Compare commits

..

22 Commits

Author SHA1 Message Date
Paul Masurel
643639f14b Introduced geopoint. 2025-12-03 17:05:27 +01:00
Paul Masurel
f85a27068d Introduced geopoint. 2025-12-03 17:05:16 +01:00
Paul Masurel
1619e05bc5 plastic surgery 2025-12-03 16:20:18 +01:00
Paul Masurel
5d03c600ba Added bugfix and unit tests
Removed use of robust.
2025-12-03 15:21:37 +01:00
Paul Masurel
32beb06382 plastic surgery 2025-12-03 13:02:10 +01:00
Paul Masurel
d8bc0e7c99 added doc 2025-12-03 12:41:17 +01:00
Paul Masurel
79622f1f0b bugfix 2025-12-01 17:13:57 +01:00
Alan Gutierrez
d26d6c34fc Fix select_nth_unstable_by_key midpoint duplicates.
Existing code behaved as if the result of `select_nth_unstable_by_key`
was either a sorted array or the product of an algorithm that gathered
partition values as in the Dutch national flag problem. The existing
code was written knowing that the former isn't true and the latter isn't
advertised. Knowing, but not remembering. Quite the oversight.
2025-12-01 16:49:22 +01:00
Alan Gutierrez
6da54fa5da Revert "Remove radix_select.rs."
This reverts commit 19eab167b6.

Restore radix select in order to implement a merge solution that will
not require a temporary file.
2025-12-01 16:49:21 +01:00
Alan Gutierrez
9f10279681 Complete Spatial/Geometry type integration.
Addressed all `todo!()` markers created when adding `Spatial` field type
and `Geometry` value type to existing code paths:

- Dynamic field handling: `Geometry` not supported in dynamic JSON
  fields, return `unimplemented!()` consistently with other complex
  types.
- Fast field writer: Panic if geometry routed incorrectly (internal
  error.)
- `OwnedValue` serialization: Implement `Geometry` to GeoJSON
  serialization and reference-to-owned conversion.
- Field type: Return `None` for `get_index_record_option()` since
  spatial fields use BKD trees, not inverted index.
- Space usage tracking: Add spatial field to `SegmentSpaceUsage` with
  proper integration through `SegmentReader`.
- Spatial query explain: Implement `explain()` method following pattern of
  other binary/constant-score queries.

Fixed `MultiPolygon` deserialization bug: count total points across all
rings, not number of rings.

Added clippy expects for legitimate too_many_arguments cases in geometric
predicates.
2025-12-01 16:49:21 +01:00
Alan Gutierrez
68009bb25b Read block kd-tree nodes using from_le_bytes.
Read node structures using `from_le_bytes` instead of casting memory.
After an inspection of columnar storage, it appears that this is the
standard practice in Rust and in the Tantivy code base. Left the
structure alignment for now in case it tends to align with cache
boundaries.
2025-12-01 16:49:20 +01:00
Alan Gutierrez
459456ca28 Remove radix_select.rs.
Ended up using `select_nth_unstable_by_key` from the Rust standard
library instead.
2025-12-01 16:49:20 +01:00
Alan Gutierrez
dbbc8c3f65 Slot block kd-tree into Tantivy.
Implemented a geometry document field with a minimal `Geometry` enum.
Now able to add that Geometry from GeoJSON parsed from a JSON document.
Geometry is triangulated if it is a polygon, otherwise it is correctly
encoded as a degenerate triangle if it is a point or a line string.
Write accumulated triangles to a block kd-tree on commit.

Serialize the original `f64` polygon for retrieval from search.

Created a query method for intersection. Query against the memory mapped
block kd-tree. Return hits and original `f64` polygon.

Implemented a merge of one or more block kd-trees from one or more
segments during merge.

Updated the block kd-tree to write to a Tantivy `WritePtr` instead of
more generic Rust I/O.
2025-12-01 16:49:16 +01:00
Alan Gutierrez
d3049cb323 Triangulation is not just a conversion.
The triangulation function in `triangle.rs` is now called
`delaunay_to_triangles` and it accepts the output of a Delaunay
triangulation from `i_triangle` and not a GeoRust multi-polygon. The
translation of user polygons to `i_triangle` polygons and subsequent
triangulation will take place outside of `triangle.rs`.
2025-12-01 16:48:34 +01:00
Alan Gutierrez
ccdf399cd7 XOR delta compression for f64 polygon coordinates.
Lossless compression for floating-point lat/lon coordinates using XOR
delta encoding on IEEE 754 bit patterns with variable-length integer
encoding. Designed for per-polygon random access in the document store,
where each polygon compresses independently without requiring sequential
decompression.
2025-12-01 16:48:33 +01:00
Alan Gutierrez
2dc46b235e Implement block kd-tree.
Implement an immutable bulk-loaded spatial index using recursive median
partitioning on bounding box dimensions. Each leaf stores up to 512
triangles with delta-compressed coordinates and doc IDs. The tree
provides three query types (intersects, within, contains) that use exact
integer arithmetic for geometric predicates and accumulate results in
bit sets for efficient deduplication across leaves.

The serialized format stores compressed leaf pages followed by the tree
structure (leaf and branch nodes), enabling zero-copy access through
memory-mapped segments without upfront decompression.
2025-12-01 16:48:32 +01:00
Alan Gutierrez
f38140f72f Add delta compression for block kd-tree leaf nodes.
Implements dimension-major bit-packing with zigzag encoding for signed i32
deltas, enabling compression of spatially-clustered triangles from 32-bit
coordinates down to 4-19 bits per delta depending on spatial extent.
2025-12-01 16:48:32 +01:00
Alan Gutierrez
0996bea7ac Add a surveyor to determine spread and prefix.
Implemented a `Surveyor` that will evaluate the bounding boxes of a set
of triangles and determine the dimension with the maximum spread and the
shared prefix for the values of dimension with the maximum spread.
2025-12-01 16:48:31 +01:00
Alan Gutierrez
1c66567efc Radix selection for block kd-tree partitioning.
Implemented byte-wise histogram selection to find median values without
comparisons, enabling efficient partitioning of spatial data during
block kd-tree construction. Processes values through multiple passes,
building histograms for each byte position after a common prefix,
avoiding the need to sort or compare elements directly.
2025-12-01 16:48:31 +01:00
Alan Gutierrez
b2a9bb279d Implement polygon tessellation.
The `triangulate` function takes a polygon with floating-point lat/lon
coordinates, converts to integer coordinates with millimeter precision
(using 2^32 scaling), performs constrained Delaunay triangulation, and
encodes the resulting triangles with boundary edge information for block
kd-tree spatial indexing.

It handles polygons with holes correctly, preserving which triangle
edges lie on the original polygon boundaries versus internal
tessellation edges.
2025-12-01 16:48:26 +01:00
Alan Gutierrez
558c99fa2d Triangle encoding for spatial indexing.
Encodes triangles with the bounding box in the first four words,
enabling efficient spatial pruning during tree traversal without
reconstructing the full triangle. The remaining words contain an
additional vertex and packed reconstruction metadata, allowing exact
triangle recovery when needed.
2025-12-01 16:47:56 +01:00
Alan Gutierrez
43b5f34721 Implement SPATIAL flag.
Implement a SPATIAL flag for use in creating a spatial field.
2025-12-01 16:47:55 +01:00
192 changed files with 5962 additions and 12373 deletions

View File

@@ -1,125 +0,0 @@
---
name: rationalize-deps
description: Analyze Cargo.toml dependencies and attempt to remove unused features to reduce compile times and binary size
---
# Rationalize Dependencies
This skill analyzes Cargo.toml dependencies to identify and remove unused features.
## Overview
Many crates enable features by default that may not be needed. This skill:
1. Identifies dependencies with default features enabled
2. Tests if `default-features = false` works
3. Identifies which specific features are actually needed
4. Verifies compilation after changes
## Step 1: Identify the target
Ask the user which crate(s) to analyze:
- A specific crate name (e.g., "tokio", "serde")
- A specific workspace member (e.g., "quickwit-search")
- "all" to scan the entire workspace
## Step 2: Analyze current dependencies
For the workspace Cargo.toml (`quickwit/Cargo.toml`), list dependencies that:
- Do NOT have `default-features = false`
- Have default features that might be unnecessary
Run: `cargo tree -p <crate> -f "{p} {f}" --edges features` to see what features are actually used.
## Step 3: For each candidate dependency
### 3a: Check the crate's default features
Look up the crate on crates.io or check its Cargo.toml to understand:
- What features are enabled by default
- What each feature provides
Use: `cargo metadata --format-version=1 | jq '.packages[] | select(.name == "<crate>") | .features'`
### 3b: Try disabling default features
Modify the dependency in `quickwit/Cargo.toml`:
From:
```toml
some-crate = { version = "1.0" }
```
To:
```toml
some-crate = { version = "1.0", default-features = false }
```
### 3c: Run cargo check
Run: `cargo check --workspace` (or target specific packages for faster feedback)
If compilation fails:
1. Read the error messages to identify which features are needed
2. Add only the required features explicitly:
```toml
some-crate = { version = "1.0", default-features = false, features = ["needed-feature"] }
```
3. Re-run cargo check
### 3d: Binary search for minimal features
If there are many default features, use binary search:
1. Start with no features
2. If it fails, add half the default features
3. Continue until you find the minimal set
## Step 4: Document findings
For each dependency analyzed, report:
- Original configuration
- New configuration (if changed)
- Features that were removed
- Any features that are required
## Step 5: Verify full build
After all changes, run:
```bash
cargo check --workspace --all-targets
cargo test --workspace --no-run
```
## Common Patterns
### Serde
Often only needs `derive`:
```toml
serde = { version = "1.0", default-features = false, features = ["derive", "std"] }
```
### Tokio
Identify which runtime features are actually used:
```toml
tokio = { version = "1.0", default-features = false, features = ["rt-multi-thread", "macros", "sync"] }
```
### Reqwest
Often doesn't need all TLS backends:
```toml
reqwest = { version = "0.11", default-features = false, features = ["rustls-tls", "json"] }
```
## Rollback
If changes cause issues:
```bash
git checkout quickwit/Cargo.toml
cargo check --workspace
```
## Tips
- Start with large crates that have many default features (tokio, reqwest, hyper)
- Use `cargo bloat --crates` to identify large dependencies
- Check `cargo tree -d` for duplicate dependencies that might indicate feature conflicts
- Some features are needed only for tests - consider using `[dev-dependencies]` features

View File

@@ -1,60 +0,0 @@
---
name: simple-pr
description: Create a simple PR from staged changes with an auto-generated commit message
disable-model-invocation: true
---
# Simple PR
Follow these steps to create a simple PR from staged changes:
## Step 1: Check workspace state
Run: `git status`
Verify that all changes have been staged (no unstaged changes). If there are unstaged changes, abort and ask the user to stage their changes first with `git add`.
Also verify that we are on the `main` branch. If not, abort and ask the user to switch to main first.
## Step 2: Ensure main is up to date
Run: `git pull origin main`
This ensures we're working from the latest code.
## Step 3: Review staged changes
Run: `git diff --cached`
Review the staged changes to understand what the PR will contain.
## Step 4: Generate commit message
Based on the staged changes, generate a concise commit message (1-2 sentences) that describes the "why" rather than the "what".
Display the proposed commit message to the user and ask for confirmation before proceeding.
## Step 5: Create a new branch
Get the git username: `git config user.name | tr ' ' '-' | tr '[:upper:]' '[:lower:]'`
Create a short, descriptive branch name based on the changes (e.g., `fix-typo-in-readme`, `add-retry-logic`, `update-deps`).
Create and checkout the branch: `git checkout -b {username}/{short-descriptive-name}`
## Step 6: Commit changes
Commit with the message from step 3:
```
git commit -m "{commit-message}"
```
## Step 7: Push and open a PR
Push the branch and open a PR:
```
git push -u origin {branch-name}
gh pr create --title "{commit-message-title}" --body "{longer-description-if-needed}"
```
Report the PR URL to the user when complete.

View File

@@ -15,11 +15,11 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Install Rust
run: rustup toolchain install nightly-2025-12-01 --profile minimal --component llvm-tools-preview
run: rustup toolchain install nightly-2024-07-01 --profile minimal --component llvm-tools-preview
- uses: Swatinem/rust-cache@v2
- uses: taiki-e/install-action@cargo-llvm-cov
- name: Generate code coverage
run: cargo +nightly-2025-12-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
run: cargo +nightly-2024-07-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
continue-on-error: true

View File

@@ -39,11 +39,11 @@ jobs:
- name: Check Formatting
run: cargo +nightly fmt --all -- --check
- name: Check Stable Compilation
run: cargo build --all-features
- name: Check Bench Compilation
run: cargo +nightly bench --no-run --profile=dev --all-features
@@ -59,10 +59,10 @@ jobs:
strategy:
matrix:
features:
- { label: "all", flags: "mmap,stopwords,lz4-compression,zstd-compression,failpoints,stemmer" }
- { label: "quickwit", flags: "mmap,quickwit,failpoints" }
- { label: "none", flags: "" }
features: [
{ label: "all", flags: "mmap,stopwords,lz4-compression,zstd-compression,failpoints" },
{ label: "quickwit", flags: "mmap,quickwit,failpoints" }
]
name: test-${{ matrix.features.label}}
@@ -80,21 +80,7 @@ jobs:
- uses: Swatinem/rust-cache@v2
- name: Run tests
run: |
# if matrix.feature.flags is empty then run on --lib to avoid compiling examples
# (as most of them rely on mmap) otherwise run all
if [ -z "${{ matrix.features.flags }}" ]; then
cargo +stable nextest run --lib --no-default-features --verbose --workspace
else
cargo +stable nextest run --features ${{ matrix.features.flags }} --no-default-features --verbose --workspace
fi
run: cargo +stable nextest run --features ${{ matrix.features.flags }} --verbose --workspace
- name: Run doctests
run: |
# if matrix.feature.flags is empty then run on --lib to avoid compiling examples
# (as most of them rely on mmap) otherwise run all
if [ -z "${{ matrix.features.flags }}" ]; then
echo "no doctest for no feature flag"
else
cargo +stable test --doc --features ${{ matrix.features.flags }} --verbose --workspace
fi
run: cargo +stable test --doc --features ${{ matrix.features.flags }} --verbose --workspace

View File

@@ -15,7 +15,7 @@ rust-version = "1.85"
exclude = ["benches/*.json", "benches/*.txt"]
[dependencies]
oneshot = "0.1.13"
oneshot = "0.1.7"
base64 = "0.22.0"
byteorder = "1.4.3"
crc32fast = "1.3.2"
@@ -27,7 +27,7 @@ regex = { version = "1.5.5", default-features = false, features = [
aho-corasick = "1.0"
tantivy-fst = "0.5"
memmap2 = { version = "0.9.0", optional = true }
lz4_flex = { version = "0.12", default-features = false, optional = true }
lz4_flex = { version = "0.11", default-features = false, optional = true }
zstd = { version = "0.13", optional = true, default-features = false }
tempfile = { version = "3.12.0", optional = true }
log = "0.4.16"
@@ -37,9 +37,9 @@ fs4 = { version = "0.13.1", optional = true }
levenshtein_automata = "0.2.1"
uuid = { version = "1.0.0", features = ["v4", "serde"] }
crossbeam-channel = "0.5.4"
rust-stemmers = { version = "1.2.0", optional = true }
rust-stemmers = "1.2.0"
downcast-rs = "2.0.1"
bitpacking = { version = "0.9.3", default-features = false, features = [
bitpacking = { version = "0.9.2", default-features = false, features = [
"bitpacker4x",
] }
census = "0.4.2"
@@ -47,15 +47,16 @@ rustc-hash = "2.0.0"
thiserror = "2.0.1"
htmlescape = "0.3.1"
fail = { version = "0.5.0", optional = true }
time = { version = "0.3.47", features = ["serde-well-known"] }
time = { version = "0.3.35", features = ["serde-well-known"] }
smallvec = "1.8.0"
rayon = "1.5.2"
lru = "0.16.3"
lru = "0.12.0"
fastdivide = "0.4.0"
itertools = "0.14.0"
measure_time = "0.9.0"
arc-swap = "1.5.0"
bon = "3.3.1"
i_triangle = "0.38.0"
columnar = { version = "0.6", path = "./columnar", package = "tantivy-columnar" }
sstable = { version = "0.6", path = "./sstable", package = "tantivy-sstable", optional = true }
@@ -64,29 +65,30 @@ query-grammar = { version = "0.25.0", path = "./query-grammar", package = "tanti
tantivy-bitpacker = { version = "0.9", path = "./bitpacker" }
common = { version = "0.10", path = "./common/", package = "tantivy-common" }
tokenizer-api = { version = "0.6", path = "./tokenizer-api", package = "tantivy-tokenizer-api" }
sketches-ddsketch = { git = "https://github.com/quickwit-oss/rust-sketches-ddsketch.git", rev = "555caf1", features = ["use_serde"] }
datasketches = "0.2.0"
sketches-ddsketch = { version = "0.3.0", features = ["use_serde"] }
hyperloglogplus = { version = "0.4.1", features = ["const-loop"] }
futures-util = { version = "0.3.28", optional = true }
futures-channel = { version = "0.3.28", optional = true }
fnv = "1.0.7"
typetag = "0.2.21"
geo-types = "0.7.17"
[target.'cfg(windows)'.dependencies]
winapi = "0.3.9"
[dev-dependencies]
binggan = "0.14.2"
rand = "0.9"
binggan = "0.14.0"
rand = "0.8.5"
maplit = "1.0.2"
matches = "0.1.9"
pretty_assertions = "1.2.1"
proptest = "1.7.0"
proptest = "1.0.0"
test-log = "0.2.10"
futures = "0.3.21"
paste = "1.0.11"
more-asserts = "0.3.1"
rand_distr = "0.5"
time = { version = "0.3.47", features = ["serde-well-known", "macros"] }
rand_distr = "0.4.3"
time = { version = "0.3.10", features = ["serde-well-known", "macros"] }
postcard = { version = "1.0.4", features = [
"use-std",
], default-features = false }
@@ -113,8 +115,7 @@ debug-assertions = true
overflow-checks = true
[features]
default = ["mmap", "stopwords", "lz4-compression", "columnar-zstd-compression", "stemmer"]
stemmer = ["rust-stemmers"]
default = ["mmap", "stopwords", "lz4-compression", "columnar-zstd-compression"]
mmap = ["fs4", "tempfile", "memmap2"]
stopwords = []
@@ -174,30 +175,6 @@ harness = false
name = "exists_json"
harness = false
[[bench]]
name = "range_query"
harness = false
[[bench]]
name = "and_or_queries"
harness = false
[[bench]]
name = "range_queries"
harness = false
[[bench]]
name = "bool_queries_with_range"
harness = false
[[bench]]
name = "str_search_and_get"
harness = false
[[bench]]
name = "merge_segments"
harness = false
[[bench]]
name = "regex_all_terms"
harness = false

View File

@@ -123,7 +123,6 @@ You can also find other bindings on [GitHub](https://github.com/search?q=tantivy
- [seshat](https://github.com/matrix-org/seshat/): A matrix message database/indexer
- [tantiny](https://github.com/baygeldin/tantiny): Tiny full-text search for Ruby
- [lnx](https://github.com/lnx-search/lnx): adaptable, typo tolerant search engine with a REST API
- [Bichon](https://github.com/rustmailer/bichon): A lightweight, high-performance Rust email archiver with WebUI
- and [more](https://github.com/search?q=tantivy)!
### On average, how much faster is Tantivy compared to Lucene?

View File

@@ -1,8 +1,7 @@
use binggan::plugins::PeakMemAllocPlugin;
use binggan::{black_box, InputGroup, PeakMemAlloc, INSTRUMENTED_SYSTEM};
use rand::distr::weighted::WeightedIndex;
use rand::prelude::SliceRandom;
use rand::rngs::StdRng;
use rand::seq::IndexedRandom;
use rand::{Rng, SeedableRng};
use rand_distr::Distribution;
use serde_json::json;
@@ -10,7 +9,7 @@ use tantivy::aggregation::agg_req::Aggregations;
use tantivy::aggregation::AggregationCollector;
use tantivy::query::{AllQuery, TermQuery};
use tantivy::schema::{IndexRecordOption, Schema, TextFieldIndexing, FAST, STRING};
use tantivy::{doc, DateTime, Index, Term};
use tantivy::{doc, Index, Term};
#[global_allocator]
pub static GLOBAL: &PeakMemAlloc<std::alloc::System> = &INSTRUMENTED_SYSTEM;
@@ -54,39 +53,27 @@ fn bench_agg(mut group: InputGroup<Index>) {
register!(group, stats_f64);
register!(group, extendedstats_f64);
register!(group, percentiles_f64);
register!(group, terms_7);
register!(group, terms_all_unique);
register!(group, terms_150_000);
register!(group, terms_few);
register!(group, terms_many);
register!(group, terms_many_top_1000);
register!(group, terms_many_order_by_term);
register!(group, terms_many_with_top_hits);
register!(group, terms_all_unique_with_avg_sub_agg);
register!(group, terms_many_with_avg_sub_agg);
register!(group, terms_status_with_avg_sub_agg);
register!(group, terms_status_with_histogram);
register!(group, terms_zipf_1000);
register!(group, terms_zipf_1000_with_histogram);
register!(group, terms_zipf_1000_with_avg_sub_agg);
register!(group, terms_few_with_avg_sub_agg);
register!(group, terms_many_json_mixed_type_with_avg_sub_agg);
register!(group, composite_term_many_page_1000);
register!(group, composite_term_many_page_1000_with_avg_sub_agg);
register!(group, composite_term_few);
register!(group, composite_histogram);
register!(group, composite_histogram_calendar);
register!(group, cardinality_agg);
register!(group, terms_status_with_cardinality_agg);
register!(group, terms_few_with_cardinality_agg);
register!(group, range_agg);
register!(group, range_agg_with_avg_sub_agg);
register!(group, range_agg_with_term_agg_status);
register!(group, range_agg_with_term_agg_few);
register!(group, range_agg_with_term_agg_many);
register!(group, histogram);
register!(group, histogram_hard_bounds);
register!(group, histogram_with_avg_sub_agg);
register!(group, histogram_with_term_agg_status);
register!(group, histogram_with_term_agg_few);
register!(group, avg_and_range_with_avg_sub_agg);
// Filter aggregation benchmarks
@@ -145,12 +132,12 @@ fn extendedstats_f64(index: &Index) {
}
fn percentiles_f64(index: &Index) {
let agg_req = json!({
"mypercentiles": {
"percentiles": {
"field": "score_f64",
"percents": [ 95, 99, 99.9 ]
}
"mypercentiles": {
"percentiles": {
"field": "score_f64",
"percents": [ 95, 99, 99.9 ]
}
}
});
execute_agg(index, agg_req);
}
@@ -165,10 +152,10 @@ fn cardinality_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn terms_status_with_cardinality_agg(index: &Index) {
fn terms_few_with_cardinality_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"terms": { "field": "text_few_terms" },
"aggs": {
"cardinality": {
"cardinality": {
@@ -181,20 +168,13 @@ fn terms_status_with_cardinality_agg(index: &Index) {
execute_agg(index, agg_req);
}
fn terms_7(index: &Index) {
fn terms_few(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_few_terms_status" } },
"my_texts": { "terms": { "field": "text_few_terms" } },
});
execute_agg(index, agg_req);
}
fn terms_all_unique(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_all_unique_terms" } },
});
execute_agg(index, agg_req);
}
fn terms_150_000(index: &Index) {
fn terms_many(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_many_terms" } },
});
@@ -242,10 +222,11 @@ fn terms_many_with_avg_sub_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn terms_all_unique_with_avg_sub_agg(index: &Index) {
fn terms_few_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_all_unique_terms" },
"terms": { "field": "text_few_terms" },
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
@@ -253,60 +234,6 @@ fn terms_all_unique_with_avg_sub_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn terms_status_with_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"aggs": {
"histo": {"histogram": { "field": "score_f64", "interval": 10 }}
}
}
});
execute_agg(index, agg_req);
}
fn terms_zipf_1000_with_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_1000_terms_zipf" },
"aggs": {
"histo": {"histogram": { "field": "score_f64", "interval": 10 }}
}
}
});
execute_agg(index, agg_req);
}
fn terms_status_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
},
});
execute_agg(index, agg_req);
}
fn terms_zipf_1000_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_1000_terms_zipf" },
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
},
});
execute_agg(index, agg_req);
}
fn terms_zipf_1000(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_1000_terms_zipf" } },
});
execute_agg(index, agg_req);
}
fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
@@ -320,75 +247,6 @@ fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) {
execute_agg(index, agg_req);
}
fn composite_term_few(index: &Index) {
let agg_req = json!({
"my_ctf": {
"composite": {
"sources": [
{ "text_few_terms": { "terms": { "field": "text_few_terms" } } }
],
"size": 1000
}
},
});
execute_agg(index, agg_req);
}
fn composite_term_many_page_1000(index: &Index) {
let agg_req = json!({
"my_ctmp1000": {
"composite": {
"sources": [
{ "text_many_terms": { "terms": { "field": "text_many_terms" } } }
],
"size": 1000
}
},
});
execute_agg(index, agg_req);
}
fn composite_term_many_page_1000_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_ctmp1000wasa": {
"composite": {
"sources": [
{ "text_many_terms": { "terms": { "field": "text_many_terms" } } }
],
"size": 1000,
},
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
},
});
execute_agg(index, agg_req);
}
fn composite_histogram(index: &Index) {
let agg_req = json!({
"my_ch": {
"composite": {
"sources": [
{ "f64_histogram": { "histogram": { "field": "score_f64", "interval": 1 } } }
],
"size": 1000
}
},
});
execute_agg(index, agg_req);
}
fn composite_histogram_calendar(index: &Index) {
let agg_req = json!({
"my_chc": {
"composite": {
"sources": [
{ "time_histogram": { "date_histogram": { "field": "timestamp", "calendar_interval": "month" } } }
],
"size": 1000
}
},
});
execute_agg(index, agg_req);
}
fn execute_agg(index: &Index, agg_req: serde_json::Value) {
let agg_req: Aggregations = serde_json::from_value(agg_req).unwrap();
let collector = get_collector(agg_req);
@@ -432,7 +290,7 @@ fn range_agg_with_avg_sub_agg(index: &Index) {
execute_agg(index, agg_req);
}
fn range_agg_with_term_agg_status(index: &Index) {
fn range_agg_with_term_agg_few(index: &Index) {
let agg_req = json!({
"rangef64": {
"range": {
@@ -447,7 +305,7 @@ fn range_agg_with_term_agg_status(index: &Index) {
]
},
"aggs": {
"my_texts": { "terms": { "field": "text_few_terms_status" } },
"my_texts": { "terms": { "field": "text_few_terms" } },
}
},
});
@@ -503,12 +361,12 @@ fn histogram_with_avg_sub_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn histogram_with_term_agg_status(index: &Index) {
fn histogram_with_term_agg_few(index: &Index) {
let agg_req = json!({
"rangef64": {
"histogram": { "field": "score_f64", "interval": 10 },
"aggs": {
"my_texts": { "terms": { "field": "text_few_terms_status" } }
"my_texts": { "terms": { "field": "text_few_terms" } }
}
}
});
@@ -553,13 +411,6 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector {
}
fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
// Flag to use existing index
let reuse_index = std::env::var("REUSE_AGG_BENCH_INDEX").is_ok();
if reuse_index && std::path::Path::new("agg_bench").exists() {
return Index::open_in_dir("agg_bench");
}
// crreate dir
std::fs::create_dir_all("agg_bench")?;
let mut schema_builder = Schema::builder();
let text_fieldtype = tantivy::schema::TextOptions::default()
.set_indexing_options(
@@ -568,50 +419,20 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
.set_stored();
let text_field = schema_builder.add_text_field("text", text_fieldtype);
let json_field = schema_builder.add_json_field("json", FAST);
let text_field_all_unique_terms =
schema_builder.add_text_field("text_all_unique_terms", STRING | FAST);
let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST);
let text_field_few_terms = schema_builder.add_text_field("text_few_terms", STRING | FAST);
let text_field_few_terms_status =
schema_builder.add_text_field("text_few_terms_status", STRING | FAST);
let text_field_1000_terms_zipf =
schema_builder.add_text_field("text_1000_terms_zipf", STRING | FAST);
let score_fieldtype = tantivy::schema::NumericOptions::default().set_fast();
let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone());
let score_field_f64 = schema_builder.add_f64_field("score_f64", score_fieldtype.clone());
let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype);
let date_field = schema_builder.add_date_field("timestamp", FAST);
// use tmp dir
let index = if reuse_index {
Index::create_in_dir("agg_bench", schema_builder.build())?
} else {
Index::create_from_tempdir(schema_builder.build())?
};
// Approximate log proportions
let status_field_data = [
("INFO", 8000),
("ERROR", 300),
("WARN", 1200),
("DEBUG", 500),
("OK", 500),
("CRITICAL", 20),
("EMERGENCY", 1),
];
let log_level_distribution =
WeightedIndex::new(status_field_data.iter().map(|item| item.1)).unwrap();
let index = Index::create_from_tempdir(schema_builder.build())?;
let few_terms_data = ["INFO", "ERROR", "WARN", "DEBUG"];
let lg_norm = rand_distr::LogNormal::new(2.996f64, 0.979f64).unwrap();
let many_terms_data = (0..150_000)
.map(|num| format!("author{num}"))
.collect::<Vec<_>>();
// Prepare 1000 unique terms sampled using a Zipf distribution.
// Exponent ~1.1 approximates top-20 terms covering around ~20%.
let terms_1000: Vec<String> = (1..=1000).map(|i| format!("term_{i}")).collect();
let zipf_1000 = rand_distr::Zipf::new(1000.0, 1.1f64).unwrap();
{
let mut rng = StdRng::from_seed([1u8; 32]);
let mut index_writer = index.writer_with_num_threads(1, 200_000_000)?;
@@ -621,27 +442,15 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
index_writer.add_document(doc!())?;
}
if cardinality == Cardinality::Multivalued {
let log_level_sample_a = status_field_data[log_level_distribution.sample(&mut rng)].0;
let log_level_sample_b = status_field_data[log_level_distribution.sample(&mut rng)].0;
let idx_a = zipf_1000.sample(&mut rng) as usize - 1;
let idx_b = zipf_1000.sample(&mut rng) as usize - 1;
let term_1000_a = &terms_1000[idx_a];
let term_1000_b = &terms_1000[idx_b];
index_writer.add_document(doc!(
json_field => json!({"mixed_type": 10.0}),
json_field => json!({"mixed_type": 10.0}),
text_field => "cool",
text_field => "cool",
text_field_all_unique_terms => "cool",
text_field_all_unique_terms => "coolo",
text_field_many_terms => "cool",
text_field_many_terms => "cool",
text_field_few_terms => "cool",
text_field_few_terms => "cool",
text_field_few_terms_status => log_level_sample_a,
text_field_few_terms_status => log_level_sample_b,
text_field_1000_terms_zipf => term_1000_a.as_str(),
text_field_1000_terms_zipf => term_1000_b.as_str(),
score_field => 1u64,
score_field => 1u64,
score_field_f64 => lg_norm.sample(&mut rng),
@@ -656,8 +465,8 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
}
let _val_max = 1_000_000.0;
for _ in 0..doc_with_value {
let val: f64 = rng.random_range(0.0..1_000_000.0);
let json = if rng.random_bool(0.1) {
let val: f64 = rng.gen_range(0.0..1_000_000.0);
let json = if rng.gen_bool(0.1) {
// 10% are numeric values
json!({ "mixed_type": val })
} else {
@@ -666,15 +475,11 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
index_writer.add_document(doc!(
text_field => "cool",
json_field => json,
text_field_all_unique_terms => format!("unique_term_{}", rng.random::<u64>()),
text_field_many_terms => many_terms_data.choose(&mut rng).unwrap().to_string(),
text_field_few_terms => few_terms_data.choose(&mut rng).unwrap().to_string(),
text_field_few_terms_status => status_field_data[log_level_distribution.sample(&mut rng)].0,
text_field_1000_terms_zipf => terms_1000[zipf_1000.sample(&mut rng) as usize - 1].as_str(),
score_field => val as u64,
score_field_f64 => lg_norm.sample(&mut rng),
score_field_i64 => val as i64,
date_field => DateTime::from_timestamp_millis((val * 1_000_000.) as i64),
))?;
if cardinality == Cardinality::OptionalSparse {
for _ in 0..20 {
@@ -723,7 +528,7 @@ fn filter_agg_all_query_with_sub_aggs(index: &Index) {
"avg_score": { "avg": { "field": "score" } },
"stats_score": { "stats": { "field": "score_f64" } },
"terms_text": {
"terms": { "field": "text_few_terms_status" }
"terms": { "field": "text_few_terms" }
}
}
}
@@ -739,7 +544,7 @@ fn filter_agg_term_query_with_sub_aggs(index: &Index) {
"avg_score": { "avg": { "field": "score" } },
"stats_score": { "stats": { "field": "score_f64" } },
"terms_text": {
"terms": { "field": "text_few_terms_status" }
"terms": { "field": "text_few_terms" }
}
}
}

View File

@@ -55,29 +55,29 @@ fn build_shared_indices(num_docs: usize, p_a: f32, p_b: f32, p_c: f32) -> (Bench
{
let mut writer = index.writer_with_num_threads(1, 500_000_000).unwrap();
for _ in 0..num_docs {
let has_a = rng.random_bool(p_a as f64);
let has_b = rng.random_bool(p_b as f64);
let has_c = rng.random_bool(p_c as f64);
let score = rng.random_range(0u64..100u64);
let score2 = rng.random_range(0u64..100_000u64);
let has_a = rng.gen_bool(p_a as f64);
let has_b = rng.gen_bool(p_b as f64);
let has_c = rng.gen_bool(p_c as f64);
let score = rng.gen_range(0u64..100u64);
let score2 = rng.gen_range(0u64..100_000u64);
let mut title_tokens: Vec<&str> = Vec::new();
let mut body_tokens: Vec<&str> = Vec::new();
if has_a {
if rng.random_bool(0.1) {
if rng.gen_bool(0.1) {
title_tokens.push("a");
} else {
body_tokens.push("a");
}
}
if has_b {
if rng.random_bool(0.1) {
if rng.gen_bool(0.1) {
title_tokens.push("b");
} else {
body_tokens.push("b");
}
}
if has_c {
if rng.random_bool(0.1) {
if rng.gen_bool(0.1) {
title_tokens.push("c");
} else {
body_tokens.push("c");

View File

@@ -1,288 +0,0 @@
use binggan::{black_box, BenchGroup, BenchRunner};
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::collector::{Collector, Count, DocSetCollector, TopDocs};
use tantivy::query::{Query, QueryParser};
use tantivy::schema::{Schema, FAST, INDEXED, TEXT};
use tantivy::{doc, Index, Order, ReloadPolicy, Searcher};
#[derive(Clone)]
struct BenchIndex {
#[allow(dead_code)]
index: Index,
searcher: Searcher,
query_parser: QueryParser,
}
fn build_shared_indices(num_docs: usize, p_title_a: f32, distribution: &str) -> BenchIndex {
// Unified schema
let mut schema_builder = Schema::builder();
let f_title = schema_builder.add_text_field("title", TEXT);
let f_num_rand = schema_builder.add_u64_field("num_rand", INDEXED);
let f_num_asc = schema_builder.add_u64_field("num_asc", INDEXED);
let f_num_rand_fast = schema_builder.add_u64_field("num_rand_fast", INDEXED | FAST);
let f_num_asc_fast = schema_builder.add_u64_field("num_asc_fast", INDEXED | FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
// Populate index with stable RNG for reproducibility.
let mut rng = StdRng::from_seed([7u8; 32]);
{
let mut writer = index.writer_with_num_threads(1, 4_000_000_000).unwrap();
match distribution {
"dense" => {
for doc_id in 0..num_docs {
// Always add title to avoid empty documents
let title_token = if rng.random_bool(p_title_a as f64) {
"a"
} else {
"b"
};
let num_rand = rng.random_range(0u64..1000u64);
let num_asc = (doc_id / 10000) as u64;
writer
.add_document(doc!(
f_title=>title_token,
f_num_rand=>num_rand,
f_num_asc=>num_asc,
f_num_rand_fast=>num_rand,
f_num_asc_fast=>num_asc,
))
.unwrap();
}
}
"sparse" => {
for doc_id in 0..num_docs {
// Always add title to avoid empty documents
let title_token = if rng.random_bool(p_title_a as f64) {
"a"
} else {
"b"
};
let num_rand = rng.random_range(0u64..10000000u64);
let num_asc = doc_id as u64;
writer
.add_document(doc!(
f_title=>title_token,
f_num_rand=>num_rand,
f_num_asc=>num_asc,
f_num_rand_fast=>num_rand,
f_num_asc_fast=>num_asc,
))
.unwrap();
}
}
_ => {
panic!("Unsupported distribution type");
}
}
writer.commit().unwrap();
}
// Prepare reader/searcher once.
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::Manual)
.try_into()
.unwrap();
let searcher = reader.searcher();
// Build query parser for title field
let qp_title = QueryParser::for_index(&index, vec![f_title]);
BenchIndex {
index,
searcher,
query_parser: qp_title,
}
}
fn main() {
// Prepare corpora with varying scenarios
let scenarios = vec![
(
"dense and 99% a".to_string(),
10_000_000,
0.99,
"dense",
0,
9,
),
(
"dense and 99% a".to_string(),
10_000_000,
0.99,
"dense",
990,
999,
),
(
"sparse and 99% a".to_string(),
10_000_000,
0.99,
"sparse",
0,
9,
),
(
"sparse and 99% a".to_string(),
10_000_000,
0.99,
"sparse",
9_999_990,
9_999_999,
),
];
let mut runner = BenchRunner::new();
for (scenario_id, n, p_title_a, num_rand_distribution, range_low, range_high) in scenarios {
// Build index for this scenario
let bench_index = build_shared_indices(n, p_title_a, num_rand_distribution);
// Create benchmark group
let mut group = runner.new_group();
// Now set the name (this moves scenario_id)
group.set_name(scenario_id);
// Define all four field types
let field_names = ["num_rand", "num_asc", "num_rand_fast", "num_asc_fast"];
// Define the three terms we want to test with
let terms = ["a", "b", "z"];
// Generate all combinations of terms and field names
let mut queries = Vec::new();
for &term in &terms {
for &field_name in &field_names {
let query_str = format!(
"{} AND {}:[{} TO {}]",
term, field_name, range_low, range_high
);
queries.push((query_str, field_name.to_string()));
}
}
let query_str = format!(
"{}:[{} TO {}] AND {}:[{} TO {}]",
"num_rand_fast", range_low, range_high, "num_asc_fast", range_low, range_high
);
queries.push((query_str, "num_asc_fast".to_string()));
// Run all benchmark tasks for each query and its corresponding field name
for (query_str, field_name) in queries {
run_benchmark_tasks(&mut group, &bench_index, &query_str, &field_name);
}
group.run();
}
}
/// Run all benchmark tasks for a given query string and field name
fn run_benchmark_tasks(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query_str: &str,
field_name: &str,
) {
// Test count
add_bench_task(bench_group, bench_index, query_str, Count, "count");
// Test all results
add_bench_task(
bench_group,
bench_index,
query_str,
DocSetCollector,
"all results",
);
// Test top 100 by the field (if it's a FAST field)
if field_name.ends_with("_fast") {
// Ascending order
{
let collector_name = format!("top100_by_{}_asc", field_name);
let field_name_owned = field_name.to_string();
add_bench_task(
bench_group,
bench_index,
query_str,
TopDocs::with_limit(100).order_by_fast_field::<u64>(field_name_owned, Order::Asc),
&collector_name,
);
}
// Descending order
{
let collector_name = format!("top100_by_{}_desc", field_name);
let field_name_owned = field_name.to_string();
add_bench_task(
bench_group,
bench_index,
query_str,
TopDocs::with_limit(100).order_by_fast_field::<u64>(field_name_owned, Order::Desc),
&collector_name,
);
}
}
}
fn add_bench_task<C: Collector + 'static>(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query_str: &str,
collector: C,
collector_name: &str,
) {
let task_name = format!("{}_{}", query_str.replace(" ", "_"), collector_name);
let query = bench_index.query_parser.parse_query(query_str).unwrap();
let search_task = SearchTask {
searcher: bench_index.searcher.clone(),
collector,
query,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
struct SearchTask<C: Collector> {
searcher: Searcher,
collector: C,
query: Box<dyn Query>,
}
impl<C: Collector> SearchTask<C> {
#[inline(never)]
pub fn run(&self) -> usize {
let result = self.searcher.search(&self.query, &self.collector).unwrap();
if let Some(count) = (&result as &dyn std::any::Any).downcast_ref::<usize>() {
*count
} else if let Some(top_docs) = (&result as &dyn std::any::Any)
.downcast_ref::<Vec<(Option<u64>, tantivy::DocAddress)>>()
{
top_docs.len()
} else if let Some(top_docs) =
(&result as &dyn std::any::Any).downcast_ref::<Vec<(u64, tantivy::DocAddress)>>()
{
top_docs.len()
} else if let Some(doc_set) = (&result as &dyn std::any::Any)
.downcast_ref::<std::collections::HashSet<tantivy::DocAddress>>()
{
doc_set.len()
} else {
eprintln!(
"Unknown collector result type: {:?}",
std::any::type_name::<C::Fruit>()
);
0
}
}
}

View File

@@ -1,224 +0,0 @@
// Benchmarks segment merging
//
// Notes:
// - Input segments are kept intact (no deletes / no IndexWriter merge).
// - Output is written to a `NullDirectory` that discards all files except
// fieldnorms (needed for merging).
use std::collections::HashMap;
use std::io::{self, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};
use binggan::{black_box, BenchRunner};
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::directory::error::{DeleteError, OpenReadError, OpenWriteError};
use tantivy::directory::{
AntiCallToken, Directory, FileHandle, OwnedBytes, TerminatingWrite, WatchCallback, WatchHandle,
WritePtr,
};
use tantivy::indexer::{merge_filtered_segments, NoMergePolicy};
use tantivy::schema::{Schema, TEXT};
use tantivy::{doc, HasLen, Index, IndexSettings, Segment};
#[derive(Clone, Default, Debug)]
struct NullDirectory {
blobs: Arc<RwLock<HashMap<PathBuf, OwnedBytes>>>,
}
struct NullWriter;
impl Write for NullWriter {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl TerminatingWrite for NullWriter {
fn terminate_ref(&mut self, _token: AntiCallToken) -> io::Result<()> {
Ok(())
}
}
struct InMemoryWriter {
path: PathBuf,
buffer: Vec<u8>,
blobs: Arc<RwLock<HashMap<PathBuf, OwnedBytes>>>,
}
impl Write for InMemoryWriter {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.buffer.extend_from_slice(buf);
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl TerminatingWrite for InMemoryWriter {
fn terminate_ref(&mut self, _token: AntiCallToken) -> io::Result<()> {
let bytes = OwnedBytes::new(std::mem::take(&mut self.buffer));
self.blobs.write().unwrap().insert(self.path.clone(), bytes);
Ok(())
}
}
#[derive(Debug, Default)]
struct NullFileHandle;
impl HasLen for NullFileHandle {
fn len(&self) -> usize {
0
}
}
impl FileHandle for NullFileHandle {
fn read_bytes(&self, _range: std::ops::Range<usize>) -> io::Result<OwnedBytes> {
unimplemented!()
}
}
impl Directory for NullDirectory {
fn get_file_handle(&self, path: &Path) -> Result<Arc<dyn FileHandle>, OpenReadError> {
if let Some(bytes) = self.blobs.read().unwrap().get(path) {
return Ok(Arc::new(bytes.clone()));
}
Ok(Arc::new(NullFileHandle))
}
fn delete(&self, _path: &Path) -> Result<(), DeleteError> {
Ok(())
}
fn exists(&self, _path: &Path) -> Result<bool, OpenReadError> {
Ok(true)
}
fn open_write(&self, path: &Path) -> Result<WritePtr, OpenWriteError> {
let path_buf = path.to_path_buf();
if path.to_string_lossy().ends_with(".fieldnorm") {
let writer = InMemoryWriter {
path: path_buf,
buffer: Vec::new(),
blobs: Arc::clone(&self.blobs),
};
Ok(io::BufWriter::new(Box::new(writer)))
} else {
Ok(io::BufWriter::new(Box::new(NullWriter)))
}
}
fn atomic_read(&self, path: &Path) -> Result<Vec<u8>, OpenReadError> {
if let Some(bytes) = self.blobs.read().unwrap().get(path) {
return Ok(bytes.as_slice().to_vec());
}
Err(OpenReadError::FileDoesNotExist(path.to_path_buf()))
}
fn atomic_write(&self, _path: &Path, _data: &[u8]) -> io::Result<()> {
Ok(())
}
fn sync_directory(&self) -> io::Result<()> {
Ok(())
}
fn watch(&self, _watch_callback: WatchCallback) -> tantivy::Result<WatchHandle> {
Ok(WatchHandle::empty())
}
}
struct MergeScenario {
#[allow(dead_code)]
index: Index,
segments: Vec<Segment>,
settings: IndexSettings,
label: String,
}
fn build_index(
num_segments: usize,
docs_per_segment: usize,
tokens_per_doc: usize,
vocab_size: usize,
) -> MergeScenario {
let mut schema_builder = Schema::builder();
let body = schema_builder.add_text_field("body", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
assert!(vocab_size > 0);
let total_tokens = num_segments * docs_per_segment * tokens_per_doc;
let use_unique_terms = vocab_size >= total_tokens;
let mut rng = StdRng::from_seed([7u8; 32]);
let mut next_token_id: u64 = 0;
{
let mut writer = index.writer_with_num_threads(1, 256_000_000).unwrap();
writer.set_merge_policy(Box::new(NoMergePolicy));
for _ in 0..num_segments {
for _ in 0..docs_per_segment {
let mut tokens = Vec::with_capacity(tokens_per_doc);
for _ in 0..tokens_per_doc {
let token_id = if use_unique_terms {
let id = next_token_id;
next_token_id += 1;
id
} else {
rng.random_range(0..vocab_size as u64)
};
tokens.push(format!("term_{token_id}"));
}
writer.add_document(doc!(body => tokens.join(" "))).unwrap();
}
writer.commit().unwrap();
}
}
let segments = index.searchable_segments().unwrap();
let settings = index.settings().clone();
let label = format!(
"segments={}, docs/seg={}, tokens/doc={}, vocab={}",
num_segments, docs_per_segment, tokens_per_doc, vocab_size
);
MergeScenario {
index,
segments,
settings,
label,
}
}
fn main() {
let scenarios = vec![
build_index(8, 50_000, 12, 8),
build_index(16, 50_000, 12, 8),
build_index(16, 100_000, 12, 8),
build_index(8, 50_000, 8, 8 * 50_000 * 8),
];
let mut runner = BenchRunner::new();
for scenario in scenarios {
let mut group = runner.new_group();
group.set_name(format!("merge_segments inv_index — {}", scenario.label));
let segments = scenario.segments.clone();
let settings = scenario.settings.clone();
group.register("merge", move |_| {
let output_dir = NullDirectory::default();
let filter_doc_ids = vec![None; segments.len()];
let merged_index =
merge_filtered_segments(&segments, settings.clone(), filter_doc_ids, output_dir)
.unwrap();
black_box(merged_index);
});
group.run();
}
}

View File

@@ -1,365 +0,0 @@
use std::ops::Bound;
use binggan::{black_box, BenchGroup, BenchRunner};
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::collector::{Count, DocSetCollector, TopDocs};
use tantivy::query::RangeQuery;
use tantivy::schema::{Schema, FAST, INDEXED};
use tantivy::{doc, Index, Order, ReloadPolicy, Searcher, Term};
#[derive(Clone)]
struct BenchIndex {
#[allow(dead_code)]
index: Index,
searcher: Searcher,
}
fn build_shared_indices(num_docs: usize, distribution: &str) -> BenchIndex {
// Schema with fast fields only
let mut schema_builder = Schema::builder();
let f_num_rand_fast = schema_builder.add_u64_field("num_rand_fast", INDEXED | FAST);
let f_num_asc_fast = schema_builder.add_u64_field("num_asc_fast", INDEXED | FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
// Populate index with stable RNG for reproducibility.
let mut rng = StdRng::from_seed([7u8; 32]);
{
let mut writer = index.writer_with_num_threads(1, 4_000_000_000).unwrap();
match distribution {
"dense" => {
for doc_id in 0..num_docs {
let num_rand = rng.random_range(0u64..1000u64);
let num_asc = (doc_id / 10000) as u64;
writer
.add_document(doc!(
f_num_rand_fast=>num_rand,
f_num_asc_fast=>num_asc,
))
.unwrap();
}
}
"sparse" => {
for doc_id in 0..num_docs {
let num_rand = rng.random_range(0u64..10000000u64);
let num_asc = doc_id as u64;
writer
.add_document(doc!(
f_num_rand_fast=>num_rand,
f_num_asc_fast=>num_asc,
))
.unwrap();
}
}
_ => {
panic!("Unsupported distribution type");
}
}
writer.commit().unwrap();
}
// Prepare reader/searcher once.
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::Manual)
.try_into()
.unwrap();
let searcher = reader.searcher();
BenchIndex { index, searcher }
}
fn main() {
// Prepare corpora with varying scenarios
let scenarios = vec![
// Dense distribution - random values in small range (0-999)
(
"dense_values_search_low_value_range".to_string(),
10_000_000,
"dense",
0,
9,
),
(
"dense_values_search_high_value_range".to_string(),
10_000_000,
"dense",
990,
999,
),
(
"dense_values_search_out_of_range".to_string(),
10_000_000,
"dense",
1000,
1002,
),
(
"sparse_values_search_low_value_range".to_string(),
10_000_000,
"sparse",
0,
9,
),
(
"sparse_values_search_high_value_range".to_string(),
10_000_000,
"sparse",
9_999_990,
9_999_999,
),
(
"sparse_values_search_out_of_range".to_string(),
10_000_000,
"sparse",
10_000_000,
10_000_002,
),
];
let mut runner = BenchRunner::new();
for (scenario_id, n, num_rand_distribution, range_low, range_high) in scenarios {
// Build index for this scenario
let bench_index = build_shared_indices(n, num_rand_distribution);
// Create benchmark group
let mut group = runner.new_group();
// Now set the name (this moves scenario_id)
group.set_name(scenario_id);
// Define fast field types
let field_names = ["num_rand_fast", "num_asc_fast"];
// Generate range queries for fast fields
for &field_name in &field_names {
// Create the range query
let field = bench_index.searcher.schema().get_field(field_name).unwrap();
let lower_term = Term::from_field_u64(field, range_low);
let upper_term = Term::from_field_u64(field, range_high);
let query = RangeQuery::new(Bound::Included(lower_term), Bound::Included(upper_term));
run_benchmark_tasks(
&mut group,
&bench_index,
query,
field_name,
range_low,
range_high,
);
}
group.run();
}
}
/// Run all benchmark tasks for a given range query and field name
fn run_benchmark_tasks(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
field_name: &str,
range_low: u64,
range_high: u64,
) {
// Test count
add_bench_task_count(
bench_group,
bench_index,
query.clone(),
"count",
field_name,
range_low,
range_high,
);
// Test top 100 by the field (ascending order)
{
let collector_name = format!("top100_by_{}_asc", field_name);
let field_name_owned = field_name.to_string();
add_bench_task_top100_asc(
bench_group,
bench_index,
query.clone(),
&collector_name,
field_name,
range_low,
range_high,
field_name_owned,
);
}
// Test top 100 by the field (descending order)
{
let collector_name = format!("top100_by_{}_desc", field_name);
let field_name_owned = field_name.to_string();
add_bench_task_top100_desc(
bench_group,
bench_index,
query,
&collector_name,
field_name,
range_low,
range_high,
field_name_owned,
);
}
}
fn add_bench_task_count(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
collector_name: &str,
field_name: &str,
range_low: u64,
range_high: u64,
) {
let task_name = format!(
"range_{}_[{} TO {}]_{}",
field_name, range_low, range_high, collector_name
);
let search_task = CountSearchTask {
searcher: bench_index.searcher.clone(),
query,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
fn add_bench_task_docset(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
collector_name: &str,
field_name: &str,
range_low: u64,
range_high: u64,
) {
let task_name = format!(
"range_{}_[{} TO {}]_{}",
field_name, range_low, range_high, collector_name
);
let search_task = DocSetSearchTask {
searcher: bench_index.searcher.clone(),
query,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
fn add_bench_task_top100_asc(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
collector_name: &str,
field_name: &str,
range_low: u64,
range_high: u64,
field_name_owned: String,
) {
let task_name = format!(
"range_{}_[{} TO {}]_{}",
field_name, range_low, range_high, collector_name
);
let search_task = Top100AscSearchTask {
searcher: bench_index.searcher.clone(),
query,
field_name: field_name_owned,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
fn add_bench_task_top100_desc(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
collector_name: &str,
field_name: &str,
range_low: u64,
range_high: u64,
field_name_owned: String,
) {
let task_name = format!(
"range_{}_[{} TO {}]_{}",
field_name, range_low, range_high, collector_name
);
let search_task = Top100DescSearchTask {
searcher: bench_index.searcher.clone(),
query,
field_name: field_name_owned,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
struct CountSearchTask {
searcher: Searcher,
query: RangeQuery,
}
impl CountSearchTask {
#[inline(never)]
pub fn run(&self) -> usize {
self.searcher.search(&self.query, &Count).unwrap()
}
}
struct DocSetSearchTask {
searcher: Searcher,
query: RangeQuery,
}
impl DocSetSearchTask {
#[inline(never)]
pub fn run(&self) -> usize {
let result = self.searcher.search(&self.query, &DocSetCollector).unwrap();
result.len()
}
}
struct Top100AscSearchTask {
searcher: Searcher,
query: RangeQuery,
field_name: String,
}
impl Top100AscSearchTask {
#[inline(never)]
pub fn run(&self) -> usize {
let collector =
TopDocs::with_limit(100).order_by_fast_field::<u64>(&self.field_name, Order::Asc);
let result = self.searcher.search(&self.query, &collector).unwrap();
for (_score, doc_address) in &result {
let _doc: tantivy::TantivyDocument = self.searcher.doc(*doc_address).unwrap();
}
result.len()
}
}
struct Top100DescSearchTask {
searcher: Searcher,
query: RangeQuery,
field_name: String,
}
impl Top100DescSearchTask {
#[inline(never)]
pub fn run(&self) -> usize {
let collector =
TopDocs::with_limit(100).order_by_fast_field::<u64>(&self.field_name, Order::Desc);
let result = self.searcher.search(&self.query, &collector).unwrap();
for (_score, doc_address) in &result {
let _doc: tantivy::TantivyDocument = self.searcher.doc(*doc_address).unwrap();
}
result.len()
}
}

View File

@@ -1,260 +0,0 @@
use std::fmt::Display;
use std::net::Ipv6Addr;
use std::ops::RangeInclusive;
use binggan::plugins::PeakMemAllocPlugin;
use binggan::{black_box, BenchRunner, OutputValue, PeakMemAlloc, INSTRUMENTED_SYSTEM};
use columnar::MonotonicallyMappableToU128;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use tantivy::collector::{Count, TopDocs};
use tantivy::query::QueryParser;
use tantivy::schema::*;
use tantivy::{doc, Index};
#[global_allocator]
pub static GLOBAL: &PeakMemAlloc<std::alloc::System> = &INSTRUMENTED_SYSTEM;
fn main() {
bench_range_query();
}
fn bench_range_query() {
let index = get_index_0_to_100();
let mut runner = BenchRunner::new();
runner.add_plugin(PeakMemAllocPlugin::new(GLOBAL));
runner.set_name("range_query on u64");
let field_name_and_descr: Vec<_> = vec![
("id", "Single Valued Range Field"),
("ids", "Multi Valued Range Field"),
];
let range_num_hits = vec![
("90_percent", get_90_percent()),
("10_percent", get_10_percent()),
("1_percent", get_1_percent()),
];
test_range(&mut runner, &index, &field_name_and_descr, range_num_hits);
runner.set_name("range_query on ip");
let field_name_and_descr: Vec<_> = vec![
("ip", "Single Valued Range Field"),
("ips", "Multi Valued Range Field"),
];
let range_num_hits = vec![
("90_percent", get_90_percent_ip()),
("10_percent", get_10_percent_ip()),
("1_percent", get_1_percent_ip()),
];
test_range(&mut runner, &index, &field_name_and_descr, range_num_hits);
}
fn test_range<T: Display>(
runner: &mut BenchRunner,
index: &Index,
field_name_and_descr: &[(&str, &str)],
range_num_hits: Vec<(&str, RangeInclusive<T>)>,
) {
for (field, suffix) in field_name_and_descr {
let term_num_hits = vec![
("", ""),
("1_percent", "veryfew"),
("10_percent", "few"),
("90_percent", "most"),
];
let mut group = runner.new_group();
group.set_name(suffix);
// all intersect combinations
for (range_name, range) in &range_num_hits {
for (term_name, term) in &term_num_hits {
let index = &index;
let test_name = if term_name.is_empty() {
format!("id_range_hit_{}", range_name)
} else {
format!(
"id_range_hit_{}_intersect_with_term_{}",
range_name, term_name
)
};
group.register(test_name, move |_| {
let query = if term_name.is_empty() {
"".to_string()
} else {
format!("AND id_name:{}", term)
};
black_box(execute_query(field, range, &query, index));
});
}
}
group.run();
}
}
fn get_index_0_to_100() -> Index {
let mut rng = StdRng::from_seed([1u8; 32]);
let num_vals = 100_000;
let docs: Vec<_> = (0..num_vals)
.map(|_i| {
let id_name = if rng.random_bool(0.01) {
"veryfew".to_string() // 1%
} else if rng.random_bool(0.1) {
"few".to_string() // 9%
} else {
"most".to_string() // 90%
};
Doc {
id_name,
id: rng.random_range(0..100),
// Multiply by 1000, so that we create most buckets in the compact space
// The benches depend on this range to select n-percent of elements with the
// methods below.
ip: Ipv6Addr::from_u128(rng.random_range(0..100) * 1000),
}
})
.collect();
create_index_from_docs(&docs)
}
#[derive(Clone, Debug)]
pub struct Doc {
pub id_name: String,
pub id: u64,
pub ip: Ipv6Addr,
}
pub fn create_index_from_docs(docs: &[Doc]) -> Index {
let mut schema_builder = Schema::builder();
let id_u64_field = schema_builder.add_u64_field("id", INDEXED | STORED | FAST);
let ids_u64_field =
schema_builder.add_u64_field("ids", NumericOptions::default().set_fast().set_indexed());
let id_f64_field = schema_builder.add_f64_field("id_f64", INDEXED | STORED | FAST);
let ids_f64_field = schema_builder.add_f64_field(
"ids_f64",
NumericOptions::default().set_fast().set_indexed(),
);
let id_i64_field = schema_builder.add_i64_field("id_i64", INDEXED | STORED | FAST);
let ids_i64_field = schema_builder.add_i64_field(
"ids_i64",
NumericOptions::default().set_fast().set_indexed(),
);
let text_field = schema_builder.add_text_field("id_name", STRING | STORED);
let text_field2 = schema_builder.add_text_field("id_name_fast", STRING | STORED | FAST);
let ip_field = schema_builder.add_ip_addr_field("ip", FAST);
let ips_field = schema_builder.add_ip_addr_field("ips", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
{
let mut index_writer = index.writer_with_num_threads(1, 50_000_000).unwrap();
for doc in docs.iter() {
index_writer
.add_document(doc!(
ids_i64_field => doc.id as i64,
ids_i64_field => doc.id as i64,
ids_f64_field => doc.id as f64,
ids_f64_field => doc.id as f64,
ids_u64_field => doc.id,
ids_u64_field => doc.id,
id_u64_field => doc.id,
id_f64_field => doc.id as f64,
id_i64_field => doc.id as i64,
text_field => doc.id_name.to_string(),
text_field2 => doc.id_name.to_string(),
ips_field => doc.ip,
ips_field => doc.ip,
ip_field => doc.ip,
))
.unwrap();
}
index_writer.commit().unwrap();
}
index
}
fn get_90_percent() -> RangeInclusive<u64> {
0..=90
}
fn get_10_percent() -> RangeInclusive<u64> {
0..=10
}
fn get_1_percent() -> RangeInclusive<u64> {
10..=10
}
fn get_90_percent_ip() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(0);
let end = Ipv6Addr::from_u128(90 * 1000);
start..=end
}
fn get_10_percent_ip() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(0);
let end = Ipv6Addr::from_u128(10 * 1000);
start..=end
}
fn get_1_percent_ip() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(10 * 1000);
let end = Ipv6Addr::from_u128(10 * 1000);
start..=end
}
struct NumHits {
count: usize,
}
impl OutputValue for NumHits {
fn column_title() -> &'static str {
"NumHits"
}
fn format(&self) -> Option<String> {
Some(self.count.to_string())
}
}
fn execute_query<T: Display>(
field: &str,
id_range: &RangeInclusive<T>,
suffix: &str,
index: &Index,
) -> NumHits {
let gen_query_inclusive = |from: &T, to: &T| {
format!(
"{}:[{} TO {}] {}",
field,
&from.to_string(),
&to.to_string(),
suffix
)
};
let query = gen_query_inclusive(id_range.start(), id_range.end());
execute_query_(&query, index)
}
fn execute_query_(query: &str, index: &Index) -> NumHits {
let query_from_text = |text: &str| {
QueryParser::for_index(index, vec![])
.parse_query(text)
.unwrap()
};
let query = query_from_text(query);
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let num_hits = searcher
.search(&query, &(TopDocs::with_limit(10).order_by_score(), Count))
.unwrap()
.1;
NumHits { count: num_hits }
}

View File

@@ -1,113 +0,0 @@
// Benchmarks regex query that matches all terms in a synthetic index.
//
// Corpus model:
// - N unique terms: t000000, t000001, ...
// - M docs
// - K tokens per doc: doc i gets terms derived from (i, token_index)
//
// Query:
// - Regex "t.*" to match all terms
//
// Run with:
// - cargo bench --bench regex_all_terms
//
use std::fmt::Write;
use binggan::{black_box, BenchRunner};
use tantivy::collector::Count;
use tantivy::query::RegexQuery;
use tantivy::schema::{Schema, TEXT};
use tantivy::{doc, Index, ReloadPolicy};
const HEAP_SIZE_BYTES: usize = 200_000_000;
#[derive(Clone, Copy)]
struct BenchConfig {
num_terms: usize,
num_docs: usize,
tokens_per_doc: usize,
}
fn main() {
let configs = default_configs();
let mut runner = BenchRunner::new();
for config in configs {
let (index, text_field) = build_index(config, HEAP_SIZE_BYTES);
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::Manual)
.try_into()
.expect("reader");
let searcher = reader.searcher();
let query = RegexQuery::from_pattern("t.*", text_field).expect("regex query");
let mut group = runner.new_group();
group.set_name(format!(
"regex_all_terms_t{}_d{}_k{}",
config.num_terms, config.num_docs, config.tokens_per_doc
));
group.register("regex_count", move |_| {
let count = searcher.search(&query, &Count).expect("search");
black_box(count);
});
group.run();
}
}
fn default_configs() -> Vec<BenchConfig> {
vec![
BenchConfig {
num_terms: 10_000,
num_docs: 100_000,
tokens_per_doc: 1,
},
BenchConfig {
num_terms: 10_000,
num_docs: 100_000,
tokens_per_doc: 8,
},
BenchConfig {
num_terms: 100_000,
num_docs: 100_000,
tokens_per_doc: 1,
},
BenchConfig {
num_terms: 100_000,
num_docs: 100_000,
tokens_per_doc: 8,
},
]
}
fn build_index(config: BenchConfig, heap_size_bytes: usize) -> (Index, tantivy::schema::Field) {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let term_width = config.num_terms.to_string().len();
{
let mut writer = index
.writer_with_num_threads(1, heap_size_bytes)
.expect("writer");
let mut buffer = String::new();
for doc_id in 0..config.num_docs {
buffer.clear();
for token_idx in 0..config.tokens_per_doc {
if token_idx > 0 {
buffer.push(' ');
}
let term_id = (doc_id * config.tokens_per_doc + token_idx) % config.num_terms;
write!(&mut buffer, "t{term_id:0term_width$}").expect("write token");
}
writer
.add_document(doc!(text_field => buffer.as_str()))
.expect("add_document");
}
writer.commit().expect("commit");
}
(index, text_field)
}

View File

@@ -1,421 +0,0 @@
// This benchmark compares different approaches for retrieving string values:
//
// 1. Fast Field Approach: retrieves string values via term_ords() and ord_to_str()
//
// 2. Doc Store Approach: retrieves string values via searcher.doc() and field extraction
//
// The benchmark includes various data distributions:
// - Dense Sequential: Sequential document IDs with dense data
// - Dense Random: Random document IDs with dense data
// - Sparse Sequential: Sequential document IDs with sparse data
// - Sparse Random: Random document IDs with sparse data
use std::ops::Bound;
use binggan::{black_box, BenchGroup, BenchRunner};
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::collector::{Count, DocSetCollector};
use tantivy::query::RangeQuery;
use tantivy::schema::document::TantivyDocument;
use tantivy::schema::{Schema, Value, FAST, STORED, STRING};
use tantivy::{doc, Index, ReloadPolicy, Searcher, Term};
#[derive(Clone)]
struct BenchIndex {
#[allow(dead_code)]
index: Index,
searcher: Searcher,
}
fn build_shared_indices(num_docs: usize, distribution: &str) -> BenchIndex {
// Schema with string fast field and stored field for doc access
let mut schema_builder = Schema::builder();
let f_str_fast = schema_builder.add_text_field("str_fast", STRING | STORED | FAST);
let f_str_stored = schema_builder.add_text_field("str_stored", STRING | STORED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
// Populate index with stable RNG for reproducibility.
let mut rng = StdRng::from_seed([7u8; 32]);
{
let mut writer = index.writer_with_num_threads(1, 4_000_000_000).unwrap();
match distribution {
"dense_random" => {
for _doc_id in 0..num_docs {
let suffix = rng.gen_range(0u64..1000u64);
let str_val = format!("str_{:03}", suffix);
writer
.add_document(doc!(
f_str_fast=>str_val.clone(),
f_str_stored=>str_val,
))
.unwrap();
}
}
"dense_sequential" => {
for doc_id in 0..num_docs {
let suffix = doc_id as u64 % 1000;
let str_val = format!("str_{:03}", suffix);
writer
.add_document(doc!(
f_str_fast=>str_val.clone(),
f_str_stored=>str_val,
))
.unwrap();
}
}
"sparse_random" => {
for _doc_id in 0..num_docs {
let suffix = rng.gen_range(0u64..1000000u64);
let str_val = format!("str_{:07}", suffix);
writer
.add_document(doc!(
f_str_fast=>str_val.clone(),
f_str_stored=>str_val,
))
.unwrap();
}
}
"sparse_sequential" => {
for doc_id in 0..num_docs {
let suffix = doc_id as u64;
let str_val = format!("str_{:07}", suffix);
writer
.add_document(doc!(
f_str_fast=>str_val.clone(),
f_str_stored=>str_val,
))
.unwrap();
}
}
_ => {
panic!("Unsupported distribution type");
}
}
writer.commit().unwrap();
}
// Prepare reader/searcher once.
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::Manual)
.try_into()
.unwrap();
let searcher = reader.searcher();
BenchIndex { index, searcher }
}
fn main() {
// Prepare corpora with varying scenarios
let scenarios = vec![
(
"dense_random_search_low_range".to_string(),
1_000_000,
"dense_random",
0,
9,
),
(
"dense_random_search_high_range".to_string(),
1_000_000,
"dense_random",
990,
999,
),
(
"dense_sequential_search_low_range".to_string(),
1_000_000,
"dense_sequential",
0,
9,
),
(
"dense_sequential_search_high_range".to_string(),
1_000_000,
"dense_sequential",
990,
999,
),
(
"sparse_random_search_low_range".to_string(),
1_000_000,
"sparse_random",
0,
9999,
),
(
"sparse_random_search_high_range".to_string(),
1_000_000,
"sparse_random",
990_000,
999_999,
),
(
"sparse_sequential_search_low_range".to_string(),
1_000_000,
"sparse_sequential",
0,
9999,
),
(
"sparse_sequential_search_high_range".to_string(),
1_000_000,
"sparse_sequential",
990_000,
999_999,
),
];
let mut runner = BenchRunner::new();
for (scenario_id, n, distribution, range_low, range_high) in scenarios {
let bench_index = build_shared_indices(n, distribution);
let mut group = runner.new_group();
group.set_name(scenario_id);
let field = bench_index.searcher.schema().get_field("str_fast").unwrap();
let (lower_str, upper_str) =
if distribution == "dense_sequential" || distribution == "dense_random" {
(
format!("str_{:03}", range_low),
format!("str_{:03}", range_high),
)
} else {
(
format!("str_{:07}", range_low),
format!("str_{:07}", range_high),
)
};
let lower_term = Term::from_field_text(field, &lower_str);
let upper_term = Term::from_field_text(field, &upper_str);
let query = RangeQuery::new(Bound::Included(lower_term), Bound::Included(upper_term));
run_benchmark_tasks(&mut group, &bench_index, query, range_low, range_high);
group.run();
}
}
/// Run all benchmark tasks for a given range query
fn run_benchmark_tasks(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
range_low: u64,
range_high: u64,
) {
// Test count of matching documents
add_bench_task_count(
bench_group,
bench_index,
query.clone(),
range_low,
range_high,
);
// Test fetching all DocIds of matching documents
add_bench_task_docset(
bench_group,
bench_index,
query.clone(),
range_low,
range_high,
);
// Test fetching all string fast field values of matching documents
add_bench_task_fetch_all_strings(
bench_group,
bench_index,
query.clone(),
range_low,
range_high,
);
// Test fetching all string values of matching documents through doc() method
add_bench_task_fetch_all_strings_from_doc(
bench_group,
bench_index,
query,
range_low,
range_high,
);
}
fn add_bench_task_count(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
range_low: u64,
range_high: u64,
) {
let task_name = format!("string_search_count_[{}-{}]", range_low, range_high);
let search_task = CountSearchTask {
searcher: bench_index.searcher.clone(),
query,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
fn add_bench_task_docset(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
range_low: u64,
range_high: u64,
) {
let task_name = format!("string_fetch_all_docset_[{}-{}]", range_low, range_high);
let search_task = DocSetSearchTask {
searcher: bench_index.searcher.clone(),
query,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
fn add_bench_task_fetch_all_strings(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
range_low: u64,
range_high: u64,
) {
let task_name = format!(
"string_fastfield_fetch_all_strings_[{}-{}]",
range_low, range_high
);
let search_task = FetchAllStringsSearchTask {
searcher: bench_index.searcher.clone(),
query,
};
bench_group.register(task_name, move |_| {
let result = black_box(search_task.run());
result.len()
});
}
fn add_bench_task_fetch_all_strings_from_doc(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
range_low: u64,
range_high: u64,
) {
let task_name = format!(
"string_doc_fetch_all_strings_[{}-{}]",
range_low, range_high
);
let search_task = FetchAllStringsFromDocTask {
searcher: bench_index.searcher.clone(),
query,
};
bench_group.register(task_name, move |_| {
let result = black_box(search_task.run());
result.len()
});
}
struct CountSearchTask {
searcher: Searcher,
query: RangeQuery,
}
impl CountSearchTask {
#[inline(never)]
pub fn run(&self) -> usize {
self.searcher.search(&self.query, &Count).unwrap()
}
}
struct DocSetSearchTask {
searcher: Searcher,
query: RangeQuery,
}
impl DocSetSearchTask {
#[inline(never)]
pub fn run(&self) -> usize {
let result = self.searcher.search(&self.query, &DocSetCollector).unwrap();
result.len()
}
}
struct FetchAllStringsSearchTask {
searcher: Searcher,
query: RangeQuery,
}
impl FetchAllStringsSearchTask {
#[inline(never)]
pub fn run(&self) -> Vec<String> {
let doc_addresses = self.searcher.search(&self.query, &DocSetCollector).unwrap();
let mut docs = doc_addresses.into_iter().collect::<Vec<_>>();
docs.sort();
let mut strings = Vec::with_capacity(docs.len());
for doc_address in docs {
let segment_reader = &self.searcher.segment_readers()[doc_address.segment_ord as usize];
let str_column_opt = segment_reader.fast_fields().str("str_fast");
if let Ok(Some(str_column)) = str_column_opt {
let doc_id = doc_address.doc_id;
let term_ord = str_column.term_ords(doc_id).next().unwrap();
let mut str_buffer = String::new();
if str_column.ord_to_str(term_ord, &mut str_buffer).is_ok() {
strings.push(str_buffer);
}
}
}
strings
}
}
struct FetchAllStringsFromDocTask {
searcher: Searcher,
query: RangeQuery,
}
impl FetchAllStringsFromDocTask {
#[inline(never)]
pub fn run(&self) -> Vec<String> {
let doc_addresses = self.searcher.search(&self.query, &DocSetCollector).unwrap();
let mut docs = doc_addresses.into_iter().collect::<Vec<_>>();
docs.sort();
let mut strings = Vec::with_capacity(docs.len());
let str_stored_field = self
.searcher
.schema()
.get_field("str_stored")
.expect("str_stored field should exist");
for doc_address in docs {
// Get the document from the doc store (row store access)
if let Ok(doc) = self.searcher.doc::<TantivyDocument>(doc_address) {
// Extract string values from the stored field
if let Some(field_value) = doc.get_first(str_stored_field) {
if let Some(text) = field_value.as_value().as_str() {
strings.push(text.to_string());
}
}
}
}
strings
}
}

View File

@@ -18,5 +18,5 @@ homepage = "https://github.com/quickwit-oss/tantivy"
bitpacking = { version = "0.9.2", default-features = false, features = ["bitpacker1x"] }
[dev-dependencies]
rand = "0.9"
rand = "0.8"
proptest = "1"

View File

@@ -4,8 +4,8 @@ extern crate test;
#[cfg(test)]
mod tests {
use rand::rng;
use rand::seq::IteratorRandom;
use rand::thread_rng;
use tantivy_bitpacker::{BitPacker, BitUnpacker, BlockedBitpacker};
use test::Bencher;
@@ -27,7 +27,7 @@ mod tests {
let num_els = 1_000_000u32;
let bit_unpacker = BitUnpacker::new(bit_width);
let data = create_bitpacked_data(bit_width, num_els);
let idxs: Vec<u32> = (0..num_els).choose_multiple(&mut rng(), 100_000);
let idxs: Vec<u32> = (0..num_els).choose_multiple(&mut thread_rng(), 100_000);
b.iter(|| {
let mut out = 0u64;
for &idx in &idxs {

View File

@@ -19,7 +19,7 @@ fn u32_to_i32(val: u32) -> i32 {
#[inline]
unsafe fn u32_to_i32_avx2(vals_u32x8s: DataType) -> DataType {
const HIGHEST_BIT_MASK: DataType = from_u32x8([HIGHEST_BIT; NUM_LANES]);
unsafe { op_xor(vals_u32x8s, HIGHEST_BIT_MASK) }
op_xor(vals_u32x8s, HIGHEST_BIT_MASK)
}
pub fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
@@ -66,19 +66,17 @@ unsafe fn filter_vec_avx2_aux(
]);
const SHIFT: __m256i = from_u32x8([NUM_LANES as u32; NUM_LANES]);
for _ in 0..num_words {
unsafe {
let word = load_unaligned(input);
let word = u32_to_i32_avx2(word);
let keeper_bitset = compute_filter_bitset(word, range_simd.clone());
let added_len = keeper_bitset.count_ones();
let filtered_doc_ids = compact(ids, keeper_bitset);
store_unaligned(output_tail as *mut __m256i, filtered_doc_ids);
output_tail = output_tail.offset(added_len as isize);
ids = op_add(ids, SHIFT);
input = input.offset(1);
}
let word = load_unaligned(input);
let word = u32_to_i32_avx2(word);
let keeper_bitset = compute_filter_bitset(word, range_simd.clone());
let added_len = keeper_bitset.count_ones();
let filtered_doc_ids = compact(ids, keeper_bitset);
store_unaligned(output_tail as *mut __m256i, filtered_doc_ids);
output_tail = output_tail.offset(added_len as isize);
ids = op_add(ids, SHIFT);
input = input.offset(1);
}
unsafe { output_tail.offset_from(output) as usize }
output_tail.offset_from(output) as usize
}
#[inline]
@@ -94,7 +92,8 @@ unsafe fn compute_filter_bitset(val: __m256i, range: std::ops::RangeInclusive<__
let too_low = op_greater(*range.start(), val);
let too_high = op_greater(val, *range.end());
let inside = op_or(too_low, too_high);
255 - std::arch::x86_64::_mm256_movemask_ps(_mm256_castsi256_ps(inside)) as u8
255 - std::arch::x86_64::_mm256_movemask_ps(std::mem::transmute::<DataType, __m256>(inside))
as u8
}
union U8x32 {

View File

@@ -22,7 +22,7 @@ downcast-rs = "2.0.1"
[dev-dependencies]
proptest = "1"
more-asserts = "0.3.1"
rand = "0.9"
rand = "0.8"
binggan = "0.14.0"
[[bench]]

View File

@@ -9,7 +9,7 @@ use tantivy_columnar::column_values::{CodecType, serialize_and_load_u64_based_co
fn get_data() -> Vec<u64> {
let mut rng = StdRng::seed_from_u64(2u64);
let mut data: Vec<_> = (100..55_000_u64)
.map(|num| num + rng.random::<u8>() as u64)
.map(|num| num + rng.r#gen::<u8>() as u64)
.collect();
data.push(99_000);
data.insert(1000, 2000);

View File

@@ -6,7 +6,7 @@ use tantivy_columnar::column_values::{CodecType, serialize_u64_based_column_valu
fn get_data() -> Vec<u64> {
let mut rng = StdRng::seed_from_u64(2u64);
let mut data: Vec<_> = (100..55_000_u64)
.map(|num| num + rng.random::<u8>() as u64)
.map(|num| num + rng.r#gen::<u8>() as u64)
.collect();
data.push(99_000);
data.insert(1000, 2000);

View File

@@ -8,7 +8,7 @@ const TOTAL_NUM_VALUES: u32 = 1_000_000;
fn gen_optional_index(fill_ratio: f64) -> OptionalIndex {
let mut rng: StdRng = StdRng::from_seed([1u8; 32]);
let vals: Vec<u32> = (0..TOTAL_NUM_VALUES)
.map(|_| rng.random_bool(fill_ratio))
.map(|_| rng.gen_bool(fill_ratio))
.enumerate()
.filter(|(_pos, val)| *val)
.map(|(pos, _)| pos as u32)
@@ -25,7 +25,7 @@ fn random_range_iterator(
let mut rng: StdRng = StdRng::from_seed([1u8; 32]);
let mut current = start;
std::iter::from_fn(move || {
current += rng.random_range(avg_step_size - avg_deviation..=avg_step_size + avg_deviation);
current += rng.gen_range(avg_step_size - avg_deviation..=avg_step_size + avg_deviation);
if current >= end { None } else { Some(current) }
})
}

View File

@@ -39,7 +39,7 @@ fn get_data_50percent_item() -> Vec<u128> {
let mut data = vec![];
for _ in 0..300_000 {
let val = rng.random_range(1..=100);
let val = rng.gen_range(1..=100);
data.push(val);
}
data.push(SINGLE_ITEM);

View File

@@ -34,7 +34,7 @@ fn get_data_50percent_item() -> Vec<u128> {
let mut data = vec![];
for _ in 0..300_000 {
let val = rng.random_range(1..=100);
let val = rng.gen_range(1..=100);
data.push(val);
}
data.push(SINGLE_ITEM);

View File

@@ -29,20 +29,12 @@ impl<T: PartialOrd + Copy + std::fmt::Debug + Send + Sync + 'static + Default>
}
}
#[inline]
pub fn fetch_block_with_missing(
&mut self,
docs: &[u32],
accessor: &Column<T>,
missing: Option<T>,
) {
pub fn fetch_block_with_missing(&mut self, docs: &[u32], accessor: &Column<T>, missing: T) {
self.fetch_block(docs, accessor);
// no missing values
if accessor.index.get_cardinality().is_full() {
return;
}
let Some(missing) = missing else {
return;
};
// We can compare docid_cache length with docs to find missing docs
// For multi value columns we can't rely on the length and always need to scan

View File

@@ -85,8 +85,8 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
}
#[inline]
pub fn first(&self, doc_id: DocId) -> Option<T> {
self.values_for_doc(doc_id).next()
pub fn first(&self, row_id: RowId) -> Option<T> {
self.values_for_doc(row_id).next()
}
/// Load the first value for each docid in the provided slice.

View File

@@ -41,6 +41,12 @@ fn transform_range_before_linear_transformation(
if range.is_empty() {
return None;
}
if stats.min_value > *range.end() {
return None;
}
if stats.max_value < *range.start() {
return None;
}
let shifted_range =
range.start().saturating_sub(stats.min_value)..=range.end().saturating_sub(stats.min_value);
let start_before_gcd_multiplication: u64 = div_ceil(*shifted_range.start(), stats.gcd);

View File

@@ -268,7 +268,7 @@ mod tests {
#[test]
fn linear_interpol_fast_field_rand() {
let mut rng = rand::rng();
let mut rng = rand::thread_rng();
for _ in 0..50 {
let mut data = (0..10_000).map(|_| rng.next_u64()).collect::<Vec<_>>();
create_and_validate::<LinearCodec>(&data, "random");

View File

@@ -122,7 +122,7 @@ pub(crate) fn create_and_validate<TColumnCodec: ColumnCodec>(
assert_eq!(vals, buffer);
if !vals.is_empty() {
let test_rand_idx = rand::rng().random_range(0..=vals.len() - 1);
let test_rand_idx = rand::thread_rng().gen_range(0..=vals.len() - 1);
let expected_positions: Vec<u32> = vals
.iter()
.enumerate()

View File

@@ -3,8 +3,7 @@ use std::sync::Arc;
use std::{fmt, io};
use common::file_slice::FileSlice;
use common::{ByteCount, DateTime, OwnedBytes};
use serde::{Deserialize, Serialize};
use common::{ByteCount, DateTime, HasLen, OwnedBytes};
use crate::column::{BytesColumn, Column, StrColumn};
use crate::column_values::{StrictlyMonotonicFn, monotonic_map_column};
@@ -318,89 +317,10 @@ impl DynamicColumnHandle {
}
pub fn num_bytes(&self) -> ByteCount {
self.file_slice.num_bytes()
}
/// Legacy helper returning the column space usage.
pub fn column_and_dictionary_num_bytes(&self) -> io::Result<ColumnSpaceUsage> {
self.space_usage()
}
/// Return the space usage of the column, optionally broken down by dictionary and column
/// values.
///
/// For dictionary encoded columns (strings and bytes), this splits the total footprint into
/// the dictionary and the remaining column data (including index and values).
/// For all other column types, the dictionary size is `None` and the column size
/// equals the total bytes.
pub fn space_usage(&self) -> io::Result<ColumnSpaceUsage> {
let total_num_bytes = self.num_bytes();
let dynamic_column = self.open()?;
let dictionary_num_bytes = match &dynamic_column {
DynamicColumn::Bytes(bytes_column) => bytes_column.dictionary().num_bytes(),
DynamicColumn::Str(str_column) => str_column.dictionary().num_bytes(),
_ => {
return Ok(ColumnSpaceUsage::new(self.num_bytes(), None));
}
};
assert!(dictionary_num_bytes <= total_num_bytes);
let column_num_bytes =
ByteCount::from(total_num_bytes.get_bytes() - dictionary_num_bytes.get_bytes());
Ok(ColumnSpaceUsage::new(
column_num_bytes,
Some(dictionary_num_bytes),
))
self.file_slice.len().into()
}
pub fn column_type(&self) -> ColumnType {
self.column_type
}
}
/// Represents space usage of a column.
///
/// `column_num_bytes` tracks the column payload (index, values and footer).
/// For dictionary encoded columns, `dictionary_num_bytes` captures the dictionary footprint.
/// [`ColumnSpaceUsage::total_num_bytes`] returns the sum of both parts.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ColumnSpaceUsage {
column_num_bytes: ByteCount,
dictionary_num_bytes: Option<ByteCount>,
}
impl ColumnSpaceUsage {
pub(crate) fn new(
column_num_bytes: ByteCount,
dictionary_num_bytes: Option<ByteCount>,
) -> Self {
ColumnSpaceUsage {
column_num_bytes,
dictionary_num_bytes,
}
}
pub fn column_num_bytes(&self) -> ByteCount {
self.column_num_bytes
}
pub fn dictionary_num_bytes(&self) -> Option<ByteCount> {
self.dictionary_num_bytes
}
pub fn total_num_bytes(&self) -> ByteCount {
self.column_num_bytes + self.dictionary_num_bytes.unwrap_or_default()
}
/// Merge two space usage values by summing their components.
pub fn merge(&self, other: &ColumnSpaceUsage) -> ColumnSpaceUsage {
let dictionary_num_bytes = match (self.dictionary_num_bytes, other.dictionary_num_bytes) {
(Some(lhs), Some(rhs)) => Some(lhs + rhs),
(Some(val), None) | (None, Some(val)) => Some(val),
(None, None) => None,
};
ColumnSpaceUsage {
column_num_bytes: self.column_num_bytes + other.column_num_bytes,
dictionary_num_bytes,
}
}
}

View File

@@ -48,7 +48,7 @@ pub use columnar::{
use sstable::VoidSSTable;
pub use value::{NumericalType, NumericalValue};
pub use self::dynamic_column::{ColumnSpaceUsage, DynamicColumn, DynamicColumnHandle};
pub use self::dynamic_column::{DynamicColumn, DynamicColumnHandle};
pub type RowId = u32;
pub type DocId = u32;
@@ -59,7 +59,7 @@ pub struct RowAddr {
pub row_id: RowId,
}
pub use sstable::{Dictionary, TermOrdHit};
pub use sstable::Dictionary;
pub type Streamer<'a> = sstable::Streamer<'a, VoidSSTable>;
pub use common::DateTime;

View File

@@ -60,7 +60,7 @@ fn test_dataframe_writer_bool() {
let DynamicColumn::Bool(bool_col) = dyn_bool_col else {
panic!();
};
let vals: Vec<Option<bool>> = (0..5).map(|doc_id| bool_col.first(doc_id)).collect();
let vals: Vec<Option<bool>> = (0..5).map(|row_id| bool_col.first(row_id)).collect();
assert_eq!(&vals, &[None, Some(false), None, Some(true), None,]);
}
@@ -108,7 +108,7 @@ fn test_dataframe_writer_ip_addr() {
let DynamicColumn::IpAddr(ip_col) = dyn_bool_col else {
panic!();
};
let vals: Vec<Option<Ipv6Addr>> = (0..5).map(|doc_id| ip_col.first(doc_id)).collect();
let vals: Vec<Option<Ipv6Addr>> = (0..5).map(|row_id| ip_col.first(row_id)).collect();
assert_eq!(
&vals,
&[
@@ -169,7 +169,7 @@ fn test_dictionary_encoded_str() {
let DynamicColumn::Str(str_col) = col_handles[0].open().unwrap() else {
panic!();
};
let index: Vec<Option<u64>> = (0..5).map(|doc_id| str_col.ords().first(doc_id)).collect();
let index: Vec<Option<u64>> = (0..5).map(|row_id| str_col.ords().first(row_id)).collect();
assert_eq!(index, &[None, Some(0), None, Some(2), Some(1)]);
assert_eq!(str_col.num_rows(), 5);
let mut term_buffer = String::new();
@@ -204,7 +204,7 @@ fn test_dictionary_encoded_bytes() {
panic!();
};
let index: Vec<Option<u64>> = (0..5)
.map(|doc_id| bytes_col.ords().first(doc_id))
.map(|row_id| bytes_col.ords().first(row_id))
.collect();
assert_eq!(index, &[None, Some(0), None, Some(2), Some(1)]);
assert_eq!(bytes_col.num_rows(), 5);

View File

@@ -15,10 +15,11 @@ repository = "https://github.com/quickwit-oss/tantivy"
byteorder = "1.4.3"
ownedbytes = { version= "0.9", path="../ownedbytes" }
async-trait = "0.1"
time = { version = "0.3.47", features = ["serde-well-known"] }
time = { version = "0.3.10", features = ["serde-well-known"] }
serde = { version = "1.0.136", features = ["derive"] }
[dev-dependencies]
binggan = "0.14.0"
proptest = "1.0.0"
rand = "0.9"
rand = "0.8.4"

View File

@@ -1,6 +1,6 @@
use binggan::{BenchRunner, black_box};
use rand::rng;
use rand::seq::IteratorRandom;
use rand::thread_rng;
use tantivy_common::{BitSet, TinySet, serialize_vint_u32};
fn bench_vint() {
@@ -17,7 +17,7 @@ fn bench_vint() {
black_box(out);
});
let vals: Vec<u32> = (0..20_000).choose_multiple(&mut rng(), 100_000);
let vals: Vec<u32> = (0..20_000).choose_multiple(&mut thread_rng(), 100_000);
runner.bench_function("bench_vint_rand", move |_| {
let mut out = 0u64;
for val in vals.iter().cloned() {

View File

@@ -181,14 +181,6 @@ pub struct BitSet {
len: u64,
max_value: u32,
}
impl std::fmt::Debug for BitSet {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("BitSet")
.field("len", &self.len)
.field("max_value", &self.max_value)
.finish()
}
}
fn num_buckets(max_val: u32) -> u32 {
max_val.div_ceil(64u32)
@@ -416,7 +408,7 @@ mod tests {
use std::collections::HashSet;
use ownedbytes::OwnedBytes;
use rand::distr::Bernoulli;
use rand::distributions::Bernoulli;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

View File

@@ -62,9 +62,7 @@ impl<W: TerminatingWrite> TerminatingWrite for CountingWriter<W> {
pub struct AntiCallToken(());
/// Trait used to indicate when no more write need to be done on a writer
///
/// Thread-safety is enforced at the call sites that require it.
pub trait TerminatingWrite: Write {
pub trait TerminatingWrite: Write + Send + Sync {
/// Indicate that the writer will no longer be used. Internally call terminate_ref.
fn terminate(mut self) -> io::Result<()>
where Self: Sized {

View File

@@ -60,7 +60,7 @@ At indexing, tantivy will try to interpret number and strings as different type
priority order.
Numbers will be interpreted as u64, i64 and f64 in that order.
Strings will be interpreted as rfc3339 dates or simple strings.
Strings will be interpreted as rfc3999 dates or simple strings.
The first working type is picked and is the only term that is emitted for indexing.
Note this interpretation happens on a per-document basis, and there is no effort to try to sniff
@@ -81,7 +81,7 @@ Will be interpreted as
(my_path.my_segment, String, 233) or (my_path.my_segment, u64, 233)
```
Likewise, we need to emit two tokens if the query contains an rfc3339 date.
Likewise, we need to emit two tokens if the query contains an rfc3999 date.
Indeed the date could have been actually a single token inside the text of a document at ingestion time. Generally speaking, we will always at least emit a string token in query parsing, and sometimes more.
If one more json field is defined, things get even more complicated.

66
examples/geo_json.rs Normal file
View File

@@ -0,0 +1,66 @@
use geo_types::Point;
use tantivy::collector::TopDocs;
use tantivy::query::SpatialQuery;
use tantivy::schema::{Schema, Value, SPATIAL, STORED, TEXT};
use tantivy::spatial::point::GeoPoint;
use tantivy::{Index, IndexWriter, TantivyDocument};
fn main() -> tantivy::Result<()> {
let mut schema_builder = Schema::builder();
schema_builder.add_json_field("properties", STORED | TEXT);
schema_builder.add_spatial_field("geometry", STORED | SPATIAL);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
let mut index_writer: IndexWriter = index.writer(50_000_000)?;
let doc = TantivyDocument::parse_json(
&schema,
r#"{
"type":"Feature",
"geometry":{
"type":"Polygon",
"coordinates":[[[-99.483911,45.577697],[-99.483869,45.571457],[-99.481739,45.571461],[-99.474881,45.571584],[-99.473167,45.571615],[-99.463394,45.57168],[-99.463391,45.57883],[-99.463368,45.586076],[-99.48177,45.585926],[-99.48384,45.585953],[-99.483885,45.57873],[-99.483911,45.577697]]]
},
"properties":{
"admin_level":"8",
"border_type":"city",
"boundary":"administrative",
"gnis:feature_id":"1267426",
"name":"Hosmer",
"place":"city",
"source":"TIGER/Line® 2008 Place Shapefiles (http://www.census.gov/geo/www/tiger/)",
"wikidata":"Q2442118",
"wikipedia":"en:Hosmer, South Dakota"
}
}"#,
)?;
index_writer.add_document(doc)?;
index_writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
let field = schema.get_field("geometry").unwrap();
let query = SpatialQuery::new(
field,
[
GeoPoint {
lon: -99.49,
lat: 45.56,
},
GeoPoint {
lon: -99.45,
lat: 45.59,
},
],
tantivy::query::SpatialQueryType::Intersects,
);
let hits = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
for (_score, doc_address) in &hits {
let retrieved_doc: TantivyDocument = searcher.doc(*doc_address)?;
if let Some(field_value) = retrieved_doc.get_first(field) {
if let Some(geometry_box) = field_value.as_value().into_geometry() {
println!("Retrieved geometry: {:?}", geometry_box);
}
}
}
assert_eq!(hits.len(), 1);
Ok(())
}

View File

@@ -560,7 +560,7 @@ fn range_infallible(inp: &str) -> JResult<&str, UserInputLeaf> {
(
(
value((), tag(">=")),
map(word_infallible(")", false), |(bound, err)| {
map(word_infallible("", false), |(bound, err)| {
(
(
bound
@@ -574,7 +574,7 @@ fn range_infallible(inp: &str) -> JResult<&str, UserInputLeaf> {
),
(
value((), tag("<=")),
map(word_infallible(")", false), |(bound, err)| {
map(word_infallible("", false), |(bound, err)| {
(
(
UserInputBound::Unbounded,
@@ -588,7 +588,7 @@ fn range_infallible(inp: &str) -> JResult<&str, UserInputLeaf> {
),
(
value((), tag(">")),
map(word_infallible(")", false), |(bound, err)| {
map(word_infallible("", false), |(bound, err)| {
(
(
bound
@@ -602,7 +602,7 @@ fn range_infallible(inp: &str) -> JResult<&str, UserInputLeaf> {
),
(
value((), tag("<")),
map(word_infallible(")", false), |(bound, err)| {
map(word_infallible("", false), |(bound, err)| {
(
(
UserInputBound::Unbounded,
@@ -704,11 +704,7 @@ fn regex(inp: &str) -> IResult<&str, UserInputLeaf> {
many1(alt((preceded(char('\\'), char('/')), none_of("/")))),
char('/'),
),
peek(alt((
value((), multispace1),
value((), char(')')),
value((), eof),
))),
peek(alt((multispace1, eof))),
),
|elements| UserInputLeaf::Regex {
field: None,
@@ -725,12 +721,8 @@ fn regex_infallible(inp: &str) -> JResult<&str, UserInputLeaf> {
opt_i_err(char('/'), "missing delimiter /"),
),
opt_i_err(
peek(alt((
value((), multispace1),
value((), char(')')),
value((), eof),
))),
"expected whitespace, closing parenthesis, or end of input",
peek(alt((multispace1, eof))),
"expected whitespace or end of input",
),
)(inp)
{
@@ -1331,14 +1323,6 @@ mod test {
test_parse_query_to_ast_helper("<a", "{\"*\" TO \"a\"}");
test_parse_query_to_ast_helper("<=a", "{\"*\" TO \"a\"]");
test_parse_query_to_ast_helper("<=bsd", "{\"*\" TO \"bsd\"]");
test_parse_query_to_ast_helper("(<=42)", "{\"*\" TO \"42\"]");
test_parse_query_to_ast_helper("(<=42 )", "{\"*\" TO \"42\"]");
test_parse_query_to_ast_helper("(age:>5)", "\"age\":{\"5\" TO \"*\"}");
test_parse_query_to_ast_helper(
"(title:bar AND age:>12)",
"(+\"title\":bar +\"age\":{\"12\" TO \"*\"})",
);
}
#[test]
@@ -1715,10 +1699,6 @@ mod test {
test_parse_query_to_ast_helper("foo:(A OR B)", "(?\"foo\":A ?\"foo\":B)");
test_parse_query_to_ast_helper("foo:(A* OR B*)", "(?\"foo\":A* ?\"foo\":B*)");
test_parse_query_to_ast_helper("foo:(*A OR *B)", "(?\"foo\":*A ?\"foo\":*B)");
// Regexes between parentheses
test_parse_query_to_ast_helper("foo:(/A.*/)", "\"foo\":/A.*/");
test_parse_query_to_ast_helper("foo:(/A.*/ OR /B.*/)", "(?\"foo\":/A.*/ ?\"foo\":/B.*/)");
}
#[test]

View File

@@ -66,7 +66,6 @@ impl UserInputLeaf {
}
UserInputLeaf::Range { field, .. } if field.is_none() => *field = Some(default_field),
UserInputLeaf::Set { field, .. } if field.is_none() => *field = Some(default_field),
UserInputLeaf::Regex { field, .. } if field.is_none() => *field = Some(default_field),
_ => (), // field was already set, do nothing
}
}

View File

@@ -1,4 +1,4 @@
use columnar::{Column, ColumnBlockAccessor, ColumnType, StrColumn};
use columnar::{Column, ColumnType, StrColumn};
use common::BitSet;
use rustc_hash::FxHashSet;
use serde::Serialize;
@@ -10,17 +10,16 @@ use crate::aggregation::accessor_helpers::{
};
use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations};
use crate::aggregation::bucket::{
build_segment_filter_collector, build_segment_range_collector, CompositeAggReqData,
CompositeAggregation, CompositeSourceAccessors, FilterAggReqData, HistogramAggReqData,
HistogramBounds, IncludeExcludeParam, MissingTermAggReqData, RangeAggReqData,
SegmentHistogramCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
FilterAggReqData, HistogramAggReqData, HistogramBounds, IncludeExcludeParam,
MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector, SegmentHistogramCollector,
SegmentRangeCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
TermsAggregationInternal,
};
use crate::aggregation::metric::{
build_segment_stats_collector, AverageAggregation, CardinalityAggReqData,
CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation, MaxAggregation,
MetricAggReqData, MinAggregation, SegmentCardinalityCollector, SegmentExtendedStatsCollector,
SegmentPercentilesCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData,
AverageAggregation, CardinalityAggReqData, CardinalityAggregationReq, CountAggregation,
ExtendedStatsAggregation, MaxAggregation, MetricAggReqData, MinAggregation,
SegmentCardinalityCollector, SegmentExtendedStatsCollector, SegmentPercentilesCollector,
SegmentStatsCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData,
TopHitsSegmentCollector,
};
use crate::aggregation::segment_agg_result::{
@@ -36,7 +35,6 @@ pub struct AggregationsSegmentCtx {
/// Request data for each aggregation type.
pub per_request: PerRequestAggSegCtx,
pub context: AggContextParams,
pub column_block_accessor: ColumnBlockAccessor<u64>,
}
impl AggregationsSegmentCtx {
@@ -74,12 +72,6 @@ impl AggregationsSegmentCtx {
self.per_request.filter_req_data.push(Some(Box::new(data)));
self.per_request.filter_req_data.len() - 1
}
pub(crate) fn push_composite_req_data(&mut self, data: CompositeAggReqData) -> usize {
self.per_request
.composite_req_data
.push(Some(Box::new(data)));
self.per_request.composite_req_data.len() - 1
}
#[inline]
pub(crate) fn get_term_req_data(&self, idx: usize) -> &TermsAggReqData {
@@ -116,19 +108,20 @@ impl AggregationsSegmentCtx {
.expect("range_req_data slot is empty (taken)")
}
#[inline]
pub(crate) fn get_composite_req_data(&self, idx: usize) -> &CompositeAggReqData {
self.per_request.composite_req_data[idx]
pub(crate) fn get_filter_req_data(&self, idx: usize) -> &FilterAggReqData {
self.per_request.filter_req_data[idx]
.as_deref()
.expect("composite_req_data slot is empty (taken)")
.expect("filter_req_data slot is empty (taken)")
}
// ---------- mutable getters ----------
#[inline]
pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData {
&mut self.per_request.stats_metric_req_data[idx]
pub(crate) fn get_term_req_data_mut(&mut self, idx: usize) -> &mut TermsAggReqData {
self.per_request.term_req_data[idx]
.as_deref_mut()
.expect("term_req_data slot is empty (taken)")
}
#[inline]
pub(crate) fn get_cardinality_req_data_mut(
&mut self,
@@ -136,7 +129,10 @@ impl AggregationsSegmentCtx {
) -> &mut CardinalityAggReqData {
&mut self.per_request.cardinality_req_data[idx]
}
#[inline]
pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData {
&mut self.per_request.stats_metric_req_data[idx]
}
#[inline]
pub(crate) fn get_histogram_req_data_mut(&mut self, idx: usize) -> &mut HistogramAggReqData {
self.per_request.histogram_req_data[idx]
@@ -146,6 +142,21 @@ impl AggregationsSegmentCtx {
// ---------- take / put (terms, histogram, range) ----------
/// Move out the boxed Terms request at `idx`, leaving `None`.
#[inline]
pub(crate) fn take_term_req_data(&mut self, idx: usize) -> Box<TermsAggReqData> {
self.per_request.term_req_data[idx]
.take()
.expect("term_req_data slot is empty (taken)")
}
/// Put back a Terms request into an empty slot at `idx`.
#[inline]
pub(crate) fn put_back_term_req_data(&mut self, idx: usize, value: Box<TermsAggReqData>) {
debug_assert!(self.per_request.term_req_data[idx].is_none());
self.per_request.term_req_data[idx] = Some(value);
}
/// Move out the boxed Histogram request at `idx`, leaving `None`.
#[inline]
pub(crate) fn take_histogram_req_data(&mut self, idx: usize) -> Box<HistogramAggReqData> {
@@ -194,25 +205,6 @@ impl AggregationsSegmentCtx {
debug_assert!(self.per_request.filter_req_data[idx].is_none());
self.per_request.filter_req_data[idx] = Some(value);
}
/// Move out the Composite request at `idx`.
#[inline]
pub(crate) fn take_composite_req_data(&mut self, idx: usize) -> Box<CompositeAggReqData> {
self.per_request.composite_req_data[idx]
.take()
.expect("composite_req_data slot is empty (taken)")
}
/// Put back a Composite request into an empty slot at `idx`.
#[inline]
pub(crate) fn put_back_composite_req_data(
&mut self,
idx: usize,
value: Box<CompositeAggReqData>,
) {
debug_assert!(self.per_request.composite_req_data[idx].is_none());
self.per_request.composite_req_data[idx] = Some(value);
}
}
/// Each type of aggregation has its own request data struct. This struct holds
@@ -240,8 +232,6 @@ pub struct PerRequestAggSegCtx {
pub top_hits_req_data: Vec<TopHitsAggReqData>,
/// MissingTermAggReqData contains the request data for a missing term aggregation.
pub missing_term_req_data: Vec<MissingTermAggReqData>,
/// CompositeAggReqData contains the request data for a composite aggregation.
pub composite_req_data: Vec<Option<Box<CompositeAggReqData>>>,
/// Request tree used to build collectors.
pub agg_tree: Vec<AggRefNode>,
@@ -289,11 +279,6 @@ impl PerRequestAggSegCtx {
.iter()
.map(|t| t.get_memory_consumption())
.sum::<usize>()
+ self
.composite_req_data
.iter()
.map(|b| b.as_ref().map(|d| d.get_memory_consumption()).unwrap_or(0))
.sum::<usize>()
+ self.agg_tree.len() * std::mem::size_of::<AggRefNode>()
}
@@ -330,17 +315,11 @@ impl PerRequestAggSegCtx {
.expect("filter_req_data slot is empty (taken)")
.name
.as_str(),
AggKind::Composite => self.composite_req_data[idx]
.as_deref()
.expect("composite_req_data slot is empty (taken)")
.name
.as_str(),
}
}
/// Convert the aggregation tree into a serializable struct representation.
/// Each node contains: { name, kind, children }.
#[allow(dead_code)]
pub fn get_view_tree(&self) -> Vec<AggTreeViewNode> {
fn node_to_view(node: &AggRefNode, pr: &PerRequestAggSegCtx) -> AggTreeViewNode {
let mut children: Vec<AggTreeViewNode> =
@@ -366,19 +345,12 @@ impl PerRequestAggSegCtx {
pub(crate) fn build_segment_agg_collectors_root(
req: &mut AggregationsSegmentCtx,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
build_segment_agg_collectors_generic(req, &req.per_request.agg_tree.clone())
build_segment_agg_collectors(req, &req.per_request.agg_tree.clone())
}
pub(crate) fn build_segment_agg_collectors(
req: &mut AggregationsSegmentCtx,
nodes: &[AggRefNode],
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
build_segment_agg_collectors_generic(req, nodes)
}
fn build_segment_agg_collectors_generic(
req: &mut AggregationsSegmentCtx,
nodes: &[AggRefNode],
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
let mut collectors = Vec::new();
for node in nodes.iter() {
@@ -416,8 +388,6 @@ pub(crate) fn build_segment_agg_collector(
Ok(Box::new(SegmentCardinalityCollector::from_req(
req_data.column_type,
node.idx_in_req_data,
req_data.accessor.clone(),
req_data.missing_value_for_accessor,
)))
}
AggKind::StatsKind(stats_type) => {
@@ -428,21 +398,20 @@ pub(crate) fn build_segment_agg_collector(
| StatsType::Count
| StatsType::Max
| StatsType::Min
| StatsType::Stats => build_segment_stats_collector(req_data),
StatsType::ExtendedStats(sigma) => Ok(Box::new(
SegmentExtendedStatsCollector::from_req(req_data, sigma),
)),
StatsType::Percentiles => {
let req_data = req.get_metric_req_data_mut(node.idx_in_req_data);
Ok(Box::new(
SegmentPercentilesCollector::from_req_and_validate(
req_data.field_type,
req_data.missing_u64,
req_data.accessor.clone(),
node.idx_in_req_data,
),
))
| StatsType::Stats => Ok(Box::new(SegmentStatsCollector::from_req(
node.idx_in_req_data,
))),
StatsType::ExtendedStats(sigma) => {
Ok(Box::new(SegmentExtendedStatsCollector::from_req(
req_data.field_type,
sigma,
node.idx_in_req_data,
req_data.missing,
)))
}
StatsType::Percentiles => Ok(Box::new(
SegmentPercentilesCollector::from_req_and_validate(node.idx_in_req_data)?,
)),
}
}
AggKind::TopHits => {
@@ -459,13 +428,12 @@ pub(crate) fn build_segment_agg_collector(
AggKind::DateHistogram => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
req, node,
)?)),
AggKind::Range => Ok(build_segment_range_collector(req, node)?),
AggKind::Filter => build_segment_filter_collector(req, node),
AggKind::Composite => Ok(Box::new(
crate::aggregation::bucket::SegmentCompositeCollector::from_req_and_validate(
req, node,
)?,
)),
AggKind::Range => Ok(Box::new(SegmentRangeCollector::from_req_and_validate(
req, node,
)?)),
AggKind::Filter => Ok(Box::new(SegmentFilterCollector::from_req_and_validate(
req, node,
)?)),
}
}
@@ -496,7 +464,6 @@ pub enum AggKind {
DateHistogram,
Range,
Filter,
Composite,
}
impl AggKind {
@@ -512,7 +479,6 @@ impl AggKind {
AggKind::DateHistogram => "DateHistogram",
AggKind::Range => "Range",
AggKind::Filter => "Filter",
AggKind::Composite => "Composite",
}
}
}
@@ -527,7 +493,6 @@ pub(crate) fn build_aggregations_data_from_req(
let mut data = AggregationsSegmentCtx {
per_request: Default::default(),
context,
column_block_accessor: ColumnBlockAccessor::default(),
};
for (name, agg) in aggs.iter() {
@@ -556,9 +521,9 @@ fn build_nodes(
let idx_in_req_data = data.push_range_req_data(RangeAggReqData {
accessor,
field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
req: range_req.clone(),
is_top_level,
});
let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?;
Ok(vec![AggRefNode {
@@ -576,7 +541,9 @@ fn build_nodes(
let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData {
accessor,
field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
sub_aggregation_blueprint: None,
req: histo_req.clone(),
is_date_histogram: false,
bounds: HistogramBounds {
@@ -601,7 +568,9 @@ fn build_nodes(
let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData {
accessor,
field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
sub_aggregation_blueprint: None,
req: histo_req,
is_date_histogram: true,
bounds: HistogramBounds {
@@ -681,6 +650,7 @@ fn build_nodes(
let idx_in_req_data = data.push_metric_req_data(MetricAggReqData {
accessor,
field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
collecting_for,
missing: *missing,
@@ -708,6 +678,7 @@ fn build_nodes(
let idx_in_req_data = data.push_metric_req_data(MetricAggReqData {
accessor,
field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
collecting_for: StatsType::Percentiles,
missing: percentiles_req.missing,
@@ -760,14 +731,6 @@ fn build_nodes(
children,
}])
}
AggregationVariants::Composite(composite_req) => Ok(vec![build_composite_node(
agg_name,
reader,
segment_ordinal,
data,
&req.sub_aggregation,
composite_req,
)?]),
AggregationVariants::Filter(filter_req) => {
// Build the query and evaluator upfront
let schema = reader.schema();
@@ -790,7 +753,6 @@ fn build_nodes(
segment_reader: reader.clone(),
evaluator,
matching_docs_buffer,
is_top_level,
});
let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?;
Ok(vec![AggRefNode {
@@ -802,35 +764,6 @@ fn build_nodes(
}
}
fn build_composite_node(
agg_name: &str,
reader: &SegmentReader,
_segment_ordinal: SegmentOrdinal,
data: &mut AggregationsSegmentCtx,
sub_aggs: &Aggregations,
req: &CompositeAggregation,
) -> crate::Result<AggRefNode> {
let mut composite_accessors = Vec::with_capacity(req.sources.len());
for source in &req.sources {
let source_after_key_opt = req.after.get(source.name()).map(|k| &k.0);
let source_accessor =
CompositeSourceAccessors::build_for_source(reader, source, source_after_key_opt)?;
composite_accessors.push(source_accessor);
}
let agg = CompositeAggReqData {
name: agg_name.to_string(),
req: req.clone(),
composite_accessors,
};
let idx = data.push_composite_req_data(agg);
let children = build_children(sub_aggs, reader, _segment_ordinal, data)?;
Ok(AggRefNode {
kind: AggKind::Composite,
idx_in_req_data: idx,
children,
})
}
fn build_children(
aggs: &Aggregations,
reader: &SegmentReader,
@@ -962,7 +895,7 @@ fn build_terms_or_cardinality_nodes(
});
}
// Add one node per accessor
// Add one node per accessor to mirror previous behavior and allow per-type missing handling.
for (accessor, column_type) in column_and_types {
let missing_value_for_accessor = if use_special_missing_agg {
None
@@ -993,8 +926,11 @@ fn build_terms_or_cardinality_nodes(
column_type,
str_dict_column: str_dict_column.clone(),
missing_value_for_accessor,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
req: TermsAggregationInternal::from_req(req),
// Will be filled later when building collectors
sub_aggregation_blueprint: None,
sug_aggregations: sub_aggs.clone(),
allowed_term_ids,
is_top_level,
@@ -1007,6 +943,7 @@ fn build_terms_or_cardinality_nodes(
column_type,
str_dict_column: str_dict_column.clone(),
missing_value_for_accessor,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
req: req.clone(),
});

View File

@@ -32,8 +32,8 @@ use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use super::bucket::{
CompositeAggregation, DateHistogramAggregationReq, FilterAggregation, HistogramAggregation,
RangeAggregation, TermsAggregation,
DateHistogramAggregationReq, FilterAggregation, HistogramAggregation, RangeAggregation,
TermsAggregation,
};
use super::metric::{
AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation,
@@ -134,9 +134,6 @@ pub enum AggregationVariants {
/// Filter documents into a single bucket.
#[serde(rename = "filter")]
Filter(FilterAggregation),
/// Multi-dimensional, paginable bucket aggregation.
#[serde(rename = "composite")]
Composite(CompositeAggregation),
// Metric aggregation types
/// Computes the average of the extracted values.
@@ -183,11 +180,6 @@ impl AggregationVariants {
AggregationVariants::Histogram(histogram) => vec![histogram.field.as_str()],
AggregationVariants::DateHistogram(histogram) => vec![histogram.field.as_str()],
AggregationVariants::Filter(filter) => filter.get_fast_field_names(),
AggregationVariants::Composite(composite) => composite
.sources
.iter()
.map(|source| source.field())
.collect(),
AggregationVariants::Average(avg) => vec![avg.field_name()],
AggregationVariants::Count(count) => vec![count.field_name()],
AggregationVariants::Max(max) => vec![max.field_name()],
@@ -222,12 +214,6 @@ impl AggregationVariants {
_ => None,
}
}
pub(crate) fn as_composite(&self) -> Option<&CompositeAggregation> {
match &self {
AggregationVariants::Composite(composite) => Some(composite),
_ => None,
}
}
pub(crate) fn as_percentile(&self) -> Option<&PercentilesAggregationReq> {
match &self {
AggregationVariants::Percentiles(percentile_req) => Some(percentile_req),

View File

@@ -9,12 +9,10 @@ use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use super::bucket::GetDocCount;
use super::intermediate_agg_result::CompositeIntermediateKey;
use super::metric::{
ExtendedStats, PercentilesMetricResult, SingleMetricResult, Stats, TopHitsMetricResult,
};
use super::{AggregationError, Key};
use crate::aggregation::bucket::AfterKey;
use crate::TantivyError;
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
@@ -160,14 +158,6 @@ pub enum BucketResult {
},
/// This is the filter result - a single bucket with sub-aggregations
Filter(FilterBucketResult),
/// This is the composite result
Composite {
/// The buckets
buckets: Vec<CompositeBucketEntry>,
/// The key to start after when paginating
#[serde(skip_serializing_if = "FxHashMap::is_empty")]
after_key: FxHashMap<String, AfterKey>,
},
}
impl BucketResult {
@@ -189,9 +179,6 @@ impl BucketResult {
// Only count sub-aggregation buckets
filter_result.sub_aggregations.get_bucket_count()
}
BucketResult::Composite { buckets, .. } => {
buckets.iter().map(|bucket| bucket.get_bucket_count()).sum()
}
}
}
}
@@ -350,87 +337,3 @@ pub struct FilterBucketResult {
#[serde(flatten)]
pub sub_aggregations: AggregationResults,
}
/// Note the type information loss compared to `CompositeIntermediateKey`.
/// Pagination is performed using `AfterKey`, which encodes type information.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum CompositeKey {
/// Boolean key
Bool(bool),
/// String key
Str(String),
/// `i64` key
I64(i64),
/// `u64` key
U64(u64),
/// `f64` key
F64(f64),
/// Null key
Null,
}
impl Eq for CompositeKey {}
impl std::hash::Hash for CompositeKey {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
core::mem::discriminant(self).hash(state);
match self {
Self::Bool(val) => val.hash(state),
Self::Str(text) => text.hash(state),
Self::F64(val) => val.to_bits().hash(state),
Self::U64(val) => val.hash(state),
Self::I64(val) => val.hash(state),
Self::Null => {}
}
}
}
impl PartialEq for CompositeKey {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::Bool(l), Self::Bool(r)) => l == r,
(Self::Str(l), Self::Str(r)) => l == r,
(Self::F64(l), Self::F64(r)) => l.to_bits() == r.to_bits(),
(Self::I64(l), Self::I64(r)) => l == r,
(Self::U64(l), Self::U64(r)) => l == r,
(Self::Null, Self::Null) => true,
_ => false,
}
}
}
impl From<CompositeIntermediateKey> for CompositeKey {
fn from(value: CompositeIntermediateKey) -> Self {
match value {
CompositeIntermediateKey::Str(s) => Self::Str(s),
CompositeIntermediateKey::IpAddr(s) => {
if let Some(ip) = s.to_ipv4_mapped() {
Self::Str(ip.to_string())
} else {
Self::Str(s.to_string())
}
}
CompositeIntermediateKey::F64(f) => Self::F64(f),
CompositeIntermediateKey::Bool(f) => Self::Bool(f),
CompositeIntermediateKey::U64(f) => Self::U64(f),
CompositeIntermediateKey::I64(f) => Self::I64(f),
CompositeIntermediateKey::DateTime(f) => Self::I64(f / 1_000_000), // ns to ms
CompositeIntermediateKey::Null => Self::Null,
}
}
}
/// Composite bucket entry with a multi-dimensional key.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CompositeBucketEntry {
/// The identifier of the bucket.
pub key: FxHashMap<String, CompositeKey>,
/// Number of documents in the bucket.
pub doc_count: u64,
#[serde(flatten)]
/// Sub-aggregations in this bucket.
pub sub_aggregation: AggregationResults,
}
impl CompositeBucketEntry {
pub(crate) fn get_bucket_count(&self) -> u64 {
1 + self.sub_aggregation.get_bucket_count()
}
}

View File

@@ -2,441 +2,15 @@ use serde_json::Value;
use crate::aggregation::agg_req::{Aggregation, Aggregations};
use crate::aggregation::agg_result::AggregationResults;
use crate::aggregation::buf_collector::DOC_BLOCK_SIZE;
use crate::aggregation::collector::AggregationCollector;
use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults;
use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms};
use crate::aggregation::DistributedAggregationCollector;
use crate::docset::COLLECT_BLOCK_BUFFER_LEN;
use crate::query::{AllQuery, TermQuery};
use crate::schema::{IndexRecordOption, Schema, FAST};
use crate::{Index, IndexWriter, Term};
// The following tests ensure that each bucket aggregation type correctly functions as a
// sub-aggregation of another bucket aggregation in two scenarios:
// 1) The parent has more buckets than the child sub-aggregation
// 2) The child sub-aggregation has more buckets than the parent
//
// These scenarios exercise the bucket id mapping and sub-aggregation routing logic.
#[test]
fn test_terms_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with 4 buckets
// Child: terms on text -> 2 buckets
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
// Exact expected structure and counts
assert_eq!(
res["parent_range"]["buckets"],
json!([
{
"key": "*-3",
"doc_count": 1,
"to": 3.0,
"child_terms": {
"buckets": [
{"doc_count": 1, "key": "cool"}
],
"sum_other_doc_count": 0
}
},
{
"key": "3-7",
"doc_count": 3,
"from": 3.0,
"to": 7.0,
"child_terms": {
"buckets": [
{"doc_count": 2, "key": "cool"},
{"doc_count": 1, "key": "nohit"}
],
"sum_other_doc_count": 0
}
},
{
"key": "7-20",
"doc_count": 3,
"from": 7.0,
"to": 20.0,
"child_terms": {
"buckets": [
{"doc_count": 3, "key": "cool"}
],
"sum_other_doc_count": 0
}
},
{
"key": "20-*",
"doc_count": 2,
"from": 20.0,
"child_terms": {
"buckets": [
{"doc_count": 1, "key": "cool"},
{"doc_count": 1, "key": "nohit"}
],
"sum_other_doc_count": 0
}
}
])
);
// Case B: child has more buckets than parent
// Parent: histogram on score with large interval -> 1 bucket
// Child: terms on text -> 2 buckets (cool/nohit)
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_hist": {
"histogram": {"field": "score", "interval": 100.0},
"aggs": {
"child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
assert_eq!(
res["parent_hist"],
json!({
"buckets": [
{
"key": 0.0,
"doc_count": 9,
"child_terms": {
"buckets": [
{"doc_count": 7, "key": "cool"},
{"doc_count": 2, "key": "nohit"}
],
"sum_other_doc_count": 0
}
}
]
})
);
Ok(())
}
#[test]
fn test_range_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with 5 buckets
// Child: coarse range with 3 buckets
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 11.0},
{"from": 11.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 20.0}
]
}
}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
assert_eq!(
res["parent_range"]["buckets"],
json!([
{"key": "*-3", "doc_count": 1, "to": 3.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 1, "to": 3.0},
{"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 3, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 1, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 2, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "20-*", "doc_count": 2, "from": 20.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 2, "from": 20.0}
]}
}
])
);
// Case B: child has more buckets than parent
// Parent: terms on text (2 buckets)
// Child: range with 4 buckets
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_terms": {
"terms": {"field": "text"},
"aggs": {
"child_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 20.0}
]
}
}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
assert_eq!(
res["parent_terms"],
json!({
"buckets": [
{
"key": "cool",
"doc_count": 7,
"child_range": {
"buckets": [
{"key": "*-3", "doc_count": 1, "to": 3.0},
{"key": "3-7", "doc_count": 2, "from": 3.0, "to": 7.0},
{"key": "7-20", "doc_count": 3, "from": 7.0, "to": 20.0},
{"key": "20-*", "doc_count": 1, "from": 20.0}
]
}
},
{
"key": "nohit",
"doc_count": 2,
"child_range": {
"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-7", "doc_count": 1, "from": 3.0, "to": 7.0},
{"key": "7-20", "doc_count": 0, "from": 7.0, "to": 20.0},
{"key": "20-*", "doc_count": 1, "from": 20.0}
]
}
}
],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 0
})
);
Ok(())
}
#[test]
fn test_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with several ranges
// Child: histogram with large interval (single bucket per parent)
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 11.0},
{"from": 11.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_hist": {"histogram": {"field": "score", "interval": 100.0}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
assert_eq!(
res["parent_range"]["buckets"],
json!([
{"key": "*-3", "doc_count": 1, "to": 3.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]}
},
{"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 3} ]}
},
{"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]}
},
{"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]}
},
{"key": "20-*", "doc_count": 2, "from": 20.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]}
}
])
);
// Case B: child has more buckets than parent
// Parent: terms on text -> 2 buckets
// Child: histogram with small interval -> multiple buckets including empties
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_terms": {
"terms": {"field": "text"},
"aggs": {
"child_hist": {"histogram": {"field": "score", "interval": 10.0}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
assert_eq!(
res["parent_terms"],
json!({
"buckets": [
{
"key": "cool",
"doc_count": 7,
"child_hist": {
"buckets": [
{"key": 0.0, "doc_count": 4},
{"key": 10.0, "doc_count": 2},
{"key": 20.0, "doc_count": 0},
{"key": 30.0, "doc_count": 0},
{"key": 40.0, "doc_count": 1}
]
}
},
{
"key": "nohit",
"doc_count": 2,
"child_hist": {
"buckets": [
{"key": 0.0, "doc_count": 1},
{"key": 10.0, "doc_count": 0},
{"key": 20.0, "doc_count": 0},
{"key": 30.0, "doc_count": 0},
{"key": 40.0, "doc_count": 1}
]
}
}
],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 0
})
);
Ok(())
}
#[test]
fn test_date_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with several buckets
// Child: date_histogram with 30d -> single bucket per parent
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 11.0},
{"from": 11.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "30d"}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
let buckets = res["parent_range"]["buckets"].as_array().unwrap();
// Verify each parent bucket has exactly one child date bucket with matching doc_count
for bucket in buckets {
let parent_count = bucket["doc_count"].as_u64().unwrap();
let child_buckets = bucket["child_date_hist"]["buckets"].as_array().unwrap();
assert_eq!(child_buckets.len(), 1);
assert_eq!(child_buckets[0]["doc_count"], parent_count);
}
// Case B: child has more buckets than parent
// Parent: terms on text (2 buckets)
// Child: date_histogram with 1d -> multiple buckets
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_terms": {
"terms": {"field": "text"},
"aggs": {
"child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "1d"}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
let buckets = res["parent_terms"]["buckets"].as_array().unwrap();
// cool bucket
assert_eq!(buckets[0]["key"], "cool");
let cool_buckets = buckets[0]["child_date_hist"]["buckets"].as_array().unwrap();
assert_eq!(cool_buckets.len(), 3);
assert_eq!(cool_buckets[0]["doc_count"], 1); // day 0
assert_eq!(cool_buckets[1]["doc_count"], 4); // day 1
assert_eq!(cool_buckets[2]["doc_count"], 2); // day 2
// nohit bucket
assert_eq!(buckets[1]["key"], "nohit");
let nohit_buckets = buckets[1]["child_date_hist"]["buckets"].as_array().unwrap();
assert_eq!(nohit_buckets.len(), 2);
assert_eq!(nohit_buckets[0]["doc_count"], 1); // day 1
assert_eq!(nohit_buckets[1]["doc_count"], 1); // day 2
Ok(())
}
fn get_avg_req(field_name: &str) -> Aggregation {
serde_json::from_value(json!({
"avg": {
@@ -451,10 +25,6 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector {
}
// *** EVERY BUCKET-TYPE SHOULD BE TESTED HERE ***
// Note: The flushng part of these tests are outdated, since the buffering change after converting
// the collection into one collector per request instead of per bucket.
//
// However they are useful as they test a complex aggregation requests.
fn test_aggregation_flushing(
merge_segments: bool,
use_distributed_collector: bool,
@@ -467,9 +37,8 @@ fn test_aggregation_flushing(
let reader = index.reader()?;
assert_eq!(COLLECT_BLOCK_BUFFER_LEN, 64);
// In the tree we cache documents of COLLECT_BLOCK_BUFFER_LEN before passing them down as one
// block.
assert_eq!(DOC_BLOCK_SIZE, 64);
// In the tree we cache Documents of DOC_BLOCK_SIZE, before passing them down as one block.
//
// Build a request so that on the first level we have one full cache, which is then flushed.
// The same cache should have some residue docs at the end, which are flushed (Range 0-70)

View File

@@ -1,548 +0,0 @@
use std::net::Ipv6Addr;
use columnar::column_values::CompactSpaceU64Accessor;
use columnar::{Column, ColumnType, MonotonicallyMappableToU64, StrColumn, TermOrdHit};
use crate::aggregation::accessor_helpers::get_numeric_or_date_column_types;
use crate::aggregation::bucket::composite::numeric_types::num_proj;
use crate::aggregation::bucket::composite::numeric_types::num_proj::ProjectedNumber;
use crate::aggregation::bucket::composite::ToTypePaginationOrder;
use crate::aggregation::bucket::{
parse_into_milliseconds, CalendarInterval, CompositeAggregation, CompositeAggregationSource,
MissingOrder, Order,
};
use crate::aggregation::intermediate_agg_result::CompositeIntermediateKey;
use crate::{SegmentReader, TantivyError};
/// Contains all information required by the SegmentCompositeCollector to perform the
/// composite aggregation on a segment.
pub struct CompositeAggReqData {
/// The name of the aggregation.
pub name: String,
/// The normalized term aggregation request.
pub req: CompositeAggregation,
/// Accessors for each source, each source can have multiple accessors (columns).
pub composite_accessors: Vec<CompositeSourceAccessors>,
}
impl CompositeAggReqData {
/// Estimate the memory consumption of this struct in bytes.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
+ self.composite_accessors.len() * std::mem::size_of::<CompositeSourceAccessors>()
}
}
/// Accessors for a single column in a composite source.
pub struct CompositeAccessor {
/// The fast field column
pub column: Column<u64>,
/// The column type
pub column_type: ColumnType,
/// Term dictionary if the column type is Str
///
/// Only used by term sources
pub str_dict_column: Option<StrColumn>,
/// Parsed date interval for date histogram sources
pub date_histogram_interval: PrecomputedDateInterval,
}
/// Accessors to all the columns that belong to the field of a composite source.
pub struct CompositeSourceAccessors {
/// The accessors for this source
pub accessors: Vec<CompositeAccessor>,
/// The key after which to start collecting results. Applies to the first
/// column of the source.
pub after_key: PrecomputedAfterKey,
/// The column index the after_key applies to. The after_key only applies to
/// one column. Columns before should be skipped. Columns after should be
/// kept without comparison to the after_key.
pub after_key_accessor_idx: usize,
/// Whether to skip missing values because of the after_key. Skipping only
/// applies if the value for previous columns were exactly equal to the
/// corresponding after keys (is_on_after_key).
pub skip_missing: bool,
/// The after key was set to null to indicate that the last collected key
/// was a missing value.
pub is_after_key_explicit_missing: bool,
}
impl CompositeSourceAccessors {
/// Creates a new set of accessors for the composite source.
///
/// Precomputes some values to make collection faster.
pub fn build_for_source(
reader: &SegmentReader,
source: &CompositeAggregationSource,
// First option is None when no after key was set in the query, the
// second option is None when the after key was set but its value for
// this source was set to `null`
source_after_key_opt: Option<&CompositeIntermediateKey>,
) -> crate::Result<Self> {
let is_after_key_explicit_missing = source_after_key_opt
.map(|after_key| matches!(after_key, CompositeIntermediateKey::Null))
.unwrap_or(false);
let mut skip_missing = false;
if let Some(CompositeIntermediateKey::Null) = source_after_key_opt {
if !source.missing_bucket() {
return Err(TantivyError::InvalidArgument(
"the 'after' key for a source cannot be null when 'missing_bucket' is false"
.to_string(),
));
}
} else if source_after_key_opt.is_some() {
// if missing buckets come first and we have a non null after key, we skip missing
if MissingOrder::First == source.missing_order() {
skip_missing = true;
}
if MissingOrder::Default == source.missing_order() && Order::Asc == source.order() {
skip_missing = true;
}
};
match source {
CompositeAggregationSource::Terms(source) => {
let allowed_column_types = [
ColumnType::I64,
ColumnType::U64,
ColumnType::F64,
ColumnType::Str,
ColumnType::DateTime,
ColumnType::Bool,
ColumnType::IpAddr,
// ColumnType::Bytes Unsupported
];
let mut columns_and_types = reader
.fast_fields()
.u64_lenient_for_type_all(Some(&allowed_column_types), &source.field)?;
// Sort columns by their pagination order and determine which to skip
columns_and_types.sort_by_key(|(_, col_type): &(Column, ColumnType)| {
col_type.column_pagination_order()
});
if source.order == Order::Desc {
columns_and_types.reverse();
}
let after_key_accessor_idx = find_first_column_to_collect(
&columns_and_types,
source_after_key_opt,
source.missing_order,
source.order,
)?;
let source_collectors: Vec<CompositeAccessor> = columns_and_types
.into_iter()
.map(|(column, column_type)| {
Ok(CompositeAccessor {
column,
column_type,
str_dict_column: reader.fast_fields().str(&source.field)?,
date_histogram_interval: PrecomputedDateInterval::NotApplicable,
})
})
.collect::<crate::Result<_>>()?;
let after_key = if let Some(first_col) =
source_collectors.get(after_key_accessor_idx)
{
match source_after_key_opt {
Some(after_key) => PrecomputedAfterKey::precompute(
&first_col,
after_key,
&source.field,
source.missing_order,
source.order,
)?,
None => {
precompute_missing_after_key(false, source.missing_order, source.order)
}
}
} else {
// if no columns, we don't care about the after_key
PrecomputedAfterKey::Next(0)
};
Ok(CompositeSourceAccessors {
accessors: source_collectors,
is_after_key_explicit_missing,
skip_missing,
after_key,
after_key_accessor_idx,
})
}
CompositeAggregationSource::Histogram(source) => {
let column_and_types: Vec<(Column, ColumnType)> =
reader.fast_fields().u64_lenient_for_type_all(
Some(get_numeric_or_date_column_types()),
&source.field,
)?;
let source_collectors: Vec<CompositeAccessor> = column_and_types
.into_iter()
.map(|(column, column_type)| {
Ok(CompositeAccessor {
column,
column_type,
str_dict_column: None,
date_histogram_interval: PrecomputedDateInterval::NotApplicable,
})
})
.collect::<crate::Result<_>>()?;
let after_key = match source_after_key_opt {
Some(CompositeIntermediateKey::F64(key)) => {
let normalized_key = *key / source.interval;
num_proj::f64_to_i64(normalized_key).into()
}
Some(CompositeIntermediateKey::Null) => {
precompute_missing_after_key(true, source.missing_order, source.order)
}
None => precompute_missing_after_key(true, source.missing_order, source.order),
_ => {
return Err(crate::TantivyError::InvalidArgument(
"After key type invalid for interval composite source".to_string(),
));
}
};
Ok(CompositeSourceAccessors {
accessors: source_collectors,
is_after_key_explicit_missing,
skip_missing,
after_key,
after_key_accessor_idx: 0,
})
}
CompositeAggregationSource::DateHistogram(source) => {
let column_and_types = reader
.fast_fields()
.u64_lenient_for_type_all(Some(&[ColumnType::DateTime]), &source.field)?;
let date_histogram_interval =
PrecomputedDateInterval::from_date_histogram_source_intervals(
&source.fixed_interval,
source.calendar_interval,
)?;
let source_collectors: Vec<CompositeAccessor> = column_and_types
.into_iter()
.map(|(column, column_type)| {
Ok(CompositeAccessor {
column,
column_type,
str_dict_column: None,
date_histogram_interval,
})
})
.collect::<crate::Result<_>>()?;
let after_key = match source_after_key_opt {
Some(CompositeIntermediateKey::DateTime(key)) => {
PrecomputedAfterKey::Exact(key.to_u64())
}
Some(CompositeIntermediateKey::Null) => {
precompute_missing_after_key(true, source.missing_order, source.order)
}
None => precompute_missing_after_key(true, source.missing_order, source.order),
_ => {
return Err(crate::TantivyError::InvalidArgument(
"After key type invalid for interval composite source".to_string(),
));
}
};
Ok(CompositeSourceAccessors {
accessors: source_collectors,
is_after_key_explicit_missing,
skip_missing,
after_key,
after_key_accessor_idx: 0,
})
}
}
}
}
/// Finds the index of the first column we should start collecting from to
/// resume the pagination from the after_key.
fn find_first_column_to_collect<T>(
sorted_columns: &[(T, ColumnType)],
after_key_opt: Option<&CompositeIntermediateKey>,
missing_order: MissingOrder,
order: Order,
) -> crate::Result<usize> {
let after_key = match after_key_opt {
None => return Ok(0), // No pagination, start from beginning
Some(key) => key,
};
// Handle null after_key (we were on a missing value last time)
if matches!(after_key, CompositeIntermediateKey::Null) {
return match (missing_order, order) {
// Missing values come first, so all columns remain
(MissingOrder::First, _) | (MissingOrder::Default, Order::Asc) => Ok(0),
// Missing values come last, so all columns are done
(MissingOrder::Last, _) | (MissingOrder::Default, Order::Desc) => {
Ok(sorted_columns.len())
}
};
}
// Find the first column whose type order matches or follows the after_key's
// type in the pagination sequence
let after_key_column_order = after_key.column_pagination_order();
for (idx, (_, col_type)) in sorted_columns.iter().enumerate() {
let col_order = col_type.column_pagination_order();
let is_first_to_collect = match order {
Order::Asc => col_order >= after_key_column_order,
Order::Desc => col_order <= after_key_column_order,
};
if is_first_to_collect {
return Ok(idx);
}
}
// All columns are before the after_key, nothing left to collect
Ok(sorted_columns.len())
}
fn precompute_missing_after_key(
is_after_key_explicit_missing: bool,
missing_order: MissingOrder,
order: Order,
) -> PrecomputedAfterKey {
let after_last = PrecomputedAfterKey::AfterLast;
let before_first = PrecomputedAfterKey::Next(0);
match (is_after_key_explicit_missing, missing_order, order) {
(true, MissingOrder::First, Order::Asc) => before_first,
(true, MissingOrder::First, Order::Desc) => after_last,
(true, MissingOrder::Last, Order::Asc) => after_last,
(true, MissingOrder::Last, Order::Desc) => before_first,
(true, MissingOrder::Default, Order::Asc) => before_first,
(true, MissingOrder::Default, Order::Desc) => after_last,
(false, _, Order::Asc) => before_first,
(false, _, Order::Desc) => after_last,
}
}
/// A parsed representation of the date interval for date histogram sources
#[derive(Clone, Copy, Debug)]
pub enum PrecomputedDateInterval {
/// This is not a date histogram source
NotApplicable,
/// Source was configured with a fixed interval
FixedNanoseconds(i64),
/// Source was configured with a calendar interval
Calendar(CalendarInterval),
}
impl PrecomputedDateInterval {
/// Validates the date histogram source interval fields and parses a date interval from them.
pub fn from_date_histogram_source_intervals(
fixed_interval: &Option<String>,
calendar_interval: Option<CalendarInterval>,
) -> crate::Result<Self> {
match (fixed_interval, calendar_interval) {
(Some(_), Some(_)) | (None, None) => Err(TantivyError::InvalidArgument(
"date histogram source must one and only one of fixed_interval or \
calendar_interval set"
.to_string(),
)),
(Some(fixed_interval), None) => {
let fixed_interval_ms = parse_into_milliseconds(&fixed_interval)?;
Ok(PrecomputedDateInterval::FixedNanoseconds(
fixed_interval_ms * 1_000_000,
))
}
(None, Some(calendar_interval)) => {
Ok(PrecomputedDateInterval::Calendar(calendar_interval))
}
}
}
}
/// The after key projected to the u64 column space
///
/// Some column types (term, IP) might not have an exact representation of the
/// specified after key
#[derive(Debug)]
pub enum PrecomputedAfterKey {
/// The after key could be exactly represented in the column space.
Exact(u64),
/// The after key could not be exactly represented exactly represented, so
/// this is the next closest one.
Next(u64),
/// The after key could not be represented in the column space, it is
/// greater than all value
AfterLast,
}
impl From<TermOrdHit> for PrecomputedAfterKey {
fn from(hit: TermOrdHit) -> Self {
match hit {
TermOrdHit::Exact(ord) => PrecomputedAfterKey::Exact(ord),
// TermOrdHit represents AfterLast as Next(u64::MAX), we keep it as is
TermOrdHit::Next(ord) => PrecomputedAfterKey::Next(ord),
}
}
}
impl<T: MonotonicallyMappableToU64> From<ProjectedNumber<T>> for PrecomputedAfterKey {
fn from(num: ProjectedNumber<T>) -> Self {
match num {
ProjectedNumber::Exact(number) => PrecomputedAfterKey::Exact(number.to_u64()),
ProjectedNumber::Next(number) => PrecomputedAfterKey::Next(number.to_u64()),
ProjectedNumber::AfterLast => PrecomputedAfterKey::AfterLast,
}
}
}
// /!\ These operators only makes sense if both values are in the same column space
impl PrecomputedAfterKey {
pub fn equals(&self, column_value: u64) -> bool {
match self {
PrecomputedAfterKey::Exact(v) => *v == column_value,
PrecomputedAfterKey::Next(_) => false,
PrecomputedAfterKey::AfterLast => false,
}
}
pub fn gt(&self, column_value: u64) -> bool {
match self {
PrecomputedAfterKey::Exact(v) => *v > column_value,
PrecomputedAfterKey::Next(v) => *v > column_value,
PrecomputedAfterKey::AfterLast => true,
}
}
pub fn lt(&self, column_value: u64) -> bool {
match self {
PrecomputedAfterKey::Exact(v) => *v < column_value,
// a value equal to the next is greater than the after key
PrecomputedAfterKey::Next(v) => *v <= column_value,
PrecomputedAfterKey::AfterLast => false,
}
}
fn precompute_ip_addr(column: &Column<u64>, key: &Ipv6Addr) -> crate::Result<Self> {
// For IP addresses we need to find the compact space value.
// We try to convert via the column's min/max range scan.
// Since CompactSpaceU64Accessor::u128_to_compact is not public,
// we search linearly for the exact u64 value by scanning column values.
let ip_u128 = key.to_bits();
// Scan for matching value - IP columns are typically small
let num_vals = column.values.num_vals();
let mut found_exact = false;
let mut exact_compact = 0u64;
let mut best_next: Option<u64> = None;
for doc_id in 0..num_vals {
let val = column.values.get_val(doc_id);
// We need the CompactSpaceU64Accessor to convert compact to u128
let compact_accessor = column
.values
.clone()
.downcast_arc::<CompactSpaceU64Accessor>()
.map_err(|_| {
TantivyError::AggregationError(
crate::aggregation::AggregationError::InternalError(
"type mismatch: could not downcast to CompactSpaceU64Accessor"
.to_string(),
),
)
})?;
let val_u128 = compact_accessor.compact_to_u128(val as u32);
if val_u128 == ip_u128 {
found_exact = true;
exact_compact = val;
break;
} else if val_u128 > ip_u128 {
match best_next {
None => best_next = Some(val),
Some(current_best) => {
let current_u128 = compact_accessor.compact_to_u128(current_best as u32);
if val_u128 < current_u128 {
best_next = Some(val);
}
}
}
}
}
if found_exact {
Ok(PrecomputedAfterKey::Exact(exact_compact))
} else if let Some(next) = best_next {
Ok(PrecomputedAfterKey::Next(next))
} else {
Ok(PrecomputedAfterKey::AfterLast)
}
}
fn precompute_term_ord(
str_dict_column: &Option<StrColumn>,
key: &str,
field: &str,
) -> crate::Result<Self> {
let dict = str_dict_column
.as_ref()
.expect("dictionary missing for str accessor")
.dictionary();
let next_ord = dict.term_ord_or_next(key).map_err(|_| {
TantivyError::InvalidArgument(format!(
"failed to lookup after_key '{}' for field '{}'",
key, field
))
})?;
Ok(next_ord.into())
}
/// Projects the after key into the column space of the given accessor.
///
/// The computed after key will not take care of skipping entire columns
/// when the after key type is ordered after the accessor's type, that
/// should be performed earlier.
pub fn precompute(
composite_accessor: &CompositeAccessor,
source_after_key: &CompositeIntermediateKey,
field: &str,
missing_order: MissingOrder,
order: Order,
) -> crate::Result<Self> {
use CompositeIntermediateKey as CIKey;
let precomputed_key = match (composite_accessor.column_type, source_after_key) {
(ColumnType::Bytes, _) => panic!("unsupported"),
// null after key
(_, CIKey::Null) => precompute_missing_after_key(false, missing_order, order),
// numerical
(ColumnType::I64, CIKey::I64(k)) => PrecomputedAfterKey::Exact(k.to_u64()),
(ColumnType::I64, CIKey::U64(k)) => num_proj::u64_to_i64(*k).into(),
(ColumnType::I64, CIKey::F64(k)) => num_proj::f64_to_i64(*k).into(),
(ColumnType::U64, CIKey::I64(k)) => num_proj::i64_to_u64(*k).into(),
(ColumnType::U64, CIKey::U64(k)) => PrecomputedAfterKey::Exact(*k),
(ColumnType::U64, CIKey::F64(k)) => num_proj::f64_to_u64(*k).into(),
(ColumnType::F64, CIKey::I64(k)) => num_proj::i64_to_f64(*k).into(),
(ColumnType::F64, CIKey::U64(k)) => num_proj::u64_to_f64(*k).into(),
(ColumnType::F64, CIKey::F64(k)) => PrecomputedAfterKey::Exact(k.to_u64()),
// boolean
(ColumnType::Bool, CIKey::Bool(key)) => PrecomputedAfterKey::Exact(key.to_u64()),
// string
(ColumnType::Str, CIKey::Str(key)) => PrecomputedAfterKey::precompute_term_ord(
&composite_accessor.str_dict_column,
key,
field,
)?,
// date time
(ColumnType::DateTime, CIKey::DateTime(key)) => {
PrecomputedAfterKey::Exact(key.to_u64())
}
// ip address
(ColumnType::IpAddr, CIKey::IpAddr(key)) => {
PrecomputedAfterKey::precompute_ip_addr(&composite_accessor.column, key)?
}
// assume the column's type is ordered after the after_key's type
_ => PrecomputedAfterKey::keep_all(order),
};
Ok(precomputed_key)
}
fn keep_all(order: Order) -> Self {
match order {
Order::Asc => PrecomputedAfterKey::Next(0),
Order::Desc => PrecomputedAfterKey::Next(u64::MAX),
}
}
}

View File

@@ -1,140 +0,0 @@
use time::convert::{Day, Nanosecond};
use time::{Time, UtcDateTime};
const NS_IN_DAY: i64 = Nanosecond::per_t::<i128>(Day) as i64;
/// Computes the timestamp in nanoseconds corresponding to the beginning of the
/// year (January 1st at midnight UTC).
pub(super) fn try_year_bucket(timestamp_ns: i64) -> crate::Result<i64> {
year_bucket_using_time_crate(timestamp_ns).map_err(|e| {
crate::TantivyError::InvalidArgument(format!(
"Failed to compute year bucket for timestamp {}: {}",
timestamp_ns,
e.to_string()
))
})
}
/// Computes the timestamp in nanoseconds corresponding to the beginning of the
/// month (1st at midnight UTC).
pub(super) fn try_month_bucket(timestamp_ns: i64) -> crate::Result<i64> {
month_bucket_using_time_crate(timestamp_ns).map_err(|e| {
crate::TantivyError::InvalidArgument(format!(
"Failed to compute month bucket for timestamp {}: {}",
timestamp_ns,
e.to_string()
))
})
}
/// Computes the timestamp in nanoseconds corresponding to the beginning of the
/// week (Monday at midnight UTC).
pub(super) fn week_bucket(timestamp_ns: i64) -> i64 {
// 1970-01-01 was a Thursday (weekday = 4)
let days_since_epoch = timestamp_ns.div_euclid(NS_IN_DAY);
// Find the weekday: 0=Monday, ..., 6=Sunday
let weekday = (days_since_epoch + 3).rem_euclid(7);
let monday_days_since_epoch = days_since_epoch - weekday;
monday_days_since_epoch * NS_IN_DAY
}
fn year_bucket_using_time_crate(timestamp_ns: i64) -> Result<i64, time::Error> {
let timestamp_ns = UtcDateTime::from_unix_timestamp_nanos(timestamp_ns as i128)?
.replace_ordinal(1)?
.replace_time(Time::MIDNIGHT)
.unix_timestamp_nanos();
Ok(timestamp_ns as i64)
}
fn month_bucket_using_time_crate(timestamp_ns: i64) -> Result<i64, time::Error> {
let timestamp_ns = UtcDateTime::from_unix_timestamp_nanos(timestamp_ns as i128)?
.replace_day(1)?
.replace_time(Time::MIDNIGHT)
.unix_timestamp_nanos();
Ok(timestamp_ns as i64)
}
#[cfg(test)]
mod tests {
use std::i64;
use time::format_description::well_known::Iso8601;
use time::UtcDateTime;
use super::*;
fn ts_ns(iso: &str) -> i64 {
UtcDateTime::parse(iso, &Iso8601::DEFAULT)
.unwrap()
.unix_timestamp_nanos() as i64
}
#[test]
fn test_year_bucket() {
let ts = ts_ns("1970-01-01T00:00:00Z");
let res = try_year_bucket(ts).unwrap();
assert_eq!(res, ts_ns("1970-01-01T00:00:00Z"));
let ts = ts_ns("1970-06-01T10:00:01.010Z");
let res = try_year_bucket(ts).unwrap();
assert_eq!(res, ts_ns("1970-01-01T00:00:00Z"));
let ts = ts_ns("2008-12-31T23:59:59.999999999Z"); // leap year
let res = try_year_bucket(ts).unwrap();
assert_eq!(res, ts_ns("2008-01-01T00:00:00Z"));
let ts = ts_ns("2008-01-01T00:00:00Z"); // leap year
let res = try_year_bucket(ts).unwrap();
assert_eq!(res, ts_ns("2008-01-01T00:00:00Z"));
let ts = ts_ns("2010-12-31T23:59:59.999999999Z");
let res = try_year_bucket(ts).unwrap();
assert_eq!(res, ts_ns("2010-01-01T00:00:00Z"));
let ts = ts_ns("1972-06-01T00:10:00Z");
let res = try_year_bucket(ts).unwrap();
assert_eq!(res, ts_ns("1972-01-01T00:00:00Z"));
}
#[test]
fn test_month_bucket() {
let ts = ts_ns("1970-01-15T00:00:00Z");
let res = try_month_bucket(ts).unwrap();
assert_eq!(res, ts_ns("1970-01-01T00:00:00Z"));
let ts = ts_ns("1970-02-01T00:00:00Z");
let res = try_month_bucket(ts).unwrap();
assert_eq!(res, ts_ns("1970-02-01T00:00:00Z"));
let ts = ts_ns("2000-01-31T23:59:59.999999999Z");
let res = try_month_bucket(ts).unwrap();
assert_eq!(res, ts_ns("2000-01-01T00:00:00Z"));
}
#[test]
fn test_week_bucket() {
let ts = ts_ns("1970-01-05T00:00:00Z"); // Monday
let res = week_bucket(ts);
assert_eq!(res, ts_ns("1970-01-05T00:00:00Z"));
let ts = ts_ns("1970-01-05T23:59:59Z"); // Monday
let res = week_bucket(ts);
assert_eq!(res, ts_ns("1970-01-05T00:00:00Z"));
let ts = ts_ns("1970-01-07T01:13:00Z"); // Wednesday
let res = week_bucket(ts);
assert_eq!(res, ts_ns("1970-01-05T00:00:00Z"));
let ts = ts_ns("1970-01-11T23:59:59.999999999Z"); // Sunday
let res = week_bucket(ts);
assert_eq!(res, ts_ns("1970-01-05T00:00:00Z"));
let ts = ts_ns("2025-10-16T10:41:59.010Z"); // Thursday
let res = week_bucket(ts);
assert_eq!(res, ts_ns("2025-10-13T00:00:00Z"));
let ts = ts_ns("1970-01-01T00:00:00Z"); // Thursday
let res = week_bucket(ts);
assert_eq!(res, ts_ns("1969-12-29T00:00:00Z")); // Negative
}
}

View File

@@ -1,674 +0,0 @@
use std::fmt::Debug;
use std::mem;
use std::net::Ipv6Addr;
use columnar::column_values::CompactSpaceU64Accessor;
use columnar::{
Column, ColumnType, Dictionary, MonotonicallyMappableToU128, MonotonicallyMappableToU64,
NumericalValue, StrColumn,
};
use rustc_hash::FxHashMap;
use smallvec::SmallVec;
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::bucket::composite::accessors::{
CompositeAccessor, CompositeAggReqData, PrecomputedDateInterval,
};
use crate::aggregation::bucket::composite::calendar_interval;
use crate::aggregation::bucket::composite::map::{DynArrayHeapMap, MAX_DYN_ARRAY_SIZE};
use crate::aggregation::bucket::{
CalendarInterval, CompositeAggregationSource, MissingOrder, Order,
};
use crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardSubAggCache};
use crate::aggregation::intermediate_agg_result::{
CompositeIntermediateKey, IntermediateAggregationResult, IntermediateAggregationResults,
IntermediateBucketResult, IntermediateCompositeBucketEntry, IntermediateCompositeBucketResult,
};
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::BucketId;
use crate::TantivyError;
#[derive(Clone, Debug)]
struct CompositeBucketCollector {
count: u32,
bucket_id: BucketId,
}
/// Compact sortable representation of a single source value within a composite key.
///
/// The struct encodes both the column identity and the fast field value in a way
/// that preserves the desired sort order via the derived `Ord` implementation
/// (fields are compared top-to-bottom: `sort_key` first, then `encoded_value`).
///
/// ## `sort_key` encoding
/// - `0` — missing value, sorted first
/// - `1..=254` — present value; the original accessor index is `sort_key - 1`
/// - `u8::MAX` (255) — missing value, sorted last
///
/// ## `encoded_value` encoding
/// - `0` when the field is missing
/// - The raw u64 fast-field representation when order is ascending
/// - Bitwise NOT of the raw u64 when order is descending
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Default, Hash)]
struct InternalValueRepr {
/// Column index biased by +1 (so 0 and u8::MAX are reserved for missing sentinels).
sort_key: u8,
/// Fast field value, possibly bit-flipped for descending order.
encoded_value: u64,
}
impl InternalValueRepr {
#[inline]
fn new_term(raw: u64, accessor_idx: u8, order: Order) -> Self {
let encoded_value = match order {
Order::Asc => raw,
Order::Desc => !raw,
};
InternalValueRepr {
sort_key: accessor_idx + 1,
encoded_value,
}
}
/// For histogram sources the column index is irrelevant (always 1).
#[inline]
fn new_histogram(raw: u64, order: Order) -> Self {
let encoded_value = match order {
Order::Asc => raw,
Order::Desc => !raw,
};
InternalValueRepr {
sort_key: 1,
encoded_value,
}
}
#[inline]
fn new_missing(order: Order, missing_order: MissingOrder) -> Self {
let sort_key = match (missing_order, order) {
(MissingOrder::First, _) | (MissingOrder::Default, Order::Asc) => 0,
(MissingOrder::Last, _) | (MissingOrder::Default, Order::Desc) => u8::MAX,
};
InternalValueRepr {
sort_key,
encoded_value: 0,
}
}
/// Decode back to `(accessor_idx, raw_value)`.
/// Returns `None` when the value represents a missing field.
#[inline]
fn decode(self, order: Order) -> Option<(u8, u64)> {
if self.sort_key == 0 || self.sort_key == u8::MAX {
return None;
}
let raw = match order {
Order::Asc => self.encoded_value,
Order::Desc => !self.encoded_value,
};
Some((self.sort_key - 1, raw))
}
}
/// The collector puts values from the fast field into the correct buckets and
/// does a conversion to the correct datatype.
#[derive(Debug)]
pub struct SegmentCompositeCollector {
/// One DynArrayHeapMap per parent bucket.
parent_buckets: Vec<DynArrayHeapMap<InternalValueRepr, CompositeBucketCollector>>,
accessor_idx: usize,
sub_agg: Option<CachedSubAggs<HighCardSubAggCache>>,
bucket_id_provider: BucketIdProvider,
/// Number of sources, needed when creating new DynArrayHeapMaps.
num_sources: usize,
}
impl SegmentAggregationCollector for SegmentCompositeCollector {
fn add_intermediate_aggregation_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_data
.get_composite_req_data(self.accessor_idx)
.name
.clone();
let buckets = self.into_intermediate_bucket_result(agg_data, parent_bucket_id)?;
results.push(
name,
IntermediateAggregationResult::Bucket(IntermediateBucketResult::Composite { buckets }),
)?;
Ok(())
}
fn collect(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let mem_pre = self.get_memory_consumption();
let composite_agg_data = agg_data.take_composite_req_data(self.accessor_idx);
for doc in docs {
let mut sub_level_values = SmallVec::new();
recursive_key_visitor(
*doc,
&composite_agg_data,
0,
&mut sub_level_values,
&mut self.parent_buckets[parent_bucket_id as usize],
true,
&mut self.sub_agg,
&mut self.bucket_id_provider,
)?;
}
agg_data.put_back_composite_req_data(self.accessor_idx, composite_agg_data);
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.check_flush_local(agg_data)?;
}
let mem_delta = self.get_memory_consumption() - mem_pre;
if mem_delta > 0 {
agg_data.context.limits.add_memory_consumed(mem_delta)?;
}
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.flush(agg_data)?;
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
let required_len = max_bucket as usize + 1;
while self.parent_buckets.len() < required_len {
let map = DynArrayHeapMap::try_new(self.num_sources)?;
self.parent_buckets.push(map);
}
Ok(())
}
}
impl SegmentCompositeCollector {
fn get_memory_consumption(&self) -> u64 {
self.parent_buckets
.iter()
.map(|m| m.memory_consumption())
.sum()
}
pub(crate) fn from_req_and_validate(
req_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
validate_req(req_data, node.idx_in_req_data)?;
let has_sub_aggregations = !node.children.is_empty();
let sub_agg = if has_sub_aggregations {
let sub_agg_collector = build_segment_agg_collectors(req_data, &node.children)?;
Some(CachedSubAggs::new(sub_agg_collector))
} else {
None
};
let composite_req_data = req_data.get_composite_req_data(node.idx_in_req_data);
let num_sources = composite_req_data.req.sources.len();
Ok(SegmentCompositeCollector {
parent_buckets: vec![DynArrayHeapMap::try_new(num_sources)?],
accessor_idx: node.idx_in_req_data,
sub_agg,
bucket_id_provider: BucketIdProvider::default(),
num_sources,
})
}
#[inline]
fn into_intermediate_bucket_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
parent_bucket_id: BucketId,
) -> crate::Result<IntermediateCompositeBucketResult> {
let empty_map = DynArrayHeapMap::try_new(self.num_sources)?;
let heap_map = mem::replace(
&mut self.parent_buckets[parent_bucket_id as usize],
empty_map,
);
let mut dict: FxHashMap<Vec<CompositeIntermediateKey>, IntermediateCompositeBucketEntry> =
Default::default();
dict.reserve(heap_map.size());
let composite_data = agg_data.get_composite_req_data(self.accessor_idx);
for (key_internal_repr, agg) in heap_map.into_iter() {
let key = resolve_key(&key_internal_repr, composite_data)?;
let mut sub_aggregation_res = IntermediateAggregationResults::default();
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut sub_aggregation_res,
agg.bucket_id,
)?;
}
dict.insert(
key,
IntermediateCompositeBucketEntry {
doc_count: agg.count,
sub_aggregation: sub_aggregation_res,
},
);
}
Ok(IntermediateCompositeBucketResult {
entries: dict,
target_size: composite_data.req.size,
orders: composite_data
.req
.sources
.iter()
.map(|source| match source {
CompositeAggregationSource::Terms(t) => (t.order, t.missing_order),
CompositeAggregationSource::Histogram(h) => (h.order, h.missing_order),
CompositeAggregationSource::DateHistogram(d) => (d.order, d.missing_order),
})
.collect(),
})
}
}
fn validate_req(req_data: &mut AggregationsSegmentCtx, accessor_idx: usize) -> crate::Result<()> {
let composite_data = req_data.get_composite_req_data(accessor_idx);
let req = &composite_data.req;
if req.sources.is_empty() {
return Err(TantivyError::InvalidArgument(
"composite aggregation must have at least one source".to_string(),
));
}
if req.size == 0 {
return Err(TantivyError::InvalidArgument(
"composite aggregation 'size' must be > 0".to_string(),
));
}
let column_types_for_sources = composite_data.composite_accessors.iter().map(|item| {
item.accessors
.iter()
.map(|a| a.column_type)
.collect::<Vec<_>>()
});
for column_types in column_types_for_sources {
if column_types.len() > MAX_DYN_ARRAY_SIZE {
return Err(TantivyError::InvalidArgument(format!(
"composite aggregation source supports maximum {MAX_DYN_ARRAY_SIZE} sources",
)));
}
if column_types.contains(&ColumnType::Bytes) {
return Err(TantivyError::InvalidArgument(
"composite aggregation does not support 'bytes' field type".to_string(),
));
}
}
Ok(())
}
fn collect_bucket_with_limit(
doc_id: crate::DocId,
limit_num_buckets: usize,
buckets: &mut DynArrayHeapMap<InternalValueRepr, CompositeBucketCollector>,
key: &[InternalValueRepr],
sub_agg: &mut Option<CachedSubAggs<HighCardSubAggCache>>,
bucket_id_provider: &mut BucketIdProvider,
) {
let mut record_in_bucket = |bucket: &mut CompositeBucketCollector| {
bucket.count += 1;
if let Some(sub_agg) = sub_agg {
sub_agg.push(bucket.bucket_id, doc_id);
}
};
// We still have room for buckets, just insert
if buckets.size() < limit_num_buckets {
let bucket = buckets.get_or_insert_with(key, || CompositeBucketCollector {
count: 0,
bucket_id: bucket_id_provider.next_bucket_id(),
});
record_in_bucket(bucket);
return;
}
// Map is full, but we can still update the bucket if it already exists
if let Some(bucket) = buckets.get_mut(key) {
record_in_bucket(bucket);
return;
}
// Check if the item qualifies to enter the top-k, and evict the highest if it does
if let Some(highest_key) = buckets.peek_highest() {
if key < highest_key {
buckets.evict_highest();
let bucket = buckets.get_or_insert_with(key, || CompositeBucketCollector {
count: 0,
bucket_id: bucket_id_provider.next_bucket_id(),
});
record_in_bucket(bucket);
}
}
}
/// Converts the composite key from its internal column space representation
/// (segment specific) into its intermediate form.
fn resolve_key(
internal_key: &[InternalValueRepr],
agg_data: &CompositeAggReqData,
) -> crate::Result<Vec<CompositeIntermediateKey>> {
internal_key
.iter()
.enumerate()
.map(|(idx, val)| {
resolve_internal_value_repr(
*val,
&agg_data.req.sources[idx],
&agg_data.composite_accessors[idx].accessors,
)
})
.collect()
}
fn resolve_internal_value_repr(
internal_value_repr: InternalValueRepr,
source: &CompositeAggregationSource,
composite_accessors: &[CompositeAccessor],
) -> crate::Result<CompositeIntermediateKey> {
let decoded_value_opt = match source {
CompositeAggregationSource::Terms(source) => internal_value_repr.decode(source.order),
CompositeAggregationSource::Histogram(source) => internal_value_repr.decode(source.order),
CompositeAggregationSource::DateHistogram(source) => {
internal_value_repr.decode(source.order)
}
};
let Some((decoded_accessor_idx, val)) = decoded_value_opt else {
return Ok(CompositeIntermediateKey::Null);
};
let key = match source {
CompositeAggregationSource::Terms(_) => {
let CompositeAccessor {
column_type,
str_dict_column,
column,
..
} = &composite_accessors[decoded_accessor_idx as usize];
resolve_term(val, column_type, str_dict_column, column)?
}
CompositeAggregationSource::Histogram(source) => {
CompositeIntermediateKey::F64(i64::from_u64(val) as f64 * source.interval)
}
CompositeAggregationSource::DateHistogram(_) => {
CompositeIntermediateKey::DateTime(i64::from_u64(val))
}
};
Ok(key)
}
fn resolve_term(
val: u64,
column_type: &ColumnType,
str_dict_column: &Option<StrColumn>,
column: &Column,
) -> crate::Result<CompositeIntermediateKey> {
let key = if *column_type == ColumnType::Str {
let fallback_dict = Dictionary::empty();
let term_dict = str_dict_column
.as_ref()
.map(|el| el.dictionary())
.unwrap_or_else(|| &fallback_dict);
let mut buffer = Vec::new();
term_dict.ord_to_term(val, &mut buffer)?;
CompositeIntermediateKey::Str(
String::from_utf8(buffer.to_vec()).expect("could not convert to String"),
)
} else if *column_type == ColumnType::DateTime {
let val = i64::from_u64(val);
CompositeIntermediateKey::DateTime(val)
} else if *column_type == ColumnType::Bool {
let val = bool::from_u64(val);
CompositeIntermediateKey::Bool(val)
} else if *column_type == ColumnType::IpAddr {
let compact_space_accessor = column
.values
.clone()
.downcast_arc::<CompactSpaceU64Accessor>()
.map_err(|_| {
TantivyError::AggregationError(crate::aggregation::AggregationError::InternalError(
"Type mismatch: Could not downcast to CompactSpaceU64Accessor".to_string(),
))
})?;
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
let val = Ipv6Addr::from_u128(val);
CompositeIntermediateKey::IpAddr(val)
} else if *column_type == ColumnType::U64 {
CompositeIntermediateKey::U64(val)
} else if *column_type == ColumnType::I64 {
CompositeIntermediateKey::I64(i64::from_u64(val))
} else {
let val = f64::from_u64(val);
let val: NumericalValue = val.into();
match val.normalize() {
NumericalValue::U64(val) => CompositeIntermediateKey::U64(val),
NumericalValue::I64(val) => CompositeIntermediateKey::I64(val),
NumericalValue::F64(val) => CompositeIntermediateKey::F64(val),
}
};
Ok(key)
}
/// Depth-first walk of the accessors to build the composite key combinations
/// and update the buckets.
fn recursive_key_visitor(
doc_id: crate::DocId,
composite_agg_data: &CompositeAggReqData,
source_idx_for_recursion: usize,
sub_level_values: &mut SmallVec<[InternalValueRepr; MAX_DYN_ARRAY_SIZE]>,
buckets: &mut DynArrayHeapMap<InternalValueRepr, CompositeBucketCollector>,
// whether we need to consider the after_key in the following levels
is_on_after_key: bool,
sub_agg: &mut Option<CachedSubAggs<HighCardSubAggCache>>,
bucket_id_provider: &mut BucketIdProvider,
) -> crate::Result<()> {
if source_idx_for_recursion == composite_agg_data.req.sources.len() {
if !is_on_after_key {
collect_bucket_with_limit(
doc_id,
composite_agg_data.req.size as usize,
buckets,
sub_level_values,
sub_agg,
bucket_id_provider,
);
}
return Ok(());
}
let current_level_accessors = &composite_agg_data.composite_accessors[source_idx_for_recursion];
let current_level_source = &composite_agg_data.req.sources[source_idx_for_recursion];
let mut missing = true;
for (accessor_idx, accessor) in current_level_accessors.accessors.iter().enumerate() {
let values = accessor.column.values_for_doc(doc_id);
for value in values {
missing = false;
match current_level_source {
CompositeAggregationSource::Terms(_) => {
let preceeds_after_key_type =
accessor_idx < current_level_accessors.after_key_accessor_idx;
if is_on_after_key && preceeds_after_key_type {
break;
}
let matches_after_key_type =
accessor_idx == current_level_accessors.after_key_accessor_idx;
if matches_after_key_type && is_on_after_key {
let should_skip = match current_level_source.order() {
Order::Asc => current_level_accessors.after_key.gt(value),
Order::Desc => current_level_accessors.after_key.lt(value),
};
if should_skip {
continue;
}
}
sub_level_values.push(InternalValueRepr::new_term(
value,
accessor_idx as u8,
current_level_source.order(),
));
let still_on_after_key =
matches_after_key_type && current_level_accessors.after_key.equals(value);
recursive_key_visitor(
doc_id,
composite_agg_data,
source_idx_for_recursion + 1,
sub_level_values,
buckets,
is_on_after_key && still_on_after_key,
sub_agg,
bucket_id_provider,
)?;
sub_level_values.pop();
}
CompositeAggregationSource::Histogram(source) => {
let float_value = match accessor.column_type {
ColumnType::U64 => value as f64,
ColumnType::I64 => i64::from_u64(value) as f64,
ColumnType::DateTime => i64::from_u64(value) as f64 / 1_000_000.,
ColumnType::F64 => f64::from_u64(value),
_ => {
panic!(
"unexpected type {:?}. This should not happen",
accessor.column_type
)
}
};
let bucket_index = (float_value / source.interval).floor() as i64;
let bucket_value = i64::to_u64(bucket_index);
if is_on_after_key {
let should_skip = match current_level_source.order() {
Order::Asc => current_level_accessors.after_key.gt(bucket_value),
Order::Desc => current_level_accessors.after_key.lt(bucket_value),
};
if should_skip {
continue;
}
}
sub_level_values.push(InternalValueRepr::new_histogram(
bucket_value,
current_level_source.order(),
));
let still_on_after_key = current_level_accessors.after_key.equals(bucket_value);
recursive_key_visitor(
doc_id,
composite_agg_data,
source_idx_for_recursion + 1,
sub_level_values,
buckets,
is_on_after_key && still_on_after_key,
sub_agg,
bucket_id_provider,
)?;
sub_level_values.pop();
}
CompositeAggregationSource::DateHistogram(_) => {
let value_ns = match accessor.column_type {
ColumnType::DateTime => i64::from_u64(value),
_ => {
panic!(
"unexpected type {:?}. This should not happen",
accessor.column_type
)
}
};
let bucket_index = match accessor.date_histogram_interval {
PrecomputedDateInterval::FixedNanoseconds(fixed_interval_ns) => {
(value_ns / fixed_interval_ns) * fixed_interval_ns
}
PrecomputedDateInterval::Calendar(CalendarInterval::Year) => {
calendar_interval::try_year_bucket(value_ns)?
}
PrecomputedDateInterval::Calendar(CalendarInterval::Month) => {
calendar_interval::try_month_bucket(value_ns)?
}
PrecomputedDateInterval::Calendar(CalendarInterval::Week) => {
calendar_interval::week_bucket(value_ns)
}
PrecomputedDateInterval::NotApplicable => {
panic!("interval not precomputed for date histogram source")
}
};
let bucket_value = i64::to_u64(bucket_index);
if is_on_after_key {
let should_skip = match current_level_source.order() {
Order::Asc => current_level_accessors.after_key.gt(bucket_value),
Order::Desc => current_level_accessors.after_key.lt(bucket_value),
};
if should_skip {
continue;
}
}
sub_level_values.push(InternalValueRepr::new_histogram(
bucket_value,
current_level_source.order(),
));
let still_on_after_key = current_level_accessors.after_key.equals(bucket_value);
recursive_key_visitor(
doc_id,
composite_agg_data,
source_idx_for_recursion + 1,
sub_level_values,
buckets,
is_on_after_key && still_on_after_key,
sub_agg,
bucket_id_provider,
)?;
sub_level_values.pop();
}
};
}
}
if missing && current_level_source.missing_bucket() {
if is_on_after_key && current_level_accessors.skip_missing {
return Ok(());
}
sub_level_values.push(InternalValueRepr::new_missing(
current_level_source.order(),
current_level_source.missing_order(),
));
recursive_key_visitor(
doc_id,
composite_agg_data,
source_idx_for_recursion + 1,
sub_level_values,
buckets,
is_on_after_key && current_level_accessors.is_after_key_explicit_missing,
sub_agg,
bucket_id_provider,
)?;
sub_level_values.pop();
}
Ok(())
}

View File

@@ -1,329 +0,0 @@
use std::collections::BinaryHeap;
use std::fmt::Debug;
use std::hash::Hash;
use rustc_hash::FxHashMap;
use smallvec::SmallVec;
use crate::TantivyError;
/// Map backed by a hash map for fast access and a binary heap to track the
/// highest key. The key is an array of fixed size S.
#[derive(Clone, Debug)]
struct ArrayHeapMap<K: Ord, V, const S: usize> {
pub(crate) buckets: FxHashMap<[K; S], V>,
pub(crate) heap: BinaryHeap<[K; S]>,
}
impl<K: Ord, V, const S: usize> Default for ArrayHeapMap<K, V, S> {
fn default() -> Self {
ArrayHeapMap {
buckets: FxHashMap::default(),
heap: BinaryHeap::default(),
}
}
}
impl<K: Eq + Hash + Clone + Ord, V, const S: usize> ArrayHeapMap<K, V, S> {
/// Panics if the length of `key` is not S.
fn get_or_insert_with<F: FnOnce() -> V>(&mut self, key: &[K], f: F) -> &mut V {
let key_array: &[K; S] = key.try_into().expect("Key length mismatch");
self.buckets.entry(key_array.clone()).or_insert_with(|| {
self.heap.push(key_array.clone());
f()
})
}
/// Panics if the length of `key` is not S.
fn get_mut(&mut self, key: &[K]) -> Option<&mut V> {
let key_array: &[K; S] = key.try_into().expect("Key length mismatch");
self.buckets.get_mut(key_array)
}
fn peek_highest(&self) -> Option<&[K]> {
self.heap.peek().map(|k_array| k_array.as_slice())
}
fn evict_highest(&mut self) {
if let Some(highest) = self.heap.pop() {
self.buckets.remove(&highest);
}
}
fn memory_consumption(&self) -> u64 {
let key_size = std::mem::size_of::<[K; S]>();
let map_size = (key_size + std::mem::size_of::<V>()) * self.buckets.capacity();
let heap_size = key_size * self.heap.capacity();
(map_size + heap_size) as u64
}
}
impl<K: Copy + Ord + Clone + 'static, V: 'static, const S: usize> ArrayHeapMap<K, V, S> {
fn into_iter(self) -> Box<dyn Iterator<Item = (SmallVec<[K; MAX_DYN_ARRAY_SIZE]>, V)>> {
Box::new(
self.buckets
.into_iter()
.map(|(k, v)| (SmallVec::from_slice(&k), v)),
)
}
}
pub(super) const MAX_DYN_ARRAY_SIZE: usize = 16;
const MAX_DYN_ARRAY_SIZE_PLUS_ONE: usize = MAX_DYN_ARRAY_SIZE + 1;
/// A map optimized for memory footprint, fast access and efficient eviction of
/// the highest key.
///
/// Keys are inlined arrays of size 1 to [MAX_DYN_ARRAY_SIZE] but for a given
/// instance the key size is fixed. This allows to avoid heap allocations for the
/// keys.
#[derive(Clone, Debug)]
pub(super) struct DynArrayHeapMap<K: Ord, V>(DynArrayHeapMapInner<K, V>);
/// Wrapper around ArrayHeapMap to dynamically dispatch on the array size.
#[derive(Clone, Debug)]
enum DynArrayHeapMapInner<K: Ord, V> {
Dim1(ArrayHeapMap<K, V, 1>),
Dim2(ArrayHeapMap<K, V, 2>),
Dim3(ArrayHeapMap<K, V, 3>),
Dim4(ArrayHeapMap<K, V, 4>),
Dim5(ArrayHeapMap<K, V, 5>),
Dim6(ArrayHeapMap<K, V, 6>),
Dim7(ArrayHeapMap<K, V, 7>),
Dim8(ArrayHeapMap<K, V, 8>),
Dim9(ArrayHeapMap<K, V, 9>),
Dim10(ArrayHeapMap<K, V, 10>),
Dim11(ArrayHeapMap<K, V, 11>),
Dim12(ArrayHeapMap<K, V, 12>),
Dim13(ArrayHeapMap<K, V, 13>),
Dim14(ArrayHeapMap<K, V, 14>),
Dim15(ArrayHeapMap<K, V, 15>),
Dim16(ArrayHeapMap<K, V, 16>),
}
impl<K: Ord, V> DynArrayHeapMap<K, V> {
/// Creates a new heap map with dynamic array keys of size `key_dimension`.
pub(super) fn try_new(key_dimension: usize) -> crate::Result<Self> {
let inner = match key_dimension {
0 => {
return Err(TantivyError::InvalidArgument(
"DynArrayHeapMap dimension must be at least 1".to_string(),
))
}
1 => DynArrayHeapMapInner::Dim1(ArrayHeapMap::default()),
2 => DynArrayHeapMapInner::Dim2(ArrayHeapMap::default()),
3 => DynArrayHeapMapInner::Dim3(ArrayHeapMap::default()),
4 => DynArrayHeapMapInner::Dim4(ArrayHeapMap::default()),
5 => DynArrayHeapMapInner::Dim5(ArrayHeapMap::default()),
6 => DynArrayHeapMapInner::Dim6(ArrayHeapMap::default()),
7 => DynArrayHeapMapInner::Dim7(ArrayHeapMap::default()),
8 => DynArrayHeapMapInner::Dim8(ArrayHeapMap::default()),
9 => DynArrayHeapMapInner::Dim9(ArrayHeapMap::default()),
10 => DynArrayHeapMapInner::Dim10(ArrayHeapMap::default()),
11 => DynArrayHeapMapInner::Dim11(ArrayHeapMap::default()),
12 => DynArrayHeapMapInner::Dim12(ArrayHeapMap::default()),
13 => DynArrayHeapMapInner::Dim13(ArrayHeapMap::default()),
14 => DynArrayHeapMapInner::Dim14(ArrayHeapMap::default()),
15 => DynArrayHeapMapInner::Dim15(ArrayHeapMap::default()),
16 => DynArrayHeapMapInner::Dim16(ArrayHeapMap::default()),
MAX_DYN_ARRAY_SIZE_PLUS_ONE.. => {
return Err(TantivyError::InvalidArgument(format!(
"DynArrayHeapMap supports maximum {MAX_DYN_ARRAY_SIZE} dimensions, got \
{key_dimension}",
)))
}
};
Ok(DynArrayHeapMap(inner))
}
/// Number of elements in the map. This is not the dimension of the keys.
pub(super) fn size(&self) -> usize {
match &self.0 {
DynArrayHeapMapInner::Dim1(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim2(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim3(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim4(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim5(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim6(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim7(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim8(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim9(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim10(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim11(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim12(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim13(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim14(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim15(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim16(map) => map.buckets.len(),
}
}
}
impl<K: Ord + Hash + Clone, V> DynArrayHeapMap<K, V> {
/// Get a mutable reference to the value corresponding to `key` or inserts a new
/// value created by calling `f`.
///
/// Panics if the length of `key` does not match the key dimension of the map.
pub(super) fn get_or_insert_with<F: FnOnce() -> V>(&mut self, key: &[K], f: F) -> &mut V {
match &mut self.0 {
DynArrayHeapMapInner::Dim1(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim2(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim3(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim4(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim5(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim6(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim7(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim8(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim9(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim10(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim11(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim12(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim13(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim14(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim15(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim16(map) => map.get_or_insert_with(key, f),
}
}
/// Returns a mutable reference to the value corresponding to `key`.
///
/// Panics if the length of `key` does not match the key dimension of the map.
pub fn get_mut(&mut self, key: &[K]) -> Option<&mut V> {
match &mut self.0 {
DynArrayHeapMapInner::Dim1(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim2(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim3(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim4(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim5(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim6(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim7(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim8(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim9(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim10(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim11(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim12(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim13(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim14(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim15(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim16(map) => map.get_mut(key),
}
}
/// Returns a reference to the highest key in the map.
pub(super) fn peek_highest(&self) -> Option<&[K]> {
match &self.0 {
DynArrayHeapMapInner::Dim1(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim2(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim3(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim4(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim5(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim6(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim7(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim8(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim9(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim10(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim11(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim12(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim13(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim14(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim15(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim16(map) => map.peek_highest(),
}
}
/// Removes the entry with the highest key from the map.
pub(super) fn evict_highest(&mut self) {
match &mut self.0 {
DynArrayHeapMapInner::Dim1(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim2(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim3(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim4(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim5(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim6(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim7(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim8(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim9(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim10(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim11(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim12(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim13(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim14(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim15(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim16(map) => map.evict_highest(),
}
}
pub(crate) fn memory_consumption(&self) -> u64 {
match &self.0 {
DynArrayHeapMapInner::Dim1(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim2(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim3(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim4(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim5(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim6(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim7(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim8(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim9(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim10(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim11(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim12(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim13(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim14(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim15(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim16(map) => map.memory_consumption(),
}
}
}
impl<K: Ord + Clone + Copy + 'static, V: 'static> DynArrayHeapMap<K, V> {
/// Turns this map into an iterator over key-value pairs.
pub fn into_iter(self) -> impl Iterator<Item = (SmallVec<[K; MAX_DYN_ARRAY_SIZE]>, V)> {
match self.0 {
DynArrayHeapMapInner::Dim1(map) => map.into_iter(),
DynArrayHeapMapInner::Dim2(map) => map.into_iter(),
DynArrayHeapMapInner::Dim3(map) => map.into_iter(),
DynArrayHeapMapInner::Dim4(map) => map.into_iter(),
DynArrayHeapMapInner::Dim5(map) => map.into_iter(),
DynArrayHeapMapInner::Dim6(map) => map.into_iter(),
DynArrayHeapMapInner::Dim7(map) => map.into_iter(),
DynArrayHeapMapInner::Dim8(map) => map.into_iter(),
DynArrayHeapMapInner::Dim9(map) => map.into_iter(),
DynArrayHeapMapInner::Dim10(map) => map.into_iter(),
DynArrayHeapMapInner::Dim11(map) => map.into_iter(),
DynArrayHeapMapInner::Dim12(map) => map.into_iter(),
DynArrayHeapMapInner::Dim13(map) => map.into_iter(),
DynArrayHeapMapInner::Dim14(map) => map.into_iter(),
DynArrayHeapMapInner::Dim15(map) => map.into_iter(),
DynArrayHeapMapInner::Dim16(map) => map.into_iter(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_dyn_array_heap_map() {
let mut map = DynArrayHeapMap::<u32, &str>::try_new(2).unwrap();
// insert
let key1 = [1u32, 2u32];
let key2 = [2u32, 1u32];
map.get_or_insert_with(&key1, || "a");
map.get_or_insert_with(&key2, || "b");
assert_eq!(map.size(), 2);
// evict highest
assert_eq!(map.peek_highest(), Some(&key2[..]));
map.evict_highest();
assert_eq!(map.size(), 1);
assert_eq!(map.peek_highest(), Some(&key1[..]));
// into_iter
let mut iter = map.into_iter();
let (k, v) = iter.next().unwrap();
assert_eq!(k.as_slice(), &key1);
assert_eq!(v, "c");
assert_eq!(iter.next(), None);
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,460 +0,0 @@
/// This modules helps comparing numerical values of different types (i64, u64
/// and f64).
pub(super) mod num_cmp {
use std::cmp::Ordering;
use crate::TantivyError;
pub fn cmp_i64_f64(left_i: i64, right_f: f64) -> crate::Result<Ordering> {
if right_f.is_nan() {
return Err(TantivyError::InvalidArgument(
"NaN comparison is not supported".to_string(),
));
}
// If right_f is < i64::MIN then left_i > right_f (i64::MIN=-2^63 can be
// exactly represented as f64)
if right_f < i64::MIN as f64 {
return Ok(Ordering::Greater);
}
// If right_f is >= i64::MAX then left_i < right_f (i64::MAX=2^63-1 cannot
// be exactly represented as f64)
if right_f >= i64::MAX as f64 {
return Ok(Ordering::Less);
}
// Now right_f is in (i64::MIN, i64::MAX), so `right_f as i64` is
// well-defined (truncation toward 0)
let right_as_i = right_f as i64;
let result = match left_i.cmp(&right_as_i) {
Ordering::Less => Ordering::Less,
Ordering::Greater => Ordering::Greater,
Ordering::Equal => {
// they have the same integer part, compare the fraction
let rem = right_f - (right_as_i as f64);
if rem == 0.0 {
Ordering::Equal
} else if right_f > 0.0 {
Ordering::Less
} else {
Ordering::Greater
}
}
};
Ok(result)
}
pub fn cmp_u64_f64(left_u: u64, right_f: f64) -> crate::Result<Ordering> {
if right_f.is_nan() {
return Err(TantivyError::InvalidArgument(
"NaN comparison is not supported".to_string(),
));
}
// Negative floats are always less than any u64 >= 0
if right_f < 0.0 {
return Ok(Ordering::Greater);
}
// If right_f is >= u64::MAX then left_u < right_f (u64::MAX=2^64-1 cannot be exactly)
let max_as_f = u64::MAX as f64;
if right_f > max_as_f {
return Ok(Ordering::Less);
}
// Now right_f is in (0, u64::MAX), so `right_f as u64` is well-defined
// (truncation toward 0)
let right_as_u = right_f as u64;
let result = match left_u.cmp(&right_as_u) {
Ordering::Less => Ordering::Less,
Ordering::Greater => Ordering::Greater,
Ordering::Equal => {
// they have the same integer part, compare the fraction
let rem = right_f - (right_as_u as f64);
if rem == 0.0 {
Ordering::Equal
} else {
Ordering::Less
}
}
};
Ok(result)
}
pub fn cmp_i64_u64(left_i: i64, right_u: u64) -> Ordering {
if left_i < 0 {
Ordering::Less
} else {
let left_as_u = left_i as u64;
left_as_u.cmp(&right_u)
}
}
}
/// This modules helps projecting numerical values to other numerical types.
/// When the target value space cannot exactly represent the source value, the
/// next representable value is returned (or AfterLast if the source value is
/// larger than the largest representable value).
///
/// All functions in this module assume that f64 values are not NaN.
pub(super) mod num_proj {
#[derive(Debug, PartialEq)]
pub enum ProjectedNumber<T> {
Exact(T),
Next(T),
AfterLast,
}
pub fn i64_to_u64(value: i64) -> ProjectedNumber<u64> {
if value < 0 {
ProjectedNumber::Next(0)
} else {
ProjectedNumber::Exact(value as u64)
}
}
pub fn u64_to_i64(value: u64) -> ProjectedNumber<i64> {
if value > i64::MAX as u64 {
ProjectedNumber::AfterLast
} else {
ProjectedNumber::Exact(value as i64)
}
}
pub fn f64_to_u64(value: f64) -> ProjectedNumber<u64> {
if value < 0.0 {
ProjectedNumber::Next(0)
} else if value > u64::MAX as f64 {
ProjectedNumber::AfterLast
} else if value.fract() == 0.0 {
ProjectedNumber::Exact(value as u64)
} else {
// casting f64 to u64 truncates toward zero
ProjectedNumber::Next(value as u64 + 1)
}
}
pub fn f64_to_i64(value: f64) -> ProjectedNumber<i64> {
if value < (i64::MIN as f64) {
return ProjectedNumber::Next(i64::MIN);
} else if value >= (i64::MAX as f64) {
return ProjectedNumber::AfterLast;
} else if value.fract() == 0.0 {
ProjectedNumber::Exact(value as i64)
} else if value > 0.0 {
// casting f64 to i64 truncates toward zero
ProjectedNumber::Next(value as i64 + 1)
} else {
ProjectedNumber::Next(value as i64)
}
}
pub fn i64_to_f64(value: i64) -> ProjectedNumber<f64> {
let value_f = value as f64;
let k_roundtrip = value_f as i64;
if k_roundtrip == value {
// between -2^53 and 2^53 all i64 are exactly represented as f64
ProjectedNumber::Exact(value_f)
} else {
// for very large/small i64 values, it is approximated to the closest f64
if k_roundtrip > value {
ProjectedNumber::Next(value_f)
} else {
ProjectedNumber::Next(value_f.next_up())
}
}
}
pub fn u64_to_f64(value: u64) -> ProjectedNumber<f64> {
let value_f = value as f64;
let k_roundtrip = value_f as u64;
if k_roundtrip == value {
// between 0 and 2^53 all u64 are exactly represented as f64
ProjectedNumber::Exact(value_f)
} else if k_roundtrip > value {
ProjectedNumber::Next(value_f)
} else {
ProjectedNumber::Next(value_f.next_up())
}
}
}
#[cfg(test)]
mod num_cmp_tests {
use std::cmp::Ordering;
use super::num_cmp::*;
#[test]
fn test_cmp_u64_f64() {
// Basic comparisons
assert_eq!(cmp_u64_f64(5, 5.0).unwrap(), Ordering::Equal);
assert_eq!(cmp_u64_f64(5, 6.0).unwrap(), Ordering::Less);
assert_eq!(cmp_u64_f64(6, 5.0).unwrap(), Ordering::Greater);
assert_eq!(cmp_u64_f64(0, 0.0).unwrap(), Ordering::Equal);
assert_eq!(cmp_u64_f64(0, 0.1).unwrap(), Ordering::Less);
// Negative float values should always be less than any u64
assert_eq!(cmp_u64_f64(0, -0.1).unwrap(), Ordering::Greater);
assert_eq!(cmp_u64_f64(5, -5.0).unwrap(), Ordering::Greater);
assert_eq!(cmp_u64_f64(u64::MAX, -1e20).unwrap(), Ordering::Greater);
// Tests with extreme values
assert_eq!(cmp_u64_f64(u64::MAX, 1e20).unwrap(), Ordering::Less);
// Precision edge cases: large u64 that loses precision when converted to f64
// => 2^54, exactly represented as f64
let large_f64 = 18_014_398_509_481_984.0;
let large_u64 = 18_014_398_509_481_984;
// prove that large_u64 is exactly represented as f64
assert_eq!(large_u64 as f64, large_f64);
assert_eq!(cmp_u64_f64(large_u64, large_f64).unwrap(), Ordering::Equal);
// => (2^54 + 1) cannot be exactly represented in f64
let large_u64_plus_1 = 18_014_398_509_481_985;
// prove that it is represented as f64 by large_f64
assert_eq!(large_u64_plus_1 as f64, large_f64);
assert_eq!(
cmp_u64_f64(large_u64_plus_1, large_f64).unwrap(),
Ordering::Greater
);
// => (2^54 - 1) cannot be exactly represented in f64
let large_u64_minus_1 = 18_014_398_509_481_983;
// prove that it is also represented as f64 by large_f64
assert_eq!(large_u64_minus_1 as f64, large_f64);
assert_eq!(
cmp_u64_f64(large_u64_minus_1, large_f64).unwrap(),
Ordering::Less
);
// NaN comparison results in an error
assert!(cmp_u64_f64(0, f64::NAN).is_err());
}
#[test]
fn test_cmp_i64_f64() {
// Basic comparisons
assert_eq!(cmp_i64_f64(5, 5.0).unwrap(), Ordering::Equal);
assert_eq!(cmp_i64_f64(5, 6.0).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(6, 5.0).unwrap(), Ordering::Greater);
assert_eq!(cmp_i64_f64(-5, -5.0).unwrap(), Ordering::Equal);
assert_eq!(cmp_i64_f64(-5, -4.0).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(-4, -5.0).unwrap(), Ordering::Greater);
assert_eq!(cmp_i64_f64(-5, 5.0).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(5, -5.0).unwrap(), Ordering::Greater);
assert_eq!(cmp_i64_f64(0, -0.1).unwrap(), Ordering::Greater);
assert_eq!(cmp_i64_f64(0, 0.1).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(-1, -0.5).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(-1, 0.0).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(0, 0.0).unwrap(), Ordering::Equal);
// Tests with extreme values
assert_eq!(cmp_i64_f64(i64::MAX, 1e20).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(i64::MIN, -1e20).unwrap(), Ordering::Greater);
// Precision edge cases: large i64 that loses precision when converted to f64
// => 2^54, exactly represented as f64
let large_f64 = 18_014_398_509_481_984.0;
let large_i64 = 18_014_398_509_481_984;
// prove that large_i64 is exactly represented as f64
assert_eq!(large_i64 as f64, large_f64);
assert_eq!(cmp_i64_f64(large_i64, large_f64).unwrap(), Ordering::Equal);
// => (1_i64 << 54) + 1 cannot be exactly represented in f64
let large_i64_plus_1 = 18_014_398_509_481_985;
// prove that it is represented as f64 by large_f64
assert_eq!(large_i64_plus_1 as f64, large_f64);
assert_eq!(
cmp_i64_f64(large_i64_plus_1, large_f64).unwrap(),
Ordering::Greater
);
// => (1_i64 << 54) - 1 cannot be exactly represented in f64
let large_i64_minus_1 = 18_014_398_509_481_983;
// prove that it is also represented as f64 by large_f64
assert_eq!(large_i64_minus_1 as f64, large_f64);
assert_eq!(
cmp_i64_f64(large_i64_minus_1, large_f64).unwrap(),
Ordering::Less
);
// Same precision edge case but with negative values
// => -2^54, exactly represented as f64
let large_neg_f64 = -18_014_398_509_481_984.0;
let large_neg_i64 = -18_014_398_509_481_984;
// prove that large_neg_i64 is exactly represented as f64
assert_eq!(large_neg_i64 as f64, large_neg_f64);
assert_eq!(
cmp_i64_f64(large_neg_i64, large_neg_f64).unwrap(),
Ordering::Equal
);
// => (-2^54 + 1) cannot be exactly represented in f64
let large_neg_i64_plus_1 = -18_014_398_509_481_985;
// prove that it is represented as f64 by large_neg_f64
assert_eq!(large_neg_i64_plus_1 as f64, large_neg_f64);
assert_eq!(
cmp_i64_f64(large_neg_i64_plus_1, large_neg_f64).unwrap(),
Ordering::Less
);
// => (-2^54 - 1) cannot be exactly represented in f64
let large_neg_i64_minus_1 = -18_014_398_509_481_983;
// prove that it is also represented as f64 by large_neg_f64
assert_eq!(large_neg_i64_minus_1 as f64, large_neg_f64);
assert_eq!(
cmp_i64_f64(large_neg_i64_minus_1, large_neg_f64).unwrap(),
Ordering::Greater
);
// NaN comparison results in an error
assert!(cmp_i64_f64(0, f64::NAN).is_err());
}
#[test]
fn test_cmp_i64_u64() {
// Test with negative i64 values (should always be less than any u64)
assert_eq!(cmp_i64_u64(-1, 0), Ordering::Less);
assert_eq!(cmp_i64_u64(i64::MIN, 0), Ordering::Less);
assert_eq!(cmp_i64_u64(i64::MIN, u64::MAX), Ordering::Less);
// Test with positive i64 values
assert_eq!(cmp_i64_u64(0, 0), Ordering::Equal);
assert_eq!(cmp_i64_u64(1, 0), Ordering::Greater);
assert_eq!(cmp_i64_u64(1, 1), Ordering::Equal);
assert_eq!(cmp_i64_u64(0, 1), Ordering::Less);
assert_eq!(cmp_i64_u64(5, 10), Ordering::Less);
assert_eq!(cmp_i64_u64(10, 5), Ordering::Greater);
// Test with values near i64::MAX and u64 conversion
assert_eq!(cmp_i64_u64(i64::MAX, i64::MAX as u64), Ordering::Equal);
assert_eq!(cmp_i64_u64(i64::MAX, (i64::MAX as u64) + 1), Ordering::Less);
assert_eq!(cmp_i64_u64(i64::MAX, u64::MAX), Ordering::Less);
}
}
#[cfg(test)]
mod num_proj_tests {
use super::num_proj::{self, ProjectedNumber};
#[test]
fn test_i64_to_u64() {
assert_eq!(num_proj::i64_to_u64(-1), ProjectedNumber::Next(0));
assert_eq!(num_proj::i64_to_u64(i64::MIN), ProjectedNumber::Next(0));
assert_eq!(num_proj::i64_to_u64(0), ProjectedNumber::Exact(0));
assert_eq!(num_proj::i64_to_u64(42), ProjectedNumber::Exact(42));
assert_eq!(
num_proj::i64_to_u64(i64::MAX),
ProjectedNumber::Exact(i64::MAX as u64)
);
}
#[test]
fn test_u64_to_i64() {
assert_eq!(num_proj::u64_to_i64(0), ProjectedNumber::Exact(0));
assert_eq!(num_proj::u64_to_i64(42), ProjectedNumber::Exact(42));
assert_eq!(
num_proj::u64_to_i64(i64::MAX as u64),
ProjectedNumber::Exact(i64::MAX)
);
assert_eq!(
num_proj::u64_to_i64((i64::MAX as u64) + 1),
ProjectedNumber::AfterLast
);
assert_eq!(num_proj::u64_to_i64(u64::MAX), ProjectedNumber::AfterLast);
}
#[test]
fn test_f64_to_u64() {
assert_eq!(num_proj::f64_to_u64(-1e25), ProjectedNumber::Next(0));
assert_eq!(num_proj::f64_to_u64(-0.1), ProjectedNumber::Next(0));
assert_eq!(num_proj::f64_to_u64(1e20), ProjectedNumber::AfterLast);
assert_eq!(
num_proj::f64_to_u64(f64::INFINITY),
ProjectedNumber::AfterLast
);
assert_eq!(num_proj::f64_to_u64(0.0), ProjectedNumber::Exact(0));
assert_eq!(num_proj::f64_to_u64(42.0), ProjectedNumber::Exact(42));
assert_eq!(num_proj::f64_to_u64(0.5), ProjectedNumber::Next(1));
assert_eq!(num_proj::f64_to_u64(42.1), ProjectedNumber::Next(43));
}
#[test]
fn test_f64_to_i64() {
assert_eq!(num_proj::f64_to_i64(-1e20), ProjectedNumber::Next(i64::MIN));
assert_eq!(
num_proj::f64_to_i64(f64::NEG_INFINITY),
ProjectedNumber::Next(i64::MIN)
);
assert_eq!(num_proj::f64_to_i64(1e20), ProjectedNumber::AfterLast);
assert_eq!(
num_proj::f64_to_i64(f64::INFINITY),
ProjectedNumber::AfterLast
);
assert_eq!(num_proj::f64_to_i64(0.0), ProjectedNumber::Exact(0));
assert_eq!(num_proj::f64_to_i64(42.0), ProjectedNumber::Exact(42));
assert_eq!(num_proj::f64_to_i64(-42.0), ProjectedNumber::Exact(-42));
assert_eq!(num_proj::f64_to_i64(0.5), ProjectedNumber::Next(1));
assert_eq!(num_proj::f64_to_i64(42.1), ProjectedNumber::Next(43));
assert_eq!(num_proj::f64_to_i64(-0.5), ProjectedNumber::Next(0));
assert_eq!(num_proj::f64_to_i64(-42.1), ProjectedNumber::Next(-42));
}
#[test]
fn test_i64_to_f64() {
assert_eq!(num_proj::i64_to_f64(0), ProjectedNumber::Exact(0.0));
assert_eq!(num_proj::i64_to_f64(42), ProjectedNumber::Exact(42.0));
assert_eq!(num_proj::i64_to_f64(-42), ProjectedNumber::Exact(-42.0));
let max_exact = 9_007_199_254_740_992; // 2^53
assert_eq!(
num_proj::i64_to_f64(max_exact),
ProjectedNumber::Exact(max_exact as f64)
);
// Test values that cannot be exactly represented as f64 (integers above 2^53)
let large_i64 = 9_007_199_254_740_993; // 2^53 + 1
let closest_f64 = 9_007_199_254_740_992.0;
assert_eq!(large_i64 as f64, closest_f64);
if let ProjectedNumber::Next(val) = num_proj::i64_to_f64(large_i64) {
// Verify that the returned float is different from the direct cast
assert!(val > closest_f64);
assert!(val - closest_f64 < 2. * f64::EPSILON * closest_f64);
} else {
panic!("Expected ProjectedNumber::Next for large_i64");
}
// Test with very large negative value
let large_neg_i64 = -9_007_199_254_740_993; // -(2^53 + 1)
let closest_neg_f64 = -9_007_199_254_740_992.0;
assert_eq!(large_neg_i64 as f64, closest_neg_f64);
if let ProjectedNumber::Next(val) = num_proj::i64_to_f64(large_neg_i64) {
// Verify that the returned float is the closest representable f64
assert_eq!(val, closest_neg_f64);
} else {
panic!("Expected ProjectedNumber::Next for large_neg_i64");
}
}
#[test]
fn test_u64_to_f64() {
assert_eq!(num_proj::u64_to_f64(0), ProjectedNumber::Exact(0.0));
assert_eq!(num_proj::u64_to_f64(42), ProjectedNumber::Exact(42.0));
// Test the largest u64 value that can be exactly represented as f64 (2^53)
let max_exact = 9_007_199_254_740_992; // 2^53
assert_eq!(
num_proj::u64_to_f64(max_exact),
ProjectedNumber::Exact(max_exact as f64)
);
// Test values that cannot be exactly represented as f64 (integers above 2^53)
let large_u64 = 9_007_199_254_740_993; // 2^53 + 1
let closest_f64 = 9_007_199_254_740_992.0;
assert_eq!(large_u64 as f64, closest_f64);
if let ProjectedNumber::Next(val) = num_proj::u64_to_f64(large_u64) {
// Verify that the returned float is different from the direct cast
assert!(val > closest_f64);
assert!(val - closest_f64 < 2. * f64::EPSILON * closest_f64);
} else {
panic!("Expected ProjectedNumber::Next for large_u64");
}
}
}

View File

@@ -6,14 +6,10 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::cached_sub_aggs::{
CachedSubAggs, HighCardSubAggCache, LowCardSubAggCache, SubAggCache,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
};
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::BucketId;
use crate::aggregation::segment_agg_result::{CollectorClone, SegmentAggregationCollector};
use crate::docset::DocSet;
use crate::query::{AllQuery, EnableScoring, Query, QueryParser};
use crate::schema::Schema;
@@ -408,18 +404,15 @@ pub struct FilterAggReqData {
pub evaluator: DocumentQueryEvaluator,
/// Reusable buffer for matching documents to minimize allocations during collection
pub matching_docs_buffer: Vec<DocId>,
/// True if this filter aggregation is at the top level of the aggregation tree (not nested).
pub is_top_level: bool,
}
impl FilterAggReqData {
pub(crate) fn get_memory_consumption(&self) -> usize {
// Estimate: name + segment reader reference + bitset + buffer capacity
self.name.len()
+ std::mem::size_of::<SegmentReader>()
+ self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes)
+ self.matching_docs_buffer.capacity() * std::mem::size_of::<DocId>()
+ std::mem::size_of::<bool>()
+ std::mem::size_of::<SegmentReader>()
+ self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes)
+ self.matching_docs_buffer.capacity() * std::mem::size_of::<DocId>()
}
}
@@ -496,24 +489,17 @@ impl Debug for DocumentQueryEvaluator {
}
}
#[derive(Debug, Clone, PartialEq, Copy)]
struct DocCount {
doc_count: u64,
bucket_id: BucketId,
}
/// Segment collector for filter aggregation
pub struct SegmentFilterCollector<C: SubAggCache> {
/// Document counts per parent bucket
parent_buckets: Vec<DocCount>,
pub struct SegmentFilterCollector {
/// Document count in this bucket
doc_count: u64,
/// Sub-aggregation collectors
sub_aggregations: Option<CachedSubAggs<C>>,
bucket_id_provider: BucketIdProvider,
sub_aggregations: Option<Box<dyn SegmentAggregationCollector>>,
/// Accessor index for this filter aggregation (to access FilterAggReqData)
accessor_idx: usize,
}
impl<C: SubAggCache> SegmentFilterCollector<C> {
impl SegmentFilterCollector {
/// Create a new filter segment collector following the new agg_data pattern
pub(crate) fn from_req_and_validate(
req: &mut AggregationsSegmentCtx,
@@ -525,75 +511,47 @@ impl<C: SubAggCache> SegmentFilterCollector<C> {
} else {
None
};
let sub_agg_collector = sub_agg_collector.map(CachedSubAggs::new);
Ok(SegmentFilterCollector {
parent_buckets: Vec::new(),
doc_count: 0,
sub_aggregations: sub_agg_collector,
accessor_idx: node.idx_in_req_data,
bucket_id_provider: BucketIdProvider::default(),
})
}
}
pub(crate) fn build_segment_filter_collector(
req: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
let is_top_level = req.per_request.filter_req_data[node.idx_in_req_data]
.as_ref()
.expect("filter_req_data slot is empty")
.is_top_level;
if is_top_level {
Ok(Box::new(
SegmentFilterCollector::<LowCardSubAggCache>::from_req_and_validate(req, node)?,
))
} else {
Ok(Box::new(
SegmentFilterCollector::<HighCardSubAggCache>::from_req_and_validate(req, node)?,
))
}
}
impl<C: SubAggCache> Debug for SegmentFilterCollector<C> {
impl Debug for SegmentFilterCollector {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentFilterCollector")
.field("buckets", &self.parent_buckets)
.field("doc_count", &self.doc_count)
.field("has_sub_aggs", &self.sub_aggregations.is_some())
.field("accessor_idx", &self.accessor_idx)
.finish()
}
}
impl<C: SubAggCache> SegmentAggregationCollector for SegmentFilterCollector<C> {
impl CollectorClone for SegmentFilterCollector {
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector> {
// For now, panic - this needs proper implementation with weight recreation
panic!("SegmentFilterCollector cloning not yet implemented - requires weight recreation")
}
}
impl SegmentAggregationCollector for SegmentFilterCollector {
fn add_intermediate_aggregation_result(
&mut self,
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let mut sub_results = IntermediateAggregationResults::default();
let bucket_opt = self.parent_buckets.get(parent_bucket_id as usize);
if let Some(sub_aggs) = &mut self.sub_aggregations {
sub_aggs
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut sub_results,
// Here we create a new bucket ID for sub-aggregations if the bucket doesn't
// exist, so that sub-aggregations can still produce results (e.g., zero doc
// count)
bucket_opt
.map(|bucket| bucket.bucket_id)
.unwrap_or(self.bucket_id_provider.next_bucket_id()),
)?;
if let Some(sub_aggs) = self.sub_aggregations {
sub_aggs.add_intermediate_aggregation_result(agg_data, &mut sub_results)?;
}
// Create the filter bucket result
let filter_bucket_result = IntermediateBucketResult::Filter {
doc_count: bucket_opt.map(|b| b.doc_count).unwrap_or(0),
doc_count: self.doc_count,
sub_aggregations: sub_results,
};
@@ -612,17 +570,32 @@ impl<C: SubAggCache> SegmentAggregationCollector for SegmentFilterCollector<C> {
Ok(())
}
fn collect(
fn collect(&mut self, doc: DocId, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
// Access the evaluator from FilterAggReqData
let req_data = agg_data.get_filter_req_data(self.accessor_idx);
// O(1) BitSet lookup to check if document matches filter
if req_data.evaluator.matches_document(doc) {
self.doc_count += 1;
// If we have sub-aggregations, collect on them for this filtered document
if let Some(sub_aggs) = &mut self.sub_aggregations {
sub_aggs.collect(doc, agg_data)?;
}
}
Ok(())
}
#[inline]
fn collect_block(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
docs: &[DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
if docs.is_empty() {
return Ok(());
}
let mut bucket = self.parent_buckets[parent_bucket_id as usize];
// Take the request data to avoid borrow checker issues with sub-aggregations
let mut req = agg_data.take_filter_req_data(self.accessor_idx);
@@ -631,24 +604,18 @@ impl<C: SubAggCache> SegmentAggregationCollector for SegmentFilterCollector<C> {
req.evaluator
.filter_batch(docs, &mut req.matching_docs_buffer);
bucket.doc_count += req.matching_docs_buffer.len() as u64;
self.doc_count += req.matching_docs_buffer.len() as u64;
// Batch process sub-aggregations if we have matches
if !req.matching_docs_buffer.is_empty() {
if let Some(sub_aggs) = &mut self.sub_aggregations {
for &doc_id in &req.matching_docs_buffer {
sub_aggs.push(bucket.bucket_id, doc_id);
}
// Use collect_block for better sub-aggregation performance
sub_aggs.collect_block(&req.matching_docs_buffer, agg_data)?;
}
}
// Put the request data back
agg_data.put_back_filter_req_data(self.accessor_idx, req);
if let Some(sub_aggs) = &mut self.sub_aggregations {
sub_aggs.check_flush_local(agg_data)?;
}
// put back bucket
self.parent_buckets[parent_bucket_id as usize] = bucket;
Ok(())
}
@@ -659,21 +626,6 @@ impl<C: SubAggCache> SegmentAggregationCollector for SegmentFilterCollector<C> {
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
let bucket_id = self.bucket_id_provider.next_bucket_id();
self.parent_buckets.push(DocCount {
doc_count: 0,
bucket_id,
});
}
Ok(())
}
}
/// Intermediate result for filter aggregation
@@ -1567,9 +1519,9 @@ mod tests {
let searcher = reader.searcher();
let agg = json!({
"test": {
"filter": deserialized,
"aggs": { "count": { "value_count": { "field": "brand" } } }
"test": {
"filter": deserialized,
"aggs": { "count": { "value_count": { "field": "brand" } } }
}
});

View File

@@ -207,7 +207,7 @@ fn parse_offset_into_milliseconds(input: &str) -> Result<i64, AggregationError>
}
}
pub(crate) fn parse_into_milliseconds(input: &str) -> Result<i64, AggregationError> {
fn parse_into_milliseconds(input: &str) -> Result<i64, AggregationError> {
let split_boundary = input
.as_bytes()
.iter()

View File

@@ -1,6 +1,6 @@
use std::cmp::Ordering;
use columnar::{Column, ColumnType};
use columnar::{Column, ColumnBlockAccessor, ColumnType};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use tantivy_bitpacker::minmax;
@@ -8,14 +8,14 @@ use tantivy_bitpacker::minmax;
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::agg_limits::MemoryConsumption;
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::agg_result::BucketEntry;
use crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardCachedSubAggs};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateHistogramBucketEntry,
};
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::TantivyError;
@@ -26,8 +26,13 @@ pub struct HistogramAggReqData {
pub accessor: Column<u64>,
/// The field type of the fast field.
pub field_type: ColumnType,
/// The column block accessor to access the fast field values.
pub column_block_accessor: ColumnBlockAccessor<u64>,
/// The name of the aggregation.
pub name: String,
/// The sub aggregation blueprint, used to create sub aggregations for each bucket.
/// Will be filled during initialization of the collector.
pub sub_aggregation_blueprint: Option<Box<dyn SegmentAggregationCollector>>,
/// The histogram aggregation request.
pub req: HistogramAggregation,
/// True if this is a date_histogram aggregation.
@@ -252,24 +257,18 @@ impl HistogramBounds {
pub(crate) struct SegmentHistogramBucketEntry {
pub key: f64,
pub doc_count: u64,
pub bucket_id: BucketId,
}
impl SegmentHistogramBucketEntry {
pub(crate) fn into_intermediate_bucket_entry(
self,
sub_aggregation: &mut Option<HighCardCachedSubAggs>,
sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateHistogramBucketEntry> {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
if let Some(sub_aggregation) = sub_aggregation {
sub_aggregation
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut sub_aggregation_res,
self.bucket_id,
)?;
.add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?;
}
Ok(IntermediateHistogramBucketEntry {
key: self.key,
@@ -279,38 +278,27 @@ impl SegmentHistogramBucketEntry {
}
}
#[derive(Clone, Debug, Default)]
struct HistogramBuckets {
pub buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
}
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct SegmentHistogramCollector {
/// The buckets containing the aggregation data.
/// One Histogram bucket per parent bucket id.
parent_buckets: Vec<HistogramBuckets>,
sub_agg: Option<HighCardCachedSubAggs>,
buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
sub_aggregations: FxHashMap<i64, Box<dyn SegmentAggregationCollector>>,
accessor_idx: usize,
bucket_id_provider: BucketIdProvider,
}
impl SegmentAggregationCollector for SegmentHistogramCollector {
fn add_intermediate_aggregation_result(
&mut self,
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_data
.get_histogram_req_data(self.accessor_idx)
.name
.clone();
// TODO: avoid prepare_max_bucket here and handle empty buckets.
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let histogram = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);
let bucket = self.add_intermediate_bucket_result(agg_data, histogram)?;
let bucket = self.into_intermediate_bucket_result(agg_data)?;
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
Ok(())
@@ -319,40 +307,44 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
#[inline]
fn collect(
&mut self,
parent_bucket_id: BucketId,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req = agg_data.take_histogram_req_data(self.accessor_idx);
let mut req = agg_data.take_histogram_req_data(self.accessor_idx);
let mem_pre = self.get_memory_consumption();
let buckets = &mut self.parent_buckets[parent_bucket_id as usize].buckets;
let bounds = req.bounds;
let interval = req.req.interval;
let offset = req.offset;
let get_bucket_pos = |val| get_bucket_pos_f64(val, interval, offset) as i64;
agg_data
.column_block_accessor
.fetch_block(docs, &req.accessor);
for (doc, val) in agg_data
req.column_block_accessor.fetch_block(docs, &req.accessor);
for (doc, val) in req
.column_block_accessor
.iter_docid_vals(docs, &req.accessor)
{
let val = f64_from_fastfield_u64(val, req.field_type);
let val = f64_from_fastfield_u64(val, &req.field_type);
let bucket_pos = get_bucket_pos(val);
if bounds.contains(val) {
let bucket = buckets.entry(bucket_pos).or_insert_with(|| {
let bucket = self.buckets.entry(bucket_pos).or_insert_with(|| {
let key = get_bucket_key_from_pos(bucket_pos as f64, interval, offset);
SegmentHistogramBucketEntry {
key,
doc_count: 0,
bucket_id: self.bucket_id_provider.next_bucket_id(),
}
SegmentHistogramBucketEntry { key, doc_count: 0 }
});
bucket.doc_count += 1;
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.push(bucket.bucket_id, doc);
if let Some(sub_aggregation_blueprint) = req.sub_aggregation_blueprint.as_ref() {
self.sub_aggregations
.entry(bucket_pos)
.or_insert_with(|| sub_aggregation_blueprint.clone())
.collect(doc, agg_data)?;
}
}
}
@@ -366,30 +358,14 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
.add_memory_consumed(mem_delta as u64)?;
}
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.check_flush_local(agg_data)?;
}
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if let Some(sub_aggregation) = &mut self.sub_agg {
for sub_aggregation in self.sub_aggregations.values_mut() {
sub_aggregation.flush(agg_data)?;
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
self.parent_buckets.push(HistogramBuckets {
buckets: FxHashMap::default(),
});
}
Ok(())
}
}
@@ -397,19 +373,22 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
impl SegmentHistogramCollector {
fn get_memory_consumption(&self) -> usize {
let self_mem = std::mem::size_of::<Self>();
let buckets_mem = self.parent_buckets.len() * std::mem::size_of::<HistogramBuckets>();
self_mem + buckets_mem
let sub_aggs_mem = self.sub_aggregations.memory_consumption();
let buckets_mem = self.buckets.memory_consumption();
self_mem + sub_aggs_mem + buckets_mem
}
/// Converts the collector result into a intermediate bucket result.
fn add_intermediate_bucket_result(
&mut self,
pub fn into_intermediate_bucket_result(
self,
agg_data: &AggregationsSegmentCtx,
histogram: HistogramBuckets,
) -> crate::Result<IntermediateBucketResult> {
let mut buckets = Vec::with_capacity(histogram.buckets.len());
let mut buckets = Vec::with_capacity(self.buckets.len());
for bucket in histogram.buckets.into_values() {
let bucket_res = bucket.into_intermediate_bucket_entry(&mut self.sub_agg, agg_data);
for (bucket_pos, bucket) in self.buckets {
let bucket_res = bucket.into_intermediate_bucket_entry(
self.sub_aggregations.get(&bucket_pos).cloned(),
agg_data,
);
buckets.push(bucket_res?);
}
@@ -429,7 +408,7 @@ impl SegmentHistogramCollector {
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let sub_agg = if !node.children.is_empty() {
let blueprint = if !node.children.is_empty() {
Some(build_segment_agg_collectors(agg_data, &node.children)?)
} else {
None
@@ -444,13 +423,13 @@ impl SegmentHistogramCollector {
max: f64::MAX,
});
req_data.offset = req_data.req.offset.unwrap_or(0.0);
let sub_agg = sub_agg.map(CachedSubAggs::new);
req_data.sub_aggregation_blueprint = blueprint;
Ok(Self {
parent_buckets: Default::default(),
sub_agg,
buckets: Default::default(),
sub_aggregations: Default::default(),
accessor_idx: node.idx_in_req_data,
bucket_id_provider: BucketIdProvider::default(),
})
}
}

View File

@@ -22,7 +22,6 @@
//! - [Range](RangeAggregation)
//! - [Terms](TermsAggregation)
mod composite;
mod filter;
mod histogram;
mod range;
@@ -32,7 +31,6 @@ mod term_missing_agg;
use std::collections::HashMap;
use std::fmt;
pub use composite::*;
pub use filter::*;
pub use histogram::*;
pub use range::*;

View File

@@ -1,22 +1,18 @@
use std::fmt::Debug;
use std::ops::Range;
use columnar::{Column, ColumnType};
use columnar::{Column, ColumnBlockAccessor, ColumnType};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::agg_limits::AggregationLimitsGuard;
use crate::aggregation::cached_sub_aggs::{
CachedSubAggs, HighCardSubAggCache, LowCardCachedSubAggs, LowCardSubAggCache, SubAggCache,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateRangeBucketEntry, IntermediateRangeBucketResult,
};
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::TantivyError;
@@ -27,12 +23,12 @@ pub struct RangeAggReqData {
pub accessor: Column<u64>,
/// The type of the fast field.
pub field_type: ColumnType,
/// The column block accessor to access the fast field values.
pub column_block_accessor: ColumnBlockAccessor<u64>,
/// The range aggregation request.
pub req: RangeAggregation,
/// The name of the aggregation.
pub name: String,
/// Whether this is a top-level aggregation.
pub is_top_level: bool,
}
impl RangeAggReqData {
@@ -155,47 +151,19 @@ pub(crate) struct SegmentRangeAndBucketEntry {
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
pub struct SegmentRangeCollector<C: SubAggCache> {
#[derive(Clone, Debug)]
pub struct SegmentRangeCollector {
/// The buckets containing the aggregation data.
/// One for each ParentBucketId
parent_buckets: Vec<Vec<SegmentRangeAndBucketEntry>>,
buckets: Vec<SegmentRangeAndBucketEntry>,
column_type: ColumnType,
pub(crate) accessor_idx: usize,
sub_agg: Option<CachedSubAggs<C>>,
/// Here things get a bit weird. We need to assign unique bucket ids across all
/// parent buckets. So we keep track of the next available bucket id here.
/// This allows a kind of flattening of the bucket ids across all parent buckets.
/// E.g. in nested aggregations:
/// Term Agg -> Range aggregation -> Stats aggregation
/// E.g. the Term Agg creates 3 buckets ["INFO", "ERROR", "WARN"], each of these has a Range
/// aggregation with 4 buckets. The Range aggregation will create buckets with ids:
/// - INFO: 0,1,2,3
/// - ERROR: 4,5,6,7
/// - WARN: 8,9,10,11
///
/// This allows the Stats aggregation to have unique bucket ids to refer to.
bucket_id_provider: BucketIdProvider,
limits: AggregationLimitsGuard,
}
impl<C: SubAggCache> Debug for SegmentRangeCollector<C> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentRangeCollector")
.field("parent_buckets_len", &self.parent_buckets.len())
.field("column_type", &self.column_type)
.field("accessor_idx", &self.accessor_idx)
.field("has_sub_agg", &self.sub_agg.is_some())
.finish()
}
}
/// TODO: Bad naming, there's also SegmentRangeAndBucketEntry
#[derive(Clone)]
pub(crate) struct SegmentRangeBucketEntry {
pub key: Key,
pub doc_count: u64,
// pub sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
pub bucket_id: BucketId,
pub sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
/// The from range of the bucket. Equals `f64::MIN` when `None`.
pub from: Option<f64>,
/// The to range of the bucket. Equals `f64::MAX` when `None`. Open interval, `to` is not
@@ -216,50 +184,48 @@ impl Debug for SegmentRangeBucketEntry {
impl SegmentRangeBucketEntry {
pub(crate) fn into_intermediate_bucket_entry(
self,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateRangeBucketEntry> {
let sub_aggregation = IntermediateAggregationResults::default();
let mut sub_aggregation_res = IntermediateAggregationResults::default();
if let Some(sub_aggregation) = self.sub_aggregation {
sub_aggregation
.add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?
} else {
Default::default()
};
Ok(IntermediateRangeBucketEntry {
key: self.key.into(),
doc_count: self.doc_count,
sub_aggregation_res: sub_aggregation,
sub_aggregation: sub_aggregation_res,
from: self.from,
to: self.to,
})
}
}
impl<C: SubAggCache> SegmentAggregationCollector for SegmentRangeCollector<C> {
impl SegmentAggregationCollector for SegmentRangeCollector {
fn add_intermediate_aggregation_result(
&mut self,
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let field_type = self.column_type;
let name = agg_data
.get_range_req_data(self.accessor_idx)
.name
.to_string();
let buckets = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);
let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = buckets
let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = self
.buckets
.into_iter()
.map(|range_bucket| {
let bucket_id = range_bucket.bucket.bucket_id;
let mut agg = range_bucket.bucket.into_intermediate_bucket_entry()?;
if let Some(sub_aggregation) = &mut self.sub_agg {
sub_aggregation
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut agg.sub_aggregation_res,
bucket_id,
)?;
}
Ok((range_to_string(&range_bucket.range, &field_type)?, agg))
.map(move |range_bucket| {
Ok((
range_to_string(&range_bucket.range, &field_type)?,
range_bucket
.bucket
.into_intermediate_bucket_entry(agg_data)?,
))
})
.collect::<crate::Result<_>>()?;
@@ -276,114 +242,73 @@ impl<C: SubAggCache> SegmentAggregationCollector for SegmentRangeCollector<C> {
#[inline]
fn collect(
&mut self,
parent_bucket_id: BucketId,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req = agg_data.take_range_req_data(self.accessor_idx);
// Take request data to avoid borrow conflicts during sub-aggregation
let mut req = agg_data.take_range_req_data(self.accessor_idx);
agg_data
.column_block_accessor
.fetch_block(docs, &req.accessor);
req.column_block_accessor.fetch_block(docs, &req.accessor);
let buckets = &mut self.parent_buckets[parent_bucket_id as usize];
for (doc, val) in agg_data
for (doc, val) in req
.column_block_accessor
.iter_docid_vals(docs, &req.accessor)
{
let bucket_pos = get_bucket_pos(val, buckets);
let bucket = &mut buckets[bucket_pos];
let bucket_pos = self.get_bucket_pos(val);
let bucket = &mut self.buckets[bucket_pos];
bucket.bucket.doc_count += 1;
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.push(bucket.bucket.bucket_id, doc);
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() {
sub_agg.collect(doc, agg_data)?;
}
}
agg_data.put_back_range_req_data(self.accessor_idx, req);
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.check_flush_local(agg_data)?;
}
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.flush(agg_data)?;
for bucket in self.buckets.iter_mut() {
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() {
sub_agg.flush(agg_data)?;
}
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
let new_buckets = self.create_new_buckets(agg_data)?;
self.parent_buckets.push(new_buckets);
}
Ok(())
}
}
/// Build a concrete `SegmentRangeCollector` with either a Vec- or HashMap-backed
/// bucket storage, depending on the column type and aggregation level.
pub(crate) fn build_segment_range_collector(
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
let accessor_idx = node.idx_in_req_data;
let req_data = agg_data.get_range_req_data(node.idx_in_req_data);
let field_type = req_data.field_type;
// TODO: A better metric instead of is_top_level would be the number of buckets expected.
// E.g. If range agg is not top level, but the parent is a bucket agg with less than 10 buckets,
// we can are still in low cardinality territory.
let is_low_card = req_data.is_top_level && req_data.req.ranges.len() <= 64;
let sub_agg = if !node.children.is_empty() {
Some(build_segment_agg_collectors(agg_data, &node.children)?)
} else {
None
};
if is_low_card {
Ok(Box::new(SegmentRangeCollector::<LowCardSubAggCache> {
sub_agg: sub_agg.map(LowCardCachedSubAggs::new),
column_type: field_type,
accessor_idx,
parent_buckets: Vec::new(),
bucket_id_provider: BucketIdProvider::default(),
limits: agg_data.context.limits.clone(),
}))
} else {
Ok(Box::new(SegmentRangeCollector::<HighCardSubAggCache> {
sub_agg: sub_agg.map(CachedSubAggs::new),
column_type: field_type,
accessor_idx,
parent_buckets: Vec::new(),
bucket_id_provider: BucketIdProvider::default(),
limits: agg_data.context.limits.clone(),
}))
}
}
impl<C: SubAggCache> SegmentRangeCollector<C> {
pub(crate) fn create_new_buckets(
&mut self,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<Vec<SegmentRangeAndBucketEntry>> {
let field_type = self.column_type;
let req_data = agg_data.get_range_req_data(self.accessor_idx);
impl SegmentRangeCollector {
pub(crate) fn from_req_and_validate(
req_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let accessor_idx = node.idx_in_req_data;
let (field_type, ranges) = {
let req_view = req_data.get_range_req_data(node.idx_in_req_data);
(req_view.field_type, req_view.req.ranges.clone())
};
// The range input on the request is f64.
// We need to convert to u64 ranges, because we read the values as u64.
// The mapping from the conversion is monotonic so ordering is preserved.
let buckets: Vec<_> = extend_validate_ranges(&req_data.req.ranges, &field_type)?
let sub_agg_prototype = if !node.children.is_empty() {
Some(build_segment_agg_collectors(req_data, &node.children)?)
} else {
None
};
let buckets: Vec<_> = extend_validate_ranges(&ranges, &field_type)?
.iter()
.map(|range| {
let bucket_id = self.bucket_id_provider.next_bucket_id();
let key = range
.key
.clone()
@@ -392,20 +317,20 @@ impl<C: SubAggCache> SegmentRangeCollector<C> {
let to = if range.range.end == u64::MAX {
None
} else {
Some(f64_from_fastfield_u64(range.range.end, field_type))
Some(f64_from_fastfield_u64(range.range.end, &field_type))
};
let from = if range.range.start == u64::MIN {
None
} else {
Some(f64_from_fastfield_u64(range.range.start, field_type))
Some(f64_from_fastfield_u64(range.range.start, &field_type))
};
// let sub_aggregation = sub_agg_prototype.clone();
let sub_aggregation = sub_agg_prototype.clone();
Ok(SegmentRangeAndBucketEntry {
range: range.range.clone(),
bucket: SegmentRangeBucketEntry {
doc_count: 0,
bucket_id,
sub_aggregation,
key,
from,
to,
@@ -414,19 +339,26 @@ impl<C: SubAggCache> SegmentRangeCollector<C> {
})
.collect::<crate::Result<_>>()?;
self.limits.add_memory_consumed(
req_data.context.limits.add_memory_consumed(
buckets.len() as u64 * std::mem::size_of::<SegmentRangeAndBucketEntry>() as u64,
)?;
Ok(buckets)
Ok(SegmentRangeCollector {
buckets,
column_type: field_type,
accessor_idx,
})
}
#[inline]
fn get_bucket_pos(&self, val: u64) -> usize {
let pos = self
.buckets
.binary_search_by_key(&val, |probe| probe.range.start)
.unwrap_or_else(|pos| pos - 1);
debug_assert!(self.buckets[pos].range.contains(&val));
pos
}
}
#[inline]
fn get_bucket_pos(val: u64, buckets: &[SegmentRangeAndBucketEntry]) -> usize {
let pos = buckets
.binary_search_by_key(&val, |probe| probe.range.start)
.unwrap_or_else(|pos| pos - 1);
debug_assert!(buckets[pos].range.contains(&val));
pos
}
/// Converts the user provided f64 range value to fast field value space.
@@ -524,7 +456,7 @@ pub(crate) fn range_to_string(
let val = i64::from_u64(val);
format_date(val)
} else {
Ok(f64_from_fastfield_u64(val, *field_type).to_string())
Ok(f64_from_fastfield_u64(val, field_type).to_string())
}
};
@@ -554,7 +486,7 @@ mod tests {
pub fn get_collector_from_ranges(
ranges: Vec<RangeAggregationRange>,
field_type: ColumnType,
) -> SegmentRangeCollector<HighCardSubAggCache> {
) -> SegmentRangeCollector {
let req = RangeAggregation {
field: "dummy".to_string(),
ranges,
@@ -574,33 +506,30 @@ mod tests {
let to = if range.range.end == u64::MAX {
None
} else {
Some(f64_from_fastfield_u64(range.range.end, field_type))
Some(f64_from_fastfield_u64(range.range.end, &field_type))
};
let from = if range.range.start == u64::MIN {
None
} else {
Some(f64_from_fastfield_u64(range.range.start, field_type))
Some(f64_from_fastfield_u64(range.range.start, &field_type))
};
SegmentRangeAndBucketEntry {
range: range.range.clone(),
bucket: SegmentRangeBucketEntry {
doc_count: 0,
sub_aggregation: None,
key,
from,
to,
bucket_id: 0,
},
}
})
.collect();
SegmentRangeCollector {
parent_buckets: vec![buckets],
buckets,
column_type: field_type,
accessor_idx: 0,
sub_agg: None,
bucket_id_provider: Default::default(),
limits: AggregationLimitsGuard::default(),
}
}
@@ -847,7 +776,7 @@ mod tests {
let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.parent_buckets[0].clone();
let buckets = collector.buckets;
assert_eq!(buckets[0].range.start, u64::MIN);
assert_eq!(buckets[0].range.end, 10f64.to_u64());
assert_eq!(buckets[1].range.start, 10f64.to_u64());
@@ -870,7 +799,7 @@ mod tests {
];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.parent_buckets[0].clone();
let buckets = collector.buckets;
assert_eq!(buckets[0].range.start, u64::MIN);
assert_eq!(buckets[0].range.end, 10f64.to_u64());
assert_eq!(buckets[1].range.start, 10f64.to_u64());
@@ -885,7 +814,7 @@ mod tests {
let buckets = vec![(-10f64..-1f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.parent_buckets[0].clone();
let buckets = collector.buckets;
assert_eq!(&buckets[0].bucket.key.to_string(), "*--10");
assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "-1-*");
}
@@ -894,7 +823,7 @@ mod tests {
let buckets = vec![(0f64..10f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.parent_buckets[0].clone();
let buckets = collector.buckets;
assert_eq!(&buckets[0].bucket.key.to_string(), "*-0");
assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "10-*");
}
@@ -903,7 +832,7 @@ mod tests {
fn range_binary_search_test_u64() {
let check_ranges = |ranges: Vec<RangeAggregationRange>| {
let collector = get_collector_from_ranges(ranges, ColumnType::U64);
let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);
let search = |val: u64| collector.get_bucket_pos(val);
assert_eq!(search(u64::MIN), 0);
assert_eq!(search(9), 0);
@@ -949,7 +878,7 @@ mod tests {
let ranges = vec![(10.0..100.0).into()];
let collector = get_collector_from_ranges(ranges, ColumnType::F64);
let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);
let search = |val: u64| collector.get_bucket_pos(val);
assert_eq!(search(u64::MIN), 0);
assert_eq!(search(9f64.to_u64()), 0);
@@ -961,3 +890,63 @@ mod tests {
// the max value
}
}
#[cfg(all(test, feature = "unstable"))]
mod bench {
use itertools::Itertools;
use rand::seq::SliceRandom;
use rand::thread_rng;
use super::*;
use crate::aggregation::bucket::range::tests::get_collector_from_ranges;
const TOTAL_DOCS: u64 = 1_000_000u64;
const NUM_DOCS: u64 = 50_000u64;
fn get_collector_with_buckets(num_buckets: u64, num_docs: u64) -> SegmentRangeCollector {
let bucket_size = num_docs / num_buckets;
let mut buckets: Vec<RangeAggregationRange> = vec![];
for i in 0..num_buckets {
let bucket_start = (i * bucket_size) as f64;
buckets.push((bucket_start..bucket_start + bucket_size as f64).into())
}
get_collector_from_ranges(buckets, ColumnType::U64)
}
fn get_rand_docs(total_docs: u64, num_docs_returned: u64) -> Vec<u64> {
let mut rng = thread_rng();
let all_docs = (0..total_docs - 1).collect_vec();
let mut vals = all_docs
.as_slice()
.choose_multiple(&mut rng, num_docs_returned as usize)
.cloned()
.collect_vec();
vals.sort();
vals
}
fn bench_range_binary_search(b: &mut test::Bencher, num_buckets: u64) {
let collector = get_collector_with_buckets(num_buckets, TOTAL_DOCS);
let vals = get_rand_docs(TOTAL_DOCS, NUM_DOCS);
b.iter(|| {
let mut bucket_pos = 0;
for val in &vals {
bucket_pos = collector.get_bucket_pos(*val);
}
bucket_pos
})
}
#[bench]
fn bench_range_100_buckets(b: &mut test::Bencher) {
bench_range_binary_search(b, 100)
}
#[bench]
fn bench_range_10_buckets(b: &mut test::Bencher) {
bench_range_binary_search(b, 10)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -5,13 +5,11 @@ use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::bucket::term_agg::TermsAggregation;
use crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardCachedSubAggs};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult,
};
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::BucketId;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
/// Special aggregation to handle missing values for term aggregations.
/// This missing aggregation will check multiple columns for existence.
@@ -37,55 +35,41 @@ impl MissingTermAggReqData {
}
}
#[derive(Default, Debug, Clone)]
struct MissingCount {
missing_count: u32,
bucket_id: BucketId,
}
/// The specialized missing term aggregation.
#[derive(Default, Debug)]
#[derive(Default, Debug, Clone)]
pub struct TermMissingAgg {
missing_count: u32,
accessor_idx: usize,
sub_agg: Option<HighCardCachedSubAggs>,
/// Idx = parent bucket id, Value = missing count for that bucket
missing_count_per_bucket: Vec<MissingCount>,
bucket_id_provider: BucketIdProvider,
sub_agg: Option<Box<dyn SegmentAggregationCollector>>,
}
impl TermMissingAgg {
pub(crate) fn new(
agg_data: &mut AggregationsSegmentCtx,
req_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let has_sub_aggregations = !node.children.is_empty();
let accessor_idx = node.idx_in_req_data;
let sub_agg = if has_sub_aggregations {
let sub_aggregation = build_segment_agg_collectors(agg_data, &node.children)?;
let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?;
Some(sub_aggregation)
} else {
None
};
let sub_agg = sub_agg.map(CachedSubAggs::new);
let bucket_id_provider = BucketIdProvider::default();
Ok(Self {
accessor_idx,
sub_agg,
missing_count_per_bucket: Vec::new(),
bucket_id_provider,
..Default::default()
})
}
}
impl SegmentAggregationCollector for TermMissingAgg {
fn add_intermediate_aggregation_result(
&mut self,
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
let term_agg = &req_data.req;
let missing = term_agg
@@ -96,16 +80,13 @@ impl SegmentAggregationCollector for TermMissingAgg {
let mut entries: FxHashMap<IntermediateKey, IntermediateTermBucketEntry> =
Default::default();
let missing_count = &self.missing_count_per_bucket[parent_bucket_id as usize];
let mut missing_entry = IntermediateTermBucketEntry {
doc_count: missing_count.missing_count,
doc_count: self.missing_count,
sub_aggregation: Default::default(),
};
if let Some(sub_agg) = &mut self.sub_agg {
if let Some(sub_agg) = self.sub_agg {
let mut res = IntermediateAggregationResults::default();
sub_agg
.get_sub_agg_collector()
.add_intermediate_aggregation_result(agg_data, &mut res, missing_count.bucket_id)?;
sub_agg.add_intermediate_aggregation_result(agg_data, &mut res)?;
missing_entry.sub_aggregation = res;
}
entries.insert(missing.into(), missing_entry);
@@ -128,52 +109,30 @@ impl SegmentAggregationCollector for TermMissingAgg {
fn collect(
&mut self,
parent_bucket_id: BucketId,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
let has_value = req_data
.accessors
.iter()
.any(|(acc, _)| acc.index.has_value(doc));
if !has_value {
self.missing_count += 1;
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.collect(doc, agg_data)?;
}
}
Ok(())
}
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let bucket = &mut self.missing_count_per_bucket[parent_bucket_id as usize];
let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
for doc in docs {
let doc = *doc;
let has_value = req_data
.accessors
.iter()
.any(|(acc, _)| acc.index.has_value(doc));
if !has_value {
bucket.missing_count += 1;
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.push(bucket.bucket_id, doc);
}
}
}
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.check_flush_local(agg_data)?;
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.missing_count_per_bucket.len() <= max_bucket as usize {
let bucket_id = self.bucket_id_provider.next_bucket_id();
self.missing_count_per_bucket.push(MissingCount {
missing_count: 0,
bucket_id,
});
}
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.flush(agg_data)?;
self.collect(*doc, agg_data)?;
}
Ok(())
}

View File

@@ -0,0 +1,87 @@
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::DocId;
#[cfg(test)]
pub(crate) const DOC_BLOCK_SIZE: usize = 64;
#[cfg(not(test))]
pub(crate) const DOC_BLOCK_SIZE: usize = 256;
pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE];
/// BufAggregationCollector buffers documents before calling collect_block().
#[derive(Clone)]
pub(crate) struct BufAggregationCollector {
pub(crate) collector: Box<dyn SegmentAggregationCollector>,
staged_docs: DocBlock,
num_staged_docs: usize,
}
impl std::fmt::Debug for BufAggregationCollector {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("SegmentAggregationResultsCollector")
.field("staged_docs", &&self.staged_docs[..self.num_staged_docs])
.field("num_staged_docs", &self.num_staged_docs)
.finish()
}
}
impl BufAggregationCollector {
pub fn new(collector: Box<dyn SegmentAggregationCollector>) -> Self {
Self {
collector,
num_staged_docs: 0,
staged_docs: [0; DOC_BLOCK_SIZE],
}
}
}
impl SegmentAggregationCollector for BufAggregationCollector {
#[inline]
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
Box::new(self.collector).add_intermediate_aggregation_result(agg_data, results)
}
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.staged_docs[self.num_staged_docs] = doc;
self.num_staged_docs += 1;
if self.num_staged_docs == self.staged_docs.len() {
self.collector
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?;
self.num_staged_docs = 0;
}
Ok(())
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collector.collect_block(docs, agg_data)?;
Ok(())
}
#[inline]
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
self.collector
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?;
self.num_staged_docs = 0;
self.collector.flush(agg_data)?;
Ok(())
}
}

View File

@@ -1,245 +0,0 @@
use std::fmt::Debug;
use super::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::bucket::MAX_NUM_TERMS_FOR_VEC;
use crate::aggregation::BucketId;
use crate::DocId;
/// A cache for sub-aggregations, storing doc ids per bucket id.
/// Depending on the cardinality of the parent aggregation, we use different
/// storage strategies.
///
/// ## Low Cardinality
/// Cardinality here refers to the number of unique flattened buckets that can be created
/// by the parent aggregation.
/// Flattened buckets are the result of combining all buckets per collector
/// into a single list of buckets, where each bucket is identified by its BucketId.
///
/// ## Usage
/// Since this is caching for sub-aggregations, it is only used by bucket
/// aggregations.
///
/// TODO: consider using a more advanced data structure for high cardinality
/// aggregations.
/// What this datastructure does in general is to group docs by bucket id.
#[derive(Debug)]
pub(crate) struct CachedSubAggs<C: SubAggCache> {
cache: C,
sub_agg_collector: Box<dyn SegmentAggregationCollector>,
num_docs: usize,
}
pub type LowCardCachedSubAggs = CachedSubAggs<LowCardSubAggCache>;
pub type HighCardCachedSubAggs = CachedSubAggs<HighCardSubAggCache>;
const FLUSH_THRESHOLD: usize = 2048;
/// A trait for caching sub-aggregation doc ids per bucket id.
/// Different implementations can be used depending on the cardinality
/// of the parent aggregation.
pub trait SubAggCache: Debug {
fn new() -> Self;
fn push(&mut self, bucket_id: BucketId, doc_id: DocId);
fn flush_local(
&mut self,
sub_agg: &mut Box<dyn SegmentAggregationCollector>,
agg_data: &mut AggregationsSegmentCtx,
force: bool,
) -> crate::Result<()>;
}
impl<Backend: SubAggCache + Debug> CachedSubAggs<Backend> {
pub fn new(sub_agg: Box<dyn SegmentAggregationCollector>) -> Self {
Self {
cache: Backend::new(),
sub_agg_collector: sub_agg,
num_docs: 0,
}
}
pub fn get_sub_agg_collector(&mut self) -> &mut Box<dyn SegmentAggregationCollector> {
&mut self.sub_agg_collector
}
#[inline]
pub fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
self.cache.push(bucket_id, doc_id);
self.num_docs += 1;
}
/// Check if we need to flush based on the number of documents cached.
/// If so, flushes the cache to the provided aggregation collector.
pub fn check_flush_local(
&mut self,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
if self.num_docs >= FLUSH_THRESHOLD {
self.cache
.flush_local(&mut self.sub_agg_collector, agg_data, false)?;
self.num_docs = 0;
}
Ok(())
}
/// Note: this _does_ flush the sub aggregations.
pub fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if self.num_docs != 0 {
self.cache
.flush_local(&mut self.sub_agg_collector, agg_data, true)?;
self.num_docs = 0;
}
self.sub_agg_collector.flush(agg_data)?;
Ok(())
}
}
/// Number of partitions for high cardinality sub-aggregation cache.
const NUM_PARTITIONS: usize = 16;
#[derive(Debug)]
pub(crate) struct HighCardSubAggCache {
/// This weird partitioning is used to do some cheap grouping on the bucket ids.
/// bucket ids are dense, e.g. when we don't detect the cardinality as low cardinality,
/// but there are just 16 bucket ids, each bucket id will go to its own partition.
///
/// We want to keep this cheap, because high cardinality aggregations can have a lot of
/// buckets, and there may be nothing to group.
partitions: Box<[PartitionEntry; NUM_PARTITIONS]>,
}
impl HighCardSubAggCache {
#[inline]
fn clear(&mut self) {
for partition in self.partitions.iter_mut() {
partition.clear();
}
}
}
#[derive(Debug, Clone, Default)]
struct PartitionEntry {
bucket_ids: Vec<BucketId>,
docs: Vec<DocId>,
}
impl PartitionEntry {
#[inline]
fn clear(&mut self) {
self.bucket_ids.clear();
self.docs.clear();
}
}
impl SubAggCache for HighCardSubAggCache {
fn new() -> Self {
Self {
partitions: Box::new(core::array::from_fn(|_| PartitionEntry::default())),
}
}
fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
let idx = bucket_id % NUM_PARTITIONS as u32;
let slot = &mut self.partitions[idx as usize];
slot.bucket_ids.push(bucket_id);
slot.docs.push(doc_id);
}
fn flush_local(
&mut self,
sub_agg: &mut Box<dyn SegmentAggregationCollector>,
agg_data: &mut AggregationsSegmentCtx,
_force: bool,
) -> crate::Result<()> {
let mut max_bucket = 0u32;
for partition in self.partitions.iter() {
if let Some(&local_max) = partition.bucket_ids.iter().max() {
max_bucket = max_bucket.max(local_max);
}
}
sub_agg.prepare_max_bucket(max_bucket, agg_data)?;
for slot in self.partitions.iter() {
if !slot.bucket_ids.is_empty() {
// Reduce dynamic dispatch overhead by collecting a full partition in one call.
sub_agg.collect_multiple(&slot.bucket_ids, &slot.docs, agg_data)?;
}
}
self.clear();
Ok(())
}
}
#[derive(Debug)]
pub(crate) struct LowCardSubAggCache {
/// Cache doc ids per bucket for sub-aggregations.
///
/// The outer Vec is indexed by BucketId.
per_bucket_docs: Vec<Vec<DocId>>,
}
impl LowCardSubAggCache {
#[inline]
fn clear(&mut self) {
for v in &mut self.per_bucket_docs {
v.clear();
}
}
}
impl SubAggCache for LowCardSubAggCache {
fn new() -> Self {
Self {
per_bucket_docs: Vec::new(),
}
}
fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
let idx = bucket_id as usize;
if self.per_bucket_docs.len() <= idx {
self.per_bucket_docs.resize_with(idx + 1, Vec::new);
}
self.per_bucket_docs[idx].push(doc_id);
}
fn flush_local(
&mut self,
sub_agg: &mut Box<dyn SegmentAggregationCollector>,
agg_data: &mut AggregationsSegmentCtx,
force: bool,
) -> crate::Result<()> {
// Pre-aggregated: call collect per bucket.
let max_bucket = (self.per_bucket_docs.len() as BucketId).saturating_sub(1);
sub_agg.prepare_max_bucket(max_bucket, agg_data)?;
// The threshold above which we flush buckets individually.
// Note: We need to make sure that we don't lock ourselves into a situation where we hit
// the FLUSH_THRESHOLD, but never flush any buckets. (except the final flush)
let mut bucket_treshold = FLUSH_THRESHOLD / (self.per_bucket_docs.len().max(1) * 2);
const _: () = {
// MAX_NUM_TERMS_FOR_VEC threshold is used for term aggregations
// Note: There may be other flexible values, for other aggregations, but we can use the
// const value here as a upper bound. (better than nothing)
let bucket_treshold_limit = FLUSH_THRESHOLD / (MAX_NUM_TERMS_FOR_VEC as usize * 2);
assert!(
bucket_treshold_limit > 0,
"Bucket threshold must be greater than 0"
);
};
if force {
bucket_treshold = 0;
}
for (bucket_id, docs) in self
.per_bucket_docs
.iter()
.enumerate()
.filter(|(_, docs)| docs.len() > bucket_treshold)
{
sub_agg.collect(bucket_id as BucketId, docs, agg_data)?;
}
self.clear();
Ok(())
}
}

View File

@@ -1,9 +1,9 @@
use super::agg_req::Aggregations;
use super::agg_result::AggregationResults;
use super::cached_sub_aggs::LowCardCachedSubAggs;
use super::buf_collector::BufAggregationCollector;
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::SegmentAggregationCollector;
use super::AggContextParams;
// group buffering strategy is chosen explicitly by callers; no need to hash-group on the fly.
use crate::aggregation::agg_data::{
build_aggregations_data_from_req, build_segment_agg_collectors_root, AggregationsSegmentCtx,
};
@@ -136,7 +136,7 @@ fn merge_fruits(
/// `AggregationSegmentCollector` does the aggregation collection on a segment.
pub struct AggregationSegmentCollector {
aggs_with_accessor: AggregationsSegmentCtx,
agg_collector: LowCardCachedSubAggs,
agg_collector: BufAggregationCollector,
error: Option<TantivyError>,
}
@@ -151,11 +151,8 @@ impl AggregationSegmentCollector {
) -> crate::Result<Self> {
let mut agg_data =
build_aggregations_data_from_req(agg, reader, segment_ordinal, context.clone())?;
let mut result =
LowCardCachedSubAggs::new(build_segment_agg_collectors_root(&mut agg_data)?);
result
.get_sub_agg_collector()
.prepare_max_bucket(0, &agg_data)?; // prepare for bucket zero
let result =
BufAggregationCollector::new(build_segment_agg_collectors_root(&mut agg_data)?);
Ok(AggregationSegmentCollector {
aggs_with_accessor: agg_data,
@@ -173,31 +170,26 @@ impl SegmentCollector for AggregationSegmentCollector {
if self.error.is_some() {
return;
}
self.agg_collector.push(0, doc);
match self
if let Err(err) = self
.agg_collector
.check_flush_local(&mut self.aggs_with_accessor)
.collect(doc, &mut self.aggs_with_accessor)
{
Ok(_) => {}
Err(e) => {
self.error = Some(e);
}
self.error = Some(err);
}
}
/// The query pushes the documents to the collector via this method.
///
/// Only valid for Collectors that ignore docs
fn collect_block(&mut self, docs: &[DocId]) {
if self.error.is_some() {
return;
}
match self.agg_collector.get_sub_agg_collector().collect(
0,
docs,
&mut self.aggs_with_accessor,
) {
Ok(_) => {}
Err(e) => {
self.error = Some(e);
}
if let Err(err) = self
.agg_collector
.collect_block(docs, &mut self.aggs_with_accessor)
{
self.error = Some(err);
}
}
@@ -208,13 +200,10 @@ impl SegmentCollector for AggregationSegmentCollector {
self.agg_collector.flush(&mut self.aggs_with_accessor)?;
let mut sub_aggregation_res = IntermediateAggregationResults::default();
self.agg_collector
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
&self.aggs_with_accessor,
&mut sub_aggregation_res,
0,
)?;
Box::new(self.agg_collector).add_intermediate_aggregation_result(
&self.aggs_with_accessor,
&mut sub_aggregation_res,
)?;
Ok(sub_aggregation_res)
}

View File

@@ -15,9 +15,8 @@ use serde::{Deserialize, Serialize};
use super::agg_req::{Aggregation, AggregationVariants, Aggregations};
use super::agg_result::{AggregationResult, BucketResult, MetricResult, RangeBucketEntry};
use super::bucket::{
composite_intermediate_key_ordering, cut_off_buckets, get_agg_name_and_property,
intermediate_histogram_buckets_to_final_buckets, CompositeAggregation, GetDocCount,
MissingOrder, Order, OrderTarget, RangeAggregation, TermsAggregation,
cut_off_buckets, get_agg_name_and_property, intermediate_histogram_buckets_to_final_buckets,
GetDocCount, Order, OrderTarget, RangeAggregation, TermsAggregation,
};
use super::metric::{
IntermediateAverage, IntermediateCount, IntermediateExtendedStats, IntermediateMax,
@@ -26,7 +25,7 @@ use super::metric::{
use super::segment_agg_result::AggregationLimitsGuard;
use super::{format_date, AggregationError, Key, SerializedKey};
use crate::aggregation::agg_result::{
AggregationResults, BucketEntries, BucketEntry, CompositeBucketEntry, FilterBucketResult,
AggregationResults, BucketEntries, BucketEntry, FilterBucketResult,
};
use crate::aggregation::bucket::TermsAggregationInternal;
use crate::aggregation::metric::CardinalityCollector;
@@ -91,19 +90,6 @@ impl From<IntermediateKey> for Key {
impl Eq for IntermediateKey {}
impl std::fmt::Display for IntermediateKey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
IntermediateKey::Str(val) => f.write_str(val),
IntermediateKey::F64(val) => f.write_str(&val.to_string()),
IntermediateKey::U64(val) => f.write_str(&val.to_string()),
IntermediateKey::I64(val) => f.write_str(&val.to_string()),
IntermediateKey::Bool(val) => f.write_str(&val.to_string()),
IntermediateKey::IpAddr(val) => f.write_str(&val.to_string()),
}
}
}
impl std::hash::Hash for IntermediateKey {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
core::mem::discriminant(self).hash(state);
@@ -119,21 +105,6 @@ impl std::hash::Hash for IntermediateKey {
}
impl IntermediateAggregationResults {
/// Returns a reference to the intermediate aggregation result for the given key.
pub fn get(&self, key: &str) -> Option<&IntermediateAggregationResult> {
self.aggs_res.get(key)
}
/// Removes and returns the intermediate aggregation result for the given key.
pub fn remove(&mut self, key: &str) -> Option<IntermediateAggregationResult> {
self.aggs_res.remove(key)
}
/// Returns an iterator over the keys in the intermediate aggregation results.
pub fn keys(&self) -> impl Iterator<Item = &String> {
self.aggs_res.keys()
}
/// Add a result
pub fn push(&mut self, key: String, value: IntermediateAggregationResult) -> crate::Result<()> {
let entry = self.aggs_res.entry(key);
@@ -281,11 +252,6 @@ pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult
doc_count: 0,
sub_aggregations: IntermediateAggregationResults::default(),
}),
Composite(_) => {
IntermediateAggregationResult::Bucket(IntermediateBucketResult::Composite {
buckets: IntermediateCompositeBucketResult::default(),
})
}
}
}
@@ -479,11 +445,6 @@ pub enum IntermediateBucketResult {
/// Sub-aggregation results
sub_aggregations: IntermediateAggregationResults,
},
/// Composite aggregation
Composite {
/// The composite buckets
buckets: IntermediateCompositeBucketResult,
},
}
impl IntermediateBucketResult {
@@ -579,13 +540,6 @@ impl IntermediateBucketResult {
sub_aggregations: final_sub_aggregations,
}))
}
IntermediateBucketResult::Composite { buckets } => {
let composite_req = req
.agg
.as_composite()
.expect("unexpected aggregation, expected composite aggregation");
buckets.into_final_result(composite_req, req.sub_aggregation(), limits)
}
}
}
@@ -652,16 +606,6 @@ impl IntermediateBucketResult {
*doc_count_left += doc_count_right;
sub_aggs_left.merge_fruits(sub_aggs_right)?;
}
(
IntermediateBucketResult::Composite {
buckets: composite_left,
},
IntermediateBucketResult::Composite {
buckets: composite_right,
},
) => {
composite_left.merge_fruits(composite_right)?;
}
(IntermediateBucketResult::Range(_), _) => {
panic!("try merge on different types")
}
@@ -674,9 +618,6 @@ impl IntermediateBucketResult {
(IntermediateBucketResult::Filter { .. }, _) => {
panic!("try merge on different types")
}
(IntermediateBucketResult::Composite { .. }, _) => {
panic!("try merge on different types")
}
}
Ok(())
}
@@ -698,21 +639,6 @@ pub struct IntermediateTermBucketResult {
}
impl IntermediateTermBucketResult {
/// Returns a reference to the map of bucket entries keyed by [`IntermediateKey`].
pub fn entries(&self) -> &FxHashMap<IntermediateKey, IntermediateTermBucketEntry> {
&self.entries
}
/// Returns the count of documents not included in the returned buckets.
pub fn sum_other_doc_count(&self) -> u64 {
self.sum_other_doc_count
}
/// Returns the upper bound of the error on document counts in the returned buckets.
pub fn doc_count_error_upper_bound(&self) -> u64 {
self.doc_count_error_upper_bound
}
pub(crate) fn into_final_result(
self,
req: &TermsAggregation,
@@ -866,7 +792,7 @@ pub struct IntermediateRangeBucketEntry {
/// The number of documents in the bucket.
pub doc_count: u64,
/// The sub_aggregation in this bucket.
pub sub_aggregation_res: IntermediateAggregationResults,
pub sub_aggregation: IntermediateAggregationResults,
/// The from range of the bucket. Equals `f64::MIN` when `None`.
pub from: Option<f64>,
/// The to range of the bucket. Equals `f64::MAX` when `None`.
@@ -885,7 +811,7 @@ impl IntermediateRangeBucketEntry {
key: self.key.into(),
doc_count: self.doc_count,
sub_aggregation: self
.sub_aggregation_res
.sub_aggregation
.into_final_result_internal(req, limits)?,
to: self.to,
from: self.from,
@@ -894,7 +820,7 @@ impl IntermediateRangeBucketEntry {
};
// If we have a date type on the histogram buckets, we add the `key_as_string` field as
// rfc3339
// rfc339
if column_type == Some(ColumnType::DateTime) {
if let Some(val) = range_bucket_entry.to {
let key_as_string = format_date(val as i64)?;
@@ -931,8 +857,7 @@ impl MergeFruits for IntermediateTermBucketEntry {
impl MergeFruits for IntermediateRangeBucketEntry {
fn merge_fruits(&mut self, other: IntermediateRangeBucketEntry) -> crate::Result<()> {
self.doc_count += other.doc_count;
self.sub_aggregation_res
.merge_fruits(other.sub_aggregation_res)?;
self.sub_aggregation.merge_fruits(other.sub_aggregation)?;
Ok(())
}
}
@@ -945,176 +870,6 @@ impl MergeFruits for IntermediateHistogramBucketEntry {
}
}
/// Entry for the composite bucket.
pub type IntermediateCompositeBucketEntry = IntermediateTermBucketEntry;
/// The fully typed key for composite aggregation
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum CompositeIntermediateKey {
/// Bool key
Bool(bool),
/// String key
Str(String),
/// Float key
F64(f64),
/// Signed integer key
I64(i64),
/// Unsigned integer key
U64(u64),
/// DateTime key, nanoseconds since epoch
DateTime(i64),
/// IP Address key
IpAddr(Ipv6Addr),
/// Missing value key
Null,
}
impl Eq for CompositeIntermediateKey {}
impl std::hash::Hash for CompositeIntermediateKey {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
core::mem::discriminant(self).hash(state);
match self {
CompositeIntermediateKey::Bool(val) => val.hash(state),
CompositeIntermediateKey::Str(text) => text.hash(state),
CompositeIntermediateKey::F64(val) => val.to_bits().hash(state),
CompositeIntermediateKey::U64(val) => val.hash(state),
CompositeIntermediateKey::I64(val) => val.hash(state),
CompositeIntermediateKey::DateTime(val) => val.hash(state),
CompositeIntermediateKey::IpAddr(val) => val.hash(state),
CompositeIntermediateKey::Null => {}
}
}
}
/// Composite aggregation page.
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct IntermediateCompositeBucketResult {
pub(crate) entries: FxHashMap<Vec<CompositeIntermediateKey>, IntermediateCompositeBucketEntry>,
pub(crate) target_size: u32,
pub(crate) orders: Vec<(Order, MissingOrder)>,
}
impl IntermediateCompositeBucketResult {
pub(crate) fn into_final_result(
self,
req: &CompositeAggregation,
sub_aggregation_req: &Aggregations,
limits: &mut AggregationLimitsGuard,
) -> crate::Result<BucketResult> {
let trimmed_entry_vec =
trim_composite_buckets(self.entries, &self.orders, self.target_size)?;
let after_key = if trimmed_entry_vec.len() == req.size as usize {
trimmed_entry_vec
.last()
.map(|bucket| {
let (intermediate_key, _entry) = bucket;
intermediate_key
.iter()
.enumerate()
.map(|(idx, intermediate_key)| {
let source = &req.sources[idx];
(source.name().to_string(), intermediate_key.clone().into())
})
.collect()
})
.unwrap()
} else {
FxHashMap::default()
};
let buckets = trimmed_entry_vec
.into_iter()
.map(|(intermediate_key, entry)| {
let key = intermediate_key
.into_iter()
.enumerate()
.map(|(idx, intermediate_key)| {
let source = &req.sources[idx];
(source.name().to_string(), intermediate_key.into())
})
.collect();
Ok(CompositeBucketEntry {
key,
doc_count: entry.doc_count as u64,
sub_aggregation: entry
.sub_aggregation
.into_final_result_internal(sub_aggregation_req, limits)?,
})
})
.collect::<crate::Result<Vec<_>>>()?;
Ok(BucketResult::Composite { after_key, buckets })
}
fn merge_fruits(&mut self, other: IntermediateCompositeBucketResult) -> crate::Result<()> {
merge_maps(&mut self.entries, other.entries)?;
if self.entries.len() as u32 > 2 * self.target_size {
self.trim()?;
}
Ok(())
}
/// Trim the composite buckets to the target size, according to the ordering.
pub(crate) fn trim(&mut self) -> crate::Result<()> {
if self.entries.len() as u32 <= self.target_size {
return Ok(());
}
let sorted_entries = trim_composite_buckets(
std::mem::take(&mut self.entries),
&self.orders,
self.target_size,
)?;
self.entries = sorted_entries.into_iter().collect();
Ok(())
}
}
fn trim_composite_buckets(
entries: FxHashMap<Vec<CompositeIntermediateKey>, IntermediateCompositeBucketEntry>,
orders: &[(Order, MissingOrder)],
target_size: u32,
) -> crate::Result<
Vec<(
Vec<CompositeIntermediateKey>,
IntermediateCompositeBucketEntry,
)>,
> {
let mut entries: Vec<_> = entries.into_iter().collect();
let mut sort_error: Option<TantivyError> = None;
entries.sort_by(|(left_key, _), (right_key, _)| {
if sort_error.is_some() {
return Ordering::Equal;
}
for idx in 0..orders.len() {
match composite_intermediate_key_ordering(
&left_key[idx],
&right_key[idx],
orders[idx].0,
orders[idx].1,
) {
Ok(ordering) if ordering != Ordering::Equal => return ordering,
Ok(_) => continue,
Err(err) => {
sort_error = Some(err);
break;
}
}
}
Ordering::Equal
});
if let Some(err) = sort_error {
return Err(err);
}
entries.truncate(target_size as usize);
Ok(entries)
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
@@ -1132,7 +887,7 @@ mod tests {
IntermediateRangeBucketEntry {
key: IntermediateKey::Str(key.to_string()),
doc_count: *doc_count,
sub_aggregation_res: Default::default(),
sub_aggregation: Default::default(),
from: None,
to: None,
},
@@ -1165,7 +920,7 @@ mod tests {
doc_count: *doc_count,
from: None,
to: None,
sub_aggregation_res: get_sub_test_tree(&[(
sub_aggregation: get_sub_test_tree(&[(
sub_aggregation_key.to_string(),
*sub_aggregation_count,
)]),

View File

@@ -52,15 +52,11 @@ pub struct IntermediateAverage {
impl IntermediateAverage {
/// Creates a new [`IntermediateAverage`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
}
/// Returns a reference to the underlying [`IntermediateStats`].
pub fn stats(&self) -> &IntermediateStats {
&self.stats
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateAverage) {
self.stats.merge_fruits(other.stats);

View File

@@ -1,11 +1,12 @@
use std::hash::Hash;
use std::collections::hash_map::DefaultHasher;
use std::hash::{BuildHasher, Hasher};
use columnar::column_values::CompactSpaceU64Accessor;
use columnar::{Column, ColumnType, Dictionary, StrColumn};
use columnar::{Column, ColumnBlockAccessor, ColumnType, Dictionary, StrColumn};
use common::f64_to_u64;
use datasketches::hll::{HllSketch, HllType, HllUnion};
use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};
use rustc_hash::FxHashSet;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde::{Deserialize, Serialize};
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
@@ -15,17 +16,29 @@ use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::TantivyError;
/// Log2 of the number of registers for the HLL sketch.
/// 2^11 = 2048 registers, giving ~2.3% relative error and ~1KB per sketch (Hll4).
const LG_K: u8 = 11;
#[derive(Clone, Debug, Serialize, Deserialize)]
struct BuildSaltedHasher {
salt: u8,
}
impl BuildHasher for BuildSaltedHasher {
type Hasher = DefaultHasher;
fn build_hasher(&self) -> Self::Hasher {
let mut hasher = DefaultHasher::new();
hasher.write_u8(self.salt);
hasher
}
}
/// # Cardinality
///
/// The cardinality aggregation allows for computing an estimate
/// of the number of different values in a data set based on the
/// Apache DataSketches HyperLogLog algorithm. This is particularly useful for
/// understanding the uniqueness of values in a large dataset where counting
/// each unique value individually would be computationally expensive.
/// HyperLogLog++ algorithm. This is particularly useful for understanding the
/// uniqueness of values in a large dataset where counting each unique value
/// individually would be computationally expensive.
///
/// For example, you might use a cardinality aggregation to estimate the number
/// of unique visitors to a website by aggregating on a field that contains
@@ -93,6 +106,8 @@ pub struct CardinalityAggReqData {
pub str_dict_column: Option<StrColumn>,
/// The missing value normalized to the internal u64 representation of the field type.
pub missing_value_for_accessor: Option<u64>,
/// The column block accessor to access the fast field values.
pub(crate) column_block_accessor: ColumnBlockAccessor<u64>,
/// The name of the aggregation.
pub name: String,
/// The aggregation request.
@@ -120,34 +135,45 @@ impl CardinalityAggregationReq {
}
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct SegmentCardinalityCollector {
buckets: Vec<SegmentCardinalityCollectorBucket>,
accessor_idx: usize,
/// The column accessor to access the fast field values.
accessor: Column<u64>,
/// The column_type of the field.
column_type: ColumnType,
/// The missing value normalized to the internal u64 representation of the field type.
missing_value_for_accessor: Option<u64>,
}
#[derive(Clone, Debug, PartialEq, Default)]
pub(crate) struct SegmentCardinalityCollectorBucket {
cardinality: CardinalityCollector,
entries: FxHashSet<u64>,
accessor_idx: usize,
}
impl SegmentCardinalityCollectorBucket {
pub fn new(column_type: ColumnType) -> Self {
impl SegmentCardinalityCollector {
pub fn from_req(column_type: ColumnType, accessor_idx: usize) -> Self {
Self {
cardinality: CardinalityCollector::new(column_type as u8),
entries: FxHashSet::default(),
entries: Default::default(),
accessor_idx,
}
}
fn fetch_block_with_field(
&mut self,
docs: &[crate::DocId],
agg_data: &mut CardinalityAggReqData,
) {
if let Some(missing) = agg_data.missing_value_for_accessor {
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&agg_data.accessor,
missing,
);
} else {
agg_data
.column_block_accessor
.fetch_block(docs, &agg_data.accessor);
}
}
fn into_intermediate_metric_result(
mut self,
req_data: &CardinalityAggReqData,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateMetricResult> {
let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
if req_data.column_type == ColumnType::Str {
let fallback_dict = Dictionary::empty();
let dict = req_data
@@ -168,10 +194,9 @@ impl SegmentCardinalityCollectorBucket {
term_ids.push(term_ord as u32);
}
}
term_ids.sort_unstable();
dict.sorted_ords_to_term_cb(term_ids.iter().map(|term| *term as u64), |term| {
self.cardinality.insert(term);
self.cardinality.sketch.insert_any(&term);
Ok(())
})?;
if has_missing {
@@ -182,17 +207,17 @@ impl SegmentCardinalityCollectorBucket {
);
match missing_key {
Key::Str(missing) => {
self.cardinality.insert(missing.as_str());
self.cardinality.sketch.insert_any(&missing);
}
Key::F64(val) => {
let val = f64_to_u64(*val);
self.cardinality.insert(val);
self.cardinality.sketch.insert_any(&val);
}
Key::U64(val) => {
self.cardinality.insert(*val);
self.cardinality.sketch.insert_any(&val);
}
Key::I64(val) => {
self.cardinality.insert(*val);
self.cardinality.sketch.insert_any(&val);
}
}
}
@@ -202,49 +227,16 @@ impl SegmentCardinalityCollectorBucket {
}
}
impl SegmentCardinalityCollector {
pub fn from_req(
column_type: ColumnType,
accessor_idx: usize,
accessor: Column<u64>,
missing_value_for_accessor: Option<u64>,
) -> Self {
Self {
buckets: vec![SegmentCardinalityCollectorBucket::new(column_type); 1],
column_type,
accessor_idx,
accessor,
missing_value_for_accessor,
}
}
fn fetch_block_with_field(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) {
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&self.accessor,
self.missing_value_for_accessor,
);
}
}
impl SegmentAggregationCollector for SegmentCardinalityCollector {
fn add_intermediate_aggregation_result(
&mut self,
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
let name = req_data.name.to_string();
// take the bucket in buckets and replace it with a new empty one
let bucket = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
let intermediate_result = bucket.into_intermediate_metric_result(req_data)?;
let intermediate_result = self.into_intermediate_metric_result(agg_data)?;
results.push(
name,
IntermediateAggregationResult::Metric(intermediate_result),
@@ -255,20 +247,27 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
fn collect(
&mut self,
parent_bucket_id: BucketId,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.fetch_block_with_field(docs, agg_data);
let bucket = &mut self.buckets[parent_bucket_id as usize];
let req_data = agg_data.get_cardinality_req_data_mut(self.accessor_idx);
self.fetch_block_with_field(docs, req_data);
let col_block_accessor = &agg_data.column_block_accessor;
if self.column_type == ColumnType::Str {
let col_block_accessor = &req_data.column_block_accessor;
if req_data.column_type == ColumnType::Str {
for term_ord in col_block_accessor.iter_vals() {
bucket.entries.insert(term_ord);
self.entries.insert(term_ord);
}
} else if self.column_type == ColumnType::IpAddr {
let compact_space_accessor = self
} else if req_data.column_type == ColumnType::IpAddr {
let compact_space_accessor = req_data
.accessor
.values
.clone()
@@ -283,43 +282,23 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
})?;
for val in col_block_accessor.iter_vals() {
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
bucket.cardinality.insert(val);
self.cardinality.sketch.insert_any(&val);
}
} else {
for val in col_block_accessor.iter_vals() {
bucket.cardinality.insert(val);
self.cardinality.sketch.insert_any(&val);
}
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
if max_bucket as usize >= self.buckets.len() {
self.buckets.resize_with(max_bucket as usize + 1, || {
SegmentCardinalityCollectorBucket::new(self.column_type)
});
}
Ok(())
}
}
#[derive(Clone, Debug)]
/// The cardinality collector used during segment collection and for merging results.
/// Uses Apache DataSketches HLL (lg_k=11, Hll4) for compact binary serialization
/// and cross-language compatibility (e.g. Java `datasketches` library).
#[derive(Clone, Debug, Serialize, Deserialize)]
/// The percentiles collector used during segment collection and for merging results.
pub struct CardinalityCollector {
sketch: HllSketch,
/// Salt derived from `ColumnType`, used to differentiate values of different column types
/// that map to the same u64 (e.g. bool `false` = 0 vs i64 `0`).
/// Not serialized — only needed during insertion, not after sketch registers are populated.
salt: u8,
sketch: HyperLogLogPlus<u64, BuildSaltedHasher>,
}
impl Default for CardinalityCollector {
fn default() -> Self {
Self::new(0)
@@ -332,52 +311,25 @@ impl PartialEq for CardinalityCollector {
}
}
impl Serialize for CardinalityCollector {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let bytes = self.sketch.serialize();
serializer.serialize_bytes(&bytes)
}
}
impl<'de> Deserialize<'de> for CardinalityCollector {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let bytes: Vec<u8> = Deserialize::deserialize(deserializer)?;
let sketch = HllSketch::deserialize(&bytes).map_err(serde::de::Error::custom)?;
Ok(Self { sketch, salt: 0 })
}
}
impl CardinalityCollector {
/// Compute the final cardinality estimate.
pub fn finalize(self) -> Option<f64> {
Some(self.sketch.clone().count().trunc())
}
fn new(salt: u8) -> Self {
Self {
sketch: HllSketch::new(LG_K, HllType::Hll4),
salt,
sketch: HyperLogLogPlus::new(16, BuildSaltedHasher { salt }).unwrap(),
}
}
/// Insert a value into the HLL sketch, salted by the column type.
/// The salt ensures that identical u64 values from different column types
/// (e.g. bool `false` vs i64 `0`) are counted as distinct.
pub(crate) fn insert<T: Hash>(&mut self, value: T) {
self.sketch.update((self.salt, value));
}
/// Compute the final cardinality estimate.
pub fn finalize(self) -> Option<f64> {
Some(self.sketch.estimate().trunc())
}
/// Serialize the HLL sketch to its compact binary representation.
/// The format is cross-language compatible with Apache DataSketches (Java, C++, Python).
pub fn to_sketch_bytes(&self) -> Vec<u8> {
self.sketch.serialize()
}
pub(crate) fn merge_fruits(&mut self, right: CardinalityCollector) -> crate::Result<()> {
let mut union = HllUnion::new(LG_K);
union.update(&self.sketch);
union.update(&right.sketch);
self.sketch = union.get_result(HllType::Hll4);
self.sketch.merge(&right.sketch).map_err(|err| {
TantivyError::AggregationError(AggregationError::InternalError(format!(
"Error while merging cardinality {err:?}"
)))
})?;
Ok(())
}
}
@@ -539,75 +491,4 @@ mod tests {
Ok(())
}
#[test]
fn cardinality_collector_serde_roundtrip() {
use super::CardinalityCollector;
let mut collector = CardinalityCollector::default();
collector.insert("hello");
collector.insert("world");
collector.insert("hello"); // duplicate
let serialized = serde_json::to_vec(&collector).unwrap();
let deserialized: CardinalityCollector = serde_json::from_slice(&serialized).unwrap();
let original_estimate = collector.finalize().unwrap();
let roundtrip_estimate = deserialized.finalize().unwrap();
assert_eq!(original_estimate, roundtrip_estimate);
assert_eq!(original_estimate, 2.0);
}
#[test]
fn cardinality_collector_merge() {
use super::CardinalityCollector;
let mut left = CardinalityCollector::default();
left.insert("a");
left.insert("b");
let mut right = CardinalityCollector::default();
right.insert("b");
right.insert("c");
left.merge_fruits(right).unwrap();
let estimate = left.finalize().unwrap();
assert_eq!(estimate, 3.0);
}
#[test]
fn cardinality_collector_serialize_deserialize_binary() {
use datasketches::hll::HllSketch;
use super::CardinalityCollector;
let mut collector = CardinalityCollector::default();
collector.insert("apple");
collector.insert("banana");
collector.insert("cherry");
let bytes = collector.to_sketch_bytes();
let deserialized = HllSketch::deserialize(&bytes).unwrap();
assert!((deserialized.estimate() - 3.0).abs() < 0.01);
}
#[test]
fn cardinality_collector_salt_differentiates_types() {
use super::CardinalityCollector;
// Without salt, same u64 value from different column types would collide
let mut collector_bool = CardinalityCollector::new(5); // e.g. ColumnType::Bool
collector_bool.insert(0u64); // false
collector_bool.insert(1u64); // true
let mut collector_i64 = CardinalityCollector::new(2); // e.g. ColumnType::I64
collector_i64.insert(0u64);
collector_i64.insert(1u64);
// Merge them
collector_bool.merge_fruits(collector_i64).unwrap();
let estimate = collector_bool.finalize().unwrap();
// Should be 4 because salt makes (5, 0) != (2, 0) and (5, 1) != (2, 1)
assert_eq!(estimate, 4.0);
}
}

View File

@@ -52,8 +52,10 @@ pub struct IntermediateCount {
impl IntermediateCount {
/// Creates a new [`IntermediateCount`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateCount) {

View File

@@ -8,9 +8,10 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::metric::MetricAggReqData;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::TantivyError;
use crate::{DocId, TantivyError};
/// A multi-value metric aggregation that computes a collection of extended statistics
/// on numeric values that are extracted
@@ -317,28 +318,51 @@ impl IntermediateExtendedStats {
}
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct SegmentExtendedStatsCollector {
name: String,
missing: Option<u64>,
field_type: ColumnType,
accessor: columnar::Column<u64>,
buckets: Vec<IntermediateExtendedStats>,
sigma: Option<f64>,
pub(crate) extended_stats: IntermediateExtendedStats,
pub(crate) accessor_idx: usize,
val_cache: Vec<u64>,
}
impl SegmentExtendedStatsCollector {
pub fn from_req(req: &MetricAggReqData, sigma: Option<f64>) -> Self {
let missing = req
.missing
.and_then(|val| f64_to_fastfield_u64(val, &req.field_type));
pub fn from_req(
field_type: ColumnType,
sigma: Option<f64>,
accessor_idx: usize,
missing: Option<f64>,
) -> Self {
let missing = missing.and_then(|val| f64_to_fastfield_u64(val, &field_type));
Self {
name: req.name.clone(),
field_type: req.field_type,
accessor: req.accessor.clone(),
field_type,
extended_stats: IntermediateExtendedStats::with_sigma(sigma),
accessor_idx,
missing,
buckets: vec![IntermediateExtendedStats::with_sigma(sigma); 16],
sigma,
val_cache: Default::default(),
}
}
#[inline]
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
req_data: &mut MetricAggReqData,
) {
if let Some(missing) = self.missing.as_ref() {
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&req_data.accessor,
*missing,
);
} else {
req_data
.column_block_accessor
.fetch_block(docs, &req_data.accessor);
}
for val in req_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.extended_stats.collect(val1);
}
}
}
@@ -346,18 +370,15 @@ impl SegmentExtendedStatsCollector {
impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
#[inline]
fn add_intermediate_aggregation_result(
&mut self,
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = self.name.clone();
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let extended_stats = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone();
results.push(
name,
IntermediateAggregationResult::Metric(IntermediateMetricResult::ExtendedStats(
extended_stats,
self.extended_stats,
)),
)?;
@@ -367,36 +388,39 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
#[inline]
fn collect(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let mut extended_stats = self.buckets[parent_bucket_id as usize].clone();
agg_data
.column_block_accessor
.fetch_block_with_missing(docs, &self.accessor, self.missing);
for val in agg_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, self.field_type);
extended_stats.collect(val1);
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
if let Some(missing) = self.missing {
let mut has_val = false;
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.extended_stats.collect(val1);
has_val = true;
}
if !has_val {
self.extended_stats
.collect(f64_from_fastfield_u64(missing, &self.field_type));
}
} else {
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.extended_stats.collect(val1);
}
}
// store back
self.buckets[parent_bucket_id as usize] = extended_stats;
Ok(())
}
fn prepare_max_bucket(
#[inline]
fn collect_block(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
if self.buckets.len() <= max_bucket as usize {
self.buckets.resize_with(max_bucket as usize + 1, || {
IntermediateExtendedStats::with_sigma(self.sigma)
});
}
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
self.collect_block_with_field(docs, req_data);
Ok(())
}
}

View File

@@ -52,8 +52,10 @@ pub struct IntermediateMax {
impl IntermediateMax {
/// Creates a new [`IntermediateMax`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateMax) {

View File

@@ -52,8 +52,10 @@ pub struct IntermediateMin {
impl IntermediateMin {
/// Creates a new [`IntermediateMin`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateMin) {

View File

@@ -31,7 +31,7 @@ use std::collections::HashMap;
pub use average::*;
pub use cardinality::*;
use columnar::{Column, ColumnType};
use columnar::{Column, ColumnBlockAccessor, ColumnType};
pub use count::*;
pub use extended_stats::*;
pub use max::*;
@@ -55,6 +55,8 @@ pub struct MetricAggReqData {
pub field_type: ColumnType,
/// The missing value normalized to the internal u64 representation of the field type.
pub missing_u64: Option<u64>,
/// The column block accessor to access the fast field values.
pub column_block_accessor: ColumnBlockAccessor<u64>,
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
/// Used when converting to intermediate result
@@ -107,11 +109,8 @@ pub enum PercentileValues {
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
/// The entry when requesting percentiles with keyed: false
pub struct PercentileValuesVecEntry {
/// Percentile
pub key: f64,
/// Value at the percentile
pub value: f64,
key: f64,
value: f64,
}
/// Single-metric aggregations use this common result structure.

View File

@@ -7,9 +7,10 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::metric::MetricAggReqData;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::TantivyError;
use crate::{DocId, TantivyError};
/// # Percentiles
///
@@ -130,16 +131,10 @@ impl PercentilesAggregationReq {
}
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct SegmentPercentilesCollector {
pub(crate) buckets: Vec<PercentilesCollector>,
pub(crate) percentiles: PercentilesCollector,
pub(crate) accessor_idx: usize,
/// The type of the field.
pub field_type: ColumnType,
/// The missing value normalized to the internal u64 representation of the field type.
pub missing_u64: Option<u64>,
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
}
#[derive(Clone, Serialize, Deserialize)]
@@ -222,12 +217,6 @@ impl PercentilesCollector {
self.sketch.add(val);
}
/// Encode the underlying DDSketch to Java-compatible binary format
/// for cross-language serialization with Java consumers.
pub fn to_sketch_bytes(&self) -> Vec<u8> {
self.sketch.to_java_bytes()
}
pub(crate) fn merge_fruits(&mut self, right: PercentilesCollector) -> crate::Result<()> {
self.sketch.merge(&right.sketch).map_err(|err| {
TantivyError::AggregationError(AggregationError::InternalError(format!(
@@ -240,18 +229,33 @@ impl PercentilesCollector {
}
impl SegmentPercentilesCollector {
pub fn from_req_and_validate(
field_type: ColumnType,
missing_u64: Option<u64>,
accessor: Column<u64>,
accessor_idx: usize,
) -> Self {
Self {
buckets: Vec::with_capacity(64),
field_type,
missing_u64,
accessor,
pub fn from_req_and_validate(accessor_idx: usize) -> crate::Result<Self> {
Ok(Self {
percentiles: PercentilesCollector::new(),
accessor_idx,
})
}
#[inline]
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
req_data: &mut MetricAggReqData,
) {
if let Some(missing) = req_data.missing_u64.as_ref() {
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&req_data.accessor,
*missing,
);
} else {
req_data
.column_block_accessor
.fetch_block(docs, &req_data.accessor);
}
for val in req_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.percentiles.collect(val1);
}
}
}
@@ -259,18 +263,12 @@ impl SegmentPercentilesCollector {
impl SegmentAggregationCollector for SegmentPercentilesCollector {
#[inline]
fn add_intermediate_aggregation_result(
&mut self,
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone();
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
// Swap collector with an empty one to avoid cloning
let percentiles_collector = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
let intermediate_metric_result =
IntermediateMetricResult::Percentiles(percentiles_collector);
let intermediate_metric_result = IntermediateMetricResult::Percentiles(self.percentiles);
results.push(
name,
@@ -283,33 +281,40 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector {
#[inline]
fn collect(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let percentiles = &mut self.buckets[parent_bucket_id as usize];
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&self.accessor,
self.missing_u64,
);
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
for val in agg_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, self.field_type);
percentiles.collect(val1);
if let Some(missing) = req_data.missing_u64 {
let mut has_val = false;
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.percentiles.collect(val1);
has_val = true;
}
if !has_val {
self.percentiles
.collect(f64_from_fastfield_u64(missing, &req_data.field_type));
}
} else {
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.percentiles.collect(val1);
}
}
Ok(())
}
fn prepare_max_bucket(
#[inline]
fn collect_block(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.buckets.len() <= max_bucket as usize {
self.buckets.push(PercentilesCollector::new());
}
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
self.collect_block_with_field(docs, req_data);
Ok(())
}
}
@@ -331,7 +336,7 @@ mod tests {
use crate::aggregation::AggregationCollector;
use crate::query::AllQuery;
use crate::schema::{Schema, FAST};
use crate::{assert_nearly_equals, Index};
use crate::Index;
#[test]
fn test_aggregation_percentiles_empty_index() -> crate::Result<()> {
@@ -614,16 +619,12 @@ mod tests {
let res = exec_request_with_query(agg_req, &index, None)?;
assert_eq!(res["range_with_stats"]["buckets"][0]["doc_count"], 3);
assert_nearly_equals!(
res["range_with_stats"]["buckets"][0]["percentiles"]["values"]["1.0"]
.as_f64()
.unwrap(),
assert_eq!(
res["range_with_stats"]["buckets"][0]["percentiles"]["values"]["1.0"],
5.0028295751107414
);
assert_nearly_equals!(
res["range_with_stats"]["buckets"][0]["percentiles"]["values"]["99.0"]
.as_f64()
.unwrap(),
assert_eq!(
res["range_with_stats"]["buckets"][0]["percentiles"]["values"]["99.0"],
10.07469668951144
);
@@ -669,14 +670,8 @@ mod tests {
let res = exec_request_with_query(agg_req, &index, None)?;
assert_nearly_equals!(
res["percentiles"]["values"]["1.0"].as_f64().unwrap(),
5.0028295751107414
);
assert_nearly_equals!(
res["percentiles"]["values"]["99.0"].as_f64().unwrap(),
10.07469668951144
);
assert_eq!(res["percentiles"]["values"]["1.0"], 5.0028295751107414);
assert_eq!(res["percentiles"]["values"]["99.0"], 10.07469668951144);
Ok(())
}

View File

@@ -1,6 +1,5 @@
use std::fmt::Debug;
use columnar::{Column, ColumnType};
use serde::{Deserialize, Serialize};
use super::*;
@@ -8,9 +7,10 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::metric::MetricAggReqData;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::TantivyError;
use crate::{DocId, TantivyError};
/// A multi-value metric aggregation that computes a collection of statistics on numeric values that
/// are extracted from the aggregated documents.
@@ -83,7 +83,7 @@ impl Stats {
/// Intermediate result of the stats aggregation that can be combined with other intermediate
/// results.
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct IntermediateStats {
/// The number of extracted values.
pub(crate) count: u64,
@@ -110,16 +110,6 @@ impl Default for IntermediateStats {
}
impl IntermediateStats {
/// Returns the number of values collected.
pub fn count(&self) -> u64 {
self.count
}
/// Returns the sum of all values collected.
pub fn sum(&self) -> f64 {
self.sum
}
/// Merges the other stats intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateStats) {
self.count += other.count;
@@ -197,75 +187,75 @@ pub enum StatsType {
Percentiles,
}
fn create_collector<const TYPE_ID: u8>(
req: &MetricAggReqData,
) -> Box<dyn SegmentAggregationCollector> {
Box::new(SegmentStatsCollector::<TYPE_ID> {
name: req.name.clone(),
collecting_for: req.collecting_for,
is_number_or_date_type: req.is_number_or_date_type,
missing_u64: req.missing_u64,
accessor: req.accessor.clone(),
buckets: vec![IntermediateStats::default()],
})
#[derive(Clone, Debug)]
pub(crate) struct SegmentStatsCollector {
pub(crate) stats: IntermediateStats,
pub(crate) accessor_idx: usize,
}
/// Build a concrete `SegmentStatsCollector` depending on the column type.
pub(crate) fn build_segment_stats_collector(
req: &MetricAggReqData,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
match req.field_type {
ColumnType::I64 => Ok(create_collector::<{ ColumnType::I64 as u8 }>(req)),
ColumnType::U64 => Ok(create_collector::<{ ColumnType::U64 as u8 }>(req)),
ColumnType::F64 => Ok(create_collector::<{ ColumnType::F64 as u8 }>(req)),
ColumnType::Bool => Ok(create_collector::<{ ColumnType::Bool as u8 }>(req)),
ColumnType::DateTime => Ok(create_collector::<{ ColumnType::DateTime as u8 }>(req)),
ColumnType::Bytes => Ok(create_collector::<{ ColumnType::Bytes as u8 }>(req)),
ColumnType::Str => Ok(create_collector::<{ ColumnType::Str as u8 }>(req)),
ColumnType::IpAddr => Ok(create_collector::<{ ColumnType::IpAddr as u8 }>(req)),
impl SegmentStatsCollector {
pub fn from_req(accessor_idx: usize) -> Self {
Self {
stats: IntermediateStats::default(),
accessor_idx,
}
}
#[inline]
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
req_data: &mut MetricAggReqData,
) {
if let Some(missing) = req_data.missing_u64.as_ref() {
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&req_data.accessor,
*missing,
);
} else {
req_data
.column_block_accessor
.fetch_block(docs, &req_data.accessor);
}
if req_data.is_number_or_date_type {
for val in req_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.stats.collect(val1);
}
} else {
for _val in req_data.column_block_accessor.iter_vals() {
// we ignore the value and simply record that we got something
self.stats.collect(0.0);
}
}
}
}
#[repr(C)]
#[derive(Clone, Debug)]
pub(crate) struct SegmentStatsCollector<const COLUMN_TYPE_ID: u8> {
pub(crate) missing_u64: Option<u64>,
pub(crate) accessor: Column<u64>,
pub(crate) is_number_or_date_type: bool,
pub(crate) buckets: Vec<IntermediateStats>,
pub(crate) name: String,
pub(crate) collecting_for: StatsType,
}
impl<const COLUMN_TYPE_ID: u8> SegmentAggregationCollector
for SegmentStatsCollector<COLUMN_TYPE_ID>
{
impl SegmentAggregationCollector for SegmentStatsCollector {
#[inline]
fn add_intermediate_aggregation_result(
&mut self,
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = self.name.clone();
let req = agg_data.get_metric_req_data(self.accessor_idx);
let name = req.name.clone();
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let stats = self.buckets[parent_bucket_id as usize];
let intermediate_metric_result = match self.collecting_for {
let intermediate_metric_result = match req.collecting_for {
StatsType::Average => {
IntermediateMetricResult::Average(IntermediateAverage::from_stats(stats))
IntermediateMetricResult::Average(IntermediateAverage::from_collector(*self))
}
StatsType::Count => {
IntermediateMetricResult::Count(IntermediateCount::from_stats(stats))
IntermediateMetricResult::Count(IntermediateCount::from_collector(*self))
}
StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_stats(stats)),
StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_stats(stats)),
StatsType::Stats => IntermediateMetricResult::Stats(stats),
StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_stats(stats)),
StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_collector(*self)),
StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_collector(*self)),
StatsType::Stats => IntermediateMetricResult::Stats(self.stats),
StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_collector(*self)),
_ => {
return Err(TantivyError::InvalidArgument(format!(
"Unsupported stats type for stats aggregation: {:?}",
self.collecting_for
req.collecting_for
)))
}
};
@@ -281,67 +271,41 @@ impl<const COLUMN_TYPE_ID: u8> SegmentAggregationCollector
#[inline]
fn collect(
&mut self,
parent_bucket_id: BucketId,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
if let Some(missing) = req_data.missing_u64 {
let mut has_val = false;
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.stats.collect(val1);
has_val = true;
}
if !has_val {
self.stats
.collect(f64_from_fastfield_u64(missing, &req_data.field_type));
}
} else {
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.stats.collect(val1);
}
}
Ok(())
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
// TODO: remove once we fetch all values for all bucket ids in one go
if docs.len() == 1 && self.missing_u64.is_none() {
collect_stats::<COLUMN_TYPE_ID>(
&mut self.buckets[parent_bucket_id as usize],
self.accessor.values_for_doc(docs[0]),
self.is_number_or_date_type,
)?;
return Ok(());
}
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&self.accessor,
self.missing_u64,
);
collect_stats::<COLUMN_TYPE_ID>(
&mut self.buckets[parent_bucket_id as usize],
agg_data.column_block_accessor.iter_vals(),
self.is_number_or_date_type,
)?;
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
self.collect_block_with_field(docs, req_data);
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
let required_buckets = (max_bucket as usize) + 1;
if self.buckets.len() < required_buckets {
self.buckets
.resize_with(required_buckets, IntermediateStats::default);
}
Ok(())
}
}
#[inline]
fn collect_stats<const COLUMN_TYPE_ID: u8>(
stats: &mut IntermediateStats,
vals: impl Iterator<Item = u64>,
is_number_or_date_type: bool,
) -> crate::Result<()> {
if is_number_or_date_type {
for val in vals {
let val1 = convert_to_f64::<COLUMN_TYPE_ID>(val);
stats.collect(val1);
}
} else {
for _val in vals {
// we ignore the value and simply record that we got something
stats.collect(0.0);
}
}
Ok(())
}
#[cfg(test)]

View File

@@ -52,8 +52,10 @@ pub struct IntermediateSum {
impl IntermediateSum {
/// Creates a new [`IntermediateSum`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateSum) {

View File

@@ -15,11 +15,12 @@ use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateMetricResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::{AggregationError, BucketId};
use crate::aggregation::AggregationError;
use crate::collector::sort_key::ReverseComparator;
use crate::collector::TopNComputer;
use crate::schema::OwnedValue;
use crate::{DocAddress, DocId, SegmentOrdinal};
// duplicate import removed; already imported above
/// Contains all information required by the TopHitsSegmentCollector to perform the
/// top_hits aggregation on a segment.
@@ -471,10 +472,7 @@ impl TopHitsTopNComputer {
/// Create a new TopHitsCollector
pub fn new(req: &TopHitsAggregationReq) -> Self {
Self {
top_n: TopNComputer::new_with_comparator(
req.size + req.from.unwrap_or(0),
ReverseComparator,
),
top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)),
req: req.clone(),
}
}
@@ -520,8 +518,7 @@ impl TopHitsTopNComputer {
pub(crate) struct TopHitsSegmentCollector {
segment_ordinal: SegmentOrdinal,
accessor_idx: usize,
buckets: Vec<TopNComputer<Vec<DocValueAndOrder>, DocAddress, ReverseComparator>>,
num_hits: usize,
top_n: TopNComputer<Vec<DocValueAndOrder>, DocAddress, ReverseComparator>,
}
impl TopHitsSegmentCollector {
@@ -530,29 +527,19 @@ impl TopHitsSegmentCollector {
accessor_idx: usize,
segment_ordinal: SegmentOrdinal,
) -> Self {
let num_hits = req.size + req.from.unwrap_or(0);
Self {
num_hits,
top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)),
segment_ordinal,
accessor_idx,
buckets: vec![TopNComputer::new_with_comparator(num_hits, ReverseComparator); 1],
}
}
fn get_top_hits_computer(
&mut self,
parent_bucket_id: BucketId,
fn into_top_hits_collector(
self,
value_accessors: &HashMap<String, Vec<DynamicColumn>>,
req: &TopHitsAggregationReq,
) -> TopHitsTopNComputer {
if parent_bucket_id as usize >= self.buckets.len() {
return TopHitsTopNComputer::new(req);
}
let top_n = std::mem::replace(
&mut self.buckets[parent_bucket_id as usize],
TopNComputer::new(0),
);
let mut top_hits_computer = TopHitsTopNComputer::new(req);
let top_results = top_n.into_vec();
let top_results = self.top_n.into_vec();
for res in top_results {
let doc_value_fields = req.get_document_field_data(value_accessors, res.doc.doc_id);
@@ -567,24 +554,54 @@ impl TopHitsSegmentCollector {
top_hits_computer
}
/// TODO add a specialized variant for a single sort field
fn collect_with(
&mut self,
doc_id: crate::DocId,
req: &TopHitsAggregationReq,
accessors: &[(Column<u64>, ColumnType)],
) -> crate::Result<()> {
let sorts: Vec<DocValueAndOrder> = req
.sort
.iter()
.enumerate()
.map(|(idx, KeyOrder { order, .. })| {
let order = *order;
let value = accessors
.get(idx)
.expect("could not find field in accessors")
.0
.values_for_doc(doc_id)
.next();
DocValueAndOrder { value, order }
})
.collect();
self.top_n.push(
sorts,
DocAddress {
segment_ord: self.segment_ordinal,
doc_id,
},
);
Ok(())
}
}
impl SegmentAggregationCollector for TopHitsSegmentCollector {
fn add_intermediate_aggregation_result(
&mut self,
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut crate::aggregation::intermediate_agg_result::IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
let value_accessors = &req_data.value_accessors;
let intermediate_result = IntermediateMetricResult::TopHits(self.get_top_hits_computer(
parent_bucket_id,
value_accessors,
&req_data.req,
));
let intermediate_result = IntermediateMetricResult::TopHits(
self.into_top_hits_collector(value_accessors, &req_data.req),
);
results.push(
req_data.name.to_string(),
IntermediateAggregationResult::Metric(intermediate_result),
@@ -594,54 +611,24 @@ impl SegmentAggregationCollector for TopHitsSegmentCollector {
/// TODO: Consider a caching layer to reduce the call overhead
fn collect(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
doc_id: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let top_n = &mut self.buckets[parent_bucket_id as usize];
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
let req = &req_data.req;
let accessors = &req_data.accessors;
for &doc_id in docs {
// TODO: this is terrible, a new vec is allocated for every doc
// We can fetch blocks instead
// We don't need to store the order for every value
let sorts: Vec<DocValueAndOrder> = req
.sort
.iter()
.enumerate()
.map(|(idx, KeyOrder { order, .. })| {
let order = *order;
let value = accessors
.get(idx)
.expect("could not find field in accessors")
.0
.values_for_doc(doc_id)
.next();
DocValueAndOrder { value, order }
})
.collect();
top_n.push(
sorts,
DocAddress {
segment_ord: self.segment_ordinal,
doc_id,
},
);
}
self.collect_with(doc_id, &req_data.req, &req_data.accessors)?;
Ok(())
}
fn prepare_max_bucket(
fn collect_block(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.buckets.resize(
(max_bucket as usize) + 1,
TopNComputer::new_with_comparator(self.num_hits, ReverseComparator),
);
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
// TODO: Consider getting fields with the column block accessor.
for doc in docs {
self.collect_with(*doc, &req_data.req, &req_data.accessors)?;
}
Ok(())
}
}
@@ -759,7 +746,7 @@ mod tests {
],
"from": 0,
}
}
}
}))
.unwrap();
@@ -888,7 +875,7 @@ mod tests {
"mixed.*",
],
}
}
}
}))?;
let collector = AggregationCollector::from_aggs(d, Default::default());

View File

@@ -133,7 +133,7 @@ mod agg_limits;
pub mod agg_req;
pub mod agg_result;
pub mod bucket;
pub(crate) mod cached_sub_aggs;
mod buf_collector;
mod collector;
mod date;
mod error;
@@ -162,19 +162,6 @@ use serde::{Deserialize, Deserializer, Serialize};
use crate::tokenizer::TokenizerManager;
/// A bucket id is a dense identifier for a bucket within an aggregation.
/// It is used to index into a Vec that hold per-bucket data.
///
/// For example, in a terms aggregation, each unique term will be assigned a incremental BucketId.
/// This BucketId will be forwarded to sub-aggregations to identify the parent bucket.
///
/// This allows to have a single AggregationCollector instance per aggregation,
/// that can handle multiple buckets efficiently.
///
/// The API to call sub-aggregations is therefore a &[(BucketId, &[DocId])].
/// For that we'll need a buffer. One Vec per bucket aggregation is needed.
pub type BucketId = u32;
/// Context parameters for aggregation execution
///
/// This struct holds shared resources needed during aggregation execution:
@@ -348,37 +335,19 @@ impl Display for Key {
}
}
pub(crate) fn convert_to_f64<const COLUMN_TYPE_ID: u8>(val: u64) -> f64 {
if COLUMN_TYPE_ID == ColumnType::U64 as u8 {
val as f64
} else if COLUMN_TYPE_ID == ColumnType::I64 as u8
|| COLUMN_TYPE_ID == ColumnType::DateTime as u8
{
i64::from_u64(val) as f64
} else if COLUMN_TYPE_ID == ColumnType::F64 as u8 {
f64::from_u64(val)
} else if COLUMN_TYPE_ID == ColumnType::Bool as u8 {
val as f64
} else {
panic!(
"ColumnType ID {} cannot be converted to f64 metric",
COLUMN_TYPE_ID
)
}
}
/// Inverse of `to_fastfield_u64`. Used to convert to `f64` for metrics.
///
/// # Panics
/// Only `u64`, `f64`, `date`, and `i64` are supported.
pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: ColumnType) -> f64 {
pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: &ColumnType) -> f64 {
match field_type {
ColumnType::U64 => convert_to_f64::<{ ColumnType::U64 as u8 }>(val),
ColumnType::I64 => convert_to_f64::<{ ColumnType::I64 as u8 }>(val),
ColumnType::F64 => convert_to_f64::<{ ColumnType::F64 as u8 }>(val),
ColumnType::Bool => convert_to_f64::<{ ColumnType::Bool as u8 }>(val),
ColumnType::DateTime => convert_to_f64::<{ ColumnType::DateTime as u8 }>(val),
_ => panic!("unexpected type {field_type:?}. This should not happen"),
ColumnType::U64 => val as f64,
ColumnType::I64 | ColumnType::DateTime => i64::from_u64(val) as f64,
ColumnType::F64 => f64::from_u64(val),
ColumnType::Bool => val as f64,
_ => {
panic!("unexpected type {field_type:?}. This should not happen")
}
}
}

View File

@@ -8,67 +8,25 @@ use std::fmt::Debug;
pub(crate) use super::agg_limits::AggregationLimitsGuard;
use super::intermediate_agg_result::IntermediateAggregationResults;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::BucketId;
/// Monotonically increasing provider of BucketIds.
#[derive(Debug, Clone, Default)]
pub struct BucketIdProvider(u32);
impl BucketIdProvider {
/// Get the next BucketId.
pub fn next_bucket_id(&mut self) -> BucketId {
let bucket_id = self.0;
self.0 += 1;
bucket_id
}
}
/// A SegmentAggregationCollector is used to collect aggregation results.
pub trait SegmentAggregationCollector: Debug {
pub trait SegmentAggregationCollector: CollectorClone + Debug {
fn add_intermediate_aggregation_result(
&mut self,
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()>;
/// Note: The caller needs to call `prepare_max_bucket` before calling `collect`.
fn collect(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()>;
/// Collect docs for multiple buckets in one call.
/// Minimizes dynamic dispatch overhead when collecting many buckets.
///
/// Note: The caller needs to call `prepare_max_bucket` before calling `collect`.
fn collect_multiple(
fn collect_block(
&mut self,
bucket_ids: &[BucketId],
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
debug_assert_eq!(bucket_ids.len(), docs.len());
let mut start = 0;
while start < bucket_ids.len() {
let bucket_id = bucket_ids[start];
let mut end = start + 1;
while end < bucket_ids.len() && bucket_ids[end] == bucket_id {
end += 1;
}
self.collect(bucket_id, &docs[start..end], agg_data)?;
start = end;
}
Ok(())
}
/// Prepare the collector for collecting up to BucketId `max_bucket`.
/// This is useful so we can split allocation ahead of time of collecting.
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()>;
/// Finalize method. Some Aggregator collect blocks of docs before calling `collect_block`.
@@ -78,7 +36,26 @@ pub trait SegmentAggregationCollector: Debug {
}
}
#[derive(Default)]
/// A helper trait to enable cloning of Box<dyn SegmentAggregationCollector>
pub trait CollectorClone {
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector>;
}
impl<T> CollectorClone for T
where T: 'static + SegmentAggregationCollector + Clone
{
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector> {
Box::new(self.clone())
}
}
impl Clone for Box<dyn SegmentAggregationCollector> {
fn clone(&self) -> Box<dyn SegmentAggregationCollector> {
self.clone_box()
}
}
#[derive(Clone, Default)]
/// The GenericSegmentAggregationResultsCollector is the generic version of the collector, which
/// can handle arbitrary complexity of sub-aggregations. Ideally we never have to pick this one
/// and can provide specialized versions instead, that remove some of its overhead.
@@ -96,13 +73,12 @@ impl Debug for GenericSegmentAggregationResultsCollector {
impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
fn add_intermediate_aggregation_result(
&mut self,
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
for agg in &mut self.aggs {
agg.add_intermediate_aggregation_result(agg_data, results, parent_bucket_id)?;
for agg in self.aggs {
agg.add_intermediate_aggregation_result(agg_data, results)?;
}
Ok(())
@@ -110,13 +86,23 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
fn collect(
&mut self,
parent_bucket_id: BucketId,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)?;
Ok(())
}
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
for collector in &mut self.aggs {
collector.collect(parent_bucket_id, docs, agg_data)?;
collector.collect_block(docs, agg_data)?;
}
Ok(())
}
@@ -126,15 +112,4 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
for collector in &mut self.aggs {
collector.prepare_max_bucket(max_bucket, agg_data)?;
}
Ok(())
}
}

View File

@@ -486,9 +486,9 @@ mod tests {
use std::collections::BTreeSet;
use columnar::Dictionary;
use rand::distr::Uniform;
use rand::distributions::Uniform;
use rand::prelude::SliceRandom;
use rand::{rng, Rng};
use rand::{thread_rng, Rng};
use super::{FacetCollector, FacetCounts};
use crate::collector::facet_collector::compress_mapping;
@@ -731,7 +731,7 @@ mod tests {
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let uniform = Uniform::new_inclusive(1, 100_000).unwrap();
let uniform = Uniform::new_inclusive(1, 100_000);
let mut docs: Vec<TantivyDocument> =
vec![("a", 10), ("b", 100), ("c", 7), ("d", 12), ("e", 21)]
.into_iter()
@@ -741,11 +741,14 @@ mod tests {
std::iter::repeat_n(doc, count)
})
.map(|mut doc| {
doc.add_facet(facet_field, &format!("/facet/{}", rng().sample(uniform)));
doc.add_facet(
facet_field,
&format!("/facet/{}", thread_rng().sample(uniform)),
);
doc
})
.collect();
docs[..].shuffle(&mut rng());
docs[..].shuffle(&mut thread_rng());
let mut index_writer: IndexWriter = index.writer_for_tests().unwrap();
for doc in docs {
@@ -819,8 +822,8 @@ mod tests {
#[cfg(all(test, feature = "unstable"))]
mod bench {
use rand::rng;
use rand::seq::SliceRandom;
use rand::thread_rng;
use test::Bencher;
use crate::collector::FacetCollector;
@@ -843,7 +846,7 @@ mod bench {
}
}
// 40425 docs
docs[..].shuffle(&mut rng());
docs[..].shuffle(&mut thread_rng());
let mut index_writer: IndexWriter = index.writer_for_tests().unwrap();
for doc in docs {

View File

@@ -1,50 +1,25 @@
mod order;
mod sort_by_bytes;
mod sort_by_erased_type;
mod sort_by_score;
mod sort_by_static_fast_value;
mod sort_by_string;
mod sort_key_computer;
pub use order::*;
pub use sort_by_bytes::SortByBytes;
pub use sort_by_erased_type::SortByErasedType;
pub use sort_by_score::SortBySimilarityScore;
pub use sort_by_static_fast_value::SortByStaticFastValue;
pub use sort_by_string::SortByString;
pub use sort_key_computer::{SegmentSortKeyComputer, SortKeyComputer};
#[cfg(test)]
pub(crate) mod tests {
// By spec, regardless of whether ascending or descending order was requested, in presence of a
// tie, we sort by ascending doc id/doc address.
pub(crate) fn sort_hits<TSortKey: Ord, D: Ord>(
hits: &mut [ComparableDoc<TSortKey, D>],
order: Order,
) {
if order.is_asc() {
hits.sort_by(|l, r| l.sort_key.cmp(&r.sort_key).then(l.doc.cmp(&r.doc)));
} else {
hits.sort_by(|l, r| {
l.sort_key
.cmp(&r.sort_key)
.reverse() // This is descending
.then(l.doc.cmp(&r.doc))
});
}
}
mod tests {
use std::collections::HashMap;
use std::ops::Range;
use crate::collector::sort_key::{
SortByErasedType, SortBySimilarityScore, SortByStaticFastValue, SortByString,
};
use crate::collector::sort_key::{SortBySimilarityScore, SortByStaticFastValue, SortByString};
use crate::collector::{ComparableDoc, DocSetCollector, TopDocs};
use crate::indexer::NoMergePolicy;
use crate::query::{AllQuery, QueryParser};
use crate::schema::{OwnedValue, Schema, FAST, TEXT};
use crate::schema::{Schema, FAST, TEXT};
use crate::{DocAddress, Document, Index, Order, Score, Searcher};
fn make_index() -> crate::Result<Index> {
@@ -319,9 +294,11 @@ pub(crate) mod tests {
(SortBySimilarityScore, score_order),
(SortByString::for_field("city"), city_order),
));
let results: Vec<((Score, Option<String>), DocAddress)> =
searcher.search(&AllQuery, &top_collector)?;
Ok(results.into_iter().map(|(f, doc)| (f, ids[&doc])).collect())
Ok(searcher
.search(&AllQuery, &top_collector)?
.into_iter()
.map(|(f, doc)| (f, ids[&doc]))
.collect())
}
assert_eq!(
@@ -346,51 +323,6 @@ pub(crate) mod tests {
Ok(())
}
#[test]
fn test_order_by_score_then_owned_value() -> crate::Result<()> {
let index = make_index()?;
type SortKey = (Score, OwnedValue);
fn query(
index: &Index,
score_order: Order,
city_order: Order,
) -> crate::Result<Vec<(SortKey, u64)>> {
let searcher = index.reader()?.searcher();
let ids = id_mapping(&searcher);
let top_collector = TopDocs::with_limit(4).order_by::<(Score, OwnedValue)>((
(SortBySimilarityScore, score_order),
(SortByErasedType::for_field("city"), city_order),
));
let results: Vec<((Score, OwnedValue), DocAddress)> =
searcher.search(&AllQuery, &top_collector)?;
Ok(results.into_iter().map(|(f, doc)| (f, ids[&doc])).collect())
}
assert_eq!(
&query(&index, Order::Asc, Order::Asc)?,
&[
((1.0, OwnedValue::Str("austin".to_owned())), 0),
((1.0, OwnedValue::Str("greenville".to_owned())), 1),
((1.0, OwnedValue::Str("tokyo".to_owned())), 2),
((1.0, OwnedValue::Null), 3),
]
);
assert_eq!(
&query(&index, Order::Asc, Order::Desc)?,
&[
((1.0, OwnedValue::Str("tokyo".to_owned())), 2),
((1.0, OwnedValue::Str("greenville".to_owned())), 1),
((1.0, OwnedValue::Str("austin".to_owned())), 0),
((1.0, OwnedValue::Null), 3),
]
);
Ok(())
}
use proptest::prelude::*;
proptest! {
@@ -440,10 +372,15 @@ pub(crate) mod tests {
// Using the TopDocs collector should always be equivalent to sorting, skipping the
// offset, and then taking the limit.
let sorted_docs: Vec<_> = {
let mut comparable_docs: Vec<ComparableDoc<_, _>> =
let sorted_docs: Vec<_> = if order.is_desc() {
let mut comparable_docs: Vec<ComparableDoc<_, _, true>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
sort_hits(&mut comparable_docs, order);
comparable_docs.sort();
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
} else {
let mut comparable_docs: Vec<ComparableDoc<_, _, false>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
comparable_docs.sort();
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
};
let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::<Vec<_>>();

View File

@@ -1,116 +1,36 @@
use std::cmp::Ordering;
use columnar::MonotonicallyMappableToU64;
use serde::{Deserialize, Serialize};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::schema::{OwnedValue, Schema};
use crate::schema::Schema;
use crate::{DocId, Order, Score};
fn compare_owned_value<const NULLS_FIRST: bool>(lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
match (lhs, rhs) {
(OwnedValue::Null, OwnedValue::Null) => Ordering::Equal,
(OwnedValue::Null, _) => {
if NULLS_FIRST {
Ordering::Less
} else {
Ordering::Greater
}
}
(_, OwnedValue::Null) => {
if NULLS_FIRST {
Ordering::Greater
} else {
Ordering::Less
}
}
(OwnedValue::Str(a), OwnedValue::Str(b)) => a.cmp(b),
(OwnedValue::PreTokStr(a), OwnedValue::PreTokStr(b)) => a.cmp(b),
(OwnedValue::U64(a), OwnedValue::U64(b)) => a.cmp(b),
(OwnedValue::I64(a), OwnedValue::I64(b)) => a.cmp(b),
(OwnedValue::F64(a), OwnedValue::F64(b)) => a.to_u64().cmp(&b.to_u64()),
(OwnedValue::Bool(a), OwnedValue::Bool(b)) => a.cmp(b),
(OwnedValue::Date(a), OwnedValue::Date(b)) => a.cmp(b),
(OwnedValue::Facet(a), OwnedValue::Facet(b)) => a.cmp(b),
(OwnedValue::Bytes(a), OwnedValue::Bytes(b)) => a.cmp(b),
(OwnedValue::IpAddr(a), OwnedValue::IpAddr(b)) => a.cmp(b),
(OwnedValue::U64(a), OwnedValue::I64(b)) => {
if *b < 0 {
Ordering::Greater
} else {
a.cmp(&(*b as u64))
}
}
(OwnedValue::I64(a), OwnedValue::U64(b)) => {
if *a < 0 {
Ordering::Less
} else {
(*a as u64).cmp(b)
}
}
(OwnedValue::U64(a), OwnedValue::F64(b)) => (*a as f64).to_u64().cmp(&b.to_u64()),
(OwnedValue::F64(a), OwnedValue::U64(b)) => a.to_u64().cmp(&(*b as f64).to_u64()),
(OwnedValue::I64(a), OwnedValue::F64(b)) => (*a as f64).to_u64().cmp(&b.to_u64()),
(OwnedValue::F64(a), OwnedValue::I64(b)) => a.to_u64().cmp(&(*b as f64).to_u64()),
(a, b) => {
let ord = a.discriminant_value().cmp(&b.discriminant_value());
// If the discriminant is equal, it's because a new type was added, but hasn't been
// included in this `match` statement.
assert!(
ord != Ordering::Equal,
"Unimplemented comparison for type of {a:?}, {b:?}"
);
ord
}
}
}
/// Comparator trait defining the order in which documents should be ordered.
pub trait Comparator<T>: Send + Sync + std::fmt::Debug + Default {
/// Return the order between two values.
fn compare(&self, lhs: &T, rhs: &T) -> Ordering;
}
/// Compare values naturally (e.g. 1 < 2).
///
/// When used with `TopDocs`, which reverses the order, this results in a
/// "Descending" sort (Greatest values first).
///
/// `None` (or Null for `OwnedValue`) values are considered to be smaller than any other value,
/// and will therefore appear last in a descending sort (e.g. `[Some(20), Some(10), None]`).
/// With the natural comparator, the top k collector will return
/// the top documents in decreasing order.
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
pub struct NaturalComparator;
impl<T: PartialOrd> Comparator<T> for NaturalComparator {
#[inline(always)]
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
lhs.partial_cmp(rhs).unwrap_or(Ordering::Equal)
lhs.partial_cmp(rhs).unwrap()
}
}
/// A (partial) implementation of comparison for OwnedValue.
/// Sorts document in reverse order.
///
/// Intended for use within columns of homogenous types, and so will panic for OwnedValues with
/// mismatched types. The one exception is Null, for which we do define all comparisons.
impl Comparator<OwnedValue> for NaturalComparator {
#[inline(always)]
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
compare_owned_value::</* NULLS_FIRST= */ true>(lhs, rhs)
}
}
/// Compare values in reverse (e.g. 2 < 1).
///
/// When used with `TopDocs`, which reverses the order, this results in an
/// "Ascending" sort (Smallest values first).
///
/// `None` is considered smaller than `Some` in the underlying comparator, but because the
/// comparison is reversed, `None` is effectively treated as the lowest value in the resulting
/// Ascending sort (e.g. `[None, Some(10), Some(20)]`).
/// If the sort key is None, it will considered as the lowest value, and will therefore appear
/// first.
///
/// The ReverseComparator does not necessarily imply that the sort order is reversed compared
/// to the NaturalComparator. In presence of a tie on the sort key, documents will always be
/// sorted by ascending `DocId`/`DocAddress` in TopN results, regardless of the sort key's order.
/// to the NaturalComparator. In presence of a tie, both version will retain the higher doc ids.
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
pub struct ReverseComparator;
@@ -123,15 +43,11 @@ where NaturalComparator: Comparator<T>
}
}
/// Compare values in reverse, but treating `None` as lower than `Some`.
///
/// When used with `TopDocs`, which reverses the order, this results in an
/// "Ascending" sort (Smallest values first), but with `None` values appearing last
/// (e.g. `[Some(10), Some(20), None]`).
/// Sorts document in reverse order, but considers None as having the lowest value.
///
/// This is usually what is wanted when sorting by a field in an ascending order.
/// For instance, in an e-commerce website, if sorting by price ascending,
/// the cheapest items would appear first, and items without a price would appear last.
/// For instance, in a e-commerce website, if I sort by price ascending, I most likely want the
/// cheapest items first, and the items without a price at last.
#[derive(Debug, Copy, Clone, Default)]
pub struct ReverseNoneIsLowerComparator;
@@ -191,84 +107,6 @@ impl Comparator<String> for ReverseNoneIsLowerComparator {
}
}
impl Comparator<OwnedValue> for ReverseNoneIsLowerComparator {
#[inline(always)]
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
compare_owned_value::</* NULLS_FIRST= */ false>(rhs, lhs)
}
}
/// Compare values naturally, but treating `None` as higher than `Some`.
///
/// When used with `TopDocs`, which reverses the order, this results in a
/// "Descending" sort (Greatest values first), but with `None` values appearing first
/// (e.g. `[None, Some(20), Some(10)]`).
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
pub struct NaturalNoneIsHigherComparator;
impl<T> Comparator<Option<T>> for NaturalNoneIsHigherComparator
where NaturalComparator: Comparator<T>
{
#[inline(always)]
fn compare(&self, lhs_opt: &Option<T>, rhs_opt: &Option<T>) -> Ordering {
match (lhs_opt, rhs_opt) {
(None, None) => Ordering::Equal,
(None, Some(_)) => Ordering::Greater,
(Some(_), None) => Ordering::Less,
(Some(lhs), Some(rhs)) => NaturalComparator.compare(lhs, rhs),
}
}
}
impl Comparator<u32> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &u32, rhs: &u32) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
}
impl Comparator<u64> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &u64, rhs: &u64) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
}
impl Comparator<f64> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &f64, rhs: &f64) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
}
impl Comparator<f32> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &f32, rhs: &f32) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
}
impl Comparator<i64> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &i64, rhs: &i64) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
}
impl Comparator<String> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &String, rhs: &String) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
}
impl Comparator<OwnedValue> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
compare_owned_value::</* NULLS_FIRST= */ false>(lhs, rhs)
}
}
/// An enum representing the different sort orders.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Default)]
pub enum ComparatorEnum {
@@ -277,10 +115,8 @@ pub enum ComparatorEnum {
Natural,
/// Reverse order (See [ReverseComparator])
Reverse,
/// Reverse order by treating None as the lowest value. (See [ReverseNoneLowerComparator])
/// Reverse order by treating None as the lowest value.(See [ReverseNoneLowerComparator])
ReverseNoneLower,
/// Natural order but treating None as the highest value. (See [NaturalNoneIsHigherComparator])
NaturalNoneHigher,
}
impl From<Order> for ComparatorEnum {
@@ -297,7 +133,6 @@ where
ReverseNoneIsLowerComparator: Comparator<T>,
NaturalComparator: Comparator<T>,
ReverseComparator: Comparator<T>,
NaturalNoneIsHigherComparator: Comparator<T>,
{
#[inline(always)]
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
@@ -305,7 +140,6 @@ where
ComparatorEnum::Natural => NaturalComparator.compare(lhs, rhs),
ComparatorEnum::Reverse => ReverseComparator.compare(lhs, rhs),
ComparatorEnum::ReverseNoneLower => ReverseNoneIsLowerComparator.compare(lhs, rhs),
ComparatorEnum::NaturalNoneHigher => NaturalNoneIsHigherComparator.compare(lhs, rhs),
}
}
}
@@ -488,12 +322,11 @@ impl<TSegmentSortKeyComputer, TSegmentSortKey, TComparator> SegmentSortKeyComput
for SegmentSortKeyComputerWithComparator<TSegmentSortKeyComputer, TComparator>
where
TSegmentSortKeyComputer: SegmentSortKeyComputer<SegmentSortKey = TSegmentSortKey>,
TSegmentSortKey: Clone + 'static + Sync + Send,
TSegmentSortKey: PartialOrd + Clone + 'static + Sync + Send,
TComparator: Comparator<TSegmentSortKey> + 'static + Sync + Send,
{
type SortKey = TSegmentSortKeyComputer::SortKey;
type SegmentSortKey = TSegmentSortKey;
type SegmentComparator = TComparator;
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
self.segment_sort_key_computer.segment_sort_key(doc, score)
@@ -513,55 +346,3 @@ where
.convert_segment_sort_key(sort_key)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::schema::OwnedValue;
#[test]
fn test_natural_none_is_higher() {
let comp = NaturalNoneIsHigherComparator;
let null = None;
let v1 = Some(1_u64);
let v2 = Some(2_u64);
// NaturalNoneIsGreaterComparator logic:
// 1. Delegates to NaturalComparator for non-nulls.
// NaturalComparator compare(2, 1) -> 2.cmp(1) -> Greater.
assert_eq!(comp.compare(&v2, &v1), Ordering::Greater);
// 2. Treats None (Null) as Greater than any value.
// compare(None, Some(2)) should be Greater.
assert_eq!(comp.compare(&null, &v2), Ordering::Greater);
// compare(Some(1), None) should be Less.
assert_eq!(comp.compare(&v1, &null), Ordering::Less);
// compare(None, None) should be Equal.
assert_eq!(comp.compare(&null, &null), Ordering::Equal);
}
#[test]
fn test_mixed_ownedvalue_compare() {
let u = OwnedValue::U64(10);
let i = OwnedValue::I64(10);
let f = OwnedValue::F64(10.0);
let nc = NaturalComparator;
assert_eq!(nc.compare(&u, &i), Ordering::Equal);
assert_eq!(nc.compare(&u, &f), Ordering::Equal);
assert_eq!(nc.compare(&i, &f), Ordering::Equal);
let u2 = OwnedValue::U64(11);
assert_eq!(nc.compare(&u2, &f), Ordering::Greater);
let s = OwnedValue::Str("a".to_string());
// Str < U64
assert_eq!(nc.compare(&s, &u), Ordering::Less);
// Str < I64
assert_eq!(nc.compare(&s, &i), Ordering::Less);
// Str < F64
assert_eq!(nc.compare(&s, &f), Ordering::Less);
}
}

View File

@@ -1,168 +0,0 @@
use columnar::BytesColumn;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::termdict::TermOrdinal;
use crate::{DocId, Score};
/// Sort by the first value of a bytes column.
///
/// If the field is multivalued, only the first value is considered.
///
/// Documents that do not have this value are still considered.
/// Their sort key will simply be `None`.
#[derive(Debug, Clone)]
pub struct SortByBytes {
column_name: String,
}
impl SortByBytes {
/// Creates a new sort by bytes sort key computer.
pub fn for_field(column_name: impl ToString) -> Self {
SortByBytes {
column_name: column_name.to_string(),
}
}
}
impl SortKeyComputer for SortByBytes {
type SortKey = Option<Vec<u8>>;
type Child = ByBytesColumnSegmentSortKeyComputer;
type Comparator = NaturalComparator;
fn segment_sort_key_computer(
&self,
segment_reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
let bytes_column_opt = segment_reader.fast_fields().bytes(&self.column_name)?;
Ok(ByBytesColumnSegmentSortKeyComputer { bytes_column_opt })
}
}
/// Segment-level sort key computer for bytes columns.
pub struct ByBytesColumnSegmentSortKeyComputer {
bytes_column_opt: Option<BytesColumn>,
}
impl SegmentSortKeyComputer for ByBytesColumnSegmentSortKeyComputer {
type SortKey = Option<Vec<u8>>;
type SegmentSortKey = Option<TermOrdinal>;
type SegmentComparator = NaturalComparator;
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Option<TermOrdinal> {
let bytes_column = self.bytes_column_opt.as_ref()?;
bytes_column.ords().first(doc)
}
fn convert_segment_sort_key(&self, term_ord_opt: Option<TermOrdinal>) -> Option<Vec<u8>> {
// TODO: Individual lookups to the dictionary like this are very likely to repeatedly
// decompress the same blocks. See https://github.com/quickwit-oss/tantivy/issues/2776
let term_ord = term_ord_opt?;
let bytes_column = self.bytes_column_opt.as_ref()?;
let mut bytes = Vec::new();
bytes_column
.dictionary()
.ord_to_term(term_ord, &mut bytes)
.ok()?;
Some(bytes)
}
}
#[cfg(test)]
mod tests {
use super::SortByBytes;
use crate::collector::TopDocs;
use crate::query::AllQuery;
use crate::schema::{BytesOptions, Schema, FAST, INDEXED};
use crate::{Index, IndexWriter, Order, TantivyDocument};
#[test]
fn test_sort_by_bytes_asc() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let bytes_field = schema_builder
.add_bytes_field("data", BytesOptions::default().set_fast().set_indexed());
let id_field = schema_builder.add_u64_field("id", FAST | INDEXED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
// Insert documents with byte values in non-sorted order
let test_data: Vec<(u64, Vec<u8>)> = vec![
(1, vec![0x02, 0x00]),
(2, vec![0x00, 0x10]),
(3, vec![0x01, 0x00]),
(4, vec![0x00, 0x20]),
];
for (id, bytes) in &test_data {
let mut doc = TantivyDocument::new();
doc.add_u64(id_field, *id);
doc.add_bytes(bytes_field, bytes);
index_writer.add_document(doc)?;
}
index_writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
// Sort ascending by bytes
let top_docs =
TopDocs::with_limit(10).order_by((SortByBytes::for_field("data"), Order::Asc));
let results: Vec<(Option<Vec<u8>>, _)> = searcher.search(&AllQuery, &top_docs)?;
// Expected order: [0x00,0x10], [0x00,0x20], [0x01,0x00], [0x02,0x00]
let sorted_bytes: Vec<Option<Vec<u8>>> = results.into_iter().map(|(b, _)| b).collect();
assert_eq!(
sorted_bytes,
vec![
Some(vec![0x00, 0x10]),
Some(vec![0x00, 0x20]),
Some(vec![0x01, 0x00]),
Some(vec![0x02, 0x00]),
]
);
Ok(())
}
#[test]
fn test_sort_by_bytes_desc() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let bytes_field = schema_builder
.add_bytes_field("data", BytesOptions::default().set_fast().set_indexed());
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
let test_data: Vec<Vec<u8>> = vec![vec![0x00, 0x10], vec![0x02, 0x00], vec![0x01, 0x00]];
for bytes in &test_data {
let mut doc = TantivyDocument::new();
doc.add_bytes(bytes_field, bytes);
index_writer.add_document(doc)?;
}
index_writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
// Sort descending by bytes
let top_docs =
TopDocs::with_limit(10).order_by((SortByBytes::for_field("data"), Order::Desc));
let results: Vec<(Option<Vec<u8>>, _)> = searcher.search(&AllQuery, &top_docs)?;
// Expected order (descending): [0x02,0x00], [0x01,0x00], [0x00,0x10]
let sorted_bytes: Vec<Option<Vec<u8>>> = results.into_iter().map(|(b, _)| b).collect();
assert_eq!(
sorted_bytes,
vec![
Some(vec![0x02, 0x00]),
Some(vec![0x01, 0x00]),
Some(vec![0x00, 0x10]),
]
);
Ok(())
}
}

View File

@@ -1,430 +0,0 @@
use columnar::{ColumnType, MonotonicallyMappableToU64};
use crate::collector::sort_key::{
NaturalComparator, SortByBytes, SortBySimilarityScore, SortByStaticFastValue, SortByString,
};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::fastfield::FastFieldNotAvailableError;
use crate::schema::OwnedValue;
use crate::{DateTime, DocId, Score};
/// Sort by the boxed / OwnedValue representation of either a fast field, or of the score.
///
/// Using the OwnedValue representation allows for type erasure, and can be useful when sort orders
/// are not known until runtime. But it comes with a performance cost: wherever possible, prefer to
/// use a SortKeyComputer implementation with a known-type at compile time.
#[derive(Debug, Clone)]
pub enum SortByErasedType {
/// Sort by a fast field
Field(String),
/// Sort by score
Score,
}
impl SortByErasedType {
/// Creates a new sort key computer which will sort by the given fast field column, with type
/// erasure.
pub fn for_field(column_name: impl ToString) -> Self {
Self::Field(column_name.to_string())
}
/// Creates a new sort key computer which will sort by score, with type erasure.
pub fn for_score() -> Self {
Self::Score
}
}
trait ErasedSegmentSortKeyComputer: Send + Sync {
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64>;
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue;
}
struct ErasedSegmentSortKeyComputerWrapper<C, F> {
inner: C,
converter: F,
}
impl<C, F> ErasedSegmentSortKeyComputer for ErasedSegmentSortKeyComputerWrapper<C, F>
where
C: SegmentSortKeyComputer<SegmentSortKey = Option<u64>> + Send + Sync,
F: Fn(C::SortKey) -> OwnedValue + Send + Sync + 'static,
{
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64> {
self.inner.segment_sort_key(doc, score)
}
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue {
let val = self.inner.convert_segment_sort_key(sort_key);
(self.converter)(val)
}
}
struct ScoreSegmentSortKeyComputer {
segment_computer: SortBySimilarityScore,
}
impl ErasedSegmentSortKeyComputer for ScoreSegmentSortKeyComputer {
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64> {
let score_value: f64 = self.segment_computer.segment_sort_key(doc, score).into();
Some(score_value.to_u64())
}
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue {
let score_value: u64 = sort_key.expect("This implementation always produces a score.");
OwnedValue::F64(f64::from_u64(score_value))
}
}
impl SortKeyComputer for SortByErasedType {
type SortKey = OwnedValue;
type Child = ErasedColumnSegmentSortKeyComputer;
type Comparator = NaturalComparator;
fn requires_scoring(&self) -> bool {
matches!(self, Self::Score)
}
fn segment_sort_key_computer(
&self,
segment_reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
let inner: Box<dyn ErasedSegmentSortKeyComputer> = match self {
Self::Field(column_name) => {
let fast_fields = segment_reader.fast_fields();
// TODO: We currently double-open the column to avoid relying on the implementation
// details of `SortByString` or `SortByStaticFastValue`. Once
// https://github.com/quickwit-oss/tantivy/issues/2776 is resolved, we should
// consider directly constructing the appropriate `SegmentSortKeyComputer` type for
// the column that we open here.
let (_column, column_type) =
fast_fields.u64_lenient(column_name)?.ok_or_else(|| {
FastFieldNotAvailableError {
field_name: column_name.to_owned(),
}
})?;
match column_type {
ColumnType::Str => {
let computer = SortByString::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
Box::new(ErasedSegmentSortKeyComputerWrapper {
inner,
converter: |val: Option<String>| {
val.map(OwnedValue::Str).unwrap_or(OwnedValue::Null)
},
})
}
ColumnType::Bytes => {
let computer = SortByBytes::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
Box::new(ErasedSegmentSortKeyComputerWrapper {
inner,
converter: |val: Option<Vec<u8>>| {
val.map(OwnedValue::Bytes).unwrap_or(OwnedValue::Null)
},
})
}
ColumnType::U64 => {
let computer = SortByStaticFastValue::<u64>::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
Box::new(ErasedSegmentSortKeyComputerWrapper {
inner,
converter: |val: Option<u64>| {
val.map(OwnedValue::U64).unwrap_or(OwnedValue::Null)
},
})
}
ColumnType::I64 => {
let computer = SortByStaticFastValue::<i64>::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
Box::new(ErasedSegmentSortKeyComputerWrapper {
inner,
converter: |val: Option<i64>| {
val.map(OwnedValue::I64).unwrap_or(OwnedValue::Null)
},
})
}
ColumnType::F64 => {
let computer = SortByStaticFastValue::<f64>::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
Box::new(ErasedSegmentSortKeyComputerWrapper {
inner,
converter: |val: Option<f64>| {
val.map(OwnedValue::F64).unwrap_or(OwnedValue::Null)
},
})
}
ColumnType::Bool => {
let computer = SortByStaticFastValue::<bool>::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
Box::new(ErasedSegmentSortKeyComputerWrapper {
inner,
converter: |val: Option<bool>| {
val.map(OwnedValue::Bool).unwrap_or(OwnedValue::Null)
},
})
}
ColumnType::DateTime => {
let computer = SortByStaticFastValue::<DateTime>::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
Box::new(ErasedSegmentSortKeyComputerWrapper {
inner,
converter: |val: Option<DateTime>| {
val.map(OwnedValue::Date).unwrap_or(OwnedValue::Null)
},
})
}
column_type => {
return Err(crate::TantivyError::SchemaError(format!(
"Field `{}` is of type {column_type:?}, which is not supported for \
sorting by owned value yet.",
column_name
)))
}
}
}
Self::Score => Box::new(ScoreSegmentSortKeyComputer {
segment_computer: SortBySimilarityScore,
}),
};
Ok(ErasedColumnSegmentSortKeyComputer { inner })
}
}
pub struct ErasedColumnSegmentSortKeyComputer {
inner: Box<dyn ErasedSegmentSortKeyComputer>,
}
impl SegmentSortKeyComputer for ErasedColumnSegmentSortKeyComputer {
type SortKey = OwnedValue;
type SegmentSortKey = Option<u64>;
type SegmentComparator = NaturalComparator;
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64> {
self.inner.segment_sort_key(doc, score)
}
fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> OwnedValue {
self.inner.convert_segment_sort_key(segment_sort_key)
}
}
#[cfg(test)]
mod tests {
use crate::collector::sort_key::{ComparatorEnum, SortByErasedType};
use crate::collector::TopDocs;
use crate::query::AllQuery;
use crate::schema::{OwnedValue, Schema, FAST, TEXT};
use crate::Index;
#[test]
fn test_sort_by_owned_u64() {
let mut schema_builder = Schema::builder();
let id_field = schema_builder.add_u64_field("id", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer = index.writer_for_tests().unwrap();
writer.add_document(doc!(id_field => 10u64)).unwrap();
writer.add_document(doc!(id_field => 2u64)).unwrap();
writer.add_document(doc!()).unwrap();
writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let collector = TopDocs::with_limit(10)
.order_by((SortByErasedType::for_field("id"), ComparatorEnum::Natural));
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
assert_eq!(
values,
vec![OwnedValue::U64(10), OwnedValue::U64(2), OwnedValue::Null]
);
let collector = TopDocs::with_limit(10).order_by((
SortByErasedType::for_field("id"),
ComparatorEnum::ReverseNoneLower,
));
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
assert_eq!(
values,
vec![OwnedValue::U64(2), OwnedValue::U64(10), OwnedValue::Null]
);
}
#[test]
fn test_sort_by_owned_string() {
let mut schema_builder = Schema::builder();
let city_field = schema_builder.add_text_field("city", FAST | TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer = index.writer_for_tests().unwrap();
writer.add_document(doc!(city_field => "tokyo")).unwrap();
writer.add_document(doc!(city_field => "austin")).unwrap();
writer.add_document(doc!()).unwrap();
writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let collector = TopDocs::with_limit(10).order_by((
SortByErasedType::for_field("city"),
ComparatorEnum::ReverseNoneLower,
));
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
assert_eq!(
values,
vec![
OwnedValue::Str("austin".to_string()),
OwnedValue::Str("tokyo".to_string()),
OwnedValue::Null
]
);
}
#[test]
fn test_sort_by_owned_bytes() {
let mut schema_builder = Schema::builder();
let data_field = schema_builder.add_bytes_field("data", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer = index.writer_for_tests().unwrap();
writer
.add_document(doc!(data_field => vec![0x03u8, 0x00]))
.unwrap();
writer
.add_document(doc!(data_field => vec![0x01u8, 0x00]))
.unwrap();
writer
.add_document(doc!(data_field => vec![0x02u8, 0x00]))
.unwrap();
writer.add_document(doc!()).unwrap();
writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
// Sort descending (Natural - highest first)
let collector = TopDocs::with_limit(10)
.order_by((SortByErasedType::for_field("data"), ComparatorEnum::Natural));
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
assert_eq!(
values,
vec![
OwnedValue::Bytes(vec![0x03, 0x00]),
OwnedValue::Bytes(vec![0x02, 0x00]),
OwnedValue::Bytes(vec![0x01, 0x00]),
OwnedValue::Null
]
);
// Sort ascending (ReverseNoneLower - lowest first, nulls last)
let collector = TopDocs::with_limit(10).order_by((
SortByErasedType::for_field("data"),
ComparatorEnum::ReverseNoneLower,
));
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
assert_eq!(
values,
vec![
OwnedValue::Bytes(vec![0x01, 0x00]),
OwnedValue::Bytes(vec![0x02, 0x00]),
OwnedValue::Bytes(vec![0x03, 0x00]),
OwnedValue::Null
]
);
}
#[test]
fn test_sort_by_owned_reverse() {
let mut schema_builder = Schema::builder();
let id_field = schema_builder.add_u64_field("id", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer = index.writer_for_tests().unwrap();
writer.add_document(doc!(id_field => 10u64)).unwrap();
writer.add_document(doc!(id_field => 2u64)).unwrap();
writer.add_document(doc!()).unwrap();
writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let collector = TopDocs::with_limit(10)
.order_by((SortByErasedType::for_field("id"), ComparatorEnum::Reverse));
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
assert_eq!(
values,
vec![OwnedValue::Null, OwnedValue::U64(2), OwnedValue::U64(10)]
);
}
#[test]
fn test_sort_by_owned_score() {
let mut schema_builder = Schema::builder();
let body_field = schema_builder.add_text_field("body", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer = index.writer_for_tests().unwrap();
writer.add_document(doc!(body_field => "a a")).unwrap();
writer.add_document(doc!(body_field => "a")).unwrap();
writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let query_parser = crate::query::QueryParser::for_index(&index, vec![body_field]);
let query = query_parser.parse_query("a").unwrap();
// Sort by score descending (Natural)
let collector = TopDocs::with_limit(10)
.order_by((SortByErasedType::for_score(), ComparatorEnum::Natural));
let top_docs = searcher.search(&query, &collector).unwrap();
let values: Vec<f64> = top_docs
.into_iter()
.map(|(key, _)| match key {
OwnedValue::F64(val) => val,
_ => panic!("Wrong type {key:?}"),
})
.collect();
assert_eq!(values.len(), 2);
assert!(values[0] > values[1]);
// Sort by score ascending (ReverseNoneLower)
let collector = TopDocs::with_limit(10).order_by((
SortByErasedType::for_score(),
ComparatorEnum::ReverseNoneLower,
));
let top_docs = searcher.search(&query, &collector).unwrap();
let values: Vec<f64> = top_docs
.into_iter()
.map(|(key, _)| match key {
OwnedValue::F64(val) => val,
_ => panic!("Wrong type {key:?}"),
})
.collect();
assert_eq!(values.len(), 2);
assert!(values[0] < values[1]);
}
}

View File

@@ -63,8 +63,8 @@ impl SortKeyComputer for SortBySimilarityScore {
impl SegmentSortKeyComputer for SortBySimilarityScore {
type SortKey = Score;
type SegmentSortKey = Score;
type SegmentComparator = NaturalComparator;
#[inline(always)]
fn segment_sort_key(&mut self, _doc: DocId, score: Score) -> Score {

View File

@@ -34,7 +34,9 @@ impl<T: FastValue> SortByStaticFastValue<T> {
impl<T: FastValue> SortKeyComputer for SortByStaticFastValue<T> {
type Child = SortByFastValueSegmentSortKeyComputer<T>;
type SortKey = Option<T>;
type Comparator = NaturalComparator;
fn check_schema(&self, schema: &crate::schema::Schema) -> crate::Result<()> {
@@ -82,8 +84,8 @@ pub struct SortByFastValueSegmentSortKeyComputer<T> {
impl<T: FastValue> SegmentSortKeyComputer for SortByFastValueSegmentSortKeyComputer<T> {
type SortKey = Option<T>;
type SegmentSortKey = Option<u64>;
type SegmentComparator = NaturalComparator;
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Self::SegmentSortKey {

View File

@@ -30,7 +30,9 @@ impl SortByString {
impl SortKeyComputer for SortByString {
type SortKey = Option<String>;
type Child = ByStringColumnSegmentSortKeyComputer;
type Comparator = NaturalComparator;
fn segment_sort_key_computer(
@@ -48,8 +50,8 @@ pub struct ByStringColumnSegmentSortKeyComputer {
impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
type SortKey = Option<String>;
type SegmentSortKey = Option<TermOrdinal>;
type SegmentComparator = NaturalComparator;
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Option<TermOrdinal> {
@@ -58,8 +60,6 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
}
fn convert_segment_sort_key(&self, term_ord_opt: Option<TermOrdinal>) -> Option<String> {
// TODO: Individual lookups to the dictionary like this are very likely to repeatedly
// decompress the same blocks. See https://github.com/quickwit-oss/tantivy/issues/2776
let term_ord = term_ord_opt?;
let str_column = self.str_column_opt.as_ref()?;
let mut bytes = Vec::new();

View File

@@ -12,21 +12,13 @@ use crate::{DocAddress, DocId, Result, Score, SegmentReader};
/// It is the segment local version of the [`SortKeyComputer`].
pub trait SegmentSortKeyComputer: 'static {
/// The final score being emitted.
type SortKey: 'static + Send + Sync + Clone;
type SortKey: 'static + PartialOrd + Send + Sync + Clone;
/// Sort key used by at the segment level by the `SegmentSortKeyComputer`.
///
/// It is typically small like a `u64`, and is meant to be converted
/// to the final score at the end of the collection of the segment.
type SegmentSortKey: 'static + Clone + Send + Sync + Clone;
/// Comparator type.
type SegmentComparator: Comparator<Self::SegmentSortKey> + 'static;
/// Returns the segment sort key comparator.
fn segment_comparator(&self) -> Self::SegmentComparator {
Self::SegmentComparator::default()
}
type SegmentSortKey: 'static + PartialOrd + Clone + Send + Sync + Clone;
/// Computes the sort key for the given document and score.
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey;
@@ -55,7 +47,7 @@ pub trait SegmentSortKeyComputer: 'static {
left: &Self::SegmentSortKey,
right: &Self::SegmentSortKey,
) -> Ordering {
self.segment_comparator().compare(left, right)
NaturalComparator.compare(left, right)
}
/// Implementing this method makes it possible to avoid computing
@@ -89,7 +81,7 @@ pub trait SegmentSortKeyComputer: 'static {
/// the sort key at a segment scale.
pub trait SortKeyComputer: Sync {
/// The sort key type.
type SortKey: 'static + Send + Sync + Clone + std::fmt::Debug;
type SortKey: 'static + Send + Sync + PartialOrd + Clone + std::fmt::Debug;
/// Type of the associated [`SegmentSortKeyComputer`].
type Child: SegmentSortKeyComputer<SortKey = Self::SortKey>;
/// Comparator type.
@@ -144,7 +136,10 @@ where
HeadSortKeyComputer: SortKeyComputer,
TailSortKeyComputer: SortKeyComputer,
{
type SortKey = (HeadSortKeyComputer::SortKey, TailSortKeyComputer::SortKey);
type SortKey = (
<HeadSortKeyComputer::Child as SegmentSortKeyComputer>::SortKey,
<TailSortKeyComputer::Child as SegmentSortKeyComputer>::SortKey,
);
type Child = (HeadSortKeyComputer::Child, TailSortKeyComputer::Child);
type Comparator = (
@@ -193,11 +188,6 @@ where
TailSegmentSortKeyComputer::SegmentSortKey,
);
type SegmentComparator = (
HeadSegmentSortKeyComputer::SegmentComparator,
TailSegmentSortKeyComputer::SegmentComparator,
);
/// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on
/// its ordering.
///
@@ -279,12 +269,11 @@ impl<T, PreviousScore, NewScore> SegmentSortKeyComputer
for MappedSegmentSortKeyComputer<T, PreviousScore, NewScore>
where
T: SegmentSortKeyComputer<SortKey = PreviousScore>,
PreviousScore: 'static + Clone + Send + Sync,
NewScore: 'static + Clone + Send + Sync,
PreviousScore: 'static + Clone + Send + Sync + PartialOrd,
NewScore: 'static + Clone + Send + Sync + PartialOrd,
{
type SortKey = NewScore;
type SegmentSortKey = T::SegmentSortKey;
type SegmentComparator = T::SegmentComparator;
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
self.sort_key_computer.segment_sort_key(doc, score)
@@ -474,7 +463,6 @@ where
{
type SortKey = TSortKey;
type SegmentSortKey = TSortKey;
type SegmentComparator = NaturalComparator;
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> TSortKey {
(self)(doc)

View File

@@ -160,7 +160,7 @@ mod tests {
expected: &[(crate::Score, usize)],
) {
let mut vals: Vec<(crate::Score, usize)> = (0..10).map(|val| (val as f32, val)).collect();
vals.shuffle(&mut rand::rng());
vals.shuffle(&mut rand::thread_rng());
let vals_merged = merge_top_k(vals.into_iter(), doc_range, ComparatorEnum::from(order));
assert_eq!(&vals_merged, expected);
}

View File

@@ -1,22 +1,64 @@
use std::cmp::Ordering;
use serde::{Deserialize, Serialize};
/// Contains a feature (field, score, etc.) of a document along with the document address.
///
/// Used only by TopNComputer, which implements the actual comparison via a `Comparator`.
#[derive(Clone, Default, Eq, PartialEq, Serialize, Deserialize)]
pub struct ComparableDoc<T, D> {
/// It guarantees stable sorting: in case of a tie on the feature, the document
/// address is used.
///
/// The REVERSE_ORDER generic parameter controls whether the by-feature order
/// should be reversed, which is useful for achieving for example largest-first
/// semantics without having to wrap the feature in a `Reverse`.
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct ComparableDoc<T, D, const REVERSE_ORDER: bool = false> {
/// The feature of the document. In practice, this is
/// is a type which can be compared with a `Comparator<T>`.
/// is any type that implements `PartialOrd`.
pub sort_key: T,
/// The document address. In practice, this is either a `DocId` or `DocAddress`.
/// The document address. In practice, this is any
/// type that implements `PartialOrd`, and is guaranteed
/// to be unique for each document.
pub doc: D,
}
impl<T: std::fmt::Debug, D: std::fmt::Debug> std::fmt::Debug for ComparableDoc<T, D> {
impl<T: std::fmt::Debug, D: std::fmt::Debug, const R: bool> std::fmt::Debug
for ComparableDoc<T, D, R>
{
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("ComparableDoc")
f.debug_struct(format!("ComparableDoc<_, _ {R}").as_str())
.field("feature", &self.sort_key)
.field("doc", &self.doc)
.finish()
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialOrd for ComparableDoc<T, D, R> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> Ord for ComparableDoc<T, D, R> {
#[inline]
fn cmp(&self, other: &Self) -> Ordering {
let by_feature = self
.sort_key
.partial_cmp(&other.sort_key)
.map(|ord| if R { ord.reverse() } else { ord })
.unwrap_or(Ordering::Equal);
let lazy_by_doc_address = || self.doc.partial_cmp(&other.doc).unwrap_or(Ordering::Equal);
// In case of a tie on the feature, we sort by ascending
// `DocAddress` in order to ensure a stable sorting of the
// documents.
by_feature.then_with(lazy_by_doc_address)
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialEq for ComparableDoc<T, D, R> {
fn eq(&self, other: &Self) -> bool {
self.cmp(other) == Ordering::Equal
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> Eq for ComparableDoc<T, D, R> {}

View File

@@ -23,9 +23,10 @@ use crate::{DocAddress, DocId, Order, Score, SegmentReader};
/// The theoretical complexity for collecting the top `K` out of `N` documents
/// is `O(N + K)`.
///
/// This collector guarantees a stable sorting in case of a tie on the
/// document score/sort key: The document address (`DocAddress`) is used as a tie breaker.
/// In case of a tie on the sort key, documents are always sorted by ascending `DocAddress`.
/// This collector does not guarantee a stable sorting in case of a tie on the
/// document score, for stable sorting `PartialOrd` needs to resolve on other fields
/// like docid in case of score equality.
/// Only then, it is suitable for pagination.
///
/// ```rust
/// use tantivy::collector::TopDocs;
@@ -324,7 +325,7 @@ impl TopDocs {
sort_key_computer: impl SortKeyComputer<SortKey = TSortKey> + Send + 'static,
) -> impl Collector<Fruit = Vec<(TSortKey, DocAddress)>>
where
TSortKey: 'static + Clone + Send + Sync + std::fmt::Debug,
TSortKey: 'static + Clone + Send + Sync + PartialOrd + std::fmt::Debug,
{
TopBySortKeyCollector::new(sort_key_computer, self.doc_range())
}
@@ -445,7 +446,7 @@ where
F: 'static + Send + Sync + Fn(&SegmentReader) -> TTweakScoreSortKeyFn,
TTweakScoreSortKeyFn: 'static + Fn(DocId, Score) -> TSortKey,
TweakScoreSegmentSortKeyComputer<TTweakScoreSortKeyFn>:
SegmentSortKeyComputer<SortKey = TSortKey, SegmentSortKey = TSortKey>,
SegmentSortKeyComputer<SortKey = TSortKey>,
TSortKey: 'static + PartialOrd + Clone + Send + Sync + std::fmt::Debug,
{
type SortKey = TSortKey;
@@ -480,7 +481,6 @@ where
{
type SortKey = TSortKey;
type SegmentSortKey = TSortKey;
type SegmentComparator = NaturalComparator;
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> TSortKey {
(self.sort_key_fn)(doc, score)
@@ -500,13 +500,8 @@ where
///
/// For TopN == 0, it will be relative expensive.
///
/// The TopNComputer will tiebreak by using ascending `D` (DocId or DocAddress):
/// i.e., in case of a tie on the sort key, the `DocId|DocAddress` are always sorted in
/// ascending order, regardless of the `Comparator` used for the `Score` type.
///
/// NOTE: Items must be `push`ed to the TopNComputer in ascending `DocId|DocAddress` order, as the
/// threshold used to eliminate docs does not include the `DocId` or `DocAddress`: this provides
/// the ascending `DocId|DocAddress` tie-breaking behavior without additional comparisons.
/// When using the natural comparator, the top N computer returns the top N elements in
/// descending order, as expected for a top N.
#[derive(Serialize, Deserialize)]
#[serde(from = "TopNComputerDeser<Score, D, C>")]
pub struct TopNComputer<Score, D, C> {
@@ -585,18 +580,6 @@ where
}
}
#[inline(always)]
fn compare_for_top_k<TSortKey, D: Ord, C: Comparator<TSortKey>>(
c: &C,
lhs: &ComparableDoc<TSortKey, D>,
rhs: &ComparableDoc<TSortKey, D>,
) -> std::cmp::Ordering {
c.compare(&lhs.sort_key, &rhs.sort_key)
.reverse() // Reverse here because we want top K.
.then_with(|| lhs.doc.cmp(&rhs.doc)) // Regardless of asc/desc, in presence of a tie, we
// sort by doc id
}
impl<TSortKey, D, C> TopNComputer<TSortKey, D, C>
where
D: Ord,
@@ -617,13 +600,10 @@ where
/// Push a new document to the top n.
/// If the document is below the current threshold, it will be ignored.
///
/// NOTE: `push` must be called in ascending `DocId`/`DocAddress` order.
#[inline]
pub fn push(&mut self, sort_key: TSortKey, doc: D) {
if let Some(last_median) = &self.threshold {
// See the struct docs for an explanation of why this comparison is strict.
if self.comparator.compare(&sort_key, last_median) != Ordering::Greater {
if self.comparator.compare(&sort_key, last_median) == Ordering::Less {
return;
}
}
@@ -649,7 +629,9 @@ where
fn truncate_top_n(&mut self) -> TSortKey {
// Use select_nth_unstable to find the top nth score
let (_, median_el, _) = self.buffer.select_nth_unstable_by(self.top_n, |lhs, rhs| {
compare_for_top_k(&self.comparator, lhs, rhs)
self.comparator
.compare(&rhs.sort_key, &lhs.sort_key)
.then_with(|| lhs.doc.cmp(&rhs.doc))
});
let median_score = median_el.sort_key.clone();
@@ -664,8 +646,11 @@ where
if self.buffer.len() > self.top_n {
self.truncate_top_n();
}
self.buffer
.sort_unstable_by(|lhs, rhs| compare_for_top_k(&self.comparator, lhs, rhs));
self.buffer.sort_unstable_by(|left, right| {
self.comparator
.compare(&right.sort_key, &left.sort_key)
.then_with(|| left.doc.cmp(&right.doc))
});
self.buffer
}
@@ -770,33 +755,6 @@ mod tests {
);
}
#[test]
fn test_topn_computer_duplicates() {
let mut computer: TopNComputer<u32, u32, NaturalComparator> =
TopNComputer::new_with_comparator(2, NaturalComparator);
computer.push(1u32, 1u32);
computer.push(1u32, 2u32);
computer.push(1u32, 3u32);
computer.push(1u32, 4u32);
computer.push(1u32, 5u32);
// In the presence of duplicates, DocIds are always ascending order.
assert_eq!(
computer.into_sorted_vec(),
&[
ComparableDoc {
sort_key: 1u32,
doc: 1u32,
},
ComparableDoc {
sort_key: 1u32,
doc: 2u32,
}
]
);
}
#[test]
fn test_topn_computer_no_panic() {
for top_n in 0..10 {
@@ -814,17 +772,14 @@ mod tests {
#[test]
fn test_topn_computer_asc_prop(
limit in 0..10_usize,
mut docs in proptest::collection::vec((0..100_u64, 0..100_u64), 0..100_usize),
docs in proptest::collection::vec((0..100_u64, 0..100_u64), 0..100_usize),
) {
// NB: TopNComputer must receive inputs in ascending DocId order.
docs.sort_by_key(|(_, doc_id)| *doc_id);
let mut computer: TopNComputer<_, _, ReverseComparator> = TopNComputer::new_with_comparator(limit, ReverseComparator);
for (feature, doc) in &docs {
computer.push(*feature, *doc);
}
let mut comparable_docs: Vec<ComparableDoc<u64, u64>> =
docs.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc }).collect();
crate::collector::sort_key::tests::sort_hits(&mut comparable_docs, Order::Asc);
let mut comparable_docs: Vec<ComparableDoc<u64, u64>> = docs.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc }).collect::<Vec<_>>();
comparable_docs.sort();
comparable_docs.truncate(limit);
prop_assert_eq!(
computer.into_sorted_vec(),
@@ -1451,10 +1406,15 @@ mod tests {
// Using the TopDocs collector should always be equivalent to sorting, skipping the
// offset, and then taking the limit.
let sorted_docs: Vec<_> = {
let mut comparable_docs: Vec<ComparableDoc<_, _>> =
let sorted_docs: Vec<_> = if order.is_desc() {
let mut comparable_docs: Vec<ComparableDoc<_, _, true>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
crate::collector::sort_key::tests::sort_hits(&mut comparable_docs, order);
comparable_docs.sort();
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
} else {
let mut comparable_docs: Vec<ComparableDoc<_, _, false>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
comparable_docs.sort();
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
};
let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::<Vec<_>>();

View File

@@ -48,15 +48,7 @@ impl Executor {
F: Sized + Sync + Fn(A) -> crate::Result<R>,
{
match self {
Executor::SingleThread => {
// Avoid `collect`, since the stacktrace is blown up by it, which makes profiling
// harder.
let mut result = Vec::with_capacity(args.size_hint().0);
for arg in args {
result.push(f(arg)?);
}
Ok(result)
}
Executor::SingleThread => args.map(f).collect::<crate::Result<_>>(),
Executor::ThreadPool(pool) => {
let args: Vec<A> = args.collect();
let num_fruits = args.len();

View File

@@ -227,6 +227,9 @@ pub(crate) fn index_json_value<'a, V: Value<'a>>(
ReferenceValueLeaf::IpAddr(_) => {
unimplemented!("IP address support in dynamic fields is not yet implemented")
}
ReferenceValueLeaf::Geometry(_) => {
unimplemented!("Geometry support in dynamic fields is not implemented")
}
},
ReferenceValue::Array(elements) => {
for val in elements {
@@ -406,7 +409,7 @@ mod tests {
let mut term = Term::from_field_json_path(field, "color", false);
term.append_type_and_str("red");
assert_eq!(term.serialized_value_bytes(), b"color\x00sred".to_vec())
assert_eq!(term.serialized_term(), b"\x00\x00\x00\x01jcolor\x00sred")
}
#[test]
@@ -416,8 +419,8 @@ mod tests {
term.append_type_and_fast_value(-4i64);
assert_eq!(
term.serialized_value_bytes(),
b"color\x00i\x7f\xff\xff\xff\xff\xff\xff\xfc".to_vec()
term.serialized_term(),
b"\x00\x00\x00\x01jcolor\x00i\x7f\xff\xff\xff\xff\xff\xff\xfc"
)
}
@@ -428,8 +431,8 @@ mod tests {
term.append_type_and_fast_value(4u64);
assert_eq!(
term.serialized_value_bytes(),
b"color\x00u\x00\x00\x00\x00\x00\x00\x00\x04".to_vec()
term.serialized_term(),
b"\x00\x00\x00\x01jcolor\x00u\x00\x00\x00\x00\x00\x00\x00\x04"
)
}
@@ -439,8 +442,8 @@ mod tests {
let mut term = Term::from_field_json_path(field, "color", false);
term.append_type_and_fast_value(4.0f64);
assert_eq!(
term.serialized_value_bytes(),
b"color\x00f\xc0\x10\x00\x00\x00\x00\x00\x00".to_vec()
term.serialized_term(),
b"\x00\x00\x00\x01jcolor\x00f\xc0\x10\x00\x00\x00\x00\x00\x00"
)
}
@@ -450,8 +453,8 @@ mod tests {
let mut term = Term::from_field_json_path(field, "color", false);
term.append_type_and_fast_value(true);
assert_eq!(
term.serialized_value_bytes(),
b"color\x00o\x00\x00\x00\x00\x00\x00\x00\x01".to_vec()
term.serialized_term(),
b"\x00\x00\x00\x01jcolor\x00o\x00\x00\x00\x00\x00\x00\x00\x01"
)
}

View File

@@ -5,7 +5,7 @@ use std::ops::Range;
use common::{BinarySerializable, CountingWriter, HasLen, VInt};
use crate::directory::{FileSlice, TerminatingWrite, WritePtr};
use crate::schema::{Field, Schema};
use crate::schema::Field;
use crate::space_usage::{FieldUsage, PerFieldSpaceUsage};
#[derive(Eq, PartialEq, Hash, Copy, Ord, PartialOrd, Clone, Debug)]
@@ -167,11 +167,10 @@ impl CompositeFile {
.map(|byte_range| self.data.slice(byte_range.clone()))
}
pub fn space_usage(&self, schema: &Schema) -> PerFieldSpaceUsage {
pub fn space_usage(&self) -> PerFieldSpaceUsage {
let mut fields = Vec::new();
for (&field_addr, byte_range) in &self.offsets_index {
let field_name = schema.get_field_name(field_addr.field).to_string();
let mut field_usage = FieldUsage::empty(field_name);
let mut field_usage = FieldUsage::empty(field_addr.field);
field_usage.add_field_idx(field_addr.idx, byte_range.len().into());
fields.push(field_usage);
}

View File

@@ -1,5 +1,3 @@
mod file_watcher;
use std::collections::HashMap;
use std::fmt;
use std::fs::{self, File, OpenOptions};
@@ -9,7 +7,6 @@ use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock, Weak};
use common::StableDeref;
use file_watcher::FileWatcher;
use fs4::fs_std::FileExt;
#[cfg(all(feature = "mmap", unix))]
pub use memmap2::Advice;
@@ -21,6 +18,7 @@ use crate::core::META_FILEPATH;
use crate::directory::error::{
DeleteError, LockError, OpenDirectoryError, OpenReadError, OpenWriteError,
};
use crate::directory::file_watcher::FileWatcher;
use crate::directory::{
AntiCallToken, Directory, DirectoryLock, FileHandle, Lock, OwnedBytes, TerminatingWrite,
WatchCallback, WatchHandle, WritePtr,
@@ -676,7 +674,7 @@ mod tests {
let num_segments = reader.searcher().segment_readers().len();
assert!(num_segments <= 4);
let num_components_except_deletes_and_tempstore =
crate::index::SegmentComponent::iterator().len() - 1;
crate::index::SegmentComponent::iterator().len() - 2;
let max_num_mmapped = num_components_except_deletes_and_tempstore * num_segments;
assert_eventually(|| {
let num_mmapped = mmap_directory.get_cache_info().mmapped.len();

View File

@@ -5,6 +5,7 @@ mod mmap_directory;
mod directory;
mod directory_lock;
mod file_watcher;
pub mod footer;
mod managed_directory;
mod ram_directory;
@@ -21,7 +22,7 @@ use std::path::PathBuf;
pub use common::file_slice::{FileHandle, FileSlice};
pub use common::{AntiCallToken, OwnedBytes, TerminatingWrite};
pub use self::composite_file::{CompositeFile, CompositeWrite};
pub(crate) use self::composite_file::{CompositeFile, CompositeWrite};
pub use self::directory::{Directory, DirectoryClone, DirectoryLock};
pub use self::directory_lock::{Lock, INDEX_WRITER_LOCK, META_LOCK};
pub use self::ram_directory::RamDirectory;
@@ -52,7 +53,7 @@ pub use self::mmap_directory::MmapDirectory;
///
/// `WritePtr` are required to implement both Write
/// and Seek.
pub type WritePtr = BufWriter<Box<dyn TerminatingWrite + Send + Sync>>;
pub type WritePtr = BufWriter<Box<dyn TerminatingWrite>>;
#[cfg(test)]
mod tests;

View File

@@ -40,8 +40,6 @@ pub trait DocSet: Send {
/// of `DocSet` should support it.
///
/// Calling `seek(TERMINATED)` is also legal and is the normal way to consume a `DocSet`.
///
/// `target` has to be larger or equal to `.doc()` when calling `seek`.
fn seek(&mut self, target: DocId) -> DocId {
let mut doc = self.doc();
debug_assert!(doc <= target);
@@ -51,57 +49,6 @@ pub trait DocSet: Send {
doc
}
/// !!!Dragons ahead!!!
/// In spirit, this is an approximate and dangerous version of `seek`.
///
/// It can leave the DocSet in an `invalid` state and might return a
/// lower bound of what the result of Seek would have been.
///
///
/// More accurately it returns either:
/// - Found if the target is in the docset. In that case, the DocSet is left in a valid state.
/// - SeekLowerBound(seek_lower_bound) if the target is not in the docset. In that case, The
/// DocSet can be the left in a invalid state. The DocSet should then only receives call to
/// `seek_danger(..)` until it returns `Found`, and get back to a valid state.
///
/// `seek_lower_bound` can be any `DocId` (in the docset or not) as long as it is in
/// `(target .. seek_result] U {TERMINATED}` where `seek_result` is the first document in the
/// docset greater than to `target`.
///
/// `seek_danger` may return `SeekLowerBound(TERMINATED)`.
///
/// Calling `seek_danger` with TERMINATED as a target is allowed,
/// and should always return NewTarget(TERMINATED) or anything larger as TERMINATED is NOT in
/// the DocSet.
///
/// DocSets that already have an efficient `seek` method don't need to implement
/// `seek_danger`.
///
/// Consecutive calls to seek_danger are guaranteed to have strictly increasing `target`
/// values.
fn seek_danger(&mut self, target: DocId) -> SeekDangerResult {
if target >= TERMINATED {
debug_assert!(target == TERMINATED);
// No need to advance.
return SeekDangerResult::SeekLowerBound(target);
}
// The default implementation does not include any
// `danger zone` behavior.
//
// It does not leave the scorer in an invalid state.
// For this reason, we can safely call `self.doc()`.
let mut doc = self.doc();
if doc < target {
doc = self.seek(target);
}
if doc == target {
SeekDangerResult::Found
} else {
SeekDangerResult::SeekLowerBound(doc)
}
}
/// Fills a given mutable buffer with the next doc ids from the
/// `DocSet`
///
@@ -147,15 +94,6 @@ pub trait DocSet: Send {
/// which would be the number of documents in the DocSet.
///
/// By default this returns `size_hint()`.
///
/// DocSets may have vastly different cost depending on their type,
/// e.g. an intersection with 10 hits is much cheaper than
/// a phrase search with 10 hits, since it needs to load positions.
///
/// ### Future Work
/// We may want to differentiate `DocSet` costs more more granular, e.g.
/// creation_cost, advance_cost, seek_cost on to get a good estimation
/// what query types to choose.
fn cost(&self) -> u64 {
self.size_hint() as u64
}
@@ -190,17 +128,6 @@ pub trait DocSet: Send {
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SeekDangerResult {
/// The target was found in the DocSet.
Found,
/// The target was not found in the DocSet.
/// We return a range in which the value could be.
/// The given target can be any DocId, that is <= than the first document
/// in the docset after the target.
SeekLowerBound(DocId),
}
impl DocSet for &mut dyn DocSet {
fn advance(&mut self) -> u32 {
(**self).advance()
@@ -210,10 +137,6 @@ impl DocSet for &mut dyn DocSet {
(**self).seek(target)
}
fn seek_danger(&mut self, target: DocId) -> SeekDangerResult {
(**self).seek_danger(target)
}
fn doc(&self) -> u32 {
(**self).doc()
}
@@ -246,11 +169,6 @@ impl<TDocSet: DocSet + ?Sized> DocSet for Box<TDocSet> {
unboxed.seek(target)
}
fn seek_danger(&mut self, target: DocId) -> SeekDangerResult {
let unboxed: &mut TDocSet = self.borrow_mut();
unboxed.seek_danger(target)
}
fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize {
let unboxed: &mut TDocSet = self.borrow_mut();
unboxed.fill_buffer(buffer)

View File

@@ -162,7 +162,7 @@ mod tests {
mod bench {
use rand::prelude::IteratorRandom;
use rand::rng;
use rand::thread_rng;
use test::Bencher;
use super::AliveBitSet;
@@ -176,7 +176,7 @@ mod bench {
}
fn remove_rand(raw: &mut Vec<u32>) {
let i = (0..raw.len()).choose(&mut rng()).unwrap();
let i = (0..raw.len()).choose(&mut thread_rng()).unwrap();
raw.remove(i);
}

View File

@@ -683,7 +683,7 @@ mod tests {
}
#[test]
fn test_datefastfield() -> crate::Result<()> {
fn test_datefastfield() {
let mut schema_builder = Schema::builder();
let date_field = schema_builder.add_date_field(
"date",
@@ -697,22 +697,28 @@ mod tests {
);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests()?;
let mut index_writer = index.writer_for_tests().unwrap();
index_writer.set_merge_policy(Box::new(NoMergePolicy));
index_writer.add_document(doc!(
date_field => DateTime::from_u64(1i64.to_u64()),
multi_date_field => DateTime::from_u64(2i64.to_u64()),
multi_date_field => DateTime::from_u64(3i64.to_u64())
))?;
index_writer.add_document(doc!(
date_field => DateTime::from_u64(4i64.to_u64())
))?;
index_writer.add_document(doc!(
multi_date_field => DateTime::from_u64(5i64.to_u64()),
multi_date_field => DateTime::from_u64(6i64.to_u64())
))?;
index_writer.commit()?;
let reader = index.reader()?;
index_writer
.add_document(doc!(
date_field => DateTime::from_u64(1i64.to_u64()),
multi_date_field => DateTime::from_u64(2i64.to_u64()),
multi_date_field => DateTime::from_u64(3i64.to_u64())
))
.unwrap();
index_writer
.add_document(doc!(
date_field => DateTime::from_u64(4i64.to_u64())
))
.unwrap();
index_writer
.add_document(doc!(
multi_date_field => DateTime::from_u64(5i64.to_u64()),
multi_date_field => DateTime::from_u64(6i64.to_u64())
))
.unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
assert_eq!(searcher.segment_readers().len(), 1);
let segment_reader = searcher.segment_reader(0);
@@ -746,7 +752,6 @@ mod tests {
assert_eq!(dates[0].into_timestamp_nanos(), 5i64);
assert_eq!(dates[1].into_timestamp_nanos(), 6i64);
}
Ok(())
}
#[test]
@@ -879,7 +884,7 @@ mod tests {
const ONE_HOUR_IN_MICROSECS: i64 = 3_600 * 1_000_000;
let times: Vec<DateTime> = std::iter::repeat_with(|| {
// +- One hour.
let t = T0 + rng.random_range(-ONE_HOUR_IN_MICROSECS..ONE_HOUR_IN_MICROSECS);
let t = T0 + rng.gen_range(-ONE_HOUR_IN_MICROSECS..ONE_HOUR_IN_MICROSECS);
DateTime::from_timestamp_micros(t)
})
.take(1_000)

View File

@@ -8,7 +8,7 @@ use columnar::{
};
use common::ByteCount;
use crate::core::json_utils::{encode_column_name, json_path_sep_to_dot};
use crate::core::json_utils::encode_column_name;
use crate::directory::FileSlice;
use crate::schema::{Field, FieldEntry, FieldType, Schema};
use crate::space_usage::{FieldUsage, PerFieldSpaceUsage};
@@ -39,15 +39,19 @@ impl FastFieldReaders {
self.resolve_column_name_given_default_field(column_name, default_field_opt)
}
pub(crate) fn space_usage(&self) -> io::Result<PerFieldSpaceUsage> {
pub(crate) fn space_usage(&self, schema: &Schema) -> io::Result<PerFieldSpaceUsage> {
let mut per_field_usages: Vec<FieldUsage> = Default::default();
for (mut field_name, column_handle) in self.columnar.iter_columns()? {
json_path_sep_to_dot(&mut field_name);
let space_usage = column_handle.space_usage()?;
let mut field_usage = FieldUsage::empty(field_name);
field_usage.set_column_usage(space_usage);
for (field, field_entry) in schema.fields() {
let column_handles = self.columnar.read_columns(field_entry.name())?;
let num_bytes: ByteCount = column_handles
.iter()
.map(|column_handle| column_handle.num_bytes())
.sum();
let mut field_usage = FieldUsage::empty(field);
field_usage.add_field_idx(0, num_bytes);
per_field_usages.push(field_usage);
}
// TODO fix space usage for JSON fields.
Ok(PerFieldSpaceUsage::new(per_field_usages))
}

View File

@@ -189,6 +189,9 @@ impl FastFieldsWriter {
.record_str(doc_id, field_name, &token.text);
}
}
ReferenceValueLeaf::Geometry(_) => {
panic!("Geometry fields should not be routed to fast field writer")
}
},
ReferenceValue::Array(val) => {
// TODO: Check this is the correct behaviour we want.
@@ -320,6 +323,9 @@ fn record_json_value_to_columnar_writer<'a, V: Value<'a>>(
"Pre-tokenized string support in dynamic fields is not yet implemented"
)
}
ReferenceValueLeaf::Geometry(_) => {
unimplemented!("Geometry support in dynamic fields is not yet implemented")
}
},
ReferenceValue::Array(elements) => {
for el in elements {

View File

@@ -2,7 +2,7 @@ use std::sync::Arc;
use super::{fieldnorm_to_id, id_to_fieldnorm};
use crate::directory::{CompositeFile, FileSlice, OwnedBytes};
use crate::schema::{Field, Schema};
use crate::schema::Field;
use crate::space_usage::PerFieldSpaceUsage;
use crate::DocId;
@@ -37,8 +37,8 @@ impl FieldNormReaders {
}
/// Return a break down of the space usage per field.
pub fn space_usage(&self, schema: &Schema) -> PerFieldSpaceUsage {
self.data.space_usage(schema)
pub fn space_usage(&self) -> PerFieldSpaceUsage {
self.data.space_usage()
}
/// Returns a handle to inner file

View File

@@ -1,6 +1,6 @@
use std::collections::HashSet;
use rand::{rng, Rng};
use rand::{thread_rng, Rng};
use crate::indexer::index_writer::MEMORY_BUDGET_NUM_BYTES_MIN;
use crate::schema::*;
@@ -29,7 +29,7 @@ fn test_functional_store() -> crate::Result<()> {
let index = Index::create_in_ram(schema);
let reader = index.reader()?;
let mut rng = rng();
let mut rng = thread_rng();
let mut index_writer: IndexWriter =
index.writer_with_num_threads(3, 3 * MEMORY_BUDGET_NUM_BYTES_MIN)?;
@@ -38,9 +38,9 @@ fn test_functional_store() -> crate::Result<()> {
let mut doc_id = 0u64;
for _iteration in 0..get_num_iterations() {
let num_docs: usize = rng.random_range(0..4);
let num_docs: usize = rng.gen_range(0..4);
if !doc_set.is_empty() {
let doc_to_remove_id = rng.random_range(0..doc_set.len());
let doc_to_remove_id = rng.gen_range(0..doc_set.len());
let removed_doc_id = doc_set.swap_remove(doc_to_remove_id);
index_writer.delete_term(Term::from_field_u64(id_field, removed_doc_id));
}
@@ -70,10 +70,10 @@ const LOREM: &str = "Doc Lorem ipsum dolor sit amet, consectetur adipiscing elit
cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat \
non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.";
fn get_text() -> String {
use rand::seq::IndexedRandom;
let mut rng = rng();
use rand::seq::SliceRandom;
let mut rng = thread_rng();
let tokens: Vec<_> = LOREM.split(' ').collect();
let random_val = rng.random_range(0..20);
let random_val = rng.gen_range(0..20);
(0..random_val)
.map(|_| tokens.choose(&mut rng).unwrap())
@@ -101,7 +101,7 @@ fn test_functional_indexing_unsorted() -> crate::Result<()> {
let index = Index::create_from_tempdir(schema)?;
let reader = index.reader()?;
let mut rng = rng();
let mut rng = thread_rng();
let mut index_writer: IndexWriter =
index.writer_with_num_threads(3, 3 * MEMORY_BUDGET_NUM_BYTES_MIN)?;
@@ -110,7 +110,7 @@ fn test_functional_indexing_unsorted() -> crate::Result<()> {
let mut uncommitted_docs: HashSet<u64> = HashSet::new();
for _ in 0..get_num_iterations() {
let random_val = rng.random_range(0..20);
let random_val = rng.gen_range(0..20);
if random_val == 0 {
index_writer.commit()?;
committed_docs.extend(&uncommitted_docs);

View File

@@ -1,6 +1,8 @@
use std::collections::HashSet;
use std::fmt;
use std::path::PathBuf;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
@@ -11,9 +13,9 @@ use crate::store::Compressor;
use crate::{Inventory, Opstamp, TrackedObject};
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DeleteMeta {
struct DeleteMeta {
num_deleted_docs: u32,
pub opstamp: Opstamp,
opstamp: Opstamp,
}
#[derive(Clone, Default)]
@@ -35,6 +37,7 @@ impl SegmentMetaInventory {
let inner = InnerSegmentMeta {
segment_id,
max_doc,
include_temp_doc_store: Arc::new(AtomicBool::new(true)),
deletes: None,
};
SegmentMeta::from(self.inventory.track(inner))
@@ -82,6 +85,15 @@ impl SegmentMeta {
self.tracked.segment_id
}
/// Removes the Component::TempStore from the alive list and
/// therefore marks the temp docstore file to be deleted by
/// the garbage collection.
pub fn untrack_temp_docstore(&self) {
self.tracked
.include_temp_doc_store
.store(false, std::sync::atomic::Ordering::Relaxed);
}
/// Returns the number of deleted documents.
pub fn num_deleted_docs(&self) -> u32 {
self.tracked
@@ -99,9 +111,20 @@ impl SegmentMeta {
/// is by removing all files that have been created by tantivy
/// and are not used by any segment anymore.
pub fn list_files(&self) -> HashSet<PathBuf> {
SegmentComponent::iterator()
.map(|component| self.relative_path(*component))
.collect::<HashSet<PathBuf>>()
if self
.tracked
.include_temp_doc_store
.load(std::sync::atomic::Ordering::Relaxed)
{
SegmentComponent::iterator()
.map(|component| self.relative_path(*component))
.collect::<HashSet<PathBuf>>()
} else {
SegmentComponent::iterator()
.filter(|comp| *comp != &SegmentComponent::TempStore)
.map(|component| self.relative_path(*component))
.collect::<HashSet<PathBuf>>()
}
}
/// Returns the relative path of a component of our segment.
@@ -115,9 +138,11 @@ impl SegmentMeta {
SegmentComponent::Positions => ".pos".to_string(),
SegmentComponent::Terms => ".term".to_string(),
SegmentComponent::Store => ".store".to_string(),
SegmentComponent::TempStore => ".store.temp".to_string(),
SegmentComponent::FastFields => ".fast".to_string(),
SegmentComponent::FieldNorms => ".fieldnorm".to_string(),
SegmentComponent::Delete => format!(".{}.del", self.delete_opstamp().unwrap_or(0)),
SegmentComponent::Spatial => ".spatial".to_string(),
});
PathBuf::from(path)
}
@@ -159,6 +184,7 @@ impl SegmentMeta {
segment_id: inner_meta.segment_id,
max_doc,
deletes: None,
include_temp_doc_store: Arc::new(AtomicBool::new(true)),
});
SegmentMeta { tracked }
}
@@ -177,6 +203,7 @@ impl SegmentMeta {
let tracked = self.tracked.map(move |inner_meta| InnerSegmentMeta {
segment_id: inner_meta.segment_id,
max_doc: inner_meta.max_doc,
include_temp_doc_store: Arc::new(AtomicBool::new(true)),
deletes: Some(delete_meta),
});
SegmentMeta { tracked }
@@ -187,7 +214,15 @@ impl SegmentMeta {
struct InnerSegmentMeta {
segment_id: SegmentId,
max_doc: u32,
pub deletes: Option<DeleteMeta>,
deletes: Option<DeleteMeta>,
/// If you want to avoid the SegmentComponent::TempStore file to be covered by
/// garbage collection and deleted, set this to true. This is used during merge.
#[serde(skip)]
#[serde(default = "default_temp_store")]
pub(crate) include_temp_doc_store: Arc<AtomicBool>,
}
fn default_temp_store() -> Arc<AtomicBool> {
Arc::new(AtomicBool::new(false))
}
impl InnerSegmentMeta {
@@ -370,10 +405,7 @@ mod tests {
schema_builder.build()
};
let index_metas = IndexMeta {
index_settings: IndexSettings {
docstore_compression: Compressor::None,
..Default::default()
},
index_settings: IndexSettings::default(),
segments: Vec::new(),
schema,
opstamp: 0u64,
@@ -382,7 +414,7 @@ mod tests {
let json = serde_json::ser::to_string(&index_metas).expect("serialization failed");
assert_eq!(
json,
r#"{"index_settings":{"docstore_compression":"none","docstore_blocksize":16384},"segments":[],"schema":[{"name":"text","type":"text","options":{"indexing":{"record":"position","fieldnorms":true,"tokenizer":"default"},"stored":false,"fast":false}}],"opstamp":0}"#
r#"{"index_settings":{"docstore_compression":"lz4","docstore_blocksize":16384},"segments":[],"schema":[{"name":"text","type":"text","options":{"indexing":{"record":"position","fieldnorms":true,"tokenizer":"default"},"stored":false,"fast":false}}],"opstamp":0}"#
);
let deser_meta: UntrackedIndexMeta = serde_json::from_str(&json).unwrap();
@@ -463,8 +495,6 @@ mod tests {
#[test]
#[cfg(feature = "lz4-compression")]
fn test_index_settings_default() {
use crate::store::Compressor;
let mut index_settings = IndexSettings::default();
assert_eq!(
index_settings,

View File

@@ -46,7 +46,7 @@ impl Segment {
///
/// This method is only used when updating `max_doc` from 0
/// as we finalize a fresh new segment.
pub fn with_max_doc(self, max_doc: u32) -> Segment {
pub(crate) fn with_max_doc(self, max_doc: u32) -> Segment {
Segment {
index: self.index,
meta: self.meta.with_max_doc(max_doc),

Some files were not shown because too many files have changed in this diff Show More