Compare commits

..

1 Commits

Author SHA1 Message Date
Paul Masurel
4b131c3b40 Optimize nested term aggregations with a shared term ordinal cache
When a string term aggregation is nested inside another bucket
aggregation, each parent bucket previously resolved term ordinals
to strings independently via `sorted_ords_to_term_cb`. This was
redundant since sibling buckets share the same term dictionary.

Introduce `TermOrdToStrCache` which resolves all observed term
ordinals once across all parent buckets, then stores the results
in a `StringArena` (a single contiguous String buffer) with
compact `StringRef` handles. The cache uses either a dense
`Vec<Option<StringRef>>` or a sparse `FxHashMap<u64, StringRef>`
depending on the ordinal range.

Also add `collect_term_ords` to the `TermAggregationMap` trait so
each map variant can export its live ordinals into a shared set.

Closes #2892
2026-04-25 18:37:35 +02:00
34 changed files with 838 additions and 2294 deletions

View File

@@ -4,9 +4,6 @@ on:
push:
branches: [main]
permissions:
contents: read
# Ensures that we cancel running jobs for the same PR / same workflow.
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@@ -15,20 +12,16 @@ concurrency:
jobs:
coverage:
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- uses: actions/checkout@v6
- name: Install Rust
run: rustup toolchain install nightly-2025-12-01 --profile minimal --component llvm-tools-preview
- uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1
- uses: taiki-e/install-action@e4b3a0453201addddc06d3a72db90326aad87084 # cargo-llvm-cov
- uses: Swatinem/rust-cache@v2
- uses: taiki-e/install-action@cargo-llvm-cov
- name: Generate code coverage
run: cargo +nightly-2025-12-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
- name: Upload coverage to Codecov
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
uses: codecov/codecov-action@v6
continue-on-error: true
with:
token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos

View File

@@ -8,9 +8,6 @@ env:
CARGO_TERM_COLOR: always
NUM_FUNCTIONAL_TEST_ITERATIONS: 20000
permissions:
contents: read
# Ensures that we cancel running jobs for the same PR / same workflow.
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@@ -21,13 +18,10 @@ jobs:
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- uses: actions/checkout@v6
- name: Install stable
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af # v1.0.7
uses: actions-rs/toolchain@v1
with:
toolchain: stable
profile: minimal

View File

@@ -1,49 +0,0 @@
name: OpenSSF Scorecard
on:
schedule:
- cron: '0 0 * * 0'
push:
branches:
- main
permissions:
contents: read
jobs:
analysis:
name: Scorecards analysis
runs-on: ubuntu-latest
permissions:
# Needed to upload the results to code-scanning dashboard.
security-events: write
# Needed to publish results
id-token: write
steps:
- name: 'Checkout code'
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: 'Run analysis'
uses: ossf/scorecard-action@4eaacf0543bb3f2c246792bd56e8cdeffafb205a # v2.4.3
with:
results_file: results.sarif
results_format: sarif
repo_token: ${{ secrets.GITHUB_TOKEN }}
publish_results: true
# Upload the results as artifacts.
- name: 'Upload artifact'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: SARIF file
path: results.sarif
retention-days: 5
# Upload the results to GitHub's code scanning dashboard.
- name: 'Upload to code-scanning'
uses: github/codeql-action/upload-sarif@95e58e9a2cdfd71adc6e0353d5c52f41a045d225 # v4.35.2
with:
sarif_file: results.sarif

View File

@@ -9,9 +9,6 @@ on:
env:
CARGO_TERM_COLOR: always
permissions:
contents: read
# Ensures that we cancel running jobs for the same PR / same workflow.
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@@ -22,27 +19,23 @@ jobs:
runs-on: ubuntu-latest
permissions:
contents: read
checks: write
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- uses: actions/checkout@v6
- name: Install nightly
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af # v1.0.7
uses: actions-rs/toolchain@v1
with:
toolchain: nightly
profile: minimal
components: rustfmt
- name: Install stable
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af # v1.0.7
uses: actions-rs/toolchain@v1
with:
toolchain: stable
profile: minimal
components: clippy
- uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1
- uses: Swatinem/rust-cache@v2
- name: Check Formatting
run: cargo +nightly fmt --all -- --check
@@ -54,7 +47,7 @@ jobs:
- name: Check Bench Compilation
run: cargo +nightly bench --no-run --profile=dev --all-features
- uses: actions-rs/clippy-check@b5b5f21f4797c02da247df37026fcd0a5024aa4d # v1.0.7
- uses: actions-rs/clippy-check@v1
with:
toolchain: stable
token: ${{ secrets.GITHUB_TOKEN }}
@@ -64,9 +57,6 @@ jobs:
runs-on: ubuntu-latest
permissions:
contents: read
strategy:
matrix:
features:
@@ -77,17 +67,17 @@ jobs:
name: test-${{ matrix.features.label}}
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- uses: actions/checkout@v6
- name: Install stable
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af # v1.0.7
uses: actions-rs/toolchain@v1
with:
toolchain: stable
profile: minimal
override: true
- uses: taiki-e/install-action@56cc9adf3a3e2c23eafb56e8acaf9d0373cb845a # nextest
- uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1
- uses: taiki-e/install-action@nextest
- uses: Swatinem/rust-cache@v2
- name: Run tests
run: |

View File

@@ -205,7 +205,3 @@ harness = false
[[bench]]
name = "query_parser_nested"
harness = false
[[bench]]
name = "intersection_bench"
harness = false

View File

@@ -1,7 +1,6 @@
[![Docs](https://docs.rs/tantivy/badge.svg)](https://docs.rs/crate/tantivy/)
[![Build Status](https://github.com/quickwit-oss/tantivy/actions/workflows/test.yml/badge.svg)](https://github.com/quickwit-oss/tantivy/actions/workflows/test.yml)
[![codecov](https://codecov.io/gh/quickwit-oss/tantivy/branch/main/graph/badge.svg)](https://codecov.io/gh/quickwit-oss/tantivy)
[![OpenSSF Scorecard](https://api.scorecard.dev/projects/github.com/quickwit-oss/tantivy/badge)](https://scorecard.dev/viewer/?uri=github.com/quickwit-oss/tantivy)
[![Join the chat at https://discord.gg/MT27AG5EVE](https://shields.io/discord/908281611840282624?label=chat%20on%20discord)](https://discord.gg/MT27AG5EVE)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![Crates.io](https://img.shields.io/crates/v/tantivy.svg)](https://crates.io/crates/tantivy)

View File

@@ -79,12 +79,8 @@ fn bench_agg(mut group: InputGroup<Index>) {
register!(group, composite_histogram_calendar);
register!(group, cardinality_agg);
register!(group, cardinality_agg_high_card);
register!(group, cardinality_agg_low_card);
register!(group, terms_status_with_cardinality_agg);
register!(group, terms_100_buckets_with_cardinality_agg);
register!(group, terms_many_with_single_term_order_by_card);
register!(group, terms_many_with_single_term_2_order_by_card);
register!(group, range_agg);
register!(group, range_agg_with_avg_sub_agg);
@@ -172,32 +168,6 @@ fn cardinality_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
// Full-scan cardinality on a near-1M-cardinality string field.
// Hits the dense (PagedBitset) path: every doc has a unique term,
// so the bucket promotes from FxHashSet shortly into the scan.
fn cardinality_agg_high_card(index: &Index) {
let agg_req = json!({
"cardinality": {
"cardinality": {
"field": "text_all_unique_terms"
},
}
});
execute_agg(index, agg_req);
}
// Full-scan cardinality on a tiny-cardinality string field (7 distinct
// values). Stays on the FxHashSet path — the promotion threshold is
// never crossed. Validates no regression on the sparse path.
fn cardinality_agg_low_card(index: &Index) {
let agg_req = json!({
"cardinality": {
"cardinality": {
"field": "text_few_terms_status"
},
}
});
execute_agg(index, agg_req);
}
fn terms_status_with_cardinality_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
@@ -230,64 +200,13 @@ fn terms_100_buckets_with_cardinality_agg(index: &Index) {
execute_agg(index, agg_req);
}
fn terms_many_with_single_term_order_by_card(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_many_terms" },
"aggs": {
"nested_terms": {
"terms": {
"field": "single_term",
"order": { "cardinality": "desc" }
},
"aggs": {
"cardinality": {
"cardinality": { "field": "text_few_terms" }
}
}
}
}
},
});
execute_agg(index, agg_req);
}
// Two-level terms ordered by cardinality at each level: a high-card outer terms
// (text_many_terms) ordered by a cardinality sub-agg, with a nested low-card terms
// (text_few_terms_status) also ordered by a cardinality sub-agg, plus an avg.
fn terms_many_with_single_term_2_order_by_card(index: &Index) {
let agg_req = json!({
"by_ip": {
"terms": {
"field": "text_many_terms",
"order": { "card_few_terms": "desc" }
},
"aggs": {
"card_few_terms": {
"cardinality": { "field": "text_few_terms" }
},
"nested_terms": {
"terms": {
"field": " single_term",
"order": { "distinct_path2": "desc" }
},
"aggs": {
"avg_botscore": { "avg": { "field": "score" } },
"distinct_path2": { "cardinality": { "field": "text_few_terms" } }
}
}
}
}
});
execute_agg(index, agg_req);
}
fn terms_7(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_few_terms_status" } },
});
execute_agg(index, agg_req);
}
fn terms_all_unique(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_all_unique_terms" } },
@@ -691,8 +610,7 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
TextFieldIndexing::default().set_index_option(IndexRecordOption::WithFreqs),
)
.set_stored();
let text_field = schema_builder.add_text_field("text", text_fieldtype.clone());
let single_term = schema_builder.add_text_field("single_term", FAST);
let text_field = schema_builder.add_text_field("text", text_fieldtype);
let json_field = schema_builder.add_json_field("json", FAST);
let text_field_all_unique_terms =
schema_builder.add_text_field("text_all_unique_terms", STRING | FAST);
@@ -756,8 +674,6 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
index_writer.add_document(doc!(
json_field => json!({"mixed_type": 10.0}),
json_field => json!({"mixed_type": 10.0}),
single_term => "single_term",
single_term => "single_term",
text_field => "cool",
text_field => "cool",
text_field_all_unique_terms => "cool",
@@ -792,7 +708,6 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
json!({"mixed_type": many_terms_data.choose(&mut rng).unwrap().to_string()})
};
index_writer.add_document(doc!(
single_term => "single_term",
text_field => "cool",
json_field => json,
text_field_all_unique_terms => format!("unique_term_{}", rng.random::<u64>()),

View File

@@ -1,149 +0,0 @@
// Benchmarks top-K intersection of term scorers (block_wand_intersection).
//
// What's measured:
// - Conjunctive queries (+a +b, +a +b +c) with top-10 by score
// - Varying doc-frequency balance between terms (balanced, skewed, very skewed)
// - Realistic term frequencies (geometric distribution, mostly low)
// - 1M-doc single segment
//
// Run with: cargo bench --bench intersection_bench
use binggan::{black_box, BenchRunner};
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::collector::TopDocs;
use tantivy::query::QueryParser;
use tantivy::schema::{Schema, TEXT};
use tantivy::{doc, Index, ReloadPolicy, Searcher};
const NUM_DOCS: usize = 1_000_000;
struct BenchIndex {
searcher: Searcher,
query_parser: QueryParser,
}
/// Generate term frequency from a geometric-like distribution.
/// Most values are 1, a few are 2-3, rarely higher.
/// p controls the decay: higher p → more weight on tf=1.
fn random_term_freq(rng: &mut StdRng, p: f64) -> u32 {
let mut tf = 1u32;
while tf < 10 && rng.random_bool(1.0 - p) {
tf += 1;
}
tf
}
/// Build an index with three terms (a, b, c) with given doc-frequency probabilities.
/// Each term occurrence has a realistic term frequency (geometric distribution).
/// Field length is padded with filler tokens to create varied fieldnorms.
fn build_index(p_a: f64, p_b: f64, p_c: f64) -> BenchIndex {
let mut schema_builder = Schema::builder();
let body = schema_builder.add_text_field("body", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut rng = StdRng::from_seed([42u8; 32]);
{
let mut writer = index.writer_with_num_threads(1, 500_000_000).unwrap();
for _ in 0..NUM_DOCS {
let mut tokens: Vec<String> = Vec::new();
if rng.random_bool(p_a) {
let tf = random_term_freq(&mut rng, 0.7);
for _ in 0..tf {
tokens.push("aaa".to_string());
}
}
if rng.random_bool(p_b) {
let tf = random_term_freq(&mut rng, 0.7);
for _ in 0..tf {
tokens.push("bbb".to_string());
}
}
if rng.random_bool(p_c) {
let tf = random_term_freq(&mut rng, 0.7);
for _ in 0..tf {
tokens.push("ccc".to_string());
}
}
// Pad with filler to create varied field lengths (5-30 tokens).
let filler_count = rng.random_range(5u32..30u32);
for _ in 0..filler_count {
tokens.push("filler".to_string());
}
let text = tokens.join(" ");
writer.add_document(doc!(body => text)).unwrap();
}
writer.commit().unwrap();
}
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::Manual)
.try_into()
.unwrap();
let searcher = reader.searcher();
let query_parser = QueryParser::for_index(&index, vec![body]);
BenchIndex {
searcher,
query_parser,
}
}
fn main() {
// Scenarios: (label, p_a, p_b, p_c)
//
// "balanced": all terms ~10% → intersection ~1% of docs
// "skewed": one common (50%), one rare (2%) → intersection ~1%
// "very_skewed": one very common (80%), one very rare (0.5%) → intersection ~0.4%
// "three_balanced": three terms ~20% each → intersection ~0.8%
// "three_skewed": 50% / 10% / 2% → intersection ~0.1%
let scenarios: Vec<(&str, f64, f64, f64)> = vec![
("balanced_10%_10%", 0.10, 0.10, 0.0),
("skewed_50%_2%", 0.50, 0.02, 0.0),
("very_skewed_80%_0.5%", 0.80, 0.005, 0.0),
("three_balanced_20%_20%_20%", 0.20, 0.20, 0.20),
("three_skewed_50%_10%_2%", 0.50, 0.10, 0.02),
];
let mut runner = BenchRunner::new();
for (label, p_a, p_b, p_c) in &scenarios {
let bench_index = build_index(*p_a, *p_b, *p_c);
let mut group = runner.new_group();
group.set_name(format!("intersection — {label}"));
// Two-term intersection
if *p_a > 0.0 && *p_b > 0.0 {
let query_str = "+aaa +bbb";
let query = bench_index.query_parser.parse_query(query_str).unwrap();
let searcher = bench_index.searcher.clone();
group.register(format!("{query_str} top10"), move |_| {
let collector = TopDocs::with_limit(10).order_by_score();
black_box(searcher.search(&query, &collector).unwrap());
1usize
});
}
// Three-term intersection
if *p_c > 0.0 {
let query_str = "+aaa +bbb +ccc";
let query = bench_index.query_parser.parse_query(query_str).unwrap();
let searcher = bench_index.searcher.clone();
group.register(format!("{query_str} top10"), move |_| {
let collector = TopDocs::with_limit(10).order_by_score();
black_box(searcher.search(&query, &collector).unwrap());
1usize
});
}
group.run();
}
}

View File

@@ -20,8 +20,8 @@ use crate::aggregation::metric::{
build_segment_stats_collector, AverageAggregation, CardinalityAggReqData,
CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation, MaxAggregation,
MetricAggReqData, MinAggregation, SegmentCardinalityCollector, SegmentExtendedStatsCollector,
SegmentPercentilesCollector, StatsAggregation, StatsType, SumAggregation, TermOrdSet,
TopHitsAggReqData, TopHitsSegmentCollector, BITSET_MAX_TERM_ORD,
SegmentPercentilesCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData,
TopHitsSegmentCollector,
};
use crate::aggregation::segment_agg_result::{
GenericSegmentAggregationResultsCollector, SegmentAggregationCollector,
@@ -41,7 +41,7 @@ pub struct AggregationsSegmentCtx {
impl AggregationsSegmentCtx {
pub(crate) fn push_term_req_data(&mut self, data: TermsAggReqData) -> usize {
self.per_request.term_req_data.push(data);
self.per_request.term_req_data.push(Some(Box::new(data)));
self.per_request.term_req_data.len() - 1
}
pub(crate) fn push_cardinality_req_data(&mut self, data: CardinalityAggReqData) -> usize {
@@ -61,25 +61,31 @@ impl AggregationsSegmentCtx {
self.per_request.missing_term_req_data.len() - 1
}
pub(crate) fn push_histogram_req_data(&mut self, data: HistogramAggReqData) -> usize {
self.per_request.histogram_req_data.push(data);
self.per_request
.histogram_req_data
.push(Some(Box::new(data)));
self.per_request.histogram_req_data.len() - 1
}
pub(crate) fn push_range_req_data(&mut self, data: RangeAggReqData) -> usize {
self.per_request.range_req_data.push(data);
self.per_request.range_req_data.push(Some(Box::new(data)));
self.per_request.range_req_data.len() - 1
}
pub(crate) fn push_filter_req_data(&mut self, data: FilterAggReqData) -> usize {
self.per_request.filter_req_data.push(data);
self.per_request.filter_req_data.push(Some(Box::new(data)));
self.per_request.filter_req_data.len() - 1
}
pub(crate) fn push_composite_req_data(&mut self, data: CompositeAggReqData) -> usize {
self.per_request.composite_req_data.push(data);
self.per_request
.composite_req_data
.push(Some(Box::new(data)));
self.per_request.composite_req_data.len() - 1
}
#[inline]
pub(crate) fn get_term_req_data(&self, idx: usize) -> &TermsAggReqData {
&self.per_request.term_req_data[idx]
self.per_request.term_req_data[idx]
.as_deref()
.expect("term_req_data slot is empty (taken)")
}
#[inline]
pub(crate) fn get_cardinality_req_data(&self, idx: usize) -> &CardinalityAggReqData {
@@ -97,6 +103,116 @@ impl AggregationsSegmentCtx {
pub(crate) fn get_missing_term_req_data(&self, idx: usize) -> &MissingTermAggReqData {
&self.per_request.missing_term_req_data[idx]
}
#[inline]
pub(crate) fn get_histogram_req_data(&self, idx: usize) -> &HistogramAggReqData {
self.per_request.histogram_req_data[idx]
.as_deref()
.expect("histogram_req_data slot is empty (taken)")
}
#[inline]
pub(crate) fn get_range_req_data(&self, idx: usize) -> &RangeAggReqData {
self.per_request.range_req_data[idx]
.as_deref()
.expect("range_req_data slot is empty (taken)")
}
#[inline]
pub(crate) fn get_composite_req_data(&self, idx: usize) -> &CompositeAggReqData {
self.per_request.composite_req_data[idx]
.as_deref()
.expect("composite_req_data slot is empty (taken)")
}
// ---------- mutable getters ----------
#[inline]
pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData {
&mut self.per_request.stats_metric_req_data[idx]
}
#[inline]
pub(crate) fn get_cardinality_req_data_mut(
&mut self,
idx: usize,
) -> &mut CardinalityAggReqData {
&mut self.per_request.cardinality_req_data[idx]
}
#[inline]
pub(crate) fn get_histogram_req_data_mut(&mut self, idx: usize) -> &mut HistogramAggReqData {
self.per_request.histogram_req_data[idx]
.as_deref_mut()
.expect("histogram_req_data slot is empty (taken)")
}
// ---------- take / put (terms, histogram, range) ----------
/// Move out the boxed Histogram request at `idx`, leaving `None`.
#[inline]
pub(crate) fn take_histogram_req_data(&mut self, idx: usize) -> Box<HistogramAggReqData> {
self.per_request.histogram_req_data[idx]
.take()
.expect("histogram_req_data slot is empty (taken)")
}
/// Put back a Histogram request into an empty slot at `idx`.
#[inline]
pub(crate) fn put_back_histogram_req_data(
&mut self,
idx: usize,
value: Box<HistogramAggReqData>,
) {
debug_assert!(self.per_request.histogram_req_data[idx].is_none());
self.per_request.histogram_req_data[idx] = Some(value);
}
/// Move out the boxed Range request at `idx`, leaving `None`.
#[inline]
pub(crate) fn take_range_req_data(&mut self, idx: usize) -> Box<RangeAggReqData> {
self.per_request.range_req_data[idx]
.take()
.expect("range_req_data slot is empty (taken)")
}
/// Put back a Range request into an empty slot at `idx`.
#[inline]
pub(crate) fn put_back_range_req_data(&mut self, idx: usize, value: Box<RangeAggReqData>) {
debug_assert!(self.per_request.range_req_data[idx].is_none());
self.per_request.range_req_data[idx] = Some(value);
}
/// Move out the boxed Filter request at `idx`, leaving `None`.
#[inline]
pub(crate) fn take_filter_req_data(&mut self, idx: usize) -> Box<FilterAggReqData> {
self.per_request.filter_req_data[idx]
.take()
.expect("filter_req_data slot is empty (taken)")
}
/// Put back a Filter request into an empty slot at `idx`.
#[inline]
pub(crate) fn put_back_filter_req_data(&mut self, idx: usize, value: Box<FilterAggReqData>) {
debug_assert!(self.per_request.filter_req_data[idx].is_none());
self.per_request.filter_req_data[idx] = Some(value);
}
/// Move out the Composite request at `idx`.
#[inline]
pub(crate) fn take_composite_req_data(&mut self, idx: usize) -> Box<CompositeAggReqData> {
self.per_request.composite_req_data[idx]
.take()
.expect("composite_req_data slot is empty (taken)")
}
/// Put back a Composite request into an empty slot at `idx`.
#[inline]
pub(crate) fn put_back_composite_req_data(
&mut self,
idx: usize,
value: Box<CompositeAggReqData>,
) {
debug_assert!(self.per_request.composite_req_data[idx].is_none());
self.per_request.composite_req_data[idx] = Some(value);
}
}
/// Each type of aggregation has its own request data struct. This struct holds
@@ -107,14 +223,15 @@ impl AggregationsSegmentCtx {
/// for a node with [AggKind::Terms]).
#[derive(Default)]
pub struct PerRequestAggSegCtx {
// Box for cheap take/put - Only necessary for bucket aggs that have sub-aggregations
/// TermsAggReqData contains the request data for a terms aggregation.
pub term_req_data: Vec<TermsAggReqData>,
pub term_req_data: Vec<Option<Box<TermsAggReqData>>>,
/// HistogramAggReqData contains the request data for a histogram aggregation.
pub histogram_req_data: Vec<HistogramAggReqData>,
pub histogram_req_data: Vec<Option<Box<HistogramAggReqData>>>,
/// RangeAggReqData contains the request data for a range aggregation.
pub range_req_data: Vec<RangeAggReqData>,
pub range_req_data: Vec<Option<Box<RangeAggReqData>>>,
/// FilterAggReqData contains the request data for a filter aggregation.
pub filter_req_data: Vec<FilterAggReqData>,
pub filter_req_data: Vec<Option<Box<FilterAggReqData>>>,
/// Shared by avg, min, max, sum, stats, extended_stats, count
pub stats_metric_req_data: Vec<MetricAggReqData>,
/// CardinalityAggReqData contains the request data for a cardinality aggregation.
@@ -124,7 +241,7 @@ pub struct PerRequestAggSegCtx {
/// MissingTermAggReqData contains the request data for a missing term aggregation.
pub missing_term_req_data: Vec<MissingTermAggReqData>,
/// CompositeAggReqData contains the request data for a composite aggregation.
pub composite_req_data: Vec<CompositeAggReqData>,
pub composite_req_data: Vec<Option<Box<CompositeAggReqData>>>,
/// Request tree used to build collectors.
pub agg_tree: Vec<AggRefNode>,
@@ -135,22 +252,22 @@ impl PerRequestAggSegCtx {
fn get_memory_consumption(&self) -> usize {
self.term_req_data
.iter()
.map(|t| t.get_memory_consumption())
.map(|b| b.as_ref().unwrap().get_memory_consumption())
.sum::<usize>()
+ self
.histogram_req_data
.iter()
.map(|t| t.get_memory_consumption())
.map(|b| b.as_ref().unwrap().get_memory_consumption())
.sum::<usize>()
+ self
.range_req_data
.iter()
.map(|t| t.get_memory_consumption())
.map(|b| b.as_ref().unwrap().get_memory_consumption())
.sum::<usize>()
+ self
.filter_req_data
.iter()
.map(|t| t.get_memory_consumption())
.map(|b| b.as_ref().unwrap().get_memory_consumption())
.sum::<usize>()
+ self
.stats_metric_req_data
@@ -175,7 +292,7 @@ impl PerRequestAggSegCtx {
+ self
.composite_req_data
.iter()
.map(|t| t.get_memory_consumption())
.map(|b| b.as_ref().map(|d| d.get_memory_consumption()).unwrap_or(0))
.sum::<usize>()
+ self.agg_tree.len() * std::mem::size_of::<AggRefNode>()
}
@@ -184,16 +301,40 @@ impl PerRequestAggSegCtx {
let idx = node.idx_in_req_data;
let kind = node.kind;
match kind {
AggKind::Terms => self.term_req_data[idx].name.as_str(),
AggKind::Terms => self.term_req_data[idx]
.as_deref()
.expect("term_req_data slot is empty (taken)")
.name
.as_str(),
AggKind::Cardinality => &self.cardinality_req_data[idx].name,
AggKind::StatsKind(_) => &self.stats_metric_req_data[idx].name,
AggKind::TopHits => &self.top_hits_req_data[idx].name,
AggKind::MissingTerm => &self.missing_term_req_data[idx].name,
AggKind::Histogram => self.histogram_req_data[idx].name.as_str(),
AggKind::DateHistogram => self.histogram_req_data[idx].name.as_str(),
AggKind::Range => self.range_req_data[idx].name.as_str(),
AggKind::Filter => self.filter_req_data[idx].name.as_str(),
AggKind::Composite => self.composite_req_data[idx].name.as_str(),
AggKind::Histogram => self.histogram_req_data[idx]
.as_deref()
.expect("histogram_req_data slot is empty (taken)")
.name
.as_str(),
AggKind::DateHistogram => self.histogram_req_data[idx]
.as_deref()
.expect("histogram_req_data slot is empty (taken)")
.name
.as_str(),
AggKind::Range => self.range_req_data[idx]
.as_deref()
.expect("range_req_data slot is empty (taken)")
.name
.as_str(),
AggKind::Filter => self.filter_req_data[idx]
.as_deref()
.expect("filter_req_data slot is empty (taken)")
.name
.as_str(),
AggKind::Composite => self.composite_req_data[idx]
.as_deref()
.expect("composite_req_data slot is empty (taken)")
.name
.as_str(),
}
}
@@ -271,39 +412,13 @@ pub(crate) fn build_segment_agg_collector(
Ok(Box::new(TermMissingAgg::new(req, node)?))
}
AggKind::Cardinality => {
let req_data = req.get_cardinality_req_data(node.idx_in_req_data);
// For str columns, choose the per-bucket entries representation
// based on the segment's column.max_value():
// * small (< BITSET_MAX_TERM_ORD): `BitSet`, pre-allocated, no promotion machinery.
// * large: `TermOrdSet` (sparse FxHashSet that promotes to a paged bitset).
// For non-str columns the `entries` field is unused (values go
// straight into the HLL sketch); we still pick `TermOrdSet`
// because its empty Sparse(FxHashSet) costs nothing.
let is_str = req_data.column_type == ColumnType::Str;
let max_term_ord_inclusive = if is_str {
req_data.accessor.max_value()
} else {
0
};
let collector: Box<dyn SegmentAggregationCollector> =
if is_str && max_term_ord_inclusive < BITSET_MAX_TERM_ORD {
Box::new(SegmentCardinalityCollector::<BitSet>::from_req(
req_data.column_type,
node.idx_in_req_data,
req_data.accessor.clone(),
req_data.missing_value_for_accessor,
max_term_ord_inclusive,
))
} else {
Box::new(SegmentCardinalityCollector::<TermOrdSet>::from_req(
req_data.column_type,
node.idx_in_req_data,
req_data.accessor.clone(),
req_data.missing_value_for_accessor,
max_term_ord_inclusive,
))
};
Ok(collector)
let req_data = &mut req.get_cardinality_req_data_mut(node.idx_in_req_data);
Ok(Box::new(SegmentCardinalityCollector::from_req(
req_data.column_type,
node.idx_in_req_data,
req_data.accessor.clone(),
req_data.missing_value_for_accessor,
)))
}
AggKind::StatsKind(stats_type) => {
let req_data = &mut req.per_request.stats_metric_req_data[node.idx_in_req_data];
@@ -318,7 +433,7 @@ pub(crate) fn build_segment_agg_collector(
SegmentExtendedStatsCollector::from_req(req_data, sigma),
)),
StatsType::Percentiles => {
let req_data = req.get_metric_req_data(node.idx_in_req_data);
let req_data = req.get_metric_req_data_mut(node.idx_in_req_data);
Ok(Box::new(
SegmentPercentilesCollector::from_req_and_validate(
req_data.field_type,
@@ -658,18 +773,23 @@ fn build_nodes(
let schema = reader.schema();
let tokenizers = &data.context.tokenizers;
let query = filter_req.parse_query(schema, tokenizers)?;
let evaluator =
std::rc::Rc::new(crate::aggregation::bucket::DocumentQueryEvaluator::new(
query,
schema.clone(),
reader,
)?);
let evaluator = crate::aggregation::bucket::DocumentQueryEvaluator::new(
query,
schema.clone(),
reader,
)?;
// Pre-allocate buffer for batch filtering
let max_doc = reader.max_doc();
let buffer_capacity = crate::docset::COLLECT_BLOCK_BUFFER_LEN.min(max_doc as usize);
let matching_docs_buffer = Vec::with_capacity(buffer_capacity);
let idx_in_req_data = data.push_filter_req_data(FilterAggReqData {
name: agg_name.to_string(),
req: filter_req.clone(),
segment_reader: reader.clone(),
evaluator,
matching_docs_buffer,
is_top_level,
});
let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?;
@@ -886,20 +1006,10 @@ fn build_terms_or_cardinality_nodes(
(idx_in_req_data, AggKind::Terms)
}
TermsOrCardinalityRequest::Cardinality(ref req) => {
// `str_dict_column` is computed once per field; for JSON paths
// with mixed types it's `Some` even on the numeric req_data.
// Cardinality only consults it for the str column path, so
// gate by column_type to avoid driving non-str collectors
// through the coupon-cache path.
let str_dict_column_for_req = if column_type == ColumnType::Str {
str_dict_column.clone()
} else {
None
};
let idx_in_req_data = data.push_cardinality_req_data(CardinalityAggReqData {
accessor,
column_type,
str_dict_column: str_dict_column_for_req,
str_dict_column: str_dict_column.clone(),
missing_value_for_accessor,
name: agg_name.to_string(),
req: req.clone(),

View File

@@ -16,7 +16,6 @@ use crate::{SegmentReader, TantivyError};
/// Contains all information required by the SegmentCompositeCollector to perform the
/// composite aggregation on a segment.
#[derive(Debug, Clone)]
pub struct CompositeAggReqData {
/// The name of the aggregation.
pub name: String,
@@ -35,7 +34,6 @@ impl CompositeAggReqData {
}
/// Accessors for a single column in a composite source.
#[derive(Debug, Clone)]
pub struct CompositeAccessor {
/// The fast field column
pub column: Column<u64>,
@@ -50,7 +48,6 @@ pub struct CompositeAccessor {
}
/// Accessors to all the columns that belong to the field of a composite source.
#[derive(Debug, Clone)]
pub struct CompositeSourceAccessors {
/// The accessors for this source
pub accessors: Vec<CompositeAccessor>,
@@ -361,7 +358,7 @@ impl PrecomputedDateInterval {
///
/// Some column types (term, IP) might not have an exact representation of the
/// specified after key
#[derive(Debug, Clone)]
#[derive(Debug)]
pub enum PrecomputedAfterKey {
/// The after key could be exactly represented in the column space.
Exact(u64),

View File

@@ -118,7 +118,7 @@ impl InternalValueRepr {
pub struct SegmentCompositeCollector {
/// One DynArrayHeapMap per parent bucket.
parent_buckets: Vec<DynArrayHeapMap<InternalValueRepr, CompositeBucketCollector>>,
req_data: CompositeAggReqData,
accessor_idx: usize,
sub_agg: Option<BufferedSubAggs<HighCardSubAggBuffer>>,
bucket_id_provider: BucketIdProvider,
/// Number of sources, needed when creating new DynArrayHeapMaps.
@@ -132,7 +132,10 @@ impl SegmentAggregationCollector for SegmentCompositeCollector {
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = self.req_data.name.clone();
let name = agg_data
.get_composite_req_data(self.accessor_idx)
.name
.clone();
let buckets = self.add_intermediate_bucket_result(agg_data, parent_bucket_id)?;
results.push(
@@ -150,11 +153,12 @@ impl SegmentAggregationCollector for SegmentCompositeCollector {
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let mem_pre = self.get_memory_consumption(parent_bucket_id);
let composite_agg_data = agg_data.take_composite_req_data(self.accessor_idx);
for doc in docs {
let mut visitor = CompositeKeyVisitor {
doc_id: *doc,
composite_agg_data: &self.req_data,
composite_agg_data: &composite_agg_data,
buckets: &mut self.parent_buckets[parent_bucket_id as usize],
sub_agg: &mut self.sub_agg,
bucket_id_provider: &mut self.bucket_id_provider,
@@ -162,6 +166,7 @@ impl SegmentAggregationCollector for SegmentCompositeCollector {
};
visitor.visit(0, true)?;
}
agg_data.put_back_composite_req_data(self.accessor_idx, composite_agg_data);
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.check_flush_local(agg_data)?;
@@ -194,17 +199,6 @@ impl SegmentAggregationCollector for SegmentCompositeCollector {
}
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// Composite is a multi-bucket agg with no single value to extract.
None
}
}
impl SegmentCompositeCollector {
@@ -216,13 +210,7 @@ impl SegmentCompositeCollector {
req_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let composite_req_data =
req_data.per_request.composite_req_data[node.idx_in_req_data].clone();
validate_req(&composite_req_data)?;
req_data
.context
.limits
.add_memory_consumed(composite_req_data.get_memory_consumption() as u64)?;
validate_req(req_data, node.idx_in_req_data)?;
let has_sub_aggregations = !node.children.is_empty();
let sub_agg = if has_sub_aggregations {
@@ -232,11 +220,12 @@ impl SegmentCompositeCollector {
None
};
let composite_req_data = req_data.get_composite_req_data(node.idx_in_req_data);
let num_sources = composite_req_data.req.sources.len();
Ok(SegmentCompositeCollector {
parent_buckets: vec![DynArrayHeapMap::try_new(num_sources)?],
req_data: composite_req_data,
accessor_idx: node.idx_in_req_data,
sub_agg,
bucket_id_provider: BucketIdProvider::default(),
num_sources,
@@ -258,7 +247,7 @@ impl SegmentCompositeCollector {
let mut dict: FxHashMap<Vec<CompositeIntermediateKey>, IntermediateCompositeBucketEntry> =
Default::default();
dict.reserve(heap_map.size());
let composite_data = &self.req_data;
let composite_data = agg_data.get_composite_req_data(self.accessor_idx);
for (key_internal_repr, agg) in heap_map.into_iter() {
let key = resolve_key(&key_internal_repr, composite_data)?;
let mut sub_aggregation_res = IntermediateAggregationResults::default();
@@ -298,7 +287,8 @@ impl SegmentCompositeCollector {
}
}
fn validate_req(composite_data: &CompositeAggReqData) -> crate::Result<()> {
fn validate_req(req_data: &mut AggregationsSegmentCtx, accessor_idx: usize) -> crate::Result<()> {
let composite_data = req_data.get_composite_req_data(accessor_idx);
let req = &composite_data.req;
if req.sources.is_empty() {
return Err(TantivyError::InvalidArgument(

View File

@@ -559,30 +559,34 @@ mod tests {
page_size,
agg_req,
);
assert!(
res["my_composite"].get("after_key").is_some(),
"expected after_key on every non-empty page"
);
after_key = Some(res["my_composite"]["after_key"].clone());
}
// Using the after_key from the last page must yield an empty page.
let agg_req_json = json!({
"my_composite": {
"composite": {
"sources": composite_agg_sources,
"size": page_size,
"after": after_key,
}
if page_idx + 1 < page_count {
assert!(
res["my_composite"].get("after_key").is_some(),
"expected after_key on all but last page"
);
after_key = Some(res["my_composite"]["after_key"].clone());
} else if res["my_composite"].get("after_key").is_some() {
// currently we sometime have an after_key on the last page,
// check that the next "page" is empty
let agg_req_json = json!({
"my_composite": {
"composite": {
"sources": composite_agg_sources,
"size": page_size,
"after": res["my_composite"]["after_key"].clone(),
}
}
});
let agg_req: Aggregations = serde_json::from_value(agg_req_json).unwrap();
let res = exec_request(agg_req.clone(), index).unwrap();
assert_eq!(
res["my_composite"]["buckets"],
json!([]),
"expected no buckets when using after_key from last page, query: {:?}",
agg_req
);
}
});
let agg_req: Aggregations = serde_json::from_value(agg_req_json).unwrap();
let res = exec_request(agg_req.clone(), index).unwrap();
assert_eq!(
res["my_composite"]["buckets"],
json!([]),
"expected no buckets when using after_key from last page, query: {:?}",
agg_req
);
}
}
}
@@ -707,28 +711,8 @@ mod tests {
{"key": {"myterm": "terme"}, "doc_count": 1}
])
);
// paginating past last page should be empty
let agg_req_json = json!({
"my_composite": {
"composite": {
"sources": [
{"myterm": {"terms": {"field": "string_id"}}}
],
"size": 3,
"after": &res["my_composite"]["after_key"]
}
}
});
let agg_req: Aggregations = serde_json::from_value(agg_req_json).unwrap();
let res = exec_request(agg_req.clone(), &index).unwrap();
assert!(res["my_composite"].get("after_key").is_none());
assert_eq!(
res["my_composite"]["buckets"],
json!([]),
"expected no buckets when using after_key from last page, query: {:?}",
agg_req
);
Ok(())
}
@@ -836,10 +820,7 @@ mod tests {
{"key": {"myterm": "apple"}, "doc_count": 1}
])
);
assert_eq!(
res["fruity_aggreg"]["after_key"],
json!({"myterm": "str:apple"})
);
assert!(res["fruity_aggreg"].get("after_key").is_none());
Ok(())
}
@@ -1811,14 +1792,7 @@ mod tests {
{"key": {"month": ms_timestamp_from_iso_str("2021-02-01T00:00:00Z"), "category": "books"}, "doc_count": 1},
]),
);
let feb_2021_ns = ms_timestamp_from_iso_str("2021-02-01T00:00:00Z") * 1_000_000;
assert_eq!(
res["my_composite"]["after_key"],
json!({
"month": format!("dt:{}", feb_2021_ns),
"category": "str:books"
})
);
assert!(res["my_composite"].get("after_key").is_none());
Ok(())
}

View File

@@ -1,5 +1,4 @@
use std::fmt::Debug;
use std::rc::Rc;
use common::BitSet;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
@@ -397,7 +396,6 @@ impl PartialEq for FilterAggregation {
/// Request data for filter aggregation
/// This struct holds the per-segment data needed to execute a filter aggregation
#[derive(Clone)]
pub struct FilterAggReqData {
/// The name of the filter aggregation
pub name: String,
@@ -405,20 +403,22 @@ pub struct FilterAggReqData {
pub req: FilterAggregation,
/// The segment reader
pub segment_reader: SegmentReader,
/// Document evaluator for the filter query (precomputed BitSet).
/// Wrapped in `Rc` so cloning the request data does not duplicate the (potentially large)
/// underlying BitSet.
pub evaluator: Rc<DocumentQueryEvaluator>,
/// Document evaluator for the filter query (precomputed BitSet)
/// This is built once when the request data is created
pub evaluator: DocumentQueryEvaluator,
/// Reusable buffer for matching documents to minimize allocations during collection
pub matching_docs_buffer: Vec<DocId>,
/// True if this filter aggregation is at the top level of the aggregation tree (not nested).
pub is_top_level: bool,
}
impl FilterAggReqData {
pub(crate) fn get_memory_consumption(&self) -> usize {
// Estimate: name + segment reader reference + bitset
// Estimate: name + segment reader reference + bitset + buffer capacity
self.name.len()
+ std::mem::size_of::<SegmentReader>()
+ self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes)
+ self.matching_docs_buffer.capacity() * std::mem::size_of::<DocId>()
+ std::mem::size_of::<bool>()
}
}
@@ -509,10 +509,8 @@ pub struct SegmentFilterCollector<B: SubAggBuffer> {
/// Sub-aggregation collectors
sub_aggregations: Option<BufferedSubAggs<B>>,
bucket_id_provider: BucketIdProvider,
/// Per-segment filter request data, owned by this collector.
req_data: FilterAggReqData,
/// Reusable buffer for matching documents to minimize allocations during collection.
matching_docs_buffer: Vec<DocId>,
/// Accessor index for this filter aggregation (to access FilterAggReqData)
accessor_idx: usize,
}
impl<B: SubAggBuffer> SegmentFilterCollector<B> {
@@ -520,7 +518,6 @@ impl<B: SubAggBuffer> SegmentFilterCollector<B> {
pub(crate) fn from_req_and_validate(
req: &mut AggregationsSegmentCtx,
node: &AggRefNode,
req_data: FilterAggReqData,
) -> crate::Result<Self> {
// Build sub-aggregation collectors if any
let sub_agg_collector = if !node.children.is_empty() {
@@ -530,15 +527,11 @@ impl<B: SubAggBuffer> SegmentFilterCollector<B> {
};
let sub_agg_collector = sub_agg_collector.map(BufferedSubAggs::new);
let max_doc = req_data.segment_reader.max_doc();
let buffer_capacity = crate::docset::COLLECT_BLOCK_BUFFER_LEN.min(max_doc as usize);
Ok(SegmentFilterCollector {
parent_buckets: Vec::new(),
sub_aggregations: sub_agg_collector,
req_data,
accessor_idx: node.idx_in_req_data,
bucket_id_provider: BucketIdProvider::default(),
matching_docs_buffer: Vec::with_capacity(buffer_capacity),
})
}
}
@@ -547,23 +540,18 @@ pub(crate) fn build_segment_filter_collector(
req: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
let req_data = req.per_request.filter_req_data[node.idx_in_req_data].clone();
req.context
.limits
.add_memory_consumed(req_data.get_memory_consumption() as u64)?;
let is_top_level = req_data.is_top_level;
let is_top_level = req.per_request.filter_req_data[node.idx_in_req_data]
.as_ref()
.expect("filter_req_data slot is empty")
.is_top_level;
if is_top_level {
Ok(Box::new(
SegmentFilterCollector::<LowCardSubAggBuffer>::from_req_and_validate(
req, node, req_data,
)?,
SegmentFilterCollector::<LowCardSubAggBuffer>::from_req_and_validate(req, node)?,
))
} else {
Ok(Box::new(
SegmentFilterCollector::<HighCardSubAggBuffer>::from_req_and_validate(
req, node, req_data,
)?,
SegmentFilterCollector::<HighCardSubAggBuffer>::from_req_and_validate(req, node)?,
))
}
}
@@ -573,7 +561,7 @@ impl<B: SubAggBuffer> Debug for SegmentFilterCollector<B> {
f.debug_struct("SegmentFilterCollector")
.field("buckets", &self.parent_buckets)
.field("has_sub_aggs", &self.sub_aggregations.is_some())
.field("name", &self.req_data.name)
.field("accessor_idx", &self.accessor_idx)
.finish()
}
}
@@ -610,7 +598,11 @@ impl<B: SubAggBuffer> SegmentAggregationCollector for SegmentFilterCollector<B>
};
// Get the name of this filter aggregation
let name = self.req_data.name.clone();
let name = agg_data.per_request.filter_req_data[self.accessor_idx]
.as_ref()
.expect("filter_req_data slot is empty")
.name
.clone();
results.push(
name,
@@ -631,24 +623,27 @@ impl<B: SubAggBuffer> SegmentAggregationCollector for SegmentFilterCollector<B>
}
let mut bucket = self.parent_buckets[parent_bucket_id as usize];
// Take the request data to avoid borrow checker issues with sub-aggregations
let mut req = agg_data.take_filter_req_data(self.accessor_idx);
// Use batch filtering with O(1) BitSet lookups
self.matching_docs_buffer.clear();
self.req_data
.evaluator
.filter_batch(docs, &mut self.matching_docs_buffer);
req.matching_docs_buffer.clear();
req.evaluator
.filter_batch(docs, &mut req.matching_docs_buffer);
bucket.doc_count += self.matching_docs_buffer.len() as u64;
bucket.doc_count += req.matching_docs_buffer.len() as u64;
// Batch process sub-aggregations if we have matches
if !self.matching_docs_buffer.is_empty() {
if !req.matching_docs_buffer.is_empty() {
if let Some(sub_aggs) = &mut self.sub_aggregations {
for &doc_id in &self.matching_docs_buffer {
for &doc_id in &req.matching_docs_buffer {
sub_aggs.push(bucket.bucket_id, doc_id);
}
}
}
// Put the request data back
agg_data.put_back_filter_req_data(self.accessor_idx, req);
if let Some(sub_aggs) = &mut self.sub_aggregations {
sub_aggs.check_flush_local(agg_data)?;
}
@@ -679,17 +674,6 @@ impl<B: SubAggBuffer> SegmentAggregationCollector for SegmentFilterCollector<B>
}
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// TODO: forward into the inner `sub_agg` for nested order paths (`filter.metric`).
None
}
}
/// Intermediate result for filter aggregation

View File

@@ -21,7 +21,6 @@ use crate::TantivyError;
/// Contains all information required by the SegmentHistogramCollector to perform the
/// histogram or date_histogram aggregation on a segment.
#[derive(Debug, Clone)]
pub struct HistogramAggReqData {
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
@@ -298,7 +297,7 @@ pub struct SegmentHistogramCollector {
/// One Histogram bucket per parent bucket id.
parent_buckets: Vec<HistogramBuckets>,
sub_agg: Option<HighCardBufferedSubAggs>,
req_data: HistogramAggReqData,
accessor_idx: usize,
bucket_id_provider: BucketIdProvider,
}
@@ -309,7 +308,10 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = self.req_data.name.clone();
let name = agg_data
.get_histogram_req_data(self.accessor_idx)
.name
.clone();
// TODO: avoid prepare_max_bucket here and handle empty buckets.
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let histogram = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);
@@ -326,10 +328,10 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req = agg_data.take_histogram_req_data(self.accessor_idx);
let mem_pre = self.get_memory_consumption(parent_bucket_id);
let buckets = &mut self.parent_buckets[parent_bucket_id as usize].buckets;
let req = &self.req_data;
let bounds = req.bounds;
let interval = req.req.interval;
let offset = req.offset;
@@ -359,6 +361,7 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
}
}
}
agg_data.put_back_histogram_req_data(self.accessor_idx, req);
let mem_delta = self.get_memory_consumption(parent_bucket_id) - mem_pre;
if mem_delta > 0 {
@@ -391,17 +394,6 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
}
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// Histogram is a multi-bucket agg with no single value to extract.
None
}
}
impl SegmentHistogramCollector {
@@ -424,7 +416,10 @@ impl SegmentHistogramCollector {
}
buckets.sort_unstable_by(|b1, b2| b1.key.total_cmp(&b2.key));
let is_date_agg = self.req_data.field_type == ColumnType::DateTime;
let is_date_agg = agg_data
.get_histogram_req_data(self.accessor_idx)
.field_type
== ColumnType::DateTime;
Ok(IntermediateBucketResult::Histogram {
buckets,
is_date_agg,
@@ -440,7 +435,7 @@ impl SegmentHistogramCollector {
} else {
None
};
let mut req_data = agg_data.per_request.histogram_req_data[node.idx_in_req_data].clone();
let req_data = agg_data.get_histogram_req_data_mut(node.idx_in_req_data);
req_data.req.validate()?;
if req_data.field_type == ColumnType::DateTime && !req_data.is_date_histogram {
req_data.req.normalize_date_time();
@@ -450,16 +445,12 @@ impl SegmentHistogramCollector {
max: f64::MAX,
});
req_data.offset = req_data.req.offset.unwrap_or(0.0);
agg_data
.context
.limits
.add_memory_consumed(req_data.get_memory_consumption() as u64)?;
let sub_agg = sub_agg.map(BufferedSubAggs::new);
Ok(Self {
parent_buckets: Default::default(),
sub_agg,
req_data,
accessor_idx: node.idx_in_req_data,
bucket_id_provider: BucketIdProvider::default(),
})
}

View File

@@ -28,6 +28,7 @@ mod histogram;
mod range;
mod term_agg;
mod term_missing_agg;
mod term_ord_to_str_cache;
use std::collections::HashMap;
use std::fmt;

View File

@@ -23,7 +23,6 @@ use crate::TantivyError;
/// Contains all information required by the SegmentRangeCollector to perform the
/// range aggregation on a segment.
#[derive(Debug, Clone)]
pub struct RangeAggReqData {
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
@@ -162,7 +161,7 @@ pub struct SegmentRangeCollector<B: SubAggBuffer> {
/// One for each ParentBucketId
parent_buckets: Vec<Vec<SegmentRangeAndBucketEntry>>,
column_type: ColumnType,
pub(crate) req_data: RangeAggReqData,
pub(crate) accessor_idx: usize,
sub_agg: Option<BufferedSubAggs<B>>,
/// Here things get a bit weird. We need to assign unique bucket ids across all
/// parent buckets. So we keep track of the next available bucket id here.
@@ -185,7 +184,7 @@ impl<B: SubAggBuffer> Debug for SegmentRangeCollector<B> {
f.debug_struct("SegmentRangeCollector")
.field("parent_buckets_len", &self.parent_buckets.len())
.field("column_type", &self.column_type)
.field("name", &self.req_data.name)
.field("accessor_idx", &self.accessor_idx)
.field("has_sub_agg", &self.sub_agg.is_some())
.finish()
}
@@ -240,7 +239,10 @@ impl<B: SubAggBuffer> SegmentAggregationCollector for SegmentRangeCollector<B> {
) -> crate::Result<()> {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let field_type = self.column_type;
let name = self.req_data.name.to_string();
let name = agg_data
.get_range_req_data(self.accessor_idx)
.name
.to_string();
let buckets = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);
@@ -279,15 +281,17 @@ impl<B: SubAggBuffer> SegmentAggregationCollector for SegmentRangeCollector<B> {
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req = agg_data.take_range_req_data(self.accessor_idx);
agg_data
.column_block_accessor
.fetch_block(docs, &self.req_data.accessor);
.fetch_block(docs, &req.accessor);
let buckets = &mut self.parent_buckets[parent_bucket_id as usize];
for (doc, val) in agg_data
.column_block_accessor
.iter_docid_vals(docs, &self.req_data.accessor)
.iter_docid_vals(docs, &req.accessor)
{
let bucket_pos = get_bucket_pos(val, buckets);
let bucket = &mut buckets[bucket_pos];
@@ -297,6 +301,7 @@ impl<B: SubAggBuffer> SegmentAggregationCollector for SegmentRangeCollector<B> {
}
}
agg_data.put_back_range_req_data(self.accessor_idx, req);
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.check_flush_local(agg_data)?;
}
@@ -314,26 +319,15 @@ impl<B: SubAggBuffer> SegmentAggregationCollector for SegmentRangeCollector<B> {
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
let new_buckets = self.create_new_buckets()?;
let new_buckets = self.create_new_buckets(agg_data)?;
self.parent_buckets.push(new_buckets);
}
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// Range is a multi-bucket agg with no single value to extract.
None
}
}
/// Build a concrete `SegmentRangeCollector` with either a Vec- or HashMap-backed
/// bucket storage, depending on the column type and aggregation level.
@@ -341,11 +335,8 @@ pub(crate) fn build_segment_range_collector(
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
let req_data = agg_data.per_request.range_req_data[node.idx_in_req_data].clone();
agg_data
.context
.limits
.add_memory_consumed(req_data.get_memory_consumption() as u64)?;
let accessor_idx = node.idx_in_req_data;
let req_data = agg_data.get_range_req_data(node.idx_in_req_data);
let field_type = req_data.field_type;
// TODO: A better metric instead of is_top_level would be the number of buckets expected.
@@ -363,7 +354,7 @@ pub(crate) fn build_segment_range_collector(
Ok(Box::new(SegmentRangeCollector::<LowCardSubAggBuffer> {
sub_agg: sub_agg.map(LowCardBufferedSubAggs::new),
column_type: field_type,
req_data,
accessor_idx,
parent_buckets: Vec::new(),
bucket_id_provider: BucketIdProvider::default(),
limits: agg_data.context.limits.clone(),
@@ -372,7 +363,7 @@ pub(crate) fn build_segment_range_collector(
Ok(Box::new(SegmentRangeCollector::<HighCardSubAggBuffer> {
sub_agg: sub_agg.map(BufferedSubAggs::new),
column_type: field_type,
req_data,
accessor_idx,
parent_buckets: Vec::new(),
bucket_id_provider: BucketIdProvider::default(),
limits: agg_data.context.limits.clone(),
@@ -381,9 +372,12 @@ pub(crate) fn build_segment_range_collector(
}
impl<B: SubAggBuffer> SegmentRangeCollector<B> {
pub(crate) fn create_new_buckets(&mut self) -> crate::Result<Vec<SegmentRangeAndBucketEntry>> {
pub(crate) fn create_new_buckets(
&mut self,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<Vec<SegmentRangeAndBucketEntry>> {
let field_type = self.column_type;
let req_data = &self.req_data;
let req_data = agg_data.get_range_req_data(self.accessor_idx);
// The range input on the request is f64.
// We need to convert to u64 ranges, because we read the values as u64.
// The mapping from the conversion is monotonic so ordering is preserved.
@@ -558,16 +552,17 @@ mod tests {
get_test_index_with_num_docs,
};
pub fn build_test_buckets(
pub fn get_collector_from_ranges(
ranges: Vec<RangeAggregationRange>,
field_type: ColumnType,
) -> Vec<SegmentRangeAndBucketEntry> {
) -> SegmentRangeCollector<HighCardSubAggBuffer> {
let req = RangeAggregation {
field: "dummy".to_string(),
ranges,
..Default::default()
};
extend_validate_ranges(&req.ranges, &field_type)
// Build buckets directly as in from_req_and_validate without AggregationsData
let buckets: Vec<_> = extend_validate_ranges(&req.ranges, &field_type)
.expect("unexpected error in extend_validate_ranges")
.iter()
.map(|range| {
@@ -598,7 +593,16 @@ mod tests {
},
}
})
.collect()
.collect();
SegmentRangeCollector {
parent_buckets: vec![buckets],
column_type: field_type,
accessor_idx: 0,
sub_agg: None,
bucket_id_provider: Default::default(),
limits: AggregationLimitsGuard::default(),
}
}
#[test]
@@ -842,9 +846,9 @@ mod tests {
#[test]
fn bucket_test_extend_range_hole() {
let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()];
let parent_buckets = vec![build_test_buckets(buckets, ColumnType::F64)];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = parent_buckets[0].clone();
let buckets = collector.parent_buckets[0].clone();
assert_eq!(buckets[0].range.start, u64::MIN);
assert_eq!(buckets[0].range.end, 10f64.to_u64());
assert_eq!(buckets[1].range.start, 10f64.to_u64());
@@ -865,9 +869,9 @@ mod tests {
(10f64..20f64).into(),
(20f64..f64::MAX).into(),
];
let parent_buckets = vec![build_test_buckets(buckets, ColumnType::F64)];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = parent_buckets[0].clone();
let buckets = collector.parent_buckets[0].clone();
assert_eq!(buckets[0].range.start, u64::MIN);
assert_eq!(buckets[0].range.end, 10f64.to_u64());
assert_eq!(buckets[1].range.start, 10f64.to_u64());
@@ -880,18 +884,18 @@ mod tests {
#[test]
fn bucket_range_test_negative_vals() {
let buckets = vec![(-10f64..-1f64).into()];
let parent_buckets = vec![build_test_buckets(buckets, ColumnType::F64)];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = parent_buckets[0].clone();
let buckets = collector.parent_buckets[0].clone();
assert_eq!(&buckets[0].bucket.key.to_string(), "*--10");
assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "-1-*");
}
#[test]
fn bucket_range_test_positive_vals() {
let buckets = vec![(0f64..10f64).into()];
let parent_buckets = vec![build_test_buckets(buckets, ColumnType::F64)];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = parent_buckets[0].clone();
let buckets = collector.parent_buckets[0].clone();
assert_eq!(&buckets[0].bucket.key.to_string(), "*-0");
assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "10-*");
}
@@ -899,8 +903,8 @@ mod tests {
#[test]
fn range_binary_search_test_u64() {
let check_ranges = |ranges: Vec<RangeAggregationRange>| {
let parent_buckets = vec![build_test_buckets(ranges, ColumnType::U64)];
let search = |val: u64| get_bucket_pos(val, &parent_buckets[0]);
let collector = get_collector_from_ranges(ranges, ColumnType::U64);
let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);
assert_eq!(search(u64::MIN), 0);
assert_eq!(search(9), 0);
@@ -945,8 +949,8 @@ mod tests {
fn range_binary_search_test_f64() {
let ranges = vec![(10.0..100.0).into()];
let parent_buckets = vec![build_test_buckets(ranges, ColumnType::F64)];
let search = |val: u64| get_bucket_pos(val, &parent_buckets[0]);
let collector = get_collector_from_ranges(ranges, ColumnType::F64);
let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);
assert_eq!(search(u64::MIN), 0);
assert_eq!(search(9f64.to_u64()), 0);

View File

@@ -7,9 +7,10 @@ use columnar::{
NumericalValue, StrColumn,
};
use common::{BitSet, TinySet};
use rustc_hash::FxHashMap;
use rustc_hash::{FxBuildHasher, FxHashMap, FxHashSet};
use serde::{Deserialize, Serialize};
use super::term_ord_to_str_cache::{StringArena, StringRef, TermOrdToStrCache};
use super::{CustomOrder, Order, OrderTarget};
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
@@ -352,15 +353,19 @@ pub(crate) fn build_segment_term_collector(
)));
}
// Validate that the referenced sub-aggregation exists when ordering by one.
if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target {
let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name);
node.get_sub_agg(agg_name, &req_data.per_request)
.ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"could not find aggregation with name {agg_name} in metric sub_aggregations"
))
})?;
// Validate sub aggregation exists when ordering by sub-aggregation.
{
if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target {
let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name);
node.get_sub_agg(agg_name, &req_data.per_request)
.ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"could not find aggregation with name {agg_name} in metric \
sub_aggregations"
))
})?;
}
}
// Build sub-aggregation blueprint if there are children.
@@ -393,6 +398,7 @@ pub(crate) fn build_segment_term_collector(
bucket_id_provider,
max_term_id,
terms_req_data,
term_ord_cache: None,
};
Ok(Box::new(collector))
} else if is_top_level && max_term_id < MAX_NUM_TERMS_FOR_VEC {
@@ -404,6 +410,7 @@ pub(crate) fn build_segment_term_collector(
bucket_id_provider,
max_term_id,
terms_req_data,
term_ord_cache: None,
};
Ok(Box::new(collector))
} else if max_term_id < 8_000_000 && is_top_level {
@@ -418,6 +425,7 @@ pub(crate) fn build_segment_term_collector(
bucket_id_provider,
max_term_id,
terms_req_data,
term_ord_cache: None,
};
Ok(Box::new(collector))
} else {
@@ -431,6 +439,7 @@ pub(crate) fn build_segment_term_collector(
bucket_id_provider,
max_term_id,
terms_req_data,
term_ord_cache: None,
};
Ok(Box::new(collector))
}
@@ -466,6 +475,9 @@ trait TermAggregationMap: Clone + Debug + 'static {
/// Returns the term aggregation as a vector of (term_id, bucket) pairs,
/// in any order.
fn into_vec(self) -> Vec<(u64, Bucket)>;
/// Collects all term ordinals present in this map into the given set.
fn collect_term_ords(&self, set: &mut FxHashSet<u64>);
}
#[derive(Clone, Debug)]
@@ -618,6 +630,20 @@ impl TermAggregationMap for PagedTermMap {
Self { pages, mem_usage }
}
fn collect_term_ords(&self, set: &mut FxHashSet<u64>) {
for (page_idx, page_opt) in self.pages.iter().enumerate() {
if let Some(page) = page_opt {
let base_term_id = (page_idx << PAGE_SHIFT) as u64;
for (bucket_pos, &tiny_set) in page.presence.iter().enumerate() {
let base_offset = bucket_pos * 64;
for bit in tiny_set.into_iter() {
set.insert(base_term_id + (base_offset + bit as usize) as u64);
}
}
}
}
}
}
impl TermAggregationMap for HashMapTermBuckets {
@@ -644,6 +670,10 @@ impl TermAggregationMap for HashMapTermBuckets {
fn new(_max_term_id: u64, _bucket_id_provider: &mut BucketIdProvider) -> Self {
Self::default()
}
fn collect_term_ords(&self, set: &mut FxHashSet<u64>) {
set.extend(self.bucket_map.keys().copied());
}
}
/// An optimized term map implementation for a compact set of term ordinals.
@@ -700,6 +730,14 @@ impl TermAggregationMap for VecTermBucketsNoAgg {
.collect(),
}
}
fn collect_term_ords(&self, set: &mut FxHashSet<u64>) {
for (term_id, &count) in self.buckets.iter().enumerate() {
if count > 0 {
set.insert(term_id as u64);
}
}
}
}
/// An optimized term map implementation for a compact set of term ordinals.
@@ -749,6 +787,63 @@ impl TermAggregationMap for VecTermBuckets {
.collect(),
}
}
fn collect_term_ords(&self, set: &mut FxHashSet<u64>) {
for (term_id, bucket) in self.buckets.iter().enumerate() {
if bucket.count > 0 {
set.insert(term_id as u64);
}
}
}
}
fn build_term_ord_cache<TermMap: TermAggregationMap>(
parent_buckets: &[TermMap],
dictionary: &Dictionary,
term_req: &TermsAggReqData,
) -> std::io::Result<TermOrdToStrCache> {
let capacity: usize = parent_buckets.len() * 64;
let mut term_ords_set: FxHashSet<u64> =
FxHashSet::with_capacity_and_hasher(capacity, FxBuildHasher);
for bucket in parent_buckets.iter() {
bucket.collect_term_ords(&mut term_ords_set);
}
if let Some(missing_sentinel) = term_req.missing_value_for_accessor {
term_ords_set.remove(&missing_sentinel);
}
let mut term_ords: Vec<u64> = term_ords_set.into_iter().collect();
term_ords.sort_unstable();
term_ords.pop_if(|highest_term_ord| *highest_term_ord >= dictionary.num_terms() as u64);
let mut string_arena = StringArena::default();
let mut string_refs: Vec<StringRef> = Vec::with_capacity(term_ords.len());
let all_found: bool = dictionary.sorted_ords_to_term_cb(&term_ords, |term_bytes| {
let term_str = std::str::from_utf8(term_bytes).expect("could not convert to str");
string_refs.push(string_arena.register_str(term_str));
})?;
assert!(all_found);
let missing_key: Option<IntermediateKey> =
term_req
.req
.missing
.as_ref()
.map(|missing_value| match missing_value {
Key::Str(s) => IntermediateKey::Str(s.clone()),
Key::F64(v) => IntermediateKey::F64(*v),
Key::U64(v) => IntermediateKey::U64(*v),
Key::I64(v) => IntermediateKey::I64(*v),
});
Ok(TermOrdToStrCache::new(
term_ords,
string_refs,
string_arena,
missing_key,
))
}
/// The collector puts values from the fast field into the correct buckets and does a conversion to
@@ -761,6 +856,7 @@ struct SegmentTermCollector<TermMap: TermAggregationMap, B: SubAggBuffer> {
bucket_id_provider: BucketIdProvider,
max_term_id: u64,
terms_req_data: TermsAggReqData,
term_ord_cache: Option<TermOrdToStrCache>,
}
pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) {
@@ -779,6 +875,17 @@ impl<TermMap: TermAggregationMap, B: SubAggBuffer> SegmentAggregationCollector
) -> crate::Result<()> {
// TODO: avoid prepare_max_bucket here and handle empty buckets.
self.prepare_max_bucket(bucket, agg_data)?;
if self.terms_req_data.column_type == ColumnType::Str && self.term_ord_cache.is_none() {
if let Some(str_dict_column) = &self.terms_req_data.str_dict_column {
self.term_ord_cache = Some(build_term_ord_cache(
&self.parent_buckets,
str_dict_column.dictionary(),
&self.terms_req_data,
)?);
}
}
let bucket = std::mem::replace(
&mut self.parent_buckets[bucket as usize],
TermMap::new(0, &mut self.bucket_id_provider),
@@ -793,6 +900,7 @@ impl<TermMap: TermAggregationMap, B: SubAggBuffer> SegmentAggregationCollector
.map(BufferedSubAggs::get_sub_agg_collector),
bucket,
agg_data,
self.term_ord_cache.as_ref(),
)?;
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
Ok(())
@@ -883,17 +991,6 @@ impl<TermMap: TermAggregationMap, B: SubAggBuffer> SegmentAggregationCollector
}
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// Terms is a multi-bucket agg with no single value to extract.
None
}
}
/// Missing value are represented as a sentinel value in the column.
@@ -964,9 +1061,13 @@ where
mut sub_agg_collector: Option<&mut dyn SegmentAggregationCollector>,
term_buckets: TermMap,
agg_data: &AggregationsSegmentCtx,
term_ord_cache: Option<&TermOrdToStrCache>,
) -> crate::Result<IntermediateBucketResult> {
let mut entries: Vec<(u64, Bucket)> = term_buckets.into_vec();
let order_by_sub_aggregation =
matches!(term_req.req.order.target, OrderTarget::SubAggregation(_));
match &term_req.req.order.target {
OrderTarget::Key => {
// We rely on the fact, that term ordinals match the order of the strings
@@ -978,37 +1079,10 @@ where
entries.sort_unstable_by_key(|bucket| bucket.0);
}
}
OrderTarget::SubAggregation(sub_agg_path) => {
// Peek segment-level metric values, sort, then fall through to
// `cut_off_buckets`. Like Elasticsearch, we always cut off when ordering
// by a sub-agg: top-K results are approximate and may differ from the
// global ordering, especially for non-monotonic metrics like avg/min.
let coll = sub_agg_collector.as_deref().ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"Could not find sub-aggregation collector for path {sub_agg_path}"
))
})?;
let (agg_name, agg_prop) = get_agg_name_and_property(sub_agg_path);
// Fetch values up-front; otherwise sort would re-compute per comparison
let mut keyed: Vec<(f64, (u64, Bucket))> = entries
.into_iter()
.map(|bucket| {
let metric_value = coll
.compute_metric_value(bucket.1.bucket_id, agg_name, agg_prop, agg_data)
.unwrap_or(0.0);
(metric_value, bucket)
})
.collect();
if term_req.req.order.order == Order::Desc {
keyed.sort_unstable_by(|a, b| {
b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
});
} else {
keyed.sort_unstable_by(|a, b| {
a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
});
}
entries = keyed.into_iter().map(|(_, e)| e).collect();
OrderTarget::SubAggregation(_name) => {
// don't sort and cut off since it's hard to make assumptions on the quality of the
// results when cutting off du to unknown nature of the sub_aggregation (possible
// to check).
}
OrderTarget::Count => {
if term_req.req.order.order == Order::Desc {
@@ -1019,8 +1093,11 @@ where
}
}
let (term_doc_count_before_cutoff, sum_other_doc_count) =
cut_off_buckets(&mut entries, term_req.req.segment_size as usize);
let (term_doc_count_before_cutoff, sum_other_doc_count) = if order_by_sub_aggregation {
(0, 0)
} else {
cut_off_buckets(&mut entries, term_req.req.segment_size as usize)
};
let mut dict: FxHashMap<IntermediateKey, IntermediateTermBucketEntry> = Default::default();
dict.reserve(entries.len());
@@ -1033,43 +1110,76 @@ where
.map(|el| el.dictionary())
.unwrap_or_else(|| &fallback_dict);
if let Some((intermediate_key, bucket)) = extract_missing_value(&mut entries, term_req)
{
let intermediate_entry = into_intermediate_bucket_entry(
bucket,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)?;
dict.insert(intermediate_key, intermediate_entry);
}
if let Some(cache) = term_ord_cache {
// Use cached term resolution: missing value is handled via the cache.
if let Some(missing_sentinel) = term_req.missing_value_for_accessor {
if let Some(pos) = entries.iter().position(|(tid, _)| *tid == missing_sentinel)
{
let (_tid, bucket) = entries.swap_remove(pos);
if let Some(missing_key) = cache.missing_key() {
let intermediate_entry = into_intermediate_bucket_entry(
bucket,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)?;
dict.insert(missing_key.clone(), intermediate_entry);
}
}
}
// Sort by term ord
entries.sort_unstable_by_key(|bucket| bucket.0);
let (term_ids, buckets): (Vec<u64>, Vec<Bucket>) = entries.into_iter().unzip();
let intermediate_entries: Vec<IntermediateTermBucketEntry> = buckets
.into_iter()
.map(|bucket| {
into_intermediate_bucket_entry(
for (term_ord, bucket) in entries {
if let Some(term_str) = cache.get(term_ord) {
let intermediate_entry = into_intermediate_bucket_entry(
bucket,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)?;
dict.insert(
IntermediateKey::Str(term_str.to_string()),
intermediate_entry,
);
}
}
} else {
if let Some((intermediate_key, bucket)) =
extract_missing_value(&mut entries, term_req)
{
let intermediate_entry = into_intermediate_bucket_entry(
bucket,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)
})
.collect::<crate::Result<_>>()?;
)?;
dict.insert(intermediate_key, intermediate_entry);
}
let mut intermediate_entry_it = intermediate_entries.into_iter();
// Sort by term ord
entries.sort_unstable_by_key(|bucket| bucket.0);
term_dict.sorted_ords_to_term_cb(&term_ids[..], |term| {
let intermediate_entry = intermediate_entry_it.next().unwrap();
dict.insert(
IntermediateKey::Str(
String::from_utf8(term.to_vec()).expect("could not convert to String"),
),
intermediate_entry,
);
})?;
let (term_ids, buckets): (Vec<u64>, Vec<Bucket>) = entries.into_iter().unzip();
let intermediate_entries: Vec<IntermediateTermBucketEntry> = buckets
.into_iter()
.map(|bucket| {
into_intermediate_bucket_entry(
bucket,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)
})
.collect::<crate::Result<_>>()?;
let mut intermediate_entry_it = intermediate_entries.into_iter();
term_dict.sorted_ords_to_term_cb(&term_ids[..], |term| {
let intermediate_entry = intermediate_entry_it.next().unwrap();
dict.insert(
IntermediateKey::Str(
String::from_utf8(term.to_vec()).expect("could not convert to String"),
),
intermediate_entry,
);
})?;
}
if term_req.req.min_doc_count == 0 {
// TODO: Handle rev streaming for descending sorting by keys
@@ -1190,13 +1300,13 @@ where
impl<TermMap: TermAggregationMap, B: SubAggBuffer> SegmentTermCollector<TermMap, B> {
#[inline]
fn collect_terms_with_docs(
iter: impl Iterator<Item = (crate::DocId, u64)>,
doc_term_ord_iter: impl Iterator<Item = (crate::DocId, u64)>,
term_buckets: &mut TermMap,
bucket_id_provider: &mut BucketIdProvider,
sub_agg: &mut BufferedSubAggs<B>,
) {
for (doc, term_id) in iter {
let bucket_id = term_buckets.term_entry(term_id, bucket_id_provider);
for (doc, term_ord) in doc_term_ord_iter {
let bucket_id = term_buckets.term_entry(term_ord, bucket_id_provider);
sub_agg.push(bucket_id, doc);
}
}
@@ -1795,263 +1905,6 @@ mod tests {
Ok(())
}
#[test]
fn terms_aggregation_order_by_cardinality_desc_single_segment() -> crate::Result<()> {
terms_aggregation_order_by_cardinality_desc(true)
}
#[test]
fn terms_aggregation_order_by_cardinality_desc_multi_segment() -> crate::Result<()> {
terms_aggregation_order_by_cardinality_desc(false)
}
fn terms_aggregation_order_by_cardinality_desc(merge_segments: bool) -> crate::Result<()> {
// Distinct score values per bucket key: A→5, B→1, C→3.
// Order by cardinality desc must yield A, C, B.
let segment_and_terms = vec![vec![
(1.0, "A".to_string()),
(2.0, "A".to_string()),
(3.0, "A".to_string()),
(4.0, "A".to_string()),
(5.0, "A".to_string()),
(1.0, "B".to_string()),
(1.0, "B".to_string()),
(1.0, "B".to_string()),
(1.0, "C".to_string()),
(2.0, "C".to_string()),
(3.0, "C".to_string()),
]];
let index = get_test_index_from_values_and_terms(merge_segments, &segment_and_terms)?;
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "card": "desc" }
},
"aggs": {
"card": { "cardinality": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][0]["card"]["value"], 5.0);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][1]["card"]["value"], 3.0);
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
assert_eq!(res["my_texts"]["buckets"][2]["card"]["value"], 1.0);
// Asc engages the segment-cutoff path too (monotonic-safe: discarded buckets had
// local card >= cutoff, so merged card >= cutoff and they cannot be globally smallest).
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "card": "asc" }
},
"aggs": {
"card": { "cardinality": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "B");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][2]["key"], "A");
// size=2 with desc engages the segment cutoff: must keep top-2 by cardinality (A, C),
// and `sum_other_doc_count` reflects the dropped B (3 docs).
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"size": 2,
"order": { "card": "desc" }
},
"aggs": {
"card": { "cardinality": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"].as_array().unwrap().len(), 2);
// size=2 with asc engages the segment cutoff: must keep bottom-2 by cardinality (B, C).
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"size": 2,
"order": { "card": "asc" }
},
"aggs": {
"card": { "cardinality": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "B");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"].as_array().unwrap().len(), 2);
Ok(())
}
#[test]
fn terms_aggregation_order_by_sum_single_segment() -> crate::Result<()> {
terms_aggregation_order_by_sum(true)
}
#[test]
fn terms_aggregation_order_by_sum_multi_segment() -> crate::Result<()> {
terms_aggregation_order_by_sum(false)
}
fn terms_aggregation_order_by_sum(merge_segments: bool) -> crate::Result<()> {
// Per-bucket sums on the U64 `score` column (non-negative => sum is monotonic):
// A → 1+2+3+4+5 = 15, B → 1+1+1 = 3, C → 1+2+3 = 6.
let segment_and_terms = vec![
vec![
(1.0, "A".to_string()),
(2.0, "A".to_string()),
(3.0, "A".to_string()),
(1.0, "B".to_string()),
(1.0, "C".to_string()),
],
vec![
(4.0, "A".to_string()),
(5.0, "A".to_string()),
(1.0, "B".to_string()),
(1.0, "B".to_string()),
(2.0, "C".to_string()),
(3.0, "C".to_string()),
],
];
let index = get_test_index_from_values_and_terms(merge_segments, &segment_and_terms)?;
// Desc on a Sum metric engages the fast path (column is U64).
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "total": "desc" }
},
"aggs": {
"total": { "sum": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][0]["total"]["value"], 15.0);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][1]["total"]["value"], 6.0);
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
assert_eq!(res["my_texts"]["buckets"][2]["total"]["value"], 3.0);
// Asc engages the fast path too — discarded buckets had local sum >= cutoff,
// and merged sum >= local (non-negative addends), so they cannot be globally smallest.
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "total": "asc" }
},
"aggs": {
"total": { "sum": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "B");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][2]["key"], "A");
// size=2 desc with cutoff: top-2 by sum (A, C).
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"size": 2,
"order": { "total": "desc" }
},
"aggs": {
"total": { "sum": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"].as_array().unwrap().len(), 2);
// Stats sub-property: ordering by `mystats.sum` on a U64 column also engages.
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "mystats.sum": "desc" }
},
"aggs": {
"mystats": { "stats": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
// Sum on a signed column (I64) takes the same cutoff path. Results may be
// approximate near the boundary on adversarial data, but for this dataset the
// top-K is unambiguous.
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "total": "desc" }
},
"aggs": {
"total": { "sum": { "field": "score_i64" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
// Order by extended_stats sub-property exercises compute_metric_value on the
// ExtendedStats collector. A→max=5, B→max=1, C→max=3, so desc by max → A, C, B.
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "ext.max": "desc" }
},
"aggs": {
"ext": { "extended_stats": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
Ok(())
}
#[test]
fn terms_aggregation_test_order_key_single_segment() -> crate::Result<()> {
terms_aggregation_test_order_key_merge_segment(true)

View File

@@ -177,17 +177,6 @@ impl SegmentAggregationCollector for TermMissingAgg {
}
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// TODO: forward to `sub_agg` for nested order paths (`missing_agg>metric`).
None
}
}
#[cfg(test)]

View File

@@ -0,0 +1,196 @@
use rustc_hash::FxHashMap;
use crate::aggregation::intermediate_agg_result::IntermediateKey;
#[derive(Clone, Copy, Debug)]
pub(crate) struct StringRef {
start: u32,
len: u32,
}
#[derive(Clone, Debug, Default)]
pub(crate) struct StringArena {
buffer: String,
}
impl StringArena {
pub fn register_str(&mut self, value: &str) -> StringRef {
let start = self.buffer.len() as u32;
self.buffer.push_str(value);
StringRef {
start,
len: value.len() as u32,
}
}
pub fn get_str(&self, string_ref: StringRef) -> &str {
let start = string_ref.start as usize;
let end = start + string_ref.len as usize;
&self.buffer[start..end]
}
pub fn len(&self) -> usize {
self.buffer.len()
}
}
pub(crate) struct TermOrdToStrCache {
string_arena: StringArena,
missing_key: Option<IntermediateKey>,
addr: TermOrdToAddr,
}
enum TermOrdToAddr {
Dense { offsets: Vec<Option<StringRef>> },
Sparse { terms: FxHashMap<u64, StringRef> },
}
impl std::fmt::Debug for TermOrdToStrCache {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match &self.addr {
TermOrdToAddr::Dense { offsets } => f
.debug_struct("TermOrdToStrCache::Dense")
.field("num_slots", &offsets.len())
.field("arena_bytes", &self.string_arena.len())
.finish(),
TermOrdToAddr::Sparse { terms } => f
.debug_struct("TermOrdToStrCache::Sparse")
.field("num_terms", &terms.len())
.field("arena_bytes", &self.string_arena.len())
.finish(),
}
}
}
impl TermOrdToStrCache {
/// `term_ords` must be sorted. Each entry in `string_refs` is a reference
/// into `string_arena` for the corresponding term ord.
pub fn new(
term_ords: Vec<u64>,
string_refs: Vec<StringRef>,
string_arena: StringArena,
missing_key: Option<IntermediateKey>,
) -> TermOrdToStrCache {
let num_terms = term_ords.len();
assert_eq!(num_terms, string_refs.len());
if term_ords.is_empty() {
return TermOrdToStrCache {
string_arena,
missing_key,
addr: TermOrdToAddr::Dense {
offsets: Vec::new(),
},
};
}
let highest_term_ord = term_ords.last().copied().unwrap_or(0u64);
let should_use_dense =
highest_term_ord < 1_000_000u64 || highest_term_ord < num_terms as u64 * 3u64;
let addr = if should_use_dense {
let num_slots = highest_term_ord as usize + 1;
let mut offsets: Vec<Option<StringRef>> = vec![None; num_slots];
for (term_ord, string_ref) in term_ords.into_iter().zip(string_refs) {
offsets[term_ord as usize] = Some(string_ref);
}
TermOrdToAddr::Dense { offsets }
} else {
let terms: FxHashMap<u64, StringRef> = term_ords.into_iter().zip(string_refs).collect();
TermOrdToAddr::Sparse { terms }
};
TermOrdToStrCache {
string_arena,
missing_key,
addr,
}
}
pub fn get(&self, term_ord: u64) -> Option<&str> {
let string_ref = match &self.addr {
TermOrdToAddr::Dense { offsets } => (*offsets.get(term_ord as usize)?)?,
TermOrdToAddr::Sparse { terms } => *terms.get(&term_ord)?,
};
Some(self.string_arena.get_str(string_ref))
}
pub fn missing_key(&self) -> Option<&IntermediateKey> {
self.missing_key.as_ref()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_string_arena() {
let mut arena = StringArena::default();
let r1 = arena.register_str("hello");
let r2 = arena.register_str("world");
let r3 = arena.register_str("");
let r4 = arena.register_str("!");
assert_eq!(arena.get_str(r1), "hello");
assert_eq!(arena.get_str(r2), "world");
assert_eq!(arena.get_str(r3), "");
assert_eq!(arena.get_str(r4), "!");
}
fn build_arena(terms: &[&str]) -> (StringArena, Vec<StringRef>) {
let mut arena = StringArena::default();
let refs: Vec<StringRef> = terms.iter().map(|t| arena.register_str(t)).collect();
(arena, refs)
}
#[test]
fn test_dense_cache() {
let term_ords = vec![0, 2, 5];
let (arena, refs) = build_arena(&["alpha", "beta", "gamma"]);
let cache = TermOrdToStrCache::new(term_ords, refs, arena, None);
assert_eq!(cache.get(0), Some("alpha"));
assert_eq!(cache.get(1), None);
assert_eq!(cache.get(2), Some("beta"));
assert_eq!(cache.get(3), None);
assert_eq!(cache.get(5), Some("gamma"));
assert_eq!(cache.get(6), None);
assert_eq!(cache.get(100), None);
}
#[test]
fn test_sparse_cache() {
let term_ords = vec![1_000_000, 2_000_000, 5_000_000];
let (arena, refs) = build_arena(&["foo", "bar", "baz"]);
let cache = TermOrdToStrCache::new(term_ords, refs, arena, None);
assert_eq!(cache.get(1_000_000), Some("foo"));
assert_eq!(cache.get(2_000_000), Some("bar"));
assert_eq!(cache.get(5_000_000), Some("baz"));
assert_eq!(cache.get(0), None);
assert_eq!(cache.get(3_000_000), None);
}
#[test]
fn test_empty_cache() {
let cache = TermOrdToStrCache::new(Vec::new(), Vec::new(), StringArena::default(), None);
assert_eq!(cache.get(0), None);
assert_eq!(cache.get(100), None);
}
#[test]
fn test_missing_key() {
let missing = IntermediateKey::Str("N/A".to_string());
let (arena, refs) = build_arena(&["x"]);
let cache = TermOrdToStrCache::new(vec![0], refs, arena, Some(missing));
assert_eq!(
cache.missing_key(),
Some(&IntermediateKey::Str("N/A".to_string()))
);
let (arena, refs) = build_arena(&["x"]);
let cache_no_missing = TermOrdToStrCache::new(vec![0], refs, arena, None);
assert_eq!(cache_no_missing.missing_key(), None);
}
}

View File

@@ -1004,20 +1004,24 @@ impl IntermediateCompositeBucketResult {
) -> crate::Result<BucketResult> {
let trimmed_entry_vec =
trim_composite_buckets(self.entries, &self.orders, self.target_size)?;
let after_key = trimmed_entry_vec
.last()
.map(|bucket| {
let (intermediate_key, _entry) = bucket;
intermediate_key
.iter()
.enumerate()
.map(|(idx, intermediate_key)| {
let source = &req.sources[idx];
(source.name().to_string(), intermediate_key.clone().into())
})
.collect()
})
.unwrap_or_default();
let after_key = if trimmed_entry_vec.len() == req.size as usize {
trimmed_entry_vec
.last()
.map(|bucket| {
let (intermediate_key, _entry) = bucket;
intermediate_key
.iter()
.enumerate()
.map(|(idx, intermediate_key)| {
let source = &req.sources[idx];
(source.name().to_string(), intermediate_key.clone().into())
})
.collect()
})
.unwrap()
} else {
FxHashMap::default()
};
let buckets = trimmed_entry_vec
.into_iter()

View File

@@ -4,7 +4,6 @@ use std::io;
use columnar::column_values::CompactSpaceU64Accessor;
use columnar::{Column, ColumnType, Dictionary, StrColumn};
use common::{BitSet, TinySet};
use datasketches::hll::{Coupon, HllSketch, HllType, HllUnion};
use rustc_hash::{FxBuildHasher, FxHashMap, FxHashSet};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
@@ -21,12 +20,6 @@ use crate::TantivyError;
/// 2^11 = 2048 registers, giving ~2.3% relative error and ~1KB per sketch (Hll4).
const LG_K: u8 = 11;
/// Promote FxHashSet<u64> -> PagedBitset at ~3% density (`len * 32 >
/// dict_num_terms`). Past this point the bitset (~`dict_num_terms / 7.5`
/// bytes) is smaller than the hashset (~10 B/entry minimum) and avoids
/// the per-insert hash.
const PROMOTION_RATIO: u64 = 32;
/// # Cardinality
///
/// The cardinality aggregation allows for computing an estimate
@@ -184,263 +177,9 @@ impl CouponCache {
}
}
// =================================================================
// PagedBitset: a sparse bitset indexed by term_ord.
//
// Used as the dense alternative to FxHashSet<u64> once a string
// cardinality bucket has accumulated enough unique term ordinals.
// Memory is bounded to (touched pages) * (page bytes), not
// (max_term_ord / 8).
//
// Page geometry mirrors `PagedTermMap` in `term_agg.rs`: 1024 ords
// per page, lazy `Vec<Option<Box<Page>>>` directory.
// =================================================================
const BITSET_PAGE_SHIFT: u32 = 10;
const BITSET_PAGE_BITS: u64 = 1u64 << BITSET_PAGE_SHIFT; // 1024
const BITSET_PAGE_MASK: u64 = BITSET_PAGE_BITS - 1;
const BITSET_WORDS_PER_PAGE: usize = (BITSET_PAGE_BITS / 64) as usize; // 16
#[derive(Clone)]
struct PagedBitsetPage {
words: [TinySet; BITSET_WORDS_PER_PAGE],
}
impl PagedBitsetPage {
fn new() -> Self {
Self {
words: [TinySet::empty(); BITSET_WORDS_PER_PAGE],
}
}
}
pub(crate) struct PagedBitset {
pages: Vec<Option<Box<PagedBitsetPage>>>,
/// Cached number of set bits, maintained on insert.
count: u64,
}
impl PagedBitset {
/// Allocates a directory big enough to hold ords up to and including
/// `max_term_ord`. Pages are allocated lazily on first set.
fn with_max_term_ord(max_term_ord: u64) -> Self {
let max_page_idx = (max_term_ord >> BITSET_PAGE_SHIFT) as usize;
let num_pages = max_page_idx + 1;
Self {
pages: vec![None; num_pages],
count: 0,
}
}
#[inline]
fn insert(&mut self, term_ord: u64) {
let page_idx = (term_ord >> BITSET_PAGE_SHIFT) as usize;
let intra = term_ord & BITSET_PAGE_MASK;
let word_idx = (intra >> 6) as usize;
let bit_idx = (intra & 63) as u32;
let page = match &mut self.pages[page_idx] {
Some(p) => p,
None => {
self.pages[page_idx] = Some(Box::new(PagedBitsetPage::new()));
self.pages[page_idx].as_mut().unwrap()
}
};
if page.words[word_idx].insert_mut(bit_idx) {
self.count += 1;
}
}
/// Number of set bits. O(1).
#[inline]
fn len(&self) -> u64 {
self.count
}
/// Iterate set ords in ascending order.
fn iter_sorted(&self) -> impl Iterator<Item = u64> + '_ {
self.pages
.iter()
.enumerate()
.filter_map(|(page_idx, page_opt)| page_opt.as_ref().map(|p| (page_idx, p)))
.flat_map(|(page_idx, page)| {
let page_base_ord = (page_idx as u64) << BITSET_PAGE_SHIFT;
page.words
.iter()
.enumerate()
.flat_map(move |(word_idx, &word)| {
let word_base_ord = page_base_ord + (word_idx as u64) * 64;
word.into_iter()
.map(move |bit| word_base_ord + u64::from(bit))
})
})
}
}
/// Threshold below which we use `BitSet` instead of `TermOrdSet`.
///
/// Both `BitSet` and `FxHashSet<u64>` have the same 32-byte struct, so the comparison is heap only:
/// * `BitSet` at T=256: 5 `TinySet` words covering 258 bits (with the missing-value sentinel) =
/// 40 bytes.
/// * `FxHashSet<u64>` after one insert: 4-bucket hashbrown table ≈ 56 bytes
pub(crate) const BITSET_MAX_TERM_ORD: u64 = 256;
// =================================================================
// TermOrdAccumulator: per-bucket abstraction over the entries set.
//
// Implementations:
// - `BitSet` (from `common`): used when `column.max_value()` is small (< BITSET_MAX_TERM_ORD).
// Pre-allocated, no promotion.
// - `TermOrdSet`: adaptive, starts as FxHashSet and promotes to a paged bitset when occupancy
// crosses the density threshold (only if promotion is enabled — typically gated on top-level
// aggregation).
//
// The trait lets `SegmentCardinalityCollector` be generic over the choice
// so the hot collect() loop monomorphizes to a direct call (no enum
// dispatch per insert).
// =================================================================
pub(crate) trait TermOrdAccumulator: Sized {
/// Construct an empty accumulator.
/// `max_term_ord_inclusive` is the largest term_ord that may be
/// inserted (used to size pre-allocated bitsets and the dense bitset
/// on promotion).
fn new(max_term_ord_inclusive: u64) -> Self;
fn insert(&mut self, term_ord: u64);
/// Bulk insert. Implementations may override to hoist any inner
/// dispatch outside the loop. Default loops `insert`.
#[inline]
fn extend_from_iter<I: IntoIterator<Item = u64>>(&mut self, ords: I) {
for ord in ords {
self.insert(ord);
}
}
/// Hook called once per ingested block. Adaptive impls use this to
/// decide on sparse->dense promotion.
fn maybe_compact(&mut self) {}
fn len(&self) -> usize;
fn iter_ords(&self) -> impl Iterator<Item = u64> + '_;
}
impl TermOrdAccumulator for BitSet {
#[inline]
fn new(max_term_ord_inclusive: u64) -> Self {
// `BitSet::with_max_value(M)` accepts ords in [0, M).
// We need ords up to and including `max_term_ord_inclusive`, plus
// the missing-value sentinel `column.max_value() + 1`.
BitSet::with_max_value((max_term_ord_inclusive + 2) as u32)
}
#[inline]
fn insert(&mut self, term_ord: u64) {
BitSet::insert(self, term_ord as u32);
}
#[inline]
fn len(&self) -> usize {
BitSet::len(self)
}
fn iter_ords(&self) -> impl Iterator<Item = u64> + '_ {
// `BitSet` itself doesn't expose iteration, but
// `BitSet::tinyset(bucket)` does. Walk per-bucket and yield each
// set bit. The capacity is `max_value()`; iterating to
// `div_ceil(64)` covers every possible ord exactly once.
let num_buckets = self.max_value().div_ceil(64);
(0..num_buckets).flat_map(move |bucket| {
let chunk_base = u64::from(bucket) * 64;
self.tinyset(bucket)
.into_iter()
.map(move |bit| chunk_base + u64::from(bit))
})
}
}
// =================================================================
// TermOrdSet: adaptive sparse->dense accumulator.
//
// Starts as an FxHashSet (cheap when few ords are seen). When occupancy
// crosses `len * PROMOTION_RATIO > max_term_ord_inclusive`, drains into
// a `PagedBitset` and continues dense. Promotion is one-way.
// =================================================================
pub(crate) struct TermOrdSet {
inner: TermOrdSetInner,
/// Largest term_ord that may be inserted. Used for both sizing the
/// dense bitset on promotion and as the promotion-threshold reference.
max_term_ord_inclusive: u64,
}
enum TermOrdSetInner {
Sparse(FxHashSet<u64>),
Dense(PagedBitset),
}
impl TermOrdAccumulator for TermOrdSet {
fn new(max_term_ord_inclusive: u64) -> Self {
Self {
inner: TermOrdSetInner::Sparse(FxHashSet::default()),
max_term_ord_inclusive,
}
}
#[inline]
fn insert(&mut self, term_ord: u64) {
match &mut self.inner {
TermOrdSetInner::Sparse(set) => {
set.insert(term_ord);
}
TermOrdSetInner::Dense(bitset) => bitset.insert(term_ord),
}
}
/// Hoist the Sparse/Dense match outside the per-ord loop so that a
/// block of inserts dispatches once.
fn extend_from_iter<I: IntoIterator<Item = u64>>(&mut self, ords: I) {
match &mut self.inner {
TermOrdSetInner::Sparse(set) => {
for ord in ords {
set.insert(ord);
}
}
TermOrdSetInner::Dense(bitset) => {
for ord in ords {
bitset.insert(ord);
}
}
}
}
fn maybe_compact(&mut self) {
let TermOrdSetInner::Sparse(set) = &mut self.inner else {
return;
};
if set.len() as u64 * PROMOTION_RATIO <= self.max_term_ord_inclusive {
return;
}
// Size for ord <= max_term_ord_inclusive plus the missing sentinel
// (column.max_value() + 1, which may equal max_term_ord_inclusive
// when the column references every dictionary term).
let mut bitset = PagedBitset::with_max_term_ord(self.max_term_ord_inclusive + 1);
let set = std::mem::take(set);
for ord in set {
bitset.insert(ord);
}
self.inner = TermOrdSetInner::Dense(bitset);
}
fn len(&self) -> usize {
match &self.inner {
TermOrdSetInner::Sparse(set) => set.len(),
TermOrdSetInner::Dense(bitset) => bitset.len() as usize,
}
}
fn iter_ords(&self) -> impl Iterator<Item = u64> + '_ {
match &self.inner {
TermOrdSetInner::Sparse(set) => itertools::Either::Left(set.iter().copied()),
TermOrdSetInner::Dense(bitset) => itertools::Either::Right(bitset.iter_sorted()),
}
}
}
pub(crate) struct SegmentCardinalityCollector<S: TermOrdAccumulator> {
pub(crate) struct SegmentCardinalityCollector {
/// Buckets are Some(_) until they get consumed by into_intermediate_results().
buckets: Vec<Option<SegmentCardinalityCollectorBucket<S>>>,
buckets: Vec<Option<SegmentCardinalityCollectorBucket>>,
accessor_idx: usize,
/// The column accessor to access the fast field values.
accessor: Column<u64>,
@@ -449,13 +188,9 @@ pub(crate) struct SegmentCardinalityCollector<S: TermOrdAccumulator> {
/// The missing value normalized to the internal u64 representation of the field type.
missing_value_for_accessor: Option<u64>,
coupon_cache: Option<CouponCache>,
/// Largest term_ord that may be inserted into a bucket. For str columns
/// this is `accessor.max_value()`; for non-str columns this is unused
/// (no inserts go into `entries`) and set to 0.
max_term_ord_inclusive: u64,
}
impl<S: TermOrdAccumulator> Debug for SegmentCardinalityCollector<S> {
impl Debug for SegmentCardinalityCollector {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("SegmentCardinalityCollector")
.field("column_type", &self.column_type)
@@ -467,21 +202,16 @@ impl<S: TermOrdAccumulator> Debug for SegmentCardinalityCollector<S> {
}
}
/// Per-bucket state. Shape depends on column kind: str columns dedup
/// term ords and only build the HLL sketch at finalization (saves the
/// ~96 B `CardinalityCollector` per bucket during collect); numeric/IpAddr
/// columns feed the sketch directly during collect.
pub(crate) enum SegmentCardinalityCollectorBucket<S: TermOrdAccumulator> {
Str(S),
Numeric(CardinalityCollector),
pub(crate) struct SegmentCardinalityCollectorBucket {
cardinality: CardinalityCollector,
entries: FxHashSet<u64>,
}
impl<S: TermOrdAccumulator> SegmentCardinalityCollectorBucket<S> {
impl SegmentCardinalityCollectorBucket {
#[inline(always)]
pub fn new(column_type: ColumnType, max_term_ord_inclusive: u64) -> Self {
if column_type == ColumnType::Str {
Self::Str(S::new(max_term_ord_inclusive))
} else {
Self::Numeric(CardinalityCollector::new(column_type as u8))
pub fn new(column_type: ColumnType) -> Self {
Self {
cardinality: CardinalityCollector::new(column_type as u8),
entries: FxHashSet::default(),
}
}
@@ -492,57 +222,37 @@ impl<S: TermOrdAccumulator> SegmentCardinalityCollectorBucket<S> {
//
// If the column is str, then the values are dictionary encoded
// and have not been added to the sketch yet.
// We need to resolves the term ords accumulated in the str entries
// with the coupon cache, and append the results to a fresh sketch.
// We need to resolves the term ords accumulated in self.entries
// with the coupon cache, and append the results to the sketch.
fn into_intermediate_metric_result(
self,
mut self,
coupon_cache_opt: Option<&CouponCache>,
) -> crate::Result<IntermediateMetricResult> {
let cardinality = match self {
Self::Str(entries) => {
let mut cardinality = CardinalityCollector::new(ColumnType::Str as u8);
if let Some(coupon_cache) = coupon_cache_opt {
// Sketch must be empty for str columns: coupons are appended here
// from the term_ord set (and not directly during collection).
assert!(cardinality.sketch.is_empty());
append_to_sketch(&entries, coupon_cache, &mut cardinality);
}
cardinality
}
Self::Numeric(cardinality) => cardinality,
};
Ok(IntermediateMetricResult::Cardinality(cardinality))
if let Some(coupon_cache) = coupon_cache_opt {
assert!(self.cardinality.sketch.is_empty());
append_to_sketch(&self.entries, coupon_cache, &mut self.cardinality);
}
Ok(IntermediateMetricResult::Cardinality(self.cardinality))
}
}
/// Builds a coupon cache from the given buckets, dictionary, and optional missing value.
/// Returns a mapping from term_ord to the hash (coupon) of the associated term.
fn build_coupon_cache<S: TermOrdAccumulator>(
buckets: &[Option<SegmentCardinalityCollectorBucket<S>>],
fn build_coupon_cache(
buckets: &[Option<SegmentCardinalityCollectorBucket>],
dictionary: &Dictionary,
missing_value_opt: Option<&Key>,
) -> io::Result<CouponCache> {
// Caller restricts this to str cardinality collectors, so every
// present bucket must be the `Str` variant. Pass 1 validates and
// computes the capacity hint; pass 2 inserts.
let mut max_bucket_len = 0usize;
let term_ords_capacity: usize = buckets
.iter()
.flatten()
.map(|bucket| bucket.entries.len())
.max()
.unwrap_or(0)
* 2;
let mut term_ords_set = FxHashSet::with_capacity_and_hasher(term_ords_capacity, FxBuildHasher);
for bucket in buckets.iter().flatten() {
match bucket {
SegmentCardinalityCollectorBucket::Str(entries) => {
max_bucket_len = max_bucket_len.max(entries.len());
}
SegmentCardinalityCollectorBucket::Numeric(_) => {
return Err(io::Error::other(
"build_coupon_cache invoked with a non-str bucket",
));
}
}
}
let mut term_ords_set = FxHashSet::with_capacity_and_hasher(max_bucket_len * 2, FxBuildHasher);
for bucket in buckets.iter().flatten() {
if let SegmentCardinalityCollectorBucket::Str(entries) = bucket {
term_ords_set.extend(entries.iter_ords());
}
term_ords_set.extend(bucket.entries.iter().copied());
}
let mut term_ords: Vec<u64> = term_ords_set.into_iter().collect();
term_ords.sort_unstable();
@@ -574,8 +284,8 @@ fn build_coupon_cache<S: TermOrdAccumulator>(
Ok(CouponCache::new(term_ords, coupons, missing_coupon_opt))
}
fn append_to_sketch<S: TermOrdAccumulator>(
term_ords: &S,
fn append_to_sketch(
term_ords: &FxHashSet<u64>,
coupon_cache: &CouponCache,
sketch: &mut CardinalityCollector,
) {
@@ -584,7 +294,7 @@ fn append_to_sketch<S: TermOrdAccumulator>(
coupon_map,
missing_coupon_opt,
} => {
for term_ord in term_ords.iter_ords() {
for &term_ord in term_ords {
if let Some(coupon) = coupon_map
.get(term_ord as usize)
.copied()
@@ -598,8 +308,8 @@ fn append_to_sketch<S: TermOrdAccumulator>(
coupon_map,
missing_coupon_opt,
} => {
for term_ord in term_ords.iter_ords() {
if let Some(coupon) = coupon_map.get(&term_ord).copied().or(*missing_coupon_opt) {
for term_ord in term_ords {
if let Some(coupon) = coupon_map.get(term_ord).copied().or(*missing_coupon_opt) {
sketch.insert_coupon(coupon);
}
}
@@ -607,13 +317,12 @@ fn append_to_sketch<S: TermOrdAccumulator>(
}
}
impl<S: TermOrdAccumulator> SegmentCardinalityCollector<S> {
impl SegmentCardinalityCollector {
pub fn from_req(
column_type: ColumnType,
accessor_idx: usize,
accessor: Column<u64>,
missing_value_for_accessor: Option<u64>,
max_term_ord_inclusive: u64,
) -> Self {
Self {
buckets: Vec::new(),
@@ -622,7 +331,6 @@ impl<S: TermOrdAccumulator> SegmentCardinalityCollector<S> {
accessor,
missing_value_for_accessor,
coupon_cache: None,
max_term_ord_inclusive,
}
}
@@ -639,9 +347,7 @@ impl<S: TermOrdAccumulator> SegmentCardinalityCollector<S> {
}
}
impl<S: TermOrdAccumulator + 'static> SegmentAggregationCollector
for SegmentCardinalityCollector<S>
{
impl SegmentAggregationCollector for SegmentCardinalityCollector {
fn add_intermediate_aggregation_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
@@ -696,41 +402,31 @@ impl<S: TermOrdAccumulator + 'static> SegmentAggregationCollector
));
};
let col_block_accessor = &agg_data.column_block_accessor;
match bucket {
SegmentCardinalityCollectorBucket::Str(entries) => {
// Promotion check runs on the pre-block state: the first call
// sees an empty set (no-op), and the last block of inserts
// doesn't trigger a promotion of a set we won't grow further.
// The trait dispatches once per block (via `extend_from_iter`)
// for adaptive variants and inlines to a tight loop for the
// BitSet path.
entries.maybe_compact();
entries.extend_from_iter(col_block_accessor.iter_vals());
if self.column_type == ColumnType::Str {
for term_ord in col_block_accessor.iter_vals() {
bucket.entries.insert(term_ord);
}
SegmentCardinalityCollectorBucket::Numeric(cardinality) => {
if self.column_type == ColumnType::IpAddr {
let compact_space_accessor = self
.accessor
.values
.clone()
.downcast_arc::<CompactSpaceU64Accessor>()
.map_err(|_| {
TantivyError::AggregationError(
crate::aggregation::AggregationError::InternalError(
"Type mismatch: Could not downcast to CompactSpaceU64Accessor"
.to_string(),
),
)
})?;
for val in col_block_accessor.iter_vals() {
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
cardinality.insert(val);
}
} else {
for val in col_block_accessor.iter_vals() {
cardinality.insert(val);
}
}
} else if self.column_type == ColumnType::IpAddr {
let compact_space_accessor = self
.accessor
.values
.clone()
.downcast_arc::<CompactSpaceU64Accessor>()
.map_err(|_| {
TantivyError::AggregationError(
crate::aggregation::AggregationError::InternalError(
"Type mismatch: Could not downcast to CompactSpaceU64Accessor"
.to_string(),
),
)
})?;
for val in col_block_accessor.iter_vals() {
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
bucket.cardinality.insert(val);
}
} else {
for val in col_block_accessor.iter_vals() {
bucket.cardinality.insert(val);
}
}
@@ -743,40 +439,12 @@ impl<S: TermOrdAccumulator + 'static> SegmentAggregationCollector
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
if max_bucket as usize >= self.buckets.len() {
let column_type = self.column_type;
let max_term_ord_inclusive = self.max_term_ord_inclusive;
self.buckets.resize_with(max_bucket as usize + 1, || {
Some(SegmentCardinalityCollectorBucket::<S>::new(
column_type,
max_term_ord_inclusive,
))
Some(SegmentCardinalityCollectorBucket::new(self.column_type))
});
}
Ok(())
}
fn compute_metric_value(
&self,
bucket_id: BucketId,
sub_agg_name: &str,
sub_agg_property: &str,
agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
if req_data.name != sub_agg_name || !sub_agg_property.is_empty() {
return None;
}
let bucket = self.buckets.get(bucket_id as usize)?.as_ref()?;
// For string columns the sketch isn't built until finalization; the
// term_ord set's len is the exact distinct count. For numeric columns
// the sketch is populated during collect.
match bucket {
SegmentCardinalityCollectorBucket::Str(entries) => Some(entries.len() as f64),
SegmentCardinalityCollectorBucket::Numeric(cardinality) => {
Some(cardinality.sketch.estimate().trunc())
}
}
}
}
#[derive(Clone, Debug)]
@@ -924,134 +592,6 @@ mod tests {
Ok(())
}
/// Build a single-segment string-cardinality index with 32 unique terms.
/// `column.max_value() = 31` is well below `BITSET_MAX_TERM_ORD`,
/// so the bucket exercises the `BitSet` path end to end.
#[test]
fn cardinality_aggregation_test_str_bitset() -> crate::Result<()> {
let terms: Vec<String> = (0..32).map(|i| format!("term_{i}")).collect();
let term_refs: Vec<Vec<&str>> = terms.iter().map(|t| vec![t.as_str()]).collect::<Vec<_>>();
// single segment so we have a single dictionary of 32 terms.
let index = get_test_index_from_terms(true, &term_refs)?;
let agg_req: Aggregations = serde_json::from_value(json!({
"cardinality": {
"cardinality": { "field": "string_id" }
},
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["cardinality"]["value"], 32.0);
Ok(())
}
/// `BitSet` path with a `missing` parameter: the column-level missing
/// sentinel (`column.max_value() + 1`) flows into the bitset, the
/// dict lookup filter at finalization drops it, and the missing
/// coupon is applied separately.
#[test]
fn cardinality_aggregation_test_str_bitset_with_missing() {
let mut schema_builder = Schema::builder();
let name_field = schema_builder.add_text_field("name", STRING | FAST);
let index = Index::create_in_ram(schema_builder.build());
let mut writer = index.writer_for_tests().unwrap();
for i in 0..16 {
let term = format!("t{i:02}");
writer.add_document(doc!(name_field => term)).unwrap();
}
// One empty doc, exercising the missing sentinel.
writer.add_document(doc!()).unwrap();
writer.commit().unwrap();
let agg_req: Aggregations = serde_json::from_value(json!({
"cardinality": {
"cardinality": {
"field": "name",
"missing": "MISSING_SENTINEL_KEY",
}
},
}))
.unwrap();
let res = exec_request(agg_req, &index).unwrap();
// 16 distinct real terms + 1 distinct "missing" value = 17.
assert_eq!(res["cardinality"]["value"], 17.0);
}
/// Unit-test the PagedBitset itself: cross-page inserts produce sorted
/// iteration, len() matches the inserted set, and duplicates are
/// idempotent.
#[test]
fn paged_bitset_basic() {
use super::PagedBitset;
// Span several pages: BITSET_PAGE_BITS = 1024, so ords > 1024 land
// on the second page, > 2048 on the third, etc.
let ords = [0u64, 1, 63, 64, 1023, 1024, 1025, 4096, 4097, 9999, 10_000];
let max_ord = *ords.iter().max().unwrap();
let mut bitset = PagedBitset::with_max_term_ord(max_ord);
for &ord in &ords {
bitset.insert(ord);
// Idempotent: inserting again must not increase count.
bitset.insert(ord);
}
assert_eq!(bitset.len(), ords.len() as u64);
let collected: Vec<u64> = bitset.iter_sorted().collect();
let mut expected: Vec<u64> = ords.to_vec();
expected.sort_unstable();
assert_eq!(collected, expected);
}
/// Unit-test `TermOrdSet`: starts Sparse, promotes to Dense on
/// `maybe_compact` once the density threshold is crossed, and
/// `iter_ords()` yields the same set in either state. Ords spanning
/// multiple paged-bitset pages exercise the Dense iter ordering.
#[test]
fn term_ord_set_promotes_on_maybe_compact() {
use super::{TermOrdAccumulator, TermOrdSet, PROMOTION_RATIO};
// Pick max so promotion needs few inserts: len * RATIO > max with
// RATIO=32 and max=64 trips at len=3 (3*32=96 > 64).
let max_term_ord = 64u64;
let mut set = <TermOrdSet as TermOrdAccumulator>::new(max_term_ord);
// Two inserts: should stay Sparse after maybe_compact (2 * RATIO = 64, not > 64).
set.insert(0);
set.insert(7);
set.maybe_compact();
assert_eq!(set.len(), 2);
// Third insert promotes on next maybe_compact.
set.insert(20);
assert_eq!(set.len(), 3);
// Sanity check: at len=3, 3 * PROMOTION_RATIO = 96 > 64.
assert!(3u64 * PROMOTION_RATIO > max_term_ord);
set.maybe_compact();
// Post-promotion: extending continues to work.
set.insert(15);
set.insert(15); // dup
assert_eq!(set.len(), 4);
let mut collected: Vec<u64> = set.iter_ords().collect();
collected.sort_unstable();
assert_eq!(collected, vec![0, 7, 15, 20]);
}
/// Unit-test the `BitSet` impl of `TermOrdAccumulator`: insert,
/// dedup, and iter_ords order.
#[test]
fn bitset_accumulator_basic() {
use common::BitSet;
use super::TermOrdAccumulator;
let mut set = <BitSet as TermOrdAccumulator>::new(255);
for ord in [0u64, 1, 63, 64, 65, 128, 200, 200, 0] {
<BitSet as TermOrdAccumulator>::insert(&mut set, ord);
}
assert_eq!(<BitSet as TermOrdAccumulator>::len(&set), 7);
let collected: Vec<u64> = set.iter_ords().collect();
assert_eq!(collected, vec![0, 1, 63, 64, 65, 128, 200]);
}
#[test]
fn cardinality_aggregation_u64() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
@@ -1143,42 +683,6 @@ mod tests {
Ok(())
}
/// A JSON path that resolves to both a Str column and a numeric column
/// produces two collector instances per segment — one with `Str` buckets
/// and one with `Numeric` buckets. Their `IntermediateMetricResult`s must
/// merge into the union cardinality.
#[test]
fn cardinality_aggregation_json_str_and_numeric() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let field = schema_builder.add_json_field("json", FAST);
let index = Index::create_in_ram(schema_builder.build());
{
let mut writer = index.writer_for_tests()?;
writer.add_document(doc!(field => json!({"value": "hello"})))?;
writer.add_document(doc!(field => json!({"value": "world"})))?;
writer.add_document(doc!(field => json!({"value": "hello"})))?; // dup str
writer.add_document(doc!(field => json!({"value": i64::from_u64(7u64)})))?;
writer.add_document(doc!(field => json!({"value": i64::from_u64(42u64)})))?;
writer.add_document(doc!(field => json!({"value": i64::from_u64(7u64)})))?; // dup num
writer.commit()?;
}
let agg_req: Aggregations = serde_json::from_value(json!({
"cardinality": {
"cardinality": {
"field": "json.value"
},
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
// 4 distinct values: "hello", "world", 7, 42.
assert_eq!(res["cardinality"]["value"], 4.0);
Ok(())
}
#[test]
fn cardinality_collector_serde_roundtrip() {
use super::CardinalityCollector;

View File

@@ -399,26 +399,6 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
}
Ok(())
}
fn compute_metric_value(
&self,
bucket_id: BucketId,
sub_agg_name: &str,
sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
if self.name != sub_agg_name {
return None;
}
let extended = self.buckets.get(bucket_id as usize)?;
// Finalize is a pure read of accumulators — calling it here for the cutoff sort
// doesn't disturb the eventual intermediate result.
extended
.finalize()
.get_value(sub_agg_property)
.ok()
.flatten()
}
}
#[cfg(test)]

View File

@@ -312,26 +312,6 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector {
}
Ok(())
}
fn compute_metric_value(
&self,
bucket_id: BucketId,
sub_agg_name: &str,
sub_agg_property: &str,
agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
if agg_data.get_metric_req_data(self.accessor_idx).name != sub_agg_name {
return None;
}
let percentile: f64 = sub_agg_property.parse().ok()?;
if !(0.0..=100.0).contains(&percentile) {
return None;
}
let bucket = self.buckets.get(bucket_id as usize)?;
// DDSketch.quantile is a pure read; calling it here for the cutoff sort does
// not affect the intermediate state used for the final result.
bucket.sketch.quantile(percentile / 100.0).ok().flatten()
}
}
#[cfg(test)]

View File

@@ -321,40 +321,6 @@ impl<const COLUMN_TYPE_ID: u8> SegmentAggregationCollector
}
Ok(())
}
fn compute_metric_value(
&self,
bucket_id: BucketId,
sub_agg_name: &str,
sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
if self.name != sub_agg_name {
return None;
}
let stats = self.buckets.get(bucket_id as usize)?;
// The property depends on what we're collecting:
// - StatsType::Stats exposes count/sum/min/max/avg via dotted property.
// - Single-value kinds (Sum/Count/Min/Max/Average) expect an empty property and return
// the value they were configured to collect.
let prop = match self.collecting_for {
StatsType::Stats if !sub_agg_property.is_empty() => sub_agg_property,
StatsType::Sum if sub_agg_property.is_empty() => "sum",
StatsType::Count if sub_agg_property.is_empty() => "count",
StatsType::Max if sub_agg_property.is_empty() => "max",
StatsType::Min if sub_agg_property.is_empty() => "min",
StatsType::Average if sub_agg_property.is_empty() => "avg",
_ => return None,
};
match prop {
"count" => Some(stats.count as f64),
"sum" => Some(stats.sum),
"min" if stats.count > 0 => Some(stats.min),
"max" if stats.count > 0 => Some(stats.max),
"avg" if stats.count > 0 => Some(stats.sum / stats.count as f64),
_ => None,
}
}
}
#[inline]

View File

@@ -644,17 +644,6 @@ impl SegmentAggregationCollector for TopHitsSegmentCollector {
);
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// top_hits is not a numeric metric and cannot be used as an order target.
None
}
}
#[cfg(test)]

View File

@@ -76,31 +76,6 @@ pub trait SegmentAggregationCollector: Debug {
fn flush(&mut self, _agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
Ok(())
}
/// Compute the segment-level metric value of the named direct-child metric for `bucket_id`.
///
/// Used by parent term aggs that order by a sub-aggregation: the parent sorts on
/// this value and cuts off at segment time, matching the approximation tradeoff
/// Elasticsearch makes for any sub-agg ordering.
///
/// `sub_agg_property` is the dotted suffix (e.g. `"sum"` in `mystats.sum`); empty when
/// the metric is a single-value kind such as cardinality.
///
/// Returns `None` only on name mismatch, unknown property, or empty bucket. Implementations
/// may finalize their per-bucket state (e.g. compute a percentile from a sketch); calls
/// must be idempotent so the final intermediate result is unaffected.
///
/// No default impl on purpose: every collector must decide explicitly whether it
/// produces a metric value, forwards into children (single-bucket aggs), or rejects
/// the lookup. A silent `None` default would let a parent term agg's cutoff sort all
/// buckets to the same key and drop arbitrary winners.
fn compute_metric_value(
&self,
bucket_id: BucketId,
sub_agg_name: &str,
sub_agg_property: &str,
agg_data: &AggregationsSegmentCtx,
) -> Option<f64>;
}
#[derive(Default)]
@@ -162,21 +137,4 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
}
Ok(())
}
fn compute_metric_value(
&self,
bucket_id: BucketId,
sub_agg_name: &str,
sub_agg_property: &str,
agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
for agg in &self.aggs {
if let Some(value) =
agg.compute_metric_value(bucket_id, sub_agg_name, sub_agg_property, agg_data)
{
return Some(value);
}
}
None
}
}

View File

@@ -389,13 +389,6 @@ impl SegmentCollector for FacetSegmentCollector {
}
let mut facet = vec![];
let (facet_ord, facet_depth) = self.unique_facet_ords[collapsed_facet_ord];
// u64::MAX is used as a sentinel for unmapped ordinals (e.g. when a
// document has the exact registered facet, not a child of it).
// Passing it to ord_to_term would resolve to the last dictionary
// entry and produce a spurious facet from an unrelated branch.
if facet_ord == u64::MAX {
continue;
}
// TODO handle errors.
if facet_dict.ord_to_term(facet_ord, &mut facet).is_ok() {
if let Some((end_collapsed_facet, _)) = facet
@@ -821,63 +814,6 @@ mod tests {
assert!(!super::is_child_facet(&b"foo\0bar"[..], &b"foo"[..]));
assert!(!super::is_child_facet(&b"foo"[..], &b"foobar\0baz"[..]));
}
// Regression test for https://github.com/quickwit-oss/tantivy/issues/2494
// When a document has the exact registered facet path (not just a child),
// harvest() must not turn the unmapped sentinel into a spurious root entry.
#[test]
fn test_facet_collector_wrong_root() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let facet_field = schema_builder.add_facet_field("facet", FacetOptions::default());
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
let facets: Vec<&str> = vec![
"/science-fiction/asimov",
"/science-fiction/clarke",
"/science-fiction/dick",
"/science-fiction/herbert",
"/science-fiction/orwell",
// This exact match on the registered facet is the bug trigger:
// its ordinal maps to the sentinel (u64::MAX, 0) in the collapse
// mapping, which without the fix resolves to an unrelated term.
"/fantasy/epic-fantasy",
"/fantasy/epic-fantasy/tolkien",
"/fantasy/epic-fantasy/martin",
];
for facet_str in &facets {
index_writer.add_document(doc!(
facet_field => Facet::from(*facet_str)
))?;
}
index_writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
let term = Term::from_facet(facet_field, &Facet::from("/fantasy/epic-fantasy"));
let query = TermQuery::new(term, IndexRecordOption::Basic);
let mut facet_collector = FacetCollector::for_field("facet");
facet_collector.add_facet("/fantasy/epic-fantasy");
let counts: FacetCounts = searcher.search(&query, &facet_collector)?;
let result: Vec<(String, u64)> = counts
.get("/")
.map(|(facet, count)| (facet.to_string(), count))
.collect();
// Only children of /fantasy/epic-fantasy should appear, not /science-fiction
assert_eq!(
result,
vec![
("/fantasy/epic-fantasy/martin".to_string(), 1),
("/fantasy/epic-fantasy/tolkien".to_string(), 1),
]
);
Ok(())
}
}
#[cfg(all(test, feature = "unstable"))]

View File

@@ -249,12 +249,6 @@ impl BlockSegmentPostings {
/// Returns the length of the current block.
///
/// Returns the decoded term-frequency buffer for the current block.
#[inline]
pub(crate) fn freq_output_array(&self) -> &[u32] {
self.freq_decoder.output_array()
}
/// All blocks have a length of `NUM_DOCS_PER_BLOCK`,
/// except the last block that may have a length
/// of any number between 1 and `NUM_DOCS_PER_BLOCK - 1`
@@ -304,11 +298,6 @@ impl BlockSegmentPostings {
}
}
#[inline]
pub(crate) fn has_remaining_docs(&self) -> bool {
self.skip_reader.has_remaining_docs()
}
pub(crate) fn block_is_loaded(&self) -> bool {
self.block_loaded
}

View File

@@ -146,11 +146,6 @@ impl SkipReader {
skip_reader
}
#[inline(always)]
pub fn has_remaining_docs(&self) -> bool {
self.remaining_docs != 0
}
pub fn reset(&mut self, data: OwnedBytes, doc_freq: u32) {
self.last_doc_in_block = if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 {
0

View File

@@ -1,464 +0,0 @@
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
use crate::query::term_query::TermScorer;
use crate::query::Scorer;
use crate::{DocId, DocSet, Score, TERMINATED};
/// Block-max pruning for top-K over intersection of term scorers.
///
/// Uses the least-frequent term as "leader" to define 128-doc processing windows.
/// For each window, the sum of block_max_scores is compared to the current threshold;
/// if the block can't beat it, the entire block is skipped.
///
/// Within non-skipped blocks, individual documents are pruned by checking whether
/// leader_score + sum(secondary block_max_scores) can exceed the threshold before
/// performing the expensive intersection membership check (seeking into secondary scorers).
///
/// # Preconditions
/// - `scorers` has at least 2 elements
/// - All scorers read frequencies (`FreqReadingOption::ReadFreq`)
pub(crate) fn block_wand_intersection(
mut scorers: Vec<TermScorer>,
mut threshold: Score,
callback: &mut dyn FnMut(DocId, Score) -> Score,
) {
assert!(scorers.len() >= 2);
// Sort by cost (ascending). scorers[0] becomes the "leader" (rarest term).
scorers.sort_by_key(TermScorer::size_hint);
let (leader, secondaries) = scorers.split_first_mut().unwrap();
// Precompute global max scores for early termination checks.
let leader_max_score: Score = leader.max_score();
let secondaries_global_max_sum: Score = secondaries.iter().map(TermScorer::max_score).sum();
// Early exit: no document can possibly beat the threshold.
if leader_max_score + secondaries_global_max_sum <= threshold {
return;
}
// Borrow fieldnorm reader and BM25 weight before the main loop.
// These are immutable references to disjoint fields from block_cursor,
// but Rust's borrow checker can't see through method calls, so we
// extract them once upfront.
let fieldnorm_reader = leader.fieldnorm_reader().clone();
let bm25_weight = leader.bm25_weight().clone();
let mut doc = leader.doc();
let mut secondary_block_max_scores: Box<[f32]> =
vec![0.0f32; secondaries.len()].into_boxed_slice();
let mut secondary_suffix_block_max: Box<[f32]> =
vec![0.0f32; secondaries.len()].into_boxed_slice();
while doc < TERMINATED {
// --- Phase 1: Block-level pruning ---
//
// Position all skip readers on the block containing `doc`.
// seek_block is cheap: it only advances the skip reader, no block decompression.
leader.seek_block(doc);
let leader_block_max: Score = leader.block_max_score();
// Compute the window end as the minimum last_doc_in_block across all scorers.
// This ensures the block_max values are valid for all docs in [doc, window_end].
// Different scorers have independently aligned blocks, so we must use the
// smallest window where all block_max values hold.
let mut window_end: DocId = leader.last_doc_in_block();
let mut secondary_block_max_sum: Score = 0.0;
let num_secondaries = secondaries.len();
for (idx, secondary) in secondaries.iter_mut().enumerate() {
secondary.block_cursor().seek_block(doc);
if !secondary.block_cursor().has_remaining_docs() {
return;
}
window_end = window_end.min(secondary.last_doc_in_block());
let bms = secondary.block_max_score();
secondary_block_max_scores[idx] = bms;
secondary_block_max_sum += bms;
}
if leader_block_max + secondary_block_max_sum <= threshold {
// The entire window cannot beat the threshold. Skip past it.
doc = window_end + 1;
continue;
}
// --- Phase 2: Batch processing within the window ---
//
// Score-first approach: decode the leader's block, filter by threshold,
// then check intersection membership only for survivors. This avoids expensive
// secondary seeks for docs that can't beat the threshold.
let block_cursor = leader.block_cursor();
// seek loads the block and returns the in-block index of the first doc >= `doc`.
let start_idx = block_cursor.seek(doc);
// Use the branchless binary search on the doc decoder to find the first
// index past window_end.
let end_idx = block_cursor
.doc_decoder
.seek_within_block(window_end + 1)
.min(block_cursor.block_len());
let block_docs = &block_cursor.doc_decoder.output_array()[start_idx..end_idx];
let block_freqs = &block_cursor.freq_output_array()[start_idx..end_idx];
// Pass 1: Batch-compute leader BM25 scores and branchlessly filter
// candidates that can't beat the threshold.
//
// The trick: always write to the buffer at `num_candidates`, then
// conditionally advance the count. The compiler can turn this into
// a cmov instead of a branch, avoiding misprediction costs.
let score_threshold = threshold - secondary_block_max_sum;
let mut candidate_doc_ids = [0u32; COMPRESSION_BLOCK_SIZE];
let mut candidate_scores = [0.0f32; COMPRESSION_BLOCK_SIZE];
let mut num_candidates = 0usize;
for (candidate_doc, term_freq) in
block_docs.iter().copied().zip(block_freqs.iter().copied())
{
let fieldnorm_id = fieldnorm_reader.fieldnorm_id(candidate_doc);
let leader_score = bm25_weight.score(fieldnorm_id, term_freq);
candidate_doc_ids[num_candidates] = candidate_doc;
candidate_scores[num_candidates] = leader_score;
num_candidates += (leader_score > score_threshold) as usize;
}
// Precompute suffix sums: suffix[i] = sum of block_max for secondaries[i+1..].
// Used in Phase 2 to prune candidates that can't beat threshold even with
// remaining secondaries contributing their block_max.
if num_candidates == 0 {
doc = window_end + 1;
continue;
}
let mut running = 0.0f32;
for idx in (0..num_secondaries).rev() {
secondary_suffix_block_max[idx] = running;
running += secondary_block_max_scores[idx];
}
// Pass 2: Check intersection membership only for survivors.
// score_threshold may be stale (threshold can increase from callbacks),
// but that's conservative — we may check a few extra candidates, never miss one.
'next_candidate: for candidate_idx in 0..num_candidates {
let candidate_doc = candidate_doc_ids[candidate_idx];
let mut total_score: Score = candidate_scores[candidate_idx];
for (secondary_idx, secondary) in secondaries.iter_mut().enumerate() {
// If a previous candidate already advanced this secondary past
// candidate_doc, the candidate can't be in the intersection.
if secondary.doc() > candidate_doc {
continue 'next_candidate;
}
let seek_result = secondary.seek(candidate_doc);
if seek_result != candidate_doc {
continue 'next_candidate;
}
total_score += secondary.score();
// Prune: even if all remaining secondaries score at their block max,
// can we still beat the threshold?
if total_score + secondary_suffix_block_max[secondary_idx] <= threshold {
continue 'next_candidate;
}
}
// All secondaries matched.
if total_score > threshold {
threshold = callback(candidate_doc, total_score);
if leader_max_score + secondaries_global_max_sum <= threshold {
return;
}
}
}
doc = window_end + 1;
}
}
#[cfg(test)]
mod tests {
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use proptest::prelude::*;
use crate::query::term_query::TermScorer;
use crate::query::{Bm25Weight, Scorer};
use crate::{DocId, DocSet, Score, TERMINATED};
struct Float(Score);
impl Eq for Float {}
impl PartialEq for Float {
fn eq(&self, other: &Self) -> bool {
self.cmp(other) == Ordering::Equal
}
}
impl PartialOrd for Float {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Float {
fn cmp(&self, other: &Self) -> Ordering {
other.0.partial_cmp(&self.0).unwrap_or(Ordering::Equal)
}
}
fn nearly_equals(left: Score, right: Score) -> bool {
(left - right).abs() < 0.0001 * (left + right).abs()
}
/// Run block_wand_intersection and collect (doc, score) pairs above threshold.
fn compute_checkpoints_block_wand_intersection(
term_scorers: Vec<TermScorer>,
top_k: usize,
) -> Vec<(DocId, Score)> {
let mut heap: BinaryHeap<Float> = BinaryHeap::with_capacity(top_k);
let mut checkpoints: Vec<(DocId, Score)> = Vec::new();
let mut limit: Score = 0.0;
let callback = &mut |doc, score| {
heap.push(Float(score));
if heap.len() > top_k {
heap.pop().unwrap();
}
if heap.len() == top_k {
limit = heap.peek().unwrap().0;
}
if !nearly_equals(score, limit) {
checkpoints.push((doc, score));
}
limit
};
super::block_wand_intersection(term_scorers, Score::MIN, callback);
checkpoints
}
/// Naive baseline: intersect by iterating all docs.
fn compute_checkpoints_naive_intersection(
mut term_scorers: Vec<TermScorer>,
top_k: usize,
) -> Vec<(DocId, Score)> {
let mut heap: BinaryHeap<Float> = BinaryHeap::with_capacity(top_k);
let mut checkpoints: Vec<(DocId, Score)> = Vec::new();
let mut limit = Score::MIN;
// Sort by cost to use the cheapest as driver.
term_scorers.sort_by_key(|s| s.cost());
let (leader, secondaries) = term_scorers.split_first_mut().unwrap();
let mut doc = leader.doc();
while doc != TERMINATED {
let mut all_match = true;
for secondary in secondaries.iter_mut() {
let secondary_doc = secondary.doc();
let seek_result = if secondary_doc <= doc {
secondary.seek(doc)
} else {
secondary_doc
};
if seek_result != doc {
all_match = false;
break;
}
}
if all_match {
let score: Score =
leader.score() + secondaries.iter_mut().map(|s| s.score()).sum::<Score>();
if score > limit {
heap.push(Float(score));
if heap.len() > top_k {
heap.pop().unwrap();
}
if heap.len() == top_k {
limit = heap.peek().unwrap().0;
}
if !nearly_equals(score, limit) {
checkpoints.push((doc, score));
}
}
}
doc = leader.advance();
}
checkpoints
}
const MAX_TERM_FREQ: u32 = 100u32;
fn posting_list(max_doc: u32) -> BoxedStrategy<Vec<(DocId, u32)>> {
(1..max_doc + 1)
.prop_flat_map(move |doc_freq| {
(
proptest::bits::bitset::sampled(doc_freq as usize, 0..max_doc as usize),
proptest::collection::vec(1u32..MAX_TERM_FREQ, doc_freq as usize),
)
})
.prop_map(|(docset, term_freqs)| {
docset
.iter()
.map(|doc| doc as u32)
.zip(term_freqs.iter().cloned())
.collect::<Vec<_>>()
})
.boxed()
}
#[expect(clippy::type_complexity)]
fn gen_term_scorers(num_scorers: usize) -> BoxedStrategy<(Vec<Vec<(DocId, u32)>>, Vec<u32>)> {
(1u32..100u32)
.prop_flat_map(move |max_doc: u32| {
(
proptest::collection::vec(posting_list(max_doc), num_scorers),
proptest::collection::vec(2u32..10u32 * MAX_TERM_FREQ, max_doc as usize),
)
})
.boxed()
}
fn test_block_wand_intersection_aux(posting_lists: &[Vec<(DocId, u32)>], fieldnorms: &[u32]) {
// Repeat docs 64 times to create multi-block scenarios, matching block_wand.rs test
// strategy.
const REPEAT: usize = 64;
let fieldnorms_expanded: Vec<u32> = fieldnorms
.iter()
.cloned()
.flat_map(|fieldnorm| std::iter::repeat_n(fieldnorm, REPEAT))
.collect();
let postings_lists_expanded: Vec<Vec<(DocId, u32)>> = posting_lists
.iter()
.map(|posting_list| {
posting_list
.iter()
.cloned()
.flat_map(|(doc, term_freq)| {
(0_u32..REPEAT as u32).map(move |offset| {
(
doc * (REPEAT as u32) + offset,
if offset == 0 { term_freq } else { 1 },
)
})
})
.collect::<Vec<(DocId, u32)>>()
})
.collect();
let total_fieldnorms: u64 = fieldnorms_expanded
.iter()
.cloned()
.map(|fieldnorm| fieldnorm as u64)
.sum();
let average_fieldnorm = (total_fieldnorms as Score) / (fieldnorms_expanded.len() as Score);
let max_doc = fieldnorms_expanded.len();
let make_scorers = || -> Vec<TermScorer> {
postings_lists_expanded
.iter()
.map(|postings| {
let bm25_weight = Bm25Weight::for_one_term(
postings.len() as u64,
max_doc as u64,
average_fieldnorm,
);
TermScorer::create_for_test(postings, &fieldnorms_expanded[..], bm25_weight)
})
.collect()
};
for top_k in 1..4 {
let checkpoints_optimized =
compute_checkpoints_block_wand_intersection(make_scorers(), top_k);
let checkpoints_naive = compute_checkpoints_naive_intersection(make_scorers(), top_k);
assert_eq!(
checkpoints_optimized.len(),
checkpoints_naive.len(),
"Mismatch in checkpoint count for top_k={top_k}"
);
for (&(left_doc, left_score), &(right_doc, right_score)) in
checkpoints_optimized.iter().zip(checkpoints_naive.iter())
{
assert_eq!(left_doc, right_doc);
assert!(
nearly_equals(left_score, right_score),
"Score mismatch for doc {left_doc}: {left_score} vs {right_score}"
);
}
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(500))]
#[test]
fn test_block_wand_intersection_two_scorers(
(posting_lists, fieldnorms) in gen_term_scorers(2)
) {
test_block_wand_intersection_aux(&posting_lists[..], &fieldnorms[..]);
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(500))]
#[test]
fn test_block_wand_intersection_three_scorers(
(posting_lists, fieldnorms) in gen_term_scorers(3)
) {
test_block_wand_intersection_aux(&posting_lists[..], &fieldnorms[..]);
}
}
#[test]
fn test_block_wand_intersection_disjoint() {
// Two posting lists with no overlap — intersection is empty.
let fieldnorms: Vec<u32> = vec![10; 200];
let average_fieldnorm = 10.0;
let postings_a: Vec<(DocId, u32)> = (0..100).map(|d| (d, 1)).collect();
let postings_b: Vec<(DocId, u32)> = (100..200).map(|d| (d, 1)).collect();
let scorer_a = TermScorer::create_for_test(
&postings_a,
&fieldnorms,
Bm25Weight::for_one_term(100, 200, average_fieldnorm),
);
let scorer_b = TermScorer::create_for_test(
&postings_b,
&fieldnorms,
Bm25Weight::for_one_term(100, 200, average_fieldnorm),
);
let checkpoints = compute_checkpoints_block_wand_intersection(vec![scorer_a, scorer_b], 10);
assert!(checkpoints.is_empty());
}
#[test]
fn test_block_wand_intersection_all_overlap() {
// Two posting lists with full overlap.
let fieldnorms: Vec<u32> = vec![10; 50];
let average_fieldnorm = 10.0;
let postings: Vec<(DocId, u32)> = (0..50).map(|d| (d, 3)).collect();
let make_scorer = || {
TermScorer::create_for_test(
&postings,
&fieldnorms,
Bm25Weight::for_one_term(50, 50, average_fieldnorm),
)
};
let checkpoints_opt =
compute_checkpoints_block_wand_intersection(vec![make_scorer(), make_scorer()], 5);
let checkpoints_naive =
compute_checkpoints_naive_intersection(vec![make_scorer(), make_scorer()], 5);
assert_eq!(checkpoints_opt.len(), checkpoints_naive.len());
}
}

View File

@@ -16,7 +16,6 @@ use crate::{DocId, Score};
enum SpecializedScorer {
TermUnion(Vec<TermScorer>),
TermIntersection(Vec<TermScorer>),
Other(Box<dyn Scorer>),
}
@@ -50,9 +49,10 @@ where
TScoreCombiner: ScoreCombiner,
{
assert!(!scorers.is_empty());
if scorers.len() == 1 && !scorers[0].is::<TermScorer>() {
if scorers.len() == 1 {
return SpecializedScorer::Other(scorers.into_iter().next().unwrap()); //< we checked the size beforehand
}
{
let is_all_term_queries = scorers.iter().all(|scorer| scorer.is::<TermScorer>());
if is_all_term_queries {
@@ -66,9 +66,6 @@ where
{
// Block wand is only available if we read frequencies.
return SpecializedScorer::TermUnion(scorers);
} else if scorers.len() == 1 {
// Single TermScorer without freq reading — unwrap directly.
return SpecializedScorer::Other(Box::new(scorers.into_iter().next().unwrap()));
} else {
return SpecializedScorer::Other(Box::new(BufferedUnionScorer::build(
scorers,
@@ -96,13 +93,6 @@ fn into_box_scorer<TScoreCombiner: ScoreCombiner>(
BufferedUnionScorer::build(term_scorers, score_combiner_fn, num_docs);
Box::new(union_scorer)
}
SpecializedScorer::TermIntersection(term_scorers) => {
let boxed_scorers: Vec<Box<dyn Scorer>> = term_scorers
.into_iter()
.map(|s| Box::new(s) as Box<dyn Scorer>)
.collect();
intersect_scorers(boxed_scorers, num_docs)
}
SpecializedScorer::Other(scorer) => scorer,
}
}
@@ -307,43 +297,14 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
// Result depends entirely on MUST + any removed AllScorers.
let combined_all_scorer_count = must_special_scorer_counts.num_all_scorers
+ should_special_scorer_counts.num_all_scorers;
// Try to detect a pure TermScorer intersection for block-max optimization.
// Preconditions: no removed AllScorers, at least 2 scorers, all TermScorer
// with frequency reading enabled.
if combined_all_scorer_count == 0
&& must_scorers.len() >= 2
&& must_scorers.iter().all(|s| s.is::<TermScorer>())
{
let term_scorers: Vec<TermScorer> = must_scorers
.into_iter()
.map(|s| *(s.downcast::<TermScorer>().map_err(|_| ()).unwrap()))
.collect();
if term_scorers
.iter()
.all(|s| s.freq_reading_option() == FreqReadingOption::ReadFreq)
{
SpecializedScorer::TermIntersection(term_scorers)
} else {
let must_scorers: Vec<Box<dyn Scorer>> = term_scorers
.into_iter()
.map(|s| Box::new(s) as Box<dyn Scorer>)
.collect();
let boxed_scorer: Box<dyn Scorer> =
effective_must_scorer(must_scorers, 0, reader.max_doc(), num_docs)
.unwrap_or_else(|| Box::new(EmptyScorer));
SpecializedScorer::Other(boxed_scorer)
}
} else {
let boxed_scorer: Box<dyn Scorer> = effective_must_scorer(
must_scorers,
combined_all_scorer_count,
reader.max_doc(),
num_docs,
)
.unwrap_or_else(|| Box::new(EmptyScorer));
SpecializedScorer::Other(boxed_scorer)
}
let boxed_scorer: Box<dyn Scorer> = effective_must_scorer(
must_scorers,
combined_all_scorer_count,
reader.max_doc(),
num_docs,
)
.unwrap_or_else(|| Box::new(EmptyScorer));
SpecializedScorer::Other(boxed_scorer)
}
(ShouldScorersCombinationMethod::Optional(should_scorer), must_scorers) => {
// Optional SHOULD: contributes to scoring but not required for matching.
@@ -502,21 +463,15 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
callback: &mut dyn FnMut(DocId, Score),
) -> crate::Result<()> {
let scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn)?;
let num_docs = reader.num_docs();
match scorer {
SpecializedScorer::TermUnion(term_scorers) => {
let mut union_scorer =
BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn, num_docs);
let mut union_scorer = BufferedUnionScorer::build(
term_scorers,
&self.score_combiner_fn,
reader.num_docs(),
);
for_each_scorer(&mut union_scorer, callback);
}
SpecializedScorer::TermIntersection(term_scorers) => {
let boxed_scorers: Vec<Box<dyn Scorer>> = term_scorers
.into_iter()
.map(|term_scorer| Box::new(term_scorer) as Box<dyn Scorer>)
.collect();
let mut intersection = intersect_scorers(boxed_scorers, num_docs);
for_each_scorer(intersection.as_mut(), callback);
}
SpecializedScorer::Other(mut scorer) => {
for_each_scorer(scorer.as_mut(), callback);
}
@@ -530,23 +485,17 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
callback: &mut dyn FnMut(&[DocId]),
) -> crate::Result<()> {
let scorer = self.complex_scorer(reader, 1.0, || DoNothingCombiner)?;
let num_docs = reader.num_docs();
let mut buffer = [0u32; COLLECT_BLOCK_BUFFER_LEN];
match scorer {
SpecializedScorer::TermUnion(term_scorers) => {
let mut union_scorer =
BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn, num_docs);
let mut union_scorer = BufferedUnionScorer::build(
term_scorers,
&self.score_combiner_fn,
reader.num_docs(),
);
for_each_docset_buffered(&mut union_scorer, &mut buffer, callback);
}
SpecializedScorer::TermIntersection(term_scorers) => {
let boxed_scorers: Vec<Box<dyn Scorer>> = term_scorers
.into_iter()
.map(|term_scorer| Box::new(term_scorer) as Box<dyn Scorer>)
.collect();
let mut intersection = intersect_scorers(boxed_scorers, num_docs);
for_each_docset_buffered(intersection.as_mut(), &mut buffer, callback);
}
SpecializedScorer::Other(mut scorer) => {
for_each_docset_buffered(scorer.as_mut(), &mut buffer, callback);
}
@@ -575,9 +524,6 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
SpecializedScorer::TermUnion(term_scorers) => {
super::block_wand(term_scorers, threshold, callback);
}
SpecializedScorer::TermIntersection(term_scorers) => {
super::block_wand_intersection(term_scorers, threshold, callback);
}
SpecializedScorer::Other(mut scorer) => {
for_each_pruning_scorer(scorer.as_mut(), threshold, callback);
}

View File

@@ -1,10 +1,8 @@
mod block_wand_intersection;
mod block_wand_union;
mod block_wand;
mod boolean_query;
mod boolean_weight;
pub(crate) use self::block_wand_intersection::block_wand_intersection;
pub(crate) use self::block_wand_union::{block_wand, block_wand_single_scorer};
pub(crate) use self::block_wand::{block_wand, block_wand_single_scorer};
pub use self::boolean_query::BooleanQuery;
pub use self::boolean_weight::BooleanWeight;

View File

@@ -1,6 +1,6 @@
use crate::docset::DocSet;
use crate::fieldnorm::FieldNormReader;
use crate::postings::{BlockSegmentPostings, FreqReadingOption, Postings, SegmentPostings};
use crate::postings::{FreqReadingOption, Postings, SegmentPostings};
use crate::query::bm25::Bm25Weight;
use crate::query::{Explanation, Scorer};
use crate::{DocId, Score};
@@ -95,21 +95,6 @@ impl TermScorer {
pub fn last_doc_in_block(&self) -> DocId {
self.postings.block_cursor.skip_reader().last_doc_in_block()
}
/// Returns a mutable reference to the underlying block cursor.
pub(crate) fn block_cursor(&mut self) -> &mut BlockSegmentPostings {
&mut self.postings.block_cursor
}
/// Returns a reference to the fieldnorm reader for batch lookups.
pub(crate) fn fieldnorm_reader(&self) -> &FieldNormReader {
&self.fieldnorm_reader
}
/// Returns a reference to the BM25 weight for batch score computation.
pub(crate) fn bm25_weight(&self) -> &Bm25Weight {
&self.similarity_weight
}
}
impl DocSet for TermScorer {