Compare commits

97 Commits

Author SHA1 Message Date
Luca Cominardi
f154c96b75 feat: expose IndexMerger and SegmentDocIdMapping for external doc-order control
Add `merge_segments_with_doc_id_mapping`, a public variant of
`merge_filtered_segments` that accepts a caller-supplied `SegmentDocIdMapping`
so that consumers can drive the final document order without a persistent sort
field.  Also expose `IndexMerger`, `SegmentDocIdMapping`, and the new
`write_with_doc_id_mapping` method in the public API.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-22 10:44:40 +02:00
Mithun Chicklore Yogendra
92b519d177 fix: preserve NULL ordering during numeric segment merges (#106)
The merger used first_or_default_col(0u64) for numeric sort fields,
making NULLs indistinguishable from zero during the k-way merge and
interleaving them with zeros. Use Column<u64>::first() (Option<u64>)
and match None/Some like the Str/Bytes path, and account for nullable
columns in the disjunct-stacking check.
2026-06-18 12:11:53 +02:00
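Illustration for 92b519d177: a minimal sketch of why a default of 0 interleaves NULLs with zeros during a merge. The two-way `merge_sorted` helper and its inputs are hypothetical stand-ins for the k-way merge, not the merger's actual code; only the idea of comparing `Option<u64>` instead of a defaulted `u64` comes from the commit message.

```rust
/// Merge two segments that are each already sorted under `key`.
/// Hypothetical helper: a two-way stand-in for the k-way merge.
fn merge_sorted<K: Ord>(
    a: &[Option<u64>],
    b: &[Option<u64>],
    key: impl Fn(Option<u64>) -> K,
) -> Vec<Option<u64>> {
    let (mut i, mut j) = (0, 0);
    let mut out = Vec::with_capacity(a.len() + b.len());
    while i < a.len() && j < b.len() {
        if key(a[i]) <= key(b[j]) {
            out.push(a[i]);
            i += 1;
        } else {
            out.push(b[j]);
            j += 1;
        }
    }
    out.extend_from_slice(&a[i..]);
    out.extend_from_slice(&b[j..]);
    out
}

fn main() {
    let seg_a = [None, Some(0)];
    let seg_b = [None, Some(7)];

    // Defaulting NULL to 0 makes None and Some(0) compare equal.
    let defaulted = merge_sorted(&seg_a, &seg_b, |v| v.unwrap_or(0));
    // Comparing the Option itself keeps NULLs grouped (None sorts before Some).
    let nullable = merge_sorted(&seg_a, &seg_b, |v| v);

    println!("defaulted: {:?}", defaulted);
    println!("nullable:  {:?}", nullable);
}
```

Whether NULLs sort first or last is a policy choice; this sketch simply relies on `Option`'s built-in ordering.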
Mithun Chicklore Yogendra
c09eacb994 fix: use native typed comparison for numeric sort keys (#105)
ColumnarWriter::sort_order() computed the sort key as
`f64::coerce(nv) as f32`, which has only 24 bits of mantissa: values
above 2^24 could collide (e.g. 16_777_216 and 16_777_217), and the 0.0
default made NULL indistinguishable from zero. Compare each
NumericalValue in its native type (u64/i64 cmp, f64 total_cmp), with
NULL ordered separately from Some(0).
2026-06-18 12:11:53 +02:00
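Illustration for c09eacb994: a small sketch of the collision described above. `lossy_key` is a hypothetical name that mimics the `f64 as f32` conversion; it is not the function from `ColumnarWriter`.

```rust
/// Mimics the lossy sort key: an f32 keeps only 24 bits of mantissa.
fn lossy_key(value: u64) -> f32 {
    (value as f64) as f32
}

fn main() {
    let (a, b) = (16_777_216u64, 16_777_217u64);

    // Both values round to the same f32, so the lossy keys collide.
    println!("f32 keys equal: {}", lossy_key(a) == lossy_key(b));

    // Comparing in the native type keeps them distinct.
    println!("native u64 cmp: {:?}", a.cmp(&b));

    // For floats, total_cmp gives a total order (it also orders -0.0 before 0.0).
    println!("total_cmp: {:?}", (-0.0f64).total_cmp(&0.0));
}
```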
Mithun Chicklore Yogendra
b59fac74bb feat: enable sort_by for Str/Bytes fast fields (#101)
Adds support for sorting by Str and Bytes fast fields on both
single-segment writes and cross-segment merges. Dictionary-encoded
fields use per-segment ordinals, so segment dictionaries are merged via
the columnar TermMerger to compute per-segment local-ord -> global-ord
mappings; the remapped u64 ordinals are then compared during kmerge.
2026-06-18 12:11:53 +02:00
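Illustration for b59fac74bb: a simplified sketch of a local-ord -> global-ord mapping. It assumes each segment dictionary is a sorted, deduplicated list of terms and builds the global dictionary with a `BTreeSet`; the real change uses the columnar `TermMerger`, which this does not reproduce. `local_to_global` is a hypothetical name.

```rust
use std::collections::BTreeSet;

/// For each segment, map every local ordinal to its ordinal in the
/// merged (global) dictionary. Hypothetical helper for illustration.
fn local_to_global(segment_dicts: &[Vec<&str>]) -> Vec<Vec<u64>> {
    // Global dictionary: sorted union of all segment terms.
    let global: Vec<&str> = segment_dicts
        .iter()
        .flatten()
        .copied()
        .collect::<BTreeSet<_>>()
        .into_iter()
        .collect();

    segment_dicts
        .iter()
        .map(|dict| {
            dict.iter()
                .map(|term| {
                    global
                        .binary_search(term)
                        .expect("every segment term is in the global dictionary")
                        as u64
                })
                .collect()
        })
        .collect()
}

fn main() {
    let mappings = local_to_global(&[
        vec!["apple", "cherry"],
        vec!["banana", "cherry"],
    ]);
    // Local ordinal 1 means "cherry" in both segments; after remapping,
    // both refer to the same global ordinal and can be compared directly.
    println!("{:?}", mappings);
}
```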
Stu Hood
47ff79e2fc feat: Restore sort by field (#92)
Restores index sorting: an index can be physically ordered by a fast
field (IndexSettings::sort_by_field), applied both on single-segment
writes (stream to TempStore, then resort into Store) and on
cross-segment merges (disjunct fast path, otherwise a full k-merge with
a generated doc-id mapping).

Motivation (downstream join / sortedness work):
* https://github.com/paradedb/paradedb/issues/2997
* https://github.com/paradedb/paradedb/issues/3053

Re-expressed against current upstream APIs: no merge CancelSentinel,
MergeOptimizedInvertedIndexReader, or ignore_store plumbing.
2026-06-18 12:11:53 +02:00
Ming Ying
7b90642011 foundation: restore SegmentComponent::TempStore (revert #2815)
Index sorting streams documents to a temporary store before resorting
them into the final Store. Upstream #2815 ("Remove temp file") removed
SegmentComponent::TempStore as part of the index-sorting removal
cleanup; this restores it (the upstream-original mechanism, including
the include_temp_doc_store garbage-collection tracking).

This also subsumes paradedb/tantivy #104 ("Index sort test in CI"),
whose fix was re-adding TempStore to the component iterator and the
list_files GC handling.
2026-06-18 12:11:52 +02:00
Pascal Seitz
c096b2ad89 aggregation/terms: charge fused term_counts to the memory limit
term_counts (one u32/term) was allocated but not charged to
AggregationLimitsGuard, so a memory limit could be exceeded silently.
Charge it, skip allocating it when unbounded, and add a regression test.
2026-06-16 21:23:23 +08:00
Pascal Seitz
ac7a3d347c add comment, hoist variables 2026-06-16 21:23:23 +08:00
Pascal Seitz
03520a0719 add top level comment 2026-06-16 21:23:23 +08:00
Pascal Seitz
86a4c47bed merge loops, histo with bounds may benefit from single vec opt 2026-06-16 21:23:23 +08:00
Pascal Seitz
fb23e8908f add histogram with bounds 2026-06-16 21:23:23 +08:00
Pascal Seitz
3ca510dff0 aggregation/terms: tidy fused term×histogram grid construction
Rename the value threaded through build_segment_term_collector and
maybe_build_collector from max_term_id to col_max_val/max_column_val — it
is the column's max value, only later reused as the max term id. Make the
grid-size arithmetic overflow-/zero-safe (saturating_add, checked_div).
2026-06-16 21:23:23 +08:00
Pascal Seitz
3cb400c300 clarify counts/term_counts field docs
Spell out that `counts` is the flattened per-term × time-bucket grid (each
term's own contiguous slice) and that `term_counts` is only needed when the
per-term total can't be derived from that grid (i.e. with hard bounds).
2026-06-16 21:23:23 +08:00
Pascal Seitz
ef13489d63 skip hard_bounds that can't exclude any value
When a histogram's hard_bounds are wider than the column's value range, the
per-doc `bounds.contains` check can never fail. Collapse such bounds to the
unbounded sentinel in `normalize_histogram_req`, so both the general histogram
hot loop and the fused term×histogram path skip the check — the latter then
derives per-term counts from the grid (the ~17% win) instead of falling back to
per-doc counting just because `bounds != [MIN, MAX]`.

Only the collect-time filter is affected: empty-bucket emission reads
`req.hard_bounds` directly, and hard_bounds only ever clips that range, so a
wider-than-data bound leaves results unchanged. Covered by new tests on the
general and fused paths, including mid-interval (bucket-splitting) bounds.

Also tighten the fused-path u32-overflow guard to bound on `num_vals()` (the
per-value increment count) rather than `num_docs()`, and document why the fused
collector's hot-loop fields are hoisted into locals (re-reading them from memory
each iteration measured ~15% slower).
2026-06-16 21:23:23 +08:00
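Illustration for ef13489d63: a sketch of collapsing bounds that cannot exclude any value. The tuple representation, the `UNBOUNDED` constant, and `normalize_bounds` are assumptions made for this example; the commit performs the equivalent step inside `normalize_histogram_req`.

```rust
/// Assumed sentinel meaning "no bounds check needed".
const UNBOUNDED: (f64, f64) = (f64::MIN, f64::MAX);

/// If the bounds already cover the column's whole value range, the per-doc
/// `contains` check can never fail, so collapse them to the sentinel.
fn normalize_bounds(bounds: (f64, f64), col_min: f64, col_max: f64) -> (f64, f64) {
    if bounds.0 <= col_min && col_max <= bounds.1 {
        UNBOUNDED
    } else {
        bounds
    }
}

fn main() {
    // Column values lie in [10, 20].
    println!("{:?}", normalize_bounds((0.0, 100.0), 10.0, 20.0)); // wider than the data
    println!("{:?}", normalize_bounds((15.0, 100.0), 10.0, 20.0)); // can exclude values
}
```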
Pascal Seitz
9f7aea4765 derive term counts 2026-06-16 21:23:23 +08:00
Pascal Seitz
2c8536ab11 add specialized TermHistogram 2026-06-16 21:23:23 +08:00
Pascal Seitz
05f4c02ac5 add dense histogram, optional sub-buckets 2026-06-16 21:23:23 +08:00
Pascal Seitz
d137779219 add no sub-gg fastpath 2026-06-16 21:23:23 +08:00
Pascal Seitz
8f9846ac80 use get_range when possible 2026-06-16 21:23:23 +08:00
Pascal Seitz
52e24a9757 add status -> date histogram bench 2026-06-16 21:23:23 +08:00
trinity-1686a
00714326af Merge pull request #2960 from Darkheir/fix/query_grammar_boost_and_escape
fix(query-grammar): Fix issues on boosted and regex queries
2026-06-16 12:03:23 +02:00
Mohammad Dashti
799f7b4646 Built SUM final result in each branch directly.
Keeps the empty-bucket coercion visible at the boundary instead of a
shared binding, following the reviewer's suggested shape.
2026-06-16 03:10:30 +08:00
Mohammad Dashti
fc88d80726 docs: drop downstream-specific name from none_if_no_match doc
The flag's purpose is described well enough by "SQL-style consumers";
no need to call out a specific downstream.
2026-06-16 03:10:30 +08:00
Mohammad Dashti
6a684e7c38 feat: opt-in none_if_no_match flag on SumAggregation for SQL-style null
Switch the default serialized output of `sum` on empty / all-missing
buckets back to `"value": 0` to match Elasticsearch, and gate the
SQL-style `"value": null` behavior behind a new
`none_if_no_match: Option<bool>` flag on `SumAggregation`.

`IntermediateSum::finalize` still returns `Option<f64>` internally so
the Rust API stays parallel to min/max/avg, but the ES-vs-SQL choice is
made at the boundary in `IntermediateMetricResult::into_final_metric_result`:
`None` is coerced to `Some(0.0)` unless `none_if_no_match` is set on the
aggregation request.

Adds `AggregationVariants::as_sum()` accessor for that boundary check
and two end-to-end tests covering both the default ES behavior and the
opt-in null behavior on an empty index.
2026-06-16 03:10:30 +08:00
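Illustration for 6a684e7c38: a minimal sketch of the boundary coercion described above. `finalize_sum` is a hypothetical free function; in the commit the decision lives in `IntermediateMetricResult::into_final_metric_result`.

```rust
/// `intermediate` is what the internal finalize step returns (None when no
/// values were collected); `none_if_no_match` mirrors the request flag.
fn finalize_sum(intermediate: Option<f64>, none_if_no_match: Option<bool>) -> Option<f64> {
    match intermediate {
        Some(sum) => Some(sum),
        // Opt-in SQL-style behavior: keep the "no value" result.
        None if none_if_no_match.unwrap_or(false) => None,
        // Default Elasticsearch-style behavior: report 0.
        None => Some(0.0),
    }
}

fn main() {
    println!("{:?}", finalize_sum(None, None));
    println!("{:?}", finalize_sum(None, Some(true)));
    println!("{:?}", finalize_sum(Some(4.5), Some(true)));
}
```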
Mohammad Dashti
94fe52cc67 docs: clarify SUM finalize returning None diverges from Elasticsearch
Surface the trade-off in the doc comment so future reviewers see why
this differs from ES (which returns "value": 0 for sum over
empty/all-missing buckets) and what consumers (ParadeDB SQL NULL) the
None variant is meant to serve.
2026-06-16 03:10:30 +08:00
Mohammad Dashti
2ff39f6f7f fix: return None from SUM when no values were collected
IntermediateSum::finalize() returned Some(0.0) even when count==0
(all documents had missing/NULL values). This differs from MIN, MAX,
and AVG which all return None for count==0.

The 0.0 came from IntermediateStats' default sum initialization.
Consumers (like ParadeDB) that map None to SQL NULL were incorrectly
getting 0 for SUM on all-NULL groups.

Fixes paradedb/paradedb#4621
2026-06-16 03:10:30 +08:00
Windforce17
1d06328cb3 Add BlockSegmentPostings::rank() for skip-list-based positional counting
Add a public rank(target) method on BlockSegmentPostings that returns the
number of docs with a doc id strictly smaller than target. It jumps to the
candidate block through the skip list and decodes a single block, so the cost
is O(skip-list entries) + one block decode rather than O(doc_freq).

This is a useful primitive for range counting over a posting list (e.g. number
of matches in a [lo, hi) doc-id window) without iterating every matched doc.

To support it, expose SkipReader::remaining_docs() (pub(crate)). Like seek(),
rank() advances the cursor forward only and must be called with non-decreasing,
valid (<= TERMINATED) targets. Adds a unit test covering multi-block lists and
the below-first / above-last / empty edge cases.
2026-06-15 18:56:49 +08:00
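Illustration for 1d06328cb3: a simplified, stateless sketch of rank over a blocked posting list. `BlockedPostings` is a hypothetical type; each block's last doc id plays the role of a skip-list entry. Unlike the real method, this sketch does not keep a forward-only cursor and holds blocks uncompressed.

```rust
/// Hypothetical stand-in for a block-encoded posting list:
/// doc ids are strictly increasing across and within blocks.
struct BlockedPostings {
    blocks: Vec<Vec<u32>>,
}

impl BlockedPostings {
    /// Number of docs with a doc id strictly smaller than `target`.
    fn rank(&self, target: u32) -> usize {
        let mut skipped = 0;
        for block in &self.blocks {
            let Some(&last_doc) = block.last() else { continue };
            if last_doc < target {
                // Whole block is below the target: count it without scanning it.
                skipped += block.len();
            } else {
                // Candidate block: only this one needs to be inspected.
                return skipped + block.partition_point(|&doc| doc < target);
            }
        }
        skipped
    }
}

fn main() {
    let postings = BlockedPostings {
        blocks: vec![vec![1, 3, 5], vec![8, 13, 21], vec![34]],
    };
    // Matches in the doc-id window [3, 14) as a difference of two ranks.
    let in_window = postings.rank(14) - postings.rank(3);
    println!("matches in [3, 14): {}", in_window);
}
```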
Darkheir
7fd1dbe9f5 fix(query-grammar): Fix issues on boosted and regex queries
Signed-off-by: Darkheir <raphael.cohen@sekoia.io>
2026-06-15 10:50:07 +02:00
Pascal Seitz
b19f0ddc77 fix clippy 2026-06-09 23:14:12 +08:00
Pascal Seitz
b4acfcf881 cleanup AggregationsSegmentCtx
The metric/cardinality/histogram _mut getters had no callers needing
mutation; their two uses already pass the resulting reference as &T.

simplify req_data ownership: clone into collectors, Rc only for filter BitSet

Replace Vec<Option<Box<T>>> + take/put-back round-trip with Vec<T> +
direct clone into collector. Collectors now own their per-segment
request data outright, removing the borrow-checker dance that the
take/put-back pattern existed to satisfy.

The structural clones are cheap (Column<u64> is Arc-internal) except
for the filter aggregation, whose DocumentQueryEvaluator carries a
precomputed per-segment BitSet sized by max_doc. Wrap that in
Rc<DocumentQueryEvaluator> so FilterAggReqData::clone() bumps a
refcount instead of duplicating the BitSet. Move SegmentFilterCollector's
matching_docs_buffer out of FilterAggReqData so its pre-allocated
capacity is preserved per collector instead of being lost on every clone.
2026-06-09 23:14:12 +08:00
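Illustration for b4acfcf881: a sketch of why wrapping the large per-segment bitset in `Rc` makes cloning cheap. `EvaluatorSketch` and `FilterReqSketch` are hypothetical stand-ins for `DocumentQueryEvaluator` and `FilterAggReqData`; their fields are assumptions.

```rust
use std::rc::Rc;

/// Stand-in for the evaluator holding a precomputed bitset sized by max_doc.
struct EvaluatorSketch {
    bits: Vec<u64>,
}

/// Stand-in for the per-segment request data that each collector clones.
#[derive(Clone)]
struct FilterReqSketch {
    evaluator: Rc<EvaluatorSketch>,
}

fn new_req(max_doc: usize) -> FilterReqSketch {
    FilterReqSketch {
        evaluator: Rc::new(EvaluatorSketch {
            bits: vec![0; (max_doc + 63) / 64],
        }),
    }
}

fn main() {
    let req = new_req(1_000_000);
    // Cloning bumps a refcount; the bitset itself is not duplicated.
    let per_collector = req.clone();
    println!(
        "shared allocation: {}",
        Rc::ptr_eq(&req.evaluator, &per_collector.evaluator)
    );
    println!("bitset words: {}", per_collector.evaluator.bits.len());
}
```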
dependabot[bot]
3a8240b123 Bump codecov/codecov-action from 6.0.0 to 7.0.0
Bumps [codecov/codecov-action](https://github.com/codecov/codecov-action) from 6.0.0 to 7.0.0.
- [Release notes](https://github.com/codecov/codecov-action/releases)
- [Changelog](https://github.com/codecov/codecov-action/blob/main/CHANGELOG.md)
- [Commits](57e3a136b7...fb8b3582c8)

---
updated-dependencies:
- dependency-name: codecov/codecov-action
  dependency-version: 7.0.0
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-06-09 14:48:17 +08:00
dependabot[bot]
fd9713e1ca Bump actions/checkout from 6.0.2 to 6.0.3 (#2949)
Bumps [actions/checkout](https://github.com/actions/checkout) from 6.0.2 to 6.0.3.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](de0fac2e45...df4cb1c069)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: 6.0.3
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-08 10:55:54 +02:00
dependabot[bot]
96f3784f79 Bump github/codeql-action from 4.35.2 to 4.36.1 (#2948)
Bumps [github/codeql-action](https://github.com/github/codeql-action) from 4.35.2 to 4.36.1.
- [Release notes](https://github.com/github/codeql-action/releases)
- [Changelog](https://github.com/github/codeql-action/blob/main/CHANGELOG.md)
- [Commits](95e58e9a2c...87557b9c84)

---
updated-dependencies:
- dependency-name: github/codeql-action
  dependency-version: 4.36.1
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-08 10:49:04 +02:00
dependabot[bot]
87a6679a79 Bump actions/upload-artifact from 7.0.0 to 7.0.1 (#2917)
Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 7.0.0 to 7.0.1.
- [Release notes](https://github.com/actions/upload-artifact/releases)
- [Commits](bbbca2ddaa...043fb46d1a)

---
updated-dependencies:
- dependency-name: actions/upload-artifact
  dependency-version: 7.0.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-08 10:48:48 +02:00
dependabot[bot]
864a6aa72c Update murmurhash32 requirement from 0.3 to 0.4 (#2894)
Updates the requirements on [murmurhash32](https://github.com/quickwit-inc/murmurhash32) to permit the latest version.
- [Commits](https://github.com/quickwit-inc/murmurhash32/commits)

---
updated-dependencies:
- dependency-name: murmurhash32
  dependency-version: 0.4.0
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-08 10:48:32 +02:00
Paul Masurel
abcf6754a2 CR comments from https://github.com/quickwit-oss/tantivy/pull/2940 (#2952)
Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
2026-06-08 10:47:58 +02:00
Kanishk Sachan
70a8e56ee5 test(postings): add unit tests for TermFrequencyRecorder
Closes #2285

The TermFrequencyRecorder was completely untested. Add five focused tests:

- term_frequency_recorder_has_term_freq: verifies the recorder
  correctly advertises term-frequency support via has_term_freq()
- term_frequency_recorder_zero_docs: term_doc_freq() returns Some(0)
  before any documents are recorded
- term_frequency_recorder_term_doc_freq_single_doc: one document with
  two occurrences yields term_doc_freq() == Some(1)
- term_frequency_recorder_term_doc_freq_multiple_docs: three documents
  with varying term frequencies yield term_doc_freq() == Some(3),
  confirming the count tracks documents, not occurrences
- term_frequency_recorder_single_occurrence_per_doc: each of three
  documents has exactly one occurrence
- term_frequency_recorder_high_frequency_doc: a single document with
  1000 occurrences still yields term_doc_freq() == Some(1)
2026-06-06 14:44:51 +08:00
Paul Masurel
62705526e8 Add sve + neon filter vec implementation as spotted by Adam (#2940)
* Add filter_vec benchmarks (dense, sparse, full coverage)

Uses get_ids_for_value_range to exercise both the bitpacking decode and
the filter_vec SIMD path together under realistic cache conditions.

* Add NEON and SVE implementations for filter_vec

Adds aarch64-specific SIMD paths (NEON always available on aarch64;
SVE gated on nightly + non-Apple target) with routing logic in mod.rs
that selects the best available instruction set at runtime.

* Using asm! to workaround the lack of stabilized SVE intrinsics

* showing instruction set

* improved proptesting

* removing build.rs

---------

Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
2026-06-04 17:51:26 +02:00
Paul Masurel
a27c64998f Cargo clippy fix (#2943)
Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
2026-06-01 14:39:44 +02:00
Paul Masurel
46b3fb9ed3 Relying on upstream version of datasketch and stop using HLL 4. (#2936)
We were relying on a fork for:

- a bugfix in LIST serialization
- a better API exposing a new Coupon type, required for caching coupons.

We also stop using HLL8 in the hope of fixing
https://datadoghq.atlassian.net/browse/CLOUDPREM-625

Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
2026-05-19 13:29:35 +02:00
trinity-1686a
fbe620b9b4 Merge pull request #2933 from quickwit-oss/1686a/sstable-opt
optimise sstable index access pattern
2026-05-19 11:43:17 +02:00
trinity-1686a
95d8a3989a cr 2026-05-19 11:38:48 +02:00
trinity-1686a
ea61a68db4 skip sstable index binary search when ordinal is in same block 2026-05-16 11:35:38 +02:00
trinity-1686a
c367df37c1 refactor sstable index 2026-05-16 11:30:02 +02:00
Mohammad Dashti
d99a5d4e91 Rename validate_aggregation_fields to validate_aggregation_fields_exist
Applies @PSeitz's review suggestion to make the function name more
descriptive of what it checks. Also adds a doc note clarifying why
validation is opt-in rather than enforced by default.
2026-05-16 15:45:20 +08:00
Mohammad Dashti
2de6f075ce Fixed the example 2026-05-16 15:45:20 +08:00
Mohammad Dashti
18080067c7 Applied PR comment:
I would move it outside of the aggregation. You can fetch the fields from the aggregation request and do a validation in a helper function
2026-05-16 15:45:20 +08:00
Mohammad Dashti
95db7d2e5c Revert "Revert all impl."
This reverts commit d5e0991549a05bf80f19f853f7689ad69f96e7e5.
2026-05-16 15:45:20 +08:00
Mohammad Dashti
fc017c4c74 Applied PR comments. 2026-05-16 15:45:20 +08:00
Mohammad Dashti
141c91d028 Added a flag: strict_validation 2026-05-16 15:45:20 +08:00
Mohammad Dashti
36a83e7c1a Fixed agg validation 2026-05-16 15:45:20 +08:00
jinhelin
be11f8a6a1 Fix opening positions file error 2026-05-14 15:55:59 +08:00
dependabot[bot]
4305e4029e Update binggan requirement from 0.16.1 to 0.17.0
Updates the requirements on [binggan](https://github.com/pseitz/binggan) to permit the latest version.
- [Changelog](https://github.com/PSeitz/binggan/blob/main/CHANGELOG.md)
- [Commits](https://github.com/pseitz/binggan/commits)

---
updated-dependencies:
- dependency-name: binggan
  dependency-version: 0.17.0
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-05-12 15:10:20 +08:00
Pascal Seitz
edfb02b47e switch to enum, fix mixed types for cardinality agg 2026-05-05 16:39:51 +08:00
Pascal Seitz
d0fad88bac use bitsets for card agg 2026-05-05 16:39:51 +08:00
Pascal Seitz
351280c0b4 add card bench for high card 2026-05-05 16:39:51 +08:00
James Sewell
4480cf0a98 Enable BMW for single-scorer boolean queries by removing early return in scorer_union (#2915)
The early return for `scorers.len() == 1` in `scorer_union` short-circuits a single TermScorer into `SpecializedScorer::Other`, bypassing the `TermUnion` path that enables block-max WAND (BMW) in `for_each_pruning`.

This was originally addressed in PR #2898 (backed out), which added a special case in `BooleanWeight::for_each_pruning`. PR #2912 (merged as d27ca164a) added a single-scorer fast path inside `block_wand` itself, but did not remove this early return — so a single SHOULD TermScorer still never reaches the BMW path.

Removing the early return lets a single TermScorer with freq reading flow through to `SpecializedScorer::TermUnion`, where `block_wand` → `block_wand_single_scorer` handles it efficiently.
2026-04-28 14:49:53 -07:00
Pascal Seitz
d47abdf104 early cut off for order by sub agg in term agg 2026-04-28 16:59:59 +02:00
Pascal Seitz
c11952eb7c add order by agg benchmark 2026-04-28 16:59:59 +02:00
trinity-1686a
09667ee9c8 Merge pull request #2909 from osyniakov/claude/add-ossf-scorecard-1z6Vn
Add OpenSSF Scorecard workflow
2026-04-28 11:57:36 +02:00
trinity-1686a
333ccf5300 Merge pull request #2896 from osyniakov/claude/fix-issues-5945-5937-eQm1Q
ci: pin GitHub Actions to full commit SHAs and restrict token permissions
2026-04-28 11:57:18 +02:00
Oleksii Syniakov
60a39a4689 Merge branch 'main' into claude/fix-issues-5945-5937-eQm1Q 2026-04-28 10:28:23 +02:00
Oleksii Syniakov
f8f3e4277f remove not needed permissions for the public repo 2026-04-28 10:09:30 +02:00
Oleksii Syniakov
ff1433713a bump upload-sarif -> 4.35.2
Co-authored-by: trinity-1686a <trinity.pointard@gmail.com>
2026-04-28 10:07:45 +02:00
trinity-1686a
ca139d8eb1 Merge pull request #2910 from quickwit-oss/abdul.andha/composite-agg-after
Composite aggregations: send after key on last page
2026-04-27 23:38:52 +02:00
Abdul Andha
ac508108aa address pr comment 2026-04-27 12:39:38 -04:00
Paul Masurel
63da5a21b2 Optimizing top K using Adrien Grand's ideas (#2865)
* Optimizing top K using Adrien Grand's ideas

https://jpountz.github.io/2025/08/28/compiled-vs-vectorized-search-engine-edition.html

* Suffix-sum pruning for multi-term intersection candidates

After scoring each secondary in Phase 2, check whether remaining
secondaries' block_max scores can still beat the threshold. Skip
to the next candidate early if impossible, avoiding expensive seeks
into later secondaries.

Improves three-term intersection by ~8% on the balanced benchmark
while keeping two-term performance neutral.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Claude CR comment

* Removed 16 term scorer limit.

---------

Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-26 12:14:40 +02:00
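Illustration for 63da5a21b2: a sketch of the suffix-sum pruning idea for one intersection candidate. All names and the flat-slice inputs are hypothetical; the real scorer works on posting-list cursors and per-block maxima, which are not modeled here.

```rust
/// suffix[i] = sum of block_maxes[i..]; suffix[len] = 0.
fn suffix_sums(block_maxes: &[f32]) -> Vec<f32> {
    let mut suffix = vec![0.0; block_maxes.len() + 1];
    for i in (0..block_maxes.len()).rev() {
        suffix[i] = suffix[i + 1] + block_maxes[i];
    }
    suffix
}

/// Score a candidate term by term. After each secondary, give up early if
/// even the best case for the remaining secondaries cannot beat `threshold`.
fn score_candidate(
    primary: f32,
    secondary_scores: &[f32],
    block_maxes: &[f32],
    threshold: f32,
) -> Option<f32> {
    let suffix = suffix_sums(block_maxes);
    let mut score = primary;
    for (i, s) in secondary_scores.iter().enumerate() {
        score += s;
        if score + suffix[i + 1] <= threshold {
            return None; // later secondaries are never touched
        }
    }
    (score > threshold).then_some(score)
}

fn main() {
    let secondaries = [0.25, 0.5];
    let block_maxes = [0.5, 0.75];
    println!("{:?}", score_candidate(1.0, &secondaries, &block_maxes, 2.5));
    println!("{:?}", score_candidate(1.0, &secondaries, &block_maxes, 1.5));
}
```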
lif
54cd5bba98 fix: skip sentinel facet ords in harvest to prevent wrong root (#2867)
When a document has the exact registered facet path (not a child),
compute_collapse_mapping_one maps it to a sentinel (u64::MAX, 0).
Without filtering, harvest() passes u64::MAX to ord_to_term which
resolves to the last dictionary entry, producing a spurious facet
from an unrelated branch.

Skip entries where facet_ord == u64::MAX in harvest().

Closes #2494

Signed-off-by: majiayu000 <1835304752@qq.com>
2026-04-25 22:23:30 +02:00
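Illustration for 54cd5bba98: a sketch of the sentinel filter. `ord_to_term` here is a hypothetical stand-in that resolves out-of-range ordinals to the last entry, mimicking the behavior the commit message describes; `harvest` and its inputs are likewise simplified, and the dictionary is assumed non-empty.

```rust
const SENTINEL_ORD: u64 = u64::MAX;

/// Stand-in lookup: an out-of-range ordinal lands on the last entry
/// (an assumption that mirrors the behavior described in the commit).
fn ord_to_term<'a>(dictionary: &[&'a str], ord: u64) -> &'a str {
    let idx = ord.min(dictionary.len() as u64 - 1) as usize;
    dictionary[idx]
}

/// Turn (facet_ord, count) pairs into (facet, count) pairs,
/// optionally skipping the sentinel ordinal.
fn harvest<'a>(
    counts: &[(u64, u64)],
    dictionary: &[&'a str],
    skip_sentinel: bool,
) -> Vec<(&'a str, u64)> {
    counts
        .iter()
        .filter(|(ord, _)| !(skip_sentinel && *ord == SENTINEL_ORD))
        .map(|&(ord, count)| (ord_to_term(dictionary, ord), count))
        .collect()
}

fn main() {
    let dictionary = ["/a/b", "/a/c", "/z/y"];
    let counts = [(0, 3), (SENTINEL_ORD, 1), (2, 5)];
    println!("unfiltered: {:?}", harvest(&counts, &dictionary, false));
    println!("filtered:   {:?}", harvest(&counts, &dictionary, true));
}
```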
Paul Masurel
d27ca164a9 block_wand: use single-scorer path when there is only one scorer 2026-04-25 16:35:00 +02:00
dependabot[bot]
2f5a48e8b1 Update criterion requirement from 0.5 to 0.8 (#2873)
Updates the requirements on [criterion](https://github.com/criterion-rs/criterion.rs) to permit the latest version.
- [Release notes](https://github.com/criterion-rs/criterion.rs/releases)
- [Changelog](https://github.com/criterion-rs/criterion.rs/blob/master/CHANGELOG.md)
- [Commits](https://github.com/criterion-rs/criterion.rs/compare/0.5.0...criterion-v0.8.2)

---
updated-dependencies:
- dependency-name: criterion
  dependency-version: 0.8.2
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-25 14:15:53 +02:00
dependabot[bot]
ae0ab907fe Bump actions/checkout from 4 to 6 (#2875)
Bumps [actions/checkout](https://github.com/actions/checkout) from 4 to 6.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v4...v6)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-25 14:15:27 +02:00
dependabot[bot]
7d62e084e7 Bump codecov/codecov-action from 3 to 6 (#2876)
Bumps [codecov/codecov-action](https://github.com/codecov/codecov-action) from 3 to 6.
- [Release notes](https://github.com/codecov/codecov-action/releases)
- [Changelog](https://github.com/codecov/codecov-action/blob/main/CHANGELOG.md)
- [Commits](https://github.com/codecov/codecov-action/compare/v3...v6)

---
updated-dependencies:
- dependency-name: codecov/codecov-action
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-25 14:14:54 +02:00
James Sewell
322286ee16 Tighten Block-Max in single-scorer (#2897)
In the Block-Max WAND single-scorer path, a block is skipped when
block_max_score() < threshold, whereas the multi-term path uses
block_max_score_upperbound <= threshold.

As both of these are guarded later on with if score > threshold, we can
use the more efficient form (<=) in the single-scorer path.

Single-scorer block skip (<, should be <=): https://github.com/quickwit-oss/tantivy/blob/main/src/query/boolean_query/block_wand.rs#L231
Multi-scorer block skip (already <=): https://github.com/quickwit-oss/tantivy/blob/main/src/query/boolean_query/block_wand.rs#L179
Single-scorer per-doc guard (>): https://github.com/quickwit-oss/tantivy/blob/main/src/query/boolean_query/block_wand.rs#L246
Multi-scorer per-doc guard (>): https://github.com/quickwit-oss/tantivy/blob/main/src/query/boolean_query/block_wand.rs#L206

This will improve performance when there are many identical scores.
2026-04-25 14:13:07 +02:00
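Illustration for 322286ee16: a sketch showing that skipping on `<=` decodes fewer blocks without changing the hits, because the per-doc guard is a strict `>`. `collect_hits` and its block representation are hypothetical, and the threshold is held fixed here, whereas in real Block-Max WAND it rises as hits are collected.

```rust
/// Each block is (block_max_score, per-doc scores).
/// Returns (number of blocks decoded, scores that beat the threshold).
fn collect_hits(
    blocks: &[(f32, Vec<f32>)],
    threshold: f32,
    skip_on_equal: bool,
) -> (usize, Vec<f32>) {
    let mut decoded = 0;
    let mut hits = Vec::new();
    for (block_max, scores) in blocks {
        let skip = if skip_on_equal {
            *block_max <= threshold
        } else {
            *block_max < threshold
        };
        if skip {
            continue;
        }
        decoded += 1;
        // Per-doc guard: only strictly better scores are kept.
        hits.extend(scores.iter().copied().filter(|&s| s > threshold));
    }
    (decoded, hits)
}

fn main() {
    let blocks = vec![
        (1.0, vec![1.0, 1.0]),
        (2.0, vec![2.0, 0.5]),
        (1.0, vec![1.0]),
    ];
    println!("strict <: {:?}", collect_hits(&blocks, 1.0, false));
    println!("with  <=: {:?}", collect_hits(&blocks, 1.0, true));
}
```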
RJ Barman
73ad18fa1e fix: Add space for missing sentinel in allowed bitset when a missing key is provided (#119) (#2907)
## Bug Overview
Under certain conditions, a `terms` aggregation request can cause a
bounds-check panic. Those conditions are:
- The queried field must be a text field
- There must be a segment where the number of distinct terms in its
dictionary for the queried field is divisible by 64 (i.e. where
`count(term_dict.keys) % 64 == 0`)
- That same segment must contain at least one document that does not
contain this field.
- The request must contain a `missing` key that is a string.
- The request must contain an `include` or `exclude` filter.
For example:
```json
{
    "my_bool": {
        "terms": {
            "field": "title",
            "include": "foo",
            "missing": "__NULL__",
        }
    }
}
```
Check out the added tests in `src/aggregation/bucket/term_agg.rs` to see
this in action.

## How the bug happens
### Preparation
While preparing the aggregation nodes:
1) When we've provided a `missing` key, we derive a missing sentinel.
For string keys this is the column's max value (which for string keys is
always the number of terms in this segment) + 1.
2) For string columns only, we optionally prep an "allowed" `BitSet` of
allowed term ids (`build_allowed_term_ids_for_str` in
`src/aggregation/agg_data.rs`).
- If no `include` or `exclude` filter is provided, this just returns
`None`, causing this check to be skipped down the line.
- Otherwise the bitset is initialized to hold exactly the
number of terms in the segment's term dictionary, and the bits are set to
signify which terms are to be included in the results.

### Collection
If we have an "allowed" `BitSet`, we filter documents against it. For
each document, we check if the `BitSet` contains the document's term id.
For documents without the field, this is the missing sentinel we derived
earlier, minus 1 (to account for zero-based indexing): `(num_terms + 1) - 1`.
However, the `BitSet`'s size is only `num_terms`. Normally, this
slips by without a problem, but if `num_terms % 64 == 0`, this will
cause a panic.

### Why `BitSet` panics
`BitSet` is represented under the hood by a boxed slice of `u64`s. When
you go to check a bit using `BitSet::contains`, it must determine which
of those `u64`s the bit is in, and then the position within that `u64`
of the bit.

In cases where the number of terms is not divisible by 64, the `BitSet`
must waste some bits. When we then look up the missing sentinel's bit,
it happens to be one of those wasted bits, whose value `BitSet` happily
returns. For example, if the number of terms was 63:
```rust
let bitset_init_size = 63; // so BitSet's boxed slice has a length of 1, capable of holding 64 bits, term id [0, 62]
let missing_sentinel = 63; // num_terms + 1 - 1;
let byte_pos = missing_sentinel / 64; // 0 - within the valid slice
let bit_pos = missing_sentinel % 64; // 63 - hits the 1 wasted bit
```

But if the number of terms is divisible by 64, then the `BitSet`
is exactly aligned to the `u64` boundary:
```rust
let bitset_init_size = 64; // so BitSet's boxed slice has a length of 1, capable of holding 64 bits, term ids [0, 63]
let missing_sentinel = 64; // num_terms + 1 - 1
let byte_pos = missing_sentinel / 64; // 1 - idx 1 >= slice length 1
let bit_pos = missing_sentinel % 64; // 0
```
We try to access a `u64` outside the bounds of the boxed slice,
so the bounds check fails and panics.

## Fixing it
The fix is simple. If we need to account for the missing sentinel,
initialize the `BitSet` with capacity for one more bit.

## Tests
- Added a bunch of unit tests that hit these conditions. I ensured they
failed without the fix, and that they now pass.
- All unit tests pass with the fix in place

## Other
- The investigation that led to finding this bug began with
https://github.com/paradedb/paradedb/issues/4746.
2026-04-25 14:11:47 +02:00
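Illustration for 73ad18fa1e: a runnable sketch of the out-of-bounds lookup and the one-extra-bit fix. `MiniBitSet` is a hypothetical stand-in for tantivy's `BitSet`; it returns `None` where direct slice indexing would panic, so the demo itself does not crash.

```rust
/// Hypothetical miniature bitset backed by u64 words.
struct MiniBitSet {
    words: Box<[u64]>,
}

impl MiniBitSet {
    /// Allocate just enough words to hold bits `0..max_value`.
    fn with_max_value(max_value: u32) -> Self {
        let num_words = (max_value as usize + 63) / 64;
        MiniBitSet {
            words: vec![0u64; num_words].into_boxed_slice(),
        }
    }

    /// `None` marks the case where `words[bit / 64]` would be out of bounds.
    fn contains_checked(&self, bit: u32) -> Option<bool> {
        let word = self.words.get((bit / 64) as usize)?;
        Some((*word & (1u64 << (bit % 64))) != 0)
    }
}

fn main() {
    // 63 terms: the sentinel (63) lands on a wasted bit of word 0.
    println!("{:?}", MiniBitSet::with_max_value(63).contains_checked(63));
    // 64 terms: the sentinel (64) needs word 1, which does not exist.
    println!("{:?}", MiniBitSet::with_max_value(64).contains_checked(64));
    // Fix: reserve one more bit when a missing sentinel is in play.
    println!("{:?}", MiniBitSet::with_max_value(64 + 1).contains_checked(64));
}
```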
Abdul Andha
4fbae92187 send after key on last page 2026-04-24 15:33:26 -04:00
Cameron
89f0cef807 Fix O(2^n) query parser regression for deeply-nested queries (#2905)
* Fix O(2^n) query parser regression for deeply-nested queries

The top-level `ast()` parser used `alt((boolean_expr, single_leaf))` at
every group level. When the group contained a single leaf with no
trailing operand, `boolean_expr` would parse `occur_leaf` (recursing
into the inner group), fail at `multispace1`, backtrack, and then
`single_leaf` would re-parse `occur_leaf` from scratch. Every nesting
level doubled the work, giving O(2^n) time for queries like
`(((((title:test)))))`.

Parse `occur_leaf` once and peek ahead for a trailing operand instead
of backtracking. This keeps parsing O(n) and also avoids the duplicate
parse for simple single-leaf queries.

Fixes #2498.

Measured on the issue reproducer (release build):

    depth   before     after
       20   0.87 s   <1 us
       25  28.23 s   <1 us
       60  (years)   ~5 us

Non-pathological queries are unaffected or slightly faster:

    query                     before     after
    hello                     650 ns     308 ns
    a AND b AND c            1380 ns    1364 ns
    title:rust AND (...)     3426 ns    3460 ns

All 53 existing grammar tests and 56 query_parser tests pass. Adds a
regression test at depth 60 that would not complete under the old
parser.

* Add ignored benchmark for nested query parsing at depth 20/21

Matches the depths from issue #2498 which reported 0.87 s / 1.72 s
under the regression. With the fix these parse in single-digit
microseconds. Runs via:

  cargo test -p tantivy-query-grammar --release bench_deeply_nested \
      -- --ignored --nocapture

* Propagate Err::Failure and Err::Incomplete from operand parser

`alt((boolean_expr, single_leaf))` only retried on `Err::Error` and
propagated `Err::Failure` and `Err::Incomplete`. The replacement was
catching all three with `Err(_)`, which would silently fall back to
a single leaf if any cut point were ever added to `operand_leaf` or
its descendants. Match specifically on `Err::Error` to preserve the
original `alt` semantics.

* Replace inline bench with binggan bench in benches/

Move the nested-query benchmark out of the query-grammar test module
and into a proper binggan benchmark at benches/query_parser_nested.rs,
registered as a harnessless bench in Cargo.toml. Keeps the correctness
regression test (depth 60) in place.

Run with: cargo bench --bench query_parser_nested

* Fix rustfmt import ordering in query_parser_nested bench
2026-04-24 03:54:00 -04:00
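Illustration for 89f0cef807: a toy parser that counts leaf parses to show the doubling. The grammar (`&` as the only operator, single lowercase letters as terms) and every name below are hypothetical simplifications; this is not the nom-based grammar from the commit. The `backtracking` flag switches between re-parsing the leaf after a failed "boolean expression" attempt and parsing it once, then peeking for an operand.

```rust
struct Parser<'a> {
    input: &'a [u8],
    leaf_calls: u64,
}

impl<'a> Parser<'a> {
    /// leaf := '(' ast ')' | 'a'..='z'. Returns the position after the leaf.
    fn leaf(&mut self, pos: usize, backtracking: bool) -> Option<usize> {
        self.leaf_calls += 1;
        let byte = *self.input.get(pos)?;
        match byte {
            b'(' => {
                let end = self.ast(pos + 1, backtracking)?;
                (self.input.get(end) == Some(&b')')).then_some(end + 1)
            }
            b'a'..=b'z' => Some(pos + 1),
            _ => None,
        }
    }

    /// Zero or more trailing "& leaf" operands.
    fn operands(&mut self, mut end: usize, backtracking: bool) -> Option<usize> {
        while self.input.get(end) == Some(&b'&') {
            end = self.leaf(end + 1, backtracking)?;
        }
        Some(end)
    }

    fn ast(&mut self, pos: usize, backtracking: bool) -> Option<usize> {
        let first = self.leaf(pos, backtracking)?;
        if backtracking && self.input.get(first) != Some(&b'&') {
            // Old shape: the boolean-expression branch fails without an
            // operand, so the single-leaf branch re-parses from scratch.
            return self.leaf(pos, true);
        }
        // New shape: the leaf was parsed once; just peek for operands.
        self.operands(first, backtracking)
    }
}

/// Parse `depth` nested parentheses around one term and count leaf parses.
fn count_leaf_calls(depth: usize, backtracking: bool) -> u64 {
    let input = format!("{}a{}", "(".repeat(depth), ")".repeat(depth));
    let mut parser = Parser { input: input.as_bytes(), leaf_calls: 0 };
    assert_eq!(parser.ast(0, backtracking), Some(input.len()));
    parser.leaf_calls
}

fn main() {
    for depth in [5, 10, 15] {
        println!(
            "depth {depth}: backtracking = {}, parse-once = {}",
            count_leaf_calls(depth, true),
            count_leaf_calls(depth, false)
        );
    }
}
```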
Claude
a5d297c75f Add OpenSSF Scorecard workflow
Runs weekly security analysis and uploads SARIF results to GitHub code
scanning. Third-party actions are pinned by commit SHA. Adds the Scorecard
badge to the README.

Based on quickwit-oss/quickwit#5969.
2026-04-24 06:56:58 +00:00
Pascal Seitz
2e16243f9a fix memory consumption for histogram 2026-04-21 13:58:39 +02:00
Pascal Seitz
e015abab8e docs: add 0.26.1 changelog entry for aggregation perf fix 2026-04-21 11:12:37 +02:00
Pascal Seitz
73c711ec74 perf(agg): only measure active parent bucket in composite collect
Same change as 26a589e for SegmentCompositeCollector: get_memory_consumption
summed across all parent_buckets on every block, scaling with outer bucket
cardinality. Pass parent_bucket_id and index the single bucket.
2026-04-21 07:26:58 +02:00
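
A rough sketch of why the accounting change matters, assuming hypothetical `ParentBucket` and `Collector` types (the real `SegmentCompositeCollector` holds more state). The `buckets_visited` counter exists only to make the cost visible.

```rust
/// Hypothetical per-parent-bucket state.
struct ParentBucket {
    entries: Vec<u64>,
}

impl ParentBucket {
    fn memory_consumption(&self) -> usize {
        self.entries.len() * std::mem::size_of::<u64>()
    }
}

struct Collector {
    parent_buckets: Vec<ParentBucket>,
    /// How many buckets were inspected for memory accounting so far.
    buckets_visited: usize,
}

impl Collector {
    /// Old shape: every collected block walks all parent buckets,
    /// so the cost grows with the outer bucket cardinality.
    fn memory_all_buckets(&mut self) -> usize {
        self.buckets_visited += self.parent_buckets.len();
        self.parent_buckets
            .iter()
            .map(ParentBucket::memory_consumption)
            .sum()
    }

    /// New shape: only the bucket that received the block is measured.
    fn memory_active_bucket(&mut self, parent_bucket_id: usize) -> usize {
        self.buckets_visited += 1;
        self.parent_buckets[parent_bucket_id].memory_consumption()
    }
}
```

With 100 parent buckets and one block per bucket, the first method inspects 10,000 buckets in total and the second inspects 100.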
Pascal Seitz
cb037c8079 add inline 2026-04-21 07:26:58 +02:00
Pascal Seitz
ed3453606b agg fix: compute memory consumption only for current bucket 2026-04-21 07:26:58 +02:00
Pascal Seitz
e9641f99c5 add nested term benchmark 2026-04-21 07:26:58 +02:00
Paul Masurel
13d74c3c20 Update binggan requirement from 0.16.0 to 0.16.1 (#2899) 2026-04-20 11:59:47 +02:00
Claude
3a6a3de8d7 ci: update pinned Action SHAs to current latest versions
The previous commit pinned actions to commit SHAs but used stale
version tags (v4.2.2, v2.7.5, old nextest/cargo-llvm-cov refs).
Update to the actual current HEAD of each pinned tag:

  actions/checkout        v4.2.2 → v4.3.1  (34e114876b0b...)
  Swatinem/rust-cache     v2.7.5 → v2.9.1  (c19371144df3...)
  taiki-e/install-action  nextest           (56cc9adf3a3e...)
  taiki-e/install-action  cargo-llvm-cov    (e4b3a0453201...)

actions-rs/toolchain, actions-rs/clippy-check, and
codecov/codecov-action SHAs were already correct.

https://claude.ai/code/session_01VD7Bo8upj3cQwWDf9ni2Ln
2026-04-16 06:49:47 +00:00
Claude
af3c6c0070 ci: pin GitHub Actions to full commit SHAs and restrict token permissions
Fixes two supply chain / token security issues:

- Pin all third-party Actions to immutable full commit SHAs instead of
  mutable version tags (addresses unpinned-dependencies risk, analogous
  to quickwit-oss/quickwit#5937):
    actions/checkout v4.2.2
    actions-rs/toolchain v1.0.7
    Swatinem/rust-cache v2.7.5
    taiki-e/install-action nextest / cargo-llvm-cov
    actions-rs/clippy-check v1.0.7
    codecov/codecov-action v3.1.6

- Add explicit least-privilege `permissions` blocks at workflow and job
  level (addresses excessive GITHUB_TOKEN permissions, analogous to
  quickwit-oss/quickwit#5945):
    default: contents: read
    check job: also grants checks: write (required by clippy-check)

https://claude.ai/code/session_01VD7Bo8upj3cQwWDf9ni2Ln
2026-04-15 20:55:43 +00:00
dependabot[bot]
058afff8b7 Update binggan requirement from 0.15.3 to 0.16.0
Updates the requirements on [binggan](https://github.com/pseitz/binggan) to permit the latest version.
- [Changelog](https://github.com/PSeitz/binggan/blob/main/CHANGELOG.md)
- [Commits](https://github.com/pseitz/binggan/commits)

---
updated-dependencies:
- dependency-name: binggan
  dependency-version: 0.16.0
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-15 08:58:03 +02:00
Paul Masurel
58aa4b7074 Fix cardinality aggregation using invalid coupons (#2893)
Previously, coupons were computed via murmurhash32 and fed as raw u32
to the HLL sketch, bypassing the sketch's internal hashing and producing
invalid (slot, value) pairs. Switch to Coupon::from_hash from the
datasketches crate which correctly derives coupons, and drop the
now-unused murmurhash32 dependency.

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-13 19:14:30 +02:00
Paul Masurel
04beab3b29 Performance improvement for nested cardinality aggregation
When a string cardinality aggregation is nested, it ends up being applied to many different buckets.
Dictionary encoding relies on a different dictionary for each segment.

As a result, during segment collection, we only collect term ordinals in a HashSet, and decode them through the
term dictionary at the end of collection.

Before this PR, this decoding phase was done once for each bucket, causing the same work to be done over and over. This PR introduces a coupon cache. The HLL sketch relies on a hash of the string values.

We populate the cache before bucket collection, and get our values from it.

This PR also renames "caching" to "buffering" in aggregation (it was never caching), and does several cleanups.
2026-04-10 14:51:00 +02:00
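
The caching idea can be sketched as follows. Everything here is an assumption made for illustration: `TermDict` is a hypothetical stand-in for a segment's term dictionary, and the hash comes from the standard library's `DefaultHasher` rather than the coupon derivation used by the HLL sketch.

```rust
use std::cell::Cell;
use std::collections::hash_map::DefaultHasher;
use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};

/// Hypothetical stand-in for a per-segment term dictionary.
struct TermDict {
    terms: Vec<String>,
    /// Number of ordinal -> term decodes performed.
    lookups: Cell<usize>,
}

impl TermDict {
    /// Decode one ordinal and hash the resulting term.
    fn hash_of_ord(&self, ord: u32) -> u64 {
        self.lookups.set(self.lookups.get() + 1);
        let mut hasher = DefaultHasher::new();
        self.terms[ord as usize].hash(&mut hasher);
        hasher.finish()
    }
}

/// Decode every distinct ordinal once, before visiting the buckets.
fn build_cache(dict: &TermDict, buckets: &[HashSet<u32>]) -> HashMap<u32, u64> {
    let mut cache = HashMap::new();
    for ord in buckets.iter().flatten() {
        cache.entry(*ord).or_insert_with(|| dict.hash_of_ord(*ord));
    }
    cache
}

/// Per-bucket work is then a plain cache read per ordinal.
fn bucket_hashes(cache: &HashMap<u32, u64>, bucket: &HashSet<u32>) -> Vec<u64> {
    let mut hashes: Vec<u64> = bucket.iter().map(|ord| cache[ord]).collect();
    hashes.sort_unstable();
    hashes
}
```

The number of dictionary decodes now depends on the number of distinct ordinals in the segment, not on how many buckets contain them.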
alexanderbianchi
3cd9011f87 Make BucketEntries::iter, PercentileValuesVecEntry fields, and TopNComputer::threshold public (#2890)
These items need to be accessible from the tantivy-datafusion crate:
- BucketEntries::iter() for iterating aggregation bucket results
- PercentileValuesVecEntry.key/.value for reading percentile results
- TopNComputer.threshold for Block-WAND score pruning in the inverted index provider

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Paul Masurel <paul@quickwit.io>
2026-04-09 13:32:31 +02:00
Paul Masurel
d2c1b8bc2c Optimized intersection count using a bitset when the first leg is dense 2026-04-06 12:01:52 -04:00
nuri
a65107135a Use BinaryHeap for score-based top-K collection (#2881)
* Use BinaryHeap for score-based top-K collection

* Use peek_mut and add proptest for TopNHeap

---------

Co-authored-by: nryoo <nryoo@nryooui-MacBookPro.local>
2026-04-04 19:49:05 +02:00
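
For readers unfamiliar with the technique, here is a minimal sketch of heap-based top-K selection with `peek_mut`. It is not the `TopNHeap` from the commit: the `(score, doc)` pairs use integer scores so they are `Ord`, which is an assumption of this sketch, and ties are broken by the tuple ordering.

```rust
use std::cmp::Reverse;
use std::collections::BinaryHeap;

/// Keep the `k` largest `(score, doc)` pairs using a min-heap of size `k`.
fn top_k(hits: impl IntoIterator<Item = (u32, u32)>, k: usize) -> Vec<(u32, u32)> {
    // `Reverse` turns the max-heap into a min-heap, so the root is the
    // weakest hit currently kept.
    let mut heap: BinaryHeap<Reverse<(u32, u32)>> = BinaryHeap::with_capacity(k);
    for hit in hits {
        if heap.len() < k {
            heap.push(Reverse(hit));
        } else if let Some(mut weakest) = heap.peek_mut() {
            // Overwrite the root in place; the heap re-sifts when
            // `weakest` is dropped, which avoids a pop followed by a push.
            if hit > weakest.0 {
                *weakest = Reverse(hit);
            }
        }
    }
    let mut out: Vec<(u32, u32)> = heap.into_iter().map(|Reverse(hit)| hit).collect();
    // Best hit first.
    out.sort_unstable_by(|a, b| b.cmp(a));
    out
}
```

Hits that do not beat the current root cost a single comparison, which is the common case once the heap is full.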
Pascal Seitz
5c344db1bf chore: Release 2026-03-31 17:15:34 +08:00
Pascal Seitz
dc0f31554d unbump for release and update Changelog.md 2026-03-31 17:15:34 +08:00
trinity-1686a
a28ce3ee54 Merge pull request #2869 from quickwit-oss/trinity.pointard/maint
add dependabot cooldown
2026-03-31 09:52:22 +02:00
dependabot[bot]
3abc137bfe Update binggan requirement from 0.14.2 to 0.15.3 (#2870)
Updates the requirements on [binggan](https://github.com/pseitz/binggan) to permit the latest version.
- [Commits](https://github.com/pseitz/binggan/commits)

---
updated-dependencies:
- dependency-name: binggan
  dependency-version: 0.15.3
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-31 07:59:02 +08:00
trinity Pointard
cf9800f981 add dependabot cooldown 2026-03-30 11:36:04 +02:00
114 changed files with 9816 additions and 1639 deletions

@@ -6,6 +6,8 @@ updates:
interval: daily
time: "20:00"
open-pull-requests-limit: 10
cooldown:
default-days: 2
- package-ecosystem: "github-actions"
directory: "/"
@@ -13,3 +15,5 @@ updates:
interval: daily
time: "20:00"
open-pull-requests-limit: 10
cooldown:
default-days: 2

@@ -4,6 +4,9 @@ on:
push:
branches: [main]
permissions:
contents: read
# Ensures that we cancel running jobs for the same PR / same workflow.
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@@ -12,16 +15,20 @@ concurrency:
jobs:
coverage:
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- name: Install Rust
run: rustup toolchain install nightly-2025-12-01 --profile minimal --component llvm-tools-preview
- uses: Swatinem/rust-cache@v2
- uses: taiki-e/install-action@cargo-llvm-cov
- uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1
- uses: taiki-e/install-action@e4b3a0453201addddc06d3a72db90326aad87084 # cargo-llvm-cov
- name: Generate code coverage
run: cargo +nightly-2025-12-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f # v7.0.0
continue-on-error: true
with:
token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos

@@ -8,6 +8,9 @@ env:
CARGO_TERM_COLOR: always
NUM_FUNCTIONAL_TEST_ITERATIONS: 20000
permissions:
contents: read
# Ensures that we cancel running jobs for the same PR / same workflow.
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@@ -18,10 +21,13 @@ jobs:
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- name: Install stable
uses: actions-rs/toolchain@v1
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af # v1.0.7
with:
toolchain: stable
profile: minimal

.github/workflows/scorecard.yml

@@ -0,0 +1,49 @@
name: OpenSSF Scorecard
on:
schedule:
- cron: '0 0 * * 0'
push:
branches:
- main
permissions:
contents: read
jobs:
analysis:
name: Scorecards analysis
runs-on: ubuntu-latest
permissions:
# Needed to upload the results to code-scanning dashboard.
security-events: write
# Needed to publish results
id-token: write
steps:
- name: 'Checkout code'
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
with:
persist-credentials: false
- name: 'Run analysis'
uses: ossf/scorecard-action@4eaacf0543bb3f2c246792bd56e8cdeffafb205a # v2.4.3
with:
results_file: results.sarif
results_format: sarif
repo_token: ${{ secrets.GITHUB_TOKEN }}
publish_results: true
# Upload the results as artifacts.
- name: 'Upload artifact'
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: SARIF file
path: results.sarif
retention-days: 5
# Upload the results to GitHub's code scanning dashboard.
- name: 'Upload to code-scanning'
uses: github/codeql-action/upload-sarif@87557b9c84dde89fdd9b10e88954ac2f4248e463 # v4.36.1
with:
sarif_file: results.sarif

@@ -9,6 +9,9 @@ on:
env:
CARGO_TERM_COLOR: always
permissions:
contents: read
# Ensures that we cancel running jobs for the same PR / same workflow.
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@@ -19,23 +22,27 @@ jobs:
runs-on: ubuntu-latest
permissions:
contents: read
checks: write
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- name: Install nightly
uses: actions-rs/toolchain@v1
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af # v1.0.7
with:
toolchain: nightly
profile: minimal
components: rustfmt
- name: Install stable
uses: actions-rs/toolchain@v1
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af # v1.0.7
with:
toolchain: stable
profile: minimal
components: clippy
- uses: Swatinem/rust-cache@v2
- uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1
- name: Check Formatting
run: cargo +nightly fmt --all -- --check
@@ -47,7 +54,7 @@ jobs:
- name: Check Bench Compilation
run: cargo +nightly bench --no-run --profile=dev --all-features
- uses: actions-rs/clippy-check@v1
- uses: actions-rs/clippy-check@b5b5f21f4797c02da247df37026fcd0a5024aa4d # v1.0.7
with:
toolchain: stable
token: ${{ secrets.GITHUB_TOKEN }}
@@ -57,6 +64,9 @@ jobs:
runs-on: ubuntu-latest
permissions:
contents: read
strategy:
matrix:
features:
@@ -67,17 +77,17 @@ jobs:
name: test-${{ matrix.features.label}}
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- name: Install stable
uses: actions-rs/toolchain@v1
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af # v1.0.7
with:
toolchain: stable
profile: minimal
override: true
- uses: taiki-e/install-action@nextest
- uses: Swatinem/rust-cache@v2
- uses: taiki-e/install-action@56cc9adf3a3e2c23eafb56e8acaf9d0373cb845a # nextest
- uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1
- name: Run tests
run: |

@@ -1,3 +1,9 @@
Tantivy 0.26.1
================================
## Performance
- Fix quadratic runtime in nested term and composite aggregations: memory accounting scanned all parent buckets on every collect instead of just the current parent (@PSeitz @fulmicoton)
Tantivy 0.26 (Unreleased)
================================
@@ -45,6 +51,7 @@ Tantivy 0.26 (Unreleased)
- Add `seek_danger` on `DocSet` for more efficient intersections [#2538](https://github.com/quickwit-oss/tantivy/pull/2538) [#2810](https://github.com/quickwit-oss/tantivy/pull/2810)(@PSeitz @stuhood @fulmicoton)
- Skip column traversal in `RangeDocSet` when query range does not overlap with column bounds [#2783](https://github.com/quickwit-oss/tantivy/pull/2783)(@ChangRui-Ryan)
- Speed up exclude queries by supporting multiple excluded `DocSet`s without intermediate union [#2825](https://github.com/quickwit-oss/tantivy/pull/2825)(@PSeitz)
- Improve union performance for non-score unions with `fill_buffer` and optimized `TinySet` [#2863](https://github.com/quickwit-oss/tantivy/pull/2863)(@PSeitz)
Tantivy 0.25
================================

@@ -57,15 +57,15 @@ measure_time = "0.9.0"
arc-swap = "1.5.0"
bon = "3.3.1"
columnar = { version = "0.6", path = "./columnar", package = "tantivy-columnar" }
sstable = { version = "0.6", path = "./sstable", package = "tantivy-sstable", optional = true }
stacker = { version = "0.6", path = "./stacker", package = "tantivy-stacker" }
query-grammar = { version = "0.25.0", path = "./query-grammar", package = "tantivy-query-grammar" }
tantivy-bitpacker = { version = "0.9", path = "./bitpacker" }
common = { version = "0.10", path = "./common/", package = "tantivy-common" }
tokenizer-api = { version = "0.6", path = "./tokenizer-api", package = "tantivy-tokenizer-api" }
columnar = { version = "0.7", path = "./columnar", package = "tantivy-columnar" }
sstable = { version = "0.7", path = "./sstable", package = "tantivy-sstable", optional = true }
stacker = { version = "0.7", path = "./stacker", package = "tantivy-stacker" }
query-grammar = { version = "0.26.0", path = "./query-grammar", package = "tantivy-query-grammar" }
tantivy-bitpacker = { version = "0.10", path = "./bitpacker" }
common = { version = "0.11", path = "./common/", package = "tantivy-common" }
tokenizer-api = { version = "0.7", path = "./tokenizer-api", package = "tantivy-tokenizer-api" }
sketches-ddsketch = { version = "0.4", features = ["use_serde"] }
datasketches = "0.2.0"
datasketches = { version = "0.3.0", features = ["hll"] }
futures-util = { version = "0.3.28", optional = true }
futures-channel = { version = "0.3.28", optional = true }
fnv = "1.0.7"
@@ -75,7 +75,7 @@ typetag = "0.2.21"
winapi = "0.3.9"
[dev-dependencies]
binggan = "0.14.2"
binggan = "0.17.0"
rand = "0.9"
maplit = "1.0.2"
matches = "0.1.9"
@@ -92,7 +92,7 @@ postcard = { version = "1.0.4", features = [
], default-features = false }
[target.'cfg(not(windows))'.dev-dependencies]
criterion = { version = "0.5", default-features = false }
criterion = { version = "0.8", default-features = false }
[dev-dependencies.fail]
version = "0.5.0"
@@ -201,3 +201,11 @@ harness = false
[[bench]]
name = "regex_all_terms"
harness = false
[[bench]]
name = "query_parser_nested"
harness = false
[[bench]]
name = "intersection_bench"
harness = false

@@ -1,6 +1,7 @@
[![Docs](https://docs.rs/tantivy/badge.svg)](https://docs.rs/crate/tantivy/)
[![Build Status](https://github.com/quickwit-oss/tantivy/actions/workflows/test.yml/badge.svg)](https://github.com/quickwit-oss/tantivy/actions/workflows/test.yml)
[![codecov](https://codecov.io/gh/quickwit-oss/tantivy/branch/main/graph/badge.svg)](https://codecov.io/gh/quickwit-oss/tantivy)
[![OpenSSF Scorecard](https://api.scorecard.dev/projects/github.com/quickwit-oss/tantivy/badge)](https://scorecard.dev/viewer/?uri=github.com/quickwit-oss/tantivy)
[![Join the chat at https://discord.gg/MT27AG5EVE](https://shields.io/discord/908281611840282624?label=chat%20on%20discord)](https://discord.gg/MT27AG5EVE)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![Crates.io](https://img.shields.io/crates/v/tantivy.svg)](https://crates.io/crates/tantivy)

@@ -63,7 +63,12 @@ fn bench_agg(mut group: InputGroup<Index>) {
register!(group, terms_all_unique_with_avg_sub_agg);
register!(group, terms_many_with_avg_sub_agg);
register!(group, terms_status_with_avg_sub_agg);
register!(group, terms_status_with_terms_zipf_1000_sub_agg);
register!(group, terms_zipf_1000_with_terms_status_sub_agg);
register!(group, terms_status_with_histogram);
register!(group, terms_status_with_date_histogram);
register!(group, terms_status_with_date_histogram_hard_bounds);
register!(group, terms_status_with_date_histogram_and_sibling_terms);
register!(group, terms_zipf_1000);
register!(group, terms_zipf_1000_with_histogram);
register!(group, terms_zipf_1000_with_avg_sub_agg);
@@ -77,7 +82,12 @@ fn bench_agg(mut group: InputGroup<Index>) {
register!(group, composite_histogram_calendar);
register!(group, cardinality_agg);
register!(group, cardinality_agg_high_card);
register!(group, cardinality_agg_low_card);
register!(group, terms_status_with_cardinality_agg);
register!(group, terms_100_buckets_with_cardinality_agg);
register!(group, terms_many_with_single_term_order_by_card);
register!(group, terms_many_with_single_term_2_order_by_card);
register!(group, range_agg);
register!(group, range_agg_with_avg_sub_agg);
@@ -165,10 +175,52 @@ fn cardinality_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
// Full-scan cardinality on a near-1M-cardinality string field.
// Hits the dense (PagedBitset) path: every doc has a unique term,
// so the bucket is promoted from FxHashSet early in the scan.
fn cardinality_agg_high_card(index: &Index) {
let agg_req = json!({
"cardinality": {
"cardinality": {
"field": "text_all_unique_terms"
},
}
});
execute_agg(index, agg_req);
}
// Full-scan cardinality on a tiny-cardinality string field (7 distinct
// values). Stays on the FxHashSet path — the promotion threshold is
// never crossed. Validates no regression on the sparse path.
fn cardinality_agg_low_card(index: &Index) {
let agg_req = json!({
"cardinality": {
"cardinality": {
"field": "text_few_terms_status"
},
}
});
execute_agg(index, agg_req);
}
fn terms_status_with_cardinality_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"aggs": {
"cardinality": {
"cardinality": {
"field": "text_few_terms_status"
},
}
}
},
});
execute_agg(index, agg_req);
}
fn terms_100_buckets_with_cardinality_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_1000_terms_zipf", "size": 100 },
"aggs": {
"cardinality": {
"cardinality": {
@@ -181,6 +233,58 @@ fn terms_status_with_cardinality_agg(index: &Index) {
execute_agg(index, agg_req);
}
fn terms_many_with_single_term_order_by_card(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_many_terms" },
"aggs": {
"nested_terms": {
"terms": {
"field": "single_term",
"order": { "cardinality": "desc" }
},
"aggs": {
"cardinality": {
"cardinality": { "field": "text_few_terms" }
}
}
}
}
},
});
execute_agg(index, agg_req);
}
// Two-level terms ordered by cardinality at each level: a high-card outer terms
// (text_many_terms) ordered by a cardinality sub-agg, with a nested single-value terms
// (single_term) also ordered by a cardinality sub-agg, plus an avg.
fn terms_many_with_single_term_2_order_by_card(index: &Index) {
let agg_req = json!({
"by_ip": {
"terms": {
"field": "text_many_terms",
"order": { "card_few_terms": "desc" }
},
"aggs": {
"card_few_terms": {
"cardinality": { "field": "text_few_terms" }
},
"nested_terms": {
"terms": {
"field": "single_term",
"order": { "distinct_path2": "desc" }
},
"aggs": {
"avg_botscore": { "avg": { "field": "score" } },
"distinct_path2": { "cardinality": { "field": "text_few_terms" } }
}
}
}
}
});
execute_agg(index, agg_req);
}
fn terms_7(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_few_terms_status" } },
@@ -253,6 +357,30 @@ fn terms_all_unique_with_avg_sub_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn terms_status_with_terms_zipf_1000_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"aggs": {
"nested_terms": { "terms": { "field": "text_1000_terms_zipf" } }
}
}
});
execute_agg(index, agg_req);
}
fn terms_zipf_1000_with_terms_status_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_1000_terms_zipf" },
"aggs": {
"nested_terms": { "terms": { "field": "text_few_terms_status" } }
}
}
});
execute_agg(index, agg_req);
}
fn terms_status_with_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
@@ -265,6 +393,57 @@ fn terms_status_with_histogram(index: &Index) {
execute_agg(index, agg_req);
}
fn terms_status_with_date_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"aggs": {
"over_time": { "date_histogram": { "field": "timestamp", "fixed_interval": "1h" } }
}
}
});
execute_agg(index, agg_req);
}
/// Same fused terms × date_histogram, but with `hard_bounds`. The timestamps span 0..120h; the
/// bounds drop only the first and last hour (ms: 1h=3_600_000, 119h=428_400_000), so almost every
/// doc is in-bounds. This exercises the collector's hard-bounds path: `bounds.contains` runs per
/// doc (the `all_docs_in_bounds` short-circuit is off) and the rare out-of-bounds doc takes the
/// `term_counts` branch.
fn terms_status_with_date_histogram_hard_bounds(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"aggs": {
"over_time": {
"date_histogram": {
"field": "timestamp",
"fixed_interval": "1h",
"hard_bounds": { "min": 3_600_000, "max": 428_400_000 }
}
}
}
}
});
execute_agg(index, agg_req);
}
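
The millisecond values in the `hard_bounds` above follow from the 120-hour span. A small sketch of the arithmetic, with hypothetical names (`HOUR_MS`, `SPAN_HOURS`, `hard_bounds_ms`, `in_bounds`) and assuming the bounds are inclusive on both ends:

```rust
const HOUR_MS: i64 = 3_600 * 1_000; // one hour in milliseconds
const SPAN_HOURS: i64 = 120; // timestamps cover 0..120h

/// Bounds that drop the first and the last hour of the span.
fn hard_bounds_ms() -> (i64, i64) {
    (HOUR_MS, (SPAN_HOURS - 1) * HOUR_MS)
}

/// Inclusive range check (inclusivity is an assumption of this sketch).
fn in_bounds(ts_ms: i64) -> bool {
    let (min, max) = hard_bounds_ms();
    (min..=max).contains(&ts_ms)
}
```

Only timestamps in the first or last hour fall outside, so roughly 118 of the 120 hours of data stay in bounds.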
/// Same fused terms × date_histogram, but with a sibling terms aggregation next to it. The fused
/// fast path should still trigger for `my_texts` (sibling aggregations are independent top-level
/// aggregations, so they don't change its eligibility).
fn terms_status_with_date_histogram_and_sibling_terms(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"aggs": {
"over_time": { "date_histogram": { "field": "timestamp", "fixed_interval": "1h" } }
}
},
"other_texts": { "terms": { "field": "text_few_terms" } }
});
execute_agg(index, agg_req);
}
fn terms_zipf_1000_with_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
@@ -566,7 +745,8 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
TextFieldIndexing::default().set_index_option(IndexRecordOption::WithFreqs),
)
.set_stored();
let text_field = schema_builder.add_text_field("text", text_fieldtype);
let text_field = schema_builder.add_text_field("text", text_fieldtype.clone());
let single_term = schema_builder.add_text_field("single_term", FAST);
let json_field = schema_builder.add_json_field("json", FAST);
let text_field_all_unique_terms =
schema_builder.add_text_field("text_all_unique_terms", STRING | FAST);
@@ -630,6 +810,8 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
index_writer.add_document(doc!(
json_field => json!({"mixed_type": 10.0}),
json_field => json!({"mixed_type": 10.0}),
single_term => "single_term",
single_term => "single_term",
text_field => "cool",
text_field => "cool",
text_field_all_unique_terms => "cool",
@@ -655,7 +837,9 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
doc_with_value /= 20;
}
let _val_max = 1_000_000.0;
for _ in 0..doc_with_value {
const SPAN_MS: i64 = 120 * 3600 * 1000; // 120 hours in ms
const NOISE_MS: i64 = 2 * 3600 * 1000; // ±2h noise
for i in 0..doc_with_value {
let val: f64 = rng.random_range(0.0..1_000_000.0);
let json = if rng.random_bool(0.1) {
// 10% are numeric values
@@ -663,7 +847,11 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
} else {
json!({"mixed_type": many_terms_data.choose(&mut rng).unwrap().to_string()})
};
let base_ms = (i as i64 * SPAN_MS) / doc_with_value as i64;
let noise_ms = rng.random_range(-NOISE_MS..NOISE_MS);
let ts_ms = (base_ms + noise_ms).clamp(0, SPAN_MS);
index_writer.add_document(doc!(
single_term => "single_term",
text_field => "cool",
json_field => json,
text_field_all_unique_terms => format!("unique_term_{}", rng.random::<u64>()),
@@ -674,7 +862,7 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
score_field => val as u64,
score_field_f64 => lg_norm.sample(&mut rng),
score_field_i64 => val as i64,
date_field => DateTime::from_timestamp_millis((val * 1_000_000.) as i64),
date_field => DateTime::from_timestamp_millis(ts_ms),
))?;
if cardinality == Cardinality::OptionalSparse {
for _ in 0..20 {

@@ -0,0 +1,149 @@
// Benchmarks top-K intersection of term scorers (block_wand_intersection).
//
// What's measured:
// - Conjunctive queries (+a +b, +a +b +c) with top-10 by score
// - Varying doc-frequency balance between terms (balanced, skewed, very skewed)
// - Realistic term frequencies (geometric distribution, mostly low)
// - 1M-doc single segment
//
// Run with: cargo bench --bench intersection_bench
use binggan::{black_box, BenchRunner};
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::collector::TopDocs;
use tantivy::query::QueryParser;
use tantivy::schema::{Schema, TEXT};
use tantivy::{doc, Index, ReloadPolicy, Searcher};
const NUM_DOCS: usize = 1_000_000;
struct BenchIndex {
searcher: Searcher,
query_parser: QueryParser,
}
/// Generate term frequency from a geometric-like distribution.
/// Most values are 1, a few are 2-3, rarely higher.
/// p controls the decay: higher p → more weight on tf=1.
fn random_term_freq(rng: &mut StdRng, p: f64) -> u32 {
let mut tf = 1u32;
while tf < 10 && rng.random_bool(1.0 - p) {
tf += 1;
}
tf
}
/// Build an index with three terms (a, b, c) with given doc-frequency probabilities.
/// Each term occurrence has a realistic term frequency (geometric distribution).
/// Field length is padded with filler tokens to create varied fieldnorms.
fn build_index(p_a: f64, p_b: f64, p_c: f64) -> BenchIndex {
let mut schema_builder = Schema::builder();
let body = schema_builder.add_text_field("body", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut rng = StdRng::from_seed([42u8; 32]);
{
let mut writer = index.writer_with_num_threads(1, 500_000_000).unwrap();
for _ in 0..NUM_DOCS {
let mut tokens: Vec<String> = Vec::new();
if rng.random_bool(p_a) {
let tf = random_term_freq(&mut rng, 0.7);
for _ in 0..tf {
tokens.push("aaa".to_string());
}
}
if rng.random_bool(p_b) {
let tf = random_term_freq(&mut rng, 0.7);
for _ in 0..tf {
tokens.push("bbb".to_string());
}
}
if rng.random_bool(p_c) {
let tf = random_term_freq(&mut rng, 0.7);
for _ in 0..tf {
tokens.push("ccc".to_string());
}
}
// Pad with filler to create varied field lengths (5 to 29 filler tokens).
let filler_count = rng.random_range(5u32..30u32);
for _ in 0..filler_count {
tokens.push("filler".to_string());
}
let text = tokens.join(" ");
writer.add_document(doc!(body => text)).unwrap();
}
writer.commit().unwrap();
}
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::Manual)
.try_into()
.unwrap();
let searcher = reader.searcher();
let query_parser = QueryParser::for_index(&index, vec![body]);
BenchIndex {
searcher,
query_parser,
}
}
fn main() {
// Scenarios: (label, p_a, p_b, p_c)
//
// "balanced": all terms ~10% → intersection ~1% of docs
// "skewed": one common (50%), one rare (2%) → intersection ~1%
// "very_skewed": one very common (80%), one very rare (0.5%) → intersection ~0.4%
// "three_balanced": three terms ~20% each → intersection ~0.8%
// "three_skewed": 50% / 10% / 2% → intersection ~0.1%
let scenarios: Vec<(&str, f64, f64, f64)> = vec![
("balanced_10%_10%", 0.10, 0.10, 0.0),
("skewed_50%_2%", 0.50, 0.02, 0.0),
("very_skewed_80%_0.5%", 0.80, 0.005, 0.0),
("three_balanced_20%_20%_20%", 0.20, 0.20, 0.20),
("three_skewed_50%_10%_2%", 0.50, 0.10, 0.02),
];
let mut runner = BenchRunner::new();
for (label, p_a, p_b, p_c) in &scenarios {
let bench_index = build_index(*p_a, *p_b, *p_c);
let mut group = runner.new_group();
group.set_name(format!("intersection — {label}"));
// Two-term intersection
if *p_a > 0.0 && *p_b > 0.0 {
let query_str = "+aaa +bbb";
let query = bench_index.query_parser.parse_query(query_str).unwrap();
let searcher = bench_index.searcher.clone();
group.register(format!("{query_str} top10"), move |_| {
let collector = TopDocs::with_limit(10).order_by_score();
black_box(searcher.search(&query, &collector).unwrap());
1usize
});
}
// Three-term intersection
if *p_c > 0.0 {
let query_str = "+aaa +bbb +ccc";
let query = bench_index.query_parser.parse_query(query_str).unwrap();
let searcher = bench_index.searcher.clone();
group.register(format!("{query_str} top10"), move |_| {
let collector = TopDocs::with_limit(10).order_by_score();
black_box(searcher.search(&query, &collector).unwrap());
1usize
});
}
group.run();
}
}
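
The intersection sizes quoted in the scenario comments are products of the per-term probabilities, since `build_index` draws each term independently per document. A small sketch of that arithmetic, using hypothetical helper names:

```rust
/// Expected fraction of documents that contain every term, assuming each
/// term is drawn independently with the given probability.
fn expected_selectivity(probs: &[f64]) -> f64 {
    probs.iter().product()
}

/// Expected number of matching documents in an index of `num_docs` docs.
fn expected_matches(probs: &[f64], num_docs: usize) -> f64 {
    expected_selectivity(probs) * num_docs as f64
}
```

For the `very_skewed` scenario this gives 0.80 × 0.005 = 0.4% of documents, or about 4,000 matches out of 1M. The actual counts vary with the seeded random draws.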

@@ -0,0 +1,35 @@
// Benchmark for the query grammar parsing deeply nested queries.
//
// Regression guard for https://github.com/quickwit-oss/tantivy/issues/2498:
// at depth 20/21 the old parser took 0.87 s / 1.72 s respectively because
// `ast()` retried `occur_leaf` on backtrack, giving O(2^n) time. With the
// fix parsing is linear and completes in microseconds.
//
// Run with: `cargo bench --bench query_parser_nested`.
use binggan::{black_box, BenchRunner};
use tantivy::query_grammar::parse_query;
fn nested_query(depth: usize, leading_plus: bool) -> String {
let leading = "(".repeat(depth);
let trailing = ")".repeat(depth);
let prefix = if leading_plus { "+" } else { "" };
format!("{prefix}{leading}title:test{trailing}")
}
fn main() {
let mut runner = BenchRunner::new();
for depth in [20, 21] {
for leading_plus in [false, true] {
let query = nested_query(depth, leading_plus);
let label = format!(
"parse_nested_depth_{depth}_{}",
if leading_plus { "plus" } else { "plain" },
);
runner.bench_function(&label, move |_| {
black_box(parse_query(black_box(&query)).unwrap());
});
}
}
}
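
To illustrate the kind of blowup the header comment describes, here is a toy model. It is not the tantivy grammar: the two-rule grammar, the `Stats` counter and the `naive` flag are all assumptions made for this sketch. In naive mode a leaf is parsed, discarded when no operator follows, and parsed again, which doubles the work at every nesting level.

```rust
/// Counts leaf-parser invocations so both strategies can be compared.
#[derive(Default)]
struct Stats {
    leaf_calls: u64,
}

// Toy grammar (hypothetical, far smaller than the real one):
//   expr := leaf '&' expr | leaf
//   leaf := '(' expr ')' | 'x'

fn leaf(input: &[u8], pos: usize, naive: bool, stats: &mut Stats) -> Option<usize> {
    stats.leaf_calls += 1;
    match *input.get(pos)? {
        b'x' => Some(pos + 1),
        b'(' => {
            let after = expr(input, pos + 1, naive, stats)?;
            (input.get(after) == Some(&b')')).then_some(after + 1)
        }
        _ => None,
    }
}

fn expr(input: &[u8], pos: usize, naive: bool, stats: &mut Stats) -> Option<usize> {
    if naive {
        // Try the binary form first. If no operator follows, the parsed
        // leaf is thrown away and parsed again as a plain leaf.
        if let Some(after) = leaf(input, pos, true, stats) {
            if input.get(after) == Some(&b'&') {
                if let Some(end) = expr(input, after + 1, true, stats) {
                    return Some(end);
                }
            }
        }
        leaf(input, pos, true, stats)
    } else {
        // Parse the leaf once, then decide based on the next byte.
        let after = leaf(input, pos, false, stats)?;
        if input.get(after) == Some(&b'&') {
            expr(input, after + 1, false, stats)
        } else {
            Some(after)
        }
    }
}

/// Builds `(((...x...)))` with `depth` pairs of parentheses.
fn nested(depth: usize) -> Vec<u8> {
    format!("{}x{}", "(".repeat(depth), ")".repeat(depth)).into_bytes()
}
```

In this toy grammar the naive strategy calls `leaf` 2^(depth+2) - 2 times for `nested(depth)`, while the single-pass strategy calls it depth + 1 times.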

@@ -1,6 +1,6 @@
[package]
name = "tantivy-bitpacker"
version = "0.9.0"
version = "0.10.0"
edition = "2024"
authors = ["Paul Masurel <paul.masurel@gmail.com>"]
license = "MIT"
@@ -18,5 +18,10 @@ homepage = "https://github.com/quickwit-oss/tantivy"
bitpacking = { version = "0.9.2", default-features = false, features = ["bitpacker1x"] }
[dev-dependencies]
binggan = "0.17.0"
rand = "0.9"
proptest = "1"
[[bench]]
name = "bench"
harness = false

@@ -1,65 +1,110 @@
#![feature(test)]
use std::cell::RefCell;
extern crate test;
use binggan::{BenchRunner, black_box};
use rand::rng;
use rand::seq::IteratorRandom;
use tantivy_bitpacker::{BitPacker, BitUnpacker, BlockedBitpacker};
#[cfg(test)]
mod tests {
use rand::rng;
use rand::seq::IteratorRandom;
use tantivy_bitpacker::{BitPacker, BitUnpacker, BlockedBitpacker};
use test::Bencher;
fn create_bitpacked_data(bit_width: u8, num_els: u32) -> Vec<u8> {
let mut bitpacker = BitPacker::new();
let mut buffer = Vec::new();
for _ in 0..num_els {
bitpacker.write(0u64, bit_width, &mut buffer).unwrap();
bitpacker.flush(&mut buffer).unwrap();
}
buffer
}
#[inline(never)]
fn create_bitpacked_data(bit_width: u8, num_els: u32) -> Vec<u8> {
let mut bitpacker = BitPacker::new();
let mut buffer = Vec::new();
for _ in 0..num_els {
// the values do not matter.
bitpacker.write(0u64, bit_width, &mut buffer).unwrap();
bitpacker.flush(&mut buffer).unwrap();
const N: usize = 100_000;
const MAX_VAL: u64 = 1_000;
const BIT_WIDTH: u8 = 10; // 2^10 = 1024 > MAX_VAL
fn create_packed_data() -> (BitUnpacker, Vec<u8>) {
let mut bitpacker = BitPacker::new();
let mut data = Vec::new();
for i in 0..N as u64 {
let val = i * MAX_VAL / N as u64;
bitpacker.write(val, BIT_WIDTH, &mut data).unwrap();
}
bitpacker.close(&mut data).unwrap();
(BitUnpacker::new(BIT_WIDTH), data)
}
fn bench_bitpacking() {
let mut runner = BenchRunner::new();
let bit_width = 3;
let num_els = 1_000_000u32;
let bit_unpacker = BitUnpacker::new(bit_width);
let data = create_bitpacked_data(bit_width, num_els);
let idxs: Vec<u32> = (0..num_els).choose_multiple(&mut rng(), 100_000);
runner.bench_function("bitpacking_read", move |_| {
let mut out = 0u64;
for &idx in &idxs {
out = out.wrapping_add(bit_unpacker.get(idx, &data[..]));
}
buffer
}
black_box(out);
});
}
#[bench]
fn bench_bitpacking_read(b: &mut Bencher) {
let bit_width = 3;
let num_els = 1_000_000u32;
let bit_unpacker = BitUnpacker::new(bit_width);
let data = create_bitpacked_data(bit_width, num_els);
let idxs: Vec<u32> = (0..num_els).choose_multiple(&mut rng(), 100_000);
b.iter(|| {
let mut out = 0u64;
for &idx in &idxs {
out = out.wrapping_add(bit_unpacker.get(idx, &data[..]));
}
out
});
fn bench_blocked_bitpacker() {
let mut runner = BenchRunner::new();
let mut blocked_bitpacker = BlockedBitpacker::new();
for val in 0..=21500 {
blocked_bitpacker.add(val * val);
}
#[bench]
fn bench_blockedbitp_read(b: &mut Bencher) {
runner.bench_function("blockedbitp_read", move |_| {
let mut out = 0u64;
for val in 0..=21500 {
out = out.wrapping_add(blocked_bitpacker.get(val));
}
black_box(out);
});
runner.bench_function("blockedbitp_create", |_| {
let mut blocked_bitpacker = BlockedBitpacker::new();
for val in 0..=21500 {
blocked_bitpacker.add(val * val);
}
b.iter(|| {
let mut out = 0u64;
for val in 0..=21500 {
out = out.wrapping_add(blocked_bitpacker.get(val));
}
out
});
}
#[bench]
fn bench_blockedbitp_create(b: &mut Bencher) {
b.iter(|| {
let mut blocked_bitpacker = BlockedBitpacker::new();
for val in 0..=21500 {
blocked_bitpacker.add(val * val);
}
blocked_bitpacker
});
}
black_box(blocked_bitpacker);
});
}
fn bench_filter_vec() {
let mut runner = BenchRunner::new();
let (unpacker, data) = create_packed_data();
let positions = RefCell::new(Vec::with_capacity(N));
runner.bench_function("filter_vec_dense", move |_| {
unpacker.get_ids_for_value_range(
250..=750,
0..N as u32,
&data,
&mut positions.borrow_mut(),
);
black_box(positions.borrow().len());
});
let (unpacker, data) = create_packed_data();
let positions = RefCell::new(Vec::with_capacity(N));
runner.bench_function("filter_vec_sparse", move |_| {
unpacker.get_ids_for_value_range(0..=50, 0..N as u32, &data, &mut positions.borrow_mut());
black_box(positions.borrow().len());
});
let (unpacker, data) = create_packed_data();
let positions = RefCell::new(Vec::with_capacity(N));
runner.bench_function("filter_vec_full", move |_| {
unpacker.get_ids_for_value_range(
0..=MAX_VAL,
0..N as u32,
&data,
&mut positions.borrow_mut(),
);
black_box(positions.borrow().len());
});
}
fn main() {
bench_bitpacking();
bench_blocked_bitpacker();
bench_filter_vec();
}

View File

@@ -1,8 +1,17 @@
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
use std::arch::is_aarch64_feature_detected;
use std::ops::RangeInclusive;
#[cfg(target_arch = "x86_64")]
mod avx2;
#[cfg(target_arch = "aarch64")]
mod neon;
// SVE intrinsics are not exposed on aarch64-apple-darwin.
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
mod sve;
mod scalar;
#[derive(Clone, Copy, Eq, PartialEq, Debug)]
@@ -10,6 +19,10 @@ mod scalar;
enum FilterImplPerInstructionSet {
#[cfg(target_arch = "x86_64")]
AVX2 = 0u8,
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
SVE = 3u8,
#[cfg(target_arch = "aarch64")]
Neon = 2u8,
Scalar = 1u8,
}
@@ -19,29 +32,57 @@ impl FilterImplPerInstructionSet {
match *self {
#[cfg(target_arch = "x86_64")]
FilterImplPerInstructionSet::AVX2 => is_x86_feature_detected!("avx2"),
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
FilterImplPerInstructionSet::SVE => is_aarch64_feature_detected!("sve"),
// NEON is mandatory on aarch64, so it is always available.
#[cfg(target_arch = "aarch64")]
FilterImplPerInstructionSet::Neon => true,
FilterImplPerInstructionSet::Scalar => true,
}
}
}
// List of available implementation in preferred order.
// List of available implementations in preferred order.
#[cfg(target_arch = "x86_64")]
const IMPLS: [FilterImplPerInstructionSet; 2] = [
FilterImplPerInstructionSet::AVX2,
FilterImplPerInstructionSet::Scalar,
];
#[cfg(not(target_arch = "x86_64"))]
// Non-Apple aarch64: try SVE, NEON, Scalar.
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
const IMPLS: [FilterImplPerInstructionSet; 3] = [
FilterImplPerInstructionSet::SVE,
FilterImplPerInstructionSet::Neon,
FilterImplPerInstructionSet::Scalar,
];
// Apple aarch64 (M-series): SVE not available; use NEON or Scalar.
#[cfg(all(target_arch = "aarch64", target_vendor = "apple"))]
const IMPLS: [FilterImplPerInstructionSet; 2] = [
FilterImplPerInstructionSet::Neon,
FilterImplPerInstructionSet::Scalar,
];
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
const IMPLS: [FilterImplPerInstructionSet; 1] = [FilterImplPerInstructionSet::Scalar];
impl FilterImplPerInstructionSet {
#[inline]
#[allow(unused_variables)] // on non-x86_64, code is unused.
#[allow(unused_variables)]
fn from(code: u8) -> FilterImplPerInstructionSet {
#[cfg(target_arch = "x86_64")]
if code == FilterImplPerInstructionSet::AVX2 as u8 {
return FilterImplPerInstructionSet::AVX2;
}
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
if code == FilterImplPerInstructionSet::SVE as u8 {
return FilterImplPerInstructionSet::SVE;
}
#[cfg(target_arch = "aarch64")]
if code == FilterImplPerInstructionSet::Neon as u8 {
return FilterImplPerInstructionSet::Neon;
}
FilterImplPerInstructionSet::Scalar
}
@@ -50,6 +91,13 @@ impl FilterImplPerInstructionSet {
match self {
#[cfg(target_arch = "x86_64")]
FilterImplPerInstructionSet::AVX2 => avx2::filter_vec_in_place(range, offset, output),
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
// SAFETY: SVE availability was verified by is_available() before selecting this impl.
FilterImplPerInstructionSet::SVE => unsafe {
sve::filter_vec_in_place(range, offset, output)
},
#[cfg(target_arch = "aarch64")]
FilterImplPerInstructionSet::Neon => neon::filter_vec_in_place(range, offset, output),
FilterImplPerInstructionSet::Scalar => {
scalar::filter_vec_in_place(range, offset, output)
}
@@ -57,6 +105,12 @@ impl FilterImplPerInstructionSet {
}
}
fn available_impls() -> impl Iterator<Item = FilterImplPerInstructionSet> {
IMPLS
.into_iter()
.filter(FilterImplPerInstructionSet::is_available)
}
#[inline]
fn get_best_available_instruction_set() -> FilterImplPerInstructionSet {
use std::sync::atomic::{AtomicU8, Ordering};
@@ -64,10 +118,7 @@ fn get_best_available_instruction_set() -> FilterImplPerInstructionSet {
let instruction_set_byte: u8 = INSTRUCTION_SET_BYTE.load(Ordering::Relaxed);
if instruction_set_byte == u8::MAX {
// Let's initialize the instruction set and cache it.
let instruction_set = IMPLS
.into_iter()
.find(FilterImplPerInstructionSet::is_available)
.unwrap();
let instruction_set = available_impls().next().unwrap();
INSTRUCTION_SET_BYTE.store(instruction_set as u8, Ordering::Relaxed);
return instruction_set;
}
@@ -80,12 +131,12 @@ pub fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut
#[cfg(test)]
mod tests {
use proptest::strategy::Strategy;
use super::*;
#[test]
fn test_get_best_available_instruction_set() {
// This does not test much unfortunately.
// We just make sure the function returns without crashing and returns the same result.
let instruction_set = get_best_available_instruction_set();
assert_eq!(get_best_available_instruction_set(), instruction_set);
}
@@ -102,6 +153,31 @@ mod tests {
}
}
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
#[test]
fn test_instruction_set_to_code_from_code() {
for instruction_set in [
FilterImplPerInstructionSet::SVE,
FilterImplPerInstructionSet::Neon,
FilterImplPerInstructionSet::Scalar,
] {
let code = instruction_set as u8;
assert_eq!(instruction_set, FilterImplPerInstructionSet::from(code));
}
}
#[cfg(all(target_arch = "aarch64", target_vendor = "apple"))]
#[test]
fn test_instruction_set_to_code_from_code() {
for instruction_set in [
FilterImplPerInstructionSet::Neon,
FilterImplPerInstructionSet::Scalar,
] {
let code = instruction_set as u8;
assert_eq!(instruction_set, FilterImplPerInstructionSet::from(code));
}
}
fn test_filter_impl_empty_aux(filter_impl: FilterImplPerInstructionSet) {
let mut output = vec![];
filter_impl.filter_vec_in_place(0..=u32::MAX, 0, &mut output);
@@ -126,11 +202,20 @@ mod tests {
assert_eq!(&output, &[1, 3, 4, 5, 6, 7, 8]);
}
fn test_filter_impl_empty_range_aux(filter_impl: FilterImplPerInstructionSet) {
// start > end: RangeInclusive::contains always returns false; output must be empty.
// The SVE path's wrapping_sub would otherwise produce a huge range_width.
let mut output = vec![3, 2, 1, 5, 11, 2, 5, 10, 2];
filter_impl.filter_vec_in_place(10..=5, 0, &mut output);
assert_eq!(&output, &[]);
}
fn test_filter_impl_test_suite(filter_impl: FilterImplPerInstructionSet) {
test_filter_impl_empty_aux(filter_impl);
test_filter_impl_simple_aux(filter_impl);
test_filter_impl_simple_aux_shifted(filter_impl);
test_filter_impl_simple_outside_i32_range(filter_impl);
test_filter_impl_empty_range_aux(filter_impl);
}
#[test]
@@ -141,25 +226,60 @@ mod tests {
}
}
#[test]
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
fn test_filter_implementation_sve() {
if FilterImplPerInstructionSet::SVE.is_available() {
test_filter_impl_test_suite(FilterImplPerInstructionSet::SVE);
}
}
#[test]
#[cfg(target_arch = "aarch64")]
fn test_filter_implementation_neon() {
test_filter_impl_test_suite(FilterImplPerInstructionSet::Neon);
}
#[test]
fn test_filter_implementation_scalar() {
test_filter_impl_test_suite(FilterImplPerInstructionSet::Scalar);
}
#[cfg(target_arch = "x86_64")]
fn max_val_strategy() -> impl proptest::strategy::Strategy<Value = u32> {
proptest::prop_oneof![
0u32..10u32,
255u32..258u32,
proptest::prelude::Just(1u32 << 25),
proptest::prelude::Just(u32::MAX - 1),
proptest::prelude::Just(u32::MAX),
]
}
fn vals_strategy() -> impl proptest::strategy::Strategy<Value = Vec<u32>> {
proptest::prop_oneof![
proptest::collection::vec(proptest::prelude::any::<u32>(), 0..300),
max_val_strategy()
.prop_flat_map(|max_val| { proptest::collection::vec(0..=max_val, 0..300) })
]
}
proptest::proptest! {
#[test]
fn test_filter_compare_scalar_and_avx2_impl_proptest(
start in proptest::prelude::any::<u32>(),
end in proptest::prelude::any::<u32>(),
fn test_filter_compare_scalar_and_impls_impl_proptest(
start in 0u32..400u32,
end in 0u32..400u32,
offset in 0u32..2u32,
mut vals in proptest::collection::vec(0..u32::MAX, 0..30)) {
if FilterImplPerInstructionSet::AVX2.is_available() {
let mut vals_clone = vals.clone();
FilterImplPerInstructionSet::AVX2.filter_vec_in_place(start..=end, offset, &mut vals);
FilterImplPerInstructionSet::Scalar.filter_vec_in_place(start..=end, offset, &mut vals_clone);
assert_eq!(&vals, &vals_clone);
}
vals in vals_strategy()) {
for implementation in available_impls() {
if implementation == FilterImplPerInstructionSet::Scalar {
continue;
}
let mut impl_output = vals.clone();
let mut scalar_output = vals.clone();
implementation.filter_vec_in_place(start..=end, offset, &mut impl_output);
FilterImplPerInstructionSet::Scalar.filter_vec_in_place(start..=end, offset, &mut scalar_output);
assert_eq!(&impl_output, &scalar_output);
}
}
}
}
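The selection logic in this file (try implementations in preferred order, then cache the winner as a byte) can be sketched in isolation. This is a minimal illustration, not the crate's code: `Impl`, `Fast`, `from_code`, and `best_impl` are hypothetical names, and `Fast` is hard-wired as unavailable so the fallback is visible.

```rust
use std::sync::atomic::{AtomicU8, Ordering};

// Hypothetical two-variant stand-in for FilterImplPerInstructionSet.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
enum Impl {
    Fast = 0,
    Scalar = 1,
}

impl Impl {
    // Assumption for the sketch: the fast path is never detected.
    fn is_available(&self) -> bool {
        match self {
            Impl::Fast => false,
            Impl::Scalar => true,
        }
    }

    // Unknown codes fall back to the always-available scalar path.
    fn from_code(code: u8) -> Impl {
        if code == Impl::Fast as u8 {
            Impl::Fast
        } else {
            Impl::Scalar
        }
    }
}

// Preferred order: fastest first, scalar last so a match always exists.
const IMPLS: [Impl; 2] = [Impl::Fast, Impl::Scalar];

fn best_impl() -> Impl {
    // u8::MAX is the "not initialized yet" sentinel.
    static CACHE: AtomicU8 = AtomicU8::new(u8::MAX);
    let code = CACHE.load(Ordering::Relaxed);
    if code == u8::MAX {
        let chosen = IMPLS
            .iter()
            .copied()
            .find(|candidate| candidate.is_available())
            .unwrap();
        CACHE.store(chosen as u8, Ordering::Relaxed);
        return chosen;
    }
    Impl::from_code(code)
}
```

Relaxed ordering is enough in this sketch because every thread that races on the first call computes the same value, so a duplicated detection is harmless.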

View File

@@ -0,0 +1,118 @@
use std::arch::aarch64::*;
use std::ops::RangeInclusive;
const NUM_LANES: usize = 4;
// Compacts matching lanes to the front using a byte-level shuffle.
// `mask` is a 4-bit value: bit k=1 means lane k should appear in the output.
#[inline]
#[target_feature(enable = "neon")]
unsafe fn compact(data: uint32x4_t, mask: u8) -> uint32x4_t {
unsafe {
// SAFETY: mask is always in [0, 15] by construction (max sum of [1,2,4,8]).
// BYTE_SHUFFLE_TABLE has 16 entries, so this is always in bounds.
let shuffle = BYTE_SHUFFLE_TABLE.get_unchecked(mask as usize);
let shuffle_vec = vld1q_u8(shuffle.as_ptr());
vreinterpretq_u32_u8(vqtbl1q_u8(vreinterpretq_u8_u32(data), shuffle_vec))
}
}
// Safe (not unsafe) because NEON is mandatory on aarch64: no runtime feature check needed.
#[inline(never)]
pub fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
let num_words = output.len() / NUM_LANES;
let mut output_len = unsafe {
filter_vec_neon_aux(
output.as_ptr(),
range.clone(),
output.as_mut_ptr(),
offset,
num_words,
)
};
let remainder_start = num_words * NUM_LANES;
for i in remainder_start..output.len() {
let val = output[i];
output[output_len] = offset + i as u32;
output_len += if range.contains(&val) { 1 } else { 0 };
}
output.truncate(output_len);
}
#[target_feature(enable = "neon")]
unsafe fn filter_vec_neon_aux(
input: *const u32,
range: RangeInclusive<u32>,
output: *mut u32,
offset: u32,
num_words: usize,
) -> usize {
unsafe {
let mut input = input;
let mut output_tail = output;
let range_start_simd = vdupq_n_u32(*range.start());
let range_end_simd = vdupq_n_u32(*range.end());
let mut ids = vld1q_u32([offset, offset + 1, offset + 2, offset + 3].as_ptr());
let shift = vdupq_n_u32(NUM_LANES as u32);
let bit_weights = vld1q_u32([1u32, 2, 4, 8].as_ptr());
for _ in 0..num_words {
let word = vld1q_u32(input);
// Unsigned compares: CMHS (compare higher or same) tests `word >= start`
// and `end >= word`. ANDing both gives the inside-range mask directly,
// which is cheaper than computing `outside` and then negating.
let ge_start = vcgeq_u32(word, range_start_simd);
let le_end = vcleq_u32(word, range_end_simd);
// inside[k] = 0xFFFFFFFF if val[k] is in range, 0 otherwise.
let inside = vandq_u32(ge_start, le_end);
// Build the 4-bit mask: AND bit_weights with the inside lane mask, so each
// inside lane contributes its bit_weight (1, 2, 4, or 8). Summing yields the
// 4-bit mask in one addv.
let inside_bits = vandq_u32(bit_weights, inside);
let mask = vaddvq_u32(inside_bits) as u8;
// mask is mathematically bounded: max value is 1+2+4+8=15 (all lanes match)
debug_assert!(mask <= 15, "mask must fit in 4 bits: {}", mask);
// Count of matching lanes = popcount(mask). Derives the count directly from
// the mask instead of running a parallel SIMD reduction over `outside`.
let added_len = mask.count_ones() as usize;
// Safe because mask is guaranteed to be in [0, 15]
let filtered_ids = compact(ids, mask);
vst1q_u32(output_tail, filtered_ids);
output_tail = output_tail.add(added_len);
ids = vaddq_u32(ids, shift);
input = input.add(NUM_LANES);
}
output_tail.offset_from(output) as usize
}
}
// Byte shuffle patterns to compact matching lanes to the front of the vector.
// Index is a 4-bit mask: bit k=1 means lane k (bytes 4k..4k+3) is in-range.
// The j-th set bit determines which input lane goes to output position j.
const BYTE_SHUFFLE_TABLE: [[u8; 16]; 16] = [
[
16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
], // 0b0000: none
[0, 1, 2, 3, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], // 0b0001: lane 0
[4, 5, 6, 7, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], // 0b0010: lane 1
[0, 1, 2, 3, 4, 5, 6, 7, 16, 16, 16, 16, 16, 16, 16, 16], // 0b0011: lanes 0,1
[8, 9, 10, 11, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], // 0b0100: lane 2
[0, 1, 2, 3, 8, 9, 10, 11, 16, 16, 16, 16, 16, 16, 16, 16], // 0b0101: lanes 0,2
[4, 5, 6, 7, 8, 9, 10, 11, 16, 16, 16, 16, 16, 16, 16, 16], // 0b0110: lanes 1,2
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 16, 16, 16], // 0b0111: lanes 0,1,2
[
12, 13, 14, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
], // 0b1000: lane 3
[0, 1, 2, 3, 12, 13, 14, 15, 16, 16, 16, 16, 16, 16, 16, 16], // 0b1001: lanes 0,3
[4, 5, 6, 7, 12, 13, 14, 15, 16, 16, 16, 16, 16, 16, 16, 16], // 0b1010: lanes 1,3
[0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 16, 16, 16, 16], // 0b1011: lanes 0,1,3
[8, 9, 10, 11, 12, 13, 14, 15, 16, 16, 16, 16, 16, 16, 16, 16], // 0b1100: lanes 2,3
[0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 16, 16, 16, 16], // 0b1101: lanes 0,2,3
[4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 16, 16, 16], // 0b1110: lanes 1,2,3
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], // 0b1111: all lanes
];
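The hand-written table above follows a regular rule, so it can be cross-checked against a generated one. The sketch below is an illustration under stated assumptions: `build_shuffle_table` and `compact_scalar` are hypothetical helpers, and `compact_scalar` emulates the table lookup in scalar code, writing 0 for any index of 16 or more as the NEON TBL instruction does.

```rust
// Builds the 16-entry byte shuffle table: for each 4-bit mask, the j-th set
// bit selects which input lane lands in output lane j. Unused output bytes
// keep the out-of-range index 16.
fn build_shuffle_table() -> [[u8; 16]; 16] {
    let mut table = [[16u8; 16]; 16];
    for mask in 0..16usize {
        let mut out_lane = 0;
        for lane in 0..4usize {
            if mask & (1 << lane) != 0 {
                for byte in 0..4usize {
                    table[mask][out_lane * 4 + byte] = (lane * 4 + byte) as u8;
                }
                out_lane += 1;
            }
        }
    }
    table
}

// Scalar emulation of the compaction step. `mask` must be below 16.
fn compact_scalar(ids: [u32; 4], mask: u8) -> [u32; 4] {
    let table = build_shuffle_table();
    // Flatten the four lanes into 16 little-endian bytes.
    let mut bytes = [0u8; 16];
    for (lane, id) in ids.iter().enumerate() {
        bytes[lane * 4..lane * 4 + 4].copy_from_slice(&id.to_le_bytes());
    }
    // Table lookup: out-of-range indices produce a zero byte.
    let mut shuffled = [0u8; 16];
    for (dst, &idx) in shuffled.iter_mut().zip(table[mask as usize].iter()) {
        *dst = if idx < 16 { bytes[idx as usize] } else { 0 };
    }
    // Reassemble the four output lanes.
    let mut out = [0u32; 4];
    for lane in 0..4 {
        let mut word = [0u8; 4];
        word.copy_from_slice(&shuffled[lane * 4..lane * 4 + 4]);
        out[lane] = u32::from_le_bytes(word);
    }
    out
}
```

A unit test comparing `build_shuffle_table()` with `BYTE_SHUFFLE_TABLE` would catch a typo in any of the 256 hand-written bytes.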

View File

@@ -0,0 +1,260 @@
use std::ops::RangeInclusive;
// SVE vector length (in u32 lanes) is not a compile-time constant; query at runtime.
// Safe to call only when SVE is confirmed available via is_aarch64_feature_detected!("sve").
#[target_feature(enable = "sve")]
unsafe fn num_lanes() -> usize {
let vl: usize;
unsafe {
core::arch::asm!(
"cntw {vl}",
vl = out(reg) vl,
options(nostack, nomem, preserves_flags),
);
}
vl
}
// SAFETY: caller must ensure SVE is available (checked via is_aarch64_feature_detected!("sve")).
// Unlike NEON, SVE is optional on aarch64 and not guaranteed by the target architecture.
pub unsafe fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
if range.start() > range.end() {
output.clear();
return;
}
let vl = unsafe { num_lanes() };
let num_words = output.len() / vl;
let range_start = *range.start();
// Unsigned subtraction trick: val ∈ [lo, hi] ↔ (val - lo) ≤ᵤ (hi - lo).
// Values below lo wrap around to large u32, so the single unsigned ≤ excludes them.
let range_width = range.end().wrapping_sub(range_start);
let mut output_len = unsafe {
filter_vec_sve_aux(
output.as_ptr(),
range_start,
range_width,
output.as_mut_ptr(),
offset,
num_words,
vl,
)
};
let remainder_start = num_words * vl;
for i in remainder_start..output.len() {
let val = output[i];
output[output_len] = offset + i as u32;
output_len += if range.contains(&val) { 1 } else { 0 };
}
output.truncate(output_len);
}
// Register allocation for the asm! blocks:
// z0 ids_a (index vector for first half of each pair, advances by step2 each iter)
// z1 range_width broadcast
// z2 range_start broadcast
// z3 step2 broadcast (2 * vl)
// z4 ids_b (index vector for second half, = ids_a + step, advances by step2)
// z5 scratch: loaded word_a, then compacted_a
// z6 scratch: loaded word_b, then compacted_b
// p0 all-true predicate (ptrue p0.s)
// p1 in-range mask for word_a
// p2 in-range mask for word_b
#[target_feature(enable = "sve")]
unsafe fn filter_vec_sve_aux(
input: *const u32,
range_start: u32,
range_width: u32,
output: *mut u32,
offset: u32,
num_words: usize,
vl: usize,
) -> usize {
let num_pairs = num_words / 2;
let mut input_ptr = input;
let mut output_tail = output;
if num_pairs > 0 {
unsafe {
// We rely on asm! because the SVE intrinsics are not available in stable Rust.
// The assembly that follows was generated by nightly rustc from the intrinsics version
// at the bottom of this file.
core::arch::asm!(
// --- Setup ---
// All-true predicate for 32-bit lanes.
"ptrue p0.s",
// ids_a = [offset, offset+1, offset+2, ...]
"index z0.s, {offset:w}, #1",
// Broadcast scalars into SVE vectors.
"mov z1.s, {range_width:w}",
"mov z2.s, {range_start:w}",
// vl_gpr = number of 32-bit lanes (cntw).
"cntw {vl_gpr}",
// step2_bytes will first hold 2*vl (for the step2 vector), then 2*VL in bytes.
"lsl {step2_bytes}, {vl_gpr}, #1",
// z4 = step = [vl, vl, ...]; will become ids_b after the add below.
"mov z4.s, {vl_gpr:w}",
// z3 = step2 = [2*vl, 2*vl, ...], used to advance both id vectors each iter.
"mov z3.s, {step2_bytes:w}",
// Repurpose step2_bytes to hold the byte stride for advancing the input pointer
// by two full SVE vectors per iteration.
"rdvl {step2_bytes}, #2",
// ids_b = ids_a + step = [offset+vl, offset+vl+1, ...]
"add z4.s, z0.s, z4.s",
// --- Main loop: process two SVE vectors (ids_a and ids_b) per iteration ---
"0:",
// Load two consecutive SVE vectors from input.
"ld1w {{z5.s}}, p0/z, [{input}]",
"ld1w {{z6.s}}, p0/z, [{input}, #1, mul vl]",
// Advance input pointer by 2 * VL bytes.
"add {input}, {input}, {step2_bytes}",
// Unsigned shift: subtract range_start so the in-range check becomes a single unsigned ≤ compare.
"sub z5.s, z5.s, z2.s",
"sub z6.s, z6.s, z2.s",
// in_range: shifted value ≤ range_width (unsigned, so values below lo also fail).
"cmphs p1.s, p0/z, z1.s, z5.s",
"cmphs p2.s, p0/z, z1.s, z6.s",
// Count matching lanes; both cntp calls have independent inputs for OOO parallelism.
"cntp {cnt_a}, p0, p1.s",
"compact z5.s, p1, z0.s",
"compact z6.s, p2, z4.s",
"cntp {cnt_b}, p0, p2.s",
// Advance id vectors for the next iteration.
"add z0.s, z0.s, z3.s",
"add z4.s, z4.s, z3.s",
// Store compacted ids. Only the first cnt_a / cnt_b slots are valid; the rest
// will be overwritten by subsequent iterations before the final truncate.
"str z5, [{out}]",
"st1w {{z6.s}}, p0, [{out}, {cnt_a}, lsl #2]",
"add {out}, {out}, {cnt_a}, lsl #2",
"add {out}, {out}, {cnt_b}, lsl #2",
"subs {pairs}, {pairs}, #1",
"b.ne 0b",
// --- Operands ---
input = inout(reg) input_ptr,
out = inout(reg) output_tail,
pairs = inout(reg) num_pairs => _,
offset = in(reg) offset,
range_start = in(reg) range_start,
range_width = in(reg) range_width,
vl_gpr = out(reg) _,
step2_bytes = out(reg) _,
cnt_a = out(reg) _,
cnt_b = out(reg) _,
out("p0") _, out("p1") _, out("p2") _,
out("v0") _, out("v1") _, out("v2") _, out("v3") _,
out("v4") _, out("v5") _, out("v6") _,
options(nostack),
);
}
}
// Handle an odd trailing vector.
if num_words % 2 == 1 {
// ids_a for the odd word starts at offset + num_pairs * 2 * vl.
// input_ptr was advanced by the main loop and now points at the odd word.
let odd_offset =
offset.wrapping_add((num_pairs as u32).wrapping_mul(2).wrapping_mul(vl as u32));
unsafe {
core::arch::asm!(
"ptrue p0.s",
"index z0.s, {odd_offset:w}, #1",
"mov z1.s, {range_width:w}",
"mov z2.s, {range_start:w}",
"ld1w {{z3.s}}, p0/z, [{input}]",
"sub z3.s, z3.s, z2.s",
"cmphs p1.s, p0/z, z1.s, z3.s",
"cntp {cnt}, p0, p1.s",
"compact z0.s, p1, z0.s",
"str z0, [{out}]",
"add {out}, {out}, {cnt}, lsl #2",
odd_offset = in(reg) odd_offset,
range_width = in(reg) range_width,
range_start = in(reg) range_start,
input = in(reg) input_ptr,
out = inout(reg) output_tail,
cnt = out(reg) _,
out("p0") _, out("p1") _,
out("v0") _, out("v1") _, out("v2") _, out("v3") _,
options(nostack),
);
}
}
unsafe { output_tail.offset_from(output) as usize }
}
// SVE implementation using intrinsics (the version the asm above was generated from).
//
// #[target_feature(enable = "sve")]
// unsafe fn filter_vec_sve_aux(
// input: *const u32,
// range_start: u32,
// range_width: u32,
// output: *mut u32,
// offset: u32,
// num_words: usize,
// vl: usize,
// ) -> usize {
// unsafe {
// let all_true = svptrue_b32();
// let range_start_simd = svdup_n_u32(range_start);
// let range_width_simd = svdup_n_u32(range_width);
// // ids_a covers [offset .. offset+vl), ids_b covers the next vl ids.
// // Keeping them separate breaks the loop-carried dependency through ids so
// // both compact/cntp chains are fully independent within each unrolled body.
// let mut ids_a = svindex_u32(offset, 1);
// let step = svdup_n_u32(vl as u32);
// let step2 = svdup_n_u32(2 * vl as u32);
// let mut ids_b = svadd_u32_x(all_true, ids_a, step);
// let mut input = input;
// let mut output_tail = output;
// // Unrolled ×2: both cntp calls have independent inputs and execute in parallel.
// // The two output_tail updates are sequential but together cost 4+1+1=6 cy per
// // pair vs 5+5=10 cy for two scalar iterations, breaking the cntp latency chain.
// let num_pairs = num_words / 2;
// for _ in 0..num_pairs {
// let word_a = svld1_u32(all_true, input);
// let word_b = svld1_u32(all_true, input.add(vl));
// let shifted_a = svsub_u32_x(all_true, word_a, range_start_simd);
// let shifted_b = svsub_u32_x(all_true, word_b, range_start_simd);
// let in_range_a = svcmple_u32(all_true, shifted_a, range_width_simd);
// let in_range_b = svcmple_u32(all_true, shifted_b, range_width_simd);
// let compacted_a = svcompact_u32(in_range_a, ids_a);
// let compacted_b = svcompact_u32(in_range_b, ids_b);
// // cntp_a and cntp_b have independent inputs: OOO engine issues them in parallel.
// let added_len_a = svcntp_b32(all_true, in_range_a) as usize;
// let added_len_b = svcntp_b32(all_true, in_range_b) as usize;
// // Write the full vector — only the first added_len slots are valid.
// // Subsequent iterations overwrite the trailing zeros before truncate.
// svst1_u32(all_true, output_tail, compacted_a);
// output_tail = output_tail.add(added_len_a);
// svst1_u32(all_true, output_tail, compacted_b);
// output_tail = output_tail.add(added_len_b);
// ids_a = svadd_u32_x(all_true, ids_a, step2);
// ids_b = svadd_u32_x(all_true, ids_b, step2);
// input = input.add(2 * vl);
// }
// // Handle an odd trailing word.
// if num_words % 2 == 1 {
// let word = svld1_u32(all_true, input);
// let shifted = svsub_u32_x(all_true, word, range_start_simd);
// let in_range = svcmple_u32(all_true, shifted, range_width_simd);
// let added_len = svcntp_b32(all_true, in_range) as usize;
// let compacted_ids = svcompact_u32(in_range, ids_a);
// svst1_u32(all_true, output_tail, compacted_ids);
// output_tail = output_tail.add(added_len);
// }
// output_tail.offset_from(output) as usize
// }
// }
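The unsigned subtraction trick used by the SVE path can be checked in scalar code. The sketch below is a reference model, not the crate's scalar implementation: `in_range_wrapping` and `filter_ids` are hypothetical names, and the early return for `lo > hi` mirrors the guard at the top of the SVE `filter_vec_in_place`.

```rust
// Single-comparison range check: val is in [lo, hi] iff
// (val - lo) <= (hi - lo) in wrapping unsigned arithmetic.
// Only valid when lo <= hi; values below lo wrap to large numbers and fail.
fn in_range_wrapping(val: u32, lo: u32, hi: u32) -> bool {
    debug_assert!(lo <= hi);
    val.wrapping_sub(lo) <= hi.wrapping_sub(lo)
}

// Returns offset + index for every value inside [lo, hi].
fn filter_ids(lo: u32, hi: u32, offset: u32, vals: &[u32]) -> Vec<u32> {
    // An inverted range matches nothing. Without this guard, hi - lo would
    // wrap to a huge width and almost every value would pass.
    if lo > hi {
        return Vec::new();
    }
    vals.iter()
        .enumerate()
        .filter(|&(_, &val)| in_range_wrapping(val, lo, hi))
        .map(|(idx, _)| offset + idx as u32)
        .collect()
}
```

For example, with `lo = 5` and `val = 4` the subtraction wraps to `u32::MAX`, which exceeds any possible width, so the value is rejected without a second comparison.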

View File

@@ -1,6 +1,6 @@
[package]
name = "tantivy-columnar"
version = "0.6.0"
version = "0.7.0"
edition = "2024"
license = "MIT"
homepage = "https://github.com/quickwit-oss/tantivy"
@@ -12,10 +12,10 @@ categories = ["database-implementations", "data-structures", "compression"]
itertools = "0.14.0"
fastdivide = "0.4.0"
stacker = { version= "0.6", path = "../stacker", package="tantivy-stacker"}
sstable = { version= "0.6", path = "../sstable", package = "tantivy-sstable" }
common = { version= "0.10", path = "../common", package = "tantivy-common" }
tantivy-bitpacker = { version= "0.9", path = "../bitpacker/" }
stacker = { version= "0.7", path = "../stacker", package="tantivy-stacker"}
sstable = { version= "0.7", path = "../sstable", package = "tantivy-sstable" }
common = { version= "0.11", path = "../common", package = "tantivy-common" }
tantivy-bitpacker = { version= "0.10", path = "../bitpacker/" }
serde = "1.0.152"
downcast-rs = "2.0.1"
@@ -23,7 +23,7 @@ downcast-rs = "2.0.1"
proptest = "1"
more-asserts = "0.3.1"
rand = "0.9"
binggan = "0.14.0"
binggan = "0.17.0"
[[bench]]
name = "bench_merge"

View File

@@ -28,7 +28,7 @@ fn get_test_columns() -> Columns {
}
let mut buffer: Vec<u8> = Vec::new();
dataframe_writer
.serialize(data.len() as u32, &mut buffer)
.serialize(data.len() as u32, None, &mut buffer)
.unwrap();
let columnar = ColumnarReader::open(buffer).unwrap();

View File

@@ -54,6 +54,6 @@ pub fn generate_columnar_with_name(card: Card, num_docs: u32, column_name: &str)
}
let mut wrt: Vec<u8> = Vec::new();
columnar_writer.serialize(num_docs, &mut wrt).unwrap();
columnar_writer.serialize(num_docs, None, &mut wrt).unwrap();
ColumnarReader::open(wrt).unwrap()
}

View File

@@ -15,9 +15,37 @@ impl<T: PartialOrd + Copy + std::fmt::Debug + Send + Sync + 'static + Default>
{
#[inline]
pub fn fetch_block<'a>(&'a mut self, docs: &'a [u32], accessor: &Column<T>) {
if accessor.index.get_cardinality().is_full() {
self.val_cache.resize(docs.len(), T::default());
accessor.values.get_vals(docs, &mut self.val_cache);
self.fetch_block_with_is_full(docs, accessor, accessor.index.get_cardinality().is_full());
}
/// Like [`Self::fetch_block`] but takes the column's fullness instead of querying
/// `accessor.index.get_cardinality()` each call — for callers that know it up front (e.g.
/// checked once at construction). `is_full` must equal
/// `accessor.index.get_cardinality().is_full()`.
#[inline]
pub fn fetch_block_with_is_full<'a>(
&'a mut self,
docs: &'a [u32],
accessor: &Column<T>,
is_full: bool,
) {
if is_full {
// Skip the resize when already the right length (common case: fixed-size blocks).
if self.val_cache.len() != docs.len() {
self.val_cache.resize(docs.len(), T::default());
}
// When the docs form a contiguous ascending run we can fetch the values
// as a single range. This lets codecs (e.g. bitpacked) bulk-decode the
// slice instead of gathering value-by-value, and avoids per-value dynamic
// dispatch. `docs` is always sorted ascending and free of duplicates here,
// so comparing the endpoints is enough to detect contiguity.
if is_contiguous(docs) {
accessor
.values
.get_range(docs[0] as u64, &mut self.val_cache);
} else {
accessor.values.get_vals(docs, &mut self.val_cache);
}
} else {
self.docid_cache.clear();
self.row_id_cache.clear();
@@ -33,14 +61,14 @@ impl<T: PartialOrd + Copy + std::fmt::Debug + Send + Sync + 'static + Default>
&mut self,
docs: &[u32],
accessor: &Column<T>,
missing: Option<T>,
missing_opt: Option<T>,
) {
self.fetch_block(docs, accessor);
// no missing values
if accessor.index.get_cardinality().is_full() {
return;
}
let Some(missing) = missing else {
let Some(missing) = missing_opt else {
return;
};
@@ -158,6 +186,22 @@ impl<T: PartialOrd + Copy + std::fmt::Debug + Send + Sync + 'static + Default>
}
}
/// Returns true if `docs` is a contiguous ascending run `[d, d + 1, ..., d + n - 1]`.
///
/// Assumes `docs` is sorted ascending and free of duplicates (the invariant for the
/// doc blocks passed to `fetch_block`), so comparing the endpoints is sufficient.
#[inline]
fn is_contiguous(docs: &[u32]) -> bool {
let (Some(&first), Some(&last)) = (docs.first(), docs.last()) else {
return false;
};
debug_assert!(
docs.windows(2).all(|w| w[0] < w[1]),
"fetch_block requires docs sorted ascending without duplicates"
);
(last - first) as usize + 1 == docs.len()
}
/// Given two sorted lists of docids `docs` and `hits`, where `hits` is a subset of `docs`,
/// returns all docs that are not in `hits`.
fn find_missing_docs<F>(docs: &[u32], hits: &[u32], mut callback: F)
@@ -191,6 +235,7 @@ where F: FnMut(u32) {
}
#[cfg(test)]
#[allow(clippy::field_reassign_with_default)]
mod tests {
use super::*;
@@ -287,4 +332,46 @@ mod tests {
assert_eq!(accessor.docid_cache, vec![0]);
assert_eq!(accessor.val_cache, vec![1]);
}
#[test]
fn test_is_contiguous() {
assert!(!is_contiguous(&[]));
assert!(is_contiguous(&[5]));
assert!(is_contiguous(&[5, 6, 7, 8]));
assert!(is_contiguous(&[0, 1, 2]));
assert!(!is_contiguous(&[5, 7, 8]));
assert!(!is_contiguous(&[0, 1, 3]));
}
#[test]
fn test_fetch_block_contiguous_and_gather_match() {
use crate::column_index::ColumnIndex;
use crate::column_values::{
ALL_U64_CODEC_TYPES, serialize_and_load_u64_based_column_values,
};
let vals: Vec<u64> = (0..200u64).map(|i| i * 7 + 3).collect();
let values =
serialize_and_load_u64_based_column_values::<u64>(&&vals[..], &ALL_U64_CODEC_TYPES);
let column = Column {
index: ColumnIndex::Full,
values,
};
let check = |accessor: &mut ColumnBlockAccessor<u64>, docs: &[u32]| {
accessor.fetch_block(docs, &column);
let got: Vec<(u32, u64)> = accessor.iter_docid_vals(docs, &column).collect();
let expected: Vec<(u32, u64)> = docs.iter().map(|&d| (d, vals[d as usize])).collect();
assert_eq!(got, expected);
};
let mut accessor = ColumnBlockAccessor::<u64>::default();
// Contiguous block -> get_range fast path.
check(&mut accessor, &(10..74).collect::<Vec<u32>>());
// Non-contiguous block -> get_vals gather path.
check(&mut accessor, &[0, 5, 9, 100, 199]);
// Single doc and full span.
check(&mut accessor, &[42]);
check(&mut accessor, &(0..200).collect::<Vec<u32>>());
}
}
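The endpoint comparison in `is_contiguous` is only sound under the stated invariant (sorted ascending, no duplicates). The sketch below illustrates why, using two hypothetical helpers: `endpoints_span_len` repeats the endpoint check without the debug assertion, and `is_contiguous_slow` is an exhaustive O(n) check used as ground truth.

```rust
// Endpoint check only: O(1), correct for strictly ascending input.
// Input must at least be sorted, otherwise `last - first` can underflow.
fn endpoints_span_len(docs: &[u32]) -> bool {
    match (docs.first(), docs.last()) {
        (Some(&first), Some(&last)) => (last - first) as usize + 1 == docs.len(),
        _ => false,
    }
}

// Ground truth: every element is exactly one more than its predecessor.
fn is_contiguous_slow(docs: &[u32]) -> bool {
    !docs.is_empty() && docs.windows(2).all(|pair| pair[0] + 1 == pair[1])
}
```

With a duplicate, as in `[5, 5, 7]`, the endpoints span three ids and the slice holds three elements, so the O(1) check reports a contiguous run while the exhaustive check does not. That false positive is what the `debug_assert!` in `is_contiguous` guards against.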

View File

@@ -375,7 +375,7 @@ mod tests {
columnar_writer.record_numerical(5, "full", u64::MAX);
let mut wrt: Vec<u8> = Vec::new();
columnar_writer.serialize(7, &mut wrt).unwrap();
columnar_writer.serialize(7, None, &mut wrt).unwrap();
let reader = ColumnarReader::open(wrt).unwrap();
// Open the column as u64

View File

@@ -15,7 +15,9 @@ fn test_optional_index_with_num_docs(num_docs: u32) {
let mut dataframe_writer = ColumnarWriter::default();
dataframe_writer.record_numerical(100, "score", 80i64);
let mut buffer: Vec<u8> = Vec::new();
dataframe_writer.serialize(num_docs, &mut buffer).unwrap();
dataframe_writer
.serialize(num_docs, None, &mut buffer)
.unwrap();
let columnar = ColumnarReader::open(buffer).unwrap();
assert_eq!(columnar.num_columns(), 1);
let cols: Vec<DynamicColumnHandle> = columnar.read_columns("score").unwrap();

View File

@@ -119,8 +119,18 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync + DowncastSync {
/// the segment's `maxdoc`.
#[inline(always)]
fn get_range(&self, start: u64, output: &mut [T]) {
for (out, idx) in output.iter_mut().zip(start..) {
let mut out_chunks = output.chunks_exact_mut(4);
let mut idx = start;
for out_x4 in out_chunks.by_ref() {
out_x4[0] = self.get_val(idx as u32);
out_x4[1] = self.get_val((idx + 1) as u32);
out_x4[2] = self.get_val((idx + 2) as u32);
out_x4[3] = self.get_val((idx + 3) as u32);
idx += 4;
}
for out in out_chunks.into_remainder() {
*out = self.get_val(idx as u32);
idx += 1;
}
}
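The unrolled `get_range` above can be tried in isolation. The following is a minimal sketch, assuming a hypothetical `SliceColumn` type that stands in for a real column reader (it is not part of the crate); only the loop structure mirrors the diff.

```rust
/// Hypothetical stand-in for a column reader: values live in a plain slice.
struct SliceColumn<'a>(&'a [u64]);

impl SliceColumn<'_> {
    /// Returns the value stored at row `idx`.
    fn get_val(&self, idx: u32) -> u64 {
        self.0[idx as usize]
    }

    /// Fills `output` with the values of rows `start, start + 1, ...`.
    /// Full groups of four are written by the unrolled loop; the remaining
    /// 0 to 3 slots are written one by one.
    fn get_range(&self, start: u64, output: &mut [u64]) {
        let mut chunks = output.chunks_exact_mut(4);
        let mut idx = start;
        for chunk in chunks.by_ref() {
            chunk[0] = self.get_val(idx as u32);
            chunk[1] = self.get_val((idx + 1) as u32);
            chunk[2] = self.get_val((idx + 2) as u32);
            chunk[3] = self.get_val((idx + 3) as u32);
            idx += 4;
        }
        for out in chunks.into_remainder() {
            *out = self.get_val(idx as u32);
            idx += 1;
        }
    }
}
```

Calling `get_range(3, &mut out)` with a 7-element `out` exercises both paths: one unrolled group of four, then a remainder of three.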

View File

@@ -121,6 +121,22 @@ pub(crate) fn create_and_validate<TColumnCodec: ColumnCodec>(
reader.get_vals(&all_docs, &mut buffer);
assert_eq!(vals, buffer);
// Validate `get_range` over the full column and a sub-range. The sub-range starts
// at a non-zero offset to exercise the entrance-ramp alignment of the batch decode.
buffer.resize(all_docs.len(), 0);
reader.get_range(0, &mut buffer);
assert_eq!(vals, buffer, "get_range (full) mismatch in data set {name}");
if vals.len() >= 2 {
let start = 1usize;
buffer.resize(vals.len() - start, 0);
reader.get_range(start as u64, &mut buffer);
assert_eq!(
&vals[start..],
&buffer[..],
"get_range (sub-range) mismatch in data set {name}"
);
}
if !vals.is_empty() {
let test_rand_idx = rand::rng().random_range(0..=vals.len() - 1);
let expected_positions: Vec<u32> = vals

View File

@@ -33,6 +33,25 @@ pub fn merge_bytes_or_str_column(
Ok(())
}
/// Computes a per-segment mapping from old term ordinal to merged term ordinal.
///
/// Performs a streaming k-way merge of per-segment term dictionaries (SSTable-backed) to build
/// a unified ordering. For each segment, the output is a `Vec<TermOrdinal>` where index `i`
/// holds the merged global ordinal corresponding to segment-local ordinal `i`.
///
/// This is used by index sorting to compare terms from different segments without materializing
/// term bytes in memory — only ordinals are compared.
#[doc(hidden)]
pub fn compute_merged_term_ord_mapping(
bytes_columns: &[BytesColumn],
) -> io::Result<Vec<Vec<TermOrdinal>>> {
let bytes_columns_opt: Vec<Option<BytesColumn>> =
bytes_columns.iter().cloned().map(Some).collect();
let term_ord_mapping =
merge_dict_and_compute_term_ord_mapping(&bytes_columns_opt, |_| true, |_| Ok(()))?;
Ok(term_ord_mapping.into_per_segment_new_term_ordinals())
}
struct RemappedTermOrdinalsValues<'a> {
bytes_columns: &'a [Option<BytesColumn>],
term_ord_mapping: &'a TermOrdinalMapping,
@@ -118,14 +137,14 @@ fn is_term_present(bitsets: &[Option<BitSet>], term_merger: &TermMerger) -> bool
false
}
fn serialize_merged_dict(
fn merge_dict_and_compute_term_ord_mapping(
bytes_columns: &[Option<BytesColumn>],
merge_row_order: &MergeRowOrder,
output: &mut impl Write,
mut should_keep_term: impl FnMut(&TermMerger) -> bool,
mut emit_term: impl FnMut(&[u8]) -> io::Result<()>,
) -> io::Result<TermOrdinalMapping> {
let mut term_ord_mapping = TermOrdinalMapping::default();
let mut field_term_streams = Vec::new();
let mut field_term_streams = Vec::with_capacity(bytes_columns.len());
for (segment_ord, column_opt) in bytes_columns.iter().enumerate() {
if let Some(column) = column_opt {
term_ord_mapping.add_segment(column.dictionary.num_terms());
@@ -141,21 +160,33 @@ fn serialize_merged_dict(
}
let mut merged_terms = TermMerger::new(field_term_streams);
let mut sstable_builder = sstable::VoidSSTable::writer(output);
match merge_row_order {
MergeRowOrder::Stack(_) => {
let mut current_term_ord = 0;
while merged_terms.advance() {
let term_bytes: &[u8] = merged_terms.key();
sstable_builder.insert(term_bytes, &())?;
for (segment_ord, from_term_ord) in merged_terms.matching_segments() {
term_ord_mapping.register_from_to(segment_ord, from_term_ord, current_term_ord);
}
current_term_ord += 1;
}
sstable_builder.finish()?;
let mut current_term_ord = 0;
while merged_terms.advance() {
if !should_keep_term(&merged_terms) {
continue;
}
emit_term(merged_terms.key())?;
for (segment_ord, from_term_ord) in merged_terms.matching_segments() {
term_ord_mapping.register_from_to(segment_ord, from_term_ord, current_term_ord);
}
current_term_ord += 1;
}
Ok(term_ord_mapping)
}
fn serialize_merged_dict(
bytes_columns: &[Option<BytesColumn>],
merge_row_order: &MergeRowOrder,
output: &mut impl Write,
) -> io::Result<TermOrdinalMapping> {
let mut sstable_builder = sstable::VoidSSTable::writer(output);
let term_ord_mapping = match merge_row_order {
MergeRowOrder::Stack(_) => merge_dict_and_compute_term_ord_mapping(
bytes_columns,
|_| true,
|term_bytes| sstable_builder.insert(term_bytes, &()),
)?,
MergeRowOrder::Shuffled(shuffle_merge_order) => {
assert_eq!(shuffle_merge_order.alive_bitsets.len(), bytes_columns.len());
let mut term_bitsets: Vec<Option<BitSet>> = Vec::with_capacity(bytes_columns.len());
@@ -174,21 +205,14 @@ fn serialize_merged_dict(
}
}
}
let mut current_term_ord = 0;
while merged_terms.advance() {
let term_bytes: &[u8] = merged_terms.key();
if !is_term_present(&term_bitsets[..], &merged_terms) {
continue;
}
sstable_builder.insert(term_bytes, &())?;
for (segment_ord, from_term_ord) in merged_terms.matching_segments() {
term_ord_mapping.register_from_to(segment_ord, from_term_ord, current_term_ord);
}
current_term_ord += 1;
}
sstable_builder.finish()?;
merge_dict_and_compute_term_ord_mapping(
bytes_columns,
|merged_terms| is_term_present(&term_bitsets[..], merged_terms),
|term_bytes| sstable_builder.insert(term_bytes, &()),
)?
}
}
};
sstable_builder.finish()?;
Ok(term_ord_mapping)
}
@@ -211,4 +235,8 @@ impl TermOrdinalMapping {
fn get_segment(&self, segment_ord: u32) -> &[TermOrdinal] {
&self.per_segment_new_term_ordinals[segment_ord as usize]
}
fn into_per_segment_new_term_ordinals(self) -> Vec<Vec<TermOrdinal>> {
self.per_segment_new_term_ordinals
}
}
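The idea behind `compute_merged_term_ord_mapping` can be illustrated without the columnar types. The sketch below is an assumption-laden simplification: dictionaries are plain sorted `Vec<&[u8]>` values and the k-way merge uses a `BinaryHeap` rather than the crate's `TermMerger`. The function name `merged_term_ord_mapping` is hypothetical.

```rust
use std::cmp::Reverse;
use std::collections::BinaryHeap;

/// Maps each segment-local term ordinal to its ordinal in the merged dictionary.
/// `segments[s]` must be sorted and free of duplicates, as a term dictionary is.
fn merged_term_ord_mapping(segments: &[Vec<&[u8]>]) -> Vec<Vec<u64>> {
    let mut mapping: Vec<Vec<u64>> = segments
        .iter()
        .map(|terms| vec![0u64; terms.len()])
        .collect();
    // Min-heap of (term, segment ordinal, local ordinal): one cursor per segment.
    let mut heap = BinaryHeap::new();
    for (seg, terms) in segments.iter().enumerate() {
        if let Some(&first) = terms.first() {
            heap.push(Reverse((first, seg, 0usize)));
        }
    }
    let mut global_ord = 0u64;
    let mut previous: Option<&[u8]> = None;
    while let Some(Reverse((term, seg, local))) = heap.pop() {
        // A new distinct term takes the next global ordinal;
        // the same term seen in another segment shares the current one.
        if let Some(prev) = previous {
            if prev != term {
                global_ord += 1;
            }
        }
        previous = Some(term);
        mapping[seg][local] = global_ord;
        // Advance this segment's cursor.
        if let Some(&next) = segments[seg].get(local + 1) {
            heap.push(Reverse((next, seg, local + 1)));
        }
    }
    mapping
}
```

With the mapping in hand, two documents from different segments can be compared by their remapped `u64` ordinals alone, which is what the doc comment above describes.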

View File

@@ -7,6 +7,7 @@ use std::io;
use std::net::Ipv6Addr;
use std::sync::Arc;
pub use merge_dict_column::compute_merged_term_ord_mapping;
pub use merge_mapping::{MergeRowOrder, ShuffleMergeOrder, StackMergeOrder};
use super::writer::ColumnarSerializer;

View File

@@ -17,7 +17,7 @@ fn make_columnar<T: Into<NumericalValue> + HasAssociatedColumnType + Copy>(
}
let mut buffer: Vec<u8> = Vec::new();
dataframe_writer
.serialize(vals.len() as RowId, &mut buffer)
.serialize(vals.len() as RowId, None, &mut buffer)
.unwrap();
ColumnarReader::open(buffer).unwrap()
}
@@ -143,7 +143,9 @@ fn make_numerical_columnar_multiple_columns(
.max()
.unwrap_or(0u32);
let mut buffer: Vec<u8> = Vec::new();
dataframe_writer.serialize(num_rows, &mut buffer).unwrap();
dataframe_writer
.serialize(num_rows, None, &mut buffer)
.unwrap();
ColumnarReader::open(buffer).unwrap()
}
@@ -166,7 +168,9 @@ fn make_byte_columnar_multiple_columns(
}
}
let mut buffer: Vec<u8> = Vec::new();
dataframe_writer.serialize(num_rows, &mut buffer).unwrap();
dataframe_writer
.serialize(num_rows, None, &mut buffer)
.unwrap();
ColumnarReader::open(buffer).unwrap()
}
@@ -185,7 +189,9 @@ fn make_text_columnar_multiple_columns(columns: &[(&str, &[&[&str]])]) -> Column
.max()
.unwrap_or(0u32);
let mut buffer: Vec<u8> = Vec::new();
dataframe_writer.serialize(num_rows, &mut buffer).unwrap();
dataframe_writer
.serialize(num_rows, None, &mut buffer)
.unwrap();
ColumnarReader::open(buffer).unwrap()
}
@@ -544,7 +550,7 @@ fn build_columnar(spec: &ColumnarSpec) -> ColumnarReader {
}
let mut buffer = Vec::new();
writer.serialize(max_row_id + 1, &mut buffer).unwrap();
writer.serialize(max_row_id + 1, None, &mut buffer).unwrap();
ColumnarReader::open(buffer).unwrap()
}

View File

@@ -8,6 +8,9 @@ pub use column_type::{ColumnType, HasAssociatedColumnType};
pub use format_version::{CURRENT_VERSION, Version};
#[cfg(test)]
pub(crate) use merge::ColumnTypeCategory;
pub use merge::{MergeRowOrder, ShuffleMergeOrder, StackMergeOrder, merge_columnar};
pub use merge::{
MergeRowOrder, ShuffleMergeOrder, StackMergeOrder, compute_merged_term_ord_mapping,
merge_columnar,
};
pub use reader::ColumnarReader;
pub use writer::ColumnarWriter;

View File

@@ -226,7 +226,7 @@ mod tests {
columnar_writer.record_column_type("col1", ColumnType::Str, false);
columnar_writer.record_column_type("col2", ColumnType::U64, false);
let mut buffer = Vec::new();
columnar_writer.serialize(1, &mut buffer).unwrap();
columnar_writer.serialize(1, None, &mut buffer).unwrap();
let columnar = ColumnarReader::open(buffer).unwrap();
let columns = columnar.list_columns().unwrap();
assert_eq!(columns.len(), 2);
@@ -242,7 +242,7 @@ mod tests {
columnar_writer.record_column_type("count", ColumnType::U64, false);
columnar_writer.record_numerical(1, "count", 1u64);
let mut buffer = Vec::new();
columnar_writer.serialize(2, &mut buffer).unwrap();
columnar_writer.serialize(2, None, &mut buffer).unwrap();
let columnar = ColumnarReader::open(buffer).unwrap();
let columns = columnar.list_columns().unwrap();
assert_eq!(columns.len(), 1);
@@ -256,7 +256,7 @@ mod tests {
columnar_writer.record_column_type("col", ColumnType::U64, false);
columnar_writer.record_numerical(1, "col", 1u64);
let mut buffer = Vec::new();
columnar_writer.serialize(2, &mut buffer).unwrap();
columnar_writer.serialize(2, None, &mut buffer).unwrap();
let columnar = ColumnarReader::open(buffer).unwrap();
{
let columns = columnar.read_columns("col").unwrap();
@@ -285,7 +285,7 @@ mod tests {
columnar_writer.record_str(1, "col1", "hello");
columnar_writer.record_str(0, "col2", "hello");
let mut buffer = Vec::new();
columnar_writer.serialize(2, &mut buffer).unwrap();
columnar_writer.serialize(2, None, &mut buffer).unwrap();
let columnar = ColumnarReader::open(buffer).unwrap();
{

View File

@@ -41,10 +41,31 @@ impl ColumnWriter {
pub(super) fn operation_iterator<'a, V: SymbolValue>(
&self,
arena: &MemoryArena,
old_to_new_ids_opt: Option<&[RowId]>,
buffer: &'a mut Vec<u8>,
) -> impl Iterator<Item = ColumnOperation<V>> + 'a + use<'a, V> {
buffer.clear();
self.values.read_to_end(arena, buffer);
if let Some(old_to_new_ids) = old_to_new_ids_opt {
// TODO avoid the extra deserialization / serialization.
let mut sorted_ops: Vec<(RowId, ColumnOperation<V>)> = Vec::new();
let mut new_doc = 0u32;
let mut cursor = &buffer[..];
for op in std::iter::from_fn(|| ColumnOperation::<V>::deserialize(&mut cursor)) {
if let ColumnOperation::NewDoc(doc) = &op {
new_doc = old_to_new_ids[*doc as usize];
sorted_ops.push((new_doc, ColumnOperation::NewDoc(new_doc)));
} else {
sorted_ops.push((new_doc, op));
}
}
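// A stable sort is crucial here: it keeps each doc's `NewDoc` marker ahead of its
// values and preserves the order of the values within a doc.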
sorted_ops.sort_by_key(|(new_doc_id, _)| *new_doc_id);
buffer.clear();
for (_, op) in sorted_ops {
buffer.extend_from_slice(op.serialize().as_ref());
}
}
let mut cursor: &[u8] = &buffer[..];
std::iter::from_fn(move || ColumnOperation::deserialize(&mut cursor))
}
@@ -211,9 +232,11 @@ impl NumericalColumnWriter {
pub(super) fn operation_iterator<'a>(
self,
arena: &MemoryArena,
old_to_new_ids: Option<&[RowId]>,
buffer: &'a mut Vec<u8>,
) -> impl Iterator<Item = ColumnOperation<NumericalValue>> + 'a + use<'a> {
self.column_writer.operation_iterator(arena, buffer)
self.column_writer
.operation_iterator(arena, old_to_new_ids, buffer)
}
}
@@ -255,9 +278,11 @@ impl StrOrBytesColumnWriter {
pub(super) fn operation_iterator<'a>(
&self,
arena: &MemoryArena,
old_to_new_ids: Option<&[RowId]>,
byte_buffer: &'a mut Vec<u8>,
) -> impl Iterator<Item = ColumnOperation<UnorderedId>> + 'a + use<'a> {
self.column_writer.operation_iterator(arena, byte_buffer)
self.column_writer
.operation_iterator(arena, old_to_new_ids, byte_buffer)
}
}
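The row-id remapping performed in `operation_iterator` can be sketched on a simplified operation type. `Op` and `remap_ops` below are hypothetical names introduced for illustration; the real code works on serialized `ColumnOperation<V>` values read from the arena.

```rust
/// Simplified stand-in for a column operation stream.
#[derive(Debug, PartialEq)]
enum Op {
    NewDoc(u32),
    Value(u64),
}

/// Rewrites doc ids through `old_to_new` and reorders the stream by new doc id.
fn remap_ops(ops: &[Op], old_to_new: &[u32]) -> Vec<Op> {
    let mut keyed: Vec<(u32, Op)> = Vec::with_capacity(ops.len());
    let mut new_doc = 0u32;
    for op in ops {
        match op {
            Op::NewDoc(old_doc) => {
                new_doc = old_to_new[*old_doc as usize];
                keyed.push((new_doc, Op::NewDoc(new_doc)));
            }
            // Values inherit the doc id of the most recent `NewDoc`.
            Op::Value(value) => keyed.push((new_doc, Op::Value(*value))),
        }
    }
    // `sort_by_key` is stable, so within one doc the `NewDoc` marker stays
    // first and the values keep their recorded order.
    keyed.sort_by_key(|(doc, _)| *doc);
    keyed.into_iter().map(|(_, op)| op).collect()
}
```

Swapping in an unstable sort here could move a value ahead of its `NewDoc` marker, which would attach it to the wrong document when the stream is replayed.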

View File

@@ -44,7 +44,7 @@ struct SpareBuffers {
/// columnar_writer.record_str(1u32 /* doc id */, "product_name", "Apple");
/// columnar_writer.record_numerical(0u32 /* doc id */, "price", 10.5f64); //< uh oh we ended up mixing integer and floats.
/// let mut wrt: Vec<u8> = Vec::new();
/// columnar_writer.serialize(2u32, &mut wrt).unwrap();
/// columnar_writer.serialize(2u32, None, &mut wrt).unwrap();
/// ```
#[derive(Default)]
pub struct ColumnarWriter {
@@ -76,6 +76,75 @@ impl ColumnarWriter {
.sum::<usize>()
}
/// Returns the doc ids in `0..num_docs`, sorted by the `sort_field` column.
///
/// If the column is multivalued, only the first value of each row is used as the sort key.
/// A row with no value sorts before every row that has one (after them when `reversed`
/// is set).
///
/// Returns an empty `Vec` if there is no numerical, datetime, str or bytes column named
/// `sort_field`.
///
/// The sort applied is stable.
pub fn sort_order(&self, sort_field: &str, num_docs: RowId, reversed: bool) -> Vec<u32> {
let Some(numerical_col_writer) = self
.numerical_field_hash_map
.get::<NumericalColumnWriter>(sort_field.as_bytes())
.or_else(|| {
self.datetime_field_hash_map
.get::<NumericalColumnWriter>(sort_field.as_bytes())
})
else {
let str_or_bytes_column_opt = self
.str_field_hash_map
.get::<StrOrBytesColumnWriter>(sort_field.as_bytes())
.or_else(|| {
self.bytes_field_hash_map
.get::<StrOrBytesColumnWriter>(sort_field.as_bytes())
});
let Some(str_or_bytes_column) = str_or_bytes_column_opt else {
return Vec::new();
};
let dictionary_builder = &self.dictionaries[str_or_bytes_column.dictionary_id as usize];
let term_id_mapping = dictionary_builder.build_term_id_mapping(&self.arena);
let mut symbols_buffer = Vec::new();
return collect_sort_order_from_ops(
str_or_bytes_column.operation_iterator(&self.arena, None, &mut symbols_buffer),
num_docs,
reversed,
|uid| Some(term_id_mapping.to_ord(uid).0),
None,
|a, b| a.cmp(b),
);
};
let mut symbols_buffer = Vec::new();
collect_sort_order_from_ops(
numerical_col_writer.operation_iterator(&self.arena, None, &mut symbols_buffer),
num_docs,
reversed,
// MonotonicallyMappableToU64 converts each value to u64 in an
// order-preserving way (u64: identity, i64: XOR sign bit, f64: bit
// manipulation). Converting once per document lets the comparator be
// a simple u64 cmp instead of unwrapping the NumericalValue variant
// on every comparison.
//
// For f64, NaN maps to a deterministic u64 via raw bit manipulation,
// so it sorts to a consistent position. Sorting only requires total
// ordering, not IEEE 754 equality semantics where NaN != NaN.
|nv| {
Some(match nv {
NumericalValue::U64(v) => v.to_u64(),
NumericalValue::I64(v) => v.to_u64(),
NumericalValue::F64(v) => v.to_u64(),
})
},
// None for missing values. Option<u64> sorts None < Some(_),
// placing nulls before non-null values.
None,
|a, b| a.cmp(b),
)
}
/// Records a column type. This is useful to bypass the coercion process,
/// make sure an empty column is present in the resulting columnar, or set
/// `sort_values_within_row`.
@@ -246,7 +315,12 @@ impl ColumnarWriter {
},
);
}
pub fn serialize(&mut self, num_docs: RowId, wrt: &mut dyn io::Write) -> io::Result<()> {
pub fn serialize(
&mut self,
num_docs: RowId,
old_to_new_row_ids: Option<&[RowId]>,
wrt: &mut dyn io::Write,
) -> io::Result<()> {
let mut serializer = ColumnarSerializer::new(wrt);
let mut columns: Vec<(&[u8], ColumnType, Addr)> = self
@@ -303,7 +377,11 @@ impl ColumnarWriter {
serialize_bool_column(
cardinality,
num_docs,
column_writer.operation_iterator(arena, &mut symbol_byte_buffer),
column_writer.operation_iterator(
arena,
old_to_new_row_ids,
&mut symbol_byte_buffer,
),
buffers,
&mut column_serializer,
)?;
@@ -317,7 +395,11 @@ impl ColumnarWriter {
serialize_ip_addr_column(
cardinality,
num_docs,
column_writer.operation_iterator(arena, &mut symbol_byte_buffer),
column_writer.operation_iterator(
arena,
old_to_new_row_ids,
&mut symbol_byte_buffer,
),
buffers,
&mut column_serializer,
)?;
@@ -342,8 +424,11 @@ impl ColumnarWriter {
num_docs,
str_or_bytes_column_writer.sort_values_within_row,
dictionary_builder,
str_or_bytes_column_writer
.operation_iterator(arena, &mut symbol_byte_buffer),
str_or_bytes_column_writer.operation_iterator(
arena,
old_to_new_row_ids,
&mut symbol_byte_buffer,
),
buffers,
&self.arena,
&mut column_serializer,
@@ -361,7 +446,11 @@ impl ColumnarWriter {
cardinality,
num_docs,
numerical_type,
numerical_column_writer.operation_iterator(arena, &mut symbol_byte_buffer),
numerical_column_writer.operation_iterator(
arena,
old_to_new_row_ids,
&mut symbol_byte_buffer,
),
buffers,
&mut column_serializer,
)?;
@@ -376,7 +465,11 @@ impl ColumnarWriter {
cardinality,
num_docs,
NumericalType::I64,
column_writer.operation_iterator(arena, &mut symbol_byte_buffer),
column_writer.operation_iterator(
arena,
old_to_new_row_ids,
&mut symbol_byte_buffer,
),
buffers,
&mut column_serializer,
)?;
@@ -389,6 +482,56 @@ impl ColumnarWriter {
}
}
/// Shared sorting pattern for both numeric and Str/Bytes sort fields.
///
/// Iterates column operations, fills gaps for missing docs with `default_key`, converts each value
/// to a sort key via `value_to_key`, then sorts by the key using `cmp_keys`. Returns the doc ids
/// in sorted order.
fn collect_sort_order_from_ops<V, K: Clone>(
ops: impl Iterator<Item = ColumnOperation<V>>,
num_docs: RowId,
reversed: bool,
value_to_key: impl Fn(V) -> K,
default_key: K,
cmp_keys: impl Fn(&K, &K) -> std::cmp::Ordering,
) -> Vec<u32> {
let mut doc_sort_keys: Vec<(K, RowId)> = Vec::with_capacity(num_docs as usize);
let mut start_doc_check_fill: RowId = 0;
let mut current_doc_opt: Option<RowId> = None;
for op in ops {
match op {
ColumnOperation::NewDoc(doc) => {
current_doc_opt = Some(doc);
}
ColumnOperation::Value(val) => {
if let Some(current_doc) = current_doc_opt {
// Fill gaps since the last doc with the default key.
doc_sort_keys.extend(
(start_doc_check_fill..current_doc).map(|doc| (default_key.clone(), doc)),
);
start_doc_check_fill = current_doc + 1;
// For multivalued fields, only the first value is used.
current_doc_opt = None;
doc_sort_keys.push((value_to_key(val), current_doc));
}
}
}
}
// Fill remaining docs at the tail.
doc_sort_keys.extend((start_doc_check_fill..num_docs).map(|doc| (default_key.clone(), doc)));
doc_sort_keys.sort_by(|(left_key, _), (right_key, _)| {
let cmp = cmp_keys(left_key, right_key);
if reversed { cmp.reverse() } else { cmp }
});
doc_sort_keys
.into_iter()
.map(|(_sort_key, doc)| doc)
.collect()
}
// Serialize [Dictionary, Column, dictionary num bytes U32::LE]
// Column: [Column Index, Column Values, column index num bytes U32::LE]
#[expect(clippy::too_many_arguments)]
@@ -689,7 +832,7 @@ mod tests {
assert_eq!(column_writer.get_cardinality(3), Cardinality::Full);
let mut buffer = Vec::new();
let symbols: Vec<ColumnOperation<NumericalValue>> = column_writer
.operation_iterator(&arena, &mut buffer)
.operation_iterator(&arena, None, &mut buffer)
.collect();
assert_eq!(symbols.len(), 6);
assert!(matches!(symbols[0], ColumnOperation::NewDoc(0u32)));
@@ -718,7 +861,7 @@ mod tests {
assert_eq!(column_writer.get_cardinality(3), Cardinality::Optional);
let mut buffer = Vec::new();
let symbols: Vec<ColumnOperation<NumericalValue>> = column_writer
.operation_iterator(&arena, &mut buffer)
.operation_iterator(&arena, None, &mut buffer)
.collect();
assert_eq!(symbols.len(), 4);
assert!(matches!(symbols[0], ColumnOperation::NewDoc(1u32)));
@@ -741,7 +884,7 @@ mod tests {
assert_eq!(column_writer.get_cardinality(2), Cardinality::Optional);
let mut buffer = Vec::new();
let symbols: Vec<ColumnOperation<NumericalValue>> = column_writer
.operation_iterator(&arena, &mut buffer)
.operation_iterator(&arena, None, &mut buffer)
.collect();
assert_eq!(symbols.len(), 2);
assert!(matches!(symbols[0], ColumnOperation::NewDoc(0u32)));
@@ -760,7 +903,7 @@ mod tests {
assert_eq!(column_writer.get_cardinality(1), Cardinality::Multivalued);
let mut buffer = Vec::new();
let symbols: Vec<ColumnOperation<NumericalValue>> = column_writer
.operation_iterator(&arena, &mut buffer)
.operation_iterator(&arena, None, &mut buffer)
.collect();
assert_eq!(symbols.len(), 3);
assert!(matches!(symbols[0], ColumnOperation::NewDoc(0u32)));
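The key construction used by `sort_order` and `collect_sort_order_from_ops` in this file can be sketched on its own. The example below assumes the first value of each document has already been extracted into an `Option<i64>`; `i64_to_u64` and `sort_order` are hypothetical free functions, and the sign-bit flip is the standard order-preserving `i64 -> u64` mapping that the comment in the diff describes.

```rust
/// Order-preserving `i64 -> u64` mapping: flipping the sign bit moves
/// negative numbers below non-negative ones in unsigned order.
fn i64_to_u64(value: i64) -> u64 {
    (value as u64) ^ (1u64 << 63)
}

/// Returns doc ids sorted by their first value. `None` marks a missing value.
fn sort_order(first_values: &[Option<i64>], reversed: bool) -> Vec<u32> {
    let mut keyed: Vec<(Option<u64>, u32)> = first_values
        .iter()
        .enumerate()
        .map(|(doc, value)| (value.map(i64_to_u64), doc as u32))
        .collect();
    // `Option<u64>` orders `None` before every `Some(_)`, so missing values
    // come first in ascending order. `sort_by` is stable, so ties keep doc order.
    keyed.sort_by(|(left, _), (right, _)| {
        let ordering = left.cmp(right);
        if reversed { ordering.reverse() } else { ordering }
    });
    keyed.into_iter().map(|(_, doc)| doc).collect()
}
```

Because the comparator is reversed rather than the output, tied documents keep ascending doc-id order in both directions.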

View File

@@ -27,7 +27,7 @@ fn generate_columnar(num_docs: u32, value_offset: u64) -> Vec<u8> {
}
let mut wrt: Vec<u8> = Vec::new();
columnar_writer.serialize(num_docs, &mut wrt).unwrap();
columnar_writer.serialize(num_docs, None, &mut wrt).unwrap();
wrt
}

View File

@@ -51,6 +51,16 @@ impl DictionaryBuilder {
UnorderedId(unordered_id)
}
fn build_sorted_terms<'a>(&'a self, arena: &'a MemoryArena) -> Vec<(&'a [u8], UnorderedId)> {
let mut terms: Vec<(&[u8], UnorderedId)> = self
.dict
.iter(arena)
.map(|(k, v)| (k, arena.read(v)))
.collect();
terms.sort_unstable_by_key(|(key, _)| *key);
terms
}
/// Serializes the dictionary into an sstable and returns the
/// `UnorderedId -> TermOrdinal` map.
pub fn serialize<'a, W: io::Write + 'a>(
@@ -58,12 +68,7 @@ impl DictionaryBuilder {
arena: &MemoryArena,
wrt: &mut W,
) -> io::Result<TermIdMapping> {
let mut terms: Vec<(&[u8], UnorderedId)> = self
.dict
.iter(arena)
.map(|(k, v)| (k, arena.read(v)))
.collect();
terms.sort_unstable_by_key(|(key, _)| *key);
let terms = self.build_sorted_terms(arena);
// TODO Remove the allocation.
let mut unordered_to_ord: Vec<OrderedId> = vec![OrderedId(0u32); terms.len()];
let mut sstable_builder = sstable::VoidSSTable::writer(wrt);
@@ -76,6 +81,16 @@ impl DictionaryBuilder {
Ok(TermIdMapping { unordered_to_ord })
}
/// Build the `UnorderedId -> OrderedId` mapping in memory without serializing.
pub fn build_term_id_mapping(&self, arena: &MemoryArena) -> TermIdMapping {
let terms = self.build_sorted_terms(arena);
let mut unordered_to_ord: Vec<OrderedId> = vec![OrderedId(0u32); terms.len()];
for (ord, (_key, unordered_id)) in terms.into_iter().enumerate() {
unordered_to_ord[unordered_id.0 as usize] = OrderedId(ord as u32);
}
TermIdMapping { unordered_to_ord }
}
pub(crate) fn mem_usage(&self) -> usize {
self.dict.mem_usage()
}
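What `build_term_id_mapping` computes can be shown with plain slices. This is a sketch under the assumption that terms are distinct and that a term's insertion position plays the role of its `UnorderedId`; the free function below is hypothetical and returns a bare `Vec<u32>` instead of a `TermIdMapping`.

```rust
/// Given distinct terms in insertion order (index = unordered id), returns a
/// table where `table[unordered_id]` is the term's rank in sorted order.
fn build_term_id_mapping(terms_in_insertion_order: &[&[u8]]) -> Vec<u32> {
    let mut sorted: Vec<(&[u8], usize)> = terms_in_insertion_order
        .iter()
        .enumerate()
        .map(|(unordered_id, &term)| (term, unordered_id))
        .collect();
    // Terms are distinct, so an unstable sort cannot reorder equal keys.
    sorted.sort_unstable_by_key(|(term, _)| *term);
    let mut unordered_to_ord = vec![0u32; sorted.len()];
    for (ord, (_term, unordered_id)) in sorted.into_iter().enumerate() {
        unordered_to_ord[unordered_id] = ord as u32;
    }
    unordered_to_ord
}
```

Looking up a document's unordered id in the returned table yields a key that sorts exactly like the term bytes, so the sort never has to touch the bytes again.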

View File

@@ -43,7 +43,8 @@ pub use column_values::{
};
pub use columnar::{
CURRENT_VERSION, ColumnType, ColumnarReader, ColumnarWriter, HasAssociatedColumnType,
MergeRowOrder, ShuffleMergeOrder, StackMergeOrder, Version, merge_columnar,
MergeRowOrder, ShuffleMergeOrder, StackMergeOrder, Version, compute_merged_term_ord_mapping,
merge_columnar,
};
use sstable::VoidSSTable;
pub use value::{NumericalType, NumericalValue};

View File

@@ -21,7 +21,7 @@ fn test_dataframe_writer_str() {
dataframe_writer.record_str(1u32, "my_string", "hello");
dataframe_writer.record_str(3u32, "my_string", "helloeee");
let mut buffer: Vec<u8> = Vec::new();
dataframe_writer.serialize(5, &mut buffer).unwrap();
dataframe_writer.serialize(5, None, &mut buffer).unwrap();
let columnar = ColumnarReader::open(buffer).unwrap();
assert_eq!(columnar.num_columns(), 1);
let cols: Vec<DynamicColumnHandle> = columnar.read_columns("my_string").unwrap();
@@ -35,7 +35,7 @@ fn test_dataframe_writer_bytes() {
dataframe_writer.record_bytes(1u32, "my_string", b"hello");
dataframe_writer.record_bytes(3u32, "my_string", b"helloeee");
let mut buffer: Vec<u8> = Vec::new();
dataframe_writer.serialize(5, &mut buffer).unwrap();
dataframe_writer.serialize(5, None, &mut buffer).unwrap();
let columnar = ColumnarReader::open(buffer).unwrap();
assert_eq!(columnar.num_columns(), 1);
let cols: Vec<DynamicColumnHandle> = columnar.read_columns("my_string").unwrap();
@@ -49,7 +49,7 @@ fn test_dataframe_writer_bool() {
dataframe_writer.record_bool(1u32, "bool.value", false);
dataframe_writer.record_bool(3u32, "bool.value", true);
let mut buffer: Vec<u8> = Vec::new();
dataframe_writer.serialize(5, &mut buffer).unwrap();
dataframe_writer.serialize(5, None, &mut buffer).unwrap();
let columnar = ColumnarReader::open(buffer).unwrap();
assert_eq!(columnar.num_columns(), 1);
let cols: Vec<DynamicColumnHandle> = columnar.read_columns("bool.value").unwrap();
@@ -74,7 +74,7 @@ fn test_dataframe_writer_u64_multivalued() {
dataframe_writer.record_numerical(6u32, "divisor", 2u64);
dataframe_writer.record_numerical(6u32, "divisor", 3u64);
let mut buffer: Vec<u8> = Vec::new();
dataframe_writer.serialize(7, &mut buffer).unwrap();
dataframe_writer.serialize(7, None, &mut buffer).unwrap();
let columnar = ColumnarReader::open(buffer).unwrap();
assert_eq!(columnar.num_columns(), 1);
let cols: Vec<DynamicColumnHandle> = columnar.read_columns("divisor").unwrap();
@@ -97,7 +97,7 @@ fn test_dataframe_writer_ip_addr() {
dataframe_writer.record_ip_addr(1, "ip_addr", Ipv6Addr::from_u128(1001));
dataframe_writer.record_ip_addr(3, "ip_addr", Ipv6Addr::from_u128(1050));
let mut buffer: Vec<u8> = Vec::new();
dataframe_writer.serialize(5, &mut buffer).unwrap();
dataframe_writer.serialize(5, None, &mut buffer).unwrap();
let columnar = ColumnarReader::open(buffer).unwrap();
assert_eq!(columnar.num_columns(), 1);
let cols: Vec<DynamicColumnHandle> = columnar.read_columns("ip_addr").unwrap();
@@ -128,7 +128,7 @@ fn test_dataframe_writer_numerical() {
dataframe_writer.record_numerical(2u32, "srical.value", NumericalValue::U64(13u64));
dataframe_writer.record_numerical(4u32, "srical.value", NumericalValue::U64(15u64));
let mut buffer: Vec<u8> = Vec::new();
dataframe_writer.serialize(6, &mut buffer).unwrap();
dataframe_writer.serialize(6, None, &mut buffer).unwrap();
let columnar = ColumnarReader::open(buffer).unwrap();
assert_eq!(columnar.num_columns(), 1);
let cols: Vec<DynamicColumnHandle> = columnar.read_columns("srical.value").unwrap();
@@ -153,6 +153,46 @@ fn test_dataframe_writer_numerical() {
assert_eq!(column_i64.first(6), None); //< we can change the spec for that one.
}
#[test]
fn test_dataframe_sort_by_full() {
let mut dataframe_writer = ColumnarWriter::default();
dataframe_writer.record_numerical(0u32, "value", NumericalValue::U64(1));
dataframe_writer.record_numerical(1u32, "value", NumericalValue::U64(2));
let data = dataframe_writer.sort_order("value", 2, false);
assert_eq!(data, vec![0, 1]);
}
#[test]
fn test_dataframe_sort_by_opt() {
let mut dataframe_writer = ColumnarWriter::default();
dataframe_writer.record_numerical(1u32, "value", NumericalValue::U64(3));
dataframe_writer.record_numerical(3u32, "value", NumericalValue::U64(2));
let data = dataframe_writer.sort_order("value", 5, false);
// Docs 0, 2 and 4 have no value (None) and sort first, in doc order.
assert_eq!(data, vec![0, 2, 4, 3, 1]);
let data = dataframe_writer.sort_order("value", 5, true);
assert_eq!(
data,
vec![4, 2, 0, 3, 1].into_iter().rev().collect::<Vec<_>>()
);
}
#[test]
fn test_dataframe_sort_by_multi() {
let mut dataframe_writer = ColumnarWriter::default();
// valid for sort
dataframe_writer.record_numerical(1u32, "value", NumericalValue::U64(2));
// those are ignored for sort
dataframe_writer.record_numerical(1u32, "value", NumericalValue::U64(4));
dataframe_writer.record_numerical(1u32, "value", NumericalValue::U64(4));
// valid for sort
dataframe_writer.record_numerical(3u32, "value", NumericalValue::U64(3));
// ignored, would change sort order
dataframe_writer.record_numerical(3u32, "value", NumericalValue::U64(1));
let data = dataframe_writer.sort_order("value", 4, false);
assert_eq!(data, vec![0, 2, 1, 3]);
}
#[test]
fn test_dictionary_encoded_str() {
let mut buffer = Vec::new();
@@ -161,7 +201,7 @@ fn test_dictionary_encoded_str() {
columnar_writer.record_str(3, "my.column", "c");
columnar_writer.record_str(3, "my.column2", "different_column!");
columnar_writer.record_str(4, "my.column", "b");
columnar_writer.serialize(5, &mut buffer).unwrap();
columnar_writer.serialize(5, None, &mut buffer).unwrap();
let columnar_reader = ColumnarReader::open(buffer).unwrap();
assert_eq!(columnar_reader.num_columns(), 2);
let col_handles = columnar_reader.read_columns("my.column").unwrap();
@@ -195,7 +235,7 @@ fn test_dictionary_encoded_bytes() {
columnar_writer.record_bytes(3, "my.column", b"c");
columnar_writer.record_bytes(3, "my.column2", b"different_column!");
columnar_writer.record_bytes(4, "my.column", b"b");
columnar_writer.serialize(5, &mut buffer).unwrap();
columnar_writer.serialize(5, None, &mut buffer).unwrap();
let columnar_reader = ColumnarReader::open(buffer).unwrap();
assert_eq!(columnar_reader.num_columns(), 2);
let col_handles = columnar_reader.read_columns("my.column").unwrap();
@@ -232,6 +272,93 @@ fn test_dictionary_encoded_bytes() {
assert_eq!(term_buffer, b"b");
}
#[test]
fn test_sort_order_str_asc_desc() {
let mut dataframe_writer = ColumnarWriter::default();
dataframe_writer.record_str(0, "s", "z");
dataframe_writer.record_str(2, "s", "a");
dataframe_writer.record_str(3, "s", "m");
let asc = dataframe_writer.sort_order("s", 4, false);
assert_eq!(asc, vec![1, 2, 3, 0]); // None, a, m, z
let desc = dataframe_writer.sort_order("s", 4, true);
assert_eq!(desc, vec![0, 3, 2, 1]); // z, m, a, None
}
#[test]
fn test_sort_order_str_empty_vs_missing() {
let mut dataframe_writer = ColumnarWriter::default();
dataframe_writer.record_str(0, "s", "");
let asc = dataframe_writer.sort_order("s", 2, false);
assert_eq!(asc, vec![1, 0]); // None first, then empty string
}
#[test]
fn test_sort_order_str_multivalued_stable() {
let mut dataframe_writer = ColumnarWriter::default();
dataframe_writer.record_str(0, "s", "z");
dataframe_writer.record_str(0, "s", "a"); // multivalued; first value wins
dataframe_writer.record_str(1, "s", "b");
dataframe_writer.record_str(2, "s", "b");
let asc = dataframe_writer.sort_order("s", 3, false);
assert_eq!(asc, vec![1, 2, 0]); // b, b (stable), z
}
#[test]
fn test_sort_order_bytes_asc() {
let mut dataframe_writer = ColumnarWriter::default();
dataframe_writer.record_bytes(1, "b", &[0x01]);
dataframe_writer.record_bytes(3, "b", &[0x00]);
let asc = dataframe_writer.sort_order("b", 4, false);
assert_eq!(asc, vec![0, 2, 3, 1]); // None, None, 0x00, 0x01
}
#[test]
fn test_sort_order_numeric_u64_above_2_24() {
let mut dataframe_writer = ColumnarWriter::default();
dataframe_writer.record_numerical(0, "n", 16_777_217u64);
dataframe_writer.record_numerical(1, "n", 16_777_216u64);
let asc = dataframe_writer.sort_order("n", 2, false);
assert_eq!(asc, vec![1, 0]); // 16,777,216 then 16,777,217
}
#[test]
fn test_sort_order_numeric_u64_above_2_53() {
let mut dataframe_writer = ColumnarWriter::default();
dataframe_writer.record_numerical(0, "n", 9_007_199_254_740_993u64);
dataframe_writer.record_numerical(1, "n", 9_007_199_254_740_992u64);
let asc = dataframe_writer.sort_order("n", 2, false);
assert_eq!(asc, vec![1, 0]); // smaller value first
}
#[test]
fn test_sort_order_numeric_null_vs_zero() {
let mut dataframe_writer = ColumnarWriter::default();
dataframe_writer.record_numerical(0, "n", 0u64);
let asc = dataframe_writer.sort_order("n", 2, false);
assert_eq!(asc, vec![1, 0]); // None first, then 0
}
#[test]
fn test_sort_order_datetime_close_timestamps() {
let mut dataframe_writer = ColumnarWriter::default();
// Two timestamps 1 nanosecond apart. As f32, both round to the same value.
let dt1 = DateTime::from_timestamp_nanos(1_700_000_000_000_000_001);
let dt2 = DateTime::from_timestamp_nanos(1_700_000_000_000_000_000);
dataframe_writer.record_datetime(0, "ts", dt1);
dataframe_writer.record_datetime(1, "ts", dt2);
let asc = dataframe_writer.sort_order("ts", 2, false);
assert_eq!(asc, vec![1, 0]); // smaller timestamp first
}
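The precision tests above rely on the fact that casting large integers to floats is lossy. The sketch below reproduces the two collisions with plain casts; `lossy_f32_key` and `lossy_f64_key` are hypothetical helpers written for this illustration, not functions from the crate.

```rust
/// Sort key obtained by going through `f32`, which has a 24-bit significand.
fn lossy_f32_key(value: u64) -> f32 {
    value as f64 as f32
}

/// Sort key obtained by going through `f64`, which has a 53-bit significand.
fn lossy_f64_key(value: u64) -> f64 {
    value as f64
}

/// Returns true when two distinct integers collapse to the same `f32` key.
fn collide_as_f32(left: u64, right: u64) -> bool {
    left != right && lossy_f32_key(left) == lossy_f32_key(right)
}

/// Returns true when two distinct integers collapse to the same `f64` key.
fn collide_as_f64(left: u64, right: u64) -> bool {
    left != right && lossy_f64_key(left) == lossy_f64_key(right)
}
```

`collide_as_f32(16_777_216, 16_777_217)` and `collide_as_f64(9_007_199_254_740_992, 9_007_199_254_740_993)` both hold, whereas comparing the `u64` values directly keeps them distinct, which is the behaviour the tests assert.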
fn num_strategy() -> impl Strategy<Value = NumericalValue> {
prop_oneof![
3 => Just(NumericalValue::U64(0u64)),
@@ -329,12 +456,26 @@ fn columnar_docs_strategy() -> impl Strategy<Value = Vec<Vec<(&'static str, Colu
.prop_flat_map(|num_docs| proptest::collection::vec(doc_strategy(), num_docs))
}
fn columnar_docs_and_mapping_strategy()
-> impl Strategy<Value = (Vec<Vec<(&'static str, ColumnValue)>>, Vec<RowId>)> {
columnar_docs_strategy().prop_flat_map(|docs| {
permutation_strategy(docs.len()).prop_map(move |permutation| (docs.clone(), permutation))
})
}
fn permutation_strategy(n: usize) -> impl Strategy<Value = Vec<RowId>> {
Just((0u32..n as RowId).collect()).prop_shuffle()
}
fn permutation_and_subset_strategy(n: usize) -> impl Strategy<Value = Vec<usize>> {
let vals: Vec<usize> = (0..n).collect();
subsequence(vals, 0..=n).prop_shuffle()
}
fn build_columnar_with_mapping(docs: &[Vec<(&'static str, ColumnValue)>]) -> ColumnarReader {
fn build_columnar_with_mapping(
docs: &[Vec<(&'static str, ColumnValue)>],
old_to_new_row_ids_opt: Option<&[RowId]>,
) -> ColumnarReader {
let num_docs = docs.len() as u32;
let mut buffer = Vec::new();
let mut columnar_writer = ColumnarWriter::default();
@@ -362,13 +503,15 @@ fn build_columnar_with_mapping(docs: &[Vec<(&'static str, ColumnValue)>]) -> Col
}
}
}
columnar_writer.serialize(num_docs, &mut buffer).unwrap();
columnar_writer
.serialize(num_docs, old_to_new_row_ids_opt, &mut buffer)
.unwrap();
ColumnarReader::open(buffer).unwrap()
}
fn build_columnar(docs: &[Vec<(&'static str, ColumnValue)>]) -> ColumnarReader {
build_columnar_with_mapping(docs)
build_columnar_with_mapping(docs, None)
}
fn assert_columnar_eq_strict(left: &ColumnarReader, right: &ColumnarReader) {
@@ -628,6 +771,54 @@ proptest! {
}
}
// Same as `test_single_columnar_builder_proptest` but with a shuffling mapping.
proptest! {
#![proptest_config(ProptestConfig::with_cases(500))]
#[test]
fn test_single_columnar_builder_with_shuffle_proptest((docs, mapping) in columnar_docs_and_mapping_strategy()) {
let columnar = build_columnar_with_mapping(&docs[..], Some(&mapping));
assert_eq!(columnar.num_docs() as usize, docs.len());
let mut expected_columns: HashMap<(&str, ColumnTypeCategory), HashMap<u32, Vec<&ColumnValue>> > = Default::default();
for (doc_id, doc_vals) in docs.iter().enumerate() {
for (col_name, col_val) in doc_vals {
expected_columns
.entry((col_name, col_val.column_type_category()))
.or_default()
.entry(mapping[doc_id])
.or_default()
.push(col_val);
}
}
let column_list = columnar.list_columns().unwrap();
assert_eq!(expected_columns.len(), column_list.len());
for (column_name, column) in column_list {
let dynamic_column = column.open().unwrap();
let col_category: ColumnTypeCategory = dynamic_column.column_type().into();
let expected_col_values: &HashMap<u32, Vec<&ColumnValue>> = expected_columns.get(&(column_name.as_str(), col_category)).unwrap();
for _doc_id in 0..columnar.num_docs() {
match &dynamic_column {
DynamicColumn::Bool(col) =>
assert_column_values(col, expected_col_values),
DynamicColumn::I64(col) =>
assert_column_values(col, expected_col_values),
DynamicColumn::U64(col) =>
assert_column_values(col, expected_col_values),
DynamicColumn::F64(col) =>
assert_column_values(col, expected_col_values),
DynamicColumn::IpAddr(col) =>
assert_column_values(col, expected_col_values),
DynamicColumn::DateTime(col) =>
assert_column_values(col, expected_col_values),
DynamicColumn::Bytes(col) =>
assert_bytes_column_values(col, expected_col_values, false),
DynamicColumn::Str(col) =>
assert_bytes_column_values(col, expected_col_values, true),
}
}
}
}
}
// This test creates 2 or 3 small random columnars and attempts to merge them.
// It compares the resulting merged dataframe with what would have been obtained by building the
// dataframe from the concatenated rows to begin with.
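
The shuffle proptest above looks up `mapping[doc_id]` to decide where a document's values are expected, which suggests the mapping is read as "old row id to new row id". A minimal sketch of that reading follows; `apply_old_to_new` is a hypothetical helper and not part of the columnar crate.

```rust
/// Hypothetical helper: reorders `rows` so that the row at old position `i`
/// ends up at position `old_to_new[i]`. Panics if `old_to_new` is not a
/// permutation of `0..rows.len()`.
fn apply_old_to_new<T: Clone>(rows: &[T], old_to_new: &[u32]) -> Vec<T> {
    assert_eq!(rows.len(), old_to_new.len());
    let mut reordered: Vec<Option<T>> = vec![None; rows.len()];
    for (old_row_id, &new_row_id) in old_to_new.iter().enumerate() {
        let slot = &mut reordered[new_row_id as usize];
        assert!(slot.is_none(), "old_to_new must be a permutation");
        *slot = Some(rows[old_row_id].clone());
    }
    reordered
        .into_iter()
        .map(|row| row.expect("old_to_new must be a permutation"))
        .collect()
}
```

For example, with rows `["a", "b", "c"]` and mapping `[2, 0, 1]`, row `"a"` moves to position 2, `"b"` to position 0 and `"c"` to position 1.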


@@ -1,6 +1,6 @@
[package]
name = "tantivy-common"
version = "0.10.0"
version = "0.11.0"
authors = ["Paul Masurel <paul@quickwit.io>", "Pascal Seitz <pascal@quickwit.io>"]
license = "MIT"
edition = "2024"
@@ -19,6 +19,6 @@ time = { version = "0.3.47", features = ["serde-well-known"] }
serde = { version = "1.0.136", features = ["derive"] }
[dev-dependencies]
binggan = "0.14.0"
binggan = "0.17.0"
proptest = "1.0.0"
rand = "0.9"


@@ -47,6 +47,9 @@ impl TinySet {
TinySet(val)
}
/// An empty `TinySet` constant.
pub const EMPTY: TinySet = TinySet(0u64);
/// Returns an empty `TinySet`.
#[inline]
pub fn empty() -> TinySet {


@@ -121,7 +121,7 @@ pub struct FileSlice {
impl fmt::Debug for FileSlice {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "FileSlice({:?}, {:?})", &self.data, self.range)
write!(f, "FileSlice({:?}, {:?})", self.data, self.range)
}
}


@@ -7,11 +7,6 @@
- [Other](#other)
- [Usage](#usage)
# Index Sorting has been removed!
More infos here:
https://github.com/quickwit-oss/tantivy/issues/2352
# Index Sorting
Tantivy allows you to sort the index according to a property.


@@ -1,6 +1,6 @@
[package]
name = "tantivy-query-grammar"
version = "0.25.0"
version = "0.26.0"
authors = ["Paul Masurel <paul.masurel@gmail.com>"]
license = "MIT"
categories = ["database-implementations", "data-structures"]


@@ -327,7 +327,9 @@ fn exists(inp: &str) -> IResult<&str, UserInputLeaf> {
peek(alt((
value(
"",
satisfy(|c: char| c.is_whitespace() || ESCAPE_IN_WORD.contains(&c)),
satisfy(|c: char| {
c.is_whitespace() || (ESCAPE_IN_WORD.contains(&c) && c != '\\')
}),
),
eof,
))),
@@ -345,7 +347,9 @@ fn exists_precond(inp: &str) -> IResult<&str, (), ()> {
peek(alt((
value(
"",
satisfy(|c: char| c.is_whitespace() || ESCAPE_IN_WORD.contains(&c)),
satisfy(|c: char| {
c.is_whitespace() || (ESCAPE_IN_WORD.contains(&c) && c != '\\')
}),
),
eof,
))), // we need to check this isn't a wildcard query
@@ -707,6 +711,7 @@ fn regex(inp: &str) -> IResult<&str, UserInputLeaf> {
peek(alt((
value((), multispace1),
value((), char(')')),
value((), char('^')),
value((), eof),
))),
),
@@ -728,9 +733,10 @@ fn regex_infallible(inp: &str) -> JResult<&str, UserInputLeaf> {
peek(alt((
value((), multispace1),
value((), char(')')),
value((), char('^')),
value((), eof),
))),
"expected whitespace, closing parenthesis, or end of input",
"expected whitespace, closing parenthesis, boost, or end of input",
),
)(inp)
{
@@ -773,6 +779,10 @@ fn leaf(inp: &str) -> IResult<&str, UserInputAst> {
value((), multispace1),
value((), char(')')),
value((), eof),
value(
(),
satisfy(|c: char| ESCAPE_IN_WORD.contains(&c) && c != '\\'),
),
))),
),
|_| UserInputAst::from(UserInputLeaf::All),
@@ -805,6 +815,10 @@ fn leaf_infallible(inp: &str) -> JResult<&str, Option<UserInputAst>> {
value((), multispace1),
value((), char(')')),
value((), eof),
value(
(),
satisfy(|c: char| ESCAPE_IN_WORD.contains(&c) && c != '\\'),
),
))),
),
),
@@ -1045,18 +1059,43 @@ fn operand_leaf(inp: &str) -> IResult<&str, (Option<BinaryOperand>, Option<Occur
}
fn ast(inp: &str) -> IResult<&str, UserInputAst> {
let boolean_expr = map_res(
separated_pair(occur_leaf, multispace1, many1(operand_leaf)),
|(left, right)| aggregate_binary_expressions(left, right),
);
let single_leaf = map(occur_leaf, |(occur, ast)| {
if occur == Some(Occur::MustNot) {
ast.unary(Occur::MustNot)
} else {
ast
}
});
delimited(multispace0, alt((boolean_expr, single_leaf)), multispace0)(inp)
// Parse `occur_leaf` once, then conditionally extend into a boolean
// expression. The previous implementation used `alt((boolean_expr,
// single_leaf))` which, when the input was a single leaf with no
// following operand, would parse `occur_leaf` once for `boolean_expr`,
// fail at `multispace1`, backtrack, then re-parse `occur_leaf` for
// `single_leaf`. With recursively-nested groups like `(+(+(+a)))`, that
// doubling at every level produced O(2^n) parse time. Parsing once and
// peeking ahead for the operand keeps it O(n).
delimited(
multispace0,
|inp| {
let (rest, first) = occur_leaf(inp)?;
// Only fall back on `Err::Error` (recoverable), mirroring
// `alt`'s behaviour. `Err::Failure` and `Err::Incomplete`
// must propagate so cut points and streaming needs are not
// accidentally swallowed if they are ever introduced in the
// operand parsers.
match preceded(multispace1, many1(operand_leaf))(rest) {
Ok((rest, more)) => {
let combined = aggregate_binary_expressions(first, more)
.map_err(|_| nom::Err::Error(Error::new(inp, ErrorKind::MapRes)))?;
Ok((rest, combined))
}
Err(nom::Err::Error(_)) => {
let (occur, ast) = first;
let single = if occur == Some(Occur::MustNot) {
ast.unary(Occur::MustNot)
} else {
ast
};
Ok((rest, single))
}
Err(e) => Err(e),
}
},
multispace0,
)(inp)
}
fn ast_infallible(inp: &str) -> JResult<&str, UserInputAst> {
@@ -1726,6 +1765,8 @@ mod test {
test_parse_query_to_ast_helper("*", "*");
test_parse_query_to_ast_helper("(*)", "*");
test_parse_query_to_ast_helper("(* )", "*");
// All query with boost
test_parse_query_to_ast_helper("*^2", "(*)^2");
}
#[test]
@@ -1788,6 +1829,7 @@ mod test {
test_parse_query_to_ast_helper("a:b*", "\"a\":b*");
test_parse_query_to_ast_helper("a:*b", "\"a\":*b");
test_parse_query_to_ast_helper(r#"a:*def*"#, "\"a\":*def*");
test_parse_query_to_ast_helper("a:*\\:foo", "\"a\":*:foo");
}
#[test]
@@ -1852,6 +1894,8 @@ mod test {
},
_ => panic!("Expected a leaf"),
}
// Regex followed by `^boost`
test_parse_query_to_ast_helper(r#"foo:/bar/^2"#, r#"("foo":/bar/)^2"#);
}
#[test]
@@ -1891,4 +1935,23 @@ mod test {
r#"(+"field":'happy tax payer' +"other_field":1)"#,
);
}
// Regression test for https://github.com/quickwit-oss/tantivy/issues/2498:
// deeply nested parenthesized queries used to take O(2^n) time because the
// top-level `ast()` parser tried `boolean_expr` first and re-parsed the
// inner `occur_leaf` when it backtracked to `single_leaf`. Depth 60 would
// take ~10^18 operations under the regression; with the fix it parses
// instantly. We use `test_parse_query_to_ast_helper` so this test would
// never finish if the regression returned.
#[test]
fn test_parse_deeply_nested_query() {
let depth = 60;
let leading: String = "(".repeat(depth);
let trailing: String = ")".repeat(depth);
let query = format!("{leading}title:test{trailing}");
test_parse_query_to_ast_helper(&query, r#""title":test"#);
let query_with_plus = format!("+{leading}title:test{trailing}");
test_parse_query_to_ast_helper(&query_with_plus, r#""title":test"#);
}
}
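
The comments in `ast()` and in the regression test describe how trying `boolean_expr` first and then re-parsing for `single_leaf` doubles the work at every nesting level. The sketch below reproduces that effect on a toy grammar with a call counter. It is an illustration under stated assumptions: it does not use nom, the grammar (`leaf := '(' expr ')' | [a-z]+`, operands separated by a single space) is far simpler than the real query grammar, and all function names are hypothetical.

```rust
/// Parses `leaf := '(' expr ')' | [a-z]+` and returns the remaining input.
/// `calls` counts how many times a leaf parse is attempted.
fn leaf<'a>(inp: &'a str, parse_once: bool, calls: &mut u64) -> Option<&'a str> {
    *calls += 1;
    if let Some(inner) = inp.strip_prefix('(') {
        let rest = expr(inner, parse_once, calls)?;
        return rest.strip_prefix(')');
    }
    let len = inp.bytes().take_while(|b| b.is_ascii_lowercase()).count();
    if len == 0 { None } else { Some(&inp[len..]) }
}

/// Parses one or more ` leaf` operands; fails if there is not at least one.
fn operands<'a>(mut inp: &'a str, parse_once: bool, calls: &mut u64) -> Option<&'a str> {
    let mut seen_operand = false;
    while let Some(after_space) = inp.strip_prefix(' ') {
        match leaf(after_space, parse_once, calls) {
            Some(rest) => {
                inp = rest;
                seen_operand = true;
            }
            None => break,
        }
    }
    if seen_operand { Some(inp) } else { None }
}

fn expr<'a>(inp: &'a str, parse_once: bool, calls: &mut u64) -> Option<&'a str> {
    if parse_once {
        // Parse the first leaf once, then optionally extend with operands.
        let rest = leaf(inp, parse_once, calls)?;
        Some(operands(rest, parse_once, calls).unwrap_or(rest))
    } else {
        // Mimics `alt((boolean_expr, single_leaf))`: when no operand follows,
        // the first leaf is discarded and parsed a second time.
        let boolean_expr = match leaf(inp, parse_once, calls) {
            Some(rest) => operands(rest, parse_once, calls),
            None => None,
        };
        match boolean_expr {
            Some(rest) => Some(rest),
            None => leaf(inp, parse_once, calls),
        }
    }
}

/// Counts leaf parse attempts for `((...(a)...))` nested `depth` times.
fn count_leaf_calls(depth: usize, parse_once: bool) -> u64 {
    let query = format!("{}a{}", "(".repeat(depth), ")".repeat(depth));
    let mut calls = 0;
    assert_eq!(expr(&query, parse_once, &mut calls), Some(""));
    calls
}
```

In this toy model the parse-once variant attempts `depth + 1` leaf parses, while the backtracking variant attempts `2^(depth + 2) - 2`, which is why a depth of 60 is out of reach for the old strategy.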


@@ -10,18 +10,18 @@ use crate::aggregation::accessor_helpers::{
};
use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations};
use crate::aggregation::bucket::{
build_segment_filter_collector, build_segment_range_collector, CompositeAggReqData,
CompositeAggregation, CompositeSourceAccessors, FilterAggReqData, HistogramAggReqData,
HistogramBounds, IncludeExcludeParam, MissingTermAggReqData, RangeAggReqData,
SegmentHistogramCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
TermsAggregationInternal,
build_segment_filter_collector, build_segment_histogram_collector,
build_segment_range_collector, CompositeAggReqData, CompositeAggregation,
CompositeSourceAccessors, FilterAggReqData, HistogramAggReqData, HistogramBounds,
IncludeExcludeParam, MissingTermAggReqData, RangeAggReqData, TermMissingAgg, TermsAggReqData,
TermsAggregation, TermsAggregationInternal,
};
use crate::aggregation::metric::{
build_segment_stats_collector, AverageAggregation, CardinalityAggReqData,
CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation, MaxAggregation,
MetricAggReqData, MinAggregation, SegmentCardinalityCollector, SegmentExtendedStatsCollector,
SegmentPercentilesCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData,
TopHitsSegmentCollector,
SegmentPercentilesCollector, StatsAggregation, StatsType, SumAggregation, TermOrdSet,
TopHitsAggReqData, TopHitsSegmentCollector, BITSET_MAX_TERM_ORD,
};
use crate::aggregation::segment_agg_result::{
GenericSegmentAggregationResultsCollector, SegmentAggregationCollector,
@@ -41,7 +41,7 @@ pub struct AggregationsSegmentCtx {
impl AggregationsSegmentCtx {
pub(crate) fn push_term_req_data(&mut self, data: TermsAggReqData) -> usize {
self.per_request.term_req_data.push(Some(Box::new(data)));
self.per_request.term_req_data.push(data);
self.per_request.term_req_data.len() - 1
}
pub(crate) fn push_cardinality_req_data(&mut self, data: CardinalityAggReqData) -> usize {
@@ -61,31 +61,25 @@ impl AggregationsSegmentCtx {
self.per_request.missing_term_req_data.len() - 1
}
pub(crate) fn push_histogram_req_data(&mut self, data: HistogramAggReqData) -> usize {
self.per_request
.histogram_req_data
.push(Some(Box::new(data)));
self.per_request.histogram_req_data.push(data);
self.per_request.histogram_req_data.len() - 1
}
pub(crate) fn push_range_req_data(&mut self, data: RangeAggReqData) -> usize {
self.per_request.range_req_data.push(Some(Box::new(data)));
self.per_request.range_req_data.push(data);
self.per_request.range_req_data.len() - 1
}
pub(crate) fn push_filter_req_data(&mut self, data: FilterAggReqData) -> usize {
self.per_request.filter_req_data.push(Some(Box::new(data)));
self.per_request.filter_req_data.push(data);
self.per_request.filter_req_data.len() - 1
}
pub(crate) fn push_composite_req_data(&mut self, data: CompositeAggReqData) -> usize {
self.per_request
.composite_req_data
.push(Some(Box::new(data)));
self.per_request.composite_req_data.push(data);
self.per_request.composite_req_data.len() - 1
}
#[inline]
pub(crate) fn get_term_req_data(&self, idx: usize) -> &TermsAggReqData {
self.per_request.term_req_data[idx]
.as_deref()
.expect("term_req_data slot is empty (taken)")
&self.per_request.term_req_data[idx]
}
#[inline]
pub(crate) fn get_cardinality_req_data(&self, idx: usize) -> &CardinalityAggReqData {
@@ -103,116 +97,6 @@ impl AggregationsSegmentCtx {
pub(crate) fn get_missing_term_req_data(&self, idx: usize) -> &MissingTermAggReqData {
&self.per_request.missing_term_req_data[idx]
}
#[inline]
pub(crate) fn get_histogram_req_data(&self, idx: usize) -> &HistogramAggReqData {
self.per_request.histogram_req_data[idx]
.as_deref()
.expect("histogram_req_data slot is empty (taken)")
}
#[inline]
pub(crate) fn get_range_req_data(&self, idx: usize) -> &RangeAggReqData {
self.per_request.range_req_data[idx]
.as_deref()
.expect("range_req_data slot is empty (taken)")
}
#[inline]
pub(crate) fn get_composite_req_data(&self, idx: usize) -> &CompositeAggReqData {
self.per_request.composite_req_data[idx]
.as_deref()
.expect("composite_req_data slot is empty (taken)")
}
// ---------- mutable getters ----------
#[inline]
pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData {
&mut self.per_request.stats_metric_req_data[idx]
}
#[inline]
pub(crate) fn get_cardinality_req_data_mut(
&mut self,
idx: usize,
) -> &mut CardinalityAggReqData {
&mut self.per_request.cardinality_req_data[idx]
}
#[inline]
pub(crate) fn get_histogram_req_data_mut(&mut self, idx: usize) -> &mut HistogramAggReqData {
self.per_request.histogram_req_data[idx]
.as_deref_mut()
.expect("histogram_req_data slot is empty (taken)")
}
// ---------- take / put (terms, histogram, range) ----------
/// Move out the boxed Histogram request at `idx`, leaving `None`.
#[inline]
pub(crate) fn take_histogram_req_data(&mut self, idx: usize) -> Box<HistogramAggReqData> {
self.per_request.histogram_req_data[idx]
.take()
.expect("histogram_req_data slot is empty (taken)")
}
/// Put back a Histogram request into an empty slot at `idx`.
#[inline]
pub(crate) fn put_back_histogram_req_data(
&mut self,
idx: usize,
value: Box<HistogramAggReqData>,
) {
debug_assert!(self.per_request.histogram_req_data[idx].is_none());
self.per_request.histogram_req_data[idx] = Some(value);
}
/// Move out the boxed Range request at `idx`, leaving `None`.
#[inline]
pub(crate) fn take_range_req_data(&mut self, idx: usize) -> Box<RangeAggReqData> {
self.per_request.range_req_data[idx]
.take()
.expect("range_req_data slot is empty (taken)")
}
/// Put back a Range request into an empty slot at `idx`.
#[inline]
pub(crate) fn put_back_range_req_data(&mut self, idx: usize, value: Box<RangeAggReqData>) {
debug_assert!(self.per_request.range_req_data[idx].is_none());
self.per_request.range_req_data[idx] = Some(value);
}
/// Move out the boxed Filter request at `idx`, leaving `None`.
#[inline]
pub(crate) fn take_filter_req_data(&mut self, idx: usize) -> Box<FilterAggReqData> {
self.per_request.filter_req_data[idx]
.take()
.expect("filter_req_data slot is empty (taken)")
}
/// Put back a Filter request into an empty slot at `idx`.
#[inline]
pub(crate) fn put_back_filter_req_data(&mut self, idx: usize, value: Box<FilterAggReqData>) {
debug_assert!(self.per_request.filter_req_data[idx].is_none());
self.per_request.filter_req_data[idx] = Some(value);
}
/// Move out the Composite request at `idx`.
#[inline]
pub(crate) fn take_composite_req_data(&mut self, idx: usize) -> Box<CompositeAggReqData> {
self.per_request.composite_req_data[idx]
.take()
.expect("composite_req_data slot is empty (taken)")
}
/// Put back a Composite request into an empty slot at `idx`.
#[inline]
pub(crate) fn put_back_composite_req_data(
&mut self,
idx: usize,
value: Box<CompositeAggReqData>,
) {
debug_assert!(self.per_request.composite_req_data[idx].is_none());
self.per_request.composite_req_data[idx] = Some(value);
}
}
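
The removed `take_*` / `put_back_*` helpers above existed so that a collector could read its request data while also mutating the shared context. The sketch below contrasts that pattern with a collector that owns its request data, as `SegmentCompositeCollector` appears to do after this change via its `req_data` field. All types here (`ReqData`, `Ctx`, `TakePutCollector`, `OwningCollector`) are simplified, hypothetical stand-ins, and how the real collector obtains its copy is not shown in this diff.

```rust
#[derive(Clone, Debug, PartialEq)]
struct ReqData {
    name: String,
}

/// Stand-in for the shared per-segment context.
struct Ctx {
    slots: Vec<Option<Box<ReqData>>>,
    memory_consumed: usize,
}

impl Ctx {
    /// Moves the boxed request out of its slot, leaving `None` behind.
    fn take(&mut self, idx: usize) -> Box<ReqData> {
        self.slots[idx].take().expect("slot is empty (taken)")
    }

    /// Puts a request back into an empty slot.
    fn put_back(&mut self, idx: usize, value: Box<ReqData>) {
        debug_assert!(self.slots[idx].is_none());
        self.slots[idx] = Some(value);
    }

    fn add_memory(&mut self, bytes: usize) {
        self.memory_consumed += bytes;
    }
}

/// Old pattern: borrow the request by temporarily moving it out of the context.
struct TakePutCollector {
    idx: usize,
}

impl TakePutCollector {
    fn collect(&mut self, ctx: &mut Ctx) {
        let req = ctx.take(self.idx);
        // `ctx` can be mutated here because `req` no longer borrows from it.
        ctx.add_memory(req.name.len());
        ctx.put_back(self.idx, req);
    }
}

/// New pattern: the collector owns its request data, so no slot juggling is needed.
struct OwningCollector {
    req_data: ReqData,
}

impl OwningCollector {
    fn collect(&mut self, ctx: &mut Ctx) {
        ctx.add_memory(self.req_data.name.len());
    }
}
```

The owning variant cannot hit the "slot is empty (taken)" panic. The trade-off is that the request data must be cloned or moved into the collector up front.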
/// Each type of aggregation has its own request data struct. This struct holds
@@ -223,15 +107,14 @@ impl AggregationsSegmentCtx {
/// for a node with [AggKind::Terms]).
#[derive(Default)]
pub struct PerRequestAggSegCtx {
// Box for cheap take/put - Only necessary for bucket aggs that have sub-aggregations
/// TermsAggReqData contains the request data for a terms aggregation.
pub term_req_data: Vec<Option<Box<TermsAggReqData>>>,
pub term_req_data: Vec<TermsAggReqData>,
/// HistogramAggReqData contains the request data for a histogram aggregation.
pub histogram_req_data: Vec<Option<Box<HistogramAggReqData>>>,
pub histogram_req_data: Vec<HistogramAggReqData>,
/// RangeAggReqData contains the request data for a range aggregation.
pub range_req_data: Vec<Option<Box<RangeAggReqData>>>,
pub range_req_data: Vec<RangeAggReqData>,
/// FilterAggReqData contains the request data for a filter aggregation.
pub filter_req_data: Vec<Option<Box<FilterAggReqData>>>,
pub filter_req_data: Vec<FilterAggReqData>,
/// Shared by avg, min, max, sum, stats, extended_stats, count
pub stats_metric_req_data: Vec<MetricAggReqData>,
/// CardinalityAggReqData contains the request data for a cardinality aggregation.
@@ -241,7 +124,7 @@ pub struct PerRequestAggSegCtx {
/// MissingTermAggReqData contains the request data for a missing term aggregation.
pub missing_term_req_data: Vec<MissingTermAggReqData>,
/// CompositeAggReqData contains the request data for a composite aggregation.
pub composite_req_data: Vec<Option<Box<CompositeAggReqData>>>,
pub composite_req_data: Vec<CompositeAggReqData>,
/// Request tree used to build collectors.
pub agg_tree: Vec<AggRefNode>,
@@ -252,22 +135,22 @@ impl PerRequestAggSegCtx {
fn get_memory_consumption(&self) -> usize {
self.term_req_data
.iter()
.map(|b| b.as_ref().unwrap().get_memory_consumption())
.map(|t| t.get_memory_consumption())
.sum::<usize>()
+ self
.histogram_req_data
.iter()
.map(|b| b.as_ref().unwrap().get_memory_consumption())
.map(|t| t.get_memory_consumption())
.sum::<usize>()
+ self
.range_req_data
.iter()
.map(|b| b.as_ref().unwrap().get_memory_consumption())
.map(|t| t.get_memory_consumption())
.sum::<usize>()
+ self
.filter_req_data
.iter()
.map(|b| b.as_ref().unwrap().get_memory_consumption())
.map(|t| t.get_memory_consumption())
.sum::<usize>()
+ self
.stats_metric_req_data
@@ -292,7 +175,7 @@ impl PerRequestAggSegCtx {
+ self
.composite_req_data
.iter()
.map(|b| b.as_ref().map(|d| d.get_memory_consumption()).unwrap_or(0))
.map(|t| t.get_memory_consumption())
.sum::<usize>()
+ self.agg_tree.len() * std::mem::size_of::<AggRefNode>()
}
@@ -301,40 +184,16 @@ impl PerRequestAggSegCtx {
let idx = node.idx_in_req_data;
let kind = node.kind;
match kind {
AggKind::Terms => self.term_req_data[idx]
.as_deref()
.expect("term_req_data slot is empty (taken)")
.name
.as_str(),
AggKind::Terms => self.term_req_data[idx].name.as_str(),
AggKind::Cardinality => &self.cardinality_req_data[idx].name,
AggKind::StatsKind(_) => &self.stats_metric_req_data[idx].name,
AggKind::TopHits => &self.top_hits_req_data[idx].name,
AggKind::MissingTerm => &self.missing_term_req_data[idx].name,
AggKind::Histogram => self.histogram_req_data[idx]
.as_deref()
.expect("histogram_req_data slot is empty (taken)")
.name
.as_str(),
AggKind::DateHistogram => self.histogram_req_data[idx]
.as_deref()
.expect("histogram_req_data slot is empty (taken)")
.name
.as_str(),
AggKind::Range => self.range_req_data[idx]
.as_deref()
.expect("range_req_data slot is empty (taken)")
.name
.as_str(),
AggKind::Filter => self.filter_req_data[idx]
.as_deref()
.expect("filter_req_data slot is empty (taken)")
.name
.as_str(),
AggKind::Composite => self.composite_req_data[idx]
.as_deref()
.expect("composite_req_data slot is empty (taken)")
.name
.as_str(),
AggKind::Histogram => self.histogram_req_data[idx].name.as_str(),
AggKind::DateHistogram => self.histogram_req_data[idx].name.as_str(),
AggKind::Range => self.range_req_data[idx].name.as_str(),
AggKind::Filter => self.filter_req_data[idx].name.as_str(),
AggKind::Composite => self.composite_req_data[idx].name.as_str(),
}
}
@@ -412,13 +271,39 @@ pub(crate) fn build_segment_agg_collector(
Ok(Box::new(TermMissingAgg::new(req, node)?))
}
AggKind::Cardinality => {
let req_data = &mut req.get_cardinality_req_data_mut(node.idx_in_req_data);
Ok(Box::new(SegmentCardinalityCollector::from_req(
req_data.column_type,
node.idx_in_req_data,
req_data.accessor.clone(),
req_data.missing_value_for_accessor,
)))
let req_data = req.get_cardinality_req_data(node.idx_in_req_data);
// For str columns, choose the per-bucket entries representation
// based on the segment's column.max_value():
// * small (< BITSET_MAX_TERM_ORD): `BitSet`, pre-allocated, no promotion machinery.
// * large: `TermOrdSet` (sparse FxHashSet that promotes to a paged bitset).
// For non-str columns the `entries` field is unused (values go
// straight into the HLL sketch); we still pick `TermOrdSet`
// because its empty Sparse(FxHashSet) costs nothing.
let is_str = req_data.column_type == ColumnType::Str;
let max_term_ord_inclusive = if is_str {
req_data.accessor.max_value()
} else {
0
};
let collector: Box<dyn SegmentAggregationCollector> =
if is_str && max_term_ord_inclusive < BITSET_MAX_TERM_ORD {
Box::new(SegmentCardinalityCollector::<BitSet>::from_req(
req_data.column_type,
node.idx_in_req_data,
req_data.accessor.clone(),
req_data.missing_value_for_accessor,
max_term_ord_inclusive,
))
} else {
Box::new(SegmentCardinalityCollector::<TermOrdSet>::from_req(
req_data.column_type,
node.idx_in_req_data,
req_data.accessor.clone(),
req_data.missing_value_for_accessor,
max_term_ord_inclusive,
))
};
Ok(collector)
}
AggKind::StatsKind(stats_type) => {
let req_data = &mut req.per_request.stats_metric_req_data[node.idx_in_req_data];
@@ -433,7 +318,7 @@ pub(crate) fn build_segment_agg_collector(
SegmentExtendedStatsCollector::from_req(req_data, sigma),
)),
StatsType::Percentiles => {
let req_data = req.get_metric_req_data_mut(node.idx_in_req_data);
let req_data = req.get_metric_req_data(node.idx_in_req_data);
Ok(Box::new(
SegmentPercentilesCollector::from_req_and_validate(
req_data.field_type,
@@ -453,12 +338,8 @@ pub(crate) fn build_segment_agg_collector(
req_data.segment_ordinal,
)))
}
AggKind::Histogram => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
req, node,
)?)),
AggKind::DateHistogram => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
req, node,
)?)),
AggKind::Histogram => build_segment_histogram_collector(req, node),
AggKind::DateHistogram => build_segment_histogram_collector(req, node),
AggKind::Range => Ok(build_segment_range_collector(req, node)?),
AggKind::Filter => build_segment_filter_collector(req, node),
AggKind::Composite => Ok(Box::new(
@@ -773,23 +654,18 @@ fn build_nodes(
let schema = reader.schema();
let tokenizers = &data.context.tokenizers;
let query = filter_req.parse_query(schema, tokenizers)?;
let evaluator = crate::aggregation::bucket::DocumentQueryEvaluator::new(
query,
schema.clone(),
reader,
)?;
// Pre-allocate buffer for batch filtering
let max_doc = reader.max_doc();
let buffer_capacity = crate::docset::COLLECT_BLOCK_BUFFER_LEN.min(max_doc as usize);
let matching_docs_buffer = Vec::with_capacity(buffer_capacity);
let evaluator =
std::rc::Rc::new(crate::aggregation::bucket::DocumentQueryEvaluator::new(
query,
schema.clone(),
reader,
)?);
let idx_in_req_data = data.push_filter_req_data(FilterAggReqData {
name: agg_name.to_string(),
req: filter_req.clone(),
segment_reader: reader.clone(),
evaluator,
matching_docs_buffer,
is_top_level,
});
let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?;
@@ -985,8 +861,12 @@ fn build_terms_or_cardinality_nodes(
let str_col = str_dict_column
.as_ref()
.expect("str_dict_column must exist for string column");
allowed_term_ids =
build_allowed_term_ids_for_str(str_col, &req.include, &req.exclude)?;
allowed_term_ids = build_allowed_term_ids_for_str(
str_col,
&req.include,
&req.exclude,
missing.is_some(),
)?;
};
let idx_in_req_data = data.push_term_req_data(TermsAggReqData {
accessor,
@@ -1002,10 +882,20 @@ fn build_terms_or_cardinality_nodes(
(idx_in_req_data, AggKind::Terms)
}
TermsOrCardinalityRequest::Cardinality(ref req) => {
// `str_dict_column` is computed once per field; for JSON paths
// with mixed types it's `Some` even on the numeric req_data.
// Cardinality only consults it for the str column path, so
// gate by column_type to avoid driving non-str collectors
// through the coupon-cache path.
let str_dict_column_for_req = if column_type == ColumnType::Str {
str_dict_column.clone()
} else {
None
};
let idx_in_req_data = data.push_cardinality_req_data(CardinalityAggReqData {
accessor,
column_type,
str_dict_column: str_dict_column.clone(),
str_dict_column: str_dict_column_for_req,
missing_value_for_accessor,
name: agg_name.to_string(),
req: req.clone(),
@@ -1025,16 +915,21 @@ fn build_terms_or_cardinality_nodes(
/// Builds a single BitSet of allowed term ordinals for a string dictionary column according to
/// include/exclude parameters.
///
/// When `reserve_missing_sentinel` is true, the bitset has one additional slot for the missing
/// term ordinal.
fn build_allowed_term_ids_for_str(
str_col: &StrColumn,
include: &Option<IncludeExcludeParam>,
exclude: &Option<IncludeExcludeParam>,
reserve_missing_sentinel: bool,
) -> crate::Result<Option<BitSet>> {
let mut allowed: Option<BitSet> = None;
let num_terms = str_col.dictionary().num_terms() as u32;
let missing_sentinel_adjustment = if reserve_missing_sentinel { 1 } else { 0 };
let allowed_capacity = str_col.dictionary().num_terms() as u32 + missing_sentinel_adjustment;
if let Some(include) = include {
// add matches
allowed = Some(BitSet::with_max_value(num_terms));
allowed = Some(BitSet::with_max_value(allowed_capacity));
let allowed = allowed.as_mut().unwrap();
for_each_matching_term_ord(str_col, include, |ord| allowed.insert(ord))?;
};
@@ -1042,7 +937,7 @@ fn build_allowed_term_ids_for_str(
if let Some(exclude) = exclude {
if allowed.is_none() {
// Start with all terms allowed
allowed = Some(BitSet::with_max_value_and_full(num_terms));
allowed = Some(BitSet::with_max_value_and_full(allowed_capacity));
}
let allowed = allowed.as_mut().unwrap();
for_each_matching_term_ord(str_col, exclude, |ord| allowed.remove(ord))?;
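
The capacity change in `build_allowed_term_ids_for_str` can be pictured with a toy bitset. This is a sketch under assumptions: `ToyBitSet` is a hypothetical stand-in for the real `BitSet`, and it assumes the missing sentinel is the ordinal equal to `num_terms` (one past the last dictionary term), which this diff does not show explicitly.

```rust
/// Toy fixed-capacity bitset, standing in for the real `BitSet` type.
struct ToyBitSet {
    words: Vec<u64>,
    max_value: u32,
}

impl ToyBitSet {
    /// Creates an empty set able to hold ordinals in `0..max_value`.
    fn with_max_value(max_value: u32) -> Self {
        let num_words = (max_value as usize + 63) / 64;
        ToyBitSet { words: vec![0u64; num_words], max_value }
    }

    fn insert(&mut self, ord: u32) {
        assert!(ord < self.max_value, "ordinal out of range");
        self.words[(ord / 64) as usize] |= 1u64 << (ord % 64);
    }

    fn contains(&self, ord: u32) -> bool {
        ord < self.max_value
            && (self.words[(ord / 64) as usize] & (1u64 << (ord % 64))) != 0
    }
}

/// Mirrors the capacity computation: one extra slot when a missing value is configured.
fn allowed_capacity(num_terms: u32, reserve_missing_sentinel: bool) -> u32 {
    num_terms + u32::from(reserve_missing_sentinel)
}
```

With three dictionary terms and a missing value configured, the capacity is 4, so inserting the assumed sentinel ordinal 3 stays in range. Without the extra slot, the same insert would be out of bounds.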


@@ -115,6 +115,71 @@ pub fn get_fast_field_names(aggs: &Aggregations) -> HashSet<String> {
fast_field_names
}
/// Validates that all fields referenced in the aggregation request exist in the schema
/// and are configured as fast fields.
///
/// This is a convenience function for upfront validation before executing aggregations.
/// Returns an error if any field doesn't exist or is not a fast field.
///
/// Validation is intentionally opt-in rather than baked into aggregation execution: the
/// default lenient behavior (returning empty results for missing fields) supports
/// schema evolution and federated queries where the same request runs against segments
/// or indices with different schemas.
///
/// # Example
/// ```
/// use tantivy::aggregation::agg_req::{Aggregations, validate_aggregation_fields_exist};
/// use tantivy::schema::{Schema, FAST};
/// use tantivy::Index;
///
/// # fn main() -> tantivy::Result<()> {
/// // Create a simple index
/// let mut schema_builder = Schema::builder();
/// schema_builder.add_f64_field("price", FAST);
/// let schema = schema_builder.build();
/// let index = Index::create_in_ram(schema);
///
/// // Parse aggregation request
/// let agg_req: Aggregations = serde_json::from_str(r#"{
/// "avg_price": { "avg": { "field": "price" } }
/// }"#)?;
///
/// let reader = index.reader()?;
/// let searcher = reader.searcher();
///
/// // Validate fields before executing
/// for segment_reader in searcher.segment_readers() {
/// validate_aggregation_fields_exist(&agg_req, segment_reader)?;
/// }
/// # Ok(())
/// # }
/// ```
pub fn validate_aggregation_fields_exist(
aggs: &Aggregations,
reader: &crate::SegmentReader,
) -> crate::Result<()> {
let field_names = get_fast_field_names(aggs);
let schema = reader.schema();
for field_name in field_names {
// Check if the field is either directly in the schema or could be part of a json field
// present in the schema, and verify it's a fast field.
if let Some((field, _path)) = schema.find_field(&field_name) {
let field_type = schema.get_field_entry(field).field_type();
if !field_type.is_fast() {
return Err(crate::TantivyError::SchemaError(format!(
"Field '{}' is not a fast field. Aggregations require fast fields.",
field_name
)));
}
} else {
return Err(crate::TantivyError::FieldNotFound(field_name));
}
}
Ok(())
}
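
The control flow of `validate_aggregation_fields_exist` can be sketched without any tantivy types. In the sketch below the schema is modelled as a map from field name to an "is fast field" flag, and `ValidationError` and `validate_fields` are hypothetical names. The real function resolves fields through `Schema::find_field` and returns `TantivyError` variants instead.

```rust
use std::collections::HashMap;

#[derive(Debug, PartialEq)]
enum ValidationError {
    FieldNotFound(String),
    NotFastField(String),
}

/// Returns the first problem found among `requested` field names, if any.
/// `schema` maps a field name to whether it is configured as a fast field.
fn validate_fields(
    requested: &[&str],
    schema: &HashMap<&str, bool>,
) -> Result<(), ValidationError> {
    for &name in requested {
        match schema.get(name).copied() {
            Some(true) => {}
            Some(false) => return Err(ValidationError::NotFastField(name.to_string())),
            None => return Err(ValidationError::FieldNotFound(name.to_string())),
        }
    }
    Ok(())
}
```

Because the check stops at the first failing field, a caller that wants a complete report would need to collect the errors instead of returning early.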
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
/// All aggregation types.
pub enum AggregationVariants {
@@ -234,6 +299,12 @@ impl AggregationVariants {
_ => None,
}
}
pub(crate) fn as_sum(&self) -> Option<&SumAggregation> {
match &self {
AggregationVariants::Sum(sum) => Some(sum),
_ => None,
}
}
}
#[cfg(test)]


@@ -208,7 +208,8 @@ pub enum BucketEntries<T> {
}
impl<T> BucketEntries<T> {
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = &'a T> + 'a> {
/// Iterate over all bucket entries.
pub fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = &'a T> + 'a> {
match self {
BucketEntries::Vec(vec) => Box::new(vec.iter()),
BucketEntries::HashMap(map) => Box::new(map.values()),
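
Now that `iter` is public, callers can walk the buckets without caring which variant holds them. The sketch below shows the idea on a toy enum. `ToyBucketEntries` and `total_doc_count` are hypothetical, and the use of `std::collections::HashMap` with `String` keys is an assumption about the map variant.

```rust
use std::collections::HashMap;

/// Toy stand-in for `BucketEntries<T>` with the same two variants.
enum ToyBucketEntries<T> {
    Vec(Vec<T>),
    HashMap(HashMap<String, T>),
}

impl<T> ToyBucketEntries<T> {
    /// Iterates over all entries, boxing the iterator so that both variants
    /// share a single return type.
    fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = &'a T> + 'a> {
        match self {
            ToyBucketEntries::Vec(vec) => Box::new(vec.iter()),
            ToyBucketEntries::HashMap(map) => Box::new(map.values()),
        }
    }
}

/// Example consumer: sums per-bucket document counts regardless of the variant.
fn total_doc_count(entries: &ToyBucketEntries<u64>) -> u64 {
    entries.iter().sum()
}
```

Note that iteration order over the map variant is unspecified, so order-sensitive consumers should not rely on it.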


@@ -1436,3 +1436,46 @@ fn test_aggregation_on_json_object_mixed_numerical_segments() {
)
);
}
#[test]
fn test_aggregation_field_validation_helper() {
// Test the standalone validation helper function for field validation
let index = get_test_index_2_segments(false).unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let segment_reader = searcher.segment_reader(0);
// Test with invalid field
let agg_req: Aggregations = serde_json::from_str(
r#"{
"avg_test": {
"avg": { "field": "nonexistent_field" }
}
}"#,
)
.unwrap();
let result =
crate::aggregation::agg_req::validate_aggregation_fields_exist(&agg_req, segment_reader);
assert!(result.is_err());
match result {
Err(crate::TantivyError::FieldNotFound(field_name)) => {
assert_eq!(field_name, "nonexistent_field");
}
_ => panic!("Expected FieldNotFound error, got: {:?}", result),
}
// Test with valid field
let agg_req: Aggregations = serde_json::from_str(
r#"{
"avg_test": {
"avg": { "field": "score" }
}
}"#,
)
.unwrap();
let result =
crate::aggregation::agg_req::validate_aggregation_fields_exist(&agg_req, segment_reader);
assert!(result.is_ok());
}


@@ -16,6 +16,7 @@ use crate::{SegmentReader, TantivyError};
/// Contains all information required by the SegmentCompositeCollector to perform the
/// composite aggregation on a segment.
#[derive(Debug, Clone)]
pub struct CompositeAggReqData {
/// The name of the aggregation.
pub name: String,
@@ -34,6 +35,7 @@ impl CompositeAggReqData {
}
/// Accessors for a single column in a composite source.
#[derive(Debug, Clone)]
pub struct CompositeAccessor {
/// The fast field column
pub column: Column<u64>,
@@ -48,6 +50,7 @@ pub struct CompositeAccessor {
}
/// Accessors to all the columns that belong to the field of a composite source.
#[derive(Debug, Clone)]
pub struct CompositeSourceAccessors {
/// The accessors for this source
pub accessors: Vec<CompositeAccessor>,
@@ -358,7 +361,7 @@ impl PrecomputedDateInterval {
///
/// Some column types (term, IP) might not have an exact representation of the
/// specified after key
#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum PrecomputedAfterKey {
/// The after key could be exactly represented in the column space.
Exact(u64),


@@ -21,7 +21,7 @@ use crate::aggregation::bucket::composite::map::{DynArrayHeapMap, MAX_DYN_ARRAY_
use crate::aggregation::bucket::{
CalendarInterval, CompositeAggregationSource, MissingOrder, Order,
};
use crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardSubAggCache};
use crate::aggregation::buffered_sub_aggs::{BufferedSubAggs, HighCardSubAggBuffer};
use crate::aggregation::intermediate_agg_result::{
CompositeIntermediateKey, IntermediateAggregationResult, IntermediateAggregationResults,
IntermediateBucketResult, IntermediateCompositeBucketEntry, IntermediateCompositeBucketResult,
@@ -118,8 +118,8 @@ impl InternalValueRepr {
pub struct SegmentCompositeCollector {
/// One DynArrayHeapMap per parent bucket.
parent_buckets: Vec<DynArrayHeapMap<InternalValueRepr, CompositeBucketCollector>>,
accessor_idx: usize,
sub_agg: Option<CachedSubAggs<HighCardSubAggCache>>,
req_data: CompositeAggReqData,
sub_agg: Option<BufferedSubAggs<HighCardSubAggBuffer>>,
bucket_id_provider: BucketIdProvider,
/// Number of sources, needed when creating new DynArrayHeapMaps.
num_sources: usize,
@@ -132,10 +132,7 @@ impl SegmentAggregationCollector for SegmentCompositeCollector {
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_data
.get_composite_req_data(self.accessor_idx)
.name
.clone();
let name = self.req_data.name.clone();
let buckets = self.add_intermediate_bucket_result(agg_data, parent_bucket_id)?;
results.push(
@@ -152,13 +149,12 @@ impl SegmentAggregationCollector for SegmentCompositeCollector {
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let mem_pre = self.get_memory_consumption();
let composite_agg_data = agg_data.take_composite_req_data(self.accessor_idx);
let mem_pre = self.get_memory_consumption(parent_bucket_id);
for doc in docs {
let mut visitor = CompositeKeyVisitor {
doc_id: *doc,
composite_agg_data: &composite_agg_data,
composite_agg_data: &self.req_data,
buckets: &mut self.parent_buckets[parent_bucket_id as usize],
sub_agg: &mut self.sub_agg,
bucket_id_provider: &mut self.bucket_id_provider,
@@ -166,13 +162,12 @@ impl SegmentAggregationCollector for SegmentCompositeCollector {
};
visitor.visit(0, true)?;
}
agg_data.put_back_composite_req_data(self.accessor_idx, composite_agg_data);
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.check_flush_local(agg_data)?;
}
let mem_delta = self.get_memory_consumption() - mem_pre;
let mem_delta = self.get_memory_consumption(parent_bucket_id) - mem_pre;
if mem_delta > 0 {
agg_data.context.limits.add_memory_consumed(mem_delta)?;
}
@@ -199,36 +194,49 @@ impl SegmentAggregationCollector for SegmentCompositeCollector {
}
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// Composite is a multi-bucket agg with no single value to extract.
None
}
}
impl SegmentCompositeCollector {
fn get_memory_consumption(&self) -> u64 {
self.parent_buckets
.iter()
.map(|m| m.memory_consumption())
.sum()
fn get_memory_consumption(&self, parent_bucket_id: BucketId) -> u64 {
self.parent_buckets[parent_bucket_id as usize].memory_consumption()
}
pub(crate) fn from_req_and_validate(
req_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
validate_req(req_data, node.idx_in_req_data)?;
let composite_req_data =
req_data.per_request.composite_req_data[node.idx_in_req_data].clone();
validate_req(&composite_req_data)?;
req_data
.context
.limits
.add_memory_consumed(composite_req_data.get_memory_consumption() as u64)?;
let has_sub_aggregations = !node.children.is_empty();
let sub_agg = if has_sub_aggregations {
let sub_agg_collector = build_segment_agg_collectors(req_data, &node.children)?;
Some(CachedSubAggs::new(sub_agg_collector))
Some(BufferedSubAggs::new(sub_agg_collector))
} else {
None
};
let composite_req_data = req_data.get_composite_req_data(node.idx_in_req_data);
let num_sources = composite_req_data.req.sources.len();
Ok(SegmentCompositeCollector {
parent_buckets: vec![DynArrayHeapMap::try_new(num_sources)?],
accessor_idx: node.idx_in_req_data,
req_data: composite_req_data,
sub_agg,
bucket_id_provider: BucketIdProvider::default(),
num_sources,
@@ -250,7 +258,7 @@ impl SegmentCompositeCollector {
let mut dict: FxHashMap<Vec<CompositeIntermediateKey>, IntermediateCompositeBucketEntry> =
Default::default();
dict.reserve(heap_map.size());
let composite_data = agg_data.get_composite_req_data(self.accessor_idx);
let composite_data = &self.req_data;
for (key_internal_repr, agg) in heap_map.into_iter() {
let key = resolve_key(&key_internal_repr, composite_data)?;
let mut sub_aggregation_res = IntermediateAggregationResults::default();
@@ -290,8 +298,7 @@ impl SegmentCompositeCollector {
}
}
fn validate_req(req_data: &mut AggregationsSegmentCtx, accessor_idx: usize) -> crate::Result<()> {
let composite_data = req_data.get_composite_req_data(accessor_idx);
fn validate_req(composite_data: &CompositeAggReqData) -> crate::Result<()> {
let req = &composite_data.req;
if req.sources.is_empty() {
return Err(TantivyError::InvalidArgument(
@@ -332,7 +339,7 @@ fn collect_bucket_with_limit(
limit_num_buckets: usize,
buckets: &mut DynArrayHeapMap<InternalValueRepr, CompositeBucketCollector>,
key: &[InternalValueRepr],
sub_agg: &mut Option<CachedSubAggs<HighCardSubAggCache>>,
sub_agg: &mut Option<BufferedSubAggs<HighCardSubAggBuffer>>,
bucket_id_provider: &mut BucketIdProvider,
) {
let mut record_in_bucket = |bucket: &mut CompositeBucketCollector| {
@@ -488,7 +495,7 @@ struct CompositeKeyVisitor<'a> {
doc_id: crate::DocId,
composite_agg_data: &'a CompositeAggReqData,
buckets: &'a mut DynArrayHeapMap<InternalValueRepr, CompositeBucketCollector>,
sub_agg: &'a mut Option<CachedSubAggs<HighCardSubAggCache>>,
sub_agg: &'a mut Option<BufferedSubAggs<HighCardSubAggBuffer>>,
bucket_id_provider: &'a mut BucketIdProvider,
sub_level_values: SmallVec<[InternalValueRepr; MAX_DYN_ARRAY_SIZE]>,
}
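
Review note: the hunks above replace the `take_composite_req_data` / `put_back_composite_req_data` round trip with a `req_data` field owned by the collector. A minimal sketch of that ownership pattern follows. `ReqData`, `SharedCtx`, and `Collector` are hypothetical stand-ins, not tantivy's types.

```rust
// Hypothetical stand-ins for the per-request data and the shared segment context.
#[derive(Debug, Clone)]
struct ReqData {
    name: String,
    num_sources: usize,
}

struct SharedCtx {
    per_request: Vec<ReqData>,
}

// The collector owns a clone of its request data, so `collect` can read it
// while the shared context stays free to be borrowed mutably elsewhere.
struct Collector {
    req_data: ReqData,
    collected: u64,
}

impl Collector {
    fn from_ctx(ctx: &SharedCtx, idx: usize) -> Self {
        Collector {
            req_data: ctx.per_request[idx].clone(),
            collected: 0,
        }
    }

    fn collect(&mut self, docs: &[u32], _ctx: &mut SharedCtx) {
        // No take/put-back needed: `req_data` and `collected` are disjoint fields.
        let req = &self.req_data;
        self.collected += docs.len() as u64 * req.num_sources as u64;
    }

    fn name(&self) -> String {
        self.req_data.name.clone()
    }
}
```

The trade-off in this sketch is one clone per collector at construction time in exchange for not threading an index through every call.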


@@ -511,14 +511,14 @@ mod tests {
fn datetime_from_iso_str(date_str: &str) -> common::DateTime {
let dt = OffsetDateTime::parse(date_str, &Rfc3339)
.expect(&format!("Failed to parse date: {}", date_str));
.unwrap_or_else(|_| panic!("Failed to parse date: {}", date_str));
let timestamp_secs = dt.unix_timestamp_nanos();
common::DateTime::from_timestamp_nanos(timestamp_secs as i64)
}
fn ms_timestamp_from_iso_str(date_str: &str) -> i64 {
let dt = OffsetDateTime::parse(date_str, &Rfc3339)
.expect(&format!("Failed to parse date: {}", date_str));
.unwrap_or_else(|_| panic!("Failed to parse date: {}", date_str));
(dt.unix_timestamp_nanos() / 1_000_000) as i64
}
@@ -548,7 +548,7 @@ mod tests {
agg_req_json["my_composite"]["composite"]["after"] = after_key.take().unwrap();
}
let agg_req: Aggregations = serde_json::from_value(agg_req_json).unwrap();
let res = exec_request(agg_req.clone(), &index).unwrap();
let res = exec_request(agg_req.clone(), index).unwrap();
let expected_page_buckets = &expected_buckets_vec[page_idx * page_size
..std::cmp::min((page_idx + 1) * page_size, expected_buckets_vec.len())];
assert_eq!(
@@ -559,34 +559,30 @@ mod tests {
page_size,
agg_req,
);
if page_idx + 1 < page_count {
assert!(
res["my_composite"].get("after_key").is_some(),
"expected after_key on all but last page"
);
after_key = Some(res["my_composite"]["after_key"].clone());
} else if res["my_composite"].get("after_key").is_some() {
// currently we sometimes have an after_key on the last page,
// check that the next "page" is empty
let agg_req_json = json!({
"my_composite": {
"composite": {
"sources": composite_agg_sources,
"size": page_size,
"after": res["my_composite"]["after_key"].clone(),
}
}
});
let agg_req: Aggregations = serde_json::from_value(agg_req_json).unwrap();
let res = exec_request(agg_req.clone(), &index).unwrap();
assert_eq!(
res["my_composite"]["buckets"],
json!([]),
"expected no buckets when using after_key from last page, query: {:?}",
agg_req
);
}
assert!(
res["my_composite"].get("after_key").is_some(),
"expected after_key on every non-empty page"
);
after_key = Some(res["my_composite"]["after_key"].clone());
}
// Using the after_key from the last page must yield an empty page.
let agg_req_json = json!({
"my_composite": {
"composite": {
"sources": composite_agg_sources,
"size": page_size,
"after": after_key,
}
}
});
let agg_req: Aggregations = serde_json::from_value(agg_req_json).unwrap();
let res = exec_request(agg_req.clone(), index).unwrap();
assert_eq!(
res["my_composite"]["buckets"],
json!([]),
"expected no buckets when using after_key from last page, query: {:?}",
agg_req
);
}
}
@@ -711,8 +707,28 @@ mod tests {
{"key": {"myterm": "terme"}, "doc_count": 1}
])
);
assert!(res["my_composite"].get("after_key").is_none());
// paginating past last page should be empty
let agg_req_json = json!({
"my_composite": {
"composite": {
"sources": [
{"myterm": {"terms": {"field": "string_id"}}}
],
"size": 3,
"after": &res["my_composite"]["after_key"]
}
}
});
let agg_req: Aggregations = serde_json::from_value(agg_req_json).unwrap();
let res = exec_request(agg_req.clone(), &index).unwrap();
assert!(res["my_composite"].get("after_key").is_none());
assert_eq!(
res["my_composite"]["buckets"],
json!([]),
"expected no buckets when using after_key from last page, query: {:?}",
agg_req
);
Ok(())
}
@@ -820,7 +836,10 @@ mod tests {
{"key": {"myterm": "apple"}, "doc_count": 1}
])
);
assert!(res["fruity_aggreg"].get("after_key").is_none());
assert_eq!(
res["fruity_aggreg"]["after_key"],
json!({"myterm": "str:apple"})
);
Ok(())
}
@@ -1792,7 +1811,14 @@ mod tests {
{"key": {"month": ms_timestamp_from_iso_str("2021-02-01T00:00:00Z"), "category": "books"}, "doc_count": 1},
]),
);
assert!(res["my_composite"].get("after_key").is_none());
let feb_2021_ns = ms_timestamp_from_iso_str("2021-02-01T00:00:00Z") * 1_000_000;
assert_eq!(
res["my_composite"]["after_key"],
json!({
"month": format!("dt:{}", feb_2021_ns),
"category": "str:books"
})
);
Ok(())
}
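
Review note: the updated tests expect an `after_key` on every non-empty page, and an empty page when the `after_key` of the last page is passed back. The toy model below illustrates that contract over a sorted slice of keys. The functions `page` and `collect_all_pages` are hypothetical and do not reflect how the composite collector is implemented.

```rust
/// Returns up to `size` keys strictly greater than `after`, plus the key of the
/// last returned bucket. The second element is `None` only for an empty page.
fn page<'a>(
    sorted_keys: &[&'a str],
    after: Option<&str>,
    size: usize,
) -> (Vec<&'a str>, Option<&'a str>) {
    let buckets: Vec<&'a str> = sorted_keys
        .iter()
        .copied()
        .filter(|k| after.map_or(true, |a| *k > a))
        .take(size)
        .collect();
    let after_key = buckets.last().copied();
    (buckets, after_key)
}

/// Pages through all keys, stopping at the first empty page.
fn collect_all_pages<'a>(sorted_keys: &[&'a str], size: usize) -> Vec<Vec<&'a str>> {
    let mut pages = Vec::new();
    let mut after = None;
    loop {
        let (buckets, after_key) = page(sorted_keys, after, size);
        if buckets.is_empty() {
            break;
        }
        pages.push(buckets);
        after = after_key;
    }
    pages
}
```

Under this contract a client never has to guess whether a page was the last one: it keeps requesting until a page comes back empty.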


@@ -1,4 +1,5 @@
use std::fmt::Debug;
use std::rc::Rc;
use common::BitSet;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
@@ -6,8 +7,8 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::cached_sub_aggs::{
CachedSubAggs, HighCardSubAggCache, LowCardSubAggCache, SubAggCache,
use crate::aggregation::buffered_sub_aggs::{
BufferedSubAggs, HighCardSubAggBuffer, LowCardSubAggBuffer, SubAggBuffer,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
@@ -396,6 +397,7 @@ impl PartialEq for FilterAggregation {
/// Request data for filter aggregation
/// This struct holds the per-segment data needed to execute a filter aggregation
#[derive(Clone)]
pub struct FilterAggReqData {
/// The name of the filter aggregation
pub name: String,
@@ -403,22 +405,20 @@ pub struct FilterAggReqData {
pub req: FilterAggregation,
/// The segment reader
pub segment_reader: SegmentReader,
/// Document evaluator for the filter query (precomputed BitSet)
/// This is built once when the request data is created
pub evaluator: DocumentQueryEvaluator,
/// Reusable buffer for matching documents to minimize allocations during collection
pub matching_docs_buffer: Vec<DocId>,
/// Document evaluator for the filter query (precomputed BitSet).
/// Wrapped in `Rc` so cloning the request data does not duplicate the (potentially large)
/// underlying BitSet.
pub evaluator: Rc<DocumentQueryEvaluator>,
/// True if this filter aggregation is at the top level of the aggregation tree (not nested).
pub is_top_level: bool,
}
impl FilterAggReqData {
pub(crate) fn get_memory_consumption(&self) -> usize {
// Estimate: name + segment reader reference + bitset + buffer capacity
// Estimate: name + segment reader reference + bitset
self.name.len()
+ std::mem::size_of::<SegmentReader>()
+ self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes)
+ self.matching_docs_buffer.capacity() * std::mem::size_of::<DocId>()
+ std::mem::size_of::<bool>()
}
}
@@ -503,21 +503,24 @@ struct DocCount {
}
/// Segment collector for filter aggregation
pub struct SegmentFilterCollector<C: SubAggCache> {
pub struct SegmentFilterCollector<B: SubAggBuffer> {
/// Document counts per parent bucket
parent_buckets: Vec<DocCount>,
/// Sub-aggregation collectors
sub_aggregations: Option<CachedSubAggs<C>>,
sub_aggregations: Option<BufferedSubAggs<B>>,
bucket_id_provider: BucketIdProvider,
/// Accessor index for this filter aggregation (to access FilterAggReqData)
accessor_idx: usize,
/// Per-segment filter request data, owned by this collector.
req_data: FilterAggReqData,
/// Reusable buffer for matching documents to minimize allocations during collection.
matching_docs_buffer: Vec<DocId>,
}
impl<C: SubAggCache> SegmentFilterCollector<C> {
impl<B: SubAggBuffer> SegmentFilterCollector<B> {
/// Create a new filter segment collector following the new agg_data pattern
pub(crate) fn from_req_and_validate(
req: &mut AggregationsSegmentCtx,
node: &AggRefNode,
req_data: FilterAggReqData,
) -> crate::Result<Self> {
// Build sub-aggregation collectors if any
let sub_agg_collector = if !node.children.is_empty() {
@@ -525,13 +528,17 @@ impl<C: SubAggCache> SegmentFilterCollector<C> {
} else {
None
};
let sub_agg_collector = sub_agg_collector.map(CachedSubAggs::new);
let sub_agg_collector = sub_agg_collector.map(BufferedSubAggs::new);
let max_doc = req_data.segment_reader.max_doc();
let buffer_capacity = crate::docset::COLLECT_BLOCK_BUFFER_LEN.min(max_doc as usize);
Ok(SegmentFilterCollector {
parent_buckets: Vec::new(),
sub_aggregations: sub_agg_collector,
accessor_idx: node.idx_in_req_data,
req_data,
bucket_id_provider: BucketIdProvider::default(),
matching_docs_buffer: Vec::with_capacity(buffer_capacity),
})
}
}
@@ -540,33 +547,38 @@ pub(crate) fn build_segment_filter_collector(
req: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
let is_top_level = req.per_request.filter_req_data[node.idx_in_req_data]
.as_ref()
.expect("filter_req_data slot is empty")
.is_top_level;
let req_data = req.per_request.filter_req_data[node.idx_in_req_data].clone();
req.context
.limits
.add_memory_consumed(req_data.get_memory_consumption() as u64)?;
let is_top_level = req_data.is_top_level;
if is_top_level {
Ok(Box::new(
SegmentFilterCollector::<LowCardSubAggCache>::from_req_and_validate(req, node)?,
SegmentFilterCollector::<LowCardSubAggBuffer>::from_req_and_validate(
req, node, req_data,
)?,
))
} else {
Ok(Box::new(
SegmentFilterCollector::<HighCardSubAggCache>::from_req_and_validate(req, node)?,
SegmentFilterCollector::<HighCardSubAggBuffer>::from_req_and_validate(
req, node, req_data,
)?,
))
}
}
impl<C: SubAggCache> Debug for SegmentFilterCollector<C> {
impl<B: SubAggBuffer> Debug for SegmentFilterCollector<B> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentFilterCollector")
.field("buckets", &self.parent_buckets)
.field("has_sub_aggs", &self.sub_aggregations.is_some())
.field("accessor_idx", &self.accessor_idx)
.field("name", &self.req_data.name)
.finish()
}
}
impl<C: SubAggCache> SegmentAggregationCollector for SegmentFilterCollector<C> {
impl<B: SubAggBuffer> SegmentAggregationCollector for SegmentFilterCollector<B> {
fn add_intermediate_aggregation_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
@@ -598,11 +610,7 @@ impl<C: SubAggCache> SegmentAggregationCollector for SegmentFilterCollector<C> {
};
// Get the name of this filter aggregation
let name = agg_data.per_request.filter_req_data[self.accessor_idx]
.as_ref()
.expect("filter_req_data slot is empty")
.name
.clone();
let name = self.req_data.name.clone();
results.push(
name,
@@ -623,27 +631,24 @@ impl<C: SubAggCache> SegmentAggregationCollector for SegmentFilterCollector<C> {
}
let mut bucket = self.parent_buckets[parent_bucket_id as usize];
// Take the request data to avoid borrow checker issues with sub-aggregations
let mut req = agg_data.take_filter_req_data(self.accessor_idx);
// Use batch filtering with O(1) BitSet lookups
req.matching_docs_buffer.clear();
req.evaluator
.filter_batch(docs, &mut req.matching_docs_buffer);
self.matching_docs_buffer.clear();
self.req_data
.evaluator
.filter_batch(docs, &mut self.matching_docs_buffer);
bucket.doc_count += req.matching_docs_buffer.len() as u64;
bucket.doc_count += self.matching_docs_buffer.len() as u64;
// Batch process sub-aggregations if we have matches
if !req.matching_docs_buffer.is_empty() {
if !self.matching_docs_buffer.is_empty() {
if let Some(sub_aggs) = &mut self.sub_aggregations {
for &doc_id in &req.matching_docs_buffer {
for &doc_id in &self.matching_docs_buffer {
sub_aggs.push(bucket.bucket_id, doc_id);
}
}
}
// Put the request data back
agg_data.put_back_filter_req_data(self.accessor_idx, req);
if let Some(sub_aggs) = &mut self.sub_aggregations {
sub_aggs.check_flush_local(agg_data)?;
}
@@ -674,6 +679,17 @@ impl<C: SubAggCache> SegmentAggregationCollector for SegmentFilterCollector<C> {
}
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// TODO: forward into the inner `sub_agg` for nested order paths (`filter.metric`).
None
}
}
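
Review note: `FilterAggReqData` now derives `Clone` and wraps the evaluator in `Rc`, so cloning the request data shares the bitset instead of copying it. A small sketch of that sharing follows. `Evaluator` and `FilterReq` are hypothetical stand-ins with a `Vec<u64>` in place of the real `BitSet`.

```rust
use std::rc::Rc;

// Hypothetical evaluator: one bit per document, packed into 64-bit words.
struct Evaluator {
    bitset: Vec<u64>,
}

impl Evaluator {
    fn contains(&self, doc: u32) -> bool {
        let word = (doc / 64) as usize;
        self.bitset
            .get(word)
            .map_or(false, |w| ((*w >> (doc % 64)) & 1) == 1)
    }

    // Appends the matching docs to `out`, which the caller clears and reuses.
    fn filter_batch(&self, docs: &[u32], out: &mut Vec<u32>) {
        out.extend(docs.iter().copied().filter(|&d| self.contains(d)));
    }
}

// Cloning this struct bumps a reference count; the bitset itself is not duplicated.
#[derive(Clone)]
struct FilterReq {
    name: String,
    evaluator: Rc<Evaluator>,
}
```

`Rc` is not `Send`, so this sketch assumes the request data and its clones stay on one thread.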
/// Intermediate result for filter aggregation


@@ -10,7 +10,7 @@ use crate::aggregation::agg_data::{
};
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::agg_result::BucketEntry;
use crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardCachedSubAggs};
use crate::aggregation::buffered_sub_aggs::{BufferedSubAggs, HighCardBufferedSubAggs};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateHistogramBucketEntry,
@@ -21,6 +21,7 @@ use crate::TantivyError;
/// Contains all information required by the SegmentHistogramCollector to perform the
/// histogram or date_histogram aggregation on a segment.
#[derive(Debug, Clone)]
pub struct HistogramAggReqData {
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
@@ -243,22 +244,55 @@ impl Display for HistogramBounds {
}
impl HistogramBounds {
fn contains(&self, val: f64) -> bool {
pub(crate) fn contains(&self, val: f64) -> bool {
val >= self.min && val <= self.max
}
}
#[derive(Default, Clone, Debug, PartialEq)]
pub(crate) struct SegmentHistogramBucketEntry {
pub key: f64,
pub doc_count: u64,
pub bucket_id: BucketId,
/// The per-bucket identifier stored in a [`SegmentHistogramBucketEntry`].
///
/// It is [`BucketId`] when the histogram has sub aggregations (which key their state by it), and
/// the zero-sized `()` when it does not. Without sub aggregations the id is never read, so storing
/// `()` drops 8 bytes per bucket (24 -> 16) and turns id assignment into a no-op.
pub trait BucketIdSlot: Copy + Default + std::fmt::Debug + PartialEq {
/// Assigns the next id from the provider, called once when a bucket is first filled.
fn assign(provider: &mut BucketIdProvider) -> Self;
/// Resolves to the `BucketId` for sub-aggregation bookkeeping.
///
/// Only ever called for the [`BucketId`] slot: the `()` slot is used exactly when there are no
/// sub aggregations, so every call site is guarded by `sub_agg.is_some()` and is dead for `()`.
fn to_bucket_id(self) -> BucketId;
}
impl BucketIdSlot for BucketId {
#[inline(always)]
fn assign(provider: &mut BucketIdProvider) -> Self {
provider.next_bucket_id()
}
#[inline(always)]
fn to_bucket_id(self) -> BucketId {
self
}
}
impl BucketIdSlot for () {
#[inline(always)]
fn assign(_provider: &mut BucketIdProvider) -> Self {}
#[inline(always)]
fn to_bucket_id(self) -> BucketId {
unreachable!("bucket ids are only resolved when sub aggregations are present")
}
}
impl SegmentHistogramBucketEntry {
#[derive(Default, Clone, Debug, PartialEq)]
pub(crate) struct SegmentHistogramBucketEntry<B> {
pub key: f64,
pub doc_count: u64,
pub bucket_id: B,
}
impl<B: BucketIdSlot> SegmentHistogramBucketEntry<B> {
pub(crate) fn into_intermediate_bucket_entry(
self,
sub_aggregation: &mut Option<HighCardCachedSubAggs>,
sub_aggregation: &mut Option<HighCardBufferedSubAggs>,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateHistogramBucketEntry> {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
@@ -268,7 +302,7 @@ impl SegmentHistogramBucketEntry {
.add_intermediate_aggregation_result(
agg_data,
&mut sub_aggregation_res,
self.bucket_id,
self.bucket_id.to_bucket_id(),
)?;
}
Ok(IntermediateHistogramBucketEntry {
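
Review note: the `BucketIdSlot` doc comment claims that storing `()` instead of a `BucketId` shrinks each entry from 24 to 16 bytes. The sketch below reproduces that arithmetic with a hypothetical `Entry<B>`. It assumes `BucketId` is a `u32` and a 64-bit target where `f64` and `u64` are 8-byte aligned.

```rust
use std::mem::size_of;

// Assumption for this sketch: the bucket id is a 32-bit integer.
type BucketId = u32;

// Same field layout as the entry in the diff, generic over the id slot.
#[derive(Default, Clone, Debug, PartialEq)]
struct Entry<B> {
    key: f64,
    doc_count: u64,
    bucket_id: B,
}

// Returns (size with a real id, size with the zero-sized `()` slot).
fn entry_sizes() -> (usize, usize) {
    (size_of::<Entry<BucketId>>(), size_of::<Entry<()>>())
}
```

With a `u32` id the struct needs 20 bytes of payload, which 8-byte alignment pads to 24. With `()` the id takes no space and the entry is exactly 16 bytes.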
@@ -279,34 +313,147 @@ impl SegmentHistogramBucketEntry {
}
}
#[derive(Clone, Debug, Default)]
struct HistogramBuckets {
pub buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
/// The contiguous bucket range a histogram can span, derived from the column min/max (clamped to
/// the histogram bounds). Buckets in `[base_pos, base_pos + len)` can be stored in a flat `Vec`
/// indexed by `bucket_pos - base_pos`, avoiding the hash map on the hot path.
#[derive(Clone, Copy, Debug)]
pub(crate) struct DenseRange {
/// `bucket_pos` mapped to index 0 of the dense `Vec`.
pub(crate) base_pos: i64,
/// Number of bucket positions in the range.
pub(crate) len: usize,
}
/// Storage for the histogram buckets of a single parent bucket.
///
/// Starts out sparse (a hash map keyed by `bucket_pos`). Once enough distinct buckets have been
/// filled that we are clearly going to cover most of the column's theoretical range, it switches
/// to a dense `Vec` indexed by `bucket_pos - base_pos`, which removes hashing from the hot loop.
#[derive(Clone, Debug)]
enum HistogramBuckets<B> {
Sparse(FxHashMap<i64, SegmentHistogramBucketEntry<B>>),
Dense {
base_pos: i64,
/// One slot per bucket position; a slot with `doc_count == 0` has not been hit yet.
buckets: Vec<SegmentHistogramBucketEntry<B>>,
},
}
impl<B> Default for HistogramBuckets<B> {
fn default() -> Self {
HistogramBuckets::Sparse(FxHashMap::default())
}
}
impl<B: BucketIdSlot> HistogramBuckets<B> {
fn memory_consumption(&self) -> u64 {
let num_slots = match self {
HistogramBuckets::Sparse(map) => map.capacity(),
HistogramBuckets::Dense { buckets, .. } => buckets.capacity(),
};
num_slots as u64 * std::mem::size_of::<SegmentHistogramBucketEntry<B>>() as u64
}
/// Switches from sparse to dense storage once the dense `Vec` would use no more memory than the
/// hash map does now, so the switch never increases memory. Called at block boundaries.
///
/// The `Vec` holds one `Entry` per bucket position in the range. The map additionally stores
/// the key and a control byte per slot, at a load factor of 7/16..7/8, so for a dense histogram
/// its footprint grows past the `Vec` well before full coverage. And since the `Vec` never
/// grows afterwards while the map would keep growing, dense only gets relatively cheaper — so
/// no upper bound on the range is needed: a large but sparse range simply never crosses over.
#[inline]
fn maybe_densify(&mut self, dense_range: Option<DenseRange>) {
let Some(range) = dense_range else { return };
let HistogramBuckets::Sparse(map) = self else {
return;
};
let dense_bytes = range
.len
.saturating_mul(std::mem::size_of::<SegmentHistogramBucketEntry<B>>());
let sparse_bytes = map
.capacity()
.saturating_mul(std::mem::size_of::<(i64, SegmentHistogramBucketEntry<B>)>() + 1);
if dense_bytes > sparse_bytes {
return;
}
let map = std::mem::take(map);
let mut buckets = vec![SegmentHistogramBucketEntry::<B>::default(); range.len];
for (bucket_pos, entry) in map {
buckets[(bucket_pos - range.base_pos) as usize] = entry;
}
*self = HistogramBuckets::Dense {
base_pos: range.base_pos,
buckets,
};
}
/// Returns the bucket entry for `bucket_pos`, setting its key (and `bucket_id`, when `B` is
/// [`BucketId`]) on first use.
///
/// For the dense variant `bucket_pos` is guaranteed to be inside the range, since it is
/// derived from the column min/max that bounds every value (see [`compute_dense_range`]).
#[inline]
fn get_or_create(
&mut self,
bucket_pos: i64,
bucket_id_provider: &mut BucketIdProvider,
key_from_pos: impl FnOnce(i64) -> f64,
) -> &mut SegmentHistogramBucketEntry<B> {
match self {
HistogramBuckets::Sparse(map) => {
map.entry(bucket_pos)
.or_insert_with(|| SegmentHistogramBucketEntry {
key: key_from_pos(bucket_pos),
doc_count: 0,
bucket_id: B::assign(bucket_id_provider),
})
}
HistogramBuckets::Dense { base_pos, buckets } => {
let idx = (bucket_pos - *base_pos) as usize;
debug_assert!(idx < buckets.len(), "bucket_pos outside the dense range");
let entry = &mut buckets[idx];
if entry.doc_count == 0 {
entry.key = key_from_pos(bucket_pos);
entry.bucket_id = B::assign(bucket_id_provider);
}
entry
}
}
}
/// Consumes the storage, yielding all non-empty bucket entries.
fn into_filled_entries(self) -> Vec<SegmentHistogramBucketEntry<B>> {
match self {
HistogramBuckets::Sparse(map) => map.into_values().collect(),
HistogramBuckets::Dense { buckets, .. } => {
buckets.into_iter().filter(|b| b.doc_count > 0).collect()
}
}
}
}
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
#[derive(Debug)]
pub struct SegmentHistogramCollector {
pub struct SegmentHistogramCollector<B> {
/// The buckets containing the aggregation data.
/// One Histogram bucket per parent bucket id.
parent_buckets: Vec<HistogramBuckets>,
sub_agg: Option<HighCardCachedSubAggs>,
accessor_idx: usize,
parent_buckets: Vec<HistogramBuckets<B>>,
sub_agg: Option<HighCardBufferedSubAggs>,
req_data: HistogramAggReqData,
bucket_id_provider: BucketIdProvider,
/// Theoretical bucket range derived from the column min/max, if dense `Vec` storage is
/// viable. `None` keeps every parent bucket in the sparse hash map.
dense_range: Option<DenseRange>,
}
impl SegmentAggregationCollector for SegmentHistogramCollector {
impl<B: BucketIdSlot> SegmentAggregationCollector for SegmentHistogramCollector<B> {
fn add_intermediate_aggregation_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_data
.get_histogram_req_data(self.accessor_idx)
.name
.clone();
let name = self.req_data.name.clone();
// TODO: avoid prepare_max_bucket here and handle empty buckets.
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let histogram = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);
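
Review note: `maybe_densify` switches to the dense `Vec` once it would use no more memory than the hash map does at its current capacity. The sketch below isolates that comparison. `should_densify` is a hypothetical helper, and `Entry` is a 16-byte stand-in for a bucket entry without a bucket id.

```rust
use std::mem::size_of;

// Hypothetical 16-byte entry: key and doc_count only.
type Entry = (f64, u64);

/// Mirrors the size comparison in the diff: the dense `Vec` costs one entry per
/// position in the range, the map costs key + entry + one control byte per slot.
fn should_densify(range_len: usize, map_capacity: usize) -> bool {
    let dense_bytes = range_len.saturating_mul(size_of::<Entry>());
    let sparse_bytes = map_capacity.saturating_mul(size_of::<(i64, Entry)>() + 1);
    dense_bytes <= sparse_bytes
}
```

For a range of 100 positions the dense cost is 1600 bytes. The map costs 25 bytes per slot in this model, so the crossover is at a capacity of 64 slots.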
@@ -323,10 +470,13 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req = agg_data.take_histogram_req_data(self.accessor_idx);
let mem_pre = self.get_memory_consumption();
let buckets = &mut self.parent_buckets[parent_bucket_id as usize].buckets;
let mem_pre = self.get_memory_consumption(parent_bucket_id);
let dense_range = self.dense_range;
let store = &mut self.parent_buckets[parent_bucket_id as usize];
// Upgrade to dense storage before processing the block if the buckets are dense enough.
store.maybe_densify(dense_range);
let req = &self.req_data;
let bounds = req.bounds;
let interval = req.req.interval;
let offset = req.offset;
@@ -335,35 +485,43 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
agg_data
.column_block_accessor
.fetch_block(docs, &req.accessor);
for (doc, val) in agg_data
.column_block_accessor
.iter_docid_vals(docs, &req.accessor)
{
let val = f64_from_fastfield_u64(val, req.field_type);
let bucket_pos = get_bucket_pos(val);
if bounds.contains(val) {
let bucket = buckets.entry(bucket_pos).or_insert_with(|| {
let key = get_bucket_key_from_pos(bucket_pos as f64, interval, offset);
SegmentHistogramBucketEntry {
key,
doc_count: 0,
bucket_id: self.bucket_id_provider.next_bucket_id(),
}
});
bucket.doc_count += 1;
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.push(bucket.bucket_id, doc);
// special path for nested buckets
if let Some(sub_agg) = &mut self.sub_agg {
for (doc, val) in agg_data
.column_block_accessor
.iter_docid_vals(docs, &req.accessor)
{
let val = f64_from_fastfield_u64(val, req.field_type);
if bounds.contains(val) {
let bucket = store.get_or_create(
get_bucket_pos(val),
&mut self.bucket_id_provider,
|pos| get_bucket_key_from_pos(pos as f64, interval, offset),
);
bucket.doc_count += 1;
sub_agg.push(bucket.bucket_id.to_bucket_id(), doc);
}
}
} else {
for val in agg_data.column_block_accessor.iter_vals() {
let val = f64_from_fastfield_u64(val, req.field_type);
if bounds.contains(val) {
let bucket = store.get_or_create(
get_bucket_pos(val),
&mut self.bucket_id_provider,
|pos| get_bucket_key_from_pos(pos as f64, interval, offset),
);
bucket.doc_count += 1;
}
}
}
agg_data.put_back_histogram_req_data(self.accessor_idx, req);
let mem_delta = self.get_memory_consumption() - mem_pre;
if mem_delta > 0 {
agg_data
.context
.limits
.add_memory_consumed(mem_delta as u64)?;
// `checked_sub` is `None` when densifying shrank the accounted memory; only account growth.
if let Some(mem_delta) = self
.get_memory_consumption(parent_bucket_id)
.checked_sub(mem_pre)
{
agg_data.context.limits.add_memory_consumed(mem_delta)?;
}
if let Some(sub_agg) = &mut self.sub_agg {
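
Review note: the memory accounting above switches from a plain subtraction to `checked_sub`, because densifying can shrink the accounted size within a block. A minimal sketch of that pattern follows. `Limits` and `account_growth` are hypothetical and the real limit check can fail with an error, which this sketch omits.

```rust
// Hypothetical memory budget tracker.
struct Limits {
    consumed: u64,
}

impl Limits {
    fn add_memory_consumed(&mut self, num_bytes: u64) {
        self.consumed += num_bytes;
    }
}

// Accounts only growth: `checked_sub` yields `None` when memory shrank, so
// there is no underflow and nothing is recorded for that block.
fn account_growth(limits: &mut Limits, mem_pre: u64, mem_post: u64) {
    if let Some(delta) = mem_post.checked_sub(mem_pre) {
        limits.add_memory_consumed(delta);
    }
}
```

Note that `checked_sub` returns `Some(0)` when the size is unchanged, so the accounting call still runs with a zero delta in that case.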
@@ -386,39 +544,45 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
self.parent_buckets.push(HistogramBuckets {
buckets: FxHashMap::default(),
});
self.parent_buckets.push(HistogramBuckets::default());
}
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// Histogram is a multi-bucket agg with no single value to extract.
None
}
}
impl SegmentHistogramCollector {
fn get_memory_consumption(&self) -> usize {
let self_mem = std::mem::size_of::<Self>();
let buckets_mem = self.parent_buckets.len() * std::mem::size_of::<HistogramBuckets>();
self_mem + buckets_mem
impl<B: BucketIdSlot> SegmentHistogramCollector<B> {
fn get_memory_consumption(&self, parent_bucket_id: BucketId) -> u64 {
self.parent_buckets[parent_bucket_id as usize].memory_consumption()
}
/// Converts the collector result into an intermediate bucket result.
fn add_intermediate_bucket_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
histogram: HistogramBuckets,
histogram: HistogramBuckets<B>,
) -> crate::Result<IntermediateBucketResult> {
let mut buckets = Vec::with_capacity(histogram.buckets.len());
let filled = histogram.into_filled_entries();
let mut buckets = Vec::with_capacity(filled.len());
for bucket in histogram.buckets.into_values() {
for bucket in filled {
let bucket_res = bucket.into_intermediate_bucket_entry(&mut self.sub_agg, agg_data);
buckets.push(bucket_res?);
}
buckets.sort_unstable_by(|b1, b2| b1.key.total_cmp(&b2.key));
let is_date_agg = agg_data
.get_histogram_req_data(self.accessor_idx)
.field_type
== ColumnType::DateTime;
let is_date_agg = self.req_data.field_type == ColumnType::DateTime;
Ok(IntermediateBucketResult::Histogram {
buckets,
is_date_agg,
@@ -434,32 +598,175 @@ impl SegmentHistogramCollector {
} else {
None
};
let req_data = agg_data.get_histogram_req_data_mut(node.idx_in_req_data);
req_data.req.validate()?;
if req_data.field_type == ColumnType::DateTime && !req_data.is_date_histogram {
req_data.req.normalize_date_time();
}
req_data.bounds = req_data.req.hard_bounds.unwrap_or(HistogramBounds {
min: f64::MIN,
max: f64::MAX,
});
req_data.offset = req_data.req.offset.unwrap_or(0.0);
let sub_agg = sub_agg.map(CachedSubAggs::new);
let mut req_data = agg_data.per_request.histogram_req_data[node.idx_in_req_data].clone();
normalize_histogram_req(&mut req_data)?;
agg_data
.context
.limits
.add_memory_consumed(req_data.get_memory_consumption() as u64)?;
let dense_range = compute_dense_range(
&req_data.accessor,
req_data.field_type,
req_data.req.interval,
req_data.offset,
req_data.bounds,
);
let sub_agg = sub_agg.map(BufferedSubAggs::new);
Ok(Self {
parent_buckets: Default::default(),
sub_agg,
accessor_idx: node.idx_in_req_data,
req_data,
bucket_id_provider: BucketIdProvider::default(),
dense_range,
})
}
}
impl SegmentHistogramCollector<()> {
/// Builds a histogram collector whose parent `t` is a dense histogram filled from
/// `counts[t * num_time_buckets .. (t + 1) * num_time_buckets]` (row-major). Used by the fused
/// terms×histogram collector to turn its flat 2D counters into the regular intermediate result,
/// so cross-segment merging is shared with the general path.
pub(crate) fn from_dense_rows(
req_data: HistogramAggReqData,
base_pos: i64,
num_time_buckets: usize,
counts: &[u32],
) -> Self {
let interval = req_data.req.interval;
let offset = req_data.offset;
let num_parents = counts.len().checked_div(num_time_buckets).unwrap_or(0);
let parent_buckets = (0..num_parents)
.map(|t| {
let row = &counts[t * num_time_buckets..(t + 1) * num_time_buckets];
let buckets = row
.iter()
.enumerate()
.map(|(b, &doc_count)| SegmentHistogramBucketEntry {
key: get_bucket_key_from_pos(
(base_pos + b as i64) as f64,
interval,
offset,
),
doc_count: doc_count as u64,
bucket_id: (),
})
.collect();
HistogramBuckets::Dense { base_pos, buckets }
})
.collect();
Self {
parent_buckets,
sub_agg: None,
req_data,
bucket_id_provider: BucketIdProvider::default(),
dense_range: None,
}
}
}
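Review note: the row-major layout that `from_dense_rows` reads can be sketched in isolation. The helper names `split_rows` and `bucket_key` below are hypothetical and only mirror the slicing and key arithmetic shown above; they are not part of the crate.

```rust
/// Splits a flat, row-major counter grid into one row per parent bucket.
/// `num_time_buckets` is the row width (hypothetical helper, for illustration only).
fn split_rows(counts: &[u32], num_time_buckets: usize) -> Vec<&[u32]> {
    // `checked_div` returns `None` for a zero divisor, so a zero-width grid yields no rows
    // instead of panicking on a division by zero.
    let num_parents = counts.len().checked_div(num_time_buckets).unwrap_or(0);
    (0..num_parents)
        .map(|t| &counts[t * num_time_buckets..(t + 1) * num_time_buckets])
        .collect()
}

/// Key of the bucket in column `b` of a row, given the position of the grid's first bucket.
fn bucket_key(base_pos: i64, b: usize, interval: f64, offset: f64) -> f64 {
    (base_pos + b as i64) as f64 * interval + offset
}
```

With six counters and a row width of three, this gives two parents, and the second parent owns the last three counters.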
/// Validates and normalizes a histogram request in place: applies date ns-normalization (for a
/// `histogram` on a date column) and resolves `bounds`/`offset` from the request.
fn normalize_histogram_req(req_data: &mut HistogramAggReqData) -> crate::Result<()> {
req_data.req.validate()?;
if req_data.field_type == ColumnType::DateTime && !req_data.is_date_histogram {
req_data.req.normalize_date_time();
}
req_data.bounds = req_data.req.hard_bounds.unwrap_or(HistogramBounds {
min: f64::MIN,
max: f64::MAX,
});
req_data.offset = req_data.req.offset.unwrap_or(0.0);
// Drop `hard_bounds` that can't exclude any value (the column's range already sits inside
// them): the per-doc `bounds.contains` check is then a no-op, so collapsing to the unbounded
// sentinel lets the histogram hot loop skip it and the fused term×histogram path derive
// per-term counts from the grid. Only this collect-time filter is touched — empty-bucket
// emission reads `req.hard_bounds` directly (see `get_req_min_max`), and `hard_bounds` only
// ever clips that range, so a wider-than-data bound leaves the result unchanged.
if req_data.req.hard_bounds.is_some() {
let col_min = f64_from_fastfield_u64(req_data.accessor.min_value(), req_data.field_type);
let col_max = f64_from_fastfield_u64(req_data.accessor.max_value(), req_data.field_type);
if col_min >= req_data.bounds.min && col_max <= req_data.bounds.max {
req_data.bounds = HistogramBounds {
min: f64::MIN,
max: f64::MAX,
};
}
}
Ok(())
}
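Review note: the bound-collapsing rule in `normalize_histogram_req` can be read as a small pure function. This is a minimal sketch, assuming a stand-in `Bounds` struct and a hypothetical `effective_bounds` helper; the real code works on `HistogramBounds` and the column accessor's min/max.

```rust
/// Stand-in for the histogram bounds type (hypothetical, for illustration only).
#[derive(Debug, Clone, Copy, PartialEq)]
struct Bounds {
    min: f64,
    max: f64,
}

/// The "unbounded" sentinel used when no value can be excluded.
const UNBOUNDED: Bounds = Bounds {
    min: f64::MIN,
    max: f64::MAX,
};

/// Returns the bounds the per-doc filter should use, given the column's value range.
fn effective_bounds(hard_bounds: Option<Bounds>, col_min: f64, col_max: f64) -> Bounds {
    match hard_bounds {
        // The whole column already sits inside the bounds: the check can never reject a value.
        Some(b) if col_min >= b.min && col_max <= b.max => UNBOUNDED,
        // The bounds cut into the data: keep them.
        Some(b) => b,
        // No bounds were requested.
        None => UNBOUNDED,
    }
}
```

For a column spanning [10, 16], bounds of [9.5, 16.5] collapse to the sentinel, while bounds of [11, 20] are kept because they exclude the value 10.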
/// Clones and normalizes (resolving interval/offset/bounds) the histogram request at `node`, and
/// returns it together with its dense bucket range — or `None` if the column has no usable range.
/// Used by the fused terms×histogram collector, which then owns the normalized request.
pub(crate) fn prepare_histogram_dense_range(
agg_data: &AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Option<(HistogramAggReqData, DenseRange)>> {
let mut req_data = agg_data.per_request.histogram_req_data[node.idx_in_req_data].clone();
normalize_histogram_req(&mut req_data)?;
let dense_range = compute_dense_range(
&req_data.accessor,
req_data.field_type,
req_data.req.interval,
req_data.offset,
req_data.bounds,
);
Ok(dense_range.map(|range| (req_data, range)))
}
/// Builds a boxed histogram (or date histogram) segment collector, picking the bucket-id storage
/// based on whether there are sub aggregations: `()` (no id stored) when there are none, otherwise
/// [`BucketId`].
pub(crate) fn build_segment_histogram_collector(
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
if node.children.is_empty() {
Ok(Box::new(
SegmentHistogramCollector::<()>::from_req_and_validate(agg_data, node)?,
))
} else {
Ok(Box::new(
SegmentHistogramCollector::<BucketId>::from_req_and_validate(agg_data, node)?,
))
}
}
#[inline]
fn get_bucket_pos_f64(val: f64, interval: f64, offset: f64) -> f64 {
pub(crate) fn get_bucket_pos_f64(val: f64, interval: f64, offset: f64) -> f64 {
((val - offset) / interval).floor()
}
/// Computes the dense bucket range for a column from its min/max value (clamped to the histogram
/// bounds), or `None` if there are no values within bounds (or the range overflows `usize`).
///
/// There is no upper bound on the range: whether dense storage is actually used is decided later,
/// per parent bucket, by [`HistogramBuckets::maybe_densify`] based on the memory it would save.
///
/// The column min/max bound every value the collector can see, so a `Vec` sized to this range can
/// be indexed by `bucket_pos - base_pos` without any out-of-bounds check on the hot path.
fn compute_dense_range(
accessor: &Column<u64>,
field_type: ColumnType,
interval: f64,
offset: f64,
bounds: HistogramBounds,
) -> Option<DenseRange> {
let col_min = f64_from_fastfield_u64(accessor.min_value(), field_type);
let col_max = f64_from_fastfield_u64(accessor.max_value(), field_type);
let lo = col_min.max(bounds.min);
let hi = col_max.min(bounds.max);
if lo > hi {
return None;
}
let base_pos = get_bucket_pos_f64(lo, interval, offset) as i64;
let top_pos = get_bucket_pos_f64(hi, interval, offset) as i64;
let len = usize::try_from(top_pos.checked_sub(base_pos)?.checked_add(1)?).ok()?;
(len > 0).then_some(DenseRange { base_pos, len })
}
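Review note: a self-contained sketch of the dense-range arithmetic, assuming plain `f64` column min/max in place of the `Column<u64>` accessor. The names `bucket_pos` and `dense_range` are hypothetical; the body follows the steps of `compute_dense_range` above.

```rust
use std::convert::TryFrom;

#[derive(Debug, PartialEq)]
struct DenseRange {
    base_pos: i64,
    len: usize,
}

/// Position of the bucket that contains `val`.
fn bucket_pos(val: f64, interval: f64, offset: f64) -> f64 {
    ((val - offset) / interval).floor()
}

/// Bucket range covered by the column values once clamped to the bounds.
fn dense_range(
    col_min: f64,
    col_max: f64,
    bounds_min: f64,
    bounds_max: f64,
    interval: f64,
    offset: f64,
) -> Option<DenseRange> {
    let lo = col_min.max(bounds_min);
    let hi = col_max.min(bounds_max);
    if lo > hi {
        // No value falls within the bounds.
        return None;
    }
    let base_pos = bucket_pos(lo, interval, offset) as i64;
    let top_pos = bucket_pos(hi, interval, offset) as i64;
    // Checked arithmetic: a range that overflows `i64` or `usize` yields `None`.
    let len = usize::try_from(top_pos.checked_sub(base_pos)?.checked_add(1)?).ok()?;
    (len > 0).then_some(DenseRange { base_pos, len })
}
```

A value at bucket position `p` then lives at index `p - base_pos` of a `Vec` with `len` entries.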
#[inline]
fn get_bucket_key_from_pos(bucket_pos: f64, interval: f64, offset: f64) -> f64 {
bucket_pos * interval + offset
@@ -764,6 +1071,62 @@ mod tests {
Ok(())
}
#[test]
fn histogram_dense_storage_test() -> crate::Result<()> {
histogram_dense_storage_test_with_opt(false)?;
histogram_dense_storage_test_with_opt(true)?;
Ok(())
}
/// Exercises the switch from sparse hash map to dense `Vec` storage. The switch happens at a
/// block boundary (a block is `COLLECT_BLOCK_BUFFER_LEN` = 64 docs), so we need many docs in a
/// single segment, densely covering the bucket range. `with_sub_agg` toggles the `iter_vals`
/// fast path vs. the `iter_docid_vals` path used when there is a sub aggregation.
fn histogram_dense_storage_test_with_opt(with_sub_agg: bool) -> crate::Result<()> {
let num_buckets = 50usize;
let docs_per_bucket = 10usize;
// Value `k` repeated `docs_per_bucket` times for each bucket `k`, so every value in bucket
// `k` equals `k` and the per-bucket average is exactly `k`.
let values: Vec<f64> = (0..num_buckets * docs_per_bucket)
.map(|i| (i % num_buckets) as f64)
.collect();
// `merge_segments = true` collapses the per-value segments into a single segment with all
// the docs, which is collected in 64-doc blocks and therefore switches to dense storage.
let index = get_test_index_from_values(true, &values)?;
let agg_req: Aggregations = serde_json::from_value(if with_sub_agg {
json!({
"histogram": {
"histogram": { "field": "score_f64", "interval": 1.0 },
"aggs": { "avg": { "avg": { "field": "score_f64" } } }
}
})
} else {
json!({
"histogram": {
"histogram": { "field": "score_f64", "interval": 1.0 }
}
})
})
.unwrap();
let res = exec_request(agg_req, &index)?;
for k in 0..num_buckets {
assert_eq!(res["histogram"]["buckets"][k]["key"], k as f64);
assert_eq!(
res["histogram"]["buckets"][k]["doc_count"],
docs_per_bucket as u64
);
if with_sub_agg {
assert_eq!(res["histogram"]["buckets"][k]["avg"]["value"], k as f64);
}
}
assert_eq!(res["histogram"]["buckets"][num_buckets], Value::Null);
Ok(())
}
#[test]
fn histogram_memory_limit() -> crate::Result<()> {
let index = get_test_index_with_num_docs(true, 100)?;
@@ -1058,6 +1421,55 @@ mod tests {
Ok(())
}
#[test]
fn histogram_non_binding_hard_bounds_test_multi_segment() -> crate::Result<()> {
histogram_non_binding_hard_bounds_test_with_opt(false)
}
#[test]
fn histogram_non_binding_hard_bounds_test_single_segment() -> crate::Result<()> {
histogram_non_binding_hard_bounds_test_with_opt(true)
}
/// `hard_bounds` wider than the data (here with mid-interval edges, to cover the "bound cuts a
/// bucket" case) can't exclude any value, so the result must be identical to the same request
/// without bounds. Guards the normalization that collapses such bounds to the unbounded
/// sentinel so the hot loop / fused path can skip the per-doc bounds check.
fn histogram_non_binding_hard_bounds_test_with_opt(merge_segments: bool) -> crate::Result<()> {
let values = vec![10.0, 12.0, 14.0, 16.0, 10.0, 13.0, 10.0, 12.0];
let index = get_test_index_from_values(merge_segments, &values)?;
// Mid-interval edges, but wider than the data range [10, 16] -> they exclude nothing.
let with_bounds: Aggregations = serde_json::from_value(json!({
"histogram": {
"histogram": {
"field": "score_f64",
"interval": 1.0,
"hard_bounds": { "min": 9.5, "max": 16.5 }
}
}
}))
.unwrap();
let no_bounds: Aggregations = serde_json::from_value(json!({
"histogram": {
"histogram": { "field": "score_f64", "interval": 1.0 }
}
}))
.unwrap();
let res_bounds = exec_request(with_bounds, &index)?;
let res_plain = exec_request(no_bounds, &index)?;
// Dropping a non-binding bound must not change anything.
assert_eq!(res_bounds, res_plain);
// Sanity: buckets span the data range with gaps filled (min_doc_count defaults to 0).
assert_eq!(res_bounds["histogram"]["buckets"][0]["key"], 10.0);
assert_eq!(res_bounds["histogram"]["buckets"][0]["doc_count"], 3);
assert_eq!(res_bounds["histogram"]["buckets"][6]["key"], 16.0);
assert_eq!(res_bounds["histogram"]["buckets"][6]["doc_count"], 1);
assert_eq!(res_bounds["histogram"]["buckets"][7], Value::Null);
Ok(())
}
#[test]
fn histogram_empty_result_behaviour_test_single_segment() -> crate::Result<()> {
histogram_empty_result_behaviour_test_with_opt(true)


@@ -9,8 +9,9 @@ use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::agg_limits::AggregationLimitsGuard;
use crate::aggregation::cached_sub_aggs::{
CachedSubAggs, HighCardSubAggCache, LowCardCachedSubAggs, LowCardSubAggCache, SubAggCache,
use crate::aggregation::buffered_sub_aggs::{
BufferedSubAggs, HighCardSubAggBuffer, LowCardBufferedSubAggs, LowCardSubAggBuffer,
SubAggBuffer,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
@@ -22,6 +23,7 @@ use crate::TantivyError;
/// Contains all information required by the SegmentRangeCollector to perform the
/// range aggregation on a segment.
#[derive(Debug, Clone)]
pub struct RangeAggReqData {
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
@@ -155,13 +157,13 @@ pub(crate) struct SegmentRangeAndBucketEntry {
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
pub struct SegmentRangeCollector<C: SubAggCache> {
pub struct SegmentRangeCollector<B: SubAggBuffer> {
/// The buckets containing the aggregation data.
/// One for each ParentBucketId
parent_buckets: Vec<Vec<SegmentRangeAndBucketEntry>>,
column_type: ColumnType,
pub(crate) accessor_idx: usize,
sub_agg: Option<CachedSubAggs<C>>,
pub(crate) req_data: RangeAggReqData,
sub_agg: Option<BufferedSubAggs<B>>,
/// Here things get a bit weird. We need to assign unique bucket ids across all
/// parent buckets. So we keep track of the next available bucket id here.
/// This allows a kind of flattening of the bucket ids across all parent buckets.
@@ -178,12 +180,12 @@ pub struct SegmentRangeCollector<C: SubAggCache> {
limits: AggregationLimitsGuard,
}
impl<C: SubAggCache> Debug for SegmentRangeCollector<C> {
impl<B: SubAggBuffer> Debug for SegmentRangeCollector<B> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentRangeCollector")
.field("parent_buckets_len", &self.parent_buckets.len())
.field("column_type", &self.column_type)
.field("accessor_idx", &self.accessor_idx)
.field("name", &self.req_data.name)
.field("has_sub_agg", &self.sub_agg.is_some())
.finish()
}
@@ -229,7 +231,7 @@ impl SegmentRangeBucketEntry {
}
}
impl<C: SubAggCache> SegmentAggregationCollector for SegmentRangeCollector<C> {
impl<B: SubAggBuffer> SegmentAggregationCollector for SegmentRangeCollector<B> {
fn add_intermediate_aggregation_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
@@ -238,10 +240,7 @@ impl<C: SubAggCache> SegmentAggregationCollector for SegmentRangeCollector<C> {
) -> crate::Result<()> {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let field_type = self.column_type;
let name = agg_data
.get_range_req_data(self.accessor_idx)
.name
.to_string();
let name = self.req_data.name.to_string();
let buckets = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);
@@ -280,17 +279,15 @@ impl<C: SubAggCache> SegmentAggregationCollector for SegmentRangeCollector<C> {
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req = agg_data.take_range_req_data(self.accessor_idx);
agg_data
.column_block_accessor
.fetch_block(docs, &req.accessor);
.fetch_block(docs, &self.req_data.accessor);
let buckets = &mut self.parent_buckets[parent_bucket_id as usize];
for (doc, val) in agg_data
.column_block_accessor
.iter_docid_vals(docs, &req.accessor)
.iter_docid_vals(docs, &self.req_data.accessor)
{
let bucket_pos = get_bucket_pos(val, buckets);
let bucket = &mut buckets[bucket_pos];
@@ -300,7 +297,6 @@ impl<C: SubAggCache> SegmentAggregationCollector for SegmentRangeCollector<C> {
}
}
agg_data.put_back_range_req_data(self.accessor_idx, req);
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.check_flush_local(agg_data)?;
}
@@ -318,15 +314,26 @@ impl<C: SubAggCache> SegmentAggregationCollector for SegmentRangeCollector<C> {
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
agg_data: &AggregationsSegmentCtx,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
let new_buckets = self.create_new_buckets(agg_data)?;
let new_buckets = self.create_new_buckets()?;
self.parent_buckets.push(new_buckets);
}
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// Range is a multi-bucket agg with no single value to extract.
None
}
}
/// Build a concrete `SegmentRangeCollector` with either a Vec- or HashMap-backed
/// bucket storage, depending on the column type and aggregation level.
@@ -334,8 +341,11 @@ pub(crate) fn build_segment_range_collector(
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
let accessor_idx = node.idx_in_req_data;
let req_data = agg_data.get_range_req_data(node.idx_in_req_data);
let req_data = agg_data.per_request.range_req_data[node.idx_in_req_data].clone();
agg_data
.context
.limits
.add_memory_consumed(req_data.get_memory_consumption() as u64)?;
let field_type = req_data.field_type;
// TODO: A better metric instead of is_top_level would be the number of buckets expected.
@@ -350,19 +360,19 @@ pub(crate) fn build_segment_range_collector(
};
if is_low_card {
Ok(Box::new(SegmentRangeCollector::<LowCardSubAggCache> {
sub_agg: sub_agg.map(LowCardCachedSubAggs::new),
Ok(Box::new(SegmentRangeCollector::<LowCardSubAggBuffer> {
sub_agg: sub_agg.map(LowCardBufferedSubAggs::new),
column_type: field_type,
accessor_idx,
req_data,
parent_buckets: Vec::new(),
bucket_id_provider: BucketIdProvider::default(),
limits: agg_data.context.limits.clone(),
}))
} else {
Ok(Box::new(SegmentRangeCollector::<HighCardSubAggCache> {
sub_agg: sub_agg.map(CachedSubAggs::new),
Ok(Box::new(SegmentRangeCollector::<HighCardSubAggBuffer> {
sub_agg: sub_agg.map(BufferedSubAggs::new),
column_type: field_type,
accessor_idx,
req_data,
parent_buckets: Vec::new(),
bucket_id_provider: BucketIdProvider::default(),
limits: agg_data.context.limits.clone(),
@@ -370,13 +380,10 @@ pub(crate) fn build_segment_range_collector(
}
}
impl<C: SubAggCache> SegmentRangeCollector<C> {
pub(crate) fn create_new_buckets(
&mut self,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<Vec<SegmentRangeAndBucketEntry>> {
impl<B: SubAggBuffer> SegmentRangeCollector<B> {
pub(crate) fn create_new_buckets(&mut self) -> crate::Result<Vec<SegmentRangeAndBucketEntry>> {
let field_type = self.column_type;
let req_data = agg_data.get_range_req_data(self.accessor_idx);
let req_data = &self.req_data;
// The range input on the request is f64.
// We need to convert to u64 ranges, because we read the values as u64.
// The mapping from the conversion is monotonic so ordering is preserved.
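Review note: the comment above relies on an order-preserving mapping from `f64` to `u64`. The sketch below shows one common way to build such a mapping from the IEEE 754 bit pattern; the function name `f64_to_sortable_u64` is hypothetical, and it is an assumption that the `to_u64()` calls in this diff behave the same way.

```rust
/// Maps an `f64` to a `u64` so that numeric order is preserved (NaN aside).
/// Illustrative sketch of a common technique, not necessarily the crate's implementation.
fn f64_to_sortable_u64(val: f64) -> u64 {
    let bits = val.to_bits();
    if val.is_sign_positive() {
        // Non-negative floats: set the top bit so they sort above every negative float.
        bits ^ (1u64 << 63)
    } else {
        // Negative floats: flip all bits so that larger magnitudes sort lower.
        !bits
    }
}
```

Under this mapping `f64::MIN` does not land on `u64::MIN`, which matches the special case called out in `bucket_test_range_conversion_special_case`.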
@@ -551,17 +558,16 @@ mod tests {
get_test_index_with_num_docs,
};
pub fn get_collector_from_ranges(
ranges: Vec<RangeAggregationRange>,
pub fn build_test_buckets(
ranges: &[RangeAggregationRange],
field_type: ColumnType,
) -> SegmentRangeCollector<HighCardSubAggCache> {
) -> Vec<SegmentRangeAndBucketEntry> {
let req = RangeAggregation {
field: "dummy".to_string(),
ranges,
ranges: ranges.to_vec(),
..Default::default()
};
// Build buckets directly as in from_req_and_validate without AggregationsData
let buckets: Vec<_> = extend_validate_ranges(&req.ranges, &field_type)
extend_validate_ranges(&req.ranges, &field_type)
.expect("unexpected error in extend_validate_ranges")
.iter()
.map(|range| {
@@ -592,16 +598,7 @@ mod tests {
},
}
})
.collect();
SegmentRangeCollector {
parent_buckets: vec![buckets],
column_type: field_type,
accessor_idx: 0,
sub_agg: None,
bucket_id_provider: Default::default(),
limits: AggregationLimitsGuard::default(),
}
.collect()
}
#[test]
@@ -844,10 +841,10 @@ mod tests {
#[test]
fn bucket_test_extend_range_hole() {
let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = [(10f64..20f64).into(), (30f64..40f64).into()];
let parent_buckets = [build_test_buckets(&buckets, ColumnType::F64)];
let buckets = collector.parent_buckets[0].clone();
let buckets = parent_buckets[0].clone();
assert_eq!(buckets[0].range.start, u64::MIN);
assert_eq!(buckets[0].range.end, 10f64.to_u64());
assert_eq!(buckets[1].range.start, 10f64.to_u64());
@@ -863,14 +860,14 @@ mod tests {
fn bucket_test_range_conversion_special_case() {
// the monotonic conversion between f64 and u64, does not map f64::MIN.to_u64() ==
// u64::MIN, but the into trait converts f64::MIN/MAX to None
let buckets = vec![
let buckets = [
(f64::MIN..10f64).into(),
(10f64..20f64).into(),
(20f64..f64::MAX).into(),
];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let parent_buckets = [build_test_buckets(&buckets, ColumnType::F64)];
let buckets = collector.parent_buckets[0].clone();
let buckets = parent_buckets[0].clone();
assert_eq!(buckets[0].range.start, u64::MIN);
assert_eq!(buckets[0].range.end, 10f64.to_u64());
assert_eq!(buckets[1].range.start, 10f64.to_u64());
@@ -882,28 +879,28 @@ mod tests {
#[test]
fn bucket_range_test_negative_vals() {
let buckets = vec![(-10f64..-1f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = [(-10f64..-1f64).into()];
let parent_buckets = [build_test_buckets(&buckets, ColumnType::F64)];
let buckets = collector.parent_buckets[0].clone();
let buckets = parent_buckets[0].clone();
assert_eq!(&buckets[0].bucket.key.to_string(), "*--10");
assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "-1-*");
}
#[test]
fn bucket_range_test_positive_vals() {
let buckets = vec![(0f64..10f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = [(0f64..10f64).into()];
let parent_buckets = [build_test_buckets(&buckets, ColumnType::F64)];
let buckets = collector.parent_buckets[0].clone();
let buckets = parent_buckets[0].clone();
assert_eq!(&buckets[0].bucket.key.to_string(), "*-0");
assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "10-*");
}
#[test]
fn range_binary_search_test_u64() {
let check_ranges = |ranges: Vec<RangeAggregationRange>| {
let collector = get_collector_from_ranges(ranges, ColumnType::U64);
let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);
let check_ranges = |ranges: &[RangeAggregationRange]| {
let parent_buckets = [build_test_buckets(ranges, ColumnType::U64)];
let search = |val: u64| get_bucket_pos(val, &parent_buckets[0]);
assert_eq!(search(u64::MIN), 0);
assert_eq!(search(9), 0);
@@ -916,7 +913,7 @@ mod tests {
};
let ranges = vec![(10.0..100.0).into()];
check_ranges(ranges);
check_ranges(&ranges);
let ranges = vec![
RangeAggregationRange {
@@ -926,7 +923,7 @@ mod tests {
},
(10.0..100.0).into(),
];
check_ranges(ranges);
check_ranges(&ranges);
let ranges = vec![
RangeAggregationRange {
@@ -941,15 +938,15 @@ mod tests {
from: Some(100.0),
},
];
check_ranges(ranges);
check_ranges(&ranges);
}
#[test]
fn range_binary_search_test_f64() {
let ranges = vec![(10.0..100.0).into()];
let ranges = [(10.0..100.0).into()];
let collector = get_collector_from_ranges(ranges, ColumnType::F64);
let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);
let parent_buckets = [build_test_buckets(&ranges, ColumnType::F64)];
let search = |val: u64| get_bucket_pos(val, &parent_buckets[0]);
assert_eq!(search(u64::MIN), 0);
assert_eq!(search(9f64.to_u64()), 0);


@@ -1,5 +1,4 @@
use std::fmt::Debug;
use std::io;
use std::net::Ipv6Addr;
use columnar::column_values::CompactSpaceU64Accessor;
@@ -17,8 +16,9 @@ use crate::aggregation::agg_data::{
};
use crate::aggregation::agg_limits::MemoryConsumption;
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::cached_sub_aggs::{
CachedSubAggs, HighCardSubAggCache, LowCardCachedSubAggs, LowCardSubAggCache, SubAggCache,
use crate::aggregation::buffered_sub_aggs::{
BufferedSubAggs, HighCardSubAggBuffer, LowCardBufferedSubAggs, LowCardSubAggBuffer,
SubAggBuffer,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
@@ -29,6 +29,8 @@ use crate::aggregation::{format_date, BucketId, Key};
use crate::error::DataCorruption;
use crate::TantivyError;
mod term_histogram;
/// Contains all information required by the SegmentTermCollector to perform the
/// terms aggregation on a segment.
#[derive(Debug, Clone)]
@@ -352,19 +354,15 @@ pub(crate) fn build_segment_term_collector(
)));
}
// Validate sub aggregation exists when ordering by sub-aggregation.
{
if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target {
let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name);
node.get_sub_agg(agg_name, &req_data.per_request)
.ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"could not find aggregation with name {agg_name} in metric \
sub_aggregations"
))
})?;
}
// Validate that the referenced sub-aggregation exists when ordering by one.
if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target {
let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name);
node.get_sub_agg(agg_name, &req_data.per_request)
.ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"could not find aggregation with name {agg_name} in metric sub_aggregations"
))
})?;
}
// Build sub-aggregation blueprint if there are children.
@@ -378,9 +376,21 @@ pub(crate) fn build_segment_term_collector(
// Let's see if we can use a vec to aggregate our data
// instead of a hashmap.
let col_max_value = terms_req_data.accessor.max_value();
let max_term_id: u64 =
let max_column_val: u64 =
col_max_value.max(terms_req_data.missing_value_for_accessor.unwrap_or(0u64));
// Fused fast path: low-cardinality terms × a single `histogram`/`date_histogram` leaf over full
// columns with a small enough bucket grid. Anything else falls through to the general path.
if let Some(collector) = term_histogram::maybe_build_collector(
req_data,
node,
&terms_req_data,
max_column_val,
is_top_level,
)? {
return Ok(collector);
}
let sub_agg_collector = if has_sub_aggregations {
Some(build_segment_agg_collectors(req_data, &node.children)?)
} else {
@@ -389,51 +399,51 @@ pub(crate) fn build_segment_term_collector(
let mut bucket_id_provider = BucketIdProvider::default();
// Decide which bucket storage is best suited for this aggregation.
if is_top_level && max_term_id < MAX_NUM_TERMS_FOR_VEC && !has_sub_aggregations {
let term_buckets = VecTermBucketsNoAgg::new(max_term_id + 1, &mut bucket_id_provider);
let collector: SegmentTermCollector<_, HighCardSubAggCache> = SegmentTermCollector {
if is_top_level && max_column_val < MAX_NUM_TERMS_FOR_VEC && !has_sub_aggregations {
let term_buckets = VecTermBucketsNoAgg::new(max_column_val + 1, &mut bucket_id_provider);
let collector: SegmentTermCollector<_, HighCardSubAggBuffer> = SegmentTermCollector {
parent_buckets: vec![term_buckets],
sub_agg: None,
bucket_id_provider,
max_term_id,
max_term_id: max_column_val,
terms_req_data,
};
Ok(Box::new(collector))
} else if is_top_level && max_term_id < MAX_NUM_TERMS_FOR_VEC {
let term_buckets = VecTermBuckets::new(max_term_id + 1, &mut bucket_id_provider);
let sub_agg = sub_agg_collector.map(LowCardCachedSubAggs::new);
let collector: SegmentTermCollector<_, LowCardSubAggCache> = SegmentTermCollector {
} else if is_top_level && max_column_val < MAX_NUM_TERMS_FOR_VEC {
let term_buckets = VecTermBuckets::new(max_column_val + 1, &mut bucket_id_provider);
let sub_agg = sub_agg_collector.map(LowCardBufferedSubAggs::new);
let collector: SegmentTermCollector<_, LowCardSubAggBuffer> = SegmentTermCollector {
parent_buckets: vec![term_buckets],
sub_agg,
bucket_id_provider,
max_term_id,
max_term_id: max_column_val,
terms_req_data,
};
Ok(Box::new(collector))
} else if max_term_id < 8_000_000 && is_top_level {
} else if max_column_val < 8_000_000 && is_top_level {
let term_buckets: PagedTermMap =
PagedTermMap::new(max_term_id + 1, &mut bucket_id_provider);
PagedTermMap::new(max_column_val + 1, &mut bucket_id_provider);
// Build sub-aggregation blueprint (flat pairs)
let sub_agg = sub_agg_collector.map(CachedSubAggs::new);
let collector: SegmentTermCollector<PagedTermMap, HighCardSubAggCache> =
let sub_agg = sub_agg_collector.map(BufferedSubAggs::new);
let collector: SegmentTermCollector<PagedTermMap, HighCardSubAggBuffer> =
SegmentTermCollector {
parent_buckets: vec![term_buckets],
sub_agg,
bucket_id_provider,
max_term_id,
max_term_id: max_column_val,
terms_req_data,
};
Ok(Box::new(collector))
} else {
let term_buckets: HashMapTermBuckets = HashMapTermBuckets::default();
// Build sub-aggregation blueprint (flat pairs)
let sub_agg = sub_agg_collector.map(CachedSubAggs::new);
let collector: SegmentTermCollector<HashMapTermBuckets, HighCardSubAggCache> =
let sub_agg = sub_agg_collector.map(BufferedSubAggs::new);
let collector: SegmentTermCollector<HashMapTermBuckets, HighCardSubAggBuffer> =
SegmentTermCollector {
parent_buckets: vec![term_buckets],
sub_agg,
bucket_id_provider,
max_term_id,
max_term_id: max_column_val,
terms_req_data,
};
Ok(Box::new(collector))
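Review note: the storage selection in this hunk reduces to a small decision function. This is a sketch under assumptions: the `TermStorage` enum and `choose_term_storage` are hypothetical names, and the Vec threshold is passed in as a parameter because the value of `MAX_NUM_TERMS_FOR_VEC` is not visible in this diff. The fused terms×histogram path is tried before any of these branches and is not modelled here.

```rust
/// Which bucket storage the general terms path would pick (illustrative only).
#[derive(Debug, PartialEq)]
enum TermStorage {
    DenseVecNoSubAgg,
    DenseVec,
    Paged,
    Hashed,
}

fn choose_term_storage(
    is_top_level: bool,
    has_sub_aggregations: bool,
    max_column_val: u64,
    max_num_terms_for_vec: u64, // stands in for MAX_NUM_TERMS_FOR_VEC
) -> TermStorage {
    if is_top_level && max_column_val < max_num_terms_for_vec && !has_sub_aggregations {
        TermStorage::DenseVecNoSubAgg
    } else if is_top_level && max_column_val < max_num_terms_for_vec {
        TermStorage::DenseVec
    } else if is_top_level && max_column_val < 8_000_000 {
        TermStorage::Paged
    } else {
        // Nested aggregations and very large id spaces fall back to a hash map.
        TermStorage::Hashed
    }
}
```

Note that every array-backed variant requires `is_top_level`; a nested terms aggregation always ends up on the hash map.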
@@ -758,10 +768,10 @@ impl TermAggregationMap for VecTermBuckets {
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
#[derive(Debug)]
struct SegmentTermCollector<TermMap: TermAggregationMap, C: SubAggCache> {
struct SegmentTermCollector<TermMap: TermAggregationMap, B: SubAggBuffer> {
/// The buckets containing the aggregation data.
parent_buckets: Vec<TermMap>,
sub_agg: Option<CachedSubAggs<C>>,
sub_agg: Option<BufferedSubAggs<B>>,
bucket_id_provider: BucketIdProvider,
max_term_id: u64,
terms_req_data: TermsAggReqData,
@@ -772,8 +782,8 @@ pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) {
(agg_name, agg_property)
}
impl<TermMap: TermAggregationMap, C: SubAggCache> SegmentAggregationCollector
for SegmentTermCollector<TermMap, C>
impl<TermMap: TermAggregationMap, B: SubAggBuffer> SegmentAggregationCollector
for SegmentTermCollector<TermMap, B>
{
fn add_intermediate_aggregation_result(
&mut self,
@@ -790,8 +800,14 @@ impl<TermMap: TermAggregationMap, C: SubAggCache> SegmentAggregationCollector
let term_req = &self.terms_req_data;
let name = term_req.name.clone();
let bucket =
Self::into_intermediate_bucket_result(term_req, &mut self.sub_agg, bucket, agg_data)?;
let bucket = Self::into_intermediate_bucket_result(
term_req,
self.sub_agg
.as_mut()
.map(BufferedSubAggs::get_sub_agg_collector),
bucket,
agg_data,
)?;
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
Ok(())
}
@@ -803,7 +819,7 @@ impl<TermMap: TermAggregationMap, C: SubAggCache> SegmentAggregationCollector
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let mem_pre = self.get_memory_consumption();
let mem_pre = self.get_memory_consumption(parent_bucket_id);
let req_data = &mut self.terms_req_data;
@@ -847,7 +863,7 @@ impl<TermMap: TermAggregationMap, C: SubAggCache> SegmentAggregationCollector
}
}
let mem_delta = self.get_memory_consumption() - mem_pre;
let mem_delta = self.get_memory_consumption(parent_bucket_id) - mem_pre;
if mem_delta > 0 {
agg_data
.context
@@ -881,6 +897,17 @@ impl<TermMap: TermAggregationMap, C: SubAggCache> SegmentAggregationCollector
}
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// Terms is a multi-bucket agg with no single value to extract.
None
}
}
/// Missing values are represented as a sentinel value in the column.
@@ -907,30 +934,53 @@ fn extract_missing_value<T>(
Some((key, bucket))
}
impl<TermMap, C> SegmentTermCollector<TermMap, C>
fn reborrow_opt_collector<'a>(
opt: &'a mut Option<&mut dyn SegmentAggregationCollector>,
) -> Option<&'a mut dyn SegmentAggregationCollector> {
match opt {
Some(inner) => Some(*inner),
None => None,
}
}
fn into_intermediate_bucket_entry(
bucket: Bucket,
sub_agg_collector: Option<&mut dyn SegmentAggregationCollector>,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateTermBucketEntry> {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
if let Some(sub_agg_collector) = sub_agg_collector {
sub_agg_collector.add_intermediate_aggregation_result(
agg_data,
&mut sub_aggregation_res,
bucket.bucket_id,
)?;
}
Ok(IntermediateTermBucketEntry {
doc_count: bucket.count,
sub_aggregation: sub_aggregation_res,
})
}
impl<TermMap, B> SegmentTermCollector<TermMap, B>
where
TermMap: TermAggregationMap,
C: SubAggCache,
B: SubAggBuffer,
{
fn get_memory_consumption(&self) -> usize {
self.parent_buckets
.iter()
.map(|b| b.get_memory_consumption())
.sum()
#[inline]
fn get_memory_consumption(&self, parent_bucket_id: BucketId) -> usize {
self.parent_buckets[parent_bucket_id as usize].get_memory_consumption()
}
#[inline]
pub(crate) fn into_intermediate_bucket_result(
term_req: &TermsAggReqData,
sub_agg: &mut Option<CachedSubAggs<C>>,
mut sub_agg_collector: Option<&mut dyn SegmentAggregationCollector>,
term_buckets: TermMap,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateBucketResult> {
let mut entries: Vec<(u64, Bucket)> = term_buckets.into_vec();
let order_by_sub_aggregation =
matches!(term_req.req.order.target, OrderTarget::SubAggregation(_));
match &term_req.req.order.target {
OrderTarget::Key => {
// We rely on the fact, that term ordinals match the order of the strings
@@ -942,10 +992,37 @@ where
entries.sort_unstable_by_key(|bucket| bucket.0);
}
}
OrderTarget::SubAggregation(_name) => {
// don't sort and cut off since it's hard to make assumptions on the quality of the
// results when cutting off due to unknown nature of the sub_aggregation (possible
// to check).
OrderTarget::SubAggregation(sub_agg_path) => {
// Peek segment-level metric values, sort, then fall through to
// `cut_off_buckets`. Like Elasticsearch, we always cut off when ordering
// by a sub-agg: top-K results are approximate and may differ from the
// global ordering, especially for non-monotonic metrics like avg/min.
let coll = sub_agg_collector.as_deref().ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"Could not find sub-aggregation collector for path {sub_agg_path}"
))
})?;
let (agg_name, agg_prop) = get_agg_name_and_property(sub_agg_path);
// Fetch values up-front; otherwise sort would re-compute per comparison
let mut keyed: Vec<(f64, (u64, Bucket))> = entries
.into_iter()
.map(|bucket| {
let metric_value = coll
.compute_metric_value(bucket.1.bucket_id, agg_name, agg_prop, agg_data)
.unwrap_or(0.0);
(metric_value, bucket)
})
.collect();
if term_req.req.order.order == Order::Desc {
keyed.sort_unstable_by(|a, b| {
b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
});
} else {
keyed.sort_unstable_by(|a, b| {
a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
});
}
entries = keyed.into_iter().map(|(_, e)| e).collect();
}
OrderTarget::Count => {
if term_req.req.order.order == Order::Desc {
@@ -956,40 +1033,12 @@ where
}
}
let (term_doc_count_before_cutoff, sum_other_doc_count) = if order_by_sub_aggregation {
(0, 0)
} else {
cut_off_buckets(&mut entries, term_req.req.segment_size as usize)
};
let (term_doc_count_before_cutoff, sum_other_doc_count) =
cut_off_buckets(&mut entries, term_req.req.segment_size as usize);
let mut dict: FxHashMap<IntermediateKey, IntermediateTermBucketEntry> = Default::default();
dict.reserve(entries.len());
let into_intermediate_bucket_entry =
|bucket: Bucket,
sub_agg: &mut Option<CachedSubAggs<C>>|
-> crate::Result<IntermediateTermBucketEntry> {
if let Some(sub_agg) = sub_agg {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
sub_agg
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut sub_aggregation_res,
bucket.bucket_id,
)?;
Ok(IntermediateTermBucketEntry {
doc_count: bucket.count,
sub_aggregation: sub_aggregation_res,
})
} else {
Ok(IntermediateTermBucketEntry {
doc_count: bucket.count,
sub_aggregation: Default::default(),
})
}
};
if term_req.column_type == ColumnType::Str {
let fallback_dict = Dictionary::empty();
let term_dict = term_req
@@ -1000,7 +1049,11 @@ where
if let Some((intermediate_key, bucket)) = extract_missing_value(&mut entries, term_req)
{
let intermediate_entry = into_intermediate_bucket_entry(bucket, sub_agg)?;
let intermediate_entry = into_intermediate_bucket_entry(
bucket,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)?;
dict.insert(intermediate_key, intermediate_entry);
}
@@ -1008,19 +1061,28 @@ where
entries.sort_unstable_by_key(|bucket| bucket.0);
let (term_ids, buckets): (Vec<u64>, Vec<Bucket>) = entries.into_iter().unzip();
let mut buckets_it = buckets.into_iter();
term_dict.sorted_ords_to_term_cb(term_ids.into_iter(), |term| {
let bucket = buckets_it.next().unwrap();
let intermediate_entry =
into_intermediate_bucket_entry(bucket, sub_agg).map_err(io::Error::other)?;
let intermediate_entries: Vec<IntermediateTermBucketEntry> = buckets
.into_iter()
.map(|bucket| {
into_intermediate_bucket_entry(
bucket,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)
})
.collect::<crate::Result<_>>()?;
let mut intermediate_entry_it = intermediate_entries.into_iter();
term_dict.sorted_ords_to_term_cb(&term_ids[..], |term| {
let intermediate_entry = intermediate_entry_it.next().unwrap();
dict.insert(
IntermediateKey::Str(
String::from_utf8(term.to_vec()).expect("could not convert to String"),
),
intermediate_entry,
);
Ok(())
})?;
if term_req.req.min_doc_count == 0 {
@@ -1055,14 +1117,22 @@ where
}
} else if term_req.column_type == ColumnType::DateTime {
for (val, doc_count) in entries {
let intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?;
let intermediate_entry = into_intermediate_bucket_entry(
doc_count,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)?;
let val = i64::from_u64(val);
let date = format_date(val)?;
dict.insert(IntermediateKey::Str(date), intermediate_entry);
}
} else if term_req.column_type == ColumnType::Bool {
for (val, doc_count) in entries {
let intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?;
let intermediate_entry = into_intermediate_bucket_entry(
doc_count,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)?;
let val = bool::from_u64(val);
dict.insert(IntermediateKey::Bool(val), intermediate_entry);
}
@@ -1082,14 +1152,22 @@ where
})?;
for (val, doc_count) in entries {
let intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?;
let intermediate_entry = into_intermediate_bucket_entry(
doc_count,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)?;
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
let val = Ipv6Addr::from_u128(val);
dict.insert(IntermediateKey::IpAddr(val), intermediate_entry);
}
} else {
for (val, doc_count) in entries {
let intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?;
let intermediate_entry = into_intermediate_bucket_entry(
doc_count,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)?;
if term_req.column_type == ColumnType::U64 {
dict.insert(IntermediateKey::U64(val), intermediate_entry);
} else if term_req.column_type == ColumnType::I64 {
@@ -1123,13 +1201,13 @@ where
}
}
impl<TermMap: TermAggregationMap, C: SubAggCache> SegmentTermCollector<TermMap, C> {
impl<TermMap: TermAggregationMap, B: SubAggBuffer> SegmentTermCollector<TermMap, B> {
#[inline]
fn collect_terms_with_docs(
iter: impl Iterator<Item = (crate::DocId, u64)>,
term_buckets: &mut TermMap,
bucket_id_provider: &mut BucketIdProvider,
sub_agg: &mut CachedSubAggs<C>,
sub_agg: &mut BufferedSubAggs<B>,
) {
for (doc, term_id) in iter {
let bucket_id = term_buckets.term_entry(term_id, bucket_id_provider);
@@ -1202,7 +1280,7 @@ mod tests {
use crate::aggregation::{AggregationLimitsGuard, DistributedAggregationCollector};
use crate::indexer::NoMergePolicy;
use crate::query::AllQuery;
use crate::schema::{IntoIpv6Addr, Schema, FAST, STRING};
use crate::schema::{IntoIpv6Addr, Schema, FAST, INDEXED, STRING, TEXT};
use crate::{Index, IndexWriter};
#[test]
@@ -1731,6 +1809,263 @@ mod tests {
Ok(())
}
#[test]
fn terms_aggregation_order_by_cardinality_desc_single_segment() -> crate::Result<()> {
terms_aggregation_order_by_cardinality_desc(true)
}
#[test]
fn terms_aggregation_order_by_cardinality_desc_multi_segment() -> crate::Result<()> {
terms_aggregation_order_by_cardinality_desc(false)
}
fn terms_aggregation_order_by_cardinality_desc(merge_segments: bool) -> crate::Result<()> {
// Distinct score values per bucket key: A→5, B→1, C→3.
// Order by cardinality desc must yield A, C, B.
let segment_and_terms = vec![vec![
(1.0, "A".to_string()),
(2.0, "A".to_string()),
(3.0, "A".to_string()),
(4.0, "A".to_string()),
(5.0, "A".to_string()),
(1.0, "B".to_string()),
(1.0, "B".to_string()),
(1.0, "B".to_string()),
(1.0, "C".to_string()),
(2.0, "C".to_string()),
(3.0, "C".to_string()),
]];
let index = get_test_index_from_values_and_terms(merge_segments, &segment_and_terms)?;
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "card": "desc" }
},
"aggs": {
"card": { "cardinality": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][0]["card"]["value"], 5.0);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][1]["card"]["value"], 3.0);
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
assert_eq!(res["my_texts"]["buckets"][2]["card"]["value"], 1.0);
// Asc engages the segment-cutoff path too (monotonic-safe: discarded buckets had
// local card >= cutoff, so merged card >= cutoff and they cannot be globally smallest).
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "card": "asc" }
},
"aggs": {
"card": { "cardinality": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "B");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][2]["key"], "A");
// size=2 with desc engages the segment cutoff: must keep top-2 by cardinality (A, C),
// and `sum_other_doc_count` reflects the dropped B (3 docs).
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"size": 2,
"order": { "card": "desc" }
},
"aggs": {
"card": { "cardinality": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"].as_array().unwrap().len(), 2);
// size=2 with asc engages the segment cutoff: must keep bottom-2 by cardinality (B, C).
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"size": 2,
"order": { "card": "asc" }
},
"aggs": {
"card": { "cardinality": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "B");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"].as_array().unwrap().len(), 2);
Ok(())
}
#[test]
fn terms_aggregation_order_by_sum_single_segment() -> crate::Result<()> {
terms_aggregation_order_by_sum(true)
}
#[test]
fn terms_aggregation_order_by_sum_multi_segment() -> crate::Result<()> {
terms_aggregation_order_by_sum(false)
}
fn terms_aggregation_order_by_sum(merge_segments: bool) -> crate::Result<()> {
// Per-bucket sums on the U64 `score` column (non-negative => sum is monotonic):
// A → 1+2+3+4+5 = 15, B → 1+1+1 = 3, C → 1+2+3 = 6.
let segment_and_terms = vec![
vec![
(1.0, "A".to_string()),
(2.0, "A".to_string()),
(3.0, "A".to_string()),
(1.0, "B".to_string()),
(1.0, "C".to_string()),
],
vec![
(4.0, "A".to_string()),
(5.0, "A".to_string()),
(1.0, "B".to_string()),
(1.0, "B".to_string()),
(2.0, "C".to_string()),
(3.0, "C".to_string()),
],
];
let index = get_test_index_from_values_and_terms(merge_segments, &segment_and_terms)?;
// Desc on a Sum metric engages the fast path (column is U64).
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "total": "desc" }
},
"aggs": {
"total": { "sum": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][0]["total"]["value"], 15.0);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][1]["total"]["value"], 6.0);
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
assert_eq!(res["my_texts"]["buckets"][2]["total"]["value"], 3.0);
// Asc engages the fast path too — discarded buckets had local sum >= cutoff,
// and merged sum >= local (non-negative addends), so they cannot be globally smallest.
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "total": "asc" }
},
"aggs": {
"total": { "sum": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "B");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][2]["key"], "A");
// size=2 desc with cutoff: top-2 by sum (A, C).
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"size": 2,
"order": { "total": "desc" }
},
"aggs": {
"total": { "sum": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"].as_array().unwrap().len(), 2);
// Stats sub-property: ordering by `mystats.sum` on a U64 column also engages.
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "mystats.sum": "desc" }
},
"aggs": {
"mystats": { "stats": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
// Sum on a signed column (I64) takes the same cutoff path. Results may be
// approximate near the boundary on adversarial data, but for this dataset the
// top-K is unambiguous.
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "total": "desc" }
},
"aggs": {
"total": { "sum": { "field": "score_i64" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
// Order by extended_stats sub-property exercises compute_metric_value on the
// ExtendedStats collector. A→max=5, B→max=1, C→max=3, so desc by max → A, C, B.
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "ext.max": "desc" }
},
"aggs": {
"ext": { "extended_stats": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
Ok(())
}
#[test]
fn terms_aggregation_test_order_key_single_segment() -> crate::Result<()> {
terms_aggregation_test_order_key_merge_segment(true)
@@ -2896,4 +3231,101 @@ mod tests {
Ok(())
}
fn prep_index_with_n_unique_terms_plus_one_null(n: u64) -> crate::Result<Index> {
let mut schema_builder = Schema::builder();
let id_field = schema_builder.add_u64_field("id", INDEXED);
let title_field = schema_builder.add_text_field("title", TEXT | FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
// set to one thread to guarantee all docs end up in the same segment
let mut writer = index.writer_with_num_threads(1, 50_000_000)?;
writer.add_document(doc!(
id_field => 0u64,
))?;
for i in 1u64..=n {
let title = format!("foo{i}");
writer.add_document(doc!(
id_field => i,
title_field => title,
))?;
}
writer.commit()?;
Ok(index)
}
#[test]
fn null_bitset_bounds_check_regression() -> crate::Result<()> {
// include cases
for i in 0..=4 {
let index = prep_index_with_n_unique_terms_plus_one_null(i * 64)?;
let normal_req: Aggregations = serde_json::from_value(json!({
"my_bool": {
"terms": {
"field": "title",
"missing": "__NULL__",
"size": 1000,
}
}
}))?;
let include_req: Aggregations = serde_json::from_value(json!({
"my_bool": {
"terms": {
"field": "title",
"include": "foo(.*)",
"missing": "__NULL__",
"size": 1000,
}
}
}))?;
let exclude_req: Aggregations = serde_json::from_value(json!({
"my_bool": {
"terms": {
"field": "title",
"exclude": "foo(.*)",
"missing": "__NULL__",
"size": 1000,
}
}
}))?;
let normal_res = exec_request(normal_req, &index)?;
let normal_buckets = normal_res["my_bool"]["buckets"].as_array().unwrap();
assert_eq!(
normal_buckets.len(),
(i * 64) as usize + 1,
"The normal request should return all 'foo' buckets, plus the missing term bucket",
);
let include_res = exec_request(include_req, &index)?;
eprintln!("include_res: {include_res:?}");
let include_buckets = include_res["my_bool"]["buckets"].as_array().unwrap();
assert_eq!(
include_buckets.len(),
(i * 64) as usize,
"The include request should return all 'foo' buckets, and not the missing term \
bucket",
);
assert!(include_buckets
.iter()
.all(|b| b["key"].as_str().unwrap().starts_with("foo")));
let exclude_res = exec_request(exclude_req, &index)?;
let exclude_buckets = exclude_res["my_bool"]["buckets"].as_array().unwrap();
if i != 0 {
// TODO: Remove this if after fixing exclude + missing bug
assert_eq!(
exclude_buckets.len(),
1,
"The exclude request should exclude all 'foo' buckets, and only the missing \
term bucket should remain",
);
assert_eq!(exclude_buckets[0]["key"], "__NULL__");
}
}
Ok(())
}
}
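
Review note: the `OrderTarget::SubAggregation` branch above computes each bucket's metric once, sorts on it, and then cuts off at the segment size. Below is a minimal standalone sketch of that decorate, sort, cut-off pattern. It does not use the crate's types: `Entry`, `sort_by_metric_and_cut_off`, and the `metric` closure are hypothetical stand-ins for `(u64, Bucket)`, the branch itself, and `compute_metric_value`, and the cut-off is a simplified guess at what `cut_off_buckets` does (keep the first `size` entries, sum the doc counts of the rest).

```rust
use std::cmp::Ordering;

/// Hypothetical stand-in for a term bucket: (term ordinal, doc count).
type Entry = (u64, u32);

/// Sorts `entries` by a metric computed once per entry, then keeps the first
/// `size` entries. Returns the kept entries and the summed doc count of the
/// dropped ones.
fn sort_by_metric_and_cut_off(
    entries: Vec<Entry>,
    metric: impl Fn(&Entry) -> Option<f64>,
    descending: bool,
    size: usize,
) -> (Vec<Entry>, u64) {
    // Decorate: evaluate the metric once per entry, not once per comparison.
    // A missing metric falls back to 0.0, mirroring `unwrap_or(0.0)` in the diff.
    let mut keyed: Vec<(f64, Entry)> = entries
        .into_iter()
        .map(|e| (metric(&e).unwrap_or(0.0), e))
        .collect();
    // Sort: incomparable values (NaN) are treated as equal.
    keyed.sort_unstable_by(|a, b| {
        let ord = a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal);
        if descending { ord.reverse() } else { ord }
    });
    // Cut off: everything past `size` is only reported as a doc-count sum.
    let dropped: u64 = keyed.iter().skip(size).map(|(_, e)| u64::from(e.1)).sum();
    keyed.truncate(size);
    // Undecorate: drop the sort keys again.
    (keyed.into_iter().map(|(_, e)| e).collect(), dropped)
}

fn main() {
    // Data from the cardinality test: A has 5 docs (cardinality 5),
    // B has 3 docs (cardinality 1), C has 3 docs (cardinality 3).
    let entries: Vec<Entry> = vec![(0, 5), (1, 3), (2, 3)];
    let card: [f64; 3] = [5.0, 1.0, 3.0];
    let (kept, dropped) =
        sort_by_metric_and_cut_off(entries, |e| card.get(e.0 as usize).copied(), true, 2);
    println!("kept = {kept:?}, dropped doc count = {dropped}");
}
```

With `size = 2` and descending order the sketch keeps A and C and reports B's 3 docs as dropped, which is the shape the `size: 2` test cases assert on.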


@@ -0,0 +1,585 @@
//! Fused collector for the very common shape `terms` (low cardinality) × a single
//! `histogram`/`date_histogram` sub-aggregation with nothing nested below it.
//!
//! See [`SegmentTermHistogramCollector`] for the approach and [`maybe_build_collector`] for the
//! conditions under which it is used.
use columnar::ColumnBlockAccessor;
use super::{Bucket, SegmentTermCollector, TermsAggReqData, VecTermBuckets};
use crate::aggregation::agg_data::{AggKind, AggRefNode, AggregationsSegmentCtx};
use crate::aggregation::bucket::{
get_bucket_pos_f64, prepare_histogram_dense_range, HistogramAggReqData,
SegmentHistogramCollector,
};
use crate::aggregation::buffered_sub_aggs::LowCardSubAggBuffer;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults,
};
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::{f64_from_fastfield_u64, BucketId};
/// Maximum number of cells (`num_terms × num_time_buckets`) in the fused flat 2D grid. Above this
/// the grid would be too large/cache-unfriendly, so we fall back to the general buffered path.
/// `1 << 14` cells = 64 KB of `u32` counters, comfortably L2-resident.
///
/// Since we are only at the top-level, this won't be multiplied by any parent buckets.
const MAX_FUSED_GRID_BUCKETS: usize = 16384;
/// Fused collector for `terms` (low cardinality) × a single `histogram`/`date_histogram` leaf with
/// nothing nested below it, when the resulting `num_terms × num_time_buckets` grid is small (see
/// [`MAX_FUSED_GRID_BUCKETS`]).
///
/// It keeps a flat, fully dense 2D counter grid (`counts[term * num_time_buckets + bucket]`) and a
/// per-term total. A single pass reads both the term and histogram columns in document order and
/// bumps the counters directly — no doc-id buffering, no per-term scattered re-fetch, no dynamic
/// dispatch on flush, no per-bucket key/id storage during collection (keys are derived from the
/// index at the end).
///
/// At result time the flat grid is expanded back into the regular term map + histogram storage and
/// handed to the shared intermediate-result builders, so cross-segment merging is identical to the
/// general path.
#[derive(Debug)]
pub(crate) struct SegmentTermHistogramCollector {
/// Per-term count of docs *outside* `hard_bounds` (still in `doc_count`, but in no bucket).
/// Per-term total = this + the term's `counts` row-sum; left empty when there are no hard
/// bounds (every doc is in-bounds, so there's no remainder to track).
term_counts: Vec<u32>,
/// Flattened `[num_terms * num_time_buckets]` histogram counters (`u32`, see
/// `term_counts`).
///
/// Each term id gets its own contiguous slice of `num_time_buckets` histogram counters.
/// When every doc is counted (no hard bounds), the per-term total is the sum over that
/// term's slice.
counts: Vec<u32>,
/// Histogram buckets per term (the dense time-range length).
num_time_buckets: usize,
/// `bucket_pos` mapped to time-bucket index 0.
base_pos: i64,
terms_req_data: TermsAggReqData,
/// The (cloned, normalized) histogram request: its column + interval/offset/bounds.
hist_req_data: HistogramAggReqData,
/// Private block accessors for both columns. We read them together, so each needs its own
/// (the shared `agg_data` scratch accessor only holds one block at a time). Owning them keeps
/// `collect` independent of `agg_data`.
term_block: ColumnBlockAccessor<u64>,
hist_block: ColumnBlockAccessor<u64>,
/// No hard bounds, so every doc is in-bounds.
all_docs_in_bounds: bool,
/// Both columns are full (fused-path precondition); cached so `collect` skips the per-block
/// cardinality lookup in `fetch_block`.
is_full: bool,
}
impl SegmentAggregationCollector for SegmentTermHistogramCollector {
fn add_intermediate_aggregation_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
debug_assert_eq!(
parent_bucket_id, 0,
"fused term-histogram collector is top-level only"
);
// Expand the flat grid back into the regular structures and reuse the shared builders, so
// ordering/cut-off/dict handling and cross-segment merging match the general path exactly.
let mut bucket_id_provider = BucketIdProvider::default();
// Per-term total = histogram row-sum (in-bounds) + `term_counts` (out-of-bounds remainder,
// empty when there are no hard bounds).
let term_buckets = VecTermBuckets {
buckets: self
.counts
.chunks_exact(self.num_time_buckets)
.enumerate()
.map(|(term_id, row)| {
let in_bounds: u32 = row.iter().sum();
let out_of_bounds = self.term_counts.get(term_id).copied().unwrap_or(0);
Bucket {
count: in_bounds + out_of_bounds,
bucket_id: bucket_id_provider.next_bucket_id(),
}
})
.collect(),
};
let mut histogram = SegmentHistogramCollector::<()>::from_dense_rows(
self.hist_req_data.clone(),
self.base_pos,
self.num_time_buckets,
&self.counts,
);
let name = self.terms_req_data.name.clone();
let bucket = SegmentTermCollector::<VecTermBuckets, LowCardSubAggBuffer>::into_intermediate_bucket_result(
&self.terms_req_data,
Some(&mut histogram as &mut dyn SegmentAggregationCollector),
term_buckets,
agg_data,
)?;
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
Ok(())
}
#[inline]
fn collect(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
_agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
debug_assert_eq!(
parent_bucket_id, 0,
"fused term-histogram collector is top-level only"
);
// Fetch both columns into our own accessors (we read them together, so they can't share the
// single `agg_data` scratch accessor). The collector owns all its inputs, so `collect`
// doesn't touch `agg_data`.
self.term_block
.fetch_block_with_is_full(docs, &self.terms_req_data.accessor, self.is_full);
self.hist_block
.fetch_block_with_is_full(docs, &self.hist_req_data.accessor, self.is_full);
// Hoist the loop-invariant fields into locals: the optimizer can't prove the
// `self.counts`/`self.term_counts` writes don't alias these `self` fields, so it can't keep
// them in registers and re-reads them from memory every iteration — ~15% slower on
// `terms_status_with_date_histogram` when read straight from `self`.
// Note: check which are actually relevant.
let field_type = self.hist_req_data.field_type;
let bounds = self.hist_req_data.bounds;
let interval = self.hist_req_data.req.interval;
let offset = self.hist_req_data.offset;
let base_pos = self.base_pos;
let num_time_buckets = self.num_time_buckets;
let all_docs_in_bounds = self.all_docs_in_bounds;
let term_counts = &mut self.term_counts;
let counts = &mut self.counts;
// Both columns are full (checked at construction), so values align with `docs` positionally
// and are read together in one pass.
// In-bounds docs bump the `counts` grid, out-of-bounds bump `term_counts`; deriving the
// total at flush avoids a per-doc `term_counts` RMW that serializes on
// store-to-load forwarding.
for (term_id, hist_raw) in self.term_block.iter_vals().zip(self.hist_block.iter_vals()) {
let term_id = term_id as usize;
let val = f64_from_fastfield_u64(hist_raw, field_type);
if all_docs_in_bounds || bounds.contains(val) {
let bucket = (get_bucket_pos_f64(val, interval, offset) as i64 - base_pos) as usize;
debug_assert!(
bucket < num_time_buckets,
"histogram bucket outside dense range"
);
counts[term_id * num_time_buckets + bucket] += 1;
} else {
term_counts[term_id] += 1;
}
}
Ok(())
}
fn flush(&mut self, _agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
// Nothing is buffered: `collect` writes the flat grid directly.
Ok(())
}
fn prepare_max_bucket(
&mut self,
_max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
// Top-level: the flat grid is allocated up front.
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
None
}
}
/// Builds the fused terms×histogram collector for a single top-level parent, when the shape is
/// eligible. Returns `Ok(None)` to fall back to the general buffered terms path.
///
/// Eligibility: top-level, low-cardinality terms over a full column with no missing/include-exclude
/// handling; a single `histogram`/`date_histogram` leaf (no nesting below it) over a full column;
/// and a `num_terms × num_time_buckets` grid no larger than [`MAX_FUSED_GRID_BUCKETS`].
pub(super) fn maybe_build_collector(
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
terms_req_data: &TermsAggReqData,
col_max_val: u64,
is_top_level: bool,
) -> crate::Result<Option<Box<dyn SegmentAggregationCollector>>> {
// Both columns must be full (one value per doc) so their values align positionally with `docs`
// and we can zip them. Requiring full columns also makes the terms agg's `missing` config a
// no-op (`fetch_block_with_missing` early-returns on full columns), so we needn't check for it.
//
// We don't cap the term cardinality here: the flat grid is bounded by the total cell count
// (`num_terms * num_time_buckets <= MAX_FUSED_GRID_BUCKETS`) checked below, which subsumes it.
//
// We only allow this at the top-level, since we don't know how many buckets are created. We
// are less likely to get enough docs for the preallocation to be worth it, and there's a risk of
// using too much memory. We could check the maximum theoretical buckets up-front and pass
// them down.
let fuseable = is_top_level
// TODO: We can easily support this
&& terms_req_data.allowed_term_ids.is_none()
&& terms_req_data.accessor.get_cardinality().is_full()
// The flat counters are `u32`, bumped once per value, so no count can exceed the column's
// value count. (Essentially always true here: the column is full, so its value count
// equals the doc count, and `DocId` is `u32`.)
&& terms_req_data.accessor.values.num_vals() < u32::MAX
&& node.children.len() == 1
&& matches!(
node.children[0].kind,
AggKind::Histogram | AggKind::DateHistogram
)
&& node.children[0].children.is_empty()
&& agg_data.per_request.histogram_req_data[node.children[0].idx_in_req_data]
.accessor
.get_cardinality()
.is_full();
if !fuseable {
return Ok(None);
}
// Clone + normalize the histogram request and get its dense bucket range; only take the fused
// path when the flat `num_terms × num_time_buckets` grid is small enough.
let Some((hist_req_data, range)) = prepare_histogram_dense_range(agg_data, &node.children[0])?
else {
return Ok(None);
};
let num_terms = col_max_val.saturating_add(1) as usize;
if num_terms.saturating_mul(range.len) > MAX_FUSED_GRID_BUCKETS {
return Ok(None);
}
// No hard bounds means every doc is in-bounds, letting `collect` short-circuit the bounds
// check — and leaving `term_counts` (the out-of-bounds remainder) unused, so we skip allocating
// it.
let all_docs_in_bounds =
hist_req_data.bounds.min == f64::MIN && hist_req_data.bounds.max == f64::MAX;
let counts = vec![0u32; num_terms * range.len];
let term_counts = if all_docs_in_bounds {
Vec::new()
} else {
vec![0u32; num_terms]
};
// Charge both grids to the aggregation memory limit.
agg_data.context.limits.add_memory_consumed(
((counts.len() + term_counts.len()) * std::mem::size_of::<u32>()) as u64,
)?;
Ok(Some(Box::new(SegmentTermHistogramCollector {
term_counts,
counts,
num_time_buckets: range.len,
base_pos: range.base_pos,
terms_req_data: terms_req_data.clone(),
hist_req_data,
term_block: ColumnBlockAccessor::default(),
hist_block: ColumnBlockAccessor::default(),
all_docs_in_bounds,
is_full: terms_req_data.accessor.get_cardinality().is_full(),
})))
}
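
Review note: a minimal standalone sketch of the flat-grid idea used by the fused collector, for readers who want to check the indexing and the "total = row sum + out-of-bounds remainder" bookkeeping in isolation. Everything here is hypothetical and simplified: `FlatGrid`, `bucket_pos`, and the plain slices stand in for the collector, `get_bucket_pos_f64` (assumed here to be a floor division, with no offset), and the column block accessors. Unlike the real collector, the sketch always allocates the out-of-bounds vector.

```rust
/// Hypothetical, simplified dense terms x histogram grid.
struct FlatGrid {
    /// Row-major counters: `counts[term * num_buckets + bucket]`.
    counts: Vec<u32>,
    /// Per-term count of docs whose value lies outside `[min, max]`.
    out_of_bounds: Vec<u32>,
    num_buckets: usize,
    /// Bucket position that maps to column 0 of the grid.
    base_pos: i64,
    interval: f64,
    min: f64,
    max: f64,
}

/// Assumed bucket position: floor(val / interval), offset omitted.
fn bucket_pos(val: f64, interval: f64) -> i64 {
    (val / interval).floor() as i64
}

impl FlatGrid {
    fn new(num_terms: usize, min: f64, max: f64, interval: f64) -> Self {
        let base_pos = bucket_pos(min, interval);
        let num_buckets = (bucket_pos(max, interval) - base_pos + 1) as usize;
        FlatGrid {
            counts: vec![0; num_terms * num_buckets],
            out_of_bounds: vec![0; num_terms],
            num_buckets,
            base_pos,
            interval,
            min,
            max,
        }
    }

    /// Single pass over two positionally aligned columns (both "full").
    fn collect(&mut self, term_ids: &[usize], vals: &[f64]) {
        for (&term, &val) in term_ids.iter().zip(vals) {
            if val >= self.min && val <= self.max {
                let bucket = (bucket_pos(val, self.interval) - self.base_pos) as usize;
                self.counts[term * self.num_buckets + bucket] += 1;
            } else {
                self.out_of_bounds[term] += 1;
            }
        }
    }

    /// Per-term doc count = in-bounds row sum + out-of-bounds remainder.
    fn term_totals(&self) -> Vec<u32> {
        self.counts
            .chunks_exact(self.num_buckets)
            .zip(&self.out_of_bounds)
            .map(|(row, oob)| row.iter().sum::<u32>() + *oob)
            .collect()
    }
}

fn main() {
    // Same data shape as the hard-bounds test: 300 docs, term = i % 3,
    // value = i % 20, bounds [5, 14], interval 1.
    let terms: Vec<usize> = (0..300).map(|i| i % 3).collect();
    let vals: Vec<f64> = (0..300).map(|i| (i % 20) as f64).collect();
    let mut grid = FlatGrid::new(3, 5.0, 14.0, 1.0);
    grid.collect(&terms, &vals);
    println!("totals = {:?}", grid.term_totals());
    println!("out of bounds = {:?}", grid.out_of_bounds);
}
```

Because every (term, value) pair occurs 5 times in this data, each of the 30 grid cells ends at 5, each term has 50 out-of-bounds docs, and each per-term total is 100.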
#[cfg(test)]
mod tests {
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::tests::{
exec_request, exec_request_with_query_and_memory_limit,
get_test_index_from_values_and_terms,
};
use crate::aggregation::AggregationLimitsGuard;
/// Hand-computed correctness check for the fused terms×histogram fast path
/// ([`super::SegmentTermHistogramCollector`]): low-cardinality terms × a histogram leaf over
/// full columns, exercised single- and multi-segment.
#[test]
fn fused_term_histogram_test() -> crate::Result<()> {
fused_term_histogram_with_opt(false)?;
fused_term_histogram_with_opt(true)?;
Ok(())
}
fn fused_term_histogram_with_opt(merge_segments: bool) -> crate::Result<()> {
// 300 docs: term = {a, b, c} by i % 3, histogram value = i % 20 (interval 1 => buckets
// 0..=19). gcd(3, 20) = 1, so every (term, bucket) pair occurs exactly 300 / 60 = 5 times.
let docs: Vec<(f64, String)> = (0..300u64)
.map(|i| {
(
(i % 20) as f64,
["a", "b", "c"][(i % 3) as usize].to_string(),
)
})
.collect();
// Two segments, to also exercise cross-segment merging of the fused per-term histograms.
let segments = vec![docs[..150].to_vec(), docs[150..].to_vec()];
let index = get_test_index_from_values_and_terms(merge_segments, &segments)?;
let agg_req: Aggregations = serde_json::from_value(serde_json::json!({
"by_term": {
"terms": { "field": "string_id", "order": { "_key": "asc" } },
"aggs": {
"histo": { "histogram": { "field": "score_f64", "interval": 1.0 } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
for (term_idx, term) in ["a", "b", "c"].iter().enumerate() {
assert_eq!(res["by_term"]["buckets"][term_idx]["key"], *term);
assert_eq!(res["by_term"]["buckets"][term_idx]["doc_count"], 100);
let histo = &res["by_term"]["buckets"][term_idx]["histo"]["buckets"];
for b in 0..20usize {
assert_eq!(histo[b]["key"], b as f64, "term {term} bucket {b}");
assert_eq!(histo[b]["doc_count"], 5, "term {term} bucket {b}");
}
assert_eq!(histo[20], serde_json::Value::Null);
}
assert_eq!(res["by_term"]["buckets"][3], serde_json::Value::Null);
Ok(())
}
/// A `missing` config on a *full* term column still takes the fused path (the string sentinel
/// is just `col_max + 1`, so the column stays low-cardinality). Since no doc is missing, the
/// real term buckets must be exactly as without `missing`.
#[test]
fn fused_term_histogram_with_missing_on_full_column() -> crate::Result<()> {
let docs: Vec<(f64, String)> = (0..300u64)
.map(|i| {
(
(i % 20) as f64,
["a", "b", "c"][(i % 3) as usize].to_string(),
)
})
.collect();
let index = get_test_index_from_values_and_terms(true, &[docs])?;
let agg_req: Aggregations = serde_json::from_value(serde_json::json!({
"by_term": {
"terms": { "field": "string_id", "missing": "MISSING", "order": { "_key": "asc" } },
"aggs": {
"histo": { "histogram": { "field": "score_f64", "interval": 1.0 } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
// Column is full, so "MISSING" never applies: a, b, c are unchanged (100 docs, 5 per
// bucket).
for (term_idx, term) in ["a", "b", "c"].iter().enumerate() {
assert_eq!(res["by_term"]["buckets"][term_idx]["key"], *term);
assert_eq!(res["by_term"]["buckets"][term_idx]["doc_count"], 100);
let histo = &res["by_term"]["buckets"][term_idx]["histo"]["buckets"];
for b in 0..20usize {
assert_eq!(histo[b]["doc_count"], 5, "term {term} bucket {b}");
}
}
Ok(())
}
/// Term cardinality above the general path's `MAX_NUM_TERMS_FOR_VEC` (100) still fuses: the
/// flat grid is bounded by the total cell count (`num_terms * num_time_buckets`), not the
/// term count.
#[test]
fn fused_term_histogram_many_terms() -> crate::Result<()> {
let num_terms = 150usize;
let docs_per_term = 2usize;
// All docs share histogram value 0 (a single bucket), so the grid is 150 x 1 = 150 cells.
let docs: Vec<(f64, String)> = (0..num_terms * docs_per_term)
.map(|i| (0.0, format!("t{:03}", i % num_terms)))
.collect();
let index = get_test_index_from_values_and_terms(true, &[docs])?;
let agg_req: Aggregations = serde_json::from_value(serde_json::json!({
"by_term": {
"terms": { "field": "string_id", "size": 1000, "order": { "_key": "asc" } },
"aggs": {
"histo": { "histogram": { "field": "score_f64", "interval": 1.0 } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
let buckets = res["by_term"]["buckets"].as_array().unwrap();
assert_eq!(buckets.len(), num_terms);
for (i, bucket) in buckets.iter().enumerate() {
assert_eq!(bucket["key"], format!("t{i:03}"));
assert_eq!(bucket["doc_count"], docs_per_term as u64);
assert_eq!(bucket["histo"]["buckets"][0]["key"], 0.0);
assert_eq!(
bucket["histo"]["buckets"][0]["doc_count"],
docs_per_term as u64
);
}
Ok(())
}
/// `hard_bounds` exercises the non-derived `term_counts` branch: a term's `doc_count` must
/// count *every* doc with that term, including docs whose histogram value is outside the
/// bounds (those are excluded from the histogram buckets but still counted for the term). This
/// is the case where the per-doc `term_counts` increment cannot be replaced by the grid
/// row-sum.
#[test]
fn fused_term_histogram_with_hard_bounds() -> crate::Result<()> {
// 300 docs: term = {a, b, c} by i % 3, value = i % 20. Per term: 100 docs, each value in
// 0..=19 occurring 5 times.
let docs: Vec<(f64, String)> = (0..300u64)
.map(|i| {
(
(i % 20) as f64,
["a", "b", "c"][(i % 3) as usize].to_string(),
)
})
.collect();
let index = get_test_index_from_values_and_terms(true, &[docs])?;
// hard_bounds [5, 14] (inclusive) keeps only values 5..=14 in the histogram (10 buckets);
// values 0..=4 and 15..=19 are out of bounds.
let agg_req: Aggregations = serde_json::from_value(serde_json::json!({
"by_term": {
"terms": { "field": "string_id", "order": { "_key": "asc" } },
"aggs": {
"histo": {
"histogram": {
"field": "score_f64",
"interval": 1.0,
"hard_bounds": { "min": 5.0, "max": 14.0 }
}
}
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
for (term_idx, term) in ["a", "b", "c"].iter().enumerate() {
assert_eq!(res["by_term"]["buckets"][term_idx]["key"], *term);
// doc_count includes the 50 per-term docs whose value is outside [5, 14].
assert_eq!(res["by_term"]["buckets"][term_idx]["doc_count"], 100);
let histo = &res["by_term"]["buckets"][term_idx]["histo"]["buckets"];
for b in 0..10usize {
let key = 5 + b;
assert_eq!(histo[b]["key"], key as f64, "term {term} bucket key {key}");
assert_eq!(histo[b]["doc_count"], 5, "term {term} bucket {key}");
}
// Only the 10 in-bounds buckets exist.
assert_eq!(histo[10], serde_json::Value::Null);
}
Ok(())
}
/// Non-binding `hard_bounds` (wider than the data, with mid-interval edges) must still produce
/// exact results via the derive-from-grid path: since no doc is out of bounds, normalization
/// drops the bound, every doc lands in the dense range, and each term's total equals its
/// histogram row-sum. This is the case that previously fell back to the per-doc counter only
/// because `bounds != [MIN, MAX]`.
#[test]
fn fused_term_histogram_with_non_binding_hard_bounds() -> crate::Result<()> {
// 300 docs: term = {a, b, c} by i % 3, value = i % 20. Data values span [0, 19].
let docs: Vec<(f64, String)> = (0..300u64)
.map(|i| {
(
(i % 20) as f64,
["a", "b", "c"][(i % 3) as usize].to_string(),
)
})
.collect();
let index = get_test_index_from_values_and_terms(true, &[docs])?;
// Bounds wider than [0, 19], with mid-interval edges -> they exclude nothing.
let agg_req: Aggregations = serde_json::from_value(serde_json::json!({
"by_term": {
"terms": { "field": "string_id", "order": { "_key": "asc" } },
"aggs": {
"histo": {
"histogram": {
"field": "score_f64",
"interval": 1.0,
"hard_bounds": { "min": -0.5, "max": 19.5 }
}
}
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
for (term_idx, term) in ["a", "b", "c"].iter().enumerate() {
assert_eq!(res["by_term"]["buckets"][term_idx]["key"], *term);
// Every doc is in-bounds, so the per-term total is the full 100 (as without bounds).
assert_eq!(res["by_term"]["buckets"][term_idx]["doc_count"], 100);
let histo = &res["by_term"]["buckets"][term_idx]["histo"]["buckets"];
for b in 0..20usize {
assert_eq!(histo[b]["key"], b as f64, "term {term} bucket {b}");
assert_eq!(histo[b]["doc_count"], 5, "term {term} bucket {b}");
}
assert_eq!(histo[20], serde_json::Value::Null);
}
Ok(())
}
/// Regression: with hard bounds the fused path allocates `term_counts` (one `u32`/term) on top
/// of the grid, and that allocation must be charged to the memory limit. With many terms and a
/// single time bucket the two are equal in size, so a limit admitting the grid alone but not
/// grid + `term_counts` must fail.
#[test]
fn fused_term_histogram_hard_bounds_charges_term_counts() -> crate::Result<()> {
// 16k distinct terms, one doc each; values alternate in/out of the single-bucket bounds
// [5, 5] so the bounds bind and `term_counts` is allocated. num_terms=16000,
// num_time_buckets=1 => `counts` and `term_counts` are ~64 KB each.
let docs: Vec<(f64, String)> = (0..16_000u64)
.map(|i| (if i % 2 == 0 { 5.0 } else { 10.0 }, format!("t{i:05}")))
.collect();
let index = get_test_index_from_values_and_terms(true, &[docs])?;
let agg_req: Aggregations = serde_json::from_value(serde_json::json!({
"by_term": {
"terms": { "field": "string_id" },
"aggs": {
"histo": {
"histogram": {
"field": "score_f64",
"interval": 1.0,
"hard_bounds": { "min": 5.0, "max": 5.0 }
}
}
}
}
}))
.unwrap();
// ~96 KB admits the grid (~64 KB) but not grid + `term_counts` (~128 KB).
let err = exec_request_with_query_and_memory_limit(
agg_req,
&index,
None,
AggregationLimitsGuard::new(Some(96_000), None),
)
.unwrap_err();
assert!(
err.to_string().contains("memory limit was exceeded"),
"expected a memory-limit error, got: {err}"
);
Ok(())
}
}
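
Reviewer note: the `hard_bounds` tests above depend on the difference between a term's total `doc_count` and the row-sum of its histogram cells. The sketch below is a minimal stand-alone model of that idea, not tantivy's implementation: `FusedGrid`, its fields, and the integer bucket bounds are all hypothetical names and simplifications chosen for illustration.

```rust
/// Hypothetical flat (term x histogram-bucket) grid; names are illustrative, not tantivy's.
struct FusedGrid {
    /// Index of the first in-bounds histogram bucket.
    min_bucket: i64,
    /// Number of in-bounds histogram buckets per term.
    num_buckets: usize,
    /// Row-major cells: `counts[term * num_buckets + bucket]`.
    counts: Vec<u32>,
    /// Per-term totals, counting out-of-bounds docs too (one `u32` per term).
    term_counts: Vec<u32>,
}

impl FusedGrid {
    fn new(num_terms: usize, min_bucket: i64, max_bucket: i64) -> Self {
        let num_buckets = (max_bucket - min_bucket + 1) as usize;
        FusedGrid {
            min_bucket,
            num_buckets,
            counts: vec![0; num_terms * num_buckets],
            term_counts: vec![0; num_terms],
        }
    }

    /// Every doc counts toward its term; only in-bounds docs touch the grid.
    fn collect(&mut self, term: usize, value: f64, interval: f64) {
        self.term_counts[term] += 1;
        let bucket = (value / interval).floor() as i64 - self.min_bucket;
        if bucket >= 0 && (bucket as usize) < self.num_buckets {
            self.counts[term * self.num_buckets + bucket as usize] += 1;
        }
    }

    /// Sum of one term's histogram cells; equals `term_counts[term]` only
    /// when no doc of that term fell outside the bounds.
    fn row_sum(&self, term: usize) -> u32 {
        let start = term * self.num_buckets;
        self.counts[start..start + self.num_buckets].iter().sum()
    }
}
```

With the test data used above (300 docs, term = `i % 3`, value = `i % 20`) and bounds `[5, 14]`, this model gives a row-sum of 50 per term against a per-term total of 100, which is why binding bounds need the separate per-term counter.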

@@ -5,7 +5,7 @@ use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::bucket::term_agg::TermsAggregation;
use crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardCachedSubAggs};
use crate::aggregation::buffered_sub_aggs::{BufferedSubAggs, HighCardBufferedSubAggs};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult,
@@ -47,7 +47,7 @@ struct MissingCount {
#[derive(Default, Debug)]
pub struct TermMissingAgg {
accessor_idx: usize,
sub_agg: Option<HighCardCachedSubAggs>,
sub_agg: Option<HighCardBufferedSubAggs>,
/// Idx = parent bucket id, Value = missing count for that bucket
missing_count_per_bucket: Vec<MissingCount>,
bucket_id_provider: BucketIdProvider,
@@ -66,7 +66,7 @@ impl TermMissingAgg {
None
};
let sub_agg = sub_agg.map(CachedSubAggs::new);
let sub_agg = sub_agg.map(BufferedSubAggs::new);
let bucket_id_provider = BucketIdProvider::default();
Ok(Self {
@@ -177,6 +177,17 @@ impl SegmentAggregationCollector for TermMissingAgg {
}
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// TODO: forward to `sub_agg` for nested order paths (`missing_agg>metric`).
None
}
}
#[cfg(test)]

@@ -6,7 +6,7 @@ use crate::aggregation::bucket::MAX_NUM_TERMS_FOR_VEC;
use crate::aggregation::BucketId;
use crate::DocId;
/// A cache for sub-aggregations, storing doc ids per bucket id.
/// A buffer for sub-aggregations, storing doc ids per bucket id.
/// Depending on the cardinality of the parent aggregation, we use different
/// storage strategies.
///
@@ -24,21 +24,21 @@ use crate::DocId;
/// aggregations.
/// What this datastructure does in general is to group docs by bucket id.
#[derive(Debug)]
pub(crate) struct CachedSubAggs<C: SubAggCache> {
cache: C,
pub(crate) struct BufferedSubAggs<B: SubAggBuffer> {
buffer: B,
sub_agg_collector: Box<dyn SegmentAggregationCollector>,
num_docs: usize,
}
pub type LowCardCachedSubAggs = CachedSubAggs<LowCardSubAggCache>;
pub type HighCardCachedSubAggs = CachedSubAggs<HighCardSubAggCache>;
pub type LowCardBufferedSubAggs = BufferedSubAggs<LowCardSubAggBuffer>;
pub type HighCardBufferedSubAggs = BufferedSubAggs<HighCardSubAggBuffer>;
const FLUSH_THRESHOLD: usize = 2048;
/// A trait for caching sub-aggregation doc ids per bucket id.
/// A trait for buffering sub-aggregation doc ids per bucket id.
/// Different implementations can be used depending on the cardinality
/// of the parent aggregation.
pub trait SubAggCache: Debug {
pub trait SubAggBuffer: Debug {
fn new() -> Self;
fn push(&mut self, bucket_id: BucketId, doc_id: DocId);
fn flush_local(
@@ -49,22 +49,22 @@ pub trait SubAggCache: Debug {
) -> crate::Result<()>;
}
impl<Backend: SubAggCache + Debug> CachedSubAggs<Backend> {
impl<Backend: SubAggBuffer + Debug> BufferedSubAggs<Backend> {
pub fn new(sub_agg: Box<dyn SegmentAggregationCollector>) -> Self {
Self {
cache: Backend::new(),
buffer: Backend::new(),
sub_agg_collector: sub_agg,
num_docs: 0,
}
}
pub fn get_sub_agg_collector(&mut self) -> &mut Box<dyn SegmentAggregationCollector> {
&mut self.sub_agg_collector
pub fn get_sub_agg_collector(&mut self) -> &mut dyn SegmentAggregationCollector {
&mut *self.sub_agg_collector
}
#[inline]
pub fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
self.cache.push(bucket_id, doc_id);
self.buffer.push(bucket_id, doc_id);
self.num_docs += 1;
}
@@ -75,7 +75,7 @@ impl<Backend: SubAggCache + Debug> CachedSubAggs<Backend> {
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
if self.num_docs >= FLUSH_THRESHOLD {
self.cache
self.buffer
.flush_local(&mut self.sub_agg_collector, agg_data, false)?;
self.num_docs = 0;
}
@@ -85,7 +85,7 @@ impl<Backend: SubAggCache + Debug> CachedSubAggs<Backend> {
/// Note: this _does_ flush the sub aggregations.
pub fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if self.num_docs != 0 {
self.cache
self.buffer
.flush_local(&mut self.sub_agg_collector, agg_data, true)?;
self.num_docs = 0;
}
@@ -94,11 +94,11 @@ impl<Backend: SubAggCache + Debug> CachedSubAggs<Backend> {
}
}
/// Number of partitions for high cardinality sub-aggregation cache.
/// Number of partitions for high cardinality sub-aggregation buffer.
const NUM_PARTITIONS: usize = 16;
#[derive(Debug)]
pub(crate) struct HighCardSubAggCache {
pub(crate) struct HighCardSubAggBuffer {
/// This weird partitioning is used to do some cheap grouping on the bucket ids.
/// bucket ids are dense, e.g. when we don't detect the cardinality as low cardinality,
/// but there are just 16 bucket ids, each bucket id will go to its own partition.
@@ -108,7 +108,7 @@ pub(crate) struct HighCardSubAggCache {
partitions: Box<[PartitionEntry; NUM_PARTITIONS]>,
}
impl HighCardSubAggCache {
impl HighCardSubAggBuffer {
#[inline]
fn clear(&mut self) {
for partition in self.partitions.iter_mut() {
@@ -131,13 +131,14 @@ impl PartitionEntry {
}
}
impl SubAggCache for HighCardSubAggCache {
impl SubAggBuffer for HighCardSubAggBuffer {
fn new() -> Self {
Self {
partitions: Box::new(core::array::from_fn(|_| PartitionEntry::default())),
}
}
#[inline]
fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
let idx = bucket_id % NUM_PARTITIONS as u32;
let slot = &mut self.partitions[idx as usize];
@@ -173,14 +174,14 @@ impl SubAggCache for HighCardSubAggCache {
}
#[derive(Debug)]
pub(crate) struct LowCardSubAggCache {
/// Cache doc ids per bucket for sub-aggregations.
pub(crate) struct LowCardSubAggBuffer {
/// Buffer doc ids per bucket for sub-aggregations.
///
/// The outer Vec is indexed by BucketId.
per_bucket_docs: Vec<Vec<DocId>>,
}
impl LowCardSubAggCache {
impl LowCardSubAggBuffer {
#[inline]
fn clear(&mut self) {
for v in &mut self.per_bucket_docs {
@@ -189,13 +190,14 @@ impl LowCardSubAggCache {
}
}
impl SubAggCache for LowCardSubAggCache {
impl SubAggBuffer for LowCardSubAggBuffer {
fn new() -> Self {
Self {
per_bucket_docs: Vec::new(),
}
}
#[inline]
fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
let idx = bucket_id as usize;
if self.per_bucket_docs.len() <= idx {
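
Reviewer note: the renamed buffer types group `(bucket_id, doc_id)` pairs before handing them to the sub-aggregation collector. The sketch below models only the high-cardinality idea visible in this diff (partition by `bucket_id % NUM_PARTITIONS`, flush once `FLUSH_THRESHOLD` docs are buffered). `PartitionedBuffer`, its `sink` parameter, and the `(u32, u32)` pair layout are assumptions made for illustration; the real `PartitionEntry` layout is not shown in this hunk.

```rust
const NUM_PARTITIONS: usize = 16;
const FLUSH_THRESHOLD: usize = 2048;

/// Hypothetical stand-in for a high-cardinality sub-aggregation buffer.
struct PartitionedBuffer {
    /// One growable list of (bucket_id, doc_id) pairs per partition.
    partitions: [Vec<(u32, u32)>; NUM_PARTITIONS],
    num_docs: usize,
}

impl PartitionedBuffer {
    fn new() -> Self {
        PartitionedBuffer {
            partitions: std::array::from_fn(|_| Vec::new()),
            num_docs: 0,
        }
    }

    /// Cheap grouping: dense bucket ids spread evenly over the partitions.
    fn push(&mut self, bucket_id: u32, doc_id: u32) {
        let idx = (bucket_id % NUM_PARTITIONS as u32) as usize;
        self.partitions[idx].push((bucket_id, doc_id));
        self.num_docs += 1;
    }

    /// Flushes only once the threshold is reached; returns whether it flushed.
    fn flush_if_full(&mut self, sink: &mut Vec<(u32, u32)>) -> bool {
        if self.num_docs < FLUSH_THRESHOLD {
            return false;
        }
        self.flush(sink);
        true
    }

    /// Drains partition by partition, so pairs arrive grouped by partition.
    fn flush(&mut self, sink: &mut Vec<(u32, u32)>) {
        for partition in self.partitions.iter_mut() {
            sink.append(partition);
        }
        self.num_docs = 0;
    }
}
```

When there are at most 16 dense bucket ids, each id gets its own partition, so a flush delivers each bucket's docs as one contiguous run.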

@@ -1,6 +1,6 @@
use super::agg_req::Aggregations;
use super::agg_result::AggregationResults;
use super::cached_sub_aggs::LowCardCachedSubAggs;
use super::buffered_sub_aggs::LowCardBufferedSubAggs;
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::AggContextParams;
// group buffering strategy is chosen explicitly by callers; no need to hash-group on the fly.
@@ -136,7 +136,7 @@ fn merge_fruits(
/// `AggregationSegmentCollector` does the aggregation collection on a segment.
pub struct AggregationSegmentCollector {
aggs_with_accessor: AggregationsSegmentCtx,
agg_collector: LowCardCachedSubAggs,
agg_collector: LowCardBufferedSubAggs,
error: Option<TantivyError>,
}
@@ -152,7 +152,7 @@ impl AggregationSegmentCollector {
let mut agg_data =
build_aggregations_data_from_req(agg, reader, segment_ordinal, context.clone())?;
let mut result =
LowCardCachedSubAggs::new(build_segment_agg_collectors_root(&mut agg_data)?);
LowCardBufferedSubAggs::new(build_segment_agg_collectors_root(&mut agg_data)?);
result
.get_sub_agg_collector()
.prepare_max_bucket(0, &agg_data)?; // prepare for bucket zero

@@ -377,7 +377,22 @@ impl IntermediateMetricResult {
MetricResult::ExtendedStats(intermediate_stats.finalize())
}
IntermediateMetricResult::Sum(intermediate_sum) => {
MetricResult::Sum(intermediate_sum.finalize().into())
// By default match Elasticsearch: empty / all-missing sum
// buckets serialize as `"value": 0`, not `"value": null`.
// The non-ES `none_if_no_match` flag on `SumAggregation`
// opts into SQL-style `null` for downstream consumers.
let none_if_no_match = req
.agg
.as_sum()
.and_then(|sum| sum.none_if_no_match)
.unwrap_or(false);
let value = intermediate_sum.finalize();
if none_if_no_match {
MetricResult::Sum(value.into())
} else {
let value = Some(value.unwrap_or(0.0));
MetricResult::Sum(value.into())
}
}
IntermediateMetricResult::Percentiles(percentiles) => MetricResult::Percentiles(
percentiles
@@ -1004,24 +1019,20 @@ impl IntermediateCompositeBucketResult {
) -> crate::Result<BucketResult> {
let trimmed_entry_vec =
trim_composite_buckets(self.entries, &self.orders, self.target_size)?;
let after_key = if trimmed_entry_vec.len() == req.size as usize {
trimmed_entry_vec
.last()
.map(|bucket| {
let (intermediate_key, _entry) = bucket;
intermediate_key
.iter()
.enumerate()
.map(|(idx, intermediate_key)| {
let source = &req.sources[idx];
(source.name().to_string(), intermediate_key.clone().into())
})
.collect()
})
.unwrap()
} else {
FxHashMap::default()
};
let after_key = trimmed_entry_vec
.last()
.map(|bucket| {
let (intermediate_key, _entry) = bucket;
intermediate_key
.iter()
.enumerate()
.map(|(idx, intermediate_key)| {
let source = &req.sources[idx];
(source.name().to_string(), intermediate_key.clone().into())
})
.collect()
})
.unwrap_or_default();
let buckets = trimmed_entry_vec
.into_iter()

File diff suppressed because it is too large.

@@ -399,6 +399,26 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
}
Ok(())
}
fn compute_metric_value(
&self,
bucket_id: BucketId,
sub_agg_name: &str,
sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
if self.name != sub_agg_name {
return None;
}
let extended = self.buckets.get(bucket_id as usize)?;
// Finalize is a pure read of accumulators — calling it here for the cutoff sort
// doesn't disturb the eventual intermediate result.
extended
.finalize()
.get_value(sub_agg_property)
.ok()
.flatten()
}
}
#[cfg(test)]

@@ -107,10 +107,9 @@ pub enum PercentileValues {
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
/// The entry when requesting percentiles with keyed: false
pub struct PercentileValuesVecEntry {
/// Percentile
/// The percentile key (e.g. 1.0, 5.0, 25.0).
pub key: f64,
/// Value at the percentile
/// The percentile value. `NaN` when there are no values.
pub value: f64,
}

@@ -312,6 +312,26 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector {
}
Ok(())
}
fn compute_metric_value(
&self,
bucket_id: BucketId,
sub_agg_name: &str,
sub_agg_property: &str,
agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
if agg_data.get_metric_req_data(self.accessor_idx).name != sub_agg_name {
return None;
}
let percentile: f64 = sub_agg_property.parse().ok()?;
if !(0.0..=100.0).contains(&percentile) {
return None;
}
let bucket = self.buckets.get(bucket_id as usize)?;
// DDSketch.quantile is a pure read; calling it here for the cutoff sort does
// not affect the intermediate state used for the final result.
bucket.sketch.quantile(percentile / 100.0).ok().flatten()
}
}
#[cfg(test)]

@@ -321,6 +321,40 @@ impl<const COLUMN_TYPE_ID: u8> SegmentAggregationCollector
}
Ok(())
}
fn compute_metric_value(
&self,
bucket_id: BucketId,
sub_agg_name: &str,
sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
if self.name != sub_agg_name {
return None;
}
let stats = self.buckets.get(bucket_id as usize)?;
// The property depends on what we're collecting:
// - StatsType::Stats exposes count/sum/min/max/avg via dotted property.
// - Single-value kinds (Sum/Count/Min/Max/Average) expect an empty property and return
// the value they were configured to collect.
let prop = match self.collecting_for {
StatsType::Stats if !sub_agg_property.is_empty() => sub_agg_property,
StatsType::Sum if sub_agg_property.is_empty() => "sum",
StatsType::Count if sub_agg_property.is_empty() => "count",
StatsType::Max if sub_agg_property.is_empty() => "max",
StatsType::Min if sub_agg_property.is_empty() => "min",
StatsType::Average if sub_agg_property.is_empty() => "avg",
_ => return None,
};
match prop {
"count" => Some(stats.count as f64),
"sum" => Some(stats.sum),
"min" if stats.count > 0 => Some(stats.min),
"max" if stats.count > 0 => Some(stats.max),
"avg" if stats.count > 0 => Some(stats.sum / stats.count as f64),
_ => None,
}
}
}
#[inline]

@@ -27,6 +27,16 @@ pub struct SumAggregation {
/// { "field": "my_numbers", "missing": "10.0" }
#[serde(default, deserialize_with = "deserialize_option_f64")]
pub missing: Option<f64>,
/// Non-Elasticsearch extension. When `Some(true)`, the serialized result
/// returns `"value": null` if no values were collected (all documents had
/// missing/NULL values for the field), matching the behavior of `min`,
/// `max`, and `avg`. When `None` or `Some(false)` (the default) the
/// result returns `"value": 0`, matching Elasticsearch.
///
/// Intended for SQL-style consumers where `SUM` of zero rows is `NULL`
/// and must be distinguishable from a bucket that genuinely sums to `0`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub none_if_no_match: Option<bool>,
}
impl SumAggregation {
@@ -35,6 +45,7 @@ impl SumAggregation {
Self {
field: field_name,
missing: None,
none_if_no_match: None,
}
}
/// Returns the field name the aggregation is computed on.
@@ -59,8 +70,104 @@ impl IntermediateSum {
pub fn merge_fruits(&mut self, other: IntermediateSum) {
self.stats.merge_fruits(other.stats);
}
/// Computes the final minimum value.
/// Computes the final sum value.
///
/// Returns `None` when no values were collected, matching the Rust-side
/// behavior of `IntermediateMin`, `IntermediateMax`, and
/// `IntermediateAvg`. The Elasticsearch-vs-SQL choice for the
/// user-visible result is made at the boundary in
/// [`IntermediateMetricResult::into_final_metric_result`]: by default
/// `None` is coerced to `Some(0.0)` to match Elasticsearch
/// (`"value": 0`), and the [`SumAggregation::none_if_no_match`] flag
/// opts out of that coercion for SQL-style consumers.
pub fn finalize(&self) -> Option<f64> {
Some(self.stats.finalize().sum)
let stats = self.stats.finalize();
if stats.count == 0 {
None
} else {
Some(stats.sum)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_sum_finalize_returns_none_when_no_values() {
// Default IntermediateSum has count=0 — finalize should return None,
// matching MIN/MAX/AVG behavior for all-NULL groups.
let sum = IntermediateSum::default();
assert_eq!(sum.finalize(), None);
}
#[test]
fn test_sum_finalize_returns_value_when_has_values() {
let mut sum = IntermediateSum::default();
// Merge in a result that has actual values
let stats = IntermediateStats {
count: 3,
sum: 42.0,
min: 10.0,
max: 20.0,
..Default::default()
};
let other = IntermediateSum::from_stats(stats);
sum.merge_fruits(other);
assert_eq!(sum.finalize(), Some(42.0));
}
#[test]
fn test_sum_merge_two_empty_still_none() {
let mut a = IntermediateSum::default();
let b = IntermediateSum::default();
a.merge_fruits(b);
assert_eq!(a.finalize(), None);
}
#[test]
fn test_sum_aggregation_empty_index_default_matches_es() -> crate::Result<()> {
use serde_json::json;
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::tests::{exec_request, get_test_index_from_terms};
// Empty index — sum has no values to collect.
let values: Vec<Vec<&str>> = vec![];
let index = get_test_index_from_terms(false, &values)?;
let agg_req: Aggregations = serde_json::from_value(json!({
"score_sum": { "sum": { "field": "score" } }
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
// Default: match Elasticsearch — empty sum serializes as 0, not null.
assert_eq!(res["score_sum"]["value"], 0.0);
Ok(())
}
#[test]
fn test_sum_aggregation_empty_index_none_if_no_match_opt_in() -> crate::Result<()> {
use serde_json::json;
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::tests::{exec_request, get_test_index_from_terms};
let values: Vec<Vec<&str>> = vec![];
let index = get_test_index_from_terms(false, &values)?;
let agg_req: Aggregations = serde_json::from_value(json!({
"score_sum": { "sum": { "field": "score", "none_if_no_match": true } }
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
// Opt-in non-ES extension — empty sum serializes as null.
assert!(
res["score_sum"]["value"].is_null(),
"expected null, got {:?}",
res["score_sum"]["value"]
);
Ok(())
}
}
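
Reviewer note: the two-layer behavior introduced here (`finalize` returns `None` for an empty sum, and the final-result boundary coerces that to `0` unless `none_if_no_match` is set) can be modeled in isolation. The sketch below is a simplified stand-in: `SumState` and `final_sum` are hypothetical names, and only the `Option<bool>` flag semantics are taken from the diff.

```rust
/// Hypothetical reduction of the intermediate sum state to a count and a sum.
#[derive(Default, Clone, Copy)]
struct SumState {
    count: u64,
    sum: f64,
}

impl SumState {
    fn collect(&mut self, value: f64) {
        self.count += 1;
        self.sum += value;
    }

    /// `None` when nothing was collected, so "no rows" differs from "sums to 0".
    fn finalize(&self) -> Option<f64> {
        if self.count == 0 {
            None
        } else {
            Some(self.sum)
        }
    }
}

/// Boundary step: by default an empty sum becomes `Some(0.0)` (Elasticsearch
/// style); `Some(true)` keeps `None` (SQL style).
fn final_sum(state: &SumState, none_if_no_match: Option<bool>) -> Option<f64> {
    let value = state.finalize();
    if none_if_no_match.unwrap_or(false) {
        value
    } else {
        Some(value.unwrap_or(0.0))
    }
}
```

Keeping the coercion at the boundary means a bucket that genuinely sums to `0` reports `Some(0.0)` under both settings, and only the empty case differs.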

@@ -644,6 +644,17 @@ impl SegmentAggregationCollector for TopHitsSegmentCollector {
);
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// top_hits is not a numeric metric and cannot be used as an order target.
None
}
}
#[cfg(test)]

@@ -133,7 +133,7 @@ mod agg_limits;
pub mod agg_req;
pub mod agg_result;
pub mod bucket;
pub(crate) mod cached_sub_aggs;
pub(crate) mod buffered_sub_aggs;
mod collector;
mod date;
mod error;

@@ -76,6 +76,31 @@ pub trait SegmentAggregationCollector: Debug {
fn flush(&mut self, _agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
Ok(())
}
/// Compute the segment-level metric value of the named direct-child metric for `bucket_id`.
///
/// Used by parent term aggs that order by a sub-aggregation: the parent sorts on
/// this value and cuts off at segment time, matching the approximation tradeoff
/// Elasticsearch makes for any sub-agg ordering.
///
/// `sub_agg_property` is the dotted suffix (e.g. `"sum"` in `mystats.sum`); empty when
/// the metric is a single-value kind such as cardinality.
///
/// Returns `None` on name mismatch, unknown property, or empty bucket, or when the collector
/// rejects the lookup outright (e.g. a non-numeric metric such as `top_hits`). Implementations
/// may finalize their per-bucket state (e.g. compute a percentile from a sketch); calls
/// must be idempotent so the final intermediate result is unaffected.
///
/// No default impl on purpose: every collector must decide explicitly whether it
/// produces a metric value, forwards into children (single-bucket aggs), or rejects
/// the lookup. A silent `None` default would let a parent term agg's cutoff sort all
/// buckets to the same key and drop arbitrary winners.
fn compute_metric_value(
&self,
bucket_id: BucketId,
sub_agg_name: &str,
sub_agg_property: &str,
agg_data: &AggregationsSegmentCtx,
) -> Option<f64>;
}
#[derive(Default)]
@@ -137,4 +162,21 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
}
Ok(())
}
fn compute_metric_value(
&self,
bucket_id: BucketId,
sub_agg_name: &str,
sub_agg_property: &str,
agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
for agg in &self.aggs {
if let Some(value) =
agg.compute_metric_value(bucket_id, sub_agg_name, sub_agg_property, agg_data)
{
return Some(value);
}
}
None
}
}
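
Reviewer note: the lookup contract added to the trait has two roles, a leaf that answers for its own name and a container that forwards to its children and returns the first hit. The sketch below shows that shape under simplifying assumptions: `MetricLookup`, `StatsLike`, and `Group` are hypothetical types, the `agg_data` argument is dropped, and only `count`, `sum`, and `avg` are modeled.

```rust
/// Hypothetical, simplified version of the lookup contract (no `agg_data` argument).
trait MetricLookup {
    fn compute_metric_value(&self, bucket_id: u32, name: &str, property: &str) -> Option<f64>;
}

/// Hypothetical stats-like leaf: one (count, sum) pair per bucket.
struct StatsLike {
    name: String,
    buckets: Vec<(u64, f64)>,
}

impl MetricLookup for StatsLike {
    fn compute_metric_value(&self, bucket_id: u32, name: &str, property: &str) -> Option<f64> {
        // A leaf only answers for its own name.
        if self.name != name {
            return None;
        }
        let &(count, sum) = self.buckets.get(bucket_id as usize)?;
        match property {
            "count" => Some(count as f64),
            "sum" => Some(sum),
            // An empty bucket has no meaningful average.
            "avg" if count > 0 => Some(sum / count as f64),
            _ => None,
        }
    }
}

/// Hypothetical container that forwards the lookup to its children.
struct Group {
    children: Vec<Box<dyn MetricLookup>>,
}

impl MetricLookup for Group {
    fn compute_metric_value(&self, bucket_id: u32, name: &str, property: &str) -> Option<f64> {
        // First child that recognizes the name and property wins.
        self.children
            .iter()
            .find_map(|child| child.compute_metric_value(bucket_id, name, property))
    }
}
```

A parent ordering by `qty.avg` would call the group once per bucket and sort on the returned value; a `None` for every bucket would make all sort keys equal, which is the failure mode the "no default impl" comment warns about.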

@@ -1,5 +1,6 @@
use super::Collector;
use crate::collector::SegmentCollector;
use crate::query::Weight;
use crate::{DocId, Score, SegmentOrdinal, SegmentReader};
/// `CountCollector` collector only counts how many
@@ -55,6 +56,15 @@ impl Collector for Count {
fn merge_fruits(&self, segment_counts: Vec<usize>) -> crate::Result<usize> {
Ok(segment_counts.into_iter().sum())
}
fn collect_segment(
&self,
weight: &dyn Weight,
_segment_ord: u32,
reader: &SegmentReader,
) -> crate::Result<usize> {
Ok(weight.count(reader)? as usize)
}
}
#[derive(Default)]

@@ -389,6 +389,13 @@ impl SegmentCollector for FacetSegmentCollector {
}
let mut facet = vec![];
let (facet_ord, facet_depth) = self.unique_facet_ords[collapsed_facet_ord];
// u64::MAX is used as a sentinel for unmapped ordinals (e.g. when a
// document has the exact registered facet, not a child of it).
// Passing it to ord_to_term would resolve to the last dictionary
// entry and produce a spurious facet from an unrelated branch.
if facet_ord == u64::MAX {
continue;
}
// TODO handle errors.
if facet_dict.ord_to_term(facet_ord, &mut facet).is_ok() {
if let Some((end_collapsed_facet, _)) = facet
@@ -814,6 +821,63 @@ mod tests {
assert!(!super::is_child_facet(&b"foo\0bar"[..], &b"foo"[..]));
assert!(!super::is_child_facet(&b"foo"[..], &b"foobar\0baz"[..]));
}
// Regression test for https://github.com/quickwit-oss/tantivy/issues/2494
// When a document has the exact registered facet path (not just a child),
// harvest() must not turn the unmapped sentinel into a spurious root entry.
#[test]
fn test_facet_collector_wrong_root() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let facet_field = schema_builder.add_facet_field("facet", FacetOptions::default());
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
let facets: Vec<&str> = vec![
"/science-fiction/asimov",
"/science-fiction/clarke",
"/science-fiction/dick",
"/science-fiction/herbert",
"/science-fiction/orwell",
// This exact match on the registered facet is the bug trigger:
// its ordinal maps to the sentinel (u64::MAX, 0) in the collapse
// mapping, which without the fix resolves to an unrelated term.
"/fantasy/epic-fantasy",
"/fantasy/epic-fantasy/tolkien",
"/fantasy/epic-fantasy/martin",
];
for facet_str in &facets {
index_writer.add_document(doc!(
facet_field => Facet::from(*facet_str)
))?;
}
index_writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
let term = Term::from_facet(facet_field, &Facet::from("/fantasy/epic-fantasy"));
let query = TermQuery::new(term, IndexRecordOption::Basic);
let mut facet_collector = FacetCollector::for_field("facet");
facet_collector.add_facet("/fantasy/epic-fantasy");
let counts: FacetCounts = searcher.search(&query, &facet_collector)?;
let result: Vec<(String, u64)> = counts
.get("/")
.map(|(facet, count)| (facet.to_string(), count))
.collect();
// Only children of /fantasy/epic-fantasy should appear, not /science-fiction
assert_eq!(
result,
vec![
("/fantasy/epic-fantasy/martin".to_string(), 1),
("/fantasy/epic-fantasy/tolkien".to_string(), 1),
]
);
Ok(())
}
}
#[cfg(all(test, feature = "unstable"))]

@@ -1,5 +1,8 @@
use std::cmp::{Ordering, Reverse};
use std::collections::BinaryHeap;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer, TopNComputer};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::{DocAddress, DocId, Score};
/// Sort by similarity score.
@@ -25,6 +28,10 @@ impl SortKeyComputer for SortBySimilarityScore {
}
// Sorting by score is special in that it allows for the Block-WAND optimization.
//
// We use a BinaryHeap (TopNHeap) instead of TopNComputer here so that the
// threshold is always the exact K-th best score. TopNComputer only updates its
// threshold every K docs (at truncation), giving Block-WAND a stale bound.
fn collect_segment_top_k(
&self,
k: usize,
@@ -32,12 +39,10 @@ impl SortKeyComputer for SortBySimilarityScore {
reader: &crate::SegmentReader,
segment_ord: u32,
) -> crate::Result<Vec<(Self::SortKey, DocAddress)>> {
let mut top_n: TopNComputer<Score, DocId, Self::Comparator> =
TopNComputer::new_with_comparator(k, self.comparator());
let mut top_n = TopNHeap::new(k);
if let Some(alive_bitset) = reader.alive_bitset() {
let mut threshold = Score::MIN;
top_n.threshold = Some(threshold);
weight.for_each_pruning(Score::MIN, reader, &mut |doc, score| {
if alive_bitset.is_deleted(doc) {
return threshold;
@@ -56,7 +61,7 @@ impl SortKeyComputer for SortBySimilarityScore {
Ok(top_n
.into_vec()
.into_iter()
.map(|cid| (cid.sort_key, DocAddress::new(segment_ord, cid.doc)))
.map(|(score, doc)| (score, DocAddress::new(segment_ord, doc)))
.collect())
}
}
@@ -75,3 +80,204 @@ impl SegmentSortKeyComputer for SortBySimilarityScore {
score
}
}
/// Min-heap entry: higher score = greater, lower doc wins ties.
struct ScoreHeapEntry {
score: Score,
doc: DocId,
}
impl Eq for ScoreHeapEntry {}
impl PartialEq for ScoreHeapEntry {
fn eq(&self, other: &Self) -> bool {
self.cmp(other) == Ordering::Equal
}
}
impl PartialOrd for ScoreHeapEntry {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for ScoreHeapEntry {
fn cmp(&self, other: &Self) -> Ordering {
self.score
.partial_cmp(&other.score)
.unwrap_or(Ordering::Equal)
.then_with(|| other.doc.cmp(&self.doc))
}
}
/// Heap-based top-K for score collection. O(log K) per insert, but the threshold
/// is always tight, so Block-WAND prunes better than with [`TopNComputer`]'s
/// buffer/median approach.
///
/// Like [`TopNComputer`], items must arrive in ascending doc order, and equal
/// scores are rejected (strict `>`) so that lower doc IDs win ties.
///
/// [`TopNComputer`]: crate::collector::TopNComputer
struct TopNHeap {
heap: BinaryHeap<Reverse<ScoreHeapEntry>>,
top_n: usize,
threshold: Option<Score>,
}
impl TopNHeap {
fn new(top_n: usize) -> Self {
TopNHeap {
heap: BinaryHeap::with_capacity(top_n),
top_n,
threshold: None,
}
}
#[inline]
fn push(&mut self, score: Score, doc: DocId) {
if self.heap.len() < self.top_n {
self.heap.push(Reverse(ScoreHeapEntry { score, doc }));
if self.heap.len() == self.top_n {
self.threshold = self.heap.peek().map(|Reverse(entry)| entry.score);
}
} else if let Some(threshold) = self.threshold {
if score > threshold {
// peek_mut + assign is a single sift-down, vs pop + push = two sifts.
if let Some(mut min) = self.heap.peek_mut() {
*min = Reverse(ScoreHeapEntry { score, doc });
}
self.threshold = self.heap.peek().map(|Reverse(entry)| entry.score);
}
}
}
fn into_vec(self) -> Vec<(Score, DocId)> {
self.heap
.into_vec()
.into_iter()
.map(|Reverse(entry)| (entry.score, entry.doc))
.collect()
}
}
#[cfg(test)]
mod tests {
use proptest::prelude::*;
use super::*;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::TopNComputer;
#[test]
fn test_top_n_heap_zero_capacity() {
let mut heap = TopNHeap::new(0);
heap.push(1.0, 0);
heap.push(2.0, 1);
assert!(heap.into_vec().is_empty());
}
#[test]
fn test_top_n_heap_basic() {
let mut heap = TopNHeap::new(2);
heap.push(1.0, 0);
heap.push(3.0, 1);
heap.push(2.0, 2);
let mut results = heap.into_vec();
results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap().then_with(|| a.1.cmp(&b.1)));
assert_eq!(results, vec![(3.0, 1), (2.0, 2)]);
}
#[test]
fn test_top_n_heap_threshold_always_accurate() {
let mut heap = TopNHeap::new(2);
assert_eq!(heap.threshold, None);
heap.push(1.0, 0);
assert_eq!(heap.threshold, None);
heap.push(3.0, 1);
assert_eq!(heap.threshold, Some(1.0));
heap.push(2.0, 2); // evicts 1.0
assert_eq!(heap.threshold, Some(2.0));
heap.push(4.0, 3); // evicts 2.0
assert_eq!(heap.threshold, Some(3.0));
}
#[test]
fn test_top_n_heap_tiebreaking_lower_doc_wins() {
let mut heap = TopNHeap::new(2);
heap.push(5.0, 0);
heap.push(5.0, 1);
heap.push(5.0, 2); // rejected: not strictly > threshold
let mut results = heap.into_vec();
results.sort_by_key(|&(_, doc)| doc);
assert_eq!(results, vec![(5.0, 0), (5.0, 1)]);
}
#[test]
fn test_top_n_heap_single_element() {
let mut heap = TopNHeap::new(1);
heap.push(1.0, 0);
assert_eq!(heap.threshold, Some(1.0));
heap.push(0.5, 1); // rejected
heap.push(2.0, 2); // accepted
assert_eq!(heap.threshold, Some(2.0));
let results = heap.into_vec();
assert_eq!(results, vec![(2.0, 2)]);
}
#[test]
fn test_top_n_heap_under_capacity() {
let mut heap = TopNHeap::new(5);
heap.push(3.0, 0);
heap.push(1.0, 1);
heap.push(2.0, 2);
// Only 3 elements, capacity is 5 — all should be kept
assert_eq!(heap.threshold, None);
let mut results = heap.into_vec();
results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap().then_with(|| a.1.cmp(&b.1)));
assert_eq!(results, vec![(3.0, 0), (2.0, 2), (1.0, 1)]);
}
proptest! {
#[test]
fn test_top_n_heap_matches_top_n_computer(
limit in 0..20_usize,
mut docs in proptest::collection::vec((0..1000_u32, 0..1000_u32), 0..200_usize),
) {
// Both require ascending doc order.
docs.sort_by_key(|(_, doc_id)| *doc_id);
docs.dedup_by_key(|(_, doc_id)| *doc_id);
let mut heap = TopNHeap::new(limit);
let mut computer: TopNComputer<Score, DocId, NaturalComparator> =
TopNComputer::new_with_comparator(limit, NaturalComparator);
for &(score_u32, doc) in &docs {
let score = score_u32 as Score;
heap.push(score, doc);
computer.push(score, doc);
}
let mut heap_results = heap.into_vec();
heap_results.sort_by(|a, b| {
b.0.partial_cmp(&a.0).unwrap().then_with(|| a.1.cmp(&b.1))
});
let computer_results: Vec<(Score, DocId)> = computer
.into_sorted_vec()
.into_iter()
.map(|cd| (cd.sort_key, cd.doc))
.collect();
prop_assert_eq!(heap_results, computer_results);
}
}
}
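
The bounded-heap logic exercised by the tests above can be sketched on its own. This is a minimal illustration, not the crate's implementation: `Entry`, `TopN`, and `into_sorted` are hypothetical names, and the tie-breaking order (lower doc id wins) is an assumption inferred from `test_top_n_heap_tiebreaking_lower_doc_wins`.

```rust
use std::cmp::{Ordering, Reverse};
use std::collections::BinaryHeap;

/// Hypothetical heap entry: a higher score ranks higher, ties go to the lower doc id.
#[derive(Debug, Clone, Copy)]
struct Entry {
    score: f32,
    doc: u32,
}

impl Ord for Entry {
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` gives f32 a total order; the doc comparison is reversed
        // so that, among equal scores, the lower doc id is the "greater" entry.
        self.score
            .total_cmp(&other.score)
            .then_with(|| other.doc.cmp(&self.doc))
    }
}

impl PartialOrd for Entry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl PartialEq for Entry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Entry {}

/// Keeps the `top_n` best entries in a min-heap (`Reverse` turns the max-heap around).
struct TopN {
    heap: BinaryHeap<Reverse<Entry>>,
    top_n: usize,
    threshold: Option<f32>,
}

impl TopN {
    fn new(top_n: usize) -> Self {
        TopN {
            heap: BinaryHeap::with_capacity(top_n),
            top_n,
            threshold: None,
        }
    }

    fn push(&mut self, score: f32, doc: u32) {
        if self.heap.len() < self.top_n {
            self.heap.push(Reverse(Entry { score, doc }));
            if self.heap.len() == self.top_n {
                // The heap just became full: its minimum is the pruning threshold.
                self.threshold = self.heap.peek().map(|Reverse(e)| e.score);
            }
        } else if let Some(threshold) = self.threshold {
            if score > threshold {
                // Overwrite the minimum in place; the heap is repaired once,
                // when the `PeekMut` guard is dropped.
                if let Some(mut min) = self.heap.peek_mut() {
                    *min = Reverse(Entry { score, doc });
                }
                self.threshold = self.heap.peek().map(|Reverse(e)| e.score);
            }
        }
    }

    /// Returns `(score, doc)` pairs, best first.
    fn into_sorted(self) -> Vec<(f32, u32)> {
        let mut entries: Vec<Entry> = self.heap.into_iter().map(|Reverse(e)| e).collect();
        entries.sort_by(|a, b| b.cmp(a));
        entries.into_iter().map(|e| (e.score, e.doc)).collect()
    }
}
```

With `top_n == 0` the first branch is never taken and the threshold stays `None`, so every push is a no-op, which matches the zero-capacity test.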

View File

@@ -52,7 +52,7 @@ impl<T: FastValue> SortKeyComputer for SortByStaticFastValue<T> {
if schema_type != T::to_type() {
return Err(crate::TantivyError::SchemaError(format!(
"Field `{}` is of type {schema_type:?}, not of the type {:?}.",
&self.field,
self.field,
T::to_type()
)));
}

View File

@@ -513,7 +513,9 @@ pub struct TopNComputer<Score, D, C> {
/// The buffer reverses sort order to get top-semantics instead of bottom-semantics
buffer: Vec<ComparableDoc<Score, D>>,
top_n: usize,
pub(crate) threshold: Option<Score>,
/// The current threshold for pruning. Documents with scores at or below
/// this value are skipped by `push()`. Updated when the buffer is truncated.
pub threshold: Option<Score>,
comparator: C,
}

View File

@@ -676,7 +676,7 @@ mod tests {
let num_segments = reader.searcher().segment_readers().len();
assert!(num_segments <= 4);
let num_components_except_deletes_and_tempstore =
crate::index::SegmentComponent::iterator().len() - 1;
crate::index::SegmentComponent::iterator().len() - 2;
let max_num_mmapped = num_components_except_deletes_and_tempstore * num_segments;
assert_eventually(|| {
let num_mmapped = mmap_directory.get_cache_info().mmapped.len();

View File

@@ -1,5 +1,7 @@
use std::borrow::{Borrow, BorrowMut};
use common::TinySet;
use crate::fastfield::AliveBitSet;
use crate::DocId;
@@ -14,6 +16,12 @@ pub const TERMINATED: DocId = i32::MAX as u32;
/// exactly this size as long as we can fill the buffer.
pub const COLLECT_BLOCK_BUFFER_LEN: usize = 64;
/// Number of `TinySet` (64-bit) buckets in a block used by [`DocSet::fill_bitset_block`].
pub const BLOCK_NUM_TINYBITSETS: usize = 16;
/// Number of doc IDs covered by one block: `BLOCK_NUM_TINYBITSETS * 64 = 1024`.
pub const BLOCK_WINDOW: u32 = BLOCK_NUM_TINYBITSETS as u32 * 64;
/// Represents an iterable set of sorted doc ids.
pub trait DocSet: Send {
/// Goes to the next element.
@@ -160,6 +168,31 @@ pub trait DocSet: Send {
self.size_hint() as u64
}
/// Fills a bitmask representing which documents in `[min_doc, min_doc + BLOCK_WINDOW)` are
/// present in this docset.
///
/// The window is divided into `BLOCK_NUM_TINYBITSETS` buckets of 64 docs each.
/// Returns the next doc `>= min_doc + BLOCK_WINDOW`, or `TERMINATED` if exhausted.
fn fill_bitset_block(
&mut self,
min_doc: DocId,
mask: &mut [TinySet; BLOCK_NUM_TINYBITSETS],
) -> DocId {
self.seek(min_doc);
let horizon = min_doc + BLOCK_WINDOW;
loop {
let doc = self.doc();
if doc >= horizon {
return doc;
}
let delta = doc - min_doc;
mask[(delta / 64) as usize].insert_mut(delta % 64);
if self.advance() == TERMINATED {
return TERMINATED;
}
}
}
/// Returns the number of documents matching.
/// Calling this method consumes the `DocSet`.
fn count(&mut self, alive_bitset: &AliveBitSet) -> u32 {
@@ -214,6 +247,18 @@ impl DocSet for &mut dyn DocSet {
(**self).seek_danger(target)
}
fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize {
(**self).fill_buffer(buffer)
}
fn fill_bitset_block(
&mut self,
min_doc: DocId,
mask: &mut [TinySet; BLOCK_NUM_TINYBITSETS],
) -> DocId {
(**self).fill_bitset_block(min_doc, mask)
}
fn doc(&self) -> u32 {
(**self).doc()
}
@@ -256,6 +301,15 @@ impl<TDocSet: DocSet + ?Sized> DocSet for Box<TDocSet> {
unboxed.fill_buffer(buffer)
}
fn fill_bitset_block(
&mut self,
min_doc: DocId,
mask: &mut [TinySet; BLOCK_NUM_TINYBITSETS],
) -> DocId {
let unboxed: &mut TDocSet = self.borrow_mut();
unboxed.fill_bitset_block(min_doc, mask)
}
fn doc(&self) -> DocId {
let unboxed: &TDocSet = self.borrow();
unboxed.doc()
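
The bucket and bit arithmetic of the default `fill_bitset_block` can be illustrated without the `DocSet` trait. The sketch below makes two simplifying assumptions: a sorted slice stands in for the docset, and plain `u64` words stand in for `TinySet`. `fill_block` and `BLOCK_NUM_BUCKETS` are hypothetical names.

```rust
const BLOCK_NUM_BUCKETS: usize = 16;
/// 16 buckets of 64 bits each cover 1024 doc ids.
const BLOCK_WINDOW: u32 = BLOCK_NUM_BUCKETS as u32 * 64;

/// Sets one bit per doc of `sorted_docs` that falls in `[min_doc, min_doc + BLOCK_WINDOW)`.
/// Returns the first doc at or beyond the window, or `None` if the docs are exhausted.
fn fill_block(
    sorted_docs: &[u32],
    min_doc: u32,
    mask: &mut [u64; BLOCK_NUM_BUCKETS],
) -> Option<u32> {
    let horizon = min_doc + BLOCK_WINDOW;
    // Stand-in for `seek(min_doc)`: skip every doc below the window.
    let start = sorted_docs.partition_point(|&doc| doc < min_doc);
    for &doc in &sorted_docs[start..] {
        if doc >= horizon {
            return Some(doc);
        }
        let delta = doc - min_doc;
        // `delta / 64` selects the bucket, `delta % 64` the bit inside it.
        mask[(delta / 64) as usize] |= 1u64 << (delta % 64);
    }
    None
}
```

A caller would typically clear the mask, call the function once per window, and feed the returned doc back in to decide which window to visit next.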

View File

@@ -127,7 +127,7 @@ mod tests {
fast_field_writers
.add_document(&doc!(*FIELD=>2u64))
.unwrap();
fast_field_writers.serialize(&mut write).unwrap();
fast_field_writers.serialize(&mut write, None).unwrap();
write.terminate().unwrap();
}
let file = directory.open_read(path).unwrap();
@@ -178,7 +178,7 @@ mod tests {
fast_field_writers
.add_document(&doc!(*FIELD=>215u64))
.unwrap();
fast_field_writers.serialize(&mut write).unwrap();
fast_field_writers.serialize(&mut write, None).unwrap();
write.terminate().unwrap();
}
let file = directory.open_read(path).unwrap();
@@ -211,7 +211,7 @@ mod tests {
.add_document(&doc!(*FIELD=>100_000u64))
.unwrap();
}
fast_field_writers.serialize(&mut write).unwrap();
fast_field_writers.serialize(&mut write, None).unwrap();
write.terminate().unwrap();
}
let file = directory.open_read(path).unwrap();
@@ -243,7 +243,7 @@ mod tests {
.add_document(&doc!(*FIELD=>5_000_000_000_000_000_000u64 + doc_id))
.unwrap();
}
fast_field_writers.serialize(&mut write).unwrap();
fast_field_writers.serialize(&mut write, None).unwrap();
write.terminate().unwrap();
}
let file = directory.open_read(path).unwrap();
@@ -276,7 +276,7 @@ mod tests {
doc.add_i64(i64_field, i);
fast_field_writers.add_document(&doc).unwrap();
}
fast_field_writers.serialize(&mut write).unwrap();
fast_field_writers.serialize(&mut write, None).unwrap();
write.terminate().unwrap();
}
let file = directory.open_read(path).unwrap();
@@ -315,7 +315,7 @@ mod tests {
let mut fast_field_writers = FastFieldsWriter::from_schema(&schema).unwrap();
let doc = TantivyDocument::default();
fast_field_writers.add_document(&doc).unwrap();
fast_field_writers.serialize(&mut write).unwrap();
fast_field_writers.serialize(&mut write, None).unwrap();
write.terminate().unwrap();
}
@@ -348,7 +348,7 @@ mod tests {
let mut fast_field_writers = FastFieldsWriter::from_schema(&schema).unwrap();
let doc = TantivyDocument::default();
fast_field_writers.add_document(&doc).unwrap();
fast_field_writers.serialize(&mut write).unwrap();
fast_field_writers.serialize(&mut write, None).unwrap();
write.terminate().unwrap();
}
@@ -385,7 +385,7 @@ mod tests {
for &x in &permutation {
fast_field_writers.add_document(&doc!(*FIELD=>x)).unwrap();
}
fast_field_writers.serialize(&mut write).unwrap();
fast_field_writers.serialize(&mut write, None).unwrap();
write.terminate().unwrap();
}
let file = directory.open_read(path).unwrap();
@@ -770,7 +770,7 @@ mod tests {
fast_field_writers
.add_document(&doc!(field=>false))
.unwrap();
fast_field_writers.serialize(&mut write).unwrap();
fast_field_writers.serialize(&mut write, None).unwrap();
write.terminate().unwrap();
}
let file = directory.open_read(path).unwrap();
@@ -802,7 +802,7 @@ mod tests {
.add_document(&doc!(field=>false))
.unwrap();
}
fast_field_writers.serialize(&mut write).unwrap();
fast_field_writers.serialize(&mut write, None).unwrap();
write.terminate().unwrap();
}
let file = directory.open_read(path).unwrap();
@@ -827,7 +827,7 @@ mod tests {
let mut fast_field_writers = FastFieldsWriter::from_schema(&schema).unwrap();
let doc = TantivyDocument::default();
fast_field_writers.add_document(&doc).unwrap();
fast_field_writers.serialize(&mut write).unwrap();
fast_field_writers.serialize(&mut write, None).unwrap();
write.terminate().unwrap();
}
let file = directory.open_read(path).unwrap();
@@ -855,7 +855,7 @@ mod tests {
for doc in docs {
fast_field_writers.add_document(doc).unwrap();
}
fast_field_writers.serialize(&mut write).unwrap();
fast_field_writers.serialize(&mut write, None).unwrap();
write.terminate().unwrap();
}
Ok(directory)

View File

@@ -4,6 +4,7 @@ use columnar::{ColumnarWriter, NumericalValue};
use common::{DateTimePrecision, JsonPathWriter};
use tokenizer_api::Token;
use crate::indexer::doc_id_mapping::DocIdMapping;
use crate::schema::document::{Document, ReferenceValue, ReferenceValueLeaf, Value};
use crate::schema::{value_type_to_column_type, Field, FieldType, Schema, Type};
use crate::tokenizer::{TextAnalyzer, TokenizerManager};
@@ -105,6 +106,16 @@ impl FastFieldsWriter {
self.columnar_writer.mem_usage()
}
pub(crate) fn sort_order(
&self,
sort_field: &str,
num_docs: DocId,
reversed: bool,
) -> Vec<DocId> {
self.columnar_writer
.sort_order(sort_field, num_docs, reversed)
}
/// Indexes all of the fastfields of a new document.
pub fn add_document<D: Document>(&mut self, doc: &D) -> crate::Result<()> {
let doc_id = self.num_docs;
@@ -222,9 +233,16 @@ impl FastFieldsWriter {
/// Serializes all of the `FastFieldWriter`s by pushing them in
/// order to the fast field serializer.
pub fn serialize(mut self, wrt: &mut dyn io::Write) -> io::Result<()> {
pub fn serialize(
mut self,
wrt: &mut dyn io::Write,
doc_id_map_opt: Option<&DocIdMapping>,
) -> io::Result<()> {
let num_docs = self.num_docs;
self.columnar_writer.serialize(num_docs, wrt)?;
let old_to_new_row_ids =
doc_id_map_opt.map(|doc_id_mapping| doc_id_mapping.old_to_new_ids());
self.columnar_writer
.serialize(num_docs, old_to_new_row_ids, wrt)?;
Ok(())
}
}
@@ -374,7 +392,7 @@ mod tests {
}
let mut buffer = Vec::new();
columnar_writer
.serialize(json_docs.len() as DocId, &mut buffer)
.serialize(json_docs.len() as DocId, None, &mut buffer)
.unwrap();
ColumnarReader::open(buffer).unwrap()
}

View File

@@ -77,7 +77,7 @@ mod tests {
let mut fieldnorm_writers = FieldNormsWriter::for_schema(&SCHEMA);
fieldnorm_writers.record(2u32, *TXT_FIELD, 5);
fieldnorm_writers.record(3u32, *TXT_FIELD, 3);
fieldnorm_writers.serialize(serializer)?;
fieldnorm_writers.serialize(serializer, None)?;
}
let file = directory.open_read(path)?;
{

View File

@@ -2,6 +2,7 @@ use std::cmp::Ordering;
use std::{io, iter};
use super::{fieldnorm_to_id, FieldNormsSerializer};
use crate::indexer::doc_id_mapping::DocIdMapping;
use crate::schema::{Field, Schema};
use crate::DocId;
@@ -91,7 +92,11 @@ impl FieldNormsWriter {
}
/// Serialize the seen fieldnorm values to the serializer for all fields.
pub fn serialize(&self, mut fieldnorms_serializer: FieldNormsSerializer) -> io::Result<()> {
pub fn serialize(
&self,
mut fieldnorms_serializer: FieldNormsSerializer,
doc_id_map: Option<&DocIdMapping>,
) -> io::Result<()> {
for (field, fieldnorms_buffer) in self.fieldnorms_buffers.iter().enumerate().filter_map(
|(field_id, fieldnorms_buffer_opt)| {
fieldnorms_buffer_opt.as_ref().map(|fieldnorms_buffer| {
@@ -99,7 +104,12 @@ impl FieldNormsWriter {
})
},
) {
fieldnorms_serializer.serialize_field(field, fieldnorms_buffer)?;
if let Some(doc_id_map) = doc_id_map {
let remapped_fieldnorm_buffer = doc_id_map.remap(fieldnorms_buffer);
fieldnorms_serializer.serialize_field(field, &remapped_fieldnorm_buffer)?;
} else {
fieldnorms_serializer.serialize_field(field, fieldnorms_buffer)?;
}
}
fieldnorms_serializer.close()?;
Ok(())
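
What `doc_id_map.remap(fieldnorms_buffer)` does to a per-document buffer can be shown with a small stand-in for `DocIdMapping`. This is a sketch under one stated assumption: the new-to-old vector is a full permutation of `0..n`. `Mapping` and `from_new_to_old` are hypothetical names.

```rust
/// Hypothetical stand-in for `DocIdMapping`: a permutation and its inverse.
struct Mapping {
    new_to_old: Vec<u32>,
    old_to_new: Vec<u32>,
}

impl Mapping {
    /// Builds the inverse table from `new_to_old`, assumed to be a permutation of `0..n`.
    fn from_new_to_old(new_to_old: Vec<u32>) -> Self {
        let mut old_to_new = vec![0u32; new_to_old.len()];
        for (new_id, &old_id) in new_to_old.iter().enumerate() {
            old_to_new[old_id as usize] = new_id as u32;
        }
        Mapping {
            new_to_old,
            old_to_new,
        }
    }

    /// Reorders `els` (indexed by old doc id) so that it is indexed by new doc id.
    fn remap<T: Copy>(&self, els: &[T]) -> Vec<T> {
        self.new_to_old
            .iter()
            .map(|&old_id| els[old_id as usize])
            .collect()
    }
}
```

For example, with `new_to_old = [2, 0, 1]`, the value recorded for old doc 2 moves to position 0 of the remapped buffer.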

View File

@@ -4,7 +4,8 @@ use rand::{rng, Rng};
use crate::indexer::index_writer::MEMORY_BUDGET_NUM_BYTES_MIN;
use crate::schema::*;
use crate::{doc, schema, Index, IndexWriter, Searcher};
#[allow(deprecated)]
use crate::{doc, schema, Index, IndexSettings, IndexSortByField, IndexWriter, Order, Searcher};
fn check_index_content(searcher: &Searcher, vals: &[u64]) -> crate::Result<()> {
assert!(searcher.segment_readers().len() < 20);
@@ -62,6 +63,71 @@ fn get_num_iterations() -> usize {
.map(|str| str.parse().unwrap())
.unwrap_or(2000)
}
#[test]
#[ignore]
fn test_functional_indexing_sorted() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let id_field = schema_builder.add_u64_field("id", INDEXED | FAST);
let multiples_field = schema_builder.add_u64_field("multiples", INDEXED);
let text_field_options = TextOptions::default()
.set_indexing_options(
TextFieldIndexing::default()
.set_index_option(schema::IndexRecordOption::WithFreqsAndPositions),
)
.set_stored();
let text_field = schema_builder.add_text_field("text_field", text_field_options);
let schema = schema_builder.build();
let mut index_builder = Index::builder().schema(schema);
index_builder = index_builder.settings(IndexSettings {
sort_by_field: Some(IndexSortByField {
field: "id".to_string(),
order: Order::Desc,
}),
..Default::default()
});
let index = index_builder.create_from_tempdir().unwrap();
let reader = index.reader()?;
let mut rng = rng();
let mut index_writer: IndexWriter =
index.writer_with_num_threads(3, 3 * MEMORY_BUDGET_NUM_BYTES_MIN)?;
let mut committed_docs: HashSet<u64> = HashSet::new();
let mut uncommitted_docs: HashSet<u64> = HashSet::new();
for _ in 0..get_num_iterations() {
let random_val = rng.random_range(0..20);
if random_val == 0 {
index_writer.commit()?;
committed_docs.extend(&uncommitted_docs);
uncommitted_docs.clear();
reader.reload()?;
let searcher = reader.searcher();
// check that everything is correct.
check_index_content(
&searcher,
&committed_docs.iter().cloned().collect::<Vec<u64>>(),
)?;
} else if committed_docs.remove(&random_val) || uncommitted_docs.remove(&random_val) {
let doc_id_term = Term::from_field_u64(id_field, random_val);
index_writer.delete_term(doc_id_term);
} else {
uncommitted_docs.insert(random_val);
let mut doc = TantivyDocument::new();
doc.add_u64(id_field, random_val);
for i in 1u64..10u64 {
doc.add_u64(multiples_field, random_val * i);
}
doc.add_text(text_field, get_text());
index_writer.add_document(doc)?;
}
}
Ok(())
}
const LOREM: &str = "Doc Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod \
tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, \

View File

@@ -22,7 +22,7 @@ use crate::indexer::segment_updater::save_metas;
use crate::indexer::{IndexWriter, SingleSegmentIndexWriter};
use crate::reader::{IndexReader, IndexReaderBuilder};
use crate::schema::document::Document;
use crate::schema::{Field, FieldType, Schema};
use crate::schema::{Field, FieldType, Schema, Type};
use crate::tokenizer::{TextAnalyzer, TokenizerManager};
use crate::SegmentReader;
@@ -232,7 +232,38 @@ impl IndexBuilder {
}
fn validate(&self) -> crate::Result<()> {
if let Some(_schema) = self.schema.as_ref() {
if let Some(schema) = self.schema.as_ref() {
if let Some(sort_by_field) = self.index_settings.sort_by_field.as_ref() {
let schema_field = schema.get_field(&sort_by_field.field).map_err(|_| {
TantivyError::InvalidArgument(format!(
"Field to sort index {} not found in schema",
sort_by_field.field
))
})?;
let entry = schema.get_field_entry(schema_field);
if !entry.is_fast() {
return Err(TantivyError::InvalidArgument(format!(
"Field {} is no fast field. Field needs to be a single value fast field \
to be used to sort an index",
sort_by_field.field
)));
}
let supported_field_types = [
Type::I64,
Type::U64,
Type::F64,
Type::Date,
Type::Str,
Type::Bytes,
];
let field_type = entry.field_type().value_type();
if !supported_field_types.contains(&field_type) {
return Err(TantivyError::InvalidArgument(format!(
"Unsupported field type in sort_by_field: {field_type:?}. Supported field \
types: {supported_field_types:?} ",
)));
}
}
Ok(())
} else {
Err(TantivyError::InvalidArgument(

View File

@@ -1,6 +1,8 @@
use std::collections::HashSet;
use std::fmt;
use std::path::PathBuf;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
@@ -35,6 +37,7 @@ impl SegmentMetaInventory {
let inner = InnerSegmentMeta {
segment_id,
max_doc,
include_temp_doc_store: Arc::new(AtomicBool::new(true)),
deletes: None,
};
SegmentMeta::from(self.inventory.track(inner))
@@ -82,6 +85,15 @@ impl SegmentMeta {
self.tracked.segment_id
}
/// Removes `SegmentComponent::TempStore` from the list of live files,
/// which marks the temp docstore file for deletion by
/// garbage collection.
pub fn untrack_temp_docstore(&self) {
self.tracked
.include_temp_doc_store
.store(false, std::sync::atomic::Ordering::Relaxed);
}
/// Returns the number of deleted documents.
pub fn num_deleted_docs(&self) -> u32 {
self.tracked
@@ -99,9 +111,20 @@ impl SegmentMeta {
/// is by removing all files that have been created by tantivy
/// and are not used by any segment anymore.
pub fn list_files(&self) -> HashSet<PathBuf> {
SegmentComponent::iterator()
.map(|component| self.relative_path(*component))
.collect::<HashSet<PathBuf>>()
if self
.tracked
.include_temp_doc_store
.load(std::sync::atomic::Ordering::Relaxed)
{
SegmentComponent::iterator()
.map(|component| self.relative_path(*component))
.collect::<HashSet<PathBuf>>()
} else {
SegmentComponent::iterator()
.filter(|comp| *comp != &SegmentComponent::TempStore)
.map(|component| self.relative_path(*component))
.collect::<HashSet<PathBuf>>()
}
}
/// Returns the relative path of a component of our segment.
@@ -115,6 +138,7 @@ impl SegmentMeta {
SegmentComponent::Positions => ".pos".to_string(),
SegmentComponent::Terms => ".term".to_string(),
SegmentComponent::Store => ".store".to_string(),
SegmentComponent::TempStore => ".store.temp".to_string(),
SegmentComponent::FastFields => ".fast".to_string(),
SegmentComponent::FieldNorms => ".fieldnorm".to_string(),
SegmentComponent::Delete => format!(".{}.del", self.delete_opstamp().unwrap_or(0)),
@@ -159,6 +183,7 @@ impl SegmentMeta {
segment_id: inner_meta.segment_id,
max_doc,
deletes: None,
include_temp_doc_store: Arc::new(AtomicBool::new(true)),
});
SegmentMeta { tracked }
}
@@ -177,6 +202,7 @@ impl SegmentMeta {
let tracked = self.tracked.map(move |inner_meta| InnerSegmentMeta {
segment_id: inner_meta.segment_id,
max_doc: inner_meta.max_doc,
include_temp_doc_store: Arc::new(AtomicBool::new(true)),
deletes: Some(delete_meta),
});
SegmentMeta { tracked }
@@ -188,6 +214,14 @@ struct InnerSegmentMeta {
segment_id: SegmentId,
max_doc: u32,
pub deletes: Option<DeleteMeta>,
/// Set this to `true` to keep the `SegmentComponent::TempStore` file in the list of
/// live files, so that garbage collection does not delete it. This is used during merge.
#[serde(skip)]
#[serde(default = "default_temp_store")]
pub(crate) include_temp_doc_store: Arc<AtomicBool>,
}
fn default_temp_store() -> Arc<AtomicBool> {
Arc::new(AtomicBool::new(false))
}
impl InnerSegmentMeta {
@@ -212,6 +246,10 @@ fn is_true(val: &bool) -> bool {
/// index, like presort documents.
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub struct IndexSettings {
/// Sorts the documents by information
/// provided in `IndexSortByField`
#[serde(skip_serializing_if = "Option::is_none")]
pub sort_by_field: Option<IndexSortByField>,
/// The `Compressor` used to compress the doc store.
#[serde(default)]
pub docstore_compression: Compressor,
@@ -234,6 +272,7 @@ fn default_docstore_blocksize() -> usize {
impl Default for IndexSettings {
fn default() -> Self {
Self {
sort_by_field: None,
docstore_compression: Compressor::default(),
docstore_blocksize: default_docstore_blocksize(),
docstore_compress_dedicated_thread: true,
@@ -241,6 +280,18 @@ impl Default for IndexSettings {
}
}
/// Settings to presort the documents in an index
///
/// Presorting documents can greatly improve performance
/// in some scenarios, by applying top n
/// optimizations.
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub struct IndexSortByField {
/// The field to sort the documents by
pub field: String,
/// The order to sort the documents by
pub order: Order,
}
/// The order to sort by
#[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub enum Order {
@@ -360,7 +411,7 @@ mod tests {
use crate::store::Compressor;
#[cfg(feature = "zstd-compression")]
use crate::store::ZstdCompressor;
use crate::IndexSettings;
use crate::{IndexSettings, IndexSortByField, Order};
#[test]
fn test_serialize_metas() {
@@ -372,6 +423,10 @@ mod tests {
let index_metas = IndexMeta {
index_settings: IndexSettings {
docstore_compression: Compressor::None,
sort_by_field: Some(IndexSortByField {
field: "text".to_string(),
order: Order::Asc,
}),
..Default::default()
},
segments: Vec::new(),
@@ -382,7 +437,7 @@ mod tests {
let json = serde_json::ser::to_string(&index_metas).expect("serialization failed");
assert_eq!(
json,
r#"{"index_settings":{"docstore_compression":"none","docstore_blocksize":16384},"segments":[],"schema":[{"name":"text","type":"text","options":{"indexing":{"record":"position","fieldnorms":true,"tokenizer":"default"},"stored":false,"fast":false}}],"opstamp":0}"#
r#"{"index_settings":{"sort_by_field":{"field":"text","order":"Asc"},"docstore_compression":"none","docstore_blocksize":16384},"segments":[],"schema":[{"name":"text","type":"text","options":{"indexing":{"record":"position","fieldnorms":true,"tokenizer":"default"},"stored":false,"fast":false}}],"opstamp":0}"#
);
let deser_meta: UntrackedIndexMeta = serde_json::from_str(&json).unwrap();
@@ -401,6 +456,10 @@ mod tests {
};
let index_metas = IndexMeta {
index_settings: IndexSettings {
sort_by_field: Some(IndexSortByField {
field: "text".to_string(),
order: Order::Asc,
}),
docstore_compression: crate::store::Compressor::Zstd(ZstdCompressor {
compression_level: Some(4),
}),
@@ -415,7 +474,7 @@ mod tests {
let json = serde_json::ser::to_string(&index_metas).expect("serialization failed");
assert_eq!(
json,
r#"{"index_settings":{"docstore_compression":"zstd(compression_level=4)","docstore_blocksize":1000000},"segments":[],"schema":[{"name":"text","type":"text","options":{"indexing":{"record":"position","fieldnorms":true,"tokenizer":"default"},"stored":false,"fast":false}}],"opstamp":0}"#
r#"{"index_settings":{"sort_by_field":{"field":"text","order":"Asc"},"docstore_compression":"zstd(compression_level=4)","docstore_blocksize":1000000},"segments":[],"schema":[{"name":"text","type":"text","options":{"indexing":{"record":"position","fieldnorms":true,"tokenizer":"default"},"stored":false,"fast":false}}],"opstamp":0}"#
);
let deser_meta: UntrackedIndexMeta = serde_json::from_str(&json).unwrap();
@@ -427,35 +486,35 @@ mod tests {
#[test]
#[cfg(all(feature = "lz4-compression", feature = "zstd-compression"))]
fn test_serialize_metas_invalid_comp() {
let json = r#"{"index_settings":{"docstore_compression":"zsstd","docstore_blocksize":1000000},"segments":[],"schema":[{"name":"text","type":"text","options":{"indexing":{"record":"position","fieldnorms":true,"tokenizer":"default"},"stored":false,"fast":false}}],"opstamp":0}"#;
let json = r#"{"index_settings":{"sort_by_field":{"field":"text","order":"Asc"},"docstore_compression":"zsstd","docstore_blocksize":1000000},"segments":[],"schema":[{"name":"text","type":"text","options":{"indexing":{"record":"position","fieldnorms":true,"tokenizer":"default"},"stored":false,"fast":false}}],"opstamp":0}"#;
let err = serde_json::from_str::<UntrackedIndexMeta>(json).unwrap_err();
assert_eq!(
err.to_string(),
"unknown variant `zsstd`, expected one of `none`, `lz4`, `zstd`, \
`zstd(compression_level=5)` at line 1 column 49"
`zstd(compression_level=5)` at line 1 column 96"
.to_string()
);
let json = r#"{"index_settings":{"docstore_compression":"zstd(bla=10)","docstore_blocksize":1000000},"segments":[],"schema":[{"name":"text","type":"text","options":{"indexing":{"record":"position","fieldnorms":true,"tokenizer":"default"},"stored":false,"fast":false}}],"opstamp":0}"#;
let json = r#"{"index_settings":{"sort_by_field":{"field":"text","order":"Asc"},"docstore_compression":"zstd(bla=10)","docstore_blocksize":1000000},"segments":[],"schema":[{"name":"text","type":"text","options":{"indexing":{"record":"position","fieldnorms":true,"tokenizer":"default"},"stored":false,"fast":false}}],"opstamp":0}"#;
let err = serde_json::from_str::<UntrackedIndexMeta>(json).unwrap_err();
assert_eq!(
err.to_string(),
"unknown zstd option \"bla\" at line 1 column 56".to_string()
"unknown zstd option \"bla\" at line 1 column 103".to_string()
);
}
#[test]
#[cfg(not(feature = "zstd-compression"))]
fn test_serialize_metas_unsupported_comp() {
let json = r#"{"index_settings":{"docstore_compression":"zstd","docstore_blocksize":1000000},"segments":[],"schema":[{"name":"text","type":"text","options":{"indexing":{"record":"position","fieldnorms":true,"tokenizer":"default"},"stored":false,"fast":false}}],"opstamp":0}"#;
let json = r#"{"index_settings":{"sort_by_field":{"field":"text","order":"Asc"},"docstore_compression":"zstd","docstore_blocksize":1000000},"segments":[],"schema":[{"name":"text","type":"text","options":{"indexing":{"record":"position","fieldnorms":true,"tokenizer":"default"},"stored":false,"fast":false}}],"opstamp":0}"#;
let err = serde_json::from_str::<UntrackedIndexMeta>(json).unwrap_err();
assert_eq!(
err.to_string(),
"unsupported variant `zstd`, please enable Tantivy's `zstd-compression` feature at \
line 1 column 48"
line 1 column 95"
.to_string()
);
}
@@ -469,6 +528,7 @@ mod tests {
assert_eq!(
index_settings,
IndexSettings {
sort_by_field: None,
docstore_compression: Compressor::default(),
docstore_compress_dedicated_thread: true,
docstore_blocksize: 16_384
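
The interaction between `untrack_temp_docstore` and `list_files` in this file comes down to a shared atomic flag that gates one entry of the file list. The following is a simplified sketch, not the crate's code: `Meta`, its string segment id, and the three-extension list are assumptions made for illustration.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

/// Hypothetical segment metadata; clones share the same flag through the `Arc`.
#[derive(Clone)]
struct Meta {
    segment_id: String,
    include_temp_store: Arc<AtomicBool>,
}

impl Meta {
    fn new(segment_id: &str) -> Self {
        Meta {
            segment_id: segment_id.to_string(),
            include_temp_store: Arc::new(AtomicBool::new(true)),
        }
    }

    /// Drops the temp store from the live file list, leaving it to garbage collection.
    fn untrack_temp_store(&self) {
        self.include_temp_store.store(false, Ordering::Relaxed);
    }

    /// Lists the files that garbage collection must keep.
    fn list_files(&self) -> Vec<String> {
        let include_temp = self.include_temp_store.load(Ordering::Relaxed);
        [".store", ".store.temp", ".fast"]
            .iter()
            .filter(|ext| include_temp || **ext != ".store.temp")
            .map(|ext| format!("{}{}", self.segment_id, ext))
            .collect()
    }
}
```

Because the flag lives behind an `Arc`, untracking through one handle is visible to every clone of the same metadata.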

View File

@@ -12,7 +12,7 @@ mod segment_reader;
pub use self::index::{Index, IndexBuilder};
pub(crate) use self::index_meta::SegmentMetaInventory;
pub use self::index_meta::{IndexMeta, IndexSettings, Order, SegmentMeta};
pub use self::index_meta::{IndexMeta, IndexSettings, IndexSortByField, Order, SegmentMeta};
pub use self::inverted_index_reader::InvertedIndexReader;
pub use self::segment::Segment;
pub use self::segment_component::SegmentComponent;

View File

@@ -23,6 +23,8 @@ pub enum SegmentComponent {
/// Accessing a document from the store is relatively slow, as it
/// requires to decompress the entire block it belongs to.
Store,
/// Temporary storage of the documents, before they are streamed to `Store`.
TempStore,
/// Bitset describing which document of the segment is alive.
/// (It was representing deleted docs but changed to represent alive docs from v0.17)
Delete,
@@ -31,13 +33,14 @@ pub enum SegmentComponent {
impl SegmentComponent {
/// Iterates through the components.
pub fn iterator() -> slice::Iter<'static, SegmentComponent> {
static SEGMENT_COMPONENTS: [SegmentComponent; 7] = [
static SEGMENT_COMPONENTS: [SegmentComponent; 8] = [
SegmentComponent::Postings,
SegmentComponent::Positions,
SegmentComponent::FastFields,
SegmentComponent::FieldNorms,
SegmentComponent::Terms,
SegmentComponent::Store,
SegmentComponent::TempStore,
SegmentComponent::Delete,
];
SEGMENT_COMPONENTS.iter()

View File

@@ -6,6 +6,7 @@ use common::{ByteCount, HasLen};
use fnv::FnvHashMap;
use itertools::Itertools;
use crate::directory::error::OpenReadError;
use crate::directory::{CompositeFile, FileSlice};
use crate::error::DataCorruption;
use crate::fastfield::{intersect_alive_bitsets, AliveBitSet, FacetReader, FastFieldReaders};
@@ -159,12 +160,10 @@ impl SegmentReader {
let postings_file = segment.open_read(SegmentComponent::Postings)?;
let postings_composite = CompositeFile::open(&postings_file)?;
let positions_composite = {
if let Ok(positions_file) = segment.open_read(SegmentComponent::Positions) {
CompositeFile::open(&positions_file)?
} else {
CompositeFile::empty()
}
let positions_composite = match segment.open_read(SegmentComponent::Positions) {
Ok(positions_file) => CompositeFile::open(&positions_file)?,
Err(OpenReadError::FileDoesNotExist(_)) => CompositeFile::empty(),
Err(open_read_error) => return Err(open_read_error.into()),
};
let schema = segment.schema();
@@ -323,7 +322,7 @@ impl SegmentReader {
// Without expand dots enabled dots need to be escaped.
let escaped_json_path = json_path.replace('.', "\\.");
let full_path = format!("{field_name}.{escaped_json_path}");
let full_path_unescaped = format!("{}.{}", field_name, &json_path);
let full_path_unescaped = format!("{}.{}", field_name, json_path);
map_to_canonical.insert(full_path_unescaped, full_path.to_string());
full_path
} else {

View File

@@ -3,17 +3,28 @@
use common::ReadOnlyBitSet;
use crate::DocAddress;
use super::SegmentWriter;
use crate::schema::{Field, Schema};
use crate::{DocAddress, DocId, IndexSortByField, TantivyError};
/// Describes how the document ID mapping was produced during a merge.
#[derive(Copy, Clone, Eq, PartialEq)]
pub enum MappingType {
/// Segments are concatenated in order with no deletes; doc IDs are contiguous ranges.
Stacked,
/// Segments are concatenated in order but some documents have been deleted and are skipped.
StackedWithDeletes,
/// Documents have been reordered (e.g. sorted by a field or externally shuffled).
Shuffled,
}
/// Struct to provide mapping from new doc_id to old doc_id and segment.
///
/// Callers outside tantivy (e.g. pomsky's merge executor) can construct a
/// `Shuffled` mapping directly from a precomputed permutation and pass it
/// into [`IndexMerger::write_with_doc_id_mapping`].
#[derive(Clone)]
pub(crate) struct SegmentDocIdMapping {
pub struct SegmentDocIdMapping {
pub(crate) new_doc_id_to_old_doc_addr: Vec<DocAddress>,
pub(crate) alive_bitsets: Vec<Option<ReadOnlyBitSet>>,
mapping_type: MappingType,
@@ -32,6 +43,25 @@ impl SegmentDocIdMapping {
}
}
/// Build a `Shuffled` mapping from an explicit permutation of [`DocAddress`]es.
///
/// `new_doc_id_to_old_doc_addr[new_id]` gives the source segment and doc id for
/// the document that should appear at position `new_id` in the merged segment.
/// `alive_bitsets` must contain one entry per source segment (in the same order
/// as passed to [`IndexMerger::open_with_custom_alive_set`]), each `None` if that
/// segment has no deletes.
pub fn new_shuffled(
new_doc_id_to_old_doc_addr: Vec<DocAddress>,
alive_bitsets: Vec<Option<ReadOnlyBitSet>>,
) -> Self {
Self {
new_doc_id_to_old_doc_addr,
mapping_type: MappingType::Shuffled,
alive_bitsets,
}
}
/// Returns the [`MappingType`] that describes how this mapping was constructed.
pub fn mapping_type(&self) -> MappingType {
self.mapping_type
}
@@ -43,4 +73,559 @@ impl SegmentDocIdMapping {
pub(crate) fn iter_old_doc_addrs(&self) -> impl Iterator<Item = DocAddress> + '_ {
self.new_doc_id_to_old_doc_addr.iter().copied()
}
/// This flag means the segments are simply stacked in the order of their ordinal.
/// e.g. [(0, 1), .. (n, 1), (0, 2)..., (m, 2)]
///
/// The segments may contain deletes, which are expressed by skipping a
/// `DocId`. [(0, 1), (0, 3)] <--- here doc_id=0 and doc_id=1 have been deleted
///
/// Being trivial is equivalent to having the `new_doc_id_to_old_doc_addr` array sorted.
///
/// This allows for some optimization.
pub(crate) fn is_trivial(&self) -> bool {
match self.mapping_type {
MappingType::Stacked | MappingType::StackedWithDeletes => true,
MappingType::Shuffled => false,
}
}
}
/// Bidirectional mapping between old and new doc IDs within a single segment.
pub struct DocIdMapping {
new_doc_id_to_old: Vec<DocId>,
old_doc_id_to_new: Vec<DocId>,
}
impl DocIdMapping {
/// Constructs a [`DocIdMapping`] from a vector mapping each new doc ID to its old doc ID.
pub fn from_new_id_to_old_id(new_doc_id_to_old: Vec<DocId>) -> Self {
let max_doc = new_doc_id_to_old.len();
let old_max_doc = new_doc_id_to_old
.iter()
.cloned()
.max()
.map(|n| n + 1)
.unwrap_or(0);
let mut old_doc_id_to_new = vec![0; old_max_doc as usize];
for i in 0..max_doc {
old_doc_id_to_new[new_doc_id_to_old[i] as usize] = i as DocId;
}
DocIdMapping {
new_doc_id_to_old,
old_doc_id_to_new,
}
}
/// Returns the new doc_id for the given old doc_id.
pub fn get_new_doc_id(&self, doc_id: DocId) -> DocId {
self.old_doc_id_to_new[doc_id as usize]
}
/// Returns the old doc_id for the given new doc_id.
pub fn get_old_doc_id(&self, doc_id: DocId) -> DocId {
self.new_doc_id_to_old[doc_id as usize]
}
/// Iterates over old doc_ids in the order of the new doc_ids.
pub fn iter_old_doc_ids(&self) -> impl Iterator<Item = DocId> + Clone + '_ {
self.new_doc_id_to_old.iter().cloned()
}
/// Returns a slice mapping each old doc ID to its corresponding new doc ID.
pub fn old_to_new_ids(&self) -> &[DocId] {
&self.old_doc_id_to_new[..]
}
/// Remaps a given array to the new doc ids.
pub fn remap<T: Copy>(&self, els: &[T]) -> Vec<T> {
self.new_doc_id_to_old
.iter()
.map(|old_doc| els[*old_doc as usize])
.collect()
}
/// Returns the number of new (post-sort) doc IDs in this mapping.
pub fn num_new_doc_ids(&self) -> usize {
self.new_doc_id_to_old.len()
}
/// Returns the number of old (pre-sort) doc IDs covered by this mapping.
pub fn num_old_doc_ids(&self) -> usize {
self.old_doc_id_to_new.len()
}
}
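To make the two directions of `DocIdMapping` concrete, here is a dependency-free sketch that follows the same steps as `from_new_id_to_old_id` and `remap` above. The free functions `invert` and `remap` are names chosen for this illustration; they are not part of the crate's API.

```rust
type DocId = u32;

/// Builds the inverse (old -> new) of a new -> old mapping, following the
/// same steps as `from_new_id_to_old_id` in the diff.
fn invert(new_to_old: &[DocId]) -> Vec<DocId> {
    let old_max_doc = new_to_old
        .iter()
        .copied()
        .max()
        .map_or(0, |n| n as usize + 1);
    let mut old_to_new = vec![0; old_max_doc];
    for (new_id, &old_id) in new_to_old.iter().enumerate() {
        old_to_new[old_id as usize] = new_id as DocId;
    }
    old_to_new
}

/// Reorders per-document values so that index `new_id` holds the value of
/// the document that previously lived at `new_to_old[new_id]`.
fn remap<T: Copy>(new_to_old: &[DocId], els: &[T]) -> Vec<T> {
    new_to_old.iter().map(|&old| els[old as usize]).collect()
}

fn main() {
    let new_to_old: [DocId; 3] = [3, 2, 5];
    let old_to_new = invert(&new_to_old);
    // Old ids 0, 1 and 4 are absent from the mapping and keep the default 0,
    // so they cannot be told apart from old id 3, which really maps to 0.
    println!("{old_to_new:?}");

    let values = [0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000];
    println!("{:?}", remap(&[2u32, 8, 3], &values));
}
```

The inverse array is sized by the largest old id that appears, so a mapping that drops documents still answers lookups for every old id up to that maximum. Absent ids fall back to 0, which is the behavior `test_doc_mapping` asserts.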
pub(crate) fn expect_field_id_for_sort_field(
schema: &Schema,
sort_by_field: &IndexSortByField,
) -> crate::Result<Field> {
schema.get_field(&sort_by_field.field).map_err(|_| {
TantivyError::InvalidArgument(format!(
"field to sort index by not found: {:?}",
sort_by_field.field
))
})
}
// Generates a document mapping in the form of [index new doc_id] -> old doc_id
// TODO detect if field is already sorted and discard mapping
pub(crate) fn get_doc_id_mapping_from_field(
sort_by_field: IndexSortByField,
segment_writer: &SegmentWriter,
) -> crate::Result<DocIdMapping> {
let schema = segment_writer.segment_serializer.segment().schema();
expect_field_id_for_sort_field(&schema, &sort_by_field)?; // for now expect
let new_doc_id_to_old = segment_writer.fast_field_writers.sort_order(
sort_by_field.field.as_str(),
segment_writer.max_doc(),
sort_by_field.order.is_desc(),
);
// create new doc_id to old doc_id index (used in fast_field_writers)
Ok(DocIdMapping::from_new_id_to_old_id(new_doc_id_to_old))
}
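`get_doc_id_mapping_from_field` delegates the ordering to `fast_field_writers.sort_order`, whose body is not part of this diff. As a rough sketch of what such a function has to produce, the hypothetical `sort_order` below computes a `new_doc_id -> old_doc_id` permutation from one optional key per document. Placing documents without a value first in ascending order is an assumption of this sketch, not a statement about the crate.

```rust
/// Returns `new_doc_id -> old_doc_id`, ordering documents by `keys`.
/// `None` (no value) sorts before every `Some` when ascending; that
/// placement is a choice made for this sketch only.
fn sort_order(keys: &[Option<u64>], descending: bool) -> Vec<u32> {
    let mut new_to_old: Vec<u32> = (0..keys.len() as u32).collect();
    // `sort_by` is stable, so documents with equal keys keep insertion order.
    new_to_old.sort_by(|&a, &b| {
        let ord = keys[a as usize].cmp(&keys[b as usize]);
        if descending {
            ord.reverse()
        } else {
            ord
        }
    });
    new_to_old
}

fn main() {
    // The `my_number` values used by `create_test_index`, in insertion order.
    let keys: Vec<Option<u64>> = [40, 20, 100, 10, 30].iter().map(|&v| Some(v)).collect();
    println!("asc:  {:?}", sort_order(&keys, false));
    println!("desc: {:?}", sort_order(&keys, true));

    // A missing value stays distinct from an explicit zero.
    println!("null: {:?}", sort_order(&[Some(0), None, Some(5)], false));
}
```

With the test data, old doc 3 (value 10) lands at new position 0 in ascending order and at position 4 in descending order, which is what the `vec![0]` and `vec![4]` assertions in the tests below check.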
#[cfg(test)]
mod tests_indexsorting {
use common::DateTime;
use crate::collector::TopDocs;
use crate::indexer::doc_id_mapping::DocIdMapping;
use crate::indexer::NoMergePolicy;
use crate::query::QueryParser;
use crate::schema::*;
use crate::{DocAddress, Index, IndexBuilder, IndexSettings, IndexSortByField, Order};
fn create_test_index(
index_settings: Option<IndexSettings>,
text_field_options: TextOptions,
) -> crate::Result<Index> {
let mut schema_builder = Schema::builder();
let my_text_field = schema_builder.add_text_field("text_field", text_field_options);
let my_string_field = schema_builder.add_text_field("string_field", STRING | STORED);
let my_number =
schema_builder.add_u64_field("my_number", NumericOptions::default().set_fast());
let multi_numbers =
schema_builder.add_u64_field("multi_numbers", NumericOptions::default().set_fast());
let schema = schema_builder.build();
let mut index_builder = Index::builder().schema(schema);
if let Some(settings) = index_settings {
index_builder = index_builder.settings(settings);
}
let index = index_builder.create_in_ram()?;
let mut index_writer = index.writer_for_tests()?;
index_writer.add_document(doc!(my_number=>40_u64))?;
index_writer.add_document(
doc!(my_number=>20_u64, multi_numbers => 5_u64, multi_numbers => 6_u64),
)?;
index_writer.add_document(doc!(my_number=>100_u64))?;
index_writer.add_document(
doc!(my_number=>10_u64, my_string_field=> "blublub", my_text_field => "some text"),
)?;
index_writer.add_document(doc!(my_number=>30_u64, multi_numbers => 3_u64 ))?;
index_writer.commit()?;
Ok(index)
}
fn get_text_options() -> TextOptions {
TextOptions::default().set_indexing_options(
TextFieldIndexing::default().set_index_option(IndexRecordOption::Basic),
)
}
#[test]
fn test_sort_index_test_text_field() -> crate::Result<()> {
// there are different serializers for different settings in postings/recorder.rs
// test remapping for all of them
let options = vec![
get_text_options(),
get_text_options().set_indexing_options(
TextFieldIndexing::default().set_index_option(IndexRecordOption::WithFreqs),
),
get_text_options().set_indexing_options(
TextFieldIndexing::default()
.set_index_option(IndexRecordOption::WithFreqsAndPositions),
),
];
for option in options {
// let options = get_text_options();
// no index_sort
let index = create_test_index(None, option.clone())?;
let my_text_field = index.schema().get_field("text_field").unwrap();
let searcher = index.reader()?.searcher();
let query = QueryParser::for_index(&index, vec![my_text_field]).parse_query("text")?;
let top_docs: Vec<(f32, DocAddress)> =
searcher.search(&query, &TopDocs::with_limit(3).order_by_score())?;
assert_eq!(
top_docs.iter().map(|el| el.1.doc_id).collect::<Vec<_>>(),
vec![3]
);
// sort by field asc
let index = create_test_index(
Some(IndexSettings {
sort_by_field: Some(IndexSortByField {
field: "my_number".to_string(),
order: Order::Asc,
}),
..Default::default()
}),
option.clone(),
)?;
let my_text_field = index.schema().get_field("text_field").unwrap();
let reader = index.reader()?;
let searcher = reader.searcher();
let query = QueryParser::for_index(&index, vec![my_text_field]).parse_query("text")?;
let top_docs: Vec<(f32, DocAddress)> =
searcher.search(&query, &TopDocs::with_limit(3).order_by_score())?;
assert_eq!(
top_docs.iter().map(|el| el.1.doc_id).collect::<Vec<_>>(),
vec![0]
);
// test new field norm mapping
{
let my_text_field = index.schema().get_field("text_field").unwrap();
let fieldnorm_reader = searcher
.segment_reader(0)
.get_fieldnorms_reader(my_text_field)?;
assert_eq!(fieldnorm_reader.fieldnorm(0), 2); // some text
assert_eq!(fieldnorm_reader.fieldnorm(1), 0);
}
// sort by field desc
let index = create_test_index(
Some(IndexSettings {
sort_by_field: Some(IndexSortByField {
field: "my_number".to_string(),
order: Order::Desc,
}),
..Default::default()
}),
option.clone(),
)?;
let my_string_field = index.schema().get_field("text_field").unwrap();
let searcher = index.reader()?.searcher();
let query =
QueryParser::for_index(&index, vec![my_string_field]).parse_query("text")?;
let top_docs: Vec<(f32, DocAddress)> =
searcher.search(&query, &TopDocs::with_limit(3).order_by_score())?;
assert_eq!(
top_docs.iter().map(|el| el.1.doc_id).collect::<Vec<_>>(),
vec![4]
);
// test new field norm mapping
{
let my_text_field = index.schema().get_field("text_field").unwrap();
let fieldnorm_reader = searcher
.segment_reader(0)
.get_fieldnorms_reader(my_text_field)?;
assert_eq!(fieldnorm_reader.fieldnorm(0), 0);
assert_eq!(fieldnorm_reader.fieldnorm(1), 0);
assert_eq!(fieldnorm_reader.fieldnorm(2), 0);
assert_eq!(fieldnorm_reader.fieldnorm(3), 0);
assert_eq!(fieldnorm_reader.fieldnorm(4), 2); // some text
}
}
Ok(())
}
#[test]
fn test_sort_index_get_documents() -> crate::Result<()> {
// default baseline
let index = create_test_index(None, get_text_options())?;
let my_string_field = index.schema().get_field("string_field").unwrap();
let searcher = index.reader()?.searcher();
{
assert!(searcher
.doc::<TantivyDocument>(DocAddress::new(0, 0))?
.get_first(my_string_field)
.is_none());
assert_eq!(
searcher
.doc::<TantivyDocument>(DocAddress::new(0, 3))?
.get_first(my_string_field)
.unwrap()
.as_str(),
Some("blublub")
);
}
// sort by field asc
let index = create_test_index(
Some(IndexSettings {
sort_by_field: Some(IndexSortByField {
field: "my_number".to_string(),
order: Order::Asc,
}),
..Default::default()
}),
get_text_options(),
)?;
let my_string_field = index.schema().get_field("string_field").unwrap();
let searcher = index.reader()?.searcher();
{
assert_eq!(
searcher
.doc::<TantivyDocument>(DocAddress::new(0, 0))?
.get_first(my_string_field)
.unwrap()
.as_str(),
Some("blublub")
);
let doc = searcher.doc::<TantivyDocument>(DocAddress::new(0, 4))?;
assert!(doc.get_first(my_string_field).is_none());
}
// sort by field desc
let index = create_test_index(
Some(IndexSettings {
sort_by_field: Some(IndexSortByField {
field: "my_number".to_string(),
order: Order::Desc,
}),
..Default::default()
}),
get_text_options(),
)?;
let my_string_field = index.schema().get_field("string_field").unwrap();
let searcher = index.reader()?.searcher();
{
let doc = searcher.doc::<TantivyDocument>(DocAddress::new(0, 4))?;
assert_eq!(
doc.get_first(my_string_field).unwrap().as_str(),
Some("blublub")
);
}
Ok(())
}
#[test]
fn test_sort_index_test_string_field() -> crate::Result<()> {
let index = create_test_index(None, get_text_options())?;
let my_string_field = index.schema().get_field("string_field").unwrap();
let searcher = index.reader()?.searcher();
let query = QueryParser::for_index(&index, vec![my_string_field]).parse_query("blublub")?;
let top_docs: Vec<(f32, DocAddress)> =
searcher.search(&query, &TopDocs::with_limit(3).order_by_score())?;
assert_eq!(
top_docs.iter().map(|el| el.1.doc_id).collect::<Vec<_>>(),
vec![3]
);
let index = create_test_index(
Some(IndexSettings {
sort_by_field: Some(IndexSortByField {
field: "my_number".to_string(),
order: Order::Asc,
}),
..Default::default()
}),
get_text_options(),
)?;
let my_string_field = index.schema().get_field("string_field").unwrap();
let reader = index.reader()?;
let searcher = reader.searcher();
let query = QueryParser::for_index(&index, vec![my_string_field]).parse_query("blublub")?;
let top_docs: Vec<(f32, DocAddress)> =
searcher.search(&query, &TopDocs::with_limit(3).order_by_score())?;
assert_eq!(
top_docs.iter().map(|el| el.1.doc_id).collect::<Vec<_>>(),
vec![0]
);
// test new field norm mapping
{
let my_text_field = index.schema().get_field("text_field").unwrap();
let fieldnorm_reader = searcher
.segment_reader(0)
.get_fieldnorms_reader(my_text_field)?;
assert_eq!(fieldnorm_reader.fieldnorm(0), 2); // some text
assert_eq!(fieldnorm_reader.fieldnorm(1), 0);
}
// sort by field desc
let index = create_test_index(
Some(IndexSettings {
sort_by_field: Some(IndexSortByField {
field: "my_number".to_string(),
order: Order::Desc,
}),
..Default::default()
}),
get_text_options(),
)?;
let my_string_field = index.schema().get_field("string_field").unwrap();
let searcher = index.reader()?.searcher();
let query = QueryParser::for_index(&index, vec![my_string_field]).parse_query("blublub")?;
let top_docs: Vec<(f32, DocAddress)> =
searcher.search(&query, &TopDocs::with_limit(3).order_by_score())?;
assert_eq!(
top_docs.iter().map(|el| el.1.doc_id).collect::<Vec<_>>(),
vec![4]
);
// test new field norm mapping
{
let my_text_field = index.schema().get_field("text_field").unwrap();
let fieldnorm_reader = searcher
.segment_reader(0)
.get_fieldnorms_reader(my_text_field)?;
assert_eq!(fieldnorm_reader.fieldnorm(0), 0);
assert_eq!(fieldnorm_reader.fieldnorm(1), 0);
assert_eq!(fieldnorm_reader.fieldnorm(2), 0);
assert_eq!(fieldnorm_reader.fieldnorm(3), 0);
assert_eq!(fieldnorm_reader.fieldnorm(4), 2); // some text
}
Ok(())
}
#[test]
fn test_sort_index_fast_field() -> crate::Result<()> {
let index = create_test_index(
Some(IndexSettings {
sort_by_field: Some(IndexSortByField {
field: "my_number".to_string(),
order: Order::Asc,
}),
..Default::default()
}),
get_text_options(),
)?;
assert_eq!(
index.settings().sort_by_field.as_ref().unwrap().field,
"my_number".to_string()
);
let searcher = index.reader()?.searcher();
assert_eq!(searcher.segment_readers().len(), 1);
let segment_reader = searcher.segment_reader(0);
let fast_fields = segment_reader.fast_fields();
let fast_field = fast_fields
.u64("my_number")
.unwrap()
.first_or_default_col(999);
assert_eq!(fast_field.get_val(0), 10u64);
assert_eq!(fast_field.get_val(1), 20u64);
assert_eq!(fast_field.get_val(2), 30u64);
let multifield = fast_fields.u64("multi_numbers").unwrap();
let vals: Vec<u64> = multifield.values_for_doc(0u32).collect();
assert_eq!(vals, &[] as &[u64]);
let vals: Vec<_> = multifield.values_for_doc(1u32).collect();
assert_eq!(vals, &[5, 6]);
let vals: Vec<_> = multifield.values_for_doc(2u32).collect();
assert_eq!(vals, &[3]);
Ok(())
}
#[test]
fn test_with_sort_by_date_field() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let date_field = schema_builder.add_date_field("date", INDEXED | STORED | FAST);
let schema = schema_builder.build();
let settings = IndexSettings {
sort_by_field: Some(IndexSortByField {
field: "date".to_string(),
order: Order::Desc,
}),
..Default::default()
};
let index = Index::builder()
.schema(schema)
.settings(settings)
.create_in_ram()?;
let mut index_writer = index.writer_for_tests()?;
index_writer.set_merge_policy(Box::new(NoMergePolicy));
index_writer.add_document(doc!(
date_field => DateTime::from_timestamp_secs(1000),
))?;
index_writer.add_document(doc!(
date_field => DateTime::from_timestamp_secs(999),
))?;
index_writer.add_document(doc!(
date_field => DateTime::from_timestamp_secs(1001),
))?;
index_writer.commit()?;
let searcher = index.reader()?.searcher();
assert_eq!(searcher.segment_readers().len(), 1);
let segment_reader = searcher.segment_reader(0);
let fast_fields = segment_reader.fast_fields();
let fast_field = fast_fields
.date("date")
.unwrap()
.first_or_default_col(DateTime::from_timestamp_secs(0));
assert_eq!(fast_field.get_val(0), DateTime::from_timestamp_secs(1001));
assert_eq!(fast_field.get_val(1), DateTime::from_timestamp_secs(1000));
assert_eq!(fast_field.get_val(2), DateTime::from_timestamp_secs(999));
Ok(())
}
#[test]
fn test_doc_mapping() {
let doc_mapping = DocIdMapping::from_new_id_to_old_id(vec![3, 2, 5]);
assert_eq!(doc_mapping.get_old_doc_id(0), 3);
assert_eq!(doc_mapping.get_old_doc_id(1), 2);
assert_eq!(doc_mapping.get_old_doc_id(2), 5);
assert_eq!(doc_mapping.get_new_doc_id(0), 0);
assert_eq!(doc_mapping.get_new_doc_id(1), 0);
assert_eq!(doc_mapping.get_new_doc_id(2), 1);
assert_eq!(doc_mapping.get_new_doc_id(3), 0);
assert_eq!(doc_mapping.get_new_doc_id(4), 0);
assert_eq!(doc_mapping.get_new_doc_id(5), 2);
}
#[test]
fn test_doc_mapping_remap() {
let doc_mapping = DocIdMapping::from_new_id_to_old_id(vec![2, 8, 3]);
assert_eq!(
&doc_mapping.remap(&[0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000]),
&[2000, 8000, 3000]
);
}
#[test]
fn test_text_sort() -> crate::Result<()> {
let mut schema_builder = SchemaBuilder::new();
let id_field = schema_builder.add_text_field("id", STRING | FAST | STORED);
schema_builder.add_text_field("name", TEXT | STORED);
let index = IndexBuilder::new()
.schema(schema_builder.build())
.settings(IndexSettings {
sort_by_field: Some(IndexSortByField {
field: "id".to_string(),
order: Order::Asc,
}),
..Default::default()
})
.create_in_ram()?;
let mut index_writer = index.writer_for_tests()?;
index_writer.add_document(doc!(id_field => "z"))?;
index_writer.add_document(doc!(id_field => "a"))?;
index_writer.add_document(doc!(id_field => "m"))?;
index_writer.commit()?;
let searcher = index.reader()?.searcher();
let segment_reader = searcher.segment_reader(0);
let str_col = segment_reader.fast_fields().str("id")?.unwrap();
let mut values = Vec::new();
for doc in 0..segment_reader.max_doc() {
if let Some(ord) = str_col.ords().first(doc) {
let mut s = String::new();
str_col.ord_to_str(ord, &mut s)?;
values.push(s);
}
}
assert_eq!(values, vec!["a", "m", "z"]);
Ok(())
}
}

@@ -218,7 +218,7 @@ fn index_documents<D: Document>(
let alive_bitset_opt = apply_deletes(&segment_with_max_doc, &mut delete_cursor, &doc_opstamps)?;
let meta = segment_with_max_doc.meta().clone();
meta.untrack_temp_docstore();
// update segment_updater inventory to remove tempstore
let segment_entry = SegmentEntry::new(meta, delete_cursor, alive_bitset_opt);
segment_updater.schedule_add_segment(segment_entry).wait()?;
@@ -819,7 +819,7 @@ mod tests {
use std::collections::{HashMap, HashSet};
use std::net::Ipv6Addr;
use columnar::{Column, MonotonicallyMappableToU128};
use columnar::{Cardinality, Column, MonotonicallyMappableToU128};
use itertools::Itertools;
use proptest::prop_oneof;
@@ -829,7 +829,7 @@ mod tests {
use crate::error::*;
use crate::indexer::index_writer::MEMORY_BUDGET_NUM_BYTES_MIN;
use crate::indexer::{IndexWriterOptions, NoMergePolicy};
use crate::query::{QueryParser, TermQuery};
use crate::query::{BooleanQuery, Occur, Query, QueryParser, TermQuery};
use crate::schema::{
self, Facet, FacetOptions, IndexRecordOption, IpAddrOptions, JsonObjectOptions,
NumericOptions, Schema, TextFieldIndexing, TextOptions, Value, FAST, INDEXED, STORED,
@@ -837,8 +837,8 @@ mod tests {
};
use crate::store::DOCSTORE_CACHE_CAPACITY;
use crate::{
DateTime, DocAddress, Index, IndexSettings, IndexWriter, ReloadPolicy, TantivyDocument,
Term,
DateTime, DocAddress, Index, IndexSettings, IndexSortByField, IndexWriter, Order,
ReloadPolicy, TantivyDocument, Term,
};
const LOREM: &str = "Doc Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do \
@@ -1479,6 +1479,116 @@ mod tests {
assert!(text_fast_field.term_ords(1).eq([1].into_iter()));
}
#[test]
fn test_delete_with_sort_by_field() -> crate::Result<()> {
let mut schema_builder = schema::Schema::builder();
let id_field = schema_builder.add_u64_field("id", INDEXED | schema::STORED | FAST);
let schema = schema_builder.build();
let settings = IndexSettings {
sort_by_field: Some(IndexSortByField {
field: "id".to_string(),
order: Order::Desc,
}),
..Default::default()
};
let index = Index::builder()
.schema(schema)
.settings(settings)
.create_in_ram()?;
let index_reader = index.reader()?;
let mut index_writer = index.writer_for_tests()?;
// create and delete docs in same commit
for id in 0u64..5u64 {
index_writer.add_document(doc!(id_field => id))?;
}
for id in 2u64..4u64 {
index_writer.delete_term(Term::from_field_u64(id_field, id));
}
for id in 5u64..10u64 {
index_writer.add_document(doc!(id_field => id))?;
}
index_writer.commit()?;
index_reader.reload()?;
let searcher = index_reader.searcher();
assert_eq!(searcher.segment_readers().len(), 1);
let segment_reader = searcher.segment_reader(0);
assert_eq!(segment_reader.num_docs(), 8);
assert_eq!(segment_reader.max_doc(), 10);
let fast_field_reader = segment_reader.fast_fields().u64("id")?;
let in_order_alive_ids: Vec<u64> = segment_reader
.doc_ids_alive()
.flat_map(|doc| fast_field_reader.values_for_doc(doc))
.collect();
assert_eq!(&in_order_alive_ids[..], &[9, 8, 7, 6, 5, 4, 1, 0]);
Ok(())
}
#[test]
fn test_delete_query_with_sort_by_field() -> crate::Result<()> {
let mut schema_builder = schema::Schema::builder();
let id_field = schema_builder.add_u64_field("id", INDEXED | schema::STORED | FAST);
let schema = schema_builder.build();
let settings = IndexSettings {
sort_by_field: Some(IndexSortByField {
field: "id".to_string(),
order: Order::Desc,
}),
..Default::default()
};
let index = Index::builder()
.schema(schema)
.settings(settings)
.create_in_ram()?;
let index_reader = index.reader()?;
let mut index_writer = index.writer_for_tests()?;
// create and delete docs in same commit
for id in 0u64..5u64 {
index_writer.add_document(doc!(id_field => id))?;
}
for id in 1u64..4u64 {
let term = Term::from_field_u64(id_field, id);
let not_term = Term::from_field_u64(id_field, 2);
let term = Box::new(TermQuery::new(term, Default::default()));
let not_term = Box::new(TermQuery::new(not_term, Default::default()));
let query: BooleanQuery = vec![
(Occur::Must, term as Box<dyn Query>),
(Occur::MustNot, not_term as Box<dyn Query>),
]
.into();
index_writer.delete_query(Box::new(query))?;
}
for id in 5u64..10u64 {
index_writer.add_document(doc!(id_field => id))?;
}
index_writer.commit()?;
index_reader.reload()?;
let searcher = index_reader.searcher();
assert_eq!(searcher.segment_readers().len(), 1);
let segment_reader = searcher.segment_reader(0);
assert_eq!(segment_reader.num_docs(), 8);
assert_eq!(segment_reader.max_doc(), 10);
let fast_field_reader = segment_reader.fast_fields().u64("id")?;
let in_order_alive_ids: Vec<u64> = segment_reader
.doc_ids_alive()
.flat_map(|doc| fast_field_reader.values_for_doc(doc))
.collect();
assert_eq!(&in_order_alive_ids[..], &[9, 8, 7, 6, 5, 4, 2, 0]);
Ok(())
}
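The expected values in the two delete tests above can be reproduced with a small simulation. This is only a sketch of the arithmetic: `sorted_alive_ids` is a hypothetical helper, and it ignores opstamps, which is safe here because none of the documents added after the deletes match a deleted id.

```rust
/// Simulates the segment the tests expect: ids 0..10 are indexed, some are
/// deleted, and the segment is sorted by id in descending order.
/// Returns (num_docs, max_doc, alive ids in new doc order).
fn sorted_alive_ids(deleted: &[u64]) -> (usize, usize, Vec<u64>) {
    let mut ids: Vec<u64> = (0..10).collect();
    // The descending sort rewrites the doc order of the whole segment.
    ids.sort_by(|a, b| b.cmp(a));
    // Deleted docs still occupy a doc id, so they count toward max_doc.
    let max_doc = ids.len();
    let alive: Vec<u64> = ids
        .into_iter()
        .filter(|id| !deleted.contains(id))
        .collect();
    (alive.len(), max_doc, alive)
}

fn main() {
    // test_delete_with_sort_by_field: delete_term for ids 2 and 3.
    println!("{:?}", sorted_alive_ids(&[2, 3]));

    // test_delete_query_with_sort_by_field: Must(id) AND MustNot(2) for
    // id in 1..4 matches nothing when id == 2, so only 1 and 3 are deleted.
    let deleted: Vec<u64> = (1u64..4).filter(|&id| id != 2).collect();
    println!("{:?}", sorted_alive_ids(&deleted));
}
```

Both cases yield `num_docs == 8` and `max_doc == 10`, and the alive ids come out in descending order with only the deleted ids missing.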
#[derive(Debug, Clone)]
enum IndexingOp {
AddMultipleDoc {
@@ -1625,7 +1735,11 @@ mod tests {
id_list
}
fn test_operation_strategy(ops: &[IndexingOp], force_end_merge: bool) -> crate::Result<Index> {
fn test_operation_strategy(
ops: &[IndexingOp],
sort_index: bool,
force_end_merge: bool,
) -> crate::Result<Index> {
let mut schema_builder = schema::Schema::builder();
let json_field = schema_builder.add_json_field("json", FAST | TEXT | STORED);
let ip_field = schema_builder.add_ip_addr_field("ip", FAST | INDEXED | STORED);
@@ -1661,7 +1775,15 @@ mod tests {
);
let facet_field = schema_builder.add_facet_field("facet", FacetOptions::default());
let schema = schema_builder.build();
let settings = {
let settings = if sort_index {
IndexSettings {
sort_by_field: Some(IndexSortByField {
field: "id_opt".to_string(),
order: Order::Asc,
}),
..Default::default()
}
} else {
IndexSettings {
..Default::default()
}
@@ -2225,13 +2347,33 @@ mod tests {
}
}
// Test if index property is in sort order
if sort_index {
// load all id_opt in each segment and check they are in order
for reader in searcher.segment_readers() {
let (ff_reader, _) = reader.fast_fields().u64_lenient("id_opt").unwrap().unwrap();
let mut ids_in_segment: Vec<u64> = Vec::new();
for doc in 0..reader.num_docs() {
ids_in_segment.extend(ff_reader.values_for_doc(doc));
}
assert!(is_sorted(&ids_in_segment));
fn is_sorted<T>(data: &[T]) -> bool
where T: Ord {
data.windows(2).all(|w| w[0] <= w[1])
}
}
}
Ok(index)
}
#[test]
fn test_fast_field_range() {
let ops: Vec<_> = (0..1000).map(IndexingOp::add).collect();
assert!(test_operation_strategy(&ops, true).is_ok());
assert!(test_operation_strategy(&ops, false, true).is_ok());
}
#[test]
@@ -2245,6 +2387,7 @@ mod tests {
IndexingOp::Commit,
IndexingOp::Merge
],
true,
false
)
.is_ok());
@@ -2261,6 +2404,7 @@ mod tests {
IndexingOp::add(1),
IndexingOp::Commit,
],
false,
true
)
.is_ok());
@@ -2268,24 +2412,97 @@ mod tests {
#[test]
fn test_minimal_sort_force_end_merge() {
assert!(
test_operation_strategy(&[IndexingOp::add(23), IndexingOp::add(13),], false).is_ok()
);
assert!(test_operation_strategy(
&[IndexingOp::add(23), IndexingOp::add(13),],
false,
false
)
.is_ok());
}
#[test]
fn test_minimal_no_force_end_merge() {
fn test_minimal_sort() {
let mut schema_builder = Schema::builder();
let val = schema_builder.add_u64_field("val", FAST);
let id = schema_builder.add_u64_field("id", FAST);
let schema = schema_builder.build();
let settings = IndexSettings {
sort_by_field: Some(IndexSortByField {
field: "id".to_string(),
order: Order::Asc,
}),
..Default::default()
};
let index = Index::builder()
.schema(schema)
.settings(settings)
.create_in_ram()
.unwrap();
let mut writer = index.writer_for_tests().unwrap();
writer
.add_document(doc!(id=> 3u64, val=>4u64, val=>4u64))
.unwrap();
writer
.add_document(doc!(id=> 2u64, val=>2u64, val=>2u64))
.unwrap();
writer
.add_document(doc!(id=> 1u64, val=>1u64, val=>1u64))
.unwrap();
writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let segment_reader = searcher.segment_reader(0);
let id_col: Column = segment_reader
.fast_fields()
.column_opt("id")
.unwrap()
.unwrap();
let val_col: Column = segment_reader
.fast_fields()
.column_opt("val")
.unwrap()
.unwrap();
assert_eq!(id_col.get_cardinality(), Cardinality::Full);
assert_eq!(val_col.get_cardinality(), Cardinality::Multivalued);
assert_eq!(id_col.first(0u32), Some(1u64));
assert_eq!(id_col.first(1u32), Some(2u64));
assert!(val_col.values_for_doc(0u32).eq([1u64, 1u64].into_iter()));
assert!(val_col.values_for_doc(1u32).eq([2u64, 2u64].into_iter()));
}
#[test]
fn test_minimal_sort_force_end_merge_with_delete() {
assert!(test_operation_strategy(
&[
IndexingOp::add(23),
IndexingOp::add(13),
IndexingOp::DeleteDoc { id: 13 }
],
true,
true
)
.is_ok());
}
#[test]
fn test_minimal_no_sort_no_force_end_merge() {
assert!(test_operation_strategy(
&[
IndexingOp::add(23),
IndexingOp::add(13),
IndexingOp::DeleteDoc { id: 13 }
],
false,
false
)
.is_ok());
}
#[test]
fn test_minimal_sort_merge() {
assert!(test_operation_strategy(&[IndexingOp::add(3),], true, true).is_ok());
}
use proptest::prelude::*;
proptest! {
@@ -2293,23 +2510,77 @@ mod tests {
#![proptest_config(ProptestConfig::with_cases(20))]
#[test]
fn test_delete_proptest_adding(ops in proptest::collection::vec(adding_operation_strategy(), 1..100)) {
assert!(test_operation_strategy(&ops[..], false).is_ok());
assert!(test_operation_strategy(&ops[..], true, false).is_ok());
}
#[test]
fn test_delete_proptest_with_merge_adding(ops in proptest::collection::vec(adding_operation_strategy(), 1..100)) {
assert!(test_operation_strategy(&ops[..], true).is_ok());
assert!(test_operation_strategy(&ops[..], false, false).is_ok());
}
#[test]
fn test_delete_proptest(ops in proptest::collection::vec(balanced_operation_strategy(), 1..10)) {
assert!(test_operation_strategy(&ops[..], false).is_ok());
assert!(test_operation_strategy(&ops[..], true, true).is_ok());
}
#[test]
fn test_delete_proptest_with_merge(ops in proptest::collection::vec(balanced_operation_strategy(), 1..100)) {
assert!(test_operation_strategy(&ops[..], true).is_ok());
assert!(test_operation_strategy(&ops[..], false, true).is_ok());
}
#[test]
#[ignore = "doesn't work with deferred segment loading"]
fn test_delete_without_sort_proptest(ops in proptest::collection::vec(balanced_operation_strategy(), 1..10)) {
assert!(test_operation_strategy(&ops[..], false, false).is_ok());
}
#[test]
#[ignore = "doesn't work with deferred segment loading"]
fn test_delete_with_sort_proptest_with_merge(ops in proptest::collection::vec(balanced_operation_strategy(), 1..10)) {
assert!(test_operation_strategy(&ops[..], true, true).is_ok());
}
}
#[test]
fn test_delete_with_sort_by_field_last_opstamp_is_not_max() -> crate::Result<()> {
let mut schema_builder = schema::Schema::builder();
let sort_by_field = schema_builder.add_u64_field("sort_by", FAST);
let id_field = schema_builder.add_u64_field("id", INDEXED);
let schema = schema_builder.build();
let settings = IndexSettings {
sort_by_field: Some(IndexSortByField {
field: "sort_by".to_string(),
order: Order::Asc,
}),
..Default::default()
};
let index = Index::builder()
.schema(schema)
.settings(settings)
.create_in_ram()?;
let mut index_writer = index.writer_for_tests()?;
// We add a doc...
index_writer.add_document(doc!(sort_by_field => 2u64, id_field => 0u64))?;
// And remove it.
index_writer.delete_term(Term::from_field_u64(id_field, 0u64));
// We add another doc.
index_writer.add_document(doc!(sort_by_field=>1u64, id_field => 0u64))?;
// The expected result is a segment with
// maxdoc = 2
// numdoc = 1.
index_writer.commit()?;
let searcher = index.reader()?.searcher();
assert_eq!(searcher.segment_readers().len(), 1);
let segment_reader = searcher.segment_reader(0);
assert_eq!(segment_reader.max_doc(), 2);
assert_eq!(segment_reader.num_docs(), 1);
Ok(())
}
#[test]
@@ -2326,7 +2597,7 @@ mod tests {
IndexingOp::add(4),
Commit,
];
test_operation_strategy(&ops[..], true).unwrap();
test_operation_strategy(&ops[..], false, true).unwrap();
}
#[test]
@@ -2339,7 +2610,7 @@ mod tests {
Commit,
Merge,
];
test_operation_strategy(&ops[..], true).unwrap();
test_operation_strategy(&ops[..], false, true).unwrap();
}
#[test]
@@ -2351,7 +2622,7 @@ mod tests {
IndexingOp::add(13),
Commit,
];
test_operation_strategy(&ops[..], true).unwrap();
test_operation_strategy(&ops[..], false, true).unwrap();
}
#[test]
@@ -2362,7 +2633,7 @@ mod tests {
IndexingOp::add(9),
IndexingOp::add(10),
];
test_operation_strategy(&ops[..], false).unwrap();
test_operation_strategy(&ops[..], false, false).unwrap();
}
#[test]
@@ -2389,6 +2660,7 @@ mod tests {
IndexingOp::Commit,
IndexingOp::Commit
],
false,
false
)
.is_ok());
@@ -2409,6 +2681,7 @@ mod tests {
IndexingOp::Merge,
],
true,
false,
)
.unwrap();
}

@@ -1,148 +0,0 @@
#[cfg(test)]
mod tests {
use crate::collector::TopDocs;
use crate::fastfield::AliveBitSet;
use crate::index::Index;
use crate::postings::Postings;
use crate::query::QueryParser;
use crate::schema::{
self, BytesOptions, Facet, FacetOptions, IndexRecordOption, NumericOptions,
TextFieldIndexing, TextOptions,
};
use crate::{DocAddress, DocSet, IndexSettings, IndexWriter, Term};
fn create_test_index(index_settings: Option<IndexSettings>) -> crate::Result<Index> {
let mut schema_builder = schema::Schema::builder();
let int_options = NumericOptions::default()
.set_fast()
.set_stored()
.set_indexed();
let int_field = schema_builder.add_u64_field("intval", int_options);
let bytes_options = BytesOptions::default().set_fast().set_indexed();
let bytes_field = schema_builder.add_bytes_field("bytes", bytes_options);
let facet_field = schema_builder.add_facet_field("facet", FacetOptions::default());
let multi_numbers =
schema_builder.add_u64_field("multi_numbers", NumericOptions::default().set_fast());
let text_field_options = TextOptions::default()
.set_indexing_options(
TextFieldIndexing::default()
.set_index_option(schema::IndexRecordOption::WithFreqsAndPositions),
)
.set_stored();
let text_field = schema_builder.add_text_field("text_field", text_field_options);
let schema = schema_builder.build();
let mut index_builder = Index::builder().schema(schema);
if let Some(settings) = index_settings {
index_builder = index_builder.settings(settings);
}
let index = index_builder.create_in_ram()?;
{
let mut index_writer = index.writer_for_tests()?;
// segment 1 - range 1-3
index_writer.add_document(doc!(int_field=>1_u64))?;
index_writer.add_document(
doc!(int_field=>3_u64, multi_numbers => 3_u64, multi_numbers => 4_u64, bytes_field => vec![1, 2, 3], text_field => "some text", facet_field=> Facet::from("/book/crime")),
)?;
index_writer.add_document(
doc!(int_field=>1_u64, text_field=> "deleteme", text_field => "ok text more text"),
)?;
index_writer.add_document(
doc!(int_field=>2_u64, multi_numbers => 2_u64, multi_numbers => 3_u64, text_field => "ok text more text"),
)?;
index_writer.commit()?;
index_writer.add_document(doc!(int_field=>20_u64, multi_numbers => 20_u64))?;
let in_val = 1u64;
index_writer.add_document(doc!(int_field=>in_val, text_field=> "deleteme" , text_field => "ok text more text", facet_field=> Facet::from("/book/crime")))?;
index_writer.commit()?;
let int_vals = [10u64, 5];
index_writer.add_document( // position of this doc after delete in desc sorting = [2], in disjunct case [1]
doc!(int_field=>int_vals[0], multi_numbers => 10_u64, multi_numbers => 11_u64, text_field=> "blubber", facet_field=> Facet::from("/book/fantasy")),
)?;
index_writer.add_document(doc!(int_field=>int_vals[1], text_field=> "deleteme"))?;
index_writer.add_document(
doc!(int_field=>1_000u64, multi_numbers => 1001_u64, multi_numbers => 1002_u64, bytes_field => vec![5, 5],text_field => "the biggest num")
)?;
index_writer.delete_term(Term::from_field_text(text_field, "deleteme"));
index_writer.commit()?;
}
// Merging the segments
{
let segment_ids = index.searchable_segment_ids()?;
let mut index_writer: IndexWriter = index.writer_for_tests()?;
index_writer.merge(&segment_ids).wait()?;
index_writer.wait_merging_threads()?;
}
Ok(index)
}
#[test]
fn test_merge_index() {
let index = create_test_index(Some(IndexSettings {
..Default::default()
}))
.unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
assert_eq!(searcher.segment_readers().len(), 1);
let segment_reader = searcher.segment_readers().last().unwrap();
let searcher = index.reader().unwrap().searcher();
{
let my_text_field = index.schema().get_field("text_field").unwrap();
let do_search = |term: &str| {
let query = QueryParser::for_index(&index, vec![my_text_field])
.parse_query(term)
.unwrap();
let top_docs: Vec<(f32, DocAddress)> = searcher
.search(&query, &TopDocs::with_limit(3).order_by_score())
.unwrap();
top_docs.iter().map(|el| el.1.doc_id).collect::<Vec<_>>()
};
assert_eq!(do_search("some"), vec![1]);
assert_eq!(do_search("blubber"), vec![3]);
assert_eq!(do_search("biggest"), vec![4]);
}
// postings file
{
let my_text_field = index.schema().get_field("text_field").unwrap();
let term_a = Term::from_field_text(my_text_field, "text");
let inverted_index = segment_reader.inverted_index(my_text_field).unwrap();
let mut postings = inverted_index
.read_postings(&term_a, IndexRecordOption::WithFreqsAndPositions)
.unwrap()
.unwrap();
assert_eq!(postings.doc_freq(), 2);
let fallback_bitset = AliveBitSet::for_test_from_deleted_docs(&[0], 100);
assert_eq!(
postings.doc_freq_given_deletes(
segment_reader.alive_bitset().unwrap_or(&fallback_bitset)
),
2
);
assert_eq!(postings.term_freq(), 1);
let mut output = vec![];
postings.positions(&mut output);
assert_eq!(output, vec![1]);
postings.advance();
assert_eq!(postings.term_freq(), 2);
postings.positions(&mut output);
assert_eq!(output, vec![1, 3]);
}
}
}

@@ -1,7 +1,8 @@
use std::sync::Arc;
use columnar::{
ColumnType, ColumnarReader, MergeRowOrder, RowAddr, ShuffleMergeOrder, StackMergeOrder,
compute_merged_term_ord_mapping, BytesColumn, Column, ColumnType, ColumnarReader,
MergeRowOrder, RowAddr, ShuffleMergeOrder, StackMergeOrder,
};
use common::ReadOnlyBitSet;
use itertools::Itertools;
@@ -10,16 +11,52 @@ use measure_time::debug_time;
use crate::directory::WritePtr;
use crate::docset::{DocSet, TERMINATED};
use crate::error::DataCorruption;
use crate::fastfield::AliveBitSet;
use crate::fastfield::{AliveBitSet, FastFieldNotAvailableError};
use crate::fieldnorm::{FieldNormReader, FieldNormReaders, FieldNormsSerializer, FieldNormsWriter};
use crate::index::{Segment, SegmentComponent, SegmentReader};
use crate::indexer::doc_id_mapping::{MappingType, SegmentDocIdMapping};
use crate::indexer::SegmentSerializer;
use crate::postings::{InvertedIndexSerializer, Postings, SegmentPostings};
use crate::schema::{value_type_to_column_type, Field, FieldType, Schema};
use crate::schema::{value_type_to_column_type, Field, FieldType, Schema, Type};
use crate::store::StoreWriter;
use crate::termdict::{TermMerger, TermOrdinal};
use crate::{DocAddress, DocId, InvertedIndexReader};
use crate::{
DocAddress, DocId, IndexSettings, IndexSortByField, InvertedIndexReader, Order, SegmentOrdinal,
};
/// Per-segment accessor for Str/Bytes sort fields during index merging.
///
/// Each segment stores its own term dictionary with segment-local ordinals. To compare terms
/// across segments we compute a merged global dictionary and map each segment's local ordinals
/// to the corresponding merged ordinal via `merged_term_ord_mapping`. This avoids materializing
/// the actual term bytes during the merge sort — ordinal comparison is sufficient because the
/// merged dictionary preserves lexicographic order.
struct StrBytesSortFieldAccessor {
ords: Column<u64>,
merged_term_ord_mapping: Vec<TermOrdinal>,
}
impl StrBytesSortFieldAccessor {
fn remapped_term_ord(&self, doc_id: DocId) -> Option<TermOrdinal> {
self.ords.first(doc_id).map(|old_ord| {
let old_ord = old_ord as usize;
debug_assert!(old_ord < self.merged_term_ord_mapping.len());
self.merged_term_ord_mapping[old_ord]
})
}
}
/// Owned per-segment sort-field accessors, kept alive for the duration of the merge.
///
/// - `Numeric`: direct column value access — all numeric/datetime types share a single u64 column
/// interface, so segments can be compared directly by value.
/// - `StrBytes`: ordinal-based access — each segment's local term ordinals are remapped to merged
/// global ordinals so that cross-segment lexicographic comparison works without loading term
/// bytes.
enum ReaderSortFieldAccessors {
Numeric(Vec<(SegmentOrdinal, Column<u64>)>),
StrBytes(Vec<(SegmentOrdinal, StrBytesSortFieldAccessor)>),
}
/// Segment's max doc must be `< MAX_DOC_LIMIT`.
///
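Review note: the ordinal remapping described in the `StrBytesSortFieldAccessor` doc comment can be pictured with plain vectors. The sketch below is not the `compute_merged_term_ord_mapping` implementation; the function name `merged_ord_mappings` and the use of `Vec<&str>` as a stand-in for a segment dictionary are assumptions made for illustration only.

```rust
use std::collections::BTreeSet;

/// For each segment dictionary (sorted, deduplicated terms), builds a table
/// mapping local ordinal -> ordinal in the merged dictionary.
/// Hypothetical helper: a stand-in for the per-segment mapping stored in
/// `merged_term_ord_mapping`.
fn merged_ord_mappings(segment_dicts: &[Vec<&str>]) -> Vec<Vec<u64>> {
    // The merged dictionary is the sorted union of all segment terms.
    let merged: Vec<&str> = segment_dicts
        .iter()
        .flatten()
        .copied()
        .collect::<BTreeSet<&str>>()
        .into_iter()
        .collect();
    // Each local ordinal is replaced by the term's position in the union.
    segment_dicts
        .iter()
        .map(|dict| {
            dict.iter()
                .map(|term| merged.binary_search(term).expect("term is in the union") as u64)
                .collect()
        })
        .collect()
}
```

Because the merged dictionary is sorted, comparing two remapped ordinals gives the same result as comparing the underlying terms, which is why the merge never needs to load term bytes.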
@@ -76,7 +113,9 @@ fn estimate_total_num_tokens(readers: &[SegmentReader], field: Field) -> crate::
Ok(total_num_tokens)
}
/// Merges multiple index segments into a single new segment.
pub struct IndexMerger {
index_settings: IndexSettings,
schema: Schema,
pub(crate) readers: Vec<SegmentReader>,
max_doc: u32,
@@ -112,7 +151,7 @@ fn convert_to_merge_order(
) -> MergeRowOrder {
match doc_id_mapping.mapping_type() {
MappingType::Stacked => MergeRowOrder::Stack(StackMergeOrder::stack(columnars)),
MappingType::StackedWithDeletes => {
MappingType::StackedWithDeletes | MappingType::Shuffled => {
// RUST/LLVM is amazing. The following conversion is actually a no-op:
// no allocation, no copy.
let new_row_id_to_old_row_id: Vec<RowAddr> = doc_id_mapping
@@ -145,25 +184,62 @@ fn extract_fast_field_required_columns(schema: &Schema) -> Vec<(String, ColumnTy
}
impl IndexMerger {
pub fn open(schema: Schema, segments: &[Segment]) -> crate::Result<IndexMerger> {
let alive_bitset = segments.iter().map(|_| None).collect_vec();
Self::open_with_custom_alive_set(schema, segments, alive_bitset)
fn total_num_new_docs(&self) -> usize {
self.readers
.iter()
.map(|reader| reader.num_docs() as usize)
.sum()
}
// Create merge with a custom delete set.
// For every Segment, a delete bitset can be provided, which
// will be merged with the existing bit set. Make sure the index
// corresponds to the segment index.
//
// If `None` is provided for custom alive set, the regular alive set will be used.
// If an alive_bitset is provided, the union between the provided and regular
// alive set will be used.
//
// This can be used to merge but also apply an additional filter.
// One use case is demux, which is basically taking a list of
// segments and partitions them e.g. by a value in a field.
fn collect_alive_bitsets(&self) -> Vec<Option<ReadOnlyBitSet>> {
self.readers
.iter()
.map(|reader| {
reader
.alive_bitset()
.map(|alive_bitset| alive_bitset.bitset().clone())
})
.collect()
}
/// Column cardinality metadata (`Optional`) covers all docs including deleted ones.
/// A segment can report `Optional` but have zero live NULLs if every NULL doc was
/// deleted. We scan alive docs to distinguish this case, because deleted NULLs
/// are excluded from the merge and shouldn't block the disjunct-stack path.
fn segment_has_live_nulls(&self, segment_ord: SegmentOrdinal, col: &Column<u64>) -> bool {
if col.get_cardinality() != columnar::Cardinality::Optional {
return false;
}
let reader = &self.readers[segment_ord as usize];
if !reader.has_deletes() {
return true;
}
reader
.doc_ids_alive()
.any(|doc_id| col.first(doc_id).is_none())
}
/// Opens an [`IndexMerger`] over the given segments using their existing delete sets.
pub fn open(
schema: Schema,
index_settings: IndexSettings,
segments: &[Segment],
) -> crate::Result<IndexMerger> {
let alive_bitset = segments.iter().map(|_| None).collect_vec();
Self::open_with_custom_alive_set(schema, index_settings, segments, alive_bitset)
}
/// Opens an [`IndexMerger`] with a custom alive (delete) set per segment.
///
/// For every segment, an optional [`AliveBitSet`] can be provided which is intersected
/// with the segment's existing alive set. Pass `None` for a segment to use its existing
/// delete set unchanged.
///
/// This allows merging while also applying an additional filter, for example to demux
/// documents by a field value into separate output segments.
pub fn open_with_custom_alive_set(
schema: Schema,
index_settings: IndexSettings,
segments: &[Segment],
alive_bitset_opt: Vec<Option<AliveBitSet>>,
) -> crate::Result<IndexMerger> {
@@ -177,6 +253,12 @@ impl IndexMerger {
}
let max_doc = readers.iter().map(|reader| reader.num_docs()).sum();
if let Some(sort_by_field) = index_settings.sort_by_field.as_ref() {
let schema_field = schema.get_field(&sort_by_field.field)?;
let field_entry = schema.get_field_entry(schema_field);
let field_type = field_entry.field_type().value_type();
readers = Self::sort_readers_by_min_sort_field(readers, sort_by_field, field_type)?;
}
// sort segments by their natural sort setting
if max_doc >= MAX_DOC_LIMIT {
let err_msg = format!(
@@ -186,12 +268,50 @@ impl IndexMerger {
return Err(crate::TantivyError::InvalidArgument(err_msg));
}
Ok(IndexMerger {
index_settings,
schema,
readers,
max_doc,
})
}
fn sort_by_field_type(&self, sort_by_field: &IndexSortByField) -> crate::Result<Type> {
let schema_field = self.schema.get_field(&sort_by_field.field)?;
let field_entry = self.schema.get_field_entry(schema_field);
Ok(field_entry.field_type().value_type())
}
fn sort_readers_by_min_sort_field(
readers: Vec<SegmentReader>,
sort_by_field: &IndexSortByField,
field_type: Type,
) -> crate::Result<Vec<SegmentReader>> {
if matches!(field_type, Type::Str | Type::Bytes) {
// Ordinals are per-segment and not directly comparable, so the "disjunct min/max"
// shortcut that works for numeric fields does not apply here.
return Ok(readers);
}
// presort the readers by their min_values, so that when they are disjunct, we can use
// the regular merge logic (implicitly sorted)
let mut readers_with_min_sort_values = readers
.into_iter()
.map(|reader| {
let accessor = Self::get_numeric_accessor(&reader, sort_by_field)?;
Ok((reader, accessor.min_value()))
})
.collect::<crate::Result<Vec<_>>>()?;
if sort_by_field.order.is_asc() {
readers_with_min_sort_values.sort_by_key(|(_, min_val)| *min_val);
} else {
readers_with_min_sort_values.sort_by_key(|(_, min_val)| std::cmp::Reverse(*min_val));
}
Ok(readers_with_min_sort_values
.into_iter()
.map(|(reader, _)| reader)
.collect())
}
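Review note: a minimal sketch of the presort step above, assuming each segment is reduced to a `(name, min_value)` pair. The function `presort_by_min` and the tuple representation are hypothetical; the real code sorts `SegmentReader`s by `Column::min_value()`.

```rust
use std::cmp::Reverse;

/// Orders segments by the minimum value of their sort column, ascending or
/// descending, mirroring the `sort_by_key` / `Reverse` pattern in the hunk.
fn presort_by_min(mut segments: Vec<(&'static str, u64)>, asc: bool) -> Vec<&'static str> {
    if asc {
        segments.sort_by_key(|&(_, min_val)| min_val);
    } else {
        segments.sort_by_key(|&(_, min_val)| Reverse(min_val));
    }
    segments.into_iter().map(|(name, _)| name).collect()
}
```

If the value ranges turn out to be disjunct after this presort, the stacked segments are already globally sorted and no per-document reordering is needed.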
fn write_fieldnorms(
&self,
mut fieldnorms_serializer: FieldNormsSerializer,
@@ -239,14 +359,261 @@ impl IndexMerger {
Ok(())
}
/// Checks if segments can use the fast disjunct-stack path (byte concatenation)
/// instead of a full k-way merge.
///
/// Stacking preserves per-segment order but doesn't reposition docs across segments.
/// NULLs must sort first (ASC) or last (DESC) globally, but stacking can't move a
/// NULL from segment 2 before values in segment 1. So any live NULL forces a full
/// k-way merge to place NULLs correctly.
fn is_disjunct_and_sorted_on_sort_property(
&self,
sort_by_field: &IndexSortByField,
) -> crate::Result<bool> {
let field_type = self.sort_by_field_type(sort_by_field)?;
// Disjunct shortcut is invalid for Str/Bytes because ords are per-segment.
if matches!(field_type, Type::Str | Type::Bytes) {
return Ok(false);
}
let reader_ordinal_and_field_accessors = self.get_numeric_accessors(sort_by_field)?;
let asc = sort_by_field.order.is_asc();
let values_disjunct = reader_ordinal_and_field_accessors
.iter()
.map(|(_, col)| col)
.tuple_windows()
.all(|(col1, col2)| {
if asc {
col1.max_value() <= col2.min_value()
} else {
col1.min_value() >= col2.max_value()
}
});
if !values_disjunct {
return Ok(false);
}
let has_live_nulls = reader_ordinal_and_field_accessors
.iter()
.any(|(segment_ord, col)| self.segment_has_live_nulls(*segment_ord, col));
Ok(!has_live_nulls)
}
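Review note: the disjunct-stack check can be sketched on summary statistics alone. `SegmentStats` and `can_stack` are hypothetical names; the live-NULL flag is assumed to have been computed beforehand (the real code derives it from the column cardinality and a scan of alive docs).

```rust
/// Summary of one segment's sort column (hypothetical type for illustration).
struct SegmentStats {
    min: u64,
    max: u64,
    has_live_nulls: bool,
}

/// Returns true if the segments, in their current order, can simply be
/// stacked: adjacent value ranges must not overlap in the sort direction,
/// and no segment may contain a live NULL.
fn can_stack(segments: &[SegmentStats], asc: bool) -> bool {
    let values_disjunct = segments.windows(2).all(|pair| {
        if asc {
            pair[0].max <= pair[1].min
        } else {
            pair[0].min >= pair[1].max
        }
    });
    values_disjunct && !segments.iter().any(|s| s.has_live_nulls)
}
```

A single live NULL in a later segment is enough to reject stacking, since that document would have to move ahead of values in earlier segments (ascending) and stacking cannot do that.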
fn get_str_bytes_column(
reader: &SegmentReader,
sort_by_field: &IndexSortByField,
field_type: Type,
) -> crate::Result<BytesColumn> {
let not_available = || -> crate::TantivyError {
FastFieldNotAvailableError {
field_name: sort_by_field.field.to_string(),
}
.into()
};
match field_type {
Type::Str => reader
.fast_fields()
.str(&sort_by_field.field)?
.map(Into::into)
.ok_or_else(not_available),
Type::Bytes => reader
.fast_fields()
.bytes(&sort_by_field.field)?
.ok_or_else(not_available),
_ => unreachable!("get_str_bytes_column called with non-Str/Bytes type"),
}
}
/// Builds per-segment [`StrBytesSortFieldAccessor`]s for Str/Bytes sort fields.
///
/// 1. Extracts each segment's `BytesColumn` (term dictionary + ordinal column).
/// 2. Computes a merged dictionary across all segments via [`compute_merged_term_ord_mapping`],
/// producing a per-segment mapping from local term ordinal → merged global ordinal.
/// 3. Wraps each segment's ordinal column and mapping into a `StrBytesSortFieldAccessor`.
fn get_str_bytes_accessors(
&self,
sort_by_field: &IndexSortByField,
field_type: Type,
) -> crate::Result<Vec<(SegmentOrdinal, StrBytesSortFieldAccessor)>> {
let bytes_columns = self
.readers
.iter()
.map(|reader| Self::get_str_bytes_column(reader, sort_by_field, field_type))
.collect::<crate::Result<Vec<_>>>()?;
let merged_term_ord_mappings = compute_merged_term_ord_mapping(&bytes_columns)?;
debug_assert_eq!(bytes_columns.len(), merged_term_ord_mappings.len());
let accessors = bytes_columns
.into_iter()
.zip(merged_term_ord_mappings)
.enumerate()
.map(
|(reader_ordinal, (bytes_column, merged_term_ord_mapping))| {
(
reader_ordinal as SegmentOrdinal,
StrBytesSortFieldAccessor {
ords: bytes_column.ords().clone(),
merged_term_ord_mapping,
},
)
},
)
.collect::<Vec<_>>();
Ok(accessors)
}
/// Returns the full `Column<u64>` so callers can use `Column::first()` which
/// returns `Option<u64>` — `None` for NULLs, `Some` for real values. This
/// distinction is required for correct NULL ordering during merge sort and
/// for detecting live NULLs in the disjunct-stack check.
fn get_numeric_accessor(
reader: &SegmentReader,
sort_by_field: &IndexSortByField,
) -> crate::Result<Column<u64>> {
reader.schema().get_field(&sort_by_field.field)?;
let (value_accessor, _column_type) = reader
.fast_fields()
.u64_lenient(&sort_by_field.field)?
.ok_or_else(|| FastFieldNotAvailableError {
field_name: sort_by_field.field.to_string(),
})?;
Ok(value_accessor)
}
fn get_numeric_accessors(
&self,
sort_by_field: &IndexSortByField,
) -> crate::Result<Vec<(SegmentOrdinal, Column<u64>)>> {
self.readers
.iter()
.enumerate()
.map(|(reader_ordinal, reader)| {
let reader_ordinal = reader_ordinal as SegmentOrdinal;
let accessor = Self::get_numeric_accessor(reader, sort_by_field)?;
Ok((reader_ordinal, accessor))
})
.collect::<crate::Result<Vec<_>>>()
}
/// Builds owned per-segment sort accessors so they stay alive during merge.
///
/// Dispatches on the sort field's value type: numeric types use direct column value access,
/// while Str/Bytes types go through the ordinal-remapping path (see
/// [`StrBytesSortFieldAccessor`]).
fn get_reader_with_sort_field_accessor(
&self,
sort_by_field: &IndexSortByField,
) -> crate::Result<ReaderSortFieldAccessors> {
let field_type = self.sort_by_field_type(sort_by_field)?;
if matches!(field_type, Type::Str | Type::Bytes) {
let accessors = self.get_str_bytes_accessors(sort_by_field, field_type)?;
return Ok(ReaderSortFieldAccessors::StrBytes(accessors));
}
let accessors = self.get_numeric_accessors(sort_by_field)?;
Ok(ReaderSortFieldAccessors::Numeric(accessors))
}
fn extend_sorted_doc_ids<T, F>(
&self,
reader_ordinal_and_field_accessors: &[(SegmentOrdinal, T)],
mut is_less: F,
sorted_doc_ids: &mut Vec<DocAddress>,
) where
F: FnMut(&(DocId, &SegmentOrdinal, &T), &(DocId, &SegmentOrdinal, &T)) -> bool,
{
let doc_id_reader_pair =
reader_ordinal_and_field_accessors
.iter()
.map(|(reader_ord, ff_reader)| {
let reader = &self.readers[*reader_ord as usize];
reader
.doc_ids_alive()
.map(move |doc_id| (doc_id, reader_ord, ff_reader))
});
sorted_doc_ids.extend(
doc_id_reader_pair
.into_iter()
.kmerge_by(|a, b| is_less(a, b))
.map(|(doc_id, &segment_ord, _)| DocAddress {
doc_id,
segment_ord,
}),
);
}
/// Generates the doc_id mapping, where the position in the vec is the new
/// doc_id.
/// Each entry records the old doc_id together with the ordinal position of
/// its reader in `self.readers`.
pub(crate) fn generate_doc_id_mapping_with_sort_by_field(
&self,
sort_by_field: &IndexSortByField,
) -> crate::Result<SegmentDocIdMapping> {
let sort_field_accessors = self.get_reader_with_sort_field_accessor(sort_by_field)?;
// Loading the field accessor on demand causes a 15x regression
let total_num_new_docs = self.total_num_new_docs();
let mut sorted_doc_ids: Vec<DocAddress> = Vec::with_capacity(total_num_new_docs);
// K-way merge of alive doc ids across segments, ordered by the sort field.
//
// Numeric: compare raw u64 column values directly.
// Str/Bytes: compare merged global ordinals obtained via `remapped_term_ord`.
// Documents without a value map to `None` — first in ascending, last in descending.
let asc = sort_by_field.order == Order::Asc;
match sort_field_accessors {
ReaderSortFieldAccessors::Numeric(reader_ordinal_and_field_accessors) => {
self.extend_sorted_doc_ids(
&reader_ordinal_and_field_accessors,
|a, b| {
// Column::first() returns Option<u64>: None for NULLs, Some for values.
// Option's Ord puts None < Some, giving NULL-first in ASC, NULL-last in
// DESC.
let val1 = a.2.first(a.0);
let val2 = b.2.first(b.0);
if asc {
val1 < val2
} else {
val1 > val2
}
},
&mut sorted_doc_ids,
);
}
ReaderSortFieldAccessors::StrBytes(reader_ordinal_and_field_accessors) => {
self.extend_sorted_doc_ids(
&reader_ordinal_and_field_accessors,
|a, b| {
let val1 = a.2.remapped_term_ord(a.0);
let val2 = b.2.remapped_term_ord(b.0);
if asc {
val1 < val2
} else {
val1 > val2
}
},
&mut sorted_doc_ids,
);
}
}
let alive_bitsets = self.collect_alive_bitsets();
Ok(SegmentDocIdMapping::new(
sorted_doc_ids,
MappingType::Shuffled,
alive_bitsets,
))
}
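Review note: the NULL handling in the k-way merge relies on `Option<u64>` ordering `None` before every `Some`. The sketch below replaces `kmerge_by` with a simple scan over segment heads and models each segment as a `Vec<Option<u64>>` that is already sorted in the requested direction; `kway_merge` and the `(segment, doc)` tuple output are assumptions for illustration.

```rust
/// Merges per-segment sorted columns into one global order of
/// (segment ordinal, doc id) pairs. `None` models a NULL sort value.
fn kway_merge(segments: &[Vec<Option<u64>>], asc: bool) -> Vec<(usize, usize)> {
    let mut cursors = vec![0usize; segments.len()];
    let total: usize = segments.iter().map(Vec::len).sum();
    let mut out = Vec::with_capacity(total);
    while out.len() < total {
        // Pick the segment whose head sorts first; ties keep the earlier segment.
        let mut best: Option<usize> = None;
        for (seg, &pos) in cursors.iter().enumerate() {
            if pos >= segments[seg].len() {
                continue;
            }
            let candidate = segments[seg][pos];
            let better = match best {
                None => true,
                Some(b) => {
                    let current = segments[b][cursors[b]];
                    // Option's Ord: None < Some(_), so NULL is first in ASC
                    // and last in DESC, and never equal to Some(0).
                    if asc { candidate < current } else { candidate > current }
                }
            };
            if better {
                best = Some(seg);
            }
        }
        let seg = best.expect("at least one segment has remaining docs");
        out.push((seg, cursors[seg]));
        cursors[seg] += 1;
    }
    out
}
```

With a `0u64` default instead of `Option`, a NULL and a real zero would compare equal and could interleave, which is the behavior the NULL-ordering fix in this range removes.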
/// Creates a mapping for the case where the segments are simply stacked. This makes it
/// possible to share the same merge code paths between sorted and unsorted indexes.
pub(crate) fn get_doc_id_from_concatenated_data(&self) -> crate::Result<SegmentDocIdMapping> {
let total_num_new_docs = self
.readers
.iter()
.map(|reader| reader.num_docs() as usize)
.sum();
let total_num_new_docs = self.total_num_new_docs();
let mut mapping: Vec<DocAddress> = Vec::with_capacity(total_num_new_docs);
@@ -262,20 +629,13 @@ impl IndexMerger {
}),
);
let has_deletes: bool = self.readers.iter().any(SegmentReader::has_deletes);
let has_deletes = self.readers.iter().any(SegmentReader::has_deletes);
let mapping_type = if has_deletes {
MappingType::StackedWithDeletes
} else {
MappingType::Stacked
};
let alive_bitsets: Vec<Option<ReadOnlyBitSet>> = self
.readers
.iter()
.map(|reader| {
let alive_bitset = reader.alive_bitset()?;
Some(alive_bitset.bitset().clone())
})
.collect();
let alive_bitsets = self.collect_alive_bitsets();
Ok(SegmentDocIdMapping::new(
mapping,
mapping_type,
@@ -356,6 +716,7 @@ impl IndexMerger {
);
let mut segment_postings_containing_the_term: Vec<(usize, SegmentPostings)> = vec![];
let mut doc_id_and_positions = vec![];
while merged_terms.advance() {
segment_postings_containing_the_term.clear();
@@ -451,13 +812,37 @@ impl IndexMerger {
0u32
};
let delta_positions = delta_computer.compute_delta(&positions_buffer);
field_serializer.write_doc(remapped_doc_id, term_freq, delta_positions);
// If the doc_id mapping is non-trivial, the doc_ids are reordered rather than
// just stacked. The field serializer expects monotonically increasing
// doc_ids, so we collect and sort them first, before writing.
//
// This is probably not strictly necessary: a kmerge could avoid loading the
// postings into a vec, but then the merge logic would deviate much more from
// the stacking case (unsorted index).
if !doc_id_mapping.is_trivial() {
doc_id_and_positions.push((
remapped_doc_id,
term_freq,
positions_buffer.to_vec(),
));
} else {
let delta_positions = delta_computer.compute_delta(&positions_buffer);
field_serializer.write_doc(remapped_doc_id, term_freq, delta_positions);
}
}
doc = segment_postings.advance();
}
}
if !doc_id_mapping.is_trivial() {
doc_id_and_positions.sort_unstable_by_key(|&(doc_id, _, _)| doc_id);
for (doc_id, term_freq, positions) in &doc_id_and_positions {
let delta_positions = delta_computer.compute_delta(positions);
field_serializer.write_doc(*doc_id, *term_freq, delta_positions);
}
doc_id_and_positions.clear();
}
// closing the term.
field_serializer.close_term()?;
}
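Review note: a small sketch of the buffer-then-sort step for one term, assuming postings are `(old_doc_id, term_freq)` pairs per segment and that `old_to_new[segment][old_doc_id]` gives the remapped doc id. Both the function `remap_and_sort_postings` and these table shapes are hypothetical; positions are left out to keep the example short.

```rust
/// Remaps the postings of a single term to new doc ids and returns them in
/// increasing doc id order, as a serializer expecting monotonic ids requires.
fn remap_and_sort_postings(
    postings_per_segment: &[Vec<(u32, u32)>],
    old_to_new: &[Vec<u32>],
) -> Vec<(u32, u32)> {
    let mut buffered: Vec<(u32, u32)> = Vec::new();
    for (seg, postings) in postings_per_segment.iter().enumerate() {
        for &(old_doc, term_freq) in postings {
            // After remapping, ids from different segments interleave freely.
            buffered.push((old_to_new[seg][old_doc as usize], term_freq));
        }
    }
    // New doc ids are unique, so an unstable sort is sufficient.
    buffered.sort_unstable_by_key(|&(doc, _)| doc);
    buffered
}
```

The cost is one temporary vector per term; the comment in the hunk notes that a kmerge could avoid it at the price of diverging from the stacking path.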
@@ -486,13 +871,47 @@ impl IndexMerger {
Ok(())
}
fn write_storable_fields(&self, store_writer: &mut StoreWriter) -> crate::Result<()> {
fn write_storable_fields(
&self,
store_writer: &mut StoreWriter,
doc_id_mapping: &SegmentDocIdMapping,
) -> crate::Result<()> {
debug_time!("write-storable-fields");
debug!("write-storable-field");
for reader in &self.readers {
let store_reader = reader.get_store_reader(1)?;
if reader.has_deletes()
if !doc_id_mapping.is_trivial() {
debug!("non-trivial-doc-id-mapping");
let store_readers: Vec<_> = self
.readers
.iter()
.map(|reader| reader.get_store_reader(50))
.collect::<Result<_, _>>()?;
let mut document_iterators: Vec<_> = store_readers
.iter()
.enumerate()
.map(|(i, store)| store.iter_raw(self.readers[i].alive_bitset()))
.collect();
for old_doc_addr in doc_id_mapping.iter_old_doc_addrs() {
let doc_bytes_it = &mut document_iterators[old_doc_addr.segment_ord as usize];
if let Some(doc_bytes_res) = doc_bytes_it.next() {
let doc_bytes = doc_bytes_res?;
store_writer.store_bytes(&doc_bytes)?;
} else {
return Err(DataCorruption::comment_only(format!(
"unexpected missing document in docstore on merge, doc address \
{old_doc_addr:?}",
))
.into());
}
}
} else {
debug!("trivial-doc-id-mapping");
for reader in &self.readers {
let store_reader = reader.get_store_reader(1)?;
if reader.has_deletes()
// If there is not enough data in the store, we avoid stacking in order to
// avoid creating many small blocks in the doc store. Once we have 5 full blocks,
// we start stacking. In the worst case 2/7 of the blocks would be very small.
@@ -508,13 +927,14 @@ impl IndexMerger {
// take 7 in order to not walk over all checkpoints.
|| store_reader.block_checkpoints().take(7).count() < 6
|| store_reader.decompressor() != store_writer.compressor().into()
{
for doc_bytes_res in store_reader.iter_raw(reader.alive_bitset()) {
let doc_bytes = doc_bytes_res?;
store_writer.store_bytes(&doc_bytes)?;
{
for doc_bytes_res in store_reader.iter_raw(reader.alive_bitset()) {
let doc_bytes = doc_bytes_res?;
store_writer.store_bytes(&doc_bytes)?;
}
} else {
store_writer.stack(store_reader)?;
}
} else {
store_writer.stack(store_reader)?;
}
}
Ok(())
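Review note: in the non-trivial branch above, each old doc address only selects a segment, and the document bytes come from calling `next()` on that segment's raw iterator. Read that way, the doc store path appears to assume that documents of the same segment keep their relative order in the mapping. This is an inference from the hunk, not a documented guarantee. A caller building a custom mapping could check the property with something like the following, where `keeps_segment_local_order` and the `(segment_ord, doc_id)` tuples are hypothetical stand-ins for `DocAddress` values.

```rust
/// Returns true if, for every segment, doc ids appear in strictly increasing
/// order within the mapping (entries of different segments may interleave).
fn keeps_segment_local_order(mapping: &[(u32, u32)], num_segments: usize) -> bool {
    let mut last_seen: Vec<Option<u32>> = vec![None; num_segments];
    for &(segment_ord, doc_id) in mapping {
        let slot = &mut last_seen[segment_ord as usize];
        if let Some(prev) = *slot {
            if doc_id <= prev {
                return false;
            }
        }
        *slot = Some(doc_id);
    }
    true
}
```

A mapping produced by a k-way merge of already sorted segments satisfies this by construction; an arbitrary shuffle may not.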
@@ -525,8 +945,42 @@ impl IndexMerger {
///
/// # Returns
/// The number of documents in the resulting segment.
pub fn write(&self, mut serializer: SegmentSerializer) -> crate::Result<u32> {
let doc_id_mapping = self.get_doc_id_from_concatenated_data()?;
pub fn write(&self, serializer: SegmentSerializer) -> crate::Result<u32> {
let doc_id_mapping = if let Some(sort_by_field) = self.index_settings.sort_by_field.as_ref()
{
if self.is_disjunct_and_sorted_on_sort_property(sort_by_field)? {
self.get_doc_id_from_concatenated_data()?
} else {
self.generate_doc_id_mapping_with_sort_by_field(sort_by_field)?
}
} else {
self.get_doc_id_from_concatenated_data()?
};
self.write_with_mapping(serializer, doc_id_mapping)
}
/// Like [`Self::write`], but uses the caller-supplied `doc_id_mapping` instead of
/// deriving one from an index sort field.
///
/// The mapping must cover *all* live documents across every segment passed to
/// [`IndexMerger::open_with_custom_alive_set`]. The simplest way to build one
/// is [`SegmentDocIdMapping::new_shuffled`].
///
/// # Returns
/// The number of documents in the resulting segment.
pub fn write_with_doc_id_mapping(
&self,
serializer: SegmentSerializer,
doc_id_mapping: SegmentDocIdMapping,
) -> crate::Result<u32> {
self.write_with_mapping(serializer, doc_id_mapping)
}
fn write_with_mapping(
&self,
mut serializer: SegmentSerializer,
doc_id_mapping: SegmentDocIdMapping,
) -> crate::Result<u32> {
debug!("write-fieldnorms");
if let Some(fieldnorms_serializer) = serializer.extract_fieldnorms_serializer() {
self.write_fieldnorms(fieldnorms_serializer, &doc_id_mapping)?;
@@ -543,7 +997,7 @@ impl IndexMerger {
)?;
debug!("write-storagefields");
self.write_storable_fields(serializer.get_store_writer())?;
self.write_storable_fields(serializer.get_store_writer(), &doc_id_mapping)?;
debug!("write-fastfields");
self.write_fast_fields(serializer.get_fast_field_write(), doc_id_mapping)?;
@@ -555,7 +1009,6 @@ impl IndexMerger {
#[cfg(test)]
mod tests {
use columnar::Column;
use proptest::prop_oneof;
use proptest::strategy::Strategy;
@@ -575,7 +1028,7 @@ mod tests {
use crate::time::OffsetDateTime;
use crate::{
assert_nearly_equals, schema, DateTime, DocAddress, DocId, DocSet, IndexSettings,
IndexWriter, Searcher,
IndexSortByField, IndexWriter, Order, Searcher,
};
#[test]
@@ -1048,6 +1501,60 @@ mod tests {
test_merge_facets(None, true)
}
#[test]
fn test_merge_facets_sort_asc() {
// In the merge case this will go through the doc_id mapping code
test_merge_facets(
Some(IndexSettings {
sort_by_field: Some(IndexSortByField {
field: "intval".to_string(),
order: Order::Asc,
}),
..Default::default()
}),
true,
);
// In the merge case this will not go through the doc_id mapping code, because the data
// is sorted and disjunct
test_merge_facets(
Some(IndexSettings {
sort_by_field: Some(IndexSortByField {
field: "intval".to_string(),
order: Order::Asc,
}),
..Default::default()
}),
false,
);
}
#[test]
fn test_merge_facets_sort_desc() {
// In the merge case this will go through the doc_id mapping code
test_merge_facets(
Some(IndexSettings {
sort_by_field: Some(IndexSortByField {
field: "intval".to_string(),
order: Order::Desc,
}),
..Default::default()
}),
true,
);
// In the merge case this will not go through the doc_id mapping code, because the data
// is sorted and disjunct
test_merge_facets(
Some(IndexSettings {
sort_by_field: Some(IndexSortByField {
field: "intval".to_string(),
order: Order::Desc,
}),
..Default::default()
}),
false,
);
}
// force_segment_value_overlap forces the int values used for sorting to have overlapping
// min and max ranges between segments, so that the merge algorithm can't apply certain
// optimizations
fn test_merge_facets(index_settings: Option<IndexSettings>, force_segment_value_overlap: bool) {

File diff suppressed because it is too large

@@ -8,17 +8,18 @@
pub(crate) mod delete_queue;
pub(crate) mod path_to_unordered_id;
pub(crate) mod doc_id_mapping;
pub mod doc_id_mapping;
mod doc_opstamp_mapping;
mod flat_map_with_buffer;
pub(crate) mod index_writer;
pub(crate) mod index_writer_status;
pub(crate) mod indexing_term;
mod log_merge_policy;
mod merge_index_test;
mod merge_operation;
pub(crate) mod merge_policy;
pub(crate) mod merger;
/// Segment merger: combines multiple segments into one.
pub mod merger;
mod merger_sorted_index_test;
pub(crate) mod operation;
pub(crate) mod prepared_commit;
mod segment_entry;
@@ -33,15 +34,19 @@ mod stamper;
use crossbeam_channel as channel;
use smallvec::SmallVec;
pub use self::doc_id_mapping::SegmentDocIdMapping;
pub use self::index_writer::{advance_deletes, IndexWriter, IndexWriterOptions};
pub use self::log_merge_policy::LogMergePolicy;
pub use self::merge_operation::MergeOperation;
pub use self::merge_policy::{MergeCandidate, MergePolicy, NoMergePolicy};
pub use self::merger::IndexMerger;
pub use self::operation::{AddOperation, DeleteOperation, UserOperation};
pub use self::prepared_commit::PreparedCommit;
pub use self::segment_entry::SegmentEntry;
pub(crate) use self::segment_serializer::SegmentSerializer;
pub use self::segment_updater::{merge_filtered_segments, merge_indices};
pub use self::segment_updater::{
merge_filtered_segments, merge_indices, merge_segments_with_doc_id_mapping,
};
pub use self::segment_writer::SegmentWriter;
pub use self::single_segment_index_writer::SingleSegmentIndexWriter;

@@ -18,9 +18,27 @@ pub struct SegmentSerializer {
impl SegmentSerializer {
/// Creates a new `SegmentSerializer`.
pub fn for_segment(mut segment: Segment) -> crate::Result<SegmentSerializer> {
pub fn for_segment(
mut segment: Segment,
is_in_merge: bool,
) -> crate::Result<SegmentSerializer> {
// If the segment is going to be sorted, we stream the docs first to a temporary file.
// In the merge case this is not necessary because we can kmerge the already sorted
// segments
let remapping_required = segment.index().settings().sort_by_field.is_some() && !is_in_merge;
let settings = segment.index().settings().clone();
let store_writer = {
let store_writer = if remapping_required {
let store_write = segment.open_write(SegmentComponent::TempStore)?;
StoreWriter::new(
store_write,
crate::store::Compressor::None,
// We want fast random access on the docs, so we choose a small block size.
// If this is zero, the skip index will contain too many checkpoints and
// therefore will be relatively slow.
16000,
settings.docstore_compress_dedicated_thread,
)?
} else {
let store_write = segment.open_write(SegmentComponent::Store)?;
StoreWriter::new(
store_write,
@@ -54,6 +72,10 @@ impl SegmentSerializer {
&self.segment
}
pub fn segment_mut(&mut self) -> &mut Segment {
&mut self.segment
}
/// Accessor to the `PostingsSerializer`.
pub fn get_postings_serializer(&mut self) -> &mut InvertedIndexSerializer {
&mut self.postings_serializer

@@ -15,6 +15,7 @@ use crate::directory::{Directory, DirectoryClone, GarbageCollectionResult};
use crate::fastfield::AliveBitSet;
use crate::index::{Index, IndexMeta, IndexSettings, Segment, SegmentId, SegmentMeta};
use crate::indexer::delete_queue::DeleteCursor;
use crate::indexer::doc_id_mapping::SegmentDocIdMapping;
use crate::indexer::index_writer::advance_deletes;
use crate::indexer::merge_operation::MergeOperationInventory;
use crate::indexer::merger::IndexMerger;
@@ -114,10 +115,11 @@ fn merge(
.collect();
// An IndexMerger is like a "view" of our merged segments.
let merger: IndexMerger = IndexMerger::open(index.schema(), &segments[..])?;
let merger: IndexMerger =
IndexMerger::open(index.schema(), index.settings().clone(), &segments[..])?;
// ... we just serialize this index merger in our new segment to merge the segments.
let segment_serializer = SegmentSerializer::for_segment(merged_segment.clone())?;
let segment_serializer = SegmentSerializer::for_segment(merged_segment.clone(), true)?;
let num_docs = merger.write(segment_serializer)?;
@@ -218,9 +220,13 @@ pub fn merge_filtered_segments<T: Into<Box<dyn Directory>>>(
)?;
let merged_segment = merged_index.new_segment();
let merged_segment_id = merged_segment.id();
let merger: IndexMerger =
IndexMerger::open_with_custom_alive_set(merged_index.schema(), segments, filter_doc_ids)?;
let segment_serializer = SegmentSerializer::for_segment(merged_segment)?;
let merger: IndexMerger = IndexMerger::open_with_custom_alive_set(
merged_index.schema(),
merged_index.settings().clone(),
segments,
filter_doc_ids,
)?;
let segment_serializer = SegmentSerializer::for_segment(merged_segment, true)?;
let num_docs = merger.write(segment_serializer)?;
let segment_meta = merged_index.new_segment_meta(merged_segment_id, num_docs);
@@ -250,6 +256,82 @@ pub fn merge_filtered_segments<T: Into<Box<dyn Directory>>>(
Ok(merged_index)
}
/// Like [`merge_filtered_segments`], but uses a caller-supplied [`SegmentDocIdMapping`]
/// to control the final document order. The mapping should be built from the same
/// segments (in the same order) passed here.
///
/// Use this to apply an external reordering during a merge without relying on a persistent fast field.
///
/// # Warning
/// Same caveats as [`merge_filtered_segments`]: no live `IndexWriter` allowed.
#[doc(hidden)]
pub fn merge_segments_with_doc_id_mapping<T: Into<Box<dyn Directory>>>(
segments: &[Segment],
target_settings: IndexSettings,
filter_doc_ids: Vec<Option<AliveBitSet>>,
doc_id_mapping: SegmentDocIdMapping,
output_directory: T,
) -> crate::Result<Index> {
if segments.is_empty() {
return Err(crate::TantivyError::InvalidArgument(
"No segments given to merge".to_string(),
));
}
let target_schema = segments[0].schema();
if segments
.iter()
.skip(1)
.any(|seg| seg.schema() != target_schema)
{
return Err(crate::TantivyError::InvalidArgument(
"Attempt to merge different schema indices".to_string(),
));
}
let mut merged_index = Index::create(
output_directory,
target_schema.clone(),
target_settings.clone(),
)?;
let merged_segment = merged_index.new_segment();
let merged_segment_id = merged_segment.id();
let merger: IndexMerger = IndexMerger::open_with_custom_alive_set(
merged_index.schema(),
merged_index.settings().clone(),
segments,
filter_doc_ids,
)?;
let segment_serializer = SegmentSerializer::for_segment(merged_segment, true)?;
let num_docs = merger.write_with_doc_id_mapping(segment_serializer, doc_id_mapping)?;
let segment_meta = merged_index.new_segment_meta(merged_segment_id, num_docs);
let stats = format!(
"Segments Merge (external reordering): [{}]",
segments
.iter()
.fold(String::new(), |sum, current| format!(
"{sum}{} ",
current.meta().id().uuid_string()
))
.trim_end()
);
let index_meta = IndexMeta {
index_settings: target_settings,
segments: vec![segment_meta],
schema: target_schema,
opstamp: 0u64,
payload: Some(stats),
};
save_metas(&index_meta, merged_index.directory_mut())?;
Ok(merged_index)
}
pub(crate) struct InnerSegmentUpdater {
// we keep a copy of the current active IndexMeta to
// avoid loading the file every time we need it in the
@@ -1115,6 +1197,7 @@ mod tests {
)?;
let merger: IndexMerger = IndexMerger::open_with_custom_alive_set(
merged_index.schema(),
merged_index.settings().clone(),
&segments[..],
filter_segments,
)?;
@@ -1130,6 +1213,7 @@ mod tests {
Index::create(RamDirectory::default(), target_schema, target_settings)?;
let merger: IndexMerger = IndexMerger::open_with_custom_alive_set(
merged_index.schema(),
merged_index.settings().clone(),
&segments[..],
filter_segments,
)?;


@@ -3,6 +3,7 @@ use common::JsonPathWriter;
use itertools::Itertools;
use tokenizer_api::BoxTokenStream;
use super::doc_id_mapping::{get_doc_id_mapping_from_field, DocIdMapping};
use super::operation::AddOperation;
use crate::fastfield::FastFieldsWriter;
use crate::fieldnorm::{FieldNormReaders, FieldNormsWriter};
@@ -16,6 +17,7 @@ use crate::postings::{
};
use crate::schema::document::{Document, Value};
use crate::schema::{FieldEntry, FieldType, Schema, DATE_TIME_PRECISION_INDEXED};
use crate::store::{StoreReader, StoreWriter};
use crate::tokenizer::{FacetTokenizer, PreTokenizedStream, TextAnalyzer, Tokenizer};
use crate::{DocId, Opstamp, TantivyError};
@@ -40,6 +42,20 @@ fn compute_initial_table_size(per_thread_memory_budget: usize) -> crate::Result<
})
}
fn remap_doc_opstamps(
opstamps: Vec<Opstamp>,
doc_id_mapping_opt: Option<&DocIdMapping>,
) -> Vec<Opstamp> {
if let Some(doc_id_mapping) = doc_id_mapping_opt {
doc_id_mapping
.iter_old_doc_ids()
.map(|doc| opstamps[doc as usize])
.collect()
} else {
opstamps
}
}
/// A `SegmentWriter` is in charge of creating a segment index from a
/// set of documents.
///
@@ -75,7 +91,7 @@ impl SegmentWriter {
let tokenizer_manager = segment.index().tokenizers().clone();
let tokenizer_manager_fast_field = segment.index().fast_field_tokenizer().clone();
let table_size = compute_initial_table_size(memory_budget_in_bytes)?;
let segment_serializer = SegmentSerializer::for_segment(segment)?;
let segment_serializer = SegmentSerializer::for_segment(segment, false)?;
let per_field_postings_writers = PerFieldPostingsWriter::for_schema(&schema);
let per_field_text_analyzers = schema
.fields()
@@ -124,6 +140,15 @@ impl SegmentWriter {
/// be used afterwards.
pub fn finalize(mut self) -> crate::Result<Vec<u64>> {
self.fieldnorms_writer.fill_up_to_max_doc(self.max_doc);
let mapping: Option<DocIdMapping> = self
.segment_serializer
.segment()
.index()
.settings()
.sort_by_field
.clone()
.map(|sort_by_field| get_doc_id_mapping_from_field(sort_by_field, &self))
.transpose()?;
remap_and_write(
self.schema,
&self.per_field_postings_writers,
@@ -131,8 +156,10 @@ impl SegmentWriter {
self.fast_field_writers,
&self.fieldnorms_writer,
self.segment_serializer,
mapping.as_ref(),
)?;
Ok(self.doc_opstamps)
let doc_opstamps = remap_doc_opstamps(self.doc_opstamps, mapping.as_ref());
Ok(doc_opstamps)
}
/// Returns an estimation of the current memory usage of the segment writer.
@@ -393,10 +420,11 @@ fn remap_and_write(
fast_field_writers: FastFieldsWriter,
fieldnorms_writer: &FieldNormsWriter,
mut serializer: SegmentSerializer,
doc_id_map: Option<&DocIdMapping>,
) -> crate::Result<()> {
debug!("remap-and-write");
if let Some(fieldnorms_serializer) = serializer.extract_fieldnorms_serializer() {
fieldnorms_writer.serialize(fieldnorms_serializer)?;
fieldnorms_writer.serialize(fieldnorms_serializer, doc_id_map)?;
}
let fieldnorm_data = serializer
.segment()
@@ -407,10 +435,39 @@ fn remap_and_write(
schema,
per_field_postings_writers,
fieldnorm_readers,
doc_id_map,
serializer.get_postings_serializer(),
)?;
debug!("fastfield-serialize");
fast_field_writers.serialize(serializer.get_fast_field_write())?;
fast_field_writers.serialize(serializer.get_fast_field_write(), doc_id_map)?;
// Finalize the temp docstore and write a new docstore whose order reflects the doc_id_map.
if let Some(doc_id_map) = doc_id_map {
debug!("resort-docstore");
let store_write = serializer
.segment_mut()
.open_write(SegmentComponent::Store)?;
let settings = serializer.segment().index().settings();
let store_writer = StoreWriter::new(
store_write,
settings.docstore_compression,
settings.docstore_blocksize,
settings.docstore_compress_dedicated_thread,
)?;
let old_store_writer = std::mem::replace(&mut serializer.store_writer, store_writer);
old_store_writer.close()?;
let store_read = StoreReader::open(
serializer
.segment()
.open_read(SegmentComponent::TempStore)?,
1, /* The docstore is configured to have one doc per block, and each doc is
* accessed only once: we don't need caching. */
)?;
for old_doc_id in doc_id_map.iter_old_doc_ids() {
let doc_bytes = store_read.get_document_bytes(old_doc_id)?;
serializer.get_store_writer().store_bytes(&doc_bytes)?;
}
}
debug!("serializer-close");
serializer.close()?;
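Both `remap_doc_opstamps` and the docstore rewrite above walk the old doc ids in new-doc-id order and copy the matching value. A minimal sketch of that idea, assuming a simplified permutation type: `DocIdPermutation` and `remap_column` are hypothetical names used only for illustration and are not tantivy's `DocIdMapping`.

```rust
/// Hypothetical, simplified stand-in for a doc id mapping: a permutation
/// of doc ids stored in both directions.
struct DocIdPermutation {
    /// new_to_old[new_doc_id] == old_doc_id
    new_to_old: Vec<u32>,
    /// old_to_new[old_doc_id] == new_doc_id
    old_to_new: Vec<u32>,
}

impl DocIdPermutation {
    /// Builds the permutation from the old doc ids listed in their new order.
    /// Assumes `new_to_old` is a valid permutation of `0..len`.
    fn from_new_to_old(new_to_old: Vec<u32>) -> Self {
        let mut old_to_new = vec![0u32; new_to_old.len()];
        for (new_doc_id, &old_doc_id) in new_to_old.iter().enumerate() {
            old_to_new[old_doc_id as usize] = new_doc_id as u32;
        }
        Self { new_to_old, old_to_new }
    }

    /// Old doc ids, yielded in the order of the new doc ids.
    fn iter_old_doc_ids(&self) -> impl Iterator<Item = u32> + '_ {
        self.new_to_old.iter().copied()
    }

    /// New doc id assigned to `old_doc_id`.
    fn get_new_doc_id(&self, old_doc_id: u32) -> u32 {
        self.old_to_new[old_doc_id as usize]
    }
}

/// Reorders a per-document column (opstamps, stored bytes, ...) so that
/// position `new_doc_id` holds the value of the matching old document.
fn remap_column<T: Clone>(column: &[T], mapping: &DocIdPermutation) -> Vec<T> {
    mapping
        .iter_old_doc_ids()
        .map(|old_doc_id| column[old_doc_id as usize].clone())
        .collect()
}
```

With `new_to_old = [1, 2, 0]`, the column `[100, 101, 102]` becomes `[101, 102, 100]`, and `get_new_doc_id(0)` returns `2`.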


@@ -226,10 +226,13 @@ pub use self::docset::{DocSet, COLLECT_BLOCK_BUFFER_LEN, TERMINATED};
pub use crate::core::{json_utils, Executor, Searcher, SearcherGeneration};
pub use crate::directory::Directory;
pub use crate::index::{
Index, IndexBuilder, IndexMeta, IndexSettings, InvertedIndexReader, Order, Segment,
SegmentMeta, SegmentReader,
Index, IndexBuilder, IndexMeta, IndexSettings, IndexSortByField, InvertedIndexReader, Order,
Segment, SegmentMeta, SegmentReader,
};
pub use crate::indexer::{
IndexMerger, IndexWriter, SegmentDocIdMapping, SingleSegmentIndexWriter,
merge_segments_with_doc_id_mapping,
};
pub use crate::indexer::{IndexWriter, SingleSegmentIndexWriter};
pub use crate::schema::{Document, TantivyDocument, Term};
/// Index format version.


@@ -249,6 +249,12 @@ impl BlockSegmentPostings {
/// Returns the decoded term-frequency buffer for the current block.
#[inline]
pub(crate) fn freq_output_array(&self) -> &[u32] {
self.freq_decoder.output_array()
}
/// Returns the length of the current block.
///
/// All blocks have a length of `NUM_DOCS_PER_BLOCK`,
/// except the last block that may have a length
/// of any number between 1 and `NUM_DOCS_PER_BLOCK - 1`
@@ -281,6 +287,33 @@ impl BlockSegmentPostings {
doc
}
/// Returns the number of documents with a doc id strictly smaller than `target`
/// (i.e. the *rank* of `target` in this posting list).
///
/// This jumps to the block that may contain `target` through the skip list, so no
/// skipped block is decoded; a single block is then decoded to locate `target`
/// within it. The cost is therefore `O(number_of_skip_list_entries)` plus one block
/// decode, rather than `O(doc_freq)`.
///
/// Like [`Self::seek`], the underlying cursor only ever moves forward. This method
/// must be called with **non-decreasing** `target` values (galloping); calling it
/// with a `target` smaller than a previous one yields an incorrect result. `target`
/// must be a valid doc id (i.e. `target <= TERMINATED`), exactly as for `seek`.
///
/// Edge cases: returns `0` when `target` is smaller than every doc id, and
/// `doc_freq()` when `target` is larger than every doc id.
pub fn rank(&mut self, target: DocId) -> u32 {
if self.doc_freq == 0 {
return 0;
}
// `within` = number of docs in the landed block with a doc id < target.
let within = self.seek(target);
// `remaining_docs` counts the landed block and everything after it, so the
// difference is the number of docs in all blocks strictly before it.
let docs_before_block = self.doc_freq - self.skip_reader.remaining_docs();
docs_before_block + within as u32
}
pub(crate) fn position_offset(&self) -> u64 {
self.skip_reader.position_offset()
}
@@ -298,6 +331,11 @@ impl BlockSegmentPostings {
}
}
#[inline]
pub(crate) fn has_remaining_docs(&self) -> bool {
self.skip_reader.has_remaining_docs()
}
pub(crate) fn block_is_loaded(&self) -> bool {
self.block_loaded
}
@@ -557,4 +595,38 @@ mod tests {
assert_eq!(block_segments.docs(), &[1, 3, 5]);
Ok(())
}
#[test]
fn test_block_segment_postings_rank() -> crate::Result<()> {
// ~8 blocks worth of docs so the skip list is actually exercised.
let docs: Vec<DocId> = (0..1000u32).map(|i| i * 3).collect();
let mut block_postings = build_block_postings(&docs[..])?;
let doc_freq = block_postings.doc_freq();
// rank(target) must equal the number of docs strictly below target.
// As the API requires, targets are valid doc ids (<= TERMINATED) queried in
// non-decreasing order.
let targets = [
0u32, 1, 2, 3, 4, 299, 300, 301, 1500, 2996, 2997, 3000, 10_000,
];
for &target in &targets {
let expected = docs.iter().filter(|&&d| d < target).count() as u32;
assert_eq!(
block_postings.rank(target),
expected,
"rank({target}) mismatch"
);
}
// Edge cases: below the first doc -> 0, above the last doc -> doc_freq.
let mut fresh = build_block_postings(&docs[..])?;
assert_eq!(fresh.rank(0), 0);
let mut fresh = build_block_postings(&docs[..])?;
assert_eq!(fresh.rank(1_000_000), doc_freq);
// Empty postings: rank is always 0.
let mut empty = BlockSegmentPostings::empty();
assert_eq!(empty.rank(42), 0);
Ok(())
}
}
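The `rank` method above combines a skip-list jump with a search inside one block. A stateless sketch of that arithmetic, under loud assumptions: `BlockPostings`, `BLOCK_SIZE = 4`, and the uncompressed `docs` vector are hypothetical simplifications. The real cursor only moves forward and derives the prefix count from `remaining_docs()`, whereas this sketch binary-searches the skip list from scratch on every call.

```rust
/// Hypothetical block size, kept tiny for readability.
const BLOCK_SIZE: usize = 4;

/// Simplified stand-in for a block-encoded posting list with a skip list.
struct BlockPostings {
    /// Sorted doc ids, conceptually split into chunks of `BLOCK_SIZE`.
    docs: Vec<u32>,
    /// Skip list: the last doc id of each block.
    last_doc_in_block: Vec<u32>,
}

impl BlockPostings {
    fn new(docs: Vec<u32>) -> Self {
        let last_doc_in_block: Vec<u32> = docs
            .chunks(BLOCK_SIZE)
            .map(|block| *block.last().unwrap())
            .collect();
        Self { docs, last_doc_in_block }
    }

    /// Number of docs with a doc id strictly smaller than `target`.
    fn rank(&self, target: u32) -> u32 {
        // Skip-list step: first block whose last doc id is >= target.
        let block_idx = self
            .last_doc_in_block
            .partition_point(|&last| last < target);
        if block_idx == self.last_doc_in_block.len() {
            // `target` is larger than every doc id (or the list is empty).
            return self.docs.len() as u32;
        }
        // Every earlier block is full, so the prefix count is exact.
        let docs_before_block = block_idx * BLOCK_SIZE;
        let block_end = (docs_before_block + BLOCK_SIZE).min(self.docs.len());
        // In-block step: count the docs of this block that are < target.
        let within = self.docs[docs_before_block..block_end]
            .partition_point(|&doc| doc < target);
        (docs_before_block + within) as u32
    }
}
```

Only one block is inspected per call; the blocks before it contribute through `docs_before_block` without being read.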


@@ -3,6 +3,7 @@ use std::io;
use common::json_path_writer::JSON_END_OF_PATH;
use stacker::Addr;
use crate::indexer::doc_id_mapping::DocIdMapping;
use crate::indexer::indexing_term::IndexingTerm;
use crate::indexer::path_to_unordered_id::OrderedPathId;
use crate::postings::postings_writer::SpecializedPostingsWriter;
@@ -62,6 +63,7 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
&self,
ordered_term_addrs: &[(Field, OrderedPathId, &[u8], Addr)],
ordered_id_to_path: &[&str],
doc_id_map: Option<&DocIdMapping>,
ctx: &IndexingContext,
serializer: &mut FieldSerializer,
) -> io::Result<()> {
@@ -84,6 +86,7 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
SpecializedPostingsWriter::<Rec>::serialize_one_term(
term_buffer.as_bytes(),
*addr,
doc_id_map,
&mut buffer_lender,
ctx,
serializer,
@@ -92,6 +95,7 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
SpecializedPostingsWriter::<DocIdRecorder>::serialize_one_term(
term_buffer.as_bytes(),
*addr,
doc_id_map,
&mut buffer_lender,
ctx,
serializer,


@@ -5,6 +5,7 @@ use std::ops::Range;
use stacker::Addr;
use crate::fieldnorm::FieldNormReaders;
use crate::indexer::doc_id_mapping::DocIdMapping;
use crate::indexer::indexing_term::IndexingTerm;
use crate::indexer::path_to_unordered_id::OrderedPathId;
use crate::postings::recorder::{BufferLender, Recorder};
@@ -50,6 +51,7 @@ pub(crate) fn serialize_postings(
schema: Schema,
per_field_postings_writers: &PerFieldPostingsWriter,
fieldnorm_readers: FieldNormReaders,
doc_id_map: Option<&DocIdMapping>,
serializer: &mut InvertedIndexSerializer,
) -> crate::Result<()> {
// Replace unordered ids by ordered ids to be able to sort
@@ -85,6 +87,7 @@ pub(crate) fn serialize_postings(
postings_writer.serialize(
&term_offsets[byte_offsets],
&ordered_id_to_path,
doc_id_map,
&ctx,
&mut field_serializer,
)?;
@@ -120,6 +123,7 @@ pub(crate) trait PostingsWriter: Send + Sync {
&self,
term_addrs: &[(Field, OrderedPathId, &[u8], Addr)],
ordered_id_to_path: &[&str],
doc_id_map: Option<&DocIdMapping>,
ctx: &IndexingContext,
serializer: &mut FieldSerializer,
) -> io::Result<()>;
@@ -184,6 +188,7 @@ impl<Rec: Recorder> SpecializedPostingsWriter<Rec> {
pub(crate) fn serialize_one_term(
term: &[u8],
addr: Addr,
doc_id_map: Option<&DocIdMapping>,
buffer_lender: &mut BufferLender,
ctx: &IndexingContext,
serializer: &mut FieldSerializer,
@@ -191,7 +196,7 @@ impl<Rec: Recorder> SpecializedPostingsWriter<Rec> {
let recorder: Rec = ctx.term_index.read(addr);
let term_doc_freq = recorder.term_doc_freq().unwrap_or(0u32);
serializer.new_term(term, term_doc_freq, recorder.has_term_freq())?;
recorder.serialize(&ctx.arena, serializer, buffer_lender);
recorder.serialize(&ctx.arena, doc_id_map, serializer, buffer_lender);
serializer.close_term()?;
Ok(())
}
@@ -231,12 +236,13 @@ impl<Rec: Recorder> PostingsWriter for SpecializedPostingsWriter<Rec> {
&self,
term_addrs: &[(Field, OrderedPathId, &[u8], Addr)],
_ordered_id_to_path: &[&str],
doc_id_map: Option<&DocIdMapping>,
ctx: &IndexingContext,
serializer: &mut FieldSerializer,
) -> io::Result<()> {
let mut buffer_lender = BufferLender::default();
for (_field, _path_id, term, addr) in term_addrs {
Self::serialize_one_term(term, *addr, &mut buffer_lender, ctx, serializer)?;
Self::serialize_one_term(term, *addr, doc_id_map, &mut buffer_lender, ctx, serializer)?;
}
Ok(())
}


@@ -1,6 +1,7 @@
use common::read_u32_vint;
use stacker::{ExpUnrolledLinkedList, MemoryArena};
use crate::indexer::doc_id_mapping::DocIdMapping;
use crate::postings::FieldSerializer;
use crate::DocId;
@@ -70,6 +71,7 @@ pub(crate) trait Recorder: Copy + Default + Send + Sync + 'static {
fn serialize(
&self,
arena: &MemoryArena,
doc_id_map: Option<&DocIdMapping>,
serializer: &mut FieldSerializer<'_>,
buffer_lender: &mut BufferLender,
);
@@ -113,15 +115,26 @@ impl Recorder for DocIdRecorder {
fn serialize(
&self,
arena: &MemoryArena,
doc_id_map: Option<&DocIdMapping>,
serializer: &mut FieldSerializer<'_>,
buffer_lender: &mut BufferLender,
) {
let buffer = buffer_lender.lend_u8();
let (buffer, doc_ids) = buffer_lender.lend_all();
// TODO avoid reading twice.
self.stack.read_to_end(arena, buffer);
let iter = get_sum_reader(VInt32Reader::new(&buffer[..]));
for doc_id in iter {
serializer.write_doc(doc_id, 0u32, &[][..]);
if let Some(doc_id_map) = doc_id_map {
let iter = get_sum_reader(VInt32Reader::new(&buffer[..]));
doc_ids.extend(iter.map(|old_doc_id| doc_id_map.get_new_doc_id(old_doc_id)));
doc_ids.sort_unstable();
for doc in doc_ids {
serializer.write_doc(*doc, 0u32, &[][..]);
}
} else {
let iter = get_sum_reader(VInt32Reader::new(&buffer[..]));
for doc_id in iter {
serializer.write_doc(doc_id, 0u32, &[][..]);
}
}
}
@@ -181,18 +194,35 @@ impl Recorder for TermFrequencyRecorder {
fn serialize(
&self,
arena: &MemoryArena,
doc_id_map: Option<&DocIdMapping>,
serializer: &mut FieldSerializer<'_>,
buffer_lender: &mut BufferLender,
) {
let buffer = buffer_lender.lend_u8();
self.stack.read_to_end(arena, buffer);
let mut u32_it = VInt32Reader::new(&buffer[..]);
let mut prev_doc = 0;
while let Some(delta_doc_id) = u32_it.next() {
let doc_id = prev_doc + delta_doc_id;
prev_doc = doc_id;
let term_freq = u32_it.next().unwrap_or(self.current_tf);
serializer.write_doc(doc_id, term_freq, &[][..]);
if let Some(doc_id_map) = doc_id_map {
let mut doc_id_and_tf = vec![];
let mut prev_doc = 0;
while let Some(delta_doc_id) = u32_it.next() {
let doc_id = prev_doc + delta_doc_id;
prev_doc = doc_id;
let term_freq = u32_it.next().unwrap_or(self.current_tf);
doc_id_and_tf.push((doc_id_map.get_new_doc_id(doc_id), term_freq));
}
doc_id_and_tf.sort_unstable_by_key(|&(doc_id, _)| doc_id);
for (doc_id, tf) in doc_id_and_tf {
serializer.write_doc(doc_id, tf, &[][..]);
}
} else {
let mut prev_doc = 0;
while let Some(delta_doc_id) = u32_it.next() {
let doc_id = prev_doc + delta_doc_id;
prev_doc = doc_id;
let term_freq = u32_it.next().unwrap_or(self.current_tf);
serializer.write_doc(doc_id, term_freq, &[][..]);
}
}
}
@@ -238,12 +268,14 @@ impl Recorder for TfAndPositionRecorder {
fn serialize(
&self,
arena: &MemoryArena,
doc_id_map: Option<&DocIdMapping>,
serializer: &mut FieldSerializer<'_>,
buffer_lender: &mut BufferLender,
) {
let (buffer_u8, buffer_positions) = buffer_lender.lend_all();
self.stack.read_to_end(arena, buffer_u8);
let mut u32_it = VInt32Reader::new(&buffer_u8[..]);
let mut doc_id_and_positions = vec![];
let mut prev_doc = 0;
while let Some(delta_doc_id) = u32_it.next() {
let doc_id = prev_doc + delta_doc_id;
@@ -262,7 +294,19 @@ impl Recorder for TfAndPositionRecorder {
}
}
}
serializer.write_doc(doc_id, buffer_positions.len() as u32, buffer_positions);
if let Some(doc_id_map) = doc_id_map {
// this simple variant to remap may consume too much memory
doc_id_and_positions
.push((doc_id_map.get_new_doc_id(doc_id), buffer_positions.to_vec()));
} else {
serializer.write_doc(doc_id, buffer_positions.len() as u32, buffer_positions);
}
}
if doc_id_map.is_some() {
doc_id_and_positions.sort_unstable_by_key(|&(doc_id, _)| doc_id);
for (doc_id, positions) in doc_id_and_positions {
serializer.write_doc(doc_id, positions.len() as u32, &positions);
}
}
}
@@ -275,8 +319,9 @@ impl Recorder for TfAndPositionRecorder {
mod tests {
use common::write_u32_vint;
use stacker::MemoryArena;
use super::{BufferLender, VInt32Reader};
use super::{BufferLender, Recorder, TermFrequencyRecorder, VInt32Reader};
#[test]
fn test_buffer_lender() {
@@ -314,4 +359,98 @@ mod tests {
let res: Vec<u32> = VInt32Reader::new(&buffer[..]).collect();
assert_eq!(&res[..], &vals[..]);
}
// ── TermFrequencyRecorder ─────────────────────────────────────────────────
#[test]
fn term_frequency_recorder_has_term_freq() {
let rec = TermFrequencyRecorder::default();
assert!(
rec.has_term_freq(),
"TermFrequencyRecorder must advertise term-frequency support"
);
}
#[test]
fn term_frequency_recorder_term_doc_freq_single_doc() {
let mut arena = MemoryArena::default();
let mut rec = TermFrequencyRecorder::default();
// Record one document with two term occurrences.
rec.new_doc(0, &mut arena);
rec.record_position(0, &mut arena);
rec.record_position(1, &mut arena);
rec.close_doc(&mut arena);
assert_eq!(
rec.term_doc_freq(),
Some(1),
"term_doc_freq should be 1 after recording one document"
);
}
#[test]
fn term_frequency_recorder_term_doc_freq_multiple_docs() {
let mut arena = MemoryArena::default();
let mut rec = TermFrequencyRecorder::default();
// Three documents with 1, 3, and 2 occurrences respectively.
for (doc, tf) in [(0u32, 1u32), (5, 3), (10, 2)] {
rec.new_doc(doc, &mut arena);
for pos in 0..tf {
rec.record_position(pos, &mut arena);
}
rec.close_doc(&mut arena);
}
assert_eq!(
rec.term_doc_freq(),
Some(3),
"term_doc_freq should equal the number of documents recorded"
);
}
#[test]
fn term_frequency_recorder_zero_docs() {
let rec = TermFrequencyRecorder::default();
assert_eq!(
rec.term_doc_freq(),
Some(0),
"term_doc_freq should be 0 before any document is recorded"
);
}
#[test]
fn term_frequency_recorder_single_occurrence_per_doc() {
let mut arena = MemoryArena::default();
let mut rec = TermFrequencyRecorder::default();
// Each document has exactly one occurrence — the minimum non-trivial case.
for doc in [1u32, 2, 100] {
rec.new_doc(doc, &mut arena);
rec.record_position(0, &mut arena);
rec.close_doc(&mut arena);
}
assert_eq!(rec.term_doc_freq(), Some(3));
}
#[test]
fn term_frequency_recorder_high_frequency_doc() {
let mut arena = MemoryArena::default();
let mut rec = TermFrequencyRecorder::default();
// A document where the term appears many times.
rec.new_doc(42, &mut arena);
for pos in 0..1000 {
rec.record_position(pos, &mut arena);
}
rec.close_doc(&mut arena);
assert_eq!(
rec.term_doc_freq(),
Some(1),
"term_doc_freq counts documents, not occurrences"
);
}
}
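The remapped branches of the recorders above follow one pattern: decode the delta-encoded old doc ids, translate each to its new doc id, then sort so the serializer still receives doc ids in increasing order. A small sketch of that pattern, assuming a plain `old_to_new` slice in place of `DocIdMapping::get_new_doc_id`; the function names are hypothetical.

```rust
/// Decodes delta-encoded doc ids: each entry is the gap to the previous doc,
/// starting from 0.
fn decode_deltas(deltas: &[u32]) -> Vec<u32> {
    let mut prev_doc = 0u32;
    deltas
        .iter()
        .map(|&delta| {
            prev_doc += delta;
            prev_doc
        })
        .collect()
}

/// Returns `(new_doc_id, term_freq)` pairs sorted by new doc id.
/// `old_to_new[old_doc_id]` is assumed to hold the new doc id.
fn remap_postings(deltas: &[u32], term_freqs: &[u32], old_to_new: &[u32]) -> Vec<(u32, u32)> {
    let mut doc_id_and_tf: Vec<(u32, u32)> = decode_deltas(deltas)
        .into_iter()
        .zip(term_freqs.iter().copied())
        .map(|(old_doc_id, tf)| (old_to_new[old_doc_id as usize], tf))
        .collect();
    // The remapping can scramble the order, so restore the sorted invariant.
    doc_id_and_tf.sort_unstable_by_key(|&(doc_id, _)| doc_id);
    doc_id_and_tf
}
```

The sort is what makes the remapped path more expensive than the pass-through path: the whole posting list of a term has to be materialized before anything is written.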


@@ -146,6 +146,11 @@ impl SkipReader {
skip_reader
}
#[inline(always)]
pub fn has_remaining_docs(&self) -> bool {
self.remaining_docs != 0
}
pub fn reset(&mut self, data: OwnedBytes, doc_freq: u32) {
self.last_doc_in_block = if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 {
0
@@ -182,6 +187,12 @@ impl SkipReader {
self.last_doc_in_block
}
/// Number of docs from the start of the current block to the end of the postings
/// (i.e. the current block plus every block after it).
pub(crate) fn remaining_docs(&self) -> u32 {
self.remaining_docs
}
pub fn position_offset(&self) -> u64 {
self.position_offset
}


@@ -0,0 +1,464 @@
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
use crate::query::term_query::TermScorer;
use crate::query::Scorer;
use crate::{DocId, DocSet, Score, TERMINATED};
/// Block-max pruning for top-K over intersection of term scorers.
///
/// Uses the least-frequent term as "leader" to define processing windows of at most
/// 128 docs. For each window, the sum of block_max_scores is compared to the current
/// threshold; if the window can't beat it, the entire window is skipped.
///
/// Within non-skipped blocks, individual documents are pruned by checking whether
/// leader_score + sum(secondary block_max_scores) can exceed the threshold before
/// performing the expensive intersection membership check (seeking into secondary scorers).
///
/// # Preconditions
/// - `scorers` has at least 2 elements
/// - All scorers read frequencies (`FreqReadingOption::ReadFreq`)
pub(crate) fn block_wand_intersection(
mut scorers: Vec<TermScorer>,
mut threshold: Score,
callback: &mut dyn FnMut(DocId, Score) -> Score,
) {
assert!(scorers.len() >= 2);
// Sort by size hint (ascending). scorers[0] becomes the "leader" (rarest term).
scorers.sort_by_key(TermScorer::size_hint);
let (leader, secondaries) = scorers.split_first_mut().unwrap();
// Precompute global max scores for early termination checks.
let leader_max_score: Score = leader.max_score();
let secondaries_global_max_sum: Score = secondaries.iter().map(TermScorer::max_score).sum();
// Early exit: no document can possibly beat the threshold.
if leader_max_score + secondaries_global_max_sum <= threshold {
return;
}
// Clone the fieldnorm reader and BM25 weight before the main loop.
// They live in fields disjoint from block_cursor, but Rust's borrow
// checker can't see through method calls, so we take owned copies
// once upfront instead of holding references.
let fieldnorm_reader = leader.fieldnorm_reader().clone();
let bm25_weight = leader.bm25_weight().clone();
let mut doc = leader.doc();
let mut secondary_block_max_scores: Box<[f32]> =
vec![0.0f32; secondaries.len()].into_boxed_slice();
let mut secondary_suffix_block_max: Box<[f32]> =
vec![0.0f32; secondaries.len()].into_boxed_slice();
while doc < TERMINATED {
// --- Phase 1: Block-level pruning ---
//
// Position all skip readers on the block containing `doc`.
// seek_block is cheap: it only advances the skip reader, no block decompression.
leader.seek_block(doc);
let leader_block_max: Score = leader.block_max_score();
// Compute the window end as the minimum last_doc_in_block across all scorers.
// This ensures the block_max values are valid for all docs in [doc, window_end].
// Different scorers have independently aligned blocks, so we must use the
// smallest window where all block_max values hold.
let mut window_end: DocId = leader.last_doc_in_block();
let mut secondary_block_max_sum: Score = 0.0;
let num_secondaries = secondaries.len();
for (idx, secondary) in secondaries.iter_mut().enumerate() {
secondary.block_cursor().seek_block(doc);
if !secondary.block_cursor().has_remaining_docs() {
return;
}
window_end = window_end.min(secondary.last_doc_in_block());
let bms = secondary.block_max_score();
secondary_block_max_scores[idx] = bms;
secondary_block_max_sum += bms;
}
if leader_block_max + secondary_block_max_sum <= threshold {
// The entire window cannot beat the threshold. Skip past it.
doc = window_end + 1;
continue;
}
// --- Phase 2: Batch processing within the window ---
//
// Score-first approach: decode the leader's block, filter by threshold,
// then check intersection membership only for survivors. This avoids expensive
// secondary seeks for docs that can't beat the threshold.
let block_cursor = leader.block_cursor();
// seek loads the block and returns the in-block index of the first doc >= `doc`.
let start_idx = block_cursor.seek(doc);
// Use the branchless binary search on the doc decoder to find the first
// index past window_end.
let end_idx = block_cursor
.doc_decoder
.seek_within_block(window_end + 1)
.min(block_cursor.block_len());
let block_docs = &block_cursor.doc_decoder.output_array()[start_idx..end_idx];
let block_freqs = &block_cursor.freq_output_array()[start_idx..end_idx];
// Pass 1: Batch-compute leader BM25 scores and branchlessly filter
// candidates that can't beat the threshold.
//
// The trick: always write to the buffer at `num_candidates`, then
// conditionally advance the count. The compiler can turn this into
// a cmov instead of a branch, avoiding misprediction costs.
let score_threshold = threshold - secondary_block_max_sum;
let mut candidate_doc_ids = [0u32; COMPRESSION_BLOCK_SIZE];
let mut candidate_scores = [0.0f32; COMPRESSION_BLOCK_SIZE];
let mut num_candidates = 0usize;
for (candidate_doc, term_freq) in
block_docs.iter().copied().zip(block_freqs.iter().copied())
{
let fieldnorm_id = fieldnorm_reader.fieldnorm_id(candidate_doc);
let leader_score = bm25_weight.score(fieldnorm_id, term_freq);
candidate_doc_ids[num_candidates] = candidate_doc;
candidate_scores[num_candidates] = leader_score;
num_candidates += (leader_score > score_threshold) as usize;
}
if num_candidates == 0 {
doc = window_end + 1;
continue;
}
// Precompute suffix sums: suffix[i] = sum of block_max for secondaries[i+1..].
// Used in Pass 2 to prune candidates that can't beat the threshold even with
// the remaining secondaries contributing their block_max.
let mut running = 0.0f32;
for idx in (0..num_secondaries).rev() {
secondary_suffix_block_max[idx] = running;
running += secondary_block_max_scores[idx];
}
// Pass 2: Check intersection membership only for survivors.
// score_threshold may be stale (threshold can increase from callbacks),
// but that's conservative — we may check a few extra candidates, never miss one.
'next_candidate: for candidate_idx in 0..num_candidates {
let candidate_doc = candidate_doc_ids[candidate_idx];
let mut total_score: Score = candidate_scores[candidate_idx];
for (secondary_idx, secondary) in secondaries.iter_mut().enumerate() {
// If a previous candidate already advanced this secondary past
// candidate_doc, the candidate can't be in the intersection.
if secondary.doc() > candidate_doc {
continue 'next_candidate;
}
let seek_result = secondary.seek(candidate_doc);
if seek_result != candidate_doc {
continue 'next_candidate;
}
total_score += secondary.score();
// Prune: even if all remaining secondaries score at their block max,
// can we still beat the threshold?
if total_score + secondary_suffix_block_max[secondary_idx] <= threshold {
continue 'next_candidate;
}
}
// All secondaries matched.
if total_score > threshold {
threshold = callback(candidate_doc, total_score);
if leader_max_score + secondaries_global_max_sum <= threshold {
return;
}
}
}
doc = window_end + 1;
}
}
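Two small techniques carry most of the pruning in `block_wand_intersection`: the branchless candidate filter of Pass 1 and the suffix sums of block-max scores used in Pass 2. The sketch below isolates both on plain slices; `filter_candidates` and `suffix_block_max` are hypothetical helper names, and heap-allocated `Vec`s stand in for the fixed-size stack buffers of the original.

```rust
/// Pass 1 idea: always write to slot `num_candidates`, then advance the
/// count only when the score passes. Rejected entries are overwritten by
/// the next write, so the loop body needs no `if`.
fn filter_candidates(
    docs: &[u32],
    scores: &[f32],
    score_threshold: f32,
) -> (Vec<u32>, Vec<f32>) {
    let mut candidate_docs = vec![0u32; docs.len()];
    let mut candidate_scores = vec![0.0f32; docs.len()];
    let mut num_candidates = 0usize;
    for (&doc, &score) in docs.iter().zip(scores) {
        candidate_docs[num_candidates] = doc;
        candidate_scores[num_candidates] = score;
        num_candidates += (score > score_threshold) as usize;
    }
    candidate_docs.truncate(num_candidates);
    candidate_scores.truncate(num_candidates);
    (candidate_docs, candidate_scores)
}

/// Pass 2 idea: suffix[i] is the sum of block-max scores of the
/// secondaries after index `i`, i.e. the best score still obtainable once
/// secondary `i` has been scored.
fn suffix_block_max(block_max_scores: &[f32]) -> Vec<f32> {
    let mut suffix = vec![0.0f32; block_max_scores.len()];
    let mut running = 0.0f32;
    for idx in (0..block_max_scores.len()).rev() {
        suffix[idx] = running;
        running += block_max_scores[idx];
    }
    suffix
}

/// True if a candidate with `partial_score` after secondary `idx` can still
/// exceed `threshold`, assuming all later secondaries hit their block max.
fn can_still_beat(partial_score: f32, suffix: &[f32], idx: usize, threshold: f32) -> bool {
    partial_score + suffix[idx] > threshold
}
```

Whether the compiler actually emits a conditional move for the filter depends on the target and optimization level; the sketch only shows the code shape that makes this possible.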
#[cfg(test)]
mod tests {
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use proptest::prelude::*;
use crate::query::term_query::TermScorer;
use crate::query::{Bm25Weight, Scorer};
use crate::{DocId, DocSet, Score, TERMINATED};
struct Float(Score);
impl Eq for Float {}
impl PartialEq for Float {
fn eq(&self, other: &Self) -> bool {
self.cmp(other) == Ordering::Equal
}
}
impl PartialOrd for Float {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Float {
fn cmp(&self, other: &Self) -> Ordering {
other.0.partial_cmp(&self.0).unwrap_or(Ordering::Equal)
}
}
fn nearly_equals(left: Score, right: Score) -> bool {
(left - right).abs() < 0.0001 * (left + right).abs()
}
/// Run block_wand_intersection and collect (doc, score) pairs above threshold.
fn compute_checkpoints_block_wand_intersection(
term_scorers: Vec<TermScorer>,
top_k: usize,
) -> Vec<(DocId, Score)> {
let mut heap: BinaryHeap<Float> = BinaryHeap::with_capacity(top_k);
let mut checkpoints: Vec<(DocId, Score)> = Vec::new();
let mut limit: Score = 0.0;
let callback = &mut |doc, score| {
heap.push(Float(score));
if heap.len() > top_k {
heap.pop().unwrap();
}
if heap.len() == top_k {
limit = heap.peek().unwrap().0;
}
if !nearly_equals(score, limit) {
checkpoints.push((doc, score));
}
limit
};
super::block_wand_intersection(term_scorers, Score::MIN, callback);
checkpoints
}
/// Naive baseline: intersect by iterating all docs.
fn compute_checkpoints_naive_intersection(
mut term_scorers: Vec<TermScorer>,
top_k: usize,
) -> Vec<(DocId, Score)> {
let mut heap: BinaryHeap<Float> = BinaryHeap::with_capacity(top_k);
let mut checkpoints: Vec<(DocId, Score)> = Vec::new();
let mut limit = Score::MIN;
// Sort by cost to use the cheapest as driver.
term_scorers.sort_by_key(|s| s.cost());
let (leader, secondaries) = term_scorers.split_first_mut().unwrap();
let mut doc = leader.doc();
while doc != TERMINATED {
let mut all_match = true;
for secondary in secondaries.iter_mut() {
let secondary_doc = secondary.doc();
let seek_result = if secondary_doc <= doc {
secondary.seek(doc)
} else {
secondary_doc
};
if seek_result != doc {
all_match = false;
break;
}
}
if all_match {
let score: Score =
leader.score() + secondaries.iter_mut().map(|s| s.score()).sum::<Score>();
if score > limit {
heap.push(Float(score));
if heap.len() > top_k {
heap.pop().unwrap();
}
if heap.len() == top_k {
limit = heap.peek().unwrap().0;
}
if !nearly_equals(score, limit) {
checkpoints.push((doc, score));
}
}
}
doc = leader.advance();
}
checkpoints
}
const MAX_TERM_FREQ: u32 = 100u32;
fn posting_list(max_doc: u32) -> BoxedStrategy<Vec<(DocId, u32)>> {
(1..max_doc + 1)
.prop_flat_map(move |doc_freq| {
(
proptest::bits::bitset::sampled(doc_freq as usize, 0..max_doc as usize),
proptest::collection::vec(1u32..MAX_TERM_FREQ, doc_freq as usize),
)
})
.prop_map(|(docset, term_freqs)| {
docset
.iter()
.map(|doc| doc as u32)
.zip(term_freqs.iter().cloned())
.collect::<Vec<_>>()
})
.boxed()
}
#[expect(clippy::type_complexity)]
fn gen_term_scorers(num_scorers: usize) -> BoxedStrategy<(Vec<Vec<(DocId, u32)>>, Vec<u32>)> {
(1u32..100u32)
.prop_flat_map(move |max_doc: u32| {
(
proptest::collection::vec(posting_list(max_doc), num_scorers),
proptest::collection::vec(2u32..10u32 * MAX_TERM_FREQ, max_doc as usize),
)
})
.boxed()
}
fn test_block_wand_intersection_aux(posting_lists: &[Vec<(DocId, u32)>], fieldnorms: &[u32]) {
// Repeat docs 64 times to create multi-block scenarios, matching block_wand.rs test
// strategy.
const REPEAT: usize = 64;
let fieldnorms_expanded: Vec<u32> = fieldnorms
.iter()
.cloned()
.flat_map(|fieldnorm| std::iter::repeat_n(fieldnorm, REPEAT))
.collect();
let postings_lists_expanded: Vec<Vec<(DocId, u32)>> = posting_lists
.iter()
.map(|posting_list| {
posting_list
.iter()
.cloned()
.flat_map(|(doc, term_freq)| {
(0_u32..REPEAT as u32).map(move |offset| {
(
doc * (REPEAT as u32) + offset,
if offset == 0 { term_freq } else { 1 },
)
})
})
.collect::<Vec<(DocId, u32)>>()
})
.collect();
let total_fieldnorms: u64 = fieldnorms_expanded
.iter()
.cloned()
.map(|fieldnorm| fieldnorm as u64)
.sum();
let average_fieldnorm = (total_fieldnorms as Score) / (fieldnorms_expanded.len() as Score);
let max_doc = fieldnorms_expanded.len();
let make_scorers = || -> Vec<TermScorer> {
postings_lists_expanded
.iter()
.map(|postings| {
let bm25_weight = Bm25Weight::for_one_term(
postings.len() as u64,
max_doc as u64,
average_fieldnorm,
);
TermScorer::create_for_test(postings, &fieldnorms_expanded[..], bm25_weight)
})
.collect()
};
for top_k in 1..4 {
let checkpoints_optimized =
compute_checkpoints_block_wand_intersection(make_scorers(), top_k);
let checkpoints_naive = compute_checkpoints_naive_intersection(make_scorers(), top_k);
assert_eq!(
checkpoints_optimized.len(),
checkpoints_naive.len(),
"Mismatch in checkpoint count for top_k={top_k}"
);
for (&(left_doc, left_score), &(right_doc, right_score)) in
checkpoints_optimized.iter().zip(checkpoints_naive.iter())
{
assert_eq!(left_doc, right_doc);
assert!(
nearly_equals(left_score, right_score),
"Score mismatch for doc {left_doc}: {left_score} vs {right_score}"
);
}
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(500))]
#[test]
fn test_block_wand_intersection_two_scorers(
(posting_lists, fieldnorms) in gen_term_scorers(2)
) {
test_block_wand_intersection_aux(&posting_lists[..], &fieldnorms[..]);
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(500))]
#[test]
fn test_block_wand_intersection_three_scorers(
(posting_lists, fieldnorms) in gen_term_scorers(3)
) {
test_block_wand_intersection_aux(&posting_lists[..], &fieldnorms[..]);
}
}
#[test]
fn test_block_wand_intersection_disjoint() {
// Two posting lists with no overlap — intersection is empty.
let fieldnorms: Vec<u32> = vec![10; 200];
let average_fieldnorm = 10.0;
let postings_a: Vec<(DocId, u32)> = (0..100).map(|d| (d, 1)).collect();
let postings_b: Vec<(DocId, u32)> = (100..200).map(|d| (d, 1)).collect();
let scorer_a = TermScorer::create_for_test(
&postings_a,
&fieldnorms,
Bm25Weight::for_one_term(100, 200, average_fieldnorm),
);
let scorer_b = TermScorer::create_for_test(
&postings_b,
&fieldnorms,
Bm25Weight::for_one_term(100, 200, average_fieldnorm),
);
let checkpoints = compute_checkpoints_block_wand_intersection(vec![scorer_a, scorer_b], 10);
assert!(checkpoints.is_empty());
}
#[test]
fn test_block_wand_intersection_all_overlap() {
// Two posting lists with full overlap.
let fieldnorms: Vec<u32> = vec![10; 50];
let average_fieldnorm = 10.0;
let postings: Vec<(DocId, u32)> = (0..50).map(|d| (d, 3)).collect();
let make_scorer = || {
TermScorer::create_for_test(
&postings,
&fieldnorms,
Bm25Weight::for_one_term(50, 50, average_fieldnorm),
)
};
let checkpoints_opt =
compute_checkpoints_block_wand_intersection(vec![make_scorer(), make_scorer()], 5);
let checkpoints_naive =
compute_checkpoints_naive_intersection(vec![make_scorer(), make_scorer()], 5);
assert_eq!(checkpoints_opt.len(), checkpoints_naive.len());
}
}


@@ -50,7 +50,7 @@ fn block_max_was_too_low_advance_one_scorer(
scorers: &mut [TermScorerWithMaxScore],
pivot_len: usize,
) {
debug_assert!(is_sorted(scorers.iter().map(|scorer| scorer.doc())));
debug_assert!(scorers.iter().map(|scorer| scorer.doc()).is_sorted());
let mut scorer_to_seek = pivot_len - 1;
let mut global_max_score = scorers[scorer_to_seek].max_score;
let mut doc_to_seek_after = scorers[scorer_to_seek].last_doc_in_block();
@@ -76,7 +76,7 @@ fn block_max_was_too_low_advance_one_scorer(
scorers[scorer_to_seek].seek(doc_to_seek_after);
restore_ordering(scorers, scorer_to_seek);
debug_assert!(is_sorted(scorers.iter().map(|scorer| scorer.doc())));
debug_assert!(scorers.iter().map(|scorer| scorer.doc()).is_sorted());
}
// Given a list of term_scorers and a `ord` and assuming that `term_scorers[ord]` is sorted
@@ -90,7 +90,7 @@ fn restore_ordering(term_scorers: &mut [TermScorerWithMaxScore], ord: usize) {
}
term_scorers.swap(i, i - 1);
}
debug_assert!(is_sorted(term_scorers.iter().map(|scorer| scorer.doc())));
debug_assert!(term_scorers.iter().map(|scorer| scorer.doc()).is_sorted());
}
// Attempts to advance all term_scorers between `&term_scorers[0..before_len]` to the pivot.
@@ -150,17 +150,21 @@ pub fn block_wand(
mut threshold: Score,
callback: &mut dyn FnMut(u32, Score) -> Score,
) {
scorers.retain(|scorer| scorer.doc() < TERMINATED);
if scorers.len() == 1 {
let scorer = scorers.pop().unwrap();
return block_wand_single_scorer(scorer, threshold, callback);
}
let mut scorers: Vec<TermScorerWithMaxScore> = scorers
.iter_mut()
.map(TermScorerWithMaxScore::from)
.collect();
scorers.sort_by_key(|scorer| scorer.doc());
// At this point we need to ensure that the scorers are sorted!
debug_assert!(is_sorted(scorers.iter().map(|scorer| scorer.doc())));
scorers.sort_by_key(|scorer| scorer.doc());
while let Some((before_pivot_len, pivot_len, pivot_doc)) =
find_pivot_doc(&scorers[..], threshold)
{
debug_assert!(is_sorted(scorers.iter().map(|scorer| scorer.doc())));
debug_assert!(scorers.iter().map(|scorer| scorer.doc()).is_sorted());
debug_assert_ne!(pivot_doc, TERMINATED);
debug_assert!(before_pivot_len < pivot_len);
@@ -228,7 +232,7 @@ pub fn block_wand_single_scorer(
loop {
// We position the scorer on a block that can reach
// the threshold.
while scorer.block_max_score() < threshold {
while scorer.block_max_score() <= threshold {
let last_doc_in_block = scorer.last_doc_in_block();
if last_doc_in_block == TERMINATED {
return;
@@ -286,18 +290,6 @@ impl DerefMut for TermScorerWithMaxScore<'_> {
}
}
fn is_sorted<I: Iterator<Item = DocId>>(mut it: I) -> bool {
if let Some(first) = it.next() {
let mut prev = first;
for doc in it {
if doc < prev {
return false;
}
prev = doc;
}
}
true
}
#[cfg(test)]
mod tests {
use std::cmp::Ordering;

Some files were not shown because too many files have changed in this diff.
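
The last hunks shown swap the hand-rolled `is_sorted` helper for the standard library's `Iterator::is_sorted`. The sketch below is a standalone check, not part of the change set, that the two agree on a few doc-id sequences, including ones with repeated ids. It assumes Rust 1.82 or later (where `Iterator::is_sorted` is stable) and that `DocId` is `u32`; the name `is_sorted_manual` is hypothetical and only used here to keep the old helper apart from the std method.

```rust
// Assumption: DocId is a plain u32, as in the diff's callback signature.
type DocId = u32;

/// Copy of the helper removed by the diff, renamed for this comparison.
/// Returns true when the sequence is non-decreasing.
fn is_sorted_manual<I: Iterator<Item = DocId>>(mut it: I) -> bool {
    if let Some(first) = it.next() {
        let mut prev = first;
        for doc in it {
            if doc < prev {
                return false;
            }
            prev = doc;
        }
    }
    true
}

fn main() {
    // Empty, single-element, duplicate-containing, unsorted, and extreme inputs.
    let cases: [&[DocId]; 5] = [&[], &[7], &[1, 2, 2, 5], &[3, 1, 2], &[0, u32::MAX]];
    for docs in cases {
        let manual = is_sorted_manual(docs.iter().copied());
        // Std method; needs Rust 1.82+.
        let builtin = docs.iter().copied().is_sorted();
        assert_eq!(manual, builtin, "disagreement on {docs:?}");
    }
}
```

Both versions accept equal neighbouring ids, which matters here because several scorers can sit on the same doc at once.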