Compare commits

..

91 Commits

Author SHA1 Message Date
Stu Hood
147214b0eb Implement GreaterThanOrEqual and LessThanOrEqual to handle boundary cases in Chain. 2025-12-29 15:38:28 -07:00
Stu Hood
865a12f4bb Simpler implementation of first_vals_in_value_range. 2025-12-29 14:51:21 -07:00
Stu Hood
00110312c9 Fix compound filters, and remove redundant implementation in Chain implementation 2025-12-29 14:51:17 -07:00
Stu Hood
b2e980b450 Property test for Comparator/ValueRange consistency, and fixes. 2025-12-27 21:03:02 -07:00
Stu Hood
1a701b86bd Remove allow-dead-code annotation. 2025-12-27 17:56:55 -07:00
Stu Hood
ee4538d6c2 test_order_by_u64_prop 2025-12-27 16:04:24 -07:00
Stu Hood
25f1e9aa9f Move ComparableDoc to a reusable location, allowing for pushing directly from ColumnValues into a TopNComputer buffer in some cases. 2025-12-27 14:01:02 -07:00
Stu Hood
6b03b28bac Use a Buffer generic scratch buffer parameter on TopNComputer to allow for internal iteration in SegmentSortKeyComputer. 2025-12-27 13:03:56 -07:00
Stu Hood
7a5241cb83 Update comments. 2025-12-26 17:52:40 -07:00
Stu Hood
0f5e0f6f87 TODO: Audit. 2025-12-26 16:12:15 -07:00
Stu Hood
a654115d9a Squash with Optional. WIP: Still needs work: we are allocating. 2025-12-26 15:18:17 -07:00
Stu Hood
1a17515ead Convert test_order_by_compound_filtering_with_none to a proptest. 2025-12-26 14:25:01 -07:00
Stu Hood
0f1b0ce527 Optimize Optional indexes. TODO: Audit. 2025-12-26 13:01:30 -07:00
Stu Hood
0c920dfc61 Add a ValueRange filter to SegmentSortKeyComputer::segment_sort_keys. 2025-12-26 12:27:11 -07:00
Stu Hood
996fc936f6 Add null handling to first_vals_in_value_range. 2025-12-26 11:12:53 -07:00
Stu Hood
5ff38e1605 WIP: Add ValueRange cases for Comparators. 2025-12-26 11:02:19 -07:00
Stu Hood
e8a4adeedd Replace Column::first_vals with Column::first_vals_in_value_range. 2025-12-25 15:39:18 -07:00
Stu Hood
efc9e585a9 WIP: Add ValueRange::All 2025-12-25 15:16:26 -07:00
Stu Hood
f4252fc184 WIP: Add ValueRange. 2025-12-25 14:53:15 -07:00
Stu Hood
53c067d1f3 Restore laziness in ChainSegmentSortKeyComputer. 2025-12-24 10:39:26 -07:00
Stu Hood
259c1ed965 Isolate accept_sort_key_lazy to ChainSegmentSortKeyComputer. 2025-12-23 17:37:33 -07:00
Stu Hood
1afc432df8 Use an internal buffer in the SegmentSortKeyComputer. 2025-12-23 17:23:10 -07:00
Stu Hood
b8acd3ac94 WIP: Add and use segment_sort_keys to remove dynamic dispatch to the column. 2025-12-23 16:44:50 -07:00
Stu Hood
b5321d2125 Implement laziness for collect_block. 2025-12-23 15:48:36 -07:00
Stu Hood
ad3e2363fe WIP: Add failing test. 2025-12-23 15:48:34 -07:00
Stu Hood
9ec5750c25 Implement collect_block for lazy scorers. 2025-12-23 15:46:41 -07:00
Stu Hood
03f09a2b5b chore: Add support for natural-order-with-none-highest in TopDocs::order_by (#90)
Add `ComparatorEnum::NaturalNoneHigher`, which matches Postgres's `DESC NULLS FIRST` behavior in `TopDocs::order_by`.

Expands comments on `Comparator` implementations to ensure that behavior for `None` is explicit.

Upstream as https://github.com/quickwit-oss/tantivy/pull/2780
2025-12-23 09:15:31 -08:00
Stu Hood
9ffe4af096 Fix TopN performance regression.
https://github.com/quickwit-oss/tantivy/pull/2777
2025-12-17 10:43:29 -07:00
Stu Hood
c56ddcb6d7 Add an erased SortKeyComputer to sort on types which are not known until runtime.
https://github.com/quickwit-oss/tantivy/pull/2770
2025-12-17 10:43:29 -07:00
Ming
5b8fff154b fix: overflow in vint buffer (#88) 2025-12-17 10:43:29 -07:00
Mohammad Dashti
ff6ee3a5db fix: post-rebase fixes
- Add missing size_hint module declaration
- Remove test-only export serialize_and_load_u64_based_column_values
- fixed quickwit CI issues
2025-12-10 10:17:28 -08:00
Moe
eda9aa437f fix: boolean query incorrectly dropping documents when AllScorer is present (#84)
Co-authored-by: Stu Hood <stuhood@gmail.com>
2025-12-10 10:17:28 -08:00
Piotr Olszak
538da08eb5 Add polish stemmer (#82)
This commit adds support for Polish language stemming.
The previously used rust-stemmers crate is abandoned and unmaintained, which blocked the addition of new languages. This change addresses a user request for Polish stemming to improve BM25 recall in their use case. The tantivy-stemmers crate is a modern, maintained alternative that also opens the door for supporting many other languages in the future.
- Added the tantivy-stemmers crate as a dependency to the workspace, alongside the existing rust-stemmers dependency (for backward compatibility)
- Introduced an internal enum that can hold an algorithm from either rust-stemmers or tantivy-stemmers
- Added Polish to the main Language enum, mapped to the new tantivy-stemmers implementation
- Updated the token stream to handle both types of stemmers internally
- Added the POLISH variant to the stopwords list
- Existing tests pass
- Added test_pl_tokenizer to verify that the Polish stemmer works correctly
2025-12-10 10:17:28 -08:00
Moe
7bd5cc5417 fix: fixed integer overflow in ExpUnrolledLinkedList for large datasets (#80) 2025-12-10 10:17:28 -08:00
Moe
5d46137556 feat: Added multiple snippet support (#76)
Adds `SnippetGenerator::snippets` to render multiple snippets in either score or position order.

Additionally: renames the existing `limit` and `offset` arguments to disambiguate between "match" positions (which are concatenated into fragments), and "snippet" positions.

Co-authored-by: Stu Hood <stuhood@gmail.com>
2025-12-10 10:17:28 -08:00
Stu Hood
92c784f697 perf: Optimize TermSet for very large sets of terms. (#75)
* Removes allocation in a bunch of places
* Removes sorting of terms if we're going to use the fast field execution method
* Adds back the (accidentally dropped) cardinality threshold
* Removes `bool` support -- using the posting lists is always more efficient for a `bool`, since there are at most two of them
* More eagerly constructs the term `HashSet` so that it happens once, rather than once per segment
2025-12-10 10:17:28 -08:00
Stu Hood
b3541d10e1 chore: Use smaller merge buffers. (#74)
## What

Reduce the per-segment buffer sizes from 4MB to 512KB.

## Why

#71 moved from buffers which covered the entire file to maximum 4MB buffers. But for merges with very large segment counts, we need to be using more conservative buffer sizes. 512KB will still eliminate most posting list reads: posting lists larger than 512KB will skip the buffer.
2025-12-10 10:17:28 -08:00
Stu Hood
7183ac6cbc fix: Use smaller buffers during merging (#71)
`MergeOptimizedInvertedIndexReader` was added in #32 in order to avoid making small reads to our underlying `FileHandle`. It did so by reading the entire content of the posting lists and positions at open time.

As that PR says:
> A likely downside to this approach is that now pg_search will be, indirectly, holding onto a lot of heap-allocated memory that was read from its block storage. Perhaps in the (near) future we can further optimize the new `MergeOptimizedInvertedIndexReader` such that it pages in blocks of a few megabytes at a time, on demand, rather than the whole file.

This PR makes that change. But it additionally removes code that was later added in #47 to borrow individual entries rather than creating `OwnedBytes` for them. I believe that this code was added due to a misunderstanding:

`OwnedBytes` is a total misnomer: the bytes are not "owned": they are immutably borrowed and reference counted. An `OwnedBytes` object can be created for any type which derefs to a slice of bytes, and can be cheaply cloned and sliced. So there is no need to actually borrow _or_ copy the buffer under the `OwnedBytes`. Removing the code that was doing so allows us to safely recreate our buffer without worrying about the lifetimes of buffers that we've handed out.
2025-12-10 10:17:28 -08:00
Stu Hood
e0476d2eb2 fix: Add support for bool to the fast field TermSet implementation (#70)
Missed in #69.

The `TermSet` fast fields implementation cribbed from `RangeQuery`'s fast fields implementation: ... which also has this bug. Will fix upstream.
2025-12-10 10:17:28 -08:00
Stu Hood
9fe0899934 perf: Implement a TermSet variant which uses fast fields (#69)
The `TermSet` `Query` currently produces one `Scorer`/`DocSet` per matched term by scanning the term dictionary and then consuming posting lists. For very large sets of terms and a fast field, it is faster to scan the fast field column while intersecting with a `HashSet` of (encoded) term values.

Following the pattern set by the two execution modes of `RangeQuery`, this PR introduces a variant of `TermSet` which uses fast fields, and then uses it when there are more than 1024 input terms (an arbitrary threshold!).

Performance is significantly improved for large `TermSet`s of primitives.
2025-12-10 10:17:28 -08:00
Stu Hood
aaa5abb7d6 chore: Expose a method to create a segment with a particular id (#68)
In support of https://github.com/paradedb/paradedb/pull/3203
2025-12-10 10:17:28 -08:00
Ming
f8b8fd0321 feat: SnippetGenerator accepts limit/offset (#66) 2025-12-10 10:17:27 -08:00
Eric Ridge
cd878a5c90 fix: support MemoryArena allocations up to 4GB (#62)
A MemoryArena should support allocations up to 4GB and https://github.com/paradedb/tantivy/pull/60 broke this by not accounting for the "max page id" when pages are now 50% the size of what the originally were.

This cleans up the code so things stay in sync if we change NUM_BITS_PAGE_ADDR again and adds a unit test
2025-12-10 10:17:27 -08:00
Eric Ridge
30c237e895 perf: various optimizations around arenas (#60)
- Use a bitset to track used buckets in the `SharedArenaHashmap`, allowing for more efficient iteration
- Create a global pool for both `MemoryArena` and `IndexingContext`
- Reduce the MemoryArea page size by half (it's now 512KB instead of 1MB)
- Centralize thread pool instances in `SegmentUpdater` so we can elide making them if all nthread sizes are zero
2025-12-10 10:17:27 -08:00
Eric Ridge
b6cd39872b fix: Allow zero indexing & merging threads (#59)
This removes a check against `IndexWriterOptions` which disallowed zero indexing worker threads (`num_worker_threads`).
2025-12-10 10:17:27 -08:00
Stu Hood
c96d801c68 perf: Lazily load in BitpackedCodec (#56)
We would like to be able to lazily load `BitpackedCodec` columns (similar to what 020bdffd61 did for `BlockwiseLinearCodec`), because in the context of `pg_search`, immediately constructing `OwnedBytes` means copying the entire content of the column into memory.

To do so, we expose some (slightly overlapped) block boundaries from `BitUnpacker`, and then lazily load each block when it is requested. Only the `get_val` function uses the cache: `get_row_ids_for_value_range` does not (yet), because it would be necessary to partition the row ids by block, and most of the time consumers using it are already loading reasonably large ranges anyway.

See https://github.com/paradedb/paradedb/pull/2894 for usage. There are a few 2x speedups in the benchmark suite, as well as a 1.8x speedup on a representative customer query. Unfortunately there are also some 13-19% slowdowns on aggregates: it looks like that is because aggregates use `get_vals`, for which the default implementation is to just call `get_val` in a loop.
2025-12-10 10:17:27 -08:00
Stu Hood
7a13e0294d Avoid copying into OwnedBytes when opening a fast field column Dictionary. (#55)
When a fast fields string/bytes `Dictionary` is opened, we currently read the entire dictionary from `FileSlice` -> `OwnedBytes`... and then immediately wrap it back into a `FileSlice`.

Switching to `Dictionary::open` preserves the `FileSlice`, such that only the portions of the `Dictionary` which are actually accessed are read from disk/buffers.
2025-12-10 10:17:27 -08:00
Eric Ridge
20d00701ee perf: lazily open positions file (#54)
Not all queries use positions and it's okay if we (from the perspective of `pg_search`, anyways) defer opening/loading them until they're first needed.

This is probably completely wrong for a mmap-based Directory, but we (again, `pg_search`) decided long ago that we don't care about that use case.

This saves a lot of disk I/O when an index has lots of segments and the query doesn't need positions.

As a drive by, make sure a random Vec has enough space before pushing items to it.  This showed up in the profiler, believe it or not.
2025-12-10 10:17:27 -08:00
Eric Ridge
526afc6111 chore: internal API visibility adjustments (#53) 2025-12-10 10:17:27 -08:00
Ming
f9e4a8413b make the directory BufWriter capacity configurable (#52) 2025-12-10 10:17:27 -08:00
Ming
58124bb164 changes to make merging work (#48) 2025-12-10 10:17:27 -08:00
Eric Ridge
176f7e852a perf: remove general overhead during segment merging (#47) 2025-12-10 10:17:27 -08:00
Ming
cfa5f94114 chore: Make some delete-related functions public (#46) 2025-12-10 10:17:26 -08:00
Ming
5e449e7dda feat: SnippetGenerator can handle JSON fields (#42) 2025-12-10 10:17:26 -08:00
Stu Hood
1617459b01 Expose some methods which are necessary to create a streaming version of sorted_ords_to_term_cb. (#43)
See https://github.com/paradedb/paradedb/pull/2612.

We might eventually want that function upstreamed, but there are more changes planned to it for https://github.com/paradedb/paradedb/issues/2619, so doing the expedient thing now.
2025-12-10 10:17:26 -08:00
Ming
0e1a7e213e chore: allow merge_foreground to ignore the store (#40) 2025-12-10 10:17:26 -08:00
Ming
b0660ba196 chore: make some structs pub (#39) 2025-12-10 10:17:26 -08:00
Eric Ridge
936d6af471 feat: ability to directly merge segments in the foregound (#36)
This adds new public-facing (and internal) APIs for being able to merge a list of segments in the foreground, without using any threads.  It's largely a cut-n-paste of the existing background merge code.

For pg_search, this is beneficial because it allows us to merge directly using our `MVCCDirectory` rather than going through the `ChannelDirectory`, which has quite a bit of overhead.
2025-12-10 10:17:26 -08:00
Eric Ridge
2560de3a01 feat: IndexWriter::wait_merging_threads() return Err on merge failure (#34) 2025-12-10 10:17:26 -08:00
Eric Ridge
75a8384c2b feat: remove Directory::reconsider_merge_policy() and add other niceties to Directory API (#33)
This removes `Directory::reconsider_merge_policy()`.  After reconsidering this, it's better to make this decision ahead of time.

Also adds a `Directory::log(message: &str)` function along with passing a `Directory` reference to `MergePolicy::compute_merge_candidates()`.

It also hits some `#[derive(Debug)]` and `#[derive(Serialize)]` on a couple of structs that can benefit.
2025-12-10 10:17:26 -08:00
Eric Ridge
5b6da9123c feat: introduce a MergeOptimizedInvertedIndexReader (#32)
This is probably a bit of a misnomer as it's really a "PgSearchOptimizedInvertedIndexReaderForMerge".

What we've done here is copied `InvertedIndexReader` and internally adjusted it to hold onto the complete `OwnedBytes` of the index's postings and positions.  One or two other small touch points were required to make other internal APIs compatabile with this but they don't otherwise change functionality or I/O patterns.

`MergeOptimizedInvertedIndexReader` does change I/O patterns, however, in that the merge process now does two (potentially) very large reads when it obtains the new "merge optimized inverted index reader" for each segment.  This changes access patterns such that all the reads happen up-front rather than term-by-term as the merge process is solving.

A likely downside to this approach is that now pg_search will be, indirectly, holding onto a lot of heap-allocated memory that was read from its block storage.  Perhaps in the (near) future we can further optimize the new `MergeOptimizedInvertedIndexReader` such that it pages in blocks of a few megabytes at a time, on demand, rather than the whole file.

---

Some unit tests were also updated to resolve compilation problems by PR https://github.com/paradedb/tantivy/pull/31 that for some reason didn't show in CI.  #weird
2025-12-10 10:17:26 -08:00
Eric Ridge
8b7db36c99 feat: Add Directory::wants_cancel() function (#31)
This adds a function named `wants_cancel() -> bool` to the `Directory` trait.  It allows a Directory implementation to indicate that it would like Tantivy to cancel an operation.

Right now, querying this function only happens during key points of index merging, but _could_ be used in other places.  Technically, segment merging is the only "black box" in tantivy that users don't otherwise have the direct ability to control.

The default implementaiton of `wants_cancel()` returns false, so there's no fear of default tantivy spuriously cancelling a merge.

The cancels happen "cleanly" such that if `wants_cancel()` returns true an `Err(TantivyError::Cancelled)` is returned from the calling function at that point, and the error result will be propogated up the stack.  No panics are raised.
2025-12-10 10:17:26 -08:00
Eric Ridge
eabe589814 feat: ability to assign a panic handler to a Directory (#30)
Tantivy creates thread pools for some of its background work, specifically committing and merging.

It's possible if one of the thread workers panics that rayon will simply abort the process.  This is terrible from pg_search as that takes down the entire Postgres cluster.

These changes allow a Directory to assign a panic handler that gets called in such cases.  Which allows pg_search to gracefully rollback the current transaction, while presenting the panic message to the user.
2025-12-10 10:17:26 -08:00
Eric Ridge
65d3574dfd feat: make garbage collection opt-out (#28) 2025-12-10 10:17:26 -08:00
Ming
26d623c411 Change default index precision to microseconds (#27) 2025-12-10 10:17:25 -08:00
Eric Ridge
0552dddeb9 feat: delete docs by (SegmentId, DocId) (#26)
This teaches tantivy how to "directly" delete a document in a segment.
    
Our use case from pg_search is that we already know the segment_id and doc_id so it's waaaaay more efficient for us to delete docs through our `ambulkdelete()` routine.

It avoids doing a search, and all the stuff around that, for each of our "ctid" terms that we want to delete.
2025-12-10 10:17:25 -08:00
Eric Ridge
1b88bb61f9 feat: Add ability to construct a SegmentId from raw bytes (#24)
This allows a `SegmentId` to be constructed from a `[u8; 16]` byte array.  

It also adds a `impl Default for SegementId`, which defaults to all nulls
2025-12-10 10:17:25 -08:00
Eric Ridge
16da31cf06 perf: make the footer fixed width (#23)
Prior to this commit, the Footer tantivy serialized at the end of every file included a json blob that could be an arbitrary size.

This changes the Footer to be exactly 24 bytes (6 u32s), making sure to keep the `crc` value.  The other change we make here is to not actually read/validate the footer bytes when opening a file.

From pg_search's perspective, this is quite unnecessary Postgres buffer cache I/O and increases index/segment opening overhead, which is something pg_search does often for each query.

Two tests are ignored here as they test physical index files stored here in the repo that this change completely breaks.
2025-12-10 10:17:25 -08:00
Eric Ridge
658b9b22e0 perf: remove some fast fields loading overhead (#22)
This removes up some overhead the profiler exposed.  In the case I was testing, fast fields no longer shows up in the profile at all.

I also renamed `BlockWithLength` to `BlockWithData`
2025-12-10 10:17:25 -08:00
Eric Ridge
95661fba30 perf: teach SegmentReader to lazily open/read its various SegmentComponents (#20)
This overhauls `SegmentReader` to put its various components behind `OnceLock`s such that they can be opened and read on their first use, as oppoed when a SegmentReader is constructed -- which is once for every segment when an Index is opened.

This has a negative impact on some of Tantivy's expectations in that an existing SegementReader can still read from physical files that were deleted by a merge.  This isn't true now that the segment's physical files aren't opened until needed.  As such, I've `#[ignore]`'d six tests that expose this problem.

From our (pg_search's) side of things, we don't really have physical files and don't need to rely on the filesystem/kernel to allow reading unlinked files that are still open.

Overall, this cuts down a signficiant number of disk reads during pg_search's query planning.  With my test data it goes from 808 individual reads totalling 999,799 bytes, to 18 reads totalling 814,514 bytes.

This reduces the time it takes to plan a simple query from about 1.4ms to 0.436ms -- roughly a 3.2x improvement.
2025-12-10 10:17:25 -08:00
Philippe Noël
ddd169b77c chore: Don't do codecov (#21) 2025-12-10 10:17:25 -08:00
Eric Ridge
bb4c4b8522 perf: push FileSlices down through most of fast fields (#19)
This PR modifies internal API signatures and implementation details so that `FileSlice`s are passed down into the innards of (at least) the `BlockwiseLinearCodec`.  This allows tantivy to defer dereferencing large slices of bytes when reading numeric fast fields, and instead dereference only the slice of bytes it needs for any given compressed Block.

The motivation here is for external `Directory` implementations where it's not exactly efficient to dereference large slices of bytes.
2025-12-10 10:17:25 -08:00
Neil Hansen
ffa558e3a9 fix: tests in ci (#18) 2025-12-10 10:17:25 -08:00
Neil Hansen
a35e3dcb5a suppress warnings after rebase 2025-12-10 10:17:25 -08:00
Neil Hansen
1e3998fbad implement fuzzy scoring in sstable 2025-12-10 10:17:25 -08:00
Neil Hansen
f3df079d6b chore: point tantivy-fst to paradedb fork to fix regex 2025-12-10 10:17:24 -08:00
Ming Ying
f7c0335857 comments 2025-12-10 10:17:24 -08:00
Ming Ying
2584325e0d add reconsider_merge_policy to directory 2025-12-10 10:17:24 -08:00
Eric B. Ridge
1f2c2d0c8a fix compilation warnings on rust v1.83 2025-12-10 10:17:24 -08:00
Eric Ridge
91db6909d1 Add a payload: &mut (dyn Any + '_) argument to Directory::save_meta() (#17) 2025-12-10 10:17:24 -08:00
Ming Ying
7639b47615 small changes to make MVCC work with delete 2025-12-10 10:17:24 -08:00
Ming Ying
8b55f0f355 Make DeleteMeta pub 2025-12-10 10:17:24 -08:00
Ming Ying
8d29f19110 make save_metas provide previous metas 2025-12-10 10:17:24 -08:00
Ming Ying
d742d3277a undo changes to segment_updater.rs 2025-12-10 10:17:24 -08:00
Eric B. Ridge
3afe3714a2 no pgrx, please 2025-12-10 10:17:24 -08:00
Ming Ying
67ea8e53a8 quickwit compiles 2025-12-10 10:17:24 -08:00
Ming Ying
3adc85c017 Directory trait can read/write meta/managed 2025-12-10 10:17:24 -08:00
Ming
6bb3a22c98 expose AddOperation and with_max_doc (#7) 2025-12-10 10:17:23 -08:00
Ming
5503cfb8ef Fix managed paths (#5) 2025-12-10 10:17:23 -08:00
Alexander Alexandrov
ea0e88ae4b feat: implement TokenFilter for Option<F> (#4) 2025-12-10 10:17:23 -08:00
Neil Hansen
dee2dd3f21 Use Levenshtein distance to score documents in fuzzy term queries 2025-12-10 10:17:19 -08:00
247 changed files with 13898 additions and 8372 deletions

View File

@@ -1,29 +0,0 @@
name: Coverage
on:
push:
branches: [main]
# Ensures that we cancel running jobs for the same PR / same workflow.
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
jobs:
coverage:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Install Rust
run: rustup toolchain install nightly-2025-12-01 --profile minimal --component llvm-tools-preview
- uses: Swatinem/rust-cache@v2
- uses: taiki-e/install-action@cargo-llvm-cov
- name: Generate code coverage
run: cargo +nightly-2025-12-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
continue-on-error: true
with:
token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos
files: lcov.info
fail_ci_if_error: true

View File

@@ -39,11 +39,11 @@ jobs:
- name: Check Formatting
run: cargo +nightly fmt --all -- --check
- name: Check Stable Compilation
run: cargo build --all-features
- name: Check Bench Compilation
run: cargo +nightly bench --no-run --profile=dev --all-features
@@ -59,10 +59,10 @@ jobs:
strategy:
matrix:
features:
- { label: "all", flags: "mmap,stopwords,lz4-compression,zstd-compression,failpoints,stemmer" }
- { label: "quickwit", flags: "mmap,quickwit,failpoints" }
- { label: "none", flags: "" }
features: [
{ label: "all", flags: "mmap,stopwords,lz4-compression,zstd-compression,failpoints" },
{ label: "quickwit", flags: "mmap,quickwit,failpoints" }
]
name: test-${{ matrix.features.label}}
@@ -76,25 +76,13 @@ jobs:
profile: minimal
override: true
- uses: taiki-e/install-action@nextest
- uses: taiki-e/install-action@v2
with:
tool: 'nextest'
- uses: Swatinem/rust-cache@v2
- name: Run tests
run: |
# if matrix.feature.flags is empty then run on --lib to avoid compiling examples
# (as most of them rely on mmap) otherwise run all
if [ -z "${{ matrix.features.flags }}" ]; then
cargo +stable nextest run --lib --no-default-features --verbose --workspace
else
cargo +stable nextest run --features ${{ matrix.features.flags }} --no-default-features --verbose --workspace
fi
run: cargo +stable nextest run --features ${{ matrix.features.flags }} --verbose --workspace
- name: Run doctests
run: |
# if matrix.feature.flags is empty then run on --lib to avoid compiling examples
# (as most of them rely on mmap) otherwise run all
if [ -z "${{ matrix.features.flags }}" ]; then
echo "no doctest for no feature flag"
else
cargo +stable test --doc --features ${{ matrix.features.flags }} --verbose --workspace
fi
run: cargo +stable test --doc --features ${{ matrix.features.flags }} --verbose --workspace

5
.gitignore vendored
View File

@@ -6,7 +6,6 @@ target
target/debug
.vscode
target/release
Cargo.lock
benchmark
.DS_Store
*.bk
@@ -15,3 +14,7 @@ trace.dat
cargo-timing*
control
variable
# for `sample record -p`
profile.json
profile.json.gz

2361
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -21,13 +21,13 @@ byteorder = "1.4.3"
crc32fast = "1.3.2"
once_cell = "1.10.0"
regex = { version = "1.5.5", default-features = false, features = [
"std",
"unicode",
"std",
"unicode",
] }
aho-corasick = "1.0"
tantivy-fst = "0.5"
tantivy-fst = { git = "https://github.com/paradedb/fst.git" }
memmap2 = { version = "0.9.0", optional = true }
lz4_flex = { version = "0.12", default-features = false, optional = true }
lz4_flex = { version = "0.11", default-features = false, optional = true }
zstd = { version = "0.13", optional = true, default-features = false }
tempfile = { version = "3.12.0", optional = true }
log = "0.4.16"
@@ -37,10 +37,11 @@ fs4 = { version = "0.13.1", optional = true }
levenshtein_automata = "0.2.1"
uuid = { version = "1.0.0", features = ["v4", "serde"] }
crossbeam-channel = "0.5.4"
rust-stemmers = { version = "1.2.0", optional = true }
rust-stemmers = "1.2.0"
tantivy-stemmers = { version = "0.4.0", default-features = false, features = ["polish_yarovoy"] }
downcast-rs = "2.0.1"
bitpacking = { version = "0.9.3", default-features = false, features = [
"bitpacker4x",
bitpacking = { version = "0.9.2", default-features = false, features = [
"bitpacker4x",
] }
census = "0.4.2"
rustc-hash = "2.0.0"
@@ -48,9 +49,13 @@ thiserror = "2.0.1"
htmlescape = "0.3.1"
fail = { version = "0.5.0", optional = true }
time = { version = "0.3.35", features = ["serde-well-known"] }
# TODO: We have integer wrappers with PartialOrd, and a misfeature of
# `deranged` causes inference to fail in a bunch of cases. See
# https://github.com/jhpratt/deranged/issues/18#issuecomment-2746844093
deranged = "=0.4.0"
smallvec = "1.8.0"
rayon = "1.5.2"
lru = "0.16.3"
lru = "0.12.0"
fastdivide = "0.4.0"
itertools = "0.14.0"
measure_time = "0.9.0"
@@ -69,26 +74,27 @@ hyperloglogplus = { version = "0.4.1", features = ["const-loop"] }
futures-util = { version = "0.3.28", optional = true }
futures-channel = { version = "0.3.28", optional = true }
fnv = "1.0.7"
parking_lot = "0.12.4"
typetag = "0.2.21"
[target.'cfg(windows)'.dependencies]
winapi = "0.3.9"
[dev-dependencies]
binggan = "0.14.2"
rand = "0.9"
binggan = "0.14.0"
rand = "0.8.5"
maplit = "1.0.2"
matches = "0.1.9"
pretty_assertions = "1.2.1"
proptest = "1.7.0"
proptest = "1.0.0"
test-log = "0.2.10"
futures = "0.3.21"
paste = "1.0.11"
more-asserts = "0.3.1"
rand_distr = "0.5"
rand_distr = "0.4.3"
time = { version = "0.3.10", features = ["serde-well-known", "macros"] }
postcard = { version = "1.0.4", features = [
"use-std",
"use-std",
], default-features = false }
[target.'cfg(not(windows))'.dev-dependencies]
@@ -113,8 +119,7 @@ debug-assertions = true
overflow-checks = true
[features]
default = ["mmap", "stopwords", "lz4-compression", "columnar-zstd-compression", "stemmer"]
stemmer = ["rust-stemmers"]
default = ["mmap", "stopwords", "lz4-compression", "columnar-zstd-compression"]
mmap = ["fs4", "tempfile", "memmap2"]
stopwords = []
@@ -136,14 +141,14 @@ compare_hash_only = ["stacker/compare_hash_only"]
[workspace]
members = [
"query-grammar",
"bitpacker",
"common",
"ownedbytes",
"stacker",
"sstable",
"tokenizer-api",
"columnar",
"query-grammar",
"bitpacker",
"common",
"ownedbytes",
"stacker",
"sstable",
"tokenizer-api",
"columnar",
]
# Following the "fail" crate best practises, we isolate
@@ -174,22 +179,6 @@ harness = false
name = "exists_json"
harness = false
[[bench]]
name = "range_query"
harness = false
[[bench]]
name = "and_or_queries"
harness = false
[[bench]]
name = "range_queries"
harness = false
[[bench]]
name = "bool_queries_with_range"
harness = false
[[bench]]
name = "str_search_and_get"
harness = false

View File

@@ -1,8 +1,8 @@
use binggan::plugins::PeakMemAllocPlugin;
use binggan::{black_box, InputGroup, PeakMemAlloc, INSTRUMENTED_SYSTEM};
use rand::distr::weighted::WeightedIndex;
use rand::distributions::WeightedIndex;
use rand::prelude::SliceRandom;
use rand::rngs::StdRng;
use rand::seq::IndexedRandom;
use rand::{Rng, SeedableRng};
use rand_distr::Distribution;
use serde_json::json;
@@ -54,33 +54,33 @@ fn bench_agg(mut group: InputGroup<Index>) {
register!(group, stats_f64);
register!(group, extendedstats_f64);
register!(group, percentiles_f64);
register!(group, terms_7);
register!(group, terms_few);
register!(group, terms_all_unique);
register!(group, terms_150_000);
register!(group, terms_many);
register!(group, terms_many_top_1000);
register!(group, terms_many_order_by_term);
register!(group, terms_many_with_top_hits);
register!(group, terms_all_unique_with_avg_sub_agg);
register!(group, terms_many_with_avg_sub_agg);
register!(group, terms_few_with_avg_sub_agg);
register!(group, terms_status_with_avg_sub_agg);
register!(group, terms_status);
register!(group, terms_few_with_histogram);
register!(group, terms_status_with_histogram);
register!(group, terms_zipf_1000);
register!(group, terms_zipf_1000_with_histogram);
register!(group, terms_zipf_1000_with_avg_sub_agg);
register!(group, terms_many_json_mixed_type_with_avg_sub_agg);
register!(group, cardinality_agg);
register!(group, terms_status_with_cardinality_agg);
register!(group, terms_few_with_cardinality_agg);
register!(group, range_agg);
register!(group, range_agg_with_avg_sub_agg);
register!(group, range_agg_with_term_agg_status);
register!(group, range_agg_with_term_agg_few);
register!(group, range_agg_with_term_agg_many);
register!(group, histogram);
register!(group, histogram_hard_bounds);
register!(group, histogram_with_avg_sub_agg);
register!(group, histogram_with_term_agg_status);
register!(group, histogram_with_term_agg_few);
register!(group, avg_and_range_with_avg_sub_agg);
// Filter aggregation benchmarks
@@ -159,10 +159,10 @@ fn cardinality_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn terms_status_with_cardinality_agg(index: &Index) {
fn terms_few_with_cardinality_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"terms": { "field": "text_few_terms" },
"aggs": {
"cardinality": {
"cardinality": {
@@ -175,7 +175,13 @@ fn terms_status_with_cardinality_agg(index: &Index) {
execute_agg(index, agg_req);
}
fn terms_7(index: &Index) {
fn terms_few(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_few_terms" } },
});
execute_agg(index, agg_req);
}
fn terms_status(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_few_terms_status" } },
});
@@ -188,7 +194,7 @@ fn terms_all_unique(index: &Index) {
execute_agg(index, agg_req);
}
fn terms_150_000(index: &Index) {
fn terms_many(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_many_terms" } },
});
@@ -247,6 +253,17 @@ fn terms_all_unique_with_avg_sub_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn terms_few_with_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms" },
"aggs": {
"histo": {"histogram": { "field": "score_f64", "interval": 10 }}
}
}
});
execute_agg(index, agg_req);
}
fn terms_status_with_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
@@ -259,18 +276,17 @@ fn terms_status_with_histogram(index: &Index) {
execute_agg(index, agg_req);
}
fn terms_zipf_1000_with_histogram(index: &Index) {
fn terms_few_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_1000_terms_zipf" },
"terms": { "field": "text_few_terms" },
"aggs": {
"histo": {"histogram": { "field": "score_f64", "interval": 10 }}
"average_f64": { "avg": { "field": "score_f64" } }
}
}
},
});
execute_agg(index, agg_req);
}
fn terms_status_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
@@ -283,25 +299,6 @@ fn terms_status_with_avg_sub_agg(index: &Index) {
execute_agg(index, agg_req);
}
fn terms_zipf_1000_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_1000_terms_zipf" },
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
},
});
execute_agg(index, agg_req);
}
fn terms_zipf_1000(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_1000_terms_zipf" } },
});
execute_agg(index, agg_req);
}
fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
@@ -357,7 +354,7 @@ fn range_agg_with_avg_sub_agg(index: &Index) {
execute_agg(index, agg_req);
}
fn range_agg_with_term_agg_status(index: &Index) {
fn range_agg_with_term_agg_few(index: &Index) {
let agg_req = json!({
"rangef64": {
"range": {
@@ -372,7 +369,7 @@ fn range_agg_with_term_agg_status(index: &Index) {
]
},
"aggs": {
"my_texts": { "terms": { "field": "text_few_terms_status" } },
"my_texts": { "terms": { "field": "text_few_terms" } },
}
},
});
@@ -428,12 +425,12 @@ fn histogram_with_avg_sub_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn histogram_with_term_agg_status(index: &Index) {
fn histogram_with_term_agg_few(index: &Index) {
let agg_req = json!({
"rangef64": {
"histogram": { "field": "score_f64", "interval": 10 },
"aggs": {
"my_texts": { "terms": { "field": "text_few_terms_status" } }
"my_texts": { "terms": { "field": "text_few_terms" } }
}
}
});
@@ -478,13 +475,6 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector {
}
fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
// Flag to use existing index
let reuse_index = std::env::var("REUSE_AGG_BENCH_INDEX").is_ok();
if reuse_index && std::path::Path::new("agg_bench").exists() {
return Index::open_in_dir("agg_bench");
}
// crreate dir
std::fs::create_dir_all("agg_bench")?;
let mut schema_builder = Schema::builder();
let text_fieldtype = tantivy::schema::TextOptions::default()
.set_indexing_options(
@@ -496,44 +486,24 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
let text_field_all_unique_terms =
schema_builder.add_text_field("text_all_unique_terms", STRING | FAST);
let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST);
let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST);
let text_field_few_terms = schema_builder.add_text_field("text_few_terms", STRING | FAST);
let text_field_few_terms_status =
schema_builder.add_text_field("text_few_terms_status", STRING | FAST);
let text_field_1000_terms_zipf =
schema_builder.add_text_field("text_1000_terms_zipf", STRING | FAST);
let score_fieldtype = tantivy::schema::NumericOptions::default().set_fast();
let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone());
let score_field_f64 = schema_builder.add_f64_field("score_f64", score_fieldtype.clone());
let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype);
// use tmp dir
let index = if reuse_index {
Index::create_in_dir("agg_bench", schema_builder.build())?
} else {
Index::create_from_tempdir(schema_builder.build())?
};
// Approximate log proportions
let status_field_data = [
("INFO", 8000),
("ERROR", 300),
("WARN", 1200),
("DEBUG", 500),
("OK", 500),
("CRITICAL", 20),
("EMERGENCY", 1),
];
let log_level_distribution =
WeightedIndex::new(status_field_data.iter().map(|item| item.1)).unwrap();
let index = Index::create_from_tempdir(schema_builder.build())?;
let few_terms_data = ["INFO", "ERROR", "WARN", "DEBUG"];
// Approximate production log proportions: INFO dominant, WARN and DEBUG occasional, ERROR rare.
let log_level_distribution = WeightedIndex::new([80u32, 3, 12, 5]).unwrap();
let lg_norm = rand_distr::LogNormal::new(2.996f64, 0.979f64).unwrap();
let many_terms_data = (0..150_000)
.map(|num| format!("author{num}"))
.collect::<Vec<_>>();
// Prepare 1000 unique terms sampled using a Zipf distribution.
// Exponent ~1.1 approximates top-20 terms covering around ~20%.
let terms_1000: Vec<String> = (1..=1000).map(|i| format!("term_{i}")).collect();
let zipf_1000 = rand_distr::Zipf::new(1000.0, 1.1f64).unwrap();
{
let mut rng = StdRng::from_seed([1u8; 32]);
let mut index_writer = index.writer_with_num_threads(1, 200_000_000)?;
@@ -543,12 +513,8 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
index_writer.add_document(doc!())?;
}
if cardinality == Cardinality::Multivalued {
let log_level_sample_a = status_field_data[log_level_distribution.sample(&mut rng)].0;
let log_level_sample_b = status_field_data[log_level_distribution.sample(&mut rng)].0;
let idx_a = zipf_1000.sample(&mut rng) as usize - 1;
let idx_b = zipf_1000.sample(&mut rng) as usize - 1;
let term_1000_a = &terms_1000[idx_a];
let term_1000_b = &terms_1000[idx_b];
let log_level_sample_a = few_terms_data[log_level_distribution.sample(&mut rng)];
let log_level_sample_b = few_terms_data[log_level_distribution.sample(&mut rng)];
index_writer.add_document(doc!(
json_field => json!({"mixed_type": 10.0}),
json_field => json!({"mixed_type": 10.0}),
@@ -558,10 +524,10 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
text_field_all_unique_terms => "coolo",
text_field_many_terms => "cool",
text_field_many_terms => "cool",
text_field_few_terms => "cool",
text_field_few_terms => "cool",
text_field_few_terms_status => log_level_sample_a,
text_field_few_terms_status => log_level_sample_b,
text_field_1000_terms_zipf => term_1000_a.as_str(),
text_field_1000_terms_zipf => term_1000_b.as_str(),
score_field => 1u64,
score_field => 1u64,
score_field_f64 => lg_norm.sample(&mut rng),
@@ -576,8 +542,8 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
}
let _val_max = 1_000_000.0;
for _ in 0..doc_with_value {
let val: f64 = rng.random_range(0.0..1_000_000.0);
let json = if rng.random_bool(0.1) {
let val: f64 = rng.gen_range(0.0..1_000_000.0);
let json = if rng.gen_bool(0.1) {
// 10% are numeric values
json!({ "mixed_type": val })
} else {
@@ -586,10 +552,10 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
index_writer.add_document(doc!(
text_field => "cool",
json_field => json,
text_field_all_unique_terms => format!("unique_term_{}", rng.random::<u64>()),
text_field_all_unique_terms => format!("unique_term_{}", rng.gen::<u64>()),
text_field_many_terms => many_terms_data.choose(&mut rng).unwrap().to_string(),
text_field_few_terms_status => status_field_data[log_level_distribution.sample(&mut rng)].0,
text_field_1000_terms_zipf => terms_1000[zipf_1000.sample(&mut rng) as usize - 1].as_str(),
text_field_few_terms => few_terms_data.choose(&mut rng).unwrap().to_string(),
text_field_few_terms_status => few_terms_data[log_level_distribution.sample(&mut rng)],
score_field => val as u64,
score_field_f64 => lg_norm.sample(&mut rng),
score_field_i64 => val as i64,
@@ -641,7 +607,7 @@ fn filter_agg_all_query_with_sub_aggs(index: &Index) {
"avg_score": { "avg": { "field": "score" } },
"stats_score": { "stats": { "field": "score_f64" } },
"terms_text": {
"terms": { "field": "text_few_terms_status" }
"terms": { "field": "text_few_terms" }
}
}
}
@@ -657,7 +623,7 @@ fn filter_agg_term_query_with_sub_aggs(index: &Index) {
"avg_score": { "avg": { "field": "score" } },
"stats_score": { "stats": { "field": "score_f64" } },
"terms_text": {
"terms": { "field": "text_few_terms_status" }
"terms": { "field": "text_few_terms" }
}
}
}

View File

@@ -55,29 +55,29 @@ fn build_shared_indices(num_docs: usize, p_a: f32, p_b: f32, p_c: f32) -> (Bench
{
let mut writer = index.writer_with_num_threads(1, 500_000_000).unwrap();
for _ in 0..num_docs {
let has_a = rng.random_bool(p_a as f64);
let has_b = rng.random_bool(p_b as f64);
let has_c = rng.random_bool(p_c as f64);
let score = rng.random_range(0u64..100u64);
let score2 = rng.random_range(0u64..100_000u64);
let has_a = rng.gen_bool(p_a as f64);
let has_b = rng.gen_bool(p_b as f64);
let has_c = rng.gen_bool(p_c as f64);
let score = rng.gen_range(0u64..100u64);
let score2 = rng.gen_range(0u64..100_000u64);
let mut title_tokens: Vec<&str> = Vec::new();
let mut body_tokens: Vec<&str> = Vec::new();
if has_a {
if rng.random_bool(0.1) {
if rng.gen_bool(0.1) {
title_tokens.push("a");
} else {
body_tokens.push("a");
}
}
if has_b {
if rng.random_bool(0.1) {
if rng.gen_bool(0.1) {
title_tokens.push("b");
} else {
body_tokens.push("b");
}
}
if has_c {
if rng.random_bool(0.1) {
if rng.gen_bool(0.1) {
title_tokens.push("c");
} else {
body_tokens.push("c");

View File

@@ -1,288 +0,0 @@
use binggan::{black_box, BenchGroup, BenchRunner};
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::collector::{Collector, Count, DocSetCollector, TopDocs};
use tantivy::query::{Query, QueryParser};
use tantivy::schema::{Schema, FAST, INDEXED, TEXT};
use tantivy::{doc, Index, Order, ReloadPolicy, Searcher};
#[derive(Clone)]
struct BenchIndex {
#[allow(dead_code)]
index: Index,
searcher: Searcher,
query_parser: QueryParser,
}
fn build_shared_indices(num_docs: usize, p_title_a: f32, distribution: &str) -> BenchIndex {
// Unified schema
let mut schema_builder = Schema::builder();
let f_title = schema_builder.add_text_field("title", TEXT);
let f_num_rand = schema_builder.add_u64_field("num_rand", INDEXED);
let f_num_asc = schema_builder.add_u64_field("num_asc", INDEXED);
let f_num_rand_fast = schema_builder.add_u64_field("num_rand_fast", INDEXED | FAST);
let f_num_asc_fast = schema_builder.add_u64_field("num_asc_fast", INDEXED | FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
// Populate index with stable RNG for reproducibility.
let mut rng = StdRng::from_seed([7u8; 32]);
{
let mut writer = index.writer_with_num_threads(1, 4_000_000_000).unwrap();
match distribution {
"dense" => {
for doc_id in 0..num_docs {
// Always add title to avoid empty documents
let title_token = if rng.random_bool(p_title_a as f64) {
"a"
} else {
"b"
};
let num_rand = rng.random_range(0u64..1000u64);
let num_asc = (doc_id / 10000) as u64;
writer
.add_document(doc!(
f_title=>title_token,
f_num_rand=>num_rand,
f_num_asc=>num_asc,
f_num_rand_fast=>num_rand,
f_num_asc_fast=>num_asc,
))
.unwrap();
}
}
"sparse" => {
for doc_id in 0..num_docs {
// Always add title to avoid empty documents
let title_token = if rng.random_bool(p_title_a as f64) {
"a"
} else {
"b"
};
let num_rand = rng.random_range(0u64..10000000u64);
let num_asc = doc_id as u64;
writer
.add_document(doc!(
f_title=>title_token,
f_num_rand=>num_rand,
f_num_asc=>num_asc,
f_num_rand_fast=>num_rand,
f_num_asc_fast=>num_asc,
))
.unwrap();
}
}
_ => {
panic!("Unsupported distribution type");
}
}
writer.commit().unwrap();
}
// Prepare reader/searcher once.
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::Manual)
.try_into()
.unwrap();
let searcher = reader.searcher();
// Build query parser for title field
let qp_title = QueryParser::for_index(&index, vec![f_title]);
BenchIndex {
index,
searcher,
query_parser: qp_title,
}
}
fn main() {
// Prepare corpora with varying scenarios
let scenarios = vec![
(
"dense and 99% a".to_string(),
10_000_000,
0.99,
"dense",
0,
9,
),
(
"dense and 99% a".to_string(),
10_000_000,
0.99,
"dense",
990,
999,
),
(
"sparse and 99% a".to_string(),
10_000_000,
0.99,
"sparse",
0,
9,
),
(
"sparse and 99% a".to_string(),
10_000_000,
0.99,
"sparse",
9_999_990,
9_999_999,
),
];
let mut runner = BenchRunner::new();
for (scenario_id, n, p_title_a, num_rand_distribution, range_low, range_high) in scenarios {
// Build index for this scenario
let bench_index = build_shared_indices(n, p_title_a, num_rand_distribution);
// Create benchmark group
let mut group = runner.new_group();
// Now set the name (this moves scenario_id)
group.set_name(scenario_id);
// Define all four field types
let field_names = ["num_rand", "num_asc", "num_rand_fast", "num_asc_fast"];
// Define the three terms we want to test with
let terms = ["a", "b", "z"];
// Generate all combinations of terms and field names
let mut queries = Vec::new();
for &term in &terms {
for &field_name in &field_names {
let query_str = format!(
"{} AND {}:[{} TO {}]",
term, field_name, range_low, range_high
);
queries.push((query_str, field_name.to_string()));
}
}
let query_str = format!(
"{}:[{} TO {}] AND {}:[{} TO {}]",
"num_rand_fast", range_low, range_high, "num_asc_fast", range_low, range_high
);
queries.push((query_str, "num_asc_fast".to_string()));
// Run all benchmark tasks for each query and its corresponding field name
for (query_str, field_name) in queries {
run_benchmark_tasks(&mut group, &bench_index, &query_str, &field_name);
}
group.run();
}
}
/// Run all benchmark tasks for a given query string and field name
fn run_benchmark_tasks(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query_str: &str,
field_name: &str,
) {
// Test count
add_bench_task(bench_group, bench_index, query_str, Count, "count");
// Test all results
add_bench_task(
bench_group,
bench_index,
query_str,
DocSetCollector,
"all results",
);
// Test top 100 by the field (if it's a FAST field)
if field_name.ends_with("_fast") {
// Ascending order
{
let collector_name = format!("top100_by_{}_asc", field_name);
let field_name_owned = field_name.to_string();
add_bench_task(
bench_group,
bench_index,
query_str,
TopDocs::with_limit(100).order_by_fast_field::<u64>(field_name_owned, Order::Asc),
&collector_name,
);
}
// Descending order
{
let collector_name = format!("top100_by_{}_desc", field_name);
let field_name_owned = field_name.to_string();
add_bench_task(
bench_group,
bench_index,
query_str,
TopDocs::with_limit(100).order_by_fast_field::<u64>(field_name_owned, Order::Desc),
&collector_name,
);
}
}
}
fn add_bench_task<C: Collector + 'static>(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query_str: &str,
collector: C,
collector_name: &str,
) {
let task_name = format!("{}_{}", query_str.replace(" ", "_"), collector_name);
let query = bench_index.query_parser.parse_query(query_str).unwrap();
let search_task = SearchTask {
searcher: bench_index.searcher.clone(),
collector,
query,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
struct SearchTask<C: Collector> {
searcher: Searcher,
collector: C,
query: Box<dyn Query>,
}
impl<C: Collector> SearchTask<C> {
#[inline(never)]
pub fn run(&self) -> usize {
let result = self.searcher.search(&self.query, &self.collector).unwrap();
if let Some(count) = (&result as &dyn std::any::Any).downcast_ref::<usize>() {
*count
} else if let Some(top_docs) = (&result as &dyn std::any::Any)
.downcast_ref::<Vec<(Option<u64>, tantivy::DocAddress)>>()
{
top_docs.len()
} else if let Some(top_docs) =
(&result as &dyn std::any::Any).downcast_ref::<Vec<(u64, tantivy::DocAddress)>>()
{
top_docs.len()
} else if let Some(doc_set) = (&result as &dyn std::any::Any)
.downcast_ref::<std::collections::HashSet<tantivy::DocAddress>>()
{
doc_set.len()
} else {
eprintln!(
"Unknown collector result type: {:?}",
std::any::type_name::<C::Fruit>()
);
0
}
}
}

View File

@@ -1,365 +0,0 @@
use std::ops::Bound;
use binggan::{black_box, BenchGroup, BenchRunner};
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::collector::{Count, DocSetCollector, TopDocs};
use tantivy::query::RangeQuery;
use tantivy::schema::{Schema, FAST, INDEXED};
use tantivy::{doc, Index, Order, ReloadPolicy, Searcher, Term};
#[derive(Clone)]
struct BenchIndex {
#[allow(dead_code)]
index: Index,
searcher: Searcher,
}
fn build_shared_indices(num_docs: usize, distribution: &str) -> BenchIndex {
// Schema with fast fields only
let mut schema_builder = Schema::builder();
let f_num_rand_fast = schema_builder.add_u64_field("num_rand_fast", INDEXED | FAST);
let f_num_asc_fast = schema_builder.add_u64_field("num_asc_fast", INDEXED | FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
// Populate index with stable RNG for reproducibility.
let mut rng = StdRng::from_seed([7u8; 32]);
{
let mut writer = index.writer_with_num_threads(1, 4_000_000_000).unwrap();
match distribution {
"dense" => {
for doc_id in 0..num_docs {
let num_rand = rng.random_range(0u64..1000u64);
let num_asc = (doc_id / 10000) as u64;
writer
.add_document(doc!(
f_num_rand_fast=>num_rand,
f_num_asc_fast=>num_asc,
))
.unwrap();
}
}
"sparse" => {
for doc_id in 0..num_docs {
let num_rand = rng.random_range(0u64..10000000u64);
let num_asc = doc_id as u64;
writer
.add_document(doc!(
f_num_rand_fast=>num_rand,
f_num_asc_fast=>num_asc,
))
.unwrap();
}
}
_ => {
panic!("Unsupported distribution type");
}
}
writer.commit().unwrap();
}
// Prepare reader/searcher once.
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::Manual)
.try_into()
.unwrap();
let searcher = reader.searcher();
BenchIndex { index, searcher }
}
fn main() {
// Prepare corpora with varying scenarios
let scenarios = vec![
// Dense distribution - random values in small range (0-999)
(
"dense_values_search_low_value_range".to_string(),
10_000_000,
"dense",
0,
9,
),
(
"dense_values_search_high_value_range".to_string(),
10_000_000,
"dense",
990,
999,
),
(
"dense_values_search_out_of_range".to_string(),
10_000_000,
"dense",
1000,
1002,
),
(
"sparse_values_search_low_value_range".to_string(),
10_000_000,
"sparse",
0,
9,
),
(
"sparse_values_search_high_value_range".to_string(),
10_000_000,
"sparse",
9_999_990,
9_999_999,
),
(
"sparse_values_search_out_of_range".to_string(),
10_000_000,
"sparse",
10_000_000,
10_000_002,
),
];
let mut runner = BenchRunner::new();
for (scenario_id, n, num_rand_distribution, range_low, range_high) in scenarios {
// Build index for this scenario
let bench_index = build_shared_indices(n, num_rand_distribution);
// Create benchmark group
let mut group = runner.new_group();
// Now set the name (this moves scenario_id)
group.set_name(scenario_id);
// Define fast field types
let field_names = ["num_rand_fast", "num_asc_fast"];
// Generate range queries for fast fields
for &field_name in &field_names {
// Create the range query
let field = bench_index.searcher.schema().get_field(field_name).unwrap();
let lower_term = Term::from_field_u64(field, range_low);
let upper_term = Term::from_field_u64(field, range_high);
let query = RangeQuery::new(Bound::Included(lower_term), Bound::Included(upper_term));
run_benchmark_tasks(
&mut group,
&bench_index,
query,
field_name,
range_low,
range_high,
);
}
group.run();
}
}
/// Run all benchmark tasks for a given range query and field name
fn run_benchmark_tasks(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
field_name: &str,
range_low: u64,
range_high: u64,
) {
// Test count
add_bench_task_count(
bench_group,
bench_index,
query.clone(),
"count",
field_name,
range_low,
range_high,
);
// Test top 100 by the field (ascending order)
{
let collector_name = format!("top100_by_{}_asc", field_name);
let field_name_owned = field_name.to_string();
add_bench_task_top100_asc(
bench_group,
bench_index,
query.clone(),
&collector_name,
field_name,
range_low,
range_high,
field_name_owned,
);
}
// Test top 100 by the field (descending order)
{
let collector_name = format!("top100_by_{}_desc", field_name);
let field_name_owned = field_name.to_string();
add_bench_task_top100_desc(
bench_group,
bench_index,
query,
&collector_name,
field_name,
range_low,
range_high,
field_name_owned,
);
}
}
fn add_bench_task_count(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
collector_name: &str,
field_name: &str,
range_low: u64,
range_high: u64,
) {
let task_name = format!(
"range_{}_[{} TO {}]_{}",
field_name, range_low, range_high, collector_name
);
let search_task = CountSearchTask {
searcher: bench_index.searcher.clone(),
query,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
fn add_bench_task_docset(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
collector_name: &str,
field_name: &str,
range_low: u64,
range_high: u64,
) {
let task_name = format!(
"range_{}_[{} TO {}]_{}",
field_name, range_low, range_high, collector_name
);
let search_task = DocSetSearchTask {
searcher: bench_index.searcher.clone(),
query,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
fn add_bench_task_top100_asc(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
collector_name: &str,
field_name: &str,
range_low: u64,
range_high: u64,
field_name_owned: String,
) {
let task_name = format!(
"range_{}_[{} TO {}]_{}",
field_name, range_low, range_high, collector_name
);
let search_task = Top100AscSearchTask {
searcher: bench_index.searcher.clone(),
query,
field_name: field_name_owned,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
fn add_bench_task_top100_desc(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
collector_name: &str,
field_name: &str,
range_low: u64,
range_high: u64,
field_name_owned: String,
) {
let task_name = format!(
"range_{}_[{} TO {}]_{}",
field_name, range_low, range_high, collector_name
);
let search_task = Top100DescSearchTask {
searcher: bench_index.searcher.clone(),
query,
field_name: field_name_owned,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
struct CountSearchTask {
searcher: Searcher,
query: RangeQuery,
}
impl CountSearchTask {
#[inline(never)]
pub fn run(&self) -> usize {
self.searcher.search(&self.query, &Count).unwrap()
}
}
struct DocSetSearchTask {
searcher: Searcher,
query: RangeQuery,
}
impl DocSetSearchTask {
#[inline(never)]
pub fn run(&self) -> usize {
let result = self.searcher.search(&self.query, &DocSetCollector).unwrap();
result.len()
}
}
struct Top100AscSearchTask {
searcher: Searcher,
query: RangeQuery,
field_name: String,
}
impl Top100AscSearchTask {
#[inline(never)]
pub fn run(&self) -> usize {
let collector =
TopDocs::with_limit(100).order_by_fast_field::<u64>(&self.field_name, Order::Asc);
let result = self.searcher.search(&self.query, &collector).unwrap();
for (_score, doc_address) in &result {
let _doc: tantivy::TantivyDocument = self.searcher.doc(*doc_address).unwrap();
}
result.len()
}
}
struct Top100DescSearchTask {
searcher: Searcher,
query: RangeQuery,
field_name: String,
}
impl Top100DescSearchTask {
#[inline(never)]
pub fn run(&self) -> usize {
let collector =
TopDocs::with_limit(100).order_by_fast_field::<u64>(&self.field_name, Order::Desc);
let result = self.searcher.search(&self.query, &collector).unwrap();
for (_score, doc_address) in &result {
let _doc: tantivy::TantivyDocument = self.searcher.doc(*doc_address).unwrap();
}
result.len()
}
}

View File

@@ -1,260 +0,0 @@
use std::fmt::Display;
use std::net::Ipv6Addr;
use std::ops::RangeInclusive;
use binggan::plugins::PeakMemAllocPlugin;
use binggan::{black_box, BenchRunner, OutputValue, PeakMemAlloc, INSTRUMENTED_SYSTEM};
use columnar::MonotonicallyMappableToU128;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use tantivy::collector::{Count, TopDocs};
use tantivy::query::QueryParser;
use tantivy::schema::*;
use tantivy::{doc, Index};
#[global_allocator]
pub static GLOBAL: &PeakMemAlloc<std::alloc::System> = &INSTRUMENTED_SYSTEM;
fn main() {
bench_range_query();
}
fn bench_range_query() {
let index = get_index_0_to_100();
let mut runner = BenchRunner::new();
runner.add_plugin(PeakMemAllocPlugin::new(GLOBAL));
runner.set_name("range_query on u64");
let field_name_and_descr: Vec<_> = vec![
("id", "Single Valued Range Field"),
("ids", "Multi Valued Range Field"),
];
let range_num_hits = vec![
("90_percent", get_90_percent()),
("10_percent", get_10_percent()),
("1_percent", get_1_percent()),
];
test_range(&mut runner, &index, &field_name_and_descr, range_num_hits);
runner.set_name("range_query on ip");
let field_name_and_descr: Vec<_> = vec![
("ip", "Single Valued Range Field"),
("ips", "Multi Valued Range Field"),
];
let range_num_hits = vec![
("90_percent", get_90_percent_ip()),
("10_percent", get_10_percent_ip()),
("1_percent", get_1_percent_ip()),
];
test_range(&mut runner, &index, &field_name_and_descr, range_num_hits);
}
fn test_range<T: Display>(
runner: &mut BenchRunner,
index: &Index,
field_name_and_descr: &[(&str, &str)],
range_num_hits: Vec<(&str, RangeInclusive<T>)>,
) {
for (field, suffix) in field_name_and_descr {
let term_num_hits = vec![
("", ""),
("1_percent", "veryfew"),
("10_percent", "few"),
("90_percent", "most"),
];
let mut group = runner.new_group();
group.set_name(suffix);
// all intersect combinations
for (range_name, range) in &range_num_hits {
for (term_name, term) in &term_num_hits {
let index = &index;
let test_name = if term_name.is_empty() {
format!("id_range_hit_{}", range_name)
} else {
format!(
"id_range_hit_{}_intersect_with_term_{}",
range_name, term_name
)
};
group.register(test_name, move |_| {
let query = if term_name.is_empty() {
"".to_string()
} else {
format!("AND id_name:{}", term)
};
black_box(execute_query(field, range, &query, index));
});
}
}
group.run();
}
}
fn get_index_0_to_100() -> Index {
let mut rng = StdRng::from_seed([1u8; 32]);
let num_vals = 100_000;
let docs: Vec<_> = (0..num_vals)
.map(|_i| {
let id_name = if rng.random_bool(0.01) {
"veryfew".to_string() // 1%
} else if rng.random_bool(0.1) {
"few".to_string() // 9%
} else {
"most".to_string() // 90%
};
Doc {
id_name,
id: rng.random_range(0..100),
// Multiply by 1000, so that we create most buckets in the compact space
// The benches depend on this range to select n-percent of elements with the
// methods below.
ip: Ipv6Addr::from_u128(rng.random_range(0..100) * 1000),
}
})
.collect();
create_index_from_docs(&docs)
}
#[derive(Clone, Debug)]
pub struct Doc {
pub id_name: String,
pub id: u64,
pub ip: Ipv6Addr,
}
pub fn create_index_from_docs(docs: &[Doc]) -> Index {
let mut schema_builder = Schema::builder();
let id_u64_field = schema_builder.add_u64_field("id", INDEXED | STORED | FAST);
let ids_u64_field =
schema_builder.add_u64_field("ids", NumericOptions::default().set_fast().set_indexed());
let id_f64_field = schema_builder.add_f64_field("id_f64", INDEXED | STORED | FAST);
let ids_f64_field = schema_builder.add_f64_field(
"ids_f64",
NumericOptions::default().set_fast().set_indexed(),
);
let id_i64_field = schema_builder.add_i64_field("id_i64", INDEXED | STORED | FAST);
let ids_i64_field = schema_builder.add_i64_field(
"ids_i64",
NumericOptions::default().set_fast().set_indexed(),
);
let text_field = schema_builder.add_text_field("id_name", STRING | STORED);
let text_field2 = schema_builder.add_text_field("id_name_fast", STRING | STORED | FAST);
let ip_field = schema_builder.add_ip_addr_field("ip", FAST);
let ips_field = schema_builder.add_ip_addr_field("ips", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
{
let mut index_writer = index.writer_with_num_threads(1, 50_000_000).unwrap();
for doc in docs.iter() {
index_writer
.add_document(doc!(
ids_i64_field => doc.id as i64,
ids_i64_field => doc.id as i64,
ids_f64_field => doc.id as f64,
ids_f64_field => doc.id as f64,
ids_u64_field => doc.id,
ids_u64_field => doc.id,
id_u64_field => doc.id,
id_f64_field => doc.id as f64,
id_i64_field => doc.id as i64,
text_field => doc.id_name.to_string(),
text_field2 => doc.id_name.to_string(),
ips_field => doc.ip,
ips_field => doc.ip,
ip_field => doc.ip,
))
.unwrap();
}
index_writer.commit().unwrap();
}
index
}
fn get_90_percent() -> RangeInclusive<u64> {
0..=90
}
fn get_10_percent() -> RangeInclusive<u64> {
0..=10
}
fn get_1_percent() -> RangeInclusive<u64> {
10..=10
}
fn get_90_percent_ip() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(0);
let end = Ipv6Addr::from_u128(90 * 1000);
start..=end
}
fn get_10_percent_ip() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(0);
let end = Ipv6Addr::from_u128(10 * 1000);
start..=end
}
fn get_1_percent_ip() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(10 * 1000);
let end = Ipv6Addr::from_u128(10 * 1000);
start..=end
}
struct NumHits {
count: usize,
}
impl OutputValue for NumHits {
fn column_title() -> &'static str {
"NumHits"
}
fn format(&self) -> Option<String> {
Some(self.count.to_string())
}
}
fn execute_query<T: Display>(
field: &str,
id_range: &RangeInclusive<T>,
suffix: &str,
index: &Index,
) -> NumHits {
let gen_query_inclusive = |from: &T, to: &T| {
format!(
"{}:[{} TO {}] {}",
field,
&from.to_string(),
&to.to_string(),
suffix
)
};
let query = gen_query_inclusive(id_range.start(), id_range.end());
execute_query_(&query, index)
}
fn execute_query_(query: &str, index: &Index) -> NumHits {
let query_from_text = |text: &str| {
QueryParser::for_index(index, vec![])
.parse_query(text)
.unwrap()
};
let query = query_from_text(query);
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let num_hits = searcher
.search(&query, &(TopDocs::with_limit(10).order_by_score(), Count))
.unwrap()
.1;
NumHits { count: num_hits }
}

View File

@@ -1,421 +0,0 @@
// This benchmark compares different approaches for retrieving string values:
//
// 1. Fast Field Approach: retrieves string values via term_ords() and ord_to_str()
//
// 2. Doc Store Approach: retrieves string values via searcher.doc() and field extraction
//
// The benchmark includes various data distributions:
// - Dense Sequential: Sequential document IDs with dense data
// - Dense Random: Random document IDs with dense data
// - Sparse Sequential: Sequential document IDs with sparse data
// - Sparse Random: Random document IDs with sparse data
use std::ops::Bound;
use binggan::{black_box, BenchGroup, BenchRunner};
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::collector::{Count, DocSetCollector};
use tantivy::query::RangeQuery;
use tantivy::schema::document::TantivyDocument;
use tantivy::schema::{Schema, Value, FAST, STORED, STRING};
use tantivy::{doc, Index, ReloadPolicy, Searcher, Term};
#[derive(Clone)]
struct BenchIndex {
#[allow(dead_code)]
index: Index,
searcher: Searcher,
}
fn build_shared_indices(num_docs: usize, distribution: &str) -> BenchIndex {
// Schema with string fast field and stored field for doc access
let mut schema_builder = Schema::builder();
let f_str_fast = schema_builder.add_text_field("str_fast", STRING | STORED | FAST);
let f_str_stored = schema_builder.add_text_field("str_stored", STRING | STORED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
// Populate index with stable RNG for reproducibility.
let mut rng = StdRng::from_seed([7u8; 32]);
{
let mut writer = index.writer_with_num_threads(1, 4_000_000_000).unwrap();
match distribution {
"dense_random" => {
for _doc_id in 0..num_docs {
let suffix = rng.random_range(0u64..1000u64);
let str_val = format!("str_{:03}", suffix);
writer
.add_document(doc!(
f_str_fast=>str_val.clone(),
f_str_stored=>str_val,
))
.unwrap();
}
}
"dense_sequential" => {
for doc_id in 0..num_docs {
let suffix = doc_id as u64 % 1000;
let str_val = format!("str_{:03}", suffix);
writer
.add_document(doc!(
f_str_fast=>str_val.clone(),
f_str_stored=>str_val,
))
.unwrap();
}
}
"sparse_random" => {
for _doc_id in 0..num_docs {
let suffix = rng.random_range(0u64..1000000u64);
let str_val = format!("str_{:07}", suffix);
writer
.add_document(doc!(
f_str_fast=>str_val.clone(),
f_str_stored=>str_val,
))
.unwrap();
}
}
"sparse_sequential" => {
for doc_id in 0..num_docs {
let suffix = doc_id as u64;
let str_val = format!("str_{:07}", suffix);
writer
.add_document(doc!(
f_str_fast=>str_val.clone(),
f_str_stored=>str_val,
))
.unwrap();
}
}
_ => {
panic!("Unsupported distribution type");
}
}
writer.commit().unwrap();
}
// Prepare reader/searcher once.
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::Manual)
.try_into()
.unwrap();
let searcher = reader.searcher();
BenchIndex { index, searcher }
}
fn main() {
// Prepare corpora with varying scenarios
let scenarios = vec![
(
"dense_random_search_low_range".to_string(),
1_000_000,
"dense_random",
0,
9,
),
(
"dense_random_search_high_range".to_string(),
1_000_000,
"dense_random",
990,
999,
),
(
"dense_sequential_search_low_range".to_string(),
1_000_000,
"dense_sequential",
0,
9,
),
(
"dense_sequential_search_high_range".to_string(),
1_000_000,
"dense_sequential",
990,
999,
),
(
"sparse_random_search_low_range".to_string(),
1_000_000,
"sparse_random",
0,
9999,
),
(
"sparse_random_search_high_range".to_string(),
1_000_000,
"sparse_random",
990_000,
999_999,
),
(
"sparse_sequential_search_low_range".to_string(),
1_000_000,
"sparse_sequential",
0,
9999,
),
(
"sparse_sequential_search_high_range".to_string(),
1_000_000,
"sparse_sequential",
990_000,
999_999,
),
];
let mut runner = BenchRunner::new();
for (scenario_id, n, distribution, range_low, range_high) in scenarios {
let bench_index = build_shared_indices(n, distribution);
let mut group = runner.new_group();
group.set_name(scenario_id);
let field = bench_index.searcher.schema().get_field("str_fast").unwrap();
let (lower_str, upper_str) =
if distribution == "dense_sequential" || distribution == "dense_random" {
(
format!("str_{:03}", range_low),
format!("str_{:03}", range_high),
)
} else {
(
format!("str_{:07}", range_low),
format!("str_{:07}", range_high),
)
};
let lower_term = Term::from_field_text(field, &lower_str);
let upper_term = Term::from_field_text(field, &upper_str);
let query = RangeQuery::new(Bound::Included(lower_term), Bound::Included(upper_term));
run_benchmark_tasks(&mut group, &bench_index, query, range_low, range_high);
group.run();
}
}
/// Run all benchmark tasks for a given range query
fn run_benchmark_tasks(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
range_low: u64,
range_high: u64,
) {
// Test count of matching documents
add_bench_task_count(
bench_group,
bench_index,
query.clone(),
range_low,
range_high,
);
// Test fetching all DocIds of matching documents
add_bench_task_docset(
bench_group,
bench_index,
query.clone(),
range_low,
range_high,
);
// Test fetching all string fast field values of matching documents
add_bench_task_fetch_all_strings(
bench_group,
bench_index,
query.clone(),
range_low,
range_high,
);
// Test fetching all string values of matching documents through doc() method
add_bench_task_fetch_all_strings_from_doc(
bench_group,
bench_index,
query,
range_low,
range_high,
);
}
fn add_bench_task_count(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
range_low: u64,
range_high: u64,
) {
let task_name = format!("string_search_count_[{}-{}]", range_low, range_high);
let search_task = CountSearchTask {
searcher: bench_index.searcher.clone(),
query,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
fn add_bench_task_docset(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
range_low: u64,
range_high: u64,
) {
let task_name = format!("string_fetch_all_docset_[{}-{}]", range_low, range_high);
let search_task = DocSetSearchTask {
searcher: bench_index.searcher.clone(),
query,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
fn add_bench_task_fetch_all_strings(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
range_low: u64,
range_high: u64,
) {
let task_name = format!(
"string_fastfield_fetch_all_strings_[{}-{}]",
range_low, range_high
);
let search_task = FetchAllStringsSearchTask {
searcher: bench_index.searcher.clone(),
query,
};
bench_group.register(task_name, move |_| {
let result = black_box(search_task.run());
result.len()
});
}
fn add_bench_task_fetch_all_strings_from_doc(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
range_low: u64,
range_high: u64,
) {
let task_name = format!(
"string_doc_fetch_all_strings_[{}-{}]",
range_low, range_high
);
let search_task = FetchAllStringsFromDocTask {
searcher: bench_index.searcher.clone(),
query,
};
bench_group.register(task_name, move |_| {
let result = black_box(search_task.run());
result.len()
});
}
struct CountSearchTask {
searcher: Searcher,
query: RangeQuery,
}
impl CountSearchTask {
#[inline(never)]
pub fn run(&self) -> usize {
self.searcher.search(&self.query, &Count).unwrap()
}
}
struct DocSetSearchTask {
searcher: Searcher,
query: RangeQuery,
}
impl DocSetSearchTask {
#[inline(never)]
pub fn run(&self) -> usize {
let result = self.searcher.search(&self.query, &DocSetCollector).unwrap();
result.len()
}
}
struct FetchAllStringsSearchTask {
searcher: Searcher,
query: RangeQuery,
}
impl FetchAllStringsSearchTask {
#[inline(never)]
pub fn run(&self) -> Vec<String> {
let doc_addresses = self.searcher.search(&self.query, &DocSetCollector).unwrap();
let mut docs = doc_addresses.into_iter().collect::<Vec<_>>();
docs.sort();
let mut strings = Vec::with_capacity(docs.len());
for doc_address in docs {
let segment_reader = &self.searcher.segment_readers()[doc_address.segment_ord as usize];
let str_column_opt = segment_reader.fast_fields().str("str_fast");
if let Ok(Some(str_column)) = str_column_opt {
let doc_id = doc_address.doc_id;
let term_ord = str_column.term_ords(doc_id).next().unwrap();
let mut str_buffer = String::new();
if str_column.ord_to_str(term_ord, &mut str_buffer).is_ok() {
strings.push(str_buffer);
}
}
}
strings
}
}
struct FetchAllStringsFromDocTask {
searcher: Searcher,
query: RangeQuery,
}
impl FetchAllStringsFromDocTask {
#[inline(never)]
pub fn run(&self) -> Vec<String> {
let doc_addresses = self.searcher.search(&self.query, &DocSetCollector).unwrap();
let mut docs = doc_addresses.into_iter().collect::<Vec<_>>();
docs.sort();
let mut strings = Vec::with_capacity(docs.len());
let str_stored_field = self
.searcher
.schema()
.get_field("str_stored")
.expect("str_stored field should exist");
for doc_address in docs {
// Get the document from the doc store (row store access)
if let Ok(doc) = self.searcher.doc::<TantivyDocument>(doc_address) {
// Extract string values from the stored field
if let Some(field_value) = doc.get_first(str_stored_field) {
if let Some(text) = field_value.as_value().as_str() {
strings.push(text.to_string());
}
}
}
}
strings
}
}

View File

@@ -11,12 +11,9 @@ keywords = []
documentation = "https://docs.rs/tantivy-bitpacker/latest/tantivy_bitpacker"
homepage = "https://github.com/quickwit-oss/tantivy"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
bitpacking = { version = "0.9.2", default-features = false, features = ["bitpacker1x"] }
[dev-dependencies]
rand = "0.9"
rand = "0.8"
proptest = "1"

View File

@@ -4,8 +4,8 @@ extern crate test;
#[cfg(test)]
mod tests {
use rand::rng;
use rand::seq::IteratorRandom;
use rand::thread_rng;
use tantivy_bitpacker::{BitPacker, BitUnpacker, BlockedBitpacker};
use test::Bencher;
@@ -27,7 +27,7 @@ mod tests {
let num_els = 1_000_000u32;
let bit_unpacker = BitUnpacker::new(bit_width);
let data = create_bitpacked_data(bit_width, num_els);
let idxs: Vec<u32> = (0..num_els).choose_multiple(&mut rng(), 100_000);
let idxs: Vec<u32> = (0..num_els).choose_multiple(&mut thread_rng(), 100_000);
b.iter(|| {
let mut out = 0u64;
for &idx in &idxs {

View File

@@ -48,7 +48,7 @@ impl BitPacker {
pub fn flush<TWrite: io::Write + ?Sized>(&mut self, output: &mut TWrite) -> io::Result<()> {
if self.mini_buffer_written > 0 {
let num_bytes = self.mini_buffer_written.div_ceil(8);
let num_bytes = (self.mini_buffer_written + 7) / 8;
let bytes = self.mini_buffer.to_le_bytes();
output.write_all(&bytes[..num_bytes])?;
self.mini_buffer_written = 0;
@@ -65,10 +65,16 @@ impl BitPacker {
#[derive(Clone, Debug, Default, Copy)]
pub struct BitUnpacker {
num_bits: usize,
num_bits: u32,
mask: u64,
}
pub type BlockNumber = usize;
// 16k
const BLOCK_SIZE_MIN_POW: u8 = 14;
const BLOCK_SIZE_MIN: usize = 2 << BLOCK_SIZE_MIN_POW;
impl BitUnpacker {
/// Creates a bit unpacker, that assumes the same bitwidth for all values.
///
@@ -82,8 +88,9 @@ impl BitUnpacker {
} else {
(1u64 << num_bits) - 1u64
};
BitUnpacker {
num_bits: usize::from(num_bits),
num_bits: u32::from(num_bits),
mask,
}
}
@@ -92,16 +99,69 @@ impl BitUnpacker {
self.num_bits as u8
}
/// Calculates a block number for the given `idx`.
#[inline]
pub fn block_num(&self, idx: u32) -> BlockNumber {
// Find the address in bits of the index.
let addr_in_bits = (idx * self.num_bits) as usize;
// Then round down to the nearest byte.
let addr_in_bytes = addr_in_bits >> 3;
// And compute the containing BlockNumber.
addr_in_bytes >> (BLOCK_SIZE_MIN_POW + 1)
}
/// Given a block number and dataset length, calculates a data Range for the block.
pub fn block(&self, block: BlockNumber, data_len: usize) -> Range<usize> {
let block_addr = block << (BLOCK_SIZE_MIN_POW + 1);
// We extend the end of the block by a constant factor, so that it overlaps the next
// block. That ensures that we never need to read on a block boundary.
block_addr..(std::cmp::min(block_addr + BLOCK_SIZE_MIN + 8, data_len))
}
/// Calculates the number of blocks for the given data_len.
///
/// Usually only called at startup to pre-allocate structures.
pub fn block_count(&self, data_len: usize) -> usize {
let block_count = data_len / (BLOCK_SIZE_MIN as usize);
if data_len % (BLOCK_SIZE_MIN as usize) == 0 {
block_count
} else {
block_count + 1
}
}
/// Returns a range within the data which covers the given id_range.
///
/// NOTE: This method is used for batch reads which bypass blocks to avoid dealing with block
/// boundaries.
#[inline]
pub fn block_oblivious_range(&self, id_range: Range<u32>, data_len: usize) -> Range<usize> {
let start_in_bits = id_range.start * self.num_bits;
let start = (start_in_bits >> 3) as usize;
let end_in_bits = id_range.end * self.num_bits;
let end = (end_in_bits >> 3) as usize;
// TODO: We fetch more than we need and then truncate.
start..(std::cmp::min(end + 8, data_len))
}
#[inline]
pub fn get(&self, idx: u32, data: &[u8]) -> u64 {
let addr_in_bits = idx as usize * self.num_bits;
let addr = addr_in_bits >> 3;
self.get_from_subset(idx, 0, data)
}
/// Get the value at the given idx, which must exist within the given subset of the data.
#[inline]
pub fn get_from_subset(&self, idx: u32, data_offset: usize, data: &[u8]) -> u64 {
let addr_in_bits = idx * self.num_bits;
let addr = (addr_in_bits >> 3) as usize - data_offset;
if addr + 8 > data.len() {
if self.num_bits == 0 {
return 0;
}
let bit_shift = addr_in_bits & 7;
return self.get_slow_path(addr, bit_shift as u32, data);
return self.get_slow_path(addr, bit_shift, data);
}
let bit_shift = addr_in_bits & 7;
let bytes: [u8; 8] = (&data[addr..addr + 8]).try_into().unwrap();
@@ -113,6 +173,7 @@ impl BitUnpacker {
#[inline(never)]
fn get_slow_path(&self, addr: usize, bit_shift: u32, data: &[u8]) -> u64 {
let mut bytes: [u8; 8] = [0u8; 8];
let available_bytes = data.len() - addr;
// This function is meant to only be called if we did not have 8 bytes to load.
debug_assert!(available_bytes < 8);
@@ -128,26 +189,25 @@ impl BitUnpacker {
// #Panics
//
// This methods panics if `num_bits` is > 32.
fn get_batch_u32s(&self, start_idx: u32, data: &[u8], output: &mut [u32]) {
fn get_batch_u32s(&self, start_idx: u32, data_offset: usize, data: &[u8], output: &mut [u32]) {
assert!(
self.bit_width() <= 32,
"Bitwidth must be <= 32 to use this method."
);
let end_idx: u32 = start_idx + output.len() as u32;
let end_idx = start_idx + output.len() as u32;
// We use `usize` here to avoid overflow issues.
let end_bit_read = (end_idx as usize) * self.num_bits;
let end_byte_read = end_bit_read.div_ceil(8);
let end_bit_read = end_idx * self.num_bits;
let end_byte_read = (end_bit_read + 7) / 8;
assert!(
end_byte_read <= data.len(),
end_byte_read as usize <= data_offset + data.len(),
"Requested index is out of bounds."
);
// Simple slow implementation of get_batch_u32s, to deal with our ramps.
let get_batch_ramp = |start_idx: u32, output: &mut [u32]| {
for (out, idx) in output.iter_mut().zip(start_idx..) {
*out = self.get(idx, data) as u32;
*out = self.get_from_subset(idx, data_offset, data) as u32;
}
};
@@ -160,24 +220,24 @@ impl BitUnpacker {
// We want the start of the fast track to start align with bytes.
// A sufficient condition is to start with an idx that is a multiple of 8,
// so highway start is the closest multiple of 8 that is >= start_idx.
let entrance_ramp_len: u32 = 8 - (start_idx % 8) % 8;
let entrance_ramp_len = 8 - (start_idx % 8) % 8;
let highway_start: u32 = start_idx + entrance_ramp_len;
if highway_start + (BitPacker1x::BLOCK_LEN as u32) > end_idx {
if highway_start + BitPacker1x::BLOCK_LEN as u32 > end_idx {
// We don't have enough values to have even a single block of highway.
// Let's just supply the values the simple way.
get_batch_ramp(start_idx, output);
return;
}
let num_blocks: usize = (end_idx - highway_start) as usize / BitPacker1x::BLOCK_LEN;
let num_blocks: u32 = (end_idx - highway_start) / BitPacker1x::BLOCK_LEN as u32;
// Entrance ramp
get_batch_ramp(start_idx, &mut output[..entrance_ramp_len as usize]);
// Highway
let mut offset = (highway_start as usize * self.num_bits) / 8;
let mut offset = ((highway_start * self.num_bits) as usize / 8) - data_offset;
let mut output_cursor = (highway_start - start_idx) as usize;
for _ in 0..num_blocks {
offset += BitPacker1x.decompress(
@@ -189,7 +249,7 @@ impl BitUnpacker {
}
// Exit ramp
let highway_end: u32 = highway_start + (num_blocks * BitPacker1x::BLOCK_LEN) as u32;
let highway_end = highway_start + num_blocks * BitPacker1x::BLOCK_LEN as u32;
get_batch_ramp(highway_end, &mut output[output_cursor..]);
}
@@ -199,16 +259,27 @@ impl BitUnpacker {
id_range: Range<u32>,
data: &[u8],
positions: &mut Vec<u32>,
) {
self.get_ids_for_value_range_from_subset(range, id_range, 0, data, positions)
}
pub fn get_ids_for_value_range_from_subset(
&self,
range: RangeInclusive<u64>,
id_range: Range<u32>,
data_offset: usize,
data: &[u8],
positions: &mut Vec<u32>,
) {
if self.bit_width() > 32 {
self.get_ids_for_value_range_slow(range, id_range, data, positions)
self.get_ids_for_value_range_slow(range, id_range, data_offset, data, positions)
} else {
if *range.start() > u32::MAX as u64 {
positions.clear();
return;
}
let range_u32 = (*range.start() as u32)..=(*range.end()).min(u32::MAX as u64) as u32;
self.get_ids_for_value_range_fast(range_u32, id_range, data, positions)
self.get_ids_for_value_range_fast(range_u32, id_range, data_offset, data, positions)
}
}
@@ -216,6 +287,7 @@ impl BitUnpacker {
&self,
range: RangeInclusive<u64>,
id_range: Range<u32>,
data_offset: usize,
data: &[u8],
positions: &mut Vec<u32>,
) {
@@ -223,7 +295,7 @@ impl BitUnpacker {
for i in id_range {
// If we cared we could make this branchless, but the slow implementation should rarely
// kick in.
let val = self.get(i, data);
let val = self.get_from_subset(i, data_offset, data);
if range.contains(&val) {
positions.push(i);
}
@@ -234,11 +306,12 @@ impl BitUnpacker {
&self,
value_range: RangeInclusive<u32>,
id_range: Range<u32>,
data_offset: usize,
data: &[u8],
positions: &mut Vec<u32>,
) {
positions.resize(id_range.len(), 0u32);
self.get_batch_u32s(id_range.start, data, positions);
self.get_batch_u32s(id_range.start, data_offset, data, positions);
crate::filter_vec::filter_vec_in_place(value_range, id_range.start, positions)
}
}
@@ -329,14 +402,14 @@ mod test {
fn test_get_batch_panics_over_32_bits() {
let bitunpacker = BitUnpacker::new(33);
let mut output: [u32; 1] = [0u32];
bitunpacker.get_batch_u32s(0, &[0, 0, 0, 0, 0, 0, 0, 0], &mut output[..]);
bitunpacker.get_batch_u32s(0, 0, &[0, 0, 0, 0, 0, 0, 0, 0], &mut output[..]);
}
#[test]
fn test_get_batch_limit() {
let bitunpacker = BitUnpacker::new(1);
let mut output: [u32; 3] = [0u32, 0u32, 0u32];
bitunpacker.get_batch_u32s(8 * 4 - 3, &[0u8, 0u8, 0u8, 0u8], &mut output[..]);
bitunpacker.get_batch_u32s(8 * 4 - 3, 0, &[0u8, 0u8, 0u8, 0u8], &mut output[..]);
}
#[test]
@@ -345,7 +418,7 @@ mod test {
let bitunpacker = BitUnpacker::new(1);
let mut output: [u32; 3] = [0u32, 0u32, 0u32];
// We are missing exactly one bit.
bitunpacker.get_batch_u32s(8 * 4 - 2, &[0u8, 0u8, 0u8, 0u8], &mut output[..]);
bitunpacker.get_batch_u32s(8 * 4 - 2, 0, &[0u8, 0u8, 0u8, 0u8], &mut output[..]);
}
proptest::proptest! {
@@ -368,7 +441,7 @@ mod test {
for len in [0, 1, 2, 32, 33, 34, 64] {
for start_idx in 0u32..32u32 {
output.resize(len, 0);
bitunpacker.get_batch_u32s(start_idx, &buffer, &mut output);
bitunpacker.get_batch_u32s(start_idx, 0, &buffer, &mut output);
for (i, output_byte) in output.iter().enumerate() {
let expected = (start_idx + i as u32) & mask;
assert_eq!(*output_byte, expected);

View File

@@ -16,13 +16,13 @@ stacker = { version= "0.6", path = "../stacker", package="tantivy-stacker"}
sstable = { version= "0.6", path = "../sstable", package = "tantivy-sstable" }
common = { version= "0.10", path = "../common", package = "tantivy-common" }
tantivy-bitpacker = { version= "0.9", path = "../bitpacker/" }
serde = "1.0.152"
serde = { version = "1.0.152", features = ["derive"] }
downcast-rs = "2.0.1"
[dev-dependencies]
proptest = "1"
more-asserts = "0.3.1"
rand = "0.9"
rand = "0.8"
binggan = "0.14.0"
[[bench]]

View File

@@ -1,6 +1,6 @@
use binggan::{InputGroup, black_box};
use common::*;
use tantivy_columnar::Column;
use tantivy_columnar::{Column, ValueRange};
pub mod common;
@@ -46,16 +46,16 @@ fn bench_group(mut runner: InputGroup<Column>) {
runner.register("access_first_vals", |column| {
let mut sum = 0;
const BLOCK_SIZE: usize = 32;
let mut docs = vec![0; BLOCK_SIZE];
let mut buffer = vec![None; BLOCK_SIZE];
let mut docs = Vec::with_capacity(BLOCK_SIZE);
let mut buffer = Vec::with_capacity(BLOCK_SIZE);
for i in (0..NUM_DOCS).step_by(BLOCK_SIZE) {
// fill docs
#[allow(clippy::needless_range_loop)]
docs.clear();
for idx in 0..BLOCK_SIZE {
docs[idx] = idx as u32 + i;
docs.push(idx as u32 + i);
}
column.first_vals(&docs, &mut buffer);
buffer.clear();
column.first_vals_in_value_range(&mut docs, &mut buffer, ValueRange::All);
for val in buffer.iter() {
let Some(val) = val else { continue };
sum += *val;

View File

@@ -9,7 +9,7 @@ use tantivy_columnar::column_values::{CodecType, serialize_and_load_u64_based_co
fn get_data() -> Vec<u64> {
let mut rng = StdRng::seed_from_u64(2u64);
let mut data: Vec<_> = (100..55_000_u64)
.map(|num| num + rng.random::<u8>() as u64)
.map(|num| num + rng.r#gen::<u8>() as u64)
.collect();
data.push(99_000);
data.insert(1000, 2000);

View File

@@ -6,7 +6,7 @@ use tantivy_columnar::column_values::{CodecType, serialize_u64_based_column_valu
fn get_data() -> Vec<u64> {
let mut rng = StdRng::seed_from_u64(2u64);
let mut data: Vec<_> = (100..55_000_u64)
.map(|num| num + rng.random::<u8>() as u64)
.map(|num| num + rng.r#gen::<u8>() as u64)
.collect();
data.push(99_000);
data.insert(1000, 2000);

View File

@@ -40,7 +40,14 @@ fn main() {
let columnar_readers = columnar_readers.iter().collect::<Vec<_>>();
let merge_row_order = StackMergeOrder::stack(&columnar_readers[..]);
merge_columnar(&columnar_readers, &[], merge_row_order.into(), &mut out).unwrap();
merge_columnar(
&columnar_readers,
&[],
merge_row_order.into(),
&mut out,
|| false,
)
.unwrap();
Some(out.len() as u64)
},
);

View File

@@ -8,7 +8,7 @@ const TOTAL_NUM_VALUES: u32 = 1_000_000;
fn gen_optional_index(fill_ratio: f64) -> OptionalIndex {
let mut rng: StdRng = StdRng::from_seed([1u8; 32]);
let vals: Vec<u32> = (0..TOTAL_NUM_VALUES)
.map(|_| rng.random_bool(fill_ratio))
.map(|_| rng.gen_bool(fill_ratio))
.enumerate()
.filter(|(_pos, val)| *val)
.map(|(pos, _)| pos as u32)
@@ -25,7 +25,7 @@ fn random_range_iterator(
let mut rng: StdRng = StdRng::from_seed([1u8; 32]);
let mut current = start;
std::iter::from_fn(move || {
current += rng.random_range(avg_step_size - avg_deviation..=avg_step_size + avg_deviation);
current += rng.gen_range(avg_step_size - avg_deviation..=avg_step_size + avg_deviation);
if current >= end { None } else { Some(current) }
})
}

View File

@@ -39,7 +39,7 @@ fn get_data_50percent_item() -> Vec<u128> {
let mut data = vec![];
for _ in 0..300_000 {
let val = rng.random_range(1..=100);
let val = rng.gen_range(1..=100);
data.push(val);
}
data.push(SINGLE_ITEM);

View File

@@ -34,7 +34,7 @@ fn get_data_50percent_item() -> Vec<u128> {
let mut data = vec![];
for _ in 0..300_000 {
let val = rng.random_range(1..=100);
let val = rng.gen_range(1..=100);
data.push(val);
}
data.push(SINGLE_ITEM);

View File

@@ -29,20 +29,12 @@ impl<T: PartialOrd + Copy + std::fmt::Debug + Send + Sync + 'static + Default>
}
}
#[inline]
pub fn fetch_block_with_missing(
&mut self,
docs: &[u32],
accessor: &Column<T>,
missing: Option<T>,
) {
pub fn fetch_block_with_missing(&mut self, docs: &[u32], accessor: &Column<T>, missing: T) {
self.fetch_block(docs, accessor);
// no missing values
if accessor.index.get_cardinality().is_full() {
return;
}
let Some(missing) = missing else {
return;
};
// We can compare docid_cache length with docs to find missing docs
// For multi value columns we can't rely on the length and always need to scan

View File

@@ -1,6 +1,7 @@
mod dictionary_encoded;
mod serialize;
use std::cell::RefCell;
use std::fmt::{self, Debug};
use std::io::Write;
use std::ops::{Range, RangeInclusive};
@@ -19,6 +20,11 @@ use crate::column_values::monotonic_mapping::StrictlyMonotonicMappingToInternal;
use crate::column_values::{ColumnValues, monotonic_map_column};
use crate::{Cardinality, DocId, EmptyColumnValues, MonotonicallyMappableToU64, RowId};
thread_local! {
static ROWS: RefCell<Vec<RowId>> = const { RefCell::new(Vec::new()) };
static DOCS: RefCell<Vec<DocId>> = const { RefCell::new(Vec::new()) };
}
#[derive(Clone)]
pub struct Column<T = u64> {
pub index: ColumnIndex,
@@ -85,33 +91,8 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
}
#[inline]
pub fn first(&self, doc_id: DocId) -> Option<T> {
self.values_for_doc(doc_id).next()
}
/// Load the first value for each docid in the provided slice.
#[inline]
pub fn first_vals(&self, docids: &[DocId], output: &mut [Option<T>]) {
match &self.index {
ColumnIndex::Empty { .. } => {}
ColumnIndex::Full => self.values.get_vals_opt(docids, output),
ColumnIndex::Optional(optional_index) => {
for (i, docid) in docids.iter().enumerate() {
output[i] = optional_index
.rank_if_exists(*docid)
.map(|rowid| self.values.get_val(rowid));
}
}
ColumnIndex::Multivalued(multivalued_index) => {
for (i, docid) in docids.iter().enumerate() {
let range = multivalued_index.range(*docid);
let is_empty = range.start == range.end;
if !is_empty {
output[i] = Some(self.values.get_val(range.start));
}
}
}
}
pub fn first(&self, row_id: RowId) -> Option<T> {
self.values_for_doc(row_id).next()
}
/// Translates a block of docids to row_ids.
@@ -143,7 +124,7 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
#[inline]
pub fn get_docids_for_value_range(
&self,
value_range: RangeInclusive<T>,
value_range: ValueRange<T>,
selected_docid_range: Range<u32>,
doc_ids: &mut Vec<u32>,
) {
@@ -168,6 +149,181 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
}
}
// Separate impl block for methods requiring `Default` for `T`.
impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static + Default> Column<T> {
/// Load the first value for each docid in the provided slice.
///
/// The `docids` vector is mutated: documents that do not match the `value_range` are removed.
/// The `values` vector is populated with the values of the remaining documents.
#[inline]
pub fn first_vals_in_value_range(
&self,
input_docs: &[DocId],
output: &mut Vec<crate::ComparableDoc<Option<T>, DocId>>,
value_range: ValueRange<T>,
) {
match (&self.index, value_range) {
(ColumnIndex::Empty { .. }, value_range) => {
let nulls_match = match &value_range {
ValueRange::All => true,
ValueRange::Inclusive(_) => false,
ValueRange::GreaterThan(_, nulls_match) => *nulls_match,
ValueRange::GreaterThanOrEqual(_, nulls_match) => *nulls_match,
ValueRange::LessThan(_, nulls_match) => *nulls_match,
ValueRange::LessThanOrEqual(_, nulls_match) => *nulls_match,
};
if nulls_match {
for &doc in input_docs {
output.push(crate::ComparableDoc {
doc,
sort_key: None,
});
}
}
}
(ColumnIndex::Full, value_range) => {
self.values
.get_vals_in_value_range(input_docs, input_docs, output, value_range);
}
(ColumnIndex::Optional(optional_index), value_range) => {
let nulls_match = match &value_range {
ValueRange::All => true,
ValueRange::Inclusive(_) => false,
ValueRange::GreaterThan(_, nulls_match) => *nulls_match,
ValueRange::GreaterThanOrEqual(_, nulls_match) => *nulls_match,
ValueRange::LessThan(_, nulls_match) => *nulls_match,
ValueRange::LessThanOrEqual(_, nulls_match) => *nulls_match,
};
let fallback_needed = ROWS.with(|rows_cell| {
DOCS.with(|docs_cell| {
let mut rows = rows_cell.borrow_mut();
let mut docs = docs_cell.borrow_mut();
rows.clear();
docs.clear();
let mut has_nulls = false;
for &doc_id in input_docs {
if let Some(row_id) = optional_index.rank_if_exists(doc_id) {
rows.push(row_id);
docs.push(doc_id);
} else {
has_nulls = true;
if nulls_match {
break;
}
}
}
if !has_nulls || !nulls_match {
self.values.get_vals_in_value_range(
&rows,
&docs,
output,
value_range.clone(),
);
return false;
}
true
})
});
if fallback_needed {
for &doc_id in input_docs {
if let Some(row_id) = optional_index.rank_if_exists(doc_id) {
let val = self.values.get_val(row_id);
let value_matches = match &value_range {
ValueRange::All => true,
ValueRange::Inclusive(r) => r.contains(&val),
ValueRange::GreaterThan(t, _) => val > *t,
ValueRange::GreaterThanOrEqual(t, _) => val >= *t,
ValueRange::LessThan(t, _) => val < *t,
ValueRange::LessThanOrEqual(t, _) => val <= *t,
};
if value_matches {
output.push(crate::ComparableDoc {
doc: doc_id,
sort_key: Some(val),
});
}
} else if nulls_match {
output.push(crate::ComparableDoc {
doc: doc_id,
sort_key: None,
});
}
}
}
}
(ColumnIndex::Multivalued(multivalued_index), value_range) => {
let nulls_match = match &value_range {
ValueRange::All => true,
ValueRange::Inclusive(_) => false,
ValueRange::GreaterThan(_, nulls_match) => *nulls_match,
ValueRange::GreaterThanOrEqual(_, nulls_match) => *nulls_match,
ValueRange::LessThan(_, nulls_match) => *nulls_match,
ValueRange::LessThanOrEqual(_, nulls_match) => *nulls_match,
};
for i in 0..input_docs.len() {
let docid = input_docs[i];
let row_range = multivalued_index.range(docid);
let is_empty = row_range.start == row_range.end;
if !is_empty {
let val = self.values.get_val(row_range.start);
let matches = match &value_range {
ValueRange::All => true,
ValueRange::Inclusive(r) => r.contains(&val),
ValueRange::GreaterThan(t, _) => val > *t,
ValueRange::GreaterThanOrEqual(t, _) => val >= *t,
ValueRange::LessThan(t, _) => val < *t,
ValueRange::LessThanOrEqual(t, _) => val <= *t,
};
if matches {
output.push(crate::ComparableDoc {
doc: docid,
sort_key: Some(val),
});
}
} else if nulls_match {
output.push(crate::ComparableDoc {
doc: docid,
sort_key: None,
});
}
}
}
}
}
}
/// A range of values.
///
/// This type is intended to be used in batch APIs, where the cost of unpacking the enum
/// is outweighed by the time spent processing a batch.
///
/// Implementers should pattern match on the variants to use optimized loops for each case.
#[derive(Clone, Debug)]
pub enum ValueRange<T> {
/// A range that includes both start and end.
Inclusive(RangeInclusive<T>),
/// A range that matches all values.
All,
/// A range that matches all values greater than the threshold.
/// The boolean flag indicates if null values should be included.
GreaterThan(T, bool),
/// A range that matches all values greater than or equal to the threshold.
/// The boolean flag indicates if null values should be included.
GreaterThanOrEqual(T, bool),
/// A range that matches all values less than the threshold.
/// The boolean flag indicates if null values should be included.
LessThan(T, bool),
/// A range that matches all values less than or equal to the threshold.
/// The boolean flag indicates if null values should be included.
LessThanOrEqual(T, bool),
}
impl BinarySerializable for Cardinality {
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> std::io::Result<()> {
self.to_code().serialize(writer)

View File

@@ -2,7 +2,7 @@ use std::io;
use std::io::Write;
use std::sync::Arc;
use common::OwnedBytes;
use common::file_slice::FileSlice;
use sstable::Dictionary;
use crate::column::{BytesColumn, Column};
@@ -41,12 +41,13 @@ pub fn serialize_column_mappable_to_u64<T: MonotonicallyMappableToU64>(
}
pub fn open_column_u64<T: MonotonicallyMappableToU64>(
bytes: OwnedBytes,
file_slice: FileSlice,
format_version: Version,
) -> io::Result<Column<T>> {
let (body, column_index_num_bytes_payload) = bytes.rsplit(4);
let (body, column_index_num_bytes_payload) = file_slice.split_from_end(4);
let column_index_num_bytes = u32::from_le_bytes(
column_index_num_bytes_payload
.read_bytes()?
.as_slice()
.try_into()
.unwrap(),
@@ -61,12 +62,13 @@ pub fn open_column_u64<T: MonotonicallyMappableToU64>(
}
pub fn open_column_u128<T: MonotonicallyMappableToU128>(
bytes: OwnedBytes,
file_slice: FileSlice,
format_version: Version,
) -> io::Result<Column<T>> {
let (body, column_index_num_bytes_payload) = bytes.rsplit(4);
let (body, column_index_num_bytes_payload) = file_slice.split_from_end(4);
let column_index_num_bytes = u32::from_le_bytes(
column_index_num_bytes_payload
.read_bytes()?
.as_slice()
.try_into()
.unwrap(),
@@ -84,12 +86,13 @@ pub fn open_column_u128<T: MonotonicallyMappableToU128>(
///
/// See [`open_u128_as_compact_u64`] for more details.
pub fn open_column_u128_as_compact_u64(
bytes: OwnedBytes,
file_slice: FileSlice,
format_version: Version,
) -> io::Result<Column<u64>> {
let (body, column_index_num_bytes_payload) = bytes.rsplit(4);
let (body, column_index_num_bytes_payload) = file_slice.split_from_end(4);
let column_index_num_bytes = u32::from_le_bytes(
column_index_num_bytes_payload
.read_bytes()?
.as_slice()
.try_into()
.unwrap(),
@@ -103,11 +106,21 @@ pub fn open_column_u128_as_compact_u64(
})
}
pub fn open_column_bytes(data: OwnedBytes, format_version: Version) -> io::Result<BytesColumn> {
let (body, dictionary_len_bytes) = data.rsplit(4);
let dictionary_len = u32::from_le_bytes(dictionary_len_bytes.as_slice().try_into().unwrap());
pub fn open_column_bytes(
file_slice: FileSlice,
format_version: Version,
) -> io::Result<BytesColumn> {
let (body, dictionary_len_bytes) = file_slice.split_from_end(4);
let dictionary_len = u32::from_le_bytes(
dictionary_len_bytes
.read_bytes()?
.as_slice()
.try_into()
.unwrap(),
);
let (dictionary_bytes, column_bytes) = body.split(dictionary_len as usize);
let dictionary = Arc::new(Dictionary::from_bytes(dictionary_bytes)?);
let dictionary = Arc::new(Dictionary::open(dictionary_bytes)?);
let term_ord_column = crate::column::open_column_u64::<u64>(column_bytes, format_version)?;
Ok(BytesColumn {
dictionary,
@@ -115,7 +128,7 @@ pub fn open_column_bytes(data: OwnedBytes, format_version: Version) -> io::Resul
})
}
pub fn open_column_str(data: OwnedBytes, format_version: Version) -> io::Result<StrColumn> {
let bytes_column = open_column_bytes(data, format_version)?;
pub fn open_column_str(file_slice: FileSlice, format_version: Version) -> io::Result<StrColumn> {
let bytes_column = open_column_bytes(file_slice, format_version)?;
Ok(StrColumn::wrap(bytes_column))
}

View File

@@ -95,7 +95,7 @@ pub fn merge_column_index<'a>(
#[cfg(test)]
mod tests {
use common::OwnedBytes;
use common::file_slice::FileSlice;
use crate::column_index::merge::detect_cardinality;
use crate::column_index::multivalued_index::{
@@ -178,7 +178,7 @@ mod tests {
let mut output = Vec::new();
serialize_multivalued_index(&start_index_iterable, &mut output).unwrap();
let multivalue =
open_multivalued_index(OwnedBytes::new(output), crate::Version::V2).unwrap();
open_multivalued_index(FileSlice::from(output), crate::Version::V2).unwrap();
let start_indexes: Vec<RowId> = multivalue.get_start_index_column().iter().collect();
assert_eq!(&start_indexes, &[0, 3, 5]);
}
@@ -216,7 +216,7 @@ mod tests {
let mut output = Vec::new();
serialize_multivalued_index(&start_index_iterable, &mut output).unwrap();
let multivalue =
open_multivalued_index(OwnedBytes::new(output), crate::Version::V2).unwrap();
open_multivalued_index(FileSlice::from(output), crate::Version::V2).unwrap();
let start_indexes: Vec<RowId> = multivalue.get_start_index_column().iter().collect();
assert_eq!(&start_indexes, &[0, 3, 5, 6]);
}

View File

@@ -3,7 +3,8 @@ use std::io::Write;
use std::ops::Range;
use std::sync::Arc;
use common::{CountingWriter, OwnedBytes};
use common::CountingWriter;
use common::file_slice::FileSlice;
use super::optional_index::{open_optional_index, serialize_optional_index};
use super::{OptionalIndex, SerializableOptionalIndex, Set};
@@ -44,21 +45,26 @@ pub fn serialize_multivalued_index(
}
pub fn open_multivalued_index(
bytes: OwnedBytes,
file_slice: FileSlice,
format_version: Version,
) -> io::Result<MultiValueIndex> {
match format_version {
Version::V1 => {
let start_index_column: Arc<dyn ColumnValues<RowId>> =
load_u64_based_column_values(bytes)?;
load_u64_based_column_values(file_slice)?;
Ok(MultiValueIndex::MultiValueIndexV1(MultiValueIndexV1 {
start_index_column,
}))
}
Version::V2 => {
let (body_bytes, optional_index_len) = bytes.rsplit(4);
let optional_index_len =
u32::from_le_bytes(optional_index_len.as_slice().try_into().unwrap());
let (body_bytes, optional_index_len) = file_slice.split_from_end(4);
let optional_index_len = u32::from_le_bytes(
optional_index_len
.read_bytes()?
.as_slice()
.try_into()
.unwrap(),
);
let (optional_index_bytes, start_index_bytes) =
body_bytes.split(optional_index_len as usize);
let optional_index = open_optional_index(optional_index_bytes)?;
@@ -185,8 +191,8 @@ impl MultiValueIndex {
};
let mut buffer = Vec::new();
serialize_multivalued_index(&serializable_multivalued_index, &mut buffer).unwrap();
let bytes = OwnedBytes::new(buffer);
open_multivalued_index(bytes, Version::V2).unwrap()
let file_slice = FileSlice::from(buffer);
open_multivalued_index(file_slice, Version::V2).unwrap()
}
pub fn get_start_index_column(&self) -> &Arc<dyn crate::ColumnValues<RowId>> {
@@ -333,7 +339,7 @@ mod tests {
use std::ops::Range;
use super::MultiValueIndex;
use crate::{ColumnarReader, DynamicColumn};
use crate::{ColumnarReader, DynamicColumn, ValueRange};
fn index_to_pos_helper(
index: &MultiValueIndex,
@@ -413,7 +419,7 @@ mod tests {
assert_eq!(row_id_range, 0..4);
let check = |range, expected| {
let full_range = 0..=u64::MAX;
let full_range = ValueRange::All;
let mut docids = Vec::new();
column.get_docids_for_value_range(full_range, range, &mut docids);
assert_eq!(docids, expected);

View File

@@ -4,6 +4,7 @@ use std::sync::Arc;
mod set;
mod set_block;
use common::file_slice::FileSlice;
use common::{BinarySerializable, OwnedBytes, VInt};
pub use set::{SelectCursor, Set, SetCodec};
use set_block::{
@@ -268,8 +269,8 @@ impl OptionalIndex {
);
let mut buffer = Vec::new();
serialize_optional_index(&row_ids, num_rows, &mut buffer).unwrap();
let bytes = OwnedBytes::new(buffer);
open_optional_index(bytes).unwrap()
let file_slice = FileSlice::from(buffer);
open_optional_index(file_slice).unwrap()
}
pub fn num_docs(&self) -> RowId {
@@ -486,10 +487,17 @@ fn deserialize_optional_index_block_metadatas(
(block_metas.into_boxed_slice(), non_null_rows_before_block)
}
pub fn open_optional_index(bytes: OwnedBytes) -> io::Result<OptionalIndex> {
let (mut bytes, num_non_empty_blocks_bytes) = bytes.rsplit(2);
let num_non_empty_block_bytes =
u16::from_le_bytes(num_non_empty_blocks_bytes.as_slice().try_into().unwrap());
pub fn open_optional_index(file_slice: FileSlice) -> io::Result<OptionalIndex> {
let (bytes, num_non_empty_blocks_bytes) = file_slice.split_from_end(2);
let num_non_empty_block_bytes = u16::from_le_bytes(
num_non_empty_blocks_bytes
.read_bytes()?
.as_slice()
.try_into()
.unwrap(),
);
let mut bytes = bytes.read_bytes()?;
let num_docs = VInt::deserialize_u64(&mut bytes)? as u32;
let block_metas_num_bytes =
num_non_empty_block_bytes as usize * SERIALIZED_BLOCK_META_NUM_BYTES;

View File

@@ -59,7 +59,7 @@ fn test_with_random_sets_simple() {
let vals = 10..ELEMENTS_PER_BLOCK * 2;
let mut out: Vec<u8> = Vec::new();
serialize_optional_index(&vals, 100, &mut out).unwrap();
let null_index = open_optional_index(OwnedBytes::new(out)).unwrap();
let null_index = open_optional_index(FileSlice::from(out)).unwrap();
let ranks: Vec<u32> = (65_472u32..65_473u32).collect();
let els: Vec<u32> = ranks.iter().copied().map(|rank| rank + 10).collect();
let mut select_cursor = null_index.select_cursor();
@@ -102,7 +102,7 @@ impl<'a> Iterable<RowId> for &'a [bool] {
fn test_null_index(data: &[bool]) {
let mut out: Vec<u8> = Vec::new();
serialize_optional_index(&data, data.len() as RowId, &mut out).unwrap();
let null_index = open_optional_index(OwnedBytes::new(out)).unwrap();
let null_index = open_optional_index(FileSlice::from(out)).unwrap();
let orig_idx_with_value: Vec<u32> = data
.iter()
.enumerate()
@@ -223,3 +223,170 @@ fn test_optional_index_for_tests() {
assert!(!optional_index.contains(3));
assert_eq!(optional_index.num_docs(), 4);
}
#[cfg(all(test, feature = "unstable"))]
mod bench {
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use test::Bencher;
use super::*;
const TOTAL_NUM_VALUES: u32 = 1_000_000;
fn gen_bools(fill_ratio: f64) -> OptionalIndex {
let mut out = Vec::new();
let mut rng: StdRng = StdRng::from_seed([1u8; 32]);
let vals: Vec<RowId> = (0..TOTAL_NUM_VALUES)
.map(|_| rng.gen_bool(fill_ratio))
.enumerate()
.filter(|(_pos, val)| *val)
.map(|(pos, _)| pos as RowId)
.collect();
serialize_optional_index(&&vals[..], TOTAL_NUM_VALUES, &mut out).unwrap();
open_optional_index(FileSlice::from(out)).unwrap()
}
fn random_range_iterator(
start: u32,
end: u32,
avg_step_size: u32,
avg_deviation: u32,
) -> impl Iterator<Item = u32> {
let mut rng: StdRng = StdRng::from_seed([1u8; 32]);
let mut current = start;
std::iter::from_fn(move || {
current += rng.gen_range(avg_step_size - avg_deviation..=avg_step_size + avg_deviation);
if current >= end { None } else { Some(current) }
})
}
fn n_percent_step_iterator(percent: f32, num_values: u32) -> impl Iterator<Item = u32> {
let ratio = percent / 100.0;
let step_size = (1f32 / ratio) as u32;
let deviation = step_size - 1;
random_range_iterator(0, num_values, step_size, deviation)
}
fn walk_over_data(codec: &OptionalIndex, avg_step_size: u32) -> Option<u32> {
walk_over_data_from_positions(
codec,
random_range_iterator(0, TOTAL_NUM_VALUES, avg_step_size, 0),
)
}
fn walk_over_data_from_positions(
codec: &OptionalIndex,
positions: impl Iterator<Item = u32>,
) -> Option<u32> {
let mut dense_idx: Option<u32> = None;
for idx in positions {
dense_idx = dense_idx.or(codec.rank_if_exists(idx));
}
dense_idx
}
#[bench]
fn bench_translate_orig_to_codec_1percent_filled_10percent_hit(bench: &mut Bencher) {
let codec = gen_bools(0.01f64);
bench.iter(|| walk_over_data(&codec, 100));
}
#[bench]
fn bench_translate_orig_to_codec_5percent_filled_10percent_hit(bench: &mut Bencher) {
let codec = gen_bools(0.05f64);
bench.iter(|| walk_over_data(&codec, 100));
}
#[bench]
fn bench_translate_orig_to_codec_5percent_filled_1percent_hit(bench: &mut Bencher) {
let codec = gen_bools(0.05f64);
bench.iter(|| walk_over_data(&codec, 1000));
}
#[bench]
fn bench_translate_orig_to_codec_full_scan_1percent_filled(bench: &mut Bencher) {
let codec = gen_bools(0.01f64);
bench.iter(|| walk_over_data_from_positions(&codec, 0..TOTAL_NUM_VALUES));
}
#[bench]
fn bench_translate_orig_to_codec_full_scan_10percent_filled(bench: &mut Bencher) {
let codec = gen_bools(0.1f64);
bench.iter(|| walk_over_data_from_positions(&codec, 0..TOTAL_NUM_VALUES));
}
#[bench]
fn bench_translate_orig_to_codec_full_scan_90percent_filled(bench: &mut Bencher) {
let codec = gen_bools(0.9f64);
bench.iter(|| walk_over_data_from_positions(&codec, 0..TOTAL_NUM_VALUES));
}
#[bench]
fn bench_translate_orig_to_codec_10percent_filled_1percent_hit(bench: &mut Bencher) {
let codec = gen_bools(0.1f64);
bench.iter(|| walk_over_data(&codec, 100));
}
#[bench]
fn bench_translate_orig_to_codec_50percent_filled_1percent_hit(bench: &mut Bencher) {
let codec = gen_bools(0.5f64);
bench.iter(|| walk_over_data(&codec, 100));
}
#[bench]
fn bench_translate_orig_to_codec_90percent_filled_1percent_hit(bench: &mut Bencher) {
let codec = gen_bools(0.9f64);
bench.iter(|| walk_over_data(&codec, 100));
}
#[bench]
fn bench_translate_codec_to_orig_1percent_filled_0comma005percent_hit(bench: &mut Bencher) {
bench_translate_codec_to_orig_util(0.01f64, 0.005f32, bench);
}
#[bench]
fn bench_translate_codec_to_orig_10percent_filled_0comma005percent_hit(bench: &mut Bencher) {
bench_translate_codec_to_orig_util(0.1f64, 0.005f32, bench);
}
#[bench]
fn bench_translate_codec_to_orig_1percent_filled_10percent_hit(bench: &mut Bencher) {
bench_translate_codec_to_orig_util(0.01f64, 10f32, bench);
}
#[bench]
fn bench_translate_codec_to_orig_1percent_filled_full_scan(bench: &mut Bencher) {
bench_translate_codec_to_orig_util(0.01f64, 100f32, bench);
}
fn bench_translate_codec_to_orig_util(
percent_filled: f64,
percent_hit: f32,
bench: &mut Bencher,
) {
let codec = gen_bools(percent_filled);
let num_non_nulls = codec.num_non_nulls();
let idxs: Vec<u32> = if percent_hit == 100.0f32 {
(0..num_non_nulls).collect()
} else {
n_percent_step_iterator(percent_hit, num_non_nulls).collect()
};
let mut output = vec![0u32; idxs.len()];
bench.iter(|| {
output.copy_from_slice(&idxs[..]);
codec.select_batch(&mut output);
});
}
#[bench]
fn bench_translate_codec_to_orig_90percent_filled_0comma005percent_hit(bench: &mut Bencher) {
bench_translate_codec_to_orig_util(0.9f64, 0.005, bench);
}
#[bench]
fn bench_translate_codec_to_orig_90percent_filled_full_scan(bench: &mut Bencher) {
bench_translate_codec_to_orig_util(0.9f64, 100.0f32, bench);
}
}

View File

@@ -1,7 +1,8 @@
use std::io;
use std::io::Write;
use common::{CountingWriter, OwnedBytes};
use common::file_slice::FileSlice;
use common::{CountingWriter, HasLen};
use super::OptionalIndex;
use super::multivalued_index::SerializableMultivalueIndex;
@@ -65,27 +66,28 @@ pub fn serialize_column_index(
/// Open a serialized column index.
pub fn open_column_index(
mut bytes: OwnedBytes,
file_slice: FileSlice,
format_version: Version,
) -> io::Result<ColumnIndex> {
if bytes.is_empty() {
if file_slice.len() == 0 {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"Failed to deserialize column index. Empty buffer.",
));
}
let cardinality_code = bytes[0];
let (header, body) = file_slice.split(1);
let cardinality_code = header.read_bytes()?.as_slice()[0];
let cardinality = Cardinality::try_from_code(cardinality_code)?;
bytes.advance(1);
match cardinality {
Cardinality::Full => Ok(ColumnIndex::Full),
Cardinality::Optional => {
let optional_index = super::optional_index::open_optional_index(bytes)?;
let optional_index = super::optional_index::open_optional_index(body)?;
Ok(ColumnIndex::Optional(optional_index))
}
Cardinality::Multivalued => {
let multivalue_index =
super::multivalued_index::open_multivalued_index(bytes, format_version)?;
super::multivalued_index::open_multivalued_index(body, format_version)?;
Ok(ColumnIndex::Multivalued(multivalue_index))
}
}

View File

@@ -7,13 +7,15 @@
//! - Monotonically map values to u64/u128
use std::fmt::Debug;
use std::ops::{Range, RangeInclusive};
use std::ops::Range;
use std::sync::Arc;
use downcast_rs::DowncastSync;
pub use monotonic_mapping::{MonotonicallyMappableToU64, StrictlyMonotonicFn};
pub use monotonic_mapping_u128::MonotonicallyMappableToU128;
use crate::column::ValueRange;
mod merge;
pub(crate) mod monotonic_mapping;
pub(crate) mod monotonic_mapping_u128;
@@ -27,8 +29,7 @@ mod monotonic_column;
pub(crate) use merge::MergedColumnValues;
pub use stats::ColumnStats;
pub use u64_based::{
ALL_U64_CODEC_TYPES, CodecType, load_u64_based_column_values,
serialize_and_load_u64_based_column_values, serialize_u64_based_column_values,
ALL_U64_CODEC_TYPES, CodecType, load_u64_based_column_values, serialize_u64_based_column_values,
};
pub use u128_based::{
CompactSpaceU64Accessor, open_u128_as_compact_u64, open_u128_mapped,
@@ -109,6 +110,307 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync + DowncastSync {
}
}
/// Load the values for the provided docids.
///
/// The values are filtered by the provided value range.
fn get_vals_in_value_range(
&self,
input_indexes: &[u32],
input_doc_ids: &[u32],
output: &mut Vec<crate::ComparableDoc<Option<T>, crate::DocId>>,
value_range: ValueRange<T>,
) {
let len = input_indexes.len();
let mut read_head = 0;
match value_range {
ValueRange::All => {
while read_head + 3 < len {
let idx0 = input_indexes[read_head];
let idx1 = input_indexes[read_head + 1];
let idx2 = input_indexes[read_head + 2];
let idx3 = input_indexes[read_head + 3];
let doc0 = input_doc_ids[read_head];
let doc1 = input_doc_ids[read_head + 1];
let doc2 = input_doc_ids[read_head + 2];
let doc3 = input_doc_ids[read_head + 3];
let val0 = self.get_val(idx0);
let val1 = self.get_val(idx1);
let val2 = self.get_val(idx2);
let val3 = self.get_val(idx3);
output.push(crate::ComparableDoc {
doc: doc0,
sort_key: Some(val0),
});
output.push(crate::ComparableDoc {
doc: doc1,
sort_key: Some(val1),
});
output.push(crate::ComparableDoc {
doc: doc2,
sort_key: Some(val2),
});
output.push(crate::ComparableDoc {
doc: doc3,
sort_key: Some(val3),
});
read_head += 4;
}
}
ValueRange::Inclusive(ref range) => {
while read_head + 3 < len {
let idx0 = input_indexes[read_head];
let idx1 = input_indexes[read_head + 1];
let idx2 = input_indexes[read_head + 2];
let idx3 = input_indexes[read_head + 3];
let doc0 = input_doc_ids[read_head];
let doc1 = input_doc_ids[read_head + 1];
let doc2 = input_doc_ids[read_head + 2];
let doc3 = input_doc_ids[read_head + 3];
let val0 = self.get_val(idx0);
let val1 = self.get_val(idx1);
let val2 = self.get_val(idx2);
let val3 = self.get_val(idx3);
if range.contains(&val0) {
output.push(crate::ComparableDoc {
doc: doc0,
sort_key: Some(val0),
});
}
if range.contains(&val1) {
output.push(crate::ComparableDoc {
doc: doc1,
sort_key: Some(val1),
});
}
if range.contains(&val2) {
output.push(crate::ComparableDoc {
doc: doc2,
sort_key: Some(val2),
});
}
if range.contains(&val3) {
output.push(crate::ComparableDoc {
doc: doc3,
sort_key: Some(val3),
});
}
read_head += 4;
}
}
ValueRange::GreaterThan(ref threshold, _) => {
while read_head + 3 < len {
let idx0 = input_indexes[read_head];
let idx1 = input_indexes[read_head + 1];
let idx2 = input_indexes[read_head + 2];
let idx3 = input_indexes[read_head + 3];
let doc0 = input_doc_ids[read_head];
let doc1 = input_doc_ids[read_head + 1];
let doc2 = input_doc_ids[read_head + 2];
let doc3 = input_doc_ids[read_head + 3];
let val0 = self.get_val(idx0);
let val1 = self.get_val(idx1);
let val2 = self.get_val(idx2);
let val3 = self.get_val(idx3);
if val0 > *threshold {
output.push(crate::ComparableDoc {
doc: doc0,
sort_key: Some(val0),
});
}
if val1 > *threshold {
output.push(crate::ComparableDoc {
doc: doc1,
sort_key: Some(val1),
});
}
if val2 > *threshold {
output.push(crate::ComparableDoc {
doc: doc2,
sort_key: Some(val2),
});
}
if val3 > *threshold {
output.push(crate::ComparableDoc {
doc: doc3,
sort_key: Some(val3),
});
}
read_head += 4;
}
}
ValueRange::GreaterThanOrEqual(ref threshold, _) => {
while read_head + 3 < len {
let idx0 = input_indexes[read_head];
let idx1 = input_indexes[read_head + 1];
let idx2 = input_indexes[read_head + 2];
let idx3 = input_indexes[read_head + 3];
let doc0 = input_doc_ids[read_head];
let doc1 = input_doc_ids[read_head + 1];
let doc2 = input_doc_ids[read_head + 2];
let doc3 = input_doc_ids[read_head + 3];
let val0 = self.get_val(idx0);
let val1 = self.get_val(idx1);
let val2 = self.get_val(idx2);
let val3 = self.get_val(idx3);
if val0 >= *threshold {
output.push(crate::ComparableDoc {
doc: doc0,
sort_key: Some(val0),
});
}
if val1 >= *threshold {
output.push(crate::ComparableDoc {
doc: doc1,
sort_key: Some(val1),
});
}
if val2 >= *threshold {
output.push(crate::ComparableDoc {
doc: doc2,
sort_key: Some(val2),
});
}
if val3 >= *threshold {
output.push(crate::ComparableDoc {
doc: doc3,
sort_key: Some(val3),
});
}
read_head += 4;
}
}
ValueRange::LessThan(ref threshold, _) => {
while read_head + 3 < len {
let idx0 = input_indexes[read_head];
let idx1 = input_indexes[read_head + 1];
let idx2 = input_indexes[read_head + 2];
let idx3 = input_indexes[read_head + 3];
let doc0 = input_doc_ids[read_head];
let doc1 = input_doc_ids[read_head + 1];
let doc2 = input_doc_ids[read_head + 2];
let doc3 = input_doc_ids[read_head + 3];
let val0 = self.get_val(idx0);
let val1 = self.get_val(idx1);
let val2 = self.get_val(idx2);
let val3 = self.get_val(idx3);
if val0 < *threshold {
output.push(crate::ComparableDoc {
doc: doc0,
sort_key: Some(val0),
});
}
if val1 < *threshold {
output.push(crate::ComparableDoc {
doc: doc1,
sort_key: Some(val1),
});
}
if val2 < *threshold {
output.push(crate::ComparableDoc {
doc: doc2,
sort_key: Some(val2),
});
}
if val3 < *threshold {
output.push(crate::ComparableDoc {
doc: doc3,
sort_key: Some(val3),
});
}
read_head += 4;
}
}
ValueRange::LessThanOrEqual(ref threshold, _) => {
while read_head + 3 < len {
let idx0 = input_indexes[read_head];
let idx1 = input_indexes[read_head + 1];
let idx2 = input_indexes[read_head + 2];
let idx3 = input_indexes[read_head + 3];
let doc0 = input_doc_ids[read_head];
let doc1 = input_doc_ids[read_head + 1];
let doc2 = input_doc_ids[read_head + 2];
let doc3 = input_doc_ids[read_head + 3];
let val0 = self.get_val(idx0);
let val1 = self.get_val(idx1);
let val2 = self.get_val(idx2);
let val3 = self.get_val(idx3);
if val0 <= *threshold {
output.push(crate::ComparableDoc {
doc: doc0,
sort_key: Some(val0),
});
}
if val1 <= *threshold {
output.push(crate::ComparableDoc {
doc: doc1,
sort_key: Some(val1),
});
}
if val2 <= *threshold {
output.push(crate::ComparableDoc {
doc: doc2,
sort_key: Some(val2),
});
}
if val3 <= *threshold {
output.push(crate::ComparableDoc {
doc: doc3,
sort_key: Some(val3),
});
}
read_head += 4;
}
}
}
// Process remaining elements (0 to 3)
while read_head < len {
let idx = input_indexes[read_head];
let doc = input_doc_ids[read_head];
let val = self.get_val(idx);
let matches = match value_range {
// 'value_range' is still moved here. This is the outer `value_range`
ValueRange::All => true,
ValueRange::Inclusive(ref r) => r.contains(&val),
ValueRange::GreaterThan(ref t, _) => val > *t,
ValueRange::GreaterThanOrEqual(ref t, _) => val >= *t,
ValueRange::LessThan(ref t, _) => val < *t,
ValueRange::LessThanOrEqual(ref t, _) => val <= *t,
};
if matches {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(val),
});
}
read_head += 1;
}
}
/// Fills an output buffer with the fast field values
/// associated with the `DocId` going from
/// `start` to `start + output.len()`.
@@ -129,15 +431,54 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync + DowncastSync {
/// Note that position == docid for single value fast fields
fn get_row_ids_for_value_range(
&self,
value_range: RangeInclusive<T>,
value_range: ValueRange<T>,
row_id_range: Range<RowId>,
row_id_hits: &mut Vec<RowId>,
) {
let row_id_range = row_id_range.start..row_id_range.end.min(self.num_vals());
for idx in row_id_range {
let val = self.get_val(idx);
if value_range.contains(&val) {
row_id_hits.push(idx);
match value_range {
ValueRange::Inclusive(range) => {
for idx in row_id_range {
let val = self.get_val(idx);
if range.contains(&val) {
row_id_hits.push(idx);
}
}
}
ValueRange::GreaterThan(threshold, _) => {
for idx in row_id_range {
let val = self.get_val(idx);
if val > threshold {
row_id_hits.push(idx);
}
}
}
ValueRange::GreaterThanOrEqual(threshold, _) => {
for idx in row_id_range {
let val = self.get_val(idx);
if val >= threshold {
row_id_hits.push(idx);
}
}
}
ValueRange::LessThan(threshold, _) => {
for idx in row_id_range {
let val = self.get_val(idx);
if val < threshold {
row_id_hits.push(idx);
}
}
}
ValueRange::LessThanOrEqual(threshold, _) => {
for idx in row_id_range {
let val = self.get_val(idx);
if val <= threshold {
row_id_hits.push(idx);
}
}
}
ValueRange::All => {
row_id_hits.extend(row_id_range);
}
}
}
@@ -193,6 +534,17 @@ impl<T: PartialOrd + Default> ColumnValues<T> for EmptyColumnValues {
fn num_vals(&self) -> u32 {
0
}
fn get_vals_in_value_range(
&self,
input_indexes: &[u32],
input_doc_ids: &[u32],
output: &mut Vec<crate::ComparableDoc<Option<T>, crate::DocId>>,
value_range: ValueRange<T>,
) {
let _ = (input_indexes, input_doc_ids, output, value_range);
panic!("Internal Error: Called get_vals_in_value_range of empty column.")
}
}
impl<T: Copy + PartialOrd + Debug + 'static> ColumnValues<T> for Arc<dyn ColumnValues<T>> {
@@ -206,6 +558,18 @@ impl<T: Copy + PartialOrd + Debug + 'static> ColumnValues<T> for Arc<dyn ColumnV
self.as_ref().get_vals_opt(indexes, output)
}
#[inline(always)]
fn get_vals_in_value_range(
&self,
input_indexes: &[u32],
input_doc_ids: &[u32],
output: &mut Vec<crate::ComparableDoc<Option<T>, crate::DocId>>,
value_range: ValueRange<T>,
) {
self.as_ref()
.get_vals_in_value_range(input_indexes, input_doc_ids, output, value_range)
}
#[inline(always)]
fn min_value(&self) -> T {
self.as_ref().min_value()
@@ -234,7 +598,7 @@ impl<T: Copy + PartialOrd + Debug + 'static> ColumnValues<T> for Arc<dyn ColumnV
#[inline(always)]
fn get_row_ids_for_value_range(
&self,
range: RangeInclusive<T>,
range: ValueRange<T>,
doc_id_range: Range<u32>,
positions: &mut Vec<u32>,
) {

View File

@@ -1,8 +1,9 @@
use std::fmt::Debug;
use std::marker::PhantomData;
use std::ops::{Range, RangeInclusive};
use std::ops::Range;
use crate::ColumnValues;
use crate::column::ValueRange;
use crate::column_values::monotonic_mapping::StrictlyMonotonicFn;
struct MonotonicMappingColumn<C, T, Input> {
@@ -80,16 +81,52 @@ where
fn get_row_ids_for_value_range(
&self,
range: RangeInclusive<Output>,
range: ValueRange<Output>,
doc_id_range: Range<u32>,
positions: &mut Vec<u32>,
) {
self.from_column.get_row_ids_for_value_range(
self.monotonic_mapping.inverse(range.start().clone())
..=self.monotonic_mapping.inverse(range.end().clone()),
doc_id_range,
positions,
)
match range {
ValueRange::Inclusive(range) => self.from_column.get_row_ids_for_value_range(
ValueRange::Inclusive(
self.monotonic_mapping.inverse(range.start().clone())
..=self.monotonic_mapping.inverse(range.end().clone()),
),
doc_id_range,
positions,
),
ValueRange::All => self.from_column.get_row_ids_for_value_range(
ValueRange::All,
doc_id_range,
positions,
),
ValueRange::GreaterThan(threshold, _) => self.from_column.get_row_ids_for_value_range(
ValueRange::GreaterThan(self.monotonic_mapping.inverse(threshold), false),
doc_id_range,
positions,
),
ValueRange::GreaterThanOrEqual(threshold, _) => {
self.from_column.get_row_ids_for_value_range(
ValueRange::GreaterThanOrEqual(
self.monotonic_mapping.inverse(threshold),
false,
),
doc_id_range,
positions,
)
}
ValueRange::LessThan(threshold, _) => self.from_column.get_row_ids_for_value_range(
ValueRange::LessThan(self.monotonic_mapping.inverse(threshold), false),
doc_id_range,
positions,
),
ValueRange::LessThanOrEqual(threshold, _) => {
self.from_column.get_row_ids_for_value_range(
ValueRange::LessThanOrEqual(self.monotonic_mapping.inverse(threshold), false),
doc_id_range,
positions,
)
}
}
}
// We voluntarily do not implement get_range as it yields a regression,

View File

@@ -2,7 +2,8 @@ use std::io;
use std::io::Write;
use std::num::NonZeroU64;
use common::{BinarySerializable, VInt};
use common::file_slice::FileSlice;
use common::{BinarySerializable, HasLen, VInt};
use crate::RowId;
@@ -27,6 +28,55 @@ impl ColumnStats {
}
}
impl ColumnStats {
/// Deserialize from the tail of the given FileSlice, and return the stats and remaining prefix
/// FileSlice.
pub fn deserialize_from_tail(file_slice: FileSlice) -> io::Result<(Self, FileSlice)> {
// [`deserialize_with_size`] deserializes 4 variable-width encoded u64s, which
// could end up being, in the worst case, 9 bytes each. this is where the 36 comes from
let (stats, _) = file_slice.clone().split(36.min(file_slice.len())); // hope that's enough bytes
let mut stats = stats.read_bytes()?;
let (stats, stats_nbytes) = ColumnStats::deserialize_with_size(&mut stats)?;
let (_, remainder) = file_slice.split(stats_nbytes);
Ok((stats, remainder))
}
/// Same as [`BinarySeerializable::deserialize`] but also returns the number of bytes
/// consumed from the reader `R`
fn deserialize_with_size<R: io::Read>(reader: &mut R) -> io::Result<(Self, usize)> {
let mut nbytes = 0;
let (min_value, len) = VInt::deserialize_with_size(reader)?;
let min_value = min_value.0;
nbytes += len;
let (gcd, len) = VInt::deserialize_with_size(reader)?;
let gcd = gcd.0;
let gcd = NonZeroU64::new(gcd)
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "GCD of 0 is forbidden"))?;
nbytes += len;
let (amplitude, len) = VInt::deserialize_with_size(reader)?;
let amplitude = amplitude.0 * gcd.get();
let max_value = min_value + amplitude;
nbytes += len;
let (num_rows, len) = VInt::deserialize_with_size(reader)?;
let num_rows = num_rows.0 as RowId;
nbytes += len;
Ok((
ColumnStats {
min_value,
max_value,
num_rows,
gcd,
},
nbytes,
))
}
}
impl BinarySerializable for ColumnStats {
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
VInt(self.min_value).serialize(writer)?;

View File

@@ -25,6 +25,7 @@ use common::{BinarySerializable, CountingWriter, OwnedBytes, VInt, VIntU128};
use tantivy_bitpacker::{BitPacker, BitUnpacker};
use crate::RowId;
use crate::column::ValueRange;
use crate::column_values::ColumnValues;
/// The cost per blank is quite hard actually, since blanks are delta encoded, the actual cost of
@@ -338,14 +339,48 @@ impl ColumnValues<u64> for CompactSpaceU64Accessor {
#[inline]
fn get_row_ids_for_value_range(
&self,
value_range: RangeInclusive<u64>,
value_range: ValueRange<u64>,
position_range: Range<u32>,
positions: &mut Vec<u32>,
) {
let value_range = self.0.compact_to_u128(*value_range.start() as u32)
..=self.0.compact_to_u128(*value_range.end() as u32);
self.0
.get_row_ids_for_value_range(value_range, position_range, positions)
match value_range {
ValueRange::Inclusive(value_range) => {
let value_range = ValueRange::Inclusive(
self.0.compact_to_u128(*value_range.start() as u32)
..=self.0.compact_to_u128(*value_range.end() as u32),
);
self.0
.get_row_ids_for_value_range(value_range, position_range, positions)
}
ValueRange::All => {
let position_range = position_range.start..position_range.end.min(self.num_vals());
positions.extend(position_range);
}
ValueRange::GreaterThan(threshold, _) => {
let value_range =
ValueRange::GreaterThan(self.0.compact_to_u128(threshold as u32), false);
self.0
.get_row_ids_for_value_range(value_range, position_range, positions)
}
ValueRange::GreaterThanOrEqual(threshold, _) => {
let value_range =
ValueRange::GreaterThanOrEqual(self.0.compact_to_u128(threshold as u32), false);
self.0
.get_row_ids_for_value_range(value_range, position_range, positions)
}
ValueRange::LessThan(threshold, _) => {
let value_range =
ValueRange::LessThan(self.0.compact_to_u128(threshold as u32), false);
self.0
.get_row_ids_for_value_range(value_range, position_range, positions)
}
ValueRange::LessThanOrEqual(threshold, _) => {
let value_range =
ValueRange::LessThanOrEqual(self.0.compact_to_u128(threshold as u32), false);
self.0
.get_row_ids_for_value_range(value_range, position_range, positions)
}
}
}
}
@@ -375,10 +410,47 @@ impl ColumnValues<u128> for CompactSpaceDecompressor {
#[inline]
fn get_row_ids_for_value_range(
&self,
value_range: RangeInclusive<u128>,
value_range: ValueRange<u128>,
position_range: Range<u32>,
positions: &mut Vec<u32>,
) {
let value_range = match value_range {
ValueRange::Inclusive(value_range) => value_range,
ValueRange::All => {
let position_range = position_range.start..position_range.end.min(self.num_vals());
positions.extend(position_range);
return;
}
ValueRange::GreaterThan(threshold, _) => {
let max = self.max_value();
if threshold >= max {
return;
}
(threshold + 1)..=max
}
ValueRange::GreaterThanOrEqual(threshold, _) => {
let max = self.max_value();
if threshold > max {
return;
}
threshold..=max
}
ValueRange::LessThan(threshold, _) => {
let min = self.min_value();
if threshold <= min {
return;
}
min..=(threshold - 1)
}
ValueRange::LessThanOrEqual(threshold, _) => {
let min = self.min_value();
if threshold < min {
return;
}
min..=threshold
}
};
if value_range.start() > value_range.end() {
return;
}
@@ -560,7 +632,7 @@ mod tests {
.collect::<Vec<_>>();
let mut positions = Vec::new();
decompressor.get_row_ids_for_value_range(
range,
ValueRange::Inclusive(range),
0..decompressor.num_vals(),
&mut positions,
);
@@ -604,7 +676,11 @@ mod tests {
let val = *val;
let pos = pos as u32;
let mut positions = Vec::new();
decomp.get_row_ids_for_value_range(val..=val, pos..pos + 1, &mut positions);
decomp.get_row_ids_for_value_range(
ValueRange::Inclusive(val..=val),
pos..pos + 1,
&mut positions,
);
assert_eq!(positions, vec![pos]);
}
@@ -746,7 +822,11 @@ mod tests {
doc_id_range: Range<u32>,
) -> Vec<u32> {
let mut positions = Vec::new();
column.get_row_ids_for_value_range(value_range, doc_id_range, &mut positions);
column.get_row_ids_for_value_range(
ValueRange::Inclusive(value_range),
doc_id_range,
&mut positions,
);
positions
}
@@ -769,7 +849,7 @@ mod tests {
];
let mut out = Vec::new();
serialize_column_values_u128(&&vals[..], &mut out).unwrap();
let decomp = open_u128_mapped(OwnedBytes::new(out)).unwrap();
let decomp = open_u128_mapped(FileSlice::from(out)).unwrap();
let complete_range = 0..vals.len() as u32;
assert_eq!(
@@ -823,6 +903,7 @@ mod tests {
let _data = test_aux_vals(vals);
}
use common::file_slice::FileSlice;
use proptest::prelude::*;
fn num_strategy() -> impl Strategy<Value = u128> {

View File

@@ -5,7 +5,8 @@ use std::sync::Arc;
mod compact_space;
use common::{BinarySerializable, OwnedBytes, VInt};
use common::file_slice::FileSlice;
use common::{BinarySerializable, VInt};
pub use compact_space::{
CompactSpaceCompressor, CompactSpaceDecompressor, CompactSpaceU64Accessor,
};
@@ -101,8 +102,9 @@ impl U128FastFieldCodecType {
/// Returns the correct codec reader wrapped in the `Arc` for the data.
pub fn open_u128_mapped<T: MonotonicallyMappableToU128 + Debug>(
mut bytes: OwnedBytes,
file_slice: FileSlice,
) -> io::Result<Arc<dyn ColumnValues<T>>> {
let mut bytes = file_slice.read_bytes()?;
let header = U128Header::deserialize(&mut bytes)?;
assert_eq!(header.codec_type, U128FastFieldCodecType::CompactSpace);
let reader = CompactSpaceDecompressor::open(bytes)?;
@@ -120,7 +122,8 @@ pub fn open_u128_mapped<T: MonotonicallyMappableToU128 + Debug>(
/// # Notice
/// In case there are new codecs added, check for usages of `CompactSpaceDecompressorU64` and
/// also handle the new codecs.
pub fn open_u128_as_compact_u64(mut bytes: OwnedBytes) -> io::Result<Arc<dyn ColumnValues<u64>>> {
pub fn open_u128_as_compact_u64(file_slice: FileSlice) -> io::Result<Arc<dyn ColumnValues<u64>>> {
let mut bytes = file_slice.read_bytes()?;
let header = U128Header::deserialize(&mut bytes)?;
assert_eq!(header.codec_type, U128FastFieldCodecType::CompactSpace);
let reader = CompactSpaceU64Accessor::open(bytes)?;

View File

@@ -1,11 +1,14 @@
use std::io::{self, Write};
use std::num::NonZeroU64;
use std::ops::{Range, RangeInclusive};
use std::sync::{Arc, OnceLock};
use common::{BinarySerializable, OwnedBytes};
use common::file_slice::FileSlice;
use common::{BinarySerializable, HasLen, OwnedBytes};
use fastdivide::DividerU64;
use tantivy_bitpacker::{BitPacker, BitUnpacker, compute_num_bits};
use crate::column::ValueRange;
use crate::column_values::u64_based::{ColumnCodec, ColumnCodecEstimator, ColumnStats};
use crate::{ColumnValues, RowId};
@@ -13,9 +16,40 @@ use crate::{ColumnValues, RowId};
/// fast field is required.
#[derive(Clone)]
pub struct BitpackedReader {
data: OwnedBytes,
data: FileSlice,
bit_unpacker: BitUnpacker,
stats: ColumnStats,
blocks: Arc<[OnceLock<Block>]>,
}
impl BitpackedReader {
#[inline(always)]
fn unpack_val(&self, doc: u32) -> u64 {
let block_num = self.bit_unpacker.block_num(doc);
if block_num == 0 && self.blocks.len() == 0 {
return 0;
}
let block = self.blocks[block_num].get_or_init(|| {
let block_range = self.bit_unpacker.block(block_num, self.data.len());
let offset = block_range.start;
let data = self
.data
.slice(block_range)
.read_bytes()
.expect("Failed to read column values.");
Block { offset, data }
});
self.bit_unpacker
.get_from_subset(doc, block.offset, &block.data)
}
}
struct Block {
offset: usize,
data: OwnedBytes,
}
#[inline(always)]
@@ -41,6 +75,12 @@ fn transform_range_before_linear_transformation(
if range.is_empty() {
return None;
}
if stats.min_value > *range.end() {
return None;
}
if stats.max_value < *range.start() {
return None;
}
let shifted_range =
range.start().saturating_sub(stats.min_value)..=range.end().saturating_sub(stats.min_value);
let start_before_gcd_multiplication: u64 = div_ceil(*shifted_range.start(), stats.gcd);
@@ -51,8 +91,9 @@ fn transform_range_before_linear_transformation(
impl ColumnValues for BitpackedReader {
#[inline(always)]
fn get_val(&self, doc: u32) -> u64 {
self.stats.min_value + self.stats.gcd.get() * self.bit_unpacker.get(doc, &self.data)
self.stats.min_value + self.stats.gcd.get() * self.unpack_val(doc)
}
#[inline]
fn min_value(&self) -> u64 {
self.stats.min_value
@@ -66,24 +107,329 @@ impl ColumnValues for BitpackedReader {
self.stats.num_rows
}
fn get_vals_in_value_range(
&self,
input_indexes: &[u32],
input_doc_ids: &[u32],
output: &mut Vec<crate::ComparableDoc<Option<u64>, crate::DocId>>,
value_range: ValueRange<u64>,
) {
match value_range {
ValueRange::All => {
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(self.get_val(idx)),
});
}
}
ValueRange::Inclusive(range) => {
if let Some(transformed_range) =
transform_range_before_linear_transformation(&self.stats, range)
{
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
let raw_val = self.unpack_val(idx);
if transformed_range.contains(&raw_val) {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(
self.stats.min_value + self.stats.gcd.get() * raw_val,
),
});
}
}
}
}
ValueRange::GreaterThan(threshold, _) => {
if threshold < self.stats.min_value {
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(self.get_val(idx)),
});
}
} else if threshold >= self.stats.max_value {
// All filtered out
} else {
let raw_threshold = (threshold - self.stats.min_value) / self.stats.gcd.get();
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
let raw_val = self.unpack_val(idx);
if raw_val > raw_threshold {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(
self.stats.min_value + self.stats.gcd.get() * raw_val,
),
});
}
}
}
}
ValueRange::GreaterThanOrEqual(threshold, _) => {
if threshold <= self.stats.min_value {
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(self.get_val(idx)),
});
}
} else if threshold > self.stats.max_value {
// All filtered out
} else {
let diff = threshold - self.stats.min_value;
let gcd = self.stats.gcd.get();
let raw_threshold = (diff + gcd - 1) / gcd;
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
let raw_val = self.unpack_val(idx);
if raw_val >= raw_threshold {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(
self.stats.min_value + self.stats.gcd.get() * raw_val,
),
});
}
}
}
}
ValueRange::LessThan(threshold, _) => {
if threshold > self.stats.max_value {
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(self.get_val(idx)),
});
}
} else if threshold <= self.stats.min_value {
// All filtered out
} else {
let diff = threshold - self.stats.min_value;
let gcd = self.stats.gcd.get();
let raw_threshold = if diff % gcd == 0 {
diff / gcd
} else {
diff / gcd + 1
};
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
let raw_val = self.unpack_val(idx);
if raw_val < raw_threshold {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(
self.stats.min_value + self.stats.gcd.get() * raw_val,
),
});
}
}
}
}
ValueRange::LessThanOrEqual(threshold, _) => {
if threshold >= self.stats.max_value {
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(self.get_val(idx)),
});
}
} else if threshold < self.stats.min_value {
// All filtered out
} else {
let diff = threshold - self.stats.min_value;
let gcd = self.stats.gcd.get();
let raw_threshold = diff / gcd;
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
let raw_val = self.unpack_val(idx);
if raw_val <= raw_threshold {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(
self.stats.min_value + self.stats.gcd.get() * raw_val,
),
});
}
}
}
}
}
}
fn get_row_ids_for_value_range(
&self,
range: RangeInclusive<u64>,
range: ValueRange<u64>,
doc_id_range: Range<u32>,
positions: &mut Vec<u32>,
) {
let Some(transformed_range) =
transform_range_before_linear_transformation(&self.stats, range)
else {
positions.clear();
return;
};
self.bit_unpacker.get_ids_for_value_range(
transformed_range,
doc_id_range,
&self.data,
positions,
);
match range {
ValueRange::All => {
positions.extend(doc_id_range);
return;
}
ValueRange::Inclusive(range) => {
let Some(transformed_range) =
transform_range_before_linear_transformation(&self.stats, range)
else {
positions.clear();
return;
};
// TODO: This does not use the `self.blocks` cache, because callers are usually
// already doing sequential, and fairly dense reads. Fix it to
// iterate over blocks if that assumption turns out to be incorrect!
let data_range = self
.bit_unpacker
.block_oblivious_range(doc_id_range.clone(), self.data.len());
let data_offset = data_range.start;
let data_subset = self
.data
.slice(data_range)
.read_bytes()
.expect("Failed to read column values.");
self.bit_unpacker.get_ids_for_value_range_from_subset(
transformed_range,
doc_id_range,
data_offset,
&data_subset,
positions,
);
}
ValueRange::GreaterThan(threshold, _) => {
if threshold < self.stats.min_value {
positions.extend(doc_id_range);
return;
}
if threshold >= self.stats.max_value {
return;
}
let raw_threshold = (threshold - self.stats.min_value) / self.stats.gcd.get();
// We want raw > raw_threshold.
// bit_unpacker.get_ids_for_value_range_from_subset takes a RangeInclusive.
// We can construct a RangeInclusive: (raw_threshold + 1) ..= u64::MAX
// But max raw value is known? (max_value - min_value) / gcd.
let max_raw = (self.stats.max_value - self.stats.min_value) / self.stats.gcd.get();
let transformed_range = (raw_threshold + 1)..=max_raw;
let data_range = self
.bit_unpacker
.block_oblivious_range(doc_id_range.clone(), self.data.len());
let data_offset = data_range.start;
let data_subset = self
.data
.slice(data_range)
.read_bytes()
.expect("Failed to read column values.");
self.bit_unpacker.get_ids_for_value_range_from_subset(
transformed_range,
doc_id_range,
data_offset,
&data_subset,
positions,
);
}
ValueRange::GreaterThanOrEqual(threshold, _) => {
if threshold <= self.stats.min_value {
positions.extend(doc_id_range);
return;
}
if threshold > self.stats.max_value {
return;
}
let diff = threshold - self.stats.min_value;
let gcd = self.stats.gcd.get();
let raw_threshold = (diff + gcd - 1) / gcd;
// We want raw >= raw_threshold.
let max_raw = (self.stats.max_value - self.stats.min_value) / self.stats.gcd.get();
let transformed_range = raw_threshold..=max_raw;
let data_range = self
.bit_unpacker
.block_oblivious_range(doc_id_range.clone(), self.data.len());
let data_offset = data_range.start;
let data_subset = self
.data
.slice(data_range)
.read_bytes()
.expect("Failed to read column values.");
self.bit_unpacker.get_ids_for_value_range_from_subset(
transformed_range,
doc_id_range,
data_offset,
&data_subset,
positions,
);
}
ValueRange::LessThan(threshold, _) => {
if threshold > self.stats.max_value {
positions.extend(doc_id_range);
return;
}
if threshold <= self.stats.min_value {
return;
}
let diff = threshold - self.stats.min_value;
let gcd = self.stats.gcd.get();
// We want raw < raw_threshold_limit
// raw <= raw_threshold_limit - 1
let raw_threshold_limit = if diff % gcd == 0 {
diff / gcd
} else {
diff / gcd + 1
};
if raw_threshold_limit == 0 {
return;
}
let transformed_range = 0..=(raw_threshold_limit - 1);
let data_range = self
.bit_unpacker
.block_oblivious_range(doc_id_range.clone(), self.data.len());
let data_offset = data_range.start;
let data_subset = self
.data
.slice(data_range)
.read_bytes()
.expect("Failed to read column values.");
self.bit_unpacker.get_ids_for_value_range_from_subset(
transformed_range,
doc_id_range,
data_offset,
&data_subset,
positions,
);
}
ValueRange::LessThanOrEqual(threshold, _) => {
if threshold >= self.stats.max_value {
positions.extend(doc_id_range);
return;
}
if threshold < self.stats.min_value {
return;
}
let diff = threshold - self.stats.min_value;
let gcd = self.stats.gcd.get();
// We want raw <= raw_threshold.
let raw_threshold = diff / gcd;
let transformed_range = 0..=raw_threshold;
let data_range = self
.bit_unpacker
.block_oblivious_range(doc_id_range.clone(), self.data.len());
let data_offset = data_range.start;
let data_subset = self
.data
.slice(data_range)
.read_bytes()
.expect("Failed to read column values.");
self.bit_unpacker.get_ids_for_value_range_from_subset(
transformed_range,
doc_id_range,
data_offset,
&data_subset,
positions,
);
}
}
}
}
@@ -127,14 +473,20 @@ impl ColumnCodec for BitpackedCodec {
type Estimator = BitpackedCodecEstimator;
/// Opens a fast field given a file.
fn load(mut data: OwnedBytes) -> io::Result<Self::ColumnValues> {
let stats = ColumnStats::deserialize(&mut data)?;
fn load(file_slice: FileSlice) -> io::Result<Self::ColumnValues> {
let (stats, data) = ColumnStats::deserialize_from_tail(file_slice)?;
let num_bits = num_bits(&stats);
let bit_unpacker = BitUnpacker::new(num_bits);
let block_count = bit_unpacker.block_count(data.len());
Ok(BitpackedReader {
data,
bit_unpacker,
stats,
blocks: (0..block_count)
.into_iter()
.map(|_| OnceLock::new())
.collect(),
})
}
}

View File

@@ -1,8 +1,10 @@
use std::io;
use std::io::Write;
use std::sync::Arc;
use std::{io, iter};
use std::ops::{Deref, DerefMut};
use std::sync::{Arc, OnceLock};
use common::{BinarySerializable, CountingWriter, DeserializeFrom, OwnedBytes};
use common::file_slice::FileSlice;
use common::{BinarySerializable, CountingWriter, DeserializeFrom, HasLen, OwnedBytes};
use fastdivide::DividerU64;
use tantivy_bitpacker::{BitPacker, BitUnpacker, compute_num_bits};
@@ -172,32 +174,63 @@ impl ColumnCodec<u64> for BlockwiseLinearCodec {
type Estimator = BlockwiseLinearEstimator;
fn load(mut bytes: OwnedBytes) -> io::Result<Self::ColumnValues> {
let stats = ColumnStats::deserialize(&mut bytes)?;
let footer_len: u32 = (&bytes[bytes.len() - 4..]).deserialize()?;
let footer_offset = bytes.len() - 4 - footer_len as usize;
let (data, mut footer) = bytes.split(footer_offset);
fn load(file_slice: FileSlice) -> io::Result<Self::ColumnValues> {
let (stats, body) = ColumnStats::deserialize_from_tail(file_slice)?;
let (_, footer) = body.clone().split_from_end(4);
let footer_len: u32 = footer.read_bytes()?.as_slice().deserialize()?;
let (data, footer) = body.split_from_end(footer_len as usize + 4);
let mut footer = footer.read_bytes()?;
let num_blocks = compute_num_blocks(stats.num_rows);
let mut blocks: Vec<Block> = iter::repeat_with(|| Block::deserialize(&mut footer))
.take(num_blocks as usize)
.collect::<io::Result<_>>()?;
let mut start_offset = 0;
for block in &mut blocks {
let mut blocks = Vec::with_capacity(num_blocks as usize);
for _ in 0..num_blocks {
let mut block = Block::deserialize(&mut footer)?;
let len = (block.bit_unpacker.bit_width() as usize) * BLOCK_SIZE as usize / 8;
block.data_start_offset = start_offset;
start_offset += (block.bit_unpacker.bit_width() as usize) * BLOCK_SIZE as usize / 8;
blocks.push(BlockWithData {
block,
file_slice: data.slice(start_offset..(start_offset + len).min(data.len())),
data: Default::default(),
});
start_offset += len;
}
Ok(BlockwiseLinearReader {
blocks: blocks.into_boxed_slice().into(),
data,
stats,
})
}
}
struct BlockWithData {
block: Block,
file_slice: FileSlice,
data: OnceLock<OwnedBytes>,
}
impl Deref for BlockWithData {
type Target = Block;
fn deref(&self) -> &Self::Target {
&self.block
}
}
impl DerefMut for BlockWithData {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.block
}
}
#[derive(Clone)]
pub struct BlockwiseLinearReader {
blocks: Arc<[Block]>,
data: OwnedBytes,
blocks: Arc<[BlockWithData]>,
stats: ColumnStats,
}
@@ -208,7 +241,9 @@ impl ColumnValues for BlockwiseLinearReader {
let idx_within_block = idx % BLOCK_SIZE;
let block = &self.blocks[block_id];
let interpoled_val: u64 = block.line.eval(idx_within_block);
let block_bytes = &self.data[block.data_start_offset..];
let block_bytes = block
.data
.get_or_init(|| block.file_slice.read_bytes().unwrap());
let bitpacked_diff = block.bit_unpacker.get(idx_within_block, block_bytes);
// TODO optimize me! the line parameters could be tweaked to include the multiplication and
// remove the dependency.

View File

@@ -1,5 +1,6 @@
use std::io;
use common::file_slice::FileSlice;
use common::{BinarySerializable, OwnedBytes};
use tantivy_bitpacker::{BitPacker, BitUnpacker, compute_num_bits};
@@ -190,7 +191,8 @@ impl ColumnCodec for LinearCodec {
type Estimator = LinearCodecEstimator;
fn load(mut data: OwnedBytes) -> io::Result<Self::ColumnValues> {
fn load(file_slice: FileSlice) -> io::Result<Self::ColumnValues> {
let mut data = file_slice.read_bytes()?;
let stats = ColumnStats::deserialize(&mut data)?;
let linear_params = LinearParams::deserialize(&mut data)?;
Ok(LinearReader {
@@ -268,7 +270,7 @@ mod tests {
#[test]
fn linear_interpol_fast_field_rand() {
let mut rng = rand::rng();
let mut rng = rand::thread_rng();
for _ in 0..50 {
let mut data = (0..10_000).map(|_| rng.next_u64()).collect::<Vec<_>>();
create_and_validate::<LinearCodec>(&data, "random");

View File

@@ -8,7 +8,8 @@ use std::io;
use std::io::Write;
use std::sync::Arc;
use common::{BinarySerializable, OwnedBytes};
use common::BinarySerializable;
use common::file_slice::FileSlice;
use crate::column_values::monotonic_mapping::{
StrictlyMonotonicMappingInverter, StrictlyMonotonicMappingToInternal,
@@ -60,7 +61,7 @@ pub trait ColumnCodec<T: PartialOrd = u64> {
type Estimator: ColumnCodecEstimator + Default;
/// Loads a column that has been serialized using this codec.
fn load(bytes: OwnedBytes) -> io::Result<Self::ColumnValues>;
fn load(file_slice: FileSlice) -> io::Result<Self::ColumnValues>;
/// Returns an estimator.
fn estimator() -> Self::Estimator {
@@ -111,20 +112,22 @@ impl CodecType {
fn load<T: MonotonicallyMappableToU64>(
&self,
bytes: OwnedBytes,
file_slice: FileSlice,
) -> io::Result<Arc<dyn ColumnValues<T>>> {
match self {
CodecType::Bitpacked => load_specific_codec::<BitpackedCodec, T>(bytes),
CodecType::Linear => load_specific_codec::<LinearCodec, T>(bytes),
CodecType::BlockwiseLinear => load_specific_codec::<BlockwiseLinearCodec, T>(bytes),
CodecType::Bitpacked => load_specific_codec::<BitpackedCodec, T>(file_slice),
CodecType::Linear => load_specific_codec::<LinearCodec, T>(file_slice),
CodecType::BlockwiseLinear => {
load_specific_codec::<BlockwiseLinearCodec, T>(file_slice)
}
}
}
}
fn load_specific_codec<C: ColumnCodec, T: MonotonicallyMappableToU64>(
bytes: OwnedBytes,
file_slice: FileSlice,
) -> io::Result<Arc<dyn ColumnValues<T>>> {
let reader = C::load(bytes)?;
let reader = C::load(file_slice)?;
let reader_typed = monotonic_map_column(
reader,
StrictlyMonotonicMappingInverter::from(StrictlyMonotonicMappingToInternal::<T>::new()),
@@ -189,25 +192,28 @@ pub fn serialize_u64_based_column_values<T: MonotonicallyMappableToU64>(
///
/// This method first identifies the codec off the first byte.
pub fn load_u64_based_column_values<T: MonotonicallyMappableToU64>(
mut bytes: OwnedBytes,
file_slice: FileSlice,
) -> io::Result<Arc<dyn ColumnValues<T>>> {
let codec_type: CodecType = bytes
.first()
.copied()
let (header, body) = file_slice.split(1);
let codec_type: CodecType = header
.read_bytes()?
.as_slice()
.get(0)
.cloned()
.and_then(CodecType::try_from_code)
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Failed to read codec type"))?;
bytes.advance(1);
codec_type.load(bytes)
codec_type.load(body)
}
/// Helper function to serialize a column (autodetect from all codecs) and then open it
#[cfg(test)]
pub fn serialize_and_load_u64_based_column_values<T: MonotonicallyMappableToU64>(
vals: &dyn Iterable,
codec_types: &[CodecType],
) -> Arc<dyn ColumnValues<T>> {
let mut buffer = Vec::new();
serialize_u64_based_column_values(vals, codec_types, &mut buffer).unwrap();
load_u64_based_column_values::<T>(OwnedBytes::new(buffer)).unwrap()
load_u64_based_column_values::<T>(FileSlice::from(buffer)).unwrap()
}
#[cfg(test)]

View File

@@ -1,3 +1,4 @@
use common::HasLen;
use proptest::prelude::*;
use proptest::{prop_oneof, proptest};
use rand::Rng;
@@ -13,7 +14,7 @@ fn test_serialize_and_load_simple() {
)
.unwrap();
assert_eq!(buffer.len(), 7);
let col = load_u64_based_column_values::<u64>(OwnedBytes::new(buffer)).unwrap();
let col = load_u64_based_column_values::<u64>(FileSlice::from(buffer)).unwrap();
assert_eq!(col.num_vals(), 3);
assert_eq!(col.get_val(0), 1);
assert_eq!(col.get_val(1), 2);
@@ -30,7 +31,7 @@ fn test_empty_column_i64() {
continue;
}
num_acceptable_codecs += 1;
let col = load_u64_based_column_values::<i64>(OwnedBytes::new(buffer)).unwrap();
let col = load_u64_based_column_values::<i64>(FileSlice::from(buffer)).unwrap();
assert_eq!(col.num_vals(), 0);
assert_eq!(col.min_value(), i64::MIN);
assert_eq!(col.max_value(), i64::MIN);
@@ -48,7 +49,7 @@ fn test_empty_column_u64() {
continue;
}
num_acceptable_codecs += 1;
let col = load_u64_based_column_values::<u64>(OwnedBytes::new(buffer)).unwrap();
let col = load_u64_based_column_values::<u64>(FileSlice::from(buffer)).unwrap();
assert_eq!(col.num_vals(), 0);
assert_eq!(col.min_value(), u64::MIN);
assert_eq!(col.max_value(), u64::MIN);
@@ -66,7 +67,7 @@ fn test_empty_column_f64() {
continue;
}
num_acceptable_codecs += 1;
let col = load_u64_based_column_values::<f64>(OwnedBytes::new(buffer)).unwrap();
let col = load_u64_based_column_values::<f64>(FileSlice::from(buffer)).unwrap();
assert_eq!(col.num_vals(), 0);
// FIXME. f64::MIN would be better!
assert!(col.min_value().is_nan());
@@ -97,7 +98,7 @@ pub(crate) fn create_and_validate<TColumnCodec: ColumnCodec>(
let actual_compression = buffer.len() as u64;
let reader = TColumnCodec::load(OwnedBytes::new(buffer)).unwrap();
let reader = TColumnCodec::load(FileSlice::from(buffer)).unwrap();
assert_eq!(reader.num_vals(), vals.len() as u32);
let mut buffer = Vec::new();
for (doc, orig_val) in vals.iter().copied().enumerate() {
@@ -122,7 +123,7 @@ pub(crate) fn create_and_validate<TColumnCodec: ColumnCodec>(
assert_eq!(vals, buffer);
if !vals.is_empty() {
let test_rand_idx = rand::rng().random_range(0..=vals.len() - 1);
let test_rand_idx = rand::thread_rng().gen_range(0..=vals.len() - 1);
let expected_positions: Vec<u32> = vals
.iter()
.enumerate()
@@ -131,7 +132,7 @@ pub(crate) fn create_and_validate<TColumnCodec: ColumnCodec>(
.collect();
let mut positions = Vec::new();
reader.get_row_ids_for_value_range(
vals[test_rand_idx]..=vals[test_rand_idx],
crate::column::ValueRange::Inclusive(vals[test_rand_idx]..=vals[test_rand_idx]),
0..vals.len() as u32,
&mut positions,
);
@@ -326,7 +327,7 @@ fn test_fastfield_gcd_i64_with_codec(codec_type: CodecType, num_vals: usize) ->
&[codec_type],
&mut buffer,
)?;
let buffer = OwnedBytes::new(buffer);
let buffer = FileSlice::from(buffer);
let column = crate::column_values::load_u64_based_column_values::<i64>(buffer.clone())?;
assert_eq!(column.get_val(0), -4000i64);
assert_eq!(column.get_val(1), -3000i64);
@@ -343,7 +344,7 @@ fn test_fastfield_gcd_i64_with_codec(codec_type: CodecType, num_vals: usize) ->
&[codec_type],
&mut buffer_without_gcd,
)?;
let buffer_without_gcd = OwnedBytes::new(buffer_without_gcd);
let buffer_without_gcd = FileSlice::from(buffer_without_gcd);
assert!(buffer_without_gcd.len() > buffer.len());
Ok(())
@@ -369,7 +370,7 @@ fn test_fastfield_gcd_u64_with_codec(codec_type: CodecType, num_vals: usize) ->
&[codec_type],
&mut buffer,
)?;
let buffer = OwnedBytes::new(buffer);
let buffer = FileSlice::from(buffer);
let column = crate::column_values::load_u64_based_column_values::<u64>(buffer.clone())?;
assert_eq!(column.get_val(0), 1000u64);
assert_eq!(column.get_val(1), 2000u64);
@@ -386,7 +387,7 @@ fn test_fastfield_gcd_u64_with_codec(codec_type: CodecType, num_vals: usize) ->
&[codec_type],
&mut buffer_without_gcd,
)?;
let buffer_without_gcd = OwnedBytes::new(buffer_without_gcd);
let buffer_without_gcd = FileSlice::from(buffer_without_gcd);
assert!(buffer_without_gcd.len() > buffer.len());
Ok(())
}
@@ -405,7 +406,7 @@ fn test_fastfield_gcd_u64() -> io::Result<()> {
#[test]
pub fn test_fastfield2() {
let test_fastfield = crate::column_values::serialize_and_load_u64_based_column_values::<u64>(
let test_fastfield = serialize_and_load_u64_based_column_values::<u64>(
&&[100u64, 200u64, 300u64][..],
&ALL_U64_CODEC_TYPES,
);

View File

@@ -4,6 +4,7 @@ mod term_merger;
use std::collections::{BTreeMap, HashSet};
use std::io;
use std::io::ErrorKind;
use std::net::Ipv6Addr;
use std::sync::Arc;
@@ -78,6 +79,7 @@ pub fn merge_columnar(
required_columns: &[(String, ColumnType)],
merge_row_order: MergeRowOrder,
output: &mut impl io::Write,
cancel: impl Fn() -> bool,
) -> io::Result<()> {
let mut serializer = ColumnarSerializer::new(output);
let num_docs_per_columnar = columnar_readers
@@ -87,6 +89,9 @@ pub fn merge_columnar(
let columns_to_merge = group_columns_for_merge(columnar_readers, required_columns)?;
for res in columns_to_merge {
if cancel() {
return Err(io::Error::new(ErrorKind::Interrupted, "Merge cancelled"));
}
let ((column_name, _column_type_category), grouped_columns) = res;
let grouped_columns = grouped_columns.open(&merge_row_order)?;
if grouped_columns.is_empty() {

View File

@@ -205,6 +205,7 @@ fn test_merge_columnar_numbers() {
&[],
MergeRowOrder::Stack(stack_merge_order),
&mut buffer,
|| false,
)
.unwrap();
let columnar_reader = ColumnarReader::open(buffer).unwrap();
@@ -233,6 +234,7 @@ fn test_merge_columnar_texts() {
&[],
MergeRowOrder::Stack(stack_merge_order),
&mut buffer,
|| false,
)
.unwrap();
let columnar_reader = ColumnarReader::open(buffer).unwrap();
@@ -282,6 +284,7 @@ fn test_merge_columnar_byte() {
&[],
MergeRowOrder::Stack(stack_merge_order),
&mut buffer,
|| false,
)
.unwrap();
let columnar_reader = ColumnarReader::open(buffer).unwrap();
@@ -338,6 +341,7 @@ fn test_merge_columnar_byte_with_missing() {
&[],
MergeRowOrder::Stack(stack_merge_order),
&mut buffer,
|| false,
)
.unwrap();
let columnar_reader = ColumnarReader::open(buffer).unwrap();
@@ -390,6 +394,7 @@ fn test_merge_columnar_different_types() {
&[],
MergeRowOrder::Stack(stack_merge_order),
&mut buffer,
|| false,
)
.unwrap();
let columnar_reader = ColumnarReader::open(buffer).unwrap();
@@ -455,6 +460,7 @@ fn test_merge_columnar_different_empty_cardinality() {
&[],
MergeRowOrder::Stack(stack_merge_order),
&mut buffer,
|| false,
)
.unwrap();
let columnar_reader = ColumnarReader::open(buffer).unwrap();
@@ -565,6 +571,7 @@ proptest! {
&[],
MergeRowOrder::Stack(stack_merge_order),
&mut out,
|| false,
).unwrap();
let merged_reader = ColumnarReader::open(out).unwrap();
@@ -582,6 +589,7 @@ proptest! {
&[],
MergeRowOrder::Stack(stack_merge_order),
&mut out,
|| false,
).unwrap();
}

View File

@@ -71,7 +71,14 @@ fn test_format(path: &str) {
let columnar_readers = vec![&reader, &reader2];
let merge_row_order = StackMergeOrder::stack(&columnar_readers[..]);
let mut out = Vec::new();
merge_columnar(&columnar_readers, &[], merge_row_order.into(), &mut out).unwrap();
merge_columnar(
&columnar_readers,
&[],
merge_row_order.into(),
&mut out,
|| false,
)
.unwrap();
let reader = ColumnarReader::open(out).unwrap();
check_columns(&reader);
}

View File

@@ -3,7 +3,7 @@ use std::sync::Arc;
use std::{fmt, io};
use common::file_slice::FileSlice;
use common::{ByteCount, DateTime, OwnedBytes};
use common::{ByteCount, DateTime};
use serde::{Deserialize, Serialize};
use crate::column::{BytesColumn, Column, StrColumn};
@@ -239,8 +239,7 @@ pub struct DynamicColumnHandle {
impl DynamicColumnHandle {
// TODO rename load
pub fn open(&self) -> io::Result<DynamicColumn> {
let column_bytes: OwnedBytes = self.file_slice.read_bytes()?;
self.open_internal(column_bytes)
self.open_internal(self.file_slice.clone())
}
#[doc(hidden)]
@@ -259,16 +258,15 @@ impl DynamicColumnHandle {
/// If not, the fastfield reader will returns the u64-value associated with the original
/// FastValue.
pub fn open_u64_lenient(&self) -> io::Result<Option<Column<u64>>> {
let column_bytes = self.file_slice.read_bytes()?;
match self.column_type {
ColumnType::Str | ColumnType::Bytes => {
let column: BytesColumn =
crate::column::open_column_bytes(column_bytes, self.format_version)?;
crate::column::open_column_bytes(self.file_slice.clone(), self.format_version)?;
Ok(Some(column.term_ord_column))
}
ColumnType::IpAddr => {
let column = crate::column::open_column_u128_as_compact_u64(
column_bytes,
self.file_slice.clone(),
self.format_version,
)?;
Ok(Some(column))
@@ -278,40 +276,40 @@ impl DynamicColumnHandle {
| ColumnType::U64
| ColumnType::F64
| ColumnType::DateTime => {
let column =
crate::column::open_column_u64::<u64>(column_bytes, self.format_version)?;
let column = crate::column::open_column_u64::<u64>(
self.file_slice.clone(),
self.format_version,
)?;
Ok(Some(column))
}
}
}
fn open_internal(&self, column_bytes: OwnedBytes) -> io::Result<DynamicColumn> {
fn open_internal(&self, file_slice: FileSlice) -> io::Result<DynamicColumn> {
let dynamic_column: DynamicColumn = match self.column_type {
ColumnType::Bytes => {
crate::column::open_column_bytes(column_bytes, self.format_version)?.into()
crate::column::open_column_bytes(file_slice, self.format_version)?.into()
}
ColumnType::Str => {
crate::column::open_column_str(column_bytes, self.format_version)?.into()
crate::column::open_column_str(file_slice, self.format_version)?.into()
}
ColumnType::I64 => {
crate::column::open_column_u64::<i64>(column_bytes, self.format_version)?.into()
crate::column::open_column_u64::<i64>(file_slice, self.format_version)?.into()
}
ColumnType::U64 => {
crate::column::open_column_u64::<u64>(column_bytes, self.format_version)?.into()
crate::column::open_column_u64::<u64>(file_slice, self.format_version)?.into()
}
ColumnType::F64 => {
crate::column::open_column_u64::<f64>(column_bytes, self.format_version)?.into()
crate::column::open_column_u64::<f64>(file_slice, self.format_version)?.into()
}
ColumnType::Bool => {
crate::column::open_column_u64::<bool>(column_bytes, self.format_version)?.into()
crate::column::open_column_u64::<bool>(file_slice, self.format_version)?.into()
}
ColumnType::IpAddr => {
crate::column::open_column_u128::<Ipv6Addr>(column_bytes, self.format_version)?
.into()
crate::column::open_column_u128::<Ipv6Addr>(file_slice, self.format_version)?.into()
}
ColumnType::DateTime => {
crate::column::open_column_u64::<DateTime>(column_bytes, self.format_version)?
.into()
crate::column::open_column_u64::<DateTime>(file_slice, self.format_version)?.into()
}
};
Ok(dynamic_column)

View File

@@ -29,6 +29,7 @@ mod column;
pub mod column_index;
pub mod column_values;
mod columnar;
mod comparable_doc;
mod dictionary;
mod dynamic_column;
mod iterable;
@@ -36,7 +37,7 @@ pub(crate) mod utils;
mod value;
pub use block_accessor::ColumnBlockAccessor;
pub use column::{BytesColumn, Column, StrColumn};
pub use column::{BytesColumn, Column, StrColumn, ValueRange};
pub use column_index::ColumnIndex;
pub use column_values::{
ColumnValues, EmptyColumnValues, MonotonicallyMappableToU64, MonotonicallyMappableToU128,
@@ -45,6 +46,7 @@ pub use columnar::{
CURRENT_VERSION, ColumnType, ColumnarReader, ColumnarWriter, HasAssociatedColumnType,
MergeRowOrder, ShuffleMergeOrder, StackMergeOrder, Version, merge_columnar,
};
pub use comparable_doc::ComparableDoc;
use sstable::VoidSSTable;
pub use value::{NumericalType, NumericalValue};

View File

@@ -60,7 +60,7 @@ fn test_dataframe_writer_bool() {
let DynamicColumn::Bool(bool_col) = dyn_bool_col else {
panic!();
};
let vals: Vec<Option<bool>> = (0..5).map(|doc_id| bool_col.first(doc_id)).collect();
let vals: Vec<Option<bool>> = (0..5).map(|row_id| bool_col.first(row_id)).collect();
assert_eq!(&vals, &[None, Some(false), None, Some(true), None,]);
}
@@ -108,7 +108,7 @@ fn test_dataframe_writer_ip_addr() {
let DynamicColumn::IpAddr(ip_col) = dyn_bool_col else {
panic!();
};
let vals: Vec<Option<Ipv6Addr>> = (0..5).map(|doc_id| ip_col.first(doc_id)).collect();
let vals: Vec<Option<Ipv6Addr>> = (0..5).map(|row_id| ip_col.first(row_id)).collect();
assert_eq!(
&vals,
&[
@@ -169,7 +169,7 @@ fn test_dictionary_encoded_str() {
let DynamicColumn::Str(str_col) = col_handles[0].open().unwrap() else {
panic!();
};
let index: Vec<Option<u64>> = (0..5).map(|doc_id| str_col.ords().first(doc_id)).collect();
let index: Vec<Option<u64>> = (0..5).map(|row_id| str_col.ords().first(row_id)).collect();
assert_eq!(index, &[None, Some(0), None, Some(2), Some(1)]);
assert_eq!(str_col.num_rows(), 5);
let mut term_buffer = String::new();
@@ -204,7 +204,7 @@ fn test_dictionary_encoded_bytes() {
panic!();
};
let index: Vec<Option<u64>> = (0..5)
.map(|doc_id| bytes_col.ords().first(doc_id))
.map(|row_id| bytes_col.ords().first(row_id))
.collect();
assert_eq!(index, &[None, Some(0), None, Some(2), Some(1)]);
assert_eq!(bytes_col.num_rows(), 5);
@@ -641,7 +641,7 @@ proptest! {
let columnar_readers_arr: Vec<&ColumnarReader> = columnar_readers.iter().collect();
let mut output: Vec<u8> = Vec::new();
let stack_merge_order = StackMergeOrder::stack(&columnar_readers_arr[..]).into();
crate::merge_columnar(&columnar_readers_arr[..], &[], stack_merge_order, &mut output).unwrap();
crate::merge_columnar(&columnar_readers_arr[..], &[], stack_merge_order, &mut output, || false,).unwrap();
let merged_columnar = ColumnarReader::open(output).unwrap();
let concat_rows: Vec<Vec<(&'static str, ColumnValue)>> = columnar_docs.iter().flatten().cloned().collect();
let expected_merged_columnar = build_columnar(&concat_rows[..]);
@@ -665,6 +665,7 @@ fn test_columnar_merging_empty_columnar() {
&[],
crate::MergeRowOrder::Stack(stack_merge_order),
&mut output,
|| false,
)
.unwrap();
let merged_columnar = ColumnarReader::open(output).unwrap();
@@ -702,6 +703,7 @@ fn test_columnar_merging_number_columns() {
&[],
crate::MergeRowOrder::Stack(stack_merge_order),
&mut output,
|| false,
)
.unwrap();
let merged_columnar = ColumnarReader::open(output).unwrap();
@@ -775,6 +777,7 @@ fn test_columnar_merge_and_remap(
&[],
shuffle_merge_order.into(),
&mut output,
|| false,
)
.unwrap();
let merged_columnar = ColumnarReader::open(output).unwrap();
@@ -817,6 +820,7 @@ fn test_columnar_merge_empty() {
&[],
shuffle_merge_order.into(),
&mut output,
|| false,
)
.unwrap();
let merged_columnar = ColumnarReader::open(output).unwrap();
@@ -843,6 +847,7 @@ fn test_columnar_merge_single_str_column() {
&[],
shuffle_merge_order.into(),
&mut output,
|| false,
)
.unwrap();
let merged_columnar = ColumnarReader::open(output).unwrap();
@@ -875,6 +880,7 @@ fn test_delete_decrease_cardinality() {
&[],
shuffle_merge_order.into(),
&mut output,
|| false,
)
.unwrap();
let merged_columnar = ColumnarReader::open(output).unwrap();

View File

@@ -21,5 +21,5 @@ serde = { version = "1.0.136", features = ["derive"] }
[dev-dependencies]
binggan = "0.14.0"
proptest = "1.0.0"
rand = "0.9"
rand = "0.8.4"

View File

@@ -1,6 +1,6 @@
use binggan::{BenchRunner, black_box};
use rand::rng;
use rand::seq::IteratorRandom;
use rand::thread_rng;
use tantivy_common::{BitSet, TinySet, serialize_vint_u32};
fn bench_vint() {
@@ -17,7 +17,7 @@ fn bench_vint() {
black_box(out);
});
let vals: Vec<u32> = (0..20_000).choose_multiple(&mut rng(), 100_000);
let vals: Vec<u32> = (0..20_000).choose_multiple(&mut thread_rng(), 100_000);
runner.bench_function("bench_vint_rand", move |_| {
let mut out = 0u64;
for val in vals.iter().cloned() {

View File

@@ -178,15 +178,9 @@ impl TinySet {
#[derive(Clone)]
pub struct BitSet {
tinysets: Box<[TinySet]>,
len: u64,
max_value: u32,
}
impl std::fmt::Debug for BitSet {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("BitSet")
.field("max_value", &self.max_value)
.finish()
}
}
fn num_buckets(max_val: u32) -> u32 {
max_val.div_ceil(64u32)
@@ -210,6 +204,7 @@ impl BitSet {
let tinybitsets = vec![TinySet::empty(); num_buckets as usize].into_boxed_slice();
BitSet {
tinysets: tinybitsets,
len: 0,
max_value,
}
}
@@ -227,6 +222,7 @@ impl BitSet {
}
BitSet {
tinysets: tinybitsets,
len: max_value as u64,
max_value,
}
}
@@ -245,19 +241,17 @@ impl BitSet {
/// Intersect with tinysets
fn intersect_update_with_iter(&mut self, other: impl Iterator<Item = TinySet>) {
self.len = 0;
for (left, right) in self.tinysets.iter_mut().zip(other) {
*left = left.intersect(right);
self.len += left.len() as u64;
}
}
/// Returns the number of elements in the `BitSet`.
#[inline]
pub fn len(&self) -> usize {
self.tinysets
.iter()
.copied()
.map(|tinyset| tinyset.len())
.sum::<u32>() as usize
self.len as usize
}
/// Inserts an element in the `BitSet`
@@ -266,7 +260,7 @@ impl BitSet {
// we do not check saturated els.
let higher = el / 64u32;
let lower = el % 64u32;
self.tinysets[higher as usize].insert_mut(lower);
self.len += u64::from(self.tinysets[higher as usize].insert_mut(lower));
}
/// Inserts an element in the `BitSet`
@@ -275,7 +269,7 @@ impl BitSet {
// we do not check saturated els.
let higher = el / 64u32;
let lower = el % 64u32;
self.tinysets[higher as usize].remove_mut(lower);
self.len -= u64::from(self.tinysets[higher as usize].remove_mut(lower));
}
/// Returns true iff the elements is in the `BitSet`.
@@ -297,9 +291,6 @@ impl BitSet {
.map(|delta_bucket| bucket + delta_bucket as u32)
}
/// Returns the maximum number of elements in the bitset.
///
/// Warning: The largest element the bitset can contain is `max_value - 1`.
#[inline]
pub fn max_value(&self) -> u32 {
self.max_value
@@ -417,7 +408,7 @@ mod tests {
use std::collections::HashSet;
use ownedbytes::OwnedBytes;
use rand::distr::Bernoulli;
use rand::distributions::Bernoulli;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

View File

@@ -0,0 +1,106 @@
use std::cell::RefCell;
use std::cmp::min;
use std::io;
use std::ops::Range;
use super::file_slice::FileSlice;
use super::{HasLen, OwnedBytes};
const DEFAULT_BUFFER_MAX_SIZE: usize = 512 * 1024; // 512K
/// A buffered reader for a FileSlice.
///
/// Reads the underlying `FileSlice` in large, sequential chunks to amortize
/// the cost of `read_bytes` calls, while keeping peak memory usage under control.
///
/// TODO: Rather than wrapping a `FileSlice` in buffering, it will usually be better to adjust a
/// `FileHandle` to directly handle buffering itself.
/// TODO: See: https://github.com/paradedb/paradedb/issues/3374
pub struct BufferedFileSlice {
file_slice: FileSlice,
buffer: RefCell<OwnedBytes>,
buffer_range: RefCell<Range<u64>>,
buffer_max_size: usize,
}
impl BufferedFileSlice {
/// Creates a new `BufferedFileSlice`.
///
/// The `buffer_max_size` is the amount of data that will be read from the
/// `FileSlice` on a buffer miss.
pub fn new(file_slice: FileSlice, buffer_max_size: usize) -> Self {
Self {
file_slice,
buffer: RefCell::new(OwnedBytes::empty()),
buffer_range: RefCell::new(0..0),
buffer_max_size,
}
}
/// Creates a new `BufferedFileSlice` with a default buffer max size.
pub fn new_with_default_buffer_size(file_slice: FileSlice) -> Self {
Self::new(file_slice, DEFAULT_BUFFER_MAX_SIZE)
}
/// Creates an empty `BufferedFileSlice`.
pub fn empty() -> Self {
Self::new(FileSlice::empty(), 0)
}
/// Returns an `OwnedBytes` corresponding to the given `required_range`.
///
/// If the requested range is not in the buffer, this will trigger a read
/// from the underlying `FileSlice`.
///
/// If the requested range is larger than the buffer_max_size, it will be read directly from the
/// source without buffering.
///
/// # Errors
///
/// Returns an `io::Error` if the underlying read fails or the range is
/// out of bounds.
pub fn get_bytes(&self, required_range: Range<u64>) -> io::Result<OwnedBytes> {
let buffer_range = self.buffer_range.borrow();
// Cache miss condition: the required range is not fully contained in the current buffer.
if required_range.start < buffer_range.start || required_range.end > buffer_range.end {
drop(buffer_range); // release borrow before mutating
if required_range.end > self.file_slice.len() as u64 {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"Requested range extends beyond the end of the file slice.",
));
}
if (required_range.end - required_range.start) as usize > self.buffer_max_size {
// This read is larger than our buffer max size.
// Read it directly and bypass the buffer to avoid churning.
return self
.file_slice
.read_bytes_slice(required_range.start as usize..required_range.end as usize);
}
let new_buffer_start = required_range.start;
let new_buffer_end = min(
new_buffer_start + self.buffer_max_size as u64,
self.file_slice.len() as u64,
);
let read_range = new_buffer_start..new_buffer_end;
let new_buffer = self
.file_slice
.read_bytes_slice(read_range.start as usize..read_range.end as usize)?;
self.buffer.replace(new_buffer);
self.buffer_range.replace(read_range);
}
// Now the data is guaranteed to be in the buffer.
let buffer = self.buffer.borrow();
let buffer_range = self.buffer_range.borrow();
let local_start = (required_range.start - buffer_range.start) as usize;
let local_end = (required_range.end - buffer_range.start) as usize;
Ok(buffer.slice(local_start..local_end))
}
}

View File

@@ -1,7 +1,7 @@
use std::fs::File;
use std::ops::{Deref, Range, RangeBounds};
use std::path::Path;
use std::sync::Arc;
use std::sync::{Arc, OnceLock};
use std::{fmt, io};
use async_trait::async_trait;
@@ -339,6 +339,27 @@ impl FileHandle for OwnedBytes {
}
}
pub struct DeferredFileSlice {
opener: Arc<dyn Fn() -> io::Result<FileSlice> + Send + Sync + 'static>,
file_slice: OnceLock<std::io::Result<FileSlice>>,
}
impl DeferredFileSlice {
pub fn new(opener: impl Fn() -> io::Result<FileSlice> + Send + Sync + 'static) -> Self {
DeferredFileSlice {
opener: Arc::new(opener),
file_slice: OnceLock::default(),
}
}
pub fn open(&self) -> io::Result<&FileSlice> {
match self.file_slice.get_or_init(|| (self.opener)()) {
Ok(file_slice) => Ok(file_slice),
Err(e) => Err(io::Error::new(io::ErrorKind::Other, e.to_string())),
}
}
}
#[cfg(test)]
mod tests {
use std::io;

View File

@@ -6,6 +6,7 @@ pub use byteorder::LittleEndian as Endianness;
mod bitset;
pub mod bounds;
pub mod buffered_file_slice;
mod byte_count;
mod datetime;
pub mod file_slice;

View File

@@ -58,6 +58,33 @@ impl BinarySerializable for VIntU128 {
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct VInt(pub u64);
impl VInt {
pub fn deserialize_with_size<R: Read>(reader: &mut R) -> io::Result<(Self, usize)> {
let mut nbytes = 0;
let mut bytes = reader.bytes();
let mut result = 0u64;
let mut shift = 0u64;
loop {
match bytes.next() {
Some(Ok(b)) => {
nbytes += 1;
result |= u64::from(b % 128u8) << shift;
if b >= STOP_BIT {
return Ok((VInt(result), nbytes));
}
shift += 7;
}
_ => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Reach end of buffer while reading VInt",
));
}
}
}
}
}
const STOP_BIT: u8 = 128;
#[inline]
@@ -225,7 +252,6 @@ impl BinarySerializable for VInt {
#[cfg(test)]
mod tests {
use super::{BinarySerializable, VInt, serialize_vint_u32};
fn aux_test_vint(val: u64) {

View File

@@ -91,10 +91,46 @@ fn main() -> tantivy::Result<()> {
}
}
// Some other powerful operations (especially `.seek`) may be useful to consume these
// A `Term` is a text token associated with a field.
// Let's go through all docs containing the term `title:the` and access their position
let term_the = Term::from_field_text(title, "the");
// Some other powerful operations (especially `.skip_to`) may be useful to consume these
// posting lists rapidly.
// You can check for them in the [`DocSet`](https://docs.rs/tantivy/~0/tantivy/trait.DocSet.html) trait
// and the [`Postings`](https://docs.rs/tantivy/~0/tantivy/trait.Postings.html) trait
// Also, for some VERY specific high performance use case like an OLAP analysis of logs,
// you can get better performance by accessing directly the blocks of doc ids.
for segment_reader in searcher.segment_readers() {
// A segment contains different data structure.
// Inverted index stands for the combination of
// - the term dictionary
// - the inverted lists associated with each terms and their positions
let inverted_index = segment_reader.inverted_index(title)?;
// This segment posting object is like a cursor over the documents matching the term.
// The `IndexRecordOption` arguments tells tantivy we will be interested in both term
// frequencies and positions.
//
// If you don't need all this information, you may get better performance by decompressing
// less information.
if let Some(mut block_segment_postings) =
inverted_index.read_block_postings(&term_the, IndexRecordOption::Basic)?
{
loop {
let docs = block_segment_postings.docs();
if docs.is_empty() {
break;
}
// Once again these docs MAY contains deleted documents as well.
let docs = block_segment_postings.docs();
// Prints `Docs [0, 2].`
println!("Docs {docs:?}");
block_segment_postings.advance();
}
}
}
Ok(())
}

View File

@@ -0,0 +1,86 @@
// # Multiple Snippets Example
//
// This example demonstrates how to return multiple text fragments
// from a document, useful for long documents with matches in different locations.
use tantivy::collector::TopDocs;
use tantivy::query::QueryParser;
use tantivy::schema::*;
use tantivy::snippet::SnippetGenerator;
use tantivy::{doc, Index, IndexWriter};
use tempfile::TempDir;
fn main() -> tantivy::Result<()> {
let index_path = TempDir::new()?;
// Define the schema
let mut schema_builder = Schema::builder();
let title = schema_builder.add_text_field("title", TEXT | STORED);
let body = schema_builder.add_text_field("body", TEXT | STORED);
let schema = schema_builder.build();
// Create the index
let index = Index::create_in_dir(&index_path, schema)?;
let mut index_writer: IndexWriter = index.writer(50_000_000)?;
// Index a long document with multiple occurrences of "rust"
index_writer.add_document(doc!(
title => "The Rust Programming Language",
body => "Rust is a systems programming language that runs blazingly fast, prevents \
segfaults, and guarantees thread safety. Lorem ipsum dolor sit amet, \
consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore. \
Rust empowers everyone to build reliable and efficient software. More filler \
text to create distance between matches. Ut enim ad minim veniam, quis nostrud \
exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. \
The Rust compiler is known for its helpful error messages. Duis aute irure \
dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla \
pariatur. Rust has a strong type system and ownership model."
))?;
index_writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
let query_parser = QueryParser::for_index(&index, vec![title, body]);
let query = query_parser.parse_query("rust")?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
// Create snippet generator
let mut snippet_generator = SnippetGenerator::create(&searcher, &*query, body)?;
println!("=== Single Snippet (Default Behavior) ===\n");
for (score, doc_address) in &top_docs {
let doc = searcher.doc::<TantivyDocument>(*doc_address)?;
let snippet = snippet_generator.snippet_from_doc(&doc);
println!("Document score: {}", score);
println!("Title: {}", doc.get_first(title).unwrap().as_str().unwrap());
println!("Single snippet: {}\n", snippet.to_html());
}
println!("\n=== Multiple Snippets (New Feature) ===\n");
// Configure to return multiple snippets
// Get up to 3 snippets
snippet_generator.set_snippets_limit(3);
// Smaller fragments
snippet_generator.set_max_num_chars(80);
// By default, multiple snippets are sorted by score. You can change this to sort by position.
// snippet_generator.set_sort_order(SnippetSortOrder::Position);
for (score, doc_address) in top_docs {
let doc = searcher.doc::<TantivyDocument>(doc_address)?;
let snippets = snippet_generator.snippets_from_doc(&doc);
println!("Document score: {}", score);
println!("Title: {}", doc.get_first(title).unwrap().as_str().unwrap());
println!("Found {} snippets:", snippets.len());
for (i, snippet) in snippets.iter().enumerate() {
println!(" Snippet {}: {}", i + 1, snippet.to_html());
}
println!();
}
Ok(())
}

3
runtests.sh Executable file
View File

@@ -0,0 +1,3 @@
#! /bin/bash
cargo +stable nextest run --features quickwit,mmap,stopwords,lz4-compression,zstd-compression,failpoints --verbose --workspace

View File

@@ -1,4 +1,4 @@
use columnar::{Column, ColumnBlockAccessor, ColumnType, StrColumn};
use columnar::{Column, ColumnType, StrColumn};
use common::BitSet;
use rustc_hash::FxHashSet;
use serde::Serialize;
@@ -10,16 +10,16 @@ use crate::aggregation::accessor_helpers::{
};
use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations};
use crate::aggregation::bucket::{
build_segment_filter_collector, build_segment_range_collector, FilterAggReqData,
HistogramAggReqData, HistogramBounds, IncludeExcludeParam, MissingTermAggReqData,
RangeAggReqData, SegmentHistogramCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
FilterAggReqData, HistogramAggReqData, HistogramBounds, IncludeExcludeParam,
MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector, SegmentHistogramCollector,
SegmentRangeCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
TermsAggregationInternal,
};
use crate::aggregation::metric::{
build_segment_stats_collector, AverageAggregation, CardinalityAggReqData,
CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation, MaxAggregation,
MetricAggReqData, MinAggregation, SegmentCardinalityCollector, SegmentExtendedStatsCollector,
SegmentPercentilesCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData,
AverageAggregation, CardinalityAggReqData, CardinalityAggregationReq, CountAggregation,
ExtendedStatsAggregation, MaxAggregation, MetricAggReqData, MinAggregation,
SegmentCardinalityCollector, SegmentExtendedStatsCollector, SegmentPercentilesCollector,
SegmentStatsCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData,
TopHitsSegmentCollector,
};
use crate::aggregation::segment_agg_result::{
@@ -35,7 +35,6 @@ pub struct AggregationsSegmentCtx {
/// Request data for each aggregation type.
pub per_request: PerRequestAggSegCtx,
pub context: AggContextParams,
pub column_block_accessor: ColumnBlockAccessor<u64>,
}
impl AggregationsSegmentCtx {
@@ -108,14 +107,21 @@ impl AggregationsSegmentCtx {
.as_deref()
.expect("range_req_data slot is empty (taken)")
}
#[inline]
pub(crate) fn get_filter_req_data(&self, idx: usize) -> &FilterAggReqData {
self.per_request.filter_req_data[idx]
.as_deref()
.expect("filter_req_data slot is empty (taken)")
}
// ---------- mutable getters ----------
#[inline]
pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData {
&mut self.per_request.stats_metric_req_data[idx]
pub(crate) fn get_term_req_data_mut(&mut self, idx: usize) -> &mut TermsAggReqData {
self.per_request.term_req_data[idx]
.as_deref_mut()
.expect("term_req_data slot is empty (taken)")
}
#[inline]
pub(crate) fn get_cardinality_req_data_mut(
&mut self,
@@ -123,7 +129,10 @@ impl AggregationsSegmentCtx {
) -> &mut CardinalityAggReqData {
&mut self.per_request.cardinality_req_data[idx]
}
#[inline]
pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData {
&mut self.per_request.stats_metric_req_data[idx]
}
#[inline]
pub(crate) fn get_histogram_req_data_mut(&mut self, idx: usize) -> &mut HistogramAggReqData {
self.per_request.histogram_req_data[idx]
@@ -133,6 +142,21 @@ impl AggregationsSegmentCtx {
// ---------- take / put (terms, histogram, range) ----------
/// Move out the boxed Terms request at `idx`, leaving `None`.
#[inline]
pub(crate) fn take_term_req_data(&mut self, idx: usize) -> Box<TermsAggReqData> {
self.per_request.term_req_data[idx]
.take()
.expect("term_req_data slot is empty (taken)")
}
/// Put back a Terms request into an empty slot at `idx`.
#[inline]
pub(crate) fn put_back_term_req_data(&mut self, idx: usize, value: Box<TermsAggReqData>) {
debug_assert!(self.per_request.term_req_data[idx].is_none());
self.per_request.term_req_data[idx] = Some(value);
}
/// Move out the boxed Histogram request at `idx`, leaving `None`.
#[inline]
pub(crate) fn take_histogram_req_data(&mut self, idx: usize) -> Box<HistogramAggReqData> {
@@ -296,7 +320,6 @@ impl PerRequestAggSegCtx {
/// Convert the aggregation tree into a serializable struct representation.
/// Each node contains: { name, kind, children }.
#[allow(dead_code)]
pub fn get_view_tree(&self) -> Vec<AggTreeViewNode> {
fn node_to_view(node: &AggRefNode, pr: &PerRequestAggSegCtx) -> AggTreeViewNode {
let mut children: Vec<AggTreeViewNode> =
@@ -322,19 +345,12 @@ impl PerRequestAggSegCtx {
pub(crate) fn build_segment_agg_collectors_root(
req: &mut AggregationsSegmentCtx,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
build_segment_agg_collectors_generic(req, &req.per_request.agg_tree.clone())
build_segment_agg_collectors(req, &req.per_request.agg_tree.clone())
}
pub(crate) fn build_segment_agg_collectors(
req: &mut AggregationsSegmentCtx,
nodes: &[AggRefNode],
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
build_segment_agg_collectors_generic(req, nodes)
}
fn build_segment_agg_collectors_generic(
req: &mut AggregationsSegmentCtx,
nodes: &[AggRefNode],
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
let mut collectors = Vec::new();
for node in nodes.iter() {
@@ -372,8 +388,6 @@ pub(crate) fn build_segment_agg_collector(
Ok(Box::new(SegmentCardinalityCollector::from_req(
req_data.column_type,
node.idx_in_req_data,
req_data.accessor.clone(),
req_data.missing_value_for_accessor,
)))
}
AggKind::StatsKind(stats_type) => {
@@ -384,21 +398,20 @@ pub(crate) fn build_segment_agg_collector(
| StatsType::Count
| StatsType::Max
| StatsType::Min
| StatsType::Stats => build_segment_stats_collector(req_data),
StatsType::ExtendedStats(sigma) => Ok(Box::new(
SegmentExtendedStatsCollector::from_req(req_data, sigma),
)),
StatsType::Percentiles => {
let req_data = req.get_metric_req_data_mut(node.idx_in_req_data);
Ok(Box::new(
SegmentPercentilesCollector::from_req_and_validate(
req_data.field_type,
req_data.missing_u64,
req_data.accessor.clone(),
node.idx_in_req_data,
),
))
| StatsType::Stats => Ok(Box::new(SegmentStatsCollector::from_req(
node.idx_in_req_data,
))),
StatsType::ExtendedStats(sigma) => {
Ok(Box::new(SegmentExtendedStatsCollector::from_req(
req_data.field_type,
sigma,
node.idx_in_req_data,
req_data.missing,
)))
}
StatsType::Percentiles => Ok(Box::new(
SegmentPercentilesCollector::from_req_and_validate(node.idx_in_req_data)?,
)),
}
}
AggKind::TopHits => {
@@ -415,8 +428,12 @@ pub(crate) fn build_segment_agg_collector(
AggKind::DateHistogram => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
req, node,
)?)),
AggKind::Range => Ok(build_segment_range_collector(req, node)?),
AggKind::Filter => build_segment_filter_collector(req, node),
AggKind::Range => Ok(Box::new(SegmentRangeCollector::from_req_and_validate(
req, node,
)?)),
AggKind::Filter => Ok(Box::new(SegmentFilterCollector::from_req_and_validate(
req, node,
)?)),
}
}
@@ -476,7 +493,6 @@ pub(crate) fn build_aggregations_data_from_req(
let mut data = AggregationsSegmentCtx {
per_request: Default::default(),
context,
column_block_accessor: ColumnBlockAccessor::default(),
};
for (name, agg) in aggs.iter() {
@@ -505,9 +521,9 @@ fn build_nodes(
let idx_in_req_data = data.push_range_req_data(RangeAggReqData {
accessor,
field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
req: range_req.clone(),
is_top_level,
});
let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?;
Ok(vec![AggRefNode {
@@ -525,7 +541,9 @@ fn build_nodes(
let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData {
accessor,
field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
sub_aggregation_blueprint: None,
req: histo_req.clone(),
is_date_histogram: false,
bounds: HistogramBounds {
@@ -550,7 +568,9 @@ fn build_nodes(
let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData {
accessor,
field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
sub_aggregation_blueprint: None,
req: histo_req,
is_date_histogram: true,
bounds: HistogramBounds {
@@ -630,6 +650,7 @@ fn build_nodes(
let idx_in_req_data = data.push_metric_req_data(MetricAggReqData {
accessor,
field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
collecting_for,
missing: *missing,
@@ -657,6 +678,7 @@ fn build_nodes(
let idx_in_req_data = data.push_metric_req_data(MetricAggReqData {
accessor,
field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
collecting_for: StatsType::Percentiles,
missing: percentiles_req.missing,
@@ -731,7 +753,6 @@ fn build_nodes(
segment_reader: reader.clone(),
evaluator,
matching_docs_buffer,
is_top_level,
});
let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?;
Ok(vec![AggRefNode {
@@ -874,7 +895,7 @@ fn build_terms_or_cardinality_nodes(
});
}
// Add one node per accessor
// Add one node per accessor to mirror previous behavior and allow per-type missing handling.
for (accessor, column_type) in column_and_types {
let missing_value_for_accessor = if use_special_missing_agg {
None
@@ -905,8 +926,11 @@ fn build_terms_or_cardinality_nodes(
column_type,
str_dict_column: str_dict_column.clone(),
missing_value_for_accessor,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
req: TermsAggregationInternal::from_req(req),
// Will be filled later when building collectors
sub_aggregation_blueprint: None,
sug_aggregations: sub_aggs.clone(),
allowed_term_ids,
is_top_level,
@@ -919,6 +943,7 @@ fn build_terms_or_cardinality_nodes(
column_type,
str_dict_column: str_dict_column.clone(),
missing_value_for_accessor,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
req: req.clone(),
});

View File

@@ -2,441 +2,15 @@ use serde_json::Value;
use crate::aggregation::agg_req::{Aggregation, Aggregations};
use crate::aggregation::agg_result::AggregationResults;
use crate::aggregation::buf_collector::DOC_BLOCK_SIZE;
use crate::aggregation::collector::AggregationCollector;
use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults;
use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms};
use crate::aggregation::DistributedAggregationCollector;
use crate::docset::COLLECT_BLOCK_BUFFER_LEN;
use crate::query::{AllQuery, TermQuery};
use crate::schema::{IndexRecordOption, Schema, FAST};
use crate::{Index, IndexWriter, Term};
// The following tests ensure that each bucket aggregation type correctly functions as a
// sub-aggregation of another bucket aggregation in two scenarios:
// 1) The parent has more buckets than the child sub-aggregation
// 2) The child sub-aggregation has more buckets than the parent
//
// These scenarios exercise the bucket id mapping and sub-aggregation routing logic.
#[test]
fn test_terms_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with 4 buckets
// Child: terms on text -> 2 buckets
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
// Exact expected structure and counts
assert_eq!(
res["parent_range"]["buckets"],
json!([
{
"key": "*-3",
"doc_count": 1,
"to": 3.0,
"child_terms": {
"buckets": [
{"doc_count": 1, "key": "cool"}
],
"sum_other_doc_count": 0
}
},
{
"key": "3-7",
"doc_count": 3,
"from": 3.0,
"to": 7.0,
"child_terms": {
"buckets": [
{"doc_count": 2, "key": "cool"},
{"doc_count": 1, "key": "nohit"}
],
"sum_other_doc_count": 0
}
},
{
"key": "7-20",
"doc_count": 3,
"from": 7.0,
"to": 20.0,
"child_terms": {
"buckets": [
{"doc_count": 3, "key": "cool"}
],
"sum_other_doc_count": 0
}
},
{
"key": "20-*",
"doc_count": 2,
"from": 20.0,
"child_terms": {
"buckets": [
{"doc_count": 1, "key": "cool"},
{"doc_count": 1, "key": "nohit"}
],
"sum_other_doc_count": 0
}
}
])
);
// Case B: child has more buckets than parent
// Parent: histogram on score with large interval -> 1 bucket
// Child: terms on text -> 2 buckets (cool/nohit)
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_hist": {
"histogram": {"field": "score", "interval": 100.0},
"aggs": {
"child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
assert_eq!(
res["parent_hist"],
json!({
"buckets": [
{
"key": 0.0,
"doc_count": 9,
"child_terms": {
"buckets": [
{"doc_count": 7, "key": "cool"},
{"doc_count": 2, "key": "nohit"}
],
"sum_other_doc_count": 0
}
}
]
})
);
Ok(())
}
#[test]
fn test_range_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with 5 buckets
// Child: coarse range with 3 buckets
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 11.0},
{"from": 11.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 20.0}
]
}
}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
assert_eq!(
res["parent_range"]["buckets"],
json!([
{"key": "*-3", "doc_count": 1, "to": 3.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 1, "to": 3.0},
{"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 3, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 1, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 2, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "20-*", "doc_count": 2, "from": 20.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 2, "from": 20.0}
]}
}
])
);
// Case B: child has more buckets than parent
// Parent: terms on text (2 buckets)
// Child: range with 4 buckets
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_terms": {
"terms": {"field": "text"},
"aggs": {
"child_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 20.0}
]
}
}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
assert_eq!(
res["parent_terms"],
json!({
"buckets": [
{
"key": "cool",
"doc_count": 7,
"child_range": {
"buckets": [
{"key": "*-3", "doc_count": 1, "to": 3.0},
{"key": "3-7", "doc_count": 2, "from": 3.0, "to": 7.0},
{"key": "7-20", "doc_count": 3, "from": 7.0, "to": 20.0},
{"key": "20-*", "doc_count": 1, "from": 20.0}
]
}
},
{
"key": "nohit",
"doc_count": 2,
"child_range": {
"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-7", "doc_count": 1, "from": 3.0, "to": 7.0},
{"key": "7-20", "doc_count": 0, "from": 7.0, "to": 20.0},
{"key": "20-*", "doc_count": 1, "from": 20.0}
]
}
}
],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 0
})
);
Ok(())
}
#[test]
fn test_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with several ranges
// Child: histogram with large interval (single bucket per parent)
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 11.0},
{"from": 11.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_hist": {"histogram": {"field": "score", "interval": 100.0}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
assert_eq!(
res["parent_range"]["buckets"],
json!([
{"key": "*-3", "doc_count": 1, "to": 3.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]}
},
{"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 3} ]}
},
{"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]}
},
{"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]}
},
{"key": "20-*", "doc_count": 2, "from": 20.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]}
}
])
);
// Case B: child has more buckets than parent
// Parent: terms on text -> 2 buckets
// Child: histogram with small interval -> multiple buckets including empties
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_terms": {
"terms": {"field": "text"},
"aggs": {
"child_hist": {"histogram": {"field": "score", "interval": 10.0}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
assert_eq!(
res["parent_terms"],
json!({
"buckets": [
{
"key": "cool",
"doc_count": 7,
"child_hist": {
"buckets": [
{"key": 0.0, "doc_count": 4},
{"key": 10.0, "doc_count": 2},
{"key": 20.0, "doc_count": 0},
{"key": 30.0, "doc_count": 0},
{"key": 40.0, "doc_count": 1}
]
}
},
{
"key": "nohit",
"doc_count": 2,
"child_hist": {
"buckets": [
{"key": 0.0, "doc_count": 1},
{"key": 10.0, "doc_count": 0},
{"key": 20.0, "doc_count": 0},
{"key": 30.0, "doc_count": 0},
{"key": 40.0, "doc_count": 1}
]
}
}
],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 0
})
);
Ok(())
}
#[test]
fn test_date_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with several buckets
// Child: date_histogram with 30d -> single bucket per parent
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 11.0},
{"from": 11.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "30d"}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
let buckets = res["parent_range"]["buckets"].as_array().unwrap();
// Verify each parent bucket has exactly one child date bucket with matching doc_count
for bucket in buckets {
let parent_count = bucket["doc_count"].as_u64().unwrap();
let child_buckets = bucket["child_date_hist"]["buckets"].as_array().unwrap();
assert_eq!(child_buckets.len(), 1);
assert_eq!(child_buckets[0]["doc_count"], parent_count);
}
// Case B: child has more buckets than parent
// Parent: terms on text (2 buckets)
// Child: date_histogram with 1d -> multiple buckets
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_terms": {
"terms": {"field": "text"},
"aggs": {
"child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "1d"}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
let buckets = res["parent_terms"]["buckets"].as_array().unwrap();
// cool bucket
assert_eq!(buckets[0]["key"], "cool");
let cool_buckets = buckets[0]["child_date_hist"]["buckets"].as_array().unwrap();
assert_eq!(cool_buckets.len(), 3);
assert_eq!(cool_buckets[0]["doc_count"], 1); // day 0
assert_eq!(cool_buckets[1]["doc_count"], 4); // day 1
assert_eq!(cool_buckets[2]["doc_count"], 2); // day 2
// nohit bucket
assert_eq!(buckets[1]["key"], "nohit");
let nohit_buckets = buckets[1]["child_date_hist"]["buckets"].as_array().unwrap();
assert_eq!(nohit_buckets.len(), 2);
assert_eq!(nohit_buckets[0]["doc_count"], 1); // day 1
assert_eq!(nohit_buckets[1]["doc_count"], 1); // day 2
Ok(())
}
fn get_avg_req(field_name: &str) -> Aggregation {
serde_json::from_value(json!({
"avg": {
@@ -451,10 +25,6 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector {
}
// *** EVERY BUCKET-TYPE SHOULD BE TESTED HERE ***
// Note: The flushng part of these tests are outdated, since the buffering change after converting
// the collection into one collector per request instead of per bucket.
//
// However they are useful as they test a complex aggregation requests.
fn test_aggregation_flushing(
merge_segments: bool,
use_distributed_collector: bool,
@@ -467,9 +37,8 @@ fn test_aggregation_flushing(
let reader = index.reader()?;
assert_eq!(COLLECT_BLOCK_BUFFER_LEN, 64);
// In the tree we cache documents of COLLECT_BLOCK_BUFFER_LEN before passing them down as one
// block.
assert_eq!(DOC_BLOCK_SIZE, 64);
// In the tree we cache Documents of DOC_BLOCK_SIZE, before passing them down as one block.
//
// Build a request so that on the first level we have one full cache, which is then flushed.
// The same cache should have some residue docs at the end, which are flushed (Range 0-70)

View File

@@ -6,14 +6,10 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::cached_sub_aggs::{
CachedSubAggs, HighCardSubAggCache, LowCardSubAggCache, SubAggCache,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
};
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::BucketId;
use crate::aggregation::segment_agg_result::{CollectorClone, SegmentAggregationCollector};
use crate::docset::DocSet;
use crate::query::{AllQuery, EnableScoring, Query, QueryParser};
use crate::schema::Schema;
@@ -408,18 +404,15 @@ pub struct FilterAggReqData {
pub evaluator: DocumentQueryEvaluator,
/// Reusable buffer for matching documents to minimize allocations during collection
pub matching_docs_buffer: Vec<DocId>,
/// True if this filter aggregation is at the top level of the aggregation tree (not nested).
pub is_top_level: bool,
}
impl FilterAggReqData {
pub(crate) fn get_memory_consumption(&self) -> usize {
// Estimate: name + segment reader reference + bitset + buffer capacity
self.name.len()
+ std::mem::size_of::<SegmentReader>()
+ self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes)
+ self.matching_docs_buffer.capacity() * std::mem::size_of::<DocId>()
+ std::mem::size_of::<bool>()
+ std::mem::size_of::<SegmentReader>()
+ self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes)
+ self.matching_docs_buffer.capacity() * std::mem::size_of::<DocId>()
}
}
@@ -496,24 +489,17 @@ impl Debug for DocumentQueryEvaluator {
}
}
#[derive(Debug, Clone, PartialEq, Copy)]
struct DocCount {
doc_count: u64,
bucket_id: BucketId,
}
/// Segment collector for filter aggregation
pub struct SegmentFilterCollector<C: SubAggCache> {
/// Document counts per parent bucket
parent_buckets: Vec<DocCount>,
pub struct SegmentFilterCollector {
/// Document count in this bucket
doc_count: u64,
/// Sub-aggregation collectors
sub_aggregations: Option<CachedSubAggs<C>>,
bucket_id_provider: BucketIdProvider,
sub_aggregations: Option<Box<dyn SegmentAggregationCollector>>,
/// Accessor index for this filter aggregation (to access FilterAggReqData)
accessor_idx: usize,
}
impl<C: SubAggCache> SegmentFilterCollector<C> {
impl SegmentFilterCollector {
/// Create a new filter segment collector following the new agg_data pattern
pub(crate) fn from_req_and_validate(
req: &mut AggregationsSegmentCtx,
@@ -525,75 +511,47 @@ impl<C: SubAggCache> SegmentFilterCollector<C> {
} else {
None
};
let sub_agg_collector = sub_agg_collector.map(CachedSubAggs::new);
Ok(SegmentFilterCollector {
parent_buckets: Vec::new(),
doc_count: 0,
sub_aggregations: sub_agg_collector,
accessor_idx: node.idx_in_req_data,
bucket_id_provider: BucketIdProvider::default(),
})
}
}
pub(crate) fn build_segment_filter_collector(
req: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
let is_top_level = req.per_request.filter_req_data[node.idx_in_req_data]
.as_ref()
.expect("filter_req_data slot is empty")
.is_top_level;
if is_top_level {
Ok(Box::new(
SegmentFilterCollector::<LowCardSubAggCache>::from_req_and_validate(req, node)?,
))
} else {
Ok(Box::new(
SegmentFilterCollector::<HighCardSubAggCache>::from_req_and_validate(req, node)?,
))
}
}
impl<C: SubAggCache> Debug for SegmentFilterCollector<C> {
impl Debug for SegmentFilterCollector {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentFilterCollector")
.field("buckets", &self.parent_buckets)
.field("doc_count", &self.doc_count)
.field("has_sub_aggs", &self.sub_aggregations.is_some())
.field("accessor_idx", &self.accessor_idx)
.finish()
}
}
impl<C: SubAggCache> SegmentAggregationCollector for SegmentFilterCollector<C> {
impl CollectorClone for SegmentFilterCollector {
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector> {
// For now, panic - this needs proper implementation with weight recreation
panic!("SegmentFilterCollector cloning not yet implemented - requires weight recreation")
}
}
impl SegmentAggregationCollector for SegmentFilterCollector {
fn add_intermediate_aggregation_result(
&mut self,
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let mut sub_results = IntermediateAggregationResults::default();
let bucket_opt = self.parent_buckets.get(parent_bucket_id as usize);
if let Some(sub_aggs) = &mut self.sub_aggregations {
sub_aggs
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut sub_results,
// Here we create a new bucket ID for sub-aggregations if the bucket doesn't
// exist, so that sub-aggregations can still produce results (e.g., zero doc
// count)
bucket_opt
.map(|bucket| bucket.bucket_id)
.unwrap_or(self.bucket_id_provider.next_bucket_id()),
)?;
if let Some(sub_aggs) = self.sub_aggregations {
sub_aggs.add_intermediate_aggregation_result(agg_data, &mut sub_results)?;
}
// Create the filter bucket result
let filter_bucket_result = IntermediateBucketResult::Filter {
doc_count: bucket_opt.map(|b| b.doc_count).unwrap_or(0),
doc_count: self.doc_count,
sub_aggregations: sub_results,
};
@@ -612,17 +570,32 @@ impl<C: SubAggCache> SegmentAggregationCollector for SegmentFilterCollector<C> {
Ok(())
}
fn collect(
fn collect(&mut self, doc: DocId, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
// Access the evaluator from FilterAggReqData
let req_data = agg_data.get_filter_req_data(self.accessor_idx);
// O(1) BitSet lookup to check if document matches filter
if req_data.evaluator.matches_document(doc) {
self.doc_count += 1;
// If we have sub-aggregations, collect on them for this filtered document
if let Some(sub_aggs) = &mut self.sub_aggregations {
sub_aggs.collect(doc, agg_data)?;
}
}
Ok(())
}
#[inline]
fn collect_block(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
docs: &[DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
if docs.is_empty() {
return Ok(());
}
let mut bucket = self.parent_buckets[parent_bucket_id as usize];
// Take the request data to avoid borrow checker issues with sub-aggregations
let mut req = agg_data.take_filter_req_data(self.accessor_idx);
@@ -631,24 +604,18 @@ impl<C: SubAggCache> SegmentAggregationCollector for SegmentFilterCollector<C> {
req.evaluator
.filter_batch(docs, &mut req.matching_docs_buffer);
bucket.doc_count += req.matching_docs_buffer.len() as u64;
self.doc_count += req.matching_docs_buffer.len() as u64;
// Batch process sub-aggregations if we have matches
if !req.matching_docs_buffer.is_empty() {
if let Some(sub_aggs) = &mut self.sub_aggregations {
for &doc_id in &req.matching_docs_buffer {
sub_aggs.push(bucket.bucket_id, doc_id);
}
// Use collect_block for better sub-aggregation performance
sub_aggs.collect_block(&req.matching_docs_buffer, agg_data)?;
}
}
// Put the request data back
agg_data.put_back_filter_req_data(self.accessor_idx, req);
if let Some(sub_aggs) = &mut self.sub_aggregations {
sub_aggs.check_flush_local(agg_data)?;
}
// put back bucket
self.parent_buckets[parent_bucket_id as usize] = bucket;
Ok(())
}
@@ -659,21 +626,6 @@ impl<C: SubAggCache> SegmentAggregationCollector for SegmentFilterCollector<C> {
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
let bucket_id = self.bucket_id_provider.next_bucket_id();
self.parent_buckets.push(DocCount {
doc_count: 0,
bucket_id,
});
}
Ok(())
}
}
/// Intermediate result for filter aggregation
@@ -1567,9 +1519,9 @@ mod tests {
let searcher = reader.searcher();
let agg = json!({
"test": {
"filter": deserialized,
"aggs": { "count": { "value_count": { "field": "brand" } } }
"test": {
"filter": deserialized,
"aggs": { "count": { "value_count": { "field": "brand" } } }
}
});

View File

@@ -1,6 +1,6 @@
use std::cmp::Ordering;
use columnar::{Column, ColumnType};
use columnar::{Column, ColumnBlockAccessor, ColumnType};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use tantivy_bitpacker::minmax;
@@ -8,14 +8,14 @@ use tantivy_bitpacker::minmax;
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::agg_limits::MemoryConsumption;
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::agg_result::BucketEntry;
use crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardCachedSubAggs};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateHistogramBucketEntry,
};
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::TantivyError;
@@ -26,8 +26,13 @@ pub struct HistogramAggReqData {
pub accessor: Column<u64>,
/// The field type of the fast field.
pub field_type: ColumnType,
/// The column block accessor to access the fast field values.
pub column_block_accessor: ColumnBlockAccessor<u64>,
/// The name of the aggregation.
pub name: String,
/// The sub aggregation blueprint, used to create sub aggregations for each bucket.
/// Will be filled during initialization of the collector.
pub sub_aggregation_blueprint: Option<Box<dyn SegmentAggregationCollector>>,
/// The histogram aggregation request.
pub req: HistogramAggregation,
/// True if this is a date_histogram aggregation.
@@ -252,24 +257,18 @@ impl HistogramBounds {
pub(crate) struct SegmentHistogramBucketEntry {
pub key: f64,
pub doc_count: u64,
pub bucket_id: BucketId,
}
impl SegmentHistogramBucketEntry {
pub(crate) fn into_intermediate_bucket_entry(
self,
sub_aggregation: &mut Option<HighCardCachedSubAggs>,
sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateHistogramBucketEntry> {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
if let Some(sub_aggregation) = sub_aggregation {
sub_aggregation
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut sub_aggregation_res,
self.bucket_id,
)?;
.add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?;
}
Ok(IntermediateHistogramBucketEntry {
key: self.key,
@@ -279,38 +278,27 @@ impl SegmentHistogramBucketEntry {
}
}
#[derive(Clone, Debug, Default)]
struct HistogramBuckets {
pub buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
}
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct SegmentHistogramCollector {
/// The buckets containing the aggregation data.
/// One Histogram bucket per parent bucket id.
parent_buckets: Vec<HistogramBuckets>,
sub_agg: Option<HighCardCachedSubAggs>,
buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
sub_aggregations: FxHashMap<i64, Box<dyn SegmentAggregationCollector>>,
accessor_idx: usize,
bucket_id_provider: BucketIdProvider,
}
impl SegmentAggregationCollector for SegmentHistogramCollector {
fn add_intermediate_aggregation_result(
&mut self,
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_data
.get_histogram_req_data(self.accessor_idx)
.name
.clone();
// TODO: avoid prepare_max_bucket here and handle empty buckets.
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let histogram = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);
let bucket = self.add_intermediate_bucket_result(agg_data, histogram)?;
let bucket = self.into_intermediate_bucket_result(agg_data)?;
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
Ok(())
@@ -319,40 +307,44 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
#[inline]
fn collect(
&mut self,
parent_bucket_id: BucketId,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req = agg_data.take_histogram_req_data(self.accessor_idx);
let mut req = agg_data.take_histogram_req_data(self.accessor_idx);
let mem_pre = self.get_memory_consumption();
let buckets = &mut self.parent_buckets[parent_bucket_id as usize].buckets;
let bounds = req.bounds;
let interval = req.req.interval;
let offset = req.offset;
let get_bucket_pos = |val| get_bucket_pos_f64(val, interval, offset) as i64;
agg_data
.column_block_accessor
.fetch_block(docs, &req.accessor);
for (doc, val) in agg_data
req.column_block_accessor.fetch_block(docs, &req.accessor);
for (doc, val) in req
.column_block_accessor
.iter_docid_vals(docs, &req.accessor)
{
let val = f64_from_fastfield_u64(val, req.field_type);
let val = f64_from_fastfield_u64(val, &req.field_type);
let bucket_pos = get_bucket_pos(val);
if bounds.contains(val) {
let bucket = buckets.entry(bucket_pos).or_insert_with(|| {
let bucket = self.buckets.entry(bucket_pos).or_insert_with(|| {
let key = get_bucket_key_from_pos(bucket_pos as f64, interval, offset);
SegmentHistogramBucketEntry {
key,
doc_count: 0,
bucket_id: self.bucket_id_provider.next_bucket_id(),
}
SegmentHistogramBucketEntry { key, doc_count: 0 }
});
bucket.doc_count += 1;
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.push(bucket.bucket_id, doc);
if let Some(sub_aggregation_blueprint) = req.sub_aggregation_blueprint.as_ref() {
self.sub_aggregations
.entry(bucket_pos)
.or_insert_with(|| sub_aggregation_blueprint.clone())
.collect(doc, agg_data)?;
}
}
}
@@ -366,30 +358,14 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
.add_memory_consumed(mem_delta as u64)?;
}
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.check_flush_local(agg_data)?;
}
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if let Some(sub_aggregation) = &mut self.sub_agg {
for sub_aggregation in self.sub_aggregations.values_mut() {
sub_aggregation.flush(agg_data)?;
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
self.parent_buckets.push(HistogramBuckets {
buckets: FxHashMap::default(),
});
}
Ok(())
}
}
@@ -397,19 +373,22 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
impl SegmentHistogramCollector {
fn get_memory_consumption(&self) -> usize {
let self_mem = std::mem::size_of::<Self>();
let buckets_mem = self.parent_buckets.len() * std::mem::size_of::<HistogramBuckets>();
self_mem + buckets_mem
let sub_aggs_mem = self.sub_aggregations.memory_consumption();
let buckets_mem = self.buckets.memory_consumption();
self_mem + sub_aggs_mem + buckets_mem
}
/// Converts the collector result into a intermediate bucket result.
fn add_intermediate_bucket_result(
&mut self,
pub fn into_intermediate_bucket_result(
self,
agg_data: &AggregationsSegmentCtx,
histogram: HistogramBuckets,
) -> crate::Result<IntermediateBucketResult> {
let mut buckets = Vec::with_capacity(histogram.buckets.len());
let mut buckets = Vec::with_capacity(self.buckets.len());
for bucket in histogram.buckets.into_values() {
let bucket_res = bucket.into_intermediate_bucket_entry(&mut self.sub_agg, agg_data);
for (bucket_pos, bucket) in self.buckets {
let bucket_res = bucket.into_intermediate_bucket_entry(
self.sub_aggregations.get(&bucket_pos).cloned(),
agg_data,
);
buckets.push(bucket_res?);
}
@@ -429,7 +408,7 @@ impl SegmentHistogramCollector {
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let sub_agg = if !node.children.is_empty() {
let blueprint = if !node.children.is_empty() {
Some(build_segment_agg_collectors(agg_data, &node.children)?)
} else {
None
@@ -444,13 +423,13 @@ impl SegmentHistogramCollector {
max: f64::MAX,
});
req_data.offset = req_data.req.offset.unwrap_or(0.0);
let sub_agg = sub_agg.map(CachedSubAggs::new);
req_data.sub_aggregation_blueprint = blueprint;
Ok(Self {
parent_buckets: Default::default(),
sub_agg,
buckets: Default::default(),
sub_aggregations: Default::default(),
accessor_idx: node.idx_in_req_data,
bucket_id_provider: BucketIdProvider::default(),
})
}
}

View File

@@ -1,22 +1,18 @@
use std::fmt::Debug;
use std::ops::Range;
use columnar::{Column, ColumnType};
use columnar::{Column, ColumnBlockAccessor, ColumnType};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::agg_limits::AggregationLimitsGuard;
use crate::aggregation::cached_sub_aggs::{
CachedSubAggs, HighCardSubAggCache, LowCardCachedSubAggs, LowCardSubAggCache, SubAggCache,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateRangeBucketEntry, IntermediateRangeBucketResult,
};
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::TantivyError;
@@ -27,12 +23,12 @@ pub struct RangeAggReqData {
pub accessor: Column<u64>,
/// The type of the fast field.
pub field_type: ColumnType,
/// The column block accessor to access the fast field values.
pub column_block_accessor: ColumnBlockAccessor<u64>,
/// The range aggregation request.
pub req: RangeAggregation,
/// The name of the aggregation.
pub name: String,
/// Whether this is a top-level aggregation.
pub is_top_level: bool,
}
impl RangeAggReqData {
@@ -155,47 +151,19 @@ pub(crate) struct SegmentRangeAndBucketEntry {
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
pub struct SegmentRangeCollector<C: SubAggCache> {
#[derive(Clone, Debug)]
pub struct SegmentRangeCollector {
/// The buckets containing the aggregation data.
/// One for each ParentBucketId
parent_buckets: Vec<Vec<SegmentRangeAndBucketEntry>>,
buckets: Vec<SegmentRangeAndBucketEntry>,
column_type: ColumnType,
pub(crate) accessor_idx: usize,
sub_agg: Option<CachedSubAggs<C>>,
/// Here things get a bit weird. We need to assign unique bucket ids across all
/// parent buckets. So we keep track of the next available bucket id here.
/// This allows a kind of flattening of the bucket ids across all parent buckets.
/// E.g. in nested aggregations:
/// Term Agg -> Range aggregation -> Stats aggregation
/// E.g. the Term Agg creates 3 buckets ["INFO", "ERROR", "WARN"], each of these has a Range
/// aggregation with 4 buckets. The Range aggregation will create buckets with ids:
/// - INFO: 0,1,2,3
/// - ERROR: 4,5,6,7
/// - WARN: 8,9,10,11
///
/// This allows the Stats aggregation to have unique bucket ids to refer to.
bucket_id_provider: BucketIdProvider,
limits: AggregationLimitsGuard,
}
impl<C: SubAggCache> Debug for SegmentRangeCollector<C> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentRangeCollector")
.field("parent_buckets_len", &self.parent_buckets.len())
.field("column_type", &self.column_type)
.field("accessor_idx", &self.accessor_idx)
.field("has_sub_agg", &self.sub_agg.is_some())
.finish()
}
}
/// TODO: Bad naming, there's also SegmentRangeAndBucketEntry
#[derive(Clone)]
pub(crate) struct SegmentRangeBucketEntry {
pub key: Key,
pub doc_count: u64,
// pub sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
pub bucket_id: BucketId,
pub sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
/// The from range of the bucket. Equals `f64::MIN` when `None`.
pub from: Option<f64>,
/// The to range of the bucket. Equals `f64::MAX` when `None`. Open interval, `to` is not
@@ -216,50 +184,48 @@ impl Debug for SegmentRangeBucketEntry {
impl SegmentRangeBucketEntry {
pub(crate) fn into_intermediate_bucket_entry(
self,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateRangeBucketEntry> {
let sub_aggregation = IntermediateAggregationResults::default();
let mut sub_aggregation_res = IntermediateAggregationResults::default();
if let Some(sub_aggregation) = self.sub_aggregation {
sub_aggregation
.add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?
} else {
Default::default()
};
Ok(IntermediateRangeBucketEntry {
key: self.key.into(),
doc_count: self.doc_count,
sub_aggregation_res: sub_aggregation,
sub_aggregation: sub_aggregation_res,
from: self.from,
to: self.to,
})
}
}
impl<C: SubAggCache> SegmentAggregationCollector for SegmentRangeCollector<C> {
impl SegmentAggregationCollector for SegmentRangeCollector {
fn add_intermediate_aggregation_result(
&mut self,
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let field_type = self.column_type;
let name = agg_data
.get_range_req_data(self.accessor_idx)
.name
.to_string();
let buckets = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);
let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = buckets
let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = self
.buckets
.into_iter()
.map(|range_bucket| {
let bucket_id = range_bucket.bucket.bucket_id;
let mut agg = range_bucket.bucket.into_intermediate_bucket_entry()?;
if let Some(sub_aggregation) = &mut self.sub_agg {
sub_aggregation
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut agg.sub_aggregation_res,
bucket_id,
)?;
}
Ok((range_to_string(&range_bucket.range, &field_type)?, agg))
.map(move |range_bucket| {
Ok((
range_to_string(&range_bucket.range, &field_type)?,
range_bucket
.bucket
.into_intermediate_bucket_entry(agg_data)?,
))
})
.collect::<crate::Result<_>>()?;
@@ -276,114 +242,73 @@ impl<C: SubAggCache> SegmentAggregationCollector for SegmentRangeCollector<C> {
#[inline]
fn collect(
&mut self,
parent_bucket_id: BucketId,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req = agg_data.take_range_req_data(self.accessor_idx);
// Take request data to avoid borrow conflicts during sub-aggregation
let mut req = agg_data.take_range_req_data(self.accessor_idx);
agg_data
.column_block_accessor
.fetch_block(docs, &req.accessor);
req.column_block_accessor.fetch_block(docs, &req.accessor);
let buckets = &mut self.parent_buckets[parent_bucket_id as usize];
for (doc, val) in agg_data
for (doc, val) in req
.column_block_accessor
.iter_docid_vals(docs, &req.accessor)
{
let bucket_pos = get_bucket_pos(val, buckets);
let bucket = &mut buckets[bucket_pos];
let bucket_pos = self.get_bucket_pos(val);
let bucket = &mut self.buckets[bucket_pos];
bucket.bucket.doc_count += 1;
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.push(bucket.bucket.bucket_id, doc);
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() {
sub_agg.collect(doc, agg_data)?;
}
}
agg_data.put_back_range_req_data(self.accessor_idx, req);
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.check_flush_local(agg_data)?;
}
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.flush(agg_data)?;
for bucket in self.buckets.iter_mut() {
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() {
sub_agg.flush(agg_data)?;
}
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
let new_buckets = self.create_new_buckets(agg_data)?;
self.parent_buckets.push(new_buckets);
}
Ok(())
}
}
/// Build a concrete `SegmentRangeCollector` with either a Vec- or HashMap-backed
/// bucket storage, depending on the column type and aggregation level.
pub(crate) fn build_segment_range_collector(
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
let accessor_idx = node.idx_in_req_data;
let req_data = agg_data.get_range_req_data(node.idx_in_req_data);
let field_type = req_data.field_type;
// TODO: A better metric instead of is_top_level would be the number of buckets expected.
// E.g. If range agg is not top level, but the parent is a bucket agg with less than 10 buckets,
// we can are still in low cardinality territory.
let is_low_card = req_data.is_top_level && req_data.req.ranges.len() <= 64;
let sub_agg = if !node.children.is_empty() {
Some(build_segment_agg_collectors(agg_data, &node.children)?)
} else {
None
};
if is_low_card {
Ok(Box::new(SegmentRangeCollector::<LowCardSubAggCache> {
sub_agg: sub_agg.map(LowCardCachedSubAggs::new),
column_type: field_type,
accessor_idx,
parent_buckets: Vec::new(),
bucket_id_provider: BucketIdProvider::default(),
limits: agg_data.context.limits.clone(),
}))
} else {
Ok(Box::new(SegmentRangeCollector::<HighCardSubAggCache> {
sub_agg: sub_agg.map(CachedSubAggs::new),
column_type: field_type,
accessor_idx,
parent_buckets: Vec::new(),
bucket_id_provider: BucketIdProvider::default(),
limits: agg_data.context.limits.clone(),
}))
}
}
impl<C: SubAggCache> SegmentRangeCollector<C> {
pub(crate) fn create_new_buckets(
&mut self,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<Vec<SegmentRangeAndBucketEntry>> {
let field_type = self.column_type;
let req_data = agg_data.get_range_req_data(self.accessor_idx);
impl SegmentRangeCollector {
pub(crate) fn from_req_and_validate(
req_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let accessor_idx = node.idx_in_req_data;
let (field_type, ranges) = {
let req_view = req_data.get_range_req_data(node.idx_in_req_data);
(req_view.field_type, req_view.req.ranges.clone())
};
// The range input on the request is f64.
// We need to convert to u64 ranges, because we read the values as u64.
// The mapping from the conversion is monotonic so ordering is preserved.
let buckets: Vec<_> = extend_validate_ranges(&req_data.req.ranges, &field_type)?
let sub_agg_prototype = if !node.children.is_empty() {
Some(build_segment_agg_collectors(req_data, &node.children)?)
} else {
None
};
let buckets: Vec<_> = extend_validate_ranges(&ranges, &field_type)?
.iter()
.map(|range| {
let bucket_id = self.bucket_id_provider.next_bucket_id();
let key = range
.key
.clone()
@@ -392,20 +317,20 @@ impl<C: SubAggCache> SegmentRangeCollector<C> {
let to = if range.range.end == u64::MAX {
None
} else {
Some(f64_from_fastfield_u64(range.range.end, field_type))
Some(f64_from_fastfield_u64(range.range.end, &field_type))
};
let from = if range.range.start == u64::MIN {
None
} else {
Some(f64_from_fastfield_u64(range.range.start, field_type))
Some(f64_from_fastfield_u64(range.range.start, &field_type))
};
// let sub_aggregation = sub_agg_prototype.clone();
let sub_aggregation = sub_agg_prototype.clone();
Ok(SegmentRangeAndBucketEntry {
range: range.range.clone(),
bucket: SegmentRangeBucketEntry {
doc_count: 0,
bucket_id,
sub_aggregation,
key,
from,
to,
@@ -414,19 +339,26 @@ impl<C: SubAggCache> SegmentRangeCollector<C> {
})
.collect::<crate::Result<_>>()?;
self.limits.add_memory_consumed(
req_data.context.limits.add_memory_consumed(
buckets.len() as u64 * std::mem::size_of::<SegmentRangeAndBucketEntry>() as u64,
)?;
Ok(buckets)
Ok(SegmentRangeCollector {
buckets,
column_type: field_type,
accessor_idx,
})
}
#[inline]
fn get_bucket_pos(&self, val: u64) -> usize {
let pos = self
.buckets
.binary_search_by_key(&val, |probe| probe.range.start)
.unwrap_or_else(|pos| pos - 1);
debug_assert!(self.buckets[pos].range.contains(&val));
pos
}
}
#[inline]
fn get_bucket_pos(val: u64, buckets: &[SegmentRangeAndBucketEntry]) -> usize {
let pos = buckets
.binary_search_by_key(&val, |probe| probe.range.start)
.unwrap_or_else(|pos| pos - 1);
debug_assert!(buckets[pos].range.contains(&val));
pos
}
/// Converts the user provided f64 range value to fast field value space.
@@ -524,7 +456,7 @@ pub(crate) fn range_to_string(
let val = i64::from_u64(val);
format_date(val)
} else {
Ok(f64_from_fastfield_u64(val, *field_type).to_string())
Ok(f64_from_fastfield_u64(val, field_type).to_string())
}
};
@@ -554,7 +486,7 @@ mod tests {
pub fn get_collector_from_ranges(
ranges: Vec<RangeAggregationRange>,
field_type: ColumnType,
) -> SegmentRangeCollector<HighCardSubAggCache> {
) -> SegmentRangeCollector {
let req = RangeAggregation {
field: "dummy".to_string(),
ranges,
@@ -574,33 +506,30 @@ mod tests {
let to = if range.range.end == u64::MAX {
None
} else {
Some(f64_from_fastfield_u64(range.range.end, field_type))
Some(f64_from_fastfield_u64(range.range.end, &field_type))
};
let from = if range.range.start == u64::MIN {
None
} else {
Some(f64_from_fastfield_u64(range.range.start, field_type))
Some(f64_from_fastfield_u64(range.range.start, &field_type))
};
SegmentRangeAndBucketEntry {
range: range.range.clone(),
bucket: SegmentRangeBucketEntry {
doc_count: 0,
sub_aggregation: None,
key,
from,
to,
bucket_id: 0,
},
}
})
.collect();
SegmentRangeCollector {
parent_buckets: vec![buckets],
buckets,
column_type: field_type,
accessor_idx: 0,
sub_agg: None,
bucket_id_provider: Default::default(),
limits: AggregationLimitsGuard::default(),
}
}
@@ -847,7 +776,7 @@ mod tests {
let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.parent_buckets[0].clone();
let buckets = collector.buckets;
assert_eq!(buckets[0].range.start, u64::MIN);
assert_eq!(buckets[0].range.end, 10f64.to_u64());
assert_eq!(buckets[1].range.start, 10f64.to_u64());
@@ -870,7 +799,7 @@ mod tests {
];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.parent_buckets[0].clone();
let buckets = collector.buckets;
assert_eq!(buckets[0].range.start, u64::MIN);
assert_eq!(buckets[0].range.end, 10f64.to_u64());
assert_eq!(buckets[1].range.start, 10f64.to_u64());
@@ -885,7 +814,7 @@ mod tests {
let buckets = vec![(-10f64..-1f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.parent_buckets[0].clone();
let buckets = collector.buckets;
assert_eq!(&buckets[0].bucket.key.to_string(), "*--10");
assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "-1-*");
}
@@ -894,7 +823,7 @@ mod tests {
let buckets = vec![(0f64..10f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.parent_buckets[0].clone();
let buckets = collector.buckets;
assert_eq!(&buckets[0].bucket.key.to_string(), "*-0");
assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "10-*");
}
@@ -903,7 +832,7 @@ mod tests {
fn range_binary_search_test_u64() {
let check_ranges = |ranges: Vec<RangeAggregationRange>| {
let collector = get_collector_from_ranges(ranges, ColumnType::U64);
let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);
let search = |val: u64| collector.get_bucket_pos(val);
assert_eq!(search(u64::MIN), 0);
assert_eq!(search(9), 0);
@@ -949,7 +878,7 @@ mod tests {
let ranges = vec![(10.0..100.0).into()];
let collector = get_collector_from_ranges(ranges, ColumnType::F64);
let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);
let search = |val: u64| collector.get_bucket_pos(val);
assert_eq!(search(u64::MIN), 0);
assert_eq!(search(9f64.to_u64()), 0);
@@ -961,3 +890,63 @@ mod tests {
// the max value
}
}
#[cfg(all(test, feature = "unstable"))]
mod bench {
use itertools::Itertools;
use rand::seq::SliceRandom;
use rand::thread_rng;
use super::*;
use crate::aggregation::bucket::range::tests::get_collector_from_ranges;
const TOTAL_DOCS: u64 = 1_000_000u64;
const NUM_DOCS: u64 = 50_000u64;
fn get_collector_with_buckets(num_buckets: u64, num_docs: u64) -> SegmentRangeCollector {
let bucket_size = num_docs / num_buckets;
let mut buckets: Vec<RangeAggregationRange> = vec![];
for i in 0..num_buckets {
let bucket_start = (i * bucket_size) as f64;
buckets.push((bucket_start..bucket_start + bucket_size as f64).into())
}
get_collector_from_ranges(buckets, ColumnType::U64)
}
fn get_rand_docs(total_docs: u64, num_docs_returned: u64) -> Vec<u64> {
let mut rng = thread_rng();
let all_docs = (0..total_docs - 1).collect_vec();
let mut vals = all_docs
.as_slice()
.choose_multiple(&mut rng, num_docs_returned as usize)
.cloned()
.collect_vec();
vals.sort();
vals
}
fn bench_range_binary_search(b: &mut test::Bencher, num_buckets: u64) {
let collector = get_collector_with_buckets(num_buckets, TOTAL_DOCS);
let vals = get_rand_docs(TOTAL_DOCS, NUM_DOCS);
b.iter(|| {
let mut bucket_pos = 0;
for val in &vals {
bucket_pos = collector.get_bucket_pos(*val);
}
bucket_pos
})
}
#[bench]
fn bench_range_100_buckets(b: &mut test::Bencher) {
bench_range_binary_search(b, 100)
}
#[bench]
fn bench_range_10_buckets(b: &mut test::Bencher) {
bench_range_binary_search(b, 10)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -5,13 +5,11 @@ use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::bucket::term_agg::TermsAggregation;
use crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardCachedSubAggs};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult,
};
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::BucketId;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
/// Special aggregation to handle missing values for term aggregations.
/// This missing aggregation will check multiple columns for existence.
@@ -37,55 +35,41 @@ impl MissingTermAggReqData {
}
}
#[derive(Default, Debug, Clone)]
struct MissingCount {
missing_count: u32,
bucket_id: BucketId,
}
/// The specialized missing term aggregation.
#[derive(Default, Debug)]
#[derive(Default, Debug, Clone)]
pub struct TermMissingAgg {
missing_count: u32,
accessor_idx: usize,
sub_agg: Option<HighCardCachedSubAggs>,
/// Idx = parent bucket id, Value = missing count for that bucket
missing_count_per_bucket: Vec<MissingCount>,
bucket_id_provider: BucketIdProvider,
sub_agg: Option<Box<dyn SegmentAggregationCollector>>,
}
impl TermMissingAgg {
pub(crate) fn new(
agg_data: &mut AggregationsSegmentCtx,
req_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let has_sub_aggregations = !node.children.is_empty();
let accessor_idx = node.idx_in_req_data;
let sub_agg = if has_sub_aggregations {
let sub_aggregation = build_segment_agg_collectors(agg_data, &node.children)?;
let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?;
Some(sub_aggregation)
} else {
None
};
let sub_agg = sub_agg.map(CachedSubAggs::new);
let bucket_id_provider = BucketIdProvider::default();
Ok(Self {
accessor_idx,
sub_agg,
missing_count_per_bucket: Vec::new(),
bucket_id_provider,
..Default::default()
})
}
}
impl SegmentAggregationCollector for TermMissingAgg {
fn add_intermediate_aggregation_result(
&mut self,
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
let term_agg = &req_data.req;
let missing = term_agg
@@ -96,16 +80,13 @@ impl SegmentAggregationCollector for TermMissingAgg {
let mut entries: FxHashMap<IntermediateKey, IntermediateTermBucketEntry> =
Default::default();
let missing_count = &self.missing_count_per_bucket[parent_bucket_id as usize];
let mut missing_entry = IntermediateTermBucketEntry {
doc_count: missing_count.missing_count,
doc_count: self.missing_count,
sub_aggregation: Default::default(),
};
if let Some(sub_agg) = &mut self.sub_agg {
if let Some(sub_agg) = self.sub_agg {
let mut res = IntermediateAggregationResults::default();
sub_agg
.get_sub_agg_collector()
.add_intermediate_aggregation_result(agg_data, &mut res, missing_count.bucket_id)?;
sub_agg.add_intermediate_aggregation_result(agg_data, &mut res)?;
missing_entry.sub_aggregation = res;
}
entries.insert(missing.into(), missing_entry);
@@ -128,52 +109,30 @@ impl SegmentAggregationCollector for TermMissingAgg {
fn collect(
&mut self,
parent_bucket_id: BucketId,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
let has_value = req_data
.accessors
.iter()
.any(|(acc, _)| acc.index.has_value(doc));
if !has_value {
self.missing_count += 1;
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.collect(doc, agg_data)?;
}
}
Ok(())
}
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let bucket = &mut self.missing_count_per_bucket[parent_bucket_id as usize];
let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
for doc in docs {
let doc = *doc;
let has_value = req_data
.accessors
.iter()
.any(|(acc, _)| acc.index.has_value(doc));
if !has_value {
bucket.missing_count += 1;
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.push(bucket.bucket_id, doc);
}
}
}
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.check_flush_local(agg_data)?;
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.missing_count_per_bucket.len() <= max_bucket as usize {
let bucket_id = self.bucket_id_provider.next_bucket_id();
self.missing_count_per_bucket.push(MissingCount {
missing_count: 0,
bucket_id,
});
}
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.flush(agg_data)?;
self.collect(*doc, agg_data)?;
}
Ok(())
}

View File

@@ -0,0 +1,87 @@
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::DocId;
#[cfg(test)]
pub(crate) const DOC_BLOCK_SIZE: usize = 64;
#[cfg(not(test))]
pub(crate) const DOC_BLOCK_SIZE: usize = 256;
pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE];
/// BufAggregationCollector buffers documents before calling collect_block().
#[derive(Clone)]
pub(crate) struct BufAggregationCollector {
pub(crate) collector: Box<dyn SegmentAggregationCollector>,
staged_docs: DocBlock,
num_staged_docs: usize,
}
impl std::fmt::Debug for BufAggregationCollector {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("SegmentAggregationResultsCollector")
.field("staged_docs", &&self.staged_docs[..self.num_staged_docs])
.field("num_staged_docs", &self.num_staged_docs)
.finish()
}
}
impl BufAggregationCollector {
pub fn new(collector: Box<dyn SegmentAggregationCollector>) -> Self {
Self {
collector,
num_staged_docs: 0,
staged_docs: [0; DOC_BLOCK_SIZE],
}
}
}
impl SegmentAggregationCollector for BufAggregationCollector {
#[inline]
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
Box::new(self.collector).add_intermediate_aggregation_result(agg_data, results)
}
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.staged_docs[self.num_staged_docs] = doc;
self.num_staged_docs += 1;
if self.num_staged_docs == self.staged_docs.len() {
self.collector
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?;
self.num_staged_docs = 0;
}
Ok(())
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collector.collect_block(docs, agg_data)?;
Ok(())
}
#[inline]
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
self.collector
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?;
self.num_staged_docs = 0;
self.collector.flush(agg_data)?;
Ok(())
}
}

View File

@@ -1,245 +0,0 @@
use std::fmt::Debug;
use super::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::bucket::MAX_NUM_TERMS_FOR_VEC;
use crate::aggregation::BucketId;
use crate::DocId;
/// A cache for sub-aggregations, storing doc ids per bucket id.
/// Depending on the cardinality of the parent aggregation, we use different
/// storage strategies.
///
/// ## Low Cardinality
/// Cardinality here refers to the number of unique flattened buckets that can be created
/// by the parent aggregation.
/// Flattened buckets are the result of combining all buckets per collector
/// into a single list of buckets, where each bucket is identified by its BucketId.
///
/// ## Usage
/// Since this is caching for sub-aggregations, it is only used by bucket
/// aggregations.
///
/// TODO: consider using a more advanced data structure for high cardinality
/// aggregations.
/// What this datastructure does in general is to group docs by bucket id.
#[derive(Debug)]
pub(crate) struct CachedSubAggs<C: SubAggCache> {
cache: C,
sub_agg_collector: Box<dyn SegmentAggregationCollector>,
num_docs: usize,
}
pub type LowCardCachedSubAggs = CachedSubAggs<LowCardSubAggCache>;
pub type HighCardCachedSubAggs = CachedSubAggs<HighCardSubAggCache>;
const FLUSH_THRESHOLD: usize = 2048;
/// A trait for caching sub-aggregation doc ids per bucket id.
/// Different implementations can be used depending on the cardinality
/// of the parent aggregation.
pub trait SubAggCache: Debug {
fn new() -> Self;
fn push(&mut self, bucket_id: BucketId, doc_id: DocId);
fn flush_local(
&mut self,
sub_agg: &mut Box<dyn SegmentAggregationCollector>,
agg_data: &mut AggregationsSegmentCtx,
force: bool,
) -> crate::Result<()>;
}
impl<Backend: SubAggCache + Debug> CachedSubAggs<Backend> {
pub fn new(sub_agg: Box<dyn SegmentAggregationCollector>) -> Self {
Self {
cache: Backend::new(),
sub_agg_collector: sub_agg,
num_docs: 0,
}
}
pub fn get_sub_agg_collector(&mut self) -> &mut Box<dyn SegmentAggregationCollector> {
&mut self.sub_agg_collector
}
#[inline]
pub fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
self.cache.push(bucket_id, doc_id);
self.num_docs += 1;
}
/// Check if we need to flush based on the number of documents cached.
/// If so, flushes the cache to the provided aggregation collector.
pub fn check_flush_local(
&mut self,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
if self.num_docs >= FLUSH_THRESHOLD {
self.cache
.flush_local(&mut self.sub_agg_collector, agg_data, false)?;
self.num_docs = 0;
}
Ok(())
}
/// Note: this _does_ flush the sub aggregations.
pub fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if self.num_docs != 0 {
self.cache
.flush_local(&mut self.sub_agg_collector, agg_data, true)?;
self.num_docs = 0;
}
self.sub_agg_collector.flush(agg_data)?;
Ok(())
}
}
/// Number of partitions for high cardinality sub-aggregation cache.
const NUM_PARTITIONS: usize = 16;
#[derive(Debug)]
pub(crate) struct HighCardSubAggCache {
/// This weird partitioning is used to do some cheap grouping on the bucket ids.
/// bucket ids are dense, e.g. when we don't detect the cardinality as low cardinality,
/// but there are just 16 bucket ids, each bucket id will go to its own partition.
///
/// We want to keep this cheap, because high cardinality aggregations can have a lot of
/// buckets, and there may be nothing to group.
partitions: Box<[PartitionEntry; NUM_PARTITIONS]>,
}
impl HighCardSubAggCache {
#[inline]
fn clear(&mut self) {
for partition in self.partitions.iter_mut() {
partition.clear();
}
}
}
#[derive(Debug, Clone, Default)]
struct PartitionEntry {
bucket_ids: Vec<BucketId>,
docs: Vec<DocId>,
}
impl PartitionEntry {
#[inline]
fn clear(&mut self) {
self.bucket_ids.clear();
self.docs.clear();
}
}
impl SubAggCache for HighCardSubAggCache {
fn new() -> Self {
Self {
partitions: Box::new(core::array::from_fn(|_| PartitionEntry::default())),
}
}
fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
let idx = bucket_id % NUM_PARTITIONS as u32;
let slot = &mut self.partitions[idx as usize];
slot.bucket_ids.push(bucket_id);
slot.docs.push(doc_id);
}
fn flush_local(
&mut self,
sub_agg: &mut Box<dyn SegmentAggregationCollector>,
agg_data: &mut AggregationsSegmentCtx,
_force: bool,
) -> crate::Result<()> {
let mut max_bucket = 0u32;
for partition in self.partitions.iter() {
if let Some(&local_max) = partition.bucket_ids.iter().max() {
max_bucket = max_bucket.max(local_max);
}
}
sub_agg.prepare_max_bucket(max_bucket, agg_data)?;
for slot in self.partitions.iter() {
if !slot.bucket_ids.is_empty() {
// Reduce dynamic dispatch overhead by collecting a full partition in one call.
sub_agg.collect_multiple(&slot.bucket_ids, &slot.docs, agg_data)?;
}
}
self.clear();
Ok(())
}
}
#[derive(Debug)]
pub(crate) struct LowCardSubAggCache {
/// Cache doc ids per bucket for sub-aggregations.
///
/// The outer Vec is indexed by BucketId.
per_bucket_docs: Vec<Vec<DocId>>,
}
impl LowCardSubAggCache {
#[inline]
fn clear(&mut self) {
for v in &mut self.per_bucket_docs {
v.clear();
}
}
}
impl SubAggCache for LowCardSubAggCache {
fn new() -> Self {
Self {
per_bucket_docs: Vec::new(),
}
}
fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
let idx = bucket_id as usize;
if self.per_bucket_docs.len() <= idx {
self.per_bucket_docs.resize_with(idx + 1, Vec::new);
}
self.per_bucket_docs[idx].push(doc_id);
}
fn flush_local(
&mut self,
sub_agg: &mut Box<dyn SegmentAggregationCollector>,
agg_data: &mut AggregationsSegmentCtx,
force: bool,
) -> crate::Result<()> {
// Pre-aggregated: call collect per bucket.
let max_bucket = (self.per_bucket_docs.len() as BucketId).saturating_sub(1);
sub_agg.prepare_max_bucket(max_bucket, agg_data)?;
// The threshold above which we flush buckets individually.
// Note: We need to make sure that we don't lock ourselves into a situation where we hit
// the FLUSH_THRESHOLD, but never flush any buckets. (except the final flush)
let mut bucket_treshold = FLUSH_THRESHOLD / (self.per_bucket_docs.len().max(1) * 2);
const _: () = {
// MAX_NUM_TERMS_FOR_VEC threshold is used for term aggregations
// Note: There may be other flexible values, for other aggregations, but we can use the
// const value here as a upper bound. (better than nothing)
let bucket_treshold_limit = FLUSH_THRESHOLD / (MAX_NUM_TERMS_FOR_VEC as usize * 2);
assert!(
bucket_treshold_limit > 0,
"Bucket threshold must be greater than 0"
);
};
if force {
bucket_treshold = 0;
}
for (bucket_id, docs) in self
.per_bucket_docs
.iter()
.enumerate()
.filter(|(_, docs)| docs.len() > bucket_treshold)
{
sub_agg.collect(bucket_id as BucketId, docs, agg_data)?;
}
self.clear();
Ok(())
}
}

View File

@@ -1,9 +1,9 @@
use super::agg_req::Aggregations;
use super::agg_result::AggregationResults;
use super::cached_sub_aggs::LowCardCachedSubAggs;
use super::buf_collector::BufAggregationCollector;
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::SegmentAggregationCollector;
use super::AggContextParams;
// group buffering strategy is chosen explicitly by callers; no need to hash-group on the fly.
use crate::aggregation::agg_data::{
build_aggregations_data_from_req, build_segment_agg_collectors_root, AggregationsSegmentCtx,
};
@@ -136,7 +136,7 @@ fn merge_fruits(
/// `AggregationSegmentCollector` does the aggregation collection on a segment.
pub struct AggregationSegmentCollector {
aggs_with_accessor: AggregationsSegmentCtx,
agg_collector: LowCardCachedSubAggs,
agg_collector: BufAggregationCollector,
error: Option<TantivyError>,
}
@@ -151,11 +151,8 @@ impl AggregationSegmentCollector {
) -> crate::Result<Self> {
let mut agg_data =
build_aggregations_data_from_req(agg, reader, segment_ordinal, context.clone())?;
let mut result =
LowCardCachedSubAggs::new(build_segment_agg_collectors_root(&mut agg_data)?);
result
.get_sub_agg_collector()
.prepare_max_bucket(0, &agg_data)?; // prepare for bucket zero
let result =
BufAggregationCollector::new(build_segment_agg_collectors_root(&mut agg_data)?);
Ok(AggregationSegmentCollector {
aggs_with_accessor: agg_data,
@@ -173,31 +170,26 @@ impl SegmentCollector for AggregationSegmentCollector {
if self.error.is_some() {
return;
}
self.agg_collector.push(0, doc);
match self
if let Err(err) = self
.agg_collector
.check_flush_local(&mut self.aggs_with_accessor)
.collect(doc, &mut self.aggs_with_accessor)
{
Ok(_) => {}
Err(e) => {
self.error = Some(e);
}
self.error = Some(err);
}
}
/// The query pushes the documents to the collector via this method.
///
/// Only valid for Collectors that ignore docs
fn collect_block(&mut self, docs: &[DocId]) {
if self.error.is_some() {
return;
}
match self.agg_collector.get_sub_agg_collector().collect(
0,
docs,
&mut self.aggs_with_accessor,
) {
Ok(_) => {}
Err(e) => {
self.error = Some(e);
}
if let Err(err) = self
.agg_collector
.collect_block(docs, &mut self.aggs_with_accessor)
{
self.error = Some(err);
}
}
@@ -208,13 +200,10 @@ impl SegmentCollector for AggregationSegmentCollector {
self.agg_collector.flush(&mut self.aggs_with_accessor)?;
let mut sub_aggregation_res = IntermediateAggregationResults::default();
self.agg_collector
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
&self.aggs_with_accessor,
&mut sub_aggregation_res,
0,
)?;
Box::new(self.agg_collector).add_intermediate_aggregation_result(
&self.aggs_with_accessor,
&mut sub_aggregation_res,
)?;
Ok(sub_aggregation_res)
}

View File

@@ -792,7 +792,7 @@ pub struct IntermediateRangeBucketEntry {
/// The number of documents in the bucket.
pub doc_count: u64,
/// The sub_aggregation in this bucket.
pub sub_aggregation_res: IntermediateAggregationResults,
pub sub_aggregation: IntermediateAggregationResults,
/// The from range of the bucket. Equals `f64::MIN` when `None`.
pub from: Option<f64>,
/// The to range of the bucket. Equals `f64::MAX` when `None`.
@@ -811,7 +811,7 @@ impl IntermediateRangeBucketEntry {
key: self.key.into(),
doc_count: self.doc_count,
sub_aggregation: self
.sub_aggregation_res
.sub_aggregation
.into_final_result_internal(req, limits)?,
to: self.to,
from: self.from,
@@ -857,8 +857,7 @@ impl MergeFruits for IntermediateTermBucketEntry {
impl MergeFruits for IntermediateRangeBucketEntry {
fn merge_fruits(&mut self, other: IntermediateRangeBucketEntry) -> crate::Result<()> {
self.doc_count += other.doc_count;
self.sub_aggregation_res
.merge_fruits(other.sub_aggregation_res)?;
self.sub_aggregation.merge_fruits(other.sub_aggregation)?;
Ok(())
}
}
@@ -888,7 +887,7 @@ mod tests {
IntermediateRangeBucketEntry {
key: IntermediateKey::Str(key.to_string()),
doc_count: *doc_count,
sub_aggregation_res: Default::default(),
sub_aggregation: Default::default(),
from: None,
to: None,
},
@@ -921,7 +920,7 @@ mod tests {
doc_count: *doc_count,
from: None,
to: None,
sub_aggregation_res: get_sub_test_tree(&[(
sub_aggregation: get_sub_test_tree(&[(
sub_aggregation_key.to_string(),
*sub_aggregation_count,
)]),

View File

@@ -52,8 +52,10 @@ pub struct IntermediateAverage {
impl IntermediateAverage {
/// Creates a new [`IntermediateAverage`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateAverage) {

View File

@@ -2,7 +2,7 @@ use std::collections::hash_map::DefaultHasher;
use std::hash::{BuildHasher, Hasher};
use columnar::column_values::CompactSpaceU64Accessor;
use columnar::{Column, ColumnType, Dictionary, StrColumn};
use columnar::{Column, ColumnBlockAccessor, ColumnType, Dictionary, StrColumn};
use common::f64_to_u64;
use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};
use rustc_hash::FxHashSet;
@@ -106,6 +106,8 @@ pub struct CardinalityAggReqData {
pub str_dict_column: Option<StrColumn>,
/// The missing value normalized to the internal u64 representation of the field type.
pub missing_value_for_accessor: Option<u64>,
/// The column block accessor to access the fast field values.
pub(crate) column_block_accessor: ColumnBlockAccessor<u64>,
/// The name of the aggregation.
pub name: String,
/// The aggregation request.
@@ -133,34 +135,45 @@ impl CardinalityAggregationReq {
}
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct SegmentCardinalityCollector {
buckets: Vec<SegmentCardinalityCollectorBucket>,
accessor_idx: usize,
/// The column accessor to access the fast field values.
accessor: Column<u64>,
/// The column_type of the field.
column_type: ColumnType,
/// The missing value normalized to the internal u64 representation of the field type.
missing_value_for_accessor: Option<u64>,
}
#[derive(Clone, Debug, PartialEq, Default)]
pub(crate) struct SegmentCardinalityCollectorBucket {
cardinality: CardinalityCollector,
entries: FxHashSet<u64>,
accessor_idx: usize,
}
impl SegmentCardinalityCollectorBucket {
pub fn new(column_type: ColumnType) -> Self {
impl SegmentCardinalityCollector {
pub fn from_req(column_type: ColumnType, accessor_idx: usize) -> Self {
Self {
cardinality: CardinalityCollector::new(column_type as u8),
entries: FxHashSet::default(),
entries: Default::default(),
accessor_idx,
}
}
fn fetch_block_with_field(
&mut self,
docs: &[crate::DocId],
agg_data: &mut CardinalityAggReqData,
) {
if let Some(missing) = agg_data.missing_value_for_accessor {
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&agg_data.accessor,
missing,
);
} else {
agg_data
.column_block_accessor
.fetch_block(docs, &agg_data.accessor);
}
}
fn into_intermediate_metric_result(
mut self,
req_data: &CardinalityAggReqData,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateMetricResult> {
let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
if req_data.column_type == ColumnType::Str {
let fallback_dict = Dictionary::empty();
let dict = req_data
@@ -181,7 +194,6 @@ impl SegmentCardinalityCollectorBucket {
term_ids.push(term_ord as u32);
}
}
term_ids.sort_unstable();
dict.sorted_ords_to_term_cb(term_ids.iter().map(|term| *term as u64), |term| {
self.cardinality.sketch.insert_any(&term);
@@ -215,49 +227,16 @@ impl SegmentCardinalityCollectorBucket {
}
}
impl SegmentCardinalityCollector {
pub fn from_req(
column_type: ColumnType,
accessor_idx: usize,
accessor: Column<u64>,
missing_value_for_accessor: Option<u64>,
) -> Self {
Self {
buckets: vec![SegmentCardinalityCollectorBucket::new(column_type); 1],
column_type,
accessor_idx,
accessor,
missing_value_for_accessor,
}
}
fn fetch_block_with_field(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) {
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&self.accessor,
self.missing_value_for_accessor,
);
}
}
impl SegmentAggregationCollector for SegmentCardinalityCollector {
fn add_intermediate_aggregation_result(
&mut self,
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
let name = req_data.name.to_string();
// take the bucket in buckets and replace it with a new empty one
let bucket = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
let intermediate_result = bucket.into_intermediate_metric_result(req_data)?;
let intermediate_result = self.into_intermediate_metric_result(agg_data)?;
results.push(
name,
IntermediateAggregationResult::Metric(intermediate_result),
@@ -268,20 +247,27 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
fn collect(
&mut self,
parent_bucket_id: BucketId,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.fetch_block_with_field(docs, agg_data);
let bucket = &mut self.buckets[parent_bucket_id as usize];
let req_data = agg_data.get_cardinality_req_data_mut(self.accessor_idx);
self.fetch_block_with_field(docs, req_data);
let col_block_accessor = &agg_data.column_block_accessor;
if self.column_type == ColumnType::Str {
let col_block_accessor = &req_data.column_block_accessor;
if req_data.column_type == ColumnType::Str {
for term_ord in col_block_accessor.iter_vals() {
bucket.entries.insert(term_ord);
self.entries.insert(term_ord);
}
} else if self.column_type == ColumnType::IpAddr {
let compact_space_accessor = self
} else if req_data.column_type == ColumnType::IpAddr {
let compact_space_accessor = req_data
.accessor
.values
.clone()
@@ -296,29 +282,16 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
})?;
for val in col_block_accessor.iter_vals() {
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
bucket.cardinality.sketch.insert_any(&val);
self.cardinality.sketch.insert_any(&val);
}
} else {
for val in col_block_accessor.iter_vals() {
bucket.cardinality.sketch.insert_any(&val);
self.cardinality.sketch.insert_any(&val);
}
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
if max_bucket as usize >= self.buckets.len() {
self.buckets.resize_with(max_bucket as usize + 1, || {
SegmentCardinalityCollectorBucket::new(self.column_type)
});
}
Ok(())
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]

View File

@@ -52,8 +52,10 @@ pub struct IntermediateCount {
impl IntermediateCount {
/// Creates a new [`IntermediateCount`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateCount) {

View File

@@ -8,9 +8,10 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::metric::MetricAggReqData;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::TantivyError;
use crate::{DocId, TantivyError};
/// A multi-value metric aggregation that computes a collection of extended statistics
/// on numeric values that are extracted
@@ -317,28 +318,51 @@ impl IntermediateExtendedStats {
}
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct SegmentExtendedStatsCollector {
name: String,
missing: Option<u64>,
field_type: ColumnType,
accessor: columnar::Column<u64>,
buckets: Vec<IntermediateExtendedStats>,
sigma: Option<f64>,
pub(crate) extended_stats: IntermediateExtendedStats,
pub(crate) accessor_idx: usize,
val_cache: Vec<u64>,
}
impl SegmentExtendedStatsCollector {
pub fn from_req(req: &MetricAggReqData, sigma: Option<f64>) -> Self {
let missing = req
.missing
.and_then(|val| f64_to_fastfield_u64(val, &req.field_type));
pub fn from_req(
field_type: ColumnType,
sigma: Option<f64>,
accessor_idx: usize,
missing: Option<f64>,
) -> Self {
let missing = missing.and_then(|val| f64_to_fastfield_u64(val, &field_type));
Self {
name: req.name.clone(),
field_type: req.field_type,
accessor: req.accessor.clone(),
field_type,
extended_stats: IntermediateExtendedStats::with_sigma(sigma),
accessor_idx,
missing,
buckets: vec![IntermediateExtendedStats::with_sigma(sigma); 16],
sigma,
val_cache: Default::default(),
}
}
#[inline]
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
req_data: &mut MetricAggReqData,
) {
if let Some(missing) = self.missing.as_ref() {
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&req_data.accessor,
*missing,
);
} else {
req_data
.column_block_accessor
.fetch_block(docs, &req_data.accessor);
}
for val in req_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.extended_stats.collect(val1);
}
}
}
@@ -346,18 +370,15 @@ impl SegmentExtendedStatsCollector {
impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
#[inline]
fn add_intermediate_aggregation_result(
&mut self,
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = self.name.clone();
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let extended_stats = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone();
results.push(
name,
IntermediateAggregationResult::Metric(IntermediateMetricResult::ExtendedStats(
extended_stats,
self.extended_stats,
)),
)?;
@@ -367,36 +388,39 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
#[inline]
fn collect(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let mut extended_stats = self.buckets[parent_bucket_id as usize].clone();
agg_data
.column_block_accessor
.fetch_block_with_missing(docs, &self.accessor, self.missing);
for val in agg_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, self.field_type);
extended_stats.collect(val1);
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
if let Some(missing) = self.missing {
let mut has_val = false;
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.extended_stats.collect(val1);
has_val = true;
}
if !has_val {
self.extended_stats
.collect(f64_from_fastfield_u64(missing, &self.field_type));
}
} else {
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.extended_stats.collect(val1);
}
}
// store back
self.buckets[parent_bucket_id as usize] = extended_stats;
Ok(())
}
fn prepare_max_bucket(
#[inline]
fn collect_block(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
if self.buckets.len() <= max_bucket as usize {
self.buckets.resize_with(max_bucket as usize + 1, || {
IntermediateExtendedStats::with_sigma(self.sigma)
});
}
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
self.collect_block_with_field(docs, req_data);
Ok(())
}
}

View File

@@ -52,8 +52,10 @@ pub struct IntermediateMax {
impl IntermediateMax {
/// Creates a new [`IntermediateMax`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateMax) {

View File

@@ -52,8 +52,10 @@ pub struct IntermediateMin {
impl IntermediateMin {
/// Creates a new [`IntermediateMin`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateMin) {

View File

@@ -31,7 +31,7 @@ use std::collections::HashMap;
pub use average::*;
pub use cardinality::*;
use columnar::{Column, ColumnType};
use columnar::{Column, ColumnBlockAccessor, ColumnType};
pub use count::*;
pub use extended_stats::*;
pub use max::*;
@@ -55,6 +55,8 @@ pub struct MetricAggReqData {
pub field_type: ColumnType,
/// The missing value normalized to the internal u64 representation of the field type.
pub missing_u64: Option<u64>,
/// The column block accessor to access the fast field values.
pub column_block_accessor: ColumnBlockAccessor<u64>,
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
/// Used when converting to intermediate result

View File

@@ -7,9 +7,10 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::metric::MetricAggReqData;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::TantivyError;
use crate::{DocId, TantivyError};
/// # Percentiles
///
@@ -130,16 +131,10 @@ impl PercentilesAggregationReq {
}
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct SegmentPercentilesCollector {
pub(crate) buckets: Vec<PercentilesCollector>,
pub(crate) percentiles: PercentilesCollector,
pub(crate) accessor_idx: usize,
/// The type of the field.
pub field_type: ColumnType,
/// The missing value normalized to the internal u64 representation of the field type.
pub missing_u64: Option<u64>,
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
}
#[derive(Clone, Serialize, Deserialize)]
@@ -234,18 +229,33 @@ impl PercentilesCollector {
}
impl SegmentPercentilesCollector {
pub fn from_req_and_validate(
field_type: ColumnType,
missing_u64: Option<u64>,
accessor: Column<u64>,
accessor_idx: usize,
) -> Self {
Self {
buckets: Vec::with_capacity(64),
field_type,
missing_u64,
accessor,
pub fn from_req_and_validate(accessor_idx: usize) -> crate::Result<Self> {
Ok(Self {
percentiles: PercentilesCollector::new(),
accessor_idx,
})
}
#[inline]
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
req_data: &mut MetricAggReqData,
) {
if let Some(missing) = req_data.missing_u64.as_ref() {
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&req_data.accessor,
*missing,
);
} else {
req_data
.column_block_accessor
.fetch_block(docs, &req_data.accessor);
}
for val in req_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.percentiles.collect(val1);
}
}
}
@@ -253,18 +263,12 @@ impl SegmentPercentilesCollector {
impl SegmentAggregationCollector for SegmentPercentilesCollector {
#[inline]
fn add_intermediate_aggregation_result(
&mut self,
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone();
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
// Swap collector with an empty one to avoid cloning
let percentiles_collector = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
let intermediate_metric_result =
IntermediateMetricResult::Percentiles(percentiles_collector);
let intermediate_metric_result = IntermediateMetricResult::Percentiles(self.percentiles);
results.push(
name,
@@ -277,33 +281,40 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector {
#[inline]
fn collect(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let percentiles = &mut self.buckets[parent_bucket_id as usize];
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&self.accessor,
self.missing_u64,
);
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
for val in agg_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, self.field_type);
percentiles.collect(val1);
if let Some(missing) = req_data.missing_u64 {
let mut has_val = false;
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.percentiles.collect(val1);
has_val = true;
}
if !has_val {
self.percentiles
.collect(f64_from_fastfield_u64(missing, &req_data.field_type));
}
} else {
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.percentiles.collect(val1);
}
}
Ok(())
}
fn prepare_max_bucket(
#[inline]
fn collect_block(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.buckets.len() <= max_bucket as usize {
self.buckets.push(PercentilesCollector::new());
}
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
self.collect_block_with_field(docs, req_data);
Ok(())
}
}

View File

@@ -1,6 +1,5 @@
use std::fmt::Debug;
use columnar::{Column, ColumnType};
use serde::{Deserialize, Serialize};
use super::*;
@@ -8,9 +7,10 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::metric::MetricAggReqData;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::TantivyError;
use crate::{DocId, TantivyError};
/// A multi-value metric aggregation that computes a collection of statistics on numeric values that
/// are extracted from the aggregated documents.
@@ -83,7 +83,7 @@ impl Stats {
/// Intermediate result of the stats aggregation that can be combined with other intermediate
/// results.
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct IntermediateStats {
/// The number of extracted values.
pub(crate) count: u64,
@@ -187,75 +187,75 @@ pub enum StatsType {
Percentiles,
}
fn create_collector<const TYPE_ID: u8>(
req: &MetricAggReqData,
) -> Box<dyn SegmentAggregationCollector> {
Box::new(SegmentStatsCollector::<TYPE_ID> {
name: req.name.clone(),
collecting_for: req.collecting_for,
is_number_or_date_type: req.is_number_or_date_type,
missing_u64: req.missing_u64,
accessor: req.accessor.clone(),
buckets: vec![IntermediateStats::default()],
})
#[derive(Clone, Debug)]
pub(crate) struct SegmentStatsCollector {
pub(crate) stats: IntermediateStats,
pub(crate) accessor_idx: usize,
}
/// Build a concrete `SegmentStatsCollector` depending on the column type.
pub(crate) fn build_segment_stats_collector(
req: &MetricAggReqData,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
match req.field_type {
ColumnType::I64 => Ok(create_collector::<{ ColumnType::I64 as u8 }>(req)),
ColumnType::U64 => Ok(create_collector::<{ ColumnType::U64 as u8 }>(req)),
ColumnType::F64 => Ok(create_collector::<{ ColumnType::F64 as u8 }>(req)),
ColumnType::Bool => Ok(create_collector::<{ ColumnType::Bool as u8 }>(req)),
ColumnType::DateTime => Ok(create_collector::<{ ColumnType::DateTime as u8 }>(req)),
ColumnType::Bytes => Ok(create_collector::<{ ColumnType::Bytes as u8 }>(req)),
ColumnType::Str => Ok(create_collector::<{ ColumnType::Str as u8 }>(req)),
ColumnType::IpAddr => Ok(create_collector::<{ ColumnType::IpAddr as u8 }>(req)),
impl SegmentStatsCollector {
pub fn from_req(accessor_idx: usize) -> Self {
Self {
stats: IntermediateStats::default(),
accessor_idx,
}
}
#[inline]
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
req_data: &mut MetricAggReqData,
) {
if let Some(missing) = req_data.missing_u64.as_ref() {
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&req_data.accessor,
*missing,
);
} else {
req_data
.column_block_accessor
.fetch_block(docs, &req_data.accessor);
}
if req_data.is_number_or_date_type {
for val in req_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.stats.collect(val1);
}
} else {
for _val in req_data.column_block_accessor.iter_vals() {
// we ignore the value and simply record that we got something
self.stats.collect(0.0);
}
}
}
}
#[repr(C)]
#[derive(Clone, Debug)]
pub(crate) struct SegmentStatsCollector<const COLUMN_TYPE_ID: u8> {
pub(crate) missing_u64: Option<u64>,
pub(crate) accessor: Column<u64>,
pub(crate) is_number_or_date_type: bool,
pub(crate) buckets: Vec<IntermediateStats>,
pub(crate) name: String,
pub(crate) collecting_for: StatsType,
}
impl<const COLUMN_TYPE_ID: u8> SegmentAggregationCollector
for SegmentStatsCollector<COLUMN_TYPE_ID>
{
impl SegmentAggregationCollector for SegmentStatsCollector {
#[inline]
fn add_intermediate_aggregation_result(
&mut self,
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = self.name.clone();
let req = agg_data.get_metric_req_data(self.accessor_idx);
let name = req.name.clone();
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let stats = self.buckets[parent_bucket_id as usize];
let intermediate_metric_result = match self.collecting_for {
let intermediate_metric_result = match req.collecting_for {
StatsType::Average => {
IntermediateMetricResult::Average(IntermediateAverage::from_stats(stats))
IntermediateMetricResult::Average(IntermediateAverage::from_collector(*self))
}
StatsType::Count => {
IntermediateMetricResult::Count(IntermediateCount::from_stats(stats))
IntermediateMetricResult::Count(IntermediateCount::from_collector(*self))
}
StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_stats(stats)),
StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_stats(stats)),
StatsType::Stats => IntermediateMetricResult::Stats(stats),
StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_stats(stats)),
StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_collector(*self)),
StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_collector(*self)),
StatsType::Stats => IntermediateMetricResult::Stats(self.stats),
StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_collector(*self)),
_ => {
return Err(TantivyError::InvalidArgument(format!(
"Unsupported stats type for stats aggregation: {:?}",
self.collecting_for
req.collecting_for
)))
}
};
@@ -271,67 +271,41 @@ impl<const COLUMN_TYPE_ID: u8> SegmentAggregationCollector
#[inline]
fn collect(
&mut self,
parent_bucket_id: BucketId,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
if let Some(missing) = req_data.missing_u64 {
let mut has_val = false;
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.stats.collect(val1);
has_val = true;
}
if !has_val {
self.stats
.collect(f64_from_fastfield_u64(missing, &req_data.field_type));
}
} else {
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.stats.collect(val1);
}
}
Ok(())
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
// TODO: remove once we fetch all values for all bucket ids in one go
if docs.len() == 1 && self.missing_u64.is_none() {
collect_stats::<COLUMN_TYPE_ID>(
&mut self.buckets[parent_bucket_id as usize],
self.accessor.values_for_doc(docs[0]),
self.is_number_or_date_type,
)?;
return Ok(());
}
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&self.accessor,
self.missing_u64,
);
collect_stats::<COLUMN_TYPE_ID>(
&mut self.buckets[parent_bucket_id as usize],
agg_data.column_block_accessor.iter_vals(),
self.is_number_or_date_type,
)?;
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
self.collect_block_with_field(docs, req_data);
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
let required_buckets = (max_bucket as usize) + 1;
if self.buckets.len() < required_buckets {
self.buckets
.resize_with(required_buckets, IntermediateStats::default);
}
Ok(())
}
}
#[inline]
fn collect_stats<const COLUMN_TYPE_ID: u8>(
stats: &mut IntermediateStats,
vals: impl Iterator<Item = u64>,
is_number_or_date_type: bool,
) -> crate::Result<()> {
if is_number_or_date_type {
for val in vals {
let val1 = convert_to_f64::<COLUMN_TYPE_ID>(val);
stats.collect(val1);
}
} else {
for _val in vals {
// we ignore the value and simply record that we got something
stats.collect(0.0);
}
}
Ok(())
}
#[cfg(test)]

View File

@@ -52,8 +52,10 @@ pub struct IntermediateSum {
impl IntermediateSum {
/// Creates a new [`IntermediateSum`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateSum) {

View File

@@ -1,7 +1,8 @@
use std::cmp::Ordering;
use std::collections::HashMap;
use std::net::Ipv6Addr;
use columnar::{Column, ColumnType, ColumnarReader, DynamicColumn};
use columnar::{Column, ColumnType, ColumnarReader, DynamicColumn, ValueRange};
use common::json_path_writer::JSON_PATH_SEGMENT_SEP_STR;
use common::DateTime;
use regex::Regex;
@@ -15,11 +16,12 @@ use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateMetricResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::{AggregationError, BucketId};
use crate::collector::sort_key::ReverseComparator;
use crate::aggregation::AggregationError;
use crate::collector::sort_key::{Comparator, ReverseComparator};
use crate::collector::TopNComputer;
use crate::schema::OwnedValue;
use crate::{DocAddress, DocId, SegmentOrdinal};
// duplicate import removed; already imported above
/// Contains all information required by the TopHitsSegmentCollector to perform the
/// top_hits aggregation on a segment.
@@ -382,7 +384,7 @@ impl From<FastFieldValue> for OwnedValue {
/// Holds a fast field value in its u64 representation, and the order in which it should be sorted.
#[derive(Clone, Serialize, Deserialize, Debug)]
struct DocValueAndOrder {
pub(crate) struct DocValueAndOrder {
/// A fast field value in its u64 representation.
value: Option<u64>,
/// Sort order for the value
@@ -454,6 +456,37 @@ impl PartialEq for DocSortValuesAndFields {
impl Eq for DocSortValuesAndFields {}
impl Comparator<DocSortValuesAndFields> for ReverseComparator {
#[inline(always)]
fn compare(&self, lhs: &DocSortValuesAndFields, rhs: &DocSortValuesAndFields) -> Ordering {
rhs.cmp(lhs)
}
fn threshold_to_valuerange(
&self,
threshold: DocSortValuesAndFields,
) -> ValueRange<DocSortValuesAndFields> {
ValueRange::LessThan(threshold, true)
}
}
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct TopHitsSegmentSortKey(pub Vec<DocValueAndOrder>);
impl Comparator<TopHitsSegmentSortKey> for ReverseComparator {
#[inline(always)]
fn compare(&self, lhs: &TopHitsSegmentSortKey, rhs: &TopHitsSegmentSortKey) -> Ordering {
rhs.cmp(lhs)
}
fn threshold_to_valuerange(
&self,
threshold: TopHitsSegmentSortKey,
) -> ValueRange<TopHitsSegmentSortKey> {
ValueRange::LessThan(threshold, true)
}
}
/// The TopHitsCollector used for collecting over segments and merging results.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct TopHitsTopNComputer {
@@ -471,10 +504,7 @@ impl TopHitsTopNComputer {
/// Create a new TopHitsCollector
pub fn new(req: &TopHitsAggregationReq) -> Self {
Self {
top_n: TopNComputer::new_with_comparator(
req.size + req.from.unwrap_or(0),
ReverseComparator,
),
top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)),
req: req.clone(),
}
}
@@ -520,8 +550,7 @@ impl TopHitsTopNComputer {
pub(crate) struct TopHitsSegmentCollector {
segment_ordinal: SegmentOrdinal,
accessor_idx: usize,
buckets: Vec<TopNComputer<Vec<DocValueAndOrder>, DocAddress, ReverseComparator>>,
num_hits: usize,
top_n: TopNComputer<TopHitsSegmentSortKey, DocAddress, ReverseComparator>,
}
impl TopHitsSegmentCollector {
@@ -530,35 +559,27 @@ impl TopHitsSegmentCollector {
accessor_idx: usize,
segment_ordinal: SegmentOrdinal,
) -> Self {
let num_hits = req.size + req.from.unwrap_or(0);
Self {
num_hits,
top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)),
segment_ordinal,
accessor_idx,
buckets: vec![TopNComputer::new_with_comparator(num_hits, ReverseComparator); 1],
}
}
fn get_top_hits_computer(
&mut self,
parent_bucket_id: BucketId,
fn into_top_hits_collector(
self,
value_accessors: &HashMap<String, Vec<DynamicColumn>>,
req: &TopHitsAggregationReq,
) -> TopHitsTopNComputer {
if parent_bucket_id as usize >= self.buckets.len() {
return TopHitsTopNComputer::new(req);
}
let top_n = std::mem::replace(
&mut self.buckets[parent_bucket_id as usize],
TopNComputer::new(0),
);
let mut top_hits_computer = TopHitsTopNComputer::new(req);
let top_results = top_n.into_vec();
// Map TopHitsSegmentSortKey back to Vec<DocValueAndOrder> if needed or use directly
// The TopNComputer here stores TopHitsSegmentSortKey.
let top_results = self.top_n.into_vec();
for res in top_results {
let doc_value_fields = req.get_document_field_data(value_accessors, res.doc.doc_id);
top_hits_computer.collect(
DocSortValuesAndFields {
sorts: res.sort_key,
sorts: res.sort_key.0,
doc_value_fields,
},
res.doc,
@@ -567,24 +588,54 @@ impl TopHitsSegmentCollector {
top_hits_computer
}
/// TODO add a specialized variant for a single sort field
fn collect_with(
&mut self,
doc_id: crate::DocId,
req: &TopHitsAggregationReq,
accessors: &[(Column<u64>, ColumnType)],
) -> crate::Result<()> {
let sorts: Vec<DocValueAndOrder> = req
.sort
.iter()
.enumerate()
.map(|(idx, KeyOrder { order, .. })| {
let order = *order;
let value = accessors
.get(idx)
.expect("could not find field in accessors")
.0
.values_for_doc(doc_id)
.next();
DocValueAndOrder { value, order }
})
.collect();
self.top_n.push(
TopHitsSegmentSortKey(sorts),
DocAddress {
segment_ord: self.segment_ordinal,
doc_id,
},
);
Ok(())
}
}
impl SegmentAggregationCollector for TopHitsSegmentCollector {
fn add_intermediate_aggregation_result(
&mut self,
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut crate::aggregation::intermediate_agg_result::IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
let value_accessors = &req_data.value_accessors;
let intermediate_result = IntermediateMetricResult::TopHits(self.get_top_hits_computer(
parent_bucket_id,
value_accessors,
&req_data.req,
));
let intermediate_result = IntermediateMetricResult::TopHits(
self.into_top_hits_collector(value_accessors, &req_data.req),
);
results.push(
req_data.name.to_string(),
IntermediateAggregationResult::Metric(intermediate_result),
@@ -594,54 +645,24 @@ impl SegmentAggregationCollector for TopHitsSegmentCollector {
/// TODO: Consider a caching layer to reduce the call overhead
fn collect(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
doc_id: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let top_n = &mut self.buckets[parent_bucket_id as usize];
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
let req = &req_data.req;
let accessors = &req_data.accessors;
for &doc_id in docs {
// TODO: this is terrible, a new vec is allocated for every doc
// We can fetch blocks instead
// We don't need to store the order for every value
let sorts: Vec<DocValueAndOrder> = req
.sort
.iter()
.enumerate()
.map(|(idx, KeyOrder { order, .. })| {
let order = *order;
let value = accessors
.get(idx)
.expect("could not find field in accessors")
.0
.values_for_doc(doc_id)
.next();
DocValueAndOrder { value, order }
})
.collect();
top_n.push(
sorts,
DocAddress {
segment_ord: self.segment_ordinal,
doc_id,
},
);
}
self.collect_with(doc_id, &req_data.req, &req_data.accessors)?;
Ok(())
}
fn prepare_max_bucket(
fn collect_block(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.buckets.resize(
(max_bucket as usize) + 1,
TopNComputer::new_with_comparator(self.num_hits, ReverseComparator),
);
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
// TODO: Consider getting fields with the column block accessor.
for doc in docs {
self.collect_with(*doc, &req_data.req, &req_data.accessors)?;
}
Ok(())
}
}
@@ -759,7 +780,7 @@ mod tests {
],
"from": 0,
}
}
}
}))
.unwrap();
@@ -888,7 +909,7 @@ mod tests {
"mixed.*",
],
}
}
}
}))?;
let collector = AggregationCollector::from_aggs(d, Default::default());

View File

@@ -133,7 +133,7 @@ mod agg_limits;
pub mod agg_req;
pub mod agg_result;
pub mod bucket;
pub(crate) mod cached_sub_aggs;
mod buf_collector;
mod collector;
mod date;
mod error;
@@ -162,19 +162,6 @@ use serde::{Deserialize, Deserializer, Serialize};
use crate::tokenizer::TokenizerManager;
/// A bucket id is a dense identifier for a bucket within an aggregation.
/// It is used to index into a Vec that hold per-bucket data.
///
/// For example, in a terms aggregation, each unique term will be assigned a incremental BucketId.
/// This BucketId will be forwarded to sub-aggregations to identify the parent bucket.
///
/// This allows to have a single AggregationCollector instance per aggregation,
/// that can handle multiple buckets efficiently.
///
/// The API to call sub-aggregations is therefore a &[(BucketId, &[DocId])].
/// For that we'll need a buffer. One Vec per bucket aggregation is needed.
pub type BucketId = u32;
/// Context parameters for aggregation execution
///
/// This struct holds shared resources needed during aggregation execution:
@@ -348,37 +335,19 @@ impl Display for Key {
}
}
pub(crate) fn convert_to_f64<const COLUMN_TYPE_ID: u8>(val: u64) -> f64 {
if COLUMN_TYPE_ID == ColumnType::U64 as u8 {
val as f64
} else if COLUMN_TYPE_ID == ColumnType::I64 as u8
|| COLUMN_TYPE_ID == ColumnType::DateTime as u8
{
i64::from_u64(val) as f64
} else if COLUMN_TYPE_ID == ColumnType::F64 as u8 {
f64::from_u64(val)
} else if COLUMN_TYPE_ID == ColumnType::Bool as u8 {
val as f64
} else {
panic!(
"ColumnType ID {} cannot be converted to f64 metric",
COLUMN_TYPE_ID
)
}
}
/// Inverse of `to_fastfield_u64`. Used to convert to `f64` for metrics.
///
/// # Panics
/// Only `u64`, `f64`, `date`, and `i64` are supported.
pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: ColumnType) -> f64 {
pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: &ColumnType) -> f64 {
match field_type {
ColumnType::U64 => convert_to_f64::<{ ColumnType::U64 as u8 }>(val),
ColumnType::I64 => convert_to_f64::<{ ColumnType::I64 as u8 }>(val),
ColumnType::F64 => convert_to_f64::<{ ColumnType::F64 as u8 }>(val),
ColumnType::Bool => convert_to_f64::<{ ColumnType::Bool as u8 }>(val),
ColumnType::DateTime => convert_to_f64::<{ ColumnType::DateTime as u8 }>(val),
_ => panic!("unexpected type {field_type:?}. This should not happen"),
ColumnType::U64 => val as f64,
ColumnType::I64 | ColumnType::DateTime => i64::from_u64(val) as f64,
ColumnType::F64 => f64::from_u64(val),
ColumnType::Bool => val as f64,
_ => {
panic!("unexpected type {field_type:?}. This should not happen")
}
}
}

View File

@@ -8,67 +8,25 @@ use std::fmt::Debug;
pub(crate) use super::agg_limits::AggregationLimitsGuard;
use super::intermediate_agg_result::IntermediateAggregationResults;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::BucketId;
/// Monotonically increasing provider of BucketIds.
#[derive(Debug, Clone, Default)]
pub struct BucketIdProvider(u32);
impl BucketIdProvider {
/// Get the next BucketId.
pub fn next_bucket_id(&mut self) -> BucketId {
let bucket_id = self.0;
self.0 += 1;
bucket_id
}
}
/// A SegmentAggregationCollector is used to collect aggregation results.
pub trait SegmentAggregationCollector: Debug {
pub trait SegmentAggregationCollector: CollectorClone + Debug {
fn add_intermediate_aggregation_result(
&mut self,
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()>;
/// Note: The caller needs to call `prepare_max_bucket` before calling `collect`.
fn collect(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()>;
/// Collect docs for multiple buckets in one call.
/// Minimizes dynamic dispatch overhead when collecting many buckets.
///
/// Note: The caller needs to call `prepare_max_bucket` before calling `collect`.
fn collect_multiple(
fn collect_block(
&mut self,
bucket_ids: &[BucketId],
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
debug_assert_eq!(bucket_ids.len(), docs.len());
let mut start = 0;
while start < bucket_ids.len() {
let bucket_id = bucket_ids[start];
let mut end = start + 1;
while end < bucket_ids.len() && bucket_ids[end] == bucket_id {
end += 1;
}
self.collect(bucket_id, &docs[start..end], agg_data)?;
start = end;
}
Ok(())
}
/// Prepare the collector for collecting up to BucketId `max_bucket`.
/// This is useful so we can split allocation ahead of time of collecting.
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()>;
/// Finalize method. Some Aggregator collect blocks of docs before calling `collect_block`.
@@ -78,7 +36,26 @@ pub trait SegmentAggregationCollector: Debug {
}
}
#[derive(Default)]
/// A helper trait to enable cloning of Box<dyn SegmentAggregationCollector>
pub trait CollectorClone {
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector>;
}
impl<T> CollectorClone for T
where T: 'static + SegmentAggregationCollector + Clone
{
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector> {
Box::new(self.clone())
}
}
impl Clone for Box<dyn SegmentAggregationCollector> {
fn clone(&self) -> Box<dyn SegmentAggregationCollector> {
self.clone_box()
}
}
#[derive(Clone, Default)]
/// The GenericSegmentAggregationResultsCollector is the generic version of the collector, which
/// can handle arbitrary complexity of sub-aggregations. Ideally we never have to pick this one
/// and can provide specialized versions instead, that remove some of its overhead.
@@ -96,13 +73,12 @@ impl Debug for GenericSegmentAggregationResultsCollector {
impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
fn add_intermediate_aggregation_result(
&mut self,
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
for agg in &mut self.aggs {
agg.add_intermediate_aggregation_result(agg_data, results, parent_bucket_id)?;
for agg in self.aggs {
agg.add_intermediate_aggregation_result(agg_data, results)?;
}
Ok(())
@@ -110,13 +86,23 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
fn collect(
&mut self,
parent_bucket_id: BucketId,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)?;
Ok(())
}
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
for collector in &mut self.aggs {
collector.collect(parent_bucket_id, docs, agg_data)?;
collector.collect_block(docs, agg_data)?;
}
Ok(())
}
@@ -126,15 +112,4 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
for collector in &mut self.aggs {
collector.prepare_max_bucket(max_bucket, agg_data)?;
}
Ok(())
}
}

View File

@@ -1,229 +0,0 @@
/// Codec specific to postings data.
pub mod postings;
/// Standard tantivy codec. This is the codec you use by default.
pub mod standard;
use std::io;
pub use standard::StandardCodec;
use crate::codec::postings::PostingsCodec;
use crate::fieldnorm::FieldNormReader;
use crate::postings::{Postings, TermInfo};
use crate::query::score_combiner::DoNothingCombiner;
use crate::query::term_query::TermScorer;
use crate::query::{box_scorer, Bm25Weight, BufferedUnionScorer, Scorer, SumCombiner};
use crate::schema::IndexRecordOption;
use crate::{DocId, InvertedIndexReader, Score};
/// Codecs describes how data is layed out on disk.
///
/// For the moment, only postings codec can be custom.
pub trait Codec: Clone + std::fmt::Debug + Send + Sync + 'static {
/// The specific postings type used by this codec.
type PostingsCodec: PostingsCodec;
/// ID of the codec. It should be unique to your codec.
/// Make it human-readable, descriptive, short and unique.
const ID: &'static str;
/// Load codec based on the codec configuration.
fn from_json_props(json_value: &serde_json::Value) -> crate::Result<Self>;
/// Get codec configuration.
fn to_json_props(&self) -> serde_json::Value;
/// Returns the postings codec.
fn postings_codec(&self) -> &Self::PostingsCodec;
}
/// Object-safe codec is a Codec that can be used in a trait object.
///
/// The point of it is to offer a way to use a codec without a proliferation of generics.
pub trait ObjectSafeCodec: 'static + Send + Sync {
/// Loads a type-erased Postings object for the given term.
///
/// If the schema used to build the index did not provide enough
/// information to match the requested `option`, a Postings is still
/// returned in a best-effort manner.
fn load_postings_type_erased(
&self,
term_info: &TermInfo,
option: IndexRecordOption,
inverted_index_reader: &InvertedIndexReader,
) -> io::Result<Box<dyn Postings>>;
/// Loads a type-erased TermScorer object for the given term.
///
/// If the schema used to build the index did not provide enough
/// information to match the requested `option`, a TermScorer is still
/// returned in a best-effort manner.
///
/// The point of this contraption is that the return TermScorer is backed,
/// not by Box<dyn Postings> but by the codec's concrete Postings type.
fn load_term_scorer_type_erased(
&self,
term_info: &TermInfo,
option: IndexRecordOption,
inverted_index_reader: &InvertedIndexReader,
fieldnorm_reader: FieldNormReader,
similarity_weight: Bm25Weight,
) -> io::Result<Box<dyn Scorer>>;
/// Loads a type-erased PhraseScorer object for the given term.
///
/// If the schema used to build the index did not provide enough
/// information to match the requested `option`, a TermScorer is still
/// returned in a best-effort manner.
///
/// The point of this contraption is that the return PhraseScorer is backed,
/// not by Box<dyn Postings> but by the codec's concrete Postings type.
fn new_phrase_scorer_type_erased(
&self,
term_infos: &[(usize, TermInfo)],
similarity_weight: Option<Bm25Weight>,
fieldnorm_reader: FieldNormReader,
slop: u32,
inverted_index_reader: &InvertedIndexReader,
) -> io::Result<Box<dyn Scorer>>;
/// Performs a for_each_pruning operation on the given scorer.
///
/// The function will go through matching documents and call the callback
/// function for all docs with a score exceeding the threshold.
///
/// The function itself will return a larger threshold value,
/// meant to update the threshold value.
///
/// If the codec and the scorer allow it, this function can rely on
/// optimizations like the block-max wand.
fn for_each_pruning(
&self,
threshold: Score,
scorer: Box<dyn Scorer>,
callback: &mut dyn FnMut(DocId, Score) -> Score,
);
/// Builds a union scorer possibly specialized if
/// all scorers are `Term<Self::Postings>`.
fn build_union_scorer_with_sum_combiner(
&self,
scorers: Vec<Box<dyn Scorer>>,
num_docs: DocId,
score_combiner_type: SumOrDoNothingCombiner,
) -> Box<dyn Scorer>;
}
impl<TCodec: Codec> ObjectSafeCodec for TCodec {
fn load_postings_type_erased(
&self,
term_info: &TermInfo,
option: IndexRecordOption,
inverted_index_reader: &InvertedIndexReader,
) -> io::Result<Box<dyn Postings>> {
let postings = inverted_index_reader
.read_postings_from_terminfo_specialized(term_info, option, self)?;
Ok(Box::new(postings))
}
fn load_term_scorer_type_erased(
&self,
term_info: &TermInfo,
option: IndexRecordOption,
inverted_index_reader: &InvertedIndexReader,
fieldnorm_reader: FieldNormReader,
similarity_weight: Bm25Weight,
) -> io::Result<Box<dyn Scorer>> {
let scorer = inverted_index_reader.new_term_scorer_specialized(
term_info,
option,
fieldnorm_reader,
similarity_weight,
self,
)?;
Ok(box_scorer(scorer))
}
fn new_phrase_scorer_type_erased(
&self,
term_infos: &[(usize, TermInfo)],
similarity_weight: Option<Bm25Weight>,
fieldnorm_reader: FieldNormReader,
slop: u32,
inverted_index_reader: &InvertedIndexReader,
) -> io::Result<Box<dyn Scorer>> {
let scorer = inverted_index_reader.new_phrase_scorer_type_specialized(
term_infos,
similarity_weight,
fieldnorm_reader,
slop,
self,
)?;
Ok(box_scorer(scorer))
}
fn build_union_scorer_with_sum_combiner(
&self,
scorers: Vec<Box<dyn Scorer>>,
num_docs: DocId,
sum_or_do_nothing_combiner: SumOrDoNothingCombiner,
) -> Box<dyn Scorer> {
if !scorers.iter().all(|scorer| {
scorer.is::<TermScorer<<<Self as Codec>::PostingsCodec as PostingsCodec>::Postings>>()
}) {
return box_scorer(BufferedUnionScorer::build(
scorers,
SumCombiner::default,
num_docs,
));
}
let specialized_scorers: Vec<
TermScorer<<<Self as Codec>::PostingsCodec as PostingsCodec>::Postings>,
> = scorers
.into_iter()
.map(|scorer| {
*scorer.downcast::<TermScorer<_>>().ok().expect(
"Downcast failed despite the fact we already checked the type was correct",
)
})
.collect();
match sum_or_do_nothing_combiner {
SumOrDoNothingCombiner::Sum => box_scorer(BufferedUnionScorer::build(
specialized_scorers,
SumCombiner::default,
num_docs,
)),
SumOrDoNothingCombiner::DoNothing => box_scorer(BufferedUnionScorer::build(
specialized_scorers,
DoNothingCombiner::default,
num_docs,
)),
}
}
fn for_each_pruning(
&self,
threshold: Score,
scorer: Box<dyn Scorer>,
callback: &mut dyn FnMut(DocId, Score) -> Score,
) {
let accerelerated_foreach_pruning_res =
<TCodec as Codec>::PostingsCodec::try_accelerated_for_each_pruning(
threshold, scorer, callback,
);
if let Err(mut scorer) = accerelerated_foreach_pruning_res {
// No acceleration available. We need to do things manually.
scorer.for_each_pruning(threshold, callback);
}
}
}
/// SumCombiner or DoNothingCombiner
#[derive(Copy, Clone)]
pub enum SumOrDoNothingCombiner {
/// Sum scores together
Sum,
/// Do not track any score.
DoNothing,
}

View File

@@ -1,123 +0,0 @@
use std::io;
/// Block-max WAND algorithm.
pub mod block_wand;
use common::OwnedBytes;
use crate::fieldnorm::FieldNormReader;
use crate::postings::Postings;
use crate::query::{Bm25Weight, Scorer};
use crate::schema::IndexRecordOption;
use crate::{DocId, Score};
/// Postings codec.
pub trait PostingsCodec: Send + Sync + 'static {
/// Serializer type for the postings codec.
type PostingsSerializer: PostingsSerializer;
/// Postings type for the postings codec.
type Postings: Postings + Clone;
/// Creates a new postings serializer.
fn new_serializer(
&self,
avg_fieldnorm: Score,
mode: IndexRecordOption,
fieldnorm_reader: Option<FieldNormReader>,
) -> Self::PostingsSerializer;
/// Loads postings
///
/// Record option is the option that was passed at indexing time.
/// Requested option is the option that is requested.
///
/// For instance, we may have term_freq in the posting list
/// but we can skip decompressing as we read the posting list.
///
/// If record option does not support the requested option,
/// this method does NOT return an error and will in fact restrict
/// requested_option to what is available.
fn load_postings(
&self,
doc_freq: u32,
postings_data: OwnedBytes,
record_option: IndexRecordOption,
requested_option: IndexRecordOption,
positions_data: Option<OwnedBytes>,
) -> io::Result<Self::Postings>;
/// If your codec supports different ways to accelerate `for_each_pruning` that's
/// where you should implement it.
///
/// Returning `Err(scorer)` without mutating the scorer nor calling the callback function,
/// is never "wrong". It just leaves the responsability to the caller to call a fallback
/// implementation on the scorer.
///
/// If your codec supports BlockMax-Wand, you just need to have your
/// postings implement `PostingsWithBlockMax` and copy what is done in the StandardPostings
/// codec to enable it.
fn try_accelerated_for_each_pruning(
_threshold: Score,
scorer: Box<dyn Scorer>,
_callback: &mut dyn FnMut(DocId, Score) -> Score,
) -> Result<(), Box<dyn Scorer>> {
Err(scorer)
}
}
/// A postings serializer is a listener that is in charge of serializing postings
///
/// IO is done only once per postings, once all of the data has been received.
/// A serializer will therefore contain internal buffers.
///
/// A serializer is created once and recycled for all postings.
///
/// Clients should use PostingsSerializer as follows.
/// ```rust,no_run
/// // First postings list
/// serializer.new_term(2, true);
/// serializer.write_doc(2, 1);
/// serializer.write_doc(6, 2);
/// serializer.close_term(3);
/// serializer.clear();
/// // Second postings list
/// serializer.new_term(1, true);
/// serializer.write_doc(3, 1);
/// serializer.close_term(3);
/// ```
pub trait PostingsSerializer {
/// The term_doc_freq here is the number of documents
/// in the postings lists.
///
/// It can be used to compute the idf that will be used for the
/// blockmax parameters.
///
/// If not available (e.g. if we do not collect `term_frequencies`
/// blockwand is disabled), the term_doc_freq passed will be set 0.
fn new_term(&mut self, term_doc_freq: u32, record_term_freq: bool);
/// Records a new document id for the current term.
/// The serializer may ignore it.
fn write_doc(&mut self, doc_id: DocId, term_freq: u32);
/// Closes the current term and writes the postings list associated.
fn close_term(&mut self, doc_freq: u32, wrt: &mut impl io::Write) -> io::Result<()>;
}
/// A light complement interface to Postings to allow block-max wand acceleration.
pub trait PostingsWithBlockMax: Postings {
/// Moves the postings to the block containign `target_doc` and returns
/// an upperbound of the score for documents in the block.
///
/// `Warning`: Calling this method may leave the postings in an invalid state.
/// callers are required to call seek before calling any other of the
/// `Postings` method (like doc / advance etc.).
fn seek_block_max(
&mut self,
target_doc: crate::DocId,
fieldnorm_reader: &FieldNormReader,
similarity_weight: &Bm25Weight,
) -> Score;
/// Returns the last document in the current block (or Terminated if this
/// is the last block).
fn last_doc_in_block(&self) -> crate::DocId;
}

View File

@@ -1,35 +0,0 @@
use serde::{Deserialize, Serialize};
use crate::codec::standard::postings::StandardPostingsCodec;
use crate::codec::Codec;
/// Tantivy's default postings codec.
pub mod postings;
/// Tantivy's default codec.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct StandardCodec;
impl Codec for StandardCodec {
type PostingsCodec = StandardPostingsCodec;
const ID: &'static str = "tantivy-default";
fn from_json_props(json_value: &serde_json::Value) -> crate::Result<Self> {
if !json_value.is_null() {
return Err(crate::TantivyError::InvalidArgument(format!(
"Codec property for the StandardCodec are unexpected. expected null, got {}",
json_value.as_str().unwrap_or("null")
)));
}
Ok(StandardCodec)
}
fn to_json_props(&self) -> serde_json::Value {
serde_json::Value::Null
}
fn postings_codec(&self) -> &Self::PostingsCodec {
&StandardPostingsCodec
}
}

View File

@@ -1,50 +0,0 @@
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
use crate::DocId;
pub struct Block {
doc_ids: [DocId; COMPRESSION_BLOCK_SIZE],
term_freqs: [u32; COMPRESSION_BLOCK_SIZE],
len: usize,
}
impl Block {
pub fn new() -> Self {
Block {
doc_ids: [0u32; COMPRESSION_BLOCK_SIZE],
term_freqs: [0u32; COMPRESSION_BLOCK_SIZE],
len: 0,
}
}
pub fn doc_ids(&self) -> &[DocId] {
&self.doc_ids[..self.len]
}
pub fn term_freqs(&self) -> &[u32] {
&self.term_freqs[..self.len]
}
pub fn clear(&mut self) {
self.len = 0;
}
pub fn append_doc(&mut self, doc: DocId, term_freq: u32) {
let len = self.len;
self.doc_ids[len] = doc;
self.term_freqs[len] = term_freq;
self.len = len + 1;
}
pub fn is_full(&self) -> bool {
self.len == COMPRESSION_BLOCK_SIZE
}
pub fn is_empty(&self) -> bool {
self.len == 0
}
pub fn last_doc(&self) -> DocId {
assert_eq!(self.len, COMPRESSION_BLOCK_SIZE);
self.doc_ids[COMPRESSION_BLOCK_SIZE - 1]
}
}

View File

@@ -1,164 +0,0 @@
use std::io;
use crate::codec::postings::block_wand::{block_wand, block_wand_single_scorer};
use crate::codec::postings::PostingsCodec;
use crate::codec::standard::postings::block_segment_postings::BlockSegmentPostings;
pub use crate::codec::standard::postings::segment_postings::SegmentPostings;
use crate::fieldnorm::FieldNormReader;
use crate::positions::PositionReader;
use crate::query::term_query::TermScorer;
use crate::query::{BufferedUnionScorer, Scorer, SumCombiner};
use crate::schema::IndexRecordOption;
use crate::{DocSet as _, Score, TERMINATED};
mod block;
mod block_segment_postings;
mod segment_postings;
mod skip;
mod standard_postings_serializer;
pub use segment_postings::SegmentPostings as StandardPostings;
pub use standard_postings_serializer::StandardPostingsSerializer;
/// The default postings codec for tantivy.
pub struct StandardPostingsCodec;
#[expect(clippy::enum_variant_names)]
#[derive(Debug, PartialEq, Clone, Copy, Eq)]
pub(crate) enum FreqReadingOption {
NoFreq,
SkipFreq,
ReadFreq,
}
impl PostingsCodec for StandardPostingsCodec {
type PostingsSerializer = StandardPostingsSerializer;
type Postings = SegmentPostings;
fn new_serializer(
&self,
avg_fieldnorm: Score,
mode: IndexRecordOption,
fieldnorm_reader: Option<FieldNormReader>,
) -> Self::PostingsSerializer {
StandardPostingsSerializer::new(avg_fieldnorm, mode, fieldnorm_reader)
}
fn load_postings(
&self,
doc_freq: u32,
postings_data: common::OwnedBytes,
record_option: IndexRecordOption,
requested_option: IndexRecordOption,
positions_data_opt: Option<common::OwnedBytes>,
) -> io::Result<Self::Postings> {
// Rationalize record_option/requested_option.
let requested_option = requested_option.downgrade(record_option);
let block_segment_postings =
BlockSegmentPostings::open(doc_freq, postings_data, record_option, requested_option)?;
let position_reader = positions_data_opt.map(PositionReader::open).transpose()?;
Ok(SegmentPostings::from_block_postings(
block_segment_postings,
position_reader,
))
}
fn try_accelerated_for_each_pruning(
mut threshold: Score,
mut scorer: Box<dyn Scorer>,
callback: &mut dyn FnMut(crate::DocId, Score) -> Score,
) -> Result<(), Box<dyn Scorer>> {
scorer = match scorer.downcast::<TermScorer<Self::Postings>>() {
Ok(term_scorer) => {
block_wand_single_scorer(*term_scorer, threshold, callback);
return Ok(());
}
Err(scorer) => scorer,
};
let mut union_scorer =
scorer.downcast::<BufferedUnionScorer<Box<dyn Scorer>, SumCombiner>>()?;
if !union_scorer
.scorers()
.iter()
.all(|scorer| scorer.is::<TermScorer<Self::Postings>>())
{
return Err(union_scorer);
}
let doc = union_scorer.doc();
if doc == TERMINATED {
return Ok(());
}
let score = union_scorer.score();
if score > threshold {
threshold = callback(doc, score);
}
let boxed_scorers: Vec<Box<dyn Scorer>> = union_scorer.into_scorers();
let scorers: Vec<TermScorer<Self::Postings>> = boxed_scorers
.into_iter()
.map(|scorer| {
*scorer.downcast::<TermScorer<Self::Postings>>().ok().expect(
"Downcast failed despite the fact we already checked the type was correct",
)
})
.collect();
block_wand(scorers, threshold, callback);
Ok(())
}
}
#[cfg(test)]
mod tests {
use common::OwnedBytes;
use super::*;
use crate::codec::postings::PostingsSerializer as _;
use crate::postings::Postings as _;
fn test_segment_postings_tf_aux(num_docs: u32, include_term_freq: bool) -> SegmentPostings {
let mut postings_serializer =
StandardPostingsCodec.new_serializer(1.0f32, IndexRecordOption::WithFreqs, None);
let mut buffer = Vec::new();
postings_serializer.new_term(num_docs, include_term_freq);
for i in 0..num_docs {
postings_serializer.write_doc(i, 2);
}
postings_serializer
.close_term(num_docs, &mut buffer)
.unwrap();
StandardPostingsCodec
.load_postings(
num_docs,
OwnedBytes::new(buffer),
IndexRecordOption::WithFreqs,
IndexRecordOption::WithFreqs,
None,
)
.unwrap()
}
#[test]
fn test_segment_postings_small_block_with_and_without_freq() {
let small_block_without_term_freq = test_segment_postings_tf_aux(1, false);
assert!(!small_block_without_term_freq.has_freq());
assert_eq!(small_block_without_term_freq.doc(), 0);
assert_eq!(small_block_without_term_freq.term_freq(), 1);
let small_block_with_term_freq = test_segment_postings_tf_aux(1, true);
assert!(small_block_with_term_freq.has_freq());
assert_eq!(small_block_with_term_freq.doc(), 0);
assert_eq!(small_block_with_term_freq.term_freq(), 2);
}
#[test]
fn test_segment_postings_large_block_with_and_without_freq() {
let large_block_without_term_freq = test_segment_postings_tf_aux(128, false);
assert!(!large_block_without_term_freq.has_freq());
assert_eq!(large_block_without_term_freq.doc(), 0);
assert_eq!(large_block_without_term_freq.term_freq(), 1);
let large_block_with_term_freq = test_segment_postings_tf_aux(128, true);
assert!(large_block_with_term_freq.has_freq());
assert_eq!(large_block_with_term_freq.doc(), 0);
assert_eq!(large_block_with_term_freq.term_freq(), 2);
}
}

View File

@@ -1,184 +0,0 @@
use std::cmp::Ordering;
use std::io::{self, Write as _};
use common::{BinarySerializable as _, VInt};
use crate::codec::postings::PostingsSerializer;
use crate::codec::standard::postings::block::Block;
use crate::codec::standard::postings::skip::SkipSerializer;
use crate::fieldnorm::FieldNormReader;
use crate::postings::compression::{BlockEncoder, VIntEncoder as _, COMPRESSION_BLOCK_SIZE};
use crate::query::Bm25Weight;
use crate::schema::IndexRecordOption;
use crate::{DocId, Score};
/// Serializer object for tantivy's default postings format.
pub struct StandardPostingsSerializer {
last_doc_id_encoded: u32,
block_encoder: BlockEncoder,
block: Box<Block>,
postings_write: Vec<u8>,
skip_write: SkipSerializer,
mode: IndexRecordOption,
fieldnorm_reader: Option<FieldNormReader>,
bm25_weight: Option<Bm25Weight>,
avg_fieldnorm: Score, /* Average number of term in the field for that segment.
* this value is used to compute the block wand information. */
term_has_freq: bool,
}
impl StandardPostingsSerializer {
pub(crate) fn new(
avg_fieldnorm: Score,
mode: IndexRecordOption,
fieldnorm_reader: Option<FieldNormReader>,
) -> StandardPostingsSerializer {
Self {
last_doc_id_encoded: 0,
block_encoder: BlockEncoder::new(),
block: Box::new(Block::new()),
postings_write: Vec::new(),
skip_write: SkipSerializer::new(),
mode,
fieldnorm_reader,
bm25_weight: None,
avg_fieldnorm,
term_has_freq: false,
}
}
}
impl PostingsSerializer for StandardPostingsSerializer {
fn new_term(&mut self, term_doc_freq: u32, record_term_freq: bool) {
self.clear();
self.term_has_freq = self.mode.has_freq() && record_term_freq;
if !self.term_has_freq {
return;
}
let num_docs_in_segment: u64 =
if let Some(fieldnorm_reader) = self.fieldnorm_reader.as_ref() {
fieldnorm_reader.num_docs() as u64
} else {
return;
};
if num_docs_in_segment == 0 {
return;
}
self.bm25_weight = Some(Bm25Weight::for_one_term_without_explain(
term_doc_freq as u64,
num_docs_in_segment,
self.avg_fieldnorm,
));
}
fn write_doc(&mut self, doc_id: DocId, term_freq: u32) {
self.block.append_doc(doc_id, term_freq);
if self.block.is_full() {
self.write_block();
}
}
fn close_term(&mut self, doc_freq: u32, output_write: &mut impl io::Write) -> io::Result<()> {
if !self.block.is_empty() {
// we have doc ids waiting to be written
// this happens when the number of doc ids is
// not a perfect multiple of our block size.
//
// In that case, the remaining part is encoded
// using variable int encoding.
{
let block_encoded = self
.block_encoder
.compress_vint_sorted(self.block.doc_ids(), self.last_doc_id_encoded);
self.postings_write.write_all(block_encoded)?;
}
// ... Idem for term frequencies
if self.term_has_freq {
let block_encoded = self
.block_encoder
.compress_vint_unsorted(self.block.term_freqs());
self.postings_write.write_all(block_encoded)?;
}
self.block.clear();
}
if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 {
let skip_data = self.skip_write.data();
VInt(skip_data.len() as u64).serialize(output_write)?;
output_write.write_all(skip_data)?;
}
output_write.write_all(&self.postings_write[..])?;
self.skip_write.clear();
self.postings_write.clear();
self.bm25_weight = None;
Ok(())
}
}
impl StandardPostingsSerializer {
fn clear(&mut self) {
self.bm25_weight = None;
self.block.clear();
self.last_doc_id_encoded = 0;
}
fn write_block(&mut self) {
{
// encode the doc ids
let (num_bits, block_encoded): (u8, &[u8]) = self
.block_encoder
.compress_block_sorted(self.block.doc_ids(), self.last_doc_id_encoded);
self.last_doc_id_encoded = self.block.last_doc();
self.skip_write
.write_doc(self.last_doc_id_encoded, num_bits);
// last el block 0, offset block 1,
self.postings_write.extend(block_encoded);
}
if self.term_has_freq {
let (num_bits, block_encoded): (u8, &[u8]) = self
.block_encoder
.compress_block_unsorted(self.block.term_freqs(), true);
self.postings_write.extend(block_encoded);
self.skip_write.write_term_freq(num_bits);
if self.mode.has_positions() {
// We serialize the sum of term freqs within the skip information
// in order to navigate through positions.
let sum_freq = self.block.term_freqs().iter().cloned().sum();
self.skip_write.write_total_term_freq(sum_freq);
}
let mut blockwand_params = (0u8, 0u32);
if let Some(bm25_weight) = self.bm25_weight.as_ref() {
if let Some(fieldnorm_reader) = self.fieldnorm_reader.as_ref() {
let docs = self.block.doc_ids().iter().cloned();
let term_freqs = self.block.term_freqs().iter().cloned();
let fieldnorms = docs.map(|doc| fieldnorm_reader.fieldnorm_id(doc));
blockwand_params = fieldnorms
.zip(term_freqs)
.max_by(
|(left_fieldnorm_id, left_term_freq),
(right_fieldnorm_id, right_term_freq)| {
let left_score =
bm25_weight.tf_factor(*left_fieldnorm_id, *left_term_freq);
let right_score =
bm25_weight.tf_factor(*right_fieldnorm_id, *right_term_freq);
left_score
.partial_cmp(&right_score)
.unwrap_or(Ordering::Equal)
},
)
.unwrap();
}
}
let (fieldnorm_id, term_freq) = blockwand_params;
self.skip_write.write_blockwand_max(fieldnorm_id, term_freq);
}
self.block.clear();
}
}

View File

@@ -486,9 +486,9 @@ mod tests {
use std::collections::BTreeSet;
use columnar::Dictionary;
use rand::distr::Uniform;
use rand::distributions::Uniform;
use rand::prelude::SliceRandom;
use rand::{rng, Rng};
use rand::{thread_rng, Rng};
use super::{FacetCollector, FacetCounts};
use crate::collector::facet_collector::compress_mapping;
@@ -731,7 +731,7 @@ mod tests {
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let uniform = Uniform::new_inclusive(1, 100_000).unwrap();
let uniform = Uniform::new_inclusive(1, 100_000);
let mut docs: Vec<TantivyDocument> =
vec![("a", 10), ("b", 100), ("c", 7), ("d", 12), ("e", 21)]
.into_iter()
@@ -741,11 +741,14 @@ mod tests {
std::iter::repeat_n(doc, count)
})
.map(|mut doc| {
doc.add_facet(facet_field, &format!("/facet/{}", rng().sample(uniform)));
doc.add_facet(
facet_field,
&format!("/facet/{}", thread_rng().sample(uniform)),
);
doc
})
.collect();
docs[..].shuffle(&mut rng());
docs[..].shuffle(&mut thread_rng());
let mut index_writer: IndexWriter = index.writer_for_tests().unwrap();
for doc in docs {
@@ -818,9 +821,8 @@ mod tests {
#[cfg(all(test, feature = "unstable"))]
mod bench {
use rand::rng;
use rand::seq::SliceRandom;
use rand::thread_rng;
use test::Bencher;
use crate::collector::FacetCollector;
@@ -843,7 +845,7 @@ mod bench {
}
}
// 40425 docs
docs[..].shuffle(&mut rng());
docs[..].shuffle(&mut thread_rng());
let mut index_writer: IndexWriter = index.writer_for_tests().unwrap();
for doc in docs {

View File

@@ -96,10 +96,9 @@ mod histogram_collector;
pub use histogram_collector::HistogramCollector;
mod multi_collector;
pub use self::multi_collector::{FruitHandle, MultiCollector, MultiFruit};
pub use columnar::ComparableDoc;
mod top_collector;
pub use self::top_collector::ComparableDoc;
pub use self::multi_collector::{FruitHandle, MultiCollector, MultiFruit};
mod top_score_collector;
pub use self::top_score_collector::{TopDocs, TopNComputer};

View File

@@ -281,7 +281,6 @@ impl SegmentCollector for MultiCollectorChild {
#[cfg(test)]
mod tests {
use super::*;
use crate::collector::{Count, TopDocs};
use crate::query::TermQuery;

View File

@@ -13,31 +13,13 @@ pub use sort_by_string::SortByString;
pub use sort_key_computer::{SegmentSortKeyComputer, SortKeyComputer};
#[cfg(test)]
pub(crate) mod tests {
// By spec, regardless of whether ascending or descending order was requested, in presence of a
// tie, we sort by ascending doc id/doc address.
pub(crate) fn sort_hits<TSortKey: Ord, D: Ord>(
hits: &mut [ComparableDoc<TSortKey, D>],
order: Order,
) {
if order.is_asc() {
hits.sort_by(|l, r| l.sort_key.cmp(&r.sort_key).then(l.doc.cmp(&r.doc)));
} else {
hits.sort_by(|l, r| {
l.sort_key
.cmp(&r.sort_key)
.reverse() // This is descending
.then(l.doc.cmp(&r.doc))
});
}
}
mod tests {
use std::collections::HashMap;
use std::ops::Range;
use crate::collector::sort_key::{
SortByErasedType, SortBySimilarityScore, SortByStaticFastValue, SortByString,
Comparator, NaturalComparator, ReverseComparator, SortByErasedType, SortBySimilarityScore,
SortByStaticFastValue, SortByString,
};
use crate::collector::{ComparableDoc, DocSetCollector, TopDocs};
use crate::indexer::NoMergePolicy;
@@ -389,6 +371,52 @@ pub(crate) mod tests {
Ok(())
}
#[test]
fn test_order_by_compound_fast_fields() -> crate::Result<()> {
let index = make_index()?;
type CompoundSortKey = (Option<String>, Option<f64>);
fn assert_query(
index: &Index,
city_order: Order,
altitude_order: Order,
expected: Vec<(CompoundSortKey, u64)>,
) -> crate::Result<()> {
let searcher = index.reader()?.searcher();
let ids = id_mapping(&searcher);
let top_collector = TopDocs::with_limit(4).order_by((
(SortByString::for_field("city"), city_order),
(
SortByStaticFastValue::<f64>::for_field("altitude"),
altitude_order,
),
));
let actual = searcher
.search(&AllQuery, &top_collector)?
.into_iter()
.map(|(key, doc)| (key, ids[&doc]))
.collect::<Vec<_>>();
assert_eq!(actual, expected);
Ok(())
}
assert_query(
&index,
Order::Asc,
Order::Desc,
vec![
((Some("austin".to_owned()), Some(149.0)), 0),
((Some("greenville".to_owned()), Some(27.0)), 1),
((Some("tokyo".to_owned()), Some(40.0)), 2),
((None, Some(0.0)), 3),
],
)?;
Ok(())
}
use proptest::prelude::*;
proptest! {
@@ -441,7 +469,11 @@ pub(crate) mod tests {
let sorted_docs: Vec<_> = {
let mut comparable_docs: Vec<ComparableDoc<_, _>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
sort_hits(&mut comparable_docs, order);
if order.is_desc() {
comparable_docs.sort_by(|l, r| NaturalComparator.compare_doc(l, r));
} else {
comparable_docs.sort_by(|l, r| ReverseComparator.compare_doc(l, r));
}
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
};
let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::<Vec<_>>();
@@ -451,4 +483,197 @@ pub(crate) mod tests {
);
}
}
proptest! {
#[test]
fn test_order_by_compound_prop(
city_order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
altitude_order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
limit in 1..20_usize,
offset in 0..20_usize,
segments_data in proptest::collection::vec(
proptest::collection::vec(
(proptest::option::of("[a-c]"), proptest::option::of(0..50u64)),
1..10_usize // segment size
),
1..4_usize // num segments
)
) {
use crate::collector::sort_key::ComparatorEnum;
use crate::TantivyDocument;
let mut schema_builder = Schema::builder();
let city = schema_builder.add_text_field("city", TEXT | FAST);
let altitude = schema_builder.add_u64_field("altitude", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
for segment_data in segments_data.into_iter() {
for (city_val, altitude_val) in segment_data.into_iter() {
let mut doc = TantivyDocument::default();
if let Some(c) = city_val {
doc.add_text(city, c);
}
if let Some(a) = altitude_val {
doc.add_u64(altitude, a);
}
index_writer.add_document(doc).unwrap();
}
index_writer.commit().unwrap();
}
let searcher = index.reader().unwrap().searcher();
let top_collector = TopDocs::with_limit(limit)
.and_offset(offset)
.order_by((
(SortByString::for_field("city"), city_order),
(
SortByStaticFastValue::<u64>::for_field("altitude"),
altitude_order,
),
));
let actual_results = searcher.search(&AllQuery, &top_collector).unwrap();
let actual_doc_ids: Vec<DocAddress> =
actual_results.into_iter().map(|(_, doc)| doc).collect();
// Verification logic
let all_docs_collector = DocSetCollector;
let all_docs = searcher.search(&AllQuery, &all_docs_collector).unwrap();
let docs_with_keys: Vec<((Option<String>, Option<u64>), DocAddress)> = all_docs
.into_iter()
.map(|doc_addr| {
let reader = searcher.segment_reader(doc_addr.segment_ord);
let city_val = if let Some(col) = reader.fast_fields().str("city").unwrap() {
let ord = col.ords().first(doc_addr.doc_id);
if let Some(ord) = ord {
let mut out = Vec::new();
col.dictionary().ord_to_term(ord, &mut out).unwrap();
String::from_utf8(out).ok()
} else {
None
}
} else {
None
};
let alt_val = if let Some((col, _)) = reader.fast_fields().u64_lenient("altitude").unwrap() {
col.first(doc_addr.doc_id)
} else {
None
};
((city_val, alt_val), doc_addr)
})
.collect();
let city_comparator = ComparatorEnum::from(city_order);
let alt_comparator = ComparatorEnum::from(altitude_order);
let comparator = (city_comparator, alt_comparator);
let mut comparable_docs: Vec<ComparableDoc<_, _>> = docs_with_keys
.into_iter()
.map(|(sort_key, doc)| ComparableDoc { sort_key, doc })
.collect();
comparable_docs.sort_by(|l, r| comparator.compare_doc(l, r));
let expected_results = comparable_docs
.into_iter()
.skip(offset)
.take(limit)
.collect::<Vec<_>>();
let expected_doc_ids: Vec<DocAddress> =
expected_results.into_iter().map(|cd| cd.doc).collect();
prop_assert_eq!(actual_doc_ids, expected_doc_ids);
}
}
proptest! {
#[test]
fn test_order_by_u64_prop(
order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
limit in 1..20_usize,
offset in 0..20_usize,
segments_data in proptest::collection::vec(
proptest::collection::vec(
proptest::option::of(0..100u64),
1..1000_usize // segment size
),
1..4_usize // num segments
)
) {
use crate::collector::sort_key::ComparatorEnum;
use crate::TantivyDocument;
let mut schema_builder = Schema::builder();
let field = schema_builder.add_u64_field("field", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
for segment_data in segments_data.into_iter() {
for val in segment_data.into_iter() {
let mut doc = TantivyDocument::default();
if let Some(v) = val {
doc.add_u64(field, v);
}
index_writer.add_document(doc).unwrap();
}
index_writer.commit().unwrap();
}
let searcher = index.reader().unwrap().searcher();
let top_collector = TopDocs::with_limit(limit)
.and_offset(offset)
.order_by((SortByStaticFastValue::<u64>::for_field("field"), order));
let actual_results = searcher.search(&AllQuery, &top_collector).unwrap();
let actual_doc_ids: Vec<DocAddress> =
actual_results.into_iter().map(|(_, doc)| doc).collect();
// Verification logic
let all_docs_collector = DocSetCollector;
let all_docs = searcher.search(&AllQuery, &all_docs_collector).unwrap();
let docs_with_keys: Vec<(Option<u64>, DocAddress)> = all_docs
.into_iter()
.map(|doc_addr| {
let reader = searcher.segment_reader(doc_addr.segment_ord);
let val = if let Some((col, _)) = reader.fast_fields().u64_lenient("field").unwrap() {
col.first(doc_addr.doc_id)
} else {
None
};
(val, doc_addr)
})
.collect();
let comparator = ComparatorEnum::from(order);
let mut comparable_docs: Vec<ComparableDoc<_, _>> = docs_with_keys
.into_iter()
.map(|(sort_key, doc)| ComparableDoc { sort_key, doc })
.collect();
comparable_docs.sort_by(|l, r| comparator.compare_doc(l, r));
let expected_results = comparable_docs
.into_iter()
.skip(offset)
.take(limit)
.collect::<Vec<_>>();
let expected_doc_ids: Vec<DocAddress> =
expected_results.into_iter().map(|cd| cd.doc).collect();
prop_assert_eq!(actual_doc_ids, expected_doc_ids);
}
}
}

View File

@@ -1,9 +1,9 @@
use std::cmp::Ordering;
use columnar::MonotonicallyMappableToU64;
use columnar::{MonotonicallyMappableToU64, ValueRange};
use serde::{Deserialize, Serialize};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
use crate::schema::{OwnedValue, Schema};
use crate::{DocId, Order, Score};
@@ -69,6 +69,26 @@ fn compare_owned_value<const NULLS_FIRST: bool>(lhs: &OwnedValue, rhs: &OwnedVal
pub trait Comparator<T>: Send + Sync + std::fmt::Debug + Default {
/// Return the order between two values.
fn compare(&self, lhs: &T, rhs: &T) -> Ordering;
/// Return the order between two ComparableDoc values, using the semantics which are
/// implemented by TopNComputer.
#[inline(always)]
fn compare_doc<D: Ord>(
&self,
lhs: &ComparableDoc<T, D>,
rhs: &ComparableDoc<T, D>,
) -> Ordering {
// TopNComputer sorts in descending order of the SortKey by default: we apply that ordering
// here to ease comparison in testing.
self.compare(&rhs.sort_key, &lhs.sort_key).then_with(|| {
// In case of a tie on the sort key, we always sort by ascending `DocAddress` in order
// to ensure a stable sorting of the documents, regardless of the sort key's order.
// See the TopNComputer docs for more information.
lhs.doc.cmp(&rhs.doc)
})
}
/// Return a `ValueRange` that matches all values that are greater than the provided threshold.
fn threshold_to_valuerange(&self, threshold: T) -> ValueRange<T>;
}
/// Compare values naturally (e.g. 1 < 2).
@@ -84,7 +104,11 @@ pub struct NaturalComparator;
impl<T: PartialOrd> Comparator<T> for NaturalComparator {
#[inline(always)]
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
lhs.partial_cmp(rhs).unwrap_or(Ordering::Equal)
lhs.partial_cmp(rhs).unwrap()
}
fn threshold_to_valuerange(&self, threshold: T) -> ValueRange<T> {
ValueRange::GreaterThan(threshold, false)
}
}
@@ -97,6 +121,10 @@ impl Comparator<OwnedValue> for NaturalComparator {
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
compare_owned_value::</* NULLS_FIRST= */ true>(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: OwnedValue) -> ValueRange<OwnedValue> {
ValueRange::GreaterThan(threshold, false)
}
}
/// Compare values in reverse (e.g. 2 < 1).
@@ -114,13 +142,69 @@ impl Comparator<OwnedValue> for NaturalComparator {
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
pub struct ReverseComparator;
impl<T> Comparator<T> for ReverseComparator
where NaturalComparator: Comparator<T>
macro_rules! impl_reverse_comparator_primitive {
($($t:ty),*) => {
$(
impl Comparator<$t> for ReverseComparator {
#[inline(always)]
fn compare(&self, lhs: &$t, rhs: &$t) -> Ordering {
NaturalComparator.compare(rhs, lhs)
}
fn threshold_to_valuerange(&self, threshold: $t) -> ValueRange<$t> {
ValueRange::LessThan(threshold, true)
}
}
)*
}
}
impl_reverse_comparator_primitive!(
bool,
u8,
u16,
u32,
u64,
u128,
usize,
i8,
i16,
i32,
i64,
i128,
isize,
f32,
f64,
String,
crate::DateTime,
Vec<u8>,
crate::schema::Facet
);
impl<T: PartialOrd + Send + Sync + std::fmt::Debug + Clone + 'static> Comparator<Option<T>>
for ReverseComparator
{
#[inline(always)]
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
fn compare(&self, lhs: &Option<T>, rhs: &Option<T>) -> Ordering {
NaturalComparator.compare(rhs, lhs)
}
fn threshold_to_valuerange(&self, threshold: Option<T>) -> ValueRange<Option<T>> {
let is_some = threshold.is_some();
ValueRange::LessThan(threshold, is_some)
}
}
impl Comparator<OwnedValue> for ReverseComparator {
#[inline(always)]
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
NaturalComparator.compare(rhs, lhs)
}
fn threshold_to_valuerange(&self, threshold: OwnedValue) -> ValueRange<OwnedValue> {
let is_not_null = !matches!(threshold, OwnedValue::Null);
ValueRange::LessThan(threshold, is_not_null)
}
}
/// Compare values in reverse, but treating `None` as lower than `Some`.
@@ -147,6 +231,14 @@ where ReverseComparator: Comparator<T>
(Some(lhs), Some(rhs)) => ReverseComparator.compare(lhs, rhs),
}
}
fn threshold_to_valuerange(&self, threshold: Option<T>) -> ValueRange<Option<T>> {
if threshold.is_some() {
ValueRange::LessThan(threshold, false)
} else {
ValueRange::GreaterThan(threshold, false)
}
}
}
impl Comparator<u32> for ReverseNoneIsLowerComparator {
@@ -154,6 +246,10 @@ impl Comparator<u32> for ReverseNoneIsLowerComparator {
fn compare(&self, lhs: &u32, rhs: &u32) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: u32) -> ValueRange<u32> {
ValueRange::LessThan(threshold, false)
}
}
impl Comparator<u64> for ReverseNoneIsLowerComparator {
@@ -161,6 +257,10 @@ impl Comparator<u64> for ReverseNoneIsLowerComparator {
fn compare(&self, lhs: &u64, rhs: &u64) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: u64) -> ValueRange<u64> {
ValueRange::LessThan(threshold, false)
}
}
impl Comparator<f64> for ReverseNoneIsLowerComparator {
@@ -168,6 +268,10 @@ impl Comparator<f64> for ReverseNoneIsLowerComparator {
fn compare(&self, lhs: &f64, rhs: &f64) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: f64) -> ValueRange<f64> {
ValueRange::LessThan(threshold, false)
}
}
impl Comparator<f32> for ReverseNoneIsLowerComparator {
@@ -175,6 +279,10 @@ impl Comparator<f32> for ReverseNoneIsLowerComparator {
fn compare(&self, lhs: &f32, rhs: &f32) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: f32) -> ValueRange<f32> {
ValueRange::LessThan(threshold, false)
}
}
impl Comparator<i64> for ReverseNoneIsLowerComparator {
@@ -182,6 +290,10 @@ impl Comparator<i64> for ReverseNoneIsLowerComparator {
fn compare(&self, lhs: &i64, rhs: &i64) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: i64) -> ValueRange<i64> {
ValueRange::LessThan(threshold, false)
}
}
impl Comparator<String> for ReverseNoneIsLowerComparator {
@@ -189,6 +301,10 @@ impl Comparator<String> for ReverseNoneIsLowerComparator {
fn compare(&self, lhs: &String, rhs: &String) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: String) -> ValueRange<String> {
ValueRange::LessThan(threshold, false)
}
}
impl Comparator<OwnedValue> for ReverseNoneIsLowerComparator {
@@ -196,6 +312,10 @@ impl Comparator<OwnedValue> for ReverseNoneIsLowerComparator {
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
compare_owned_value::</* NULLS_FIRST= */ false>(rhs, lhs)
}
fn threshold_to_valuerange(&self, threshold: OwnedValue) -> ValueRange<OwnedValue> {
ValueRange::LessThan(threshold, false)
}
}
/// Compare values naturally, but treating `None` as higher than `Some`.
@@ -218,6 +338,15 @@ where NaturalComparator: Comparator<T>
(Some(lhs), Some(rhs)) => NaturalComparator.compare(lhs, rhs),
}
}
fn threshold_to_valuerange(&self, threshold: Option<T>) -> ValueRange<Option<T>> {
if threshold.is_some() {
let is_some = threshold.is_some();
ValueRange::GreaterThan(threshold, is_some)
} else {
ValueRange::LessThan(threshold, false)
}
}
}
impl Comparator<u32> for NaturalNoneIsHigherComparator {
@@ -225,6 +354,10 @@ impl Comparator<u32> for NaturalNoneIsHigherComparator {
fn compare(&self, lhs: &u32, rhs: &u32) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: u32) -> ValueRange<u32> {
ValueRange::GreaterThan(threshold, true)
}
}
impl Comparator<u64> for NaturalNoneIsHigherComparator {
@@ -232,6 +365,10 @@ impl Comparator<u64> for NaturalNoneIsHigherComparator {
fn compare(&self, lhs: &u64, rhs: &u64) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: u64) -> ValueRange<u64> {
ValueRange::GreaterThan(threshold, true)
}
}
impl Comparator<f64> for NaturalNoneIsHigherComparator {
@@ -239,6 +376,10 @@ impl Comparator<f64> for NaturalNoneIsHigherComparator {
fn compare(&self, lhs: &f64, rhs: &f64) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: f64) -> ValueRange<f64> {
ValueRange::GreaterThan(threshold, true)
}
}
impl Comparator<f32> for NaturalNoneIsHigherComparator {
@@ -246,6 +387,10 @@ impl Comparator<f32> for NaturalNoneIsHigherComparator {
fn compare(&self, lhs: &f32, rhs: &f32) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: f32) -> ValueRange<f32> {
ValueRange::GreaterThan(threshold, true)
}
}
impl Comparator<i64> for NaturalNoneIsHigherComparator {
@@ -253,6 +398,10 @@ impl Comparator<i64> for NaturalNoneIsHigherComparator {
fn compare(&self, lhs: &i64, rhs: &i64) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: i64) -> ValueRange<i64> {
ValueRange::GreaterThan(threshold, true)
}
}
impl Comparator<String> for NaturalNoneIsHigherComparator {
@@ -260,6 +409,10 @@ impl Comparator<String> for NaturalNoneIsHigherComparator {
fn compare(&self, lhs: &String, rhs: &String) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: String) -> ValueRange<String> {
ValueRange::GreaterThan(threshold, true)
}
}
impl Comparator<OwnedValue> for NaturalNoneIsHigherComparator {
@@ -267,6 +420,10 @@ impl Comparator<OwnedValue> for NaturalNoneIsHigherComparator {
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
compare_owned_value::</* NULLS_FIRST= */ false>(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: OwnedValue) -> ValueRange<OwnedValue> {
ValueRange::GreaterThan(threshold, true)
}
}
/// An enum representing the different sort orders.
@@ -308,6 +465,19 @@ where
ComparatorEnum::NaturalNoneHigher => NaturalNoneIsHigherComparator.compare(lhs, rhs),
}
}
fn threshold_to_valuerange(&self, threshold: T) -> ValueRange<T> {
match self {
ComparatorEnum::Natural => NaturalComparator.threshold_to_valuerange(threshold),
ComparatorEnum::Reverse => ReverseComparator.threshold_to_valuerange(threshold),
ComparatorEnum::ReverseNoneLower => {
ReverseNoneIsLowerComparator.threshold_to_valuerange(threshold)
}
ComparatorEnum::NaturalNoneHigher => {
NaturalNoneIsHigherComparator.threshold_to_valuerange(threshold)
}
}
}
}
impl<Head, Tail, LeftComparator, RightComparator> Comparator<(Head, Tail)>
@@ -322,6 +492,10 @@ where
.compare(&lhs.0, &rhs.0)
.then_with(|| self.1.compare(&lhs.1, &rhs.1))
}
fn threshold_to_valuerange(&self, threshold: (Head, Tail)) -> ValueRange<(Head, Tail)> {
ValueRange::GreaterThan(threshold, false)
}
}
impl<Type1, Type2, Type3, Comparator1, Comparator2, Comparator3> Comparator<(Type1, (Type2, Type3))>
@@ -338,6 +512,13 @@ where
.then_with(|| self.1.compare(&lhs.1 .0, &rhs.1 .0))
.then_with(|| self.2.compare(&lhs.1 .1, &rhs.1 .1))
}
fn threshold_to_valuerange(
&self,
threshold: (Type1, (Type2, Type3)),
) -> ValueRange<(Type1, (Type2, Type3))> {
ValueRange::GreaterThan(threshold, false)
}
}
impl<Type1, Type2, Type3, Comparator1, Comparator2, Comparator3> Comparator<(Type1, Type2, Type3)>
@@ -354,6 +535,13 @@ where
.then_with(|| self.1.compare(&lhs.1, &rhs.1))
.then_with(|| self.2.compare(&lhs.2, &rhs.2))
}
fn threshold_to_valuerange(
&self,
threshold: (Type1, Type2, Type3),
) -> ValueRange<(Type1, Type2, Type3)> {
ValueRange::GreaterThan(threshold, false)
}
}
impl<Type1, Type2, Type3, Type4, Comparator1, Comparator2, Comparator3, Comparator4>
@@ -377,6 +565,13 @@ where
.then_with(|| self.2.compare(&lhs.1 .1 .0, &rhs.1 .1 .0))
.then_with(|| self.3.compare(&lhs.1 .1 .1, &rhs.1 .1 .1))
}
fn threshold_to_valuerange(
&self,
threshold: (Type1, (Type2, (Type3, Type4))),
) -> ValueRange<(Type1, (Type2, (Type3, Type4)))> {
ValueRange::GreaterThan(threshold, false)
}
}
impl<Type1, Type2, Type3, Type4, Comparator1, Comparator2, Comparator3, Comparator4>
@@ -400,6 +595,13 @@ where
.then_with(|| self.2.compare(&lhs.2, &rhs.2))
.then_with(|| self.3.compare(&lhs.3, &rhs.3))
}
fn threshold_to_valuerange(
&self,
threshold: (Type1, Type2, Type3, Type4),
) -> ValueRange<(Type1, Type2, Type3, Type4)> {
ValueRange::GreaterThan(threshold, false)
}
}
impl<TSortKeyComputer> SortKeyComputer for (TSortKeyComputer, ComparatorEnum)
@@ -489,16 +691,32 @@ impl<TSegmentSortKeyComputer, TSegmentSortKey, TComparator> SegmentSortKeyComput
where
TSegmentSortKeyComputer: SegmentSortKeyComputer<SegmentSortKey = TSegmentSortKey>,
TSegmentSortKey: Clone + 'static + Sync + Send,
TComparator: Comparator<TSegmentSortKey> + 'static + Sync + Send,
TComparator: Comparator<TSegmentSortKey> + Clone + 'static + Sync + Send,
{
type SortKey = TSegmentSortKeyComputer::SortKey;
type SegmentSortKey = TSegmentSortKey;
type SegmentComparator = TComparator;
type Buffer = TSegmentSortKeyComputer::Buffer;
fn segment_comparator(&self) -> Self::SegmentComparator {
self.comparator.clone()
}
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
self.segment_sort_key_computer.segment_sort_key(doc, score)
}
fn segment_sort_keys(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
buffer: &mut Self::Buffer,
filter: ValueRange<Self::SegmentSortKey>,
) {
self.segment_sort_key_computer
.segment_sort_keys(input_docs, output, buffer, filter)
}
#[inline(always)]
fn compare_segment_sort_key(
&self,
@@ -519,36 +737,13 @@ mod tests {
use super::*;
use crate::schema::OwnedValue;
#[test]
fn test_natural_none_is_higher() {
let comp = NaturalNoneIsHigherComparator;
let null = None;
let v1 = Some(1_u64);
let v2 = Some(2_u64);
// NaturalNoneIsGreaterComparator logic:
// 1. Delegates to NaturalComparator for non-nulls.
// NaturalComparator compare(2, 1) -> 2.cmp(1) -> Greater.
assert_eq!(comp.compare(&v2, &v1), Ordering::Greater);
// 2. Treats None (Null) as Greater than any value.
// compare(None, Some(2)) should be Greater.
assert_eq!(comp.compare(&null, &v2), Ordering::Greater);
// compare(Some(1), None) should be Less.
assert_eq!(comp.compare(&v1, &null), Ordering::Less);
// compare(None, None) should be Equal.
assert_eq!(comp.compare(&null, &null), Ordering::Equal);
}
#[test]
fn test_mixed_ownedvalue_compare() {
let u = OwnedValue::U64(10);
let i = OwnedValue::I64(10);
let f = OwnedValue::F64(10.0);
let nc = NaturalComparator;
let nc = NaturalComparator::default();
assert_eq!(nc.compare(&u, &i), Ordering::Equal);
assert_eq!(nc.compare(&u, &f), Ordering::Equal);
assert_eq!(nc.compare(&i, &f), Ordering::Equal);
@@ -564,4 +759,27 @@ mod tests {
// Str < F64
assert_eq!(nc.compare(&s, &f), Ordering::Less);
}
#[test]
fn test_natural_none_is_higher() {
let comp = NaturalNoneIsHigherComparator;
let null = OwnedValue::Null;
let v1 = OwnedValue::U64(1);
let v2 = OwnedValue::U64(2);
// NaturalNoneIsGreaterComparator logic:
// 1. Delegates to NaturalComparator for non-nulls.
// NaturalComparator compare(2, 1) -> 2.cmp(1) -> Greater.
assert_eq!(comp.compare(&v2, &v1), Ordering::Greater);
// 2. Treats None (Null) as Greater than any value.
// compare(Null, 2) should be Greater.
assert_eq!(comp.compare(&null, &v2), Ordering::Greater);
// compare(1, Null) should be Less.
assert_eq!(comp.compare(&v1, &null), Ordering::Less);
// compare(Null, Null) should be Equal.
assert_eq!(comp.compare(&null, &null), Ordering::Equal);
}
}

View File

@@ -1,9 +1,10 @@
use columnar::{ColumnType, MonotonicallyMappableToU64};
use columnar::{ColumnType, MonotonicallyMappableToU64, ValueRange};
use crate::collector::sort_key::sort_by_score::SortBySimilarityScoreSegmentComputer;
use crate::collector::sort_key::{
NaturalComparator, SortBySimilarityScore, SortByStaticFastValue, SortByString,
};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
use crate::fastfield::FastFieldNotAvailableError;
use crate::schema::OwnedValue;
use crate::{DateTime, DocId, Score};
@@ -36,12 +37,23 @@ impl SortByErasedType {
trait ErasedSegmentSortKeyComputer: Send + Sync {
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64>;
fn segment_sort_keys(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Option<u64>, DocId>>,
filter: ValueRange<Option<u64>>,
);
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue;
}
struct ErasedSegmentSortKeyComputerWrapper<C, F> {
struct ErasedSegmentSortKeyComputerWrapper<C, F>
where
C: SegmentSortKeyComputer<SegmentSortKey = Option<u64>> + Send + Sync,
F: Fn(C::SortKey) -> OwnedValue + Send + Sync + 'static,
{
inner: C,
converter: F,
buffer: C::Buffer,
}
impl<C, F> ErasedSegmentSortKeyComputer for ErasedSegmentSortKeyComputerWrapper<C, F>
@@ -53,6 +65,16 @@ where
self.inner.segment_sort_key(doc, score)
}
fn segment_sort_keys(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Option<u64>, DocId>>,
filter: ValueRange<Option<u64>>,
) {
self.inner
.segment_sort_keys(input_docs, output, &mut self.buffer, filter)
}
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue {
let val = self.inner.convert_segment_sort_key(sort_key);
(self.converter)(val)
@@ -60,7 +82,7 @@ where
}
struct ScoreSegmentSortKeyComputer {
segment_computer: SortBySimilarityScore,
segment_computer: SortBySimilarityScoreSegmentComputer,
}
impl ErasedSegmentSortKeyComputer for ScoreSegmentSortKeyComputer {
@@ -69,6 +91,15 @@ impl ErasedSegmentSortKeyComputer for ScoreSegmentSortKeyComputer {
Some(score_value.to_u64())
}
fn segment_sort_keys(
&mut self,
_input_docs: &[DocId],
_output: &mut Vec<ComparableDoc<Option<u64>, DocId>>,
_filter: ValueRange<Option<u64>>,
) {
unimplemented!("Batch computation not supported for score sorting")
}
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue {
let score_value: u64 = sort_key.expect("This implementation always produces a score.");
OwnedValue::F64(f64::from_u64(score_value))
@@ -112,6 +143,7 @@ impl SortKeyComputer for SortByErasedType {
converter: |val: Option<String>| {
val.map(OwnedValue::Str).unwrap_or(OwnedValue::Null)
},
buffer: Default::default(),
})
}
ColumnType::U64 => {
@@ -122,6 +154,7 @@ impl SortKeyComputer for SortByErasedType {
converter: |val: Option<u64>| {
val.map(OwnedValue::U64).unwrap_or(OwnedValue::Null)
},
buffer: Default::default(),
})
}
ColumnType::I64 => {
@@ -132,6 +165,7 @@ impl SortKeyComputer for SortByErasedType {
converter: |val: Option<i64>| {
val.map(OwnedValue::I64).unwrap_or(OwnedValue::Null)
},
buffer: Default::default(),
})
}
ColumnType::F64 => {
@@ -142,6 +176,7 @@ impl SortKeyComputer for SortByErasedType {
converter: |val: Option<f64>| {
val.map(OwnedValue::F64).unwrap_or(OwnedValue::Null)
},
buffer: Default::default(),
})
}
ColumnType::Bool => {
@@ -152,6 +187,7 @@ impl SortKeyComputer for SortByErasedType {
converter: |val: Option<bool>| {
val.map(OwnedValue::Bool).unwrap_or(OwnedValue::Null)
},
buffer: Default::default(),
})
}
ColumnType::DateTime => {
@@ -162,6 +198,7 @@ impl SortKeyComputer for SortByErasedType {
converter: |val: Option<DateTime>| {
val.map(OwnedValue::Date).unwrap_or(OwnedValue::Null)
},
buffer: Default::default(),
})
}
column_type => {
@@ -174,7 +211,8 @@ impl SortKeyComputer for SortByErasedType {
}
}
Self::Score => Box::new(ScoreSegmentSortKeyComputer {
segment_computer: SortBySimilarityScore,
segment_computer: SortBySimilarityScore
.segment_sort_key_computer(segment_reader)?,
}),
};
Ok(ErasedColumnSegmentSortKeyComputer { inner })
@@ -189,12 +227,23 @@ impl SegmentSortKeyComputer for ErasedColumnSegmentSortKeyComputer {
type SortKey = OwnedValue;
type SegmentSortKey = Option<u64>;
type SegmentComparator = NaturalComparator;
type Buffer = ();
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64> {
self.inner.segment_sort_key(doc, score)
}
fn segment_sort_keys(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
_buffer: &mut Self::Buffer,
filter: ValueRange<Self::SegmentSortKey>,
) {
self.inner.segment_sort_keys(input_docs, output, filter)
}
fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> OwnedValue {
self.inner.convert_segment_sort_key(segment_sort_key)
}
@@ -333,7 +382,7 @@ mod tests {
.into_iter()
.map(|(key, _)| match key {
OwnedValue::F64(val) => val,
_ => panic!("Wrong type {key:?}"),
_ => panic!("Wrong type {:?}", key),
})
.collect();
@@ -351,7 +400,7 @@ mod tests {
.into_iter()
.map(|(key, _)| match key {
OwnedValue::F64(val) => val,
_ => panic!("Wrong type {key:?}"),
_ => panic!("Wrong type {:?}", key),
})
.collect();

View File

@@ -1,5 +1,7 @@
use columnar::ValueRange;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer, TopNComputer};
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer, TopNComputer};
use crate::{DocAddress, DocId, Score};
/// Sort by similarity score.
@@ -9,7 +11,7 @@ pub struct SortBySimilarityScore;
impl SortKeyComputer for SortBySimilarityScore {
type SortKey = Score;
type Child = SortBySimilarityScore;
type Child = SortBySimilarityScoreSegmentComputer;
type Comparator = NaturalComparator;
@@ -21,7 +23,7 @@ impl SortKeyComputer for SortBySimilarityScore {
&self,
_segment_reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
Ok(SortBySimilarityScore)
Ok(SortBySimilarityScoreSegmentComputer)
}
// Sorting by score is special in that it allows for the Block-Wand optimization.
@@ -61,16 +63,29 @@ impl SortKeyComputer for SortBySimilarityScore {
}
}
impl SegmentSortKeyComputer for SortBySimilarityScore {
pub struct SortBySimilarityScoreSegmentComputer;
impl SegmentSortKeyComputer for SortBySimilarityScoreSegmentComputer {
type SortKey = Score;
type SegmentSortKey = Score;
type SegmentComparator = NaturalComparator;
type Buffer = ();
#[inline(always)]
fn segment_sort_key(&mut self, _doc: DocId, score: Score) -> Score {
score
}
fn segment_sort_keys(
&mut self,
_input_docs: &[DocId],
_output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
_buffer: &mut Self::Buffer,
_filter: ValueRange<Self::SegmentSortKey>,
) {
unimplemented!("Batch computation not supported for score sorting")
}
fn convert_segment_sort_key(&self, score: Score) -> Score {
score
}

View File

@@ -1,9 +1,10 @@
use std::marker::PhantomData;
use columnar::Column;
use columnar::{Column, ValueRange};
use crate::collector::sort_key::sort_key_computer::convert_optional_u64_range_to_u64_range;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
use crate::fastfield::{FastFieldNotAvailableError, FastValue};
use crate::{DocId, Score, SegmentReader};
@@ -84,13 +85,110 @@ impl<T: FastValue> SegmentSortKeyComputer for SortByFastValueSegmentSortKeyCompu
type SortKey = Option<T>;
type SegmentSortKey = Option<u64>;
type SegmentComparator = NaturalComparator;
type Buffer = ();
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Self::SegmentSortKey {
self.sort_column.first(doc)
}
fn segment_sort_keys(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
_buffer: &mut Self::Buffer,
filter: ValueRange<Self::SegmentSortKey>,
) {
let u64_filter = convert_optional_u64_range_to_u64_range(filter);
self.sort_column
.first_vals_in_value_range(input_docs, output, u64_filter);
}
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
sort_key.map(T::from_u64)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::schema::{Schema, FAST};
use crate::Index;
#[test]
fn test_sort_by_fast_value_batch() {
let mut schema_builder = Schema::builder();
let field_col = schema_builder.add_u64_field("field", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
index_writer
.add_document(crate::doc!(field_col => 10u64))
.unwrap();
index_writer
.add_document(crate::doc!(field_col => 20u64))
.unwrap();
index_writer.add_document(crate::doc!()).unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let segment_reader = searcher.segment_reader(0);
let sorter = SortByStaticFastValue::<u64>::for_field("field");
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
let mut docs = vec![0, 1, 2];
let mut output = Vec::new();
let mut buffer = ();
computer.segment_sort_keys(&mut docs, &mut output, &mut buffer, ValueRange::All);
assert_eq!(
output.iter().map(|c| c.sort_key).collect::<Vec<_>>(),
&[Some(10), Some(20), None]
);
assert_eq!(output.iter().map(|c| c.doc).collect::<Vec<_>>(), &[0, 1, 2]);
}
#[test]
fn test_sort_by_fast_value_batch_with_filter() {
let mut schema_builder = Schema::builder();
let field_col = schema_builder.add_u64_field("field", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
index_writer
.add_document(crate::doc!(field_col => 10u64))
.unwrap();
index_writer
.add_document(crate::doc!(field_col => 20u64))
.unwrap();
index_writer.add_document(crate::doc!()).unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let segment_reader = searcher.segment_reader(0);
let sorter = SortByStaticFastValue::<u64>::for_field("field");
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
let mut docs = vec![0, 1, 2];
let mut output = Vec::new();
let mut buffer = ();
computer.segment_sort_keys(
&mut docs,
&mut output,
&mut buffer,
ValueRange::GreaterThan(Some(15u64), false /* inclusive */),
);
assert_eq!(
output.iter().map(|c| c.sort_key).collect::<Vec<_>>(),
&[Some(20)]
);
assert_eq!(output.iter().map(|c| c.doc).collect::<Vec<_>>(), &[1]);
}
}

View File

@@ -1,7 +1,10 @@
use columnar::StrColumn;
use columnar::{StrColumn, ValueRange};
use crate::collector::sort_key::sort_key_computer::{
convert_optional_u64_range_to_u64_range, range_contains_none,
};
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
use crate::termdict::TermOrdinal;
use crate::{DocId, Score};
@@ -50,6 +53,7 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
type SortKey = Option<String>;
type SegmentSortKey = Option<TermOrdinal>;
type SegmentComparator = NaturalComparator;
type Buffer = ();
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Option<TermOrdinal> {
@@ -57,6 +61,28 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
str_column.ords().first(doc)
}
fn segment_sort_keys(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
_buffer: &mut Self::Buffer,
filter: ValueRange<Self::SegmentSortKey>,
) {
if let Some(str_column) = &self.str_column_opt {
let u64_filter = convert_optional_u64_range_to_u64_range(filter);
str_column
.ords()
.first_vals_in_value_range(input_docs, output, u64_filter);
} else if range_contains_none(&filter) {
for &doc in input_docs {
output.push(ComparableDoc {
doc,
sort_key: None,
});
}
}
}
fn convert_segment_sort_key(&self, term_ord_opt: Option<TermOrdinal>) -> Option<String> {
// TODO: Individual lookups to the dictionary like this are very likely to repeatedly
// decompress the same blocks. See https://github.com/quickwit-oss/tantivy/issues/2776
@@ -70,3 +96,90 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
String::try_from(bytes).ok()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::schema::{Schema, FAST, TEXT};
use crate::Index;
#[test]
fn test_sort_by_string_batch() {
let mut schema_builder = Schema::builder();
let field_col = schema_builder.add_text_field("field", FAST | TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
index_writer
.add_document(crate::doc!(field_col => "a"))
.unwrap();
index_writer
.add_document(crate::doc!(field_col => "c"))
.unwrap();
index_writer.add_document(crate::doc!()).unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let segment_reader = searcher.segment_reader(0);
let sorter = SortByString::for_field("field");
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
let mut docs = vec![0, 1, 2];
let mut output = Vec::new();
let mut buffer = ();
computer.segment_sort_keys(&mut docs, &mut output, &mut buffer, ValueRange::All);
assert_eq!(
output.iter().map(|c| c.sort_key).collect::<Vec<_>>(),
&[Some(0), Some(1), None]
);
assert_eq!(output.iter().map(|c| c.doc).collect::<Vec<_>>(), &[0, 1, 2]);
}
#[test]
fn test_sort_by_string_batch_with_filter() {
let mut schema_builder = Schema::builder();
let field_col = schema_builder.add_text_field("field", FAST | TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
index_writer
.add_document(crate::doc!(field_col => "a"))
.unwrap();
index_writer
.add_document(crate::doc!(field_col => "c"))
.unwrap();
index_writer.add_document(crate::doc!()).unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let segment_reader = searcher.segment_reader(0);
let sorter = SortByString::for_field("field");
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
let mut docs = vec![0, 1, 2];
let mut output = Vec::new();
// Filter: > "b". "a" is 0, "c" is 1.
// We want > "a" (ord 0). So we filter > ord 0.
// 0 is "a", 1 is "c".
let mut buffer = ();
computer.segment_sort_keys(
&mut docs,
&mut output,
&mut buffer,
ValueRange::GreaterThan(Some(0), false /* inclusive */),
);
assert_eq!(
output.iter().map(|c| c.sort_key).collect::<Vec<_>>(),
&[Some(1)]
);
assert_eq!(output.iter().map(|c| c.doc).collect::<Vec<_>>(), &[1]);
}
}

View File

@@ -1,8 +1,12 @@
use std::cmp::Ordering;
use columnar::ValueRange;
use crate::collector::sort_key::{Comparator, NaturalComparator};
use crate::collector::sort_key_top_collector::TopBySortKeySegmentCollector;
use crate::collector::{default_collect_segment_impl, SegmentCollector as _, TopNComputer};
use crate::collector::{
default_collect_segment_impl, ComparableDoc, SegmentCollector as _, TopNComputer,
};
use crate::schema::Schema;
use crate::{DocAddress, DocId, Result, Score, SegmentReader};
@@ -21,7 +25,10 @@ pub trait SegmentSortKeyComputer: 'static {
type SegmentSortKey: 'static + Clone + Send + Sync + Clone;
/// Comparator type.
type SegmentComparator: Comparator<Self::SegmentSortKey> + 'static;
type SegmentComparator: Comparator<Self::SegmentSortKey> + Clone + 'static;
/// Buffer type used for scratch space.
type Buffer: Default + Send + Sync + 'static;
/// Returns the segment sort key comparator.
fn segment_comparator(&self) -> Self::SegmentComparator {
@@ -31,6 +38,18 @@ pub trait SegmentSortKeyComputer: 'static {
/// Computes the sort key for the given document and score.
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey;
/// Computes the sort keys for a batch of documents.
///
/// The computed sort keys and document IDs are pushed into the `output` vector.
/// The `buffer` is used for scratch space.
fn segment_sort_keys(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
buffer: &mut Self::Buffer,
filter: ValueRange<Self::SegmentSortKey>,
);
/// Computes the sort key and pushes the document in a TopN Computer.
///
/// When using a tuple as the sorting key, the sort key is evaluated in a lazy manner.
@@ -39,12 +58,32 @@ pub trait SegmentSortKeyComputer: 'static {
&mut self,
doc: DocId,
score: Score,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C, Self::Buffer>,
) {
let sort_key = self.segment_sort_key(doc, score);
top_n_computer.push(sort_key, doc);
}
fn compute_sort_keys_and_collect<C: Comparator<Self::SegmentSortKey>>(
&mut self,
docs: &[DocId],
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C, Self::Buffer>,
) {
// The capacity of a TopNComputer is larger than 2*n + COLLECT_BLOCK_BUFFER_LEN, so we
// should always be able to `reserve` space for the entire block.
top_n_computer.reserve(docs.len());
let comparator = self.segment_comparator();
let value_range = if let Some(threshold) = &top_n_computer.threshold {
comparator.threshold_to_valuerange(threshold.clone())
} else {
ValueRange::All
};
let (buffer, scratch) = top_n_computer.buffer_and_scratch();
self.segment_sort_keys(docs, buffer, scratch, value_range);
}
/// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on
/// its ordering.
///
@@ -58,26 +97,6 @@ pub trait SegmentSortKeyComputer: 'static {
self.segment_comparator().compare(left, right)
}
/// Implementing this method makes it possible to avoid computing
/// a sort_key entirely if we can assess that it won't pass a threshold
/// with a partial computation.
///
/// This is currently used for lexicographic sorting.
fn accept_sort_key_lazy(
&mut self,
doc_id: DocId,
score: Score,
threshold: &Self::SegmentSortKey,
) -> Option<(Ordering, Self::SegmentSortKey)> {
let sort_key = self.segment_sort_key(doc_id, score);
let cmp = self.compare_segment_sort_key(&sort_key, threshold);
if cmp == Ordering::Less {
None
} else {
Some((cmp, sort_key))
}
}
/// Convert a segment level sort key into the global sort key.
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey;
}
@@ -145,7 +164,8 @@ where
TailSortKeyComputer: SortKeyComputer,
{
type SortKey = (HeadSortKeyComputer::SortKey, TailSortKeyComputer::SortKey);
type Child = (HeadSortKeyComputer::Child, TailSortKeyComputer::Child);
type Child =
ChainSegmentSortKeyComputer<HeadSortKeyComputer::Child, TailSortKeyComputer::Child>;
type Comparator = (
HeadSortKeyComputer::Comparator,
@@ -157,10 +177,10 @@ where
}
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
Ok((
self.0.segment_sort_key_computer(segment_reader)?,
self.1.segment_sort_key_computer(segment_reader)?,
))
Ok(ChainSegmentSortKeyComputer {
head: self.0.segment_sort_key_computer(segment_reader)?,
tail: self.1.segment_sort_key_computer(segment_reader)?,
})
}
/// Checks whether the schema is compatible with the sort key computer.
@@ -178,25 +198,91 @@ where
}
}
impl<HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer> SegmentSortKeyComputer
for (HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer)
pub struct ChainSegmentSortKeyComputer<Head, Tail>
where
HeadSegmentSortKeyComputer: SegmentSortKeyComputer,
TailSegmentSortKeyComputer: SegmentSortKeyComputer,
Head: SegmentSortKeyComputer,
Tail: SegmentSortKeyComputer,
{
type SortKey = (
HeadSegmentSortKeyComputer::SortKey,
TailSegmentSortKeyComputer::SortKey,
);
type SegmentSortKey = (
HeadSegmentSortKeyComputer::SegmentSortKey,
TailSegmentSortKeyComputer::SegmentSortKey,
);
head: Head,
tail: Tail,
}
type SegmentComparator = (
HeadSegmentSortKeyComputer::SegmentComparator,
TailSegmentSortKeyComputer::SegmentComparator,
);
pub struct ChainBuffer<HeadBuffer, TailBuffer, HeadKey, TailKey> {
pub head: HeadBuffer,
pub tail: TailBuffer,
pub head_output: Vec<ComparableDoc<HeadKey, DocId>>,
pub tail_output: Vec<ComparableDoc<TailKey, DocId>>,
pub tail_input_docs: Vec<DocId>,
}
impl<HeadBuffer: Default, TailBuffer: Default, HeadKey, TailKey> Default
for ChainBuffer<HeadBuffer, TailBuffer, HeadKey, TailKey>
{
fn default() -> Self {
ChainBuffer {
head: HeadBuffer::default(),
tail: TailBuffer::default(),
head_output: Vec::new(),
tail_output: Vec::new(),
tail_input_docs: Vec::new(),
}
}
}
impl<Head, Tail> ChainSegmentSortKeyComputer<Head, Tail>
where
Head: SegmentSortKeyComputer,
Tail: SegmentSortKeyComputer,
{
fn accept_sort_key_lazy(
&mut self,
doc_id: DocId,
score: Score,
threshold: &<Self as SegmentSortKeyComputer>::SegmentSortKey,
) -> Option<(Ordering, <Self as SegmentSortKeyComputer>::SegmentSortKey)> {
let (head_threshold, tail_threshold) = threshold;
let head_sort_key = self.head.segment_sort_key(doc_id, score);
let head_cmp = self
.head
.compare_segment_sort_key(&head_sort_key, head_threshold);
if head_cmp == Ordering::Less {
None
} else if head_cmp == Ordering::Equal {
let tail_sort_key = self.tail.segment_sort_key(doc_id, score);
let tail_cmp = self
.tail
.compare_segment_sort_key(&tail_sort_key, tail_threshold);
if tail_cmp == Ordering::Less {
None
} else {
Some((tail_cmp, (head_sort_key, tail_sort_key)))
}
} else {
let tail_sort_key = self.tail.segment_sort_key(doc_id, score);
Some((head_cmp, (head_sort_key, tail_sort_key)))
}
}
}
impl<Head, Tail> SegmentSortKeyComputer for ChainSegmentSortKeyComputer<Head, Tail>
where
Head: SegmentSortKeyComputer,
Tail: SegmentSortKeyComputer,
{
type SortKey = (Head::SortKey, Tail::SortKey);
type SegmentSortKey = (Head::SegmentSortKey, Tail::SegmentSortKey);
type SegmentComparator = (Head::SegmentComparator, Tail::SegmentComparator);
type Buffer =
ChainBuffer<Head::Buffer, Tail::Buffer, Head::SegmentSortKey, Tail::SegmentSortKey>;
fn segment_comparator(&self) -> Self::SegmentComparator {
(
self.head.segment_comparator(),
self.tail.segment_comparator(),
)
}
/// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on
/// its ordering.
@@ -208,9 +294,90 @@ where
left: &Self::SegmentSortKey,
right: &Self::SegmentSortKey,
) -> Ordering {
self.0
self.head
.compare_segment_sort_key(&left.0, &right.0)
.then_with(|| self.1.compare_segment_sort_key(&left.1, &right.1))
.then_with(|| self.tail.compare_segment_sort_key(&left.1, &right.1))
}
fn segment_sort_keys(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
buffer: &mut Self::Buffer,
filter: ValueRange<Self::SegmentSortKey>,
) {
let (head_filter, threshold) = match filter {
ValueRange::GreaterThan((head_threshold, tail_threshold), _)
| ValueRange::LessThan((head_threshold, tail_threshold), _) => {
let head_cmp = self.head.segment_comparator();
let strict_head_filter = head_cmp.threshold_to_valuerange(head_threshold.clone());
let head_filter = match strict_head_filter {
ValueRange::GreaterThan(t, m) => ValueRange::GreaterThanOrEqual(t, m),
ValueRange::LessThan(t, m) => ValueRange::LessThanOrEqual(t, m),
other => other,
};
(head_filter, Some((head_threshold, tail_threshold)))
}
_ => (ValueRange::All, None),
};
buffer.head_output.clear();
self.head.segment_sort_keys(
input_docs,
&mut buffer.head_output,
&mut buffer.head,
head_filter,
);
if buffer.head_output.is_empty() {
return;
}
buffer.tail_output.clear();
buffer.tail_input_docs.clear();
for cd in &buffer.head_output {
buffer.tail_input_docs.push(cd.doc);
}
self.tail.segment_sort_keys(
&buffer.tail_input_docs,
&mut buffer.tail_output,
&mut buffer.tail,
ValueRange::All,
);
let head_cmp = self.head.segment_comparator();
let tail_cmp = self.tail.segment_comparator();
for (head_doc, tail_doc) in buffer
.head_output
.drain(..)
.zip(buffer.tail_output.drain(..))
{
debug_assert_eq!(head_doc.doc, tail_doc.doc);
let doc = head_doc.doc;
let head_key = head_doc.sort_key;
let tail_key = tail_doc.sort_key;
let accept = if let Some((head_threshold, tail_threshold)) = &threshold {
let head_ord = head_cmp.compare(&head_key, head_threshold);
let ord = if head_ord == Ordering::Equal {
tail_cmp.compare(&tail_key, tail_threshold)
} else {
head_ord
};
ord == Ordering::Greater
} else {
true
};
if accept {
output.push(ComparableDoc {
sort_key: (head_key, tail_key),
doc,
});
}
}
}
#[inline(always)]
@@ -218,7 +385,7 @@ where
&mut self,
doc: DocId,
score: Score,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C, Self::Buffer>,
) {
let sort_key: Self::SegmentSortKey;
if let Some(threshold) = &top_n_computer.threshold {
@@ -235,48 +402,29 @@ where
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
let head_sort_key = self.0.segment_sort_key(doc, score);
let tail_sort_key = self.1.segment_sort_key(doc, score);
let head_sort_key = self.head.segment_sort_key(doc, score);
let tail_sort_key = self.tail.segment_sort_key(doc, score);
(head_sort_key, tail_sort_key)
}
fn accept_sort_key_lazy(
&mut self,
doc_id: DocId,
score: Score,
threshold: &Self::SegmentSortKey,
) -> Option<(Ordering, Self::SegmentSortKey)> {
let (head_threshold, tail_threshold) = threshold;
let (head_cmp, head_sort_key) =
self.0.accept_sort_key_lazy(doc_id, score, head_threshold)?;
if head_cmp == Ordering::Equal {
let (tail_cmp, tail_sort_key) =
self.1.accept_sort_key_lazy(doc_id, score, tail_threshold)?;
Some((tail_cmp, (head_sort_key, tail_sort_key)))
} else {
let tail_sort_key = self.1.segment_sort_key(doc_id, score);
Some((head_cmp, (head_sort_key, tail_sort_key)))
}
}
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
let (head_sort_key, tail_sort_key) = sort_key;
(
self.0.convert_segment_sort_key(head_sort_key),
self.1.convert_segment_sort_key(tail_sort_key),
self.head.convert_segment_sort_key(head_sort_key),
self.tail.convert_segment_sort_key(tail_sort_key),
)
}
}
/// This struct is used as an adapter to take a sort key computer and map its score to another
/// new sort key.
pub struct MappedSegmentSortKeyComputer<T, PreviousSortKey, NewSortKey> {
pub struct MappedSegmentSortKeyComputer<T: SegmentSortKeyComputer, NewSortKey> {
sort_key_computer: T,
map: fn(PreviousSortKey) -> NewSortKey,
map: fn(T::SortKey) -> NewSortKey,
}
impl<T, PreviousScore, NewScore> SegmentSortKeyComputer
for MappedSegmentSortKeyComputer<T, PreviousScore, NewScore>
for MappedSegmentSortKeyComputer<T, NewScore>
where
T: SegmentSortKeyComputer<SortKey = PreviousScore>,
PreviousScore: 'static + Clone + Send + Sync,
@@ -285,19 +433,25 @@ where
type SortKey = NewScore;
type SegmentSortKey = T::SegmentSortKey;
type SegmentComparator = T::SegmentComparator;
type Buffer = T::Buffer;
fn segment_comparator(&self) -> Self::SegmentComparator {
self.sort_key_computer.segment_comparator()
}
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
self.sort_key_computer.segment_sort_key(doc, score)
}
fn accept_sort_key_lazy(
fn segment_sort_keys(
&mut self,
doc_id: DocId,
score: Score,
threshold: &Self::SegmentSortKey,
) -> Option<(Ordering, Self::SegmentSortKey)> {
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
buffer: &mut Self::Buffer,
filter: ValueRange<Self::SegmentSortKey>,
) {
self.sort_key_computer
.accept_sort_key_lazy(doc_id, score, threshold)
.segment_sort_keys(input_docs, output, buffer, filter)
}
#[inline(always)]
@@ -305,12 +459,21 @@ where
&mut self,
doc: DocId,
score: Score,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C, Self::Buffer>,
) {
self.sort_key_computer
.compute_sort_key_and_collect(doc, score, top_n_computer);
}
fn compute_sort_keys_and_collect<C: Comparator<Self::SegmentSortKey>>(
&mut self,
docs: &[DocId],
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C, Self::Buffer>,
) {
self.sort_key_computer
.compute_sort_keys_and_collect(docs, top_n_computer);
}
fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> Self::SortKey {
(self.map)(
self.sort_key_computer
@@ -336,10 +499,6 @@ where
);
type Child = MappedSegmentSortKeyComputer<
<(SortKeyComputer1, (SortKeyComputer2, SortKeyComputer3)) as SortKeyComputer>::Child,
(
SortKeyComputer1::SortKey,
(SortKeyComputer2::SortKey, SortKeyComputer3::SortKey),
),
Self::SortKey,
>;
@@ -363,7 +522,13 @@ where
let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?;
let map = |(sort_key1, (sort_key2, sort_key3))| (sort_key1, sort_key2, sort_key3);
Ok(MappedSegmentSortKeyComputer {
sort_key_computer: (sort_key_computer1, (sort_key_computer2, sort_key_computer3)),
sort_key_computer: ChainSegmentSortKeyComputer {
head: sort_key_computer1,
tail: ChainSegmentSortKeyComputer {
head: sort_key_computer2,
tail: sort_key_computer3,
},
},
map,
})
}
@@ -398,13 +563,6 @@ where
SortKeyComputer1,
(SortKeyComputer2, (SortKeyComputer3, SortKeyComputer4)),
) as SortKeyComputer>::Child,
(
SortKeyComputer1::SortKey,
(
SortKeyComputer2::SortKey,
(SortKeyComputer3::SortKey, SortKeyComputer4::SortKey),
),
),
Self::SortKey,
>;
type SortKey = (
@@ -426,10 +584,16 @@ where
let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?;
let sort_key_computer4 = self.3.segment_sort_key_computer(segment_reader)?;
Ok(MappedSegmentSortKeyComputer {
sort_key_computer: (
sort_key_computer1,
(sort_key_computer2, (sort_key_computer3, sort_key_computer4)),
),
sort_key_computer: ChainSegmentSortKeyComputer {
head: sort_key_computer1,
tail: ChainSegmentSortKeyComputer {
head: sort_key_computer2,
tail: ChainSegmentSortKeyComputer {
head: sort_key_computer3,
tail: sort_key_computer4,
},
},
},
map: |(sort_key1, (sort_key2, (sort_key3, sort_key4)))| {
(sort_key1, sort_key2, sort_key3, sort_key4)
},
@@ -452,6 +616,13 @@ where
}
}
use std::marker::PhantomData;
pub struct FuncSegmentSortKeyComputer<F, TSortKey> {
func: F,
_phantom: PhantomData<TSortKey>,
}
impl<F, SegmentF, TSortKey> SortKeyComputer for F
where
F: 'static + Send + Sync + Fn(&SegmentReader) -> SegmentF,
@@ -459,15 +630,18 @@ where
TSortKey: 'static + PartialOrd + Clone + Send + Sync + std::fmt::Debug,
{
type SortKey = TSortKey;
type Child = SegmentF;
type Child = FuncSegmentSortKeyComputer<SegmentF, TSortKey>;
type Comparator = NaturalComparator;
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
Ok((self)(segment_reader))
Ok(FuncSegmentSortKeyComputer {
func: (self)(segment_reader),
_phantom: PhantomData,
})
}
}
impl<F, TSortKey> SegmentSortKeyComputer for F
impl<F, TSortKey> SegmentSortKeyComputer for FuncSegmentSortKeyComputer<F, TSortKey>
where
F: 'static + FnMut(DocId) -> TSortKey,
TSortKey: 'static + PartialOrd + Clone + Send + Sync,
@@ -475,9 +649,25 @@ where
type SortKey = TSortKey;
type SegmentSortKey = TSortKey;
type SegmentComparator = NaturalComparator;
type Buffer = ();
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> TSortKey {
(self)(doc)
(self.func)(doc)
}
fn segment_sort_keys(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
_buffer: &mut Self::Buffer,
_filter: ValueRange<Self::SegmentSortKey>,
) {
for &doc in input_docs {
output.push(ComparableDoc {
sort_key: (self.func)(doc),
doc,
});
}
}
/// Convert a segment level score into the global level score.
@@ -486,13 +676,75 @@ where
}
}
pub(crate) fn range_contains_none(range: &ValueRange<Option<u64>>) -> bool {
match range {
ValueRange::All => true,
ValueRange::Inclusive(r) => r.contains(&None),
ValueRange::GreaterThan(_threshold, match_nulls) => *match_nulls,
ValueRange::GreaterThanOrEqual(_threshold, match_nulls) => *match_nulls,
ValueRange::LessThan(_threshold, match_nulls) => *match_nulls,
ValueRange::LessThanOrEqual(_threshold, match_nulls) => *match_nulls,
}
}
pub(crate) fn convert_optional_u64_range_to_u64_range(
range: ValueRange<Option<u64>>,
) -> ValueRange<u64> {
match range {
ValueRange::Inclusive(r) => {
let start = r.start().unwrap_or(0);
let end = r.end().unwrap_or(u64::MAX);
ValueRange::Inclusive(start..=end)
}
ValueRange::GreaterThan(Some(val), match_nulls) => {
ValueRange::GreaterThan(val, match_nulls)
}
ValueRange::GreaterThan(None, match_nulls) => {
if match_nulls {
ValueRange::All
} else {
ValueRange::Inclusive(u64::MIN..=u64::MAX)
}
}
ValueRange::GreaterThanOrEqual(Some(val), match_nulls) => {
ValueRange::GreaterThanOrEqual(val, match_nulls)
}
ValueRange::GreaterThanOrEqual(None, match_nulls) => {
if match_nulls {
ValueRange::All
} else {
ValueRange::Inclusive(u64::MIN..=u64::MAX)
}
}
ValueRange::LessThan(None, match_nulls) => {
if match_nulls {
ValueRange::LessThan(u64::MIN, true)
} else {
ValueRange::Inclusive(1..=0)
}
}
ValueRange::LessThan(Some(val), match_nulls) => ValueRange::LessThan(val, match_nulls),
ValueRange::LessThanOrEqual(None, match_nulls) => {
if match_nulls {
ValueRange::LessThan(u64::MIN, true)
} else {
ValueRange::Inclusive(1..=0)
}
}
ValueRange::LessThanOrEqual(Some(val), match_nulls) => {
ValueRange::LessThanOrEqual(val, match_nulls)
}
ValueRange::All => ValueRange::All,
}
}
#[cfg(test)]
mod tests {
use std::cmp::Ordering;
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
use std::sync::Arc;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer, TopNComputer};
use crate::schema::Schema;
use crate::{DocId, Index, Order, SegmentReader};
@@ -640,4 +892,178 @@ mod tests {
(200u32, 2u32)
);
}
#[test]
fn test_batch_score_computer_edge_case() {
let score_computer_primary = |_segment_reader: &SegmentReader| |_doc: DocId| 200u32;
let score_computer_secondary = |_segment_reader: &SegmentReader| |_doc: DocId| "b";
let lazy_score_computer = (score_computer_primary, score_computer_secondary);
let index = build_test_index();
let searcher = index.reader().unwrap().searcher();
let mut segment_sort_key_computer = lazy_score_computer
.segment_sort_key_computer(searcher.segment_reader(0))
.unwrap();
let mut top_n_computer =
TopNComputer::new_with_comparator(10, lazy_score_computer.comparator());
// Threshold (200, "a"). Doc is (200, "b"). 200 == 200, "b" > "a". Should be accepted.
top_n_computer.threshold = Some((200, "a"));
let docs = vec![0];
segment_sort_key_computer.compute_sort_keys_and_collect(&docs, &mut top_n_computer);
let results = top_n_computer.into_sorted_vec();
assert_eq!(results.len(), 1);
let result = &results[0];
assert_eq!(result.doc, 0);
assert_eq!(result.sort_key, (200, "b"));
}
}
#[cfg(test)]
mod proptest_tests {
use proptest::prelude::*;
use super::*;
use crate::collector::sort_key::order::*;
// Re-implement logic to interpret ValueRange<Option<u64>> manually to verify expectations
fn range_contains_opt(range: &ValueRange<Option<u64>>, val: &Option<u64>) -> bool {
match range {
ValueRange::All => true,
ValueRange::Inclusive(r) => r.contains(val),
ValueRange::GreaterThan(t, match_nulls) => {
if val.is_none() {
*match_nulls
} else {
val > t
}
}
ValueRange::GreaterThanOrEqual(t, match_nulls) => {
if val.is_none() {
*match_nulls
} else {
val >= t
}
}
ValueRange::LessThan(t, match_nulls) => {
if val.is_none() {
*match_nulls
} else {
val < t
}
}
ValueRange::LessThanOrEqual(t, match_nulls) => {
if val.is_none() {
*match_nulls
} else {
val <= t
}
}
}
}
fn range_contains_u64(range: &ValueRange<u64>, val: &u64) -> bool {
match range {
ValueRange::All => true,
ValueRange::Inclusive(r) => r.contains(val),
ValueRange::GreaterThan(t, _) => val > t,
ValueRange::GreaterThanOrEqual(t, _) => val >= t,
ValueRange::LessThan(t, _) => val < t,
ValueRange::LessThanOrEqual(t, _) => val <= t,
}
}
proptest! {
#[test]
fn test_comparator_consistency_natural_none_is_lower(
threshold in any::<Option<u64>>(),
val in any::<Option<u64>>()
) {
check_comparator::<NaturalComparator>(threshold, val)?;
}
#[test]
fn test_comparator_consistency_reverse(
threshold in any::<Option<u64>>(),
val in any::<Option<u64>>()
) {
check_comparator::<ReverseComparator>(threshold, val)?;
}
#[test]
fn test_comparator_consistency_reverse_none_is_lower(
threshold in any::<Option<u64>>(),
val in any::<Option<u64>>()
) {
check_comparator::<ReverseNoneIsLowerComparator>(threshold, val)?;
}
#[test]
fn test_comparator_consistency_natural_none_is_higher(
threshold in any::<Option<u64>>(),
val in any::<Option<u64>>()
) {
check_comparator::<NaturalNoneIsHigherComparator>(threshold, val)?;
}
}
fn check_comparator<C: Comparator<Option<u64>>>(
threshold: Option<u64>,
val: Option<u64>,
) -> std::result::Result<(), proptest::test_runner::TestCaseError> {
let comparator = C::default();
let range = comparator.threshold_to_valuerange(threshold);
let ordering = comparator.compare(&val, &threshold);
let should_be_in_range = ordering == Ordering::Greater;
let in_range_opt = range_contains_opt(&range, &val);
prop_assert_eq!(
in_range_opt,
should_be_in_range,
"Comparator consistency failed for {:?}. Threshold: {:?}, Val: {:?}, Range: {:?}, \
Ordering: {:?}. range_contains_opt says {}, but compare says {}",
std::any::type_name::<C>(),
threshold,
val,
range,
ordering,
in_range_opt,
should_be_in_range
);
// Check range_contains_none
let expected_none_in_range = range_contains_opt(&range, &None);
let actual_none_in_range = range_contains_none(&range);
prop_assert_eq!(
actual_none_in_range,
expected_none_in_range,
"range_contains_none failed for {:?}. Range: {:?}. Expected (from \
range_contains_opt): {}, Actual: {}",
std::any::type_name::<C>(),
range,
expected_none_in_range,
actual_none_in_range
);
// Check convert_optional_u64_range_to_u64_range
let u64_range = convert_optional_u64_range_to_u64_range(range.clone());
if let Some(v) = val {
let in_u64_range = range_contains_u64(&u64_range, &v);
let in_opt_range = range_contains_opt(&range, &Some(v));
prop_assert_eq!(
in_u64_range,
in_opt_range,
"convert_optional_u64_range_to_u64_range failed for {:?}. Val: {:?}, OptRange: \
{:?}, U64Range: {:?}. Opt says {}, U64 says {}",
std::any::type_name::<C>(),
v,
range,
u64_range,
in_opt_range,
in_u64_range
);
}
Ok(())
}
}

View File

@@ -99,7 +99,12 @@ where
TSegmentSortKeyComputer: SegmentSortKeyComputer,
C: Comparator<TSegmentSortKeyComputer::SegmentSortKey>,
{
pub(crate) topn_computer: TopNComputer<TSegmentSortKeyComputer::SegmentSortKey, DocId, C>,
pub(crate) topn_computer: TopNComputer<
TSegmentSortKeyComputer::SegmentSortKey,
DocId,
C,
TSegmentSortKeyComputer::Buffer,
>,
pub(crate) segment_ord: u32,
pub(crate) segment_sort_key_computer: TSegmentSortKeyComputer,
}
@@ -120,6 +125,11 @@ where
);
}
fn collect_block(&mut self, docs: &[DocId]) {
self.segment_sort_key_computer
.compute_sort_keys_and_collect(docs, &mut self.topn_computer);
}
fn harvest(self) -> Self::Fruit {
let segment_ord = self.segment_ord;
let segment_hits: Vec<(TSegmentSortKeyComputer::SortKey, DocAddress)> = self
@@ -160,7 +170,7 @@ mod tests {
expected: &[(crate::Score, usize)],
) {
let mut vals: Vec<(crate::Score, usize)> = (0..10).map(|val| (val as f32, val)).collect();
vals.shuffle(&mut rand::rng());
vals.shuffle(&mut rand::thread_rng());
let vals_merged = merge_top_k(vals.into_iter(), doc_range, ComparatorEnum::from(order));
assert_eq!(&vals_merged, expected);
}

View File

@@ -2,6 +2,7 @@ use std::cmp::Ordering;
use std::fmt;
use std::ops::Range;
use columnar::ValueRange;
use serde::{Deserialize, Serialize};
use super::Collector;
@@ -10,8 +11,7 @@ use crate::collector::sort_key::{
SortByStaticFastValue, SortByString,
};
use crate::collector::sort_key_top_collector::TopBySortKeyCollector;
use crate::collector::top_collector::ComparableDoc;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
use crate::fastfield::FastValue;
use crate::{DocAddress, DocId, Order, Score, SegmentReader};
@@ -481,11 +481,22 @@ where
type SortKey = TSortKey;
type SegmentSortKey = TSortKey;
type SegmentComparator = NaturalComparator;
type Buffer = ();
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> TSortKey {
(self.sort_key_fn)(doc, score)
}
fn segment_sort_keys(
&mut self,
_input_docs: &[DocId],
_output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
_buffer: &mut Self::Buffer,
_filter: ValueRange<Self::SegmentSortKey>,
) {
unimplemented!("Batch computation is not supported for tweak score.")
}
/// Convert a segment level score into the global level score.
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
sort_key
@@ -509,12 +520,14 @@ where
/// the ascending `DocId|DocAddress` tie-breaking behavior without additional comparisons.
#[derive(Serialize, Deserialize)]
#[serde(from = "TopNComputerDeser<Score, D, C>")]
pub struct TopNComputer<Score, D, C> {
pub struct TopNComputer<Score, D, C, Buffer = ()> {
/// The buffer reverses sort order to get top-semantics instead of bottom-semantics
buffer: Vec<ComparableDoc<Score, D>>,
top_n: usize,
pub(crate) threshold: Option<Score>,
comparator: C,
#[serde(skip)]
scratch: Buffer,
}
// Intermediate struct for TopNComputer for deserialization, to keep vec capacity
@@ -526,7 +539,9 @@ struct TopNComputerDeser<Score, D, C> {
comparator: C,
}
impl<Score, D, C> From<TopNComputerDeser<Score, D, C>> for TopNComputer<Score, D, C> {
impl<Score, D, C, Buffer> From<TopNComputerDeser<Score, D, C>> for TopNComputer<Score, D, C, Buffer>
where Buffer: Default
{
fn from(mut value: TopNComputerDeser<Score, D, C>) -> Self {
let expected_cap = value.top_n.max(1) * 2;
let current_cap = value.buffer.capacity();
@@ -541,12 +556,15 @@ impl<Score, D, C> From<TopNComputerDeser<Score, D, C>> for TopNComputer<Score, D
top_n: value.top_n,
threshold: value.threshold,
comparator: value.comparator,
scratch: Buffer::default(),
}
}
}
impl<Score: std::fmt::Debug, D, C> std::fmt::Debug for TopNComputer<Score, D, C>
where C: Comparator<Score>
impl<Score: std::fmt::Debug, D, C, Buffer> std::fmt::Debug for TopNComputer<Score, D, C, Buffer>
where
C: Comparator<Score>,
Buffer: std::fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter) -> std::fmt::Result {
f.debug_struct("TopNComputer")
@@ -554,12 +572,13 @@ where C: Comparator<Score>
.field("top_n", &self.top_n)
.field("current_threshold", &self.threshold)
.field("comparator", &self.comparator)
.field("scratch", &self.scratch)
.finish()
}
}
// Custom clone to keep capacity
impl<Score: Clone, D: Clone, C: Clone> Clone for TopNComputer<Score, D, C> {
impl<Score: Clone, D: Clone, C: Clone, Buffer: Clone> Clone for TopNComputer<Score, D, C, Buffer> {
fn clone(&self) -> Self {
let mut buffer_clone = Vec::with_capacity(self.buffer.capacity());
buffer_clone.extend(self.buffer.iter().cloned());
@@ -568,15 +587,17 @@ impl<Score: Clone, D: Clone, C: Clone> Clone for TopNComputer<Score, D, C> {
top_n: self.top_n,
threshold: self.threshold.clone(),
comparator: self.comparator.clone(),
scratch: self.scratch.clone(),
}
}
}
impl<TSortKey, D> TopNComputer<TSortKey, D, ReverseComparator>
impl<TSortKey, D> TopNComputer<TSortKey, D, ReverseComparator, ()>
where
D: Ord,
TSortKey: Clone,
NaturalComparator: Comparator<TSortKey>,
ReverseComparator: Comparator<TSortKey>,
{
/// Create a new `TopNComputer`.
/// Internally it will allocate a buffer of size `2 * top_n`.
@@ -585,33 +606,26 @@ where
}
}
#[inline(always)]
fn compare_for_top_k<TSortKey, D: Ord, C: Comparator<TSortKey>>(
c: &C,
lhs: &ComparableDoc<TSortKey, D>,
rhs: &ComparableDoc<TSortKey, D>,
) -> std::cmp::Ordering {
c.compare(&lhs.sort_key, &rhs.sort_key)
.reverse() // Reverse here because we want top K.
.then_with(|| lhs.doc.cmp(&rhs.doc)) // Regardless of asc/desc, in presence of a tie, we
// sort by doc id
}
impl<TSortKey, D, C> TopNComputer<TSortKey, D, C>
impl<TSortKey, D, C, Buffer> TopNComputer<TSortKey, D, C, Buffer>
where
D: Ord,
TSortKey: Clone,
C: Comparator<TSortKey>,
Buffer: Default,
{
/// Create a new `TopNComputer`.
/// Internally it will allocate a buffer of size `2 * top_n`.
/// Internally it will allocate a buffer of size `(top_n.max(1) * 2) +
/// COLLECT_BLOCK_BUFFER_LEN`.
pub fn new_with_comparator(top_n: usize, comparator: C) -> Self {
let vec_cap = top_n.max(1) * 2;
// We ensure that there is always enough space to include an entire block in the buffer if
// need be, so that `push_block_lazy` can avoid checking capacity inside its loop.
let vec_cap = (top_n.max(1) * 2) + crate::COLLECT_BLOCK_BUFFER_LEN;
TopNComputer {
buffer: Vec::with_capacity(vec_cap),
top_n,
threshold: None,
comparator,
scratch: Buffer::default(),
}
}
@@ -635,22 +649,34 @@ where
// At this point, we need to have established that the doc is above the threshold.
#[inline(always)]
pub(crate) fn append_doc(&mut self, doc: D, sort_key: TSortKey) {
if self.buffer.len() == self.buffer.capacity() {
let median = self.truncate_top_n();
self.threshold = Some(median);
}
// This cannot panic, because we truncate_median will at least remove one element, since
// the min capacity is 2.
self.reserve(1);
// This cannot panic, because we've reserved room for one element.
let comparable_doc = ComparableDoc { doc, sort_key };
push_assuming_capacity(comparable_doc, &mut self.buffer);
}
// Ensure that there is capacity to push `additional` more elements without resizing.
#[inline(always)]
pub(crate) fn reserve(&mut self, additional: usize) {
if self.buffer.len() + additional > self.buffer.capacity() {
let median = self.truncate_top_n();
debug_assert!(self.buffer.len() + additional <= self.buffer.capacity());
self.threshold = Some(median);
}
}
pub(crate) fn buffer_and_scratch(
&mut self,
) -> (&mut Vec<ComparableDoc<TSortKey, D>>, &mut Buffer) {
(&mut self.buffer, &mut self.scratch)
}
#[inline(never)]
fn truncate_top_n(&mut self) -> TSortKey {
// Use select_nth_unstable to find the top nth score
let (_, median_el, _) = self.buffer.select_nth_unstable_by(self.top_n, |lhs, rhs| {
compare_for_top_k(&self.comparator, lhs, rhs)
});
let (_, median_el, _) = self
.buffer
.select_nth_unstable_by(self.top_n, |lhs, rhs| self.comparator.compare_doc(lhs, rhs));
let median_score = median_el.sort_key.clone();
// Remove all elements below the top_n
@@ -665,7 +691,7 @@ where
self.truncate_top_n();
}
self.buffer
.sort_unstable_by(|lhs, rhs| compare_for_top_k(&self.comparator, lhs, rhs));
.sort_unstable_by(|left, right| self.comparator.compare_doc(left, right));
self.buffer
}
@@ -684,7 +710,7 @@ where
//
// Panics if there is not enough capacity to add an element.
#[inline(always)]
fn push_assuming_capacity<T>(el: T, buf: &mut Vec<T>) {
pub fn push_assuming_capacity<T>(el: T, buf: &mut Vec<T>) {
let prev_len = buf.len();
assert!(prev_len < buf.capacity());
// This is mimicking the current (non-stabilized) implementation in std.
@@ -701,9 +727,10 @@ mod tests {
use proptest::prelude::*;
use super::{TopDocs, TopNComputer};
use crate::collector::sort_key::{ComparatorEnum, NaturalComparator, ReverseComparator};
use crate::collector::top_collector::ComparableDoc;
use crate::collector::{Collector, DocSetCollector};
use crate::collector::sort_key::{
Comparator, ComparatorEnum, NaturalComparator, ReverseComparator,
};
use crate::collector::{Collector, ComparableDoc, DocSetCollector};
use crate::query::{AllQuery, Query, QueryParser};
use crate::schema::{Field, Schema, FAST, STORED, TEXT};
use crate::time::format_description::well_known::Rfc3339;
@@ -822,9 +849,9 @@ mod tests {
for (feature, doc) in &docs {
computer.push(*feature, *doc);
}
let mut comparable_docs: Vec<ComparableDoc<u64, u64>> =
docs.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc }).collect();
crate::collector::sort_key::tests::sort_hits(&mut comparable_docs, Order::Asc);
let mut comparable_docs =
docs.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc }).collect::<Vec<_>>();
comparable_docs.sort_by(|l, r| ReverseComparator.compare_doc(l, r));
comparable_docs.truncate(limit);
prop_assert_eq!(
computer.into_sorted_vec(),
@@ -1408,11 +1435,11 @@ mod tests {
#[test]
fn test_top_field_collect_string_prop(
order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
limit in 1..256_usize,
offset in 0..256_usize,
limit in 1..32_usize,
offset in 0..32_usize,
segments_terms in
proptest::collection::vec(
proptest::collection::vec(0..32_u8, 1..32_usize),
proptest::collection::vec(0..64_u8, 1..256_usize),
0..8_usize,
)
) {
@@ -1454,7 +1481,11 @@ mod tests {
let sorted_docs: Vec<_> = {
let mut comparable_docs: Vec<ComparableDoc<_, _>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
crate::collector::sort_key::tests::sort_hits(&mut comparable_docs, order);
if order.is_desc() {
comparable_docs.sort_by(|l, r| NaturalComparator.compare_doc(l, r));
} else {
comparable_docs.sort_by(|l, r| ReverseComparator.compare_doc(l, r));
}
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
};
let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::<Vec<_>>();
@@ -1733,7 +1764,8 @@ mod tests {
#[test]
fn test_top_n_computer_not_at_capacity() {
let mut top_n_computer = TopNComputer::new_with_comparator(4, NaturalComparator);
let mut top_n_computer: TopNComputer<f32, u32, _, ()> =
TopNComputer::new_with_comparator(4, NaturalComparator);
top_n_computer.append_doc(1, 0.8);
top_n_computer.append_doc(3, 0.2);
top_n_computer.append_doc(5, 0.3);
@@ -1758,7 +1790,8 @@ mod tests {
#[test]
fn test_top_n_computer_at_capacity() {
let mut top_collector = TopNComputer::new_with_comparator(4, NaturalComparator);
let mut top_collector: TopNComputer<f32, u32, _, ()> =
TopNComputer::new_with_comparator(4, NaturalComparator);
top_collector.append_doc(1, 0.8);
top_collector.append_doc(3, 0.2);
top_collector.append_doc(5, 0.3);
@@ -1795,12 +1828,14 @@ mod tests {
let doc_ids_collection = [4, 5, 6];
let score = 3.3f32;
let mut top_collector_limit_2 = TopNComputer::new_with_comparator(2, NaturalComparator);
let mut top_collector_limit_2: TopNComputer<f32, u32, _, ()> =
TopNComputer::new_with_comparator(2, NaturalComparator);
for id in &doc_ids_collection {
top_collector_limit_2.append_doc(*id, score);
}
let mut top_collector_limit_3 = TopNComputer::new_with_comparator(3, NaturalComparator);
let mut top_collector_limit_3: TopNComputer<f32, u32, _, ()> =
TopNComputer::new_with_comparator(3, NaturalComparator);
for id in &doc_ids_collection {
top_collector_limit_3.append_doc(*id, score);
}
@@ -1821,15 +1856,16 @@ mod bench {
#[bench]
fn bench_top_segment_collector_collect_at_capacity(b: &mut Bencher) {
let mut top_collector = TopNComputer::new_with_comparator(100, NaturalComparator);
let mut top_collector: TopNComputer<f32, u32, _, ()> =
TopNComputer::new_with_comparator(100, NaturalComparator);
for i in 0..100 {
top_collector.append_doc(i, 0.8);
top_collector.append_doc(i as u32, 0.8);
}
b.iter(|| {
for i in 0..100 {
top_collector.append_doc(i, 0.8);
top_collector.append_doc(i as u32, 0.8);
}
});
}

View File

@@ -36,6 +36,7 @@ fn path_for_version(version: &str) -> String {
/// feature flag quickwit uses a different dictionary type
#[test]
#[cfg(not(feature = "quickwit"))]
#[ignore = "test incompatible with fixed-width footer changes"]
fn test_format_6() {
let path = path_for_version("6");
@@ -47,6 +48,7 @@ fn test_format_6() {
/// feature flag quickwit uses a different dictionary type
#[test]
#[cfg(not(feature = "quickwit"))]
#[ignore = "test incompatible with fixed-width footer changes"]
fn test_format_7() {
let path = path_for_version("7");

Some files were not shown because too many files have changed in this diff Show More