Compare commits

53 Commits

Author SHA1 Message Date
pascal
74f37f045d Avoid scoring buffered unions when scores are ignored
BufferedUnionScorer can use score_doc during refill only when the score combiner needs scores. DoNothingCombiner now advertises that scoring is unnecessary, preserving the no-score path for count collectors and avoiding wasted score_doc calls.

Add a regression test that verifies DoNothingCombiner does not invoke score() or score_doc() while counting a buffered union.
2026-05-31 21:58:29 +02:00
pascal
72cca113cd cargo fmt, remove impl 2026-05-31 21:48:03 +02:00
pascal
672bf45235 Clarify postings copy variable names 2026-05-31 20:50:35 +02:00
pascal
33ef167441 Share BM25 fieldnorm caches per thread
Reuse BM25 TF normalization caches for weights with the same average fieldnorm using a bounded thread-local LRU. This avoids recomputing and duplicating the cache for many terms on the same field without adding cross-thread contention.
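
A minimal sketch of the idea in this commit message, not the actual implementation: one 256-entry TF normalization table per average fieldnorm, kept in a small thread-local LRU so that many terms on the same field share it. The names `shared_table`, `compute_table`, `MAX_CACHED_TABLES`, the identity `id_to_fieldnorm` mapping, and the `K1`/`B` values are assumptions made for illustration.

```rust
use std::cell::RefCell;
use std::rc::Rc;

// Usual BM25 defaults (assumed here).
const K1: f32 = 1.2;
const B: f32 = 0.75;
// Hypothetical bound on the number of tables cached per thread.
const MAX_CACHED_TABLES: usize = 4;

type TfNormTable = [f32; 256];

// Stand-in for decoding a fieldnorm id; a real index uses a compressed mapping.
fn id_to_fieldnorm(id: u8) -> f32 {
    id as f32
}

// One normalization factor per possible fieldnorm id.
fn compute_table(average_fieldnorm: f32) -> TfNormTable {
    let mut table = [0.0f32; 256];
    for (id, slot) in table.iter_mut().enumerate() {
        *slot = K1 * (1.0 - B + B * id_to_fieldnorm(id as u8) / average_fieldnorm);
    }
    table
}

thread_local! {
    // Least recently used entry at the front, most recently used at the back.
    static TABLES: RefCell<Vec<(u32, Rc<TfNormTable>)>> = RefCell::new(Vec::new());
}

// Returns the table for this average fieldnorm, computing it at most once
// per thread while it stays in the bounded cache.
fn shared_table(average_fieldnorm: f32) -> Rc<TfNormTable> {
    let key = average_fieldnorm.to_bits();
    TABLES.with(|tables| {
        let mut tables = tables.borrow_mut();
        if let Some(pos) = tables.iter().position(|(k, _)| *k == key) {
            // Hit: move the entry to the most-recently-used position.
            let entry = tables.remove(pos);
            let table = Rc::clone(&entry.1);
            tables.push(entry);
            return table;
        }
        if tables.len() == MAX_CACHED_TABLES {
            // Full: evict the least recently used table.
            tables.remove(0);
        }
        let table = Rc::new(compute_table(average_fieldnorm));
        tables.push((key, Rc::clone(&table)));
        table
    })
}
```

Because the cache is thread-local, lookups need no locking, which matches the "without adding cross-thread contention" goal stated above; the cost is that each thread may hold its own copy of a table.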
2026-05-31 19:13:18 +02:00
pascal
bf8b263f16 Optimize buffered union scoring with block refills
Add horizon-limited buffering APIs for docsets and scorers so buffered union can refill from block-oriented postings while preserving term frequencies. This lets term scorers score buffered docs directly and reduces per-document refill overhead for dense unions.
2026-05-31 19:13:18 +02:00
pascal
24a97dbe69 Split buffered refill from scorer removal 2026-05-31 12:20:07 +02:00
pascal
34fec8b23e Defer terminated scorer removal during buffered refill 2026-05-31 11:53:11 +02:00
Paul Masurel
46b3fb9ed3 Relying on upstream version of datasketch and stop using HLL 4. (#2936)
We were relying on a fork for:

- a bugfix in LIST serialization
- a better API exposing a new Coupon type, required for caching coupons.

We also stop using HLL8 in the hope of fixing
https://datadoghq.atlassian.net/browse/CLOUDPREM-625

Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
2026-05-19 13:29:35 +02:00
trinity-1686a
fbe620b9b4 Merge pull request #2933 from quickwit-oss/1686a/sstable-opt
optimise sstable index access pattern
2026-05-19 11:43:17 +02:00
trinity-1686a
95d8a3989a cr 2026-05-19 11:38:48 +02:00
trinity-1686a
ea61a68db4 skip sstable index binary search when ordinal is in same block 2026-05-16 11:35:38 +02:00
trinity-1686a
c367df37c1 refactor sstable index 2026-05-16 11:30:02 +02:00
Mohammad Dashti
d99a5d4e91 Rename validate_aggregation_fields to validate_aggregation_fields_exist
Applies @PSeitz's review suggestion to make the function name more
descriptive of what it checks. Also adds a doc note clarifying why
validation is opt-in rather than enforced by default.
2026-05-16 15:45:20 +08:00
Mohammad Dashti
2de6f075ce Fixed the example 2026-05-16 15:45:20 +08:00
Mohammad Dashti
18080067c7 Applied PR comment:
I would move it outside of the aggregation. You can fetch the fields from the aggregation request and do a validation in a helper function
2026-05-16 15:45:20 +08:00
Mohammad Dashti
95db7d2e5c Revert "Revert all impl."
This reverts commit d5e0991549a05bf80f19f853f7689ad69f96e7e5.
2026-05-16 15:45:20 +08:00
Mohammad Dashti
fc017c4c74 Applied PR comments. 2026-05-16 15:45:20 +08:00
Mohammad Dashti
141c91d028 Added a flag: strict_validation 2026-05-16 15:45:20 +08:00
Mohammad Dashti
36a83e7c1a Fixed agg validation 2026-05-16 15:45:20 +08:00
jinhelin
be11f8a6a1 Fix opening positions file error 2026-05-14 15:55:59 +08:00
dependabot[bot]
4305e4029e Update binggan requirement from 0.16.1 to 0.17.0
Updates the requirements on [binggan](https://github.com/pseitz/binggan) to permit the latest version.
- [Changelog](https://github.com/PSeitz/binggan/blob/main/CHANGELOG.md)
- [Commits](https://github.com/pseitz/binggan/commits)

---
updated-dependencies:
- dependency-name: binggan
  dependency-version: 0.17.0
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-05-12 15:10:20 +08:00
Pascal Seitz
edfb02b47e switch to enum, fix mixed types for cardinality agg 2026-05-05 16:39:51 +08:00
Pascal Seitz
d0fad88bac use bitsets for card agg 2026-05-05 16:39:51 +08:00
Pascal Seitz
351280c0b4 add card bench for high card 2026-05-05 16:39:51 +08:00
James Sewell
4480cf0a98 Enable BMW for single-scorer boolean queries by removing early return in scorer_union (#2915)
The early return for `scorers.len() == 1` in `scorer_union` short-circuits a single TermScorer into `SpecializedScorer::Other`, bypassing the `TermUnion` path that enables block-max WAND (BMW) in `for_each_pruning`.

This was originally addressed in PR #2898 (backed out), which added a special case in `BooleanWeight::for_each_pruning`. PR #2912 (merged as d27ca164a) added a single-scorer fast path inside `block_wand` itself, but did not remove this early return — so a single SHOULD TermScorer still never reaches the BMW path.

Removing the early return lets a single TermScorer with freq reading flow through to `SpecializedScorer::TermUnion`, where `block_wand` → `block_wand_single_scorer` handles it efficiently.
2026-04-28 14:49:53 -07:00
Pascal Seitz
d47abdf104 early cut off for order by sub agg in term agg 2026-04-28 16:59:59 +02:00
Pascal Seitz
c11952eb7c add order by agg benchmark 2026-04-28 16:59:59 +02:00
trinity-1686a
09667ee9c8 Merge pull request #2909 from osyniakov/claude/add-ossf-scorecard-1z6Vn
Add OpenSSF Scorecard workflow
2026-04-28 11:57:36 +02:00
trinity-1686a
333ccf5300 Merge pull request #2896 from osyniakov/claude/fix-issues-5945-5937-eQm1Q
ci: pin GitHub Actions to full commit SHAs and restrict token permissions
2026-04-28 11:57:18 +02:00
Oleksii Syniakov
60a39a4689 Merge branch 'main' into claude/fix-issues-5945-5937-eQm1Q 2026-04-28 10:28:23 +02:00
Oleksii Syniakov
f8f3e4277f remove not needed permissions for the public repo 2026-04-28 10:09:30 +02:00
Oleksii Syniakov
ff1433713a bump upload-sarif -> 4.35.2
Co-authored-by: trinity-1686a <trinity.pointard@gmail.com>
2026-04-28 10:07:45 +02:00
trinity-1686a
ca139d8eb1 Merge pull request #2910 from quickwit-oss/abdul.andha/composite-agg-after
Composite aggregations: send after key on last page
2026-04-27 23:38:52 +02:00
Abdul Andha
ac508108aa address pr comment 2026-04-27 12:39:38 -04:00
Paul Masurel
63da5a21b2 Optimizing top K using Adrien Grand's ideas (#2865)
* Optimizing top K using Adrien Grand's ideas

https://jpountz.github.io/2025/08/28/compiled-vs-vectorized-search-engine-edition.html

* Suffix-sum pruning for multi-term intersection candidates

After scoring each secondary in Phase 2, check whether remaining
secondaries' block_max scores can still beat the threshold. Skip
to the next candidate early if impossible, avoiding expensive seeks
into later secondaries.

Improves three-term intersection by ~8% on the balanced benchmark
while keeping two-term performance neutral.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Claude CR comment

* Removed 16 term scorer limit.
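
The suffix-sum pruning bullet above can be sketched roughly as follows. This is an illustration under assumptions, not the code from the commit: `score_candidate`, its parameters, and the closure standing in for an expensive seek-and-score are hypothetical names.

```rust
/// Scores one intersection candidate, or returns `None` once it can no longer
/// beat `threshold`.
///
/// `block_max[i]` is an upper bound on the contribution of secondary `i`;
/// `score_secondary(i)` stands in for the expensive seek plus exact score.
fn score_candidate(
    lead_score: f32,
    block_max: &[f32],
    threshold: f32,
    mut score_secondary: impl FnMut(usize) -> f32,
) -> Option<f32> {
    // suffix[i] = sum of block_max[i..], with suffix[len] = 0.
    let mut suffix = vec![0.0f32; block_max.len() + 1];
    for i in (0..block_max.len()).rev() {
        suffix[i] = suffix[i + 1] + block_max[i];
    }

    let mut score = lead_score;
    for i in 0..block_max.len() {
        score += score_secondary(i);
        // Even if every remaining secondary hit its block maximum, the
        // candidate could not exceed the threshold: skip the later seeks.
        if score + suffix[i + 1] <= threshold {
            return None;
        }
    }
    Some(score)
}
```

With three secondaries bounded by 1.0 each, a lead score of 1.0 and a threshold of 5.0, the candidate is dropped after scoring only the first secondary, which is the saving the bullet describes.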

---------

Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-26 12:14:40 +02:00
lif
54cd5bba98 fix: skip sentinel facet ords in harvest to prevent wrong root (#2867)
When a document has the exact registered facet path (not a child),
compute_collapse_mapping_one maps it to a sentinel (u64::MAX, 0).
Without filtering, harvest() passes u64::MAX to ord_to_term which
resolves to the last dictionary entry, producing a spurious facet
from an unrelated branch.

Skip entries where facet_ord == u64::MAX in harvest().
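
A toy model of the failure and the fix, assuming a non-empty dictionary. The `ord_to_term` below is a hypothetical stand-in that clamps an out-of-range ordinal to the last entry, which is how this commit message describes the observed behaviour; it is not the real dictionary API.

```rust
/// Stand-in for the dictionary lookup: an out-of-range ordinal resolves to
/// the last entry (modelled here by clamping). Assumes a non-empty dictionary.
fn ord_to_term(dictionary: &[&str], ord: u64) -> String {
    let idx = (ord as usize).min(dictionary.len() - 1);
    dictionary[idx].to_string()
}

/// Turns (facet_ord, count) pairs into (facet, count) pairs, skipping the
/// `u64::MAX` sentinel so it never reaches `ord_to_term`.
fn harvest(dictionary: &[&str], counts: &[(u64, u64)]) -> Vec<(String, u64)> {
    counts
        .iter()
        .filter(|entry| entry.0 != u64::MAX)
        .map(|&(facet_ord, count)| (ord_to_term(dictionary, facet_ord), count))
        .collect()
}
```

Without the `filter`, the sentinel entry would be reported under the last dictionary term, which is the spurious facet from an unrelated branch.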

Closes #2494

Signed-off-by: majiayu000 <1835304752@qq.com>
2026-04-25 22:23:30 +02:00
Paul Masurel
d27ca164a9 block_wand: use single-scorer path when there is only one scorer 2026-04-25 16:35:00 +02:00
dependabot[bot]
2f5a48e8b1 Update criterion requirement from 0.5 to 0.8 (#2873)
Updates the requirements on [criterion](https://github.com/criterion-rs/criterion.rs) to permit the latest version.
- [Release notes](https://github.com/criterion-rs/criterion.rs/releases)
- [Changelog](https://github.com/criterion-rs/criterion.rs/blob/master/CHANGELOG.md)
- [Commits](https://github.com/criterion-rs/criterion.rs/compare/0.5.0...criterion-v0.8.2)

---
updated-dependencies:
- dependency-name: criterion
  dependency-version: 0.8.2
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-25 14:15:53 +02:00
dependabot[bot]
ae0ab907fe Bump actions/checkout from 4 to 6 (#2875)
Bumps [actions/checkout](https://github.com/actions/checkout) from 4 to 6.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v4...v6)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-25 14:15:27 +02:00
dependabot[bot]
7d62e084e7 Bump codecov/codecov-action from 3 to 6 (#2876)
Bumps [codecov/codecov-action](https://github.com/codecov/codecov-action) from 3 to 6.
- [Release notes](https://github.com/codecov/codecov-action/releases)
- [Changelog](https://github.com/codecov/codecov-action/blob/main/CHANGELOG.md)
- [Commits](https://github.com/codecov/codecov-action/compare/v3...v6)

---
updated-dependencies:
- dependency-name: codecov/codecov-action
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-25 14:14:54 +02:00
James Sewell
322286ee16 Tighten Block-Max in single-scorer (#2897)
The Block-Max WAND single-scorer path uses block_max_score() < threshold,
whereas the multi-term one uses block_max_score_upperbound <= threshold.

As both of these are guarded later on with if score > threshold, we can
use the more efficient form in single-scorer.

Single-scorer block skip (<, should be <=): https://github.com/quickwit-oss/tantivy/blob/main/src/query/boolean_query/block_wand.rs#L231
Multi-scorer block skip (already <=): https://github.com/quickwit-oss/tantivy/blob/main/src/query/boolean_query/block_wand.rs#L179
Single-scorer per-doc guard (>): https://github.com/quickwit-oss/tantivy/blob/main/src/query/boolean_query/block_wand.rs#L246
Multi-scorer per-doc guard (>): https://github.com/quickwit-oss/tantivy/blob/main/src/query/boolean_query/block_wand.rs#L206

This will improve performance when there are many identical scores.
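
To make the effect of `<` versus `<=` concrete, here is a small sketch that counts how many blocks still have to be opened under each skip rule. `blocks_scored` and `skip_on_tie` are hypothetical names; the real logic lives in `block_wand.rs`.

```rust
/// Counts the blocks that must be opened and scored doc by doc.
///
/// A document only enters the top K if `score > threshold`, so a block whose
/// maximum equals the threshold cannot contribute and is safe to skip.
fn blocks_scored(block_maxes: &[f32], threshold: f32, skip_on_tie: bool) -> usize {
    block_maxes
        .iter()
        .filter(|&&block_max| {
            let skip = if skip_on_tie {
                block_max <= threshold // tightened rule
            } else {
                block_max < threshold // previous single-scorer rule
            };
            !skip
        })
        .count()
}
```

For block maxima `[2.0, 2.0, 2.0, 3.0]` and a threshold of `2.0`, the strict rule opens all four blocks while the tightened rule opens only the last one.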
2026-04-25 14:13:07 +02:00
RJ Barman
73ad18fa1e fix: Add space for missing sentinel in allowed bitset when a missing key is provided (#119) (#2907)
## Bug Overview
Under certain conditions, a `terms` aggregation request can cause a
bounds-check panic. Those conditions are:
- The queried field must be a text field
- There must be a segment where the number of distinct terms in its
dictionary for the queried field is divisible by 64 (i.e. where
`count(term_dict.keys) % 64 == 0`)
- That same segment must contain at least one document that does not
contain this field.
- The request must contain a `missing` key that is a string.
- The request must contain an `include` or `exclude` filter.
For example:
```json
{
    "my_bool": {
        "terms": {
            "field": "title",
            "include": "foo",
            "missing": "__NULL__"
        }
    }
}
```
Check out the added tests in `src/aggregation/bucket/term_agg.rs` to see
this in action

## How the bug happens
### Preparation
While preparing the aggregation nodes:
1) When we've provided a `missing` key, we derive a missing sentinel.
For string keys this is the column's max value (which for string keys is
always the number of terms in this segment) + 1.
2) for string columns only, we optionally prep an "allowed" `BitSet` for
allowed term ids. (`build_allowed_term_ids_for_str` in
`src/aggregation/agg_data.rs`)
- If no `include` or `exclude` filter is provided, this just returns
`None`, causing this check to be skipped down the line
- Otherwise the bitset is initialized to be able to hold the exact
number of terms in the segment's term dictionary, and the bits are set to
signify which terms are to be included in the results.

### Collection
If we have an "allowed" `BitSet`, we filter documents against it. For
each document, we check if the `BitSet` contains the document's term id.
For documents without the field, this is the missing sentinel we derived
earlier, minus 1 (to account for zero-based indexing):
`(num_terms + 1) - 1`. However, the `BitSet`'s size is only `num_terms`.
Normally, this slips by without a problem, but if `num_terms % 64 == 0`,
this will cause a panic.

### Why `BitSet` panics
`BitSet` is represented under the hood by a boxed slice of `u64`s. When
you go to check a bit using `BitSet::contains`, it must determine which
of those `u64`s the bit is in, and then the position within that `u64`
of the bit.

In cases where the number of terms is not divisible by 64, the `BitSet`
must waste some bits. When we then look up the missing sentinel's bit,
it happens to be one of those wasted bits, for which `BitSet` is happy
to return the value of. For example, if the number of terms was 63:
```rust
let bitset_init_size = 63; // so BitSet's boxed slice has a length of 1, capable of holding 64 bits, term id [0, 62]
let missing_sentinel = 63; // num_terms + 1 - 1;
let byte_pos = missing_sentinel / 64; // 0 - within the valid slice
let bit_pos = missing_sentinel % 64; // 63 - hits the 1 wasted bit
```

But if the number of terms is indeed divisible by 64, then the `BitSet`
is perfectly aligned to the byte boundary:
```rust
let bitset_init_size = 64; // so BitSet's boxed slice has a length of 1, capable of holding 64 bits, term ids [0, 63]
let missing_sentinel = 64; // num_terms + 1 - 1
let byte_pos = missing_sentinel / 64; // 1 - idx 1 >= slice length 1
let bit_pos = missing_sentinel % 64; // 0 
```
We try to access an element outside of the bounds of the boxed slice,
so the bounds check fails and panics.

## Fixing it
The fix is simple. If we need to account for the missing sentinel,
initialize the `BitSet` with capacity for one more bit.
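
The sizing fix can be illustrated with a toy bitset. This is a sketch under assumptions: `ToyBitSet`, `with_max_value`, and `allowed_bitset` are hypothetical names, not the crate's `BitSet` API.

```rust
/// Toy bitset backed by a boxed slice of u64 words.
struct ToyBitSet {
    words: Box<[u64]>,
}

impl ToyBitSet {
    /// Allocates enough words to hold bits `0..max_value`.
    fn with_max_value(max_value: usize) -> Self {
        let num_words = (max_value + 63) / 64;
        ToyBitSet {
            words: vec![0u64; num_words].into_boxed_slice(),
        }
    }

    fn insert(&mut self, bit: usize) {
        self.words[bit / 64] |= 1u64 << (bit % 64);
    }

    /// Panics (slice bounds check) if `bit / 64` is past the last word.
    fn contains(&self, bit: usize) -> bool {
        (self.words[bit / 64] >> (bit % 64)) & 1 == 1
    }
}

/// Builds the "allowed term ids" bitset, reserving one extra bit when a
/// missing sentinel may be looked up.
fn allowed_bitset(num_terms: usize, has_missing_sentinel: bool) -> ToyBitSet {
    let capacity = if has_missing_sentinel {
        num_terms + 1
    } else {
        num_terms
    };
    ToyBitSet::with_max_value(capacity)
}
```

With 64 terms and no extra bit, the set has a single word and `contains(64)` would index word 1 and panic; with the extra bit it has two words and the lookup simply returns `false`.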

## Tests
- Added a bunch of unit tests that hit these conditions. I ensured they
failed without the fix, and that they now pass.
- All unit tests pass with the fix in place

## Other
- The investigation that led to finding this bug began with
https://github.com/paradedb/paradedb/issues/4746.
2026-04-25 14:11:47 +02:00
Abdul Andha
4fbae92187 send after key on last page 2026-04-24 15:33:26 -04:00
Cameron
89f0cef807 Fix O(2^n) query parser regression for deeply-nested queries (#2905)
* Fix O(2^n) query parser regression for deeply-nested queries

The top-level `ast()` parser used `alt((boolean_expr, single_leaf))` at
every group level. When the group contained a single leaf with no
trailing operand, `boolean_expr` would parse `occur_leaf` (recursing
into the inner group), fail at `multispace1`, backtrack, and then
`single_leaf` would re-parse `occur_leaf` from scratch. Every nesting
level doubled the work, giving O(2^n) time for queries like
`(((((title:test)))))`.

Parse `occur_leaf` once and peek ahead for a trailing operand instead
of backtracking. This keeps parsing O(n) and also avoids the duplicate
parse for simple single-leaf queries.

Fixes #2498.

Measured on the issue reproducer (release build):

    depth   before     after
       20   0.87 s   <1 us
       25  28.23 s   <1 us
       60  (years)   ~5 us

Non-pathological queries are unaffected or slightly faster:

    query                     before     after
    hello                     650 ns     308 ns
    a AND b AND c            1380 ns    1364 ns
    title:rust AND (...)     3426 ns    3460 ns

All 53 existing grammar tests and 56 query_parser tests pass. Adds a
regression test at depth 60 that would not complete under the old
parser.

* Add ignored benchmark for nested query parsing at depth 20/21

Matches the depths from issue #2498 which reported 0.87 s / 1.72 s
under the regression. With the fix these parse in single-digit
microseconds. Runs via:

  cargo test -p tantivy-query-grammar --release bench_deeply_nested \
      -- --ignored --nocapture

* Propagate Err::Failure and Err::Incomplete from operand parser

`alt((boolean_expr, single_leaf))` only retried on `Err::Error` and
propagated `Err::Failure` and `Err::Incomplete`. The replacement was
catching all three with `Err(_)`, which would silently fall back to
a single leaf if any cut point were ever added to `operand_leaf` or
its descendants. Match specifically on `Err::Error` to preserve the
original `alt` semantics.

* Replace inline bench with binggan bench in benches/

Move the nested-query benchmark out of the query-grammar test module
and into a proper binggan benchmark at benches/query_parser_nested.rs,
registered as a harnessless bench in Cargo.toml. Keeps the correctness
regression test (depth 60) in place.

Run with: cargo bench --bench query_parser_nested

* Fix rustfmt import ordering in query_parser_nested bench
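
The exponential blow-up described in the first bullet of this commit message can be reproduced with a toy recursive-descent parser. This sketch does not use the real grammar or nom: the grammar accepts only `x` wrapped in parentheses, and `leaf_backtracking`, `leaf_single_pass`, and `nested` are hypothetical names. `calls` counts how many leaf parses are attempted.

```rust
/// Old shape: try "leaf followed by a space" first, and when that fails,
/// parse the same leaf again from scratch as a single leaf.
fn leaf_backtracking<'a>(input: &'a str, calls: &mut u64) -> Option<&'a str> {
    *calls += 1;
    if let Some(rest) = input.strip_prefix('x') {
        return Some(rest);
    }
    let inner = input.strip_prefix('(')?;
    // Alternative 1 (boolean expression): a leaf that must be followed by a space.
    let boolean_expr = leaf_backtracking(inner, calls).filter(|rest| rest.starts_with(' '));
    // Alternative 2 (single leaf): the leaf is parsed a second time.
    let rest = match boolean_expr {
        Some(_) => return None, // operands are out of scope for this toy grammar
        None => leaf_backtracking(inner, calls)?,
    };
    rest.strip_prefix(')')
}

/// New shape: parse the leaf once, then peek for a trailing operand.
fn leaf_single_pass<'a>(input: &'a str, calls: &mut u64) -> Option<&'a str> {
    *calls += 1;
    if let Some(rest) = input.strip_prefix('x') {
        return Some(rest);
    }
    let inner = input.strip_prefix('(')?;
    let rest = leaf_single_pass(inner, calls)?;
    if rest.starts_with(' ') {
        return None; // operands are out of scope for this toy grammar
    }
    rest.strip_prefix(')')
}

/// Builds `depth` levels of parentheses around a single `x`.
fn nested(depth: usize) -> String {
    format!("{}x{}", "(".repeat(depth), ")".repeat(depth))
}
```

In this toy model every nesting level doubles the work of the backtracking version (2^(n+1) - 1 leaf parses at depth n), while the single-pass version needs n + 1, mirroring the O(2^n) versus O(n) behaviour reported above.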
2026-04-24 03:54:00 -04:00
Claude
a5d297c75f Add OpenSSF Scorecard workflow
Runs weekly security analysis and uploads SARIF results to GitHub code
scanning. Third-party actions are pinned by commit SHA. Adds the Scorecard
badge to the README.

Based on quickwit-oss/quickwit#5969.
2026-04-24 06:56:58 +00:00
Pascal Seitz
2e16243f9a fix memory consumption for histogram 2026-04-21 13:58:39 +02:00
Pascal Seitz
e015abab8e docs: add 0.26.1 changelog entry for aggregation perf fix 2026-04-21 11:12:37 +02:00
Pascal Seitz
73c711ec74 perf(agg): only measure active parent bucket in composite collect
Same change as 26a589e for SegmentCompositeCollector: get_memory_consumption
summed across all parent_buckets on every block, scaling with outer bucket
cardinality. Pass parent_bucket_id and index the single bucket.
2026-04-21 07:26:58 +02:00
Pascal Seitz
cb037c8079 add inline 2026-04-21 07:26:58 +02:00
Pascal Seitz
ed3453606b agg fix: compute memory consumption only for current bucket 2026-04-21 07:26:58 +02:00
Pascal Seitz
e9641f99c5 add nested term benchmark 2026-04-21 07:26:58 +02:00
Claude
3a6a3de8d7 ci: update pinned Action SHAs to current latest versions
The previous commit pinned actions to commit SHAs but used stale
version tags (v4.2.2, v2.7.5, old nextest/cargo-llvm-cov refs).
Update to the actual current HEAD of each pinned tag:

  actions/checkout        v4.2.2 → v4.3.1  (34e114876b0b...)
  Swatinem/rust-cache     v2.7.5 → v2.9.1  (c19371144df3...)
  taiki-e/install-action  nextest           (56cc9adf3a3e...)
  taiki-e/install-action  cargo-llvm-cov    (e4b3a0453201...)

actions-rs/toolchain, actions-rs/clippy-check, and
codecov/codecov-action SHAs were already correct.

https://claude.ai/code/session_01VD7Bo8upj3cQwWDf9ni2Ln
2026-04-16 06:49:47 +00:00
Claude
af3c6c0070 ci: pin GitHub Actions to full commit SHAs and restrict token permissions
Fixes two supply chain / token security issues:

- Pin all third-party Actions to immutable full commit SHAs instead of
  mutable version tags (addresses unpinned-dependencies risk, analogous
  to quickwit-oss/quickwit#5937):
    actions/checkout v4.2.2
    actions-rs/toolchain v1.0.7
    Swatinem/rust-cache v2.7.5
    taiki-e/install-action nextest / cargo-llvm-cov
    actions-rs/clippy-check v1.0.7
    codecov/codecov-action v3.1.6

- Add explicit least-privilege `permissions` blocks at workflow and job
  level (addresses excessive GITHUB_TOKEN permissions, analogous to
  quickwit-oss/quickwit#5945):
    default: contents: read
    check job: also grants checks: write (required by clippy-check)

https://claude.ai/code/session_01VD7Bo8upj3cQwWDf9ni2Ln
2026-04-15 20:55:43 +00:00
59 changed files with 3551 additions and 724 deletions

View File

@@ -4,6 +4,9 @@ on:
push:
branches: [main]
permissions:
contents: read
# Ensures that we cancel running jobs for the same PR / same workflow.
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@@ -12,16 +15,20 @@ concurrency:
jobs:
coverage:
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Install Rust
run: rustup toolchain install nightly-2025-12-01 --profile minimal --component llvm-tools-preview
- uses: Swatinem/rust-cache@v2
- uses: taiki-e/install-action@cargo-llvm-cov
- uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1
- uses: taiki-e/install-action@e4b3a0453201addddc06d3a72db90326aad87084 # cargo-llvm-cov
- name: Generate code coverage
run: cargo +nightly-2025-12-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
continue-on-error: true
with:
token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos

View File

@@ -8,6 +8,9 @@ env:
CARGO_TERM_COLOR: always
NUM_FUNCTIONAL_TEST_ITERATIONS: 20000
permissions:
contents: read
# Ensures that we cancel running jobs for the same PR / same workflow.
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@@ -18,10 +21,13 @@ jobs:
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Install stable
uses: actions-rs/toolchain@v1
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af # v1.0.7
with:
toolchain: stable
profile: minimal

49
.github/workflows/scorecard.yml vendored Normal file
View File

@@ -0,0 +1,49 @@
name: OpenSSF Scorecard
on:
schedule:
- cron: '0 0 * * 0'
push:
branches:
- main
permissions:
contents: read
jobs:
analysis:
name: Scorecards analysis
runs-on: ubuntu-latest
permissions:
# Needed to upload the results to code-scanning dashboard.
security-events: write
# Needed to publish results
id-token: write
steps:
- name: 'Checkout code'
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: 'Run analysis'
uses: ossf/scorecard-action@4eaacf0543bb3f2c246792bd56e8cdeffafb205a # v2.4.3
with:
results_file: results.sarif
results_format: sarif
repo_token: ${{ secrets.GITHUB_TOKEN }}
publish_results: true
# Upload the results as artifacts.
- name: 'Upload artifact'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: SARIF file
path: results.sarif
retention-days: 5
# Upload the results to GitHub's code scanning dashboard.
- name: 'Upload to code-scanning'
uses: github/codeql-action/upload-sarif@95e58e9a2cdfd71adc6e0353d5c52f41a045d225 # v4.35.2
with:
sarif_file: results.sarif

View File

@@ -9,6 +9,9 @@ on:
env:
CARGO_TERM_COLOR: always
permissions:
contents: read
# Ensures that we cancel running jobs for the same PR / same workflow.
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@@ -19,23 +22,27 @@ jobs:
runs-on: ubuntu-latest
permissions:
contents: read
checks: write
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Install nightly
uses: actions-rs/toolchain@v1
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af # v1.0.7
with:
toolchain: nightly
profile: minimal
components: rustfmt
- name: Install stable
uses: actions-rs/toolchain@v1
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af # v1.0.7
with:
toolchain: stable
profile: minimal
components: clippy
- uses: Swatinem/rust-cache@v2
- uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1
- name: Check Formatting
run: cargo +nightly fmt --all -- --check
@@ -47,7 +54,7 @@ jobs:
- name: Check Bench Compilation
run: cargo +nightly bench --no-run --profile=dev --all-features
- uses: actions-rs/clippy-check@v1
- uses: actions-rs/clippy-check@b5b5f21f4797c02da247df37026fcd0a5024aa4d # v1.0.7
with:
toolchain: stable
token: ${{ secrets.GITHUB_TOKEN }}
@@ -57,6 +64,9 @@ jobs:
runs-on: ubuntu-latest
permissions:
contents: read
strategy:
matrix:
features:
@@ -67,17 +77,17 @@ jobs:
name: test-${{ matrix.features.label}}
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Install stable
uses: actions-rs/toolchain@v1
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af # v1.0.7
with:
toolchain: stable
profile: minimal
override: true
- uses: taiki-e/install-action@nextest
- uses: Swatinem/rust-cache@v2
- uses: taiki-e/install-action@56cc9adf3a3e2c23eafb56e8acaf9d0373cb845a # nextest
- uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1
- name: Run tests
run: |

View File

@@ -1,3 +1,9 @@
Tantivy 0.26.1
================================
## Performance
- Fix quadratic runtime in nested term and composite aggregations: memory accounting scanned all parent buckets on every collect instead of just the current parent (@PSeitz @fulmicoton)
Tantivy 0.26 (Unreleased)
================================

View File

@@ -65,7 +65,7 @@ tantivy-bitpacker = { version = "0.10", path = "./bitpacker" }
common = { version = "0.11", path = "./common/", package = "tantivy-common" }
tokenizer-api = { version = "0.7", path = "./tokenizer-api", package = "tantivy-tokenizer-api" }
sketches-ddsketch = { version = "0.4", features = ["use_serde"] }
datasketches = { git = "https://github.com/fulmicoton-dd/datasketches-rust", rev = "7635fb8" }
datasketches = { version = "0.3.0", features = ["hll"] }
futures-util = { version = "0.3.28", optional = true }
futures-channel = { version = "0.3.28", optional = true }
fnv = "1.0.7"
@@ -75,7 +75,7 @@ typetag = "0.2.21"
winapi = "0.3.9"
[dev-dependencies]
binggan = "0.16.1"
binggan = "0.17.0"
rand = "0.9"
maplit = "1.0.2"
matches = "0.1.9"
@@ -92,7 +92,7 @@ postcard = { version = "1.0.4", features = [
], default-features = false }
[target.'cfg(not(windows))'.dev-dependencies]
criterion = { version = "0.5", default-features = false }
criterion = { version = "0.8", default-features = false }
[dev-dependencies.fail]
version = "0.5.0"
@@ -201,3 +201,11 @@ harness = false
[[bench]]
name = "regex_all_terms"
harness = false
[[bench]]
name = "query_parser_nested"
harness = false
[[bench]]
name = "intersection_bench"
harness = false

View File

@@ -1,6 +1,7 @@
[![Docs](https://docs.rs/tantivy/badge.svg)](https://docs.rs/crate/tantivy/)
[![Build Status](https://github.com/quickwit-oss/tantivy/actions/workflows/test.yml/badge.svg)](https://github.com/quickwit-oss/tantivy/actions/workflows/test.yml)
[![codecov](https://codecov.io/gh/quickwit-oss/tantivy/branch/main/graph/badge.svg)](https://codecov.io/gh/quickwit-oss/tantivy)
[![OpenSSF Scorecard](https://api.scorecard.dev/projects/github.com/quickwit-oss/tantivy/badge)](https://scorecard.dev/viewer/?uri=github.com/quickwit-oss/tantivy)
[![Join the chat at https://discord.gg/MT27AG5EVE](https://shields.io/discord/908281611840282624?label=chat%20on%20discord)](https://discord.gg/MT27AG5EVE)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![Crates.io](https://img.shields.io/crates/v/tantivy.svg)](https://crates.io/crates/tantivy)

View File

@@ -63,6 +63,8 @@ fn bench_agg(mut group: InputGroup<Index>) {
register!(group, terms_all_unique_with_avg_sub_agg);
register!(group, terms_many_with_avg_sub_agg);
register!(group, terms_status_with_avg_sub_agg);
register!(group, terms_status_with_terms_zipf_1000_sub_agg);
register!(group, terms_zipf_1000_with_terms_status_sub_agg);
register!(group, terms_status_with_histogram);
register!(group, terms_zipf_1000);
register!(group, terms_zipf_1000_with_histogram);
@@ -77,8 +79,12 @@ fn bench_agg(mut group: InputGroup<Index>) {
register!(group, composite_histogram_calendar);
register!(group, cardinality_agg);
register!(group, cardinality_agg_high_card);
register!(group, cardinality_agg_low_card);
register!(group, terms_status_with_cardinality_agg);
register!(group, terms_100_buckets_with_cardinality_agg);
register!(group, terms_many_with_single_term_order_by_card);
register!(group, terms_many_with_single_term_2_order_by_card);
register!(group, range_agg);
register!(group, range_agg_with_avg_sub_agg);
@@ -166,6 +172,32 @@ fn cardinality_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
// Full-scan cardinality on a near-1M-cardinality string field.
// Hits the dense (PagedBitset) path: every doc has a unique term,
// so the bucket promotes from FxHashSet shortly into the scan.
fn cardinality_agg_high_card(index: &Index) {
let agg_req = json!({
"cardinality": {
"cardinality": {
"field": "text_all_unique_terms"
},
}
});
execute_agg(index, agg_req);
}
// Full-scan cardinality on a tiny-cardinality string field (7 distinct
// values). Stays on the FxHashSet path — the promotion threshold is
// never crossed. Validates no regression on the sparse path.
fn cardinality_agg_low_card(index: &Index) {
let agg_req = json!({
"cardinality": {
"cardinality": {
"field": "text_few_terms_status"
},
}
});
execute_agg(index, agg_req);
}
fn terms_status_with_cardinality_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
@@ -198,6 +230,58 @@ fn terms_100_buckets_with_cardinality_agg(index: &Index) {
execute_agg(index, agg_req);
}
fn terms_many_with_single_term_order_by_card(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_many_terms" },
"aggs": {
"nested_terms": {
"terms": {
"field": "single_term",
"order": { "cardinality": "desc" }
},
"aggs": {
"cardinality": {
"cardinality": { "field": "text_few_terms" }
}
}
}
}
},
});
execute_agg(index, agg_req);
}
// Two-level terms ordered by cardinality at each level: a high-card outer terms
// (text_many_terms) ordered by a cardinality sub-agg, with a nested low-card terms
// (text_few_terms_status) also ordered by a cardinality sub-agg, plus an avg.
fn terms_many_with_single_term_2_order_by_card(index: &Index) {
let agg_req = json!({
"by_ip": {
"terms": {
"field": "text_many_terms",
"order": { "card_few_terms": "desc" }
},
"aggs": {
"card_few_terms": {
"cardinality": { "field": "text_few_terms" }
},
"nested_terms": {
"terms": {
"field": "single_term",
"order": { "distinct_path2": "desc" }
},
"aggs": {
"avg_botscore": { "avg": { "field": "score" } },
"distinct_path2": { "cardinality": { "field": "text_few_terms" } }
}
}
}
}
});
execute_agg(index, agg_req);
}
fn terms_7(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_few_terms_status" } },
@@ -270,6 +354,30 @@ fn terms_all_unique_with_avg_sub_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn terms_status_with_terms_zipf_1000_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"aggs": {
"nested_terms": { "terms": { "field": "text_1000_terms_zipf" } }
}
}
});
execute_agg(index, agg_req);
}
fn terms_zipf_1000_with_terms_status_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_1000_terms_zipf" },
"aggs": {
"nested_terms": { "terms": { "field": "text_few_terms_status" } }
}
}
});
execute_agg(index, agg_req);
}
fn terms_status_with_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
@@ -583,7 +691,8 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
TextFieldIndexing::default().set_index_option(IndexRecordOption::WithFreqs),
)
.set_stored();
let text_field = schema_builder.add_text_field("text", text_fieldtype);
let text_field = schema_builder.add_text_field("text", text_fieldtype.clone());
let single_term = schema_builder.add_text_field("single_term", FAST);
let json_field = schema_builder.add_json_field("json", FAST);
let text_field_all_unique_terms =
schema_builder.add_text_field("text_all_unique_terms", STRING | FAST);
@@ -647,6 +756,8 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
index_writer.add_document(doc!(
json_field => json!({"mixed_type": 10.0}),
json_field => json!({"mixed_type": 10.0}),
single_term => "single_term",
single_term => "single_term",
text_field => "cool",
text_field => "cool",
text_field_all_unique_terms => "cool",
@@ -681,6 +792,7 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
json!({"mixed_type": many_terms_data.choose(&mut rng).unwrap().to_string()})
};
index_writer.add_document(doc!(
single_term => "single_term",
text_field => "cool",
json_field => json,
text_field_all_unique_terms => format!("unique_term_{}", rng.random::<u64>()),


@@ -0,0 +1,149 @@
// Benchmarks top-K intersection of term scorers (block_wand_intersection).
//
// What's measured:
// - Conjunctive queries (+a +b, +a +b +c) with top-10 by score
// - Varying doc-frequency balance between terms (balanced, skewed, very skewed)
// - Realistic term frequencies (geometric distribution, mostly low)
// - 1M-doc single segment
//
// Run with: cargo bench --bench intersection_bench
use binggan::{black_box, BenchRunner};
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::collector::TopDocs;
use tantivy::query::QueryParser;
use tantivy::schema::{Schema, TEXT};
use tantivy::{doc, Index, ReloadPolicy, Searcher};
const NUM_DOCS: usize = 1_000_000;
struct BenchIndex {
searcher: Searcher,
query_parser: QueryParser,
}
/// Generate term frequency from a geometric-like distribution.
/// Most values are 1, a few are 2-3, rarely higher.
/// p controls the decay: higher p → more weight on tf=1.
fn random_term_freq(rng: &mut StdRng, p: f64) -> u32 {
let mut tf = 1u32;
while tf < 10 && rng.random_bool(1.0 - p) {
tf += 1;
}
tf
}
/// Build an index with three terms (a, b, c) with given doc-frequency probabilities.
/// Each term occurrence has a realistic term frequency (geometric distribution).
/// Field length is padded with filler tokens to create varied fieldnorms.
fn build_index(p_a: f64, p_b: f64, p_c: f64) -> BenchIndex {
let mut schema_builder = Schema::builder();
let body = schema_builder.add_text_field("body", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut rng = StdRng::from_seed([42u8; 32]);
{
let mut writer = index.writer_with_num_threads(1, 500_000_000).unwrap();
for _ in 0..NUM_DOCS {
let mut tokens: Vec<String> = Vec::new();
if rng.random_bool(p_a) {
let tf = random_term_freq(&mut rng, 0.7);
for _ in 0..tf {
tokens.push("aaa".to_string());
}
}
if rng.random_bool(p_b) {
let tf = random_term_freq(&mut rng, 0.7);
for _ in 0..tf {
tokens.push("bbb".to_string());
}
}
if rng.random_bool(p_c) {
let tf = random_term_freq(&mut rng, 0.7);
for _ in 0..tf {
tokens.push("ccc".to_string());
}
}
// Pad with filler to create varied field lengths (5 to 29 filler tokens).
let filler_count = rng.random_range(5u32..30u32);
for _ in 0..filler_count {
tokens.push("filler".to_string());
}
let text = tokens.join(" ");
writer.add_document(doc!(body => text)).unwrap();
}
writer.commit().unwrap();
}
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::Manual)
.try_into()
.unwrap();
let searcher = reader.searcher();
let query_parser = QueryParser::for_index(&index, vec![body]);
BenchIndex {
searcher,
query_parser,
}
}
fn main() {
// Scenarios: (label, p_a, p_b, p_c)
//
// "balanced": all terms ~10% → intersection ~1% of docs
// "skewed": one common (50%), one rare (2%) → intersection ~1%
// "very_skewed": one very common (80%), one very rare (0.5%) → intersection ~0.4%
// "three_balanced": three terms ~20% each → intersection ~0.8%
// "three_skewed": 50% / 10% / 2% → intersection ~0.1%
let scenarios: Vec<(&str, f64, f64, f64)> = vec![
("balanced_10%_10%", 0.10, 0.10, 0.0),
("skewed_50%_2%", 0.50, 0.02, 0.0),
("very_skewed_80%_0.5%", 0.80, 0.005, 0.0),
("three_balanced_20%_20%_20%", 0.20, 0.20, 0.20),
("three_skewed_50%_10%_2%", 0.50, 0.10, 0.02),
];
let mut runner = BenchRunner::new();
for (label, p_a, p_b, p_c) in &scenarios {
let bench_index = build_index(*p_a, *p_b, *p_c);
let mut group = runner.new_group();
group.set_name(format!("intersection — {label}"));
// Two-term intersection
if *p_a > 0.0 && *p_b > 0.0 {
let query_str = "+aaa +bbb";
let query = bench_index.query_parser.parse_query(query_str).unwrap();
let searcher = bench_index.searcher.clone();
group.register(format!("{query_str} top10"), move |_| {
let collector = TopDocs::with_limit(10).order_by_score();
black_box(searcher.search(&query, &collector).unwrap());
1usize
});
}
// Three-term intersection
if *p_c > 0.0 {
let query_str = "+aaa +bbb +ccc";
let query = bench_index.query_parser.parse_query(query_str).unwrap();
let searcher = bench_index.searcher.clone();
group.register(format!("{query_str} top10"), move |_| {
let collector = TopDocs::with_limit(10).order_by_score();
black_box(searcher.search(&query, &collector).unwrap());
1usize
});
}
group.run();
}
}
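
The selectivities quoted in the scenario comments, and the "mostly low" term frequencies, can be checked with a little arithmetic. The sketch below is not part of the benchmark. It assumes terms are drawn independently per document (as `build_index` does) and that the term frequency follows the capped loop in `random_term_freq`. The helper names `expected_intersection_fraction` and `expected_term_freq` are hypothetical.

```rust
/// Expected fraction of documents that contain every term, assuming each term
/// is drawn independently with the given per-document probability.
fn expected_intersection_fraction(probs: &[f64]) -> f64 {
    probs.iter().product()
}

/// Expected term frequency for a loop that starts at 1 and keeps incrementing
/// with probability `1 - p` until it reaches `cap`.
fn expected_term_freq(p: f64, cap: u32) -> f64 {
    let q = 1.0 - p; // probability of one more increment
    // P(tf > k) = q^k for k in 0..cap; the expectation is the sum of these tails.
    (0..cap).map(|k| q.powi(k as i32)).sum()
}
```

For the two-term scenarios only `p_a` and `p_b` should be passed, since `p_c` is 0.0 there and the `+aaa +bbb +ccc` query is not registered.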


@@ -0,0 +1,35 @@
// Benchmark for the query grammar parsing deeply nested queries.
//
// Regression guard for https://github.com/quickwit-oss/tantivy/issues/2498:
// at depth 20/21 the old parser took 0.87 s / 1.72 s respectively because
// `ast()` retried `occur_leaf` on backtrack, giving O(2^n) time. With the
// fix parsing is linear and completes in microseconds.
//
// Run with: `cargo bench --bench query_parser_nested`.
use binggan::{black_box, BenchRunner};
use tantivy::query_grammar::parse_query;
fn nested_query(depth: usize, leading_plus: bool) -> String {
let leading = "(".repeat(depth);
let trailing = ")".repeat(depth);
let prefix = if leading_plus { "+" } else { "" };
format!("{prefix}{leading}title:test{trailing}")
}
fn main() {
let mut runner = BenchRunner::new();
for depth in [20, 21] {
for leading_plus in [false, true] {
let query = nested_query(depth, leading_plus);
let label = format!(
"parse_nested_depth_{depth}_{}",
if leading_plus { "plus" } else { "plain" },
);
runner.bench_function(&label, move |_| {
black_box(parse_query(black_box(&query)).unwrap());
});
}
}
}
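
To make the benchmark inputs concrete, the sketch below reproduces the `nested_query` helper and adds a hypothetical `predicted_seconds` function. The second function only encodes the assumption that each extra nesting level doubled the work in the old parser, which is what the quoted 0.87 s and 1.72 s timings suggest. It is an illustration, not a measurement.

```rust
/// Same shape as the benchmark helper: `depth` nested groups around one leaf.
fn nested_query(depth: usize, leading_plus: bool) -> String {
    let leading = "(".repeat(depth);
    let trailing = ")".repeat(depth);
    let prefix = if leading_plus { "+" } else { "" };
    format!("{prefix}{leading}title:test{trailing}")
}

/// Predicted parse time under the assumption that every additional nesting
/// level doubles the work (hypothetical model, not a measurement).
fn predicted_seconds(base_depth: u32, base_seconds: f64, depth: u32) -> f64 {
    base_seconds * 2f64.powi(depth as i32 - base_depth as i32)
}
```

Under that model, going from depth 20 to depth 21 predicts about 1.74 s, close to the 1.72 s quoted in the header comment.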


@@ -23,7 +23,7 @@ downcast-rs = "2.0.1"
proptest = "1"
more-asserts = "0.3.1"
rand = "0.9"
binggan = "0.16.1"
binggan = "0.17.0"
[[bench]]
name = "bench_merge"


@@ -19,6 +19,6 @@ time = { version = "0.3.47", features = ["serde-well-known"] }
serde = { version = "1.0.136", features = ["derive"] }
[dev-dependencies]
binggan = "0.16.1"
binggan = "0.17.0"
proptest = "1.0.0"
rand = "0.9"


@@ -1045,18 +1045,43 @@ fn operand_leaf(inp: &str) -> IResult<&str, (Option<BinaryOperand>, Option<Occur
}
fn ast(inp: &str) -> IResult<&str, UserInputAst> {
let boolean_expr = map_res(
separated_pair(occur_leaf, multispace1, many1(operand_leaf)),
|(left, right)| aggregate_binary_expressions(left, right),
);
let single_leaf = map(occur_leaf, |(occur, ast)| {
if occur == Some(Occur::MustNot) {
ast.unary(Occur::MustNot)
} else {
ast
}
});
delimited(multispace0, alt((boolean_expr, single_leaf)), multispace0)(inp)
// Parse `occur_leaf` once, then conditionally extend into a boolean
// expression. The previous implementation used `alt((boolean_expr,
// single_leaf))` which, when the input was a single leaf with no
// following operand, would parse `occur_leaf` once for `boolean_expr`,
// fail at `multispace1`, backtrack, then re-parse `occur_leaf` for
// `single_leaf`. With recursively-nested groups like `(+(+(+a)))`, that
// doubling at every level produced O(2^n) parse time. Parsing once and
// peeking ahead for the operand keeps it O(n).
delimited(
multispace0,
|inp| {
let (rest, first) = occur_leaf(inp)?;
// Only fall back on `Err::Error` (recoverable), mirroring
// `alt`'s behaviour. `Err::Failure` and `Err::Incomplete`
// must propagate so cut points and streaming needs are not
// accidentally swallowed if they are ever introduced in the
// operand parsers.
match preceded(multispace1, many1(operand_leaf))(rest) {
Ok((rest, more)) => {
let combined = aggregate_binary_expressions(first, more)
.map_err(|_| nom::Err::Error(Error::new(inp, ErrorKind::MapRes)))?;
Ok((rest, combined))
}
Err(nom::Err::Error(_)) => {
let (occur, ast) = first;
let single = if occur == Some(Occur::MustNot) {
ast.unary(Occur::MustNot)
} else {
ast
};
Ok((rest, single))
}
Err(e) => Err(e),
}
},
multispace0,
)(inp)
}
fn ast_infallible(inp: &str) -> JResult<&str, UserInputAst> {
@@ -1891,4 +1916,23 @@ mod test {
r#"(+"field":'happy tax payer' +"other_field":1)"#,
);
}
// Regression test for https://github.com/quickwit-oss/tantivy/issues/2498:
// deeply nested parenthesized queries used to take O(2^n) time because the
// top-level `ast()` parser tried `boolean_expr` first and re-parsed the
// inner `occur_leaf` when it backtracked to `single_leaf`. Depth 60 would
// take ~10^18 operations under the regression; with the fix it parses
// instantly. We use `test_parse_query_to_ast_helper` so this test would
// never finish if the regression returned.
#[test]
fn test_parse_deeply_nested_query() {
let depth = 60;
let leading: String = "(".repeat(depth);
let trailing: String = ")".repeat(depth);
let query = format!("{leading}title:test{trailing}");
test_parse_query_to_ast_helper(&query, r#""title":test"#);
let query_with_plus = format!("+{leading}title:test{trailing}");
test_parse_query_to_ast_helper(&query_with_plus, r#""title":test"#);
}
}
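
The difference between the two parser shapes can be reproduced without nom. The sketch below is a hand-rolled toy grammar (groups in parentheses, leaves made of ASCII letters, an optional single operand after one space), not the tantivy grammar. All names are hypothetical, and `reparse = true` only imitates the "try boolean expression, then backtrack to single leaf" order described in the comments above.

```rust
/// Parses one leaf or one parenthesized group and counts the call.
fn occur_leaf<'a>(inp: &'a str, reparse: bool, calls: &mut u64) -> Option<&'a str> {
    *calls += 1;
    if let Some(inner) = inp.strip_prefix('(') {
        let rest = ast(inner, reparse, calls)?;
        rest.strip_prefix(')')
    } else {
        // A leaf is a non-empty run of ASCII letters.
        let end = inp
            .find(|ch: char| !ch.is_ascii_alphabetic())
            .unwrap_or(inp.len());
        if end == 0 { None } else { Some(&inp[end..]) }
    }
}

fn ast<'a>(inp: &'a str, reparse: bool, calls: &mut u64) -> Option<&'a str> {
    if reparse {
        // Old shape: attempt "leaf, space, operand" first ...
        if let Some(rest) = occur_leaf(inp, reparse, calls) {
            if let Some(after_space) = rest.strip_prefix(' ') {
                if let Some(done) = occur_leaf(after_space, reparse, calls) {
                    return Some(done);
                }
            }
        }
        // ... then backtrack and parse the very same leaf a second time.
        occur_leaf(inp, reparse, calls)
    } else {
        // New shape: parse the leaf once, then peek for an operand.
        let rest = occur_leaf(inp, reparse, calls)?;
        match rest.strip_prefix(' ') {
            Some(after_space) => occur_leaf(after_space, reparse, calls).or(Some(rest)),
            None => Some(rest),
        }
    }
}

/// Returns the number of `occur_leaf` calls if the whole query parses.
fn count_leaf_calls(query: &str, reparse: bool) -> Option<u64> {
    let mut calls = 0;
    let rest = ast(query, reparse, &mut calls)?;
    if rest.is_empty() { Some(calls) } else { None }
}
```

In this toy model a query nested `d` levels deep costs `2^(d+2) - 2` leaf parses with backtracking and `d + 1` when the leaf is parsed once, which is the same exponential-versus-linear gap the regression test guards against.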


@@ -20,8 +20,8 @@ use crate::aggregation::metric::{
build_segment_stats_collector, AverageAggregation, CardinalityAggReqData,
CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation, MaxAggregation,
MetricAggReqData, MinAggregation, SegmentCardinalityCollector, SegmentExtendedStatsCollector,
SegmentPercentilesCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData,
TopHitsSegmentCollector,
SegmentPercentilesCollector, StatsAggregation, StatsType, SumAggregation, TermOrdSet,
TopHitsAggReqData, TopHitsSegmentCollector, BITSET_MAX_TERM_ORD,
};
use crate::aggregation::segment_agg_result::{
GenericSegmentAggregationResultsCollector, SegmentAggregationCollector,
@@ -413,12 +413,38 @@ pub(crate) fn build_segment_agg_collector(
}
AggKind::Cardinality => {
let req_data = &mut req.get_cardinality_req_data_mut(node.idx_in_req_data);
Ok(Box::new(SegmentCardinalityCollector::from_req(
req_data.column_type,
node.idx_in_req_data,
req_data.accessor.clone(),
req_data.missing_value_for_accessor,
)))
// For str columns, choose the per-bucket entries representation
// based on the segment's column.max_value():
// * small (< BITSET_MAX_TERM_ORD): `BitSet`, pre-allocated, no promotion machinery.
// * large: `TermOrdSet` (sparse FxHashSet that promotes to a paged bitset).
// For non-str columns the `entries` field is unused (values go
// straight into the HLL sketch); we still pick `TermOrdSet`
// because its empty Sparse(FxHashSet) costs nothing.
let is_str = req_data.column_type == ColumnType::Str;
let max_term_ord_inclusive = if is_str {
req_data.accessor.max_value()
} else {
0
};
let collector: Box<dyn SegmentAggregationCollector> =
if is_str && max_term_ord_inclusive < BITSET_MAX_TERM_ORD {
Box::new(SegmentCardinalityCollector::<BitSet>::from_req(
req_data.column_type,
node.idx_in_req_data,
req_data.accessor.clone(),
req_data.missing_value_for_accessor,
max_term_ord_inclusive,
))
} else {
Box::new(SegmentCardinalityCollector::<TermOrdSet>::from_req(
req_data.column_type,
node.idx_in_req_data,
req_data.accessor.clone(),
req_data.missing_value_for_accessor,
max_term_ord_inclusive,
))
};
Ok(collector)
}
AggKind::StatsKind(stats_type) => {
let req_data = &mut req.per_request.stats_metric_req_data[node.idx_in_req_data];
@@ -985,8 +1011,12 @@ fn build_terms_or_cardinality_nodes(
let str_col = str_dict_column
.as_ref()
.expect("str_dict_column must exist for string column");
allowed_term_ids =
build_allowed_term_ids_for_str(str_col, &req.include, &req.exclude)?;
allowed_term_ids = build_allowed_term_ids_for_str(
str_col,
&req.include,
&req.exclude,
missing.is_some(),
)?;
};
let idx_in_req_data = data.push_term_req_data(TermsAggReqData {
accessor,
@@ -1002,10 +1032,20 @@ fn build_terms_or_cardinality_nodes(
(idx_in_req_data, AggKind::Terms)
}
TermsOrCardinalityRequest::Cardinality(ref req) => {
// `str_dict_column` is computed once per field; for JSON paths
// with mixed types it's `Some` even on the numeric req_data.
// Cardinality only consults it for the str column path, so
// gate by column_type to avoid driving non-str collectors
// through the coupon-cache path.
let str_dict_column_for_req = if column_type == ColumnType::Str {
str_dict_column.clone()
} else {
None
};
let idx_in_req_data = data.push_cardinality_req_data(CardinalityAggReqData {
accessor,
column_type,
str_dict_column: str_dict_column.clone(),
str_dict_column: str_dict_column_for_req,
missing_value_for_accessor,
name: agg_name.to_string(),
req: req.clone(),
@@ -1025,16 +1065,21 @@ fn build_terms_or_cardinality_nodes(
/// Builds a single BitSet of allowed term ordinals for a string dictionary column according to
/// include/exclude parameters.
///
/// When `reserve_missing_sentinel` is true, the bitset has one additional slot for the missing
/// term ordinal.
fn build_allowed_term_ids_for_str(
str_col: &StrColumn,
include: &Option<IncludeExcludeParam>,
exclude: &Option<IncludeExcludeParam>,
reserve_missing_sentinel: bool,
) -> crate::Result<Option<BitSet>> {
let mut allowed: Option<BitSet> = None;
let num_terms = str_col.dictionary().num_terms() as u32;
let missing_sentinel_adjustment = if reserve_missing_sentinel { 1 } else { 0 };
let allowed_capacity = str_col.dictionary().num_terms() as u32 + missing_sentinel_adjustment;
if let Some(include) = include {
// add matches
allowed = Some(BitSet::with_max_value(num_terms));
allowed = Some(BitSet::with_max_value(allowed_capacity));
let allowed = allowed.as_mut().unwrap();
for_each_matching_term_ord(str_col, include, |ord| allowed.insert(ord))?;
};
@@ -1042,7 +1087,7 @@ fn build_allowed_term_ids_for_str(
if let Some(exclude) = exclude {
if allowed.is_none() {
// Start with all terms allowed
allowed = Some(BitSet::with_max_value_and_full(num_terms));
allowed = Some(BitSet::with_max_value_and_full(allowed_capacity));
}
let allowed = allowed.as_mut().unwrap();
for_each_matching_term_ord(str_col, exclude, |ord| allowed.remove(ord))?;
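
The extra slot matters because a bitset sized for exactly `num_terms` ordinals cannot represent one more value. The sketch below uses a hypothetical `TinyBitSet` instead of tantivy's `BitSet`, and it assumes the missing sentinel is the ordinal `num_terms` (one past the last real term ordinal). That placement is an assumption made for illustration, not something this diff states.

```rust
/// Minimal fixed-capacity bitset (hypothetical stand-in for tantivy's `BitSet`).
struct TinyBitSet {
    words: Vec<u64>,
    capacity: u32,
}

impl TinyBitSet {
    fn with_capacity(capacity: u32) -> Self {
        let num_words = (capacity as usize + 63) / 64;
        TinyBitSet { words: vec![0; num_words], capacity }
    }

    fn insert(&mut self, ord: u32) {
        assert!(ord < self.capacity, "ordinal out of range");
        self.words[(ord / 64) as usize] |= 1u64 << (ord % 64);
    }

    fn contains(&self, ord: u32) -> bool {
        ord < self.capacity
            && (self.words[(ord / 64) as usize] & (1u64 << (ord % 64))) != 0
    }
}

/// Capacity needed for `num_terms` real ordinals plus an optional sentinel slot.
fn allowed_capacity(num_terms: u32, reserve_missing_sentinel: bool) -> u32 {
    num_terms + u32::from(reserve_missing_sentinel)
}
```

With 64 terms and a `missing` value configured, the capacity becomes 65, so the assumed sentinel ordinal 64 can be inserted and queried like any other ordinal.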


@@ -115,6 +115,71 @@ pub fn get_fast_field_names(aggs: &Aggregations) -> HashSet<String> {
fast_field_names
}
/// Validates that all fields referenced in the aggregation request exist in the schema
/// and are configured as fast fields.
///
/// This is a convenience function for upfront validation before executing aggregations.
/// Returns an error if any field doesn't exist or is not a fast field.
///
/// Validation is intentionally opt-in rather than baked into aggregation execution: the
/// default lenient behavior (returning empty results for missing fields) supports
/// schema evolution and federated queries where the same request runs against segments
/// or indices with different schemas.
///
/// # Example
/// ```
/// use tantivy::aggregation::agg_req::{Aggregations, validate_aggregation_fields_exist};
/// use tantivy::schema::{Schema, FAST};
/// use tantivy::Index;
///
/// # fn main() -> tantivy::Result<()> {
/// // Create a simple index
/// let mut schema_builder = Schema::builder();
/// schema_builder.add_f64_field("price", FAST);
/// let schema = schema_builder.build();
/// let index = Index::create_in_ram(schema);
///
/// // Parse aggregation request
/// let agg_req: Aggregations = serde_json::from_str(r#"{
/// "avg_price": { "avg": { "field": "price" } }
/// }"#)?;
///
/// let reader = index.reader()?;
/// let searcher = reader.searcher();
///
/// // Validate fields before executing
/// for segment_reader in searcher.segment_readers() {
/// validate_aggregation_fields_exist(&agg_req, segment_reader)?;
/// }
/// # Ok(())
/// # }
/// ```
pub fn validate_aggregation_fields_exist(
aggs: &Aggregations,
reader: &crate::SegmentReader,
) -> crate::Result<()> {
let field_names = get_fast_field_names(aggs);
let schema = reader.schema();
for field_name in field_names {
// Check if the field is either directly in the schema or could be part of a json field
// present in the schema, and verify it's a fast field.
if let Some((field, _path)) = schema.find_field(&field_name) {
let field_type = schema.get_field_entry(field).field_type();
if !field_type.is_fast() {
return Err(crate::TantivyError::SchemaError(format!(
"Field '{}' is not a fast field. Aggregations require fast fields.",
field_name
)));
}
} else {
return Err(crate::TantivyError::FieldNotFound(field_name));
}
}
Ok(())
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
/// All aggregation types.
pub enum AggregationVariants {


@@ -1436,3 +1436,46 @@ fn test_aggregation_on_json_object_mixed_numerical_segments() {
)
);
}
#[test]
fn test_aggregation_field_validation_helper() {
// Test the standalone validation helper function for field validation
let index = get_test_index_2_segments(false).unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let segment_reader = searcher.segment_reader(0);
// Test with invalid field
let agg_req: Aggregations = serde_json::from_str(
r#"{
"avg_test": {
"avg": { "field": "nonexistent_field" }
}
}"#,
)
.unwrap();
let result =
crate::aggregation::agg_req::validate_aggregation_fields_exist(&agg_req, segment_reader);
assert!(result.is_err());
match result {
Err(crate::TantivyError::FieldNotFound(field_name)) => {
assert_eq!(field_name, "nonexistent_field");
}
_ => panic!("Expected FieldNotFound error, got: {:?}", result),
}
// Test with valid field
let agg_req: Aggregations = serde_json::from_str(
r#"{
"avg_test": {
"avg": { "field": "score" }
}
}"#,
)
.unwrap();
let result =
crate::aggregation::agg_req::validate_aggregation_fields_exist(&agg_req, segment_reader);
assert!(result.is_ok());
}
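
The two outcomes exercised by this test (unknown field, known fast field) plus the third case handled by the helper (known field that is not fast) can be sketched without tantivy. The schema below is just a list of `(field name, is_fast)` pairs, and `ValidationError` and `validate_fields` are hypothetical names. JSON path resolution via `find_field` is not modeled.

```rust
#[derive(Debug, PartialEq)]
enum ValidationError {
    FieldNotFound(String),
    NotFast(String),
}

/// Checks that every requested field exists and is flagged as a fast field.
/// Stops at the first problem, like the helper in the diff.
fn validate_fields(
    schema: &[(&str, bool)],
    requested: &[&str],
) -> Result<(), ValidationError> {
    for &name in requested {
        let is_fast = schema
            .iter()
            .find(|entry| entry.0 == name)
            .map(|entry| entry.1);
        match is_fast {
            None => return Err(ValidationError::FieldNotFound(name.to_string())),
            Some(false) => return Err(ValidationError::NotFast(name.to_string())),
            Some(true) => {}
        }
    }
    Ok(())
}
```

Keeping this as a separate, opt-in call mirrors the design note in the doc comment: execution itself stays lenient, and callers that want strictness validate up front.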


@@ -152,7 +152,7 @@ impl SegmentAggregationCollector for SegmentCompositeCollector {
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let mem_pre = self.get_memory_consumption();
let mem_pre = self.get_memory_consumption(parent_bucket_id);
let composite_agg_data = agg_data.take_composite_req_data(self.accessor_idx);
for doc in docs {
@@ -172,7 +172,7 @@ impl SegmentAggregationCollector for SegmentCompositeCollector {
sub_agg.check_flush_local(agg_data)?;
}
let mem_delta = self.get_memory_consumption() - mem_pre;
let mem_delta = self.get_memory_consumption(parent_bucket_id) - mem_pre;
if mem_delta > 0 {
agg_data.context.limits.add_memory_consumed(mem_delta)?;
}
@@ -199,14 +199,22 @@ impl SegmentAggregationCollector for SegmentCompositeCollector {
}
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// Composite is a multi-bucket agg with no single value to extract.
None
}
}
impl SegmentCompositeCollector {
fn get_memory_consumption(&self) -> u64 {
self.parent_buckets
.iter()
.map(|m| m.memory_consumption())
.sum()
fn get_memory_consumption(&self, parent_bucket_id: BucketId) -> u64 {
self.parent_buckets[parent_bucket_id as usize].memory_consumption()
}
pub(crate) fn from_req_and_validate(


@@ -559,34 +559,30 @@ mod tests {
page_size,
agg_req,
);
if page_idx + 1 < page_count {
assert!(
res["my_composite"].get("after_key").is_some(),
"expected after_key on all but last page"
);
after_key = Some(res["my_composite"]["after_key"].clone());
} else if res["my_composite"].get("after_key").is_some() {
// currently we sometimes have an after_key on the last page,
// check that the next "page" is empty
let agg_req_json = json!({
"my_composite": {
"composite": {
"sources": composite_agg_sources,
"size": page_size,
"after": res["my_composite"]["after_key"].clone(),
}
}
});
let agg_req: Aggregations = serde_json::from_value(agg_req_json).unwrap();
let res = exec_request(agg_req.clone(), index).unwrap();
assert_eq!(
res["my_composite"]["buckets"],
json!([]),
"expected no buckets when using after_key from last page, query: {:?}",
agg_req
);
}
assert!(
res["my_composite"].get("after_key").is_some(),
"expected after_key on every non-empty page"
);
after_key = Some(res["my_composite"]["after_key"].clone());
}
// Using the after_key from the last page must yield an empty page.
let agg_req_json = json!({
"my_composite": {
"composite": {
"sources": composite_agg_sources,
"size": page_size,
"after": after_key,
}
}
});
let agg_req: Aggregations = serde_json::from_value(agg_req_json).unwrap();
let res = exec_request(agg_req.clone(), index).unwrap();
assert_eq!(
res["my_composite"]["buckets"],
json!([]),
"expected no buckets when using after_key from last page, query: {:?}",
agg_req
);
}
}
@@ -711,8 +707,28 @@ mod tests {
{"key": {"myterm": "terme"}, "doc_count": 1}
])
);
assert!(res["my_composite"].get("after_key").is_none());
// paginating past last page should be empty
let agg_req_json = json!({
"my_composite": {
"composite": {
"sources": [
{"myterm": {"terms": {"field": "string_id"}}}
],
"size": 3,
"after": &res["my_composite"]["after_key"]
}
}
});
let agg_req: Aggregations = serde_json::from_value(agg_req_json).unwrap();
let res = exec_request(agg_req.clone(), &index).unwrap();
assert!(res["my_composite"].get("after_key").is_none());
assert_eq!(
res["my_composite"]["buckets"],
json!([]),
"expected no buckets when using after_key from last page, query: {:?}",
agg_req
);
Ok(())
}
@@ -820,7 +836,10 @@ mod tests {
{"key": {"myterm": "apple"}, "doc_count": 1}
])
);
assert!(res["fruity_aggreg"].get("after_key").is_none());
assert_eq!(
res["fruity_aggreg"]["after_key"],
json!({"myterm": "str:apple"})
);
Ok(())
}
@@ -1792,7 +1811,14 @@ mod tests {
{"key": {"month": ms_timestamp_from_iso_str("2021-02-01T00:00:00Z"), "category": "books"}, "doc_count": 1},
]),
);
assert!(res["my_composite"].get("after_key").is_none());
let feb_2021_ns = ms_timestamp_from_iso_str("2021-02-01T00:00:00Z") * 1_000_000;
assert_eq!(
res["my_composite"]["after_key"],
json!({
"month": format!("dt:{}", feb_2021_ns),
"category": "str:books"
})
);
Ok(())
}
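
The pagination contract these tests now assert (an `after_key` on every non-empty page, and an empty page when resuming from the last `after_key`) can be sketched over a plain sorted list of keys. `Page`, `fetch_page` and `collect_all` are hypothetical names. The typed key encoding seen in the assertions (`str:apple`, `dt:...`) is deliberately not modeled.

```rust
/// One page of results over sorted keys (hypothetical simplification).
struct Page<'a> {
    buckets: Vec<&'a str>,
    after_key: Option<&'a str>,
}

/// Returns up to `size` keys strictly greater than `after`.
fn fetch_page<'a>(sorted_keys: &[&'a str], after: Option<&str>, size: usize) -> Page<'a> {
    let buckets: Vec<&'a str> = sorted_keys
        .iter()
        .copied()
        .filter(|key| after.map_or(true, |a| *key > a))
        .take(size)
        .collect();
    // Every non-empty page reports its last key; an empty page reports none.
    let after_key = buckets.last().copied();
    Page { buckets, after_key }
}

/// Pages until an empty page comes back; returns all keys and the request count.
fn collect_all<'a>(sorted_keys: &[&'a str], size: usize) -> (Vec<&'a str>, usize) {
    let mut all = Vec::new();
    let mut after = None;
    let mut requests = 0;
    loop {
        let page = fetch_page(sorted_keys, after, size);
        requests += 1;
        if page.buckets.is_empty() {
            break;
        }
        all.extend(page.buckets.iter().copied());
        after = page.after_key;
    }
    (all, requests)
}
```

A client written this way needs one extra request to learn that it has reached the end, which is the trade-off of always returning an `after_key` on a full last page.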


@@ -674,6 +674,17 @@ impl<B: SubAggBuffer> SegmentAggregationCollector for SegmentFilterCollector<B>
}
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// TODO: forward into the inner `sub_agg` for nested order paths (`filter.metric`).
None
}
}
/// Intermediate result for filter aggregation


@@ -283,6 +283,11 @@ impl SegmentHistogramBucketEntry {
struct HistogramBuckets {
pub buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
}
impl HistogramBuckets {
fn memory_consumption(&self) -> u64 {
self.buckets.capacity() as u64 * std::mem::size_of::<SegmentHistogramBucketEntry>() as u64
}
}
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
@@ -324,7 +329,7 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req = agg_data.take_histogram_req_data(self.accessor_idx);
let mem_pre = self.get_memory_consumption();
let mem_pre = self.get_memory_consumption(parent_bucket_id);
let buckets = &mut self.parent_buckets[parent_bucket_id as usize].buckets;
let bounds = req.bounds;
@@ -358,12 +363,9 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
}
agg_data.put_back_histogram_req_data(self.accessor_idx, req);
let mem_delta = self.get_memory_consumption() - mem_pre;
let mem_delta = self.get_memory_consumption(parent_bucket_id) - mem_pre;
if mem_delta > 0 {
agg_data
.context
.limits
.add_memory_consumed(mem_delta as u64)?;
agg_data.context.limits.add_memory_consumed(mem_delta)?;
}
if let Some(sub_agg) = &mut self.sub_agg {
@@ -392,14 +394,24 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
}
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// Histogram is a multi-bucket agg with no single value to extract.
None
}
}
impl SegmentHistogramCollector {
fn get_memory_consumption(&self) -> usize {
let self_mem = std::mem::size_of::<Self>();
let buckets_mem = self.parent_buckets.len() * std::mem::size_of::<HistogramBuckets>();
self_mem + buckets_mem
fn get_memory_consumption(&self, parent_bucket_id: BucketId) -> u64 {
self.parent_buckets[parent_bucket_id as usize].memory_consumption()
}
/// Converts the collector result into an intermediate bucket result.
fn add_intermediate_bucket_result(
&mut self,
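
The accounting change in these collectors measures only the parent bucket being written instead of summing over all parent buckets on every `collect` call. The sketch below is a simplified model with hypothetical `Buckets` and `Collector` types and a `u32` bucket id (assumed). It shows that the per-bucket deltas still add up to the total, because a `collect` call only grows the bucket it targets.

```rust
/// Storage for one parent bucket (hypothetical simplification).
struct Buckets {
    entries: Vec<u64>,
}

impl Buckets {
    fn memory_consumption(&self) -> u64 {
        (self.entries.capacity() * std::mem::size_of::<u64>()) as u64
    }
}

struct Collector {
    parent_buckets: Vec<Buckets>,
    /// Sum of all deltas reported so far.
    reported: u64,
}

impl Collector {
    fn new(num_parent_buckets: usize) -> Self {
        let parent_buckets = (0..num_parent_buckets)
            .map(|_| Buckets { entries: Vec::new() })
            .collect();
        Collector { parent_buckets, reported: 0 }
    }

    /// Constant-time lookup of a single bucket, instead of a sum over all buckets.
    fn memory_of(&self, parent_bucket_id: u32) -> u64 {
        self.parent_buckets[parent_bucket_id as usize].memory_consumption()
    }

    fn collect(&mut self, parent_bucket_id: u32, values: &[u64]) {
        let mem_pre = self.memory_of(parent_bucket_id);
        self.parent_buckets[parent_bucket_id as usize]
            .entries
            .extend_from_slice(values);
        // Only the touched bucket can have grown, so its delta is the full delta.
        let mem_delta = self.memory_of(parent_bucket_id).saturating_sub(mem_pre);
        self.reported += mem_delta;
    }

    /// Full scan, used here only to cross-check the incremental accounting.
    fn total_memory(&self) -> u64 {
        self.parent_buckets.iter().map(|b| b.memory_consumption()).sum()
    }
}
```

The sketch uses `saturating_sub` so that an unexpected shrink could not underflow the unsigned delta. That is a choice made for the example and not something taken from the diff.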


@@ -328,6 +328,17 @@ impl<B: SubAggBuffer> SegmentAggregationCollector for SegmentRangeCollector<B> {
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// Range is a multi-bucket agg with no single value to extract.
None
}
}
/// Build a concrete `SegmentRangeCollector` with either a Vec- or HashMap-backed
/// bucket storage, depending on the column type and aggregation level.


@@ -352,19 +352,15 @@ pub(crate) fn build_segment_term_collector(
)));
}
// Validate sub aggregation exists when ordering by sub-aggregation.
{
if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target {
let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name);
node.get_sub_agg(agg_name, &req_data.per_request)
.ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"could not find aggregation with name {agg_name} in metric \
sub_aggregations"
))
})?;
}
// Validate that the referenced sub-aggregation exists when ordering by one.
if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target {
let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name);
node.get_sub_agg(agg_name, &req_data.per_request)
.ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"could not find aggregation with name {agg_name} in metric sub_aggregations"
))
})?;
}
// Build sub-aggregation blueprint if there are children.
@@ -809,7 +805,7 @@ impl<TermMap: TermAggregationMap, B: SubAggBuffer> SegmentAggregationCollector
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
// let mem_pre = self.get_memory_consumption();
let mem_pre = self.get_memory_consumption(parent_bucket_id);
let req_data = &mut self.terms_req_data;
@@ -853,16 +849,13 @@ impl<TermMap: TermAggregationMap, B: SubAggBuffer> SegmentAggregationCollector
}
}
// let mem_delta = self.get_memory_consumption() - mem_pre;
// if mem_delta > 0 {
// agg_data
// .context
// .limits
// .add_memory_consumed(mem_delta as u64)?;
// }
// After commenting out -> 6000ms -> 36ms
let mem_delta = self.get_memory_consumption(parent_bucket_id) - mem_pre;
if mem_delta > 0 {
agg_data
.context
.limits
.add_memory_consumed(mem_delta as u64)?;
}
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.check_flush_local(agg_data)?;
}
@@ -890,6 +883,17 @@ impl<TermMap: TermAggregationMap, B: SubAggBuffer> SegmentAggregationCollector
}
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// Terms is a multi-bucket agg with no single value to extract.
None
}
}
/// Missing values are represented as a sentinel value in the column.
@@ -949,11 +953,9 @@ where
TermMap: TermAggregationMap,
B: SubAggBuffer,
{
fn get_memory_consumption(&self) -> usize {
self.parent_buckets
.iter()
.map(|b| b.get_memory_consumption())
.sum()
#[inline]
fn get_memory_consumption(&self, parent_bucket_id: BucketId) -> usize {
self.parent_buckets[parent_bucket_id as usize].get_memory_consumption()
}
#[inline]
@@ -965,9 +967,6 @@ where
) -> crate::Result<IntermediateBucketResult> {
let mut entries: Vec<(u64, Bucket)> = term_buckets.into_vec();
let order_by_sub_aggregation =
matches!(term_req.req.order.target, OrderTarget::SubAggregation(_));
match &term_req.req.order.target {
OrderTarget::Key => {
// We rely on the fact, that term ordinals match the order of the strings
@@ -979,10 +978,37 @@ where
entries.sort_unstable_by_key(|bucket| bucket.0);
}
}
OrderTarget::SubAggregation(_name) => {
// don't sort and cut off since it's hard to make assumptions on the quality of the
// results when cutting off due to unknown nature of the sub_aggregation (possible
// to check).
OrderTarget::SubAggregation(sub_agg_path) => {
// Peek segment-level metric values, sort, then fall through to
// `cut_off_buckets`. Like Elasticsearch, we always cut off when ordering
// by a sub-agg: top-K results are approximate and may differ from the
// global ordering, especially for non-monotonic metrics like avg/min.
let coll = sub_agg_collector.as_deref().ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"Could not find sub-aggregation collector for path {sub_agg_path}"
))
})?;
let (agg_name, agg_prop) = get_agg_name_and_property(sub_agg_path);
// Fetch values up-front; otherwise sort would re-compute per comparison
let mut keyed: Vec<(f64, (u64, Bucket))> = entries
.into_iter()
.map(|bucket| {
let metric_value = coll
.compute_metric_value(bucket.1.bucket_id, agg_name, agg_prop, agg_data)
.unwrap_or(0.0);
(metric_value, bucket)
})
.collect();
if term_req.req.order.order == Order::Desc {
keyed.sort_unstable_by(|a, b| {
b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
});
} else {
keyed.sort_unstable_by(|a, b| {
a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
});
}
entries = keyed.into_iter().map(|(_, e)| e).collect();
}
OrderTarget::Count => {
if term_req.req.order.order == Order::Desc {
@@ -993,11 +1019,8 @@ where
}
}
let (term_doc_count_before_cutoff, sum_other_doc_count) = if order_by_sub_aggregation {
(0, 0)
} else {
cut_off_buckets(&mut entries, term_req.req.segment_size as usize)
};
let (term_doc_count_before_cutoff, sum_other_doc_count) =
cut_off_buckets(&mut entries, term_req.req.segment_size as usize);
let mut dict: FxHashMap<IntermediateKey, IntermediateTermBucketEntry> = Default::default();
dict.reserve(entries.len());
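
The order-by-sub-aggregation path in this hunk boils down to three steps: compute one metric value per bucket up front, sort on it, then cut off and account for the dropped documents. The sketch below illustrates those steps with a hypothetical `Bucket` type and `top_buckets_by_metric` function. It uses a stable sort for deterministic ties (the diff uses `sort_unstable_by`), and its cut-off bookkeeping is an assumption about what `cut_off_buckets` reports.

```rust
#[derive(Debug, Clone, PartialEq)]
struct Bucket {
    key: String,
    doc_count: u64,
}

/// Sorts buckets by a per-bucket metric and keeps the first `size`.
/// Returns the kept buckets and the doc count of everything that was dropped.
fn top_buckets_by_metric(
    buckets: Vec<Bucket>,
    metric: impl Fn(&Bucket) -> Option<f64>,
    descending: bool,
    size: usize,
) -> (Vec<Bucket>, u64) {
    // Compute each metric once so the comparator does not recompute it.
    let mut keyed: Vec<(f64, Bucket)> = buckets
        .into_iter()
        .map(|bucket| (metric(&bucket).unwrap_or(0.0), bucket))
        .collect();
    keyed.sort_by(|a, b| {
        // NaN compares as equal, mirroring the `unwrap_or(Equal)` fallback.
        let ord = a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal);
        if descending { ord.reverse() } else { ord }
    });
    let sum_other_doc_count: u64 = keyed
        .iter()
        .skip(size)
        .map(|(_, bucket)| bucket.doc_count)
        .sum();
    keyed.truncate(size);
    let kept = keyed.into_iter().map(|(_, bucket)| bucket).collect();
    (kept, sum_other_doc_count)
}
```

Because the cut-off happens per segment, the merged top-K is approximate, as the comment in the hunk notes for metrics such as avg and min.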
@@ -1228,7 +1251,6 @@ pub(crate) fn cut_off_buckets<T: GetDocCount + Debug>(
mod tests {
use std::net::IpAddr;
use std::str::FromStr;
use std::time::Instant;
use common::DateTime;
use time::{Date, Month};
@@ -1242,10 +1264,9 @@ mod tests {
get_test_index_from_terms, get_test_index_from_values_and_terms,
};
use crate::aggregation::{AggregationLimitsGuard, DistributedAggregationCollector};
use crate::collector::{Collector, default_collect_segment_impl};
use crate::indexer::NoMergePolicy;
use crate::query::{AllQuery, EnableScoring, Query};
use crate::schema::{IntoIpv6Addr, Schema, FAST, STRING};
use crate::query::AllQuery;
use crate::schema::{IntoIpv6Addr, Schema, FAST, INDEXED, STRING, TEXT};
use crate::{Index, IndexWriter};
#[test]
@@ -1774,6 +1795,263 @@ mod tests {
Ok(())
}
#[test]
fn terms_aggregation_order_by_cardinality_desc_single_segment() -> crate::Result<()> {
terms_aggregation_order_by_cardinality_desc(true)
}
#[test]
fn terms_aggregation_order_by_cardinality_desc_multi_segment() -> crate::Result<()> {
terms_aggregation_order_by_cardinality_desc(false)
}
fn terms_aggregation_order_by_cardinality_desc(merge_segments: bool) -> crate::Result<()> {
// Distinct score values per bucket key: A→5, B→1, C→3.
// Order by cardinality desc must yield A, C, B.
let segment_and_terms = vec![vec![
(1.0, "A".to_string()),
(2.0, "A".to_string()),
(3.0, "A".to_string()),
(4.0, "A".to_string()),
(5.0, "A".to_string()),
(1.0, "B".to_string()),
(1.0, "B".to_string()),
(1.0, "B".to_string()),
(1.0, "C".to_string()),
(2.0, "C".to_string()),
(3.0, "C".to_string()),
]];
let index = get_test_index_from_values_and_terms(merge_segments, &segment_and_terms)?;
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "card": "desc" }
},
"aggs": {
"card": { "cardinality": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][0]["card"]["value"], 5.0);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][1]["card"]["value"], 3.0);
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
assert_eq!(res["my_texts"]["buckets"][2]["card"]["value"], 1.0);
// Asc engages the segment-cutoff path too (monotonic-safe: discarded buckets had
// local card >= cutoff, so merged card >= cutoff and they cannot be globally smallest).
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "card": "asc" }
},
"aggs": {
"card": { "cardinality": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "B");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][2]["key"], "A");
// size=2 with desc engages the segment cutoff: must keep top-2 by cardinality (A, C),
// and `sum_other_doc_count` reflects the dropped B (3 docs).
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"size": 2,
"order": { "card": "desc" }
},
"aggs": {
"card": { "cardinality": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"].as_array().unwrap().len(), 2);
// size=2 with asc engages the segment cutoff: must keep bottom-2 by cardinality (B, C).
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"size": 2,
"order": { "card": "asc" }
},
"aggs": {
"card": { "cardinality": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "B");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"].as_array().unwrap().len(), 2);
Ok(())
}
#[test]
fn terms_aggregation_order_by_sum_single_segment() -> crate::Result<()> {
terms_aggregation_order_by_sum(true)
}
#[test]
fn terms_aggregation_order_by_sum_multi_segment() -> crate::Result<()> {
terms_aggregation_order_by_sum(false)
}
fn terms_aggregation_order_by_sum(merge_segments: bool) -> crate::Result<()> {
// Per-bucket sums on the U64 `score` column (non-negative => sum is monotonic):
// A → 1+2+3+4+5 = 15, B → 1+1+1 = 3, C → 1+2+3 = 6.
let segment_and_terms = vec![
vec![
(1.0, "A".to_string()),
(2.0, "A".to_string()),
(3.0, "A".to_string()),
(1.0, "B".to_string()),
(1.0, "C".to_string()),
],
vec![
(4.0, "A".to_string()),
(5.0, "A".to_string()),
(1.0, "B".to_string()),
(1.0, "B".to_string()),
(2.0, "C".to_string()),
(3.0, "C".to_string()),
],
];
let index = get_test_index_from_values_and_terms(merge_segments, &segment_and_terms)?;
// Desc on a Sum metric engages the fast path (column is U64).
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "total": "desc" }
},
"aggs": {
"total": { "sum": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][0]["total"]["value"], 15.0);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][1]["total"]["value"], 6.0);
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
assert_eq!(res["my_texts"]["buckets"][2]["total"]["value"], 3.0);
// Asc engages the fast path too — discarded buckets had local sum >= cutoff,
// and merged sum >= local (non-negative addends), so they cannot be globally smallest.
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "total": "asc" }
},
"aggs": {
"total": { "sum": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "B");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][2]["key"], "A");
// size=2 desc with cutoff: top-2 by sum (A, C).
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"size": 2,
"order": { "total": "desc" }
},
"aggs": {
"total": { "sum": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"].as_array().unwrap().len(), 2);
// Stats sub-property: ordering by `mystats.sum` on a U64 column also engages.
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "mystats.sum": "desc" }
},
"aggs": {
"mystats": { "stats": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
// Sum on a signed column (I64) takes the same cutoff path. Results may be
// approximate near the boundary on adversarial data, but for this dataset the
// top-K is unambiguous.
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "total": "desc" }
},
"aggs": {
"total": { "sum": { "field": "score_i64" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
// Order by extended_stats sub-property exercises compute_metric_value on the
// ExtendedStats collector. A→max=5, B→max=1, C→max=3, so desc by max → A, C, B.
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "ext.max": "desc" }
},
"aggs": {
"ext": { "extended_stats": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
Ok(())
}
#[test]
fn terms_aggregation_test_order_key_single_segment() -> crate::Result<()> {
terms_aggregation_test_order_key_merge_segment(true)
@@ -2940,102 +3218,100 @@ mod tests {
Ok(())
}
#[test]
fn test_terms_double_nesting() {
fn prep_index_with_n_unique_terms_plus_one_null(n: u64) -> crate::Result<Index> {
let mut schema_builder = Schema::builder();
let outer_field = schema_builder.add_text_field("outer_term", STRING | FAST);
let inner_field = schema_builder.add_text_field("inner_term", STRING | FAST);
let id_field = schema_builder.add_u64_field("id", INDEXED);
let title_field = schema_builder.add_text_field("title", TEXT | FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
// set to one thread to guarantee all docs end up in the same segment
let mut writer = index.writer_with_num_threads(1, 50_000_000)?;
let outer_values = (0..10_000)
.map(|i| format!("outer_{i}"))
.collect::<Vec<_>>();
let inner_values = ["INFO", "ERROR", "WARN", "DEBUG"];
{
let mut index_writer: IndexWriter = index.writer_with_num_threads(1, 200_000_000).unwrap();
for doc_id in 0..1_000_000u64 {
let outer_val = &outer_values[doc_id as usize % outer_values.len()];
let inner_val = inner_values[doc_id as usize % inner_values.len()];
index_writer.add_document(doc!(
outer_field => outer_val.as_str(),
inner_field => inner_val,
)).unwrap();
}
index_writer.commit().unwrap();
writer.add_document(doc!(
id_field => 0u64,
))?;
for i in 1u64..=n {
let title = format!("foo{i}");
writer.add_document(doc!(
id_field => i,
title_field => title,
))?;
}
let agg_req: Aggregations = serde_json::from_value(json!({
"outer": {
"terms": { "field": "outer_term", "size": 10 },
"aggs": {
"inner": {
"terms": { "field": "inner_term" }
writer.commit()?;
Ok(index)
}
#[test]
fn null_bitset_bounds_check_regression() -> crate::Result<()> {
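// Exercise normal, include, and exclude requests on indexes holding 0, 64, ..., 256
// unique terms (multiples of 64) plus one document without a `title`.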
for i in 0..=4 {
let index = prep_index_with_n_unique_terms_plus_one_null(i * 64)?;
let normal_req: Aggregations = serde_json::from_value(json!({
"my_bool": {
"terms": {
"field": "title",
"missing": "__NULL__",
"size": 1000,
}
}
}))?;
let include_req: Aggregations = serde_json::from_value(json!({
"my_bool": {
"terms": {
"field": "title",
"include": "foo(.*)",
"missing": "__NULL__",
"size": 1000,
}
}
}))?;
let exclude_req: Aggregations = serde_json::from_value(json!({
"my_bool": {
"terms": {
"field": "title",
"exclude": "foo(.*)",
"missing": "__NULL__",
"size": 1000,
}
}
}))?;
let normal_res = exec_request(normal_req, &index)?;
let normal_buckets = normal_res["my_bool"]["buckets"].as_array().unwrap();
assert_eq!(
normal_buckets.len(),
(i * 64) as usize + 1,
"The normal request should return all 'foo' buckets, plus the missing term bucket",
);
let include_res = exec_request(include_req, &index)?;
eprintln!("include_res: {include_res:?}");
let include_buckets = include_res["my_bool"]["buckets"].as_array().unwrap();
assert_eq!(
include_buckets.len(),
(i * 64) as usize,
"The include request should return all 'foo' buckets, and not the missing term \
bucket",
);
assert!(include_buckets
.iter()
.all(|b| b["key"].as_str().unwrap().starts_with("foo")));
let exclude_res = exec_request(exclude_req, &index)?;
let exclude_buckets = exclude_res["my_bool"]["buckets"].as_array().unwrap();
if i != 0 {
// TODO: Remove this `if` once the exclude + missing bug is fixed.
assert_eq!(
exclude_buckets.len(),
1,
"The exclude request should exclude all 'foo' buckets and return only the \
missing term bucket",
);
assert_eq!(exclude_buckets[0]["key"], "__NULL__");
}
}))
.unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let collector =
crate::aggregation::AggregationCollector::from_aggs(agg_req, Default::default());
assert_eq!(searcher.segment_readers().len(), 1);
let segment_reader = searcher.segment_reader(0u32);
let all_weight = AllQuery.weight(EnableScoring::disabled_from_schema(&schema)).unwrap();
let mut segment_collector = collector.for_segment(0u32, segment_reader).unwrap();
let start = Instant::now();
default_collect_segment_impl(&mut segment_collector, &*all_weight, segment_reader, false).unwrap();
dbg!(start.elapsed());
}
#[test]
fn test_terms_simple_nesting() {
let mut schema_builder = Schema::builder();
let outer_field = schema_builder.add_text_field("outer_term", STRING | FAST);
let inner_field = schema_builder.add_text_field("inner_term", STRING | FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
let outer_values = (0..10_000)
.map(|i| format!("outer_{i}"))
.collect::<Vec<_>>();
let inner_values = ["INFO", "ERROR", "WARN", "DEBUG"];
{
let mut index_writer: IndexWriter = index.writer_with_num_threads(1, 200_000_000).unwrap();
for doc_id in 0..1_000_000u64 {
let outer_val = &outer_values[doc_id as usize % outer_values.len()];
let inner_val = inner_values[doc_id as usize % inner_values.len()];
index_writer.add_document(doc!(
outer_field => outer_val.as_str(),
inner_field => inner_val,
)).unwrap();
}
index_writer.commit().unwrap();
}
let agg_req: Aggregations = serde_json::from_value(json!({
"outer": {
"terms": { "field": "outer_term", "size": 10 },
}
}))
.unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let collector =
crate::aggregation::AggregationCollector::from_aggs(agg_req, Default::default());
assert_eq!(searcher.segment_readers().len(), 1);
let segment_reader = searcher.segment_reader(0u32);
let all_weight = AllQuery.weight(EnableScoring::disabled_from_schema(&schema)).unwrap();
let mut segment_collector = collector.for_segment(0u32, segment_reader).unwrap();
let start = Instant::now();
default_collect_segment_impl(&mut segment_collector, &*all_weight, segment_reader, false).unwrap();
dbg!(start.elapsed());
Ok(())
}
}
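
Note on the `size: 2` test cases above: they expect the segment cutoff to keep the first two buckets in the requested order and to report the dropped documents in `sum_other_doc_count`. The sketch below illustrates that bookkeeping under stated assumptions. It is not the body of `cut_off_buckets`: the `Entry` alias and the `cut_off_sorted` helper are hypothetical, and the entries are assumed to be sorted by the order key already.

```rust
/// Hypothetical bucket entry for illustration: (bucket key, doc count).
pub type Entry = (String, u64);

/// Keeps the first `keep` entries of an already-sorted list.
/// Returns (doc count of the first dropped entry, sum of dropped doc counts).
/// Both values are 0 when nothing is dropped.
pub fn cut_off_sorted(entries: &mut Vec<Entry>, keep: usize) -> (u64, u64) {
    // Doc count of the entry sitting right behind the cutoff, if any.
    let first_dropped_doc_count = entries.get(keep).map(|(_, count)| *count).unwrap_or(0);
    // Everything past the cutoff is accounted for in a single sum.
    let sum_other_doc_count: u64 = entries.iter().skip(keep).map(|(_, count)| *count).sum();
    entries.truncate(keep);
    (first_dropped_doc_count, sum_other_doc_count)
}
```

With the cardinality test data ordered descending (A with 5 docs, C with 3, B with 3), `keep = 2` retains A and C and accounts for the 3 documents of B in the second return value.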


@@ -177,6 +177,17 @@ impl SegmentAggregationCollector for TermMissingAgg {
}
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// TODO: forward to `sub_agg` for nested order paths (`missing_agg>metric`).
None
}
}
#[cfg(test)]


@@ -1004,24 +1004,20 @@ impl IntermediateCompositeBucketResult {
) -> crate::Result<BucketResult> {
let trimmed_entry_vec =
trim_composite_buckets(self.entries, &self.orders, self.target_size)?;
let after_key = if trimmed_entry_vec.len() == req.size as usize {
trimmed_entry_vec
.last()
.map(|bucket| {
let (intermediate_key, _entry) = bucket;
intermediate_key
.iter()
.enumerate()
.map(|(idx, intermediate_key)| {
let source = &req.sources[idx];
(source.name().to_string(), intermediate_key.clone().into())
})
.collect()
})
.unwrap()
} else {
FxHashMap::default()
};
let after_key = trimmed_entry_vec
.last()
.map(|bucket| {
let (intermediate_key, _entry) = bucket;
intermediate_key
.iter()
.enumerate()
.map(|(idx, intermediate_key)| {
let source = &req.sources[idx];
(source.name().to_string(), intermediate_key.clone().into())
})
.collect()
})
.unwrap_or_default();
let buckets = trimmed_entry_vec
.into_iter()
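
Note on the `after_key` change in this hunk: the old code produced an after-key only when the trimmed page held exactly `req.size` buckets, while the new code derives it from the last bucket whenever one exists and falls back to an empty map otherwise. A minimal sketch of the new shape, assuming a simplified `Bucket` type and plain string keys (both hypothetical, not the crate's types):

```rust
use std::collections::HashMap;

/// Hypothetical composite bucket: one key value per source, plus a doc count.
pub type Bucket = (Vec<String>, u64);

/// Builds the after-key from the last bucket, pairing each key value with
/// the name of the source at the same position. Empty input gives an empty map.
pub fn after_key(buckets: &[Bucket], source_names: &[&str]) -> HashMap<String, String> {
    buckets
        .last()
        .map(|(key, _doc_count)| {
            key.iter()
                .enumerate()
                .map(|(idx, value)| (source_names[idx].to_string(), value.clone()))
                .collect()
        })
        // `None` (no buckets) becomes `HashMap::default()` without a panic path.
        .unwrap_or_default()
}
```

Because the fallback is `unwrap_or_default()`, an empty bucket list is handled without any `unwrap()` call.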


@@ -4,6 +4,7 @@ use std::io;
use columnar::column_values::CompactSpaceU64Accessor;
use columnar::{Column, ColumnType, Dictionary, StrColumn};
use common::{BitSet, TinySet};
use datasketches::hll::{Coupon, HllSketch, HllType, HllUnion};
use rustc_hash::{FxBuildHasher, FxHashMap, FxHashSet};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
@@ -20,6 +21,12 @@ use crate::TantivyError;
/// 2^11 = 2048 registers, giving ~2.3% relative error and ~2KB per sketch (Hll8).
const LG_K: u8 = 11;
/// Promote FxHashSet<u64> -> PagedBitset at ~3% density (`len * 32 >
/// dict_num_terms`). Past this point the bitset (~`dict_num_terms / 7.5`
/// bytes) is smaller than the hashset (~10 B/entry minimum) and avoids
/// the per-insert hash.
const PROMOTION_RATIO: u64 = 32;
/// # Cardinality
///
/// The cardinality aggregation allows for computing an estimate
@@ -159,8 +166,12 @@ impl CouponCache {
let should_use_dense =
highest_term_ord < 1_000_000u64 || highest_term_ord < num_terms as u64 * 3u64;
if should_use_dense {
let mut coupon_map: Vec<Coupon> = vec![Coupon::EMPTY; highest_term_ord as usize + 1];
for (term_ord, coupon) in term_ords.into_iter().zip(coupons.into_iter()) {
// We don't really care about the value here. We will populate all the values we will
// read anyway.
let uninitialized_coupon = Coupon::from_hash(0);
let mut coupon_map: Vec<Coupon> =
vec![uninitialized_coupon; highest_term_ord as usize + 1];
for (term_ord, coupon) in term_ords.into_iter().zip(coupons) {
coupon_map[term_ord as usize] = coupon;
}
CouponCache::Dense {
@@ -177,9 +188,263 @@ impl CouponCache {
}
}
pub(crate) struct SegmentCardinalityCollector {
// =================================================================
// PagedBitset: a sparse bitset indexed by term_ord.
//
// Used as the dense alternative to FxHashSet<u64> once a string
// cardinality bucket has accumulated enough unique term ordinals.
// Memory is bounded to (touched pages) * (page bytes), not
// (max_term_ord / 8).
//
// Page geometry mirrors `PagedTermMap` in `term_agg.rs`: 1024 ords
// per page, lazy `Vec<Option<Box<Page>>>` directory.
// =================================================================
const BITSET_PAGE_SHIFT: u32 = 10;
const BITSET_PAGE_BITS: u64 = 1u64 << BITSET_PAGE_SHIFT; // 1024
const BITSET_PAGE_MASK: u64 = BITSET_PAGE_BITS - 1;
const BITSET_WORDS_PER_PAGE: usize = (BITSET_PAGE_BITS / 64) as usize; // 16
#[derive(Clone)]
struct PagedBitsetPage {
words: [TinySet; BITSET_WORDS_PER_PAGE],
}
impl PagedBitsetPage {
fn new() -> Self {
Self {
words: [TinySet::empty(); BITSET_WORDS_PER_PAGE],
}
}
}
pub(crate) struct PagedBitset {
pages: Vec<Option<Box<PagedBitsetPage>>>,
/// Cached number of set bits, maintained on insert.
count: u64,
}
impl PagedBitset {
/// Allocates a directory big enough to hold ords up to and including
/// `max_term_ord`. Pages are allocated lazily on first set.
fn with_max_term_ord(max_term_ord: u64) -> Self {
let max_page_idx = (max_term_ord >> BITSET_PAGE_SHIFT) as usize;
let num_pages = max_page_idx + 1;
Self {
pages: vec![None; num_pages],
count: 0,
}
}
#[inline]
fn insert(&mut self, term_ord: u64) {
let page_idx = (term_ord >> BITSET_PAGE_SHIFT) as usize;
let intra = term_ord & BITSET_PAGE_MASK;
let word_idx = (intra >> 6) as usize;
let bit_idx = (intra & 63) as u32;
let page = match &mut self.pages[page_idx] {
Some(p) => p,
None => {
self.pages[page_idx] = Some(Box::new(PagedBitsetPage::new()));
self.pages[page_idx].as_mut().unwrap()
}
};
if page.words[word_idx].insert_mut(bit_idx) {
self.count += 1;
}
}
/// Number of set bits. O(1).
#[inline]
fn len(&self) -> u64 {
self.count
}
/// Iterate set ords in ascending order.
fn iter_sorted(&self) -> impl Iterator<Item = u64> + '_ {
self.pages
.iter()
.enumerate()
.filter_map(|(page_idx, page_opt)| page_opt.as_ref().map(|p| (page_idx, p)))
.flat_map(|(page_idx, page)| {
let page_base_ord = (page_idx as u64) << BITSET_PAGE_SHIFT;
page.words
.iter()
.enumerate()
.flat_map(move |(word_idx, &word)| {
let word_base_ord = page_base_ord + (word_idx as u64) * 64;
word.into_iter()
.map(move |bit| word_base_ord + u64::from(bit))
})
})
}
}
/// Threshold below which we use `BitSet` instead of `TermOrdSet`.
///
/// Both `BitSet` and `FxHashSet<u64>` have the same 32-byte struct, so the comparison is heap only:
/// * `BitSet` at T=256: 5 `TinySet` words covering 258 bits (with the missing-value sentinel) =
/// 40 bytes.
/// * `FxHashSet<u64>` after one insert: 4-bucket hashbrown table ≈ 56 bytes
pub(crate) const BITSET_MAX_TERM_ORD: u64 = 256;
// =================================================================
// TermOrdAccumulator: per-bucket abstraction over the entries set.
//
// Implementations:
// - `BitSet` (from `common`): used when `column.max_value()` is small (< BITSET_MAX_TERM_ORD).
// Pre-allocated, no promotion.
// - `TermOrdSet`: adaptive, starts as FxHashSet and promotes to a paged bitset when occupancy
// crosses the density threshold (only if promotion is enabled — typically gated on top-level
// aggregation).
//
// The trait lets `SegmentCardinalityCollector` be generic over the choice
// so the hot collect() loop monomorphizes to a direct call (no enum
// dispatch per insert).
// =================================================================
pub(crate) trait TermOrdAccumulator: Sized {
/// Construct an empty accumulator.
/// `max_term_ord_inclusive` is the largest term_ord that may be
/// inserted (used to size pre-allocated bitsets and the dense bitset
/// on promotion).
fn new(max_term_ord_inclusive: u64) -> Self;
fn insert(&mut self, term_ord: u64);
/// Bulk insert. Implementations may override to hoist any inner
/// dispatch outside the loop. Default loops `insert`.
#[inline]
fn extend_from_iter<I: IntoIterator<Item = u64>>(&mut self, ords: I) {
for ord in ords {
self.insert(ord);
}
}
/// Hook called once per ingested block. Adaptive impls use this to
/// decide on sparse->dense promotion.
fn maybe_compact(&mut self) {}
fn len(&self) -> usize;
fn iter_ords(&self) -> impl Iterator<Item = u64> + '_;
}
impl TermOrdAccumulator for BitSet {
#[inline]
fn new(max_term_ord_inclusive: u64) -> Self {
// `BitSet::with_max_value(M)` accepts ords in [0, M).
// We need ords up to and including `max_term_ord_inclusive`, plus
// the missing-value sentinel `column.max_value() + 1`.
BitSet::with_max_value((max_term_ord_inclusive + 2) as u32)
}
#[inline]
fn insert(&mut self, term_ord: u64) {
BitSet::insert(self, term_ord as u32);
}
#[inline]
fn len(&self) -> usize {
BitSet::len(self)
}
fn iter_ords(&self) -> impl Iterator<Item = u64> + '_ {
// `BitSet` itself doesn't expose iteration, but
// `BitSet::tinyset(bucket)` does. Walk per-bucket and yield each
// set bit. The capacity is `max_value()`; iterating to
// `div_ceil(64)` covers every possible ord exactly once.
let num_buckets = self.max_value().div_ceil(64);
(0..num_buckets).flat_map(move |bucket| {
let chunk_base = u64::from(bucket) * 64;
self.tinyset(bucket)
.into_iter()
.map(move |bit| chunk_base + u64::from(bit))
})
}
}
// =================================================================
// TermOrdSet: adaptive sparse->dense accumulator.
//
// Starts as an FxHashSet (cheap when few ords are seen). When occupancy
// crosses `len * PROMOTION_RATIO > max_term_ord_inclusive`, drains into
// a `PagedBitset` and continues dense. Promotion is one-way.
// =================================================================
pub(crate) struct TermOrdSet {
inner: TermOrdSetInner,
/// Largest term_ord that may be inserted. Used for both sizing the
/// dense bitset on promotion and as the promotion-threshold reference.
max_term_ord_inclusive: u64,
}
enum TermOrdSetInner {
Sparse(FxHashSet<u64>),
Dense(PagedBitset),
}
impl TermOrdAccumulator for TermOrdSet {
fn new(max_term_ord_inclusive: u64) -> Self {
Self {
inner: TermOrdSetInner::Sparse(FxHashSet::default()),
max_term_ord_inclusive,
}
}
#[inline]
fn insert(&mut self, term_ord: u64) {
match &mut self.inner {
TermOrdSetInner::Sparse(set) => {
set.insert(term_ord);
}
TermOrdSetInner::Dense(bitset) => bitset.insert(term_ord),
}
}
/// Hoist the Sparse/Dense match outside the per-ord loop so that a
/// block of inserts dispatches once.
fn extend_from_iter<I: IntoIterator<Item = u64>>(&mut self, ords: I) {
match &mut self.inner {
TermOrdSetInner::Sparse(set) => {
for ord in ords {
set.insert(ord);
}
}
TermOrdSetInner::Dense(bitset) => {
for ord in ords {
bitset.insert(ord);
}
}
}
}
fn maybe_compact(&mut self) {
let TermOrdSetInner::Sparse(set) = &mut self.inner else {
return;
};
if set.len() as u64 * PROMOTION_RATIO <= self.max_term_ord_inclusive {
return;
}
// Size for ord <= max_term_ord_inclusive plus the missing sentinel
// (column.max_value() + 1, which may equal max_term_ord_inclusive
// when the column references every dictionary term).
let mut bitset = PagedBitset::with_max_term_ord(self.max_term_ord_inclusive + 1);
let set = std::mem::take(set);
for ord in set {
bitset.insert(ord);
}
self.inner = TermOrdSetInner::Dense(bitset);
}
fn len(&self) -> usize {
match &self.inner {
TermOrdSetInner::Sparse(set) => set.len(),
TermOrdSetInner::Dense(bitset) => bitset.len() as usize,
}
}
fn iter_ords(&self) -> impl Iterator<Item = u64> + '_ {
match &self.inner {
TermOrdSetInner::Sparse(set) => itertools::Either::Left(set.iter().copied()),
TermOrdSetInner::Dense(bitset) => itertools::Either::Right(bitset.iter_sorted()),
}
}
}
pub(crate) struct SegmentCardinalityCollector<S: TermOrdAccumulator> {
/// Buckets are Some(_) until they get consumed by into_intermediate_results().
buckets: Vec<Option<SegmentCardinalityCollectorBucket>>,
buckets: Vec<Option<SegmentCardinalityCollectorBucket<S>>>,
accessor_idx: usize,
/// The column accessor to access the fast field values.
accessor: Column<u64>,
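
Note on the sparse-to-dense promotion introduced in this hunk: the sketch below restates the idea in isolation, using only the standard library. It is a simplification under stated assumptions, not the code from the diff: `DenseBits` is a flat `Vec<u64>` bitset rather than the paged one, `std::collections::HashSet` stands in for `FxHashSet`, and the names `AdaptiveOrdSet`, `OrdSet`, and `DenseBits` are hypothetical. Only the promotion rule (`len * PROMOTION_RATIO > max_term_ord_inclusive`) and the extra slot for the missing-value sentinel are taken from the diff.

```rust
use std::collections::HashSet;

/// Same ratio as in the diff: promote at roughly 3% density.
const PROMOTION_RATIO: u64 = 32;

/// Flat bitset standing in for the paged bitset of the diff.
pub struct DenseBits {
    words: Vec<u64>,
    count: u64,
}

impl DenseBits {
    /// Allocates enough words to hold ords up to and including `max_ord`.
    pub fn with_max_ord(max_ord: u64) -> Self {
        DenseBits { words: vec![0; (max_ord / 64 + 1) as usize], count: 0 }
    }

    /// Sets the bit for `ord`; the cached count only grows on a new bit.
    pub fn insert(&mut self, ord: u64) {
        let word = &mut self.words[(ord / 64) as usize];
        let mask = 1u64 << (ord % 64);
        if *word & mask == 0 {
            *word |= mask;
            self.count += 1;
        }
    }
}

pub enum OrdSet {
    Sparse(HashSet<u64>),
    Dense(DenseBits),
}

pub struct AdaptiveOrdSet {
    inner: OrdSet,
    max_ord_inclusive: u64,
}

impl AdaptiveOrdSet {
    pub fn new(max_ord_inclusive: u64) -> Self {
        AdaptiveOrdSet { inner: OrdSet::Sparse(HashSet::new()), max_ord_inclusive }
    }

    pub fn insert(&mut self, ord: u64) {
        match &mut self.inner {
            OrdSet::Sparse(set) => {
                set.insert(ord);
            }
            OrdSet::Dense(bits) => bits.insert(ord),
        }
    }

    /// One-way promotion: drain the hash set into the bitset once dense enough.
    pub fn maybe_compact(&mut self) {
        let OrdSet::Sparse(set) = &mut self.inner else {
            return;
        };
        if set.len() as u64 * PROMOTION_RATIO <= self.max_ord_inclusive {
            return;
        }
        // One extra slot for the missing-value sentinel (max ord + 1).
        let mut bits = DenseBits::with_max_ord(self.max_ord_inclusive + 1);
        for ord in std::mem::take(set) {
            bits.insert(ord);
        }
        self.inner = OrdSet::Dense(bits);
    }

    pub fn len(&self) -> usize {
        match &self.inner {
            OrdSet::Sparse(set) => set.len(),
            OrdSet::Dense(bits) => bits.count as usize,
        }
    }

    pub fn is_dense(&self) -> bool {
        matches!(self.inner, OrdSet::Dense(_))
    }
}
```

Calling `maybe_compact` once per ingested block, as the collector in the diff does, keeps the promotion check out of the per-ord insert path.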
@@ -188,9 +453,13 @@ pub(crate) struct SegmentCardinalityCollector {
/// The missing value normalized to the internal u64 representation of the field type.
missing_value_for_accessor: Option<u64>,
coupon_cache: Option<CouponCache>,
/// Largest term_ord that may be inserted into a bucket. For str columns
/// this is `accessor.max_value()`; for non-str columns this is unused
/// (no inserts go into `entries`) and set to 0.
max_term_ord_inclusive: u64,
}
impl Debug for SegmentCardinalityCollector {
impl<S: TermOrdAccumulator> Debug for SegmentCardinalityCollector<S> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("SegmentCardinalityCollector")
.field("column_type", &self.column_type)
@@ -202,16 +471,21 @@ impl Debug for SegmentCardinalityCollector {
}
}
pub(crate) struct SegmentCardinalityCollectorBucket {
cardinality: CardinalityCollector,
entries: FxHashSet<u64>,
/// Per-bucket state. Shape depends on column kind: str columns dedup
/// term ords and only build the HLL sketch at finalization (saves the
/// ~96 B `CardinalityCollector` per bucket during collect); numeric/IpAddr
/// columns feed the sketch directly during collect.
pub(crate) enum SegmentCardinalityCollectorBucket<S: TermOrdAccumulator> {
Str(S),
Numeric(CardinalityCollector),
}
impl SegmentCardinalityCollectorBucket {
impl<S: TermOrdAccumulator> SegmentCardinalityCollectorBucket<S> {
#[inline(always)]
pub fn new(column_type: ColumnType) -> Self {
Self {
cardinality: CardinalityCollector::new(column_type as u8),
entries: FxHashSet::default(),
pub fn new(column_type: ColumnType, max_term_ord_inclusive: u64) -> Self {
if column_type == ColumnType::Str {
Self::Str(S::new(max_term_ord_inclusive))
} else {
Self::Numeric(CardinalityCollector::new(column_type as u8))
}
}
@@ -222,37 +496,57 @@ impl SegmentCardinalityCollectorBucket {
//
// If the column is str, then the values are dictionary encoded
// and have not been added to the sketch yet.
// We need to resolves the term ords accumulated in self.entries
// with the coupon cache, and append the results to the sketch.
// We need to resolve the term ords accumulated in the str entries
// with the coupon cache, and append the results to a fresh sketch.
fn into_intermediate_metric_result(
mut self,
self,
coupon_cache_opt: Option<&CouponCache>,
) -> crate::Result<IntermediateMetricResult> {
if let Some(coupon_cache) = coupon_cache_opt {
assert!(self.cardinality.sketch.is_empty());
append_to_sketch(&self.entries, coupon_cache, &mut self.cardinality);
}
Ok(IntermediateMetricResult::Cardinality(self.cardinality))
let cardinality = match self {
Self::Str(entries) => {
let mut cardinality = CardinalityCollector::new(ColumnType::Str as u8);
if let Some(coupon_cache) = coupon_cache_opt {
// Sketch must be empty for str columns: coupons are appended here
// from the term_ord set (and not directly during collection).
assert!(cardinality.sketch.is_empty());
append_to_sketch(&entries, coupon_cache, &mut cardinality);
}
cardinality
}
Self::Numeric(cardinality) => cardinality,
};
Ok(IntermediateMetricResult::Cardinality(cardinality))
}
}
/// Builds a coupon cache from the given buckets, dictionary, and optional missing value.
/// Returns a mapping from term_ord to the hash (coupon) of the associated term.
fn build_coupon_cache(
buckets: &[Option<SegmentCardinalityCollectorBucket>],
fn build_coupon_cache<S: TermOrdAccumulator>(
buckets: &[Option<SegmentCardinalityCollectorBucket<S>>],
dictionary: &Dictionary,
missing_value_opt: Option<&Key>,
) -> io::Result<CouponCache> {
let term_ords_capacity: usize = buckets
.iter()
.flatten()
.map(|bucket| bucket.entries.len())
.max()
.unwrap_or(0)
* 2;
let mut term_ords_set = FxHashSet::with_capacity_and_hasher(term_ords_capacity, FxBuildHasher);
// Caller restricts this to str cardinality collectors, so every
// present bucket must be the `Str` variant. Pass 1 validates and
// computes the capacity hint; pass 2 inserts.
let mut max_bucket_len = 0usize;
for bucket in buckets.iter().flatten() {
term_ords_set.extend(bucket.entries.iter().copied());
match bucket {
SegmentCardinalityCollectorBucket::Str(entries) => {
max_bucket_len = max_bucket_len.max(entries.len());
}
SegmentCardinalityCollectorBucket::Numeric(_) => {
return Err(io::Error::other(
"build_coupon_cache invoked with a non-str bucket",
));
}
}
}
let mut term_ords_set = FxHashSet::with_capacity_and_hasher(max_bucket_len * 2, FxBuildHasher);
for bucket in buckets.iter().flatten() {
if let SegmentCardinalityCollectorBucket::Str(entries) = bucket {
term_ords_set.extend(entries.iter_ords());
}
}
let mut term_ords: Vec<u64> = term_ords_set.into_iter().collect();
term_ords.sort_unstable();
@@ -284,8 +578,8 @@ fn build_coupon_cache(
Ok(CouponCache::new(term_ords, coupons, missing_coupon_opt))
}
fn append_to_sketch(
term_ords: &FxHashSet<u64>,
fn append_to_sketch<S: TermOrdAccumulator>(
term_ords: &S,
coupon_cache: &CouponCache,
sketch: &mut CardinalityCollector,
) {
@@ -294,7 +588,7 @@ fn append_to_sketch(
coupon_map,
missing_coupon_opt,
} => {
for &term_ord in term_ords {
for term_ord in term_ords.iter_ords() {
if let Some(coupon) = coupon_map
.get(term_ord as usize)
.copied()
@@ -308,8 +602,8 @@ fn append_to_sketch(
coupon_map,
missing_coupon_opt,
} => {
for term_ord in term_ords {
if let Some(coupon) = coupon_map.get(term_ord).copied().or(*missing_coupon_opt) {
for term_ord in term_ords.iter_ords() {
if let Some(coupon) = coupon_map.get(&term_ord).copied().or(*missing_coupon_opt) {
sketch.insert_coupon(coupon);
}
}
@@ -317,12 +611,13 @@ fn append_to_sketch(
}
}
impl SegmentCardinalityCollector {
impl<S: TermOrdAccumulator> SegmentCardinalityCollector<S> {
pub fn from_req(
column_type: ColumnType,
accessor_idx: usize,
accessor: Column<u64>,
missing_value_for_accessor: Option<u64>,
max_term_ord_inclusive: u64,
) -> Self {
Self {
buckets: Vec::new(),
@@ -331,6 +626,7 @@ impl SegmentCardinalityCollector {
accessor,
missing_value_for_accessor,
coupon_cache: None,
max_term_ord_inclusive,
}
}
@@ -347,7 +643,9 @@ impl SegmentCardinalityCollector {
}
}
impl SegmentAggregationCollector for SegmentCardinalityCollector {
impl<S: TermOrdAccumulator + 'static> SegmentAggregationCollector
for SegmentCardinalityCollector<S>
{
fn add_intermediate_aggregation_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
@@ -402,31 +700,41 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
));
};
let col_block_accessor = &agg_data.column_block_accessor;
if self.column_type == ColumnType::Str {
for term_ord in col_block_accessor.iter_vals() {
bucket.entries.insert(term_ord);
match bucket {
SegmentCardinalityCollectorBucket::Str(entries) => {
// Promotion check runs on the pre-block state: the first call
// sees an empty set (no-op), and the last block of inserts
// doesn't trigger a promotion of a set we won't grow further.
// The trait dispatches once per block (via `extend_from_iter`)
// for adaptive variants and inlines to a tight loop for the
// BitSet path.
entries.maybe_compact();
entries.extend_from_iter(col_block_accessor.iter_vals());
}
} else if self.column_type == ColumnType::IpAddr {
let compact_space_accessor = self
.accessor
.values
.clone()
.downcast_arc::<CompactSpaceU64Accessor>()
.map_err(|_| {
TantivyError::AggregationError(
crate::aggregation::AggregationError::InternalError(
"Type mismatch: Could not downcast to CompactSpaceU64Accessor"
.to_string(),
),
)
})?;
for val in col_block_accessor.iter_vals() {
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
bucket.cardinality.insert(val);
}
} else {
for val in col_block_accessor.iter_vals() {
bucket.cardinality.insert(val);
SegmentCardinalityCollectorBucket::Numeric(cardinality) => {
if self.column_type == ColumnType::IpAddr {
let compact_space_accessor = self
.accessor
.values
.clone()
.downcast_arc::<CompactSpaceU64Accessor>()
.map_err(|_| {
TantivyError::AggregationError(
crate::aggregation::AggregationError::InternalError(
"Type mismatch: Could not downcast to CompactSpaceU64Accessor"
.to_string(),
),
)
})?;
for val in col_block_accessor.iter_vals() {
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
cardinality.insert(val);
}
} else {
for val in col_block_accessor.iter_vals() {
cardinality.insert(val);
}
}
}
}
@@ -439,12 +747,40 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
if max_bucket as usize >= self.buckets.len() {
let column_type = self.column_type;
let max_term_ord_inclusive = self.max_term_ord_inclusive;
self.buckets.resize_with(max_bucket as usize + 1, || {
Some(SegmentCardinalityCollectorBucket::new(self.column_type))
Some(SegmentCardinalityCollectorBucket::<S>::new(
column_type,
max_term_ord_inclusive,
))
});
}
Ok(())
}
fn compute_metric_value(
&self,
bucket_id: BucketId,
sub_agg_name: &str,
sub_agg_property: &str,
agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
if req_data.name != sub_agg_name || !sub_agg_property.is_empty() {
return None;
}
let bucket = self.buckets.get(bucket_id as usize)?.as_ref()?;
// For string columns the sketch isn't built until finalization; the
// term_ord set's len is the exact distinct count. For numeric columns
// the sketch is populated during collect.
match bucket {
SegmentCardinalityCollectorBucket::Str(entries) => Some(entries.len() as f64),
SegmentCardinalityCollectorBucket::Numeric(cardinality) => {
Some(cardinality.sketch.estimate().trunc())
}
}
}
}
#[derive(Clone, Debug)]
@@ -489,7 +825,7 @@ impl<'de> Deserialize<'de> for CardinalityCollector {
impl CardinalityCollector {
fn new(salt: u8) -> Self {
Self {
sketch: HllSketch::new(LG_K, HllType::Hll4),
sketch: HllSketch::new(LG_K, HllType::Hll8),
salt,
}
}
@@ -520,7 +856,7 @@ impl CardinalityCollector {
let mut union = HllUnion::new(LG_K);
union.update(&self.sketch);
union.update(&right.sketch);
self.sketch = union.to_sketch(HllType::Hll4);
self.sketch = union.to_sketch(HllType::Hll8);
Ok(())
}
}
@@ -592,6 +928,134 @@ mod tests {
Ok(())
}
/// Build a single-segment string-cardinality index with 32 unique terms.
/// `column.max_value() = 31` is well below `BITSET_MAX_TERM_ORD`,
/// so the bucket exercises the `BitSet` path end to end.
#[test]
fn cardinality_aggregation_test_str_bitset() -> crate::Result<()> {
let terms: Vec<String> = (0..32).map(|i| format!("term_{i}")).collect();
let term_refs: Vec<Vec<&str>> = terms.iter().map(|t| vec![t.as_str()]).collect::<Vec<_>>();
// single segment so we have a single dictionary of 32 terms.
let index = get_test_index_from_terms(true, &term_refs)?;
let agg_req: Aggregations = serde_json::from_value(json!({
"cardinality": {
"cardinality": { "field": "string_id" }
},
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["cardinality"]["value"], 32.0);
Ok(())
}
/// `BitSet` path with a `missing` parameter: the column-level missing
/// sentinel (`column.max_value() + 1`) flows into the bitset, the
/// dict lookup filter at finalization drops it, and the missing
/// coupon is applied separately.
#[test]
fn cardinality_aggregation_test_str_bitset_with_missing() {
let mut schema_builder = Schema::builder();
let name_field = schema_builder.add_text_field("name", STRING | FAST);
let index = Index::create_in_ram(schema_builder.build());
let mut writer = index.writer_for_tests().unwrap();
for i in 0..16 {
let term = format!("t{i:02}");
writer.add_document(doc!(name_field => term)).unwrap();
}
// One empty doc, exercising the missing sentinel.
writer.add_document(doc!()).unwrap();
writer.commit().unwrap();
let agg_req: Aggregations = serde_json::from_value(json!({
"cardinality": {
"cardinality": {
"field": "name",
"missing": "MISSING_SENTINEL_KEY",
}
},
}))
.unwrap();
let res = exec_request(agg_req, &index).unwrap();
// 16 distinct real terms + 1 distinct "missing" value = 17.
assert_eq!(res["cardinality"]["value"], 17.0);
}
/// Unit-test the PagedBitset itself: cross-page inserts produce sorted
/// iteration, len() matches the inserted set, and duplicates are
/// idempotent.
#[test]
fn paged_bitset_basic() {
use super::PagedBitset;
// Span several pages: BITSET_PAGE_BITS = 1024, so ords >= 1024 land
// on the second page, >= 2048 on the third, etc.
let ords = [0u64, 1, 63, 64, 1023, 1024, 1025, 4096, 4097, 9999, 10_000];
let max_ord = *ords.iter().max().unwrap();
let mut bitset = PagedBitset::with_max_term_ord(max_ord);
for &ord in &ords {
bitset.insert(ord);
// Idempotent: inserting again must not increase count.
bitset.insert(ord);
}
assert_eq!(bitset.len(), ords.len() as u64);
let collected: Vec<u64> = bitset.iter_sorted().collect();
let mut expected: Vec<u64> = ords.to_vec();
expected.sort_unstable();
assert_eq!(collected, expected);
}
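For readers unfamiliar with the data structure exercised by `paged_bitset_basic`, here is a minimal sketch of a paged bitset, assuming the 1024-bit page size quoted in the test comment. `PagedBitsetSketch`, `with_max_ord`, `sorted_ords`, and `allocated_pages` are hypothetical names for illustration, not the `PagedBitset` in this change. Pages are allocated on first write, so a sparse set of ords only pays for the pages it touches.

```rust
/// Page size in bits, taken from the test comment (`BITSET_PAGE_BITS = 1024`).
const PAGE_BITS: u64 = 1024;
const WORDS_PER_PAGE: usize = (PAGE_BITS / 64) as usize;

/// Illustrative paged bitset (hypothetical, not the real `PagedBitset`).
struct PagedBitsetSketch {
    pages: Vec<Option<Box<[u64; WORDS_PER_PAGE]>>>,
    len: u64,
}

impl PagedBitsetSketch {
    /// Reserves page slots for ords in `0..=max_ord` without allocating any page.
    fn with_max_ord(max_ord: u64) -> Self {
        let num_pages = (max_ord / PAGE_BITS + 1) as usize;
        Self {
            pages: (0..num_pages).map(|_| None).collect(),
            len: 0,
        }
    }

    /// Sets the bit for `ord`. Re-inserting an ord is a no-op.
    /// Panics if `ord` exceeds the `max_ord` given at construction.
    fn insert(&mut self, ord: u64) {
        let page_idx = (ord / PAGE_BITS) as usize;
        let bit_in_page = ord % PAGE_BITS;
        let page = self.pages[page_idx].get_or_insert_with(|| Box::new([0u64; WORDS_PER_PAGE]));
        let word = &mut page[(bit_in_page / 64) as usize];
        let mask = 1u64 << (bit_in_page % 64);
        if *word & mask == 0 {
            *word |= mask;
            self.len += 1;
        }
    }

    /// Number of distinct ords inserted so far.
    fn len(&self) -> u64 {
        self.len
    }

    /// Number of pages that were actually allocated.
    fn allocated_pages(&self) -> usize {
        self.pages.iter().filter(|page| page.is_some()).count()
    }

    /// Returns the set ords in ascending order.
    fn sorted_ords(&self) -> Vec<u64> {
        let mut out = Vec::with_capacity(self.len as usize);
        for (page_idx, page) in self.pages.iter().enumerate() {
            let Some(words) = page else { continue };
            for (word_idx, &word) in words.iter().enumerate() {
                let mut remaining = word;
                while remaining != 0 {
                    let bit = remaining.trailing_zeros() as u64;
                    out.push(page_idx as u64 * PAGE_BITS + word_idx as u64 * 64 + bit);
                    remaining &= remaining - 1; // clear the lowest set bit
                }
            }
        }
        out
    }
}
```

Iterating page by page and word by word yields ords in ascending order without a sort, which is the property the test checks through `iter_sorted()`.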
/// Unit-test `TermOrdSet`: starts Sparse, promotes to Dense on
/// `maybe_compact` once the density threshold is crossed, and
/// `iter_ords()` yields the same set in either state. Ords spanning
/// multiple paged-bitset pages exercise the Dense iter ordering.
#[test]
fn term_ord_set_promotes_on_maybe_compact() {
use super::{TermOrdAccumulator, TermOrdSet, PROMOTION_RATIO};
// Pick max so promotion needs few inserts: len * RATIO > max with
// RATIO=32 and max=64 trips at len=3 (3*32=96 > 64).
let max_term_ord = 64u64;
let mut set = <TermOrdSet as TermOrdAccumulator>::new(max_term_ord);
// Two inserts: should stay Sparse after maybe_compact (2 * RATIO = 64, not > 64).
set.insert(0);
set.insert(7);
set.maybe_compact();
assert_eq!(set.len(), 2);
// Third insert promotes on next maybe_compact.
set.insert(20);
assert_eq!(set.len(), 3);
// Sanity check: at len=3, 3 * PROMOTION_RATIO = 96 > 64.
assert!(3u64 * PROMOTION_RATIO > max_term_ord);
set.maybe_compact();
// Post-promotion: extending continues to work.
set.insert(15);
set.insert(15); // dup
assert_eq!(set.len(), 4);
let mut collected: Vec<u64> = set.iter_ords().collect();
collected.sort_unstable();
assert_eq!(collected, vec![0, 7, 15, 20]);
}
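The promotion rule this test relies on (`len * PROMOTION_RATIO > max_term_ord`) can be sketched as follows. This is a simplified illustration under stated assumptions: `TermOrdSetSketch` and its methods are hypothetical, the sparse state is a plain `HashSet`, and the dense state is a flat `Vec<u64>` bitmap rather than the paged bitset used by the real `TermOrdSet`.

```rust
use std::collections::HashSet;

/// Value quoted in the test comment above (`RATIO=32`).
const PROMOTION_RATIO: u64 = 32;

/// Illustrative adaptive set (hypothetical, not the real `TermOrdSet`).
enum TermOrdSetSketch {
    Sparse { ords: HashSet<u64>, max_term_ord: u64 },
    Dense { words: Vec<u64>, len: u64 },
}

impl TermOrdSetSketch {
    fn new(max_term_ord: u64) -> Self {
        Self::Sparse { ords: HashSet::new(), max_term_ord }
    }

    fn insert(&mut self, ord: u64) {
        match self {
            Self::Sparse { ords, .. } => {
                ords.insert(ord);
            }
            Self::Dense { words, len } => {
                let mask = 1u64 << (ord % 64);
                let word = &mut words[(ord / 64) as usize];
                if *word & mask == 0 {
                    *word |= mask;
                    *len += 1;
                }
            }
        }
    }

    fn len(&self) -> u64 {
        match self {
            Self::Sparse { ords, .. } => ords.len() as u64,
            Self::Dense { len, .. } => *len,
        }
    }

    fn is_dense(&self) -> bool {
        matches!(self, Self::Dense { .. })
    }

    /// Promotes Sparse to Dense once `len * PROMOTION_RATIO > max_term_ord`.
    fn maybe_compact(&mut self) {
        let Self::Sparse { ords, max_term_ord } = &*self else {
            return;
        };
        if ords.len() as u64 * PROMOTION_RATIO <= *max_term_ord {
            return;
        }
        let mut words = vec![0u64; (*max_term_ord / 64 + 1) as usize];
        for &ord in ords {
            words[(ord / 64) as usize] |= 1u64 << (ord % 64);
        }
        let len = ords.len() as u64;
        *self = Self::Dense { words, len };
    }

    /// Returns the stored ords in ascending order, whatever the current state.
    fn sorted_ords(&self) -> Vec<u64> {
        let mut out: Vec<u64> = match self {
            Self::Sparse { ords, .. } => ords.iter().copied().collect(),
            Self::Dense { words, .. } => (0..words.len() as u64 * 64)
                .filter(|&ord| words[(ord / 64) as usize] & (1u64 << (ord % 64)) != 0)
                .collect(),
        };
        out.sort_unstable();
        out
    }
}
```

With `max_term_ord = 64`, two entries give `2 * 32 = 64`, which is not greater than 64, so the set stays sparse. A third entry gives `96 > 64` and the next `maybe_compact` promotes it.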
/// Unit-test the `BitSet` impl of `TermOrdAccumulator`: insert,
/// dedup, and iter_ords order.
#[test]
fn bitset_accumulator_basic() {
use common::BitSet;
use super::TermOrdAccumulator;
let mut set = <BitSet as TermOrdAccumulator>::new(255);
for ord in [0u64, 1, 63, 64, 65, 128, 200, 200, 0] {
<BitSet as TermOrdAccumulator>::insert(&mut set, ord);
}
assert_eq!(<BitSet as TermOrdAccumulator>::len(&set), 7);
let collected: Vec<u64> = set.iter_ords().collect();
assert_eq!(collected, vec![0, 1, 63, 64, 65, 128, 200]);
}
#[test]
fn cardinality_aggregation_u64() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
@@ -683,6 +1147,42 @@ mod tests {
Ok(())
}
/// A JSON path that resolves to both a Str column and a numeric column
/// produces two collector instances per segment — one with `Str` buckets
/// and one with `Numeric` buckets. Their `IntermediateMetricResult`s must
/// merge into the union cardinality.
#[test]
fn cardinality_aggregation_json_str_and_numeric() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let field = schema_builder.add_json_field("json", FAST);
let index = Index::create_in_ram(schema_builder.build());
{
let mut writer = index.writer_for_tests()?;
writer.add_document(doc!(field => json!({"value": "hello"})))?;
writer.add_document(doc!(field => json!({"value": "world"})))?;
writer.add_document(doc!(field => json!({"value": "hello"})))?; // dup str
writer.add_document(doc!(field => json!({"value": i64::from_u64(7u64)})))?;
writer.add_document(doc!(field => json!({"value": i64::from_u64(42u64)})))?;
writer.add_document(doc!(field => json!({"value": i64::from_u64(7u64)})))?; // dup num
writer.commit()?;
}
let agg_req: Aggregations = serde_json::from_value(json!({
"cardinality": {
"cardinality": {
"field": "json.value"
},
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
// 4 distinct values: "hello", "world", 7, 42.
assert_eq!(res["cardinality"]["value"], 4.0);
Ok(())
}
#[test]
fn cardinality_collector_serde_roundtrip() {
use super::CardinalityCollector;

View File

@@ -399,6 +399,26 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
}
Ok(())
}
fn compute_metric_value(
&self,
bucket_id: BucketId,
sub_agg_name: &str,
sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
if self.name != sub_agg_name {
return None;
}
let extended = self.buckets.get(bucket_id as usize)?;
// Finalize is a pure read of accumulators — calling it here for the cutoff sort
// doesn't disturb the eventual intermediate result.
extended
.finalize()
.get_value(sub_agg_property)
.ok()
.flatten()
}
}
#[cfg(test)]

View File

@@ -312,6 +312,26 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector {
}
Ok(())
}
fn compute_metric_value(
&self,
bucket_id: BucketId,
sub_agg_name: &str,
sub_agg_property: &str,
agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
if agg_data.get_metric_req_data(self.accessor_idx).name != sub_agg_name {
return None;
}
let percentile: f64 = sub_agg_property.parse().ok()?;
if !(0.0..=100.0).contains(&percentile) {
return None;
}
let bucket = self.buckets.get(bucket_id as usize)?;
// DDSketch.quantile is a pure read; calling it here for the cutoff sort does
// not affect the intermediate state used for the final result.
bucket.sketch.quantile(percentile / 100.0).ok().flatten()
}
}
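The property handling in the percentiles `compute_metric_value` above can be isolated into a small helper. The sketch below is illustrative only: `percentile_property_to_quantile` is a hypothetical name, and it covers just the parse, range check, and rescaling steps, not the sketch lookup.

```rust
/// Converts an order-by property such as "95" into a quantile in `[0.0, 1.0]`.
/// Hypothetical helper mirroring the parse and range check shown above.
fn percentile_property_to_quantile(sub_agg_property: &str) -> Option<f64> {
    let percentile: f64 = sub_agg_property.parse().ok()?;
    // NaN fails this check too, because every comparison with NaN is false.
    if !(0.0..=100.0).contains(&percentile) {
        return None;
    }
    Some(percentile / 100.0)
}
```

An empty or non-numeric property yields `None`, which matches the "unknown property" case described in the trait documentation.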
#[cfg(test)]

View File

@@ -321,6 +321,40 @@ impl<const COLUMN_TYPE_ID: u8> SegmentAggregationCollector
}
Ok(())
}
fn compute_metric_value(
&self,
bucket_id: BucketId,
sub_agg_name: &str,
sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
if self.name != sub_agg_name {
return None;
}
let stats = self.buckets.get(bucket_id as usize)?;
// The property depends on what we're collecting:
// - StatsType::Stats exposes count/sum/min/max/avg via dotted property.
// - Single-value kinds (Sum/Count/Min/Max/Average) expect an empty property and return
// the value they were configured to collect.
let prop = match self.collecting_for {
StatsType::Stats if !sub_agg_property.is_empty() => sub_agg_property,
StatsType::Sum if sub_agg_property.is_empty() => "sum",
StatsType::Count if sub_agg_property.is_empty() => "count",
StatsType::Max if sub_agg_property.is_empty() => "max",
StatsType::Min if sub_agg_property.is_empty() => "min",
StatsType::Average if sub_agg_property.is_empty() => "avg",
_ => return None,
};
match prop {
"count" => Some(stats.count as f64),
"sum" => Some(stats.sum),
"min" if stats.count > 0 => Some(stats.min),
"max" if stats.count > 0 => Some(stats.max),
"avg" if stats.count > 0 => Some(stats.sum / stats.count as f64),
_ => None,
}
}
}
#[inline]

View File

@@ -644,6 +644,17 @@ impl SegmentAggregationCollector for TopHitsSegmentCollector {
);
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// top_hits is not a numeric metric and cannot be used as an order target.
None
}
}
#[cfg(test)]

View File

@@ -76,6 +76,31 @@ pub trait SegmentAggregationCollector: Debug {
fn flush(&mut self, _agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
Ok(())
}
/// Compute the segment-level metric value of the named direct-child metric for `bucket_id`.
///
/// Used by parent term aggs that order by a sub-aggregation: the parent sorts on
/// this value and cuts off at segment time, matching the approximation tradeoff
/// Elasticsearch makes for any sub-agg ordering.
///
/// `sub_agg_property` is the dotted suffix (e.g. `"sum"` in `mystats.sum`); empty when
/// the metric is a single-value kind such as cardinality.
///
/// Returns `None` only on name mismatch, unknown property, or empty bucket. Implementations
/// may finalize their per-bucket state (e.g. compute a percentile from a sketch); calls
/// must be idempotent so the final intermediate result is unaffected.
///
/// No default impl on purpose: every collector must decide explicitly whether it
/// produces a metric value, forwards into children (single-bucket aggs), or rejects
/// the lookup. A silent `None` default would let a parent term agg's cutoff sort all
/// buckets to the same key and drop arbitrary winners.
fn compute_metric_value(
&self,
bucket_id: BucketId,
sub_agg_name: &str,
sub_agg_property: &str,
agg_data: &AggregationsSegmentCtx,
) -> Option<f64>;
}
#[derive(Default)]
@@ -137,4 +162,21 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
}
Ok(())
}
fn compute_metric_value(
&self,
bucket_id: BucketId,
sub_agg_name: &str,
sub_agg_property: &str,
agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
for agg in &self.aggs {
if let Some(value) =
agg.compute_metric_value(bucket_id, sub_agg_name, sub_agg_property, agg_data)
{
return Some(value);
}
}
None
}
}
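To illustrate why the trait documentation warns against a silent `None` default, here is a rough sketch of a segment-level cutoff that orders buckets by a metric value. `cutoff_by_metric` is a hypothetical function, not tantivy's term aggregation code; it assumes descending order and breaks ties by bucket id.

```rust
/// Keeps the `keep` buckets with the highest metric value.
/// Each entry is `(bucket_id, metric_value)`; a missing value sorts last.
/// Hypothetical illustration of a segment-time cutoff.
fn cutoff_by_metric(mut buckets: Vec<(u32, Option<f64>)>, keep: usize) -> Vec<u32> {
    buckets.sort_by(|left, right| {
        let left_val = left.1.unwrap_or(f64::NEG_INFINITY);
        let right_val = right.1.unwrap_or(f64::NEG_INFINITY);
        // Descending by value, then ascending by bucket id for determinism.
        right_val.total_cmp(&left_val).then(left.0.cmp(&right.0))
    });
    buckets.into_iter().take(keep).map(|(bucket_id, _)| bucket_id).collect()
}
```

When every bucket reports `None`, all sort keys are equal and the survivors are decided by the tie-break alone, regardless of the real metric. That is the failure mode an explicit implementation per collector is meant to rule out.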

View File

@@ -389,6 +389,13 @@ impl SegmentCollector for FacetSegmentCollector {
}
let mut facet = vec![];
let (facet_ord, facet_depth) = self.unique_facet_ords[collapsed_facet_ord];
// u64::MAX is used as a sentinel for unmapped ordinals (e.g. when a
// document has the exact registered facet, not a child of it).
// Passing it to ord_to_term would resolve to the last dictionary
// entry and produce a spurious facet from an unrelated branch.
if facet_ord == u64::MAX {
continue;
}
// TODO handle errors.
if facet_dict.ord_to_term(facet_ord, &mut facet).is_ok() {
if let Some((end_collapsed_facet, _)) = facet
@@ -814,6 +821,63 @@ mod tests {
assert!(!super::is_child_facet(&b"foo\0bar"[..], &b"foo"[..]));
assert!(!super::is_child_facet(&b"foo"[..], &b"foobar\0baz"[..]));
}
// Regression test for https://github.com/quickwit-oss/tantivy/issues/2494
// When a document has the exact registered facet path (not just a child),
// harvest() must not turn the unmapped sentinel into a spurious root entry.
#[test]
fn test_facet_collector_wrong_root() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let facet_field = schema_builder.add_facet_field("facet", FacetOptions::default());
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
let facets: Vec<&str> = vec![
"/science-fiction/asimov",
"/science-fiction/clarke",
"/science-fiction/dick",
"/science-fiction/herbert",
"/science-fiction/orwell",
// This exact match on the registered facet is the bug trigger:
// its ordinal maps to the sentinel (u64::MAX, 0) in the collapse
// mapping, which without the fix resolves to an unrelated term.
"/fantasy/epic-fantasy",
"/fantasy/epic-fantasy/tolkien",
"/fantasy/epic-fantasy/martin",
];
for facet_str in &facets {
index_writer.add_document(doc!(
facet_field => Facet::from(*facet_str)
))?;
}
index_writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
let term = Term::from_facet(facet_field, &Facet::from("/fantasy/epic-fantasy"));
let query = TermQuery::new(term, IndexRecordOption::Basic);
let mut facet_collector = FacetCollector::for_field("facet");
facet_collector.add_facet("/fantasy/epic-fantasy");
let counts: FacetCounts = searcher.search(&query, &facet_collector)?;
let result: Vec<(String, u64)> = counts
.get("/")
.map(|(facet, count)| (facet.to_string(), count))
.collect();
// Only children of /fantasy/epic-fantasy should appear, not /science-fiction
assert_eq!(
result,
vec![
("/fantasy/epic-fantasy/martin".to_string(), 1),
("/fantasy/epic-fantasy/tolkien".to_string(), 1),
]
);
Ok(())
}
}
#[cfg(all(test, feature = "unstable"))]

View File

@@ -138,6 +138,31 @@ pub trait DocSet: Send {
buffer.len()
}
/// Fills the given buffer with the next doc ids strictly smaller than `horizon` and returns
/// the number of doc ids written.
///
/// Unlike [`DocSet::fill_buffer`], this method never advances past a doc id greater than or
/// equal to `horizon`: the docset is left positioned on the first such doc.
fn fill_buffer_up_to(
&mut self,
horizon: DocId,
buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
) -> usize {
if self.doc() == TERMINATED {
return 0;
}
for (pos, buffer_val) in buffer.iter_mut().enumerate() {
let doc = self.doc();
if doc >= horizon {
return pos;
}
*buffer_val = doc;
if self.advance() == TERMINATED {
return pos + 1;
}
}
buffer.len()
}
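The horizon semantics of the default implementation above can be tried on a toy docset. This is a minimal sketch: `VecDocSet` and `BUFFER_LEN` are hypothetical stand-ins (the real trait uses `COLLECT_BLOCK_BUFFER_LEN`), and the value of `TERMINATED` is an assumption made for the example.

```rust
type DocId = u32;
/// Sentinel for an exhausted docset (value assumed for this sketch).
const TERMINATED: DocId = i32::MAX as u32;
/// Hypothetical stand-in for `COLLECT_BLOCK_BUFFER_LEN`.
const BUFFER_LEN: usize = 64;

/// Toy docset backed by a sorted vector of doc ids.
struct VecDocSet {
    docs: Vec<DocId>,
    cursor: usize,
}

impl VecDocSet {
    fn doc(&self) -> DocId {
        self.docs.get(self.cursor).copied().unwrap_or(TERMINATED)
    }

    fn advance(&mut self) -> DocId {
        self.cursor += 1;
        self.doc()
    }

    /// Same control flow as the default `fill_buffer_up_to` shown above.
    fn fill_buffer_up_to(&mut self, horizon: DocId, buffer: &mut [DocId; BUFFER_LEN]) -> usize {
        if self.doc() == TERMINATED {
            return 0;
        }
        for (pos, buffer_val) in buffer.iter_mut().enumerate() {
            let doc = self.doc();
            if doc >= horizon {
                // Stop without consuming the doc at or beyond the horizon.
                return pos;
            }
            *buffer_val = doc;
            if self.advance() == TERMINATED {
                return pos + 1;
            }
        }
        buffer.len()
    }
}
```

After a call with `horizon = 8` on docs `[1, 3, 5, 8, 13]`, the buffer holds `[1, 3, 5]` and `doc()` still returns 8, so a later refill with a larger horizon resumes exactly where the previous one stopped.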
/// Returns the current document
/// Right after creating a new `DocSet`, the docset points to the first document.
///
@@ -251,6 +276,14 @@ impl DocSet for &mut dyn DocSet {
(**self).fill_buffer(buffer)
}
fn fill_buffer_up_to(
&mut self,
horizon: DocId,
buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
) -> usize {
(**self).fill_buffer_up_to(horizon, buffer)
}
fn fill_bitset_block(
&mut self,
min_doc: DocId,

View File

@@ -6,6 +6,7 @@ use common::{ByteCount, HasLen};
use fnv::FnvHashMap;
use itertools::Itertools;
use crate::directory::error::OpenReadError;
use crate::directory::{CompositeFile, FileSlice};
use crate::error::DataCorruption;
use crate::fastfield::{intersect_alive_bitsets, AliveBitSet, FacetReader, FastFieldReaders};
@@ -159,12 +160,10 @@ impl SegmentReader {
let postings_file = segment.open_read(SegmentComponent::Postings)?;
let postings_composite = CompositeFile::open(&postings_file)?;
let positions_composite = {
if let Ok(positions_file) = segment.open_read(SegmentComponent::Positions) {
CompositeFile::open(&positions_file)?
} else {
CompositeFile::empty()
}
let positions_composite = match segment.open_read(SegmentComponent::Positions) {
Ok(positions_file) => CompositeFile::open(&positions_file)?,
Err(OpenReadError::FileDoesNotExist(_)) => CompositeFile::empty(),
Err(open_read_error) => return Err(open_read_error.into()),
};
let schema = segment.schema();

View File

@@ -240,6 +240,42 @@ impl BlockSegmentPostings {
self.freq_decoder.output_array()
}
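/// Copies docs of the current block into `docs`, starting at `block_offset` and stopping
/// before the first doc greater than or equal to `horizon`, along with their term
/// frequencies. Term frequencies are filled with 1 when the block has no decoded
/// frequencies for that range. Returns the number of entries copied.
pub(crate) fn copy_docs_and_term_freqs(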
&self,
block_offset: usize,
horizon: DocId,
docs: &mut [DocId],
term_freqs: &mut [u32],
) -> usize {
debug_assert_eq!(docs.len(), term_freqs.len());
let block_docs = self.docs();
let remaining_docs_in_block = block_docs.len().saturating_sub(block_offset);
let max_len = remaining_docs_in_block.min(docs.len());
if max_len == 0 {
return 0;
}
let source_docs = &block_docs[block_offset..block_offset + max_len];
let len = if source_docs[max_len - 1] < horizon {
max_len
} else {
source_docs
.iter()
.position(|&doc| doc >= horizon)
.unwrap_or(max_len)
};
docs[..len].copy_from_slice(&source_docs[..len]);
let block_freqs = self.freq_output_array();
if block_freqs.len() >= block_offset + len {
term_freqs[..len].copy_from_slice(&block_freqs[block_offset..block_offset + len]);
} else {
term_freqs[..len].fill(1);
}
len
}
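The copy logic above can be reproduced on plain slices to see how the horizon cut behaves. The sketch below is an illustration, not the method itself: `copy_up_to_horizon` is a hypothetical free function, it ignores `block_offset`, and it swaps the linear `position` scan for `partition_point` (a binary search), which is valid only because doc ids within a block are sorted.

```rust
/// Copies the prefix of `src_docs` that is strictly below `horizon` into `docs`,
/// together with the matching term frequencies. Returns the number of entries copied.
/// Hypothetical slice-based variant of `copy_docs_and_term_freqs`.
fn copy_up_to_horizon(
    src_docs: &[u32],
    src_freqs: &[u32],
    horizon: u32,
    docs: &mut [u32],
    term_freqs: &mut [u32],
) -> usize {
    debug_assert_eq!(docs.len(), term_freqs.len());
    let max_len = src_docs.len().min(docs.len());
    if max_len == 0 {
        return 0;
    }
    let candidates = &src_docs[..max_len];
    // Fast path: the whole range is below the horizon, no search needed.
    let len = if candidates[max_len - 1] < horizon {
        max_len
    } else {
        candidates.partition_point(|&doc| doc < horizon)
    };
    docs[..len].copy_from_slice(&candidates[..len]);
    if src_freqs.len() >= len {
        term_freqs[..len].copy_from_slice(&src_freqs[..len]);
    } else {
        // No decoded frequencies available: fall back to a frequency of 1.
        term_freqs[..len].fill(1);
    }
    len
}
```

Checking the last candidate first keeps the common case, a block entirely below the horizon, down to a single comparison before the bulk copy.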
/// Return the frequency at index `idx` of the block.
#[inline]
pub fn freq(&self, idx: usize) -> u32 {
@@ -249,6 +285,12 @@ impl BlockSegmentPostings {
/// Returns the decoded term-frequency buffer for the current block.
#[inline]
pub(crate) fn freq_output_array(&self) -> &[u32] {
self.freq_decoder.output_array()
}
/// Returns the length of the current block.
///
/// All blocks have a length of `NUM_DOCS_PER_BLOCK`,
/// except the last block that may have a length
/// of any number between 1 and `NUM_DOCS_PER_BLOCK - 1`
@@ -298,6 +340,11 @@ impl BlockSegmentPostings {
}
}
#[inline]
pub(crate) fn has_remaining_docs(&self) -> bool {
self.skip_reader.has_remaining_docs()
}
pub(crate) fn block_is_loaded(&self) -> bool {
self.block_loaded
}

View File

@@ -532,6 +532,16 @@ pub(crate) mod tests {
fn score(&mut self) -> Score {
self.0.score()
}
#[inline]
fn can_score_doc(&self) -> bool {
self.0.can_score_doc()
}
#[inline]
fn score_doc(&mut self, doc: DocId, term_freq: u32) -> Score {
self.0.score_doc(doc, term_freq)
}
}
pub fn test_skip_against_unoptimized<F: Fn() -> Box<dyn DocSet>>(

View File

@@ -1,6 +1,6 @@
use common::HasLen;
use crate::docset::DocSet;
use crate::docset::{DocSet, COLLECT_BLOCK_BUFFER_LEN};
use crate::fastfield::AliveBitSet;
use crate::positions::PositionReader;
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
@@ -151,6 +151,34 @@ impl SegmentPostings {
position_reader,
}
}
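/// Fills `docs` and `term_freqs` with the next docs strictly smaller than `horizon` and
/// their term frequencies, copying one postings block at a time. Returns the number of
/// entries written.
pub(crate) fn fill_buffer_up_to_with_term_freqs(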
&mut self,
horizon: DocId,
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
) -> usize {
let mut num_elems = 0;
while num_elems < COLLECT_BLOCK_BUFFER_LEN && self.doc() < horizon {
let copied = self.block_cursor.copy_docs_and_term_freqs(
self.cur,
horizon,
&mut docs[num_elems..],
&mut term_freqs[num_elems..],
);
if copied == 0 {
break;
}
num_elems += copied;
self.cur += copied;
if self.cur == COMPRESSION_BLOCK_SIZE {
self.cur = 0;
self.block_cursor.advance();
}
}
num_elems
}
}
impl DocSet for SegmentPostings {

View File

@@ -146,6 +146,11 @@ impl SkipReader {
skip_reader
}
#[inline(always)]
pub fn has_remaining_docs(&self) -> bool {
self.remaining_docs != 0
}
pub fn reset(&mut self, data: OwnedBytes, doc_freq: u32) {
self.last_doc_in_block = if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 {
0

View File

@@ -109,6 +109,16 @@ impl Scorer for AllScorer {
fn score(&mut self) -> Score {
1.0
}
#[inline]
fn can_score_doc(&self) -> bool {
true
}
#[inline]
fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score {
1.0
}
}
#[cfg(test)]

View File

@@ -1,5 +1,9 @@
use std::cell::RefCell;
use std::num::NonZeroUsize;
use std::sync::Arc;
use lru::LruCache;
use crate::fieldnorm::FieldNormReader;
use crate::query::Explanation;
use crate::schema::Field;
@@ -59,7 +63,9 @@ fn cached_tf_component(fieldnorm: u32, average_fieldnorm: Score) -> Score {
K1 * (1.0 - B + B * fieldnorm as Score / average_fieldnorm)
}
fn compute_tf_cache(average_fieldnorm: Score) -> Arc<[Score; 256]> {
const BM25_TF_CACHE_CAPACITY: usize = 64;
fn compute_tf_cache_uncached(average_fieldnorm: Score) -> Arc<[Score; 256]> {
let mut cache: [Score; 256] = [0.0; 256];
for (fieldnorm_id, cache_mut) in cache.iter_mut().enumerate() {
let fieldnorm = FieldNormReader::id_to_fieldnorm(fieldnorm_id as u8);
@@ -68,6 +74,36 @@ fn compute_tf_cache(average_fieldnorm: Score) -> Arc<[Score; 256]> {
Arc::new(cache)
}
thread_local! {
static TF_CACHES: RefCell<LruCache<u32, Arc<[Score; 256]>>> = RefCell::new(LruCache::new(
NonZeroUsize::new(BM25_TF_CACHE_CAPACITY).unwrap(),
));
}
/// Returns the TF normalization table for `average_fieldnorm`, shared across all [Bm25Weight]s
/// with the same average fieldnorm on the same thread. Tables live in a bounded thread-local
/// LRU cache keyed by the bit pattern of `average_fieldnorm`.
///
/// Within one query, all terms on the same field share the same average fieldnorm and thus the
/// same table, which lowers CPU cache pressure.
///
/// The table is also reused between queries on the same thread, so its memory address and
/// access pattern stay stable, which should help the CPU cache.
///
/// A thread-local cache is used to avoid any cross-thread contention.
fn compute_tf_cache(average_fieldnorm: Score) -> Arc<[Score; 256]> {
let cache_key = average_fieldnorm.to_bits();
TF_CACHES.with(|cache_by_average_fieldnorm| {
let mut cache_by_average_fieldnorm = cache_by_average_fieldnorm.borrow_mut();
if let Some(cache) = cache_by_average_fieldnorm.get(&cache_key) {
return cache.clone();
}
let cache = compute_tf_cache_uncached(average_fieldnorm);
cache_by_average_fieldnorm.put(cache_key, cache.clone());
cache
})
}
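The sharing pattern used by `compute_tf_cache` can be sketched with the standard library alone. This is a simplified illustration under stated assumptions: it uses an unbounded `HashMap` instead of the bounded `LruCache`, `shared_table` and `build_table` are hypothetical names, and the table contents use the slot index as a placeholder fieldnorm rather than `FieldNormReader::id_to_fieldnorm`.

```rust
use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::Arc;

// Standard BM25 defaults, used here only to fill the placeholder table.
const K1: f32 = 1.2;
const B: f32 = 0.75;

thread_local! {
    // One map per thread, keyed by the bit pattern of the average fieldnorm.
    // Unbounded here; the change above bounds it with an LRU.
    static TABLES: RefCell<HashMap<u32, Arc<[f32; 256]>>> = RefCell::new(HashMap::new());
}

/// Builds a 256-entry table. The slot index stands in for the real fieldnorm.
fn build_table(average_fieldnorm: f32) -> Arc<[f32; 256]> {
    let mut table = [0.0f32; 256];
    for (fieldnorm_id, slot) in table.iter_mut().enumerate() {
        *slot = K1 * (1.0 - B + B * fieldnorm_id as f32 / average_fieldnorm);
    }
    Arc::new(table)
}

/// Returns the per-thread shared table for `average_fieldnorm`, building it on first use.
fn shared_table(average_fieldnorm: f32) -> Arc<[f32; 256]> {
    // f32 is not hashable, so the raw bits serve as the key.
    let key = average_fieldnorm.to_bits();
    TABLES.with(|tables| {
        tables
            .borrow_mut()
            .entry(key)
            .or_insert_with(|| build_table(average_fieldnorm))
            .clone()
    })
}
```

Two calls with the same average on one thread return the same allocation, which is what `Arc::ptr_eq` checks in the new test. A call from another thread builds its own copy, trading a little memory for the absence of locking.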
/// A struct used for computing BM25 scores.
#[derive(Clone)]
pub struct Bm25Weight {
@@ -229,7 +265,7 @@ impl Bm25Weight {
#[cfg(test)]
mod tests {
use super::idf;
use super::{idf, Bm25Weight};
use crate::{assert_nearly_equals, Score};
#[test]
@@ -237,4 +273,12 @@ mod tests {
let score: Score = 2.0;
assert_nearly_equals!(idf(1, 2), score.ln());
}
#[test]
fn test_bm25_tf_cache_is_shared_for_same_average_fieldnorm() {
let weight1 = Bm25Weight::for_one_term(1, 10, 3.0);
let weight2 = Bm25Weight::for_one_term(2, 10, 3.0);
assert!(std::sync::Arc::ptr_eq(&weight1.cache, &weight2.cache));
}
}

View File

@@ -0,0 +1,464 @@
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
use crate::query::term_query::TermScorer;
use crate::query::Scorer;
use crate::{DocId, DocSet, Score, TERMINATED};
/// Block-max pruning for top-K over intersection of term scorers.
///
/// Uses the least-frequent term as "leader" to define 128-doc processing windows.
/// For each window, the sum of block_max_scores is compared to the current threshold;
/// if the block can't beat it, the entire block is skipped.
///
/// Within non-skipped blocks, individual documents are pruned by checking whether
/// leader_score + sum(secondary block_max_scores) can exceed the threshold before
/// performing the expensive intersection membership check (seeking into secondary scorers).
///
/// # Preconditions
/// - `scorers` has at least 2 elements
/// - All scorers read frequencies (`FreqReadingOption::ReadFreq`)
pub(crate) fn block_wand_intersection(
mut scorers: Vec<TermScorer>,
mut threshold: Score,
callback: &mut dyn FnMut(DocId, Score) -> Score,
) {
assert!(scorers.len() >= 2);
// Sort by cost (ascending). scorers[0] becomes the "leader" (rarest term).
scorers.sort_by_key(TermScorer::size_hint);
let (leader, secondaries) = scorers.split_first_mut().unwrap();
// Precompute global max scores for early termination checks.
let leader_max_score: Score = leader.max_score();
let secondaries_global_max_sum: Score = secondaries.iter().map(TermScorer::max_score).sum();
// Early exit: no document can possibly beat the threshold.
if leader_max_score + secondaries_global_max_sum <= threshold {
return;
}
// Clone the fieldnorm reader and BM25 weight before the main loop.
// They live in fields disjoint from block_cursor, but Rust's borrow
// checker can't see through method calls, so we take owned copies
// once upfront.
let fieldnorm_reader = leader.fieldnorm_reader().clone();
let bm25_weight = leader.bm25_weight().clone();
let mut doc = leader.doc();
let mut secondary_block_max_scores: Box<[f32]> =
vec![0.0f32; secondaries.len()].into_boxed_slice();
let mut secondary_suffix_block_max: Box<[f32]> =
vec![0.0f32; secondaries.len()].into_boxed_slice();
while doc < TERMINATED {
// --- Phase 1: Block-level pruning ---
//
// Position all skip readers on the block containing `doc`.
// seek_block is cheap: it only advances the skip reader, no block decompression.
leader.seek_block(doc);
let leader_block_max: Score = leader.block_max_score();
// Compute the window end as the minimum last_doc_in_block across all scorers.
// This ensures the block_max values are valid for all docs in [doc, window_end].
// Different scorers have independently aligned blocks, so we must use the
// smallest window where all block_max values hold.
let mut window_end: DocId = leader.last_doc_in_block();
let mut secondary_block_max_sum: Score = 0.0;
let num_secondaries = secondaries.len();
for (idx, secondary) in secondaries.iter_mut().enumerate() {
secondary.block_cursor().seek_block(doc);
if !secondary.block_cursor().has_remaining_docs() {
return;
}
window_end = window_end.min(secondary.last_doc_in_block());
let bms = secondary.block_max_score();
secondary_block_max_scores[idx] = bms;
secondary_block_max_sum += bms;
}
if leader_block_max + secondary_block_max_sum <= threshold {
// The entire window cannot beat the threshold. Skip past it.
doc = window_end + 1;
continue;
}
// --- Phase 2: Batch processing within the window ---
//
// Score-first approach: decode the leader's block, filter by threshold,
// then check intersection membership only for survivors. This avoids expensive
// secondary seeks for docs that can't beat the threshold.
let block_cursor = leader.block_cursor();
// seek loads the block and returns the in-block index of the first doc >= `doc`.
let start_idx = block_cursor.seek(doc);
// Use the branchless binary search on the doc decoder to find the first
// index past window_end.
let end_idx = block_cursor
.doc_decoder
.seek_within_block(window_end + 1)
.min(block_cursor.block_len());
let block_docs = &block_cursor.doc_decoder.output_array()[start_idx..end_idx];
let block_freqs = &block_cursor.freq_output_array()[start_idx..end_idx];
// Pass 1: Batch-compute leader BM25 scores and branchlessly filter
// candidates that can't beat the threshold.
//
// The trick: always write to the buffer at `num_candidates`, then
// conditionally advance the count. The compiler can turn this into
// a cmov instead of a branch, avoiding misprediction costs.
let score_threshold = threshold - secondary_block_max_sum;
let mut candidate_doc_ids = [0u32; COMPRESSION_BLOCK_SIZE];
let mut candidate_scores = [0.0f32; COMPRESSION_BLOCK_SIZE];
let mut num_candidates = 0usize;
for (candidate_doc, term_freq) in
block_docs.iter().copied().zip(block_freqs.iter().copied())
{
let fieldnorm_id = fieldnorm_reader.fieldnorm_id(candidate_doc);
let leader_score = bm25_weight.score(fieldnorm_id, term_freq);
candidate_doc_ids[num_candidates] = candidate_doc;
candidate_scores[num_candidates] = leader_score;
num_candidates += (leader_score > score_threshold) as usize;
}
if num_candidates == 0 {
doc = window_end + 1;
continue;
}
// Precompute suffix sums: suffix[i] = sum of block_max for secondaries[i+1..].
// Used in Pass 2 to prune candidates that can't beat the threshold even with
// the remaining secondaries contributing their block_max.
let mut running = 0.0f32;
for idx in (0..num_secondaries).rev() {
secondary_suffix_block_max[idx] = running;
running += secondary_block_max_scores[idx];
}
// Pass 2: Check intersection membership only for survivors.
// score_threshold may be stale (threshold can increase from callbacks),
// but that's conservative — we may check a few extra candidates, never miss one.
'next_candidate: for candidate_idx in 0..num_candidates {
let candidate_doc = candidate_doc_ids[candidate_idx];
let mut total_score: Score = candidate_scores[candidate_idx];
for (secondary_idx, secondary) in secondaries.iter_mut().enumerate() {
// If a previous candidate already advanced this secondary past
// candidate_doc, the candidate can't be in the intersection.
if secondary.doc() > candidate_doc {
continue 'next_candidate;
}
let seek_result = secondary.seek(candidate_doc);
if seek_result != candidate_doc {
continue 'next_candidate;
}
total_score += secondary.score();
// Prune: even if all remaining secondaries score at their block max,
// can we still beat the threshold?
if total_score + secondary_suffix_block_max[secondary_idx] <= threshold {
continue 'next_candidate;
}
}
// All secondaries matched.
if total_score > threshold {
threshold = callback(candidate_doc, total_score);
if leader_max_score + secondaries_global_max_sum <= threshold {
return;
}
}
}
doc = window_end + 1;
}
}
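Two small building blocks of `block_wand_intersection` are easy to study in isolation: the branchless candidate filter of Pass 1 and the exclusive suffix sums used for pruning in Pass 2. The sketch below is illustrative; `filter_branchless` and `exclusive_suffix_sums` are hypothetical helper names operating on plain slices rather than on term scorers.

```rust
/// Writes the indices of all scores strictly above `cutoff` to the front of `out`
/// and returns how many were kept. Every iteration writes to `out[kept]`;
/// only the counter update depends on the comparison, so no branch is needed.
fn filter_branchless(scores: &[f32], cutoff: f32, out: &mut [usize]) -> usize {
    assert!(out.len() >= scores.len());
    let mut kept = 0usize;
    for (idx, &score) in scores.iter().enumerate() {
        out[kept] = idx;
        kept += (score > cutoff) as usize;
    }
    kept
}

/// Returns `suffix` where `suffix[i]` is the sum of `block_max[i + 1..]`.
fn exclusive_suffix_sums(block_max: &[f32]) -> Vec<f32> {
    let mut suffix = vec![0.0f32; block_max.len()];
    let mut running = 0.0f32;
    for idx in (0..block_max.len()).rev() {
        suffix[idx] = running;
        running += block_max[idx];
    }
    suffix
}
```

Slots past the returned count may hold rejected indices, which is harmless because callers only read the first `kept` entries. With the suffix sums, after scoring secondary `i` a candidate can be dropped as soon as `total_score + suffix[i]` no longer exceeds the threshold.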
#[cfg(test)]
mod tests {
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use proptest::prelude::*;
use crate::query::term_query::TermScorer;
use crate::query::{Bm25Weight, Scorer};
use crate::{DocId, DocSet, Score, TERMINATED};
struct Float(Score);
impl Eq for Float {}
impl PartialEq for Float {
fn eq(&self, other: &Self) -> bool {
self.cmp(other) == Ordering::Equal
}
}
impl PartialOrd for Float {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Float {
fn cmp(&self, other: &Self) -> Ordering {
other.0.partial_cmp(&self.0).unwrap_or(Ordering::Equal)
}
}
fn nearly_equals(left: Score, right: Score) -> bool {
(left - right).abs() < 0.0001 * (left + right).abs()
}
/// Run block_wand_intersection and collect (doc, score) pairs above threshold.
fn compute_checkpoints_block_wand_intersection(
term_scorers: Vec<TermScorer>,
top_k: usize,
) -> Vec<(DocId, Score)> {
let mut heap: BinaryHeap<Float> = BinaryHeap::with_capacity(top_k);
let mut checkpoints: Vec<(DocId, Score)> = Vec::new();
let mut limit: Score = 0.0;
let callback = &mut |doc, score| {
heap.push(Float(score));
if heap.len() > top_k {
heap.pop().unwrap();
}
if heap.len() == top_k {
limit = heap.peek().unwrap().0;
}
if !nearly_equals(score, limit) {
checkpoints.push((doc, score));
}
limit
};
super::block_wand_intersection(term_scorers, Score::MIN, callback);
checkpoints
}
/// Naive baseline: intersect by iterating all docs.
fn compute_checkpoints_naive_intersection(
mut term_scorers: Vec<TermScorer>,
top_k: usize,
) -> Vec<(DocId, Score)> {
let mut heap: BinaryHeap<Float> = BinaryHeap::with_capacity(top_k);
let mut checkpoints: Vec<(DocId, Score)> = Vec::new();
let mut limit = Score::MIN;
// Sort by cost to use the cheapest as driver.
term_scorers.sort_by_key(|s| s.cost());
let (leader, secondaries) = term_scorers.split_first_mut().unwrap();
let mut doc = leader.doc();
while doc != TERMINATED {
let mut all_match = true;
for secondary in secondaries.iter_mut() {
let secondary_doc = secondary.doc();
let seek_result = if secondary_doc <= doc {
secondary.seek(doc)
} else {
secondary_doc
};
if seek_result != doc {
all_match = false;
break;
}
}
if all_match {
let score: Score =
leader.score() + secondaries.iter_mut().map(|s| s.score()).sum::<Score>();
if score > limit {
heap.push(Float(score));
if heap.len() > top_k {
heap.pop().unwrap();
}
if heap.len() == top_k {
limit = heap.peek().unwrap().0;
}
if !nearly_equals(score, limit) {
checkpoints.push((doc, score));
}
}
}
doc = leader.advance();
}
checkpoints
}
const MAX_TERM_FREQ: u32 = 100u32;
fn posting_list(max_doc: u32) -> BoxedStrategy<Vec<(DocId, u32)>> {
(1..max_doc + 1)
.prop_flat_map(move |doc_freq| {
(
proptest::bits::bitset::sampled(doc_freq as usize, 0..max_doc as usize),
proptest::collection::vec(1u32..MAX_TERM_FREQ, doc_freq as usize),
)
})
.prop_map(|(docset, term_freqs)| {
docset
.iter()
.map(|doc| doc as u32)
.zip(term_freqs.iter().cloned())
.collect::<Vec<_>>()
})
.boxed()
}
#[expect(clippy::type_complexity)]
fn gen_term_scorers(num_scorers: usize) -> BoxedStrategy<(Vec<Vec<(DocId, u32)>>, Vec<u32>)> {
(1u32..100u32)
.prop_flat_map(move |max_doc: u32| {
(
proptest::collection::vec(posting_list(max_doc), num_scorers),
proptest::collection::vec(2u32..10u32 * MAX_TERM_FREQ, max_doc as usize),
)
})
.boxed()
}
fn test_block_wand_intersection_aux(posting_lists: &[Vec<(DocId, u32)>], fieldnorms: &[u32]) {
// Repeat docs 64 times to create multi-block scenarios, matching block_wand.rs test
// strategy.
const REPEAT: usize = 64;
let fieldnorms_expanded: Vec<u32> = fieldnorms
.iter()
.cloned()
.flat_map(|fieldnorm| std::iter::repeat_n(fieldnorm, REPEAT))
.collect();
let postings_lists_expanded: Vec<Vec<(DocId, u32)>> = posting_lists
.iter()
.map(|posting_list| {
posting_list
.iter()
.cloned()
.flat_map(|(doc, term_freq)| {
(0_u32..REPEAT as u32).map(move |offset| {
(
doc * (REPEAT as u32) + offset,
if offset == 0 { term_freq } else { 1 },
)
})
})
.collect::<Vec<(DocId, u32)>>()
})
.collect();
let total_fieldnorms: u64 = fieldnorms_expanded
.iter()
.cloned()
.map(|fieldnorm| fieldnorm as u64)
.sum();
let average_fieldnorm = (total_fieldnorms as Score) / (fieldnorms_expanded.len() as Score);
let max_doc = fieldnorms_expanded.len();
let make_scorers = || -> Vec<TermScorer> {
postings_lists_expanded
.iter()
.map(|postings| {
let bm25_weight = Bm25Weight::for_one_term(
postings.len() as u64,
max_doc as u64,
average_fieldnorm,
);
TermScorer::create_for_test(postings, &fieldnorms_expanded[..], bm25_weight)
})
.collect()
};
for top_k in 1..4 {
let checkpoints_optimized =
compute_checkpoints_block_wand_intersection(make_scorers(), top_k);
let checkpoints_naive = compute_checkpoints_naive_intersection(make_scorers(), top_k);
assert_eq!(
checkpoints_optimized.len(),
checkpoints_naive.len(),
"Mismatch in checkpoint count for top_k={top_k}"
);
for (&(left_doc, left_score), &(right_doc, right_score)) in
checkpoints_optimized.iter().zip(checkpoints_naive.iter())
{
assert_eq!(left_doc, right_doc);
assert!(
nearly_equals(left_score, right_score),
"Score mismatch for doc {left_doc}: {left_score} vs {right_score}"
);
}
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(500))]
#[test]
fn test_block_wand_intersection_two_scorers(
(posting_lists, fieldnorms) in gen_term_scorers(2)
) {
test_block_wand_intersection_aux(&posting_lists[..], &fieldnorms[..]);
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(500))]
#[test]
fn test_block_wand_intersection_three_scorers(
(posting_lists, fieldnorms) in gen_term_scorers(3)
) {
test_block_wand_intersection_aux(&posting_lists[..], &fieldnorms[..]);
}
}
#[test]
fn test_block_wand_intersection_disjoint() {
// Two posting lists with no overlap — intersection is empty.
let fieldnorms: Vec<u32> = vec![10; 200];
let average_fieldnorm = 10.0;
let postings_a: Vec<(DocId, u32)> = (0..100).map(|d| (d, 1)).collect();
let postings_b: Vec<(DocId, u32)> = (100..200).map(|d| (d, 1)).collect();
let scorer_a = TermScorer::create_for_test(
&postings_a,
&fieldnorms,
Bm25Weight::for_one_term(100, 200, average_fieldnorm),
);
let scorer_b = TermScorer::create_for_test(
&postings_b,
&fieldnorms,
Bm25Weight::for_one_term(100, 200, average_fieldnorm),
);
let checkpoints = compute_checkpoints_block_wand_intersection(vec![scorer_a, scorer_b], 10);
assert!(checkpoints.is_empty());
}
#[test]
fn test_block_wand_intersection_all_overlap() {
// Two posting lists with full overlap.
let fieldnorms: Vec<u32> = vec![10; 50];
let average_fieldnorm = 10.0;
let postings: Vec<(DocId, u32)> = (0..50).map(|d| (d, 3)).collect();
let make_scorer = || {
TermScorer::create_for_test(
&postings,
&fieldnorms,
Bm25Weight::for_one_term(50, 50, average_fieldnorm),
)
};
let checkpoints_opt =
compute_checkpoints_block_wand_intersection(vec![make_scorer(), make_scorer()], 5);
let checkpoints_naive =
compute_checkpoints_naive_intersection(vec![make_scorer(), make_scorer()], 5);
assert_eq!(checkpoints_opt.len(), checkpoints_naive.len());
}
}
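
The checkpoint helpers in the tests above track the k-th best score with a `BinaryHeap` whose ordering is reversed, so that `peek()` returns the smallest retained score. The following is a minimal standalone sketch of that idea, not the test code itself. The names `MinScore` and `thresholds_after_each_push` are hypothetical, and the threshold is assumed to start at `f32::MIN`, as in the naive baseline.

```rust
use std::cmp::Ordering;
use std::collections::BinaryHeap;

/// Wrapper that reverses the ordering of `f32` so that `BinaryHeap`
/// (a max-heap) keeps the smallest score at the top.
/// Hypothetical name; incomparable values (NaN) are treated as equal.
struct MinScore(f32);

impl PartialEq for MinScore {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for MinScore {}

impl PartialOrd for MinScore {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MinScore {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed on purpose: `other` is compared against `self`.
        other.0.partial_cmp(&self.0).unwrap_or(Ordering::Equal)
    }
}

/// Feeds `scores` through a top-k heap and records the pruning threshold
/// after each push. The threshold stays at `f32::MIN` until `top_k`
/// scores have been retained.
fn thresholds_after_each_push(scores: &[f32], top_k: usize) -> Vec<f32> {
    let mut heap: BinaryHeap<MinScore> = BinaryHeap::with_capacity(top_k + 1);
    let mut threshold = f32::MIN;
    let mut thresholds = Vec::with_capacity(scores.len());
    for &score in scores {
        heap.push(MinScore(score));
        if heap.len() > top_k {
            // Evicts the smallest retained score.
            heap.pop();
        }
        if heap.len() == top_k {
            threshold = heap.peek().map(|min| min.0).unwrap_or(threshold);
        }
        thresholds.push(threshold);
    }
    thresholds
}
```

Once the heap is full, the threshold is the k-th best score seen so far, which is the value a pruning callback would hand back to the caller.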

View File

@@ -50,7 +50,7 @@ fn block_max_was_too_low_advance_one_scorer(
scorers: &mut [TermScorerWithMaxScore],
pivot_len: usize,
) {
debug_assert!(is_sorted(scorers.iter().map(|scorer| scorer.doc())));
debug_assert!(scorers.iter().map(|scorer| scorer.doc()).is_sorted());
let mut scorer_to_seek = pivot_len - 1;
let mut global_max_score = scorers[scorer_to_seek].max_score;
let mut doc_to_seek_after = scorers[scorer_to_seek].last_doc_in_block();
@@ -76,7 +76,7 @@ fn block_max_was_too_low_advance_one_scorer(
scorers[scorer_to_seek].seek(doc_to_seek_after);
restore_ordering(scorers, scorer_to_seek);
debug_assert!(is_sorted(scorers.iter().map(|scorer| scorer.doc())));
debug_assert!(scorers.iter().map(|scorer| scorer.doc()).is_sorted());
}
// Given a list of term_scorers and a `ord` and assuming that `term_scorers[ord]` is sorted
@@ -90,7 +90,7 @@ fn restore_ordering(term_scorers: &mut [TermScorerWithMaxScore], ord: usize) {
}
term_scorers.swap(i, i - 1);
}
debug_assert!(is_sorted(term_scorers.iter().map(|scorer| scorer.doc())));
debug_assert!(term_scorers.iter().map(|scorer| scorer.doc()).is_sorted());
}
// Attempts to advance all term_scorers between `&term_scorers[0..before_len]` to the pivot.
@@ -150,17 +150,21 @@ pub fn block_wand(
mut threshold: Score,
callback: &mut dyn FnMut(u32, Score) -> Score,
) {
scorers.retain(|scorer| scorer.doc() < TERMINATED);
if scorers.len() == 1 {
let scorer = scorers.pop().unwrap();
return block_wand_single_scorer(scorer, threshold, callback);
}
let mut scorers: Vec<TermScorerWithMaxScore> = scorers
.iter_mut()
.map(TermScorerWithMaxScore::from)
.collect();
scorers.sort_by_key(|scorer| scorer.doc());
// At this point we need to ensure that the scorers are sorted!
debug_assert!(is_sorted(scorers.iter().map(|scorer| scorer.doc())));
scorers.sort_by_key(|scorer| scorer.doc());
while let Some((before_pivot_len, pivot_len, pivot_doc)) =
find_pivot_doc(&scorers[..], threshold)
{
debug_assert!(is_sorted(scorers.iter().map(|scorer| scorer.doc())));
debug_assert!(scorers.iter().map(|scorer| scorer.doc()).is_sorted());
debug_assert_ne!(pivot_doc, TERMINATED);
debug_assert!(before_pivot_len < pivot_len);
@@ -228,7 +232,7 @@ pub fn block_wand_single_scorer(
loop {
// We position the scorer on a block that can reach
// the threshold.
while scorer.block_max_score() < threshold {
while scorer.block_max_score() <= threshold {
let last_doc_in_block = scorer.last_doc_in_block();
if last_doc_in_block == TERMINATED {
return;
@@ -286,18 +290,6 @@ impl DerefMut for TermScorerWithMaxScore<'_> {
}
}
fn is_sorted<I: Iterator<Item = DocId>>(mut it: I) -> bool {
if let Some(first) = it.next() {
let mut prev = first;
for doc in it {
if doc < prev {
return false;
}
prev = doc;
}
}
true
}
#[cfg(test)]
mod tests {
use std::cmp::Ordering;
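
The hunks above drop the hand-written `is_sorted` helper in favour of `Iterator::is_sorted` from the standard library. The sketch below compares the two on plain `u32` doc ids. It assumes a toolchain where `Iterator::is_sorted` is stable (Rust 1.82 or later); `is_sorted_manual` and `is_sorted_std` are illustrative names.

```rust
/// Copy of the removed helper, kept only for comparison.
fn is_sorted_manual<I: Iterator<Item = u32>>(mut it: I) -> bool {
    if let Some(first) = it.next() {
        let mut prev = first;
        for doc in it {
            if doc < prev {
                return false;
            }
            prev = doc;
        }
    }
    true
}

/// The same check through the standard library.
/// Equal neighbours count as sorted, exactly like the manual version.
fn is_sorted_std(docs: &[u32]) -> bool {
    docs.iter().copied().is_sorted()
}
```

Both versions accept an empty or single-element input, and both accept duplicates. That matters here because several scorers can sit on the same doc.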

View File

@@ -16,6 +16,7 @@ use crate::{DocId, Score};
enum SpecializedScorer {
TermUnion(Vec<TermScorer>),
TermIntersection(Vec<TermScorer>),
Other(Box<dyn Scorer>),
}
@@ -49,10 +50,9 @@ where
TScoreCombiner: ScoreCombiner,
{
assert!(!scorers.is_empty());
if scorers.len() == 1 {
if scorers.len() == 1 && !scorers[0].is::<TermScorer>() {
return SpecializedScorer::Other(scorers.into_iter().next().unwrap()); //< we checked the size beforehand
}
{
let is_all_term_queries = scorers.iter().all(|scorer| scorer.is::<TermScorer>());
if is_all_term_queries {
@@ -66,6 +66,9 @@ where
{
// Block wand is only available if we read frequencies.
return SpecializedScorer::TermUnion(scorers);
} else if scorers.len() == 1 {
// Single TermScorer without freq reading — unwrap directly.
return SpecializedScorer::Other(Box::new(scorers.into_iter().next().unwrap()));
} else {
return SpecializedScorer::Other(Box::new(BufferedUnionScorer::build(
scorers,
@@ -88,10 +91,21 @@ fn into_box_scorer<TScoreCombiner: ScoreCombiner>(
num_docs: u32,
) -> Box<dyn Scorer> {
match scorer {
SpecializedScorer::TermUnion(term_scorers) => {
let union_scorer =
BufferedUnionScorer::build(term_scorers, score_combiner_fn, num_docs);
Box::new(union_scorer)
SpecializedScorer::TermUnion(mut term_scorers) => {
if term_scorers.len() == 1 {
Box::new(term_scorers.pop().unwrap())
} else {
let union_scorer =
BufferedUnionScorer::build(term_scorers, score_combiner_fn, num_docs);
Box::new(union_scorer)
}
}
SpecializedScorer::TermIntersection(term_scorers) => {
let boxed_scorers: Vec<Box<dyn Scorer>> = term_scorers
.into_iter()
.map(|s| Box::new(s) as Box<dyn Scorer>)
.collect();
intersect_scorers(boxed_scorers, num_docs)
}
SpecializedScorer::Other(scorer) => scorer,
}
@@ -297,14 +311,43 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
// Result depends entirely on MUST + any removed AllScorers.
let combined_all_scorer_count = must_special_scorer_counts.num_all_scorers
+ should_special_scorer_counts.num_all_scorers;
let boxed_scorer: Box<dyn Scorer> = effective_must_scorer(
must_scorers,
combined_all_scorer_count,
reader.max_doc(),
num_docs,
)
.unwrap_or_else(|| Box::new(EmptyScorer));
SpecializedScorer::Other(boxed_scorer)
// Try to detect a pure TermScorer intersection for block-max optimization.
// Preconditions: no removed AllScorers, at least 2 scorers, all TermScorer
// with frequency reading enabled.
if combined_all_scorer_count == 0
&& must_scorers.len() >= 2
&& must_scorers.iter().all(|s| s.is::<TermScorer>())
{
let term_scorers: Vec<TermScorer> = must_scorers
.into_iter()
.map(|s| *(s.downcast::<TermScorer>().map_err(|_| ()).unwrap()))
.collect();
if term_scorers
.iter()
.all(|s| s.freq_reading_option() == FreqReadingOption::ReadFreq)
{
SpecializedScorer::TermIntersection(term_scorers)
} else {
let must_scorers: Vec<Box<dyn Scorer>> = term_scorers
.into_iter()
.map(|s| Box::new(s) as Box<dyn Scorer>)
.collect();
let boxed_scorer: Box<dyn Scorer> =
effective_must_scorer(must_scorers, 0, reader.max_doc(), num_docs)
.unwrap_or_else(|| Box::new(EmptyScorer));
SpecializedScorer::Other(boxed_scorer)
}
} else {
let boxed_scorer: Box<dyn Scorer> = effective_must_scorer(
must_scorers,
combined_all_scorer_count,
reader.max_doc(),
num_docs,
)
.unwrap_or_else(|| Box::new(EmptyScorer));
SpecializedScorer::Other(boxed_scorer)
}
}
(ShouldScorersCombinationMethod::Optional(should_scorer), must_scorers) => {
// Optional SHOULD: contributes to scoring but not required for matching.
@@ -463,15 +506,21 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
callback: &mut dyn FnMut(DocId, Score),
) -> crate::Result<()> {
let scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn)?;
let num_docs = reader.num_docs();
match scorer {
SpecializedScorer::TermUnion(term_scorers) => {
let mut union_scorer = BufferedUnionScorer::build(
term_scorers,
&self.score_combiner_fn,
reader.num_docs(),
);
let mut union_scorer =
BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn, num_docs);
for_each_scorer(&mut union_scorer, callback);
}
SpecializedScorer::TermIntersection(term_scorers) => {
let boxed_scorers: Vec<Box<dyn Scorer>> = term_scorers
.into_iter()
.map(|term_scorer| Box::new(term_scorer) as Box<dyn Scorer>)
.collect();
let mut intersection = intersect_scorers(boxed_scorers, num_docs);
for_each_scorer(intersection.as_mut(), callback);
}
SpecializedScorer::Other(mut scorer) => {
for_each_scorer(scorer.as_mut(), callback);
}
@@ -485,17 +534,23 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
callback: &mut dyn FnMut(&[DocId]),
) -> crate::Result<()> {
let scorer = self.complex_scorer(reader, 1.0, || DoNothingCombiner)?;
let num_docs = reader.num_docs();
let mut buffer = [0u32; COLLECT_BLOCK_BUFFER_LEN];
match scorer {
SpecializedScorer::TermUnion(term_scorers) => {
let mut union_scorer = BufferedUnionScorer::build(
term_scorers,
&self.score_combiner_fn,
reader.num_docs(),
);
let mut union_scorer =
BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn, num_docs);
for_each_docset_buffered(&mut union_scorer, &mut buffer, callback);
}
SpecializedScorer::TermIntersection(term_scorers) => {
let boxed_scorers: Vec<Box<dyn Scorer>> = term_scorers
.into_iter()
.map(|term_scorer| Box::new(term_scorer) as Box<dyn Scorer>)
.collect();
let mut intersection = intersect_scorers(boxed_scorers, num_docs);
for_each_docset_buffered(intersection.as_mut(), &mut buffer, callback);
}
SpecializedScorer::Other(mut scorer) => {
for_each_docset_buffered(scorer.as_mut(), &mut buffer, callback);
}
@@ -524,6 +579,9 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
SpecializedScorer::TermUnion(term_scorers) => {
super::block_wand(term_scorers, threshold, callback);
}
SpecializedScorer::TermIntersection(term_scorers) => {
super::block_wand_intersection(term_scorers, threshold, callback);
}
SpecializedScorer::Other(mut scorer) => {
for_each_pruning_scorer(scorer.as_mut(), threshold, callback);
}

View File

@@ -1,8 +1,10 @@
mod block_wand;
mod block_wand_intersection;
mod block_wand_union;
mod boolean_query;
mod boolean_weight;
pub(crate) use self::block_wand::{block_wand, block_wand_single_scorer};
pub(crate) use self::block_wand_intersection::block_wand_intersection;
pub(crate) use self::block_wand_union::{block_wand, block_wand_single_scorer};
pub use self::boolean_query::BooleanQuery;
pub use self::boolean_weight::BooleanWeight;

View File

@@ -112,6 +112,14 @@ impl<S: Scorer> DocSet for BoostScorer<S> {
self.underlying.fill_buffer(buffer)
}
fn fill_buffer_up_to(
&mut self,
horizon: DocId,
buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
) -> usize {
self.underlying.fill_buffer_up_to(horizon, buffer)
}
fn doc(&self) -> u32 {
self.underlying.doc()
}
@@ -138,6 +146,27 @@ impl<S: Scorer> Scorer for BoostScorer<S> {
fn score(&mut self) -> Score {
self.underlying.score() * self.boost
}
#[inline]
fn can_score_doc(&self) -> bool {
self.underlying.can_score_doc()
}
#[inline]
fn score_doc(&mut self, doc: DocId, term_freq: u32) -> Score {
self.underlying.score_doc(doc, term_freq) * self.boost
}
#[inline]
fn fill_buffer_up_to_with_term_freqs(
&mut self,
horizon: DocId,
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
) -> usize {
self.underlying
.fill_buffer_up_to_with_term_freqs(horizon, docs, term_freqs)
}
}
#[cfg(test)]

View File

@@ -141,6 +141,16 @@ impl<TDocSet: DocSet + 'static> Scorer for ConstScorer<TDocSet> {
fn score(&mut self) -> Score {
self.score
}
#[inline]
fn can_score_doc(&self) -> bool {
true
}
#[inline]
fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score {
self.score
}
}
#[cfg(test)]

View File

@@ -315,6 +315,20 @@ mod tests {
fn score(&mut self) -> Score {
self.foo.get(self.cursor).map(|x| x.1).unwrap_or(0.0)
}
#[inline]
fn can_score_doc(&self) -> bool {
true
}
#[inline]
fn score_doc(&mut self, doc: DocId, _term_freq: u32) -> Score {
self.foo
.iter()
.find(|(candidate_doc, _)| *candidate_doc == doc)
.map(|(_, score)| *score)
.unwrap_or(0.0)
}
}
#[test]

View File

@@ -59,6 +59,16 @@ impl Scorer for EmptyScorer {
fn score(&mut self) -> Score {
0.0
}
#[inline]
fn can_score_doc(&self) -> bool {
true
}
#[inline]
fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score {
0.0
}
}
#[cfg(test)]

View File

@@ -1,5 +1,40 @@
use crate::docset::{DocSet, TERMINATED};
use crate::query::Scorer;
use crate::Score;
use crate::{DocId, Score};
struct ScoreOnlyScorer {
doc: DocId,
score: Score,
}
impl DocSet for ScoreOnlyScorer {
fn advance(&mut self) -> DocId {
self.doc = TERMINATED;
TERMINATED
}
fn doc(&self) -> DocId {
self.doc
}
fn size_hint(&self) -> u32 {
1
}
}
impl Scorer for ScoreOnlyScorer {
fn score(&mut self) -> Score {
self.score
}
fn can_score_doc(&self) -> bool {
true
}
fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score {
self.score
}
}
/// The `ScoreCombiner` trait defines how to compute
/// an overall score given a list of scores.
@@ -10,6 +45,17 @@ pub trait ScoreCombiner: Default + Clone + Send + Copy + 'static {
/// or not.
fn update<TScorer: Scorer>(&mut self, scorer: &mut TScorer);
/// Aggregates the score combiner with an already computed score.
fn update_score(&mut self, doc: DocId, score: Score) {
let mut scorer = ScoreOnlyScorer { doc, score };
self.update(&mut scorer);
}
/// Returns true if this combiner needs scorer scores to compute its state.
fn requires_scoring() -> bool {
true
}
/// Clears the score combiner state back to its initial state.
fn clear(&mut self);
@@ -27,6 +73,12 @@ pub struct DoNothingCombiner;
impl ScoreCombiner for DoNothingCombiner {
fn update<TScorer: Scorer>(&mut self, _scorer: &mut TScorer) {}
fn update_score(&mut self, _doc: DocId, _score: Score) {}
fn requires_scoring() -> bool {
false
}
fn clear(&mut self) {}
#[inline]
@@ -42,10 +94,16 @@ pub struct SumCombiner {
}
impl ScoreCombiner for SumCombiner {
#[inline]
fn update<TScorer: Scorer>(&mut self, scorer: &mut TScorer) {
self.score += scorer.score();
}
#[inline]
fn update_score(&mut self, _doc: DocId, score: Score) {
self.score += score;
}
fn clear(&mut self) {
self.score = 0.0;
}
@@ -77,12 +135,19 @@ impl DisjunctionMaxCombiner {
}
impl ScoreCombiner for DisjunctionMaxCombiner {
#[inline]
fn update<TScorer: Scorer>(&mut self, scorer: &mut TScorer) {
let score = scorer.score();
self.max = Score::max(score, self.max);
self.sum += score;
}
#[inline]
fn update_score(&mut self, _doc: DocId, score: Score) {
self.max = Score::max(score, self.max);
self.sum += score;
}
fn clear(&mut self) {
self.max = 0.0;
self.sum = 0.0;
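
The combiner changes above add two hooks: `update_score`, which takes an already computed score, and `requires_scoring`, which lets a caller skip scoring entirely. Below is a simplified, self-contained sketch of how a caller might use them. The `Combiner` trait, the `Sum` and `DoNothing` types, and the `combine` function are hypothetical stand-ins. The real trait is generic over a `Scorer` and has more methods.

```rust
/// Simplified stand-in for the combiner trait in the diff above.
trait Combiner: Default {
    /// Folds an already computed score into the combiner state.
    fn update_score(&mut self, score: f32);

    /// Combiners that ignore scores override this to return `false`.
    fn requires_scoring() -> bool {
        true
    }

    /// Combined score (illustrative accessor).
    fn score(&self) -> f32;
}

#[derive(Default)]
struct Sum {
    score: f32,
}

impl Combiner for Sum {
    fn update_score(&mut self, score: f32) {
        self.score += score;
    }

    fn score(&self) -> f32 {
        self.score
    }
}

#[derive(Default)]
struct DoNothing;

impl Combiner for DoNothing {
    fn update_score(&mut self, _score: f32) {}

    fn requires_scoring() -> bool {
        false
    }

    fn score(&self) -> f32 {
        0.0
    }
}

/// Combines per-term scores through `C`, calling the (possibly expensive)
/// `score_fn` only when the combiner says it needs scores.
/// Returns the combined score and the number of `score_fn` calls.
fn combine<C: Combiner, F: Fn(u32) -> f32>(term_freqs: &[u32], score_fn: F) -> (f32, usize) {
    let mut combiner = C::default();
    let mut calls = 0;
    if C::requires_scoring() {
        for &term_freq in term_freqs {
            calls += 1;
            combiner.update_score(score_fn(term_freq));
        }
    }
    (combiner.score(), calls)
}
```

Because `requires_scoring` is an associated function rather than a method, the check depends only on the type, so the compiler can remove the scoring branch for a no-op combiner.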

View File

@@ -2,8 +2,8 @@ use std::ops::DerefMut;
use downcast_rs::impl_downcast;
use crate::docset::DocSet;
use crate::Score;
use crate::docset::{DocSet, COLLECT_BLOCK_BUFFER_LEN};
use crate::{DocId, Score};
/// Scored set of documents matching a query within a specific segment.
///
@@ -13,6 +13,36 @@ pub trait Scorer: downcast_rs::Downcast + DocSet + 'static {
///
/// This method will perform a bit of computation and is not cached.
fn score(&mut self) -> Score;
/// Returns true if [`Scorer::score_doc`] can score buffered docs without
/// repositioning the scorer.
///
/// Scorers whose [`Scorer::score_doc`] needs term frequencies must also override
/// [`Scorer::fill_buffer_up_to_with_term_freqs`].
fn can_score_doc(&self) -> bool {
false
}
/// Returns the score for `doc` with its term frequency.
fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score {
panic!(
"score_doc is not supported by this scorer. You need to check can_score_doc() before \
calling this method."
)
}
/// Fills docs up to `horizon`.
///
/// The default implementation does not fill `term_freqs`. Scorers whose
/// [`Scorer::score_doc`] reads term frequencies must override this method.
fn fill_buffer_up_to_with_term_freqs(
&mut self,
horizon: DocId,
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
_term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
) -> usize {
DocSet::fill_buffer_up_to(self, horizon, docs)
}
}
impl_downcast!(Scorer);
@@ -22,4 +52,25 @@ impl Scorer for Box<dyn Scorer> {
fn score(&mut self) -> Score {
self.deref_mut().score()
}
#[inline]
fn can_score_doc(&self) -> bool {
self.as_ref().can_score_doc()
}
#[inline]
fn score_doc(&mut self, doc: DocId, term_freq: u32) -> Score {
self.deref_mut().score_doc(doc, term_freq)
}
#[inline]
fn fill_buffer_up_to_with_term_freqs(
&mut self,
horizon: DocId,
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
) -> usize {
self.deref_mut()
.fill_buffer_up_to_with_term_freqs(horizon, docs, term_freqs)
}
}
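
The `fill_buffer_up_to_with_term_freqs` hook above fills fixed-size buffers with docs below a horizon, and the caller loops until a short read. The sketch below shows that contract on a plain sorted slice. `PostingCursor`, `fill_up_to` and `drain_below` are hypothetical names, and `BUFFER_LEN` is assumed to be 64, matching the "64 docs" wording in the union scorer comments.

```rust
const BUFFER_LEN: usize = 64;

/// Hypothetical cursor over a sorted posting list of `(doc, term_freq)` pairs.
struct PostingCursor<'a> {
    postings: &'a [(u32, u32)],
    pos: usize,
}

impl<'a> PostingCursor<'a> {
    /// Copies up to `BUFFER_LEN` postings with `doc < horizon` into the
    /// buffers, advances past them and returns how many were written.
    fn fill_up_to(
        &mut self,
        horizon: u32,
        docs: &mut [u32; BUFFER_LEN],
        term_freqs: &mut [u32; BUFFER_LEN],
    ) -> usize {
        let mut len = 0;
        while len < BUFFER_LEN {
            match self.postings.get(self.pos) {
                Some(&(doc, term_freq)) if doc < horizon => {
                    docs[len] = doc;
                    term_freqs[len] = term_freq;
                    len += 1;
                    self.pos += 1;
                }
                // Either the list is exhausted or the next doc is past the horizon.
                _ => break,
            }
        }
        len
    }
}

/// Drains every posting below `horizon`, one buffer at a time.
/// A short read (fewer than `BUFFER_LEN` entries) ends the loop.
fn drain_below(cursor: &mut PostingCursor, horizon: u32) -> Vec<(u32, u32)> {
    let mut docs = [0u32; BUFFER_LEN];
    let mut term_freqs = [0u32; BUFFER_LEN];
    let mut out = Vec::new();
    loop {
        let len = cursor.fill_up_to(horizon, &mut docs, &mut term_freqs);
        out.extend(docs[..len].iter().copied().zip(term_freqs[..len].iter().copied()));
        if len < BUFFER_LEN {
            break;
        }
    }
    out
}
```

Term frequencies travel with the doc ids, so a scorer can later score each buffered doc without moving its cursor back.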

View File

@@ -1,6 +1,6 @@
use crate::docset::DocSet;
use crate::docset::{DocSet, COLLECT_BLOCK_BUFFER_LEN};
use crate::fieldnorm::FieldNormReader;
use crate::postings::{FreqReadingOption, Postings, SegmentPostings};
use crate::postings::{BlockSegmentPostings, FreqReadingOption, Postings, SegmentPostings};
use crate::query::bm25::Bm25Weight;
use crate::query::{Explanation, Scorer};
use crate::{DocId, Score};
@@ -95,6 +95,21 @@ impl TermScorer {
pub fn last_doc_in_block(&self) -> DocId {
self.postings.block_cursor.skip_reader().last_doc_in_block()
}
/// Returns a mutable reference to the underlying block cursor.
pub(crate) fn block_cursor(&mut self) -> &mut BlockSegmentPostings {
&mut self.postings.block_cursor
}
/// Returns a reference to the fieldnorm reader for batch lookups.
pub(crate) fn fieldnorm_reader(&self) -> &FieldNormReader {
&self.fieldnorm_reader
}
/// Returns a reference to the BM25 weight for batch score computation.
pub(crate) fn bm25_weight(&self) -> &Bm25Weight {
&self.similarity_weight
}
}
impl DocSet for TermScorer {
@@ -132,6 +147,27 @@ impl Scorer for TermScorer {
let term_freq = self.term_freq();
self.similarity_weight.score(fieldnorm_id, term_freq)
}
#[inline]
fn can_score_doc(&self) -> bool {
true
}
#[inline]
fn score_doc(&mut self, doc: DocId, term_freq: u32) -> Score {
let fieldnorm_id = self.fieldnorm_reader.fieldnorm_id(doc);
self.similarity_weight.score(fieldnorm_id, term_freq)
}
fn fill_buffer_up_to_with_term_freqs(
&mut self,
horizon: DocId,
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
) -> usize {
self.postings
.fill_buffer_up_to_with_term_freqs(horizon, docs, term_freqs)
}
}
#[cfg(test)]

View File

@@ -10,23 +10,7 @@ use crate::{DocId, Score};
// of upcoming document IDs (the "horizon").
const HORIZON_NUM_TINYBITSETS: usize = HORIZON as usize / 64;
const HORIZON: u32 = 64u32 * 64u32;
// `drain_filter` is not stable yet.
// This function is similar except that it does is not unstable, and
// it does not keep the original vector ordering.
//
// Elements are dropped and not yielded.
fn unordered_drain_filter<T, P>(v: &mut Vec<T>, mut predicate: P)
where P: FnMut(&mut T) -> bool {
let mut i = 0;
while i < v.len() {
if predicate(&mut v[i]) {
v.swap_remove(i);
} else {
i += 1;
}
}
}
const GROUPED_INSERT_MAX_BUCKET_SPAN: u32 = 2;
/// Creates a `DocSet` that iterates through the union of two or more `DocSet`s.
pub struct BufferedUnionScorer<TScorer, TScoreCombiner = DoNothingCombiner> {
@@ -53,31 +37,213 @@ pub struct BufferedUnionScorer<TScorer, TScoreCombiner = DoNothingCombiner> {
score: Score,
/// Number of documents in the segment.
num_docs: u32,
/// Scratch buffer for block-based refill.
refill_docs: [DocId; COLLECT_BLOCK_BUFFER_LEN],
/// Scratch buffer for term frequencies matching `refill_docs`.
refill_term_freqs: [u32; COLLECT_BLOCK_BUFFER_LEN],
/// Whether all children support scoring buffered docs after advancing.
use_score_doc_refill: bool,
}
#[inline]
fn union_bucket(
bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS],
bucket_pos: u32,
tinyset: TinySet,
) {
debug_assert!((bucket_pos as usize) < HORIZON_NUM_TINYBITSETS);
// `bucket_pos` comes from a doc delta below `HORIZON`; there are exactly
// `HORIZON / 64` buckets in the refill window.
bitsets[bucket_pos as usize] = bitsets[bucket_pos as usize].union(tinyset);
}
#[inline]
fn insert_delta(bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS], delta: DocId) {
debug_assert!(delta < HORIZON);
// `delta < HORIZON`, so `delta / 64` is in the bitset array. The bit
// offset is reduced modulo 64 before being inserted in the TinySet.
bitsets[delta as usize / 64].insert_mut(delta % 64u32);
}
fn insert_and_score_full_buffer<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
scorer: &mut TScorer,
docs: &[DocId; COLLECT_BLOCK_BUFFER_LEN],
term_freqs: &[u32; COLLECT_BLOCK_BUFFER_LEN],
bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS],
score_combiner: &mut [TScoreCombiner; HORIZON as usize],
min_doc: DocId,
) {
debug_assert!(docs.windows(2).all(|pair| pair[0] < pair[1]));
debug_assert!(docs[COLLECT_BLOCK_BUFFER_LEN - 1] - min_doc < HORIZON);
let first_delta = docs[0] - min_doc;
let last_delta = docs[COLLECT_BLOCK_BUFFER_LEN - 1] - min_doc;
let first_bucket = first_delta / 64;
let last_bucket = last_delta / 64;
// Common for very dense scorers: 64 distinct doc ids in one 64-doc bucket
// means all bits in that bucket are present.
if first_bucket == last_bucket {
union_bucket(bitsets, first_bucket, TinySet::full());
score_full_buffer(scorer, docs, term_freqs, score_combiner, min_doc);
return;
}
// 64 sorted distinct integers spanning exactly 64 values are consecutive.
// If they cross a TinySet boundary, this is just the suffix of the first
// bucket plus the prefix of the second bucket.
if last_delta - first_delta == COLLECT_BLOCK_BUFFER_LEN as u32 - 1 {
union_bucket(
bitsets,
first_bucket,
TinySet::range_greater_or_equal(first_delta % 64u32),
);
union_bucket(
bitsets,
last_bucket,
TinySet::range_lower((last_delta + 1) % 64u32),
);
score_full_buffer(scorer, docs, term_freqs, score_combiner, min_doc);
return;
}
// Grouping wins only for very dense buffers that hit the same TinySet many
// times. Once the 64 docs are spread farther, a straight pass is cheaper.
if last_bucket - first_bucket <= GROUPED_INSERT_MAX_BUCKET_SPAN {
let mut bucket = first_bucket;
let mut tinyset = TinySet::empty();
for (&doc, &term_freq) in docs.iter().zip(term_freqs.iter()) {
let delta = doc - min_doc;
let delta_bucket = delta / 64;
if delta_bucket != bucket {
union_bucket(bitsets, bucket, tinyset);
bucket = delta_bucket;
tinyset = TinySet::empty();
}
tinyset.insert_mut(delta % 64u32);
let score = scorer.score_doc(doc, term_freq);
update_score_combiner(score_combiner, delta, doc, score);
}
union_bucket(bitsets, bucket, tinyset);
} else {
for (&doc, &term_freq) in docs.iter().zip(term_freqs.iter()) {
let delta = doc - min_doc;
insert_delta(bitsets, delta);
// TODO: score_doc accesses the fieldnorm reader once per _term_ instead of once per
// doc. We could optimize this by caching the fieldnorm for the doc and
// reusing it for all terms in the doc.
let score = scorer.score_doc(doc, term_freq);
update_score_combiner(score_combiner, delta, doc, score);
}
}
}
#[inline]
fn update_score_combiner<TScoreCombiner: ScoreCombiner>(
score_combiner: &mut [TScoreCombiner; HORIZON as usize],
delta: DocId,
doc: DocId,
score: Score,
) {
debug_assert!(delta < HORIZON);
// Full and partial refill only buffer docs below `horizon`, so their
// deltas are always in the score-combiner window.
score_combiner[delta as usize].update_score(doc, score);
}
fn score_full_buffer<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
scorer: &mut TScorer,
docs: &[DocId; COLLECT_BLOCK_BUFFER_LEN],
term_freqs: &[u32; COLLECT_BLOCK_BUFFER_LEN],
score_combiner: &mut [TScoreCombiner; HORIZON as usize],
min_doc: DocId,
) {
for (&doc, &term_freq) in docs.iter().zip(term_freqs.iter()) {
let score = scorer.score_doc(doc, term_freq);
update_score_combiner(score_combiner, doc - min_doc, doc, score);
}
}
fn refill_scorer_with_score_docs<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
scorer: &mut TScorer,
bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS],
score_combiner: &mut [TScoreCombiner; HORIZON as usize],
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
min_doc: DocId,
horizon: DocId,
) {
loop {
let len = scorer.fill_buffer_up_to_with_term_freqs(horizon, docs, term_freqs);
if len == COLLECT_BLOCK_BUFFER_LEN {
debug_assert!(docs[COLLECT_BLOCK_BUFFER_LEN - 1] != TERMINATED);
debug_assert!(docs[COLLECT_BLOCK_BUFFER_LEN - 1] < horizon);
insert_and_score_full_buffer(
scorer,
docs,
term_freqs,
bitsets,
score_combiner,
min_doc,
);
} else {
for (&doc, &term_freq) in docs[..len].iter().zip(term_freqs[..len].iter()) {
let delta = doc - min_doc;
insert_delta(bitsets, delta);
let score = scorer.score_doc(doc, term_freq);
update_score_combiner(score_combiner, delta, doc, score);
}
break;
}
}
}
fn refill_scorer_from_current_doc<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
scorer: &mut TScorer,
bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS],
score_combiner: &mut [TScoreCombiner; HORIZON as usize],
min_doc: DocId,
horizon: DocId,
) {
loop {
let doc = scorer.doc();
if doc >= horizon {
break;
}
let delta = doc - min_doc;
insert_delta(bitsets, delta);
debug_assert!(delta < HORIZON);
score_combiner[delta as usize].update(scorer);
scorer.advance();
}
}
fn refill<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
scorers: &mut Vec<TScorer>,
bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS],
score_combiner: &mut [TScoreCombiner; HORIZON as usize],
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
min_doc: DocId,
use_score_doc_refill: bool,
) {
unordered_drain_filter(scorers, |scorer| {
let horizon = min_doc + HORIZON;
loop {
let doc = scorer.doc();
if doc >= horizon {
return false;
}
// add this document
let delta = doc - min_doc;
bitsets[(delta / 64) as usize].insert_mut(delta % 64u32);
score_combiner[delta as usize].update(scorer);
if scorer.advance() == TERMINATED {
// remove the docset, it has been entirely consumed.
return true;
}
let horizon = min_doc + HORIZON;
for scorer in scorers.iter_mut() {
if use_score_doc_refill {
refill_scorer_with_score_docs(
scorer,
bitsets,
score_combiner,
docs,
term_freqs,
min_doc,
horizon,
);
} else {
refill_scorer_from_current_doc(scorer, bitsets, score_combiner, min_doc, horizon);
}
});
}
scorers.retain(|scorer| scorer.doc() != TERMINATED);
}
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer, TScoreCombiner> {
@@ -87,6 +253,8 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer
score_combiner_fn: impl FnOnce() -> TScoreCombiner,
num_docs: u32,
) -> BufferedUnionScorer<TScorer, TScoreCombiner> {
let use_score_doc_refill =
TScoreCombiner::requires_scoring() && docsets.iter().all(Scorer::can_score_doc);
let non_empty_docsets: Vec<TScorer> = docsets
.into_iter()
.filter(|docset| docset.doc() != TERMINATED)
@@ -100,6 +268,9 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer
doc: 0,
score: 0.0,
num_docs,
refill_docs: [TERMINATED; COLLECT_BLOCK_BUFFER_LEN],
refill_term_freqs: [1u32; COLLECT_BLOCK_BUFFER_LEN],
use_score_doc_refill,
};
if union.refill() {
union.advance();
@@ -120,7 +291,10 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer
&mut self.docsets,
&mut self.bitsets,
&mut self.scores,
&mut self.refill_docs,
&mut self.refill_term_freqs,
min_doc,
self.use_score_doc_refill,
);
true
} else {
@@ -248,12 +422,12 @@ where
// The target is outside of the buffered horizon.
// advance all docsets to a doc >= to the target.
unordered_drain_filter(&mut self.docsets, |docset| {
for docset in &mut self.docsets {
if docset.doc() < target {
docset.seek(target);
}
docset.doc() == TERMINATED
});
}
self.docsets.retain(|docset| docset.doc() != TERMINATED);
// at this point all of the docsets
// are positioned on a doc >= to the target.
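
The comments in `insert_and_score_full_buffer` describe two shortcuts for a full buffer of 64 sorted, distinct doc deltas: all in one 64-bit bucket, or 64 consecutive values straddling two buckets. The sketch below reproduces those shortcuts with plain `u64` masks instead of `TinySet`. The function names and the returned path label are illustrative only.

```rust
const BUCKET_BITS: u32 = 64;

/// Mask with bits `from..64` set. `from` must be below 64.
fn mask_greater_or_equal(from: u32) -> u64 {
    u64::MAX << from
}

/// Mask with bits `0..until` set. `until` must be below 64.
fn mask_lower(until: u32) -> u64 {
    (1u64 << until) - 1
}

/// Inserts 64 sorted, distinct deltas into `bitsets`, using the two
/// shortcuts when they apply. Returns a label for the path taken.
fn insert_full_buffer(bitsets: &mut [u64], deltas: &[u32; 64]) -> &'static str {
    let first = deltas[0];
    let last = deltas[63];
    let first_bucket = (first / BUCKET_BITS) as usize;
    let last_bucket = (last / BUCKET_BITS) as usize;
    // 64 distinct values inside one 64-value bucket fill it completely.
    if first_bucket == last_bucket {
        bitsets[first_bucket] |= u64::MAX;
        return "full-bucket";
    }
    // 64 sorted distinct values spanning exactly 64 values are consecutive:
    // a suffix of the first bucket plus a prefix of the next one.
    if last - first == 63 {
        bitsets[first_bucket] |= mask_greater_or_equal(first % BUCKET_BITS);
        bitsets[last_bucket] |= mask_lower((last + 1) % BUCKET_BITS);
        return "consecutive";
    }
    // General case: one bit per delta.
    for &delta in deltas.iter() {
        bitsets[(delta / BUCKET_BITS) as usize] |= 1u64 << (delta % BUCKET_BITS);
    }
    "per-doc"
}
```

In the consecutive case `first % 64` is never zero, since a zero offset would put all 64 values in a single bucket and be caught by the first branch. So `(last + 1) % 64` is never zero either, and both masks are non-empty.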

View File

@@ -10,6 +10,8 @@ pub use simple_union::SimpleUnion;
mod tests {
use std::collections::BTreeSet;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use common::BitSet;
@@ -18,8 +20,8 @@ mod tests {
use crate::postings::tests::test_skip_against_unoptimized;
use crate::query::score_combiner::DoNothingCombiner;
use crate::query::union::bitset_union::BitSetPostingUnion;
use crate::query::{BitSetDocSet, ConstScorer, VecDocSet};
use crate::{tests, DocId};
use crate::query::{BitSetDocSet, ConstScorer, Scorer, VecDocSet};
use crate::{tests, DocId, Score};
fn vec_doc_set_from_docs_list(
docs_list: &[Vec<DocId>],
@@ -66,6 +68,61 @@ mod tests {
}
BitSetDocSet::from(doc_bitset)
}
struct CountingScorer {
docset: VecDocSet,
score_calls: Arc<AtomicUsize>,
score_doc_calls: Arc<AtomicUsize>,
}
impl CountingScorer {
fn new(
doc_ids: Vec<DocId>,
score_calls: Arc<AtomicUsize>,
score_doc_calls: Arc<AtomicUsize>,
) -> Self {
CountingScorer {
docset: VecDocSet::from(doc_ids),
score_calls,
score_doc_calls,
}
}
}
impl DocSet for CountingScorer {
fn advance(&mut self) -> DocId {
self.docset.advance()
}
fn seek(&mut self, target: DocId) -> DocId {
self.docset.seek(target)
}
fn doc(&self) -> DocId {
self.docset.doc()
}
fn size_hint(&self) -> u32 {
self.docset.size_hint()
}
}
impl Scorer for CountingScorer {
fn score(&mut self) -> Score {
self.score_calls.fetch_add(1, Ordering::SeqCst);
1.0
}
fn can_score_doc(&self) -> bool {
true
}
fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score {
self.score_doc_calls.fetch_add(1, Ordering::SeqCst);
1.0
}
}
fn aux_test_union(docs_list: &[Vec<DocId>]) {
for constructor in [
posting_list_union_from_docs_list,
@@ -168,6 +225,22 @@ mod tests {
]);
}
#[test]
fn test_do_nothing_combiner_does_not_score_buffered_docs() {
let score_calls = Arc::new(AtomicUsize::new(0));
let score_doc_calls = Arc::new(AtomicUsize::new(0));
let scorers = vec![
CountingScorer::new(vec![1, 3, 5], score_calls.clone(), score_doc_calls.clone()),
CountingScorer::new(vec![2, 3, 6], score_calls.clone(), score_doc_calls.clone()),
];
let mut union = BufferedUnionScorer::build(scorers, DoNothingCombiner::default, 10);
assert_eq!(union.count_including_deleted(), 5);
assert_eq!(score_calls.load(Ordering::SeqCst), 0);
assert_eq!(score_doc_calls.load(Ordering::SeqCst), 0);
}
fn test_aux_union_skip(docs_list: &[Vec<DocId>], skip_targets: Vec<DocId>) {
for constructor in [
posting_list_union_from_docs_list,
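
The gating checked by the regression test above can be illustrated in isolation. The following is a minimal sketch, not tantivy's actual code: the `Combiner` trait, `SumCombiner`, `DoNothing` and `combine` are hypothetical stand-ins, and the `true` default for `requires_scoring()` is an assumption. It only shows how an associated function on the combiner type lets the caller skip score computation entirely.

```rust
/// Hypothetical, simplified stand-in for a score combiner trait.
trait Combiner: Default {
    /// Assumed default: combiners need scores unless they say otherwise.
    fn requires_scoring() -> bool {
        true
    }
    fn update(&mut self, score: f32);
    fn score(&self) -> f32;
}

/// Sums the scores it receives.
#[derive(Default)]
struct SumCombiner(f32);

impl Combiner for SumCombiner {
    fn update(&mut self, score: f32) {
        self.0 += score;
    }
    fn score(&self) -> f32 {
        self.0
    }
}

/// Ignores scores and advertises that none are needed.
#[derive(Default)]
struct DoNothing;

impl Combiner for DoNothing {
    fn requires_scoring() -> bool {
        false
    }
    fn update(&mut self, _score: f32) {}
    fn score(&self) -> f32 {
        0.0
    }
}

/// Visits `num_docs` matches and calls `score_fn` only when the combiner
/// type asks for scores. Returns the combined score and the number of
/// times `score_fn` was invoked.
fn combine<C: Combiner>(num_docs: usize, score_fn: impl Fn() -> f32) -> (f32, usize) {
    // Decided once up front, like `use_score_doc_refill` in the diff.
    let use_scores = C::requires_scoring();
    let mut combiner = C::default();
    let mut calls = 0;
    for _ in 0..num_docs {
        if use_scores {
            calls += 1;
            combiner.update(score_fn());
        }
    }
    (combiner.score(), calls)
}
```

With these definitions, `combine::<DoNothing>(3, || 1.0)` never invokes the closure, which mirrors the zero-call assertions in the test.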


@@ -23,7 +23,7 @@ zstd-compression = ["zstd"]
[dev-dependencies]
proptest = "1"
criterion = { version = "0.5", default-features = false }
criterion = { version = "0.8", default-features = false }
names = "0.14"
rand = "0.9"


@@ -14,11 +14,8 @@ use itertools::Itertools;
use tantivy_fst::Automaton;
use tantivy_fst::automaton::AlwaysMatch;
use crate::sstable_index_v3::SSTableIndexV3Empty;
use crate::streamer::{Streamer, StreamerBuilder};
use crate::{
BlockAddr, DeltaReader, Reader, SSTable, SSTableIndex, SSTableIndexV3, TermOrdinal, VoidSSTable,
};
use crate::{BlockAddr, DeltaReader, Reader, SSTable, SSTableIndex, TermOrdinal, VoidSSTable};
/// An SSTable is a sorted map that associates sorted `&[u8]` keys
/// to any kind of typed values.
@@ -288,33 +285,7 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
let (sstable_slice, index_slice) = main_slice.split(index_offset as usize);
let sstable_index_bytes = index_slice.read_bytes()?;
let sstable_index = match version {
2 => SSTableIndex::V2(
crate::sstable_index_v2::SSTableIndex::load(sstable_index_bytes).map_err(|_| {
io::Error::new(io::ErrorKind::InvalidData, "SSTable corruption")
})?,
),
3 => {
let (sstable_index_bytes, mut footerv3_len_bytes) = sstable_index_bytes.rsplit(8);
let store_offset = u64::deserialize(&mut footerv3_len_bytes)?;
if store_offset != 0 {
SSTableIndex::V3(
SSTableIndexV3::load(sstable_index_bytes, store_offset).map_err(|_| {
io::Error::new(io::ErrorKind::InvalidData, "SSTable corruption")
})?,
)
} else {
// if store_offset is zero, there is no index, so we build a pseudo-index
// assuming a single block of sstable covering everything.
SSTableIndex::V3Empty(SSTableIndexV3Empty::load(index_offset as usize))
}
}
_ => {
return Err(io::Error::other(format!(
"Unsupported sstable version, expected one of [2, 3], found {version}"
)));
}
};
let sstable_index = SSTableIndex::open(version, index_offset, sstable_index_bytes)?;
Ok(Dictionary {
sstable_slice,
@@ -525,10 +496,15 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
// Open the block for the first ordinal.
let mut bytes = Vec::new();
let mut current_block_addr = self.sstable_index.get_block_with_ord(ord);
let (mut current_block_addr, block_id) = self.sstable_index.get_and_locate_with_ord(ord);
let mut current_sstable_delta_reader =
self.sstable_delta_reader_block(current_block_addr.clone())?;
let mut current_block_ordinal = current_block_addr.first_ordinal;
let mut current_block_end_bound = self
.sstable_index
.get_block(block_id + 1)
.map(|block_addr| block_addr.first_ordinal)
.unwrap_or(u64::MAX);
loop {
// move to the ord inside the current block
@@ -557,17 +533,19 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
}
};
// TODO optimization: it is silly to do a binary search to get the block every single
// time.
//
// Check if block changed for new term_ord
let new_block_addr = self.sstable_index.get_block_with_ord(next_ord);
if new_block_addr != current_block_addr {
if next_ord >= current_block_end_bound {
let (new_block_addr, block_id) =
self.sstable_index.get_and_locate_with_ord(next_ord);
current_block_addr = new_block_addr;
current_block_ordinal = current_block_addr.first_ordinal;
current_sstable_delta_reader =
self.sstable_delta_reader_block(current_block_addr.clone())?;
bytes.clear();
current_block_end_bound = self
.sstable_index
.get_block(block_id + 1)
.map(|block_addr| block_addr.first_ordinal)
.unwrap_or(u64::MAX)
}
ord = next_ord;
}
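
The hunk above replaces a per-ordinal binary search with a cached upper bound for the current block. A minimal, self-contained sketch of that idea follows. `BlockIndex` and `blocks_for_sorted_ords` are hypothetical names, the index is reduced to a sorted list of first ordinals, and the sketch assumes the list is non-empty, starts at 0, and that the input ordinals are sorted in ascending order.

```rust
/// Hypothetical stand-in for the sstable index: `first_ordinals[i]` is the
/// first term ordinal stored in block `i` (sorted, non-empty, starting at 0).
struct BlockIndex {
    first_ordinals: Vec<u64>,
}

impl BlockIndex {
    /// Binary search for the id of the block containing `ord`.
    fn locate_with_ord(&self, ord: u64) -> usize {
        // Number of blocks whose first ordinal is <= ord, minus one.
        self.first_ordinals.partition_point(|&first| first <= ord) - 1
    }

    /// First ordinal of the block after `block_id`,
    /// or `u64::MAX` if `block_id` is the last block.
    fn end_bound(&self, block_id: usize) -> u64 {
        self.first_ordinals
            .get(block_id + 1)
            .copied()
            .unwrap_or(u64::MAX)
    }
}

/// Maps each ordinal of an ascending list to its block id.
/// Returns the block ids and the number of binary searches performed.
fn blocks_for_sorted_ords(index: &BlockIndex, ords: &[u64]) -> (Vec<usize>, usize) {
    let mut searches = 0;
    let mut block_ids = Vec::with_capacity(ords.len());
    // (block id, exclusive end bound) of the block currently open.
    let mut current: Option<(usize, u64)> = None;
    for &ord in ords {
        let block_id = match current {
            // Still inside the open block: no search needed.
            Some((block_id, end_bound)) if ord < end_bound => block_id,
            // First ordinal, or the ordinal crossed the block boundary.
            _ => {
                searches += 1;
                let block_id = index.locate_with_ord(ord);
                current = Some((block_id, index.end_bound(block_id)));
                block_id
            }
        };
        block_ids.push(block_id);
    }
    (block_ids, searches)
}
```

For blocks starting at ordinals 0, 10 and 20, the ordinals `[1, 2, 9, 10, 11, 25, 26]` need three searches instead of seven. Using `u64::MAX` as the bound of the last block is also why `get_block(block_id + 1)` has to return `None` past the end, which the `SSTableIndexV3Empty::get_block` change in this comparison provides.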

sstable/src/index/mod.rs (new file, 319 lines)

@@ -0,0 +1,319 @@
pub(crate) mod v2;
pub(crate) mod v3;
use std::io::{self, Read, Write};
use std::ops::Range;
use common::{BinarySerializable, FixedSize, OwnedBytes};
use tantivy_fst::{Automaton, MapBuilder};
use crate::{TermOrdinal, common_prefix_len};
#[derive(Debug, Clone)]
pub enum SSTableIndex {
V2(v2::SSTableIndex),
V3(v3::SSTableIndexV3),
V3Empty(v3::SSTableIndexV3Empty),
}
impl SSTableIndex {
pub(crate) fn open(
version: u32,
index_offset: u64,
index_bytes: OwnedBytes,
) -> io::Result<Self> {
let index = match version {
2 => {
SSTableIndex::V2(v2::SSTableIndex::load(index_bytes).map_err(|_| {
io::Error::new(io::ErrorKind::InvalidData, "SSTable corruption")
})?)
}
3 => {
let (index_bytes, mut footerv3_len_bytes) = index_bytes.rsplit(8);
let store_offset = u64::deserialize(&mut footerv3_len_bytes)?;
if store_offset != 0 {
SSTableIndex::V3(v3::SSTableIndexV3::load(index_bytes, store_offset).map_err(
|_| io::Error::new(io::ErrorKind::InvalidData, "SSTable corruption"),
)?)
} else {
// if store_offset is zero, there is no index, so we build a pseudo-index
// assuming a single block of sstable covering everything.
SSTableIndex::V3Empty(v3::SSTableIndexV3Empty::load(index_offset as usize))
}
}
_ => {
return Err(io::Error::other(format!(
"Unsupported sstable version, expected one of [2, 3], found {version}"
)));
}
};
Ok(index)
}
/// Get the [`BlockAddr`] of the requested block.
pub(crate) fn get_block(&self, block_id: u64) -> Option<BlockAddr> {
match self {
SSTableIndex::V2(v2_index) => v2_index.get_block(block_id as usize),
SSTableIndex::V3(v3_index) => v3_index.get_block(block_id),
SSTableIndex::V3Empty(v3_empty) => v3_empty.get_block(block_id),
}
}
/// Get the block id of the block that would contain `key`.
///
/// Returns None if `key` is lexicographically after the last key recorded.
pub(crate) fn locate_with_key(&self, key: &[u8]) -> Option<u64> {
match self {
SSTableIndex::V2(v2_index) => v2_index.locate_with_key(key).map(|i| i as u64),
SSTableIndex::V3(v3_index) => v3_index.locate_with_key(key),
SSTableIndex::V3Empty(v3_empty) => v3_empty.locate_with_key(key),
}
}
/// Get the [`BlockAddr`] of the block that would contain `key`.
///
/// Returns None if `key` is lexicographically after the last key recorded.
pub fn get_block_with_key(&self, key: &[u8]) -> Option<BlockAddr> {
match self {
SSTableIndex::V2(v2_index) => v2_index.get_block_with_key(key),
SSTableIndex::V3(v3_index) => v3_index.get_block_with_key(key),
SSTableIndex::V3Empty(v3_empty) => v3_empty.get_block_with_key(key),
}
}
pub(crate) fn locate_with_ord(&self, ord: TermOrdinal) -> u64 {
match self {
SSTableIndex::V2(v2_index) => v2_index.locate_with_ord(ord) as u64,
SSTableIndex::V3(v3_index) => v3_index.locate_with_ord(ord),
SSTableIndex::V3Empty(v3_empty) => v3_empty.locate_with_ord(ord),
}
}
/// Get the [`BlockAddr`] of the block containing the `ord`-th term.
pub(crate) fn get_block_with_ord(&self, ord: TermOrdinal) -> BlockAddr {
match self {
SSTableIndex::V2(v2_index) => v2_index.get_block_with_ord(ord),
SSTableIndex::V3(v3_index) => v3_index.get_block_with_ord(ord),
SSTableIndex::V3Empty(v3_empty) => v3_empty.get_block_with_ord(ord),
}
}
pub(crate) fn get_and_locate_with_ord(&self, ord: TermOrdinal) -> (BlockAddr, u64) {
match self {
SSTableIndex::V2(v2_index) => v2_index.get_and_locate_with_ord(ord),
SSTableIndex::V3(v3_index) => v3_index.get_and_locate_with_ord(ord),
SSTableIndex::V3Empty(v3_empty) => v3_empty.get_and_locate_with_ord(ord),
}
}
pub fn get_block_for_automaton<'a>(
&'a self,
automaton: &'a impl Automaton,
) -> impl Iterator<Item = (u64, BlockAddr)> + 'a {
match self {
SSTableIndex::V2(v2_index) => {
BlockIter::V2(v2_index.get_block_for_automaton(automaton))
}
SSTableIndex::V3(v3_index) => {
BlockIter::V3(v3_index.get_block_for_automaton(automaton))
}
SSTableIndex::V3Empty(v3_empty) => {
BlockIter::V3Empty(std::iter::once((0, v3_empty.block_addr.clone())))
}
}
}
}
enum BlockIter<V2, V3, T> {
V2(V2),
V3(V3),
V3Empty(std::iter::Once<T>),
}
impl<V2: Iterator<Item = T>, V3: Iterator<Item = T>, T> Iterator for BlockIter<V2, V3, T> {
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
match self {
BlockIter::V2(v2) => v2.next(),
BlockIter::V3(v3) => v3.next(),
BlockIter::V3Empty(once) => once.next(),
}
}
}
#[derive(Clone, Eq, PartialEq, Debug)]
pub struct BlockAddr {
pub first_ordinal: u64,
pub byte_range: Range<usize>,
}
impl BlockAddr {
fn to_block_start(&self) -> BlockStartAddr {
BlockStartAddr {
first_ordinal: self.first_ordinal,
byte_range_start: self.byte_range.start,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct BlockStartAddr {
first_ordinal: u64,
byte_range_start: usize,
}
impl BlockStartAddr {
fn to_block_addr(&self, byte_range_end: usize) -> BlockAddr {
BlockAddr {
first_ordinal: self.first_ordinal,
byte_range: self.byte_range_start..byte_range_end,
}
}
}
#[derive(Debug, Clone)]
pub(crate) struct BlockMeta {
/// Any byte string that is lexicographically greater or equal to
/// the last key in the block,
/// and yet strictly smaller than the first key in the next block.
pub last_key_or_greater: Vec<u8>,
pub block_addr: BlockAddr,
}
impl BinarySerializable for BlockStartAddr {
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
let start = self.byte_range_start as u64;
start.serialize(writer)?;
self.first_ordinal.serialize(writer)
}
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
let byte_range_start = u64::deserialize(reader)? as usize;
let first_ordinal = u64::deserialize(reader)?;
Ok(BlockStartAddr {
first_ordinal,
byte_range_start,
})
}
// Provided method
fn num_bytes(&self) -> u64 {
BlockStartAddr::SIZE_IN_BYTES as u64
}
}
impl FixedSize for BlockStartAddr {
const SIZE_IN_BYTES: usize = 2 * u64::SIZE_IN_BYTES;
}
/// Given that `left < right`,
/// mutates `left` into a shorter byte string `left'` that
/// satisfies `left <= left' < right`.
fn find_shorter_str_in_between(left: &mut Vec<u8>, right: &[u8]) {
assert!(&left[..] < right);
let common_len = common_prefix_len(left, right);
if left.len() == common_len {
return;
}
// It is possible to do one character shorter in some case,
// but it is not worth the extra complexity
for pos in (common_len + 1)..left.len() {
if left[pos] != u8::MAX {
left[pos] += 1;
left.truncate(pos + 1);
return;
}
}
}
#[derive(Default)]
pub struct SSTableIndexBuilder {
blocks: Vec<BlockMeta>,
}
impl SSTableIndexBuilder {
/// In order to make the index as light as possible, we
/// try to find a shorter alternative to the last key of the last block
/// that is still smaller than the next key.
pub(crate) fn shorten_last_block_key_given_next_key(&mut self, next_key: &[u8]) {
if let Some(last_block) = self.blocks.last_mut() {
find_shorter_str_in_between(&mut last_block.last_key_or_greater, next_key);
}
}
pub fn add_block(&mut self, last_key: &[u8], byte_range: Range<usize>, first_ordinal: u64) {
self.blocks.push(BlockMeta {
last_key_or_greater: last_key.to_vec(),
block_addr: BlockAddr {
byte_range,
first_ordinal,
},
})
}
pub fn serialize<W: std::io::Write>(&self, wrt: W) -> io::Result<u64> {
if self.blocks.len() <= 1 {
return Ok(0);
}
let counting_writer = common::CountingWriter::wrap(wrt);
let mut map_builder = MapBuilder::new(counting_writer).map_err(fst_error_to_io_error)?;
for (i, block) in self.blocks.iter().enumerate() {
map_builder
.insert(&block.last_key_or_greater, i as u64)
.map_err(fst_error_to_io_error)?;
}
let counting_writer = map_builder.into_inner().map_err(fst_error_to_io_error)?;
let written_bytes = counting_writer.written_bytes();
let mut wrt = counting_writer.finish();
let mut block_store_writer = v3::BlockAddrStoreWriter::new();
for block in &self.blocks {
block_store_writer.write_block_meta(block.block_addr.clone())?;
}
block_store_writer.serialize(&mut wrt)?;
Ok(written_bytes)
}
}
fn fst_error_to_io_error(error: tantivy_fst::Error) -> io::Error {
match error {
tantivy_fst::Error::Fst(fst_error) => io::Error::other(fst_error),
tantivy_fst::Error::Io(ioerror) => ioerror,
}
}
#[cfg(test)]
mod tests {
#[track_caller]
fn test_find_shorter_str_in_between_aux(left: &[u8], right: &[u8]) {
let mut left_buf = left.to_vec();
super::find_shorter_str_in_between(&mut left_buf, right);
assert!(left_buf.len() <= left.len());
assert!(left <= &left_buf);
assert!(&left_buf[..] < right);
}
#[test]
fn test_find_shorter_str_in_between() {
test_find_shorter_str_in_between_aux(b"", b"hello");
test_find_shorter_str_in_between_aux(b"abc", b"abcd");
test_find_shorter_str_in_between_aux(b"abcd", b"abd");
test_find_shorter_str_in_between_aux(&[0, 0, 0], &[1]);
test_find_shorter_str_in_between_aux(&[0, 0, 0], &[0, 0, 1]);
test_find_shorter_str_in_between_aux(&[0, 0, 255, 255, 255, 0u8], &[0, 1]);
}
use proptest::prelude::*;
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
#[test]
fn test_proptest_find_shorter_str(left in any::<Vec<u8>>(), right in any::<Vec<u8>>()) {
if left < right {
test_find_shorter_str_in_between_aux(&left, &right);
}
}
}
}
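
To make the behavior of `find_shorter_str_in_between` concrete, here is a standalone copy of the function as it appears in the diff, paired with a local `common_prefix_len` helper. That helper is a stand-in written for this sketch; the crate's own implementation is not shown in this comparison and may differ.

```rust
/// Stand-in helper: length of the longest common prefix of two byte strings.
fn common_prefix_len(left: &[u8], right: &[u8]) -> usize {
    left.iter()
        .zip(right.iter())
        .take_while(|(l, r)| l == r)
        .count()
}

/// Copy of the function from the diff: given `left < right`, rewrites `left`
/// into a byte string `left'` that is no longer than `left` and satisfies
/// `left <= left' < right`.
fn find_shorter_str_in_between(left: &mut Vec<u8>, right: &[u8]) {
    assert!(&left[..] < right);
    let common_len = common_prefix_len(left, right);
    // `left` is a prefix of `right`: nothing to shorten.
    if left.len() == common_len {
        return;
    }
    // The byte at `common_len` is already smaller than the one in `right`,
    // so it is kept as is. Bumping the first later byte that is not 0xFF and
    // cutting after it yields a string >= `left` and still < `right`.
    for pos in (common_len + 1)..left.len() {
        if left[pos] != u8::MAX {
            left[pos] += 1;
            left.truncate(pos + 1);
            return;
        }
    }
}
```

For example, with `left = b"hello world"` and `right = b"hf"`, the common prefix is `h`, the byte `e` is kept, `l` is bumped to `m`, and the result is `b"hem"`. With `left = b"abc"` and `right = b"abcd"`, `left` is a prefix of `right` and stays unchanged.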


@@ -77,6 +77,13 @@ impl SSTableIndex {
self.get_block(self.locate_with_ord(ord)).unwrap()
}
pub(crate) fn get_and_locate_with_ord(&self, ord: TermOrdinal) -> (BlockAddr, u64) {
let location = self.locate_with_ord(ord);
// locate_with_ord always returns an index within range
let block_addr = self.get_block(location).unwrap();
(block_addr, location as u64)
}
pub(crate) fn get_block_for_automaton<'a>(
&'a self,
automaton: &'a impl Automaton,


@@ -1,106 +1,14 @@
use std::io::{self, Read, Write};
use std::ops::Range;
use std::sync::Arc;
use common::{BinarySerializable, FixedSize, OwnedBytes};
use tantivy_bitpacker::{BitPacker, compute_num_bits};
use tantivy_fst::raw::Fst;
use tantivy_fst::{Automaton, IntoStreamer, Map, MapBuilder, Streamer};
use tantivy_fst::{Automaton, IntoStreamer, Map, Streamer};
use super::{BlockAddr, BlockStartAddr};
use crate::block_match_automaton::can_block_match_automaton;
use crate::{SSTableDataCorruption, TermOrdinal, common_prefix_len};
#[derive(Debug, Clone)]
pub enum SSTableIndex {
V2(crate::sstable_index_v2::SSTableIndex),
V3(SSTableIndexV3),
V3Empty(SSTableIndexV3Empty),
}
impl SSTableIndex {
/// Get the [`BlockAddr`] of the requested block.
pub(crate) fn get_block(&self, block_id: u64) -> Option<BlockAddr> {
match self {
SSTableIndex::V2(v2_index) => v2_index.get_block(block_id as usize),
SSTableIndex::V3(v3_index) => v3_index.get_block(block_id),
SSTableIndex::V3Empty(v3_empty) => v3_empty.get_block(block_id),
}
}
/// Get the block id of the block that would contain `key`.
///
/// Returns None if `key` is lexicographically after the last key recorded.
pub(crate) fn locate_with_key(&self, key: &[u8]) -> Option<u64> {
match self {
SSTableIndex::V2(v2_index) => v2_index.locate_with_key(key).map(|i| i as u64),
SSTableIndex::V3(v3_index) => v3_index.locate_with_key(key),
SSTableIndex::V3Empty(v3_empty) => v3_empty.locate_with_key(key),
}
}
/// Get the [`BlockAddr`] of the block that would contain `key`.
///
/// Returns None if `key` is lexicographically after the last key recorded.
pub fn get_block_with_key(&self, key: &[u8]) -> Option<BlockAddr> {
match self {
SSTableIndex::V2(v2_index) => v2_index.get_block_with_key(key),
SSTableIndex::V3(v3_index) => v3_index.get_block_with_key(key),
SSTableIndex::V3Empty(v3_empty) => v3_empty.get_block_with_key(key),
}
}
pub(crate) fn locate_with_ord(&self, ord: TermOrdinal) -> u64 {
match self {
SSTableIndex::V2(v2_index) => v2_index.locate_with_ord(ord) as u64,
SSTableIndex::V3(v3_index) => v3_index.locate_with_ord(ord),
SSTableIndex::V3Empty(v3_empty) => v3_empty.locate_with_ord(ord),
}
}
/// Get the [`BlockAddr`] of the block containing the `ord`-th term.
pub(crate) fn get_block_with_ord(&self, ord: TermOrdinal) -> BlockAddr {
match self {
SSTableIndex::V2(v2_index) => v2_index.get_block_with_ord(ord),
SSTableIndex::V3(v3_index) => v3_index.get_block_with_ord(ord),
SSTableIndex::V3Empty(v3_empty) => v3_empty.get_block_with_ord(ord),
}
}
pub fn get_block_for_automaton<'a>(
&'a self,
automaton: &'a impl Automaton,
) -> impl Iterator<Item = (u64, BlockAddr)> + 'a {
match self {
SSTableIndex::V2(v2_index) => {
BlockIter::V2(v2_index.get_block_for_automaton(automaton))
}
SSTableIndex::V3(v3_index) => {
BlockIter::V3(v3_index.get_block_for_automaton(automaton))
}
SSTableIndex::V3Empty(v3_empty) => {
BlockIter::V3Empty(std::iter::once((0, v3_empty.block_addr.clone())))
}
}
}
}
enum BlockIter<V2, V3, T> {
V2(V2),
V3(V3),
V3Empty(std::iter::Once<T>),
}
impl<V2: Iterator<Item = T>, V3: Iterator<Item = T>, T> Iterator for BlockIter<V2, V3, T> {
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
match self {
BlockIter::V2(v2) => v2.next(),
BlockIter::V3(v3) => v3.next(),
BlockIter::V3Empty(once) => once.next(),
}
}
}
use crate::{SSTableDataCorruption, TermOrdinal};
#[derive(Debug, Clone)]
pub struct SSTableIndexV3 {
@@ -160,6 +68,11 @@ impl SSTableIndexV3 {
self.block_addr_store.binary_search_ord(ord).1
}
pub(crate) fn get_and_locate_with_ord(&self, ord: TermOrdinal) -> (BlockAddr, u64) {
let (location, block_addr) = self.block_addr_store.binary_search_ord(ord);
(block_addr, location)
}
pub(crate) fn get_block_for_automaton<'a>(
&'a self,
automaton: &'a impl Automaton,
@@ -216,7 +129,7 @@ impl<A: Automaton> Iterator for GetBlockForAutomaton<'_, A> {
#[derive(Debug, Clone)]
pub struct SSTableIndexV3Empty {
block_addr: BlockAddr,
pub block_addr: BlockAddr,
}
impl SSTableIndexV3Empty {
@@ -230,8 +143,8 @@ impl SSTableIndexV3Empty {
}
/// Get the [`BlockAddr`] of the requested block.
pub(crate) fn get_block(&self, _block_id: u64) -> Option<BlockAddr> {
Some(self.block_addr.clone())
pub(crate) fn get_block(&self, block_id: u64) -> Option<BlockAddr> {
(block_id == 0).then(|| self.block_addr.clone())
}
/// Get the block id of the block that would contain `key`.
@@ -256,146 +169,9 @@ impl SSTableIndexV3Empty {
pub(crate) fn get_block_with_ord(&self, _ord: TermOrdinal) -> BlockAddr {
self.block_addr.clone()
}
}
#[derive(Clone, Eq, PartialEq, Debug)]
pub struct BlockAddr {
pub first_ordinal: u64,
pub byte_range: Range<usize>,
}
impl BlockAddr {
fn to_block_start(&self) -> BlockStartAddr {
BlockStartAddr {
first_ordinal: self.first_ordinal,
byte_range_start: self.byte_range.start,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct BlockStartAddr {
first_ordinal: u64,
byte_range_start: usize,
}
impl BlockStartAddr {
fn to_block_addr(&self, byte_range_end: usize) -> BlockAddr {
BlockAddr {
first_ordinal: self.first_ordinal,
byte_range: self.byte_range_start..byte_range_end,
}
}
}
#[derive(Debug, Clone)]
pub(crate) struct BlockMeta {
/// Any byte string that is lexicographically greater or equal to
/// the last key in the block,
/// and yet strictly smaller than the first key in the next block.
pub last_key_or_greater: Vec<u8>,
pub block_addr: BlockAddr,
}
impl BinarySerializable for BlockStartAddr {
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
let start = self.byte_range_start as u64;
start.serialize(writer)?;
self.first_ordinal.serialize(writer)
}
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
let byte_range_start = u64::deserialize(reader)? as usize;
let first_ordinal = u64::deserialize(reader)?;
Ok(BlockStartAddr {
first_ordinal,
byte_range_start,
})
}
// Provided method
fn num_bytes(&self) -> u64 {
BlockStartAddr::SIZE_IN_BYTES as u64
}
}
impl FixedSize for BlockStartAddr {
const SIZE_IN_BYTES: usize = 2 * u64::SIZE_IN_BYTES;
}
/// Given that left < right,
/// mutates `left into a shorter byte string left'` that
/// matches `left <= left' < right`.
fn find_shorter_str_in_between(left: &mut Vec<u8>, right: &[u8]) {
assert!(&left[..] < right);
let common_len = common_prefix_len(left, right);
if left.len() == common_len {
return;
}
// It is possible to do one character shorter in some case,
// but it is not worth the extra complexity
for pos in (common_len + 1)..left.len() {
if left[pos] != u8::MAX {
left[pos] += 1;
left.truncate(pos + 1);
return;
}
}
}
#[derive(Default)]
pub struct SSTableIndexBuilder {
blocks: Vec<BlockMeta>,
}
impl SSTableIndexBuilder {
/// In order to make the index as light as possible, we
/// try to find a shorter alternative to the last key of the last block
/// that is still smaller than the next key.
pub(crate) fn shorten_last_block_key_given_next_key(&mut self, next_key: &[u8]) {
if let Some(last_block) = self.blocks.last_mut() {
find_shorter_str_in_between(&mut last_block.last_key_or_greater, next_key);
}
}
pub fn add_block(&mut self, last_key: &[u8], byte_range: Range<usize>, first_ordinal: u64) {
self.blocks.push(BlockMeta {
last_key_or_greater: last_key.to_vec(),
block_addr: BlockAddr {
byte_range,
first_ordinal,
},
})
}
pub fn serialize<W: std::io::Write>(&self, wrt: W) -> io::Result<u64> {
if self.blocks.len() <= 1 {
return Ok(0);
}
let counting_writer = common::CountingWriter::wrap(wrt);
let mut map_builder = MapBuilder::new(counting_writer).map_err(fst_error_to_io_error)?;
for (i, block) in self.blocks.iter().enumerate() {
map_builder
.insert(&block.last_key_or_greater, i as u64)
.map_err(fst_error_to_io_error)?;
}
let counting_writer = map_builder.into_inner().map_err(fst_error_to_io_error)?;
let written_bytes = counting_writer.written_bytes();
let mut wrt = counting_writer.finish();
let mut block_store_writer = BlockAddrStoreWriter::new();
for block in &self.blocks {
block_store_writer.write_block_meta(block.block_addr.clone())?;
}
block_store_writer.serialize(&mut wrt)?;
Ok(written_bytes)
}
}
fn fst_error_to_io_error(error: tantivy_fst::Error) -> io::Error {
match error {
tantivy_fst::Error::Fst(fst_error) => io::Error::other(fst_error),
tantivy_fst::Error::Io(ioerror) => ioerror,
pub(crate) fn get_and_locate_with_ord(&self, _ord: TermOrdinal) -> (BlockAddr, u64) {
(self.block_addr.clone(), 0)
}
}
@@ -647,14 +423,14 @@ fn binary_search(max: u64, cmp_fn: impl Fn(u64) -> std::cmp::Ordering) -> Result
Err(left)
}
struct BlockAddrStoreWriter {
pub(crate) struct BlockAddrStoreWriter {
buffer_block_metas: Vec<u8>,
buffer_addrs: Vec<u8>,
block_addrs: Vec<BlockAddr>,
}
impl BlockAddrStoreWriter {
fn new() -> Self {
pub(crate) fn new() -> Self {
BlockAddrStoreWriter {
buffer_block_metas: Vec::new(),
buffer_addrs: Vec::new(),
@@ -662,7 +438,7 @@ impl BlockAddrStoreWriter {
}
}
fn flush_block(&mut self) -> io::Result<()> {
pub(crate) fn flush_block(&mut self) -> io::Result<()> {
if self.block_addrs.is_empty() {
return Ok(());
}
@@ -741,7 +517,7 @@ impl BlockAddrStoreWriter {
Ok(())
}
fn write_block_meta(&mut self, block_addr: BlockAddr) -> io::Result<()> {
pub(crate) fn write_block_meta(&mut self, block_addr: BlockAddr) -> io::Result<()> {
self.block_addrs.push(block_addr);
if self.block_addrs.len() >= STORE_BLOCK_LEN {
self.flush_block()?;
@@ -749,7 +525,7 @@ impl BlockAddrStoreWriter {
Ok(())
}
fn serialize<W: std::io::Write>(&mut self, wrt: &mut W) -> io::Result<()> {
pub(crate) fn serialize<W: std::io::Write>(&mut self, wrt: &mut W) -> io::Result<()> {
self.flush_block()?;
let len = self.buffer_block_metas.len() as u64;
len.serialize(wrt)?;
@@ -824,8 +600,9 @@ mod tests {
use common::OwnedBytes;
use super::*;
use crate::SSTableDataCorruption;
use crate::block_match_automaton::tests::EqBuffer;
use crate::index::BlockMeta;
use crate::{SSTableDataCorruption, SSTableIndexBuilder};
#[test]
fn test_sstable_index() {
@@ -874,36 +651,7 @@ mod tests {
assert!(matches!(data_corruption_err, SSTableDataCorruption));
}
#[track_caller]
fn test_find_shorter_str_in_between_aux(left: &[u8], right: &[u8]) {
let mut left_buf = left.to_vec();
super::find_shorter_str_in_between(&mut left_buf, right);
assert!(left_buf.len() <= left.len());
assert!(left <= &left_buf);
assert!(&left_buf[..] < right);
}
#[test]
fn test_find_shorter_str_in_between() {
test_find_shorter_str_in_between_aux(b"", b"hello");
test_find_shorter_str_in_between_aux(b"abc", b"abcd");
test_find_shorter_str_in_between_aux(b"abcd", b"abd");
test_find_shorter_str_in_between_aux(&[0, 0, 0], &[1]);
test_find_shorter_str_in_between_aux(&[0, 0, 0], &[0, 0, 1]);
test_find_shorter_str_in_between_aux(&[0, 0, 255, 255, 255, 0u8], &[0, 1]);
}
use proptest::prelude::*;
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
#[test]
fn test_proptest_find_shorter_str(left in any::<Vec<u8>>(), right in any::<Vec<u8>>()) {
if left < right {
test_find_shorter_str_in_between_aux(&left, &right);
}
}
}
// use proptest::prelude::*;
#[test]
fn test_find_best_slop() {


@@ -47,9 +47,8 @@ pub mod merge;
mod streamer;
pub mod value;
mod sstable_index_v3;
pub use sstable_index_v3::{BlockAddr, SSTableIndex, SSTableIndexBuilder, SSTableIndexV3};
mod sstable_index_v2;
mod index;
pub use index::{BlockAddr, SSTableIndex, SSTableIndexBuilder};
pub(crate) mod vint;
pub use dictionary::{Dictionary, TermOrdHit};
pub use streamer::{Streamer, StreamerBuilder};


@@ -27,7 +27,7 @@ rand = "0.9"
zipf = "7.0.0"
rustc-hash = "2.1.0"
proptest = "1.2.0"
binggan = { version = "0.16.1" }
binggan = { version = "0.17.0" }
rand_distr = "0.5"
[features]