Compare commits

..

6 Commits

Author SHA1 Message Date
Pascal Seitz
d8f4c0b703 chore: Release 0.26.1 2026-04-21 07:30:14 +02:00
Pascal Seitz
386b0a2a68 perf(agg): only measure active parent bucket in composite collect
Same change as 26a589e for SegmentCompositeCollector: get_memory_consumption
summed across all parent_buckets on every block, scaling with outer bucket
cardinality. Pass parent_bucket_id and index the single bucket.
2026-04-21 07:29:35 +02:00
Pascal Seitz
56cd88928d add inline 2026-04-21 07:29:35 +02:00
Pascal Seitz
cb8a2df8b0 agg fix: compute memory consumption only for current bucket 2026-04-21 07:29:35 +02:00
Pascal Seitz
9e63fc5081 chore: Release 2026-03-31 15:10:59 +08:00
Pascal Seitz
d882b34cf8 unbump for release and update Changelog.md 2026-03-31 14:48:43 +08:00
64 changed files with 844 additions and 4505 deletions

View File

@@ -6,8 +6,6 @@ updates:
interval: daily
time: "20:00"
open-pull-requests-limit: 10
cooldown:
default-days: 2
- package-ecosystem: "github-actions"
directory: "/"
@@ -15,5 +13,3 @@ updates:
interval: daily
time: "20:00"
open-pull-requests-limit: 10
cooldown:
default-days: 2

View File

@@ -4,9 +4,6 @@ on:
push:
branches: [main]
permissions:
contents: read
# Ensures that we cancel running jobs for the same PR / same workflow.
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@@ -15,20 +12,16 @@ concurrency:
jobs:
coverage:
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- uses: actions/checkout@v4
- name: Install Rust
run: rustup toolchain install nightly-2025-12-01 --profile minimal --component llvm-tools-preview
- uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1
- uses: taiki-e/install-action@e4b3a0453201addddc06d3a72db90326aad87084 # cargo-llvm-cov
- uses: Swatinem/rust-cache@v2
- uses: taiki-e/install-action@cargo-llvm-cov
- name: Generate code coverage
run: cargo +nightly-2025-12-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
- name: Upload coverage to Codecov
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
uses: codecov/codecov-action@v3
continue-on-error: true
with:
token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos

View File

@@ -8,9 +8,6 @@ env:
CARGO_TERM_COLOR: always
NUM_FUNCTIONAL_TEST_ITERATIONS: 20000
permissions:
contents: read
# Ensures that we cancel running jobs for the same PR / same workflow.
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@@ -21,13 +18,10 @@ jobs:
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- uses: actions/checkout@v4
- name: Install stable
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af # v1.0.7
uses: actions-rs/toolchain@v1
with:
toolchain: stable
profile: minimal

View File

@@ -1,49 +0,0 @@
name: OpenSSF Scorecard
on:
schedule:
- cron: '0 0 * * 0'
push:
branches:
- main
permissions:
contents: read
jobs:
analysis:
name: Scorecards analysis
runs-on: ubuntu-latest
permissions:
# Needed to upload the results to code-scanning dashboard.
security-events: write
# Needed to publish results
id-token: write
steps:
- name: 'Checkout code'
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: 'Run analysis'
uses: ossf/scorecard-action@4eaacf0543bb3f2c246792bd56e8cdeffafb205a # v2.4.3
with:
results_file: results.sarif
results_format: sarif
repo_token: ${{ secrets.GITHUB_TOKEN }}
publish_results: true
# Upload the results as artifacts.
- name: 'Upload artifact'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: SARIF file
path: results.sarif
retention-days: 5
# Upload the results to GitHub's code scanning dashboard.
- name: 'Upload to code-scanning'
uses: github/codeql-action/upload-sarif@95e58e9a2cdfd71adc6e0353d5c52f41a045d225 # v4.35.2
with:
sarif_file: results.sarif

View File

@@ -9,9 +9,6 @@ on:
env:
CARGO_TERM_COLOR: always
permissions:
contents: read
# Ensures that we cancel running jobs for the same PR / same workflow.
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@@ -22,27 +19,23 @@ jobs:
runs-on: ubuntu-latest
permissions:
contents: read
checks: write
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- uses: actions/checkout@v4
- name: Install nightly
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af # v1.0.7
uses: actions-rs/toolchain@v1
with:
toolchain: nightly
profile: minimal
components: rustfmt
- name: Install stable
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af # v1.0.7
uses: actions-rs/toolchain@v1
with:
toolchain: stable
profile: minimal
components: clippy
- uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1
- uses: Swatinem/rust-cache@v2
- name: Check Formatting
run: cargo +nightly fmt --all -- --check
@@ -54,7 +47,7 @@ jobs:
- name: Check Bench Compilation
run: cargo +nightly bench --no-run --profile=dev --all-features
- uses: actions-rs/clippy-check@b5b5f21f4797c02da247df37026fcd0a5024aa4d # v1.0.7
- uses: actions-rs/clippy-check@v1
with:
toolchain: stable
token: ${{ secrets.GITHUB_TOKEN }}
@@ -64,9 +57,6 @@ jobs:
runs-on: ubuntu-latest
permissions:
contents: read
strategy:
matrix:
features:
@@ -77,17 +67,17 @@ jobs:
name: test-${{ matrix.features.label}}
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- uses: actions/checkout@v4
- name: Install stable
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af # v1.0.7
uses: actions-rs/toolchain@v1
with:
toolchain: stable
profile: minimal
override: true
- uses: taiki-e/install-action@56cc9adf3a3e2c23eafb56e8acaf9d0373cb845a # nextest
- uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1
- uses: taiki-e/install-action@nextest
- uses: Swatinem/rust-cache@v2
- name: Run tests
run: |

View File

@@ -1,8 +1,9 @@
Tantivy 0.26.1
================================
## Performance
- Fix quadratic runtime in nested term and composite aggregations: memory accounting scanned all parent buckets on every collect instead of just the current parent (@PSeitz @fulmicoton)
## Bugfixes
- Fix memory consumption accounting in nested term aggregation to only scan the active parent bucket (@PSeitz)
- Fix memory consumption accounting in composite aggregation to only scan the active parent bucket (@PSeitz)
Tantivy 0.26 (Unreleased)
================================

View File

@@ -1,6 +1,6 @@
[package]
name = "tantivy"
version = "0.26.0"
version = "0.26.1"
authors = ["Paul Masurel <paul.masurel@gmail.com>"]
license = "MIT"
categories = ["database-implementations", "data-structures"]
@@ -65,7 +65,7 @@ tantivy-bitpacker = { version = "0.10", path = "./bitpacker" }
common = { version = "0.11", path = "./common/", package = "tantivy-common" }
tokenizer-api = { version = "0.7", path = "./tokenizer-api", package = "tantivy-tokenizer-api" }
sketches-ddsketch = { version = "0.4", features = ["use_serde"] }
datasketches = { version = "0.3.0", features = ["hll"] }
datasketches = "0.2.0"
futures-util = { version = "0.3.28", optional = true }
futures-channel = { version = "0.3.28", optional = true }
fnv = "1.0.7"
@@ -75,7 +75,7 @@ typetag = "0.2.21"
winapi = "0.3.9"
[dev-dependencies]
binggan = "0.17.0"
binggan = "0.15.3"
rand = "0.9"
maplit = "1.0.2"
matches = "0.1.9"
@@ -92,7 +92,7 @@ postcard = { version = "1.0.4", features = [
], default-features = false }
[target.'cfg(not(windows))'.dev-dependencies]
criterion = { version = "0.8", default-features = false }
criterion = { version = "0.5", default-features = false }
[dev-dependencies.fail]
version = "0.5.0"
@@ -201,11 +201,3 @@ harness = false
[[bench]]
name = "regex_all_terms"
harness = false
[[bench]]
name = "query_parser_nested"
harness = false
[[bench]]
name = "intersection_bench"
harness = false

View File

@@ -1,7 +1,6 @@
[![Docs](https://docs.rs/tantivy/badge.svg)](https://docs.rs/crate/tantivy/)
[![Build Status](https://github.com/quickwit-oss/tantivy/actions/workflows/test.yml/badge.svg)](https://github.com/quickwit-oss/tantivy/actions/workflows/test.yml)
[![codecov](https://codecov.io/gh/quickwit-oss/tantivy/branch/main/graph/badge.svg)](https://codecov.io/gh/quickwit-oss/tantivy)
[![OpenSSF Scorecard](https://api.scorecard.dev/projects/github.com/quickwit-oss/tantivy/badge)](https://scorecard.dev/viewer/?uri=github.com/quickwit-oss/tantivy)
[![Join the chat at https://discord.gg/MT27AG5EVE](https://shields.io/discord/908281611840282624?label=chat%20on%20discord)](https://discord.gg/MT27AG5EVE)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![Crates.io](https://img.shields.io/crates/v/tantivy.svg)](https://crates.io/crates/tantivy)

View File

@@ -63,8 +63,6 @@ fn bench_agg(mut group: InputGroup<Index>) {
register!(group, terms_all_unique_with_avg_sub_agg);
register!(group, terms_many_with_avg_sub_agg);
register!(group, terms_status_with_avg_sub_agg);
register!(group, terms_status_with_terms_zipf_1000_sub_agg);
register!(group, terms_zipf_1000_with_terms_status_sub_agg);
register!(group, terms_status_with_histogram);
register!(group, terms_zipf_1000);
register!(group, terms_zipf_1000_with_histogram);
@@ -79,12 +77,7 @@ fn bench_agg(mut group: InputGroup<Index>) {
register!(group, composite_histogram_calendar);
register!(group, cardinality_agg);
register!(group, cardinality_agg_high_card);
register!(group, cardinality_agg_low_card);
register!(group, terms_status_with_cardinality_agg);
register!(group, terms_100_buckets_with_cardinality_agg);
register!(group, terms_many_with_single_term_order_by_card);
register!(group, terms_many_with_single_term_2_order_by_card);
register!(group, range_agg);
register!(group, range_agg_with_avg_sub_agg);
@@ -172,52 +165,10 @@ fn cardinality_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
// Full-scan cardinality aggregation over `text_all_unique_terms`, a string
// field with near-1M cardinality (every doc has a unique term).
// Hits the dense (PagedBitset) path: the bucket promotes from FxHashSet
// shortly into the scan.
fn cardinality_agg_high_card(index: &Index) {
    execute_agg(
        index,
        json!({
            "cardinality": {
                "cardinality": {
                    "field": "text_all_unique_terms"
                },
            }
        }),
    );
}
// Full-scan cardinality aggregation over `text_few_terms_status`, a string
// field with tiny cardinality (7 distinct values).
// Stays on the FxHashSet path — the promotion threshold is never crossed.
// Validates no regression on the sparse path.
fn cardinality_agg_low_card(index: &Index) {
    execute_agg(
        index,
        json!({
            "cardinality": {
                "cardinality": {
                    "field": "text_few_terms_status"
                },
            }
        }),
    );
}
// Terms aggregation on `text_few_terms_status` with a cardinality sub-aggregation
// on the same field inside every bucket.
fn terms_status_with_cardinality_agg(index: &Index) {
    execute_agg(
        index,
        json!({
            "my_texts": {
                "terms": { "field": "text_few_terms_status" },
                "aggs": {
                    "cardinality": {
                        "cardinality": {
                            "field": "text_few_terms_status"
                        },
                    }
                }
            },
        }),
    );
}
fn terms_100_buckets_with_cardinality_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_1000_terms_zipf", "size": 100 },
"aggs": {
"cardinality": {
"cardinality": {
@@ -230,58 +181,6 @@ fn terms_100_buckets_with_cardinality_agg(index: &Index) {
execute_agg(index, agg_req);
}
// Outer terms aggregation on `text_many_terms`; each bucket runs a nested terms
// aggregation on `single_term` that is ordered (descending) by its own
// cardinality sub-aggregation over `text_few_terms`.
fn terms_many_with_single_term_order_by_card(index: &Index) {
    execute_agg(
        index,
        json!({
            "my_texts": {
                "terms": { "field": "text_many_terms" },
                "aggs": {
                    "nested_terms": {
                        "terms": {
                            "field": "single_term",
                            "order": { "cardinality": "desc" }
                        },
                        "aggs": {
                            "cardinality": {
                                "cardinality": { "field": "text_few_terms" }
                            }
                        }
                    }
                }
            },
        }),
    );
}
// Two-level terms ordered by cardinality at each level: a high-card outer terms
// (text_many_terms) ordered by a cardinality sub-agg, with a nested terms on
// `single_term` that is also ordered by a cardinality sub-agg, plus an avg.
//
// The nested field name must match the schema field `single_term` exactly; a
// stray leading space would name a field that is not in the schema, so the
// nested level would not be aggregating the intended column.
fn terms_many_with_single_term_2_order_by_card(index: &Index) {
    let agg_req = json!({
        "by_ip": {
            "terms": {
                "field": "text_many_terms",
                "order": { "card_few_terms": "desc" }
            },
            "aggs": {
                "card_few_terms": {
                    "cardinality": { "field": "text_few_terms" }
                },
                "nested_terms": {
                    "terms": {
                        "field": "single_term",
                        "order": { "distinct_path2": "desc" }
                    },
                    "aggs": {
                        "avg_botscore": { "avg": { "field": "score" } },
                        "distinct_path2": { "cardinality": { "field": "text_few_terms" } }
                    }
                }
            }
        }
    });
    execute_agg(index, agg_req);
}
fn terms_7(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_few_terms_status" } },
@@ -354,30 +253,6 @@ fn terms_all_unique_with_avg_sub_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
// Outer terms aggregation on `text_few_terms_status` with a nested terms
// aggregation on `text_1000_terms_zipf` in every bucket.
fn terms_status_with_terms_zipf_1000_sub_agg(index: &Index) {
    execute_agg(
        index,
        json!({
            "my_texts": {
                "terms": { "field": "text_few_terms_status" },
                "aggs": {
                    "nested_terms": { "terms": { "field": "text_1000_terms_zipf" } }
                }
            }
        }),
    );
}
// Outer terms aggregation on `text_1000_terms_zipf` with a nested terms
// aggregation on `text_few_terms_status` in every bucket (the mirror image of
// the status-outer / zipf-inner benchmark).
fn terms_zipf_1000_with_terms_status_sub_agg(index: &Index) {
    execute_agg(
        index,
        json!({
            "my_texts": {
                "terms": { "field": "text_1000_terms_zipf" },
                "aggs": {
                    "nested_terms": { "terms": { "field": "text_few_terms_status" } }
                }
            }
        }),
    );
}
fn terms_status_with_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
@@ -691,8 +566,7 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
TextFieldIndexing::default().set_index_option(IndexRecordOption::WithFreqs),
)
.set_stored();
let text_field = schema_builder.add_text_field("text", text_fieldtype.clone());
let single_term = schema_builder.add_text_field("single_term", FAST);
let text_field = schema_builder.add_text_field("text", text_fieldtype);
let json_field = schema_builder.add_json_field("json", FAST);
let text_field_all_unique_terms =
schema_builder.add_text_field("text_all_unique_terms", STRING | FAST);
@@ -756,8 +630,6 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
index_writer.add_document(doc!(
json_field => json!({"mixed_type": 10.0}),
json_field => json!({"mixed_type": 10.0}),
single_term => "single_term",
single_term => "single_term",
text_field => "cool",
text_field => "cool",
text_field_all_unique_terms => "cool",
@@ -792,7 +664,6 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
json!({"mixed_type": many_terms_data.choose(&mut rng).unwrap().to_string()})
};
index_writer.add_document(doc!(
single_term => "single_term",
text_field => "cool",
json_field => json,
text_field_all_unique_terms => format!("unique_term_{}", rng.random::<u64>()),

View File

@@ -1,149 +0,0 @@
// Benchmarks top-K intersection of term scorers (block_wand_intersection).
//
// What's measured:
// - Conjunctive queries (+a +b, +a +b +c) with top-10 by score
// - Varying doc-frequency balance between terms (balanced, skewed, very skewed)
// - Realistic term frequencies (geometric distribution, mostly low)
// - 1M-doc single segment
//
// Run with: cargo bench --bench intersection_bench
use binggan::{black_box, BenchRunner};
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::collector::TopDocs;
use tantivy::query::QueryParser;
use tantivy::schema::{Schema, TEXT};
use tantivy::{doc, Index, ReloadPolicy, Searcher};
const NUM_DOCS: usize = 1_000_000;
/// Handles produced by `build_index` that the benchmark closures need to run
/// queries against one generated index.
struct BenchIndex {
/// Searcher taken from the index reader; cloned into each benchmark closure.
searcher: Searcher,
/// Query parser created for the index with the `body` field as default field.
query_parser: QueryParser,
}
/// Draws a term frequency in `1..=10` from a geometric-like distribution.
///
/// Starting from 1, the frequency is incremented with probability `1 - p` per
/// step and capped at 10, so most values are 1, a few are 2-3, rarely higher.
/// A higher `p` puts more weight on `tf = 1`.
fn random_term_freq(rng: &mut StdRng, p: f64) -> u32 {
    let bump_probability = 1.0 - p;
    let mut term_freq = 1u32;
    loop {
        // The cap is checked first so that no random draw happens once it is reached.
        if term_freq >= 10 || !rng.random_bool(bump_probability) {
            return term_freq;
        }
        term_freq += 1;
    }
}
/// Builds an in-RAM index of `NUM_DOCS` documents over the terms `aaa`, `bbb`
/// and `ccc`, and returns a searcher and query parser for it.
///
/// `p_a`, `p_b` and `p_c` are the per-document probabilities that the respective
/// term occurs at all. When a term occurs, it is repeated `random_term_freq`
/// times, and every document is padded with filler tokens so that field lengths
/// (and therefore fieldnorms) vary.
fn build_index(p_a: f64, p_b: f64, p_c: f64) -> BenchIndex {
    let mut schema_builder = Schema::builder();
    let body = schema_builder.add_text_field("body", TEXT);
    let index = Index::create_in_ram(schema_builder.build());

    // Fixed seed: every run generates the same corpus.
    let mut rng = StdRng::from_seed([42u8; 32]);

    let mut writer = index.writer_with_num_threads(1, 500_000_000).unwrap();
    for _ in 0..NUM_DOCS {
        let mut tokens: Vec<String> = Vec::new();
        // Terms are visited in the fixed order a, b, c so the RNG draw order is stable.
        for (term, doc_probability) in [("aaa", p_a), ("bbb", p_b), ("ccc", p_c)] {
            if !rng.random_bool(doc_probability) {
                continue;
            }
            let tf = random_term_freq(&mut rng, 0.7);
            for _ in 0..tf {
                tokens.push(term.to_string());
            }
        }
        // Pad with between 5 and 29 filler tokens (the upper bound is exclusive).
        let filler_count = rng.random_range(5u32..30u32);
        tokens.extend((0..filler_count).map(|_| "filler".to_string()));

        let text = tokens.join(" ");
        writer.add_document(doc!(body => text)).unwrap();
    }
    writer.commit().unwrap();
    // Release the writer before opening the reader.
    drop(writer);

    let reader = index
        .reader_builder()
        .reload_policy(ReloadPolicy::Manual)
        .try_into()
        .unwrap();
    let searcher = reader.searcher();
    let query_parser = QueryParser::for_index(&index, vec![body]);
    BenchIndex {
        searcher,
        query_parser,
    }
}
/// Builds one index per scenario and registers the two-term and (where a third
/// term is present) three-term top-10 intersection queries against it.
fn main() {
    // Scenarios: (label, p_a, p_b, p_c)
    //
    // "balanced": all terms ~10% → intersection ~1% of docs
    // "skewed": one common (50%), one rare (2%) → intersection ~1%
    // "very_skewed": one very common (80%), one very rare (0.5%) → intersection ~0.4%
    // "three_balanced": three terms ~20% each → intersection ~0.8%
    // "three_skewed": 50% / 10% / 2% → intersection ~0.1%
    let scenarios: [(&str, f64, f64, f64); 5] = [
        ("balanced_10%_10%", 0.10, 0.10, 0.0),
        ("skewed_50%_2%", 0.50, 0.02, 0.0),
        ("very_skewed_80%_0.5%", 0.80, 0.005, 0.0),
        ("three_balanced_20%_20%_20%", 0.20, 0.20, 0.20),
        ("three_skewed_50%_10%_2%", 0.50, 0.10, 0.02),
    ];

    let mut runner = BenchRunner::new();
    for (label, p_a, p_b, p_c) in scenarios {
        let bench_index = build_index(p_a, p_b, p_c);
        let mut group = runner.new_group();
        group.set_name(format!("intersection — {label}"));

        // Queries to register for this scenario, two-term first, then three-term.
        let mut query_strs: Vec<&'static str> = Vec::with_capacity(2);
        if p_a > 0.0 && p_b > 0.0 {
            query_strs.push("+aaa +bbb");
        }
        if p_c > 0.0 {
            query_strs.push("+aaa +bbb +ccc");
        }

        for query_str in query_strs {
            let query = bench_index.query_parser.parse_query(query_str).unwrap();
            let searcher = bench_index.searcher.clone();
            group.register(format!("{query_str} top10"), move |_| {
                let collector = TopDocs::with_limit(10).order_by_score();
                black_box(searcher.search(&query, &collector).unwrap());
                1usize
            });
        }
        group.run();
    }
}

View File

@@ -1,35 +0,0 @@
// Benchmark for the query grammar parsing deeply nested queries.
//
// Regression guard for https://github.com/quickwit-oss/tantivy/issues/2498:
// at depth 20/21 the old parser took 0.87 s / 1.72 s respectively because
// `ast()` retried `occur_leaf` on backtrack, giving O(2^n) time. With the
// fix parsing is linear and completes in microseconds.
//
// Run with: `cargo bench --bench query_parser_nested`.
use binggan::{black_box, BenchRunner};
use tantivy::query_grammar::parse_query;
/// Builds the query `title:test` wrapped in `depth` pairs of parentheses,
/// prefixed with `+` when `leading_plus` is set.
fn nested_query(depth: usize, leading_plus: bool) -> String {
    let mut query = String::new();
    if leading_plus {
        query.push('+');
    }
    query.push_str(&"(".repeat(depth));
    query.push_str("title:test");
    query.push_str(&")".repeat(depth));
    query
}
/// Registers one parse benchmark per (depth, leading-plus) combination:
/// depths 20 and 21, each without and with a leading `+`.
fn main() {
    let mut runner = BenchRunner::new();
    for depth in [20, 21] {
        for (leading_plus, variant) in [(false, "plain"), (true, "plus")] {
            let label = format!("parse_nested_depth_{depth}_{variant}");
            let query = nested_query(depth, leading_plus);
            runner.bench_function(&label, move |_| {
                black_box(parse_query(black_box(&query)).unwrap());
            });
        }
    }
}

View File

@@ -18,10 +18,5 @@ homepage = "https://github.com/quickwit-oss/tantivy"
bitpacking = { version = "0.9.2", default-features = false, features = ["bitpacker1x"] }
[dev-dependencies]
binggan = "0.17.0"
rand = "0.9"
proptest = "1"
[[bench]]
name = "bench"
harness = false

View File

@@ -1,110 +1,65 @@
use std::cell::RefCell;
#![feature(test)]
use binggan::{BenchRunner, black_box};
use rand::rng;
use rand::seq::IteratorRandom;
use tantivy_bitpacker::{BitPacker, BitUnpacker, BlockedBitpacker};
extern crate test;
fn create_bitpacked_data(bit_width: u8, num_els: u32) -> Vec<u8> {
let mut bitpacker = BitPacker::new();
let mut buffer = Vec::new();
for _ in 0..num_els {
bitpacker.write(0u64, bit_width, &mut buffer).unwrap();
bitpacker.flush(&mut buffer).unwrap();
}
buffer
}
#[cfg(test)]
mod tests {
use rand::rng;
use rand::seq::IteratorRandom;
use tantivy_bitpacker::{BitPacker, BitUnpacker, BlockedBitpacker};
use test::Bencher;
const N: usize = 100_000;
const MAX_VAL: u64 = 1_000;
const BIT_WIDTH: u8 = 10; // 2^10 = 1024 > MAX_VAL
fn create_packed_data() -> (BitUnpacker, Vec<u8>) {
let mut bitpacker = BitPacker::new();
let mut data = Vec::new();
for i in 0..N as u64 {
let val = i * MAX_VAL / N as u64;
bitpacker.write(val, BIT_WIDTH, &mut data).unwrap();
}
bitpacker.close(&mut data).unwrap();
(BitUnpacker::new(BIT_WIDTH), data)
}
fn bench_bitpacking() {
let mut runner = BenchRunner::new();
let bit_width = 3;
let num_els = 1_000_000u32;
let bit_unpacker = BitUnpacker::new(bit_width);
let data = create_bitpacked_data(bit_width, num_els);
let idxs: Vec<u32> = (0..num_els).choose_multiple(&mut rng(), 100_000);
runner.bench_function("bitpacking_read", move |_| {
let mut out = 0u64;
for &idx in &idxs {
out = out.wrapping_add(bit_unpacker.get(idx, &data[..]));
#[inline(never)]
fn create_bitpacked_data(bit_width: u8, num_els: u32) -> Vec<u8> {
let mut bitpacker = BitPacker::new();
let mut buffer = Vec::new();
for _ in 0..num_els {
// the values do not matter.
bitpacker.write(0u64, bit_width, &mut buffer).unwrap();
bitpacker.flush(&mut buffer).unwrap();
}
black_box(out);
});
}
fn bench_blocked_bitpacker() {
let mut runner = BenchRunner::new();
let mut blocked_bitpacker = BlockedBitpacker::new();
for val in 0..=21500 {
blocked_bitpacker.add(val * val);
buffer
}
runner.bench_function("blockedbitp_read", move |_| {
let mut out = 0u64;
for val in 0..=21500 {
out = out.wrapping_add(blocked_bitpacker.get(val));
}
black_box(out);
});
runner.bench_function("blockedbitp_create", |_| {
#[bench]
fn bench_bitpacking_read(b: &mut Bencher) {
let bit_width = 3;
let num_els = 1_000_000u32;
let bit_unpacker = BitUnpacker::new(bit_width);
let data = create_bitpacked_data(bit_width, num_els);
let idxs: Vec<u32> = (0..num_els).choose_multiple(&mut rng(), 100_000);
b.iter(|| {
let mut out = 0u64;
for &idx in &idxs {
out = out.wrapping_add(bit_unpacker.get(idx, &data[..]));
}
out
});
}
#[bench]
fn bench_blockedbitp_read(b: &mut Bencher) {
let mut blocked_bitpacker = BlockedBitpacker::new();
for val in 0..=21500 {
blocked_bitpacker.add(val * val);
}
black_box(blocked_bitpacker);
});
}
fn bench_filter_vec() {
let mut runner = BenchRunner::new();
let (unpacker, data) = create_packed_data();
let positions = RefCell::new(Vec::with_capacity(N));
runner.bench_function("filter_vec_dense", move |_| {
unpacker.get_ids_for_value_range(
250..=750,
0..N as u32,
&data,
&mut positions.borrow_mut(),
);
black_box(positions.borrow().len());
});
let (unpacker, data) = create_packed_data();
let positions = RefCell::new(Vec::with_capacity(N));
runner.bench_function("filter_vec_sparse", move |_| {
unpacker.get_ids_for_value_range(0..=50, 0..N as u32, &data, &mut positions.borrow_mut());
black_box(positions.borrow().len());
});
let (unpacker, data) = create_packed_data();
let positions = RefCell::new(Vec::with_capacity(N));
runner.bench_function("filter_vec_full", move |_| {
unpacker.get_ids_for_value_range(
0..=MAX_VAL,
0..N as u32,
&data,
&mut positions.borrow_mut(),
);
black_box(positions.borrow().len());
});
}
fn main() {
bench_bitpacking();
bench_blocked_bitpacker();
bench_filter_vec();
b.iter(|| {
let mut out = 0u64;
for val in 0..=21500 {
out = out.wrapping_add(blocked_bitpacker.get(val));
}
out
});
}
#[bench]
fn bench_blockedbitp_create(b: &mut Bencher) {
b.iter(|| {
let mut blocked_bitpacker = BlockedBitpacker::new();
for val in 0..=21500 {
blocked_bitpacker.add(val * val);
}
blocked_bitpacker
});
}
}

View File

@@ -1,17 +1,8 @@
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
use std::arch::is_aarch64_feature_detected;
use std::ops::RangeInclusive;
#[cfg(target_arch = "x86_64")]
mod avx2;
#[cfg(target_arch = "aarch64")]
mod neon;
// SVE intrinsics are not exposed on aarch64-apple-darwin.
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
mod sve;
mod scalar;
#[derive(Clone, Copy, Eq, PartialEq, Debug)]
@@ -19,10 +10,6 @@ mod scalar;
enum FilterImplPerInstructionSet {
#[cfg(target_arch = "x86_64")]
AVX2 = 0u8,
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
Sve = 3u8,
#[cfg(target_arch = "aarch64")]
Neon = 2u8,
Scalar = 1u8,
}
@@ -32,57 +19,29 @@ impl FilterImplPerInstructionSet {
match *self {
#[cfg(target_arch = "x86_64")]
FilterImplPerInstructionSet::AVX2 => is_x86_feature_detected!("avx2"),
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
FilterImplPerInstructionSet::Sve => is_aarch64_feature_detected!("sve"),
// Neon is a mandatory feature on aarch64, so it is always available.
#[cfg(target_arch = "aarch64")]
FilterImplPerInstructionSet::Neon => true,
FilterImplPerInstructionSet::Scalar => true,
}
}
}
// List of available implementations in preferred order.
// List of available implementation in preferred order.
#[cfg(target_arch = "x86_64")]
const IMPLS: [FilterImplPerInstructionSet; 2] = [
FilterImplPerInstructionSet::AVX2,
FilterImplPerInstructionSet::Scalar,
];
// Non-Apple aarch64: try SVE, NEON, Scalar.
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
const IMPLS: [FilterImplPerInstructionSet; 3] = [
FilterImplPerInstructionSet::Sve,
FilterImplPerInstructionSet::Neon,
FilterImplPerInstructionSet::Scalar,
];
// Apple aarch64 (M-series): SVE not available; use NEON or Scalar.
#[cfg(all(target_arch = "aarch64", target_vendor = "apple"))]
const IMPLS: [FilterImplPerInstructionSet; 2] = [
FilterImplPerInstructionSet::Neon,
FilterImplPerInstructionSet::Scalar,
];
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
#[cfg(not(target_arch = "x86_64"))]
const IMPLS: [FilterImplPerInstructionSet; 1] = [FilterImplPerInstructionSet::Scalar];
impl FilterImplPerInstructionSet {
#[inline]
#[allow(unused_variables)]
#[allow(unused_variables)] // on non-x86_64, code is unused.
fn from(code: u8) -> FilterImplPerInstructionSet {
#[cfg(target_arch = "x86_64")]
if code == FilterImplPerInstructionSet::AVX2 as u8 {
return FilterImplPerInstructionSet::AVX2;
}
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
if code == FilterImplPerInstructionSet::Sve as u8 {
return FilterImplPerInstructionSet::Sve;
}
#[cfg(target_arch = "aarch64")]
if code == FilterImplPerInstructionSet::Neon as u8 {
return FilterImplPerInstructionSet::Neon;
}
FilterImplPerInstructionSet::Scalar
}
@@ -91,10 +50,6 @@ impl FilterImplPerInstructionSet {
match self {
#[cfg(target_arch = "x86_64")]
FilterImplPerInstructionSet::AVX2 => avx2::filter_vec_in_place(range, offset, output),
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
FilterImplPerInstructionSet::Sve => sve::filter_vec_in_place(range, offset, output),
#[cfg(target_arch = "aarch64")]
FilterImplPerInstructionSet::Neon => neon::filter_vec_in_place(range, offset, output),
FilterImplPerInstructionSet::Scalar => {
scalar::filter_vec_in_place(range, offset, output)
}
@@ -102,12 +57,6 @@ impl FilterImplPerInstructionSet {
}
}
/// Yields the entries of `IMPLS`, in their listed (preferred) order, keeping
/// only those whose `is_available` check passes.
fn available_impls() -> impl Iterator<Item = FilterImplPerInstructionSet> {
    IMPLS
        .into_iter()
        .filter(|candidate| candidate.is_available())
}
#[inline]
fn get_best_available_instruction_set() -> FilterImplPerInstructionSet {
use std::sync::atomic::{AtomicU8, Ordering};
@@ -115,7 +64,10 @@ fn get_best_available_instruction_set() -> FilterImplPerInstructionSet {
let instruction_set_byte: u8 = INSTRUCTION_SET_BYTE.load(Ordering::Relaxed);
if instruction_set_byte == u8::MAX {
// Let's initialize the instruction set and cache it.
let instruction_set = available_impls().next().unwrap();
let instruction_set = IMPLS
.into_iter()
.find(FilterImplPerInstructionSet::is_available)
.unwrap();
INSTRUCTION_SET_BYTE.store(instruction_set as u8, Ordering::Relaxed);
return instruction_set;
}
@@ -128,12 +80,12 @@ pub fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut
#[cfg(test)]
mod tests {
use proptest::strategy::Strategy;
use super::*;
#[test]
fn test_get_best_available_instruction_set() {
// This does not test much unfortunately.
// We just make sure the function returns without crashing and returns the same result.
let instruction_set = get_best_available_instruction_set();
assert_eq!(get_best_available_instruction_set(), instruction_set);
}
@@ -150,31 +102,6 @@ mod tests {
}
}
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
#[test]
fn test_instruction_set_to_code_from_code() {
for instruction_set in [
FilterImplPerInstructionSet::Sve,
FilterImplPerInstructionSet::Neon,
FilterImplPerInstructionSet::Scalar,
] {
let code = instruction_set as u8;
assert_eq!(instruction_set, FilterImplPerInstructionSet::from(code));
}
}
#[cfg(all(target_arch = "aarch64", target_vendor = "apple"))]
#[test]
fn test_instruction_set_to_code_from_code() {
for instruction_set in [
FilterImplPerInstructionSet::Neon,
FilterImplPerInstructionSet::Scalar,
] {
let code = instruction_set as u8;
assert_eq!(instruction_set, FilterImplPerInstructionSet::from(code));
}
}
fn test_filter_impl_empty_aux(filter_impl: FilterImplPerInstructionSet) {
let mut output = vec![];
filter_impl.filter_vec_in_place(0..=u32::MAX, 0, &mut output);
@@ -199,20 +126,11 @@ mod tests {
assert_eq!(&output, &[1, 3, 4, 5, 6, 7, 8]);
}
fn test_filter_impl_empty_range_aux(filter_impl: FilterImplPerInstructionSet) {
// start > end: RangeInclusive::contains always returns false; output must be empty.
// The SVE path's wrapping_sub would otherwise produce a huge range_width.
let mut output = vec![3, 2, 1, 5, 11, 2, 5, 10, 2];
filter_impl.filter_vec_in_place(10..=5, 0, &mut output);
assert_eq!(&output, &[]);
}
fn test_filter_impl_test_suite(filter_impl: FilterImplPerInstructionSet) {
test_filter_impl_empty_aux(filter_impl);
test_filter_impl_simple_aux(filter_impl);
test_filter_impl_simple_aux_shifted(filter_impl);
test_filter_impl_simple_outside_i32_range(filter_impl);
test_filter_impl_empty_range_aux(filter_impl);
}
#[test]
@@ -223,59 +141,25 @@ mod tests {
}
}
#[test]
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
fn test_filter_implementation_sve() {
if FilterImplPerInstructionSet::Sve.is_available() {
test_filter_impl_test_suite(FilterImplPerInstructionSet::Sve);
}
}
#[test]
#[cfg(target_arch = "aarch64")]
fn test_filter_implementation_neon() {
test_filter_impl_test_suite(FilterImplPerInstructionSet::Neon);
}
#[test]
fn test_filter_implementation_scalar() {
test_filter_impl_test_suite(FilterImplPerInstructionSet::Scalar);
}
fn max_val_strategy() -> impl proptest::strategy::Strategy<Value = u32> {
proptest::prop_oneof![
0u32..10u32,
255u32..258u32,
proptest::prelude::Just(1u32 << 25),
proptest::prelude::Just(u32::MAX - 1),
proptest::prelude::Just(u32::MAX),
]
}
fn vals_strategy() -> impl proptest::strategy::Strategy<Value = Vec<u32>> {
proptest::prop_oneof![
proptest::collection::vec(proptest::prelude::any::<u32>(), 0..300),
max_val_strategy()
.prop_flat_map(|max_val| { proptest::collection::vec(0..=max_val, 0..300) })
]
}
#[cfg(target_arch = "x86_64")]
proptest::proptest! {
#[test]
fn test_filter_compare_scalar_and_impls_impl_proptest(
start in 0u32..400u32,
end in 0u32..400u32,
fn test_filter_compare_scalar_and_avx2_impl_proptest(
start in proptest::prelude::any::<u32>(),
end in proptest::prelude::any::<u32>(),
offset in 0u32..2u32,
mut vals in vals_strategy()) {
for implementation in available_impls() {
if implementation == FilterImplPerInstructionSet::Scalar {
continue;
}
let mut vals_clone = vals.clone();
implementation.filter_vec_in_place(start..=end, offset, &mut vals);
FilterImplPerInstructionSet::Scalar.filter_vec_in_place(start..=end, offset, &mut vals_clone);
assert_eq!(&vals, &vals_clone);
}
mut vals in proptest::collection::vec(0..u32::MAX, 0..30)) {
if FilterImplPerInstructionSet::AVX2.is_available() {
let mut vals_clone = vals.clone();
FilterImplPerInstructionSet::AVX2.filter_vec_in_place(start..=end, offset, &mut vals);
FilterImplPerInstructionSet::Scalar.filter_vec_in_place(start..=end, offset, &mut vals_clone);
assert_eq!(&vals, &vals_clone);
}
}
}
}

View File

@@ -1,113 +0,0 @@
use std::arch::aarch64::*;
use std::ops::RangeInclusive;
const NUM_LANES: usize = 4;
/// Moves the lanes of `data` selected by `mask` to the front of the vector,
/// preserving their relative order, using a byte-level table shuffle.
///
/// `mask` is a 4-bit value: bit `k` set means lane `k` is kept.
///
/// # Safety
///
/// `mask` must be in `0..=15`; it is used as an unchecked index into
/// `BYTE_SHUFFLE_TABLE`, which has 16 entries.
#[inline]
#[target_feature(enable = "neon")]
unsafe fn compact(data: uint32x4_t, mask: u8) -> uint32x4_t {
    // SAFETY: mask is always in [0, 15] by construction (max sum of [1,2,4,8]),
    // and BYTE_SHUFFLE_TABLE has 16 entries, so the lookup is in bounds.
    let pattern = unsafe { BYTE_SHUFFLE_TABLE.get_unchecked(mask as usize) };
    // SAFETY: `pattern` points at 16 readable bytes, exactly what `vld1q_u8` loads.
    unsafe {
        let byte_indices = vld1q_u8(pattern.as_ptr());
        let data_bytes = vreinterpretq_u8_u32(data);
        vreinterpretq_u32_u8(vqtbl1q_u8(data_bytes, byte_indices))
    }
}
/// Replaces the contents of `output` with the ids (`offset + index`) of the
/// elements whose value lies in `range`, in increasing index order.
///
/// Full groups of `NUM_LANES` values go through the NEON kernel; the leftover
/// tail is handled with a scalar loop.
#[inline(never)]
pub fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
    let total_len = output.len();
    let num_words = total_len / NUM_LANES;
    // SAFETY: both pointers address the `total_len` initialized elements of `output`,
    // and the kernel only touches the first `num_words * NUM_LANES` of them.
    let mut write_pos = unsafe {
        filter_vec_neon_aux(
            output.as_ptr(),
            range.clone(),
            output.as_mut_ptr(),
            offset,
            num_words,
        )
    };
    for read_pos in (num_words * NUM_LANES)..total_len {
        // Read the value before the slot at `write_pos` (which may be this very
        // slot) is overwritten with the candidate id.
        let val = output[read_pos];
        output[write_pos] = offset + read_pos as u32;
        // The id written above is only kept if the value matched.
        if range.contains(&val) {
            write_pos += 1;
        }
    }
    output.truncate(write_pos);
}
/// NEON kernel: scans `num_words` groups of four `u32` values starting at `input`
/// and writes the ids of the in-range values contiguously starting at `output`.
/// Returns the number of ids written.
///
/// # Safety
///
/// `input` must be readable and `output` writable for `num_words * NUM_LANES`
/// elements. Every iteration stores a full 4-lane vector at the current write
/// position, even when fewer than four lanes matched.
#[target_feature(enable = "neon")]
unsafe fn filter_vec_neon_aux(
    input: *const u32,
    range: RangeInclusive<u32>,
    output: *mut u32,
    offset: u32,
    num_words: usize,
) -> usize {
    let lower = *range.start();
    let upper = *range.end();
    let first_ids = [offset, offset + 1, offset + 2, offset + 3];
    let lane_bits = [1u32, 2, 4, 8];
    unsafe {
        let lower_simd = vdupq_n_u32(lower);
        let upper_simd = vdupq_n_u32(upper);
        let id_step = vdupq_n_u32(NUM_LANES as u32);
        let bit_weights = vld1q_u32(lane_bits.as_ptr());
        let mut ids = vld1q_u32(first_ids.as_ptr());
        let mut read_ptr = input;
        let mut write_ptr = output;
        for _ in 0..num_words {
            let word = vld1q_u32(read_ptr);
            // Two unsigned compares ANDed together give an all-ones lane for every
            // value with `lower <= value <= upper`, and an all-zeros lane otherwise.
            let inside = vandq_u32(vcgeq_u32(word, lower_simd), vcleq_u32(word, upper_simd));
            // Each matching lane contributes its weight (1, 2, 4 or 8); the
            // horizontal sum is therefore the 4-bit lane mask.
            let mask = vaddvq_u32(vandq_u32(bit_weights, inside)) as u8;
            debug_assert!(mask <= 15, "mask must fit in 4 bits: {}", mask);
            // The whole vector is stored, but the write position only advances by
            // the number of matching lanes (the popcount of the mask).
            vst1q_u32(write_ptr, compact(ids, mask));
            write_ptr = write_ptr.add(mask.count_ones() as usize);
            ids = vaddq_u32(ids, id_step);
            read_ptr = read_ptr.add(NUM_LANES);
        }
        write_ptr.offset_from(output) as usize
    }
}
// Byte shuffle patterns to compact matching lanes to the front of the vector.
// Index is a 4-bit mask: bit k=1 means lane k (bytes 4k..4k+3) is in-range.
// The j-th set bit determines which input lane goes to output position j.
// Output positions beyond the number of set bits are filled with lane 0's bytes
// (0, 1, 2, 3); the kernel advances its write position only by the popcount of
// the mask, so those filler lanes are never counted as results.
const BYTE_SHUFFLE_TABLE: [[u8; 16]; 16] = [
[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3], // 0b0000: none
[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3], // 0b0001: lane 0
[4, 5, 6, 7, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3], // 0b0010: lane 1
[0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 0, 1, 2, 3], // 0b0011: lanes 0,1
[8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3], // 0b0100: lane 2
[0, 1, 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3], // 0b0101: lanes 0,2
[4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3], // 0b0110: lanes 1,2
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3], // 0b0111: lanes 0,1,2
[12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3], // 0b1000: lane 3
[0, 1, 2, 3, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3], // 0b1001: lanes 0,3
[4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3], // 0b1010: lanes 1,3
[0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3], // 0b1011: lanes 0,1,3
[8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3], // 0b1100: lanes 2,3
[0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3], // 0b1101: lanes 0,2,3
[4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3], // 0b1110: lanes 1,2,3
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], // 0b1111: all lanes
];

View File

@@ -1,258 +0,0 @@
use std::ops::RangeInclusive;
/// Returns the number of 32-bit lanes in one SVE vector on the running CPU.
///
/// The SVE vector length is not a compile-time constant, so it is queried at
/// runtime with the `cntw` instruction.
///
/// # Safety
///
/// Must only be called when SVE is confirmed available, e.g. via
/// `is_aarch64_feature_detected!("sve")`.
#[target_feature(enable = "sve")]
unsafe fn num_lanes() -> usize {
    let lanes: usize;
    // SAFETY: the asm only writes the output register; the caller guarantees
    // that SVE instructions can be executed.
    unsafe {
        core::arch::asm!(
            "cntw {lanes}",
            lanes = out(reg) lanes,
            options(nostack, nomem, preserves_flags),
        );
    }
    lanes
}
pub fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
if range.start() > range.end() {
output.clear();
return;
}
let vl = unsafe { num_lanes() };
let num_words = output.len() / vl;
let range_start = *range.start();
// Unsigned subtraction trick: val ∈ [lo, hi] ↔ (val - lo) ≤ᵤ (hi - lo).
// Values below lo wrap around to large u32, so the single unsigned ≤ excludes them.
let range_width = range.end().wrapping_sub(range_start);
let mut output_len = unsafe {
filter_vec_sve_aux(
output.as_ptr(),
range_start,
range_width,
output.as_mut_ptr(),
offset,
num_words,
vl,
)
};
let remainder_start = num_words * vl;
for i in remainder_start..output.len() {
let val = output[i];
output[output_len] = offset + i as u32;
output_len += if range.contains(&val) { 1 } else { 0 };
}
output.truncate(output_len);
}
// Register allocation for the asm! blocks:
// z0 ids_a (index vector for first half of each pair, advances by step2 each iter)
// z1 range_width broadcast
// z2 range_start broadcast
// z3 step2 broadcast (2 * vl)
// z4 ids_b (index vector for second half, = ids_a + step, advances by step2)
// z5 scratch: loaded word_a, then compacted_a
// z6 scratch: loaded word_b, then compacted_b
// p0 all-true predicate (ptrue p0.s)
// p1 in-range mask for word_a
// p2 in-range mask for word_b
//
// SVE kernel: scans `num_words` full SVE vectors of `u32` values starting at `input`
// and writes the ids of the values `v` with `v - range_start <= range_width`
// (unsigned, wrapping) contiguously starting at `output`. The id of the value at
// index `i` is `offset + i`. Returns the number of ids written.
//
// Vectors are processed two at a time in the main loop; when `num_words` is odd,
// the last vector is handled by a second, single-vector asm block.
//
// Safety: SVE must be available. `input` must be readable and `output` writable for
// `num_words * vl` elements: every step stores a full vector at the current write
// position, even when fewer lanes matched. `vl` must be the number of 32-bit lanes
// per SVE vector (the value `cntw` returns): the main loop queries `cntw` itself,
// while the id of the odd trailing vector is derived from the `vl` argument.
#[target_feature(enable = "sve")]
unsafe fn filter_vec_sve_aux(
input: *const u32,
range_start: u32,
range_width: u32,
output: *mut u32,
offset: u32,
num_words: usize,
vl: usize,
) -> usize {
let num_pairs = num_words / 2;
let mut input_ptr = input;
let mut output_tail = output;
// The loop below is a do-while (`subs` + `b.ne` at the bottom), so it must not be
// entered with a pair count of zero.
if num_pairs > 0 {
unsafe {
// We rely on asm! because the SVE intrinsics are not available in stable Rust.
// The code that follows was generated by Rustc nightly based on the intrinsics version
// at the bottom of this file.
core::arch::asm!(
// --- Setup ---
// All-true predicate for 32-bit lanes.
"ptrue p0.s",
// ids_a = [offset, offset+1, offset+2, ...]
"index z0.s, {offset:w}, #1",
// Broadcast scalars into SVE vectors.
"mov z1.s, {range_width:w}",
"mov z2.s, {range_start:w}",
// vl_gpr = number of 32-bit lanes (cntw).
"cntw {vl_gpr}",
// step2_bytes will first hold 2*vl (for the step2 vector), then 2*VL in bytes.
"lsl {step2_bytes}, {vl_gpr}, #1",
// z4 = step = [vl, vl, ...]; will become ids_b after the add below.
"mov z4.s, {vl_gpr:w}",
// z3 = step2 = [2*vl, 2*vl, ...], used to advance both id vectors each iter.
"mov z3.s, {step2_bytes:w}",
// Repurpose step2_bytes to hold the byte stride for advancing the input pointer
// by two full SVE vectors per iteration.
"rdvl {step2_bytes}, #2",
// ids_b = ids_a + step = [offset+vl, offset+vl+1, ...]
"add z4.s, z0.s, z4.s",
// --- Main loop: process two SVE vectors (ids_a and ids_b) per iteration ---
"0:",
// Load two consecutive SVE vectors from input.
"ld1w {{z5.s}}, p0/z, [{input}]",
"ld1w {{z6.s}}, p0/z, [{input}, #1, mul vl]",
// Advance input pointer by 2 * VL bytes.
"add {input}, {input}, {step2_bytes}",
// Unsigned shift: subtract range_start so in-range check becomes a single cmpu ≤.
"sub z5.s, z5.s, z2.s",
"sub z6.s, z6.s, z2.s",
// in_range: shifted value ≤ range_width (unsigned, so values below lo also fail).
"cmphs p1.s, p0/z, z1.s, z5.s",
"cmphs p2.s, p0/z, z1.s, z6.s",
// Count matching lanes; both cntp calls have independent inputs for OOO parallelism.
"cntp {cnt_a}, p0, p1.s",
"compact z5.s, p1, z0.s",
"compact z6.s, p2, z4.s",
"cntp {cnt_b}, p0, p2.s",
// Advance id vectors for the next iteration.
"add z0.s, z0.s, z3.s",
"add z4.s, z4.s, z3.s",
// Store compacted ids. Only the first cnt_a / cnt_b slots are valid; the rest
// will be overwritten by subsequent iterations before the final truncate.
"str z5, [{out}]",
"st1w {{z6.s}}, p0, [{out}, {cnt_a}, lsl #2]",
"add {out}, {out}, {cnt_a}, lsl #2",
"add {out}, {out}, {cnt_b}, lsl #2",
"subs {pairs}, {pairs}, #1",
"b.ne 0b",
// --- Operands ---
// `input` and `out` are in/out: the advanced pointers are read back into
// `input_ptr` / `output_tail` for the odd-vector tail and the final count.
input = inout(reg) input_ptr,
out = inout(reg) output_tail,
pairs = inout(reg) num_pairs => _,
offset = in(reg) offset,
range_start = in(reg) range_start,
range_width = in(reg) range_width,
vl_gpr = out(reg) _,
step2_bytes = out(reg) _,
cnt_a = out(reg) _,
cnt_b = out(reg) _,
out("p0") _, out("p1") _, out("p2") _,
out("v0") _, out("v1") _, out("v2") _, out("v3") _,
out("v4") _, out("v5") _, out("v6") _,
options(nostack),
);
}
}
// Handle an odd trailing vector.
if num_words % 2 == 1 {
// ids_a for the odd word starts at offset + num_pairs * 2 * vl.
// input_ptr was advanced by the main loop and now points at the odd word.
let odd_offset =
offset.wrapping_add((num_pairs as u32).wrapping_mul(2).wrapping_mul(vl as u32));
unsafe {
// Single-vector variant of the loop body above. It sets up its own predicate,
// id vector and broadcasts, so it does not depend on register state left
// behind by the main loop.
core::arch::asm!(
"ptrue p0.s",
"index z0.s, {odd_offset:w}, #1",
"mov z1.s, {range_width:w}",
"mov z2.s, {range_start:w}",
"ld1w {{z3.s}}, p0/z, [{input}]",
"sub z3.s, z3.s, z2.s",
"cmphs p1.s, p0/z, z1.s, z3.s",
"cntp {cnt}, p0, p1.s",
"compact z0.s, p1, z0.s",
"str z0, [{out}]",
"add {out}, {out}, {cnt}, lsl #2",
odd_offset = in(reg) odd_offset,
range_width = in(reg) range_width,
range_start = in(reg) range_start,
input = in(reg) input_ptr,
out = inout(reg) output_tail,
cnt = out(reg) _,
out("p0") _, out("p1") _,
out("v0") _, out("v1") _, out("v2") _, out("v3") _,
options(nostack),
);
}
}
// Number of ids written = distance between the final and the initial write pointer.
unsafe { output_tail.offset_from(output) as usize }
}
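// Reference implementation of `filter_vec_sve_aux` using SVE intrinsics.
// It is kept commented out because the SVE intrinsics are not available in stable
// Rust; the inline assembly above was generated from this version.
//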
// #[target_feature(enable = "sve")]
// unsafe fn filter_vec_sve_aux(
// input: *const u32,
// range_start: u32,
// range_width: u32,
// output: *mut u32,
// offset: u32,
// num_words: usize,
// vl: usize,
// ) -> usize {
// unsafe {
// let all_true = svptrue_b32();
// let range_start_simd = svdup_n_u32(range_start);
// let range_width_simd = svdup_n_u32(range_width);
// // ids_a covers [offset .. offset+vl), ids_b covers the next vl ids.
// // Keeping them separate breaks the loop-carried dependency through ids so
// // both compact/cntp chains are fully independent within each unrolled body.
// let mut ids_a = svindex_u32(offset, 1);
// let step = svdup_n_u32(vl as u32);
// let step2 = svdup_n_u32(2 * vl as u32);
// let mut ids_b = svadd_u32_x(all_true, ids_a, step);
// let mut input = input;
// let mut output_tail = output;
// // Unrolled ×2: both cntp calls have independent inputs and execute in parallel.
// // The two output_tail updates are sequential but together cost 4+1+1=6 cy per
// // pair vs 5+5=10 cy for two scalar iterations, breaking the cntp latency chain.
// let num_pairs = num_words / 2;
// for _ in 0..num_pairs {
// let word_a = svld1_u32(all_true, input);
// let word_b = svld1_u32(all_true, input.add(vl));
// let shifted_a = svsub_u32_x(all_true, word_a, range_start_simd);
// let shifted_b = svsub_u32_x(all_true, word_b, range_start_simd);
// let in_range_a = svcmple_u32(all_true, shifted_a, range_width_simd);
// let in_range_b = svcmple_u32(all_true, shifted_b, range_width_simd);
// let compacted_a = svcompact_u32(in_range_a, ids_a);
// let compacted_b = svcompact_u32(in_range_b, ids_b);
// // cntp_a and cntp_b have independent inputs: OOO engine issues them in parallel.
// let added_len_a = svcntp_b32(all_true, in_range_a) as usize;
// let added_len_b = svcntp_b32(all_true, in_range_b) as usize;
// // Write the full vector — only the first added_len slots are valid.
// // Subsequent iterations overwrite the trailing zeros before truncate.
// svst1_u32(all_true, output_tail, compacted_a);
// output_tail = output_tail.add(added_len_a);
// svst1_u32(all_true, output_tail, compacted_b);
// output_tail = output_tail.add(added_len_b);
// ids_a = svadd_u32_x(all_true, ids_a, step2);
// ids_b = svadd_u32_x(all_true, ids_b, step2);
// input = input.add(2 * vl);
// }
// // Handle an odd trailing word.
// if num_words % 2 == 1 {
// let word = svld1_u32(all_true, input);
// let shifted = svsub_u32_x(all_true, word, range_start_simd);
// let in_range = svcmple_u32(all_true, shifted, range_width_simd);
// let added_len = svcntp_b32(all_true, in_range) as usize;
// let compacted_ids = svcompact_u32(in_range, ids_a);
// svst1_u32(all_true, output_tail, compacted_ids);
// output_tail = output_tail.add(added_len);
// }
// output_tail.offset_from(output) as usize
// }
// }

View File

@@ -23,7 +23,7 @@ downcast-rs = "2.0.1"
proptest = "1"
more-asserts = "0.3.1"
rand = "0.9"
binggan = "0.17.0"
binggan = "0.15.3"
[[bench]]
name = "bench_merge"

View File

@@ -33,14 +33,14 @@ impl<T: PartialOrd + Copy + std::fmt::Debug + Send + Sync + 'static + Default>
&mut self,
docs: &[u32],
accessor: &Column<T>,
missing_opt: Option<T>,
missing: Option<T>,
) {
self.fetch_block(docs, accessor);
// no missing values
if accessor.index.get_cardinality().is_full() {
return;
}
let Some(missing) = missing_opt else {
let Some(missing) = missing else {
return;
};
@@ -191,7 +191,6 @@ where F: FnMut(u32) {
}
#[cfg(test)]
#[allow(clippy::field_reassign_with_default)]
mod tests {
use super::*;

View File

@@ -19,6 +19,6 @@ time = { version = "0.3.47", features = ["serde-well-known"] }
serde = { version = "1.0.136", features = ["derive"] }
[dev-dependencies]
binggan = "0.17.0"
binggan = "0.15.3"
proptest = "1.0.0"
rand = "0.9"

View File

@@ -47,9 +47,6 @@ impl TinySet {
TinySet(val)
}
/// An empty `TinySet` constant.
pub const EMPTY: TinySet = TinySet(0u64);
/// Returns an empty `TinySet`.
#[inline]
pub fn empty() -> TinySet {

View File

@@ -1045,43 +1045,18 @@ fn operand_leaf(inp: &str) -> IResult<&str, (Option<BinaryOperand>, Option<Occur
}
fn ast(inp: &str) -> IResult<&str, UserInputAst> {
// Parse `occur_leaf` once, then conditionally extend into a boolean
// expression. The previous implementation used `alt((boolean_expr,
// single_leaf))` which, when the input was a single leaf with no
// following operand, would parse `occur_leaf` once for `boolean_expr`,
// fail at `multispace1`, backtrack, then re-parse `occur_leaf` for
// `single_leaf`. With recursively-nested groups like `(+(+(+a)))`, that
// doubling at every level produced O(2^n) parse time. Parsing once and
// peeking ahead for the operand keeps it O(n).
delimited(
multispace0,
|inp| {
let (rest, first) = occur_leaf(inp)?;
// Only fall back on `Err::Error` (recoverable), mirroring
// `alt`'s behaviour. `Err::Failure` and `Err::Incomplete`
// must propagate so cut points and streaming needs are not
// accidentally swallowed if they are ever introduced in the
// operand parsers.
match preceded(multispace1, many1(operand_leaf))(rest) {
Ok((rest, more)) => {
let combined = aggregate_binary_expressions(first, more)
.map_err(|_| nom::Err::Error(Error::new(inp, ErrorKind::MapRes)))?;
Ok((rest, combined))
}
Err(nom::Err::Error(_)) => {
let (occur, ast) = first;
let single = if occur == Some(Occur::MustNot) {
ast.unary(Occur::MustNot)
} else {
ast
};
Ok((rest, single))
}
Err(e) => Err(e),
}
},
multispace0,
)(inp)
let boolean_expr = map_res(
separated_pair(occur_leaf, multispace1, many1(operand_leaf)),
|(left, right)| aggregate_binary_expressions(left, right),
);
let single_leaf = map(occur_leaf, |(occur, ast)| {
if occur == Some(Occur::MustNot) {
ast.unary(Occur::MustNot)
} else {
ast
}
});
delimited(multispace0, alt((boolean_expr, single_leaf)), multispace0)(inp)
}
fn ast_infallible(inp: &str) -> JResult<&str, UserInputAst> {
@@ -1916,23 +1891,4 @@ mod test {
r#"(+"field":'happy tax payer' +"other_field":1)"#,
);
}
// Regression test for https://github.com/quickwit-oss/tantivy/issues/2498.
// Deeply nested parenthesized queries took O(2^n) time when the top-level
// `ast()` parser tried `boolean_expr` first and re-parsed the inner
// `occur_leaf` after backtracking to `single_leaf`. At depth 60 that is on the
// order of 10^18 operations, so this test effectively never finishes if the
// exponential behaviour comes back; with a single-pass parser it is instant.
#[test]
fn test_parse_deeply_nested_query() {
    const DEPTH: usize = 60;
    let nested = format!("{}title:test{}", "(".repeat(DEPTH), ")".repeat(DEPTH));
    test_parse_query_to_ast_helper(&nested, r#""title":test"#);

    let nested_with_plus = format!("+{nested}");
    test_parse_query_to_ast_helper(&nested_with_plus, r#""title":test"#);
}
}

View File

@@ -20,8 +20,8 @@ use crate::aggregation::metric::{
build_segment_stats_collector, AverageAggregation, CardinalityAggReqData,
CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation, MaxAggregation,
MetricAggReqData, MinAggregation, SegmentCardinalityCollector, SegmentExtendedStatsCollector,
SegmentPercentilesCollector, StatsAggregation, StatsType, SumAggregation, TermOrdSet,
TopHitsAggReqData, TopHitsSegmentCollector, BITSET_MAX_TERM_ORD,
SegmentPercentilesCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData,
TopHitsSegmentCollector,
};
use crate::aggregation::segment_agg_result::{
GenericSegmentAggregationResultsCollector, SegmentAggregationCollector,
@@ -413,38 +413,12 @@ pub(crate) fn build_segment_agg_collector(
}
AggKind::Cardinality => {
let req_data = &mut req.get_cardinality_req_data_mut(node.idx_in_req_data);
// For str columns, choose the per-bucket entries representation
// based on the segment's column.max_value():
// * small (< BITSET_MAX_TERM_ORD): `BitSet`, pre-allocated, no promotion machinery.
// * large: `TermOrdSet` (sparse FxHashSet that promotes to a paged bitset).
// For non-str columns the `entries` field is unused (values go
// straight into the HLL sketch); we still pick `TermOrdSet`
// because its empty Sparse(FxHashSet) costs nothing.
let is_str = req_data.column_type == ColumnType::Str;
let max_term_ord_inclusive = if is_str {
req_data.accessor.max_value()
} else {
0
};
let collector: Box<dyn SegmentAggregationCollector> =
if is_str && max_term_ord_inclusive < BITSET_MAX_TERM_ORD {
Box::new(SegmentCardinalityCollector::<BitSet>::from_req(
req_data.column_type,
node.idx_in_req_data,
req_data.accessor.clone(),
req_data.missing_value_for_accessor,
max_term_ord_inclusive,
))
} else {
Box::new(SegmentCardinalityCollector::<TermOrdSet>::from_req(
req_data.column_type,
node.idx_in_req_data,
req_data.accessor.clone(),
req_data.missing_value_for_accessor,
max_term_ord_inclusive,
))
};
Ok(collector)
Ok(Box::new(SegmentCardinalityCollector::from_req(
req_data.column_type,
node.idx_in_req_data,
req_data.accessor.clone(),
req_data.missing_value_for_accessor,
)))
}
AggKind::StatsKind(stats_type) => {
let req_data = &mut req.per_request.stats_metric_req_data[node.idx_in_req_data];
@@ -1011,12 +985,8 @@ fn build_terms_or_cardinality_nodes(
let str_col = str_dict_column
.as_ref()
.expect("str_dict_column must exist for string column");
allowed_term_ids = build_allowed_term_ids_for_str(
str_col,
&req.include,
&req.exclude,
missing.is_some(),
)?;
allowed_term_ids =
build_allowed_term_ids_for_str(str_col, &req.include, &req.exclude)?;
};
let idx_in_req_data = data.push_term_req_data(TermsAggReqData {
accessor,
@@ -1032,20 +1002,10 @@ fn build_terms_or_cardinality_nodes(
(idx_in_req_data, AggKind::Terms)
}
TermsOrCardinalityRequest::Cardinality(ref req) => {
// `str_dict_column` is computed once per field; for JSON paths
// with mixed types it's `Some` even on the numeric req_data.
// Cardinality only consults it for the str column path, so
// gate by column_type to avoid driving non-str collectors
// through the coupon-cache path.
let str_dict_column_for_req = if column_type == ColumnType::Str {
str_dict_column.clone()
} else {
None
};
let idx_in_req_data = data.push_cardinality_req_data(CardinalityAggReqData {
accessor,
column_type,
str_dict_column: str_dict_column_for_req,
str_dict_column: str_dict_column.clone(),
missing_value_for_accessor,
name: agg_name.to_string(),
req: req.clone(),
@@ -1065,21 +1025,16 @@ fn build_terms_or_cardinality_nodes(
/// Builds a single BitSet of allowed term ordinals for a string dictionary column according to
/// include/exclude parameters.
///
/// When `reserve_missing_sentinel` is true, the bitset will have 1 additional slot for the missing
/// term ordinal
fn build_allowed_term_ids_for_str(
str_col: &StrColumn,
include: &Option<IncludeExcludeParam>,
exclude: &Option<IncludeExcludeParam>,
reserve_missing_sentinel: bool,
) -> crate::Result<Option<BitSet>> {
let mut allowed: Option<BitSet> = None;
let missing_sentinel_adjustment = if reserve_missing_sentinel { 1 } else { 0 };
let allowed_capacity = str_col.dictionary().num_terms() as u32 + missing_sentinel_adjustment;
let num_terms = str_col.dictionary().num_terms() as u32;
if let Some(include) = include {
// add matches
allowed = Some(BitSet::with_max_value(allowed_capacity));
allowed = Some(BitSet::with_max_value(num_terms));
let allowed = allowed.as_mut().unwrap();
for_each_matching_term_ord(str_col, include, |ord| allowed.insert(ord))?;
};
@@ -1087,7 +1042,7 @@ fn build_allowed_term_ids_for_str(
if let Some(exclude) = exclude {
if allowed.is_none() {
// Start with all terms allowed
allowed = Some(BitSet::with_max_value_and_full(allowed_capacity));
allowed = Some(BitSet::with_max_value_and_full(num_terms));
}
let allowed = allowed.as_mut().unwrap();
for_each_matching_term_ord(str_col, exclude, |ord| allowed.remove(ord))?;

View File

@@ -115,71 +115,6 @@ pub fn get_fast_field_names(aggs: &Aggregations) -> HashSet<String> {
fast_field_names
}
/// Validates that all fields referenced in the aggregation request exist in the schema
/// and are configured as fast fields.
///
/// This is a convenience function for upfront validation before executing aggregations.
///
/// Validation is intentionally opt-in rather than baked into aggregation execution: the
/// default lenient behavior (returning empty results for missing fields) supports
/// schema evolution and federated queries where the same request runs against segments
/// or indices with different schemas.
///
/// # Errors
///
/// Returns [`crate::TantivyError::FieldNotFound`] if a referenced field cannot be
/// resolved in the schema, and [`crate::TantivyError::SchemaError`] if it resolves to a
/// field that is not a fast field. Fields are checked in lexicographic order of their
/// names, so the reported field is deterministic when several are invalid.
///
/// # Example
/// ```
/// use tantivy::aggregation::agg_req::{Aggregations, validate_aggregation_fields_exist};
/// use tantivy::schema::{Schema, FAST};
/// use tantivy::Index;
///
/// # fn main() -> tantivy::Result<()> {
/// // Create a simple index
/// let mut schema_builder = Schema::builder();
/// schema_builder.add_f64_field("price", FAST);
/// let schema = schema_builder.build();
/// let index = Index::create_in_ram(schema);
///
/// // Parse aggregation request
/// let agg_req: Aggregations = serde_json::from_str(r#"{
///     "avg_price": { "avg": { "field": "price" } }
/// }"#)?;
///
/// let reader = index.reader()?;
/// let searcher = reader.searcher();
///
/// // Validate fields before executing
/// for segment_reader in searcher.segment_readers() {
///     validate_aggregation_fields_exist(&agg_req, segment_reader)?;
/// }
/// # Ok(())
/// # }
/// ```
pub fn validate_aggregation_fields_exist(
    aggs: &Aggregations,
    reader: &crate::SegmentReader,
) -> crate::Result<()> {
    let schema = reader.schema();
    // `get_fast_field_names` returns a `HashSet`, whose iteration order is
    // unspecified. Sort the names so the same request always reports the same
    // offending field.
    let mut field_names: Vec<String> = get_fast_field_names(aggs).into_iter().collect();
    field_names.sort_unstable();

    for field_name in field_names {
        // Check if the field is either directly in the schema or could be part of a json
        // field present in the schema, and verify it's a fast field.
        if let Some((field, _path)) = schema.find_field(&field_name) {
            let field_type = schema.get_field_entry(field).field_type();
            if !field_type.is_fast() {
                return Err(crate::TantivyError::SchemaError(format!(
                    "Field '{}' is not a fast field. Aggregations require fast fields.",
                    field_name
                )));
            }
        } else {
            return Err(crate::TantivyError::FieldNotFound(field_name));
        }
    }
    Ok(())
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
/// All aggregation types.
pub enum AggregationVariants {

View File

@@ -208,8 +208,7 @@ pub enum BucketEntries<T> {
}
impl<T> BucketEntries<T> {
/// Iterate over all bucket entries.
pub fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = &'a T> + 'a> {
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = &'a T> + 'a> {
match self {
BucketEntries::Vec(vec) => Box::new(vec.iter()),
BucketEntries::HashMap(map) => Box::new(map.values()),

View File

@@ -1436,46 +1436,3 @@ fn test_aggregation_on_json_object_mixed_numerical_segments() {
)
);
}
/// Exercises the standalone field-validation helper with one unknown field and
/// one valid fast field.
#[test]
fn test_aggregation_field_validation_helper() {
    let index = get_test_index_2_segments(false).unwrap();
    let reader = index.reader().unwrap();
    let searcher = reader.searcher();
    let segment_reader = searcher.segment_reader(0);

    // Parses the request and runs the validation helper against the first segment.
    let validate = |request_json: &str| {
        let agg_req: Aggregations = serde_json::from_str(request_json).unwrap();
        crate::aggregation::agg_req::validate_aggregation_fields_exist(&agg_req, segment_reader)
    };

    // Unknown field: must be rejected with `FieldNotFound` carrying the field name.
    let invalid_result = validate(
        r#"{
"avg_test": {
"avg": { "field": "nonexistent_field" }
}
}"#,
    );
    assert!(invalid_result.is_err());
    match invalid_result {
        Err(crate::TantivyError::FieldNotFound(field_name)) => {
            assert_eq!(field_name, "nonexistent_field");
        }
        _ => panic!("Expected FieldNotFound error, got: {:?}", invalid_result),
    }

    // Existing fast field: validation succeeds.
    let valid_result = validate(
        r#"{
"avg_test": {
"avg": { "field": "score" }
}
}"#,
    );
    assert!(valid_result.is_ok());
}

View File

@@ -21,7 +21,7 @@ use crate::aggregation::bucket::composite::map::{DynArrayHeapMap, MAX_DYN_ARRAY_
use crate::aggregation::bucket::{
CalendarInterval, CompositeAggregationSource, MissingOrder, Order,
};
use crate::aggregation::buffered_sub_aggs::{BufferedSubAggs, HighCardSubAggBuffer};
use crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardSubAggCache};
use crate::aggregation::intermediate_agg_result::{
CompositeIntermediateKey, IntermediateAggregationResult, IntermediateAggregationResults,
IntermediateBucketResult, IntermediateCompositeBucketEntry, IntermediateCompositeBucketResult,
@@ -119,7 +119,7 @@ pub struct SegmentCompositeCollector {
/// One DynArrayHeapMap per parent bucket.
parent_buckets: Vec<DynArrayHeapMap<InternalValueRepr, CompositeBucketCollector>>,
accessor_idx: usize,
sub_agg: Option<BufferedSubAggs<HighCardSubAggBuffer>>,
sub_agg: Option<CachedSubAggs<HighCardSubAggCache>>,
bucket_id_provider: BucketIdProvider,
/// Number of sources, needed when creating new DynArrayHeapMaps.
num_sources: usize,
@@ -199,17 +199,6 @@ impl SegmentAggregationCollector for SegmentCompositeCollector {
}
Ok(())
}
/// Always returns `None`, whatever bucket, sub-aggregation name, or property is
/// requested: every argument is ignored.
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// Composite is a multi-bucket agg with no single value to extract.
None
}
}
impl SegmentCompositeCollector {
@@ -226,7 +215,7 @@ impl SegmentCompositeCollector {
let has_sub_aggregations = !node.children.is_empty();
let sub_agg = if has_sub_aggregations {
let sub_agg_collector = build_segment_agg_collectors(req_data, &node.children)?;
Some(BufferedSubAggs::new(sub_agg_collector))
Some(CachedSubAggs::new(sub_agg_collector))
} else {
None
};
@@ -340,7 +329,7 @@ fn collect_bucket_with_limit(
limit_num_buckets: usize,
buckets: &mut DynArrayHeapMap<InternalValueRepr, CompositeBucketCollector>,
key: &[InternalValueRepr],
sub_agg: &mut Option<BufferedSubAggs<HighCardSubAggBuffer>>,
sub_agg: &mut Option<CachedSubAggs<HighCardSubAggCache>>,
bucket_id_provider: &mut BucketIdProvider,
) {
let mut record_in_bucket = |bucket: &mut CompositeBucketCollector| {
@@ -496,7 +485,7 @@ struct CompositeKeyVisitor<'a> {
doc_id: crate::DocId,
composite_agg_data: &'a CompositeAggReqData,
buckets: &'a mut DynArrayHeapMap<InternalValueRepr, CompositeBucketCollector>,
sub_agg: &'a mut Option<BufferedSubAggs<HighCardSubAggBuffer>>,
sub_agg: &'a mut Option<CachedSubAggs<HighCardSubAggCache>>,
bucket_id_provider: &'a mut BucketIdProvider,
sub_level_values: SmallVec<[InternalValueRepr; MAX_DYN_ARRAY_SIZE]>,
}

View File

@@ -511,14 +511,14 @@ mod tests {
/// Parses an RFC 3339 date string into a `common::DateTime`.
///
/// Panics with the offending input if `date_str` cannot be parsed.
fn datetime_from_iso_str(date_str: &str) -> common::DateTime {
    // `unwrap_or_else` builds the panic message only on failure, whereas
    // `expect(&format!(..))` would format it on every call.
    let dt = OffsetDateTime::parse(date_str, &Rfc3339)
        .unwrap_or_else(|_| panic!("Failed to parse date: {}", date_str));
    // The value is in nanoseconds (not seconds) and is narrowed to i64 for `DateTime`.
    let timestamp_nanos = dt.unix_timestamp_nanos();
    common::DateTime::from_timestamp_nanos(timestamp_nanos as i64)
}
/// Parses an RFC 3339 date string and returns its Unix timestamp in milliseconds.
///
/// Panics with the offending input if `date_str` cannot be parsed.
fn ms_timestamp_from_iso_str(date_str: &str) -> i64 {
    // `unwrap_or_else` builds the panic message only on failure, whereas
    // `expect(&format!(..))` would format it on every call.
    let dt = OffsetDateTime::parse(date_str, &Rfc3339)
        .unwrap_or_else(|_| panic!("Failed to parse date: {}", date_str));
    // Nanoseconds -> milliseconds.
    (dt.unix_timestamp_nanos() / 1_000_000) as i64
}
@@ -548,7 +548,7 @@ mod tests {
agg_req_json["my_composite"]["composite"]["after"] = after_key.take().unwrap();
}
let agg_req: Aggregations = serde_json::from_value(agg_req_json).unwrap();
let res = exec_request(agg_req.clone(), index).unwrap();
let res = exec_request(agg_req.clone(), &index).unwrap();
let expected_page_buckets = &expected_buckets_vec[page_idx * page_size
..std::cmp::min((page_idx + 1) * page_size, expected_buckets_vec.len())];
assert_eq!(
@@ -559,30 +559,34 @@ mod tests {
page_size,
agg_req,
);
assert!(
res["my_composite"].get("after_key").is_some(),
"expected after_key on every non-empty page"
);
after_key = Some(res["my_composite"]["after_key"].clone());
}
// Using the after_key from the last page must yield an empty page.
let agg_req_json = json!({
"my_composite": {
"composite": {
"sources": composite_agg_sources,
"size": page_size,
"after": after_key,
}
if page_idx + 1 < page_count {
assert!(
res["my_composite"].get("after_key").is_some(),
"expected after_key on all but last page"
);
after_key = Some(res["my_composite"]["after_key"].clone());
} else if res["my_composite"].get("after_key").is_some() {
// currently we sometime have an after_key on the last page,
// check that the next "page" is empty
let agg_req_json = json!({
"my_composite": {
"composite": {
"sources": composite_agg_sources,
"size": page_size,
"after": res["my_composite"]["after_key"].clone(),
}
}
});
let agg_req: Aggregations = serde_json::from_value(agg_req_json).unwrap();
let res = exec_request(agg_req.clone(), &index).unwrap();
assert_eq!(
res["my_composite"]["buckets"],
json!([]),
"expected no buckets when using after_key from last page, query: {:?}",
agg_req
);
}
});
let agg_req: Aggregations = serde_json::from_value(agg_req_json).unwrap();
let res = exec_request(agg_req.clone(), index).unwrap();
assert_eq!(
res["my_composite"]["buckets"],
json!([]),
"expected no buckets when using after_key from last page, query: {:?}",
agg_req
);
}
}
}
@@ -707,28 +711,8 @@ mod tests {
{"key": {"myterm": "terme"}, "doc_count": 1}
])
);
// paginating past last page should be empty
let agg_req_json = json!({
"my_composite": {
"composite": {
"sources": [
{"myterm": {"terms": {"field": "string_id"}}}
],
"size": 3,
"after": &res["my_composite"]["after_key"]
}
}
});
let agg_req: Aggregations = serde_json::from_value(agg_req_json).unwrap();
let res = exec_request(agg_req.clone(), &index).unwrap();
assert!(res["my_composite"].get("after_key").is_none());
assert_eq!(
res["my_composite"]["buckets"],
json!([]),
"expected no buckets when using after_key from last page, query: {:?}",
agg_req
);
Ok(())
}
@@ -836,10 +820,7 @@ mod tests {
{"key": {"myterm": "apple"}, "doc_count": 1}
])
);
assert_eq!(
res["fruity_aggreg"]["after_key"],
json!({"myterm": "str:apple"})
);
assert!(res["fruity_aggreg"].get("after_key").is_none());
Ok(())
}
@@ -1811,14 +1792,7 @@ mod tests {
{"key": {"month": ms_timestamp_from_iso_str("2021-02-01T00:00:00Z"), "category": "books"}, "doc_count": 1},
]),
);
let feb_2021_ns = ms_timestamp_from_iso_str("2021-02-01T00:00:00Z") * 1_000_000;
assert_eq!(
res["my_composite"]["after_key"],
json!({
"month": format!("dt:{}", feb_2021_ns),
"category": "str:books"
})
);
assert!(res["my_composite"].get("after_key").is_none());
Ok(())
}

View File

@@ -6,8 +6,8 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::buffered_sub_aggs::{
BufferedSubAggs, HighCardSubAggBuffer, LowCardSubAggBuffer, SubAggBuffer,
use crate::aggregation::cached_sub_aggs::{
CachedSubAggs, HighCardSubAggCache, LowCardSubAggCache, SubAggCache,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
@@ -503,17 +503,17 @@ struct DocCount {
}
/// Segment collector for filter aggregation
pub struct SegmentFilterCollector<B: SubAggBuffer> {
pub struct SegmentFilterCollector<C: SubAggCache> {
/// Document counts per parent bucket
parent_buckets: Vec<DocCount>,
/// Sub-aggregation collectors
sub_aggregations: Option<BufferedSubAggs<B>>,
sub_aggregations: Option<CachedSubAggs<C>>,
bucket_id_provider: BucketIdProvider,
/// Accessor index for this filter aggregation (to access FilterAggReqData)
accessor_idx: usize,
}
impl<B: SubAggBuffer> SegmentFilterCollector<B> {
impl<C: SubAggCache> SegmentFilterCollector<C> {
/// Create a new filter segment collector following the new agg_data pattern
pub(crate) fn from_req_and_validate(
req: &mut AggregationsSegmentCtx,
@@ -525,7 +525,7 @@ impl<B: SubAggBuffer> SegmentFilterCollector<B> {
} else {
None
};
let sub_agg_collector = sub_agg_collector.map(BufferedSubAggs::new);
let sub_agg_collector = sub_agg_collector.map(CachedSubAggs::new);
Ok(SegmentFilterCollector {
parent_buckets: Vec::new(),
@@ -547,16 +547,16 @@ pub(crate) fn build_segment_filter_collector(
if is_top_level {
Ok(Box::new(
SegmentFilterCollector::<LowCardSubAggBuffer>::from_req_and_validate(req, node)?,
SegmentFilterCollector::<LowCardSubAggCache>::from_req_and_validate(req, node)?,
))
} else {
Ok(Box::new(
SegmentFilterCollector::<HighCardSubAggBuffer>::from_req_and_validate(req, node)?,
SegmentFilterCollector::<HighCardSubAggCache>::from_req_and_validate(req, node)?,
))
}
}
impl<B: SubAggBuffer> Debug for SegmentFilterCollector<B> {
impl<C: SubAggCache> Debug for SegmentFilterCollector<C> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentFilterCollector")
.field("buckets", &self.parent_buckets)
@@ -566,7 +566,7 @@ impl<B: SubAggBuffer> Debug for SegmentFilterCollector<B> {
}
}
impl<B: SubAggBuffer> SegmentAggregationCollector for SegmentFilterCollector<B> {
impl<C: SubAggCache> SegmentAggregationCollector for SegmentFilterCollector<C> {
fn add_intermediate_aggregation_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
@@ -674,17 +674,6 @@ impl<B: SubAggBuffer> SegmentAggregationCollector for SegmentFilterCollector<B>
}
Ok(())
}
/// Looks up a single metric value for the given bucket and sub-aggregation
/// name/property.
///
/// This implementation ignores all of its arguments and always returns `None`;
/// forwarding into the inner sub-aggregation is not implemented yet (see TODO).
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// TODO: forward into the inner `sub_agg` for nested order paths (`filter.metric`).
None
}
}
/// Intermediate result for filter aggregation

View File

@@ -10,7 +10,7 @@ use crate::aggregation::agg_data::{
};
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::agg_result::BucketEntry;
use crate::aggregation::buffered_sub_aggs::{BufferedSubAggs, HighCardBufferedSubAggs};
use crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardCachedSubAggs};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateHistogramBucketEntry,
@@ -258,7 +258,7 @@ pub(crate) struct SegmentHistogramBucketEntry {
impl SegmentHistogramBucketEntry {
pub(crate) fn into_intermediate_bucket_entry(
self,
sub_aggregation: &mut Option<HighCardBufferedSubAggs>,
sub_aggregation: &mut Option<HighCardCachedSubAggs>,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateHistogramBucketEntry> {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
@@ -283,11 +283,6 @@ impl SegmentHistogramBucketEntry {
struct HistogramBuckets {
pub buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
}
impl HistogramBuckets {
    /// Estimated size in bytes of the bucket map: its allocated capacity
    /// (not its length) multiplied by the size of one
    /// `SegmentHistogramBucketEntry`.
    fn memory_consumption(&self) -> u64 {
        let entry_size = std::mem::size_of::<SegmentHistogramBucketEntry>() as u64;
        let allocated_slots = self.buckets.capacity() as u64;
        allocated_slots * entry_size
    }
}
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
@@ -296,7 +291,7 @@ pub struct SegmentHistogramCollector {
/// The buckets containing the aggregation data.
/// One Histogram bucket per parent bucket id.
parent_buckets: Vec<HistogramBuckets>,
sub_agg: Option<HighCardBufferedSubAggs>,
sub_agg: Option<HighCardCachedSubAggs>,
accessor_idx: usize,
bucket_id_provider: BucketIdProvider,
}
@@ -329,7 +324,7 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req = agg_data.take_histogram_req_data(self.accessor_idx);
let mem_pre = self.get_memory_consumption(parent_bucket_id);
let mem_pre = self.get_memory_consumption();
let buckets = &mut self.parent_buckets[parent_bucket_id as usize].buckets;
let bounds = req.bounds;
@@ -363,9 +358,12 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
}
agg_data.put_back_histogram_req_data(self.accessor_idx, req);
let mem_delta = self.get_memory_consumption(parent_bucket_id) - mem_pre;
let mem_delta = self.get_memory_consumption() - mem_pre;
if mem_delta > 0 {
agg_data.context.limits.add_memory_consumed(mem_delta)?;
agg_data
.context
.limits
.add_memory_consumed(mem_delta as u64)?;
}
if let Some(sub_agg) = &mut self.sub_agg {
@@ -394,24 +392,14 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
}
Ok(())
}
/// Looks up a single metric value for the given bucket and sub-aggregation
/// name/property.
///
/// This implementation ignores all of its arguments and always returns `None`.
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// Histogram is a multi-bucket agg with no single value to extract.
None
}
}
impl SegmentHistogramCollector {
fn get_memory_consumption(&self, parent_bucket_id: BucketId) -> u64 {
self.parent_buckets[parent_bucket_id as usize].memory_consumption()
fn get_memory_consumption(&self) -> usize {
let self_mem = std::mem::size_of::<Self>();
let buckets_mem = self.parent_buckets.len() * std::mem::size_of::<HistogramBuckets>();
self_mem + buckets_mem
}
/// Converts the collector result into an intermediate bucket result.
fn add_intermediate_bucket_result(
&mut self,
@@ -456,7 +444,7 @@ impl SegmentHistogramCollector {
max: f64::MAX,
});
req_data.offset = req_data.req.offset.unwrap_or(0.0);
let sub_agg = sub_agg.map(BufferedSubAggs::new);
let sub_agg = sub_agg.map(CachedSubAggs::new);
Ok(Self {
parent_buckets: Default::default(),

View File

@@ -9,9 +9,8 @@ use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::agg_limits::AggregationLimitsGuard;
use crate::aggregation::buffered_sub_aggs::{
BufferedSubAggs, HighCardSubAggBuffer, LowCardBufferedSubAggs, LowCardSubAggBuffer,
SubAggBuffer,
use crate::aggregation::cached_sub_aggs::{
CachedSubAggs, HighCardSubAggCache, LowCardCachedSubAggs, LowCardSubAggCache, SubAggCache,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
@@ -156,13 +155,13 @@ pub(crate) struct SegmentRangeAndBucketEntry {
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
pub struct SegmentRangeCollector<B: SubAggBuffer> {
pub struct SegmentRangeCollector<C: SubAggCache> {
/// The buckets containing the aggregation data.
/// One for each ParentBucketId
parent_buckets: Vec<Vec<SegmentRangeAndBucketEntry>>,
column_type: ColumnType,
pub(crate) accessor_idx: usize,
sub_agg: Option<BufferedSubAggs<B>>,
sub_agg: Option<CachedSubAggs<C>>,
/// Here things get a bit weird. We need to assign unique bucket ids across all
/// parent buckets. So we keep track of the next available bucket id here.
/// This allows a kind of flattening of the bucket ids across all parent buckets.
@@ -179,7 +178,7 @@ pub struct SegmentRangeCollector<B: SubAggBuffer> {
limits: AggregationLimitsGuard,
}
impl<B: SubAggBuffer> Debug for SegmentRangeCollector<B> {
impl<C: SubAggCache> Debug for SegmentRangeCollector<C> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentRangeCollector")
.field("parent_buckets_len", &self.parent_buckets.len())
@@ -230,7 +229,7 @@ impl SegmentRangeBucketEntry {
}
}
impl<B: SubAggBuffer> SegmentAggregationCollector for SegmentRangeCollector<B> {
impl<C: SubAggCache> SegmentAggregationCollector for SegmentRangeCollector<C> {
fn add_intermediate_aggregation_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
@@ -328,17 +327,6 @@ impl<B: SubAggBuffer> SegmentAggregationCollector for SegmentRangeCollector<B> {
Ok(())
}
/// Looks up a single metric value for the given bucket and sub-aggregation
/// name/property.
///
/// This implementation ignores all of its arguments and always returns `None`.
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// Range is a multi-bucket agg with no single value to extract.
None
}
}
/// Build a concrete `SegmentRangeCollector` with either a Vec- or HashMap-backed
/// bucket storage, depending on the column type and aggregation level.
@@ -362,8 +350,8 @@ pub(crate) fn build_segment_range_collector(
};
if is_low_card {
Ok(Box::new(SegmentRangeCollector::<LowCardSubAggBuffer> {
sub_agg: sub_agg.map(LowCardBufferedSubAggs::new),
Ok(Box::new(SegmentRangeCollector::<LowCardSubAggCache> {
sub_agg: sub_agg.map(LowCardCachedSubAggs::new),
column_type: field_type,
accessor_idx,
parent_buckets: Vec::new(),
@@ -371,8 +359,8 @@ pub(crate) fn build_segment_range_collector(
limits: agg_data.context.limits.clone(),
}))
} else {
Ok(Box::new(SegmentRangeCollector::<HighCardSubAggBuffer> {
sub_agg: sub_agg.map(BufferedSubAggs::new),
Ok(Box::new(SegmentRangeCollector::<HighCardSubAggCache> {
sub_agg: sub_agg.map(CachedSubAggs::new),
column_type: field_type,
accessor_idx,
parent_buckets: Vec::new(),
@@ -382,7 +370,7 @@ pub(crate) fn build_segment_range_collector(
}
}
impl<B: SubAggBuffer> SegmentRangeCollector<B> {
impl<C: SubAggCache> SegmentRangeCollector<C> {
pub(crate) fn create_new_buckets(
&mut self,
agg_data: &AggregationsSegmentCtx,
@@ -566,7 +554,7 @@ mod tests {
pub fn get_collector_from_ranges(
ranges: Vec<RangeAggregationRange>,
field_type: ColumnType,
) -> SegmentRangeCollector<HighCardSubAggBuffer> {
) -> SegmentRangeCollector<HighCardSubAggCache> {
let req = RangeAggregation {
field: "dummy".to_string(),
ranges,

View File

@@ -1,4 +1,5 @@
use std::fmt::Debug;
use std::io;
use std::net::Ipv6Addr;
use columnar::column_values::CompactSpaceU64Accessor;
@@ -16,9 +17,8 @@ use crate::aggregation::agg_data::{
};
use crate::aggregation::agg_limits::MemoryConsumption;
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::buffered_sub_aggs::{
BufferedSubAggs, HighCardSubAggBuffer, LowCardBufferedSubAggs, LowCardSubAggBuffer,
SubAggBuffer,
use crate::aggregation::cached_sub_aggs::{
CachedSubAggs, HighCardSubAggCache, LowCardCachedSubAggs, LowCardSubAggCache, SubAggCache,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
@@ -352,15 +352,19 @@ pub(crate) fn build_segment_term_collector(
)));
}
// Validate that the referenced sub-aggregation exists when ordering by one.
if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target {
let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name);
node.get_sub_agg(agg_name, &req_data.per_request)
.ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"could not find aggregation with name {agg_name} in metric sub_aggregations"
))
})?;
// Validate sub aggregation exists when ordering by sub-aggregation.
{
if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target {
let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name);
node.get_sub_agg(agg_name, &req_data.per_request)
.ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"could not find aggregation with name {agg_name} in metric \
sub_aggregations"
))
})?;
}
}
// Build sub-aggregation blueprint if there are children.
@@ -387,7 +391,7 @@ pub(crate) fn build_segment_term_collector(
// Decide which bucket storage is best suited for this aggregation.
if is_top_level && max_term_id < MAX_NUM_TERMS_FOR_VEC && !has_sub_aggregations {
let term_buckets = VecTermBucketsNoAgg::new(max_term_id + 1, &mut bucket_id_provider);
let collector: SegmentTermCollector<_, HighCardSubAggBuffer> = SegmentTermCollector {
let collector: SegmentTermCollector<_, HighCardSubAggCache> = SegmentTermCollector {
parent_buckets: vec![term_buckets],
sub_agg: None,
bucket_id_provider,
@@ -397,8 +401,8 @@ pub(crate) fn build_segment_term_collector(
Ok(Box::new(collector))
} else if is_top_level && max_term_id < MAX_NUM_TERMS_FOR_VEC {
let term_buckets = VecTermBuckets::new(max_term_id + 1, &mut bucket_id_provider);
let sub_agg = sub_agg_collector.map(LowCardBufferedSubAggs::new);
let collector: SegmentTermCollector<_, LowCardSubAggBuffer> = SegmentTermCollector {
let sub_agg = sub_agg_collector.map(LowCardCachedSubAggs::new);
let collector: SegmentTermCollector<_, LowCardSubAggCache> = SegmentTermCollector {
parent_buckets: vec![term_buckets],
sub_agg,
bucket_id_provider,
@@ -410,8 +414,8 @@ pub(crate) fn build_segment_term_collector(
let term_buckets: PagedTermMap =
PagedTermMap::new(max_term_id + 1, &mut bucket_id_provider);
// Build sub-aggregation blueprint (flat pairs)
let sub_agg = sub_agg_collector.map(BufferedSubAggs::new);
let collector: SegmentTermCollector<PagedTermMap, HighCardSubAggBuffer> =
let sub_agg = sub_agg_collector.map(CachedSubAggs::new);
let collector: SegmentTermCollector<PagedTermMap, HighCardSubAggCache> =
SegmentTermCollector {
parent_buckets: vec![term_buckets],
sub_agg,
@@ -423,8 +427,8 @@ pub(crate) fn build_segment_term_collector(
} else {
let term_buckets: HashMapTermBuckets = HashMapTermBuckets::default();
// Build sub-aggregation blueprint (flat pairs)
let sub_agg = sub_agg_collector.map(BufferedSubAggs::new);
let collector: SegmentTermCollector<HashMapTermBuckets, HighCardSubAggBuffer> =
let sub_agg = sub_agg_collector.map(CachedSubAggs::new);
let collector: SegmentTermCollector<HashMapTermBuckets, HighCardSubAggCache> =
SegmentTermCollector {
parent_buckets: vec![term_buckets],
sub_agg,
@@ -754,10 +758,10 @@ impl TermAggregationMap for VecTermBuckets {
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
#[derive(Debug)]
struct SegmentTermCollector<TermMap: TermAggregationMap, B: SubAggBuffer> {
struct SegmentTermCollector<TermMap: TermAggregationMap, C: SubAggCache> {
/// The buckets containing the aggregation data.
parent_buckets: Vec<TermMap>,
sub_agg: Option<BufferedSubAggs<B>>,
sub_agg: Option<CachedSubAggs<C>>,
bucket_id_provider: BucketIdProvider,
max_term_id: u64,
terms_req_data: TermsAggReqData,
@@ -768,8 +772,8 @@ pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) {
(agg_name, agg_property)
}
impl<TermMap: TermAggregationMap, B: SubAggBuffer> SegmentAggregationCollector
for SegmentTermCollector<TermMap, B>
impl<TermMap: TermAggregationMap, C: SubAggCache> SegmentAggregationCollector
for SegmentTermCollector<TermMap, C>
{
fn add_intermediate_aggregation_result(
&mut self,
@@ -786,14 +790,8 @@ impl<TermMap: TermAggregationMap, B: SubAggBuffer> SegmentAggregationCollector
let term_req = &self.terms_req_data;
let name = term_req.name.clone();
let bucket = Self::into_intermediate_bucket_result(
term_req,
self.sub_agg
.as_mut()
.map(BufferedSubAggs::get_sub_agg_collector),
bucket,
agg_data,
)?;
let bucket =
Self::into_intermediate_bucket_result(term_req, &mut self.sub_agg, bucket, agg_data)?;
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
Ok(())
}
@@ -883,17 +881,6 @@ impl<TermMap: TermAggregationMap, B: SubAggBuffer> SegmentAggregationCollector
}
Ok(())
}
/// Looks up a single metric value for the given bucket and sub-aggregation
/// name/property.
///
/// This implementation ignores all of its arguments and always returns `None`.
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// Terms is a multi-bucket agg with no single value to extract.
None
}
}
/// Missing values are represented as a sentinel value in the column.
@@ -920,38 +907,10 @@ fn extract_missing_value<T>(
Some((key, bucket))
}
/// Reborrows an optional mutable collector reference for the lifetime `'a`.
///
/// Returns `Some` holding a reborrow of the inner collector when `opt` is
/// `Some`, and `None` otherwise. `opt` itself is only borrowed, so the caller
/// keeps its `Option` and can call this again once the returned borrow ends.
fn reborrow_opt_collector<'a>(
opt: &'a mut Option<&mut dyn SegmentAggregationCollector>,
) -> Option<&'a mut dyn SegmentAggregationCollector> {
match opt {
// `*inner` cannot be moved out from behind the `&mut`; it is implicitly
// reborrowed to produce the returned reference.
Some(inner) => Some(*inner),
None => None,
}
}
/// Builds the `IntermediateTermBucketEntry` for a single term `Bucket`.
///
/// When `sub_agg_collector` is present, its intermediate results for
/// `bucket.bucket_id` are collected into the entry; otherwise the entry carries
/// the default (empty) sub-aggregation results. Any error reported by the
/// collector is propagated to the caller.
fn into_intermediate_bucket_entry(
    bucket: Bucket,
    sub_agg_collector: Option<&mut dyn SegmentAggregationCollector>,
    agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateTermBucketEntry> {
    let mut sub_aggregation = IntermediateAggregationResults::default();
    let bucket_id = bucket.bucket_id;
    sub_agg_collector
        .map(|collector| {
            collector.add_intermediate_aggregation_result(
                agg_data,
                &mut sub_aggregation,
                bucket_id,
            )
        })
        .transpose()?;
    let doc_count = bucket.count;
    Ok(IntermediateTermBucketEntry {
        doc_count,
        sub_aggregation,
    })
}
impl<TermMap, B> SegmentTermCollector<TermMap, B>
impl<TermMap, C> SegmentTermCollector<TermMap, C>
where
TermMap: TermAggregationMap,
B: SubAggBuffer,
C: SubAggCache,
{
#[inline]
fn get_memory_consumption(&self, parent_bucket_id: BucketId) -> usize {
@@ -961,12 +920,15 @@ where
#[inline]
pub(crate) fn into_intermediate_bucket_result(
term_req: &TermsAggReqData,
mut sub_agg_collector: Option<&mut dyn SegmentAggregationCollector>,
sub_agg: &mut Option<CachedSubAggs<C>>,
term_buckets: TermMap,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateBucketResult> {
let mut entries: Vec<(u64, Bucket)> = term_buckets.into_vec();
let order_by_sub_aggregation =
matches!(term_req.req.order.target, OrderTarget::SubAggregation(_));
match &term_req.req.order.target {
OrderTarget::Key => {
// We rely on the fact, that term ordinals match the order of the strings
@@ -978,37 +940,10 @@ where
entries.sort_unstable_by_key(|bucket| bucket.0);
}
}
OrderTarget::SubAggregation(sub_agg_path) => {
// Peek segment-level metric values, sort, then fall through to
// `cut_off_buckets`. Like Elasticsearch, we always cut off when ordering
// by a sub-agg: top-K results are approximate and may differ from the
// global ordering, especially for non-monotonic metrics like avg/min.
let coll = sub_agg_collector.as_deref().ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"Could not find sub-aggregation collector for path {sub_agg_path}"
))
})?;
let (agg_name, agg_prop) = get_agg_name_and_property(sub_agg_path);
// Fetch values up-front; otherwise sort would re-compute per comparison
let mut keyed: Vec<(f64, (u64, Bucket))> = entries
.into_iter()
.map(|bucket| {
let metric_value = coll
.compute_metric_value(bucket.1.bucket_id, agg_name, agg_prop, agg_data)
.unwrap_or(0.0);
(metric_value, bucket)
})
.collect();
if term_req.req.order.order == Order::Desc {
keyed.sort_unstable_by(|a, b| {
b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
});
} else {
keyed.sort_unstable_by(|a, b| {
a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
});
}
entries = keyed.into_iter().map(|(_, e)| e).collect();
OrderTarget::SubAggregation(_name) => {
// don't sort and cut off since it's hard to make assumptions on the quality of the
// results when cutting off due to the unknown nature of the sub_aggregation (possible
// to check).
}
OrderTarget::Count => {
if term_req.req.order.order == Order::Desc {
@@ -1019,12 +954,40 @@ where
}
}
let (term_doc_count_before_cutoff, sum_other_doc_count) =
cut_off_buckets(&mut entries, term_req.req.segment_size as usize);
let (term_doc_count_before_cutoff, sum_other_doc_count) = if order_by_sub_aggregation {
(0, 0)
} else {
cut_off_buckets(&mut entries, term_req.req.segment_size as usize)
};
let mut dict: FxHashMap<IntermediateKey, IntermediateTermBucketEntry> = Default::default();
dict.reserve(entries.len());
let into_intermediate_bucket_entry =
|bucket: Bucket,
sub_agg: &mut Option<CachedSubAggs<C>>|
-> crate::Result<IntermediateTermBucketEntry> {
if let Some(sub_agg) = sub_agg {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
sub_agg
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut sub_aggregation_res,
bucket.bucket_id,
)?;
Ok(IntermediateTermBucketEntry {
doc_count: bucket.count,
sub_aggregation: sub_aggregation_res,
})
} else {
Ok(IntermediateTermBucketEntry {
doc_count: bucket.count,
sub_aggregation: Default::default(),
})
}
};
if term_req.column_type == ColumnType::Str {
let fallback_dict = Dictionary::empty();
let term_dict = term_req
@@ -1035,11 +998,7 @@ where
if let Some((intermediate_key, bucket)) = extract_missing_value(&mut entries, term_req)
{
let intermediate_entry = into_intermediate_bucket_entry(
bucket,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)?;
let intermediate_entry = into_intermediate_bucket_entry(bucket, sub_agg)?;
dict.insert(intermediate_key, intermediate_entry);
}
@@ -1047,28 +1006,19 @@ where
entries.sort_unstable_by_key(|bucket| bucket.0);
let (term_ids, buckets): (Vec<u64>, Vec<Bucket>) = entries.into_iter().unzip();
let mut buckets_it = buckets.into_iter();
let intermediate_entries: Vec<IntermediateTermBucketEntry> = buckets
.into_iter()
.map(|bucket| {
into_intermediate_bucket_entry(
bucket,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)
})
.collect::<crate::Result<_>>()?;
let mut intermediate_entry_it = intermediate_entries.into_iter();
term_dict.sorted_ords_to_term_cb(&term_ids[..], |term| {
let intermediate_entry = intermediate_entry_it.next().unwrap();
term_dict.sorted_ords_to_term_cb(term_ids.into_iter(), |term| {
let bucket = buckets_it.next().unwrap();
let intermediate_entry =
into_intermediate_bucket_entry(bucket, sub_agg).map_err(io::Error::other)?;
dict.insert(
IntermediateKey::Str(
String::from_utf8(term.to_vec()).expect("could not convert to String"),
),
intermediate_entry,
);
Ok(())
})?;
if term_req.req.min_doc_count == 0 {
@@ -1103,22 +1053,14 @@ where
}
} else if term_req.column_type == ColumnType::DateTime {
for (val, doc_count) in entries {
let intermediate_entry = into_intermediate_bucket_entry(
doc_count,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)?;
let intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?;
let val = i64::from_u64(val);
let date = format_date(val)?;
dict.insert(IntermediateKey::Str(date), intermediate_entry);
}
} else if term_req.column_type == ColumnType::Bool {
for (val, doc_count) in entries {
let intermediate_entry = into_intermediate_bucket_entry(
doc_count,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)?;
let intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?;
let val = bool::from_u64(val);
dict.insert(IntermediateKey::Bool(val), intermediate_entry);
}
@@ -1138,22 +1080,14 @@ where
})?;
for (val, doc_count) in entries {
let intermediate_entry = into_intermediate_bucket_entry(
doc_count,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)?;
let intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?;
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
let val = Ipv6Addr::from_u128(val);
dict.insert(IntermediateKey::IpAddr(val), intermediate_entry);
}
} else {
for (val, doc_count) in entries {
let intermediate_entry = into_intermediate_bucket_entry(
doc_count,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)?;
let intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?;
if term_req.column_type == ColumnType::U64 {
dict.insert(IntermediateKey::U64(val), intermediate_entry);
} else if term_req.column_type == ColumnType::I64 {
@@ -1187,13 +1121,13 @@ where
}
}
impl<TermMap: TermAggregationMap, B: SubAggBuffer> SegmentTermCollector<TermMap, B> {
impl<TermMap: TermAggregationMap, C: SubAggCache> SegmentTermCollector<TermMap, C> {
#[inline]
fn collect_terms_with_docs(
iter: impl Iterator<Item = (crate::DocId, u64)>,
term_buckets: &mut TermMap,
bucket_id_provider: &mut BucketIdProvider,
sub_agg: &mut BufferedSubAggs<B>,
sub_agg: &mut CachedSubAggs<C>,
) {
for (doc, term_id) in iter {
let bucket_id = term_buckets.term_entry(term_id, bucket_id_provider);
@@ -1266,7 +1200,7 @@ mod tests {
use crate::aggregation::{AggregationLimitsGuard, DistributedAggregationCollector};
use crate::indexer::NoMergePolicy;
use crate::query::AllQuery;
use crate::schema::{IntoIpv6Addr, Schema, FAST, INDEXED, STRING, TEXT};
use crate::schema::{IntoIpv6Addr, Schema, FAST, STRING};
use crate::{Index, IndexWriter};
#[test]
@@ -1795,263 +1729,6 @@ mod tests {
Ok(())
}
/// Runs the order-by-cardinality checks with `merge_segments = true`.
#[test]
fn terms_aggregation_order_by_cardinality_desc_single_segment() -> crate::Result<()> {
terms_aggregation_order_by_cardinality_desc(true)
}
/// Runs the order-by-cardinality checks with `merge_segments = false`.
#[test]
fn terms_aggregation_order_by_cardinality_desc_multi_segment() -> crate::Result<()> {
terms_aggregation_order_by_cardinality_desc(false)
}
/// Shared body of the order-by-cardinality tests.
///
/// Builds an index from a single group of `(score, term)` pairs and checks the
/// bucket order of a terms aggregation on `string_id` sorted by a `cardinality`
/// sub-aggregation on `score`: descending, ascending, and both directions again
/// with `size: 2`. `merge_segments` is forwarded to
/// `get_test_index_from_values_and_terms`.
fn terms_aggregation_order_by_cardinality_desc(merge_segments: bool) -> crate::Result<()> {
// Distinct score values per bucket key: A→5, B→1, C→3.
// Order by cardinality desc must yield A, C, B.
let segment_and_terms = vec![vec![
(1.0, "A".to_string()),
(2.0, "A".to_string()),
(3.0, "A".to_string()),
(4.0, "A".to_string()),
(5.0, "A".to_string()),
(1.0, "B".to_string()),
(1.0, "B".to_string()),
(1.0, "B".to_string()),
(1.0, "C".to_string()),
(2.0, "C".to_string()),
(3.0, "C".to_string()),
]];
let index = get_test_index_from_values_and_terms(merge_segments, &segment_and_terms)?;
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "card": "desc" }
},
"aggs": {
"card": { "cardinality": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][0]["card"]["value"], 5.0);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][1]["card"]["value"], 3.0);
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
assert_eq!(res["my_texts"]["buckets"][2]["card"]["value"], 1.0);
// Asc engages the segment-cutoff path too (monotonic-safe: discarded buckets had
// local card >= cutoff, so merged card >= cutoff and they cannot be globally smallest).
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "card": "asc" }
},
"aggs": {
"card": { "cardinality": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "B");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][2]["key"], "A");
// size=2 with desc engages the segment cutoff: must keep top-2 by cardinality (A, C),
// and `sum_other_doc_count` reflects the dropped B (3 docs).
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"size": 2,
"order": { "card": "desc" }
},
"aggs": {
"card": { "cardinality": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"].as_array().unwrap().len(), 2);
// size=2 with asc engages the segment cutoff: must keep bottom-2 by cardinality (B, C).
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"size": 2,
"order": { "card": "asc" }
},
"aggs": {
"card": { "cardinality": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "B");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"].as_array().unwrap().len(), 2);
Ok(())
}
/// Runs the order-by-sum checks with `merge_segments = true`.
#[test]
fn terms_aggregation_order_by_sum_single_segment() -> crate::Result<()> {
terms_aggregation_order_by_sum(true)
}
/// Runs the order-by-sum checks with `merge_segments = false`.
#[test]
fn terms_aggregation_order_by_sum_multi_segment() -> crate::Result<()> {
terms_aggregation_order_by_sum(false)
}
/// Shared body of the order-by-sum tests.
///
/// Builds an index from two groups of `(score, term)` pairs and checks the
/// bucket order of a terms aggregation on `string_id` when sorting by a `sum`
/// sub-aggregation (desc, asc, and desc with `size: 2`), by `stats.sum`, by a
/// `sum` over `score_i64`, and by `extended_stats.max`. `merge_segments` is
/// forwarded to `get_test_index_from_values_and_terms`.
fn terms_aggregation_order_by_sum(merge_segments: bool) -> crate::Result<()> {
// Per-bucket sums on the U64 `score` column (non-negative => sum is monotonic):
// A → 1+2+3+4+5 = 15, B → 1+1+1 = 3, C → 1+2+3 = 6.
let segment_and_terms = vec![
vec![
(1.0, "A".to_string()),
(2.0, "A".to_string()),
(3.0, "A".to_string()),
(1.0, "B".to_string()),
(1.0, "C".to_string()),
],
vec![
(4.0, "A".to_string()),
(5.0, "A".to_string()),
(1.0, "B".to_string()),
(1.0, "B".to_string()),
(2.0, "C".to_string()),
(3.0, "C".to_string()),
],
];
let index = get_test_index_from_values_and_terms(merge_segments, &segment_and_terms)?;
// Desc on a Sum metric engages the fast path (column is U64).
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "total": "desc" }
},
"aggs": {
"total": { "sum": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][0]["total"]["value"], 15.0);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][1]["total"]["value"], 6.0);
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
assert_eq!(res["my_texts"]["buckets"][2]["total"]["value"], 3.0);
// Asc engages the fast path too — discarded buckets had local sum >= cutoff,
// and merged sum >= local (non-negative addends), so they cannot be globally smallest.
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "total": "asc" }
},
"aggs": {
"total": { "sum": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "B");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][2]["key"], "A");
// size=2 desc with cutoff: top-2 by sum (A, C).
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"size": 2,
"order": { "total": "desc" }
},
"aggs": {
"total": { "sum": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"].as_array().unwrap().len(), 2);
// Stats sub-property: ordering by `mystats.sum` on a U64 column also engages.
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "mystats.sum": "desc" }
},
"aggs": {
"mystats": { "stats": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
// Sum on a signed column (I64) takes the same cutoff path. Results may be
// approximate near the boundary on adversarial data, but for this dataset the
// top-K is unambiguous.
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "total": "desc" }
},
"aggs": {
"total": { "sum": { "field": "score_i64" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
// Order by extended_stats sub-property exercises compute_metric_value on the
// ExtendedStats collector. A→max=5, B→max=1, C→max=3, so desc by max → A, C, B.
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "ext.max": "desc" }
},
"aggs": {
"ext": { "extended_stats": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
Ok(())
}
#[test]
fn terms_aggregation_test_order_key_single_segment() -> crate::Result<()> {
terms_aggregation_test_order_key_merge_segment(true)
@@ -3217,101 +2894,4 @@ mod tests {
Ok(())
}
/// Builds an in-RAM index made of one document without any `title` value, followed by
/// `n` documents whose titles are the distinct terms `foo1` ..= `foo{n}`.
fn prep_index_with_n_unique_terms_plus_one_null(n: u64) -> crate::Result<Index> {
    let mut builder = Schema::builder();
    let id = builder.add_u64_field("id", INDEXED);
    let title = builder.add_text_field("title", TEXT | FAST);
    let index = Index::create_in_ram(builder.build());

    // A single indexing thread guarantees that all docs end up in the same segment.
    let mut index_writer = index.writer_with_num_threads(1, 50_000_000)?;
    // The only document that carries no `title`.
    index_writer.add_document(doc!(id => 0u64))?;
    for doc_id in 1..=n {
        index_writer.add_document(doc!(
            id => doc_id,
            title => format!("foo{doc_id}"),
        ))?;
    }
    index_writer.commit()?;
    Ok(index)
}
#[test]
fn null_bitset_bounds_check_regression() -> crate::Result<()> {
    // Builds a `terms` request on `title` with a `missing` bucket. `filter` optionally
    // adds one extra `(key, regex)` entry such as `include` or `exclude`.
    let build_request = |filter: Option<(&str, &str)>| -> crate::Result<Aggregations> {
        let mut terms = json!({
            "field": "title",
            "missing": "__NULL__",
            "size": 1000,
        });
        if let Some((key, pattern)) = filter {
            terms[key] = json!(pattern);
        }
        Ok(serde_json::from_value(json!({
            "my_bool": { "terms": terms }
        }))?)
    };

    // Exercise term counts that are exact multiples of 64: 0, 64, 128, 192 and 256.
    for step in 0..=4u64 {
        let num_terms = step * 64;
        let index = prep_index_with_n_unique_terms_plus_one_null(num_terms)?;

        let normal_req = build_request(None)?;
        let include_req = build_request(Some(("include", "foo(.*)")))?;
        let exclude_req = build_request(Some(("exclude", "foo(.*)")))?;

        // No filter: every term plus the bucket for the doc without a title.
        let normal_res = exec_request(normal_req, &index)?;
        let normal_buckets = normal_res["my_bool"]["buckets"].as_array().unwrap();
        assert_eq!(
            normal_buckets.len(),
            num_terms as usize + 1,
            "The normal request should return all 'foo' buckets, plus the missing term bucket",
        );

        // `include`: only the `foo*` terms.
        let include_res = exec_request(include_req, &index)?;
        eprintln!("include_res: {include_res:?}");
        let include_buckets = include_res["my_bool"]["buckets"].as_array().unwrap();
        assert_eq!(
            include_buckets.len(),
            num_terms as usize,
            "The include request should return all 'foo' buckets, and not the missing term \
             bucket",
        );
        assert!(include_buckets
            .iter()
            .all(|bucket| bucket["key"].as_str().unwrap().starts_with("foo")));

        // `exclude`: only the missing bucket is expected to remain.
        let exclude_res = exec_request(exclude_req, &index)?;
        let exclude_buckets = exclude_res["my_bool"]["buckets"].as_array().unwrap();
        if step != 0 {
            // TODO: Remove this if after fixing exclude + missing bug
            assert_eq!(
                exclude_buckets.len(),
                1,
                "The exclude request should exclude all 'foo' buckets, and only the missing \
                 term bucket",
            );
            assert_eq!(exclude_buckets[0]["key"], "__NULL__");
        }
    }
    Ok(())
}
}

View File

@@ -5,7 +5,7 @@ use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::bucket::term_agg::TermsAggregation;
use crate::aggregation::buffered_sub_aggs::{BufferedSubAggs, HighCardBufferedSubAggs};
use crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardCachedSubAggs};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult,
@@ -47,7 +47,7 @@ struct MissingCount {
#[derive(Default, Debug)]
pub struct TermMissingAgg {
accessor_idx: usize,
sub_agg: Option<HighCardBufferedSubAggs>,
sub_agg: Option<HighCardCachedSubAggs>,
/// Idx = parent bucket id, Value = missing count for that bucket
missing_count_per_bucket: Vec<MissingCount>,
bucket_id_provider: BucketIdProvider,
@@ -66,7 +66,7 @@ impl TermMissingAgg {
None
};
let sub_agg = sub_agg.map(BufferedSubAggs::new);
let sub_agg = sub_agg.map(CachedSubAggs::new);
let bucket_id_provider = BucketIdProvider::default();
Ok(Self {
@@ -177,17 +177,6 @@ impl SegmentAggregationCollector for TermMissingAgg {
}
Ok(())
}
/// Always yields `None`: this collector does not resolve any order target yet, so every
/// argument is ignored.
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// TODO: forward to `sub_agg` for nested order paths (`missing_agg>metric`).
None
}
}
#[cfg(test)]

View File

@@ -6,7 +6,7 @@ use crate::aggregation::bucket::MAX_NUM_TERMS_FOR_VEC;
use crate::aggregation::BucketId;
use crate::DocId;
/// A buffer for sub-aggregations, storing doc ids per bucket id.
/// A cache for sub-aggregations, storing doc ids per bucket id.
/// Depending on the cardinality of the parent aggregation, we use different
/// storage strategies.
///
@@ -24,21 +24,21 @@ use crate::DocId;
/// aggregations.
/// What this datastructure does in general is to group docs by bucket id.
#[derive(Debug)]
pub(crate) struct BufferedSubAggs<B: SubAggBuffer> {
buffer: B,
pub(crate) struct CachedSubAggs<C: SubAggCache> {
cache: C,
sub_agg_collector: Box<dyn SegmentAggregationCollector>,
num_docs: usize,
}
pub type LowCardBufferedSubAggs = BufferedSubAggs<LowCardSubAggBuffer>;
pub type HighCardBufferedSubAggs = BufferedSubAggs<HighCardSubAggBuffer>;
pub type LowCardCachedSubAggs = CachedSubAggs<LowCardSubAggCache>;
pub type HighCardCachedSubAggs = CachedSubAggs<HighCardSubAggCache>;
const FLUSH_THRESHOLD: usize = 2048;
/// A trait for buffering sub-aggregation doc ids per bucket id.
/// A trait for caching sub-aggregation doc ids per bucket id.
/// Different implementations can be used depending on the cardinality
/// of the parent aggregation.
pub trait SubAggBuffer: Debug {
pub trait SubAggCache: Debug {
fn new() -> Self;
fn push(&mut self, bucket_id: BucketId, doc_id: DocId);
fn flush_local(
@@ -49,22 +49,22 @@ pub trait SubAggBuffer: Debug {
) -> crate::Result<()>;
}
impl<Backend: SubAggBuffer + Debug> BufferedSubAggs<Backend> {
impl<Backend: SubAggCache + Debug> CachedSubAggs<Backend> {
pub fn new(sub_agg: Box<dyn SegmentAggregationCollector>) -> Self {
Self {
buffer: Backend::new(),
cache: Backend::new(),
sub_agg_collector: sub_agg,
num_docs: 0,
}
}
pub fn get_sub_agg_collector(&mut self) -> &mut dyn SegmentAggregationCollector {
&mut *self.sub_agg_collector
pub fn get_sub_agg_collector(&mut self) -> &mut Box<dyn SegmentAggregationCollector> {
&mut self.sub_agg_collector
}
#[inline]
pub fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
self.buffer.push(bucket_id, doc_id);
self.cache.push(bucket_id, doc_id);
self.num_docs += 1;
}
@@ -75,7 +75,7 @@ impl<Backend: SubAggBuffer + Debug> BufferedSubAggs<Backend> {
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
if self.num_docs >= FLUSH_THRESHOLD {
self.buffer
self.cache
.flush_local(&mut self.sub_agg_collector, agg_data, false)?;
self.num_docs = 0;
}
@@ -85,7 +85,7 @@ impl<Backend: SubAggBuffer + Debug> BufferedSubAggs<Backend> {
/// Note: this _does_ flush the sub aggregations.
pub fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if self.num_docs != 0 {
self.buffer
self.cache
.flush_local(&mut self.sub_agg_collector, agg_data, true)?;
self.num_docs = 0;
}
@@ -94,11 +94,11 @@ impl<Backend: SubAggBuffer + Debug> BufferedSubAggs<Backend> {
}
}
/// Number of partitions for high cardinality sub-aggregation buffer.
/// Number of partitions for high cardinality sub-aggregation cache.
const NUM_PARTITIONS: usize = 16;
#[derive(Debug)]
pub(crate) struct HighCardSubAggBuffer {
pub(crate) struct HighCardSubAggCache {
/// This weird partitioning is used to do some cheap grouping on the bucket ids.
/// bucket ids are dense, e.g. when we don't detect the cardinality as low cardinality,
/// but there are just 16 bucket ids, each bucket id will go to its own partition.
@@ -108,7 +108,7 @@ pub(crate) struct HighCardSubAggBuffer {
partitions: Box<[PartitionEntry; NUM_PARTITIONS]>,
}
impl HighCardSubAggBuffer {
impl HighCardSubAggCache {
#[inline]
fn clear(&mut self) {
for partition in self.partitions.iter_mut() {
@@ -131,7 +131,7 @@ impl PartitionEntry {
}
}
impl SubAggBuffer for HighCardSubAggBuffer {
impl SubAggCache for HighCardSubAggCache {
fn new() -> Self {
Self {
partitions: Box::new(core::array::from_fn(|_| PartitionEntry::default())),
@@ -173,14 +173,14 @@ impl SubAggBuffer for HighCardSubAggBuffer {
}
#[derive(Debug)]
pub(crate) struct LowCardSubAggBuffer {
/// Buffer doc ids per bucket for sub-aggregations.
pub(crate) struct LowCardSubAggCache {
/// Cache doc ids per bucket for sub-aggregations.
///
/// The outer Vec is indexed by BucketId.
per_bucket_docs: Vec<Vec<DocId>>,
}
impl LowCardSubAggBuffer {
impl LowCardSubAggCache {
#[inline]
fn clear(&mut self) {
for v in &mut self.per_bucket_docs {
@@ -189,7 +189,7 @@ impl LowCardSubAggBuffer {
}
}
impl SubAggBuffer for LowCardSubAggBuffer {
impl SubAggCache for LowCardSubAggCache {
fn new() -> Self {
Self {
per_bucket_docs: Vec::new(),

View File

@@ -1,6 +1,6 @@
use super::agg_req::Aggregations;
use super::agg_result::AggregationResults;
use super::buffered_sub_aggs::LowCardBufferedSubAggs;
use super::cached_sub_aggs::LowCardCachedSubAggs;
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::AggContextParams;
// group buffering strategy is chosen explicitly by callers; no need to hash-group on the fly.
@@ -136,7 +136,7 @@ fn merge_fruits(
/// `AggregationSegmentCollector` does the aggregation collection on a segment.
pub struct AggregationSegmentCollector {
aggs_with_accessor: AggregationsSegmentCtx,
agg_collector: LowCardBufferedSubAggs,
agg_collector: LowCardCachedSubAggs,
error: Option<TantivyError>,
}
@@ -152,7 +152,7 @@ impl AggregationSegmentCollector {
let mut agg_data =
build_aggregations_data_from_req(agg, reader, segment_ordinal, context.clone())?;
let mut result =
LowCardBufferedSubAggs::new(build_segment_agg_collectors_root(&mut agg_data)?);
LowCardCachedSubAggs::new(build_segment_agg_collectors_root(&mut agg_data)?);
result
.get_sub_agg_collector()
.prepare_max_bucket(0, &agg_data)?; // prepare for bucket zero

View File

@@ -1004,20 +1004,24 @@ impl IntermediateCompositeBucketResult {
) -> crate::Result<BucketResult> {
let trimmed_entry_vec =
trim_composite_buckets(self.entries, &self.orders, self.target_size)?;
let after_key = trimmed_entry_vec
.last()
.map(|bucket| {
let (intermediate_key, _entry) = bucket;
intermediate_key
.iter()
.enumerate()
.map(|(idx, intermediate_key)| {
let source = &req.sources[idx];
(source.name().to_string(), intermediate_key.clone().into())
})
.collect()
})
.unwrap_or_default();
let after_key = if trimmed_entry_vec.len() == req.size as usize {
trimmed_entry_vec
.last()
.map(|bucket| {
let (intermediate_key, _entry) = bucket;
intermediate_key
.iter()
.enumerate()
.map(|(idx, intermediate_key)| {
let source = &req.sources[idx];
(source.name().to_string(), intermediate_key.clone().into())
})
.collect()
})
.unwrap()
} else {
FxHashMap::default()
};
let buckets = trimmed_entry_vec
.into_iter()

File diff suppressed because it is too large Load Diff

View File

@@ -399,26 +399,6 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
}
Ok(())
}
/// Looks up `sub_agg_property` on the finalized extended stats of `bucket_id`.
///
/// Yields `None` when `sub_agg_name` is not this collector's name, when the bucket does
/// not exist, or when the property lookup fails or has no value.
fn compute_metric_value(
    &self,
    bucket_id: BucketId,
    sub_agg_name: &str,
    sub_agg_property: &str,
    _agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
    let is_target = self.name == sub_agg_name;
    if !is_target {
        return None;
    }
    let bucket_stats = self.buckets.get(bucket_id as usize)?;
    // Finalizing only reads the accumulators, so doing it here for the cutoff sort
    // leaves the eventual intermediate result untouched.
    match bucket_stats.finalize().get_value(sub_agg_property) {
        Ok(value) => value,
        Err(_) => None,
    }
}
}
#[cfg(test)]

View File

@@ -107,9 +107,10 @@ pub enum PercentileValues {
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
/// The entry when requesting percentiles with keyed: false
pub struct PercentileValuesVecEntry {
/// The percentile key (e.g. 1.0, 5.0, 25.0).
/// Percentile
pub key: f64,
/// The percentile value. `NaN` when there are no values.
/// Value at the percentile
pub value: f64,
}

View File

@@ -312,26 +312,6 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector {
}
Ok(())
}
/// Evaluates the percentile named by `sub_agg_property` (a number within `0..=100`)
/// on the sketch of `bucket_id`.
///
/// Yields `None` when `sub_agg_name` does not match this metric's request name, when the
/// property is not a number in range, when the bucket does not exist, or when the sketch
/// has no quantile to report.
fn compute_metric_value(
    &self,
    bucket_id: BucketId,
    sub_agg_name: &str,
    sub_agg_property: &str,
    agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
    if agg_data.get_metric_req_data(self.accessor_idx).name != sub_agg_name {
        return None;
    }
    let quantile = sub_agg_property
        .parse::<f64>()
        .ok()
        .filter(|percentile| (0.0..=100.0).contains(percentile))
        .map(|percentile| percentile / 100.0)?;
    // DDSketch.quantile is a pure read; calling it here for the cutoff sort does
    // not affect the intermediate state used for the final result.
    self.buckets
        .get(bucket_id as usize)
        .and_then(|bucket| bucket.sketch.quantile(quantile).ok().flatten())
}
}
#[cfg(test)]

View File

@@ -321,40 +321,6 @@ impl<const COLUMN_TYPE_ID: u8> SegmentAggregationCollector
}
Ok(())
}
/// Reads one statistic of `bucket_id` for use as an order target.
///
/// A `Stats` collector requires a non-empty `sub_agg_property` (`count`, `sum`, `min`,
/// `max` or `avg`). The single-value kinds (`Sum`, `Count`, `Min`, `Max`, `Average`)
/// require an empty property and report the value they were configured to collect.
/// Any other combination, a name mismatch or a missing bucket yields `None`; so do
/// `min`, `max` and `avg` on a bucket whose count is zero.
fn compute_metric_value(
    &self,
    bucket_id: BucketId,
    sub_agg_name: &str,
    sub_agg_property: &str,
    _agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
    if self.name != sub_agg_name {
        return None;
    }
    let stats = self.buckets.get(bucket_id as usize)?;

    // Resolve which statistic is being asked for.
    let property_given = !sub_agg_property.is_empty();
    let statistic = match (&self.collecting_for, property_given) {
        (StatsType::Stats, true) => sub_agg_property,
        (StatsType::Sum, false) => "sum",
        (StatsType::Count, false) => "count",
        (StatsType::Max, false) => "max",
        (StatsType::Min, false) => "min",
        (StatsType::Average, false) => "avg",
        _ => return None,
    };

    // `min`, `max` and `avg` are only reported once at least one value was collected.
    let has_values = stats.count > 0;
    match statistic {
        "count" => Some(stats.count as f64),
        "sum" => Some(stats.sum),
        "min" => has_values.then_some(stats.min),
        "max" => has_values.then_some(stats.max),
        "avg" => has_values.then(|| stats.sum / stats.count as f64),
        _ => None,
    }
}
}
#[inline]

View File

@@ -644,17 +644,6 @@ impl SegmentAggregationCollector for TopHitsSegmentCollector {
);
Ok(())
}
/// Always yields `None`, whatever bucket, name or property is requested.
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// top_hits is not a numeric metric and cannot be used as an order target.
None
}
}
#[cfg(test)]

View File

@@ -133,7 +133,7 @@ mod agg_limits;
pub mod agg_req;
pub mod agg_result;
pub mod bucket;
pub(crate) mod buffered_sub_aggs;
pub(crate) mod cached_sub_aggs;
mod collector;
mod date;
mod error;

View File

@@ -76,31 +76,6 @@ pub trait SegmentAggregationCollector: Debug {
fn flush(&mut self, _agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
Ok(())
}
/// Compute the segment-level metric value of the named direct-child metric for `bucket_id`.
///
/// Used by parent term aggs that order by a sub-aggregation: the parent sorts on
/// this value and cuts off at segment time, matching the approximation tradeoff
/// Elasticsearch makes for any sub-agg ordering.
///
/// `sub_agg_name` is the name of the metric being looked up; a collector that is not
/// the target is expected to answer `None` (see "name mismatch" below).
///
/// `sub_agg_property` is the dotted suffix (e.g. `"sum"` in `mystats.sum`); empty when
/// the metric is a single-value kind such as cardinality.
///
/// Returns `None` only on name mismatch, unknown property, or empty bucket. Implementations
/// may finalize their per-bucket state (e.g. compute a percentile from a sketch); calls
/// must be idempotent so the final intermediate result is unaffected.
///
/// No default impl on purpose: every collector must decide explicitly whether it
/// produces a metric value, forwards into children (single-bucket aggs), or rejects
/// the lookup. A silent `None` default would let a parent term agg's cutoff sort all
/// buckets to the same key and drop arbitrary winners.
fn compute_metric_value(
&self,
bucket_id: BucketId,
sub_agg_name: &str,
sub_agg_property: &str,
agg_data: &AggregationsSegmentCtx,
) -> Option<f64>;
}
#[derive(Default)]
@@ -162,21 +137,4 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
}
Ok(())
}
/// Asks each child collector in `self.aggs`, in order, for the requested metric value
/// and returns the first answer that is not `None`.
fn compute_metric_value(
    &self,
    bucket_id: BucketId,
    sub_agg_name: &str,
    sub_agg_property: &str,
    agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
    self.aggs.iter().find_map(|child| {
        child.compute_metric_value(bucket_id, sub_agg_name, sub_agg_property, agg_data)
    })
}
}

View File

@@ -1,6 +1,5 @@
use super::Collector;
use crate::collector::SegmentCollector;
use crate::query::Weight;
use crate::{DocId, Score, SegmentOrdinal, SegmentReader};
/// `CountCollector` collector only counts how many
@@ -56,15 +55,6 @@ impl Collector for Count {
fn merge_fruits(&self, segment_counts: Vec<usize>) -> crate::Result<usize> {
Ok(segment_counts.into_iter().sum())
}
/// Produces the count for one segment by delegating to `weight.count(reader)`.
///
/// The segment ordinal is ignored. Any error from `weight.count` is propagated.
fn collect_segment(
&self,
weight: &dyn Weight,
_segment_ord: u32,
reader: &SegmentReader,
) -> crate::Result<usize> {
Ok(weight.count(reader)? as usize)
}
}
#[derive(Default)]

View File

@@ -389,13 +389,6 @@ impl SegmentCollector for FacetSegmentCollector {
}
let mut facet = vec![];
let (facet_ord, facet_depth) = self.unique_facet_ords[collapsed_facet_ord];
// u64::MAX is used as a sentinel for unmapped ordinals (e.g. when a
// document has the exact registered facet, not a child of it).
// Passing it to ord_to_term would resolve to the last dictionary
// entry and produce a spurious facet from an unrelated branch.
if facet_ord == u64::MAX {
continue;
}
// TODO handle errors.
if facet_dict.ord_to_term(facet_ord, &mut facet).is_ok() {
if let Some((end_collapsed_facet, _)) = facet
@@ -821,63 +814,6 @@ mod tests {
assert!(!super::is_child_facet(&b"foo\0bar"[..], &b"foo"[..]));
assert!(!super::is_child_facet(&b"foo"[..], &b"foobar\0baz"[..]));
}
// Regression test for https://github.com/quickwit-oss/tantivy/issues/2494
// A document may carry exactly the registered facet path (rather than one of its
// children); harvest() must not turn the unmapped sentinel into a spurious root entry.
#[test]
fn test_facet_collector_wrong_root() -> crate::Result<()> {
    let mut schema_builder = Schema::builder();
    let facet_field = schema_builder.add_facet_field("facet", FacetOptions::default());
    let index = Index::create_in_ram(schema_builder.build());

    let mut index_writer: IndexWriter = index.writer_for_tests()?;
    let facet_paths = [
        "/science-fiction/asimov",
        "/science-fiction/clarke",
        "/science-fiction/dick",
        "/science-fiction/herbert",
        "/science-fiction/orwell",
        // This exact match on the registered facet is the bug trigger:
        // its ordinal maps to the sentinel (u64::MAX, 0) in the collapse
        // mapping, which without the fix resolves to an unrelated term.
        "/fantasy/epic-fantasy",
        "/fantasy/epic-fantasy/tolkien",
        "/fantasy/epic-fantasy/martin",
    ];
    for facet_path in facet_paths {
        index_writer.add_document(doc!(
            facet_field => Facet::from(facet_path)
        ))?;
    }
    index_writer.commit()?;

    let reader = index.reader()?;
    let searcher = reader.searcher();

    // Query and collector both target the registered facet itself.
    let registered_facet = "/fantasy/epic-fantasy";
    let query = TermQuery::new(
        Term::from_facet(facet_field, &Facet::from(registered_facet)),
        IndexRecordOption::Basic,
    );
    let mut facet_collector = FacetCollector::for_field("facet");
    facet_collector.add_facet(registered_facet);

    let counts: FacetCounts = searcher.search(&query, &facet_collector)?;
    let result: Vec<(String, u64)> = counts
        .get("/")
        .map(|(facet, count)| (facet.to_string(), count))
        .collect();

    // Only children of /fantasy/epic-fantasy should appear, not /science-fiction
    let expected: Vec<(String, u64)> = [
        "/fantasy/epic-fantasy/martin",
        "/fantasy/epic-fantasy/tolkien",
    ]
    .into_iter()
    .map(|path| (path.to_string(), 1))
    .collect();
    assert_eq!(result, expected);
    Ok(())
}
}
#[cfg(all(test, feature = "unstable"))]

View File

@@ -1,8 +1,5 @@
use std::cmp::{Ordering, Reverse};
use std::collections::BinaryHeap;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer, TopNComputer};
use crate::{DocAddress, DocId, Score};
/// Sort by similarity score.
@@ -28,10 +25,6 @@ impl SortKeyComputer for SortBySimilarityScore {
}
// Sorting by score is special in that it allows for the Block-Wand optimization.
//
// We use a BinaryHeap (TopNHeap) instead of TopNComputer here so that the
// threshold is always the exact K-th best score. TopNComputer only updates its
// threshold every K docs (at truncation), giving Block-WAND a stale bound.
fn collect_segment_top_k(
&self,
k: usize,
@@ -39,10 +32,12 @@ impl SortKeyComputer for SortBySimilarityScore {
reader: &crate::SegmentReader,
segment_ord: u32,
) -> crate::Result<Vec<(Self::SortKey, DocAddress)>> {
let mut top_n = TopNHeap::new(k);
let mut top_n: TopNComputer<Score, DocId, Self::Comparator> =
TopNComputer::new_with_comparator(k, self.comparator());
if let Some(alive_bitset) = reader.alive_bitset() {
let mut threshold = Score::MIN;
top_n.threshold = Some(threshold);
weight.for_each_pruning(Score::MIN, reader, &mut |doc, score| {
if alive_bitset.is_deleted(doc) {
return threshold;
@@ -61,7 +56,7 @@ impl SortKeyComputer for SortBySimilarityScore {
Ok(top_n
.into_vec()
.into_iter()
.map(|(score, doc)| (score, DocAddress::new(segment_ord, doc)))
.map(|cid| (cid.sort_key, DocAddress::new(segment_ord, cid.doc)))
.collect())
}
}
@@ -80,204 +75,3 @@ impl SegmentSortKeyComputer for SortBySimilarityScore {
score
}
}
/// Element of the score min-heap. Entries are ordered by score first; among equal
/// scores the entry with the lower doc id is the greater one, so it is evicted last.
struct ScoreHeapEntry {
    score: Score,
    doc: DocId,
}

impl Ord for ScoreHeapEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        match self.score.partial_cmp(&other.score) {
            // Equal or incomparable scores fall back to the reversed doc order.
            None | Some(Ordering::Equal) => other.doc.cmp(&self.doc),
            Some(by_score) => by_score,
        }
    }
}

impl PartialOrd for ScoreHeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl PartialEq for ScoreHeapEntry {
    fn eq(&self, other: &Self) -> bool {
        matches!(self.cmp(other), Ordering::Equal)
    }
}

impl Eq for ScoreHeapEntry {}

/// Bounded top-K collector for scores, backed by a binary min-heap. Inserting costs
/// O(log K), but the threshold is refreshed on every accepted document, so Block-WAND
/// prunes better than with [`TopNComputer`]'s buffer/median approach.
///
/// As with [`TopNComputer`], documents must be pushed in ascending doc order. A score
/// equal to the threshold is rejected (strict `>`), which lets lower doc IDs win ties.
///
/// [`TopNComputer`]: crate::collector::TopNComputer
struct TopNHeap {
    heap: BinaryHeap<Reverse<ScoreHeapEntry>>,
    top_n: usize,
    threshold: Option<Score>,
}

impl TopNHeap {
    /// Creates an empty collector retaining at most `top_n` entries.
    fn new(top_n: usize) -> Self {
        Self {
            heap: BinaryHeap::with_capacity(top_n),
            top_n,
            threshold: None,
        }
    }

    /// Score of the weakest entry currently retained, if any.
    fn lowest_score(&self) -> Option<Score> {
        self.heap.peek().map(|Reverse(entry)| entry.score)
    }

    /// Offers `(score, doc)` to the collector.
    ///
    /// While fewer than `top_n` entries are held the document is always kept, and the
    /// threshold becomes available once the heap is full. Afterwards a document is kept
    /// only if it strictly beats the threshold, replacing the weakest entry.
    #[inline]
    fn push(&mut self, score: Score, doc: DocId) {
        let candidate = Reverse(ScoreHeapEntry { score, doc });

        if self.heap.len() < self.top_n {
            self.heap.push(candidate);
            if self.heap.len() == self.top_n {
                self.threshold = self.lowest_score();
            }
            return;
        }

        match self.threshold {
            Some(threshold) if score > threshold => {}
            _ => return,
        }
        // Overwriting the root through `peek_mut` costs a single sift-down, whereas
        // pop + push would sift twice.
        if let Some(mut weakest) = self.heap.peek_mut() {
            *weakest = candidate;
        }
        self.threshold = self.lowest_score();
    }

    /// Consumes the collector and returns the retained `(score, doc)` pairs in the
    /// heap's internal (unsorted) order.
    fn into_vec(self) -> Vec<(Score, DocId)> {
        let entries = self.heap.into_vec();
        let mut hits = Vec::with_capacity(entries.len());
        for Reverse(ScoreHeapEntry { score, doc }) in entries {
            hits.push((score, doc));
        }
        hits
    }
}
// Unit tests for `TopNHeap`, plus a property test comparing it against `TopNComputer`.
#[cfg(test)]
mod tests {
use proptest::prelude::*;
use super::*;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::TopNComputer;
// A heap created with capacity 0 retains nothing.
#[test]
fn test_top_n_heap_zero_capacity() {
let mut heap = TopNHeap::new(0);
heap.push(1.0, 0);
heap.push(2.0, 1);
assert!(heap.into_vec().is_empty());
}
// With capacity 2, only the two highest scores survive.
#[test]
fn test_top_n_heap_basic() {
let mut heap = TopNHeap::new(2);
heap.push(1.0, 0);
heap.push(3.0, 1);
heap.push(2.0, 2);
let mut results = heap.into_vec();
// `into_vec` is unsorted: order by score descending, then doc ascending.
results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap().then_with(|| a.1.cmp(&b.1)));
assert_eq!(results, vec![(3.0, 1), (2.0, 2)]);
}
// The threshold is `None` until the heap is full, then equals the lowest retained
// score after every accepted push.
#[test]
fn test_top_n_heap_threshold_always_accurate() {
let mut heap = TopNHeap::new(2);
assert_eq!(heap.threshold, None);
heap.push(1.0, 0);
assert_eq!(heap.threshold, None);
heap.push(3.0, 1);
assert_eq!(heap.threshold, Some(1.0));
heap.push(2.0, 2); // evicts 1.0
assert_eq!(heap.threshold, Some(2.0));
heap.push(4.0, 3); // evicts 2.0
assert_eq!(heap.threshold, Some(3.0));
}
// Equal scores do not displace retained entries, so the lower doc ids are kept.
#[test]
fn test_top_n_heap_tiebreaking_lower_doc_wins() {
let mut heap = TopNHeap::new(2);
heap.push(5.0, 0);
heap.push(5.0, 1);
heap.push(5.0, 2); // rejected: not strictly > threshold
let mut results = heap.into_vec();
results.sort_by_key(|&(_, doc)| doc);
assert_eq!(results, vec![(5.0, 0), (5.0, 1)]);
}
// Capacity 1: the threshold is set by the very first push.
#[test]
fn test_top_n_heap_single_element() {
let mut heap = TopNHeap::new(1);
heap.push(1.0, 0);
assert_eq!(heap.threshold, Some(1.0));
heap.push(0.5, 1); // rejected
heap.push(2.0, 2); // accepted
assert_eq!(heap.threshold, Some(2.0));
let results = heap.into_vec();
assert_eq!(results, vec![(2.0, 2)]);
}
// Fewer pushes than the capacity: nothing is dropped and no threshold is set.
#[test]
fn test_top_n_heap_under_capacity() {
let mut heap = TopNHeap::new(5);
heap.push(3.0, 0);
heap.push(1.0, 1);
heap.push(2.0, 2);
// Only 3 elements, capacity is 5 — all should be kept
assert_eq!(heap.threshold, None);
let mut results = heap.into_vec();
results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap().then_with(|| a.1.cmp(&b.1)));
assert_eq!(results, vec![(3.0, 0), (2.0, 2), (1.0, 1)]);
}
// Property test: for random inputs with unique, ascending doc ids, `TopNHeap` and
// `TopNComputer` must agree on the sorted top-`limit` results.
proptest! {
#[test]
fn test_top_n_heap_matches_top_n_computer(
limit in 0..20_usize,
mut docs in proptest::collection::vec((0..1000_u32, 0..1000_u32), 0..200_usize),
) {
// Both require ascending doc order.
docs.sort_by_key(|(_, doc_id)| *doc_id);
docs.dedup_by_key(|(_, doc_id)| *doc_id);
let mut heap = TopNHeap::new(limit);
let mut computer: TopNComputer<Score, DocId, NaturalComparator> =
TopNComputer::new_with_comparator(limit, NaturalComparator);
for &(score_u32, doc) in &docs {
let score = score_u32 as Score;
heap.push(score, doc);
computer.push(score, doc);
}
let mut heap_results = heap.into_vec();
// Sort the heap output by score descending, then doc ascending, before comparing.
heap_results.sort_by(|a, b| {
b.0.partial_cmp(&a.0).unwrap().then_with(|| a.1.cmp(&b.1))
});
let computer_results: Vec<(Score, DocId)> = computer
.into_sorted_vec()
.into_iter()
.map(|cd| (cd.sort_key, cd.doc))
.collect();
prop_assert_eq!(heap_results, computer_results);
}
}
}

View File

@@ -513,9 +513,7 @@ pub struct TopNComputer<Score, D, C> {
/// The buffer reverses sort order to get top-semantics instead of bottom-semantics
buffer: Vec<ComparableDoc<Score, D>>,
top_n: usize,
/// The current threshold for pruning. Documents with scores at or below
/// this value are skipped by `push()`. Updated when the buffer is truncated.
pub threshold: Option<Score>,
pub(crate) threshold: Option<Score>,
comparator: C,
}

View File

@@ -1,7 +1,5 @@
use std::borrow::{Borrow, BorrowMut};
use common::TinySet;
use crate::fastfield::AliveBitSet;
use crate::DocId;
@@ -16,12 +14,6 @@ pub const TERMINATED: DocId = i32::MAX as u32;
/// exactly this size as long as we can fill the buffer.
pub const COLLECT_BLOCK_BUFFER_LEN: usize = 64;
/// Number of `TinySet` (64-bit) buckets in a block used by [`DocSet::fill_bitset_block`].
pub const BLOCK_NUM_TINYBITSETS: usize = 16;
/// Number of doc IDs covered by one block: `BLOCK_NUM_TINYBITSETS * 64 = 1024`.
pub const BLOCK_WINDOW: u32 = BLOCK_NUM_TINYBITSETS as u32 * 64;
/// Represents an iterable set of sorted doc ids.
pub trait DocSet: Send {
/// Goes to the next element.
@@ -168,31 +160,6 @@ pub trait DocSet: Send {
self.size_hint() as u64
}
/// Fills a bitmask representing which documents in `[min_doc, min_doc + BLOCK_WINDOW)` are
/// present in this docset.
///
/// The window is divided into `BLOCK_NUM_TINYBITSETS` buckets of 64 docs each.
/// Returns the next doc `>= min_doc + BLOCK_WINDOW`, or `TERMINATED` if exhausted.
///
/// Bits are only ever set: `mask` is not cleared here, so it must already be in the
/// state the caller wants to accumulate into.
fn fill_bitset_block(
&mut self,
min_doc: DocId,
mask: &mut [TinySet; BLOCK_NUM_TINYBITSETS],
) -> DocId {
// Assumes `seek` leaves the cursor on a doc `>= min_doc`; otherwise `doc - min_doc`
// below would underflow — TODO confirm against the `seek` contract.
self.seek(min_doc);
// Exclusive upper bound of the window.
let horizon = min_doc + BLOCK_WINDOW;
loop {
let doc = self.doc();
if doc >= horizon {
return doc;
}
// NOTE(review): if the docset is already exhausted and `horizon > TERMINATED`, the
// `TERMINATED` sentinel itself passes the check above and is recorded as a hit —
// confirm callers never use a `min_doc` within `BLOCK_WINDOW` of `TERMINATED`.
let delta = doc - min_doc;
// `delta / 64` selects the bucket, `delta % 64` the bit inside that bucket.
mask[(delta / 64) as usize].insert_mut(delta % 64);
if self.advance() == TERMINATED {
return TERMINATED;
}
}
}
/// Returns the number documents matching.
/// Calling this method consumes the `DocSet`.
fn count(&mut self, alive_bitset: &AliveBitSet) -> u32 {
@@ -247,18 +214,6 @@ impl DocSet for &mut dyn DocSet {
(**self).seek_danger(target)
}
// Forwards to the `fill_buffer` of the underlying `dyn DocSet`.
fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize {
(**self).fill_buffer(buffer)
}
// Forwards to the `fill_bitset_block` of the underlying `dyn DocSet`.
fn fill_bitset_block(
&mut self,
min_doc: DocId,
mask: &mut [TinySet; BLOCK_NUM_TINYBITSETS],
) -> DocId {
(**self).fill_bitset_block(min_doc, mask)
}
fn doc(&self) -> u32 {
(**self).doc()
}
@@ -301,15 +256,6 @@ impl<TDocSet: DocSet + ?Sized> DocSet for Box<TDocSet> {
unboxed.fill_buffer(buffer)
}
// Forwards to the boxed docset's own `fill_bitset_block`.
fn fill_bitset_block(
&mut self,
min_doc: DocId,
mask: &mut [TinySet; BLOCK_NUM_TINYBITSETS],
) -> DocId {
// Borrow the inner `TDocSet` mutably so the call dispatches to its implementation.
let unboxed: &mut TDocSet = self.borrow_mut();
unboxed.fill_bitset_block(min_doc, mask)
}
fn doc(&self) -> DocId {
let unboxed: &TDocSet = self.borrow();
unboxed.doc()

View File

@@ -6,7 +6,6 @@ use common::{ByteCount, HasLen};
use fnv::FnvHashMap;
use itertools::Itertools;
use crate::directory::error::OpenReadError;
use crate::directory::{CompositeFile, FileSlice};
use crate::error::DataCorruption;
use crate::fastfield::{intersect_alive_bitsets, AliveBitSet, FacetReader, FastFieldReaders};
@@ -160,10 +159,12 @@ impl SegmentReader {
let postings_file = segment.open_read(SegmentComponent::Postings)?;
let postings_composite = CompositeFile::open(&postings_file)?;
let positions_composite = match segment.open_read(SegmentComponent::Positions) {
Ok(positions_file) => CompositeFile::open(&positions_file)?,
Err(OpenReadError::FileDoesNotExist(_)) => CompositeFile::empty(),
Err(open_read_error) => return Err(open_read_error.into()),
let positions_composite = {
if let Ok(positions_file) = segment.open_read(SegmentComponent::Positions) {
CompositeFile::open(&positions_file)?
} else {
CompositeFile::empty()
}
};
let schema = segment.schema();

View File

@@ -249,12 +249,6 @@ impl BlockSegmentPostings {
/// Returns the length of the current block.
///
/// Returns the decoded term-frequency buffer for the current block.
///
/// This is a plain accessor over `freq_decoder.output_array()`.
#[inline]
pub(crate) fn freq_output_array(&self) -> &[u32] {
self.freq_decoder.output_array()
}
/// All blocks have a length of `NUM_DOCS_PER_BLOCK`,
/// except the last block that may have a length
/// of any number between 1 and `NUM_DOCS_PER_BLOCK - 1`
@@ -304,11 +298,6 @@ impl BlockSegmentPostings {
}
}
/// Returns whether the skip reader still reports remaining docs.
///
/// Forwards to the skip reader's `has_remaining_docs`.
#[inline]
pub(crate) fn has_remaining_docs(&self) -> bool {
self.skip_reader.has_remaining_docs()
}
/// Returns the current value of the `block_loaded` flag.
pub(crate) fn block_is_loaded(&self) -> bool {
self.block_loaded
}

View File

@@ -146,11 +146,6 @@ impl SkipReader {
skip_reader
}
/// Returns `true` as long as `remaining_docs` is non-zero.
#[inline(always)]
pub fn has_remaining_docs(&self) -> bool {
self.remaining_docs != 0
}
pub fn reset(&mut self, data: OwnedBytes, doc_freq: u32) {
self.last_doc_in_block = if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 {
0

View File

@@ -50,7 +50,7 @@ fn block_max_was_too_low_advance_one_scorer(
scorers: &mut [TermScorerWithMaxScore],
pivot_len: usize,
) {
debug_assert!(scorers.iter().map(|scorer| scorer.doc()).is_sorted());
debug_assert!(is_sorted(scorers.iter().map(|scorer| scorer.doc())));
let mut scorer_to_seek = pivot_len - 1;
let mut global_max_score = scorers[scorer_to_seek].max_score;
let mut doc_to_seek_after = scorers[scorer_to_seek].last_doc_in_block();
@@ -76,7 +76,7 @@ fn block_max_was_too_low_advance_one_scorer(
scorers[scorer_to_seek].seek(doc_to_seek_after);
restore_ordering(scorers, scorer_to_seek);
debug_assert!(scorers.iter().map(|scorer| scorer.doc()).is_sorted());
debug_assert!(is_sorted(scorers.iter().map(|scorer| scorer.doc())));
}
// Given a list of term_scorers and a `ord` and assuming that `term_scorers[ord]` is sorted
@@ -90,7 +90,7 @@ fn restore_ordering(term_scorers: &mut [TermScorerWithMaxScore], ord: usize) {
}
term_scorers.swap(i, i - 1);
}
debug_assert!(term_scorers.iter().map(|scorer| scorer.doc()).is_sorted());
debug_assert!(is_sorted(term_scorers.iter().map(|scorer| scorer.doc())));
}
// Attempts to advance all term_scorers between `&term_scorers[0..before_len]` to the pivot.
@@ -150,21 +150,17 @@ pub fn block_wand(
mut threshold: Score,
callback: &mut dyn FnMut(u32, Score) -> Score,
) {
scorers.retain(|scorer| scorer.doc() < TERMINATED);
if scorers.len() == 1 {
let scorer = scorers.pop().unwrap();
return block_wand_single_scorer(scorer, threshold, callback);
}
let mut scorers: Vec<TermScorerWithMaxScore> = scorers
.iter_mut()
.map(TermScorerWithMaxScore::from)
.collect();
// At this point we need to ensure that the scorers are sorted!
scorers.sort_by_key(|scorer| scorer.doc());
// At this point we need to ensure that the scorers are sorted!
debug_assert!(is_sorted(scorers.iter().map(|scorer| scorer.doc())));
while let Some((before_pivot_len, pivot_len, pivot_doc)) =
find_pivot_doc(&scorers[..], threshold)
{
debug_assert!(scorers.iter().map(|scorer| scorer.doc()).is_sorted());
debug_assert!(is_sorted(scorers.iter().map(|scorer| scorer.doc())));
debug_assert_ne!(pivot_doc, TERMINATED);
debug_assert!(before_pivot_len < pivot_len);
@@ -232,7 +228,7 @@ pub fn block_wand_single_scorer(
loop {
// We position the scorer on a block that can reach
// the threshold.
while scorer.block_max_score() <= threshold {
while scorer.block_max_score() < threshold {
let last_doc_in_block = scorer.last_doc_in_block();
if last_doc_in_block == TERMINATED {
return;
@@ -290,6 +286,18 @@ impl DerefMut for TermScorerWithMaxScore<'_> {
}
}
/// Returns `true` if the doc ids yielded by `it` never decrease.
///
/// Equal consecutive values are accepted. An empty iterator and a
/// single-element iterator are both considered sorted.
fn is_sorted<I: Iterator<Item = DocId>>(mut it: I) -> bool {
    let mut previous = match it.next() {
        Some(first) => first,
        // Nothing to compare: trivially sorted.
        None => return true,
    };
    // `all` stops at the first pair that is out of order.
    it.all(|current| {
        let in_order = previous <= current;
        previous = current;
        in_order
    })
}
#[cfg(test)]
mod tests {
use std::cmp::Ordering;

View File

@@ -1,464 +0,0 @@
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
use crate::query::term_query::TermScorer;
use crate::query::Scorer;
use crate::{DocId, DocSet, Score, TERMINATED};
/// Block-max pruning for top-K over intersection of term scorers.
///
/// Uses the least-frequent term as "leader" to define 128-doc processing windows.
/// For each window, the sum of block_max_scores is compared to the current threshold;
/// if the block can't beat it, the entire block is skipped.
///
/// Within non-skipped blocks, individual documents are pruned by checking whether
/// leader_score + sum(secondary block_max_scores) can exceed the threshold before
/// performing the expensive intersection membership check (seeking into secondary scorers).
///
/// # Arguments
/// - `scorers`: the term scorers to intersect; they are re-sorted by `size_hint`.
/// - `threshold`: only documents scoring strictly above it are reported.
/// - `callback`: called with `(doc, score)` for each reported document; the value it
///   returns becomes the new threshold.
///
/// # Preconditions
/// - `scorers` has at least 2 elements
/// - All scorers read frequencies (`FreqReadingOption::ReadFreq`)
///
/// # Panics
/// Panics if `scorers` has fewer than 2 elements.
pub(crate) fn block_wand_intersection(
mut scorers: Vec<TermScorer>,
mut threshold: Score,
callback: &mut dyn FnMut(DocId, Score) -> Score,
) {
assert!(scorers.len() >= 2);
// Sort by `size_hint` (ascending). scorers[0] becomes the "leader" (rarest term).
scorers.sort_by_key(TermScorer::size_hint);
let (leader, secondaries) = scorers.split_first_mut().unwrap();
// Precompute global max scores for early termination checks.
let leader_max_score: Score = leader.max_score();
let secondaries_global_max_sum: Score = secondaries.iter().map(TermScorer::max_score).sum();
// Early exit: no document can possibly beat the threshold.
if leader_max_score + secondaries_global_max_sum <= threshold {
return;
}
// Clone the leader's fieldnorm reader and BM25 weight before the main loop.
// The loop below holds a mutable borrow of the leader's block cursor, and
// Rust's borrow checker can't see through method calls, so we take
// independent copies once upfront.
let fieldnorm_reader = leader.fieldnorm_reader().clone();
let bm25_weight = leader.bm25_weight().clone();
let mut doc = leader.doc();
// Scratch buffers reused across windows, one slot per secondary scorer.
let mut secondary_block_max_scores: Box<[f32]> =
vec![0.0f32; secondaries.len()].into_boxed_slice();
let mut secondary_suffix_block_max: Box<[f32]> =
vec![0.0f32; secondaries.len()].into_boxed_slice();
while doc < TERMINATED {
// --- Phase 1: Block-level pruning ---
//
// Position all skip readers on the block containing `doc`.
// seek_block is cheap: it only advances the skip reader, no block decompression.
leader.seek_block(doc);
let leader_block_max: Score = leader.block_max_score();
// Compute the window end as the minimum last_doc_in_block across all scorers.
// This ensures the block_max values are valid for all docs in [doc, window_end].
// Different scorers have independently aligned blocks, so we must use the
// smallest window where all block_max values hold.
let mut window_end: DocId = leader.last_doc_in_block();
let mut secondary_block_max_sum: Score = 0.0;
let num_secondaries = secondaries.len();
for (idx, secondary) in secondaries.iter_mut().enumerate() {
secondary.block_cursor().seek_block(doc);
// A secondary with no remaining docs means the intersection has no further match.
if !secondary.block_cursor().has_remaining_docs() {
return;
}
window_end = window_end.min(secondary.last_doc_in_block());
let bms = secondary.block_max_score();
secondary_block_max_scores[idx] = bms;
secondary_block_max_sum += bms;
}
if leader_block_max + secondary_block_max_sum <= threshold {
// The entire window cannot beat the threshold. Skip past it.
doc = window_end + 1;
continue;
}
// --- Phase 2: Batch processing within the window ---
//
// Score-first approach: decode the leader's block, filter by threshold,
// then check intersection membership only for survivors. This avoids expensive
// secondary seeks for docs that can't beat the threshold.
let block_cursor = leader.block_cursor();
// seek loads the block and returns the in-block index of the first doc >= `doc`.
let start_idx = block_cursor.seek(doc);
// Use the branchless binary search on the doc decoder to find the first
// index past window_end.
let end_idx = block_cursor
.doc_decoder
.seek_within_block(window_end + 1)
.min(block_cursor.block_len());
let block_docs = &block_cursor.doc_decoder.output_array()[start_idx..end_idx];
let block_freqs = &block_cursor.freq_output_array()[start_idx..end_idx];
// Pass 1: Batch-compute leader BM25 scores and branchlessly filter
// candidates that can't beat the threshold.
//
// The trick: always write to the buffer at `num_candidates`, then
// conditionally advance the count. The compiler can turn this into
// a cmov instead of a branch, avoiding misprediction costs.
let score_threshold = threshold - secondary_block_max_sum;
let mut candidate_doc_ids = [0u32; COMPRESSION_BLOCK_SIZE];
let mut candidate_scores = [0.0f32; COMPRESSION_BLOCK_SIZE];
let mut num_candidates = 0usize;
for (candidate_doc, term_freq) in
block_docs.iter().copied().zip(block_freqs.iter().copied())
{
let fieldnorm_id = fieldnorm_reader.fieldnorm_id(candidate_doc);
let leader_score = bm25_weight.score(fieldnorm_id, term_freq);
candidate_doc_ids[num_candidates] = candidate_doc;
candidate_scores[num_candidates] = leader_score;
num_candidates += (leader_score > score_threshold) as usize;
}
// No leader doc in this window survived the filter: move on to the next window.
if num_candidates == 0 {
doc = window_end + 1;
continue;
}
// Precompute suffix sums: suffix[i] = sum of block_max for secondaries[i+1..].
// Used in Pass 2 to prune candidates that can't beat threshold even with
// remaining secondaries contributing their block_max.
let mut running = 0.0f32;
for idx in (0..num_secondaries).rev() {
secondary_suffix_block_max[idx] = running;
running += secondary_block_max_scores[idx];
}
// Pass 2: Check intersection membership only for survivors.
// score_threshold may be stale (threshold can increase from callbacks),
// but that's conservative — we may check a few extra candidates, never miss one.
'next_candidate: for candidate_idx in 0..num_candidates {
let candidate_doc = candidate_doc_ids[candidate_idx];
let mut total_score: Score = candidate_scores[candidate_idx];
for (secondary_idx, secondary) in secondaries.iter_mut().enumerate() {
// If a previous candidate already advanced this secondary past
// candidate_doc, the candidate can't be in the intersection.
if secondary.doc() > candidate_doc {
continue 'next_candidate;
}
let seek_result = secondary.seek(candidate_doc);
if seek_result != candidate_doc {
continue 'next_candidate;
}
total_score += secondary.score();
// Prune: even if all remaining secondaries score at their block max,
// can we still beat the threshold?
if total_score + secondary_suffix_block_max[secondary_idx] <= threshold {
continue 'next_candidate;
}
}
// All secondaries matched.
if total_score > threshold {
threshold = callback(candidate_doc, total_score);
// The raised threshold may now be out of reach for every remaining document.
if leader_max_score + secondaries_global_max_sum <= threshold {
return;
}
}
}
doc = window_end + 1;
}
}
#[cfg(test)]
mod tests {
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use proptest::prelude::*;
use crate::query::term_query::TermScorer;
use crate::query::{Bm25Weight, Scorer};
use crate::{DocId, DocSet, Score, TERMINATED};
/// Score wrapper with a total order that is the reverse of the natural
/// floating-point order: a larger score compares as `Less`.
///
/// Scores that cannot be compared (NaN) are treated as equal.
struct Float(Score);

impl Eq for Float {}

impl PartialEq for Float {
    fn eq(&self, other: &Self) -> bool {
        matches!(self.cmp(other), Ordering::Equal)
    }
}

impl PartialOrd for Float {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Ord for Float {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands are swapped on purpose to reverse the ordering.
        match other.0.partial_cmp(&self.0) {
            Some(ordering) => ordering,
            None => Ordering::Equal,
        }
    }
}
/// Relative comparison of two scores: returns `true` when their absolute
/// difference is strictly smaller than 0.0001 times the absolute value of
/// their sum.
fn nearly_equals(left: Score, right: Score) -> bool {
    let gap = (left - right).abs();
    let tolerance = 0.0001 * (left + right).abs();
    gap < tolerance
}
/// Run block_wand_intersection and collect (doc, score) pairs above threshold.
///
/// The callback keeps the `top_k` best scores seen so far in a heap and records
/// a `(doc, score)` checkpoint unless the score is nearly equal to the current
/// limit. The limit it returns is used by `block_wand_intersection` as the new
/// threshold.
fn compute_checkpoints_block_wand_intersection(
term_scorers: Vec<TermScorer>,
top_k: usize,
) -> Vec<(DocId, Score)> {
// `Float` reverses the score order, so the top of this heap is the lowest retained score.
let mut heap: BinaryHeap<Float> = BinaryHeap::with_capacity(top_k);
let mut checkpoints: Vec<(DocId, Score)> = Vec::new();
// Returned as-is by the callback until the heap holds `top_k` scores.
let mut limit: Score = 0.0;
let callback = &mut |doc, score| {
heap.push(Float(score));
// Keep at most `top_k` scores by evicting the lowest one.
if heap.len() > top_k {
heap.pop().unwrap();
}
// Once full, the lowest retained score is the bar a new document must beat.
if heap.len() == top_k {
limit = heap.peek().unwrap().0;
}
// Scores sitting right at the limit are not recorded.
if !nearly_equals(score, limit) {
checkpoints.push((doc, score));
}
limit
};
super::block_wand_intersection(term_scorers, Score::MIN, callback);
checkpoints
}
/// Naive baseline: intersect by iterating all docs.
///
/// Walks every doc of the cheapest scorer, seeks the other scorers to it, and
/// applies the same top-k heap / checkpoint bookkeeping as the optimized path
/// to documents whose summed score beats the current limit.
fn compute_checkpoints_naive_intersection(
mut term_scorers: Vec<TermScorer>,
top_k: usize,
) -> Vec<(DocId, Score)> {
// `Float` reverses the score order, so the top of this heap is the lowest retained score.
let mut heap: BinaryHeap<Float> = BinaryHeap::with_capacity(top_k);
let mut checkpoints: Vec<(DocId, Score)> = Vec::new();
let mut limit = Score::MIN;
// Sort by cost to use the cheapest as driver.
term_scorers.sort_by_key(|s| s.cost());
let (leader, secondaries) = term_scorers.split_first_mut().unwrap();
let mut doc = leader.doc();
while doc != TERMINATED {
let mut all_match = true;
for secondary in secondaries.iter_mut() {
let secondary_doc = secondary.doc();
// Only seek when the secondary is not already past `doc`.
let seek_result = if secondary_doc <= doc {
secondary.seek(doc)
} else {
secondary_doc
};
if seek_result != doc {
all_match = false;
break;
}
}
if all_match {
// Every scorer sits on `doc`: its score is the sum of all term scores.
let score: Score =
leader.score() + secondaries.iter_mut().map(|s| s.score()).sum::<Score>();
if score > limit {
heap.push(Float(score));
// Keep at most `top_k` scores by evicting the lowest one.
if heap.len() > top_k {
heap.pop().unwrap();
}
if heap.len() == top_k {
limit = heap.peek().unwrap().0;
}
// Scores sitting right at the limit are not recorded.
if !nearly_equals(score, limit) {
checkpoints.push((doc, score));
}
}
}
doc = leader.advance();
}
checkpoints
}
const MAX_TERM_FREQ: u32 = 100u32;
/// Proptest strategy producing a posting list as `(doc, term_freq)` pairs.
///
/// A document frequency is drawn from `1..=max_doc`; doc ids come from a sampled
/// bitset over `0..max_doc` and term frequencies from `1..MAX_TERM_FREQ`.
// NOTE(review): presumably `bitset::sampled` sets exactly `doc_freq` bits, so the
// zip below loses nothing — confirm against the proptest documentation.
fn posting_list(max_doc: u32) -> BoxedStrategy<Vec<(DocId, u32)>> {
(1..max_doc + 1)
.prop_flat_map(move |doc_freq| {
(
proptest::bits::bitset::sampled(doc_freq as usize, 0..max_doc as usize),
proptest::collection::vec(1u32..MAX_TERM_FREQ, doc_freq as usize),
)
})
.prop_map(|(docset, term_freqs)| {
// Pair each sampled doc id with one generated term frequency.
docset
.iter()
.map(|doc| doc as u32)
.zip(term_freqs.iter().cloned())
.collect::<Vec<_>>()
})
.boxed()
}
/// Proptest strategy producing the raw material for `num_scorers` term scorers.
///
/// Returns a pair: `num_scorers` posting lists over a shared `max_doc` drawn from
/// `1..100`, and `max_doc` fieldnorm values drawn from `2..10 * MAX_TERM_FREQ`.
#[expect(clippy::type_complexity)]
fn gen_term_scorers(num_scorers: usize) -> BoxedStrategy<(Vec<Vec<(DocId, u32)>>, Vec<u32>)> {
(1u32..100u32)
.prop_flat_map(move |max_doc: u32| {
(
proptest::collection::vec(posting_list(max_doc), num_scorers),
proptest::collection::vec(2u32..10u32 * MAX_TERM_FREQ, max_doc as usize),
)
})
.boxed()
}
/// Checks that `block_wand_intersection` and the naive intersection produce the
/// same checkpoints for the given posting lists and fieldnorms.
///
/// Each input document is expanded into `REPEAT` consecutive documents, and the
/// comparison is run for `top_k` in `1..4`.
fn test_block_wand_intersection_aux(posting_lists: &[Vec<(DocId, u32)>], fieldnorms: &[u32]) {
// Repeat docs 64 times to create multi-block scenarios, matching block_wand.rs test
// strategy.
const REPEAT: usize = 64;
let fieldnorms_expanded: Vec<u32> = fieldnorms
.iter()
.cloned()
.flat_map(|fieldnorm| std::iter::repeat_n(fieldnorm, REPEAT))
.collect();
let postings_lists_expanded: Vec<Vec<(DocId, u32)>> = posting_lists
.iter()
.map(|posting_list| {
posting_list
.iter()
.cloned()
.flat_map(|(doc, term_freq)| {
(0_u32..REPEAT as u32).map(move |offset| {
// Only the first copy keeps the generated term frequency;
// the other copies get a term frequency of 1.
(
doc * (REPEAT as u32) + offset,
if offset == 0 { term_freq } else { 1 },
)
})
})
.collect::<Vec<(DocId, u32)>>()
})
.collect();
let total_fieldnorms: u64 = fieldnorms_expanded
.iter()
.cloned()
.map(|fieldnorm| fieldnorm as u64)
.sum();
let average_fieldnorm = (total_fieldnorms as Score) / (fieldnorms_expanded.len() as Score);
let max_doc = fieldnorms_expanded.len();
// Builds a fresh set of scorers on every call: each run below consumes its scorers.
let make_scorers = || -> Vec<TermScorer> {
postings_lists_expanded
.iter()
.map(|postings| {
let bm25_weight = Bm25Weight::for_one_term(
postings.len() as u64,
max_doc as u64,
average_fieldnorm,
);
TermScorer::create_for_test(postings, &fieldnorms_expanded[..], bm25_weight)
})
.collect()
};
for top_k in 1..4 {
let checkpoints_optimized =
compute_checkpoints_block_wand_intersection(make_scorers(), top_k);
let checkpoints_naive = compute_checkpoints_naive_intersection(make_scorers(), top_k);
assert_eq!(
checkpoints_optimized.len(),
checkpoints_naive.len(),
"Mismatch in checkpoint count for top_k={top_k}"
);
// Same docs in the same order, with scores equal up to the relative tolerance.
for (&(left_doc, left_score), &(right_doc, right_score)) in
checkpoints_optimized.iter().zip(checkpoints_naive.iter())
{
assert_eq!(left_doc, right_doc);
assert!(
nearly_equals(left_score, right_score),
"Score mismatch for doc {left_doc}: {left_score} vs {right_score}"
);
}
}
}
// Property test (500 cases): with two generated posting lists, the optimized
// intersection must agree with the naive one.
proptest! {
#![proptest_config(ProptestConfig::with_cases(500))]
#[test]
fn test_block_wand_intersection_two_scorers(
(posting_lists, fieldnorms) in gen_term_scorers(2)
) {
test_block_wand_intersection_aux(&posting_lists[..], &fieldnorms[..]);
}
}
// Property test (500 cases): same check with three generated posting lists, so
// that more than one secondary scorer is involved.
proptest! {
#![proptest_config(ProptestConfig::with_cases(500))]
#[test]
fn test_block_wand_intersection_three_scorers(
(posting_lists, fieldnorms) in gen_term_scorers(3)
) {
test_block_wand_intersection_aux(&posting_lists[..], &fieldnorms[..]);
}
}
#[test]
fn test_block_wand_intersection_disjoint() {
    // Docs 0..100 carry the first term and docs 100..200 the second one, so no
    // document matches both and no checkpoint may be reported.
    let fieldnorms: Vec<u32> = vec![10; 200];
    let average_fieldnorm = 10.0;
    let first_half: Vec<(DocId, u32)> = (0..100).map(|doc| (doc, 1)).collect();
    let second_half: Vec<(DocId, u32)> = (100..200).map(|doc| (doc, 1)).collect();
    let build_scorer = |postings: &[(DocId, u32)]| {
        let weight = Bm25Weight::for_one_term(100, 200, average_fieldnorm);
        TermScorer::create_for_test(postings, &fieldnorms, weight)
    };
    let scorers = vec![build_scorer(&first_half), build_scorer(&second_half)];
    let checkpoints = compute_checkpoints_block_wand_intersection(scorers, 10);
    assert!(checkpoints.is_empty());
}
#[test]
fn test_block_wand_intersection_all_overlap() {
    // Two posting lists with full overlap: every document is in the intersection.
    let fieldnorms: Vec<u32> = vec![10; 50];
    let average_fieldnorm = 10.0;
    let postings: Vec<(DocId, u32)> = (0..50).map(|d| (d, 3)).collect();
    // Each run consumes its scorers, so build a fresh one on demand.
    let make_scorer = || {
        TermScorer::create_for_test(
            &postings,
            &fieldnorms,
            Bm25Weight::for_one_term(50, 50, average_fieldnorm),
        )
    };
    let checkpoints_opt =
        compute_checkpoints_block_wand_intersection(vec![make_scorer(), make_scorer()], 5);
    let checkpoints_naive =
        compute_checkpoints_naive_intersection(vec![make_scorer(), make_scorer()], 5);
    assert_eq!(checkpoints_opt.len(), checkpoints_naive.len());
    // Equal lengths alone would let wrong docs or scores slip through, so compare
    // the checkpoints pairwise as well.
    for (&(opt_doc, opt_score), &(naive_doc, naive_score)) in
        checkpoints_opt.iter().zip(checkpoints_naive.iter())
    {
        assert_eq!(opt_doc, naive_doc);
        assert!(
            nearly_equals(opt_score, naive_score),
            "Score mismatch for doc {opt_doc}: {opt_score} vs {naive_score}"
        );
    }
}
}

View File

@@ -16,7 +16,6 @@ use crate::{DocId, Score};
/// Scorer shapes that the boolean weight dispatches on.
enum SpecializedScorer {
/// A set of term scorers to be combined as a union.
TermUnion(Vec<TermScorer>),
/// A set of term scorers to be combined as an intersection.
TermIntersection(Vec<TermScorer>),
/// Any other scorer, kept behind a trait object.
Other(Box<dyn Scorer>),
}
@@ -50,9 +49,10 @@ where
TScoreCombiner: ScoreCombiner,
{
assert!(!scorers.is_empty());
if scorers.len() == 1 && !scorers[0].is::<TermScorer>() {
if scorers.len() == 1 {
return SpecializedScorer::Other(scorers.into_iter().next().unwrap()); //< we checked the size beforehand
}
{
let is_all_term_queries = scorers.iter().all(|scorer| scorer.is::<TermScorer>());
if is_all_term_queries {
@@ -66,9 +66,6 @@ where
{
// Block wand is only available if we read frequencies.
return SpecializedScorer::TermUnion(scorers);
} else if scorers.len() == 1 {
// Single TermScorer without freq reading — unwrap directly.
return SpecializedScorer::Other(Box::new(scorers.into_iter().next().unwrap()));
} else {
return SpecializedScorer::Other(Box::new(BufferedUnionScorer::build(
scorers,
@@ -96,13 +93,6 @@ fn into_box_scorer<TScoreCombiner: ScoreCombiner>(
BufferedUnionScorer::build(term_scorers, score_combiner_fn, num_docs);
Box::new(union_scorer)
}
SpecializedScorer::TermIntersection(term_scorers) => {
let boxed_scorers: Vec<Box<dyn Scorer>> = term_scorers
.into_iter()
.map(|s| Box::new(s) as Box<dyn Scorer>)
.collect();
intersect_scorers(boxed_scorers, num_docs)
}
SpecializedScorer::Other(scorer) => scorer,
}
}
@@ -307,43 +297,14 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
// Result depends entirely on MUST + any removed AllScorers.
let combined_all_scorer_count = must_special_scorer_counts.num_all_scorers
+ should_special_scorer_counts.num_all_scorers;
// Try to detect a pure TermScorer intersection for block-max optimization.
// Preconditions: no removed AllScorers, at least 2 scorers, all TermScorer
// with frequency reading enabled.
if combined_all_scorer_count == 0
&& must_scorers.len() >= 2
&& must_scorers.iter().all(|s| s.is::<TermScorer>())
{
let term_scorers: Vec<TermScorer> = must_scorers
.into_iter()
.map(|s| *(s.downcast::<TermScorer>().map_err(|_| ()).unwrap()))
.collect();
if term_scorers
.iter()
.all(|s| s.freq_reading_option() == FreqReadingOption::ReadFreq)
{
SpecializedScorer::TermIntersection(term_scorers)
} else {
let must_scorers: Vec<Box<dyn Scorer>> = term_scorers
.into_iter()
.map(|s| Box::new(s) as Box<dyn Scorer>)
.collect();
let boxed_scorer: Box<dyn Scorer> =
effective_must_scorer(must_scorers, 0, reader.max_doc(), num_docs)
.unwrap_or_else(|| Box::new(EmptyScorer));
SpecializedScorer::Other(boxed_scorer)
}
} else {
let boxed_scorer: Box<dyn Scorer> = effective_must_scorer(
must_scorers,
combined_all_scorer_count,
reader.max_doc(),
num_docs,
)
.unwrap_or_else(|| Box::new(EmptyScorer));
SpecializedScorer::Other(boxed_scorer)
}
let boxed_scorer: Box<dyn Scorer> = effective_must_scorer(
must_scorers,
combined_all_scorer_count,
reader.max_doc(),
num_docs,
)
.unwrap_or_else(|| Box::new(EmptyScorer));
SpecializedScorer::Other(boxed_scorer)
}
(ShouldScorersCombinationMethod::Optional(should_scorer), must_scorers) => {
// Optional SHOULD: contributes to scoring but not required for matching.
@@ -502,21 +463,15 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
callback: &mut dyn FnMut(DocId, Score),
) -> crate::Result<()> {
let scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn)?;
let num_docs = reader.num_docs();
match scorer {
SpecializedScorer::TermUnion(term_scorers) => {
let mut union_scorer =
BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn, num_docs);
let mut union_scorer = BufferedUnionScorer::build(
term_scorers,
&self.score_combiner_fn,
reader.num_docs(),
);
for_each_scorer(&mut union_scorer, callback);
}
SpecializedScorer::TermIntersection(term_scorers) => {
let boxed_scorers: Vec<Box<dyn Scorer>> = term_scorers
.into_iter()
.map(|term_scorer| Box::new(term_scorer) as Box<dyn Scorer>)
.collect();
let mut intersection = intersect_scorers(boxed_scorers, num_docs);
for_each_scorer(intersection.as_mut(), callback);
}
SpecializedScorer::Other(mut scorer) => {
for_each_scorer(scorer.as_mut(), callback);
}
@@ -530,23 +485,17 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
callback: &mut dyn FnMut(&[DocId]),
) -> crate::Result<()> {
let scorer = self.complex_scorer(reader, 1.0, || DoNothingCombiner)?;
let num_docs = reader.num_docs();
let mut buffer = [0u32; COLLECT_BLOCK_BUFFER_LEN];
match scorer {
SpecializedScorer::TermUnion(term_scorers) => {
let mut union_scorer =
BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn, num_docs);
let mut union_scorer = BufferedUnionScorer::build(
term_scorers,
&self.score_combiner_fn,
reader.num_docs(),
);
for_each_docset_buffered(&mut union_scorer, &mut buffer, callback);
}
SpecializedScorer::TermIntersection(term_scorers) => {
let boxed_scorers: Vec<Box<dyn Scorer>> = term_scorers
.into_iter()
.map(|term_scorer| Box::new(term_scorer) as Box<dyn Scorer>)
.collect();
let mut intersection = intersect_scorers(boxed_scorers, num_docs);
for_each_docset_buffered(intersection.as_mut(), &mut buffer, callback);
}
SpecializedScorer::Other(mut scorer) => {
for_each_docset_buffered(scorer.as_mut(), &mut buffer, callback);
}
@@ -575,9 +524,6 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
SpecializedScorer::TermUnion(term_scorers) => {
super::block_wand(term_scorers, threshold, callback);
}
SpecializedScorer::TermIntersection(term_scorers) => {
super::block_wand_intersection(term_scorers, threshold, callback);
}
SpecializedScorer::Other(mut scorer) => {
for_each_pruning_scorer(scorer.as_mut(), threshold, callback);
}

View File

@@ -1,10 +1,8 @@
mod block_wand_intersection;
mod block_wand_union;
mod block_wand;
mod boolean_query;
mod boolean_weight;
pub(crate) use self::block_wand_intersection::block_wand_intersection;
pub(crate) use self::block_wand_union::{block_wand, block_wand_single_scorer};
pub(crate) use self::block_wand::{block_wand, block_wand_single_scorer};
pub use self::boolean_query::BooleanQuery;
pub use self::boolean_weight::BooleanWeight;

View File

@@ -1,7 +1,5 @@
use common::TinySet;
use super::size_hint::estimate_intersection;
use crate::docset::{DocSet, SeekDangerResult, BLOCK_NUM_TINYBITSETS, TERMINATED};
use crate::docset::{DocSet, SeekDangerResult, TERMINATED};
use crate::query::term_query::TermScorer;
use crate::query::{EmptyScorer, Scorer};
use crate::{DocId, Score};
@@ -19,7 +17,7 @@ use crate::{DocId, Score};
/// `size_hint` of the intersection.
pub fn intersect_scorers(
mut scorers: Vec<Box<dyn Scorer>>,
segment_num_docs: u32,
num_docs_segment: u32,
) -> Box<dyn Scorer> {
if scorers.is_empty() {
return Box::new(EmptyScorer);
@@ -44,14 +42,14 @@ pub fn intersect_scorers(
left: *(left.downcast::<TermScorer>().map_err(|_| ()).unwrap()),
right: *(right.downcast::<TermScorer>().map_err(|_| ()).unwrap()),
others: scorers,
segment_num_docs,
num_docs: num_docs_segment,
});
}
Box::new(Intersection {
left,
right,
others: scorers,
segment_num_docs,
num_docs: num_docs_segment,
})
}
@@ -60,7 +58,7 @@ pub struct Intersection<TDocSet: DocSet, TOtherDocSet: DocSet = Box<dyn Scorer>>
left: TDocSet,
right: TDocSet,
others: Vec<TOtherDocSet>,
segment_num_docs: u32,
num_docs: u32,
}
fn go_to_first_doc<TDocSet: DocSet>(docsets: &mut [TDocSet]) -> DocId {
@@ -80,10 +78,7 @@ fn go_to_first_doc<TDocSet: DocSet>(docsets: &mut [TDocSet]) -> DocId {
impl<TDocSet: DocSet> Intersection<TDocSet, TDocSet> {
/// num_docs is the number of documents in the segment.
pub(crate) fn new(
mut docsets: Vec<TDocSet>,
segment_num_docs: u32,
) -> Intersection<TDocSet, TDocSet> {
pub(crate) fn new(mut docsets: Vec<TDocSet>, num_docs: u32) -> Intersection<TDocSet, TDocSet> {
let num_docsets = docsets.len();
assert!(num_docsets >= 2);
docsets.sort_by_key(|docset| docset.cost());
@@ -102,7 +97,7 @@ impl<TDocSet: DocSet> Intersection<TDocSet, TDocSet> {
left,
right,
others: docsets,
segment_num_docs,
num_docs,
}
}
}
@@ -219,7 +214,7 @@ impl<TDocSet: DocSet, TOtherDocSet: DocSet> DocSet for Intersection<TDocSet, TOt
[self.left.size_hint(), self.right.size_hint()]
.into_iter()
.chain(self.others.iter().map(DocSet::size_hint)),
self.segment_num_docs,
self.num_docs,
)
}
@@ -229,91 +224,6 @@ impl<TDocSet: DocSet, TOtherDocSet: DocSet> DocSet for Intersection<TDocSet, TOt
// If there are docsets that are bad at skipping, they should also influence the cost.
self.left.cost()
}
/// Counts the documents of the intersection, picking one of two strategies
/// from the size hint of the lead docset relative to the segment size.
fn count_including_deleted(&mut self) -> u32 {
    const DENSITY_THRESHOLD_INVERSE: u32 = 32;
    let scaled_lead_size = self
        .left
        .size_hint()
        .saturating_mul(DENSITY_THRESHOLD_INVERSE);
    let lead_is_sparse = scaled_lead_size < self.segment_num_docs;
    if !lead_is_sparse {
        // Dense lead: push documents into a block bitset and count them with
        // popcount.
        return self.count_including_deleted_dense();
    }
    // Sparse lead (it covers less than ~3% of the docs): the block approach
    // would waste time on mostly-empty blocks.
    self.count_including_deleted_sparse()
}
}
const EMPTY_BLOCK: [TinySet; BLOCK_NUM_TINYBITSETS] = [TinySet::EMPTY; BLOCK_NUM_TINYBITSETS];
/// Intersects `mask` with `update` in place, one tiny set at a time.
///
/// Returns `true` if every tiny set of `mask` is empty afterwards.
#[inline]
fn and_blocks_and_return_is_empty(
    mask: &mut [TinySet; BLOCK_NUM_TINYBITSETS],
    update: &[TinySet; BLOCK_NUM_TINYBITSETS],
) -> bool {
    let mut any_non_empty = false;
    // No early exit: every slot of `mask` has to be updated.
    for (slot, incoming) in mask.iter_mut().zip(update.iter()) {
        let merged = slot.intersect(*incoming);
        any_non_empty |= !merged.is_empty();
        *slot = merged;
    }
    !any_non_empty
}
impl<TDocSet: DocSet, TOtherDocSet: DocSet> Intersection<TDocSet, TOtherDocSet> {
    /// Counts the documents of the intersection by advancing through them one
    /// at a time until `TERMINATED` is reached.
    fn count_including_deleted_sparse(&mut self) -> u32 {
        let mut count = 0u32;
        let mut doc = self.doc();
        while doc != TERMINATED {
            count += 1;
            doc = self.advance();
        }
        count
    }

    /// Dense block-wise bitmask intersection count.
    ///
    /// Fills a 1024-doc window from each iterator, ANDs the bitmasks together,
    /// and popcounts the result. `fill_bitset_block` handles seeking tails forward
    /// when they lag behind the current block.
    fn count_including_deleted_dense(&mut self) -> u32 {
        let mut count = 0u32;
        let mut next_base = self.left.doc();
        'blocks: while next_base < TERMINATED {
            let base = next_base;

            // Fill lead bitmask.
            let mut mask = EMPTY_BLOCK;
            next_base = next_base.max(self.left.fill_bitset_block(base, &mut mask));

            let mut tail_mask = EMPTY_BLOCK;
            next_base = next_base.max(self.right.fill_bitset_block(base, &mut tail_mask));
            if and_blocks_and_return_is_empty(&mut mask, &tail_mask) {
                // Nothing left in this window: go straight to the next one.
                continue;
            }

            // AND with each additional tail.
            for other in &mut self.others {
                let mut other_mask = EMPTY_BLOCK;
                next_base = next_base.max(other.fill_bitset_block(base, &mut other_mask));
                if and_blocks_and_return_is_empty(&mut mask, &other_mask) {
                    // A bare `continue` here would only move on to the next
                    // `other` and keep filling bitmasks for an already-empty
                    // window. The label skips to the next window, as the
                    // lead/right check above does.
                    continue 'blocks;
                }
            }

            for tinyset in &mask {
                count += tinyset.len();
            }
        }
        count
    }
}
impl<TScorer, TOtherScorer> Scorer for Intersection<TScorer, TOtherScorer>
@@ -511,82 +421,6 @@ mod tests {
}
}
// Property test: `count_including_deleted` on a three-way intersection must
// equal the size of the set intersection of the three generated doc lists.
proptest! {
#[test]
fn prop_test_count_including_deleted_matches_default(
a in sorted_deduped_vec(1200, 400),
b in sorted_deduped_vec(1200, 400),
c in sorted_deduped_vec(1200, 400),
num_docs in 1200u32..2000u32,
) {
// Compute expected count via set intersection.
let expected: u32 = a.iter()
.filter(|doc| b.contains(doc) && c.contains(doc))
.count() as u32;
// Test count_including_deleted.
// NOTE(review): whether the sparse or the dense path runs depends on the
// generated sizes relative to `num_docs` — confirm both are exercised.
let make_intersection = || {
Intersection::new(
vec![
VecDocSet::from(a.clone()),
VecDocSet::from(b.clone()),
VecDocSet::from(c.clone()),
],
num_docs,
)
};
let mut intersection = make_intersection();
let count = intersection.count_including_deleted();
prop_assert_eq!(count, expected,
"count_including_deleted mismatch: a={:?}, b={:?}, c={:?}", a, b, c);
}
}
#[test]
fn test_count_including_deleted_two_way() {
    // Only docs 3 and 9 are present in both lists.
    let docsets = vec![
        VecDocSet::from(vec![1, 3, 9]),
        VecDocSet::from(vec![3, 4, 9, 18]),
    ];
    let mut both = Intersection::new(docsets, 100);
    let shared_docs = both.count_including_deleted();
    assert_eq!(shared_docs, 2);
}
#[test]
fn test_count_including_deleted_empty() {
    // No document is present in all three lists.
    let docsets = vec![
        VecDocSet::from(vec![1, 3]),
        VecDocSet::from(vec![1, 4]),
        VecDocSet::from(vec![3, 9]),
    ];
    let mut all_three = Intersection::new(docsets, 100);
    let shared_docs = all_three.count_including_deleted();
    assert_eq!(shared_docs, 0);
}
/// Exercises the dense path: the docsets hold many docs relative to the
/// segment size (at least `num_docs / 32`).
#[test]
fn test_count_including_deleted_dense_path() {
    let evens: Vec<u32> = (0..2000).step_by(2).collect();
    let multiples_of_three: Vec<u32> = (0..2000).step_by(3).collect();
    // An even number is in the intersection iff it is also a multiple of 3.
    let expected = evens.iter().filter(|doc| **doc % 3 == 0).count() as u32;
    let docsets = vec![
        VecDocSet::from(evens),
        VecDocSet::from(multiples_of_three),
    ];
    let mut intersection = Intersection::new(docsets, 2000);
    assert_eq!(intersection.count_including_deleted(), expected);
}
/// Covers a doc range larger than a single block (>1024 docs).
#[test]
fn test_count_including_deleted_multi_block() {
    let every_doc: Vec<u32> = (0..5000).collect();
    let every_seventh: Vec<u32> = (0..5000).step_by(7).collect();
    // `every_seventh` is a subset of `every_doc`, so it is the whole intersection.
    let expected = every_seventh.len() as u32;
    let docsets = vec![
        VecDocSet::from(every_doc),
        VecDocSet::from(every_seventh),
    ];
    let mut intersection = Intersection::new(docsets, 5000);
    assert_eq!(intersection.count_including_deleted(), expected);
}
#[test]
fn test_bug_2811_intersection_candidate_should_increase() {
let mut schema_builder = Schema::builder();

View File

@@ -1,6 +1,6 @@
use crate::docset::DocSet;
use crate::fieldnorm::FieldNormReader;
use crate::postings::{BlockSegmentPostings, FreqReadingOption, Postings, SegmentPostings};
use crate::postings::{FreqReadingOption, Postings, SegmentPostings};
use crate::query::bm25::Bm25Weight;
use crate::query::{Explanation, Scorer};
use crate::{DocId, Score};
@@ -95,21 +95,6 @@ impl TermScorer {
/// Returns the `last_doc_in_block` reported by the skip reader of the
/// postings' block cursor.
pub fn last_doc_in_block(&self) -> DocId {
self.postings.block_cursor.skip_reader().last_doc_in_block()
}
/// Returns a mutable reference to the underlying block cursor.
///
/// This is the `block_cursor` of the scorer's postings.
pub(crate) fn block_cursor(&mut self) -> &mut BlockSegmentPostings {
&mut self.postings.block_cursor
}
/// Returns a reference to the fieldnorm reader for batch lookups.
///
/// The reader is borrowed, not cloned.
pub(crate) fn fieldnorm_reader(&self) -> &FieldNormReader {
&self.fieldnorm_reader
}
/// Returns a reference to the BM25 weight for batch score computation.
///
/// This is the scorer's `similarity_weight` field.
pub(crate) fn bm25_weight(&self) -> &Bm25Weight {
&self.similarity_weight
}
}
impl DocSet for TermScorer {
@@ -132,12 +117,6 @@ impl DocSet for TermScorer {
// Forwards to the `size_hint` of the underlying postings.
fn size_hint(&self) -> u32 {
self.postings.size_hint()
}
// TODO
// It is probably possible to optimize fill_bitset_block for TermScorer,
// working directly with the blocks, enabling vectorization.
// I did not manage to get a performance improvement on Mac ARM,
// and do not have access to x86 to investigate.
}
impl Scorer for TermScorer {

View File

@@ -23,7 +23,7 @@ zstd-compression = ["zstd"]
[dev-dependencies]
proptest = "1"
criterion = { version = "0.8", default-features = false }
criterion = { version = "0.5", default-features = false }
names = "0.14"
rand = "0.9"

View File

@@ -14,8 +14,11 @@ use itertools::Itertools;
use tantivy_fst::Automaton;
use tantivy_fst::automaton::AlwaysMatch;
use crate::sstable_index_v3::SSTableIndexV3Empty;
use crate::streamer::{Streamer, StreamerBuilder};
use crate::{BlockAddr, DeltaReader, Reader, SSTable, SSTableIndex, TermOrdinal, VoidSSTable};
use crate::{
BlockAddr, DeltaReader, Reader, SSTable, SSTableIndex, SSTableIndexV3, TermOrdinal, VoidSSTable,
};
/// An SSTable is a sorted map that associates sorted `&[u8]` keys
/// to any kind of typed values.
@@ -285,7 +288,33 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
let (sstable_slice, index_slice) = main_slice.split(index_offset as usize);
let sstable_index_bytes = index_slice.read_bytes()?;
let sstable_index = SSTableIndex::open(version, index_offset, sstable_index_bytes)?;
let sstable_index = match version {
2 => SSTableIndex::V2(
crate::sstable_index_v2::SSTableIndex::load(sstable_index_bytes).map_err(|_| {
io::Error::new(io::ErrorKind::InvalidData, "SSTable corruption")
})?,
),
3 => {
let (sstable_index_bytes, mut footerv3_len_bytes) = sstable_index_bytes.rsplit(8);
let store_offset = u64::deserialize(&mut footerv3_len_bytes)?;
if store_offset != 0 {
SSTableIndex::V3(
SSTableIndexV3::load(sstable_index_bytes, store_offset).map_err(|_| {
io::Error::new(io::ErrorKind::InvalidData, "SSTable corruption")
})?,
)
} else {
// if store_offset is zero, there is no index, so we build a pseudo-index
// assuming a single block of sstable covering everything.
SSTableIndex::V3Empty(SSTableIndexV3Empty::load(index_offset as usize))
}
}
_ => {
return Err(io::Error::other(format!(
"Unsupported sstable version, expected one of [2, 3], found {version}"
)));
}
};
Ok(Dictionary {
sstable_slice,
@@ -483,28 +512,21 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
/// Returns the terms for a _sorted_ list of term ordinals.
///
/// Returns true if and only if all terms have been found.
pub fn sorted_ords_to_term_cb(
pub fn sorted_ords_to_term_cb<F: FnMut(&[u8]) -> io::Result<()>>(
&self,
ords: &[TermOrdinal],
mut cb: impl FnMut(&[u8]),
mut ords: impl Iterator<Item = TermOrdinal>,
mut cb: F,
) -> io::Result<bool> {
assert!(ords.is_sorted());
let mut ords = ords.iter().copied();
let Some(mut ord) = ords.next() else {
return Ok(true);
};
// Open the block for the first ordinal.
let mut bytes = Vec::new();
let (mut current_block_addr, block_id) = self.sstable_index.get_and_locate_with_ord(ord);
let mut current_block_addr = self.sstable_index.get_block_with_ord(ord);
let mut current_sstable_delta_reader =
self.sstable_delta_reader_block(current_block_addr.clone())?;
let mut current_block_ordinal = current_block_addr.first_ordinal;
let mut current_block_end_bound = self
.sstable_index
.get_block(block_id + 1)
.map(|block_addr| block_addr.first_ordinal)
.unwrap_or(u64::MAX);
loop {
// move to the ord inside the current block
@@ -516,38 +538,33 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
bytes.extend_from_slice(current_sstable_delta_reader.suffix());
current_block_ordinal += 1;
}
cb(&bytes);
cb(&bytes)?;
// fetch the next ordinal
let next_ord = loop {
let Some(next_ord) = ords.next() else {
return Ok(true);
};
if next_ord == ord {
// This is the same ordinal, let's just call the callback directly.
cb(&bytes);
} else {
// we checked it was sorted beforehand
debug_assert!(next_ord > ord);
break next_ord;
}
let Some(next_ord) = ords.next() else {
return Ok(true);
};
if next_ord >= current_block_end_bound {
let (new_block_addr, block_id) =
self.sstable_index.get_and_locate_with_ord(next_ord);
current_block_addr = new_block_addr;
current_block_ordinal = current_block_addr.first_ordinal;
current_sstable_delta_reader =
self.sstable_delta_reader_block(current_block_addr.clone())?;
bytes.clear();
current_block_end_bound = self
.sstable_index
.get_block(block_id + 1)
.map(|block_addr| block_addr.first_ordinal)
.unwrap_or(u64::MAX)
// advance forward if the new ord is different than the one we just processed
//
// this allows the input TermOrdinal iterator to contain duplicates, so long as it's
// still sorted
if next_ord < ord {
panic!("Ordinals were not sorted: received {next_ord} after {ord}");
} else if next_ord > ord {
// check if block changed for new term_ord
let new_block_addr = self.sstable_index.get_block_with_ord(next_ord);
if new_block_addr != current_block_addr {
current_block_addr = new_block_addr;
current_block_ordinal = current_block_addr.first_ordinal;
current_sstable_delta_reader =
self.sstable_delta_reader_block(current_block_addr.clone())?;
bytes.clear();
}
ord = next_ord;
} else {
// The next ord is equal to the previous ord: no need to seek or advance.
}
ord = next_ord;
}
}
@@ -654,8 +671,8 @@ mod tests {
use common::OwnedBytes;
use super::Dictionary;
use crate::MonotonicU64SSTable;
use crate::dictionary::TermOrdHit;
use crate::{MonotonicU64SSTable, TermOrdinal};
#[derive(Debug)]
struct PermissionedHandle {
@@ -918,24 +935,25 @@ mod tests {
}
#[test]
fn test_sorted_ords_to_term() {
fn test_ords_term() {
let (dic, _slice) = make_test_sstable();
// Single term
let mut terms = Vec::new();
assert!(
dic.sorted_ords_to_term_cb(&[100_000], |term| {
dic.sorted_ords_to_term_cb(100_000..100_001, |term| {
terms.push(term.to_vec());
Ok(())
})
.unwrap()
);
assert_eq!(terms, vec![format!("{:05X}", 100_000).into_bytes(),]);
// Single term
let mut terms = Vec::new();
let ords: Vec<TermOrdinal> = (100_001..100_002).collect();
assert!(
dic.sorted_ords_to_term_cb(&ords, |term| {
dic.sorted_ords_to_term_cb(100_001..100_002, |term| {
terms.push(term.to_vec());
Ok(())
})
.unwrap()
);
@@ -943,8 +961,9 @@ mod tests {
// both terms
let mut terms = Vec::new();
assert!(
dic.sorted_ords_to_term_cb(&[100_000, 100_001], |term| {
dic.sorted_ords_to_term_cb(100_000..100_002, |term| {
terms.push(term.to_vec());
Ok(())
})
.unwrap()
);
@@ -957,10 +976,10 @@ mod tests {
);
// Test cross block
let mut terms = Vec::new();
let ords: Vec<TermOrdinal> = (98653..=98655).collect();
assert!(
dic.sorted_ords_to_term_cb(&ords, |term| {
dic.sorted_ords_to_term_cb(98653..=98655, |term| {
terms.push(term.to_vec());
Ok(())
})
.unwrap()
);
@@ -972,43 +991,6 @@ mod tests {
format!("{:05X}", 98655).into_bytes(),
]
);
// redundant
let mut terms = Vec::new();
let ords: Vec<TermOrdinal> = vec![1, 1, 2];
assert!(
dic.sorted_ords_to_term_cb(&ords, |term| {
terms.push(term.to_vec());
})
.unwrap()
);
assert_eq!(
terms,
vec![
format!("{:05X}", 1).into_bytes(),
format!("{:05X}", 1).into_bytes(),
format!("{:05X}", 2).into_bytes(),
]
);
// redundant cross block
let mut terms = Vec::new();
let ords: Vec<TermOrdinal> = vec![98653, 98653, 98654, 98654, 98655, 98655];
assert!(
dic.sorted_ords_to_term_cb(&ords, |term| {
terms.push(term.to_vec());
})
.unwrap()
);
assert_eq!(
terms,
vec![
format!("{:05X}", 98_653).into_bytes(),
format!("{:05X}", 98_653).into_bytes(),
format!("{:05X}", 98_654).into_bytes(),
format!("{:05X}", 98_654).into_bytes(),
format!("{:05X}", 98_655).into_bytes(),
format!("{:05X}", 98_655).into_bytes(),
]
);
}
#[test]

View File

@@ -1,319 +0,0 @@
pub(crate) mod v2;
pub(crate) mod v3;
use std::io::{self, Read, Write};
use std::ops::Range;
use common::{BinarySerializable, FixedSize, OwnedBytes};
use tantivy_fst::{Automaton, MapBuilder};
use crate::{TermOrdinal, common_prefix_len};
/// Block index of an sstable, in one of the supported layouts.
///
/// The accessors of this type dispatch to whichever variant is held.
#[derive(Debug, Clone)]
pub enum SSTableIndex {
/// Index loaded when the sstable version is 2.
V2(v2::SSTableIndex),
/// Index loaded when the sstable version is 3 and an index was stored.
V3(v3::SSTableIndexV3),
/// Pseudo-index used for version 3 when no index was stored: it assumes a
/// single block covering the whole sstable.
V3Empty(v3::SSTableIndexV3Empty),
}
impl SSTableIndex {
pub(crate) fn open(
version: u32,
index_offset: u64,
index_bytes: OwnedBytes,
) -> io::Result<Self> {
let index = match version {
2 => {
SSTableIndex::V2(v2::SSTableIndex::load(index_bytes).map_err(|_| {
io::Error::new(io::ErrorKind::InvalidData, "SSTable corruption")
})?)
}
3 => {
let (index_bytes, mut footerv3_len_bytes) = index_bytes.rsplit(8);
let store_offset = u64::deserialize(&mut footerv3_len_bytes)?;
if store_offset != 0 {
SSTableIndex::V3(v3::SSTableIndexV3::load(index_bytes, store_offset).map_err(
|_| io::Error::new(io::ErrorKind::InvalidData, "SSTable corruption"),
)?)
} else {
// if store_offset is zero, there is no index, so we build a pseudo-index
// assuming a single block of sstable covering everything.
SSTableIndex::V3Empty(v3::SSTableIndexV3Empty::load(index_offset as usize))
}
}
_ => {
return Err(io::Error::other(format!(
"Unsupported sstable version, expected one of [2, 3], found {version}"
)));
}
};
Ok(index)
}
/// Get the [`BlockAddr`] of the requested block.
pub(crate) fn get_block(&self, block_id: u64) -> Option<BlockAddr> {
match self {
SSTableIndex::V2(v2_index) => v2_index.get_block(block_id as usize),
SSTableIndex::V3(v3_index) => v3_index.get_block(block_id),
SSTableIndex::V3Empty(v3_empty) => v3_empty.get_block(block_id),
}
}
/// Get the block id of the block that would contain `key`.
///
/// Returns None if `key` is lexicographically after the last key recorded.
pub(crate) fn locate_with_key(&self, key: &[u8]) -> Option<u64> {
match self {
SSTableIndex::V2(v2_index) => v2_index.locate_with_key(key).map(|i| i as u64),
SSTableIndex::V3(v3_index) => v3_index.locate_with_key(key),
SSTableIndex::V3Empty(v3_empty) => v3_empty.locate_with_key(key),
}
}
/// Get the [`BlockAddr`] of the block that would contain `key`.
///
/// Returns None if `key` is lexicographically after the last key recorded.
pub fn get_block_with_key(&self, key: &[u8]) -> Option<BlockAddr> {
match self {
SSTableIndex::V2(v2_index) => v2_index.get_block_with_key(key),
SSTableIndex::V3(v3_index) => v3_index.get_block_with_key(key),
SSTableIndex::V3Empty(v3_empty) => v3_empty.get_block_with_key(key),
}
}
pub(crate) fn locate_with_ord(&self, ord: TermOrdinal) -> u64 {
match self {
SSTableIndex::V2(v2_index) => v2_index.locate_with_ord(ord) as u64,
SSTableIndex::V3(v3_index) => v3_index.locate_with_ord(ord),
SSTableIndex::V3Empty(v3_empty) => v3_empty.locate_with_ord(ord),
}
}
/// Get the [`BlockAddr`] of the block containing the `ord`-th term.
pub(crate) fn get_block_with_ord(&self, ord: TermOrdinal) -> BlockAddr {
match self {
SSTableIndex::V2(v2_index) => v2_index.get_block_with_ord(ord),
SSTableIndex::V3(v3_index) => v3_index.get_block_with_ord(ord),
SSTableIndex::V3Empty(v3_empty) => v3_empty.get_block_with_ord(ord),
}
}
pub(crate) fn get_and_locate_with_ord(&self, ord: TermOrdinal) -> (BlockAddr, u64) {
match self {
SSTableIndex::V2(v2_index) => v2_index.get_and_locate_with_ord(ord),
SSTableIndex::V3(v3_index) => v3_index.get_and_locate_with_ord(ord),
SSTableIndex::V3Empty(v3_empty) => v3_empty.get_and_locate_with_ord(ord),
}
}
pub fn get_block_for_automaton<'a>(
&'a self,
automaton: &'a impl Automaton,
) -> impl Iterator<Item = (u64, BlockAddr)> + 'a {
match self {
SSTableIndex::V2(v2_index) => {
BlockIter::V2(v2_index.get_block_for_automaton(automaton))
}
SSTableIndex::V3(v3_index) => {
BlockIter::V3(v3_index.get_block_for_automaton(automaton))
}
SSTableIndex::V3Empty(v3_empty) => {
BlockIter::V3Empty(std::iter::once((0, v3_empty.block_addr.clone())))
}
}
}
}
/// Iterator returned by `SSTableIndex::get_block_for_automaton`.
///
/// It has one variant per `SSTableIndex` variant, so that a single concrete
/// iterator type can be returned whichever index is held.
enum BlockIter<V2, V3, T> {
/// Wraps the iterator produced by the V2 index.
V2(V2),
/// Wraps the iterator produced by the V3 index.
V3(V3),
/// Yields the single item built for the empty V3 index.
V3Empty(std::iter::Once<T>),
}
impl<V2: Iterator<Item = T>, V3: Iterator<Item = T>, T> Iterator for BlockIter<V2, V3, T> {
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
match self {
BlockIter::V2(v2) => v2.next(),
BlockIter::V3(v3) => v3.next(),
BlockIter::V3Empty(once) => once.next(),
}
}
}
/// Address of one block: the ordinal it starts at and the bytes it occupies.
#[derive(Clone, Eq, PartialEq, Debug)]
pub struct BlockAddr {
/// Ordinal recorded for the start of the block (presumably that of its first term).
pub first_ordinal: u64,
/// Byte range of the block.
/// NOTE(review): presumably relative to the sstable data slice; confirm.
pub byte_range: Range<usize>,
}
impl BlockAddr {
    /// Returns the start-only form of this address: the same first ordinal
    /// and the start of the byte range; the end of the range is dropped.
    fn to_block_start(&self) -> BlockStartAddr {
        let Self {
            first_ordinal,
            byte_range,
        } = self;
        BlockStartAddr {
            first_ordinal: *first_ordinal,
            byte_range_start: byte_range.start,
        }
    }
}
/// Start of a block: its first ordinal and the byte offset where it begins.
///
/// Unlike `BlockAddr` it does not store the end of the byte range;
/// `to_block_addr` takes the end as an argument.
#[derive(Debug, Clone, PartialEq, Eq)]
struct BlockStartAddr {
/// Ordinal recorded for the start of the block.
first_ordinal: u64,
/// Offset of the first byte of the block.
byte_range_start: usize,
}
impl BlockStartAddr {
    /// Builds the full [`BlockAddr`] by closing the byte range at
    /// `byte_range_end` (exclusive).
    fn to_block_addr(&self, byte_range_end: usize) -> BlockAddr {
        let byte_range = self.byte_range_start..byte_range_end;
        BlockAddr {
            first_ordinal: self.first_ordinal,
            byte_range,
        }
    }
}
/// Per-block record kept by `SSTableIndexBuilder`: a separator key plus the
/// address of the block.
#[derive(Debug, Clone)]
pub(crate) struct BlockMeta {
/// Any byte string that is lexicographically greater or equal to
/// the last key in the block,
/// and yet strictly smaller than the first key in the next block.
pub last_key_or_greater: Vec<u8>,
/// Address (first ordinal and byte range) of the block.
pub block_addr: BlockAddr,
}
impl BinarySerializable for BlockStartAddr {
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
let start = self.byte_range_start as u64;
start.serialize(writer)?;
self.first_ordinal.serialize(writer)
}
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
let byte_range_start = u64::deserialize(reader)? as usize;
let first_ordinal = u64::deserialize(reader)?;
Ok(BlockStartAddr {
first_ordinal,
byte_range_start,
})
}
// Provided method
fn num_bytes(&self) -> u64 {
BlockStartAddr::SIZE_IN_BYTES as u64
}
}
impl FixedSize for BlockStartAddr {
// `serialize` writes two `u64` values: the byte offset and the first ordinal.
const SIZE_IN_BYTES: usize = 2 * u64::SIZE_IN_BYTES;
}
/// Given that `left < right`, mutates `left` in place into a byte string
/// `left'` that is no longer than `left` and satisfies
/// `left <= left' < right`.
///
/// `left` is kept unchanged when it is a prefix of `right`, or when every
/// byte after the first differing one is `u8::MAX`.
///
/// # Panics
///
/// Panics if `left >= right`.
fn find_shorter_str_in_between(left: &mut Vec<u8>, right: &[u8]) {
    assert!(&left[..] < right);
    let common_len = common_prefix_len(left, right);
    if left.len() == common_len {
        return;
    }
    // It is possible to do one character shorter in some case,
    // but it is not worth the extra complexity
    //
    // The byte at `common_len` is already smaller than its counterpart in
    // `right`, so bumping any later byte and cutting right after it keeps
    // the result below `right`.
    let bump_pos = left
        .iter()
        .enumerate()
        .skip(common_len + 1)
        .find(|&(_, &byte)| byte != u8::MAX)
        .map(|(pos, _)| pos);
    if let Some(pos) = bump_pos {
        left[pos] += 1;
        left.truncate(pos + 1);
    }
}
/// Accumulates one `BlockMeta` per block added, and serializes them as an index.
#[derive(Default)]
pub struct SSTableIndexBuilder {
// Blocks in the order they were added through `add_block`.
blocks: Vec<BlockMeta>,
}
impl SSTableIndexBuilder {
    /// In order to make the index as light as possible, we
    /// try to find a shorter alternative to the last key of the last block
    /// that is still smaller than the next key.
    ///
    /// Does nothing when no block has been added yet.
    pub(crate) fn shorten_last_block_key_given_next_key(&mut self, next_key: &[u8]) {
        let Some(previous_block) = self.blocks.last_mut() else {
            return;
        };
        find_shorter_str_in_between(&mut previous_block.last_key_or_greater, next_key);
    }

    /// Records a new block with its last key, its byte range and the ordinal
    /// it starts at.
    pub fn add_block(&mut self, last_key: &[u8], byte_range: Range<usize>, first_ordinal: u64) {
        let last_key_or_greater = last_key.to_vec();
        let block_addr = BlockAddr {
            byte_range,
            first_ordinal,
        };
        self.blocks.push(BlockMeta {
            last_key_or_greater,
            block_addr,
        });
    }

    /// Serializes the index into `wrt`.
    ///
    /// With zero or one block nothing is written and `0` is returned.
    /// Otherwise an fst map from each block's `last_key_or_greater` to its
    /// position is written, followed by the block address store. The
    /// returned value is the number of bytes written for the fst map alone.
    pub fn serialize<W: std::io::Write>(&self, wrt: W) -> io::Result<u64> {
        if self.blocks.len() <= 1 {
            return Ok(0);
        }

        let mut map_builder =
            MapBuilder::new(common::CountingWriter::wrap(wrt)).map_err(fst_error_to_io_error)?;
        for (block_id, block_meta) in (0u64..).zip(&self.blocks) {
            map_builder
                .insert(&block_meta.last_key_or_greater, block_id)
                .map_err(fst_error_to_io_error)?;
        }
        let counting_writer = map_builder.into_inner().map_err(fst_error_to_io_error)?;
        // Measured before the address store is appended.
        let fst_len = counting_writer.written_bytes();
        let mut wrt = counting_writer.finish();

        let mut block_store_writer = v3::BlockAddrStoreWriter::new();
        for block_meta in &self.blocks {
            block_store_writer.write_block_meta(block_meta.block_addr.clone())?;
        }
        block_store_writer.serialize(&mut wrt)?;

        Ok(fst_len)
    }
}
fn fst_error_to_io_error(error: tantivy_fst::Error) -> io::Error {
match error {
tantivy_fst::Error::Fst(fst_error) => io::Error::other(fst_error),
tantivy_fst::Error::Io(ioerror) => ioerror,
}
}
#[cfg(test)]
mod tests {
    use proptest::prelude::*;

    /// Runs `find_shorter_str_in_between` on a copy of `left` and checks the
    /// result is no longer than `left` and lies in `[left, right)`.
    #[track_caller]
    fn test_find_shorter_str_in_between_aux(left: &[u8], right: &[u8]) {
        let mut shortened = left.to_vec();
        super::find_shorter_str_in_between(&mut shortened, right);
        assert!(shortened.len() <= left.len());
        assert!(left <= &shortened[..]);
        assert!(&shortened[..] < right);
    }

    #[test]
    fn test_find_shorter_str_in_between() {
        // Empty left, left as a prefix of right, and a shared prefix.
        test_find_shorter_str_in_between_aux(b"", b"hello");
        test_find_shorter_str_in_between_aux(b"abc", b"abcd");
        test_find_shorter_str_in_between_aux(b"abcd", b"abd");
        // Raw byte inputs, including a run of `u8::MAX` bytes.
        test_find_shorter_str_in_between_aux(&[0, 0, 0], &[1]);
        test_find_shorter_str_in_between_aux(&[0, 0, 0], &[0, 0, 1]);
        test_find_shorter_str_in_between_aux(&[0, 0, 255, 255, 255, 0u8], &[0, 1]);
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]
        #[test]
        fn test_proptest_find_shorter_str(left in any::<Vec<u8>>(), right in any::<Vec<u8>>()) {
            // Only pairs that satisfy the `left < right` precondition are checked.
            if left < right {
                test_find_shorter_str_in_between_aux(&left, &right);
            }
        }
    }
}

View File

@@ -47,8 +47,9 @@ pub mod merge;
mod streamer;
pub mod value;
mod index;
pub use index::{BlockAddr, SSTableIndex, SSTableIndexBuilder};
mod sstable_index_v3;
pub use sstable_index_v3::{BlockAddr, SSTableIndex, SSTableIndexBuilder, SSTableIndexV3};
mod sstable_index_v2;
pub(crate) mod vint;
pub use dictionary::{Dictionary, TermOrdHit};
pub use streamer::{Streamer, StreamerBuilder};

View File

@@ -77,13 +77,6 @@ impl SSTableIndex {
self.get_block(self.locate_with_ord(ord)).unwrap()
}
pub(crate) fn get_and_locate_with_ord(&self, ord: TermOrdinal) -> (BlockAddr, u64) {
let location = self.locate_with_ord(ord);
// locate_with_ord always returns an index within range
let block_addr = self.get_block(location).unwrap();
(block_addr, location as u64)
}
pub(crate) fn get_block_for_automaton<'a>(
&'a self,
automaton: &'a impl Automaton,

View File

@@ -1,14 +1,106 @@
use std::io::{self, Read, Write};
use std::ops::Range;
use std::sync::Arc;
use common::{BinarySerializable, FixedSize, OwnedBytes};
use tantivy_bitpacker::{BitPacker, compute_num_bits};
use tantivy_fst::raw::Fst;
use tantivy_fst::{Automaton, IntoStreamer, Map, Streamer};
use tantivy_fst::{Automaton, IntoStreamer, Map, MapBuilder, Streamer};
use super::{BlockAddr, BlockStartAddr};
use crate::block_match_automaton::can_block_match_automaton;
use crate::{SSTableDataCorruption, TermOrdinal};
use crate::{SSTableDataCorruption, TermOrdinal, common_prefix_len};
/// Block index of an sstable, in one of the supported layouts.
///
/// The accessors of this type dispatch to whichever variant is held.
#[derive(Debug, Clone)]
pub enum SSTableIndex {
/// Index in the V2 layout.
V2(crate::sstable_index_v2::SSTableIndex),
/// Index in the V3 layout.
V3(SSTableIndexV3),
/// V3 stand-in that holds a single block address.
V3Empty(SSTableIndexV3Empty),
}
impl SSTableIndex {
    /// Get the [`BlockAddr`] of the requested block.
    pub(crate) fn get_block(&self, block_id: u64) -> Option<BlockAddr> {
        match self {
            Self::V2(index) => index.get_block(block_id as usize),
            Self::V3(index) => index.get_block(block_id),
            Self::V3Empty(index) => index.get_block(block_id),
        }
    }

    /// Get the block id of the block that would contain `key`.
    ///
    /// Returns None if `key` is lexicographically after the last key recorded.
    pub(crate) fn locate_with_key(&self, key: &[u8]) -> Option<u64> {
        match self {
            Self::V2(index) => index.locate_with_key(key).map(|block_id| block_id as u64),
            Self::V3(index) => index.locate_with_key(key),
            Self::V3Empty(index) => index.locate_with_key(key),
        }
    }

    /// Get the [`BlockAddr`] of the block that would contain `key`.
    ///
    /// Returns None if `key` is lexicographically after the last key recorded.
    pub fn get_block_with_key(&self, key: &[u8]) -> Option<BlockAddr> {
        match self {
            Self::V2(index) => index.get_block_with_key(key),
            Self::V3(index) => index.get_block_with_key(key),
            Self::V3Empty(index) => index.get_block_with_key(key),
        }
    }

    /// Get the block id that the held index reports for `ord`.
    pub(crate) fn locate_with_ord(&self, ord: TermOrdinal) -> u64 {
        match self {
            Self::V2(index) => index.locate_with_ord(ord) as u64,
            Self::V3(index) => index.locate_with_ord(ord),
            Self::V3Empty(index) => index.locate_with_ord(ord),
        }
    }

    /// Get the [`BlockAddr`] of the block containing the `ord`-th term.
    pub(crate) fn get_block_with_ord(&self, ord: TermOrdinal) -> BlockAddr {
        match self {
            Self::V2(index) => index.get_block_with_ord(ord),
            Self::V3(index) => index.get_block_with_ord(ord),
            Self::V3Empty(index) => index.get_block_with_ord(ord),
        }
    }

    /// Iterates over `(block id, BlockAddr)` pairs selected for `automaton`.
    ///
    /// The V2 and V3 variants delegate to their own `get_block_for_automaton`.
    /// The empty V3 variant always yields its only block, with id 0.
    pub fn get_block_for_automaton<'a>(
        &'a self,
        automaton: &'a impl Automaton,
    ) -> impl Iterator<Item = (u64, BlockAddr)> + 'a {
        match self {
            Self::V2(index) => BlockIter::V2(index.get_block_for_automaton(automaton)),
            Self::V3(index) => BlockIter::V3(index.get_block_for_automaton(automaton)),
            Self::V3Empty(index) => {
                let only_block = (0, index.block_addr.clone());
                BlockIter::V3Empty(std::iter::once(only_block))
            }
        }
    }
}
/// Iterator returned by `SSTableIndex::get_block_for_automaton`.
///
/// It has one variant per `SSTableIndex` variant, so that a single concrete
/// iterator type can be returned whichever index is held.
enum BlockIter<V2, V3, T> {
/// Wraps the iterator produced by the V2 index.
V2(V2),
/// Wraps the iterator produced by the V3 index.
V3(V3),
/// Yields the single item built for the empty V3 index.
V3Empty(std::iter::Once<T>),
}
impl<V2: Iterator<Item = T>, V3: Iterator<Item = T>, T> Iterator for BlockIter<V2, V3, T> {
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
match self {
BlockIter::V2(v2) => v2.next(),
BlockIter::V3(v3) => v3.next(),
BlockIter::V3Empty(once) => once.next(),
}
}
}
#[derive(Debug, Clone)]
pub struct SSTableIndexV3 {
@@ -68,11 +160,6 @@ impl SSTableIndexV3 {
self.block_addr_store.binary_search_ord(ord).1
}
pub(crate) fn get_and_locate_with_ord(&self, ord: TermOrdinal) -> (BlockAddr, u64) {
let (location, block_addr) = self.block_addr_store.binary_search_ord(ord);
(block_addr, location)
}
pub(crate) fn get_block_for_automaton<'a>(
&'a self,
automaton: &'a impl Automaton,
@@ -129,7 +216,7 @@ impl<A: Automaton> Iterator for GetBlockForAutomaton<'_, A> {
#[derive(Debug, Clone)]
pub struct SSTableIndexV3Empty {
pub block_addr: BlockAddr,
block_addr: BlockAddr,
}
impl SSTableIndexV3Empty {
@@ -143,8 +230,8 @@ impl SSTableIndexV3Empty {
}
/// Get the [`BlockAddr`] of the requested block.
pub(crate) fn get_block(&self, block_id: u64) -> Option<BlockAddr> {
(block_id == 0).then(|| self.block_addr.clone())
pub(crate) fn get_block(&self, _block_id: u64) -> Option<BlockAddr> {
Some(self.block_addr.clone())
}
/// Get the block id of the block that would contain `key`.
@@ -169,9 +256,146 @@ impl SSTableIndexV3Empty {
pub(crate) fn get_block_with_ord(&self, _ord: TermOrdinal) -> BlockAddr {
self.block_addr.clone()
}
}
/// Address of one block: the ordinal it starts at and the bytes it occupies.
#[derive(Clone, Eq, PartialEq, Debug)]
pub struct BlockAddr {
/// Ordinal recorded for the start of the block (presumably that of its first term).
pub first_ordinal: u64,
/// Byte range of the block.
/// NOTE(review): presumably relative to the sstable data slice; confirm.
pub byte_range: Range<usize>,
}
pub(crate) fn get_and_locate_with_ord(&self, _ord: TermOrdinal) -> (BlockAddr, u64) {
(self.block_addr.clone(), 0)
impl BlockAddr {
fn to_block_start(&self) -> BlockStartAddr {
BlockStartAddr {
first_ordinal: self.first_ordinal,
byte_range_start: self.byte_range.start,
}
}
}
/// Start of a block: its first ordinal and the byte offset where it begins.
///
/// Unlike `BlockAddr` it does not store the end of the byte range;
/// `to_block_addr` takes the end as an argument.
#[derive(Debug, Clone, PartialEq, Eq)]
struct BlockStartAddr {
/// Ordinal recorded for the start of the block.
first_ordinal: u64,
/// Offset of the first byte of the block.
byte_range_start: usize,
}
impl BlockStartAddr {
    /// Builds the full [`BlockAddr`] by closing the byte range at
    /// `byte_range_end` (exclusive).
    fn to_block_addr(&self, byte_range_end: usize) -> BlockAddr {
        let first_ordinal = self.first_ordinal;
        let byte_range = self.byte_range_start..byte_range_end;
        BlockAddr {
            first_ordinal,
            byte_range,
        }
    }
}
/// Per-block record kept by `SSTableIndexBuilder`: a separator key plus the
/// address of the block.
#[derive(Debug, Clone)]
pub(crate) struct BlockMeta {
/// Any byte string that is lexicographically greater or equal to
/// the last key in the block,
/// and yet strictly smaller than the first key in the next block.
pub last_key_or_greater: Vec<u8>,
/// Address (first ordinal and byte range) of the block.
pub block_addr: BlockAddr,
}
impl BinarySerializable for BlockStartAddr {
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
let start = self.byte_range_start as u64;
start.serialize(writer)?;
self.first_ordinal.serialize(writer)
}
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
let byte_range_start = u64::deserialize(reader)? as usize;
let first_ordinal = u64::deserialize(reader)?;
Ok(BlockStartAddr {
first_ordinal,
byte_range_start,
})
}
// Provided method
fn num_bytes(&self) -> u64 {
BlockStartAddr::SIZE_IN_BYTES as u64
}
}
impl FixedSize for BlockStartAddr {
// `serialize` writes two `u64` values: the byte offset and the first ordinal.
const SIZE_IN_BYTES: usize = 2 * u64::SIZE_IN_BYTES;
}
/// Given that `left < right`, mutates `left` in place into a byte string
/// `left'` that is no longer than `left` and satisfies
/// `left <= left' < right`.
///
/// `left` is kept unchanged when it is a prefix of `right`, or when every
/// byte after the first differing one is `u8::MAX`.
///
/// # Panics
///
/// Panics if `left >= right`.
fn find_shorter_str_in_between(left: &mut Vec<u8>, right: &[u8]) {
    assert!(&left[..] < right);
    let common_len = common_prefix_len(left, right);
    if left.len() == common_len {
        return;
    }
    // It is possible to do one character shorter in some case,
    // but it is not worth the extra complexity
    //
    // `common_len < left.len()` here, so the slice start is in bounds.
    let tail_start = common_len + 1;
    let Some(offset) = left[tail_start..]
        .iter()
        .position(|&byte| byte != u8::MAX)
    else {
        return;
    };
    let pos = tail_start + offset;
    left[pos] += 1;
    left.truncate(pos + 1);
}
/// Accumulates one `BlockMeta` per block added, and serializes them as an index.
#[derive(Default)]
pub struct SSTableIndexBuilder {
// Blocks in the order they were added through `add_block`.
blocks: Vec<BlockMeta>,
}
impl SSTableIndexBuilder {
    /// In order to make the index as light as possible, we
    /// try to find a shorter alternative to the last key of the last block
    /// that is still smaller than the next key.
    ///
    /// Does nothing when no block has been added yet.
    pub(crate) fn shorten_last_block_key_given_next_key(&mut self, next_key: &[u8]) {
        let Some(previous_block) = self.blocks.last_mut() else {
            return;
        };
        find_shorter_str_in_between(&mut previous_block.last_key_or_greater, next_key);
    }

    /// Records a new block with its last key, its byte range and the ordinal
    /// it starts at.
    pub fn add_block(&mut self, last_key: &[u8], byte_range: Range<usize>, first_ordinal: u64) {
        let last_key_or_greater = last_key.to_vec();
        let block_addr = BlockAddr {
            byte_range,
            first_ordinal,
        };
        self.blocks.push(BlockMeta {
            last_key_or_greater,
            block_addr,
        });
    }

    /// Serializes the index into `wrt`.
    ///
    /// With zero or one block nothing is written and `0` is returned.
    /// Otherwise an fst map from each block's `last_key_or_greater` to its
    /// position is written, followed by the block address store. The
    /// returned value is the number of bytes written for the fst map alone.
    pub fn serialize<W: std::io::Write>(&self, wrt: W) -> io::Result<u64> {
        if self.blocks.len() <= 1 {
            return Ok(0);
        }

        let mut map_builder =
            MapBuilder::new(common::CountingWriter::wrap(wrt)).map_err(fst_error_to_io_error)?;
        for (block_id, block_meta) in (0u64..).zip(&self.blocks) {
            map_builder
                .insert(&block_meta.last_key_or_greater, block_id)
                .map_err(fst_error_to_io_error)?;
        }
        let counting_writer = map_builder.into_inner().map_err(fst_error_to_io_error)?;
        // Measured before the address store is appended.
        let fst_len = counting_writer.written_bytes();
        let mut wrt = counting_writer.finish();

        let mut block_store_writer = BlockAddrStoreWriter::new();
        for block_meta in &self.blocks {
            block_store_writer.write_block_meta(block_meta.block_addr.clone())?;
        }
        block_store_writer.serialize(&mut wrt)?;

        Ok(fst_len)
    }
}
fn fst_error_to_io_error(error: tantivy_fst::Error) -> io::Error {
match error {
tantivy_fst::Error::Fst(fst_error) => io::Error::other(fst_error),
tantivy_fst::Error::Io(ioerror) => ioerror,
}
}
@@ -423,14 +647,14 @@ fn binary_search(max: u64, cmp_fn: impl Fn(u64) -> std::cmp::Ordering) -> Result
Err(left)
}
pub(crate) struct BlockAddrStoreWriter {
struct BlockAddrStoreWriter {
buffer_block_metas: Vec<u8>,
buffer_addrs: Vec<u8>,
block_addrs: Vec<BlockAddr>,
}
impl BlockAddrStoreWriter {
pub(crate) fn new() -> Self {
fn new() -> Self {
BlockAddrStoreWriter {
buffer_block_metas: Vec::new(),
buffer_addrs: Vec::new(),
@@ -438,7 +662,7 @@ impl BlockAddrStoreWriter {
}
}
pub(crate) fn flush_block(&mut self) -> io::Result<()> {
fn flush_block(&mut self) -> io::Result<()> {
if self.block_addrs.is_empty() {
return Ok(());
}
@@ -517,7 +741,7 @@ impl BlockAddrStoreWriter {
Ok(())
}
pub(crate) fn write_block_meta(&mut self, block_addr: BlockAddr) -> io::Result<()> {
fn write_block_meta(&mut self, block_addr: BlockAddr) -> io::Result<()> {
self.block_addrs.push(block_addr);
if self.block_addrs.len() >= STORE_BLOCK_LEN {
self.flush_block()?;
@@ -525,7 +749,7 @@ impl BlockAddrStoreWriter {
Ok(())
}
pub(crate) fn serialize<W: std::io::Write>(&mut self, wrt: &mut W) -> io::Result<()> {
fn serialize<W: std::io::Write>(&mut self, wrt: &mut W) -> io::Result<()> {
self.flush_block()?;
let len = self.buffer_block_metas.len() as u64;
len.serialize(wrt)?;
@@ -600,9 +824,8 @@ mod tests {
use common::OwnedBytes;
use super::*;
use crate::SSTableDataCorruption;
use crate::block_match_automaton::tests::EqBuffer;
use crate::index::BlockMeta;
use crate::{SSTableDataCorruption, SSTableIndexBuilder};
#[test]
fn test_sstable_index() {
@@ -651,7 +874,36 @@ mod tests {
assert!(matches!(data_corruption_err, SSTableDataCorruption));
}
// use proptest::prelude::*;
/// Runs `find_shorter_str_in_between` on a copy of `left` and checks the
/// result is no longer than `left` and lies in `[left, right)`.
#[track_caller]
fn test_find_shorter_str_in_between_aux(left: &[u8], right: &[u8]) {
    let mut shortened = left.to_vec();
    super::find_shorter_str_in_between(&mut shortened, right);
    assert!(shortened.len() <= left.len());
    assert!(left <= &shortened[..]);
    assert!(&shortened[..] < right);
}
/// Checks `find_shorter_str_in_between` on a fixed set of `(left, right)`
/// pairs, each satisfying `left < right`.
#[test]
fn test_find_shorter_str_in_between() {
    let cases: [(&[u8], &[u8]); 6] = [
        (b"", b"hello"),
        (b"abc", b"abcd"),
        (b"abcd", b"abd"),
        (&[0, 0, 0], &[1]),
        (&[0, 0, 0], &[0, 0, 1]),
        (&[0, 0, 255, 255, 255, 0u8], &[0, 1]),
    ];
    for &(left, right) in &cases {
        test_find_shorter_str_in_between_aux(left, right);
    }
}
use proptest::prelude::*;
// Property test: runs the same checks as the fixed cases on generated
// byte strings, 100 cases per run.
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
#[test]
fn test_proptest_find_shorter_str(left in any::<Vec<u8>>(), right in any::<Vec<u8>>()) {
// Only pairs that satisfy the `left < right` precondition are checked.
if left < right {
test_find_shorter_str_in_between_aux(&left, &right);
}
}
}
#[test]
fn test_find_best_slop() {

View File

@@ -27,7 +27,7 @@ rand = "0.9"
zipf = "7.0.0"
rustc-hash = "2.1.0"
proptest = "1.2.0"
binggan = { version = "0.17.0" }
binggan = { version = "0.15.3" }
rand_distr = "0.5"
[features]