mirror of
https://github.com/quickwit-oss/tantivy.git
synced 2026-05-31 23:50:41 +00:00
Compare commits
53 Commits
paul.masur
...
faster_uni
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
74f37f045d | ||
|
|
72cca113cd | ||
|
|
672bf45235 | ||
|
|
33ef167441 | ||
|
|
bf8b263f16 | ||
|
|
24a97dbe69 | ||
|
|
34fec8b23e | ||
|
|
46b3fb9ed3 | ||
|
|
fbe620b9b4 | ||
|
|
95d8a3989a | ||
|
|
ea61a68db4 | ||
|
|
c367df37c1 | ||
|
|
d99a5d4e91 | ||
|
|
2de6f075ce | ||
|
|
18080067c7 | ||
|
|
95db7d2e5c | ||
|
|
fc017c4c74 | ||
|
|
141c91d028 | ||
|
|
36a83e7c1a | ||
|
|
be11f8a6a1 | ||
|
|
4305e4029e | ||
|
|
edfb02b47e | ||
|
|
d0fad88bac | ||
|
|
351280c0b4 | ||
|
|
4480cf0a98 | ||
|
|
d47abdf104 | ||
|
|
c11952eb7c | ||
|
|
09667ee9c8 | ||
|
|
333ccf5300 | ||
|
|
60a39a4689 | ||
|
|
f8f3e4277f | ||
|
|
ff1433713a | ||
|
|
ca139d8eb1 | ||
|
|
ac508108aa | ||
|
|
63da5a21b2 | ||
|
|
54cd5bba98 | ||
|
|
d27ca164a9 | ||
|
|
2f5a48e8b1 | ||
|
|
ae0ab907fe | ||
|
|
7d62e084e7 | ||
|
|
322286ee16 | ||
|
|
73ad18fa1e | ||
|
|
4fbae92187 | ||
|
|
89f0cef807 | ||
|
|
a5d297c75f | ||
|
|
2e16243f9a | ||
|
|
e015abab8e | ||
|
|
73c711ec74 | ||
|
|
cb037c8079 | ||
|
|
ed3453606b | ||
|
|
e9641f99c5 | ||
|
|
3a6a3de8d7 | ||
|
|
af3c6c0070 |
15
.github/workflows/coverage.yml
vendored
15
.github/workflows/coverage.yml
vendored
@@ -4,6 +4,9 @@ on:
|
||||
push:
|
||||
branches: [main]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
# Ensures that we cancel running jobs for the same PR / same workflow.
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||
@@ -12,16 +15,20 @@ concurrency:
|
||||
jobs:
|
||||
coverage:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- name: Install Rust
|
||||
run: rustup toolchain install nightly-2025-12-01 --profile minimal --component llvm-tools-preview
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- uses: taiki-e/install-action@cargo-llvm-cov
|
||||
- uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1
|
||||
- uses: taiki-e/install-action@e4b3a0453201addddc06d3a72db90326aad87084 # cargo-llvm-cov
|
||||
- name: Generate code coverage
|
||||
run: cargo +nightly-2025-12-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v3
|
||||
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
|
||||
continue-on-error: true
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos
|
||||
|
||||
10
.github/workflows/long_running.yml
vendored
10
.github/workflows/long_running.yml
vendored
@@ -8,6 +8,9 @@ env:
|
||||
CARGO_TERM_COLOR: always
|
||||
NUM_FUNCTIONAL_TEST_ITERATIONS: 20000
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
# Ensures that we cancel running jobs for the same PR / same workflow.
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||
@@ -18,10 +21,13 @@ jobs:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- name: Install stable
|
||||
uses: actions-rs/toolchain@v1
|
||||
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af # v1.0.7
|
||||
with:
|
||||
toolchain: stable
|
||||
profile: minimal
|
||||
|
||||
49
.github/workflows/scorecard.yml
vendored
Normal file
49
.github/workflows/scorecard.yml
vendored
Normal file
@@ -0,0 +1,49 @@
|
||||
name: OpenSSF Scorecard
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 0 * * 0'
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
analysis:
|
||||
name: Scorecards analysis
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
# Needed to upload the results to code-scanning dashboard.
|
||||
security-events: write
|
||||
# Needed to publish results
|
||||
id-token: write
|
||||
|
||||
steps:
|
||||
- name: 'Checkout code'
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: 'Run analysis'
|
||||
uses: ossf/scorecard-action@4eaacf0543bb3f2c246792bd56e8cdeffafb205a # v2.4.3
|
||||
with:
|
||||
results_file: results.sarif
|
||||
results_format: sarif
|
||||
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
publish_results: true
|
||||
|
||||
# Upload the results as artifacts.
|
||||
- name: 'Upload artifact'
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
with:
|
||||
name: SARIF file
|
||||
path: results.sarif
|
||||
retention-days: 5
|
||||
|
||||
# Upload the results to GitHub's code scanning dashboard.
|
||||
- name: 'Upload to code-scanning'
|
||||
uses: github/codeql-action/upload-sarif@95e58e9a2cdfd71adc6e0353d5c52f41a045d225 # v4.35.2
|
||||
with:
|
||||
sarif_file: results.sarif
|
||||
28
.github/workflows/test.yml
vendored
28
.github/workflows/test.yml
vendored
@@ -9,6 +9,9 @@ on:
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
# Ensures that we cancel running jobs for the same PR / same workflow.
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||
@@ -19,23 +22,27 @@ jobs:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
checks: write
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install nightly
|
||||
uses: actions-rs/toolchain@v1
|
||||
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af # v1.0.7
|
||||
with:
|
||||
toolchain: nightly
|
||||
profile: minimal
|
||||
components: rustfmt
|
||||
- name: Install stable
|
||||
uses: actions-rs/toolchain@v1
|
||||
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af # v1.0.7
|
||||
with:
|
||||
toolchain: stable
|
||||
profile: minimal
|
||||
components: clippy
|
||||
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1
|
||||
|
||||
- name: Check Formatting
|
||||
run: cargo +nightly fmt --all -- --check
|
||||
@@ -47,7 +54,7 @@ jobs:
|
||||
- name: Check Bench Compilation
|
||||
run: cargo +nightly bench --no-run --profile=dev --all-features
|
||||
|
||||
- uses: actions-rs/clippy-check@v1
|
||||
- uses: actions-rs/clippy-check@b5b5f21f4797c02da247df37026fcd0a5024aa4d # v1.0.7
|
||||
with:
|
||||
toolchain: stable
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
@@ -57,6 +64,9 @@ jobs:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
features:
|
||||
@@ -67,17 +77,17 @@ jobs:
|
||||
name: test-${{ matrix.features.label}}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install stable
|
||||
uses: actions-rs/toolchain@v1
|
||||
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af # v1.0.7
|
||||
with:
|
||||
toolchain: stable
|
||||
profile: minimal
|
||||
override: true
|
||||
|
||||
- uses: taiki-e/install-action@nextest
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- uses: taiki-e/install-action@56cc9adf3a3e2c23eafb56e8acaf9d0373cb845a # nextest
|
||||
- uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
|
||||
@@ -1,3 +1,9 @@
|
||||
Tantivy 0.26.1
|
||||
================================
|
||||
|
||||
## Performance
|
||||
- Fix quadratic runtime in nested term and composite aggregations: memory accounting scanned all parent buckets on every collect instead of just the current parent (@PSeitz @fulmicoton)
|
||||
|
||||
Tantivy 0.26 (Unreleased)
|
||||
================================
|
||||
|
||||
|
||||
14
Cargo.toml
14
Cargo.toml
@@ -65,7 +65,7 @@ tantivy-bitpacker = { version = "0.10", path = "./bitpacker" }
|
||||
common = { version = "0.11", path = "./common/", package = "tantivy-common" }
|
||||
tokenizer-api = { version = "0.7", path = "./tokenizer-api", package = "tantivy-tokenizer-api" }
|
||||
sketches-ddsketch = { version = "0.4", features = ["use_serde"] }
|
||||
datasketches = { git = "https://github.com/fulmicoton-dd/datasketches-rust", rev = "7635fb8" }
|
||||
datasketches = { version = "0.3.0", features = ["hll"] }
|
||||
futures-util = { version = "0.3.28", optional = true }
|
||||
futures-channel = { version = "0.3.28", optional = true }
|
||||
fnv = "1.0.7"
|
||||
@@ -75,7 +75,7 @@ typetag = "0.2.21"
|
||||
winapi = "0.3.9"
|
||||
|
||||
[dev-dependencies]
|
||||
binggan = "0.16.1"
|
||||
binggan = "0.17.0"
|
||||
rand = "0.9"
|
||||
maplit = "1.0.2"
|
||||
matches = "0.1.9"
|
||||
@@ -92,7 +92,7 @@ postcard = { version = "1.0.4", features = [
|
||||
], default-features = false }
|
||||
|
||||
[target.'cfg(not(windows))'.dev-dependencies]
|
||||
criterion = { version = "0.5", default-features = false }
|
||||
criterion = { version = "0.8", default-features = false }
|
||||
|
||||
[dev-dependencies.fail]
|
||||
version = "0.5.0"
|
||||
@@ -201,3 +201,11 @@ harness = false
|
||||
[[bench]]
|
||||
name = "regex_all_terms"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "query_parser_nested"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "intersection_bench"
|
||||
harness = false
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
[](https://docs.rs/crate/tantivy/)
|
||||
[](https://github.com/quickwit-oss/tantivy/actions/workflows/test.yml)
|
||||
[](https://codecov.io/gh/quickwit-oss/tantivy)
|
||||
[](https://scorecard.dev/viewer/?uri=github.com/quickwit-oss/tantivy)
|
||||
[](https://discord.gg/MT27AG5EVE)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
[](https://crates.io/crates/tantivy)
|
||||
|
||||
@@ -63,6 +63,8 @@ fn bench_agg(mut group: InputGroup<Index>) {
|
||||
register!(group, terms_all_unique_with_avg_sub_agg);
|
||||
register!(group, terms_many_with_avg_sub_agg);
|
||||
register!(group, terms_status_with_avg_sub_agg);
|
||||
register!(group, terms_status_with_terms_zipf_1000_sub_agg);
|
||||
register!(group, terms_zipf_1000_with_terms_status_sub_agg);
|
||||
register!(group, terms_status_with_histogram);
|
||||
register!(group, terms_zipf_1000);
|
||||
register!(group, terms_zipf_1000_with_histogram);
|
||||
@@ -77,8 +79,12 @@ fn bench_agg(mut group: InputGroup<Index>) {
|
||||
register!(group, composite_histogram_calendar);
|
||||
|
||||
register!(group, cardinality_agg);
|
||||
register!(group, cardinality_agg_high_card);
|
||||
register!(group, cardinality_agg_low_card);
|
||||
register!(group, terms_status_with_cardinality_agg);
|
||||
register!(group, terms_100_buckets_with_cardinality_agg);
|
||||
register!(group, terms_many_with_single_term_order_by_card);
|
||||
register!(group, terms_many_with_single_term_2_order_by_card);
|
||||
|
||||
register!(group, range_agg);
|
||||
register!(group, range_agg_with_avg_sub_agg);
|
||||
@@ -166,6 +172,32 @@ fn cardinality_agg(index: &Index) {
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
// Full-scan cardinality on a near-1M-cardinality string field.
|
||||
// Hits the dense (PagedBitset) path: every doc has a unique term,
|
||||
// so the bucket promotes from FxHashSet shortly into the scan.
|
||||
fn cardinality_agg_high_card(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"cardinality": {
|
||||
"cardinality": {
|
||||
"field": "text_all_unique_terms"
|
||||
},
|
||||
}
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
// Full-scan cardinality on a tiny-cardinality string field (7 distinct
|
||||
// values). Stays on the FxHashSet path — the promotion threshold is
|
||||
// never crossed. Validates no regression on the sparse path.
|
||||
fn cardinality_agg_low_card(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"cardinality": {
|
||||
"cardinality": {
|
||||
"field": "text_few_terms_status"
|
||||
},
|
||||
}
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn terms_status_with_cardinality_agg(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
@@ -198,6 +230,58 @@ fn terms_100_buckets_with_cardinality_agg(index: &Index) {
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_many_with_single_term_order_by_card(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
"terms": { "field": "text_many_terms" },
|
||||
"aggs": {
|
||||
"nested_terms": {
|
||||
"terms": {
|
||||
"field": "single_term",
|
||||
"order": { "cardinality": "desc" }
|
||||
},
|
||||
"aggs": {
|
||||
"cardinality": {
|
||||
"cardinality": { "field": "text_few_terms" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
// Two-level terms ordered by cardinality at each level: a high-card outer terms
|
||||
// (text_many_terms) ordered by a cardinality sub-agg, with a nested low-card terms
|
||||
// (text_few_terms_status) also ordered by a cardinality sub-agg, plus an avg.
|
||||
fn terms_many_with_single_term_2_order_by_card(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"by_ip": {
|
||||
"terms": {
|
||||
"field": "text_many_terms",
|
||||
"order": { "card_few_terms": "desc" }
|
||||
},
|
||||
"aggs": {
|
||||
"card_few_terms": {
|
||||
"cardinality": { "field": "text_few_terms" }
|
||||
},
|
||||
"nested_terms": {
|
||||
"terms": {
|
||||
"field": " single_term",
|
||||
"order": { "distinct_path2": "desc" }
|
||||
},
|
||||
"aggs": {
|
||||
"avg_botscore": { "avg": { "field": "score" } },
|
||||
"distinct_path2": { "cardinality": { "field": "text_few_terms" } }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_7(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": { "terms": { "field": "text_few_terms_status" } },
|
||||
@@ -270,6 +354,30 @@ fn terms_all_unique_with_avg_sub_agg(index: &Index) {
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn terms_status_with_terms_zipf_1000_sub_agg(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
"terms": { "field": "text_few_terms_status" },
|
||||
"aggs": {
|
||||
"nested_terms": { "terms": { "field": "text_1000_terms_zipf" } }
|
||||
}
|
||||
}
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_zipf_1000_with_terms_status_sub_agg(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
"terms": { "field": "text_1000_terms_zipf" },
|
||||
"aggs": {
|
||||
"nested_terms": { "terms": { "field": "text_few_terms_status" } }
|
||||
}
|
||||
}
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_status_with_histogram(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
@@ -583,7 +691,8 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
|
||||
TextFieldIndexing::default().set_index_option(IndexRecordOption::WithFreqs),
|
||||
)
|
||||
.set_stored();
|
||||
let text_field = schema_builder.add_text_field("text", text_fieldtype);
|
||||
let text_field = schema_builder.add_text_field("text", text_fieldtype.clone());
|
||||
let single_term = schema_builder.add_text_field("single_term", FAST);
|
||||
let json_field = schema_builder.add_json_field("json", FAST);
|
||||
let text_field_all_unique_terms =
|
||||
schema_builder.add_text_field("text_all_unique_terms", STRING | FAST);
|
||||
@@ -647,6 +756,8 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
|
||||
index_writer.add_document(doc!(
|
||||
json_field => json!({"mixed_type": 10.0}),
|
||||
json_field => json!({"mixed_type": 10.0}),
|
||||
single_term => "single_term",
|
||||
single_term => "single_term",
|
||||
text_field => "cool",
|
||||
text_field => "cool",
|
||||
text_field_all_unique_terms => "cool",
|
||||
@@ -681,6 +792,7 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
|
||||
json!({"mixed_type": many_terms_data.choose(&mut rng).unwrap().to_string()})
|
||||
};
|
||||
index_writer.add_document(doc!(
|
||||
single_term => "single_term",
|
||||
text_field => "cool",
|
||||
json_field => json,
|
||||
text_field_all_unique_terms => format!("unique_term_{}", rng.random::<u64>()),
|
||||
|
||||
149
benches/intersection_bench.rs
Normal file
149
benches/intersection_bench.rs
Normal file
@@ -0,0 +1,149 @@
|
||||
// Benchmarks top-K intersection of term scorers (block_wand_intersection).
|
||||
//
|
||||
// What's measured:
|
||||
// - Conjunctive queries (+a +b, +a +b +c) with top-10 by score
|
||||
// - Varying doc-frequency balance between terms (balanced, skewed, very skewed)
|
||||
// - Realistic term frequencies (geometric distribution, mostly low)
|
||||
// - 1M-doc single segment
|
||||
//
|
||||
// Run with: cargo bench --bench intersection_bench
|
||||
|
||||
use binggan::{black_box, BenchRunner};
|
||||
use rand::prelude::*;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::SeedableRng;
|
||||
use tantivy::collector::TopDocs;
|
||||
use tantivy::query::QueryParser;
|
||||
use tantivy::schema::{Schema, TEXT};
|
||||
use tantivy::{doc, Index, ReloadPolicy, Searcher};
|
||||
|
||||
const NUM_DOCS: usize = 1_000_000;
|
||||
|
||||
struct BenchIndex {
|
||||
searcher: Searcher,
|
||||
query_parser: QueryParser,
|
||||
}
|
||||
|
||||
/// Generate term frequency from a geometric-like distribution.
|
||||
/// Most values are 1, a few are 2-3, rarely higher.
|
||||
/// p controls the decay: higher p → more weight on tf=1.
|
||||
fn random_term_freq(rng: &mut StdRng, p: f64) -> u32 {
|
||||
let mut tf = 1u32;
|
||||
while tf < 10 && rng.random_bool(1.0 - p) {
|
||||
tf += 1;
|
||||
}
|
||||
tf
|
||||
}
|
||||
|
||||
/// Build an index with three terms (a, b, c) with given doc-frequency probabilities.
|
||||
/// Each term occurrence has a realistic term frequency (geometric distribution).
|
||||
/// Field length is padded with filler tokens to create varied fieldnorms.
|
||||
fn build_index(p_a: f64, p_b: f64, p_c: f64) -> BenchIndex {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let body = schema_builder.add_text_field("body", TEXT);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
|
||||
let mut rng = StdRng::from_seed([42u8; 32]);
|
||||
|
||||
{
|
||||
let mut writer = index.writer_with_num_threads(1, 500_000_000).unwrap();
|
||||
for _ in 0..NUM_DOCS {
|
||||
let mut tokens: Vec<String> = Vec::new();
|
||||
|
||||
if rng.random_bool(p_a) {
|
||||
let tf = random_term_freq(&mut rng, 0.7);
|
||||
for _ in 0..tf {
|
||||
tokens.push("aaa".to_string());
|
||||
}
|
||||
}
|
||||
if rng.random_bool(p_b) {
|
||||
let tf = random_term_freq(&mut rng, 0.7);
|
||||
for _ in 0..tf {
|
||||
tokens.push("bbb".to_string());
|
||||
}
|
||||
}
|
||||
if rng.random_bool(p_c) {
|
||||
let tf = random_term_freq(&mut rng, 0.7);
|
||||
for _ in 0..tf {
|
||||
tokens.push("ccc".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// Pad with filler to create varied field lengths (5-30 tokens).
|
||||
let filler_count = rng.random_range(5u32..30u32);
|
||||
for _ in 0..filler_count {
|
||||
tokens.push("filler".to_string());
|
||||
}
|
||||
|
||||
let text = tokens.join(" ");
|
||||
writer.add_document(doc!(body => text)).unwrap();
|
||||
}
|
||||
writer.commit().unwrap();
|
||||
}
|
||||
|
||||
let reader = index
|
||||
.reader_builder()
|
||||
.reload_policy(ReloadPolicy::Manual)
|
||||
.try_into()
|
||||
.unwrap();
|
||||
let searcher = reader.searcher();
|
||||
let query_parser = QueryParser::for_index(&index, vec![body]);
|
||||
|
||||
BenchIndex {
|
||||
searcher,
|
||||
query_parser,
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
// Scenarios: (label, p_a, p_b, p_c)
|
||||
//
|
||||
// "balanced": all terms ~10% → intersection ~1% of docs
|
||||
// "skewed": one common (50%), one rare (2%) → intersection ~1%
|
||||
// "very_skewed": one very common (80%), one very rare (0.5%) → intersection ~0.4%
|
||||
// "three_balanced": three terms ~20% each → intersection ~0.8%
|
||||
// "three_skewed": 50% / 10% / 2% → intersection ~0.1%
|
||||
let scenarios: Vec<(&str, f64, f64, f64)> = vec![
|
||||
("balanced_10%_10%", 0.10, 0.10, 0.0),
|
||||
("skewed_50%_2%", 0.50, 0.02, 0.0),
|
||||
("very_skewed_80%_0.5%", 0.80, 0.005, 0.0),
|
||||
("three_balanced_20%_20%_20%", 0.20, 0.20, 0.20),
|
||||
("three_skewed_50%_10%_2%", 0.50, 0.10, 0.02),
|
||||
];
|
||||
|
||||
let mut runner = BenchRunner::new();
|
||||
|
||||
for (label, p_a, p_b, p_c) in &scenarios {
|
||||
let bench_index = build_index(*p_a, *p_b, *p_c);
|
||||
|
||||
let mut group = runner.new_group();
|
||||
group.set_name(format!("intersection — {label}"));
|
||||
|
||||
// Two-term intersection
|
||||
if *p_a > 0.0 && *p_b > 0.0 {
|
||||
let query_str = "+aaa +bbb";
|
||||
let query = bench_index.query_parser.parse_query(query_str).unwrap();
|
||||
let searcher = bench_index.searcher.clone();
|
||||
group.register(format!("{query_str} top10"), move |_| {
|
||||
let collector = TopDocs::with_limit(10).order_by_score();
|
||||
black_box(searcher.search(&query, &collector).unwrap());
|
||||
1usize
|
||||
});
|
||||
}
|
||||
|
||||
// Three-term intersection
|
||||
if *p_c > 0.0 {
|
||||
let query_str = "+aaa +bbb +ccc";
|
||||
let query = bench_index.query_parser.parse_query(query_str).unwrap();
|
||||
let searcher = bench_index.searcher.clone();
|
||||
group.register(format!("{query_str} top10"), move |_| {
|
||||
let collector = TopDocs::with_limit(10).order_by_score();
|
||||
black_box(searcher.search(&query, &collector).unwrap());
|
||||
1usize
|
||||
});
|
||||
}
|
||||
|
||||
group.run();
|
||||
}
|
||||
}
|
||||
35
benches/query_parser_nested.rs
Normal file
35
benches/query_parser_nested.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
// Benchmark for the query grammar parsing deeply nested queries.
|
||||
//
|
||||
// Regression guard for https://github.com/quickwit-oss/tantivy/issues/2498:
|
||||
// at depth 20/21 the old parser took 0.87 s / 1.72 s respectively because
|
||||
// `ast()` retried `occur_leaf` on backtrack, giving O(2^n) time. With the
|
||||
// fix parsing is linear and completes in microseconds.
|
||||
//
|
||||
// Run with: `cargo bench --bench query_parser_nested`.
|
||||
|
||||
use binggan::{black_box, BenchRunner};
|
||||
use tantivy::query_grammar::parse_query;
|
||||
|
||||
fn nested_query(depth: usize, leading_plus: bool) -> String {
|
||||
let leading = "(".repeat(depth);
|
||||
let trailing = ")".repeat(depth);
|
||||
let prefix = if leading_plus { "+" } else { "" };
|
||||
format!("{prefix}{leading}title:test{trailing}")
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let mut runner = BenchRunner::new();
|
||||
|
||||
for depth in [20, 21] {
|
||||
for leading_plus in [false, true] {
|
||||
let query = nested_query(depth, leading_plus);
|
||||
let label = format!(
|
||||
"parse_nested_depth_{depth}_{}",
|
||||
if leading_plus { "plus" } else { "plain" },
|
||||
);
|
||||
runner.bench_function(&label, move |_| {
|
||||
black_box(parse_query(black_box(&query)).unwrap());
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -23,7 +23,7 @@ downcast-rs = "2.0.1"
|
||||
proptest = "1"
|
||||
more-asserts = "0.3.1"
|
||||
rand = "0.9"
|
||||
binggan = "0.16.1"
|
||||
binggan = "0.17.0"
|
||||
|
||||
[[bench]]
|
||||
name = "bench_merge"
|
||||
|
||||
@@ -19,6 +19,6 @@ time = { version = "0.3.47", features = ["serde-well-known"] }
|
||||
serde = { version = "1.0.136", features = ["derive"] }
|
||||
|
||||
[dev-dependencies]
|
||||
binggan = "0.16.1"
|
||||
binggan = "0.17.0"
|
||||
proptest = "1.0.0"
|
||||
rand = "0.9"
|
||||
|
||||
@@ -1045,18 +1045,43 @@ fn operand_leaf(inp: &str) -> IResult<&str, (Option<BinaryOperand>, Option<Occur
|
||||
}
|
||||
|
||||
fn ast(inp: &str) -> IResult<&str, UserInputAst> {
|
||||
let boolean_expr = map_res(
|
||||
separated_pair(occur_leaf, multispace1, many1(operand_leaf)),
|
||||
|(left, right)| aggregate_binary_expressions(left, right),
|
||||
);
|
||||
let single_leaf = map(occur_leaf, |(occur, ast)| {
|
||||
if occur == Some(Occur::MustNot) {
|
||||
ast.unary(Occur::MustNot)
|
||||
} else {
|
||||
ast
|
||||
}
|
||||
});
|
||||
delimited(multispace0, alt((boolean_expr, single_leaf)), multispace0)(inp)
|
||||
// Parse `occur_leaf` once, then conditionally extend into a boolean
|
||||
// expression. The previous implementation used `alt((boolean_expr,
|
||||
// single_leaf))` which, when the input was a single leaf with no
|
||||
// following operand, would parse `occur_leaf` once for `boolean_expr`,
|
||||
// fail at `multispace1`, backtrack, then re-parse `occur_leaf` for
|
||||
// `single_leaf`. With recursively-nested groups like `(+(+(+a)))`, that
|
||||
// doubling at every level produced O(2^n) parse time. Parsing once and
|
||||
// peeking ahead for the operand keeps it O(n).
|
||||
delimited(
|
||||
multispace0,
|
||||
|inp| {
|
||||
let (rest, first) = occur_leaf(inp)?;
|
||||
// Only fall back on `Err::Error` (recoverable), mirroring
|
||||
// `alt`'s behaviour. `Err::Failure` and `Err::Incomplete`
|
||||
// must propagate so cut points and streaming needs are not
|
||||
// accidentally swallowed if they are ever introduced in the
|
||||
// operand parsers.
|
||||
match preceded(multispace1, many1(operand_leaf))(rest) {
|
||||
Ok((rest, more)) => {
|
||||
let combined = aggregate_binary_expressions(first, more)
|
||||
.map_err(|_| nom::Err::Error(Error::new(inp, ErrorKind::MapRes)))?;
|
||||
Ok((rest, combined))
|
||||
}
|
||||
Err(nom::Err::Error(_)) => {
|
||||
let (occur, ast) = first;
|
||||
let single = if occur == Some(Occur::MustNot) {
|
||||
ast.unary(Occur::MustNot)
|
||||
} else {
|
||||
ast
|
||||
};
|
||||
Ok((rest, single))
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
},
|
||||
multispace0,
|
||||
)(inp)
|
||||
}
|
||||
|
||||
fn ast_infallible(inp: &str) -> JResult<&str, UserInputAst> {
|
||||
@@ -1891,4 +1916,23 @@ mod test {
|
||||
r#"(+"field":'happy tax payer' +"other_field":1)"#,
|
||||
);
|
||||
}
|
||||
|
||||
// Regression test for https://github.com/quickwit-oss/tantivy/issues/2498:
|
||||
// deeply nested parenthesized queries used to take O(2^n) time because the
|
||||
// top-level `ast()` parser tried `boolean_expr` first and re-parsed the
|
||||
// inner `occur_leaf` when it backtracked to `single_leaf`. Depth 60 would
|
||||
// take ~10^18 operations under the regression; with the fix it parses
|
||||
// instantly. We use `test_parse_query_to_ast_helper` so this test would
|
||||
// never finish if the regression returned.
|
||||
#[test]
|
||||
fn test_parse_deeply_nested_query() {
|
||||
let depth = 60;
|
||||
let leading: String = "(".repeat(depth);
|
||||
let trailing: String = ")".repeat(depth);
|
||||
let query = format!("{leading}title:test{trailing}");
|
||||
test_parse_query_to_ast_helper(&query, r#""title":test"#);
|
||||
|
||||
let query_with_plus = format!("+{leading}title:test{trailing}");
|
||||
test_parse_query_to_ast_helper(&query_with_plus, r#""title":test"#);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,8 +20,8 @@ use crate::aggregation::metric::{
|
||||
build_segment_stats_collector, AverageAggregation, CardinalityAggReqData,
|
||||
CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation, MaxAggregation,
|
||||
MetricAggReqData, MinAggregation, SegmentCardinalityCollector, SegmentExtendedStatsCollector,
|
||||
SegmentPercentilesCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData,
|
||||
TopHitsSegmentCollector,
|
||||
SegmentPercentilesCollector, StatsAggregation, StatsType, SumAggregation, TermOrdSet,
|
||||
TopHitsAggReqData, TopHitsSegmentCollector, BITSET_MAX_TERM_ORD,
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::{
|
||||
GenericSegmentAggregationResultsCollector, SegmentAggregationCollector,
|
||||
@@ -413,12 +413,38 @@ pub(crate) fn build_segment_agg_collector(
|
||||
}
|
||||
AggKind::Cardinality => {
|
||||
let req_data = &mut req.get_cardinality_req_data_mut(node.idx_in_req_data);
|
||||
Ok(Box::new(SegmentCardinalityCollector::from_req(
|
||||
req_data.column_type,
|
||||
node.idx_in_req_data,
|
||||
req_data.accessor.clone(),
|
||||
req_data.missing_value_for_accessor,
|
||||
)))
|
||||
// For str columns, choose the per-bucket entries representation
|
||||
// based on the segment's column.max_value():
|
||||
// * small (< BITSET_MAX_TERM_ORD): `BitSet`, pre-allocated, no promotion machinery.
|
||||
// * large: `TermOrdSet` (sparse FxHashSet that promotes to a paged bitset).
|
||||
// For non-str columns the `entries` field is unused (values go
|
||||
// straight into the HLL sketch); we still pick `TermOrdSet`
|
||||
// because its empty Sparse(FxHashSet) costs nothing.
|
||||
let is_str = req_data.column_type == ColumnType::Str;
|
||||
let max_term_ord_inclusive = if is_str {
|
||||
req_data.accessor.max_value()
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let collector: Box<dyn SegmentAggregationCollector> =
|
||||
if is_str && max_term_ord_inclusive < BITSET_MAX_TERM_ORD {
|
||||
Box::new(SegmentCardinalityCollector::<BitSet>::from_req(
|
||||
req_data.column_type,
|
||||
node.idx_in_req_data,
|
||||
req_data.accessor.clone(),
|
||||
req_data.missing_value_for_accessor,
|
||||
max_term_ord_inclusive,
|
||||
))
|
||||
} else {
|
||||
Box::new(SegmentCardinalityCollector::<TermOrdSet>::from_req(
|
||||
req_data.column_type,
|
||||
node.idx_in_req_data,
|
||||
req_data.accessor.clone(),
|
||||
req_data.missing_value_for_accessor,
|
||||
max_term_ord_inclusive,
|
||||
))
|
||||
};
|
||||
Ok(collector)
|
||||
}
|
||||
AggKind::StatsKind(stats_type) => {
|
||||
let req_data = &mut req.per_request.stats_metric_req_data[node.idx_in_req_data];
|
||||
@@ -985,8 +1011,12 @@ fn build_terms_or_cardinality_nodes(
|
||||
let str_col = str_dict_column
|
||||
.as_ref()
|
||||
.expect("str_dict_column must exist for string column");
|
||||
allowed_term_ids =
|
||||
build_allowed_term_ids_for_str(str_col, &req.include, &req.exclude)?;
|
||||
allowed_term_ids = build_allowed_term_ids_for_str(
|
||||
str_col,
|
||||
&req.include,
|
||||
&req.exclude,
|
||||
missing.is_some(),
|
||||
)?;
|
||||
};
|
||||
let idx_in_req_data = data.push_term_req_data(TermsAggReqData {
|
||||
accessor,
|
||||
@@ -1002,10 +1032,20 @@ fn build_terms_or_cardinality_nodes(
|
||||
(idx_in_req_data, AggKind::Terms)
|
||||
}
|
||||
TermsOrCardinalityRequest::Cardinality(ref req) => {
|
||||
// `str_dict_column` is computed once per field; for JSON paths
|
||||
// with mixed types it's `Some` even on the numeric req_data.
|
||||
// Cardinality only consults it for the str column path, so
|
||||
// gate by column_type to avoid driving non-str collectors
|
||||
// through the coupon-cache path.
|
||||
let str_dict_column_for_req = if column_type == ColumnType::Str {
|
||||
str_dict_column.clone()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let idx_in_req_data = data.push_cardinality_req_data(CardinalityAggReqData {
|
||||
accessor,
|
||||
column_type,
|
||||
str_dict_column: str_dict_column.clone(),
|
||||
str_dict_column: str_dict_column_for_req,
|
||||
missing_value_for_accessor,
|
||||
name: agg_name.to_string(),
|
||||
req: req.clone(),
|
||||
@@ -1025,16 +1065,21 @@ fn build_terms_or_cardinality_nodes(
|
||||
|
||||
/// Builds a single BitSet of allowed term ordinals for a string dictionary column according to
|
||||
/// include/exclude parameters.
|
||||
///
|
||||
/// When `reserve_missing_sentinel` is true, the bitset will have 1 additional slot for the missing
|
||||
/// term ordinal
|
||||
fn build_allowed_term_ids_for_str(
|
||||
str_col: &StrColumn,
|
||||
include: &Option<IncludeExcludeParam>,
|
||||
exclude: &Option<IncludeExcludeParam>,
|
||||
reserve_missing_sentinel: bool,
|
||||
) -> crate::Result<Option<BitSet>> {
|
||||
let mut allowed: Option<BitSet> = None;
|
||||
let num_terms = str_col.dictionary().num_terms() as u32;
|
||||
let missing_sentinel_adjustment = if reserve_missing_sentinel { 1 } else { 0 };
|
||||
let allowed_capacity = str_col.dictionary().num_terms() as u32 + missing_sentinel_adjustment;
|
||||
if let Some(include) = include {
|
||||
// add matches
|
||||
allowed = Some(BitSet::with_max_value(num_terms));
|
||||
allowed = Some(BitSet::with_max_value(allowed_capacity));
|
||||
let allowed = allowed.as_mut().unwrap();
|
||||
for_each_matching_term_ord(str_col, include, |ord| allowed.insert(ord))?;
|
||||
};
|
||||
@@ -1042,7 +1087,7 @@ fn build_allowed_term_ids_for_str(
|
||||
if let Some(exclude) = exclude {
|
||||
if allowed.is_none() {
|
||||
// Start with all terms allowed
|
||||
allowed = Some(BitSet::with_max_value_and_full(num_terms));
|
||||
allowed = Some(BitSet::with_max_value_and_full(allowed_capacity));
|
||||
}
|
||||
let allowed = allowed.as_mut().unwrap();
|
||||
for_each_matching_term_ord(str_col, exclude, |ord| allowed.remove(ord))?;
|
||||
|
||||
@@ -115,6 +115,71 @@ pub fn get_fast_field_names(aggs: &Aggregations) -> HashSet<String> {
|
||||
fast_field_names
|
||||
}
|
||||
|
||||
/// Validates that all fields referenced in the aggregation request exist in the schema
|
||||
/// and are configured as fast fields.
|
||||
///
|
||||
/// This is a convenience function for upfront validation before executing aggregations.
|
||||
/// Returns an error if any field doesn't exist or is not a fast field.
|
||||
///
|
||||
/// Validation is intentionally opt-in rather than baked into aggregation execution: the
|
||||
/// default lenient behavior (returning empty results for missing fields) supports
|
||||
/// schema evolution and federated queries where the same request runs against segments
|
||||
/// or indices with different schemas.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use tantivy::aggregation::agg_req::{Aggregations, validate_aggregation_fields_exist};
|
||||
/// use tantivy::schema::{Schema, FAST};
|
||||
/// use tantivy::Index;
|
||||
///
|
||||
/// # fn main() -> tantivy::Result<()> {
|
||||
/// // Create a simple index
|
||||
/// let mut schema_builder = Schema::builder();
|
||||
/// schema_builder.add_f64_field("price", FAST);
|
||||
/// let schema = schema_builder.build();
|
||||
/// let index = Index::create_in_ram(schema);
|
||||
///
|
||||
/// // Parse aggregation request
|
||||
/// let agg_req: Aggregations = serde_json::from_str(r#"{
|
||||
/// "avg_price": { "avg": { "field": "price" } }
|
||||
/// }"#)?;
|
||||
///
|
||||
/// let reader = index.reader()?;
|
||||
/// let searcher = reader.searcher();
|
||||
///
|
||||
/// // Validate fields before executing
|
||||
/// for segment_reader in searcher.segment_readers() {
|
||||
/// validate_aggregation_fields_exist(&agg_req, segment_reader)?;
|
||||
/// }
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn validate_aggregation_fields_exist(
|
||||
aggs: &Aggregations,
|
||||
reader: &crate::SegmentReader,
|
||||
) -> crate::Result<()> {
|
||||
let field_names = get_fast_field_names(aggs);
|
||||
let schema = reader.schema();
|
||||
|
||||
for field_name in field_names {
|
||||
// Check if the field is either directly in the schema or could be part of a json field
|
||||
// present in the schema, and verify it's a fast field.
|
||||
if let Some((field, _path)) = schema.find_field(&field_name) {
|
||||
let field_type = schema.get_field_entry(field).field_type();
|
||||
if !field_type.is_fast() {
|
||||
return Err(crate::TantivyError::SchemaError(format!(
|
||||
"Field '{}' is not a fast field. Aggregations require fast fields.",
|
||||
field_name
|
||||
)));
|
||||
}
|
||||
} else {
|
||||
return Err(crate::TantivyError::FieldNotFound(field_name));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
/// All aggregation types.
|
||||
pub enum AggregationVariants {
|
||||
|
||||
@@ -1436,3 +1436,46 @@ fn test_aggregation_on_json_object_mixed_numerical_segments() {
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aggregation_field_validation_helper() {
|
||||
// Test the standalone validation helper function for field validation
|
||||
let index = get_test_index_2_segments(false).unwrap();
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
let segment_reader = searcher.segment_reader(0);
|
||||
|
||||
// Test with invalid field
|
||||
let agg_req: Aggregations = serde_json::from_str(
|
||||
r#"{
|
||||
"avg_test": {
|
||||
"avg": { "field": "nonexistent_field" }
|
||||
}
|
||||
}"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let result =
|
||||
crate::aggregation::agg_req::validate_aggregation_fields_exist(&agg_req, segment_reader);
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(crate::TantivyError::FieldNotFound(field_name)) => {
|
||||
assert_eq!(field_name, "nonexistent_field");
|
||||
}
|
||||
_ => panic!("Expected FieldNotFound error, got: {:?}", result),
|
||||
}
|
||||
|
||||
// Test with valid field
|
||||
let agg_req: Aggregations = serde_json::from_str(
|
||||
r#"{
|
||||
"avg_test": {
|
||||
"avg": { "field": "score" }
|
||||
}
|
||||
}"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let result =
|
||||
crate::aggregation::agg_req::validate_aggregation_fields_exist(&agg_req, segment_reader);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
@@ -152,7 +152,7 @@ impl SegmentAggregationCollector for SegmentCompositeCollector {
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let mem_pre = self.get_memory_consumption();
|
||||
let mem_pre = self.get_memory_consumption(parent_bucket_id);
|
||||
let composite_agg_data = agg_data.take_composite_req_data(self.accessor_idx);
|
||||
|
||||
for doc in docs {
|
||||
@@ -172,7 +172,7 @@ impl SegmentAggregationCollector for SegmentCompositeCollector {
|
||||
sub_agg.check_flush_local(agg_data)?;
|
||||
}
|
||||
|
||||
let mem_delta = self.get_memory_consumption() - mem_pre;
|
||||
let mem_delta = self.get_memory_consumption(parent_bucket_id) - mem_pre;
|
||||
if mem_delta > 0 {
|
||||
agg_data.context.limits.add_memory_consumed(mem_delta)?;
|
||||
}
|
||||
@@ -199,14 +199,22 @@ impl SegmentAggregationCollector for SegmentCompositeCollector {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn compute_metric_value(
|
||||
&self,
|
||||
_bucket_id: BucketId,
|
||||
_sub_agg_name: &str,
|
||||
_sub_agg_property: &str,
|
||||
_agg_data: &AggregationsSegmentCtx,
|
||||
) -> Option<f64> {
|
||||
// Composite is a multi-bucket agg with no single value to extract.
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl SegmentCompositeCollector {
|
||||
fn get_memory_consumption(&self) -> u64 {
|
||||
self.parent_buckets
|
||||
.iter()
|
||||
.map(|m| m.memory_consumption())
|
||||
.sum()
|
||||
fn get_memory_consumption(&self, parent_bucket_id: BucketId) -> u64 {
|
||||
self.parent_buckets[parent_bucket_id as usize].memory_consumption()
|
||||
}
|
||||
|
||||
pub(crate) fn from_req_and_validate(
|
||||
|
||||
@@ -559,34 +559,30 @@ mod tests {
|
||||
page_size,
|
||||
agg_req,
|
||||
);
|
||||
if page_idx + 1 < page_count {
|
||||
assert!(
|
||||
res["my_composite"].get("after_key").is_some(),
|
||||
"expected after_key on all but last page"
|
||||
);
|
||||
after_key = Some(res["my_composite"]["after_key"].clone());
|
||||
} else if res["my_composite"].get("after_key").is_some() {
|
||||
// currently we sometime have an after_key on the last page,
|
||||
// check that the next "page" is empty
|
||||
let agg_req_json = json!({
|
||||
"my_composite": {
|
||||
"composite": {
|
||||
"sources": composite_agg_sources,
|
||||
"size": page_size,
|
||||
"after": res["my_composite"]["after_key"].clone(),
|
||||
}
|
||||
}
|
||||
});
|
||||
let agg_req: Aggregations = serde_json::from_value(agg_req_json).unwrap();
|
||||
let res = exec_request(agg_req.clone(), index).unwrap();
|
||||
assert_eq!(
|
||||
res["my_composite"]["buckets"],
|
||||
json!([]),
|
||||
"expected no buckets when using after_key from last page, query: {:?}",
|
||||
agg_req
|
||||
);
|
||||
}
|
||||
assert!(
|
||||
res["my_composite"].get("after_key").is_some(),
|
||||
"expected after_key on every non-empty page"
|
||||
);
|
||||
after_key = Some(res["my_composite"]["after_key"].clone());
|
||||
}
|
||||
// Using the after_key from the last page must yield an empty page.
|
||||
let agg_req_json = json!({
|
||||
"my_composite": {
|
||||
"composite": {
|
||||
"sources": composite_agg_sources,
|
||||
"size": page_size,
|
||||
"after": after_key,
|
||||
}
|
||||
}
|
||||
});
|
||||
let agg_req: Aggregations = serde_json::from_value(agg_req_json).unwrap();
|
||||
let res = exec_request(agg_req.clone(), index).unwrap();
|
||||
assert_eq!(
|
||||
res["my_composite"]["buckets"],
|
||||
json!([]),
|
||||
"expected no buckets when using after_key from last page, query: {:?}",
|
||||
agg_req
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -711,8 +707,28 @@ mod tests {
|
||||
{"key": {"myterm": "terme"}, "doc_count": 1}
|
||||
])
|
||||
);
|
||||
assert!(res["my_composite"].get("after_key").is_none());
|
||||
|
||||
// paginating past last page should be empty
|
||||
let agg_req_json = json!({
|
||||
"my_composite": {
|
||||
"composite": {
|
||||
"sources": [
|
||||
{"myterm": {"terms": {"field": "string_id"}}}
|
||||
],
|
||||
"size": 3,
|
||||
"after": &res["my_composite"]["after_key"]
|
||||
}
|
||||
}
|
||||
});
|
||||
let agg_req: Aggregations = serde_json::from_value(agg_req_json).unwrap();
|
||||
let res = exec_request(agg_req.clone(), &index).unwrap();
|
||||
assert!(res["my_composite"].get("after_key").is_none());
|
||||
assert_eq!(
|
||||
res["my_composite"]["buckets"],
|
||||
json!([]),
|
||||
"expected no buckets when using after_key from last page, query: {:?}",
|
||||
agg_req
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -820,7 +836,10 @@ mod tests {
|
||||
{"key": {"myterm": "apple"}, "doc_count": 1}
|
||||
])
|
||||
);
|
||||
assert!(res["fruity_aggreg"].get("after_key").is_none());
|
||||
assert_eq!(
|
||||
res["fruity_aggreg"]["after_key"],
|
||||
json!({"myterm": "str:apple"})
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1792,7 +1811,14 @@ mod tests {
|
||||
{"key": {"month": ms_timestamp_from_iso_str("2021-02-01T00:00:00Z"), "category": "books"}, "doc_count": 1},
|
||||
]),
|
||||
);
|
||||
assert!(res["my_composite"].get("after_key").is_none());
|
||||
let feb_2021_ns = ms_timestamp_from_iso_str("2021-02-01T00:00:00Z") * 1_000_000;
|
||||
assert_eq!(
|
||||
res["my_composite"]["after_key"],
|
||||
json!({
|
||||
"month": format!("dt:{}", feb_2021_ns),
|
||||
"category": "str:books"
|
||||
})
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -674,6 +674,17 @@ impl<B: SubAggBuffer> SegmentAggregationCollector for SegmentFilterCollector<B>
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn compute_metric_value(
|
||||
&self,
|
||||
_bucket_id: BucketId,
|
||||
_sub_agg_name: &str,
|
||||
_sub_agg_property: &str,
|
||||
_agg_data: &AggregationsSegmentCtx,
|
||||
) -> Option<f64> {
|
||||
// TODO: forward into the inner `sub_agg` for nested order paths (`filter.metric`).
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Intermediate result for filter aggregation
|
||||
|
||||
@@ -283,6 +283,11 @@ impl SegmentHistogramBucketEntry {
|
||||
struct HistogramBuckets {
|
||||
pub buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
|
||||
}
|
||||
impl HistogramBuckets {
|
||||
fn memory_consumption(&self) -> u64 {
|
||||
self.buckets.capacity() as u64 * std::mem::size_of::<SegmentHistogramBucketEntry>() as u64
|
||||
}
|
||||
}
|
||||
|
||||
/// The collector puts values from the fast field into the correct buckets and does a conversion to
|
||||
/// the correct datatype.
|
||||
@@ -324,7 +329,7 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let req = agg_data.take_histogram_req_data(self.accessor_idx);
|
||||
let mem_pre = self.get_memory_consumption();
|
||||
let mem_pre = self.get_memory_consumption(parent_bucket_id);
|
||||
let buckets = &mut self.parent_buckets[parent_bucket_id as usize].buckets;
|
||||
|
||||
let bounds = req.bounds;
|
||||
@@ -358,12 +363,9 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
|
||||
}
|
||||
agg_data.put_back_histogram_req_data(self.accessor_idx, req);
|
||||
|
||||
let mem_delta = self.get_memory_consumption() - mem_pre;
|
||||
let mem_delta = self.get_memory_consumption(parent_bucket_id) - mem_pre;
|
||||
if mem_delta > 0 {
|
||||
agg_data
|
||||
.context
|
||||
.limits
|
||||
.add_memory_consumed(mem_delta as u64)?;
|
||||
agg_data.context.limits.add_memory_consumed(mem_delta)?;
|
||||
}
|
||||
|
||||
if let Some(sub_agg) = &mut self.sub_agg {
|
||||
@@ -392,14 +394,24 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn compute_metric_value(
|
||||
&self,
|
||||
_bucket_id: BucketId,
|
||||
_sub_agg_name: &str,
|
||||
_sub_agg_property: &str,
|
||||
_agg_data: &AggregationsSegmentCtx,
|
||||
) -> Option<f64> {
|
||||
// Histogram is a multi-bucket agg with no single value to extract.
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl SegmentHistogramCollector {
|
||||
fn get_memory_consumption(&self) -> usize {
|
||||
let self_mem = std::mem::size_of::<Self>();
|
||||
let buckets_mem = self.parent_buckets.len() * std::mem::size_of::<HistogramBuckets>();
|
||||
self_mem + buckets_mem
|
||||
fn get_memory_consumption(&self, parent_bucket_id: BucketId) -> u64 {
|
||||
self.parent_buckets[parent_bucket_id as usize].memory_consumption()
|
||||
}
|
||||
|
||||
/// Converts the collector result into a intermediate bucket result.
|
||||
fn add_intermediate_bucket_result(
|
||||
&mut self,
|
||||
|
||||
@@ -328,6 +328,17 @@ impl<B: SubAggBuffer> SegmentAggregationCollector for SegmentRangeCollector<B> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn compute_metric_value(
|
||||
&self,
|
||||
_bucket_id: BucketId,
|
||||
_sub_agg_name: &str,
|
||||
_sub_agg_property: &str,
|
||||
_agg_data: &AggregationsSegmentCtx,
|
||||
) -> Option<f64> {
|
||||
// Range is a multi-bucket agg with no single value to extract.
|
||||
None
|
||||
}
|
||||
}
|
||||
/// Build a concrete `SegmentRangeCollector` with either a Vec- or HashMap-backed
|
||||
/// bucket storage, depending on the column type and aggregation level.
|
||||
|
||||
@@ -352,19 +352,15 @@ pub(crate) fn build_segment_term_collector(
|
||||
)));
|
||||
}
|
||||
|
||||
// Validate sub aggregation exists when ordering by sub-aggregation.
|
||||
{
|
||||
if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target {
|
||||
let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name);
|
||||
|
||||
node.get_sub_agg(agg_name, &req_data.per_request)
|
||||
.ok_or_else(|| {
|
||||
TantivyError::InvalidArgument(format!(
|
||||
"could not find aggregation with name {agg_name} in metric \
|
||||
sub_aggregations"
|
||||
))
|
||||
})?;
|
||||
}
|
||||
// Validate that the referenced sub-aggregation exists when ordering by one.
|
||||
if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target {
|
||||
let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name);
|
||||
node.get_sub_agg(agg_name, &req_data.per_request)
|
||||
.ok_or_else(|| {
|
||||
TantivyError::InvalidArgument(format!(
|
||||
"could not find aggregation with name {agg_name} in metric sub_aggregations"
|
||||
))
|
||||
})?;
|
||||
}
|
||||
|
||||
// Build sub-aggregation blueprint if there are children.
|
||||
@@ -809,7 +805,7 @@ impl<TermMap: TermAggregationMap, B: SubAggBuffer> SegmentAggregationCollector
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
// let mem_pre = self.get_memory_consumption();
|
||||
let mem_pre = self.get_memory_consumption(parent_bucket_id);
|
||||
|
||||
let req_data = &mut self.terms_req_data;
|
||||
|
||||
@@ -853,16 +849,13 @@ impl<TermMap: TermAggregationMap, B: SubAggBuffer> SegmentAggregationCollector
|
||||
}
|
||||
}
|
||||
|
||||
// let mem_delta = self.get_memory_consumption() - mem_pre;
|
||||
// if mem_delta > 0 {
|
||||
// agg_data
|
||||
// .context
|
||||
// .limits
|
||||
// .add_memory_consumed(mem_delta as u64)?;
|
||||
// }
|
||||
|
||||
// After commenting out -> 6000ms -> 36ms
|
||||
|
||||
let mem_delta = self.get_memory_consumption(parent_bucket_id) - mem_pre;
|
||||
if mem_delta > 0 {
|
||||
agg_data
|
||||
.context
|
||||
.limits
|
||||
.add_memory_consumed(mem_delta as u64)?;
|
||||
}
|
||||
if let Some(sub_agg) = &mut self.sub_agg {
|
||||
sub_agg.check_flush_local(agg_data)?;
|
||||
}
|
||||
@@ -890,6 +883,17 @@ impl<TermMap: TermAggregationMap, B: SubAggBuffer> SegmentAggregationCollector
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn compute_metric_value(
|
||||
&self,
|
||||
_bucket_id: BucketId,
|
||||
_sub_agg_name: &str,
|
||||
_sub_agg_property: &str,
|
||||
_agg_data: &AggregationsSegmentCtx,
|
||||
) -> Option<f64> {
|
||||
// Terms is a multi-bucket agg with no single value to extract.
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Missing value are represented as a sentinel value in the column.
|
||||
@@ -949,11 +953,9 @@ where
|
||||
TermMap: TermAggregationMap,
|
||||
B: SubAggBuffer,
|
||||
{
|
||||
fn get_memory_consumption(&self) -> usize {
|
||||
self.parent_buckets
|
||||
.iter()
|
||||
.map(|b| b.get_memory_consumption())
|
||||
.sum()
|
||||
#[inline]
|
||||
fn get_memory_consumption(&self, parent_bucket_id: BucketId) -> usize {
|
||||
self.parent_buckets[parent_bucket_id as usize].get_memory_consumption()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
@@ -965,9 +967,6 @@ where
|
||||
) -> crate::Result<IntermediateBucketResult> {
|
||||
let mut entries: Vec<(u64, Bucket)> = term_buckets.into_vec();
|
||||
|
||||
let order_by_sub_aggregation =
|
||||
matches!(term_req.req.order.target, OrderTarget::SubAggregation(_));
|
||||
|
||||
match &term_req.req.order.target {
|
||||
OrderTarget::Key => {
|
||||
// We rely on the fact, that term ordinals match the order of the strings
|
||||
@@ -979,10 +978,37 @@ where
|
||||
entries.sort_unstable_by_key(|bucket| bucket.0);
|
||||
}
|
||||
}
|
||||
OrderTarget::SubAggregation(_name) => {
|
||||
// don't sort and cut off since it's hard to make assumptions on the quality of the
|
||||
// results when cutting off du to unknown nature of the sub_aggregation (possible
|
||||
// to check).
|
||||
OrderTarget::SubAggregation(sub_agg_path) => {
|
||||
// Peek segment-level metric values, sort, then fall through to
|
||||
// `cut_off_buckets`. Like Elasticsearch, we always cut off when ordering
|
||||
// by a sub-agg: top-K results are approximate and may differ from the
|
||||
// global ordering, especially for non-monotonic metrics like avg/min.
|
||||
let coll = sub_agg_collector.as_deref().ok_or_else(|| {
|
||||
TantivyError::InvalidArgument(format!(
|
||||
"Could not find sub-aggregation collector for path {sub_agg_path}"
|
||||
))
|
||||
})?;
|
||||
let (agg_name, agg_prop) = get_agg_name_and_property(sub_agg_path);
|
||||
// Fetch values up-front; otherwise sort would re-compute per comparison
|
||||
let mut keyed: Vec<(f64, (u64, Bucket))> = entries
|
||||
.into_iter()
|
||||
.map(|bucket| {
|
||||
let metric_value = coll
|
||||
.compute_metric_value(bucket.1.bucket_id, agg_name, agg_prop, agg_data)
|
||||
.unwrap_or(0.0);
|
||||
(metric_value, bucket)
|
||||
})
|
||||
.collect();
|
||||
if term_req.req.order.order == Order::Desc {
|
||||
keyed.sort_unstable_by(|a, b| {
|
||||
b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
} else {
|
||||
keyed.sort_unstable_by(|a, b| {
|
||||
a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
}
|
||||
entries = keyed.into_iter().map(|(_, e)| e).collect();
|
||||
}
|
||||
OrderTarget::Count => {
|
||||
if term_req.req.order.order == Order::Desc {
|
||||
@@ -993,11 +1019,8 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
let (term_doc_count_before_cutoff, sum_other_doc_count) = if order_by_sub_aggregation {
|
||||
(0, 0)
|
||||
} else {
|
||||
cut_off_buckets(&mut entries, term_req.req.segment_size as usize)
|
||||
};
|
||||
let (term_doc_count_before_cutoff, sum_other_doc_count) =
|
||||
cut_off_buckets(&mut entries, term_req.req.segment_size as usize);
|
||||
|
||||
let mut dict: FxHashMap<IntermediateKey, IntermediateTermBucketEntry> = Default::default();
|
||||
dict.reserve(entries.len());
|
||||
@@ -1228,7 +1251,6 @@ pub(crate) fn cut_off_buckets<T: GetDocCount + Debug>(
|
||||
mod tests {
|
||||
use std::net::IpAddr;
|
||||
use std::str::FromStr;
|
||||
use std::time::Instant;
|
||||
|
||||
use common::DateTime;
|
||||
use time::{Date, Month};
|
||||
@@ -1242,10 +1264,9 @@ mod tests {
|
||||
get_test_index_from_terms, get_test_index_from_values_and_terms,
|
||||
};
|
||||
use crate::aggregation::{AggregationLimitsGuard, DistributedAggregationCollector};
|
||||
use crate::collector::{Collector, default_collect_segment_impl};
|
||||
use crate::indexer::NoMergePolicy;
|
||||
use crate::query::{AllQuery, EnableScoring, Query};
|
||||
use crate::schema::{IntoIpv6Addr, Schema, FAST, STRING};
|
||||
use crate::query::AllQuery;
|
||||
use crate::schema::{IntoIpv6Addr, Schema, FAST, INDEXED, STRING, TEXT};
|
||||
use crate::{Index, IndexWriter};
|
||||
|
||||
#[test]
|
||||
@@ -1774,6 +1795,263 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn terms_aggregation_order_by_cardinality_desc_single_segment() -> crate::Result<()> {
|
||||
terms_aggregation_order_by_cardinality_desc(true)
|
||||
}
|
||||
#[test]
|
||||
fn terms_aggregation_order_by_cardinality_desc_multi_segment() -> crate::Result<()> {
|
||||
terms_aggregation_order_by_cardinality_desc(false)
|
||||
}
|
||||
fn terms_aggregation_order_by_cardinality_desc(merge_segments: bool) -> crate::Result<()> {
|
||||
// Distinct score values per bucket key: A→5, B→1, C→3.
|
||||
// Order by cardinality desc must yield A, C, B.
|
||||
let segment_and_terms = vec![vec![
|
||||
(1.0, "A".to_string()),
|
||||
(2.0, "A".to_string()),
|
||||
(3.0, "A".to_string()),
|
||||
(4.0, "A".to_string()),
|
||||
(5.0, "A".to_string()),
|
||||
(1.0, "B".to_string()),
|
||||
(1.0, "B".to_string()),
|
||||
(1.0, "B".to_string()),
|
||||
(1.0, "C".to_string()),
|
||||
(2.0, "C".to_string()),
|
||||
(3.0, "C".to_string()),
|
||||
]];
|
||||
let index = get_test_index_from_values_and_terms(merge_segments, &segment_and_terms)?;
|
||||
|
||||
let agg_req: Aggregations = serde_json::from_value(json!({
|
||||
"my_texts": {
|
||||
"terms": {
|
||||
"field": "string_id",
|
||||
"order": { "card": "desc" }
|
||||
},
|
||||
"aggs": {
|
||||
"card": { "cardinality": { "field": "score" } }
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
let res = exec_request(agg_req, &index)?;
|
||||
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
|
||||
assert_eq!(res["my_texts"]["buckets"][0]["card"]["value"], 5.0);
|
||||
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
|
||||
assert_eq!(res["my_texts"]["buckets"][1]["card"]["value"], 3.0);
|
||||
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
|
||||
assert_eq!(res["my_texts"]["buckets"][2]["card"]["value"], 1.0);
|
||||
|
||||
// Asc engages the segment-cutoff path too (monotonic-safe: discarded buckets had
|
||||
// local card >= cutoff, so merged card >= cutoff and they cannot be globally smallest).
|
||||
let agg_req: Aggregations = serde_json::from_value(json!({
|
||||
"my_texts": {
|
||||
"terms": {
|
||||
"field": "string_id",
|
||||
"order": { "card": "asc" }
|
||||
},
|
||||
"aggs": {
|
||||
"card": { "cardinality": { "field": "score" } }
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
let res = exec_request(agg_req, &index)?;
|
||||
assert_eq!(res["my_texts"]["buckets"][0]["key"], "B");
|
||||
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
|
||||
assert_eq!(res["my_texts"]["buckets"][2]["key"], "A");
|
||||
|
||||
// size=2 with desc engages the segment cutoff: must keep top-2 by cardinality (A, C),
|
||||
// and `sum_other_doc_count` reflects the dropped B (3 docs).
|
||||
let agg_req: Aggregations = serde_json::from_value(json!({
|
||||
"my_texts": {
|
||||
"terms": {
|
||||
"field": "string_id",
|
||||
"size": 2,
|
||||
"order": { "card": "desc" }
|
||||
},
|
||||
"aggs": {
|
||||
"card": { "cardinality": { "field": "score" } }
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
let res = exec_request(agg_req, &index)?;
|
||||
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
|
||||
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
|
||||
assert_eq!(res["my_texts"]["buckets"].as_array().unwrap().len(), 2);
|
||||
|
||||
// size=2 with asc engages the segment cutoff: must keep bottom-2 by cardinality (B, C).
|
||||
let agg_req: Aggregations = serde_json::from_value(json!({
|
||||
"my_texts": {
|
||||
"terms": {
|
||||
"field": "string_id",
|
||||
"size": 2,
|
||||
"order": { "card": "asc" }
|
||||
},
|
||||
"aggs": {
|
||||
"card": { "cardinality": { "field": "score" } }
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
let res = exec_request(agg_req, &index)?;
|
||||
assert_eq!(res["my_texts"]["buckets"][0]["key"], "B");
|
||||
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
|
||||
assert_eq!(res["my_texts"]["buckets"].as_array().unwrap().len(), 2);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn terms_aggregation_order_by_sum_single_segment() -> crate::Result<()> {
|
||||
terms_aggregation_order_by_sum(true)
|
||||
}
|
||||
#[test]
|
||||
fn terms_aggregation_order_by_sum_multi_segment() -> crate::Result<()> {
|
||||
terms_aggregation_order_by_sum(false)
|
||||
}
|
||||
fn terms_aggregation_order_by_sum(merge_segments: bool) -> crate::Result<()> {
|
||||
// Per-bucket sums on the U64 `score` column (non-negative => sum is monotonic):
|
||||
// A → 1+2+3+4+5 = 15, B → 1+1+1 = 3, C → 1+2+3 = 6.
|
||||
let segment_and_terms = vec![
|
||||
vec![
|
||||
(1.0, "A".to_string()),
|
||||
(2.0, "A".to_string()),
|
||||
(3.0, "A".to_string()),
|
||||
(1.0, "B".to_string()),
|
||||
(1.0, "C".to_string()),
|
||||
],
|
||||
vec![
|
||||
(4.0, "A".to_string()),
|
||||
(5.0, "A".to_string()),
|
||||
(1.0, "B".to_string()),
|
||||
(1.0, "B".to_string()),
|
||||
(2.0, "C".to_string()),
|
||||
(3.0, "C".to_string()),
|
||||
],
|
||||
];
|
||||
let index = get_test_index_from_values_and_terms(merge_segments, &segment_and_terms)?;
|
||||
|
||||
// Desc on a Sum metric engages the fast path (column is U64).
|
||||
let agg_req: Aggregations = serde_json::from_value(json!({
|
||||
"my_texts": {
|
||||
"terms": {
|
||||
"field": "string_id",
|
||||
"order": { "total": "desc" }
|
||||
},
|
||||
"aggs": {
|
||||
"total": { "sum": { "field": "score" } }
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
let res = exec_request(agg_req, &index)?;
|
||||
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
|
||||
assert_eq!(res["my_texts"]["buckets"][0]["total"]["value"], 15.0);
|
||||
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
|
||||
assert_eq!(res["my_texts"]["buckets"][1]["total"]["value"], 6.0);
|
||||
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
|
||||
assert_eq!(res["my_texts"]["buckets"][2]["total"]["value"], 3.0);
|
||||
|
||||
// Asc engages the fast path too — discarded buckets had local sum >= cutoff,
|
||||
// and merged sum >= local (non-negative addends), so they cannot be globally smallest.
|
||||
let agg_req: Aggregations = serde_json::from_value(json!({
|
||||
"my_texts": {
|
||||
"terms": {
|
||||
"field": "string_id",
|
||||
"order": { "total": "asc" }
|
||||
},
|
||||
"aggs": {
|
||||
"total": { "sum": { "field": "score" } }
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
let res = exec_request(agg_req, &index)?;
|
||||
assert_eq!(res["my_texts"]["buckets"][0]["key"], "B");
|
||||
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
|
||||
assert_eq!(res["my_texts"]["buckets"][2]["key"], "A");
|
||||
|
||||
// size=2 desc with cutoff: top-2 by sum (A, C).
|
||||
let agg_req: Aggregations = serde_json::from_value(json!({
|
||||
"my_texts": {
|
||||
"terms": {
|
||||
"field": "string_id",
|
||||
"size": 2,
|
||||
"order": { "total": "desc" }
|
||||
},
|
||||
"aggs": {
|
||||
"total": { "sum": { "field": "score" } }
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
let res = exec_request(agg_req, &index)?;
|
||||
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
|
||||
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
|
||||
assert_eq!(res["my_texts"]["buckets"].as_array().unwrap().len(), 2);
|
||||
|
||||
// Stats sub-property: ordering by `mystats.sum` on a U64 column also engages.
|
||||
let agg_req: Aggregations = serde_json::from_value(json!({
|
||||
"my_texts": {
|
||||
"terms": {
|
||||
"field": "string_id",
|
||||
"order": { "mystats.sum": "desc" }
|
||||
},
|
||||
"aggs": {
|
||||
"mystats": { "stats": { "field": "score" } }
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
let res = exec_request(agg_req, &index)?;
|
||||
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
|
||||
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
|
||||
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
|
||||
|
||||
// Sum on a signed column (I64) takes the same cutoff path. Results may be
|
||||
// approximate near the boundary on adversarial data, but for this dataset the
|
||||
// top-K is unambiguous.
|
||||
let agg_req: Aggregations = serde_json::from_value(json!({
|
||||
"my_texts": {
|
||||
"terms": {
|
||||
"field": "string_id",
|
||||
"order": { "total": "desc" }
|
||||
},
|
||||
"aggs": {
|
||||
"total": { "sum": { "field": "score_i64" } }
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
let res = exec_request(agg_req, &index)?;
|
||||
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
|
||||
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
|
||||
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
|
||||
|
||||
// Order by extended_stats sub-property exercises compute_metric_value on the
|
||||
// ExtendedStats collector. A→max=5, B→max=1, C→max=3, so desc by max → A, C, B.
|
||||
let agg_req: Aggregations = serde_json::from_value(json!({
|
||||
"my_texts": {
|
||||
"terms": {
|
||||
"field": "string_id",
|
||||
"order": { "ext.max": "desc" }
|
||||
},
|
||||
"aggs": {
|
||||
"ext": { "extended_stats": { "field": "score" } }
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
let res = exec_request(agg_req, &index)?;
|
||||
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
|
||||
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
|
||||
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn terms_aggregation_test_order_key_single_segment() -> crate::Result<()> {
|
||||
terms_aggregation_test_order_key_merge_segment(true)
|
||||
@@ -2940,102 +3218,100 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_terms_double_nesting() {
|
||||
fn prep_index_with_n_unique_terms_plus_one_null(n: u64) -> crate::Result<Index> {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let outer_field = schema_builder.add_text_field("outer_term", STRING | FAST);
|
||||
let inner_field = schema_builder.add_text_field("inner_term", STRING | FAST);
|
||||
let id_field = schema_builder.add_u64_field("id", INDEXED);
|
||||
let title_field = schema_builder.add_text_field("title", TEXT | FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema.clone());
|
||||
// set to one thread to guarantee all docs end up in the same segment
|
||||
let mut writer = index.writer_with_num_threads(1, 50_000_000)?;
|
||||
|
||||
let outer_values = (0..10_000)
|
||||
.map(|i| format!("outer_{i}"))
|
||||
.collect::<Vec<_>>();
|
||||
let inner_values = ["INFO", "ERROR", "WARN", "DEBUG"];
|
||||
|
||||
{
|
||||
let mut index_writer: IndexWriter = index.writer_with_num_threads(1, 200_000_000).unwrap();
|
||||
for doc_id in 0..1_000_000u64 {
|
||||
let outer_val = &outer_values[doc_id as usize % outer_values.len()];
|
||||
let inner_val = inner_values[doc_id as usize % inner_values.len()];
|
||||
index_writer.add_document(doc!(
|
||||
outer_field => outer_val.as_str(),
|
||||
inner_field => inner_val,
|
||||
)).unwrap();
|
||||
}
|
||||
index_writer.commit().unwrap();
|
||||
writer.add_document(doc!(
|
||||
id_field => 0u64,
|
||||
))?;
|
||||
for i in 1u64..=n {
|
||||
let title = format!("foo{i}");
|
||||
writer.add_document(doc!(
|
||||
id_field => i,
|
||||
title_field => title,
|
||||
))?;
|
||||
}
|
||||
let agg_req: Aggregations = serde_json::from_value(json!({
|
||||
"outer": {
|
||||
"terms": { "field": "outer_term", "size": 10 },
|
||||
"aggs": {
|
||||
"inner": {
|
||||
"terms": { "field": "inner_term" }
|
||||
|
||||
writer.commit()?;
|
||||
|
||||
Ok(index)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn null_bitset_bounds_check_regression() -> crate::Result<()> {
|
||||
// include cases
|
||||
for i in 0..=4 {
|
||||
let index = prep_index_with_n_unique_terms_plus_one_null(i * 64)?;
|
||||
let normal_req: Aggregations = serde_json::from_value(json!({
|
||||
"my_bool": {
|
||||
"terms": {
|
||||
"field": "title",
|
||||
"missing": "__NULL__",
|
||||
"size": 1000,
|
||||
}
|
||||
}
|
||||
}))?;
|
||||
let include_req: Aggregations = serde_json::from_value(json!({
|
||||
"my_bool": {
|
||||
"terms": {
|
||||
"field": "title",
|
||||
"include": "foo(.*)",
|
||||
"missing": "__NULL__",
|
||||
"size": 1000,
|
||||
}
|
||||
}
|
||||
}))?;
|
||||
let exclude_req: Aggregations = serde_json::from_value(json!({
|
||||
"my_bool": {
|
||||
"terms": {
|
||||
"field": "title",
|
||||
"exclude": "foo(.*)",
|
||||
"missing": "__NULL__",
|
||||
"size": 1000,
|
||||
}
|
||||
}
|
||||
}))?;
|
||||
|
||||
let normal_res = exec_request(normal_req, &index)?;
|
||||
let normal_buckets = normal_res["my_bool"]["buckets"].as_array().unwrap();
|
||||
assert_eq!(
|
||||
normal_buckets.len(),
|
||||
(i * 64) as usize + 1,
|
||||
"The normal request should return all 'foo' buckets, plus the missing term bucket",
|
||||
);
|
||||
|
||||
let include_res = exec_request(include_req, &index)?;
|
||||
eprintln!("include_res: {include_res:?}");
|
||||
let include_buckets = include_res["my_bool"]["buckets"].as_array().unwrap();
|
||||
assert_eq!(
|
||||
include_buckets.len(),
|
||||
(i * 64) as usize,
|
||||
"The include request should return all 'foo' buckets, and not the missing term \
|
||||
bucket",
|
||||
);
|
||||
assert!(include_buckets
|
||||
.iter()
|
||||
.all(|b| b["key"].as_str().unwrap().starts_with("foo")));
|
||||
|
||||
let exclude_res = exec_request(exclude_req, &index)?;
|
||||
let exclude_buckets = exclude_res["my_bool"]["buckets"].as_array().unwrap();
|
||||
if i != 0 {
|
||||
// TODO: Remove this if after fixing exclude + missing bug
|
||||
assert_eq!(
|
||||
exclude_buckets.len(),
|
||||
1,
|
||||
"The exclude request should exclude all 'foo' buckets, and only the missing \
|
||||
term bucket",
|
||||
);
|
||||
assert_eq!(exclude_buckets[0]["key"], "__NULL__");
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
|
||||
let collector =
|
||||
crate::aggregation::AggregationCollector::from_aggs(agg_req, Default::default());
|
||||
|
||||
assert_eq!(searcher.segment_readers().len(), 1);
|
||||
let segment_reader = searcher.segment_reader(0u32);
|
||||
let all_weight = AllQuery.weight(EnableScoring::disabled_from_schema(&schema)).unwrap();
|
||||
let mut segment_collector = collector.for_segment(0u32, segment_reader).unwrap();
|
||||
let start = Instant::now();
|
||||
default_collect_segment_impl(&mut segment_collector, &*all_weight, segment_reader, false).unwrap();
|
||||
dbg!(start.elapsed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_terms_simple_nesting() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let outer_field = schema_builder.add_text_field("outer_term", STRING | FAST);
|
||||
let inner_field = schema_builder.add_text_field("inner_term", STRING | FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema.clone());
|
||||
|
||||
let outer_values = (0..10_000)
|
||||
.map(|i| format!("outer_{i}"))
|
||||
.collect::<Vec<_>>();
|
||||
let inner_values = ["INFO", "ERROR", "WARN", "DEBUG"];
|
||||
|
||||
{
|
||||
let mut index_writer: IndexWriter = index.writer_with_num_threads(1, 200_000_000).unwrap();
|
||||
for doc_id in 0..1_000_000u64 {
|
||||
let outer_val = &outer_values[doc_id as usize % outer_values.len()];
|
||||
let inner_val = inner_values[doc_id as usize % inner_values.len()];
|
||||
index_writer.add_document(doc!(
|
||||
outer_field => outer_val.as_str(),
|
||||
inner_field => inner_val,
|
||||
)).unwrap();
|
||||
}
|
||||
index_writer.commit().unwrap();
|
||||
}
|
||||
let agg_req: Aggregations = serde_json::from_value(json!({
|
||||
"outer": {
|
||||
"terms": { "field": "outer_term", "size": 10 },
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
|
||||
let collector =
|
||||
crate::aggregation::AggregationCollector::from_aggs(agg_req, Default::default());
|
||||
|
||||
assert_eq!(searcher.segment_readers().len(), 1);
|
||||
let segment_reader = searcher.segment_reader(0u32);
|
||||
let all_weight = AllQuery.weight(EnableScoring::disabled_from_schema(&schema)).unwrap();
|
||||
let mut segment_collector = collector.for_segment(0u32, segment_reader).unwrap();
|
||||
let start = Instant::now();
|
||||
default_collect_segment_impl(&mut segment_collector, &*all_weight, segment_reader, false).unwrap();
|
||||
dbg!(start.elapsed());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -177,6 +177,17 @@ impl SegmentAggregationCollector for TermMissingAgg {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn compute_metric_value(
|
||||
&self,
|
||||
_bucket_id: BucketId,
|
||||
_sub_agg_name: &str,
|
||||
_sub_agg_property: &str,
|
||||
_agg_data: &AggregationsSegmentCtx,
|
||||
) -> Option<f64> {
|
||||
// TODO: forward to `sub_agg` for nested order paths (`missing_agg>metric`).
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -1004,24 +1004,20 @@ impl IntermediateCompositeBucketResult {
|
||||
) -> crate::Result<BucketResult> {
|
||||
let trimmed_entry_vec =
|
||||
trim_composite_buckets(self.entries, &self.orders, self.target_size)?;
|
||||
let after_key = if trimmed_entry_vec.len() == req.size as usize {
|
||||
trimmed_entry_vec
|
||||
.last()
|
||||
.map(|bucket| {
|
||||
let (intermediate_key, _entry) = bucket;
|
||||
intermediate_key
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, intermediate_key)| {
|
||||
let source = &req.sources[idx];
|
||||
(source.name().to_string(), intermediate_key.clone().into())
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.unwrap()
|
||||
} else {
|
||||
FxHashMap::default()
|
||||
};
|
||||
let after_key = trimmed_entry_vec
|
||||
.last()
|
||||
.map(|bucket| {
|
||||
let (intermediate_key, _entry) = bucket;
|
||||
intermediate_key
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, intermediate_key)| {
|
||||
let source = &req.sources[idx];
|
||||
(source.name().to_string(), intermediate_key.clone().into())
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let buckets = trimmed_entry_vec
|
||||
.into_iter()
|
||||
|
||||
@@ -4,6 +4,7 @@ use std::io;
|
||||
|
||||
use columnar::column_values::CompactSpaceU64Accessor;
|
||||
use columnar::{Column, ColumnType, Dictionary, StrColumn};
|
||||
use common::{BitSet, TinySet};
|
||||
use datasketches::hll::{Coupon, HllSketch, HllType, HllUnion};
|
||||
use rustc_hash::{FxBuildHasher, FxHashMap, FxHashSet};
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
@@ -20,6 +21,12 @@ use crate::TantivyError;
|
||||
/// 2^11 = 2048 registers, giving ~2.3% relative error and ~1KB per sketch (Hll4).
|
||||
const LG_K: u8 = 11;
|
||||
|
||||
/// Promote FxHashSet<u64> -> PagedBitset at ~3% density (`len * 32 >
|
||||
/// dict_num_terms`). Past this point the bitset (~`dict_num_terms / 7.5`
|
||||
/// bytes) is smaller than the hashset (~10 B/entry minimum) and avoids
|
||||
/// the per-insert hash.
|
||||
const PROMOTION_RATIO: u64 = 32;
|
||||
|
||||
/// # Cardinality
|
||||
///
|
||||
/// The cardinality aggregation allows for computing an estimate
|
||||
@@ -159,8 +166,12 @@ impl CouponCache {
|
||||
let should_use_dense =
|
||||
highest_term_ord < 1_000_000u64 || highest_term_ord < num_terms as u64 * 3u64;
|
||||
if should_use_dense {
|
||||
let mut coupon_map: Vec<Coupon> = vec![Coupon::EMPTY; highest_term_ord as usize + 1];
|
||||
for (term_ord, coupon) in term_ords.into_iter().zip(coupons.into_iter()) {
|
||||
// We don't really care about the value here. We will populate all the values we will
|
||||
// read anyway.
|
||||
let uninitialized_coupon = Coupon::from_hash(0);
|
||||
let mut coupon_map: Vec<Coupon> =
|
||||
vec![uninitialized_coupon; highest_term_ord as usize + 1];
|
||||
for (term_ord, coupon) in term_ords.into_iter().zip(coupons) {
|
||||
coupon_map[term_ord as usize] = coupon;
|
||||
}
|
||||
CouponCache::Dense {
|
||||
@@ -177,9 +188,263 @@ impl CouponCache {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct SegmentCardinalityCollector {
|
||||
// =================================================================
|
||||
// PagedBitset: a sparse bitset indexed by term_ord.
|
||||
//
|
||||
// Used as the dense alternative to FxHashSet<u64> once a string
|
||||
// cardinality bucket has accumulated enough unique term ordinals.
|
||||
// Memory is bounded to (touched pages) * (page bytes), not
|
||||
// (max_term_ord / 8).
|
||||
//
|
||||
// Page geometry mirrors `PagedTermMap` in `term_agg.rs`: 1024 ords
|
||||
// per page, lazy `Vec<Option<Box<Page>>>` directory.
|
||||
// =================================================================
|
||||
const BITSET_PAGE_SHIFT: u32 = 10;
|
||||
const BITSET_PAGE_BITS: u64 = 1u64 << BITSET_PAGE_SHIFT; // 1024
|
||||
const BITSET_PAGE_MASK: u64 = BITSET_PAGE_BITS - 1;
|
||||
const BITSET_WORDS_PER_PAGE: usize = (BITSET_PAGE_BITS / 64) as usize; // 16
|
||||
|
||||
#[derive(Clone)]
|
||||
struct PagedBitsetPage {
|
||||
words: [TinySet; BITSET_WORDS_PER_PAGE],
|
||||
}
|
||||
|
||||
impl PagedBitsetPage {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
words: [TinySet::empty(); BITSET_WORDS_PER_PAGE],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct PagedBitset {
|
||||
pages: Vec<Option<Box<PagedBitsetPage>>>,
|
||||
/// Cached number of set bits, maintained on insert.
|
||||
count: u64,
|
||||
}
|
||||
|
||||
impl PagedBitset {
|
||||
/// Allocates a directory big enough to hold ords up to and including
|
||||
/// `max_term_ord`. Pages are allocated lazily on first set.
|
||||
fn with_max_term_ord(max_term_ord: u64) -> Self {
|
||||
let max_page_idx = (max_term_ord >> BITSET_PAGE_SHIFT) as usize;
|
||||
let num_pages = max_page_idx + 1;
|
||||
Self {
|
||||
pages: vec![None; num_pages],
|
||||
count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn insert(&mut self, term_ord: u64) {
|
||||
let page_idx = (term_ord >> BITSET_PAGE_SHIFT) as usize;
|
||||
let intra = term_ord & BITSET_PAGE_MASK;
|
||||
let word_idx = (intra >> 6) as usize;
|
||||
let bit_idx = (intra & 63) as u32;
|
||||
|
||||
let page = match &mut self.pages[page_idx] {
|
||||
Some(p) => p,
|
||||
None => {
|
||||
self.pages[page_idx] = Some(Box::new(PagedBitsetPage::new()));
|
||||
self.pages[page_idx].as_mut().unwrap()
|
||||
}
|
||||
};
|
||||
if page.words[word_idx].insert_mut(bit_idx) {
|
||||
self.count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
/// Number of set bits. O(1).
|
||||
#[inline]
|
||||
fn len(&self) -> u64 {
|
||||
self.count
|
||||
}
|
||||
|
||||
/// Iterate set ords in ascending order.
|
||||
fn iter_sorted(&self) -> impl Iterator<Item = u64> + '_ {
|
||||
self.pages
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(page_idx, page_opt)| page_opt.as_ref().map(|p| (page_idx, p)))
|
||||
.flat_map(|(page_idx, page)| {
|
||||
let page_base_ord = (page_idx as u64) << BITSET_PAGE_SHIFT;
|
||||
page.words
|
||||
.iter()
|
||||
.enumerate()
|
||||
.flat_map(move |(word_idx, &word)| {
|
||||
let word_base_ord = page_base_ord + (word_idx as u64) * 64;
|
||||
word.into_iter()
|
||||
.map(move |bit| word_base_ord + u64::from(bit))
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Threshold below which we use `BitSet` instead of `TermOrdSet`.
|
||||
///
|
||||
/// Both `BitSet` and `FxHashSet<u64>` have the same 32-byte struct, so the comparison is heap only:
|
||||
/// * `BitSet` at T=256: 5 `TinySet` words covering 258 bits (with the missing-value sentinel) =
|
||||
/// 40 bytes.
|
||||
/// * `FxHashSet<u64>` after one insert: 4-bucket hashbrown table ≈ 56 bytes
|
||||
pub(crate) const BITSET_MAX_TERM_ORD: u64 = 256;
|
||||
|
||||
// =================================================================
|
||||
// TermOrdAccumulator: per-bucket abstraction over the entries set.
|
||||
//
|
||||
// Implementations:
|
||||
// - `BitSet` (from `common`): used when `column.max_value()` is small (< BITSET_MAX_TERM_ORD).
|
||||
// Pre-allocated, no promotion.
|
||||
// - `TermOrdSet`: adaptive, starts as FxHashSet and promotes to a paged bitset when occupancy
|
||||
// crosses the density threshold (only if promotion is enabled — typically gated on top-level
|
||||
// aggregation).
|
||||
//
|
||||
// The trait lets `SegmentCardinalityCollector` be generic over the choice
|
||||
// so the hot collect() loop monomorphizes to a direct call (no enum
|
||||
// dispatch per insert).
|
||||
// =================================================================
|
||||
pub(crate) trait TermOrdAccumulator: Sized {
|
||||
/// Construct an empty accumulator.
|
||||
/// `max_term_ord_inclusive` is the largest term_ord that may be
|
||||
/// inserted (used to size pre-allocated bitsets and the dense bitset
|
||||
/// on promotion).
|
||||
fn new(max_term_ord_inclusive: u64) -> Self;
|
||||
fn insert(&mut self, term_ord: u64);
|
||||
/// Bulk insert. Implementations may override to hoist any inner
|
||||
/// dispatch outside the loop. Default loops `insert`.
|
||||
#[inline]
|
||||
fn extend_from_iter<I: IntoIterator<Item = u64>>(&mut self, ords: I) {
|
||||
for ord in ords {
|
||||
self.insert(ord);
|
||||
}
|
||||
}
|
||||
/// Hook called once per ingested block. Adaptive impls use this to
|
||||
/// decide on sparse->dense promotion.
|
||||
fn maybe_compact(&mut self) {}
|
||||
fn len(&self) -> usize;
|
||||
fn iter_ords(&self) -> impl Iterator<Item = u64> + '_;
|
||||
}
|
||||
|
||||
impl TermOrdAccumulator for BitSet {
|
||||
#[inline]
|
||||
fn new(max_term_ord_inclusive: u64) -> Self {
|
||||
// `BitSet::with_max_value(M)` accepts ords in [0, M).
|
||||
// We need ords up to and including `max_term_ord_inclusive`, plus
|
||||
// the missing-value sentinel `column.max_value() + 1`.
|
||||
BitSet::with_max_value((max_term_ord_inclusive + 2) as u32)
|
||||
}
|
||||
#[inline]
|
||||
fn insert(&mut self, term_ord: u64) {
|
||||
BitSet::insert(self, term_ord as u32);
|
||||
}
|
||||
#[inline]
|
||||
fn len(&self) -> usize {
|
||||
BitSet::len(self)
|
||||
}
|
||||
fn iter_ords(&self) -> impl Iterator<Item = u64> + '_ {
|
||||
// `BitSet` itself doesn't expose iteration, but
|
||||
// `BitSet::tinyset(bucket)` does. Walk per-bucket and yield each
|
||||
// set bit. The capacity is `max_value()`; iterating to
|
||||
// `div_ceil(64)` covers every possible ord exactly once.
|
||||
let num_buckets = self.max_value().div_ceil(64);
|
||||
(0..num_buckets).flat_map(move |bucket| {
|
||||
let chunk_base = u64::from(bucket) * 64;
|
||||
self.tinyset(bucket)
|
||||
.into_iter()
|
||||
.map(move |bit| chunk_base + u64::from(bit))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// =================================================================
|
||||
// TermOrdSet: adaptive sparse->dense accumulator.
|
||||
//
|
||||
// Starts as an FxHashSet (cheap when few ords are seen). When occupancy
|
||||
// crosses `len * PROMOTION_RATIO > max_term_ord_inclusive`, drains into
|
||||
// a `PagedBitset` and continues dense. Promotion is one-way.
|
||||
// =================================================================
|
||||
pub(crate) struct TermOrdSet {
|
||||
inner: TermOrdSetInner,
|
||||
/// Largest term_ord that may be inserted. Used for both sizing the
|
||||
/// dense bitset on promotion and as the promotion-threshold reference.
|
||||
max_term_ord_inclusive: u64,
|
||||
}
|
||||
|
||||
enum TermOrdSetInner {
|
||||
Sparse(FxHashSet<u64>),
|
||||
Dense(PagedBitset),
|
||||
}
|
||||
|
||||
impl TermOrdAccumulator for TermOrdSet {
|
||||
fn new(max_term_ord_inclusive: u64) -> Self {
|
||||
Self {
|
||||
inner: TermOrdSetInner::Sparse(FxHashSet::default()),
|
||||
max_term_ord_inclusive,
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn insert(&mut self, term_ord: u64) {
|
||||
match &mut self.inner {
|
||||
TermOrdSetInner::Sparse(set) => {
|
||||
set.insert(term_ord);
|
||||
}
|
||||
TermOrdSetInner::Dense(bitset) => bitset.insert(term_ord),
|
||||
}
|
||||
}
|
||||
|
||||
/// Hoist the Sparse/Dense match outside the per-ord loop so that a
|
||||
/// block of inserts dispatches once.
|
||||
fn extend_from_iter<I: IntoIterator<Item = u64>>(&mut self, ords: I) {
|
||||
match &mut self.inner {
|
||||
TermOrdSetInner::Sparse(set) => {
|
||||
for ord in ords {
|
||||
set.insert(ord);
|
||||
}
|
||||
}
|
||||
TermOrdSetInner::Dense(bitset) => {
|
||||
for ord in ords {
|
||||
bitset.insert(ord);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn maybe_compact(&mut self) {
|
||||
let TermOrdSetInner::Sparse(set) = &mut self.inner else {
|
||||
return;
|
||||
};
|
||||
if set.len() as u64 * PROMOTION_RATIO <= self.max_term_ord_inclusive {
|
||||
return;
|
||||
}
|
||||
// Size for ord <= max_term_ord_inclusive plus the missing sentinel
|
||||
// (column.max_value() + 1, which may equal max_term_ord_inclusive
|
||||
// when the column references every dictionary term).
|
||||
let mut bitset = PagedBitset::with_max_term_ord(self.max_term_ord_inclusive + 1);
|
||||
let set = std::mem::take(set);
|
||||
for ord in set {
|
||||
bitset.insert(ord);
|
||||
}
|
||||
self.inner = TermOrdSetInner::Dense(bitset);
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
match &self.inner {
|
||||
TermOrdSetInner::Sparse(set) => set.len(),
|
||||
TermOrdSetInner::Dense(bitset) => bitset.len() as usize,
|
||||
}
|
||||
}
|
||||
|
||||
fn iter_ords(&self) -> impl Iterator<Item = u64> + '_ {
|
||||
match &self.inner {
|
||||
TermOrdSetInner::Sparse(set) => itertools::Either::Left(set.iter().copied()),
|
||||
TermOrdSetInner::Dense(bitset) => itertools::Either::Right(bitset.iter_sorted()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct SegmentCardinalityCollector<S: TermOrdAccumulator> {
|
||||
/// Buckets are Some(_) until they get consumed by into_intermediate_results().
|
||||
buckets: Vec<Option<SegmentCardinalityCollectorBucket>>,
|
||||
buckets: Vec<Option<SegmentCardinalityCollectorBucket<S>>>,
|
||||
accessor_idx: usize,
|
||||
/// The column accessor to access the fast field values.
|
||||
accessor: Column<u64>,
|
||||
@@ -188,9 +453,13 @@ pub(crate) struct SegmentCardinalityCollector {
|
||||
/// The missing value normalized to the internal u64 representation of the field type.
|
||||
missing_value_for_accessor: Option<u64>,
|
||||
coupon_cache: Option<CouponCache>,
|
||||
/// Largest term_ord that may be inserted into a bucket. For str columns
|
||||
/// this is `accessor.max_value()`; for non-str columns this is unused
|
||||
/// (no inserts go into `entries`) and set to 0.
|
||||
max_term_ord_inclusive: u64,
|
||||
}
|
||||
|
||||
impl Debug for SegmentCardinalityCollector {
|
||||
impl<S: TermOrdAccumulator> Debug for SegmentCardinalityCollector<S> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
f.debug_struct("SegmentCardinalityCollector")
|
||||
.field("column_type", &self.column_type)
|
||||
@@ -202,16 +471,21 @@ impl Debug for SegmentCardinalityCollector {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct SegmentCardinalityCollectorBucket {
|
||||
cardinality: CardinalityCollector,
|
||||
entries: FxHashSet<u64>,
|
||||
/// Per-bucket state. Shape depends on column kind: str columns dedup
|
||||
/// term ords and only build the HLL sketch at finalization (saves the
|
||||
/// ~96 B `CardinalityCollector` per bucket during collect); numeric/IpAddr
|
||||
/// columns feed the sketch directly during collect.
|
||||
pub(crate) enum SegmentCardinalityCollectorBucket<S: TermOrdAccumulator> {
|
||||
Str(S),
|
||||
Numeric(CardinalityCollector),
|
||||
}
|
||||
impl SegmentCardinalityCollectorBucket {
|
||||
impl<S: TermOrdAccumulator> SegmentCardinalityCollectorBucket<S> {
|
||||
#[inline(always)]
|
||||
pub fn new(column_type: ColumnType) -> Self {
|
||||
Self {
|
||||
cardinality: CardinalityCollector::new(column_type as u8),
|
||||
entries: FxHashSet::default(),
|
||||
pub fn new(column_type: ColumnType, max_term_ord_inclusive: u64) -> Self {
|
||||
if column_type == ColumnType::Str {
|
||||
Self::Str(S::new(max_term_ord_inclusive))
|
||||
} else {
|
||||
Self::Numeric(CardinalityCollector::new(column_type as u8))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,37 +496,57 @@ impl SegmentCardinalityCollectorBucket {
|
||||
//
|
||||
// If the column is str, then the values are dictionary encoded
|
||||
// and have not been added to the sketch yet.
|
||||
// We need to resolves the term ords accumulated in self.entries
|
||||
// with the coupon cache, and append the results to the sketch.
|
||||
// We need to resolves the term ords accumulated in the str entries
|
||||
// with the coupon cache, and append the results to a fresh sketch.
|
||||
fn into_intermediate_metric_result(
|
||||
mut self,
|
||||
self,
|
||||
coupon_cache_opt: Option<&CouponCache>,
|
||||
) -> crate::Result<IntermediateMetricResult> {
|
||||
if let Some(coupon_cache) = coupon_cache_opt {
|
||||
assert!(self.cardinality.sketch.is_empty());
|
||||
append_to_sketch(&self.entries, coupon_cache, &mut self.cardinality);
|
||||
}
|
||||
Ok(IntermediateMetricResult::Cardinality(self.cardinality))
|
||||
let cardinality = match self {
|
||||
Self::Str(entries) => {
|
||||
let mut cardinality = CardinalityCollector::new(ColumnType::Str as u8);
|
||||
if let Some(coupon_cache) = coupon_cache_opt {
|
||||
// Sketch must be empty for str columns: coupons are appended here
|
||||
// from the term_ord set (and not directly during collection).
|
||||
assert!(cardinality.sketch.is_empty());
|
||||
append_to_sketch(&entries, coupon_cache, &mut cardinality);
|
||||
}
|
||||
cardinality
|
||||
}
|
||||
Self::Numeric(cardinality) => cardinality,
|
||||
};
|
||||
Ok(IntermediateMetricResult::Cardinality(cardinality))
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds a coupon cache from the given buckets, dictionary, and optional missing value.
|
||||
/// Returns a mapping from term_ord to the hash (coupon) of the associated term.
|
||||
fn build_coupon_cache(
|
||||
buckets: &[Option<SegmentCardinalityCollectorBucket>],
|
||||
fn build_coupon_cache<S: TermOrdAccumulator>(
|
||||
buckets: &[Option<SegmentCardinalityCollectorBucket<S>>],
|
||||
dictionary: &Dictionary,
|
||||
missing_value_opt: Option<&Key>,
|
||||
) -> io::Result<CouponCache> {
|
||||
let term_ords_capacity: usize = buckets
|
||||
.iter()
|
||||
.flatten()
|
||||
.map(|bucket| bucket.entries.len())
|
||||
.max()
|
||||
.unwrap_or(0)
|
||||
* 2;
|
||||
let mut term_ords_set = FxHashSet::with_capacity_and_hasher(term_ords_capacity, FxBuildHasher);
|
||||
// Caller restricts this to str cardinality collectors, so every
|
||||
// present bucket must be the `Str` variant. Pass 1 validates and
|
||||
// computes the capacity hint; pass 2 inserts.
|
||||
let mut max_bucket_len = 0usize;
|
||||
for bucket in buckets.iter().flatten() {
|
||||
term_ords_set.extend(bucket.entries.iter().copied());
|
||||
match bucket {
|
||||
SegmentCardinalityCollectorBucket::Str(entries) => {
|
||||
max_bucket_len = max_bucket_len.max(entries.len());
|
||||
}
|
||||
SegmentCardinalityCollectorBucket::Numeric(_) => {
|
||||
return Err(io::Error::other(
|
||||
"build_coupon_cache invoked with a non-str bucket",
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut term_ords_set = FxHashSet::with_capacity_and_hasher(max_bucket_len * 2, FxBuildHasher);
|
||||
for bucket in buckets.iter().flatten() {
|
||||
if let SegmentCardinalityCollectorBucket::Str(entries) = bucket {
|
||||
term_ords_set.extend(entries.iter_ords());
|
||||
}
|
||||
}
|
||||
let mut term_ords: Vec<u64> = term_ords_set.into_iter().collect();
|
||||
term_ords.sort_unstable();
|
||||
@@ -284,8 +578,8 @@ fn build_coupon_cache(
|
||||
Ok(CouponCache::new(term_ords, coupons, missing_coupon_opt))
|
||||
}
|
||||
|
||||
fn append_to_sketch(
|
||||
term_ords: &FxHashSet<u64>,
|
||||
fn append_to_sketch<S: TermOrdAccumulator>(
|
||||
term_ords: &S,
|
||||
coupon_cache: &CouponCache,
|
||||
sketch: &mut CardinalityCollector,
|
||||
) {
|
||||
@@ -294,7 +588,7 @@ fn append_to_sketch(
|
||||
coupon_map,
|
||||
missing_coupon_opt,
|
||||
} => {
|
||||
for &term_ord in term_ords {
|
||||
for term_ord in term_ords.iter_ords() {
|
||||
if let Some(coupon) = coupon_map
|
||||
.get(term_ord as usize)
|
||||
.copied()
|
||||
@@ -308,8 +602,8 @@ fn append_to_sketch(
|
||||
coupon_map,
|
||||
missing_coupon_opt,
|
||||
} => {
|
||||
for term_ord in term_ords {
|
||||
if let Some(coupon) = coupon_map.get(term_ord).copied().or(*missing_coupon_opt) {
|
||||
for term_ord in term_ords.iter_ords() {
|
||||
if let Some(coupon) = coupon_map.get(&term_ord).copied().or(*missing_coupon_opt) {
|
||||
sketch.insert_coupon(coupon);
|
||||
}
|
||||
}
|
||||
@@ -317,12 +611,13 @@ fn append_to_sketch(
|
||||
}
|
||||
}
|
||||
|
||||
impl SegmentCardinalityCollector {
|
||||
impl<S: TermOrdAccumulator> SegmentCardinalityCollector<S> {
|
||||
pub fn from_req(
|
||||
column_type: ColumnType,
|
||||
accessor_idx: usize,
|
||||
accessor: Column<u64>,
|
||||
missing_value_for_accessor: Option<u64>,
|
||||
max_term_ord_inclusive: u64,
|
||||
) -> Self {
|
||||
Self {
|
||||
buckets: Vec::new(),
|
||||
@@ -331,6 +626,7 @@ impl SegmentCardinalityCollector {
|
||||
accessor,
|
||||
missing_value_for_accessor,
|
||||
coupon_cache: None,
|
||||
max_term_ord_inclusive,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -347,7 +643,9 @@ impl SegmentCardinalityCollector {
|
||||
}
|
||||
}
|
||||
|
||||
impl SegmentAggregationCollector for SegmentCardinalityCollector {
|
||||
impl<S: TermOrdAccumulator + 'static> SegmentAggregationCollector
|
||||
for SegmentCardinalityCollector<S>
|
||||
{
|
||||
fn add_intermediate_aggregation_result(
|
||||
&mut self,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
@@ -402,31 +700,41 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
|
||||
));
|
||||
};
|
||||
let col_block_accessor = &agg_data.column_block_accessor;
|
||||
if self.column_type == ColumnType::Str {
|
||||
for term_ord in col_block_accessor.iter_vals() {
|
||||
bucket.entries.insert(term_ord);
|
||||
match bucket {
|
||||
SegmentCardinalityCollectorBucket::Str(entries) => {
|
||||
// Promotion check runs on the pre-block state: the first call
|
||||
// sees an empty set (no-op), and the last block of inserts
|
||||
// doesn't trigger a promotion of a set we won't grow further.
|
||||
// The trait dispatches once per block (via `extend_from_iter`)
|
||||
// for adaptive variants and inlines to a tight loop for the
|
||||
// BitSet path.
|
||||
entries.maybe_compact();
|
||||
entries.extend_from_iter(col_block_accessor.iter_vals());
|
||||
}
|
||||
} else if self.column_type == ColumnType::IpAddr {
|
||||
let compact_space_accessor = self
|
||||
.accessor
|
||||
.values
|
||||
.clone()
|
||||
.downcast_arc::<CompactSpaceU64Accessor>()
|
||||
.map_err(|_| {
|
||||
TantivyError::AggregationError(
|
||||
crate::aggregation::AggregationError::InternalError(
|
||||
"Type mismatch: Could not downcast to CompactSpaceU64Accessor"
|
||||
.to_string(),
|
||||
),
|
||||
)
|
||||
})?;
|
||||
for val in col_block_accessor.iter_vals() {
|
||||
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
|
||||
bucket.cardinality.insert(val);
|
||||
}
|
||||
} else {
|
||||
for val in col_block_accessor.iter_vals() {
|
||||
bucket.cardinality.insert(val);
|
||||
SegmentCardinalityCollectorBucket::Numeric(cardinality) => {
|
||||
if self.column_type == ColumnType::IpAddr {
|
||||
let compact_space_accessor = self
|
||||
.accessor
|
||||
.values
|
||||
.clone()
|
||||
.downcast_arc::<CompactSpaceU64Accessor>()
|
||||
.map_err(|_| {
|
||||
TantivyError::AggregationError(
|
||||
crate::aggregation::AggregationError::InternalError(
|
||||
"Type mismatch: Could not downcast to CompactSpaceU64Accessor"
|
||||
.to_string(),
|
||||
),
|
||||
)
|
||||
})?;
|
||||
for val in col_block_accessor.iter_vals() {
|
||||
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
|
||||
cardinality.insert(val);
|
||||
}
|
||||
} else {
|
||||
for val in col_block_accessor.iter_vals() {
|
||||
cardinality.insert(val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -439,12 +747,40 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
|
||||
_agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
if max_bucket as usize >= self.buckets.len() {
|
||||
let column_type = self.column_type;
|
||||
let max_term_ord_inclusive = self.max_term_ord_inclusive;
|
||||
self.buckets.resize_with(max_bucket as usize + 1, || {
|
||||
Some(SegmentCardinalityCollectorBucket::new(self.column_type))
|
||||
Some(SegmentCardinalityCollectorBucket::<S>::new(
|
||||
column_type,
|
||||
max_term_ord_inclusive,
|
||||
))
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn compute_metric_value(
|
||||
&self,
|
||||
bucket_id: BucketId,
|
||||
sub_agg_name: &str,
|
||||
sub_agg_property: &str,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
) -> Option<f64> {
|
||||
let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
|
||||
if req_data.name != sub_agg_name || !sub_agg_property.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let bucket = self.buckets.get(bucket_id as usize)?.as_ref()?;
|
||||
// For string columns the sketch isn't built until finalization; the
|
||||
// term_ord set's len is the exact distinct count. For numeric columns
|
||||
// the sketch is populated during collect.
|
||||
match bucket {
|
||||
SegmentCardinalityCollectorBucket::Str(entries) => Some(entries.len() as f64),
|
||||
SegmentCardinalityCollectorBucket::Numeric(cardinality) => {
|
||||
Some(cardinality.sketch.estimate().trunc())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -489,7 +825,7 @@ impl<'de> Deserialize<'de> for CardinalityCollector {
|
||||
impl CardinalityCollector {
|
||||
fn new(salt: u8) -> Self {
|
||||
Self {
|
||||
sketch: HllSketch::new(LG_K, HllType::Hll4),
|
||||
sketch: HllSketch::new(LG_K, HllType::Hll8),
|
||||
salt,
|
||||
}
|
||||
}
|
||||
@@ -520,7 +856,7 @@ impl CardinalityCollector {
|
||||
let mut union = HllUnion::new(LG_K);
|
||||
union.update(&self.sketch);
|
||||
union.update(&right.sketch);
|
||||
self.sketch = union.to_sketch(HllType::Hll4);
|
||||
self.sketch = union.to_sketch(HllType::Hll8);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -592,6 +928,134 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Build a single-segment string-cardinality index with 32 unique terms.
|
||||
/// `column.max_value() = 31` is well below `BITSET_MAX_TERM_ORD`,
|
||||
/// so the bucket exercises the `BitSet` path end to end.
|
||||
#[test]
|
||||
fn cardinality_aggregation_test_str_bitset() -> crate::Result<()> {
|
||||
let terms: Vec<String> = (0..32).map(|i| format!("term_{i}")).collect();
|
||||
let term_refs: Vec<Vec<&str>> = terms.iter().map(|t| vec![t.as_str()]).collect::<Vec<_>>();
|
||||
// single segment so we have a single dictionary of 32 terms.
|
||||
let index = get_test_index_from_terms(true, &term_refs)?;
|
||||
|
||||
let agg_req: Aggregations = serde_json::from_value(json!({
|
||||
"cardinality": {
|
||||
"cardinality": { "field": "string_id" }
|
||||
},
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
let res = exec_request(agg_req, &index)?;
|
||||
assert_eq!(res["cardinality"]["value"], 32.0);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// `BitSet` path with a `missing` parameter: the column-level missing
|
||||
/// sentinel (`column.max_value() + 1`) flows into the bitset, the
|
||||
/// dict lookup filter at finalization drops it, and the missing
|
||||
/// coupon is applied separately.
|
||||
#[test]
|
||||
fn cardinality_aggregation_test_str_bitset_with_missing() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let name_field = schema_builder.add_text_field("name", STRING | FAST);
|
||||
let index = Index::create_in_ram(schema_builder.build());
|
||||
let mut writer = index.writer_for_tests().unwrap();
|
||||
for i in 0..16 {
|
||||
let term = format!("t{i:02}");
|
||||
writer.add_document(doc!(name_field => term)).unwrap();
|
||||
}
|
||||
// One empty doc, exercising the missing sentinel.
|
||||
writer.add_document(doc!()).unwrap();
|
||||
writer.commit().unwrap();
|
||||
|
||||
let agg_req: Aggregations = serde_json::from_value(json!({
|
||||
"cardinality": {
|
||||
"cardinality": {
|
||||
"field": "name",
|
||||
"missing": "MISSING_SENTINEL_KEY",
|
||||
}
|
||||
},
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
let res = exec_request(agg_req, &index).unwrap();
|
||||
// 16 distinct real terms + 1 distinct "missing" value = 17.
|
||||
assert_eq!(res["cardinality"]["value"], 17.0);
|
||||
}
|
||||
|
||||
/// Unit-test the PagedBitset itself: cross-page inserts produce sorted
|
||||
/// iteration, len() matches the inserted set, and duplicates are
|
||||
/// idempotent.
|
||||
#[test]
|
||||
fn paged_bitset_basic() {
|
||||
use super::PagedBitset;
|
||||
// Span several pages: BITSET_PAGE_BITS = 1024, so ords > 1024 land
|
||||
// on the second page, > 2048 on the third, etc.
|
||||
let ords = [0u64, 1, 63, 64, 1023, 1024, 1025, 4096, 4097, 9999, 10_000];
|
||||
let max_ord = *ords.iter().max().unwrap();
|
||||
let mut bitset = PagedBitset::with_max_term_ord(max_ord);
|
||||
for &ord in &ords {
|
||||
bitset.insert(ord);
|
||||
// Idempotent: inserting again must not increase count.
|
||||
bitset.insert(ord);
|
||||
}
|
||||
assert_eq!(bitset.len(), ords.len() as u64);
|
||||
let collected: Vec<u64> = bitset.iter_sorted().collect();
|
||||
let mut expected: Vec<u64> = ords.to_vec();
|
||||
expected.sort_unstable();
|
||||
assert_eq!(collected, expected);
|
||||
}
|
||||
|
||||
/// Unit-test `TermOrdSet`: starts Sparse, promotes to Dense on
|
||||
/// `maybe_compact` once the density threshold is crossed, and
|
||||
/// `iter_ords()` yields the same set in either state. Ords spanning
|
||||
/// multiple paged-bitset pages exercise the Dense iter ordering.
|
||||
#[test]
|
||||
fn term_ord_set_promotes_on_maybe_compact() {
|
||||
use super::{TermOrdAccumulator, TermOrdSet, PROMOTION_RATIO};
|
||||
// Pick max so promotion needs few inserts: len * RATIO > max with
|
||||
// RATIO=32 and max=64 trips at len=3 (3*32=96 > 64).
|
||||
let max_term_ord = 64u64;
|
||||
let mut set = <TermOrdSet as TermOrdAccumulator>::new(max_term_ord);
|
||||
// Two inserts: should stay Sparse after maybe_compact (2 * RATIO = 64, not > 64).
|
||||
set.insert(0);
|
||||
set.insert(7);
|
||||
set.maybe_compact();
|
||||
assert_eq!(set.len(), 2);
|
||||
|
||||
// Third insert promotes on next maybe_compact.
|
||||
set.insert(20);
|
||||
assert_eq!(set.len(), 3);
|
||||
// Sanity check: at len=3, 3 * PROMOTION_RATIO = 96 > 64.
|
||||
assert!(3u64 * PROMOTION_RATIO > max_term_ord);
|
||||
set.maybe_compact();
|
||||
|
||||
// Post-promotion: extending continues to work.
|
||||
set.insert(15);
|
||||
set.insert(15); // dup
|
||||
assert_eq!(set.len(), 4);
|
||||
|
||||
let mut collected: Vec<u64> = set.iter_ords().collect();
|
||||
collected.sort_unstable();
|
||||
assert_eq!(collected, vec![0, 7, 15, 20]);
|
||||
}
|
||||
|
||||
/// Unit-test the `BitSet` impl of `TermOrdAccumulator`: insert,
|
||||
/// dedup, and iter_ords order.
|
||||
#[test]
|
||||
fn bitset_accumulator_basic() {
|
||||
use common::BitSet;
|
||||
|
||||
use super::TermOrdAccumulator;
|
||||
let mut set = <BitSet as TermOrdAccumulator>::new(255);
|
||||
for ord in [0u64, 1, 63, 64, 65, 128, 200, 200, 0] {
|
||||
<BitSet as TermOrdAccumulator>::insert(&mut set, ord);
|
||||
}
|
||||
assert_eq!(<BitSet as TermOrdAccumulator>::len(&set), 7);
|
||||
let collected: Vec<u64> = set.iter_ords().collect();
|
||||
assert_eq!(collected, vec![0, 1, 63, 64, 65, 128, 200]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cardinality_aggregation_u64() -> crate::Result<()> {
|
||||
let mut schema_builder = Schema::builder();
|
||||
@@ -683,6 +1147,42 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// A JSON path that resolves to both a Str column and a numeric column
|
||||
/// produces two collector instances per segment — one with `Str` buckets
|
||||
/// and one with `Numeric` buckets. Their `IntermediateMetricResult`s must
|
||||
/// merge into the union cardinality.
|
||||
#[test]
|
||||
fn cardinality_aggregation_json_str_and_numeric() -> crate::Result<()> {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let field = schema_builder.add_json_field("json", FAST);
|
||||
let index = Index::create_in_ram(schema_builder.build());
|
||||
{
|
||||
let mut writer = index.writer_for_tests()?;
|
||||
writer.add_document(doc!(field => json!({"value": "hello"})))?;
|
||||
writer.add_document(doc!(field => json!({"value": "world"})))?;
|
||||
writer.add_document(doc!(field => json!({"value": "hello"})))?; // dup str
|
||||
writer.add_document(doc!(field => json!({"value": i64::from_u64(7u64)})))?;
|
||||
writer.add_document(doc!(field => json!({"value": i64::from_u64(42u64)})))?;
|
||||
writer.add_document(doc!(field => json!({"value": i64::from_u64(7u64)})))?; // dup num
|
||||
writer.commit()?;
|
||||
}
|
||||
|
||||
let agg_req: Aggregations = serde_json::from_value(json!({
|
||||
"cardinality": {
|
||||
"cardinality": {
|
||||
"field": "json.value"
|
||||
},
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
let res = exec_request(agg_req, &index)?;
|
||||
// 4 distinct values: "hello", "world", 7, 42.
|
||||
assert_eq!(res["cardinality"]["value"], 4.0);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cardinality_collector_serde_roundtrip() {
|
||||
use super::CardinalityCollector;
|
||||
|
||||
@@ -399,6 +399,26 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn compute_metric_value(
|
||||
&self,
|
||||
bucket_id: BucketId,
|
||||
sub_agg_name: &str,
|
||||
sub_agg_property: &str,
|
||||
_agg_data: &AggregationsSegmentCtx,
|
||||
) -> Option<f64> {
|
||||
if self.name != sub_agg_name {
|
||||
return None;
|
||||
}
|
||||
let extended = self.buckets.get(bucket_id as usize)?;
|
||||
// Finalize is a pure read of accumulators — calling it here for the cutoff sort
|
||||
// doesn't disturb the eventual intermediate result.
|
||||
extended
|
||||
.finalize()
|
||||
.get_value(sub_agg_property)
|
||||
.ok()
|
||||
.flatten()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -312,6 +312,26 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn compute_metric_value(
|
||||
&self,
|
||||
bucket_id: BucketId,
|
||||
sub_agg_name: &str,
|
||||
sub_agg_property: &str,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
) -> Option<f64> {
|
||||
if agg_data.get_metric_req_data(self.accessor_idx).name != sub_agg_name {
|
||||
return None;
|
||||
}
|
||||
let percentile: f64 = sub_agg_property.parse().ok()?;
|
||||
if !(0.0..=100.0).contains(&percentile) {
|
||||
return None;
|
||||
}
|
||||
let bucket = self.buckets.get(bucket_id as usize)?;
|
||||
// DDSketch.quantile is a pure read; calling it here for the cutoff sort does
|
||||
// not affect the intermediate state used for the final result.
|
||||
bucket.sketch.quantile(percentile / 100.0).ok().flatten()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -321,6 +321,40 @@ impl<const COLUMN_TYPE_ID: u8> SegmentAggregationCollector
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn compute_metric_value(
|
||||
&self,
|
||||
bucket_id: BucketId,
|
||||
sub_agg_name: &str,
|
||||
sub_agg_property: &str,
|
||||
_agg_data: &AggregationsSegmentCtx,
|
||||
) -> Option<f64> {
|
||||
if self.name != sub_agg_name {
|
||||
return None;
|
||||
}
|
||||
let stats = self.buckets.get(bucket_id as usize)?;
|
||||
// The property depends on what we're collecting:
|
||||
// - StatsType::Stats exposes count/sum/min/max/avg via dotted property.
|
||||
// - Single-value kinds (Sum/Count/Min/Max/Average) expect an empty property and return
|
||||
// the value they were configured to collect.
|
||||
let prop = match self.collecting_for {
|
||||
StatsType::Stats if !sub_agg_property.is_empty() => sub_agg_property,
|
||||
StatsType::Sum if sub_agg_property.is_empty() => "sum",
|
||||
StatsType::Count if sub_agg_property.is_empty() => "count",
|
||||
StatsType::Max if sub_agg_property.is_empty() => "max",
|
||||
StatsType::Min if sub_agg_property.is_empty() => "min",
|
||||
StatsType::Average if sub_agg_property.is_empty() => "avg",
|
||||
_ => return None,
|
||||
};
|
||||
match prop {
|
||||
"count" => Some(stats.count as f64),
|
||||
"sum" => Some(stats.sum),
|
||||
"min" if stats.count > 0 => Some(stats.min),
|
||||
"max" if stats.count > 0 => Some(stats.max),
|
||||
"avg" if stats.count > 0 => Some(stats.sum / stats.count as f64),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
||||
@@ -644,6 +644,17 @@ impl SegmentAggregationCollector for TopHitsSegmentCollector {
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn compute_metric_value(
|
||||
&self,
|
||||
_bucket_id: BucketId,
|
||||
_sub_agg_name: &str,
|
||||
_sub_agg_property: &str,
|
||||
_agg_data: &AggregationsSegmentCtx,
|
||||
) -> Option<f64> {
|
||||
// top_hits is not a numeric metric and cannot be used as an order target.
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -76,6 +76,31 @@ pub trait SegmentAggregationCollector: Debug {
|
||||
fn flush(&mut self, _agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Compute the segment-level metric value of the named direct-child metric for `bucket_id`.
|
||||
///
|
||||
/// Used by parent term aggs that order by a sub-aggregation: the parent sorts on
|
||||
/// this value and cuts off at segment time, matching the approximation tradeoff
|
||||
/// Elasticsearch makes for any sub-agg ordering.
|
||||
///
|
||||
/// `sub_agg_property` is the dotted suffix (e.g. `"sum"` in `mystats.sum`); empty when
|
||||
/// the metric is a single-value kind such as cardinality.
|
||||
///
|
||||
/// Returns `None` only on name mismatch, unknown property, or empty bucket. Implementations
|
||||
/// may finalize their per-bucket state (e.g. compute a percentile from a sketch); calls
|
||||
/// must be idempotent so the final intermediate result is unaffected.
|
||||
///
|
||||
/// No default impl on purpose: every collector must decide explicitly whether it
|
||||
/// produces a metric value, forwards into children (single-bucket aggs), or rejects
|
||||
/// the lookup. A silent `None` default would let a parent term agg's cutoff sort all
|
||||
/// buckets to the same key and drop arbitrary winners.
|
||||
fn compute_metric_value(
|
||||
&self,
|
||||
bucket_id: BucketId,
|
||||
sub_agg_name: &str,
|
||||
sub_agg_property: &str,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
) -> Option<f64>;
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
@@ -137,4 +162,21 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn compute_metric_value(
|
||||
&self,
|
||||
bucket_id: BucketId,
|
||||
sub_agg_name: &str,
|
||||
sub_agg_property: &str,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
) -> Option<f64> {
|
||||
for agg in &self.aggs {
|
||||
if let Some(value) =
|
||||
agg.compute_metric_value(bucket_id, sub_agg_name, sub_agg_property, agg_data)
|
||||
{
|
||||
return Some(value);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
@@ -389,6 +389,13 @@ impl SegmentCollector for FacetSegmentCollector {
|
||||
}
|
||||
let mut facet = vec![];
|
||||
let (facet_ord, facet_depth) = self.unique_facet_ords[collapsed_facet_ord];
|
||||
// u64::MAX is used as a sentinel for unmapped ordinals (e.g. when a
|
||||
// document has the exact registered facet, not a child of it).
|
||||
// Passing it to ord_to_term would resolve to the last dictionary
|
||||
// entry and produce a spurious facet from an unrelated branch.
|
||||
if facet_ord == u64::MAX {
|
||||
continue;
|
||||
}
|
||||
// TODO handle errors.
|
||||
if facet_dict.ord_to_term(facet_ord, &mut facet).is_ok() {
|
||||
if let Some((end_collapsed_facet, _)) = facet
|
||||
@@ -814,6 +821,63 @@ mod tests {
|
||||
assert!(!super::is_child_facet(&b"foo\0bar"[..], &b"foo"[..]));
|
||||
assert!(!super::is_child_facet(&b"foo"[..], &b"foobar\0baz"[..]));
|
||||
}
|
||||
|
||||
// Regression test for https://github.com/quickwit-oss/tantivy/issues/2494
|
||||
// When a document has the exact registered facet path (not just a child),
|
||||
// harvest() must not turn the unmapped sentinel into a spurious root entry.
|
||||
#[test]
|
||||
fn test_facet_collector_wrong_root() -> crate::Result<()> {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let facet_field = schema_builder.add_facet_field("facet", FacetOptions::default());
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
|
||||
let mut index_writer: IndexWriter = index.writer_for_tests()?;
|
||||
let facets: Vec<&str> = vec![
|
||||
"/science-fiction/asimov",
|
||||
"/science-fiction/clarke",
|
||||
"/science-fiction/dick",
|
||||
"/science-fiction/herbert",
|
||||
"/science-fiction/orwell",
|
||||
// This exact match on the registered facet is the bug trigger:
|
||||
// its ordinal maps to the sentinel (u64::MAX, 0) in the collapse
|
||||
// mapping, which without the fix resolves to an unrelated term.
|
||||
"/fantasy/epic-fantasy",
|
||||
"/fantasy/epic-fantasy/tolkien",
|
||||
"/fantasy/epic-fantasy/martin",
|
||||
];
|
||||
for facet_str in &facets {
|
||||
index_writer.add_document(doc!(
|
||||
facet_field => Facet::from(*facet_str)
|
||||
))?;
|
||||
}
|
||||
index_writer.commit()?;
|
||||
|
||||
let reader = index.reader()?;
|
||||
let searcher = reader.searcher();
|
||||
|
||||
let term = Term::from_facet(facet_field, &Facet::from("/fantasy/epic-fantasy"));
|
||||
let query = TermQuery::new(term, IndexRecordOption::Basic);
|
||||
|
||||
let mut facet_collector = FacetCollector::for_field("facet");
|
||||
facet_collector.add_facet("/fantasy/epic-fantasy");
|
||||
let counts: FacetCounts = searcher.search(&query, &facet_collector)?;
|
||||
|
||||
let result: Vec<(String, u64)> = counts
|
||||
.get("/")
|
||||
.map(|(facet, count)| (facet.to_string(), count))
|
||||
.collect();
|
||||
|
||||
// Only children of /fantasy/epic-fantasy should appear, not /science-fiction
|
||||
assert_eq!(
|
||||
result,
|
||||
vec![
|
||||
("/fantasy/epic-fantasy/martin".to_string(), 1),
|
||||
("/fantasy/epic-fantasy/tolkien".to_string(), 1),
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "unstable"))]
|
||||
|
||||
@@ -138,6 +138,31 @@ pub trait DocSet: Send {
|
||||
buffer.len()
|
||||
}
|
||||
|
||||
/// Fills a given mutable buffer with the next doc ids smaller than `horizon`.
|
||||
///
|
||||
/// Unlike [`DocSet::fill_buffer`], this method must not advance past a doc id greater than or
|
||||
/// equal to `horizon`.
|
||||
fn fill_buffer_up_to(
|
||||
&mut self,
|
||||
horizon: DocId,
|
||||
buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
|
||||
) -> usize {
|
||||
if self.doc() == TERMINATED {
|
||||
return 0;
|
||||
}
|
||||
for (pos, buffer_val) in buffer.iter_mut().enumerate() {
|
||||
let doc = self.doc();
|
||||
if doc >= horizon {
|
||||
return pos;
|
||||
}
|
||||
*buffer_val = doc;
|
||||
if self.advance() == TERMINATED {
|
||||
return pos + 1;
|
||||
}
|
||||
}
|
||||
buffer.len()
|
||||
}
|
||||
|
||||
/// Returns the current document
|
||||
/// Right after creating a new `DocSet`, the docset points to the first document.
|
||||
///
|
||||
@@ -251,6 +276,14 @@ impl DocSet for &mut dyn DocSet {
|
||||
(**self).fill_buffer(buffer)
|
||||
}
|
||||
|
||||
fn fill_buffer_up_to(
|
||||
&mut self,
|
||||
horizon: DocId,
|
||||
buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
|
||||
) -> usize {
|
||||
(**self).fill_buffer_up_to(horizon, buffer)
|
||||
}
|
||||
|
||||
fn fill_bitset_block(
|
||||
&mut self,
|
||||
min_doc: DocId,
|
||||
|
||||
@@ -6,6 +6,7 @@ use common::{ByteCount, HasLen};
|
||||
use fnv::FnvHashMap;
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::directory::error::OpenReadError;
|
||||
use crate::directory::{CompositeFile, FileSlice};
|
||||
use crate::error::DataCorruption;
|
||||
use crate::fastfield::{intersect_alive_bitsets, AliveBitSet, FacetReader, FastFieldReaders};
|
||||
@@ -159,12 +160,10 @@ impl SegmentReader {
|
||||
let postings_file = segment.open_read(SegmentComponent::Postings)?;
|
||||
let postings_composite = CompositeFile::open(&postings_file)?;
|
||||
|
||||
let positions_composite = {
|
||||
if let Ok(positions_file) = segment.open_read(SegmentComponent::Positions) {
|
||||
CompositeFile::open(&positions_file)?
|
||||
} else {
|
||||
CompositeFile::empty()
|
||||
}
|
||||
let positions_composite = match segment.open_read(SegmentComponent::Positions) {
|
||||
Ok(positions_file) => CompositeFile::open(&positions_file)?,
|
||||
Err(OpenReadError::FileDoesNotExist(_)) => CompositeFile::empty(),
|
||||
Err(open_read_error) => return Err(open_read_error.into()),
|
||||
};
|
||||
|
||||
let schema = segment.schema();
|
||||
|
||||
@@ -240,6 +240,42 @@ impl BlockSegmentPostings {
|
||||
self.freq_decoder.output_array()
|
||||
}
|
||||
|
||||
pub(crate) fn copy_docs_and_term_freqs(
|
||||
&self,
|
||||
block_offset: usize,
|
||||
horizon: DocId,
|
||||
docs: &mut [DocId],
|
||||
term_freqs: &mut [u32],
|
||||
) -> usize {
|
||||
debug_assert_eq!(docs.len(), term_freqs.len());
|
||||
let block_docs = self.docs();
|
||||
let remaining_docs_in_block = block_docs.len().saturating_sub(block_offset);
|
||||
let max_len = remaining_docs_in_block.min(docs.len());
|
||||
if max_len == 0 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let source_docs = &block_docs[block_offset..block_offset + max_len];
|
||||
let len = if source_docs[max_len - 1] < horizon {
|
||||
max_len
|
||||
} else {
|
||||
source_docs
|
||||
.iter()
|
||||
.position(|&doc| doc >= horizon)
|
||||
.unwrap_or(max_len)
|
||||
};
|
||||
|
||||
docs[..len].copy_from_slice(&source_docs[..len]);
|
||||
|
||||
let block_freqs = self.freq_output_array();
|
||||
if block_freqs.len() >= block_offset + len {
|
||||
term_freqs[..len].copy_from_slice(&block_freqs[block_offset..block_offset + len]);
|
||||
} else {
|
||||
term_freqs[..len].fill(1);
|
||||
}
|
||||
len
|
||||
}
|
||||
|
||||
/// Return the frequency at index `idx` of the block.
|
||||
#[inline]
|
||||
pub fn freq(&self, idx: usize) -> u32 {
|
||||
@@ -249,6 +285,12 @@ impl BlockSegmentPostings {
|
||||
|
||||
/// Returns the length of the current block.
|
||||
///
|
||||
/// Returns the decoded term-frequency buffer for the current block.
|
||||
#[inline]
|
||||
pub(crate) fn freq_output_array(&self) -> &[u32] {
|
||||
self.freq_decoder.output_array()
|
||||
}
|
||||
|
||||
/// All blocks have a length of `NUM_DOCS_PER_BLOCK`,
|
||||
/// except the last block that may have a length
|
||||
/// of any number between 1 and `NUM_DOCS_PER_BLOCK - 1`
|
||||
@@ -298,6 +340,11 @@ impl BlockSegmentPostings {
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn has_remaining_docs(&self) -> bool {
|
||||
self.skip_reader.has_remaining_docs()
|
||||
}
|
||||
|
||||
pub(crate) fn block_is_loaded(&self) -> bool {
|
||||
self.block_loaded
|
||||
}
|
||||
|
||||
@@ -532,6 +532,16 @@ pub(crate) mod tests {
|
||||
fn score(&mut self) -> Score {
|
||||
self.0.score()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn can_score_doc(&self) -> bool {
|
||||
self.0.can_score_doc()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn score_doc(&mut self, doc: DocId, term_freq: u32) -> Score {
|
||||
self.0.score_doc(doc, term_freq)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn test_skip_against_unoptimized<F: Fn() -> Box<dyn DocSet>>(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use common::HasLen;
|
||||
|
||||
use crate::docset::DocSet;
|
||||
use crate::docset::{DocSet, COLLECT_BLOCK_BUFFER_LEN};
|
||||
use crate::fastfield::AliveBitSet;
|
||||
use crate::positions::PositionReader;
|
||||
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
|
||||
@@ -151,6 +151,34 @@ impl SegmentPostings {
|
||||
position_reader,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn fill_buffer_up_to_with_term_freqs(
|
||||
&mut self,
|
||||
horizon: DocId,
|
||||
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
|
||||
term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
|
||||
) -> usize {
|
||||
let mut num_elems = 0;
|
||||
while num_elems < COLLECT_BLOCK_BUFFER_LEN && self.doc() < horizon {
|
||||
let copied = self.block_cursor.copy_docs_and_term_freqs(
|
||||
self.cur,
|
||||
horizon,
|
||||
&mut docs[num_elems..],
|
||||
&mut term_freqs[num_elems..],
|
||||
);
|
||||
if copied == 0 {
|
||||
break;
|
||||
}
|
||||
num_elems += copied;
|
||||
self.cur += copied;
|
||||
|
||||
if self.cur == COMPRESSION_BLOCK_SIZE {
|
||||
self.cur = 0;
|
||||
self.block_cursor.advance();
|
||||
}
|
||||
}
|
||||
num_elems
|
||||
}
|
||||
}
|
||||
|
||||
impl DocSet for SegmentPostings {
|
||||
|
||||
@@ -146,6 +146,11 @@ impl SkipReader {
|
||||
skip_reader
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn has_remaining_docs(&self) -> bool {
|
||||
self.remaining_docs != 0
|
||||
}
|
||||
|
||||
pub fn reset(&mut self, data: OwnedBytes, doc_freq: u32) {
|
||||
self.last_doc_in_block = if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 {
|
||||
0
|
||||
|
||||
@@ -109,6 +109,16 @@ impl Scorer for AllScorer {
|
||||
fn score(&mut self) -> Score {
|
||||
1.0
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn can_score_doc(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score {
|
||||
1.0
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
use std::cell::RefCell;
|
||||
use std::num::NonZeroUsize;
|
||||
use std::sync::Arc;
|
||||
|
||||
use lru::LruCache;
|
||||
|
||||
use crate::fieldnorm::FieldNormReader;
|
||||
use crate::query::Explanation;
|
||||
use crate::schema::Field;
|
||||
@@ -59,7 +63,9 @@ fn cached_tf_component(fieldnorm: u32, average_fieldnorm: Score) -> Score {
|
||||
K1 * (1.0 - B + B * fieldnorm as Score / average_fieldnorm)
|
||||
}
|
||||
|
||||
fn compute_tf_cache(average_fieldnorm: Score) -> Arc<[Score; 256]> {
|
||||
const BM25_TF_CACHE_CAPACITY: usize = 64;
|
||||
|
||||
fn compute_tf_cache_uncached(average_fieldnorm: Score) -> Arc<[Score; 256]> {
|
||||
let mut cache: [Score; 256] = [0.0; 256];
|
||||
for (fieldnorm_id, cache_mut) in cache.iter_mut().enumerate() {
|
||||
let fieldnorm = FieldNormReader::id_to_fieldnorm(fieldnorm_id as u8);
|
||||
@@ -68,6 +74,36 @@ fn compute_tf_cache(average_fieldnorm: Score) -> Arc<[Score; 256]> {
|
||||
Arc::new(cache)
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
static TF_CACHES: RefCell<LruCache<u32, Arc<[Score; 256]>>> = RefCell::new(LruCache::new(
|
||||
NonZeroUsize::new(BM25_TF_CACHE_CAPACITY).unwrap(),
|
||||
));
|
||||
}
|
||||
|
||||
/// The cache is shared across all [Bm25Weight] with the same average fieldnorm on the same thread.
|
||||
/// It is stored in a thread local LRU cache.
|
||||
///
|
||||
/// On one query all terms on the same field will share the same average fieldnorm, and thus the
|
||||
/// same cache. This will lower cache pressure.
|
||||
///
|
||||
/// Even between queries (on the same thread), the cache will be reused, which allows the cache to
|
||||
/// better learn the memory address of the cache and access patterns.
|
||||
///
|
||||
/// Thread local is used in order to be defensive about potential contention on the cache.
|
||||
fn compute_tf_cache(average_fieldnorm: Score) -> Arc<[Score; 256]> {
|
||||
let cache_key = average_fieldnorm.to_bits();
|
||||
TF_CACHES.with(|cache_by_average_fieldnorm| {
|
||||
let mut cache_by_average_fieldnorm = cache_by_average_fieldnorm.borrow_mut();
|
||||
if let Some(cache) = cache_by_average_fieldnorm.get(&cache_key) {
|
||||
return cache.clone();
|
||||
}
|
||||
|
||||
let cache = compute_tf_cache_uncached(average_fieldnorm);
|
||||
cache_by_average_fieldnorm.put(cache_key, cache.clone());
|
||||
cache
|
||||
})
|
||||
}
|
||||
|
||||
/// A struct used for computing BM25 scores.
|
||||
#[derive(Clone)]
|
||||
pub struct Bm25Weight {
|
||||
@@ -229,7 +265,7 @@ impl Bm25Weight {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use super::idf;
|
||||
use super::{idf, Bm25Weight};
|
||||
use crate::{assert_nearly_equals, Score};
|
||||
|
||||
#[test]
|
||||
@@ -237,4 +273,12 @@ mod tests {
|
||||
let score: Score = 2.0;
|
||||
assert_nearly_equals!(idf(1, 2), score.ln());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bm25_tf_cache_is_shared_for_same_average_fieldnorm() {
|
||||
let weight1 = Bm25Weight::for_one_term(1, 10, 3.0);
|
||||
let weight2 = Bm25Weight::for_one_term(2, 10, 3.0);
|
||||
|
||||
assert!(std::sync::Arc::ptr_eq(&weight1.cache, &weight2.cache));
|
||||
}
|
||||
}
|
||||
|
||||
464
src/query/boolean_query/block_wand_intersection.rs
Normal file
464
src/query/boolean_query/block_wand_intersection.rs
Normal file
@@ -0,0 +1,464 @@
|
||||
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
|
||||
use crate::query::term_query::TermScorer;
|
||||
use crate::query::Scorer;
|
||||
use crate::{DocId, DocSet, Score, TERMINATED};
|
||||
|
||||
/// Block-max pruning for top-K over intersection of term scorers.
|
||||
///
|
||||
/// Uses the least-frequent term as "leader" to define 128-doc processing windows.
|
||||
/// For each window, the sum of block_max_scores is compared to the current threshold;
|
||||
/// if the block can't beat it, the entire block is skipped.
|
||||
///
|
||||
/// Within non-skipped blocks, individual documents are pruned by checking whether
|
||||
/// leader_score + sum(secondary block_max_scores) can exceed the threshold before
|
||||
/// performing the expensive intersection membership check (seeking into secondary scorers).
|
||||
///
|
||||
/// # Preconditions
|
||||
/// - `scorers` has at least 2 elements
|
||||
/// - All scorers read frequencies (`FreqReadingOption::ReadFreq`)
|
||||
pub(crate) fn block_wand_intersection(
|
||||
mut scorers: Vec<TermScorer>,
|
||||
mut threshold: Score,
|
||||
callback: &mut dyn FnMut(DocId, Score) -> Score,
|
||||
) {
|
||||
assert!(scorers.len() >= 2);
|
||||
|
||||
// Sort by cost (ascending). scorers[0] becomes the "leader" (rarest term).
|
||||
scorers.sort_by_key(TermScorer::size_hint);
|
||||
|
||||
let (leader, secondaries) = scorers.split_first_mut().unwrap();
|
||||
|
||||
// Precompute global max scores for early termination checks.
|
||||
let leader_max_score: Score = leader.max_score();
|
||||
let secondaries_global_max_sum: Score = secondaries.iter().map(TermScorer::max_score).sum();
|
||||
|
||||
// Early exit: no document can possibly beat the threshold.
|
||||
if leader_max_score + secondaries_global_max_sum <= threshold {
|
||||
return;
|
||||
}
|
||||
|
||||
// Borrow fieldnorm reader and BM25 weight before the main loop.
|
||||
// These are immutable references to disjoint fields from block_cursor,
|
||||
// but Rust's borrow checker can't see through method calls, so we
|
||||
// extract them once upfront.
|
||||
let fieldnorm_reader = leader.fieldnorm_reader().clone();
|
||||
let bm25_weight = leader.bm25_weight().clone();
|
||||
|
||||
let mut doc = leader.doc();
|
||||
|
||||
let mut secondary_block_max_scores: Box<[f32]> =
|
||||
vec![0.0f32; secondaries.len()].into_boxed_slice();
|
||||
let mut secondary_suffix_block_max: Box<[f32]> =
|
||||
vec![0.0f32; secondaries.len()].into_boxed_slice();
|
||||
|
||||
while doc < TERMINATED {
|
||||
// --- Phase 1: Block-level pruning ---
|
||||
//
|
||||
// Position all skip readers on the block containing `doc`.
|
||||
// seek_block is cheap: it only advances the skip reader, no block decompression.
|
||||
leader.seek_block(doc);
|
||||
let leader_block_max: Score = leader.block_max_score();
|
||||
|
||||
// Compute the window end as the minimum last_doc_in_block across all scorers.
|
||||
// This ensures the block_max values are valid for all docs in [doc, window_end].
|
||||
// Different scorers have independently aligned blocks, so we must use the
|
||||
// smallest window where all block_max values hold.
|
||||
let mut window_end: DocId = leader.last_doc_in_block();
|
||||
|
||||
let mut secondary_block_max_sum: Score = 0.0;
|
||||
let num_secondaries = secondaries.len();
|
||||
for (idx, secondary) in secondaries.iter_mut().enumerate() {
|
||||
secondary.block_cursor().seek_block(doc);
|
||||
if !secondary.block_cursor().has_remaining_docs() {
|
||||
return;
|
||||
}
|
||||
window_end = window_end.min(secondary.last_doc_in_block());
|
||||
let bms = secondary.block_max_score();
|
||||
secondary_block_max_scores[idx] = bms;
|
||||
secondary_block_max_sum += bms;
|
||||
}
|
||||
|
||||
if leader_block_max + secondary_block_max_sum <= threshold {
|
||||
// The entire window cannot beat the threshold. Skip past it.
|
||||
doc = window_end + 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
// --- Phase 2: Batch processing within the window ---
|
||||
//
|
||||
// Score-first approach: decode the leader's block, filter by threshold,
|
||||
// then check intersection membership only for survivors. This avoids expensive
|
||||
// secondary seeks for docs that can't beat the threshold.
|
||||
let block_cursor = leader.block_cursor();
|
||||
// seek loads the block and returns the in-block index of the first doc >= `doc`.
|
||||
let start_idx = block_cursor.seek(doc);
|
||||
|
||||
// Use the branchless binary search on the doc decoder to find the first
|
||||
// index past window_end.
|
||||
let end_idx = block_cursor
|
||||
.doc_decoder
|
||||
.seek_within_block(window_end + 1)
|
||||
.min(block_cursor.block_len());
|
||||
|
||||
let block_docs = &block_cursor.doc_decoder.output_array()[start_idx..end_idx];
|
||||
let block_freqs = &block_cursor.freq_output_array()[start_idx..end_idx];
|
||||
|
||||
// Pass 1: Batch-compute leader BM25 scores and branchlessly filter
|
||||
// candidates that can't beat the threshold.
|
||||
//
|
||||
// The trick: always write to the buffer at `num_candidates`, then
|
||||
// conditionally advance the count. The compiler can turn this into
|
||||
// a cmov instead of a branch, avoiding misprediction costs.
|
||||
let score_threshold = threshold - secondary_block_max_sum;
|
||||
let mut candidate_doc_ids = [0u32; COMPRESSION_BLOCK_SIZE];
|
||||
let mut candidate_scores = [0.0f32; COMPRESSION_BLOCK_SIZE];
|
||||
let mut num_candidates = 0usize;
|
||||
|
||||
for (candidate_doc, term_freq) in
|
||||
block_docs.iter().copied().zip(block_freqs.iter().copied())
|
||||
{
|
||||
let fieldnorm_id = fieldnorm_reader.fieldnorm_id(candidate_doc);
|
||||
let leader_score = bm25_weight.score(fieldnorm_id, term_freq);
|
||||
candidate_doc_ids[num_candidates] = candidate_doc;
|
||||
candidate_scores[num_candidates] = leader_score;
|
||||
num_candidates += (leader_score > score_threshold) as usize;
|
||||
}
|
||||
|
||||
// Precompute suffix sums: suffix[i] = sum of block_max for secondaries[i+1..].
|
||||
// Used in Phase 2 to prune candidates that can't beat threshold even with
|
||||
// remaining secondaries contributing their block_max.
|
||||
if num_candidates == 0 {
|
||||
doc = window_end + 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut running = 0.0f32;
|
||||
for idx in (0..num_secondaries).rev() {
|
||||
secondary_suffix_block_max[idx] = running;
|
||||
running += secondary_block_max_scores[idx];
|
||||
}
|
||||
|
||||
// Pass 2: Check intersection membership only for survivors.
|
||||
// score_threshold may be stale (threshold can increase from callbacks),
|
||||
// but that's conservative — we may check a few extra candidates, never miss one.
|
||||
'next_candidate: for candidate_idx in 0..num_candidates {
|
||||
let candidate_doc = candidate_doc_ids[candidate_idx];
|
||||
let mut total_score: Score = candidate_scores[candidate_idx];
|
||||
|
||||
for (secondary_idx, secondary) in secondaries.iter_mut().enumerate() {
|
||||
// If a previous candidate already advanced this secondary past
|
||||
// candidate_doc, the candidate can't be in the intersection.
|
||||
if secondary.doc() > candidate_doc {
|
||||
continue 'next_candidate;
|
||||
}
|
||||
let seek_result = secondary.seek(candidate_doc);
|
||||
if seek_result != candidate_doc {
|
||||
continue 'next_candidate;
|
||||
}
|
||||
total_score += secondary.score();
|
||||
|
||||
// Prune: even if all remaining secondaries score at their block max,
|
||||
// can we still beat the threshold?
|
||||
if total_score + secondary_suffix_block_max[secondary_idx] <= threshold {
|
||||
continue 'next_candidate;
|
||||
}
|
||||
}
|
||||
|
||||
// All secondaries matched.
|
||||
if total_score > threshold {
|
||||
threshold = callback(candidate_doc, total_score);
|
||||
|
||||
if leader_max_score + secondaries_global_max_sum <= threshold {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
doc = window_end + 1;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::cmp::Ordering;
|
||||
use std::collections::BinaryHeap;
|
||||
|
||||
use proptest::prelude::*;
|
||||
|
||||
use crate::query::term_query::TermScorer;
|
||||
use crate::query::{Bm25Weight, Scorer};
|
||||
use crate::{DocId, DocSet, Score, TERMINATED};
|
||||
|
||||
struct Float(Score);
|
||||
|
||||
impl Eq for Float {}
|
||||
|
||||
impl PartialEq for Float {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.cmp(other) == Ordering::Equal
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialOrd for Float {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl Ord for Float {
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
other.0.partial_cmp(&self.0).unwrap_or(Ordering::Equal)
|
||||
}
|
||||
}
|
||||
|
||||
fn nearly_equals(left: Score, right: Score) -> bool {
|
||||
(left - right).abs() < 0.0001 * (left + right).abs()
|
||||
}
|
||||
|
||||
/// Run block_wand_intersection and collect (doc, score) pairs above threshold.
|
||||
fn compute_checkpoints_block_wand_intersection(
|
||||
term_scorers: Vec<TermScorer>,
|
||||
top_k: usize,
|
||||
) -> Vec<(DocId, Score)> {
|
||||
let mut heap: BinaryHeap<Float> = BinaryHeap::with_capacity(top_k);
|
||||
let mut checkpoints: Vec<(DocId, Score)> = Vec::new();
|
||||
let mut limit: Score = 0.0;
|
||||
|
||||
let callback = &mut |doc, score| {
|
||||
heap.push(Float(score));
|
||||
if heap.len() > top_k {
|
||||
heap.pop().unwrap();
|
||||
}
|
||||
if heap.len() == top_k {
|
||||
limit = heap.peek().unwrap().0;
|
||||
}
|
||||
if !nearly_equals(score, limit) {
|
||||
checkpoints.push((doc, score));
|
||||
}
|
||||
limit
|
||||
};
|
||||
|
||||
super::block_wand_intersection(term_scorers, Score::MIN, callback);
|
||||
checkpoints
|
||||
}
|
||||
|
||||
/// Naive baseline: intersect by iterating all docs.
|
||||
fn compute_checkpoints_naive_intersection(
|
||||
mut term_scorers: Vec<TermScorer>,
|
||||
top_k: usize,
|
||||
) -> Vec<(DocId, Score)> {
|
||||
let mut heap: BinaryHeap<Float> = BinaryHeap::with_capacity(top_k);
|
||||
let mut checkpoints: Vec<(DocId, Score)> = Vec::new();
|
||||
let mut limit = Score::MIN;
|
||||
|
||||
// Sort by cost to use the cheapest as driver.
|
||||
term_scorers.sort_by_key(|s| s.cost());
|
||||
|
||||
let (leader, secondaries) = term_scorers.split_first_mut().unwrap();
|
||||
|
||||
let mut doc = leader.doc();
|
||||
while doc != TERMINATED {
|
||||
let mut all_match = true;
|
||||
for secondary in secondaries.iter_mut() {
|
||||
let secondary_doc = secondary.doc();
|
||||
let seek_result = if secondary_doc <= doc {
|
||||
secondary.seek(doc)
|
||||
} else {
|
||||
secondary_doc
|
||||
};
|
||||
if seek_result != doc {
|
||||
all_match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if all_match {
|
||||
let score: Score =
|
||||
leader.score() + secondaries.iter_mut().map(|s| s.score()).sum::<Score>();
|
||||
|
||||
if score > limit {
|
||||
heap.push(Float(score));
|
||||
if heap.len() > top_k {
|
||||
heap.pop().unwrap();
|
||||
}
|
||||
if heap.len() == top_k {
|
||||
limit = heap.peek().unwrap().0;
|
||||
}
|
||||
if !nearly_equals(score, limit) {
|
||||
checkpoints.push((doc, score));
|
||||
}
|
||||
}
|
||||
}
|
||||
doc = leader.advance();
|
||||
}
|
||||
checkpoints
|
||||
}
|
||||
|
||||
const MAX_TERM_FREQ: u32 = 100u32;
|
||||
|
||||
fn posting_list(max_doc: u32) -> BoxedStrategy<Vec<(DocId, u32)>> {
|
||||
(1..max_doc + 1)
|
||||
.prop_flat_map(move |doc_freq| {
|
||||
(
|
||||
proptest::bits::bitset::sampled(doc_freq as usize, 0..max_doc as usize),
|
||||
proptest::collection::vec(1u32..MAX_TERM_FREQ, doc_freq as usize),
|
||||
)
|
||||
})
|
||||
.prop_map(|(docset, term_freqs)| {
|
||||
docset
|
||||
.iter()
|
||||
.map(|doc| doc as u32)
|
||||
.zip(term_freqs.iter().cloned())
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
#[expect(clippy::type_complexity)]
|
||||
fn gen_term_scorers(num_scorers: usize) -> BoxedStrategy<(Vec<Vec<(DocId, u32)>>, Vec<u32>)> {
|
||||
(1u32..100u32)
|
||||
.prop_flat_map(move |max_doc: u32| {
|
||||
(
|
||||
proptest::collection::vec(posting_list(max_doc), num_scorers),
|
||||
proptest::collection::vec(2u32..10u32 * MAX_TERM_FREQ, max_doc as usize),
|
||||
)
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn test_block_wand_intersection_aux(posting_lists: &[Vec<(DocId, u32)>], fieldnorms: &[u32]) {
|
||||
// Repeat docs 64 times to create multi-block scenarios, matching block_wand.rs test
|
||||
// strategy.
|
||||
const REPEAT: usize = 64;
|
||||
let fieldnorms_expanded: Vec<u32> = fieldnorms
|
||||
.iter()
|
||||
.cloned()
|
||||
.flat_map(|fieldnorm| std::iter::repeat_n(fieldnorm, REPEAT))
|
||||
.collect();
|
||||
|
||||
let postings_lists_expanded: Vec<Vec<(DocId, u32)>> = posting_lists
|
||||
.iter()
|
||||
.map(|posting_list| {
|
||||
posting_list
|
||||
.iter()
|
||||
.cloned()
|
||||
.flat_map(|(doc, term_freq)| {
|
||||
(0_u32..REPEAT as u32).map(move |offset| {
|
||||
(
|
||||
doc * (REPEAT as u32) + offset,
|
||||
if offset == 0 { term_freq } else { 1 },
|
||||
)
|
||||
})
|
||||
})
|
||||
.collect::<Vec<(DocId, u32)>>()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let total_fieldnorms: u64 = fieldnorms_expanded
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|fieldnorm| fieldnorm as u64)
|
||||
.sum();
|
||||
let average_fieldnorm = (total_fieldnorms as Score) / (fieldnorms_expanded.len() as Score);
|
||||
let max_doc = fieldnorms_expanded.len();
|
||||
|
||||
let make_scorers = || -> Vec<TermScorer> {
|
||||
postings_lists_expanded
|
||||
.iter()
|
||||
.map(|postings| {
|
||||
let bm25_weight = Bm25Weight::for_one_term(
|
||||
postings.len() as u64,
|
||||
max_doc as u64,
|
||||
average_fieldnorm,
|
||||
);
|
||||
TermScorer::create_for_test(postings, &fieldnorms_expanded[..], bm25_weight)
|
||||
})
|
||||
.collect()
|
||||
};
|
||||
|
||||
for top_k in 1..4 {
|
||||
let checkpoints_optimized =
|
||||
compute_checkpoints_block_wand_intersection(make_scorers(), top_k);
|
||||
let checkpoints_naive = compute_checkpoints_naive_intersection(make_scorers(), top_k);
|
||||
assert_eq!(
|
||||
checkpoints_optimized.len(),
|
||||
checkpoints_naive.len(),
|
||||
"Mismatch in checkpoint count for top_k={top_k}"
|
||||
);
|
||||
for (&(left_doc, left_score), &(right_doc, right_score)) in
|
||||
checkpoints_optimized.iter().zip(checkpoints_naive.iter())
|
||||
{
|
||||
assert_eq!(left_doc, right_doc);
|
||||
assert!(
|
||||
nearly_equals(left_score, right_score),
|
||||
"Score mismatch for doc {left_doc}: {left_score} vs {right_score}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(500))]
|
||||
#[test]
|
||||
fn test_block_wand_intersection_two_scorers(
|
||||
(posting_lists, fieldnorms) in gen_term_scorers(2)
|
||||
) {
|
||||
test_block_wand_intersection_aux(&posting_lists[..], &fieldnorms[..]);
|
||||
}
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(500))]
|
||||
#[test]
|
||||
fn test_block_wand_intersection_three_scorers(
|
||||
(posting_lists, fieldnorms) in gen_term_scorers(3)
|
||||
) {
|
||||
test_block_wand_intersection_aux(&posting_lists[..], &fieldnorms[..]);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_block_wand_intersection_disjoint() {
|
||||
// Two posting lists with no overlap — intersection is empty.
|
||||
let fieldnorms: Vec<u32> = vec![10; 200];
|
||||
let average_fieldnorm = 10.0;
|
||||
let postings_a: Vec<(DocId, u32)> = (0..100).map(|d| (d, 1)).collect();
|
||||
let postings_b: Vec<(DocId, u32)> = (100..200).map(|d| (d, 1)).collect();
|
||||
|
||||
let scorer_a = TermScorer::create_for_test(
|
||||
&postings_a,
|
||||
&fieldnorms,
|
||||
Bm25Weight::for_one_term(100, 200, average_fieldnorm),
|
||||
);
|
||||
let scorer_b = TermScorer::create_for_test(
|
||||
&postings_b,
|
||||
&fieldnorms,
|
||||
Bm25Weight::for_one_term(100, 200, average_fieldnorm),
|
||||
);
|
||||
|
||||
let checkpoints = compute_checkpoints_block_wand_intersection(vec![scorer_a, scorer_b], 10);
|
||||
assert!(checkpoints.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_block_wand_intersection_all_overlap() {
|
||||
// Two posting lists with full overlap.
|
||||
let fieldnorms: Vec<u32> = vec![10; 50];
|
||||
let average_fieldnorm = 10.0;
|
||||
let postings: Vec<(DocId, u32)> = (0..50).map(|d| (d, 3)).collect();
|
||||
|
||||
let make_scorer = || {
|
||||
TermScorer::create_for_test(
|
||||
&postings,
|
||||
&fieldnorms,
|
||||
Bm25Weight::for_one_term(50, 50, average_fieldnorm),
|
||||
)
|
||||
};
|
||||
|
||||
let checkpoints_opt =
|
||||
compute_checkpoints_block_wand_intersection(vec![make_scorer(), make_scorer()], 5);
|
||||
let checkpoints_naive =
|
||||
compute_checkpoints_naive_intersection(vec![make_scorer(), make_scorer()], 5);
|
||||
assert_eq!(checkpoints_opt.len(), checkpoints_naive.len());
|
||||
}
|
||||
}
|
||||
@@ -50,7 +50,7 @@ fn block_max_was_too_low_advance_one_scorer(
|
||||
scorers: &mut [TermScorerWithMaxScore],
|
||||
pivot_len: usize,
|
||||
) {
|
||||
debug_assert!(is_sorted(scorers.iter().map(|scorer| scorer.doc())));
|
||||
debug_assert!(scorers.iter().map(|scorer| scorer.doc()).is_sorted());
|
||||
let mut scorer_to_seek = pivot_len - 1;
|
||||
let mut global_max_score = scorers[scorer_to_seek].max_score;
|
||||
let mut doc_to_seek_after = scorers[scorer_to_seek].last_doc_in_block();
|
||||
@@ -76,7 +76,7 @@ fn block_max_was_too_low_advance_one_scorer(
|
||||
scorers[scorer_to_seek].seek(doc_to_seek_after);
|
||||
|
||||
restore_ordering(scorers, scorer_to_seek);
|
||||
debug_assert!(is_sorted(scorers.iter().map(|scorer| scorer.doc())));
|
||||
debug_assert!(scorers.iter().map(|scorer| scorer.doc()).is_sorted());
|
||||
}
|
||||
|
||||
// Given a list of term_scorers and a `ord` and assuming that `term_scorers[ord]` is sorted
|
||||
@@ -90,7 +90,7 @@ fn restore_ordering(term_scorers: &mut [TermScorerWithMaxScore], ord: usize) {
|
||||
}
|
||||
term_scorers.swap(i, i - 1);
|
||||
}
|
||||
debug_assert!(is_sorted(term_scorers.iter().map(|scorer| scorer.doc())));
|
||||
debug_assert!(term_scorers.iter().map(|scorer| scorer.doc()).is_sorted());
|
||||
}
|
||||
|
||||
// Attempts to advance all term_scorers between `&term_scorers[0..before_len]` to the pivot.
|
||||
@@ -150,17 +150,21 @@ pub fn block_wand(
|
||||
mut threshold: Score,
|
||||
callback: &mut dyn FnMut(u32, Score) -> Score,
|
||||
) {
|
||||
scorers.retain(|scorer| scorer.doc() < TERMINATED);
|
||||
if scorers.len() == 1 {
|
||||
let scorer = scorers.pop().unwrap();
|
||||
return block_wand_single_scorer(scorer, threshold, callback);
|
||||
}
|
||||
let mut scorers: Vec<TermScorerWithMaxScore> = scorers
|
||||
.iter_mut()
|
||||
.map(TermScorerWithMaxScore::from)
|
||||
.collect();
|
||||
scorers.sort_by_key(|scorer| scorer.doc());
|
||||
// At this point we need to ensure that the scorers are sorted!
|
||||
debug_assert!(is_sorted(scorers.iter().map(|scorer| scorer.doc())));
|
||||
scorers.sort_by_key(|scorer| scorer.doc());
|
||||
while let Some((before_pivot_len, pivot_len, pivot_doc)) =
|
||||
find_pivot_doc(&scorers[..], threshold)
|
||||
{
|
||||
debug_assert!(is_sorted(scorers.iter().map(|scorer| scorer.doc())));
|
||||
debug_assert!(scorers.iter().map(|scorer| scorer.doc()).is_sorted());
|
||||
debug_assert_ne!(pivot_doc, TERMINATED);
|
||||
debug_assert!(before_pivot_len < pivot_len);
|
||||
|
||||
@@ -228,7 +232,7 @@ pub fn block_wand_single_scorer(
|
||||
loop {
|
||||
// We position the scorer on a block that can reach
|
||||
// the threshold.
|
||||
while scorer.block_max_score() < threshold {
|
||||
while scorer.block_max_score() <= threshold {
|
||||
let last_doc_in_block = scorer.last_doc_in_block();
|
||||
if last_doc_in_block == TERMINATED {
|
||||
return;
|
||||
@@ -286,18 +290,6 @@ impl DerefMut for TermScorerWithMaxScore<'_> {
|
||||
}
|
||||
}
|
||||
|
||||
fn is_sorted<I: Iterator<Item = DocId>>(mut it: I) -> bool {
|
||||
if let Some(first) = it.next() {
|
||||
let mut prev = first;
|
||||
for doc in it {
|
||||
if doc < prev {
|
||||
return false;
|
||||
}
|
||||
prev = doc;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::cmp::Ordering;
|
||||
@@ -16,6 +16,7 @@ use crate::{DocId, Score};
|
||||
|
||||
enum SpecializedScorer {
|
||||
TermUnion(Vec<TermScorer>),
|
||||
TermIntersection(Vec<TermScorer>),
|
||||
Other(Box<dyn Scorer>),
|
||||
}
|
||||
|
||||
@@ -49,10 +50,9 @@ where
|
||||
TScoreCombiner: ScoreCombiner,
|
||||
{
|
||||
assert!(!scorers.is_empty());
|
||||
if scorers.len() == 1 {
|
||||
if scorers.len() == 1 && !scorers[0].is::<TermScorer>() {
|
||||
return SpecializedScorer::Other(scorers.into_iter().next().unwrap()); //< we checked the size beforehand
|
||||
}
|
||||
|
||||
{
|
||||
let is_all_term_queries = scorers.iter().all(|scorer| scorer.is::<TermScorer>());
|
||||
if is_all_term_queries {
|
||||
@@ -66,6 +66,9 @@ where
|
||||
{
|
||||
// Block wand is only available if we read frequencies.
|
||||
return SpecializedScorer::TermUnion(scorers);
|
||||
} else if scorers.len() == 1 {
|
||||
// Single TermScorer without freq reading — unwrap directly.
|
||||
return SpecializedScorer::Other(Box::new(scorers.into_iter().next().unwrap()));
|
||||
} else {
|
||||
return SpecializedScorer::Other(Box::new(BufferedUnionScorer::build(
|
||||
scorers,
|
||||
@@ -88,10 +91,21 @@ fn into_box_scorer<TScoreCombiner: ScoreCombiner>(
|
||||
num_docs: u32,
|
||||
) -> Box<dyn Scorer> {
|
||||
match scorer {
|
||||
SpecializedScorer::TermUnion(term_scorers) => {
|
||||
let union_scorer =
|
||||
BufferedUnionScorer::build(term_scorers, score_combiner_fn, num_docs);
|
||||
Box::new(union_scorer)
|
||||
SpecializedScorer::TermUnion(mut term_scorers) => {
|
||||
if term_scorers.len() == 1 {
|
||||
Box::new(term_scorers.pop().unwrap())
|
||||
} else {
|
||||
let union_scorer =
|
||||
BufferedUnionScorer::build(term_scorers, score_combiner_fn, num_docs);
|
||||
Box::new(union_scorer)
|
||||
}
|
||||
}
|
||||
SpecializedScorer::TermIntersection(term_scorers) => {
|
||||
let boxed_scorers: Vec<Box<dyn Scorer>> = term_scorers
|
||||
.into_iter()
|
||||
.map(|s| Box::new(s) as Box<dyn Scorer>)
|
||||
.collect();
|
||||
intersect_scorers(boxed_scorers, num_docs)
|
||||
}
|
||||
SpecializedScorer::Other(scorer) => scorer,
|
||||
}
|
||||
@@ -297,14 +311,43 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
|
||||
// Result depends entirely on MUST + any removed AllScorers.
|
||||
let combined_all_scorer_count = must_special_scorer_counts.num_all_scorers
|
||||
+ should_special_scorer_counts.num_all_scorers;
|
||||
let boxed_scorer: Box<dyn Scorer> = effective_must_scorer(
|
||||
must_scorers,
|
||||
combined_all_scorer_count,
|
||||
reader.max_doc(),
|
||||
num_docs,
|
||||
)
|
||||
.unwrap_or_else(|| Box::new(EmptyScorer));
|
||||
SpecializedScorer::Other(boxed_scorer)
|
||||
|
||||
// Try to detect a pure TermScorer intersection for block-max optimization.
|
||||
// Preconditions: no removed AllScorers, at least 2 scorers, all TermScorer
|
||||
// with frequency reading enabled.
|
||||
if combined_all_scorer_count == 0
|
||||
&& must_scorers.len() >= 2
|
||||
&& must_scorers.iter().all(|s| s.is::<TermScorer>())
|
||||
{
|
||||
let term_scorers: Vec<TermScorer> = must_scorers
|
||||
.into_iter()
|
||||
.map(|s| *(s.downcast::<TermScorer>().map_err(|_| ()).unwrap()))
|
||||
.collect();
|
||||
if term_scorers
|
||||
.iter()
|
||||
.all(|s| s.freq_reading_option() == FreqReadingOption::ReadFreq)
|
||||
{
|
||||
SpecializedScorer::TermIntersection(term_scorers)
|
||||
} else {
|
||||
let must_scorers: Vec<Box<dyn Scorer>> = term_scorers
|
||||
.into_iter()
|
||||
.map(|s| Box::new(s) as Box<dyn Scorer>)
|
||||
.collect();
|
||||
let boxed_scorer: Box<dyn Scorer> =
|
||||
effective_must_scorer(must_scorers, 0, reader.max_doc(), num_docs)
|
||||
.unwrap_or_else(|| Box::new(EmptyScorer));
|
||||
SpecializedScorer::Other(boxed_scorer)
|
||||
}
|
||||
} else {
|
||||
let boxed_scorer: Box<dyn Scorer> = effective_must_scorer(
|
||||
must_scorers,
|
||||
combined_all_scorer_count,
|
||||
reader.max_doc(),
|
||||
num_docs,
|
||||
)
|
||||
.unwrap_or_else(|| Box::new(EmptyScorer));
|
||||
SpecializedScorer::Other(boxed_scorer)
|
||||
}
|
||||
}
|
||||
(ShouldScorersCombinationMethod::Optional(should_scorer), must_scorers) => {
|
||||
// Optional SHOULD: contributes to scoring but not required for matching.
|
||||
@@ -463,15 +506,21 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
|
||||
callback: &mut dyn FnMut(DocId, Score),
|
||||
) -> crate::Result<()> {
|
||||
let scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn)?;
|
||||
let num_docs = reader.num_docs();
|
||||
match scorer {
|
||||
SpecializedScorer::TermUnion(term_scorers) => {
|
||||
let mut union_scorer = BufferedUnionScorer::build(
|
||||
term_scorers,
|
||||
&self.score_combiner_fn,
|
||||
reader.num_docs(),
|
||||
);
|
||||
let mut union_scorer =
|
||||
BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn, num_docs);
|
||||
for_each_scorer(&mut union_scorer, callback);
|
||||
}
|
||||
SpecializedScorer::TermIntersection(term_scorers) => {
|
||||
let boxed_scorers: Vec<Box<dyn Scorer>> = term_scorers
|
||||
.into_iter()
|
||||
.map(|term_scorer| Box::new(term_scorer) as Box<dyn Scorer>)
|
||||
.collect();
|
||||
let mut intersection = intersect_scorers(boxed_scorers, num_docs);
|
||||
for_each_scorer(intersection.as_mut(), callback);
|
||||
}
|
||||
SpecializedScorer::Other(mut scorer) => {
|
||||
for_each_scorer(scorer.as_mut(), callback);
|
||||
}
|
||||
@@ -485,17 +534,23 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
|
||||
callback: &mut dyn FnMut(&[DocId]),
|
||||
) -> crate::Result<()> {
|
||||
let scorer = self.complex_scorer(reader, 1.0, || DoNothingCombiner)?;
|
||||
let num_docs = reader.num_docs();
|
||||
let mut buffer = [0u32; COLLECT_BLOCK_BUFFER_LEN];
|
||||
|
||||
match scorer {
|
||||
SpecializedScorer::TermUnion(term_scorers) => {
|
||||
let mut union_scorer = BufferedUnionScorer::build(
|
||||
term_scorers,
|
||||
&self.score_combiner_fn,
|
||||
reader.num_docs(),
|
||||
);
|
||||
let mut union_scorer =
|
||||
BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn, num_docs);
|
||||
for_each_docset_buffered(&mut union_scorer, &mut buffer, callback);
|
||||
}
|
||||
SpecializedScorer::TermIntersection(term_scorers) => {
|
||||
let boxed_scorers: Vec<Box<dyn Scorer>> = term_scorers
|
||||
.into_iter()
|
||||
.map(|term_scorer| Box::new(term_scorer) as Box<dyn Scorer>)
|
||||
.collect();
|
||||
let mut intersection = intersect_scorers(boxed_scorers, num_docs);
|
||||
for_each_docset_buffered(intersection.as_mut(), &mut buffer, callback);
|
||||
}
|
||||
SpecializedScorer::Other(mut scorer) => {
|
||||
for_each_docset_buffered(scorer.as_mut(), &mut buffer, callback);
|
||||
}
|
||||
@@ -524,6 +579,9 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
|
||||
SpecializedScorer::TermUnion(term_scorers) => {
|
||||
super::block_wand(term_scorers, threshold, callback);
|
||||
}
|
||||
SpecializedScorer::TermIntersection(term_scorers) => {
|
||||
super::block_wand_intersection(term_scorers, threshold, callback);
|
||||
}
|
||||
SpecializedScorer::Other(mut scorer) => {
|
||||
for_each_pruning_scorer(scorer.as_mut(), threshold, callback);
|
||||
}
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
mod block_wand;
|
||||
mod block_wand_intersection;
|
||||
mod block_wand_union;
|
||||
mod boolean_query;
|
||||
mod boolean_weight;
|
||||
|
||||
pub(crate) use self::block_wand::{block_wand, block_wand_single_scorer};
|
||||
pub(crate) use self::block_wand_intersection::block_wand_intersection;
|
||||
pub(crate) use self::block_wand_union::{block_wand, block_wand_single_scorer};
|
||||
pub use self::boolean_query::BooleanQuery;
|
||||
pub use self::boolean_weight::BooleanWeight;
|
||||
|
||||
|
||||
@@ -112,6 +112,14 @@ impl<S: Scorer> DocSet for BoostScorer<S> {
|
||||
self.underlying.fill_buffer(buffer)
|
||||
}
|
||||
|
||||
fn fill_buffer_up_to(
|
||||
&mut self,
|
||||
horizon: DocId,
|
||||
buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
|
||||
) -> usize {
|
||||
self.underlying.fill_buffer_up_to(horizon, buffer)
|
||||
}
|
||||
|
||||
fn doc(&self) -> u32 {
|
||||
self.underlying.doc()
|
||||
}
|
||||
@@ -138,6 +146,27 @@ impl<S: Scorer> Scorer for BoostScorer<S> {
|
||||
fn score(&mut self) -> Score {
|
||||
self.underlying.score() * self.boost
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn can_score_doc(&self) -> bool {
|
||||
self.underlying.can_score_doc()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn score_doc(&mut self, doc: DocId, term_freq: u32) -> Score {
|
||||
self.underlying.score_doc(doc, term_freq) * self.boost
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn fill_buffer_up_to_with_term_freqs(
|
||||
&mut self,
|
||||
horizon: DocId,
|
||||
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
|
||||
term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
|
||||
) -> usize {
|
||||
self.underlying
|
||||
.fill_buffer_up_to_with_term_freqs(horizon, docs, term_freqs)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -141,6 +141,16 @@ impl<TDocSet: DocSet + 'static> Scorer for ConstScorer<TDocSet> {
|
||||
fn score(&mut self) -> Score {
|
||||
self.score
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn can_score_doc(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score {
|
||||
self.score
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -315,6 +315,20 @@ mod tests {
|
||||
fn score(&mut self) -> Score {
|
||||
self.foo.get(self.cursor).map(|x| x.1).unwrap_or(0.0)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn can_score_doc(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn score_doc(&mut self, doc: DocId, _term_freq: u32) -> Score {
|
||||
self.foo
|
||||
.iter()
|
||||
.find(|(candidate_doc, _)| *candidate_doc == doc)
|
||||
.map(|(_, score)| *score)
|
||||
.unwrap_or(0.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -59,6 +59,16 @@ impl Scorer for EmptyScorer {
|
||||
fn score(&mut self) -> Score {
|
||||
0.0
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn can_score_doc(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -1,5 +1,40 @@
|
||||
use crate::docset::{DocSet, TERMINATED};
|
||||
use crate::query::Scorer;
|
||||
use crate::Score;
|
||||
use crate::{DocId, Score};
|
||||
|
||||
struct ScoreOnlyScorer {
|
||||
doc: DocId,
|
||||
score: Score,
|
||||
}
|
||||
|
||||
impl DocSet for ScoreOnlyScorer {
|
||||
fn advance(&mut self) -> DocId {
|
||||
self.doc = TERMINATED;
|
||||
TERMINATED
|
||||
}
|
||||
|
||||
fn doc(&self) -> DocId {
|
||||
self.doc
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> u32 {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
impl Scorer for ScoreOnlyScorer {
|
||||
fn score(&mut self) -> Score {
|
||||
self.score
|
||||
}
|
||||
|
||||
fn can_score_doc(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score {
|
||||
self.score
|
||||
}
|
||||
}
|
||||
|
||||
/// The `ScoreCombiner` trait defines how to compute
|
||||
/// an overall score given a list of scores.
|
||||
@@ -10,6 +45,17 @@ pub trait ScoreCombiner: Default + Clone + Send + Copy + 'static {
|
||||
/// or not.
|
||||
fn update<TScorer: Scorer>(&mut self, scorer: &mut TScorer);
|
||||
|
||||
/// Aggregates the score combiner with an already computed score.
|
||||
fn update_score(&mut self, doc: DocId, score: Score) {
|
||||
let mut scorer = ScoreOnlyScorer { doc, score };
|
||||
self.update(&mut scorer);
|
||||
}
|
||||
|
||||
/// Returns true if this combiner needs scorer scores to compute its state.
|
||||
fn requires_scoring() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// Clears the score combiner state back to its initial state.
|
||||
fn clear(&mut self);
|
||||
|
||||
@@ -27,6 +73,12 @@ pub struct DoNothingCombiner;
|
||||
impl ScoreCombiner for DoNothingCombiner {
|
||||
fn update<TScorer: Scorer>(&mut self, _scorer: &mut TScorer) {}
|
||||
|
||||
fn update_score(&mut self, _doc: DocId, _score: Score) {}
|
||||
|
||||
fn requires_scoring() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn clear(&mut self) {}
|
||||
|
||||
#[inline]
|
||||
@@ -42,10 +94,16 @@ pub struct SumCombiner {
|
||||
}
|
||||
|
||||
impl ScoreCombiner for SumCombiner {
|
||||
#[inline]
|
||||
fn update<TScorer: Scorer>(&mut self, scorer: &mut TScorer) {
|
||||
self.score += scorer.score();
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn update_score(&mut self, _doc: DocId, score: Score) {
|
||||
self.score += score;
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.score = 0.0;
|
||||
}
|
||||
@@ -77,12 +135,19 @@ impl DisjunctionMaxCombiner {
|
||||
}
|
||||
|
||||
impl ScoreCombiner for DisjunctionMaxCombiner {
|
||||
#[inline]
|
||||
fn update<TScorer: Scorer>(&mut self, scorer: &mut TScorer) {
|
||||
let score = scorer.score();
|
||||
self.max = Score::max(score, self.max);
|
||||
self.sum += score;
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn update_score(&mut self, _doc: DocId, score: Score) {
|
||||
self.max = Score::max(score, self.max);
|
||||
self.sum += score;
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.max = 0.0;
|
||||
self.sum = 0.0;
|
||||
|
||||
@@ -2,8 +2,8 @@ use std::ops::DerefMut;
|
||||
|
||||
use downcast_rs::impl_downcast;
|
||||
|
||||
use crate::docset::DocSet;
|
||||
use crate::Score;
|
||||
use crate::docset::{DocSet, COLLECT_BLOCK_BUFFER_LEN};
|
||||
use crate::{DocId, Score};
|
||||
|
||||
/// Scored set of documents matching a query within a specific segment.
|
||||
///
|
||||
@@ -13,6 +13,36 @@ pub trait Scorer: downcast_rs::Downcast + DocSet + 'static {
|
||||
///
|
||||
/// This method will perform a bit of computation and is not cached.
|
||||
fn score(&mut self) -> Score;
|
||||
|
||||
/// Returns true if [`Scorer::score_doc`] can score buffered docs without
|
||||
/// repositioning the scorer.
|
||||
///
|
||||
/// Scorers whose [`Scorer::score_doc`] needs term frequencies must also override
|
||||
/// [`Scorer::fill_buffer_up_to_with_term_freqs`].
|
||||
fn can_score_doc(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
/// Returns the score for `doc` with its term frequency.
|
||||
fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score {
|
||||
panic!(
|
||||
"score_doc is not supported by this scorer. You need check can_score_doc() before \
|
||||
calling this method."
|
||||
)
|
||||
}
|
||||
|
||||
/// Fills docs up to `horizon`.
|
||||
///
|
||||
/// The default implementation does not fill `term_freqs`. Scorers whose
|
||||
/// [`Scorer::score_doc`] reads term frequencies must override this method.
|
||||
fn fill_buffer_up_to_with_term_freqs(
|
||||
&mut self,
|
||||
horizon: DocId,
|
||||
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
|
||||
_term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
|
||||
) -> usize {
|
||||
DocSet::fill_buffer_up_to(self, horizon, docs)
|
||||
}
|
||||
}
|
||||
|
||||
impl_downcast!(Scorer);
|
||||
@@ -22,4 +52,25 @@ impl Scorer for Box<dyn Scorer> {
|
||||
fn score(&mut self) -> Score {
|
||||
self.deref_mut().score()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn can_score_doc(&self) -> bool {
|
||||
self.as_ref().can_score_doc()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn score_doc(&mut self, doc: DocId, term_freq: u32) -> Score {
|
||||
self.deref_mut().score_doc(doc, term_freq)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn fill_buffer_up_to_with_term_freqs(
|
||||
&mut self,
|
||||
horizon: DocId,
|
||||
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
|
||||
term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
|
||||
) -> usize {
|
||||
self.deref_mut()
|
||||
.fill_buffer_up_to_with_term_freqs(horizon, docs, term_freqs)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::docset::DocSet;
|
||||
use crate::docset::{DocSet, COLLECT_BLOCK_BUFFER_LEN};
|
||||
use crate::fieldnorm::FieldNormReader;
|
||||
use crate::postings::{FreqReadingOption, Postings, SegmentPostings};
|
||||
use crate::postings::{BlockSegmentPostings, FreqReadingOption, Postings, SegmentPostings};
|
||||
use crate::query::bm25::Bm25Weight;
|
||||
use crate::query::{Explanation, Scorer};
|
||||
use crate::{DocId, Score};
|
||||
@@ -95,6 +95,21 @@ impl TermScorer {
|
||||
pub fn last_doc_in_block(&self) -> DocId {
|
||||
self.postings.block_cursor.skip_reader().last_doc_in_block()
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the underlying block cursor.
|
||||
pub(crate) fn block_cursor(&mut self) -> &mut BlockSegmentPostings {
|
||||
&mut self.postings.block_cursor
|
||||
}
|
||||
|
||||
/// Returns a reference to the fieldnorm reader for batch lookups.
|
||||
pub(crate) fn fieldnorm_reader(&self) -> &FieldNormReader {
|
||||
&self.fieldnorm_reader
|
||||
}
|
||||
|
||||
/// Returns a reference to the BM25 weight for batch score computation.
|
||||
pub(crate) fn bm25_weight(&self) -> &Bm25Weight {
|
||||
&self.similarity_weight
|
||||
}
|
||||
}
|
||||
|
||||
impl DocSet for TermScorer {
|
||||
@@ -132,6 +147,27 @@ impl Scorer for TermScorer {
|
||||
let term_freq = self.term_freq();
|
||||
self.similarity_weight.score(fieldnorm_id, term_freq)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn can_score_doc(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn score_doc(&mut self, doc: DocId, term_freq: u32) -> Score {
|
||||
let fieldnorm_id = self.fieldnorm_reader.fieldnorm_id(doc);
|
||||
self.similarity_weight.score(fieldnorm_id, term_freq)
|
||||
}
|
||||
|
||||
fn fill_buffer_up_to_with_term_freqs(
|
||||
&mut self,
|
||||
horizon: DocId,
|
||||
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
|
||||
term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
|
||||
) -> usize {
|
||||
self.postings
|
||||
.fill_buffer_up_to_with_term_freqs(horizon, docs, term_freqs)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -10,23 +10,7 @@ use crate::{DocId, Score};
|
||||
// of upcoming document IDs (the "horizon").
|
||||
const HORIZON_NUM_TINYBITSETS: usize = HORIZON as usize / 64;
|
||||
const HORIZON: u32 = 64u32 * 64u32;
|
||||
|
||||
// `drain_filter` is not stable yet.
|
||||
// This function is similar except that it does is not unstable, and
|
||||
// it does not keep the original vector ordering.
|
||||
//
|
||||
// Elements are dropped and not yielded.
|
||||
fn unordered_drain_filter<T, P>(v: &mut Vec<T>, mut predicate: P)
|
||||
where P: FnMut(&mut T) -> bool {
|
||||
let mut i = 0;
|
||||
while i < v.len() {
|
||||
if predicate(&mut v[i]) {
|
||||
v.swap_remove(i);
|
||||
} else {
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
const GROUPED_INSERT_MAX_BUCKET_SPAN: u32 = 2;
|
||||
|
||||
/// Creates a `DocSet` that iterate through the union of two or more `DocSet`s.
|
||||
pub struct BufferedUnionScorer<TScorer, TScoreCombiner = DoNothingCombiner> {
|
||||
@@ -53,31 +37,213 @@ pub struct BufferedUnionScorer<TScorer, TScoreCombiner = DoNothingCombiner> {
|
||||
score: Score,
|
||||
/// Number of documents in the segment.
|
||||
num_docs: u32,
|
||||
/// Scratch buffer for block-based refill.
|
||||
refill_docs: [DocId; COLLECT_BLOCK_BUFFER_LEN],
|
||||
/// Scratch buffer for term frequencies matching `refill_docs`.
|
||||
refill_term_freqs: [u32; COLLECT_BLOCK_BUFFER_LEN],
|
||||
/// Whether all children support scoring buffered docs after advancing.
|
||||
use_score_doc_refill: bool,
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn union_bucket(
|
||||
bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS],
|
||||
bucket_pos: u32,
|
||||
tinyset: TinySet,
|
||||
) {
|
||||
debug_assert!((bucket_pos as usize) < HORIZON_NUM_TINYBITSETS);
|
||||
// `bucket` comes from a doc delta below `HORIZON`; there are exactly
|
||||
// `HORIZON / 64` buckets in the refill window.
|
||||
bitsets[bucket_pos as usize] = bitsets[bucket_pos as usize].union(tinyset);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn insert_delta(bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS], delta: DocId) {
|
||||
debug_assert!(delta < HORIZON);
|
||||
// `delta < HORIZON`, so `delta / 64` is in the bitset array. The bit
|
||||
// offset is reduced modulo 64 before being inserted in the TinySet.
|
||||
bitsets[delta as usize / 64].insert_mut(delta % 64u32);
|
||||
}
|
||||
|
||||
fn insert_and_score_full_buffer<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
|
||||
scorer: &mut TScorer,
|
||||
docs: &[DocId; COLLECT_BLOCK_BUFFER_LEN],
|
||||
term_freqs: &[u32; COLLECT_BLOCK_BUFFER_LEN],
|
||||
bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS],
|
||||
score_combiner: &mut [TScoreCombiner; HORIZON as usize],
|
||||
min_doc: DocId,
|
||||
) {
|
||||
debug_assert!(docs.windows(2).all(|pair| pair[0] < pair[1]));
|
||||
debug_assert!(docs[COLLECT_BLOCK_BUFFER_LEN - 1] - min_doc < HORIZON);
|
||||
|
||||
let first_delta = docs[0] - min_doc;
|
||||
let last_delta = docs[COLLECT_BLOCK_BUFFER_LEN - 1] - min_doc;
|
||||
let first_bucket = first_delta / 64;
|
||||
let last_bucket = last_delta / 64;
|
||||
|
||||
// Common for very dense scorers: 64 distinct doc ids in one 64-doc bucket
|
||||
// means all bits in that bucket are present.
|
||||
if first_bucket == last_bucket {
|
||||
union_bucket(bitsets, first_bucket, TinySet::full());
|
||||
score_full_buffer(scorer, docs, term_freqs, score_combiner, min_doc);
|
||||
return;
|
||||
}
|
||||
|
||||
// 64 sorted distinct integers spanning exactly 64 values are consecutive.
|
||||
// If they cross a TinySet boundary, this is just the suffix of the first
|
||||
// bucket plus the prefix of the second bucket.
|
||||
if last_delta - first_delta == COLLECT_BLOCK_BUFFER_LEN as u32 - 1 {
|
||||
union_bucket(
|
||||
bitsets,
|
||||
first_bucket,
|
||||
TinySet::range_greater_or_equal(first_delta % 64u32),
|
||||
);
|
||||
union_bucket(
|
||||
bitsets,
|
||||
last_bucket,
|
||||
TinySet::range_lower((last_delta + 1) % 64u32),
|
||||
);
|
||||
score_full_buffer(scorer, docs, term_freqs, score_combiner, min_doc);
|
||||
return;
|
||||
}
|
||||
|
||||
// Grouping wins only for very dense buffers that hit the same TinySet many
|
||||
// times. Once the 64 docs are spread farther, a straight pass is cheaper.
|
||||
if last_bucket - first_bucket <= GROUPED_INSERT_MAX_BUCKET_SPAN {
|
||||
let mut bucket = first_bucket;
|
||||
let mut tinyset = TinySet::empty();
|
||||
for (&doc, &term_freq) in docs.iter().zip(term_freqs.iter()) {
|
||||
let delta = doc - min_doc;
|
||||
let delta_bucket = delta / 64;
|
||||
if delta_bucket != bucket {
|
||||
union_bucket(bitsets, bucket, tinyset);
|
||||
bucket = delta_bucket;
|
||||
tinyset = TinySet::empty();
|
||||
}
|
||||
tinyset.insert_mut(delta % 64u32);
|
||||
let score = scorer.score_doc(doc, term_freq);
|
||||
update_score_combiner(score_combiner, delta, doc, score);
|
||||
}
|
||||
union_bucket(bitsets, bucket, tinyset);
|
||||
} else {
|
||||
for (&doc, &term_freq) in docs.iter().zip(term_freqs.iter()) {
|
||||
let delta = doc - min_doc;
|
||||
insert_delta(bitsets, delta);
|
||||
// TODO: score_doc access the field_norm reader for each _term_, instead of once per
|
||||
// doc. We could optimize this by caching the field norm for the doc, and
|
||||
// reusing it for all terms in the doc.
|
||||
let score = scorer.score_doc(doc, term_freq);
|
||||
update_score_combiner(score_combiner, delta, doc, score);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn update_score_combiner<TScoreCombiner: ScoreCombiner>(
|
||||
score_combiner: &mut [TScoreCombiner; HORIZON as usize],
|
||||
delta: DocId,
|
||||
doc: DocId,
|
||||
score: Score,
|
||||
) {
|
||||
debug_assert!(delta < HORIZON);
|
||||
// Full and partial refill only buffer docs below `horizon`, so their
|
||||
// deltas are always in the score-combiner window.
|
||||
score_combiner[delta as usize].update_score(doc, score);
|
||||
}
|
||||
|
||||
fn score_full_buffer<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
|
||||
scorer: &mut TScorer,
|
||||
docs: &[DocId; COLLECT_BLOCK_BUFFER_LEN],
|
||||
term_freqs: &[u32; COLLECT_BLOCK_BUFFER_LEN],
|
||||
score_combiner: &mut [TScoreCombiner; HORIZON as usize],
|
||||
min_doc: DocId,
|
||||
) {
|
||||
for (&doc, &term_freq) in docs.iter().zip(term_freqs.iter()) {
|
||||
let score = scorer.score_doc(doc, term_freq);
|
||||
update_score_combiner(score_combiner, doc - min_doc, doc, score);
|
||||
}
|
||||
}
|
||||
|
||||
fn refill_scorer_with_score_docs<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
|
||||
scorer: &mut TScorer,
|
||||
bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS],
|
||||
score_combiner: &mut [TScoreCombiner; HORIZON as usize],
|
||||
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
|
||||
term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
|
||||
min_doc: DocId,
|
||||
horizon: DocId,
|
||||
) {
|
||||
loop {
|
||||
let len = scorer.fill_buffer_up_to_with_term_freqs(horizon, docs, term_freqs);
|
||||
if len == COLLECT_BLOCK_BUFFER_LEN {
|
||||
debug_assert!(docs[COLLECT_BLOCK_BUFFER_LEN - 1] != TERMINATED);
|
||||
debug_assert!(docs[COLLECT_BLOCK_BUFFER_LEN - 1] < horizon);
|
||||
insert_and_score_full_buffer(
|
||||
scorer,
|
||||
docs,
|
||||
term_freqs,
|
||||
bitsets,
|
||||
score_combiner,
|
||||
min_doc,
|
||||
);
|
||||
} else {
|
||||
for (&doc, &term_freq) in docs[..len].iter().zip(term_freqs[..len].iter()) {
|
||||
let delta = doc - min_doc;
|
||||
insert_delta(bitsets, delta);
|
||||
let score = scorer.score_doc(doc, term_freq);
|
||||
update_score_combiner(score_combiner, delta, doc, score);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn refill_scorer_from_current_doc<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
|
||||
scorer: &mut TScorer,
|
||||
bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS],
|
||||
score_combiner: &mut [TScoreCombiner; HORIZON as usize],
|
||||
min_doc: DocId,
|
||||
horizon: DocId,
|
||||
) {
|
||||
loop {
|
||||
let doc = scorer.doc();
|
||||
if doc >= horizon {
|
||||
break;
|
||||
}
|
||||
let delta = doc - min_doc;
|
||||
insert_delta(bitsets, delta);
|
||||
debug_assert!(delta < HORIZON);
|
||||
score_combiner[delta as usize].update(scorer);
|
||||
scorer.advance();
|
||||
}
|
||||
}
|
||||
|
||||
fn refill<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
|
||||
scorers: &mut Vec<TScorer>,
|
||||
bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS],
|
||||
score_combiner: &mut [TScoreCombiner; HORIZON as usize],
|
||||
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
|
||||
term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
|
||||
min_doc: DocId,
|
||||
use_score_doc_refill: bool,
|
||||
) {
|
||||
unordered_drain_filter(scorers, |scorer| {
|
||||
let horizon = min_doc + HORIZON;
|
||||
loop {
|
||||
let doc = scorer.doc();
|
||||
if doc >= horizon {
|
||||
return false;
|
||||
}
|
||||
// add this document
|
||||
let delta = doc - min_doc;
|
||||
bitsets[(delta / 64) as usize].insert_mut(delta % 64u32);
|
||||
score_combiner[delta as usize].update(scorer);
|
||||
if scorer.advance() == TERMINATED {
|
||||
// remove the docset, it has been entirely consumed.
|
||||
return true;
|
||||
}
|
||||
let horizon = min_doc + HORIZON;
|
||||
for scorer in scorers.iter_mut() {
|
||||
if use_score_doc_refill {
|
||||
refill_scorer_with_score_docs(
|
||||
scorer,
|
||||
bitsets,
|
||||
score_combiner,
|
||||
docs,
|
||||
term_freqs,
|
||||
min_doc,
|
||||
horizon,
|
||||
);
|
||||
} else {
|
||||
refill_scorer_from_current_doc(scorer, bitsets, score_combiner, min_doc, horizon);
|
||||
}
|
||||
});
|
||||
}
|
||||
scorers.retain(|scorer| scorer.doc() != TERMINATED);
|
||||
}
|
||||
|
||||
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer, TScoreCombiner> {
|
||||
@@ -87,6 +253,8 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer
|
||||
score_combiner_fn: impl FnOnce() -> TScoreCombiner,
|
||||
num_docs: u32,
|
||||
) -> BufferedUnionScorer<TScorer, TScoreCombiner> {
|
||||
let use_score_doc_refill =
|
||||
TScoreCombiner::requires_scoring() && docsets.iter().all(Scorer::can_score_doc);
|
||||
let non_empty_docsets: Vec<TScorer> = docsets
|
||||
.into_iter()
|
||||
.filter(|docset| docset.doc() != TERMINATED)
|
||||
@@ -100,6 +268,9 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer
|
||||
doc: 0,
|
||||
score: 0.0,
|
||||
num_docs,
|
||||
refill_docs: [TERMINATED; COLLECT_BLOCK_BUFFER_LEN],
|
||||
refill_term_freqs: [1u32; COLLECT_BLOCK_BUFFER_LEN],
|
||||
use_score_doc_refill,
|
||||
};
|
||||
if union.refill() {
|
||||
union.advance();
|
||||
@@ -120,7 +291,10 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer
|
||||
&mut self.docsets,
|
||||
&mut self.bitsets,
|
||||
&mut self.scores,
|
||||
&mut self.refill_docs,
|
||||
&mut self.refill_term_freqs,
|
||||
min_doc,
|
||||
self.use_score_doc_refill,
|
||||
);
|
||||
true
|
||||
} else {
|
||||
@@ -248,12 +422,12 @@ where
|
||||
|
||||
// The target is outside of the buffered horizon.
|
||||
// advance all docsets to a doc >= to the target.
|
||||
unordered_drain_filter(&mut self.docsets, |docset| {
|
||||
for docset in &mut self.docsets {
|
||||
if docset.doc() < target {
|
||||
docset.seek(target);
|
||||
}
|
||||
docset.doc() == TERMINATED
|
||||
});
|
||||
}
|
||||
self.docsets.retain(|docset| docset.doc() != TERMINATED);
|
||||
|
||||
// at this point all of the docsets
|
||||
// are positioned on a doc >= to the target.
|
||||
|
||||
@@ -10,6 +10,8 @@ pub use simple_union::SimpleUnion;
|
||||
mod tests {
|
||||
|
||||
use std::collections::BTreeSet;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
use common::BitSet;
|
||||
|
||||
@@ -18,8 +20,8 @@ mod tests {
|
||||
use crate::postings::tests::test_skip_against_unoptimized;
|
||||
use crate::query::score_combiner::DoNothingCombiner;
|
||||
use crate::query::union::bitset_union::BitSetPostingUnion;
|
||||
use crate::query::{BitSetDocSet, ConstScorer, VecDocSet};
|
||||
use crate::{tests, DocId};
|
||||
use crate::query::{BitSetDocSet, ConstScorer, Scorer, VecDocSet};
|
||||
use crate::{tests, DocId, Score};
|
||||
|
||||
fn vec_doc_set_from_docs_list(
|
||||
docs_list: &[Vec<DocId>],
|
||||
@@ -66,6 +68,61 @@ mod tests {
|
||||
}
|
||||
BitSetDocSet::from(doc_bitset)
|
||||
}
|
||||
|
||||
struct CountingScorer {
|
||||
docset: VecDocSet,
|
||||
score_calls: Arc<AtomicUsize>,
|
||||
score_doc_calls: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl CountingScorer {
|
||||
fn new(
|
||||
doc_ids: Vec<DocId>,
|
||||
score_calls: Arc<AtomicUsize>,
|
||||
score_doc_calls: Arc<AtomicUsize>,
|
||||
) -> Self {
|
||||
CountingScorer {
|
||||
docset: VecDocSet::from(doc_ids),
|
||||
score_calls,
|
||||
score_doc_calls,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DocSet for CountingScorer {
|
||||
fn advance(&mut self) -> DocId {
|
||||
self.docset.advance()
|
||||
}
|
||||
|
||||
fn seek(&mut self, target: DocId) -> DocId {
|
||||
self.docset.seek(target)
|
||||
}
|
||||
|
||||
fn doc(&self) -> DocId {
|
||||
self.docset.doc()
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> u32 {
|
||||
self.docset.size_hint()
|
||||
}
|
||||
}
|
||||
|
||||
impl Scorer for CountingScorer {
|
||||
fn score(&mut self) -> Score {
|
||||
self.score_calls.fetch_add(1, Ordering::SeqCst);
|
||||
1.0
|
||||
}
|
||||
|
||||
fn can_score_doc(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score {
|
||||
self.score_doc_calls.fetch_add(1, Ordering::SeqCst);
|
||||
1.0
|
||||
}
|
||||
}
|
||||
|
||||
fn aux_test_union(docs_list: &[Vec<DocId>]) {
|
||||
for constructor in [
|
||||
posting_list_union_from_docs_list,
|
||||
@@ -168,6 +225,22 @@ mod tests {
|
||||
]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_do_nothing_combiner_does_not_score_buffered_docs() {
|
||||
let score_calls = Arc::new(AtomicUsize::new(0));
|
||||
let score_doc_calls = Arc::new(AtomicUsize::new(0));
|
||||
let scorers = vec![
|
||||
CountingScorer::new(vec![1, 3, 5], score_calls.clone(), score_doc_calls.clone()),
|
||||
CountingScorer::new(vec![2, 3, 6], score_calls.clone(), score_doc_calls.clone()),
|
||||
];
|
||||
|
||||
let mut union = BufferedUnionScorer::build(scorers, DoNothingCombiner::default, 10);
|
||||
|
||||
assert_eq!(union.count_including_deleted(), 5);
|
||||
assert_eq!(score_calls.load(Ordering::SeqCst), 0);
|
||||
assert_eq!(score_doc_calls.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
fn test_aux_union_skip(docs_list: &[Vec<DocId>], skip_targets: Vec<DocId>) {
|
||||
for constructor in [
|
||||
posting_list_union_from_docs_list,
|
||||
|
||||
@@ -23,7 +23,7 @@ zstd-compression = ["zstd"]
|
||||
|
||||
[dev-dependencies]
|
||||
proptest = "1"
|
||||
criterion = { version = "0.5", default-features = false }
|
||||
criterion = { version = "0.8", default-features = false }
|
||||
names = "0.14"
|
||||
rand = "0.9"
|
||||
|
||||
|
||||
@@ -14,11 +14,8 @@ use itertools::Itertools;
|
||||
use tantivy_fst::Automaton;
|
||||
use tantivy_fst::automaton::AlwaysMatch;
|
||||
|
||||
use crate::sstable_index_v3::SSTableIndexV3Empty;
|
||||
use crate::streamer::{Streamer, StreamerBuilder};
|
||||
use crate::{
|
||||
BlockAddr, DeltaReader, Reader, SSTable, SSTableIndex, SSTableIndexV3, TermOrdinal, VoidSSTable,
|
||||
};
|
||||
use crate::{BlockAddr, DeltaReader, Reader, SSTable, SSTableIndex, TermOrdinal, VoidSSTable};
|
||||
|
||||
/// An SSTable is a sorted map that associates sorted `&[u8]` keys
|
||||
/// to any kind of typed values.
|
||||
@@ -288,33 +285,7 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
|
||||
let (sstable_slice, index_slice) = main_slice.split(index_offset as usize);
|
||||
let sstable_index_bytes = index_slice.read_bytes()?;
|
||||
|
||||
let sstable_index = match version {
|
||||
2 => SSTableIndex::V2(
|
||||
crate::sstable_index_v2::SSTableIndex::load(sstable_index_bytes).map_err(|_| {
|
||||
io::Error::new(io::ErrorKind::InvalidData, "SSTable corruption")
|
||||
})?,
|
||||
),
|
||||
3 => {
|
||||
let (sstable_index_bytes, mut footerv3_len_bytes) = sstable_index_bytes.rsplit(8);
|
||||
let store_offset = u64::deserialize(&mut footerv3_len_bytes)?;
|
||||
if store_offset != 0 {
|
||||
SSTableIndex::V3(
|
||||
SSTableIndexV3::load(sstable_index_bytes, store_offset).map_err(|_| {
|
||||
io::Error::new(io::ErrorKind::InvalidData, "SSTable corruption")
|
||||
})?,
|
||||
)
|
||||
} else {
|
||||
// if store_offset is zero, there is no index, so we build a pseudo-index
|
||||
// assuming a single block of sstable covering everything.
|
||||
SSTableIndex::V3Empty(SSTableIndexV3Empty::load(index_offset as usize))
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Err(io::Error::other(format!(
|
||||
"Unsupported sstable version, expected one of [2, 3], found {version}"
|
||||
)));
|
||||
}
|
||||
};
|
||||
let sstable_index = SSTableIndex::open(version, index_offset, sstable_index_bytes)?;
|
||||
|
||||
Ok(Dictionary {
|
||||
sstable_slice,
|
||||
@@ -525,10 +496,15 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
|
||||
|
||||
// Open the block for the first ordinal.
|
||||
let mut bytes = Vec::new();
|
||||
let mut current_block_addr = self.sstable_index.get_block_with_ord(ord);
|
||||
let (mut current_block_addr, block_id) = self.sstable_index.get_and_locate_with_ord(ord);
|
||||
let mut current_sstable_delta_reader =
|
||||
self.sstable_delta_reader_block(current_block_addr.clone())?;
|
||||
let mut current_block_ordinal = current_block_addr.first_ordinal;
|
||||
let mut current_block_end_bound = self
|
||||
.sstable_index
|
||||
.get_block(block_id + 1)
|
||||
.map(|block_addr| block_addr.first_ordinal)
|
||||
.unwrap_or(u64::MAX);
|
||||
|
||||
loop {
|
||||
// move to the ord inside the current block
|
||||
@@ -557,17 +533,19 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
|
||||
}
|
||||
};
|
||||
|
||||
// TODO optimization: it is silly to do a binary search to get the block every single
|
||||
// time.
|
||||
//
|
||||
// Check if block changed for new term_ord
|
||||
let new_block_addr = self.sstable_index.get_block_with_ord(next_ord);
|
||||
if new_block_addr != current_block_addr {
|
||||
if next_ord >= current_block_end_bound {
|
||||
let (new_block_addr, block_id) =
|
||||
self.sstable_index.get_and_locate_with_ord(next_ord);
|
||||
current_block_addr = new_block_addr;
|
||||
current_block_ordinal = current_block_addr.first_ordinal;
|
||||
current_sstable_delta_reader =
|
||||
self.sstable_delta_reader_block(current_block_addr.clone())?;
|
||||
bytes.clear();
|
||||
current_block_end_bound = self
|
||||
.sstable_index
|
||||
.get_block(block_id + 1)
|
||||
.map(|block_addr| block_addr.first_ordinal)
|
||||
.unwrap_or(u64::MAX)
|
||||
}
|
||||
ord = next_ord;
|
||||
}
|
||||
|
||||
319
sstable/src/index/mod.rs
Normal file
319
sstable/src/index/mod.rs
Normal file
@@ -0,0 +1,319 @@
|
||||
pub(crate) mod v2;
|
||||
pub(crate) mod v3;
|
||||
|
||||
use std::io::{self, Read, Write};
|
||||
use std::ops::Range;
|
||||
|
||||
use common::{BinarySerializable, FixedSize, OwnedBytes};
|
||||
use tantivy_fst::{Automaton, MapBuilder};
|
||||
|
||||
use crate::{TermOrdinal, common_prefix_len};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum SSTableIndex {
|
||||
V2(v2::SSTableIndex),
|
||||
V3(v3::SSTableIndexV3),
|
||||
V3Empty(v3::SSTableIndexV3Empty),
|
||||
}
|
||||
|
||||
impl SSTableIndex {
|
||||
pub(crate) fn open(
|
||||
version: u32,
|
||||
index_offset: u64,
|
||||
index_bytes: OwnedBytes,
|
||||
) -> io::Result<Self> {
|
||||
let index = match version {
|
||||
2 => {
|
||||
SSTableIndex::V2(v2::SSTableIndex::load(index_bytes).map_err(|_| {
|
||||
io::Error::new(io::ErrorKind::InvalidData, "SSTable corruption")
|
||||
})?)
|
||||
}
|
||||
3 => {
|
||||
let (index_bytes, mut footerv3_len_bytes) = index_bytes.rsplit(8);
|
||||
let store_offset = u64::deserialize(&mut footerv3_len_bytes)?;
|
||||
if store_offset != 0 {
|
||||
SSTableIndex::V3(v3::SSTableIndexV3::load(index_bytes, store_offset).map_err(
|
||||
|_| io::Error::new(io::ErrorKind::InvalidData, "SSTable corruption"),
|
||||
)?)
|
||||
} else {
|
||||
// if store_offset is zero, there is no index, so we build a pseudo-index
|
||||
// assuming a single block of sstable covering everything.
|
||||
SSTableIndex::V3Empty(v3::SSTableIndexV3Empty::load(index_offset as usize))
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Err(io::Error::other(format!(
|
||||
"Unsupported sstable version, expected one of [2, 3], found {version}"
|
||||
)));
|
||||
}
|
||||
};
|
||||
Ok(index)
|
||||
}
|
||||
|
||||
/// Get the [`BlockAddr`] of the requested block.
|
||||
pub(crate) fn get_block(&self, block_id: u64) -> Option<BlockAddr> {
|
||||
match self {
|
||||
SSTableIndex::V2(v2_index) => v2_index.get_block(block_id as usize),
|
||||
SSTableIndex::V3(v3_index) => v3_index.get_block(block_id),
|
||||
SSTableIndex::V3Empty(v3_empty) => v3_empty.get_block(block_id),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the block id of the block that would contain `key`.
|
||||
///
|
||||
/// Returns None if `key` is lexicographically after the last key recorded.
|
||||
pub(crate) fn locate_with_key(&self, key: &[u8]) -> Option<u64> {
|
||||
match self {
|
||||
SSTableIndex::V2(v2_index) => v2_index.locate_with_key(key).map(|i| i as u64),
|
||||
SSTableIndex::V3(v3_index) => v3_index.locate_with_key(key),
|
||||
SSTableIndex::V3Empty(v3_empty) => v3_empty.locate_with_key(key),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the [`BlockAddr`] of the block that would contain `key`.
|
||||
///
|
||||
/// Returns None if `key` is lexicographically after the last key recorded.
|
||||
pub fn get_block_with_key(&self, key: &[u8]) -> Option<BlockAddr> {
|
||||
match self {
|
||||
SSTableIndex::V2(v2_index) => v2_index.get_block_with_key(key),
|
||||
SSTableIndex::V3(v3_index) => v3_index.get_block_with_key(key),
|
||||
SSTableIndex::V3Empty(v3_empty) => v3_empty.get_block_with_key(key),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn locate_with_ord(&self, ord: TermOrdinal) -> u64 {
|
||||
match self {
|
||||
SSTableIndex::V2(v2_index) => v2_index.locate_with_ord(ord) as u64,
|
||||
SSTableIndex::V3(v3_index) => v3_index.locate_with_ord(ord),
|
||||
SSTableIndex::V3Empty(v3_empty) => v3_empty.locate_with_ord(ord),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the [`BlockAddr`] of the block containing the `ord`-th term.
|
||||
pub(crate) fn get_block_with_ord(&self, ord: TermOrdinal) -> BlockAddr {
|
||||
match self {
|
||||
SSTableIndex::V2(v2_index) => v2_index.get_block_with_ord(ord),
|
||||
SSTableIndex::V3(v3_index) => v3_index.get_block_with_ord(ord),
|
||||
SSTableIndex::V3Empty(v3_empty) => v3_empty.get_block_with_ord(ord),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_and_locate_with_ord(&self, ord: TermOrdinal) -> (BlockAddr, u64) {
|
||||
match self {
|
||||
SSTableIndex::V2(v2_index) => v2_index.get_and_locate_with_ord(ord),
|
||||
SSTableIndex::V3(v3_index) => v3_index.get_and_locate_with_ord(ord),
|
||||
SSTableIndex::V3Empty(v3_empty) => v3_empty.get_and_locate_with_ord(ord),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_block_for_automaton<'a>(
|
||||
&'a self,
|
||||
automaton: &'a impl Automaton,
|
||||
) -> impl Iterator<Item = (u64, BlockAddr)> + 'a {
|
||||
match self {
|
||||
SSTableIndex::V2(v2_index) => {
|
||||
BlockIter::V2(v2_index.get_block_for_automaton(automaton))
|
||||
}
|
||||
SSTableIndex::V3(v3_index) => {
|
||||
BlockIter::V3(v3_index.get_block_for_automaton(automaton))
|
||||
}
|
||||
SSTableIndex::V3Empty(v3_empty) => {
|
||||
BlockIter::V3Empty(std::iter::once((0, v3_empty.block_addr.clone())))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum BlockIter<V2, V3, T> {
|
||||
V2(V2),
|
||||
V3(V3),
|
||||
V3Empty(std::iter::Once<T>),
|
||||
}
|
||||
|
||||
impl<V2: Iterator<Item = T>, V3: Iterator<Item = T>, T> Iterator for BlockIter<V2, V3, T> {
|
||||
type Item = T;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
match self {
|
||||
BlockIter::V2(v2) => v2.next(),
|
||||
BlockIter::V3(v3) => v3.next(),
|
||||
BlockIter::V3Empty(once) => once.next(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Debug)]
|
||||
pub struct BlockAddr {
|
||||
pub first_ordinal: u64,
|
||||
pub byte_range: Range<usize>,
|
||||
}
|
||||
|
||||
impl BlockAddr {
|
||||
fn to_block_start(&self) -> BlockStartAddr {
|
||||
BlockStartAddr {
|
||||
first_ordinal: self.first_ordinal,
|
||||
byte_range_start: self.byte_range.start,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
struct BlockStartAddr {
|
||||
first_ordinal: u64,
|
||||
byte_range_start: usize,
|
||||
}
|
||||
|
||||
impl BlockStartAddr {
|
||||
fn to_block_addr(&self, byte_range_end: usize) -> BlockAddr {
|
||||
BlockAddr {
|
||||
first_ordinal: self.first_ordinal,
|
||||
byte_range: self.byte_range_start..byte_range_end,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct BlockMeta {
|
||||
/// Any byte string that is lexicographically greater or equal to
|
||||
/// the last key in the block,
|
||||
/// and yet strictly smaller than the first key in the next block.
|
||||
pub last_key_or_greater: Vec<u8>,
|
||||
pub block_addr: BlockAddr,
|
||||
}
|
||||
|
||||
impl BinarySerializable for BlockStartAddr {
|
||||
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
|
||||
let start = self.byte_range_start as u64;
|
||||
start.serialize(writer)?;
|
||||
self.first_ordinal.serialize(writer)
|
||||
}
|
||||
|
||||
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
|
||||
let byte_range_start = u64::deserialize(reader)? as usize;
|
||||
let first_ordinal = u64::deserialize(reader)?;
|
||||
Ok(BlockStartAddr {
|
||||
first_ordinal,
|
||||
byte_range_start,
|
||||
})
|
||||
}
|
||||
|
||||
// Provided method
|
||||
fn num_bytes(&self) -> u64 {
|
||||
BlockStartAddr::SIZE_IN_BYTES as u64
|
||||
}
|
||||
}
|
||||
|
||||
impl FixedSize for BlockStartAddr {
|
||||
const SIZE_IN_BYTES: usize = 2 * u64::SIZE_IN_BYTES;
|
||||
}
|
||||
|
||||
/// Given that left < right,
|
||||
/// mutates `left into a shorter byte string left'` that
|
||||
/// matches `left <= left' < right`.
|
||||
fn find_shorter_str_in_between(left: &mut Vec<u8>, right: &[u8]) {
|
||||
assert!(&left[..] < right);
|
||||
let common_len = common_prefix_len(left, right);
|
||||
if left.len() == common_len {
|
||||
return;
|
||||
}
|
||||
// It is possible to do one character shorter in some case,
|
||||
// but it is not worth the extra complexity
|
||||
for pos in (common_len + 1)..left.len() {
|
||||
if left[pos] != u8::MAX {
|
||||
left[pos] += 1;
|
||||
left.truncate(pos + 1);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct SSTableIndexBuilder {
|
||||
blocks: Vec<BlockMeta>,
|
||||
}
|
||||
|
||||
impl SSTableIndexBuilder {
|
||||
/// In order to make the index as light as possible, we
|
||||
/// try to find a shorter alternative to the last key of the last block
|
||||
/// that is still smaller than the next key.
|
||||
pub(crate) fn shorten_last_block_key_given_next_key(&mut self, next_key: &[u8]) {
|
||||
if let Some(last_block) = self.blocks.last_mut() {
|
||||
find_shorter_str_in_between(&mut last_block.last_key_or_greater, next_key);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_block(&mut self, last_key: &[u8], byte_range: Range<usize>, first_ordinal: u64) {
|
||||
self.blocks.push(BlockMeta {
|
||||
last_key_or_greater: last_key.to_vec(),
|
||||
block_addr: BlockAddr {
|
||||
byte_range,
|
||||
first_ordinal,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
pub fn serialize<W: std::io::Write>(&self, wrt: W) -> io::Result<u64> {
|
||||
if self.blocks.len() <= 1 {
|
||||
return Ok(0);
|
||||
}
|
||||
let counting_writer = common::CountingWriter::wrap(wrt);
|
||||
let mut map_builder = MapBuilder::new(counting_writer).map_err(fst_error_to_io_error)?;
|
||||
for (i, block) in self.blocks.iter().enumerate() {
|
||||
map_builder
|
||||
.insert(&block.last_key_or_greater, i as u64)
|
||||
.map_err(fst_error_to_io_error)?;
|
||||
}
|
||||
let counting_writer = map_builder.into_inner().map_err(fst_error_to_io_error)?;
|
||||
let written_bytes = counting_writer.written_bytes();
|
||||
let mut wrt = counting_writer.finish();
|
||||
|
||||
let mut block_store_writer = v3::BlockAddrStoreWriter::new();
|
||||
for block in &self.blocks {
|
||||
block_store_writer.write_block_meta(block.block_addr.clone())?;
|
||||
}
|
||||
block_store_writer.serialize(&mut wrt)?;
|
||||
|
||||
Ok(written_bytes)
|
||||
}
|
||||
}
|
||||
|
||||
fn fst_error_to_io_error(error: tantivy_fst::Error) -> io::Error {
|
||||
match error {
|
||||
tantivy_fst::Error::Fst(fst_error) => io::Error::other(fst_error),
|
||||
tantivy_fst::Error::Io(ioerror) => ioerror,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#[track_caller]
|
||||
fn test_find_shorter_str_in_between_aux(left: &[u8], right: &[u8]) {
|
||||
let mut left_buf = left.to_vec();
|
||||
super::find_shorter_str_in_between(&mut left_buf, right);
|
||||
assert!(left_buf.len() <= left.len());
|
||||
assert!(left <= &left_buf);
|
||||
assert!(&left_buf[..] < right);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_shorter_str_in_between() {
|
||||
test_find_shorter_str_in_between_aux(b"", b"hello");
|
||||
test_find_shorter_str_in_between_aux(b"abc", b"abcd");
|
||||
test_find_shorter_str_in_between_aux(b"abcd", b"abd");
|
||||
test_find_shorter_str_in_between_aux(&[0, 0, 0], &[1]);
|
||||
test_find_shorter_str_in_between_aux(&[0, 0, 0], &[0, 0, 1]);
|
||||
test_find_shorter_str_in_between_aux(&[0, 0, 255, 255, 255, 0u8], &[0, 1]);
|
||||
}
|
||||
|
||||
use proptest::prelude::*;
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(100))]
|
||||
#[test]
|
||||
fn test_proptest_find_shorter_str(left in any::<Vec<u8>>(), right in any::<Vec<u8>>()) {
|
||||
if left < right {
|
||||
test_find_shorter_str_in_between_aux(&left, &right);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -77,6 +77,13 @@ impl SSTableIndex {
|
||||
self.get_block(self.locate_with_ord(ord)).unwrap()
|
||||
}
|
||||
|
||||
pub(crate) fn get_and_locate_with_ord(&self, ord: TermOrdinal) -> (BlockAddr, u64) {
|
||||
let location = self.locate_with_ord(ord);
|
||||
// locate_with_ord always returns an index within range
|
||||
let block_addr = self.get_block(location).unwrap();
|
||||
(block_addr, location as u64)
|
||||
}
|
||||
|
||||
pub(crate) fn get_block_for_automaton<'a>(
|
||||
&'a self,
|
||||
automaton: &'a impl Automaton,
|
||||
@@ -1,106 +1,14 @@
|
||||
use std::io::{self, Read, Write};
|
||||
use std::ops::Range;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common::{BinarySerializable, FixedSize, OwnedBytes};
|
||||
use tantivy_bitpacker::{BitPacker, compute_num_bits};
|
||||
use tantivy_fst::raw::Fst;
|
||||
use tantivy_fst::{Automaton, IntoStreamer, Map, MapBuilder, Streamer};
|
||||
use tantivy_fst::{Automaton, IntoStreamer, Map, Streamer};
|
||||
|
||||
use super::{BlockAddr, BlockStartAddr};
|
||||
use crate::block_match_automaton::can_block_match_automaton;
|
||||
use crate::{SSTableDataCorruption, TermOrdinal, common_prefix_len};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum SSTableIndex {
|
||||
V2(crate::sstable_index_v2::SSTableIndex),
|
||||
V3(SSTableIndexV3),
|
||||
V3Empty(SSTableIndexV3Empty),
|
||||
}
|
||||
|
||||
impl SSTableIndex {
|
||||
/// Get the [`BlockAddr`] of the requested block.
|
||||
pub(crate) fn get_block(&self, block_id: u64) -> Option<BlockAddr> {
|
||||
match self {
|
||||
SSTableIndex::V2(v2_index) => v2_index.get_block(block_id as usize),
|
||||
SSTableIndex::V3(v3_index) => v3_index.get_block(block_id),
|
||||
SSTableIndex::V3Empty(v3_empty) => v3_empty.get_block(block_id),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the block id of the block that would contain `key`.
|
||||
///
|
||||
/// Returns None if `key` is lexicographically after the last key recorded.
|
||||
pub(crate) fn locate_with_key(&self, key: &[u8]) -> Option<u64> {
|
||||
match self {
|
||||
SSTableIndex::V2(v2_index) => v2_index.locate_with_key(key).map(|i| i as u64),
|
||||
SSTableIndex::V3(v3_index) => v3_index.locate_with_key(key),
|
||||
SSTableIndex::V3Empty(v3_empty) => v3_empty.locate_with_key(key),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the [`BlockAddr`] of the block that would contain `key`.
|
||||
///
|
||||
/// Returns None if `key` is lexicographically after the last key recorded.
|
||||
pub fn get_block_with_key(&self, key: &[u8]) -> Option<BlockAddr> {
|
||||
match self {
|
||||
SSTableIndex::V2(v2_index) => v2_index.get_block_with_key(key),
|
||||
SSTableIndex::V3(v3_index) => v3_index.get_block_with_key(key),
|
||||
SSTableIndex::V3Empty(v3_empty) => v3_empty.get_block_with_key(key),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn locate_with_ord(&self, ord: TermOrdinal) -> u64 {
|
||||
match self {
|
||||
SSTableIndex::V2(v2_index) => v2_index.locate_with_ord(ord) as u64,
|
||||
SSTableIndex::V3(v3_index) => v3_index.locate_with_ord(ord),
|
||||
SSTableIndex::V3Empty(v3_empty) => v3_empty.locate_with_ord(ord),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the [`BlockAddr`] of the block containing the `ord`-th term.
|
||||
pub(crate) fn get_block_with_ord(&self, ord: TermOrdinal) -> BlockAddr {
|
||||
match self {
|
||||
SSTableIndex::V2(v2_index) => v2_index.get_block_with_ord(ord),
|
||||
SSTableIndex::V3(v3_index) => v3_index.get_block_with_ord(ord),
|
||||
SSTableIndex::V3Empty(v3_empty) => v3_empty.get_block_with_ord(ord),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_block_for_automaton<'a>(
|
||||
&'a self,
|
||||
automaton: &'a impl Automaton,
|
||||
) -> impl Iterator<Item = (u64, BlockAddr)> + 'a {
|
||||
match self {
|
||||
SSTableIndex::V2(v2_index) => {
|
||||
BlockIter::V2(v2_index.get_block_for_automaton(automaton))
|
||||
}
|
||||
SSTableIndex::V3(v3_index) => {
|
||||
BlockIter::V3(v3_index.get_block_for_automaton(automaton))
|
||||
}
|
||||
SSTableIndex::V3Empty(v3_empty) => {
|
||||
BlockIter::V3Empty(std::iter::once((0, v3_empty.block_addr.clone())))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum BlockIter<V2, V3, T> {
|
||||
V2(V2),
|
||||
V3(V3),
|
||||
V3Empty(std::iter::Once<T>),
|
||||
}
|
||||
|
||||
impl<V2: Iterator<Item = T>, V3: Iterator<Item = T>, T> Iterator for BlockIter<V2, V3, T> {
|
||||
type Item = T;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
match self {
|
||||
BlockIter::V2(v2) => v2.next(),
|
||||
BlockIter::V3(v3) => v3.next(),
|
||||
BlockIter::V3Empty(once) => once.next(),
|
||||
}
|
||||
}
|
||||
}
|
||||
use crate::{SSTableDataCorruption, TermOrdinal};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SSTableIndexV3 {
|
||||
@@ -160,6 +68,11 @@ impl SSTableIndexV3 {
|
||||
self.block_addr_store.binary_search_ord(ord).1
|
||||
}
|
||||
|
||||
pub(crate) fn get_and_locate_with_ord(&self, ord: TermOrdinal) -> (BlockAddr, u64) {
|
||||
let (location, block_addr) = self.block_addr_store.binary_search_ord(ord);
|
||||
(block_addr, location)
|
||||
}
|
||||
|
||||
pub(crate) fn get_block_for_automaton<'a>(
|
||||
&'a self,
|
||||
automaton: &'a impl Automaton,
|
||||
@@ -216,7 +129,7 @@ impl<A: Automaton> Iterator for GetBlockForAutomaton<'_, A> {
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SSTableIndexV3Empty {
|
||||
block_addr: BlockAddr,
|
||||
pub block_addr: BlockAddr,
|
||||
}
|
||||
|
||||
impl SSTableIndexV3Empty {
|
||||
@@ -230,8 +143,8 @@ impl SSTableIndexV3Empty {
|
||||
}
|
||||
|
||||
/// Get the [`BlockAddr`] of the requested block.
|
||||
pub(crate) fn get_block(&self, _block_id: u64) -> Option<BlockAddr> {
|
||||
Some(self.block_addr.clone())
|
||||
pub(crate) fn get_block(&self, block_id: u64) -> Option<BlockAddr> {
|
||||
(block_id == 0).then(|| self.block_addr.clone())
|
||||
}
|
||||
|
||||
/// Get the block id of the block that would contain `key`.
|
||||
@@ -256,146 +169,9 @@ impl SSTableIndexV3Empty {
|
||||
pub(crate) fn get_block_with_ord(&self, _ord: TermOrdinal) -> BlockAddr {
|
||||
self.block_addr.clone()
|
||||
}
|
||||
}
|
||||
#[derive(Clone, Eq, PartialEq, Debug)]
|
||||
pub struct BlockAddr {
|
||||
pub first_ordinal: u64,
|
||||
pub byte_range: Range<usize>,
|
||||
}
|
||||
|
||||
impl BlockAddr {
|
||||
fn to_block_start(&self) -> BlockStartAddr {
|
||||
BlockStartAddr {
|
||||
first_ordinal: self.first_ordinal,
|
||||
byte_range_start: self.byte_range.start,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
struct BlockStartAddr {
|
||||
first_ordinal: u64,
|
||||
byte_range_start: usize,
|
||||
}
|
||||
|
||||
impl BlockStartAddr {
|
||||
fn to_block_addr(&self, byte_range_end: usize) -> BlockAddr {
|
||||
BlockAddr {
|
||||
first_ordinal: self.first_ordinal,
|
||||
byte_range: self.byte_range_start..byte_range_end,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct BlockMeta {
|
||||
/// Any byte string that is lexicographically greater or equal to
|
||||
/// the last key in the block,
|
||||
/// and yet strictly smaller than the first key in the next block.
|
||||
pub last_key_or_greater: Vec<u8>,
|
||||
pub block_addr: BlockAddr,
|
||||
}
|
||||
|
||||
impl BinarySerializable for BlockStartAddr {
|
||||
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
|
||||
let start = self.byte_range_start as u64;
|
||||
start.serialize(writer)?;
|
||||
self.first_ordinal.serialize(writer)
|
||||
}
|
||||
|
||||
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
|
||||
let byte_range_start = u64::deserialize(reader)? as usize;
|
||||
let first_ordinal = u64::deserialize(reader)?;
|
||||
Ok(BlockStartAddr {
|
||||
first_ordinal,
|
||||
byte_range_start,
|
||||
})
|
||||
}
|
||||
|
||||
// Provided method
|
||||
fn num_bytes(&self) -> u64 {
|
||||
BlockStartAddr::SIZE_IN_BYTES as u64
|
||||
}
|
||||
}
|
||||
|
||||
impl FixedSize for BlockStartAddr {
|
||||
const SIZE_IN_BYTES: usize = 2 * u64::SIZE_IN_BYTES;
|
||||
}
|
||||
|
||||
/// Given that left < right,
|
||||
/// mutates `left into a shorter byte string left'` that
|
||||
/// matches `left <= left' < right`.
|
||||
fn find_shorter_str_in_between(left: &mut Vec<u8>, right: &[u8]) {
|
||||
assert!(&left[..] < right);
|
||||
let common_len = common_prefix_len(left, right);
|
||||
if left.len() == common_len {
|
||||
return;
|
||||
}
|
||||
// It is possible to do one character shorter in some case,
|
||||
// but it is not worth the extra complexity
|
||||
for pos in (common_len + 1)..left.len() {
|
||||
if left[pos] != u8::MAX {
|
||||
left[pos] += 1;
|
||||
left.truncate(pos + 1);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct SSTableIndexBuilder {
|
||||
blocks: Vec<BlockMeta>,
|
||||
}
|
||||
|
||||
impl SSTableIndexBuilder {
|
||||
/// In order to make the index as light as possible, we
|
||||
/// try to find a shorter alternative to the last key of the last block
|
||||
/// that is still smaller than the next key.
|
||||
pub(crate) fn shorten_last_block_key_given_next_key(&mut self, next_key: &[u8]) {
|
||||
if let Some(last_block) = self.blocks.last_mut() {
|
||||
find_shorter_str_in_between(&mut last_block.last_key_or_greater, next_key);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_block(&mut self, last_key: &[u8], byte_range: Range<usize>, first_ordinal: u64) {
|
||||
self.blocks.push(BlockMeta {
|
||||
last_key_or_greater: last_key.to_vec(),
|
||||
block_addr: BlockAddr {
|
||||
byte_range,
|
||||
first_ordinal,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
pub fn serialize<W: std::io::Write>(&self, wrt: W) -> io::Result<u64> {
|
||||
if self.blocks.len() <= 1 {
|
||||
return Ok(0);
|
||||
}
|
||||
let counting_writer = common::CountingWriter::wrap(wrt);
|
||||
let mut map_builder = MapBuilder::new(counting_writer).map_err(fst_error_to_io_error)?;
|
||||
for (i, block) in self.blocks.iter().enumerate() {
|
||||
map_builder
|
||||
.insert(&block.last_key_or_greater, i as u64)
|
||||
.map_err(fst_error_to_io_error)?;
|
||||
}
|
||||
let counting_writer = map_builder.into_inner().map_err(fst_error_to_io_error)?;
|
||||
let written_bytes = counting_writer.written_bytes();
|
||||
let mut wrt = counting_writer.finish();
|
||||
|
||||
let mut block_store_writer = BlockAddrStoreWriter::new();
|
||||
for block in &self.blocks {
|
||||
block_store_writer.write_block_meta(block.block_addr.clone())?;
|
||||
}
|
||||
block_store_writer.serialize(&mut wrt)?;
|
||||
|
||||
Ok(written_bytes)
|
||||
}
|
||||
}
|
||||
|
||||
fn fst_error_to_io_error(error: tantivy_fst::Error) -> io::Error {
|
||||
match error {
|
||||
tantivy_fst::Error::Fst(fst_error) => io::Error::other(fst_error),
|
||||
tantivy_fst::Error::Io(ioerror) => ioerror,
|
||||
pub(crate) fn get_and_locate_with_ord(&self, _ord: TermOrdinal) -> (BlockAddr, u64) {
|
||||
(self.block_addr.clone(), 0)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -647,14 +423,14 @@ fn binary_search(max: u64, cmp_fn: impl Fn(u64) -> std::cmp::Ordering) -> Result
|
||||
Err(left)
|
||||
}
|
||||
|
||||
struct BlockAddrStoreWriter {
|
||||
pub(crate) struct BlockAddrStoreWriter {
|
||||
buffer_block_metas: Vec<u8>,
|
||||
buffer_addrs: Vec<u8>,
|
||||
block_addrs: Vec<BlockAddr>,
|
||||
}
|
||||
|
||||
impl BlockAddrStoreWriter {
|
||||
fn new() -> Self {
|
||||
pub(crate) fn new() -> Self {
|
||||
BlockAddrStoreWriter {
|
||||
buffer_block_metas: Vec::new(),
|
||||
buffer_addrs: Vec::new(),
|
||||
@@ -662,7 +438,7 @@ impl BlockAddrStoreWriter {
|
||||
}
|
||||
}
|
||||
|
||||
fn flush_block(&mut self) -> io::Result<()> {
|
||||
pub(crate) fn flush_block(&mut self) -> io::Result<()> {
|
||||
if self.block_addrs.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
@@ -741,7 +517,7 @@ impl BlockAddrStoreWriter {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_block_meta(&mut self, block_addr: BlockAddr) -> io::Result<()> {
|
||||
pub(crate) fn write_block_meta(&mut self, block_addr: BlockAddr) -> io::Result<()> {
|
||||
self.block_addrs.push(block_addr);
|
||||
if self.block_addrs.len() >= STORE_BLOCK_LEN {
|
||||
self.flush_block()?;
|
||||
@@ -749,7 +525,7 @@ impl BlockAddrStoreWriter {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn serialize<W: std::io::Write>(&mut self, wrt: &mut W) -> io::Result<()> {
|
||||
pub(crate) fn serialize<W: std::io::Write>(&mut self, wrt: &mut W) -> io::Result<()> {
|
||||
self.flush_block()?;
|
||||
let len = self.buffer_block_metas.len() as u64;
|
||||
len.serialize(wrt)?;
|
||||
@@ -824,8 +600,9 @@ mod tests {
|
||||
use common::OwnedBytes;
|
||||
|
||||
use super::*;
|
||||
use crate::SSTableDataCorruption;
|
||||
use crate::block_match_automaton::tests::EqBuffer;
|
||||
use crate::index::BlockMeta;
|
||||
use crate::{SSTableDataCorruption, SSTableIndexBuilder};
|
||||
|
||||
#[test]
|
||||
fn test_sstable_index() {
|
||||
@@ -874,36 +651,7 @@ mod tests {
|
||||
assert!(matches!(data_corruption_err, SSTableDataCorruption));
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn test_find_shorter_str_in_between_aux(left: &[u8], right: &[u8]) {
|
||||
let mut left_buf = left.to_vec();
|
||||
super::find_shorter_str_in_between(&mut left_buf, right);
|
||||
assert!(left_buf.len() <= left.len());
|
||||
assert!(left <= &left_buf);
|
||||
assert!(&left_buf[..] < right);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_shorter_str_in_between() {
|
||||
test_find_shorter_str_in_between_aux(b"", b"hello");
|
||||
test_find_shorter_str_in_between_aux(b"abc", b"abcd");
|
||||
test_find_shorter_str_in_between_aux(b"abcd", b"abd");
|
||||
test_find_shorter_str_in_between_aux(&[0, 0, 0], &[1]);
|
||||
test_find_shorter_str_in_between_aux(&[0, 0, 0], &[0, 0, 1]);
|
||||
test_find_shorter_str_in_between_aux(&[0, 0, 255, 255, 255, 0u8], &[0, 1]);
|
||||
}
|
||||
|
||||
use proptest::prelude::*;
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(100))]
|
||||
#[test]
|
||||
fn test_proptest_find_shorter_str(left in any::<Vec<u8>>(), right in any::<Vec<u8>>()) {
|
||||
if left < right {
|
||||
test_find_shorter_str_in_between_aux(&left, &right);
|
||||
}
|
||||
}
|
||||
}
|
||||
// use proptest::prelude::*;
|
||||
|
||||
#[test]
|
||||
fn test_find_best_slop() {
|
||||
@@ -47,9 +47,8 @@ pub mod merge;
|
||||
mod streamer;
|
||||
pub mod value;
|
||||
|
||||
mod sstable_index_v3;
|
||||
pub use sstable_index_v3::{BlockAddr, SSTableIndex, SSTableIndexBuilder, SSTableIndexV3};
|
||||
mod sstable_index_v2;
|
||||
mod index;
|
||||
pub use index::{BlockAddr, SSTableIndex, SSTableIndexBuilder};
|
||||
pub(crate) mod vint;
|
||||
pub use dictionary::{Dictionary, TermOrdHit};
|
||||
pub use streamer::{Streamer, StreamerBuilder};
|
||||
|
||||
@@ -27,7 +27,7 @@ rand = "0.9"
|
||||
zipf = "7.0.0"
|
||||
rustc-hash = "2.1.0"
|
||||
proptest = "1.2.0"
|
||||
binggan = { version = "0.16.1" }
|
||||
binggan = { version = "0.17.0" }
|
||||
rand_distr = "0.5"
|
||||
|
||||
[features]
|
||||
|
||||
Reference in New Issue
Block a user