mirror of https://github.com/quickwit-oss/tantivy.git
synced 2026-03-16 10:10:43 +00:00

Compare commits: 22 commits (composite-... ... flatheadmi)
| Author | SHA1 | Date |
|---|---|---|
| | 643639f14b | |
| | f85a27068d | |
| | 1619e05bc5 | |
| | 5d03c600ba | |
| | 32beb06382 | |
| | d8bc0e7c99 | |
| | 79622f1f0b | |
| | d26d6c34fc | |
| | 6da54fa5da | |
| | 9f10279681 | |
| | 68009bb25b | |
| | 459456ca28 | |
| | dbbc8c3f65 | |
| | d3049cb323 | |
| | ccdf399cd7 | |
| | 2dc46b235e | |
| | f38140f72f | |
| | 0996bea7ac | |
| | 1c66567efc | |
| | b2a9bb279d | |
| | 558c99fa2d | |
| | 43b5f34721 | |
@@ -1,125 +0,0 @@
---
name: rationalize-deps
description: Analyze Cargo.toml dependencies and attempt to remove unused features to reduce compile times and binary size
---

# Rationalize Dependencies

This skill analyzes Cargo.toml dependencies to identify and remove unused features.

## Overview

Many crates enable features by default that may not be needed. This skill:

1. Identifies dependencies with default features enabled
2. Tests if `default-features = false` works
3. Identifies which specific features are actually needed
4. Verifies compilation after changes

## Step 1: Identify the target

Ask the user which crate(s) to analyze:

- A specific crate name (e.g., "tokio", "serde")
- A specific workspace member (e.g., "quickwit-search")
- "all" to scan the entire workspace

## Step 2: Analyze current dependencies

For the workspace Cargo.toml (`quickwit/Cargo.toml`), list dependencies that:

- Do NOT have `default-features = false`
- Have default features that might be unnecessary

Run: `cargo tree -p <crate> -f "{p} {f}" --edges features` to see what features are actually used.
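The filtering step above can also be scripted. A minimal sketch (not part of the skill itself): `cargo metadata` reports a `uses_default_features` flag for each dependency, so a small helper can list the candidates. The JSON below is a hand-written stub; in practice you would load the real output of `cargo metadata --format-version=1`.

```python
# Illustrative sketch: list dependencies of a package that still pull in
# default features, based on the `uses_default_features` flag in the
# `cargo metadata` JSON. The `stub` document below stands in for real output.
import json

def deps_with_default_features(metadata: dict, package: str) -> list[str]:
    """Names of `package`'s dependencies whose `uses_default_features`
    flag is true in the cargo metadata document."""
    for pkg in metadata["packages"]:
        if pkg["name"] == package:
            return [dep["name"]
                    for dep in pkg["dependencies"]
                    if dep.get("uses_default_features", True)]
    return []

# Stub metadata with two dependencies, one of which already opted out.
stub = json.loads("""
{"packages": [{"name": "tantivy", "dependencies": [
    {"name": "serde", "uses_default_features": true},
    {"name": "regex", "uses_default_features": false}]}]}
""")
print(deps_with_default_features(stub, "tantivy"))  # ['serde']
```

Each name printed is a candidate for the `default-features = false` experiment in step 3.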
## Step 3: For each candidate dependency

### 3a: Check the crate's default features

Look up the crate on crates.io or check its Cargo.toml to understand:

- What features are enabled by default
- What each feature provides

Use: `cargo metadata --format-version=1 | jq '.packages[] | select(.name == "<crate>") | .features'`

### 3b: Try disabling default features

Modify the dependency in `quickwit/Cargo.toml`:

From:
```toml
some-crate = { version = "1.0" }
```

To:
```toml
some-crate = { version = "1.0", default-features = false }
```

### 3c: Run cargo check

Run: `cargo check --workspace` (or target specific packages for faster feedback)

If compilation fails:

1. Read the error messages to identify which features are needed
2. Add only the required features explicitly:
   ```toml
   some-crate = { version = "1.0", default-features = false, features = ["needed-feature"] }
   ```
3. Re-run cargo check

### 3d: Binary search for minimal features

If there are many default features, use binary search:

1. Start with no features
2. If it fails, add half the default features
3. Continue until you find the minimal set
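A simpler greedy variant of the search above (drop one feature at a time instead of bisecting) is easy to sketch. The `builds` predicate is an assumption: in practice it would shell out to `cargo check --no-default-features --features <list>`, while here a toy predicate is injected so the loop itself can be shown.

```python
# Illustrative sketch: greedily shrink a feature list, keeping each drop
# whenever the crate still builds. `builds` is a stand-in for running
# `cargo check` with the given feature set.
from typing import Callable

def minimize_features(default_features: list[str],
                      builds: Callable[[frozenset], bool]) -> list[str]:
    """Drop one feature at a time; returns a locally minimal feature set."""
    kept = list(default_features)
    for feature in default_features:
        trial = [f for f in kept if f != feature]
        if builds(frozenset(trial)):
            kept = trial
    return kept

# Toy predicate: pretend the build only needs "derive".
needs = {"derive"}
result = minimize_features(["std", "derive", "rc"], lambda fs: needs <= fs)
print(result)  # ['derive']
```

Greedy dropping needs one build per feature, while bisection can rule out half the candidates per build; either way the result should be re-verified with a full `cargo check`.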
## Step 4: Document findings

For each dependency analyzed, report:

- Original configuration
- New configuration (if changed)
- Features that were removed
- Any features that are required

## Step 5: Verify full build

After all changes, run:

```bash
cargo check --workspace --all-targets
cargo test --workspace --no-run
```

## Common Patterns

### Serde

Often only needs `derive`:

```toml
serde = { version = "1.0", default-features = false, features = ["derive", "std"] }
```

### Tokio

Identify which runtime features are actually used:

```toml
tokio = { version = "1.0", default-features = false, features = ["rt-multi-thread", "macros", "sync"] }
```

### Reqwest

Often doesn't need all TLS backends:

```toml
reqwest = { version = "0.11", default-features = false, features = ["rustls-tls", "json"] }
```

## Rollback

If changes cause issues:

```bash
git checkout quickwit/Cargo.toml
cargo check --workspace
```

## Tips

- Start with large crates that have many default features (tokio, reqwest, hyper)
- Use `cargo bloat --crates` to identify large dependencies
- Check `cargo tree -d` for duplicate dependencies that might indicate feature conflicts
- Some features are needed only for tests - consider using `[dev-dependencies]` features
@@ -1,60 +0,0 @@
---
name: simple-pr
description: Create a simple PR from staged changes with an auto-generated commit message
disable-model-invocation: true
---

# Simple PR

Follow these steps to create a simple PR from staged changes:

## Step 1: Check workspace state

Run: `git status`

Verify that all changes have been staged (no unstaged changes). If there are unstaged changes, abort and ask the user to stage their changes first with `git add`.

Also verify that we are on the `main` branch. If not, abort and ask the user to switch to main first.
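The two abort conditions above can be decided mechanically from `git status --porcelain` output plus the current branch name (both obtained elsewhere, e.g. via a subprocess call). A hedged sketch, not part of the skill itself:

```python
# Illustrative sketch: in porcelain format, a non-space character in the
# second status column (including `??` for untracked files) marks work
# that is not staged; anything else is fully staged.
def ready_for_pr(porcelain: str, branch: str) -> tuple[bool, str]:
    """True iff everything is staged and we are on main."""
    if branch != "main":
        return False, "switch to main first"
    for line in porcelain.splitlines():
        if len(line) >= 2 and line[1] != " ":
            return False, "stage your changes first with `git add`"
    return True, "ok"

print(ready_for_pr("M  src/lib.rs", "main"))  # (True, 'ok')
print(ready_for_pr(" M src/lib.rs", "main"))  # (False, 'stage your changes first with `git add`')
```

An empty porcelain output on `main` also passes, which matches the "nothing unstaged" requirement.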
## Step 2: Ensure main is up to date

Run: `git pull origin main`

This ensures we're working from the latest code.

## Step 3: Review staged changes

Run: `git diff --cached`

Review the staged changes to understand what the PR will contain.

## Step 4: Generate commit message

Based on the staged changes, generate a concise commit message (1-2 sentences) that describes the "why" rather than the "what".

Display the proposed commit message to the user and ask for confirmation before proceeding.

## Step 5: Create a new branch

Get the git username: `git config user.name | tr ' ' '-' | tr '[:upper:]' '[:lower:]'`

Create a short, descriptive branch name based on the changes (e.g., `fix-typo-in-readme`, `add-retry-logic`, `update-deps`).

Create and checkout the branch: `git checkout -b {username}/{short-descriptive-name}`
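The branch-name construction in step 5 is just the `tr`-based username normalization joined with a slug; a small sketch of the same transformation:

```python
# Illustrative sketch of `git config user.name | tr ' ' '-' |
# tr '[:upper:]' '[:lower:]'` followed by `{username}/{slug}`.
def branch_name(user_name: str, slug: str) -> str:
    """Build the {username}/{short-descriptive-name} branch name."""
    username = user_name.replace(" ", "-").lower()
    return f"{username}/{slug}"

print(branch_name("Jane Doe", "fix-typo-in-readme"))  # jane-doe/fix-typo-in-readme
```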
## Step 6: Commit changes

Commit with the message from step 4:
```
git commit -m "{commit-message}"
```

## Step 7: Push and open a PR

Push the branch and open a PR:
```
git push -u origin {branch-name}
gh pr create --title "{commit-message-title}" --body "{longer-description-if-needed}"
```

Report the PR URL to the user when complete.
|
||||
4
.github/workflows/coverage.yml
vendored
4
.github/workflows/coverage.yml
vendored
```diff
@@ -15,11 +15,11 @@ jobs:
     steps:
       - uses: actions/checkout@v4
       - name: Install Rust
-        run: rustup toolchain install nightly-2025-12-01 --profile minimal --component llvm-tools-preview
+        run: rustup toolchain install nightly-2024-07-01 --profile minimal --component llvm-tools-preview
       - uses: Swatinem/rust-cache@v2
       - uses: taiki-e/install-action@cargo-llvm-cov
       - name: Generate code coverage
-        run: cargo +nightly-2025-12-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
+        run: cargo +nightly-2024-07-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@v3
         continue-on-error: true
```
30 changes: .github/workflows/test.yml (vendored)
```diff
@@ -39,11 +39,11 @@ jobs:
       - name: Check Formatting
         run: cargo +nightly fmt --all -- --check

       - name: Check Stable Compilation
         run: cargo build --all-features

       - name: Check Bench Compilation
         run: cargo +nightly bench --no-run --profile=dev --all-features

@@ -59,10 +59,10 @@ jobs:
     strategy:
       matrix:
-        features:
-          - { label: "all", flags: "mmap,stopwords,lz4-compression,zstd-compression,failpoints,stemmer" }
-          - { label: "quickwit", flags: "mmap,quickwit,failpoints" }
-          - { label: "none", flags: "" }
+        features: [
+          { label: "all", flags: "mmap,stopwords,lz4-compression,zstd-compression,failpoints" },
+          { label: "quickwit", flags: "mmap,quickwit,failpoints" }
+        ]

     name: test-${{ matrix.features.label}}

@@ -80,21 +80,7 @@ jobs:
       - uses: Swatinem/rust-cache@v2

       - name: Run tests
-        run: |
-          # if matrix.feature.flags is empty then run on --lib to avoid compiling examples
-          # (as most of them rely on mmap) otherwise run all
-          if [ -z "${{ matrix.features.flags }}" ]; then
-            cargo +stable nextest run --lib --no-default-features --verbose --workspace
-          else
-            cargo +stable nextest run --features ${{ matrix.features.flags }} --no-default-features --verbose --workspace
-          fi
+        run: cargo +stable nextest run --features ${{ matrix.features.flags }} --verbose --workspace

       - name: Run doctests
-        run: |
-          # if matrix.feature.flags is empty then run on --lib to avoid compiling examples
-          # (as most of them rely on mmap) otherwise run all
-          if [ -z "${{ matrix.features.flags }}" ]; then
-            echo "no doctest for no feature flag"
-          else
-            cargo +stable test --doc --features ${{ matrix.features.flags }} --verbose --workspace
-          fi
+        run: cargo +stable test --doc --features ${{ matrix.features.flags }} --verbose --workspace
```
53 changes: Cargo.toml
```diff
@@ -15,7 +15,7 @@ rust-version = "1.85"
 exclude = ["benches/*.json", "benches/*.txt"]

 [dependencies]
-oneshot = "0.1.13"
+oneshot = "0.1.7"
 base64 = "0.22.0"
 byteorder = "1.4.3"
 crc32fast = "1.3.2"
@@ -27,7 +27,7 @@ regex = { version = "1.5.5", default-features = false, features = [
 aho-corasick = "1.0"
 tantivy-fst = "0.5"
 memmap2 = { version = "0.9.0", optional = true }
-lz4_flex = { version = "0.12", default-features = false, optional = true }
+lz4_flex = { version = "0.11", default-features = false, optional = true }
 zstd = { version = "0.13", optional = true, default-features = false }
 tempfile = { version = "3.12.0", optional = true }
 log = "0.4.16"
@@ -37,9 +37,9 @@ fs4 = { version = "0.13.1", optional = true }
 levenshtein_automata = "0.2.1"
 uuid = { version = "1.0.0", features = ["v4", "serde"] }
 crossbeam-channel = "0.5.4"
-rust-stemmers = { version = "1.2.0", optional = true }
+rust-stemmers = "1.2.0"
 downcast-rs = "2.0.1"
-bitpacking = { version = "0.9.3", default-features = false, features = [
+bitpacking = { version = "0.9.2", default-features = false, features = [
     "bitpacker4x",
 ] }
 census = "0.4.2"
@@ -50,12 +50,13 @@ fail = { version = "0.5.0", optional = true }
 time = { version = "0.3.35", features = ["serde-well-known"] }
 smallvec = "1.8.0"
 rayon = "1.5.2"
-lru = "0.16.3"
+lru = "0.12.0"
 fastdivide = "0.4.0"
 itertools = "0.14.0"
 measure_time = "0.9.0"
 arc-swap = "1.5.0"
 bon = "3.3.1"
 i_triangle = "0.38.0"

 columnar = { version = "0.6", path = "./columnar", package = "tantivy-columnar" }
 sstable = { version = "0.6", path = "./sstable", package = "tantivy-sstable", optional = true }
@@ -64,28 +65,29 @@ query-grammar = { version = "0.25.0", path = "./query-grammar", package = "tanti
 tantivy-bitpacker = { version = "0.9", path = "./bitpacker" }
 common = { version = "0.10", path = "./common/", package = "tantivy-common" }
 tokenizer-api = { version = "0.6", path = "./tokenizer-api", package = "tantivy-tokenizer-api" }
-sketches-ddsketch = { path = "./sketches-ddsketch", features = ["use_serde"] }
-datasketches = "0.2.0"
+sketches-ddsketch = { version = "0.3.0", features = ["use_serde"] }
 hyperloglogplus = { version = "0.4.1", features = ["const-loop"] }
 futures-util = { version = "0.3.28", optional = true }
 futures-channel = { version = "0.3.28", optional = true }
 fnv = "1.0.7"
 typetag = "0.2.21"
 geo-types = "0.7.17"

 [target.'cfg(windows)'.dependencies]
 winapi = "0.3.9"

 [dev-dependencies]
-binggan = "0.14.2"
-rand = "0.9"
+binggan = "0.14.0"
+rand = "0.8.5"
 maplit = "1.0.2"
 matches = "0.1.9"
 pretty_assertions = "1.2.1"
-proptest = "1.7.0"
+proptest = "1.0.0"
 test-log = "0.2.10"
 futures = "0.3.21"
 paste = "1.0.11"
 more-asserts = "0.3.1"
-rand_distr = "0.5"
+rand_distr = "0.4.3"
 time = { version = "0.3.10", features = ["serde-well-known", "macros"] }
 postcard = { version = "1.0.4", features = [
     "use-std",
@@ -113,8 +115,7 @@ debug-assertions = true
 overflow-checks = true

 [features]
-default = ["mmap", "stopwords", "lz4-compression", "columnar-zstd-compression", "stemmer"]
-stemmer = ["rust-stemmers"]
+default = ["mmap", "stopwords", "lz4-compression", "columnar-zstd-compression"]
 mmap = ["fs4", "tempfile", "memmap2"]
 stopwords = []

@@ -144,7 +145,6 @@ members = [
     "sstable",
     "tokenizer-api",
     "columnar",
-    "sketches-ddsketch",
 ]

 # Following the "fail" crate best practises, we isolate
@@ -175,31 +175,6 @@ harness = false
 name = "exists_json"
 harness = false

-[[bench]]
-name = "range_query"
-harness = false
-
-[[bench]]
-name = "and_or_queries"
-harness = false
-
-[[bench]]
-name = "range_queries"
-harness = false
-
-[[bench]]
-name = "bool_queries_with_range"
-harness = false
-
-[[bench]]
-name = "str_search_and_get"
-harness = false
-
-[[bench]]
-name = "merge_segments"
-harness = false
-
-[[bench]]
-name = "regex_all_terms"
-harness = false
```
```diff
@@ -123,7 +123,6 @@ You can also find other bindings on [GitHub](https://github.com/search?q=tantivy
 - [seshat](https://github.com/matrix-org/seshat/): A matrix message database/indexer
 - [tantiny](https://github.com/baygeldin/tantiny): Tiny full-text search for Ruby
 - [lnx](https://github.com/lnx-search/lnx): adaptable, typo tolerant search engine with a REST API
-- [Bichon](https://github.com/rustmailer/bichon): A lightweight, high-performance Rust email archiver with WebUI
 - and [more](https://github.com/search?q=tantivy)!

 ### On average, how much faster is Tantivy compared to Lucene?
```
@@ -1,9 +1,7 @@
|
||||
use binggan::plugins::PeakMemAllocPlugin;
|
||||
use binggan::{black_box, InputGroup, PeakMemAlloc, INSTRUMENTED_SYSTEM};
|
||||
use common::DateTime;
|
||||
use rand::distr::weighted::WeightedIndex;
|
||||
use rand::prelude::SliceRandom;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::seq::IndexedRandom;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rand_distr::Distribution;
|
||||
use serde_json::json;
|
||||
@@ -55,39 +53,27 @@ fn bench_agg(mut group: InputGroup<Index>) {
|
||||
register!(group, stats_f64);
|
||||
register!(group, extendedstats_f64);
|
||||
register!(group, percentiles_f64);
|
||||
register!(group, terms_7);
|
||||
register!(group, terms_all_unique);
|
||||
register!(group, terms_150_000);
|
||||
register!(group, terms_few);
|
||||
register!(group, terms_many);
|
||||
register!(group, terms_many_top_1000);
|
||||
register!(group, terms_many_order_by_term);
|
||||
register!(group, terms_many_with_top_hits);
|
||||
register!(group, terms_all_unique_with_avg_sub_agg);
|
||||
register!(group, terms_many_with_avg_sub_agg);
|
||||
register!(group, terms_status_with_avg_sub_agg);
|
||||
register!(group, terms_status_with_histogram);
|
||||
register!(group, terms_zipf_1000);
|
||||
register!(group, terms_zipf_1000_with_histogram);
|
||||
register!(group, terms_zipf_1000_with_avg_sub_agg);
|
||||
register!(group, terms_few_with_avg_sub_agg);
|
||||
|
||||
register!(group, terms_many_json_mixed_type_with_avg_sub_agg);
|
||||
|
||||
register!(group, composite_term_many_page_1000);
|
||||
register!(group, composite_term_many_page_1000_with_avg_sub_agg);
|
||||
register!(group, composite_term_few);
|
||||
register!(group, composite_histogram);
|
||||
register!(group, composite_histogram_calendar);
|
||||
|
||||
register!(group, cardinality_agg);
|
||||
register!(group, terms_status_with_cardinality_agg);
|
||||
register!(group, terms_few_with_cardinality_agg);
|
||||
|
||||
register!(group, range_agg);
|
||||
register!(group, range_agg_with_avg_sub_agg);
|
||||
register!(group, range_agg_with_term_agg_status);
|
||||
register!(group, range_agg_with_term_agg_few);
|
||||
register!(group, range_agg_with_term_agg_many);
|
||||
register!(group, histogram);
|
||||
register!(group, histogram_hard_bounds);
|
||||
register!(group, histogram_with_avg_sub_agg);
|
||||
register!(group, histogram_with_term_agg_status);
|
||||
register!(group, histogram_with_term_agg_few);
|
||||
register!(group, avg_and_range_with_avg_sub_agg);
|
||||
|
||||
// Filter aggregation benchmarks
|
||||
@@ -146,12 +132,12 @@ fn extendedstats_f64(index: &Index) {
|
||||
}
|
||||
fn percentiles_f64(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"mypercentiles": {
|
||||
"percentiles": {
|
||||
"field": "score_f64",
|
||||
"percents": [ 95, 99, 99.9 ]
|
||||
}
|
||||
"mypercentiles": {
|
||||
"percentiles": {
|
||||
"field": "score_f64",
|
||||
"percents": [ 95, 99, 99.9 ]
|
||||
}
|
||||
}
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
@@ -166,10 +152,10 @@ fn cardinality_agg(index: &Index) {
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn terms_status_with_cardinality_agg(index: &Index) {
|
||||
fn terms_few_with_cardinality_agg(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
"terms": { "field": "text_few_terms_status" },
|
||||
"terms": { "field": "text_few_terms" },
|
||||
"aggs": {
|
||||
"cardinality": {
|
||||
"cardinality": {
|
||||
@@ -182,20 +168,13 @@ fn terms_status_with_cardinality_agg(index: &Index) {
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_7(index: &Index) {
|
||||
fn terms_few(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": { "terms": { "field": "text_few_terms_status" } },
|
||||
"my_texts": { "terms": { "field": "text_few_terms" } },
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn terms_all_unique(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": { "terms": { "field": "text_all_unique_terms" } },
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_150_000(index: &Index) {
|
||||
fn terms_many(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": { "terms": { "field": "text_many_terms" } },
|
||||
});
|
||||
@@ -243,10 +222,11 @@ fn terms_many_with_avg_sub_agg(index: &Index) {
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn terms_all_unique_with_avg_sub_agg(index: &Index) {
|
||||
|
||||
fn terms_few_with_avg_sub_agg(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
"terms": { "field": "text_all_unique_terms" },
|
||||
"terms": { "field": "text_few_terms" },
|
||||
"aggs": {
|
||||
"average_f64": { "avg": { "field": "score_f64" } }
|
||||
}
|
||||
@@ -254,60 +234,6 @@ fn terms_all_unique_with_avg_sub_agg(index: &Index) {
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn terms_status_with_histogram(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
"terms": { "field": "text_few_terms_status" },
|
||||
"aggs": {
|
||||
"histo": {"histogram": { "field": "score_f64", "interval": 10 }}
|
||||
}
|
||||
}
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_zipf_1000_with_histogram(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
"terms": { "field": "text_1000_terms_zipf" },
|
||||
"aggs": {
|
||||
"histo": {"histogram": { "field": "score_f64", "interval": 10 }}
|
||||
}
|
||||
}
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_status_with_avg_sub_agg(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
"terms": { "field": "text_few_terms_status" },
|
||||
"aggs": {
|
||||
"average_f64": { "avg": { "field": "score_f64" } }
|
||||
}
|
||||
},
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_zipf_1000_with_avg_sub_agg(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
"terms": { "field": "text_1000_terms_zipf" },
|
||||
"aggs": {
|
||||
"average_f64": { "avg": { "field": "score_f64" } }
|
||||
}
|
||||
},
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_zipf_1000(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": { "terms": { "field": "text_1000_terms_zipf" } },
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) {
|
||||
let agg_req = json!({
|
||||
@@ -320,75 +246,6 @@ fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) {
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn composite_term_few(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_ctf": {
|
||||
"composite": {
|
||||
"sources": [
|
||||
{ "text_few_terms": { "terms": { "field": "text_few_terms" } } }
|
||||
],
|
||||
"size": 1000
|
||||
}
|
||||
},
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn composite_term_many_page_1000(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_ctmp1000": {
|
||||
"composite": {
|
||||
"sources": [
|
||||
{ "text_many_terms": { "terms": { "field": "text_many_terms" } } }
|
||||
],
|
||||
"size": 1000
|
||||
}
|
||||
},
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn composite_term_many_page_1000_with_avg_sub_agg(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_ctmp1000wasa": {
|
||||
"composite": {
|
||||
"sources": [
|
||||
{ "text_many_terms": { "terms": { "field": "text_many_terms" } } }
|
||||
],
|
||||
"size": 1000,
|
||||
|
||||
},
|
||||
"aggs": {
|
||||
"average_f64": { "avg": { "field": "score_f64" } }
|
||||
}
|
||||
},
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn composite_histogram(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_ch": {
|
||||
"composite": {
|
||||
"sources": [
|
||||
{ "f64_histogram": { "histogram": { "field": "score_f64", "interval": 1 } } }
|
||||
],
|
||||
"size": 1000
|
||||
}
|
||||
},
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn composite_histogram_calendar(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_chc": {
|
||||
"composite": {
|
||||
"sources": [
|
||||
{ "time_histogram": { "date_histogram": { "field": "timestamp", "calendar_interval": "month" } } }
|
||||
],
|
||||
"size": 1000
|
||||
}
|
||||
},
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn execute_agg(index: &Index, agg_req: serde_json::Value) {
|
||||
let agg_req: Aggregations = serde_json::from_value(agg_req).unwrap();
|
||||
@@ -433,7 +290,7 @@ fn range_agg_with_avg_sub_agg(index: &Index) {
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn range_agg_with_term_agg_status(index: &Index) {
|
||||
fn range_agg_with_term_agg_few(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"rangef64": {
|
||||
"range": {
|
||||
@@ -448,7 +305,7 @@ fn range_agg_with_term_agg_status(index: &Index) {
|
||||
]
|
||||
},
|
||||
"aggs": {
|
||||
"my_texts": { "terms": { "field": "text_few_terms_status" } },
|
||||
"my_texts": { "terms": { "field": "text_few_terms" } },
|
||||
}
|
||||
},
|
||||
});
|
||||
@@ -504,12 +361,12 @@ fn histogram_with_avg_sub_agg(index: &Index) {
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn histogram_with_term_agg_status(index: &Index) {
|
||||
fn histogram_with_term_agg_few(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"rangef64": {
|
||||
"histogram": { "field": "score_f64", "interval": 10 },
|
||||
"aggs": {
|
||||
"my_texts": { "terms": { "field": "text_few_terms_status" } }
|
||||
"my_texts": { "terms": { "field": "text_few_terms" } }
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -554,13 +411,6 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector {
|
||||
}
|
||||
|
||||
fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
|
||||
// Flag to use existing index
|
||||
let reuse_index = std::env::var("REUSE_AGG_BENCH_INDEX").is_ok();
|
||||
if reuse_index && std::path::Path::new("agg_bench").exists() {
|
||||
return Index::open_in_dir("agg_bench");
|
||||
}
|
||||
// crreate dir
|
||||
std::fs::create_dir_all("agg_bench")?;
|
||||
let mut schema_builder = Schema::builder();
|
||||
let text_fieldtype = tantivy::schema::TextOptions::default()
|
||||
.set_indexing_options(
|
||||
@@ -569,48 +419,20 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
|
||||
.set_stored();
|
||||
let text_field = schema_builder.add_text_field("text", text_fieldtype);
|
||||
let json_field = schema_builder.add_json_field("json", FAST);
|
||||
let text_field_all_unique_terms =
|
||||
schema_builder.add_text_field("text_all_unique_terms", STRING | FAST);
|
||||
let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST);
|
||||
let text_field_few_terms_status =
|
||||
schema_builder.add_text_field("text_few_terms_status", STRING | FAST);
|
||||
let text_field_1000_terms_zipf =
|
||||
schema_builder.add_text_field("text_1000_terms_zipf", STRING | FAST);
|
||||
let text_field_few_terms = schema_builder.add_text_field("text_few_terms", STRING | FAST);
|
||||
let score_fieldtype = tantivy::schema::NumericOptions::default().set_fast();
|
||||
let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone());
|
||||
let score_field_f64 = schema_builder.add_f64_field("score_f64", score_fieldtype.clone());
|
||||
let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype);
|
||||
let date_field = schema_builder.add_date_field("timestamp", FAST);
|
||||
// use tmp dir
|
||||
let index = if reuse_index {
|
||||
Index::create_in_dir("agg_bench", schema_builder.build())?
|
||||
} else {
|
||||
Index::create_from_tempdir(schema_builder.build())?
|
||||
};
|
||||
// Approximate log proportions
|
||||
let status_field_data = [
|
||||
("INFO", 8000),
|
||||
("ERROR", 300),
|
||||
("WARN", 1200),
|
||||
("DEBUG", 500),
|
||||
("OK", 500),
|
||||
("CRITICAL", 20),
|
||||
("EMERGENCY", 1),
|
||||
];
|
||||
let log_level_distribution =
|
||||
WeightedIndex::new(status_field_data.iter().map(|item| item.1)).unwrap();
|
||||
let index = Index::create_from_tempdir(schema_builder.build())?;
|
||||
let few_terms_data = ["INFO", "ERROR", "WARN", "DEBUG"];
|
||||
|
||||
let lg_norm = rand_distr::LogNormal::new(2.996f64, 0.979f64).unwrap();
|
||||
|
||||
let many_terms_data = (0..150_000)
|
||||
.map(|num| format!("author{num}"))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Prepare 1000 unique terms sampled using a Zipf distribution.
|
||||
// Exponent ~1.1 approximates top-20 terms covering around ~20%.
|
||||
let terms_1000: Vec<String> = (1..=1000).map(|i| format!("term_{i}")).collect();
|
||||
let zipf_1000 = rand_distr::Zipf::new(1000.0, 1.1f64).unwrap();
|
||||
|
||||
{
|
||||
let mut rng = StdRng::from_seed([1u8; 32]);
|
||||
let mut index_writer = index.writer_with_num_threads(1, 200_000_000)?;
|
||||
@@ -620,25 +442,15 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
|
||||
index_writer.add_document(doc!())?;
|
||||
}
|
||||
if cardinality == Cardinality::Multivalued {
|
||||
let log_level_sample_a = status_field_data[log_level_distribution.sample(&mut rng)].0;
|
||||
let log_level_sample_b = status_field_data[log_level_distribution.sample(&mut rng)].0;
|
||||
let idx_a = zipf_1000.sample(&mut rng) as usize - 1;
|
||||
let idx_b = zipf_1000.sample(&mut rng) as usize - 1;
|
||||
let term_1000_a = &terms_1000[idx_a];
|
||||
let term_1000_b = &terms_1000[idx_b];
|
||||
index_writer.add_document(doc!(
|
||||
json_field => json!({"mixed_type": 10.0}),
|
||||
json_field => json!({"mixed_type": 10.0}),
|
||||
text_field => "cool",
|
||||
text_field => "cool",
|
||||
text_field_all_unique_terms => "cool",
|
||||
text_field_all_unique_terms => "coolo",
|
||||
text_field_many_terms => "cool",
|
||||
text_field_many_terms => "cool",
|
||||
text_field_few_terms_status => log_level_sample_a,
|
||||
text_field_few_terms_status => log_level_sample_b,
|
||||
text_field_1000_terms_zipf => term_1000_a.as_str(),
|
||||
text_field_1000_terms_zipf => term_1000_b.as_str(),
|
||||
text_field_few_terms => "cool",
|
||||
text_field_few_terms => "cool",
|
||||
score_field => 1u64,
|
||||
score_field => 1u64,
|
||||
score_field_f64 => lg_norm.sample(&mut rng),
|
||||
@@ -653,8 +465,8 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
|
||||
}
|
||||
let _val_max = 1_000_000.0;
|
||||
for _ in 0..doc_with_value {
|
||||
let val: f64 = rng.random_range(0.0..1_000_000.0);
|
||||
let json = if rng.random_bool(0.1) {
|
||||
let val: f64 = rng.gen_range(0.0..1_000_000.0);
|
||||
let json = if rng.gen_bool(0.1) {
|
||||
// 10% are numeric values
|
||||
json!({ "mixed_type": val })
|
||||
} else {
|
||||
@@ -663,14 +475,11 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
|
||||
index_writer.add_document(doc!(
|
||||
text_field => "cool",
|
||||
json_field => json,
|
||||
text_field_all_unique_terms => format!("unique_term_{}", rng.random::<u64>()),
|
||||
text_field_many_terms => many_terms_data.choose(&mut rng).unwrap().to_string(),
|
||||
text_field_few_terms_status => status_field_data[log_level_distribution.sample(&mut rng)].0,
|
||||
text_field_1000_terms_zipf => terms_1000[zipf_1000.sample(&mut rng) as usize - 1].as_str(),
|
||||
text_field_few_terms => few_terms_data.choose(&mut rng).unwrap().to_string(),
|
||||
score_field => val as u64,
|
||||
score_field_f64 => lg_norm.sample(&mut rng),
|
||||
score_field_i64 => val as i64,
|
||||
date_field => DateTime::from_timestamp_millis((val * 1_000_000.) as i64),
|
||||
))?;
|
||||
if cardinality == Cardinality::OptionalSparse {
|
||||
for _ in 0..20 {
|
||||
@@ -719,7 +528,7 @@ fn filter_agg_all_query_with_sub_aggs(index: &Index) {
|
||||
"avg_score": { "avg": { "field": "score" } },
|
||||
"stats_score": { "stats": { "field": "score_f64" } },
|
||||
"terms_text": {
|
||||
"terms": { "field": "text_few_terms_status" }
|
||||
"terms": { "field": "text_few_terms" }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -735,7 +544,7 @@ fn filter_agg_term_query_with_sub_aggs(index: &Index) {
|
||||
"avg_score": { "avg": { "field": "score" } },
|
||||
"stats_score": { "stats": { "field": "score_f64" } },
|
||||
"terms_text": {
|
||||
"terms": { "field": "text_few_terms_status" }
|
||||
"terms": { "field": "text_few_terms" }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,29 +55,29 @@ fn build_shared_indices(num_docs: usize, p_a: f32, p_b: f32, p_c: f32) -> (Bench
     {
         let mut writer = index.writer_with_num_threads(1, 500_000_000).unwrap();
         for _ in 0..num_docs {
-            let has_a = rng.gen_bool(p_a as f64);
-            let has_b = rng.gen_bool(p_b as f64);
-            let has_c = rng.gen_bool(p_c as f64);
-            let score = rng.gen_range(0u64..100u64);
-            let score2 = rng.gen_range(0u64..100_000u64);
+            let has_a = rng.random_bool(p_a as f64);
+            let has_b = rng.random_bool(p_b as f64);
+            let has_c = rng.random_bool(p_c as f64);
+            let score = rng.random_range(0u64..100u64);
+            let score2 = rng.random_range(0u64..100_000u64);
             let mut title_tokens: Vec<&str> = Vec::new();
             let mut body_tokens: Vec<&str> = Vec::new();
             if has_a {
-                if rng.gen_bool(0.1) {
+                if rng.random_bool(0.1) {
                     title_tokens.push("a");
                 } else {
                     body_tokens.push("a");
                 }
             }
             if has_b {
-                if rng.gen_bool(0.1) {
+                if rng.random_bool(0.1) {
                     title_tokens.push("b");
                 } else {
                     body_tokens.push("b");
                 }
             }
             if has_c {
-                if rng.gen_bool(0.1) {
+                if rng.random_bool(0.1) {
                     title_tokens.push("c");
                 } else {
                     body_tokens.push("c");

@@ -1,288 +0,0 @@
use binggan::{black_box, BenchGroup, BenchRunner};
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::collector::{Collector, Count, DocSetCollector, TopDocs};
use tantivy::query::{Query, QueryParser};
use tantivy::schema::{Schema, FAST, INDEXED, TEXT};
use tantivy::{doc, Index, Order, ReloadPolicy, Searcher};

#[derive(Clone)]
struct BenchIndex {
    #[allow(dead_code)]
    index: Index,
    searcher: Searcher,
    query_parser: QueryParser,
}

fn build_shared_indices(num_docs: usize, p_title_a: f32, distribution: &str) -> BenchIndex {
    // Unified schema
    let mut schema_builder = Schema::builder();
    let f_title = schema_builder.add_text_field("title", TEXT);
    let f_num_rand = schema_builder.add_u64_field("num_rand", INDEXED);
    let f_num_asc = schema_builder.add_u64_field("num_asc", INDEXED);
    let f_num_rand_fast = schema_builder.add_u64_field("num_rand_fast", INDEXED | FAST);
    let f_num_asc_fast = schema_builder.add_u64_field("num_asc_fast", INDEXED | FAST);
    let schema = schema_builder.build();
    let index = Index::create_in_ram(schema.clone());

    // Populate index with stable RNG for reproducibility.
    let mut rng = StdRng::from_seed([7u8; 32]);

    {
        let mut writer = index.writer_with_num_threads(1, 4_000_000_000).unwrap();

        match distribution {
            "dense" => {
                for doc_id in 0..num_docs {
                    // Always add title to avoid empty documents
                    let title_token = if rng.random_bool(p_title_a as f64) {
                        "a"
                    } else {
                        "b"
                    };

                    let num_rand = rng.random_range(0u64..1000u64);

                    let num_asc = (doc_id / 10000) as u64;

                    writer
                        .add_document(doc!(
                            f_title=>title_token,
                            f_num_rand=>num_rand,
                            f_num_asc=>num_asc,
                            f_num_rand_fast=>num_rand,
                            f_num_asc_fast=>num_asc,
                        ))
                        .unwrap();
                }
            }
            "sparse" => {
                for doc_id in 0..num_docs {
                    // Always add title to avoid empty documents
                    let title_token = if rng.random_bool(p_title_a as f64) {
                        "a"
                    } else {
                        "b"
                    };

                    let num_rand = rng.random_range(0u64..10000000u64);

                    let num_asc = doc_id as u64;

                    writer
                        .add_document(doc!(
                            f_title=>title_token,
                            f_num_rand=>num_rand,
                            f_num_asc=>num_asc,
                            f_num_rand_fast=>num_rand,
                            f_num_asc_fast=>num_asc,
                        ))
                        .unwrap();
                }
            }
            _ => {
                panic!("Unsupported distribution type");
            }
        }
        writer.commit().unwrap();
    }

    // Prepare reader/searcher once.
    let reader = index
        .reader_builder()
        .reload_policy(ReloadPolicy::Manual)
        .try_into()
        .unwrap();
    let searcher = reader.searcher();

    // Build query parser for title field
    let qp_title = QueryParser::for_index(&index, vec![f_title]);

    BenchIndex {
        index,
        searcher,
        query_parser: qp_title,
    }
}

fn main() {
    // Prepare corpora with varying scenarios
    let scenarios = vec![
        (
            "dense and 99% a".to_string(),
            10_000_000,
            0.99,
            "dense",
            0,
            9,
        ),
        (
            "dense and 99% a".to_string(),
            10_000_000,
            0.99,
            "dense",
            990,
            999,
        ),
        (
            "sparse and 99% a".to_string(),
            10_000_000,
            0.99,
            "sparse",
            0,
            9,
        ),
        (
            "sparse and 99% a".to_string(),
            10_000_000,
            0.99,
            "sparse",
            9_999_990,
            9_999_999,
        ),
    ];

    let mut runner = BenchRunner::new();
    for (scenario_id, n, p_title_a, num_rand_distribution, range_low, range_high) in scenarios {
        // Build index for this scenario
        let bench_index = build_shared_indices(n, p_title_a, num_rand_distribution);

        // Create benchmark group
        let mut group = runner.new_group();

        // Now set the name (this moves scenario_id)
        group.set_name(scenario_id);

        // Define all four field types
        let field_names = ["num_rand", "num_asc", "num_rand_fast", "num_asc_fast"];

        // Define the three terms we want to test with
        let terms = ["a", "b", "z"];

        // Generate all combinations of terms and field names
        let mut queries = Vec::new();
        for &term in &terms {
            for &field_name in &field_names {
                let query_str = format!(
                    "{} AND {}:[{} TO {}]",
                    term, field_name, range_low, range_high
                );
                queries.push((query_str, field_name.to_string()));
            }
        }

        let query_str = format!(
            "{}:[{} TO {}] AND {}:[{} TO {}]",
            "num_rand_fast", range_low, range_high, "num_asc_fast", range_low, range_high
        );
        queries.push((query_str, "num_asc_fast".to_string()));

        // Run all benchmark tasks for each query and its corresponding field name
        for (query_str, field_name) in queries {
            run_benchmark_tasks(&mut group, &bench_index, &query_str, &field_name);
        }

        group.run();
    }
}

/// Run all benchmark tasks for a given query string and field name
fn run_benchmark_tasks(
    bench_group: &mut BenchGroup,
    bench_index: &BenchIndex,
    query_str: &str,
    field_name: &str,
) {
    // Test count
    add_bench_task(bench_group, bench_index, query_str, Count, "count");

    // Test all results
    add_bench_task(
        bench_group,
        bench_index,
        query_str,
        DocSetCollector,
        "all results",
    );

    // Test top 100 by the field (if it's a FAST field)
    if field_name.ends_with("_fast") {
        // Ascending order
        {
            let collector_name = format!("top100_by_{}_asc", field_name);
            let field_name_owned = field_name.to_string();
            add_bench_task(
                bench_group,
                bench_index,
                query_str,
                TopDocs::with_limit(100).order_by_fast_field::<u64>(field_name_owned, Order::Asc),
                &collector_name,
            );
        }

        // Descending order
        {
            let collector_name = format!("top100_by_{}_desc", field_name);
            let field_name_owned = field_name.to_string();
            add_bench_task(
                bench_group,
                bench_index,
                query_str,
                TopDocs::with_limit(100).order_by_fast_field::<u64>(field_name_owned, Order::Desc),
                &collector_name,
            );
        }
    }
}

fn add_bench_task<C: Collector + 'static>(
    bench_group: &mut BenchGroup,
    bench_index: &BenchIndex,
    query_str: &str,
    collector: C,
    collector_name: &str,
) {
    let task_name = format!("{}_{}", query_str.replace(" ", "_"), collector_name);
    let query = bench_index.query_parser.parse_query(query_str).unwrap();
    let search_task = SearchTask {
        searcher: bench_index.searcher.clone(),
        collector,
        query,
    };
    bench_group.register(task_name, move |_| black_box(search_task.run()));
}

struct SearchTask<C: Collector> {
    searcher: Searcher,
    collector: C,
    query: Box<dyn Query>,
}

impl<C: Collector> SearchTask<C> {
    #[inline(never)]
    pub fn run(&self) -> usize {
        let result = self.searcher.search(&self.query, &self.collector).unwrap();
        if let Some(count) = (&result as &dyn std::any::Any).downcast_ref::<usize>() {
            *count
        } else if let Some(top_docs) = (&result as &dyn std::any::Any)
            .downcast_ref::<Vec<(Option<u64>, tantivy::DocAddress)>>()
        {
            top_docs.len()
        } else if let Some(top_docs) =
            (&result as &dyn std::any::Any).downcast_ref::<Vec<(u64, tantivy::DocAddress)>>()
        {
            top_docs.len()
        } else if let Some(doc_set) = (&result as &dyn std::any::Any)
            .downcast_ref::<std::collections::HashSet<tantivy::DocAddress>>()
        {
            doc_set.len()
        } else {
            eprintln!(
                "Unknown collector result type: {:?}",
                std::any::type_name::<C::Fruit>()
            );
            0
        }
    }
}
@@ -1,224 +0,0 @@
// Benchmarks segment merging
//
// Notes:
// - Input segments are kept intact (no deletes / no IndexWriter merge).
// - Output is written to a `NullDirectory` that discards all files except
//   fieldnorms (needed for merging).

use std::collections::HashMap;
use std::io::{self, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};

use binggan::{black_box, BenchRunner};
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::directory::error::{DeleteError, OpenReadError, OpenWriteError};
use tantivy::directory::{
    AntiCallToken, Directory, FileHandle, OwnedBytes, TerminatingWrite, WatchCallback, WatchHandle,
    WritePtr,
};
use tantivy::indexer::{merge_filtered_segments, NoMergePolicy};
use tantivy::schema::{Schema, TEXT};
use tantivy::{doc, HasLen, Index, IndexSettings, Segment};

#[derive(Clone, Default, Debug)]
struct NullDirectory {
    blobs: Arc<RwLock<HashMap<PathBuf, OwnedBytes>>>,
}

struct NullWriter;

impl Write for NullWriter {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        Ok(buf.len())
    }

    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}

impl TerminatingWrite for NullWriter {
    fn terminate_ref(&mut self, _token: AntiCallToken) -> io::Result<()> {
        Ok(())
    }
}

struct InMemoryWriter {
    path: PathBuf,
    buffer: Vec<u8>,
    blobs: Arc<RwLock<HashMap<PathBuf, OwnedBytes>>>,
}

impl Write for InMemoryWriter {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.buffer.extend_from_slice(buf);
        Ok(buf.len())
    }

    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}

impl TerminatingWrite for InMemoryWriter {
    fn terminate_ref(&mut self, _token: AntiCallToken) -> io::Result<()> {
        let bytes = OwnedBytes::new(std::mem::take(&mut self.buffer));
        self.blobs.write().unwrap().insert(self.path.clone(), bytes);
        Ok(())
    }
}

#[derive(Debug, Default)]
struct NullFileHandle;
impl HasLen for NullFileHandle {
    fn len(&self) -> usize {
        0
    }
}
impl FileHandle for NullFileHandle {
    fn read_bytes(&self, _range: std::ops::Range<usize>) -> io::Result<OwnedBytes> {
        unimplemented!()
    }
}

impl Directory for NullDirectory {
    fn get_file_handle(&self, path: &Path) -> Result<Arc<dyn FileHandle>, OpenReadError> {
        if let Some(bytes) = self.blobs.read().unwrap().get(path) {
            return Ok(Arc::new(bytes.clone()));
        }
        Ok(Arc::new(NullFileHandle))
    }

    fn delete(&self, _path: &Path) -> Result<(), DeleteError> {
        Ok(())
    }

    fn exists(&self, _path: &Path) -> Result<bool, OpenReadError> {
        Ok(true)
    }

    fn open_write(&self, path: &Path) -> Result<WritePtr, OpenWriteError> {
        let path_buf = path.to_path_buf();
        if path.to_string_lossy().ends_with(".fieldnorm") {
            let writer = InMemoryWriter {
                path: path_buf,
                buffer: Vec::new(),
                blobs: Arc::clone(&self.blobs),
            };
            Ok(io::BufWriter::new(Box::new(writer)))
        } else {
            Ok(io::BufWriter::new(Box::new(NullWriter)))
        }
    }

    fn atomic_read(&self, path: &Path) -> Result<Vec<u8>, OpenReadError> {
        if let Some(bytes) = self.blobs.read().unwrap().get(path) {
            return Ok(bytes.as_slice().to_vec());
        }
        Err(OpenReadError::FileDoesNotExist(path.to_path_buf()))
    }

    fn atomic_write(&self, _path: &Path, _data: &[u8]) -> io::Result<()> {
        Ok(())
    }

    fn sync_directory(&self) -> io::Result<()> {
        Ok(())
    }

    fn watch(&self, _watch_callback: WatchCallback) -> tantivy::Result<WatchHandle> {
        Ok(WatchHandle::empty())
    }
}

struct MergeScenario {
    #[allow(dead_code)]
    index: Index,
    segments: Vec<Segment>,
    settings: IndexSettings,
    label: String,
}

fn build_index(
    num_segments: usize,
    docs_per_segment: usize,
    tokens_per_doc: usize,
    vocab_size: usize,
) -> MergeScenario {
    let mut schema_builder = Schema::builder();
    let body = schema_builder.add_text_field("body", TEXT);
    let schema = schema_builder.build();
    let index = Index::create_in_ram(schema.clone());

    assert!(vocab_size > 0);
    let total_tokens = num_segments * docs_per_segment * tokens_per_doc;
    let use_unique_terms = vocab_size >= total_tokens;
    let mut rng = StdRng::from_seed([7u8; 32]);
    let mut next_token_id: u64 = 0;

    {
        let mut writer = index.writer_with_num_threads(1, 256_000_000).unwrap();
        writer.set_merge_policy(Box::new(NoMergePolicy));
        for _ in 0..num_segments {
            for _ in 0..docs_per_segment {
                let mut tokens = Vec::with_capacity(tokens_per_doc);
                for _ in 0..tokens_per_doc {
                    let token_id = if use_unique_terms {
                        let id = next_token_id;
                        next_token_id += 1;
                        id
                    } else {
                        rng.random_range(0..vocab_size as u64)
                    };
                    tokens.push(format!("term_{token_id}"));
                }
                writer.add_document(doc!(body => tokens.join(" "))).unwrap();
            }
            writer.commit().unwrap();
        }
    }

    let segments = index.searchable_segments().unwrap();
    let settings = index.settings().clone();
    let label = format!(
        "segments={}, docs/seg={}, tokens/doc={}, vocab={}",
        num_segments, docs_per_segment, tokens_per_doc, vocab_size
    );

    MergeScenario {
        index,
        segments,
        settings,
        label,
    }
}

fn main() {
    let scenarios = vec![
        build_index(8, 50_000, 12, 8),
        build_index(16, 50_000, 12, 8),
        build_index(16, 100_000, 12, 8),
        build_index(8, 50_000, 8, 8 * 50_000 * 8),
    ];

    let mut runner = BenchRunner::new();
    for scenario in scenarios {
        let mut group = runner.new_group();
        group.set_name(format!("merge_segments inv_index — {}", scenario.label));
        let segments = scenario.segments.clone();
        let settings = scenario.settings.clone();
        group.register("merge", move |_| {
            let output_dir = NullDirectory::default();
            let filter_doc_ids = vec![None; segments.len()];
            let merged_index =
                merge_filtered_segments(&segments, settings.clone(), filter_doc_ids, output_dir)
                    .unwrap();
            black_box(merged_index);
        });

        group.run();
    }
}
@@ -1,365 +0,0 @@
use std::ops::Bound;

use binggan::{black_box, BenchGroup, BenchRunner};
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::collector::{Count, DocSetCollector, TopDocs};
use tantivy::query::RangeQuery;
use tantivy::schema::{Schema, FAST, INDEXED};
use tantivy::{doc, Index, Order, ReloadPolicy, Searcher, Term};

#[derive(Clone)]
struct BenchIndex {
    #[allow(dead_code)]
    index: Index,
    searcher: Searcher,
}

fn build_shared_indices(num_docs: usize, distribution: &str) -> BenchIndex {
    // Schema with fast fields only
    let mut schema_builder = Schema::builder();
    let f_num_rand_fast = schema_builder.add_u64_field("num_rand_fast", INDEXED | FAST);
    let f_num_asc_fast = schema_builder.add_u64_field("num_asc_fast", INDEXED | FAST);
    let schema = schema_builder.build();
    let index = Index::create_in_ram(schema.clone());

    // Populate index with stable RNG for reproducibility.
    let mut rng = StdRng::from_seed([7u8; 32]);

    {
        let mut writer = index.writer_with_num_threads(1, 4_000_000_000).unwrap();

        match distribution {
            "dense" => {
                for doc_id in 0..num_docs {
                    let num_rand = rng.random_range(0u64..1000u64);
                    let num_asc = (doc_id / 10000) as u64;

                    writer
                        .add_document(doc!(
                            f_num_rand_fast=>num_rand,
                            f_num_asc_fast=>num_asc,
                        ))
                        .unwrap();
                }
            }
            "sparse" => {
                for doc_id in 0..num_docs {
                    let num_rand = rng.random_range(0u64..10000000u64);
                    let num_asc = doc_id as u64;

                    writer
                        .add_document(doc!(
                            f_num_rand_fast=>num_rand,
                            f_num_asc_fast=>num_asc,
                        ))
                        .unwrap();
                }
            }
            _ => {
                panic!("Unsupported distribution type");
            }
        }
        writer.commit().unwrap();
    }

    // Prepare reader/searcher once.
    let reader = index
        .reader_builder()
        .reload_policy(ReloadPolicy::Manual)
        .try_into()
        .unwrap();
    let searcher = reader.searcher();

    BenchIndex { index, searcher }
}

fn main() {
    // Prepare corpora with varying scenarios
    let scenarios = vec![
        // Dense distribution - random values in small range (0-999)
        (
            "dense_values_search_low_value_range".to_string(),
            10_000_000,
            "dense",
            0,
            9,
        ),
        (
            "dense_values_search_high_value_range".to_string(),
            10_000_000,
            "dense",
            990,
            999,
        ),
        (
            "dense_values_search_out_of_range".to_string(),
            10_000_000,
            "dense",
            1000,
            1002,
        ),
        (
            "sparse_values_search_low_value_range".to_string(),
            10_000_000,
            "sparse",
            0,
            9,
        ),
        (
            "sparse_values_search_high_value_range".to_string(),
            10_000_000,
            "sparse",
            9_999_990,
            9_999_999,
        ),
        (
            "sparse_values_search_out_of_range".to_string(),
            10_000_000,
            "sparse",
            10_000_000,
            10_000_002,
        ),
    ];

    let mut runner = BenchRunner::new();
    for (scenario_id, n, num_rand_distribution, range_low, range_high) in scenarios {
        // Build index for this scenario
        let bench_index = build_shared_indices(n, num_rand_distribution);

        // Create benchmark group
        let mut group = runner.new_group();

        // Now set the name (this moves scenario_id)
        group.set_name(scenario_id);

        // Define fast field types
        let field_names = ["num_rand_fast", "num_asc_fast"];

        // Generate range queries for fast fields
        for &field_name in &field_names {
            // Create the range query
            let field = bench_index.searcher.schema().get_field(field_name).unwrap();
            let lower_term = Term::from_field_u64(field, range_low);
            let upper_term = Term::from_field_u64(field, range_high);

            let query = RangeQuery::new(Bound::Included(lower_term), Bound::Included(upper_term));

            run_benchmark_tasks(
                &mut group,
                &bench_index,
                query,
                field_name,
                range_low,
                range_high,
            );
        }

        group.run();
    }
}

/// Run all benchmark tasks for a given range query and field name
fn run_benchmark_tasks(
    bench_group: &mut BenchGroup,
    bench_index: &BenchIndex,
    query: RangeQuery,
    field_name: &str,
    range_low: u64,
    range_high: u64,
) {
    // Test count
    add_bench_task_count(
        bench_group,
        bench_index,
        query.clone(),
        "count",
        field_name,
        range_low,
        range_high,
    );

    // Test top 100 by the field (ascending order)
    {
        let collector_name = format!("top100_by_{}_asc", field_name);
        let field_name_owned = field_name.to_string();
        add_bench_task_top100_asc(
            bench_group,
            bench_index,
            query.clone(),
            &collector_name,
            field_name,
            range_low,
            range_high,
            field_name_owned,
        );
    }

    // Test top 100 by the field (descending order)
    {
        let collector_name = format!("top100_by_{}_desc", field_name);
        let field_name_owned = field_name.to_string();
        add_bench_task_top100_desc(
            bench_group,
            bench_index,
            query,
            &collector_name,
            field_name,
            range_low,
            range_high,
            field_name_owned,
        );
    }
}

fn add_bench_task_count(
    bench_group: &mut BenchGroup,
    bench_index: &BenchIndex,
    query: RangeQuery,
    collector_name: &str,
    field_name: &str,
    range_low: u64,
    range_high: u64,
) {
    let task_name = format!(
        "range_{}_[{} TO {}]_{}",
        field_name, range_low, range_high, collector_name
    );

    let search_task = CountSearchTask {
        searcher: bench_index.searcher.clone(),
        query,
    };
    bench_group.register(task_name, move |_| black_box(search_task.run()));
}

fn add_bench_task_docset(
    bench_group: &mut BenchGroup,
    bench_index: &BenchIndex,
    query: RangeQuery,
    collector_name: &str,
    field_name: &str,
    range_low: u64,
    range_high: u64,
) {
    let task_name = format!(
        "range_{}_[{} TO {}]_{}",
        field_name, range_low, range_high, collector_name
    );

    let search_task = DocSetSearchTask {
        searcher: bench_index.searcher.clone(),
        query,
    };
    bench_group.register(task_name, move |_| black_box(search_task.run()));
}

fn add_bench_task_top100_asc(
    bench_group: &mut BenchGroup,
    bench_index: &BenchIndex,
    query: RangeQuery,
    collector_name: &str,
    field_name: &str,
    range_low: u64,
    range_high: u64,
    field_name_owned: String,
) {
    let task_name = format!(
        "range_{}_[{} TO {}]_{}",
        field_name, range_low, range_high, collector_name
    );

    let search_task = Top100AscSearchTask {
        searcher: bench_index.searcher.clone(),
        query,
        field_name: field_name_owned,
    };
    bench_group.register(task_name, move |_| black_box(search_task.run()));
}

fn add_bench_task_top100_desc(
    bench_group: &mut BenchGroup,
    bench_index: &BenchIndex,
    query: RangeQuery,
    collector_name: &str,
    field_name: &str,
    range_low: u64,
    range_high: u64,
    field_name_owned: String,
) {
    let task_name = format!(
        "range_{}_[{} TO {}]_{}",
        field_name, range_low, range_high, collector_name
    );

    let search_task = Top100DescSearchTask {
        searcher: bench_index.searcher.clone(),
        query,
        field_name: field_name_owned,
    };
    bench_group.register(task_name, move |_| black_box(search_task.run()));
}

struct CountSearchTask {
    searcher: Searcher,
    query: RangeQuery,
}

impl CountSearchTask {
    #[inline(never)]
    pub fn run(&self) -> usize {
        self.searcher.search(&self.query, &Count).unwrap()
    }
}

struct DocSetSearchTask {
    searcher: Searcher,
    query: RangeQuery,
}

impl DocSetSearchTask {
    #[inline(never)]
    pub fn run(&self) -> usize {
        let result = self.searcher.search(&self.query, &DocSetCollector).unwrap();
        result.len()
    }
}

struct Top100AscSearchTask {
    searcher: Searcher,
    query: RangeQuery,
    field_name: String,
}

impl Top100AscSearchTask {
    #[inline(never)]
    pub fn run(&self) -> usize {
        let collector =
            TopDocs::with_limit(100).order_by_fast_field::<u64>(&self.field_name, Order::Asc);
        let result = self.searcher.search(&self.query, &collector).unwrap();
        for (_score, doc_address) in &result {
            let _doc: tantivy::TantivyDocument = self.searcher.doc(*doc_address).unwrap();
        }
        result.len()
    }
}

struct Top100DescSearchTask {
    searcher: Searcher,
    query: RangeQuery,
    field_name: String,
}

impl Top100DescSearchTask {
    #[inline(never)]
    pub fn run(&self) -> usize {
        let collector =
            TopDocs::with_limit(100).order_by_fast_field::<u64>(&self.field_name, Order::Desc);
        let result = self.searcher.search(&self.query, &collector).unwrap();
        for (_score, doc_address) in &result {
            let _doc: tantivy::TantivyDocument = self.searcher.doc(*doc_address).unwrap();
        }
        result.len()
    }
}
@@ -1,260 +0,0 @@
|
||||
use std::fmt::Display;
|
||||
use std::net::Ipv6Addr;
|
||||
use std::ops::RangeInclusive;
|
||||
|
||||
use binggan::plugins::PeakMemAllocPlugin;
|
||||
use binggan::{black_box, BenchRunner, OutputValue, PeakMemAlloc, INSTRUMENTED_SYSTEM};
|
||||
use columnar::MonotonicallyMappableToU128;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use tantivy::collector::{Count, TopDocs};
|
||||
use tantivy::query::QueryParser;
|
||||
use tantivy::schema::*;
|
||||
use tantivy::{doc, Index};
|
||||
|
||||
#[global_allocator]
|
||||
pub static GLOBAL: &PeakMemAlloc<std::alloc::System> = &INSTRUMENTED_SYSTEM;
|
||||
|
||||
fn main() {
|
||||
bench_range_query();
|
||||
}
|
||||
|
||||
fn bench_range_query() {
|
||||
let index = get_index_0_to_100();
|
||||
let mut runner = BenchRunner::new();
|
||||
runner.add_plugin(PeakMemAllocPlugin::new(GLOBAL));
|
||||
|
||||
runner.set_name("range_query on u64");
|
||||
let field_name_and_descr: Vec<_> = vec![
|
||||
("id", "Single Valued Range Field"),
|
||||
("ids", "Multi Valued Range Field"),
|
||||
];
|
||||
let range_num_hits = vec![
|
||||
("90_percent", get_90_percent()),
|
||||
("10_percent", get_10_percent()),
|
||||
("1_percent", get_1_percent()),
|
||||
];
|
||||
|
||||
test_range(&mut runner, &index, &field_name_and_descr, range_num_hits);
|
||||
|
||||
runner.set_name("range_query on ip");
|
||||
let field_name_and_descr: Vec<_> = vec![
|
||||
("ip", "Single Valued Range Field"),
|
||||
("ips", "Multi Valued Range Field"),
|
||||
];
|
||||
let range_num_hits = vec![
|
||||
("90_percent", get_90_percent_ip()),
|
||||
("10_percent", get_10_percent_ip()),
|
||||
("1_percent", get_1_percent_ip()),
|
||||
];
|
||||
|
||||
test_range(&mut runner, &index, &field_name_and_descr, range_num_hits);
|
||||
}
|
||||
|
||||
fn test_range<T: Display>(
|
||||
runner: &mut BenchRunner,
|
||||
index: &Index,
|
||||
field_name_and_descr: &[(&str, &str)],
|
||||
range_num_hits: Vec<(&str, RangeInclusive<T>)>,
|
||||
) {
|
||||
for (field, suffix) in field_name_and_descr {
|
||||
let term_num_hits = vec![
|
||||
("", ""),
|
||||
("1_percent", "veryfew"),
|
||||
("10_percent", "few"),
|
||||
("90_percent", "most"),
|
||||
];
|
||||
let mut group = runner.new_group();
|
||||
group.set_name(suffix);
|
||||
// all intersect combinations
|
||||
for (range_name, range) in &range_num_hits {
|
||||
for (term_name, term) in &term_num_hits {
|
||||
let index = &index;
|
||||
let test_name = if term_name.is_empty() {
|
||||
format!("id_range_hit_{}", range_name)
|
||||
} else {
|
||||
format!(
|
||||
"id_range_hit_{}_intersect_with_term_{}",
|
||||
range_name, term_name
|
||||
)
|
||||
};
|
||||
group.register(test_name, move |_| {
|
||||
let query = if term_name.is_empty() {
|
||||
"".to_string()
|
||||
} else {
|
||||
format!("AND id_name:{}", term)
|
||||
};
|
||||
black_box(execute_query(field, range, &query, index));
|
||||
});
|
||||
}
|
||||
}
|
||||
group.run();
|
||||
}
|
||||
}
|
||||
|
||||
fn get_index_0_to_100() -> Index {
|
||||
let mut rng = StdRng::from_seed([1u8; 32]);
|
    let num_vals = 100_000;
    let docs: Vec<_> = (0..num_vals)
        .map(|_i| {
            let id_name = if rng.random_bool(0.01) {
                "veryfew".to_string() // 1%
            } else if rng.random_bool(0.1) {
                "few".to_string() // 9%
            } else {
                "most".to_string() // 90%
            };
            Doc {
                id_name,
                id: rng.random_range(0..100),
                // Multiply by 1000, so that we create most buckets in the compact space
                // The benches depend on this range to select n-percent of elements with the
                // methods below.
                ip: Ipv6Addr::from_u128(rng.random_range(0..100) * 1000),
            }
        })
        .collect();

    create_index_from_docs(&docs)
}

#[derive(Clone, Debug)]
pub struct Doc {
    pub id_name: String,
    pub id: u64,
    pub ip: Ipv6Addr,
}

pub fn create_index_from_docs(docs: &[Doc]) -> Index {
    let mut schema_builder = Schema::builder();
    let id_u64_field = schema_builder.add_u64_field("id", INDEXED | STORED | FAST);
    let ids_u64_field =
        schema_builder.add_u64_field("ids", NumericOptions::default().set_fast().set_indexed());

    let id_f64_field = schema_builder.add_f64_field("id_f64", INDEXED | STORED | FAST);
    let ids_f64_field = schema_builder.add_f64_field(
        "ids_f64",
        NumericOptions::default().set_fast().set_indexed(),
    );

    let id_i64_field = schema_builder.add_i64_field("id_i64", INDEXED | STORED | FAST);
    let ids_i64_field = schema_builder.add_i64_field(
        "ids_i64",
        NumericOptions::default().set_fast().set_indexed(),
    );

    let text_field = schema_builder.add_text_field("id_name", STRING | STORED);
    let text_field2 = schema_builder.add_text_field("id_name_fast", STRING | STORED | FAST);

    let ip_field = schema_builder.add_ip_addr_field("ip", FAST);
    let ips_field = schema_builder.add_ip_addr_field("ips", FAST);

    let schema = schema_builder.build();

    let index = Index::create_in_ram(schema);

    {
        let mut index_writer = index.writer_with_num_threads(1, 50_000_000).unwrap();
        for doc in docs.iter() {
            index_writer
                .add_document(doc!(
                    ids_i64_field => doc.id as i64,
                    ids_i64_field => doc.id as i64,
                    ids_f64_field => doc.id as f64,
                    ids_f64_field => doc.id as f64,
                    ids_u64_field => doc.id,
                    ids_u64_field => doc.id,
                    id_u64_field => doc.id,
                    id_f64_field => doc.id as f64,
                    id_i64_field => doc.id as i64,
                    text_field => doc.id_name.to_string(),
                    text_field2 => doc.id_name.to_string(),
                    ips_field => doc.ip,
                    ips_field => doc.ip,
                    ip_field => doc.ip,
                ))
                .unwrap();
        }

        index_writer.commit().unwrap();
    }
    index
}

fn get_90_percent() -> RangeInclusive<u64> {
    0..=90
}

fn get_10_percent() -> RangeInclusive<u64> {
    0..=10
}

fn get_1_percent() -> RangeInclusive<u64> {
    10..=10
}

fn get_90_percent_ip() -> RangeInclusive<Ipv6Addr> {
    let start = Ipv6Addr::from_u128(0);
    let end = Ipv6Addr::from_u128(90 * 1000);
    start..=end
}

fn get_10_percent_ip() -> RangeInclusive<Ipv6Addr> {
    let start = Ipv6Addr::from_u128(0);
    let end = Ipv6Addr::from_u128(10 * 1000);
    start..=end
}

fn get_1_percent_ip() -> RangeInclusive<Ipv6Addr> {
    let start = Ipv6Addr::from_u128(10 * 1000);
    let end = Ipv6Addr::from_u128(10 * 1000);
    start..=end
}
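The selectivity helpers above work because `id` is drawn uniformly from `0..100`: an inclusive range covering k of those 100 values matches roughly k% of documents (strictly, `0..=90` covers 91 of the 100 values, so the helper names are approximate). A minimal std-only sketch, where a deterministic `ids` vector stands in for the random draw:

```rust
use std::ops::RangeInclusive;

// Fraction of values in `ids` that fall inside `range`.
fn selectivity(ids: &[u64], range: &RangeInclusive<u64>) -> f64 {
    let hits = ids.iter().filter(|id| range.contains(id)).count();
    hits as f64 / ids.len() as f64
}

fn main() {
    // Stand-in for `rng.random_range(0..100)`: every value in 0..100 equally often.
    let ids: Vec<u64> = (0..100_000u64).map(|i| i % 100).collect();
    // 0..=90 covers 91 of 100 values, 0..=10 covers 11, 10..=10 covers exactly 1.
    assert!((selectivity(&ids, &(0..=90)) - 0.91).abs() < 1e-9);
    assert!((selectivity(&ids, &(0..=10)) - 0.11).abs() < 1e-9);
    assert!((selectivity(&ids, &(10..=10)) - 0.01).abs() < 1e-9);
    println!("selectivities check out");
}
```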

struct NumHits {
    count: usize,
}
impl OutputValue for NumHits {
    fn column_title() -> &'static str {
        "NumHits"
    }
    fn format(&self) -> Option<String> {
        Some(self.count.to_string())
    }
}

fn execute_query<T: Display>(
    field: &str,
    id_range: &RangeInclusive<T>,
    suffix: &str,
    index: &Index,
) -> NumHits {
    let gen_query_inclusive = |from: &T, to: &T| {
        format!(
            "{}:[{} TO {}] {}",
            field,
            &from.to_string(),
            &to.to_string(),
            suffix
        )
    };

    let query = gen_query_inclusive(id_range.start(), id_range.end());
    execute_query_(&query, index)
}

fn execute_query_(query: &str, index: &Index) -> NumHits {
    let query_from_text = |text: &str| {
        QueryParser::for_index(index, vec![])
            .parse_query(text)
            .unwrap()
    };
    let query = query_from_text(query);
    let reader = index.reader().unwrap();
    let searcher = reader.searcher();
    let num_hits = searcher
        .search(&query, &(TopDocs::with_limit(10).order_by_score(), Count))
        .unwrap()
        .1;
    NumHits { count: num_hits }
}

@@ -1,113 +0,0 @@
// Benchmarks regex query that matches all terms in a synthetic index.
//
// Corpus model:
// - N unique terms: t000000, t000001, ...
// - M docs
// - K tokens per doc: doc i gets terms derived from (i, token_index)
//
// Query:
// - Regex "t.*" to match all terms
//
// Run with:
// - cargo bench --bench regex_all_terms
//

use std::fmt::Write;

use binggan::{black_box, BenchRunner};
use tantivy::collector::Count;
use tantivy::query::RegexQuery;
use tantivy::schema::{Schema, TEXT};
use tantivy::{doc, Index, ReloadPolicy};

const HEAP_SIZE_BYTES: usize = 200_000_000;

#[derive(Clone, Copy)]
struct BenchConfig {
    num_terms: usize,
    num_docs: usize,
    tokens_per_doc: usize,
}

fn main() {
    let configs = default_configs();

    let mut runner = BenchRunner::new();
    for config in configs {
        let (index, text_field) = build_index(config, HEAP_SIZE_BYTES);
        let reader = index
            .reader_builder()
            .reload_policy(ReloadPolicy::Manual)
            .try_into()
            .expect("reader");
        let searcher = reader.searcher();
        let query = RegexQuery::from_pattern("t.*", text_field).expect("regex query");

        let mut group = runner.new_group();
        group.set_name(format!(
            "regex_all_terms_t{}_d{}_k{}",
            config.num_terms, config.num_docs, config.tokens_per_doc
        ));
        group.register("regex_count", move |_| {
            let count = searcher.search(&query, &Count).expect("search");
            black_box(count);
        });
        group.run();
    }
}

fn default_configs() -> Vec<BenchConfig> {
    vec![
        BenchConfig {
            num_terms: 10_000,
            num_docs: 100_000,
            tokens_per_doc: 1,
        },
        BenchConfig {
            num_terms: 10_000,
            num_docs: 100_000,
            tokens_per_doc: 8,
        },
        BenchConfig {
            num_terms: 100_000,
            num_docs: 100_000,
            tokens_per_doc: 1,
        },
        BenchConfig {
            num_terms: 100_000,
            num_docs: 100_000,
            tokens_per_doc: 8,
        },
    ]
}

fn build_index(config: BenchConfig, heap_size_bytes: usize) -> (Index, tantivy::schema::Field) {
    let mut schema_builder = Schema::builder();
    let text_field = schema_builder.add_text_field("text", TEXT);
    let schema = schema_builder.build();
    let index = Index::create_in_ram(schema);

    let term_width = config.num_terms.to_string().len();
    {
        let mut writer = index
            .writer_with_num_threads(1, heap_size_bytes)
            .expect("writer");
        let mut buffer = String::new();
        for doc_id in 0..config.num_docs {
            buffer.clear();
            for token_idx in 0..config.tokens_per_doc {
                if token_idx > 0 {
                    buffer.push(' ');
                }
                let term_id = (doc_id * config.tokens_per_doc + token_idx) % config.num_terms;
                write!(&mut buffer, "t{term_id:0term_width$}").expect("write token");
            }
            writer
                .add_document(doc!(text_field => buffer.as_str()))
                .expect("add_document");
        }
        writer.commit().expect("commit");
    }

    (index, text_field)
}
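The token scheme in `build_index` can be checked in isolation: term ids cycle through `0..num_terms` as `(doc_id * tokens_per_doc + token_idx) % num_terms` and are zero-padded to a fixed width, so every term matches the `t.*` pattern. A std-only sketch of just the token generation:

```rust
use std::fmt::Write;

// Build the token string for one document, mirroring the corpus model above.
fn doc_tokens(doc_id: usize, tokens_per_doc: usize, num_terms: usize) -> String {
    let term_width = num_terms.to_string().len();
    let mut buffer = String::new();
    for token_idx in 0..tokens_per_doc {
        if token_idx > 0 {
            buffer.push(' ');
        }
        let term_id = (doc_id * tokens_per_doc + token_idx) % num_terms;
        write!(&mut buffer, "t{term_id:0term_width$}").unwrap();
    }
    buffer
}

fn main() {
    // num_terms = 10_000 => width 5 ("10000".len()), ids wrap modulo 10_000.
    assert_eq!(doc_tokens(0, 3, 10_000), "t00000 t00001 t00002");
    assert_eq!(doc_tokens(3334, 3, 10_000), "t00002 t00003 t00004");
    // num_terms = 10 => width 2.
    assert_eq!(doc_tokens(5, 1, 10), "t05");
    println!("{}", doc_tokens(1, 2, 10_000));
}
```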

@@ -1,421 +0,0 @@
// This benchmark compares different approaches for retrieving string values:
//
// 1. Fast Field Approach: retrieves string values via term_ords() and ord_to_str()
//
// 2. Doc Store Approach: retrieves string values via searcher.doc() and field extraction
//
// The benchmark includes various data distributions:
// - Dense Sequential: Sequential document IDs with dense data
// - Dense Random: Random document IDs with dense data
// - Sparse Sequential: Sequential document IDs with sparse data
// - Sparse Random: Random document IDs with sparse data
use std::ops::Bound;

use binggan::{black_box, BenchGroup, BenchRunner};
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::collector::{Count, DocSetCollector};
use tantivy::query::RangeQuery;
use tantivy::schema::document::TantivyDocument;
use tantivy::schema::{Schema, Value, FAST, STORED, STRING};
use tantivy::{doc, Index, ReloadPolicy, Searcher, Term};

#[derive(Clone)]
struct BenchIndex {
    #[allow(dead_code)]
    index: Index,
    searcher: Searcher,
}

fn build_shared_indices(num_docs: usize, distribution: &str) -> BenchIndex {
    // Schema with string fast field and stored field for doc access
    let mut schema_builder = Schema::builder();
    let f_str_fast = schema_builder.add_text_field("str_fast", STRING | STORED | FAST);
    let f_str_stored = schema_builder.add_text_field("str_stored", STRING | STORED);
    let schema = schema_builder.build();
    let index = Index::create_in_ram(schema.clone());

    // Populate index with stable RNG for reproducibility.
    let mut rng = StdRng::from_seed([7u8; 32]);

    {
        let mut writer = index.writer_with_num_threads(1, 4_000_000_000).unwrap();

        match distribution {
            "dense_random" => {
                for _doc_id in 0..num_docs {
                    let suffix = rng.gen_range(0u64..1000u64);
                    let str_val = format!("str_{:03}", suffix);

                    writer
                        .add_document(doc!(
                            f_str_fast=>str_val.clone(),
                            f_str_stored=>str_val,
                        ))
                        .unwrap();
                }
            }
            "dense_sequential" => {
                for doc_id in 0..num_docs {
                    let suffix = doc_id as u64 % 1000;
                    let str_val = format!("str_{:03}", suffix);

                    writer
                        .add_document(doc!(
                            f_str_fast=>str_val.clone(),
                            f_str_stored=>str_val,
                        ))
                        .unwrap();
                }
            }
            "sparse_random" => {
                for _doc_id in 0..num_docs {
                    let suffix = rng.gen_range(0u64..1000000u64);
                    let str_val = format!("str_{:07}", suffix);

                    writer
                        .add_document(doc!(
                            f_str_fast=>str_val.clone(),
                            f_str_stored=>str_val,
                        ))
                        .unwrap();
                }
            }
            "sparse_sequential" => {
                for doc_id in 0..num_docs {
                    let suffix = doc_id as u64;
                    let str_val = format!("str_{:07}", suffix);

                    writer
                        .add_document(doc!(
                            f_str_fast=>str_val.clone(),
                            f_str_stored=>str_val,
                        ))
                        .unwrap();
                }
            }
            _ => {
                panic!("Unsupported distribution type");
            }
        }
        writer.commit().unwrap();
    }

    // Prepare reader/searcher once.
    let reader = index
        .reader_builder()
        .reload_policy(ReloadPolicy::Manual)
        .try_into()
        .unwrap();
    let searcher = reader.searcher();

    BenchIndex { index, searcher }
}
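The `str_{:03}` / `str_{:07}` zero padding above matters: the bench issues term range queries over these strings, and fixed-width padding makes lexicographic order agree with numeric order, so a term range like `str_000..=str_009` selects exactly the low numeric suffixes. A std-only check of that property:

```rust
fn main() {
    // Zero-padded keys: lexicographic order matches numeric order.
    let padded: Vec<String> = [2u64, 10, 100]
        .iter()
        .map(|n| format!("str_{:03}", n))
        .collect();
    let mut sorted = padded.clone();
    sorted.sort();
    assert_eq!(padded, sorted); // already numerically ordered; sort() keeps it

    // Without padding, string order disagrees with numeric order.
    assert!("str_10" < "str_2");
    // With padding the comparison matches the numbers.
    assert!("str_002" < "str_010");
    println!("padding preserves numeric order");
}
```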

fn main() {
    // Prepare corpora with varying scenarios
    let scenarios = vec![
        (
            "dense_random_search_low_range".to_string(),
            1_000_000,
            "dense_random",
            0,
            9,
        ),
        (
            "dense_random_search_high_range".to_string(),
            1_000_000,
            "dense_random",
            990,
            999,
        ),
        (
            "dense_sequential_search_low_range".to_string(),
            1_000_000,
            "dense_sequential",
            0,
            9,
        ),
        (
            "dense_sequential_search_high_range".to_string(),
            1_000_000,
            "dense_sequential",
            990,
            999,
        ),
        (
            "sparse_random_search_low_range".to_string(),
            1_000_000,
            "sparse_random",
            0,
            9999,
        ),
        (
            "sparse_random_search_high_range".to_string(),
            1_000_000,
            "sparse_random",
            990_000,
            999_999,
        ),
        (
            "sparse_sequential_search_low_range".to_string(),
            1_000_000,
            "sparse_sequential",
            0,
            9999,
        ),
        (
            "sparse_sequential_search_high_range".to_string(),
            1_000_000,
            "sparse_sequential",
            990_000,
            999_999,
        ),
    ];

    let mut runner = BenchRunner::new();
    for (scenario_id, n, distribution, range_low, range_high) in scenarios {
        let bench_index = build_shared_indices(n, distribution);
        let mut group = runner.new_group();
        group.set_name(scenario_id);

        let field = bench_index.searcher.schema().get_field("str_fast").unwrap();

        let (lower_str, upper_str) =
            if distribution == "dense_sequential" || distribution == "dense_random" {
                (
                    format!("str_{:03}", range_low),
                    format!("str_{:03}", range_high),
                )
            } else {
                (
                    format!("str_{:07}", range_low),
                    format!("str_{:07}", range_high),
                )
            };

        let lower_term = Term::from_field_text(field, &lower_str);
        let upper_term = Term::from_field_text(field, &upper_str);

        let query = RangeQuery::new(Bound::Included(lower_term), Bound::Included(upper_term));

        run_benchmark_tasks(&mut group, &bench_index, query, range_low, range_high);

        group.run();
    }
}

/// Run all benchmark tasks for a given range query
fn run_benchmark_tasks(
    bench_group: &mut BenchGroup,
    bench_index: &BenchIndex,
    query: RangeQuery,
    range_low: u64,
    range_high: u64,
) {
    // Test count of matching documents
    add_bench_task_count(
        bench_group,
        bench_index,
        query.clone(),
        range_low,
        range_high,
    );

    // Test fetching all DocIds of matching documents
    add_bench_task_docset(
        bench_group,
        bench_index,
        query.clone(),
        range_low,
        range_high,
    );

    // Test fetching all string fast field values of matching documents
    add_bench_task_fetch_all_strings(
        bench_group,
        bench_index,
        query.clone(),
        range_low,
        range_high,
    );

    // Test fetching all string values of matching documents through doc() method
    add_bench_task_fetch_all_strings_from_doc(
        bench_group,
        bench_index,
        query,
        range_low,
        range_high,
    );
}

fn add_bench_task_count(
    bench_group: &mut BenchGroup,
    bench_index: &BenchIndex,
    query: RangeQuery,
    range_low: u64,
    range_high: u64,
) {
    let task_name = format!("string_search_count_[{}-{}]", range_low, range_high);

    let search_task = CountSearchTask {
        searcher: bench_index.searcher.clone(),
        query,
    };
    bench_group.register(task_name, move |_| black_box(search_task.run()));
}

fn add_bench_task_docset(
    bench_group: &mut BenchGroup,
    bench_index: &BenchIndex,
    query: RangeQuery,
    range_low: u64,
    range_high: u64,
) {
    let task_name = format!("string_fetch_all_docset_[{}-{}]", range_low, range_high);

    let search_task = DocSetSearchTask {
        searcher: bench_index.searcher.clone(),
        query,
    };
    bench_group.register(task_name, move |_| black_box(search_task.run()));
}

fn add_bench_task_fetch_all_strings(
    bench_group: &mut BenchGroup,
    bench_index: &BenchIndex,
    query: RangeQuery,
    range_low: u64,
    range_high: u64,
) {
    let task_name = format!(
        "string_fastfield_fetch_all_strings_[{}-{}]",
        range_low, range_high
    );

    let search_task = FetchAllStringsSearchTask {
        searcher: bench_index.searcher.clone(),
        query,
    };

    bench_group.register(task_name, move |_| {
        let result = black_box(search_task.run());
        result.len()
    });
}

fn add_bench_task_fetch_all_strings_from_doc(
    bench_group: &mut BenchGroup,
    bench_index: &BenchIndex,
    query: RangeQuery,
    range_low: u64,
    range_high: u64,
) {
    let task_name = format!(
        "string_doc_fetch_all_strings_[{}-{}]",
        range_low, range_high
    );

    let search_task = FetchAllStringsFromDocTask {
        searcher: bench_index.searcher.clone(),
        query,
    };

    bench_group.register(task_name, move |_| {
        let result = black_box(search_task.run());
        result.len()
    });
}

struct CountSearchTask {
    searcher: Searcher,
    query: RangeQuery,
}

impl CountSearchTask {
    #[inline(never)]
    pub fn run(&self) -> usize {
        self.searcher.search(&self.query, &Count).unwrap()
    }
}

struct DocSetSearchTask {
    searcher: Searcher,
    query: RangeQuery,
}

impl DocSetSearchTask {
    #[inline(never)]
    pub fn run(&self) -> usize {
        let result = self.searcher.search(&self.query, &DocSetCollector).unwrap();
        result.len()
    }
}

struct FetchAllStringsSearchTask {
    searcher: Searcher,
    query: RangeQuery,
}

impl FetchAllStringsSearchTask {
    #[inline(never)]
    pub fn run(&self) -> Vec<String> {
        let doc_addresses = self.searcher.search(&self.query, &DocSetCollector).unwrap();
        let mut docs = doc_addresses.into_iter().collect::<Vec<_>>();
        docs.sort();
        let mut strings = Vec::with_capacity(docs.len());

        for doc_address in docs {
            let segment_reader = &self.searcher.segment_readers()[doc_address.segment_ord as usize];
            let str_column_opt = segment_reader.fast_fields().str("str_fast");

            if let Ok(Some(str_column)) = str_column_opt {
                let doc_id = doc_address.doc_id;
                let term_ord = str_column.term_ords(doc_id).next().unwrap();
                let mut str_buffer = String::new();
                if str_column.ord_to_str(term_ord, &mut str_buffer).is_ok() {
                    strings.push(str_buffer);
                }
            }
        }

        strings
    }
}

struct FetchAllStringsFromDocTask {
    searcher: Searcher,
    query: RangeQuery,
}

impl FetchAllStringsFromDocTask {
    #[inline(never)]
    pub fn run(&self) -> Vec<String> {
        let doc_addresses = self.searcher.search(&self.query, &DocSetCollector).unwrap();
        let mut docs = doc_addresses.into_iter().collect::<Vec<_>>();
        docs.sort();
        let mut strings = Vec::with_capacity(docs.len());

        let str_stored_field = self
            .searcher
            .schema()
            .get_field("str_stored")
            .expect("str_stored field should exist");

        for doc_address in docs {
            // Get the document from the doc store (row store access)
            if let Ok(doc) = self.searcher.doc::<TantivyDocument>(doc_address) {
                // Extract string values from the stored field
                if let Some(field_value) = doc.get_first(str_stored_field) {
                    if let Some(text) = field_value.as_value().as_str() {
                        strings.push(text.to_string());
                    }
                }
            }
        }

        strings
    }
}
@@ -18,5 +18,5 @@ homepage = "https://github.com/quickwit-oss/tantivy"
bitpacking = { version = "0.9.2", default-features = false, features = ["bitpacker1x"] }

[dev-dependencies]
rand = "0.9"
rand = "0.8"
proptest = "1"

@@ -4,8 +4,8 @@ extern crate test;

#[cfg(test)]
mod tests {
    use rand::rng;
    use rand::seq::IteratorRandom;
    use rand::thread_rng;
    use tantivy_bitpacker::{BitPacker, BitUnpacker, BlockedBitpacker};
    use test::Bencher;

@@ -27,7 +27,7 @@ mod tests {
        let num_els = 1_000_000u32;
        let bit_unpacker = BitUnpacker::new(bit_width);
        let data = create_bitpacked_data(bit_width, num_els);
        let idxs: Vec<u32> = (0..num_els).choose_multiple(&mut rng(), 100_000);
        let idxs: Vec<u32> = (0..num_els).choose_multiple(&mut thread_rng(), 100_000);
        b.iter(|| {
            let mut out = 0u64;
            for &idx in &idxs {

@@ -19,7 +19,7 @@ fn u32_to_i32(val: u32) -> i32 {
#[inline]
unsafe fn u32_to_i32_avx2(vals_u32x8s: DataType) -> DataType {
    const HIGHEST_BIT_MASK: DataType = from_u32x8([HIGHEST_BIT; NUM_LANES]);
    unsafe { op_xor(vals_u32x8s, HIGHEST_BIT_MASK) }
    op_xor(vals_u32x8s, HIGHEST_BIT_MASK)
}

pub fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {

@@ -66,19 +66,17 @@ unsafe fn filter_vec_avx2_aux(
    ]);
    const SHIFT: __m256i = from_u32x8([NUM_LANES as u32; NUM_LANES]);
    for _ in 0..num_words {
        unsafe {
            let word = load_unaligned(input);
            let word = u32_to_i32_avx2(word);
            let keeper_bitset = compute_filter_bitset(word, range_simd.clone());
            let added_len = keeper_bitset.count_ones();
            let filtered_doc_ids = compact(ids, keeper_bitset);
            store_unaligned(output_tail as *mut __m256i, filtered_doc_ids);
            output_tail = output_tail.offset(added_len as isize);
            ids = op_add(ids, SHIFT);
            input = input.offset(1);
        }
        let word = load_unaligned(input);
        let word = u32_to_i32_avx2(word);
        let keeper_bitset = compute_filter_bitset(word, range_simd.clone());
        let added_len = keeper_bitset.count_ones();
        let filtered_doc_ids = compact(ids, keeper_bitset);
        store_unaligned(output_tail as *mut __m256i, filtered_doc_ids);
        output_tail = output_tail.offset(added_len as isize);
        ids = op_add(ids, SHIFT);
        input = input.offset(1);
    }
    unsafe { output_tail.offset_from(output) as usize }
    output_tail.offset_from(output) as usize
}

#[inline]
@@ -94,7 +92,8 @@ unsafe fn compute_filter_bitset(val: __m256i, range: std::ops::RangeInclusive<__
    let too_low = op_greater(*range.start(), val);
    let too_high = op_greater(val, *range.end());
    let inside = op_or(too_low, too_high);
    255 - std::arch::x86_64::_mm256_movemask_ps(_mm256_castsi256_ps(inside)) as u8
    255 - std::arch::x86_64::_mm256_movemask_ps(std::mem::transmute::<DataType, __m256>(inside))
        as u8
}
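The AVX2 path above computes a per-word keep mask and compacts the surviving lane ids to the output tail. A scalar sketch of the same filter-and-compact contract (a hypothetical std-only stand-in, not the SIMD implementation):

```rust
use std::ops::RangeInclusive;

// Scalar equivalent of the SIMD filter: keep the ids of values inside `range`,
// writing survivors contiguously and returning the new length.
fn filter_in_place(vals: &[u32], range: RangeInclusive<u32>, output: &mut Vec<u32>) -> usize {
    output.clear();
    for (id, &val) in vals.iter().enumerate() {
        if range.contains(&val) {
            output.push(id as u32); // compacted ids, like the SIMD `compact` step
        }
    }
    output.len()
}

fn main() {
    let vals = [5u32, 42, 7, 100, 13];
    let mut out = Vec::new();
    let len = filter_in_place(&vals, 7..=42, &mut out);
    assert_eq!(len, 3);
    assert_eq!(out, vec![1, 2, 4]); // ids of 42, 7 and 13
    println!("{:?}", out);
}
```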

union U8x32 {

@@ -22,7 +22,7 @@ downcast-rs = "2.0.1"
[dev-dependencies]
proptest = "1"
more-asserts = "0.3.1"
rand = "0.9"
rand = "0.8"
binggan = "0.14.0"

[[bench]]

@@ -9,7 +9,7 @@ use tantivy_columnar::column_values::{CodecType, serialize_and_load_u64_based_co
fn get_data() -> Vec<u64> {
    let mut rng = StdRng::seed_from_u64(2u64);
    let mut data: Vec<_> = (100..55_000_u64)
        .map(|num| num + rng.random::<u8>() as u64)
        .map(|num| num + rng.r#gen::<u8>() as u64)
        .collect();
    data.push(99_000);
    data.insert(1000, 2000);

@@ -6,7 +6,7 @@ use tantivy_columnar::column_values::{CodecType, serialize_u64_based_column_valu
fn get_data() -> Vec<u64> {
    let mut rng = StdRng::seed_from_u64(2u64);
    let mut data: Vec<_> = (100..55_000_u64)
        .map(|num| num + rng.random::<u8>() as u64)
        .map(|num| num + rng.r#gen::<u8>() as u64)
        .collect();
    data.push(99_000);
    data.insert(1000, 2000);

@@ -8,7 +8,7 @@ const TOTAL_NUM_VALUES: u32 = 1_000_000;
fn gen_optional_index(fill_ratio: f64) -> OptionalIndex {
    let mut rng: StdRng = StdRng::from_seed([1u8; 32]);
    let vals: Vec<u32> = (0..TOTAL_NUM_VALUES)
        .map(|_| rng.random_bool(fill_ratio))
        .map(|_| rng.gen_bool(fill_ratio))
        .enumerate()
        .filter(|(_pos, val)| *val)
        .map(|(pos, _)| pos as u32)
@@ -25,7 +25,7 @@ fn random_range_iterator(
    let mut rng: StdRng = StdRng::from_seed([1u8; 32]);
    let mut current = start;
    std::iter::from_fn(move || {
        current += rng.random_range(avg_step_size - avg_deviation..=avg_step_size + avg_deviation);
        current += rng.gen_range(avg_step_size - avg_deviation..=avg_step_size + avg_deviation);
        if current >= end { None } else { Some(current) }
    })
}

@@ -39,7 +39,7 @@ fn get_data_50percent_item() -> Vec<u128> {

    let mut data = vec![];
    for _ in 0..300_000 {
        let val = rng.random_range(1..=100);
        let val = rng.gen_range(1..=100);
        data.push(val);
    }
    data.push(SINGLE_ITEM);

@@ -34,7 +34,7 @@ fn get_data_50percent_item() -> Vec<u128> {

    let mut data = vec![];
    for _ in 0..300_000 {
        let val = rng.random_range(1..=100);
        let val = rng.gen_range(1..=100);
        data.push(val);
    }
    data.push(SINGLE_ITEM);

@@ -29,20 +29,12 @@ impl<T: PartialOrd + Copy + std::fmt::Debug + Send + Sync + 'static + Default>
        }
    }
    #[inline]
    pub fn fetch_block_with_missing(
        &mut self,
        docs: &[u32],
        accessor: &Column<T>,
        missing: Option<T>,
    ) {
    pub fn fetch_block_with_missing(&mut self, docs: &[u32], accessor: &Column<T>, missing: T) {
        self.fetch_block(docs, accessor);
        // no missing values
        if accessor.index.get_cardinality().is_full() {
            return;
        }
        let Some(missing) = missing else {
            return;
        };

        // We can compare docid_cache length with docs to find missing docs
        // For multi value columns we can't rely on the length and always need to scan

@@ -85,8 +85,8 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
    }

    #[inline]
    pub fn first(&self, doc_id: DocId) -> Option<T> {
        self.values_for_doc(doc_id).next()
    pub fn first(&self, row_id: RowId) -> Option<T> {
        self.values_for_doc(row_id).next()
    }

    /// Load the first value for each docid in the provided slice.

@@ -31,7 +31,7 @@ pub use u64_based::{
    serialize_and_load_u64_based_column_values, serialize_u64_based_column_values,
};
pub use u128_based::{
    CompactHit, CompactSpaceU64Accessor, open_u128_as_compact_u64, open_u128_mapped,
    CompactSpaceU64Accessor, open_u128_as_compact_u64, open_u128_mapped,
    serialize_column_values_u128,
};
pub use vec_column::VecColumn;

@@ -292,19 +292,6 @@ impl BinarySerializable for IPCodecParams {
    }
}

/// Represents the result of looking up a u128 value in the compact space.
///
/// If a value is outside the compact space, the next compact value is returned.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompactHit {
    /// The value exists in the compact space
    Exact(u32),
    /// The value does not exist in the compact space, but the next higher value does
    Next(u32),
    /// The value is greater than the maximum compact value
    AfterLast,
}

/// Exposes the compact space compressed values as u64.
///
/// This allows faster access to the values, as u64 is faster to work with than u128.
@@ -322,11 +309,6 @@ impl CompactSpaceU64Accessor {
    pub fn compact_to_u128(&self, compact: u32) -> u128 {
        self.0.compact_to_u128(compact)
    }

    /// Finds the next compact space value for a given u128 value.
    pub fn u128_to_next_compact(&self, value: u128) -> CompactHit {
        self.0.u128_to_next_compact(value)
    }
}

impl ColumnValues<u64> for CompactSpaceU64Accessor {
@@ -448,26 +430,6 @@ impl CompactSpaceDecompressor {
        Ok(decompressor)
    }

    /// Finds the next compact space value for a given u128 value
    pub fn u128_to_next_compact(&self, value: u128) -> CompactHit {
        // Try to convert to compact space
        match self.u128_to_compact(value) {
            // Value is in compact space, return its compact representation
            Ok(compact) => CompactHit::Exact(compact),
            // Value is not in compact space
            Err(pos) => {
                if pos >= self.params.compact_space.ranges_mapping.len() {
                    // Value is beyond all ranges, no next value exists
                    CompactHit::AfterLast
                } else {
                    // Get the next range and return its start compact value
                    let next_range = &self.params.compact_space.ranges_mapping[pos];
                    CompactHit::Next(next_range.compact_start)
                }
            }
        }
    }
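The `u128_to_next_compact` lookup shown in this hunk is essentially a successor search over sorted ranges. The same Exact/Next/AfterLast contract can be sketched std-only with `binary_search` over a sorted value list (a hypothetical stand-in for the compact-space mapping, where the index plays the role of the compact value):

```rust
#[derive(Debug, PartialEq)]
enum CompactHit {
    Exact(u32),   // value exists in the compact space
    Next(u32),    // value is absent; this is the next higher compact value
    AfterLast,    // value is beyond the largest compact value
}

// `sorted` stands in for the compact space: index == compact value.
fn to_next_compact(sorted: &[u128], value: u128) -> CompactHit {
    match sorted.binary_search(&value) {
        Ok(idx) => CompactHit::Exact(idx as u32),
        Err(idx) if idx >= sorted.len() => CompactHit::AfterLast,
        Err(idx) => CompactHit::Next(idx as u32),
    }
}

fn main() {
    let space = [100u128, 200, 1_000_000_000];
    assert_eq!(to_next_compact(&space, 100), CompactHit::Exact(0));
    assert_eq!(to_next_compact(&space, 250), CompactHit::Next(2));
    assert_eq!(to_next_compact(&space, 50), CompactHit::Next(0));
    assert_eq!(to_next_compact(&space, 10_000_000_000), CompactHit::AfterLast);
    println!("compact hits behave as described");
}
```

This mirrors the cases exercised by the removed `test_u128_to_next_compact` test: an exact hit, a value between ranges, a value before the first range, and a value past the last one.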
|
||||
|
||||
/// Converting to compact space for the decompressor is more complex, since we may get values
|
||||
/// which are outside the compact space. e.g. if we map
|
||||
/// 1000 => 5
|
||||
@@ -861,41 +823,6 @@ mod tests {
|
||||
let _data = test_aux_vals(vals);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_u128_to_next_compact() {
|
||||
let vals = &[100u128, 200u128, 1_000_000_000u128, 1_000_000_100u128];
|
||||
let mut data = test_aux_vals(vals);
|
||||
|
||||
let _header = U128Header::deserialize(&mut data);
|
||||
let decomp = CompactSpaceDecompressor::open(data).unwrap();
|
||||
|
||||
// Test value that's already in a range
|
||||
let compact_100 = decomp.u128_to_compact(100).unwrap();
|
||||
assert_eq!(
|
||||
decomp.u128_to_next_compact(100),
|
||||
CompactHit::Exact(compact_100)
|
||||
);
|
||||
|
||||
// Test value between two ranges
|
||||
let compact_million = decomp.u128_to_compact(1_000_000_000).unwrap();
|
||||
assert_eq!(
|
||||
decomp.u128_to_next_compact(250),
|
||||
CompactHit::Next(compact_million)
|
||||
);
|
||||
|
||||
// Test value before the first range
|
||||
assert_eq!(
|
||||
decomp.u128_to_next_compact(50),
|
||||
CompactHit::Next(compact_100)
|
||||
);
|
||||
|
||||
// Test value after the last range
|
||||
assert_eq!(
|
||||
decomp.u128_to_next_compact(10_000_000_000),
|
||||
CompactHit::AfterLast
|
||||
);
|
||||
}
|
||||
|
||||
use proptest::prelude::*;
|
||||
|
||||
fn num_strategy() -> impl Strategy<Value = u128> {
|
||||
|
||||
@@ -7,7 +7,7 @@ mod compact_space;
|
||||
|
||||
use common::{BinarySerializable, OwnedBytes, VInt};
|
||||
pub use compact_space::{
|
||||
CompactHit, CompactSpaceCompressor, CompactSpaceDecompressor, CompactSpaceU64Accessor,
|
||||
CompactSpaceCompressor, CompactSpaceDecompressor, CompactSpaceU64Accessor,
|
||||
};
|
||||
|
||||
use crate::column_values::monotonic_map_column;
|
||||
|
||||
@@ -41,6 +41,12 @@ fn transform_range_before_linear_transformation(
|
||||
if range.is_empty() {
|
||||
return None;
|
||||
}
|
||||
if stats.min_value > *range.end() {
|
||||
return None;
|
||||
}
|
||||
if stats.max_value < *range.start() {
|
||||
return None;
|
||||
}
|
||||
let shifted_range =
|
||||
range.start().saturating_sub(stats.min_value)..=range.end().saturating_sub(stats.min_value);
|
||||
let start_before_gcd_multiplication: u64 = div_ceil(*shifted_range.start(), stats.gcd);
|
||||
|
||||
@@ -268,7 +268,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn linear_interpol_fast_field_rand() {
|
||||
let mut rng = rand::rng();
|
||||
let mut rng = rand::thread_rng();
|
||||
for _ in 0..50 {
|
||||
let mut data = (0..10_000).map(|_| rng.next_u64()).collect::<Vec<_>>();
|
||||
create_and_validate::<LinearCodec>(&data, "random");
|
||||
|
||||
@@ -122,7 +122,7 @@ pub(crate) fn create_and_validate<TColumnCodec: ColumnCodec>(
|
||||
assert_eq!(vals, buffer);
|
||||
|
||||
if !vals.is_empty() {
|
||||
let test_rand_idx = rand::rng().random_range(0..=vals.len() - 1);
|
||||
let test_rand_idx = rand::thread_rng().gen_range(0..=vals.len() - 1);
|
||||
let expected_positions: Vec<u32> = vals
|
||||
.iter()
|
||||
.enumerate()
|
||||
|
||||
@@ -3,8 +3,7 @@ use std::sync::Arc;
 use std::{fmt, io};

 use common::file_slice::FileSlice;
-use common::{ByteCount, DateTime, OwnedBytes};
-use serde::{Deserialize, Serialize};
+use common::{ByteCount, DateTime, HasLen, OwnedBytes};

 use crate::column::{BytesColumn, Column, StrColumn};
 use crate::column_values::{StrictlyMonotonicFn, monotonic_map_column};
@@ -318,89 +317,10 @@ impl DynamicColumnHandle {
     }

     pub fn num_bytes(&self) -> ByteCount {
-        self.file_slice.num_bytes()
-    }
-
-    /// Legacy helper returning the column space usage.
-    pub fn column_and_dictionary_num_bytes(&self) -> io::Result<ColumnSpaceUsage> {
-        self.space_usage()
-    }
-
-    /// Return the space usage of the column, optionally broken down by dictionary and column
-    /// values.
-    ///
-    /// For dictionary encoded columns (strings and bytes), this splits the total footprint into
-    /// the dictionary and the remaining column data (including index and values).
-    /// For all other column types, the dictionary size is `None` and the column size
-    /// equals the total bytes.
-    pub fn space_usage(&self) -> io::Result<ColumnSpaceUsage> {
-        let total_num_bytes = self.num_bytes();
-        let dynamic_column = self.open()?;
-        let dictionary_num_bytes = match &dynamic_column {
-            DynamicColumn::Bytes(bytes_column) => bytes_column.dictionary().num_bytes(),
-            DynamicColumn::Str(str_column) => str_column.dictionary().num_bytes(),
-            _ => {
-                return Ok(ColumnSpaceUsage::new(self.num_bytes(), None));
-            }
-        };
-        assert!(dictionary_num_bytes <= total_num_bytes);
-        let column_num_bytes =
-            ByteCount::from(total_num_bytes.get_bytes() - dictionary_num_bytes.get_bytes());
-        Ok(ColumnSpaceUsage::new(
-            column_num_bytes,
-            Some(dictionary_num_bytes),
-        ))
+        self.file_slice.len().into()
     }

     pub fn column_type(&self) -> ColumnType {
         self.column_type
     }
 }

-/// Represents space usage of a column.
-///
-/// `column_num_bytes` tracks the column payload (index, values and footer).
-/// For dictionary encoded columns, `dictionary_num_bytes` captures the dictionary footprint.
-/// [`ColumnSpaceUsage::total_num_bytes`] returns the sum of both parts.
-#[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct ColumnSpaceUsage {
-    column_num_bytes: ByteCount,
-    dictionary_num_bytes: Option<ByteCount>,
-}
-
-impl ColumnSpaceUsage {
-    pub(crate) fn new(
-        column_num_bytes: ByteCount,
-        dictionary_num_bytes: Option<ByteCount>,
-    ) -> Self {
-        ColumnSpaceUsage {
-            column_num_bytes,
-            dictionary_num_bytes,
-        }
-    }
-
-    pub fn column_num_bytes(&self) -> ByteCount {
-        self.column_num_bytes
-    }
-
-    pub fn dictionary_num_bytes(&self) -> Option<ByteCount> {
-        self.dictionary_num_bytes
-    }
-
-    pub fn total_num_bytes(&self) -> ByteCount {
-        self.column_num_bytes + self.dictionary_num_bytes.unwrap_or_default()
-    }
-
-    /// Merge two space usage values by summing their components.
-    pub fn merge(&self, other: &ColumnSpaceUsage) -> ColumnSpaceUsage {
-        let dictionary_num_bytes = match (self.dictionary_num_bytes, other.dictionary_num_bytes) {
-            (Some(lhs), Some(rhs)) => Some(lhs + rhs),
-            (Some(val), None) | (None, Some(val)) => Some(val),
-            (None, None) => None,
-        };
-        ColumnSpaceUsage {
-            column_num_bytes: self.column_num_bytes + other.column_num_bytes,
-            dictionary_num_bytes,
-        }
-    }
-}
@@ -48,7 +48,7 @@ pub use columnar::{
 use sstable::VoidSSTable;
 pub use value::{NumericalType, NumericalValue};

-pub use self::dynamic_column::{ColumnSpaceUsage, DynamicColumn, DynamicColumnHandle};
+pub use self::dynamic_column::{DynamicColumn, DynamicColumnHandle};

 pub type RowId = u32;
 pub type DocId = u32;
@@ -59,7 +59,7 @@ pub struct RowAddr {
     pub row_id: RowId,
 }

-pub use sstable::{Dictionary, TermOrdHit};
+pub use sstable::Dictionary;
 pub type Streamer<'a> = sstable::Streamer<'a, VoidSSTable>;

 pub use common::DateTime;
@@ -60,7 +60,7 @@ fn test_dataframe_writer_bool() {
     let DynamicColumn::Bool(bool_col) = dyn_bool_col else {
         panic!();
     };
-    let vals: Vec<Option<bool>> = (0..5).map(|doc_id| bool_col.first(doc_id)).collect();
+    let vals: Vec<Option<bool>> = (0..5).map(|row_id| bool_col.first(row_id)).collect();
     assert_eq!(&vals, &[None, Some(false), None, Some(true), None,]);
 }

@@ -108,7 +108,7 @@ fn test_dataframe_writer_ip_addr() {
     let DynamicColumn::IpAddr(ip_col) = dyn_bool_col else {
         panic!();
     };
-    let vals: Vec<Option<Ipv6Addr>> = (0..5).map(|doc_id| ip_col.first(doc_id)).collect();
+    let vals: Vec<Option<Ipv6Addr>> = (0..5).map(|row_id| ip_col.first(row_id)).collect();
     assert_eq!(
         &vals,
         &[
@@ -169,7 +169,7 @@ fn test_dictionary_encoded_str() {
     let DynamicColumn::Str(str_col) = col_handles[0].open().unwrap() else {
         panic!();
     };
-    let index: Vec<Option<u64>> = (0..5).map(|doc_id| str_col.ords().first(doc_id)).collect();
+    let index: Vec<Option<u64>> = (0..5).map(|row_id| str_col.ords().first(row_id)).collect();
     assert_eq!(index, &[None, Some(0), None, Some(2), Some(1)]);
     assert_eq!(str_col.num_rows(), 5);
     let mut term_buffer = String::new();
@@ -204,7 +204,7 @@ fn test_dictionary_encoded_bytes() {
         panic!();
     };
     let index: Vec<Option<u64>> = (0..5)
-        .map(|doc_id| bytes_col.ords().first(doc_id))
+        .map(|row_id| bytes_col.ords().first(row_id))
         .collect();
     assert_eq!(index, &[None, Some(0), None, Some(2), Some(1)]);
     assert_eq!(bytes_col.num_rows(), 5);
@@ -21,5 +21,5 @@ serde = { version = "1.0.136", features = ["derive"] }
 [dev-dependencies]
 binggan = "0.14.0"
 proptest = "1.0.0"
-rand = "0.9"
+rand = "0.8.4"

@@ -1,6 +1,6 @@
 use binggan::{BenchRunner, black_box};
-use rand::rng;
 use rand::seq::IteratorRandom;
+use rand::thread_rng;
 use tantivy_common::{BitSet, TinySet, serialize_vint_u32};

 fn bench_vint() {
@@ -17,7 +17,7 @@ fn bench_vint() {
         black_box(out);
     });

-    let vals: Vec<u32> = (0..20_000).choose_multiple(&mut rng(), 100_000);
+    let vals: Vec<u32> = (0..20_000).choose_multiple(&mut thread_rng(), 100_000);
     runner.bench_function("bench_vint_rand", move |_| {
         let mut out = 0u64;
         for val in vals.iter().cloned() {
@@ -181,14 +181,6 @@ pub struct BitSet {
     len: u64,
     max_value: u32,
 }
-impl std::fmt::Debug for BitSet {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        f.debug_struct("BitSet")
-            .field("len", &self.len)
-            .field("max_value", &self.max_value)
-            .finish()
-    }
-}

 fn num_buckets(max_val: u32) -> u32 {
     max_val.div_ceil(64u32)
@@ -416,7 +408,7 @@ mod tests {
     use std::collections::HashSet;

     use ownedbytes::OwnedBytes;
-    use rand::distr::Bernoulli;
+    use rand::distributions::Bernoulli;
     use rand::rngs::StdRng;
     use rand::{Rng, SeedableRng};
@@ -60,7 +60,7 @@ At indexing, tantivy will try to interpret number and strings as different type
 priority order.

 Numbers will be interpreted as u64, i64 and f64 in that order.
-Strings will be interpreted as rfc3339 dates or simple strings.
+Strings will be interpreted as rfc3999 dates or simple strings.

 The first working type is picked and is the only term that is emitted for indexing.
 Note this interpretation happens on a per-document basis, and there is no effort to try to sniff
@@ -81,7 +81,7 @@ Will be interpreted as
 (my_path.my_segment, String, 233) or (my_path.my_segment, u64, 233)
 ```

-Likewise, we need to emit two tokens if the query contains an rfc3339 date.
+Likewise, we need to emit two tokens if the query contains an rfc3999 date.
 Indeed the date could have been actually a single token inside the text of a document at ingestion time. Generally speaking, we will always at least emit a string token in query parsing, and sometimes more.

 If one more json field is defined, things get even more complicated.
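The "u64, i64 and f64 in that order" rule described above can be sketched in a few lines. This is an illustrative standalone function, not tantivy's actual implementation (the `interpret` name and the string fallback for non-RFC-3339 text are assumptions for the sketch):

```rust
// Try each numeric type in the documented priority order; anything that
// fails all numeric parses is kept as a plain string token.
fn interpret(token: &str) -> String {
    if let Ok(v) = token.parse::<u64>() {
        return format!("u64:{v}");
    }
    if let Ok(v) = token.parse::<i64>() {
        return format!("i64:{v}");
    }
    if let Ok(v) = token.parse::<f64>() {
        return format!("f64:{v}");
    }
    format!("str:{token}")
}

fn main() {
    // "233" parses as u64 first, so only the u64 interpretation is emitted.
    assert_eq!(interpret("233"), "u64:233");
    assert_eq!(interpret("-1"), "i64:-1");
    assert_eq!(interpret("2.5"), "f64:2.5");
    assert_eq!(interpret("hello"), "str:hello");
    println!("ok");
}
```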
66  examples/geo_json.rs  (new file)
@@ -0,0 +1,66 @@
use geo_types::Point;
use tantivy::collector::TopDocs;
use tantivy::query::SpatialQuery;
use tantivy::schema::{Schema, Value, SPATIAL, STORED, TEXT};
use tantivy::spatial::point::GeoPoint;
use tantivy::{Index, IndexWriter, TantivyDocument};
fn main() -> tantivy::Result<()> {
    let mut schema_builder = Schema::builder();
    schema_builder.add_json_field("properties", STORED | TEXT);
    schema_builder.add_spatial_field("geometry", STORED | SPATIAL);
    let schema = schema_builder.build();
    let index = Index::create_in_ram(schema.clone());
    let mut index_writer: IndexWriter = index.writer(50_000_000)?;
    let doc = TantivyDocument::parse_json(
        &schema,
        r#"{
            "type":"Feature",
            "geometry":{
                "type":"Polygon",
                "coordinates":[[[-99.483911,45.577697],[-99.483869,45.571457],[-99.481739,45.571461],[-99.474881,45.571584],[-99.473167,45.571615],[-99.463394,45.57168],[-99.463391,45.57883],[-99.463368,45.586076],[-99.48177,45.585926],[-99.48384,45.585953],[-99.483885,45.57873],[-99.483911,45.577697]]]
            },
            "properties":{
                "admin_level":"8",
                "border_type":"city",
                "boundary":"administrative",
                "gnis:feature_id":"1267426",
                "name":"Hosmer",
                "place":"city",
                "source":"TIGER/Line® 2008 Place Shapefiles (http://www.census.gov/geo/www/tiger/)",
                "wikidata":"Q2442118",
                "wikipedia":"en:Hosmer, South Dakota"
            }
        }"#,
    )?;
    index_writer.add_document(doc)?;
    index_writer.commit()?;

    let reader = index.reader()?;
    let searcher = reader.searcher();
    let field = schema.get_field("geometry").unwrap();
    let query = SpatialQuery::new(
        field,
        [
            GeoPoint {
                lon: -99.49,
                lat: 45.56,
            },
            GeoPoint {
                lon: -99.45,
                lat: 45.59,
            },
        ],
        tantivy::query::SpatialQueryType::Intersects,
    );
    let hits = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
    for (_score, doc_address) in &hits {
        let retrieved_doc: TantivyDocument = searcher.doc(*doc_address)?;
        if let Some(field_value) = retrieved_doc.get_first(field) {
            if let Some(geometry_box) = field_value.as_value().into_geometry() {
                println!("Retrieved geometry: {:?}", geometry_box);
            }
        }
    }
    assert_eq!(hits.len(), 1);
    Ok(())
}
@@ -560,7 +560,7 @@ fn range_infallible(inp: &str) -> JResult<&str, UserInputLeaf> {
     (
         (
             value((), tag(">=")),
-            map(word_infallible(")", false), |(bound, err)| {
+            map(word_infallible("", false), |(bound, err)| {
                 (
                     (
                         bound
@@ -574,7 +574,7 @@ fn range_infallible(inp: &str) -> JResult<&str, UserInputLeaf> {
         ),
         (
             value((), tag("<=")),
-            map(word_infallible(")", false), |(bound, err)| {
+            map(word_infallible("", false), |(bound, err)| {
                 (
                     (
                         UserInputBound::Unbounded,
@@ -588,7 +588,7 @@ fn range_infallible(inp: &str) -> JResult<&str, UserInputLeaf> {
         ),
         (
             value((), tag(">")),
-            map(word_infallible(")", false), |(bound, err)| {
+            map(word_infallible("", false), |(bound, err)| {
                 (
                     (
                         bound
@@ -602,7 +602,7 @@ fn range_infallible(inp: &str) -> JResult<&str, UserInputLeaf> {
         ),
         (
             value((), tag("<")),
-            map(word_infallible(")", false), |(bound, err)| {
+            map(word_infallible("", false), |(bound, err)| {
                 (
                     (
                         UserInputBound::Unbounded,
@@ -704,11 +704,7 @@ fn regex(inp: &str) -> IResult<&str, UserInputLeaf> {
             many1(alt((preceded(char('\\'), char('/')), none_of("/")))),
             char('/'),
         ),
-        peek(alt((
-            value((), multispace1),
-            value((), char(')')),
-            value((), eof),
-        ))),
+        peek(alt((multispace1, eof))),
     ),
     |elements| UserInputLeaf::Regex {
         field: None,
@@ -725,12 +721,8 @@ fn regex_infallible(inp: &str) -> JResult<&str, UserInputLeaf> {
             opt_i_err(char('/'), "missing delimiter /"),
         ),
         opt_i_err(
-            peek(alt((
-                value((), multispace1),
-                value((), char(')')),
-                value((), eof),
-            ))),
-            "expected whitespace, closing parenthesis, or end of input",
+            peek(alt((multispace1, eof))),
+            "expected whitespace or end of input",
         ),
     )(inp)
     {
@@ -1331,14 +1323,6 @@ mod test {
     test_parse_query_to_ast_helper("<a", "{\"*\" TO \"a\"}");
     test_parse_query_to_ast_helper("<=a", "{\"*\" TO \"a\"]");
     test_parse_query_to_ast_helper("<=bsd", "{\"*\" TO \"bsd\"]");
-
-    test_parse_query_to_ast_helper("(<=42)", "{\"*\" TO \"42\"]");
-    test_parse_query_to_ast_helper("(<=42 )", "{\"*\" TO \"42\"]");
-    test_parse_query_to_ast_helper("(age:>5)", "\"age\":{\"5\" TO \"*\"}");
-    test_parse_query_to_ast_helper(
-        "(title:bar AND age:>12)",
-        "(+\"title\":bar +\"age\":{\"12\" TO \"*\"})",
-    );
 }

 #[test]
@@ -1715,10 +1699,6 @@ mod test {
     test_parse_query_to_ast_helper("foo:(A OR B)", "(?\"foo\":A ?\"foo\":B)");
     test_parse_query_to_ast_helper("foo:(A* OR B*)", "(?\"foo\":A* ?\"foo\":B*)");
     test_parse_query_to_ast_helper("foo:(*A OR *B)", "(?\"foo\":*A ?\"foo\":*B)");
-
-    // Regexes between parentheses
-    test_parse_query_to_ast_helper("foo:(/A.*/)", "\"foo\":/A.*/");
-    test_parse_query_to_ast_helper("foo:(/A.*/ OR /B.*/)", "(?\"foo\":/A.*/ ?\"foo\":/B.*/)");
 }

 #[test]

@@ -66,7 +66,6 @@ impl UserInputLeaf {
         }
         UserInputLeaf::Range { field, .. } if field.is_none() => *field = Some(default_field),
         UserInputLeaf::Set { field, .. } if field.is_none() => *field = Some(default_field),
-        UserInputLeaf::Regex { field, .. } if field.is_none() => *field = Some(default_field),
         _ => (), // field was already set, do nothing
     }
 }
@@ -1,27 +0,0 @@
[package]
name = "sketches-ddsketch"
version = "0.3.0"
authors = ["Mike Heffner <mikeh@fesnel.com>"]
edition = "2018"
license = "Apache-2.0"
readme = "README.md"
repository = "https://github.com/mheffner/rust-sketches-ddsketch"
homepage = "https://github.com/mheffner/rust-sketches-ddsketch"
description = """
A direct port of the Golang DDSketch implementation.
"""
exclude = [".gitignore"]

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
serde = { package = "serde", version = "1.0", optional = true, features = ["derive", "serde_derive"] }

[dev-dependencies]
approx = "0.5.1"
rand = "0.8.5"
rand_distr = "0.4.3"

[features]
use_serde = ["serde", "serde/derive"]
@@ -1,201 +0,0 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [2019] [Mike Heffner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
@@ -1,11 +0,0 @@
clean:
	cargo clean

test:
	cargo test

test_logs:
	cargo test -- --nocapture

test_performance:
	cargo test --release --jobs 1 test_performance -- --ignored --nocapture
@@ -1,37 +0,0 @@
# sketches-ddsketch

This is a direct port of the [Golang](https://github.com/DataDog/sketches-go)
[DDSketch](https://arxiv.org/pdf/1908.10693.pdf) quantile sketch implementation
to Rust. DDSketch is a fully-mergeable quantile sketch with relative-error
guarantees and is extremely fast.

# DDSketch

* Sketch size automatically grows as needed, starting with 128 bins.
* Extremely fast sample insertion and sketch merges.

## Usage

```rust
use sketches_ddsketch::{Config, DDSketch};

let config = Config::defaults();
let mut sketch = DDSketch::new(c);

sketch.add(1.0);
sketch.add(1.0);
sketch.add(1.0);

// Get p=50%
let quantile = sketch.quantile(0.5).unwrap();
assert_eq!(quantile, Some(1.0));
```

## Performance

No performance tuning has been done with this implementation of the port, so we
would expect similar profiles to the original implementation.

Out of the box we see can achieve over 70M sample inserts/sec and 350K sketch
merges/sec. All tests run on a single core Intel i7 processor with 4.2Ghz max
clock.
@@ -1,98 +0,0 @@
|
||||
#[cfg(feature = "use_serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
const DEFAULT_MAX_BINS: u32 = 2048;
|
||||
const DEFAULT_ALPHA: f64 = 0.01;
|
||||
const DEFAULT_MIN_VALUE: f64 = 1.0e-9;
|
||||
|
||||
/// The configuration struct for constructing a `DDSketch`
|
||||
#[derive(Copy, Clone, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "use_serde", derive(Serialize, Deserialize))]
|
||||
pub struct Config {
|
||||
pub max_num_bins: u32,
|
||||
pub gamma: f64,
|
||||
pub(crate) gamma_ln: f64,
|
||||
pub(crate) min_value: f64,
|
||||
pub offset: i32,
|
||||
}
|
||||
|
||||
fn log_gamma(value: f64, gamma_ln: f64) -> f64 {
|
||||
value.ln() / gamma_ln
|
||||
}
|
||||
|
||||
impl Config {
|
||||
/// Construct a new `Config` struct with specific parameters. If you are unsure of how to
|
||||
/// configure this, the `defaults` method constructs a `Config` with built-in defaults.
|
||||
///
|
||||
/// `max_num_bins` is the max number of bins the DDSketch will grow to, in steps of 128 bins.
|
||||
pub fn new(alpha: f64, max_num_bins: u32, min_value: f64) -> Self {
|
||||
// Aligned with Java's LogarithmicMapping / LogLikeIndexMapping:
|
||||
// gamma = (1 + alpha) / (1 - alpha) (correctingFactor=1 for LogarithmicMapping)
|
||||
// gamma_ln = gamma.ln() (not ln_1p, to match Java's Math.log(gamma))
|
||||
// See: https://github.com/DataDog/sketches-java/blob/master/src/main/java/com/datadoghq/sketch/ddsketch/mapping/LogLikeIndexMapping.java (gamma() static method)
|
||||
// See: https://github.com/DataDog/sketches-java/blob/master/src/main/java/com/datadoghq/sketch/ddsketch/mapping/LogarithmicMapping.java (constructor, correctingFactor()=1)
|
||||
let gamma = (1.0 + alpha) / (1.0 - alpha);
|
||||
        let gamma_ln = gamma.ln();

        Config {
            max_num_bins,
            gamma,
            gamma_ln,
            min_value,
            offset: 1 - (log_gamma(min_value, gamma_ln) as i32),
        }
    }

    /// Return a `Config` using built-in default settings
    pub fn defaults() -> Self {
        Self::new(DEFAULT_ALPHA, DEFAULT_MAX_BINS, DEFAULT_MIN_VALUE)
    }

    pub fn key(&self, v: f64) -> i32 {
        // Aligned with Java's LogLikeIndexMapping.index(): floor-based indexing.
        // Java uses `(int) index` / `(int) index - 1` which is equivalent to floor().
        // See: https://github.com/DataDog/sketches-java/blob/master/src/main/java/com/datadoghq/sketch/ddsketch/mapping/LogLikeIndexMapping.java (index() method)
        self.log_gamma(v).floor() as i32
    }

    pub fn value(&self, key: i32) -> f64 {
        // Aligned with Java's LogLikeIndexMapping.value():
        // lowerBound(index) * (1 + relativeAccuracy)
        // = logInverse((index - indexOffset) / multiplier) * (1 + relativeAccuracy)
        // = gamma^key * 2*gamma/(gamma+1)
        // See: https://github.com/DataDog/sketches-java/blob/master/src/main/java/com/datadoghq/sketch/ddsketch/mapping/LogLikeIndexMapping.java (value() and lowerBound() methods)
        self.pow_gamma(key) * (2.0 * self.gamma / (1.0 + self.gamma))
    }

    pub fn log_gamma(&self, value: f64) -> f64 {
        log_gamma(value, self.gamma_ln)
    }

    pub fn pow_gamma(&self, key: i32) -> f64 {
        ((key as f64) * self.gamma_ln).exp()
    }

    pub fn min_possible(&self) -> f64 {
        self.min_value
    }

    /// Reconstruct a Config from a gamma value (as decoded from the binary format).
    /// Uses default max_num_bins and min_value.
    /// See Java: https://github.com/DataDog/sketches-java/blob/master/src/main/java/com/datadoghq/sketch/ddsketch/mapping/LogarithmicMapping.java (LogarithmicMapping(double gamma, double indexOffset) constructor)
    pub(crate) fn from_gamma(gamma: f64) -> Self {
        let gamma_ln = gamma.ln();
        Config {
            max_num_bins: DEFAULT_MAX_BINS,
            gamma,
            gamma_ln,
            min_value: DEFAULT_MIN_VALUE,
            offset: 1 - (log_gamma(DEFAULT_MIN_VALUE, gamma_ln) as i32),
        }
    }
}

impl Default for Config {
    fn default() -> Self {
        Self::new(DEFAULT_ALPHA, DEFAULT_MAX_BINS, DEFAULT_MIN_VALUE)
    }
}
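The logarithmic mapping above can be checked numerically outside the crate. This is a minimal standalone sketch (it recomputes `gamma` from a local `ALPHA` rather than using `Config`, so all names here are illustrative): for any `v`, the round-trip `value(key(v))` should stay within relative accuracy `alpha` of `v`.

```rust
// Standalone check of the logarithmic index mapping: gamma = (1 + alpha) / (1 - alpha),
// key(v) = floor(ln(v) / ln(gamma)), value(k) = gamma^k * 2*gamma / (gamma + 1).
const ALPHA: f64 = 0.01;

fn gamma() -> f64 {
    (1.0 + ALPHA) / (1.0 - ALPHA)
}

fn key(v: f64) -> i32 {
    (v.ln() / gamma().ln()).floor() as i32
}

fn value(k: i32) -> f64 {
    ((k as f64) * gamma().ln()).exp() * (2.0 * gamma() / (1.0 + gamma()))
}

fn main() {
    for &v in &[0.5_f64, 3.0, 42.0, 1e6] {
        let rel_err = ((value(key(v)) - v) / v).abs();
        // Worst case at bucket boundaries is (gamma - 1) / (gamma + 1) = alpha;
        // allow a tiny float margin.
        assert!(rel_err <= ALPHA + 1e-12, "v={v}: rel_err={rel_err}");
    }
    println!("all values within relative accuracy {ALPHA}");
}
```

The `2*gamma/(gamma+1)` factor places the returned value at the geometric midpoint of the bucket, which is what makes the error bound symmetric around `alpha`.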
@@ -1,385 +0,0 @@
use std::{error, fmt};

#[cfg(feature = "use_serde")]
use serde::{Deserialize, Serialize};

use crate::config::Config;
use crate::store::Store;

type Result<T> = std::result::Result<T, DDSketchError>;

/// General error type for DDSketch, represents either an invalid quantile or an
/// incompatible merge operation.
#[derive(Debug, Clone)]
pub enum DDSketchError {
    Quantile,
    Merge,
}

impl fmt::Display for DDSketchError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            DDSketchError::Quantile => {
                write!(f, "Invalid quantile, must be between 0 and 1 (inclusive)")
            }
            DDSketchError::Merge => write!(f, "Can not merge sketches with different configs"),
        }
    }
}

impl error::Error for DDSketchError {
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        // Generic
        None
    }
}

/// This struct represents a [DDSketch](https://arxiv.org/pdf/1908.10693.pdf)
#[derive(Clone)]
#[cfg_attr(feature = "use_serde", derive(Serialize, Deserialize))]
pub struct DDSketch {
    pub(crate) config: Config,
    pub(crate) store: Store,
    pub(crate) negative_store: Store,
    pub(crate) min: f64,
    pub(crate) max: f64,
    pub(crate) sum: f64,
    pub(crate) zero_count: u64,
}

impl Default for DDSketch {
    fn default() -> Self {
        Self::new(Default::default())
    }
}
// XXX: functions should return Option<> in the case of empty
impl DDSketch {
    /// Construct a `DDSketch`. Requires a `Config` specifying the parameters of the sketch
    pub fn new(config: Config) -> Self {
        DDSketch {
            config,
            store: Store::new(config.max_num_bins as usize),
            negative_store: Store::new(config.max_num_bins as usize),
            min: f64::INFINITY,
            max: f64::NEG_INFINITY,
            sum: 0.0,
            zero_count: 0,
        }
    }

    /// Add the sample to the sketch
    pub fn add(&mut self, v: f64) {
        if v > self.config.min_possible() {
            let key = self.config.key(v);
            self.store.add(key);
        } else if v < -self.config.min_possible() {
            let key = self.config.key(-v);
            self.negative_store.add(key);
        } else {
            self.zero_count += 1;
        }

        if v < self.min {
            self.min = v;
        }
        if self.max < v {
            self.max = v;
        }
        self.sum += v;
    }

    /// Return the quantile value for quantiles between 0.0 and 1.0. Result is an error, represented
    /// as DDSketchError::Quantile if the requested quantile is outside of that range.
    ///
    /// If the sketch is empty the result is None, else Some(v) for the quantile value.
    pub fn quantile(&self, q: f64) -> Result<Option<f64>> {
        if !(0.0..=1.0).contains(&q) {
            return Err(DDSketchError::Quantile);
        }

        if self.empty() {
            return Ok(None);
        }

        if q == 0.0 {
            return Ok(Some(self.min));
        } else if q == 1.0 {
            return Ok(Some(self.max));
        }

        let rank = (q * (self.count() as f64 - 1.0)) as u64;
        let quantile;
        if rank < self.negative_store.count() {
            let reversed_rank = self.negative_store.count() - rank - 1;
            let key = self.negative_store.key_at_rank(reversed_rank);
            quantile = -self.config.value(key);
        } else if rank < self.zero_count + self.negative_store.count() {
            quantile = 0.0;
        } else {
            let key = self
                .store
                .key_at_rank(rank - self.zero_count - self.negative_store.count());
            quantile = self.config.value(key);
        }

        Ok(Some(quantile))
    }

    /// Returns the minimum value seen, or None if sketch is empty
    pub fn min(&self) -> Option<f64> {
        if self.empty() {
            None
        } else {
            Some(self.min)
        }
    }

    /// Returns the maximum value seen, or None if sketch is empty
    pub fn max(&self) -> Option<f64> {
        if self.empty() {
            None
        } else {
            Some(self.max)
        }
    }

    /// Returns the sum of values seen, or None if sketch is empty
    pub fn sum(&self) -> Option<f64> {
        if self.empty() {
            None
        } else {
            Some(self.sum)
        }
    }

    /// Returns the number of values added to the sketch
    pub fn count(&self) -> usize {
        (self.store.count() + self.zero_count + self.negative_store.count()) as usize
    }

    /// Returns the length of the underlying `Store`. This is mainly only useful for understanding
    /// how much the sketch has grown given the inserted values.
    pub fn length(&self) -> usize {
        self.store.length() as usize + self.negative_store.length() as usize
    }

    /// Merge the contents of another sketch into this one. The sketch that is merged into this one
    /// is unchanged after the merge.
    pub fn merge(&mut self, o: &DDSketch) -> Result<()> {
        if self.config != o.config {
            return Err(DDSketchError::Merge);
        }

        let was_empty = self.store.count() == 0;

        // Merge the stores
        self.store.merge(&o.store);
        self.negative_store.merge(&o.negative_store);
        self.zero_count += o.zero_count;

        // Need to ensure we don't override min/max with initializers
        // if either store were empty
        if was_empty {
            self.min = o.min;
            self.max = o.max;
        } else if o.store.count() > 0 {
            if o.min < self.min {
                self.min = o.min;
            }
            if o.max > self.max {
                self.max = o.max;
            }
        }
        self.sum += o.sum;

        Ok(())
    }

    fn empty(&self) -> bool {
        self.count() == 0
    }

    /// Encode this sketch into the Java-compatible binary format used by
    /// `com.datadoghq.sketch.ddsketch.DDSketchWithExactSummaryStatistics`.
    pub fn to_java_bytes(&self) -> Vec<u8> {
        crate::encoding::encode_to_java_bytes(self)
    }

    /// Decode a sketch from the Java-compatible binary format.
    /// Accepts bytes produced by Java's `DDSketchWithExactSummaryStatistics.encode()`
    /// with or without the `0x02` version prefix.
    pub fn from_java_bytes(
        bytes: &[u8],
    ) -> std::result::Result<Self, crate::encoding::DecodeError> {
        crate::encoding::decode_from_java_bytes(bytes)
    }
}
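The rank arithmetic in `quantile()` partitions the global rank space into three consecutive regions: the negative store, the zero bucket, then the positive store. A tiny standalone sketch of that partition (the `bucket` helper and its counts are illustrative, not part of the crate):

```rust
// Ranks 0..neg map to the negative store, the next `zeros` ranks to the zero
// bucket, and everything after that to the positive store.
fn bucket(rank: u64, neg: u64, zeros: u64) -> &'static str {
    if rank < neg {
        "negative"
    } else if rank < neg + zeros {
        "zero"
    } else {
        "positive"
    }
}

fn main() {
    // e.g. 3 negative samples, 2 zeros, remaining samples positive
    let (neg, zeros) = (3, 2);
    assert_eq!(bucket(0, neg, zeros), "negative");
    assert_eq!(bucket(2, neg, zeros), "negative");
    assert_eq!(bucket(3, neg, zeros), "zero");
    assert_eq!(bucket(4, neg, zeros), "zero");
    assert_eq!(bucket(5, neg, zeros), "positive");
    println!("rank partition ok");
}
```

Within the negative region the rank is reversed (`count - rank - 1`) because negative values sort in the opposite order of their bucket keys.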
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use crate::{Config, DDSketch};

    #[test]
    fn test_add_zero() {
        let alpha = 0.01;
        let c = Config::new(alpha, 2048, 10e-9);
        let mut dd = DDSketch::new(c);
        dd.add(0.0);
    }

    #[test]
    fn test_quartiles() {
        let alpha = 0.01;
        let c = Config::new(alpha, 2048, 10e-9);
        let mut dd = DDSketch::new(c);

        // Initialize sketch with {1.0, 2.0, 3.0, 4.0}
        for i in 1..5 {
            dd.add(i as f64);
        }

        // We expect the following mappings from quantile to value:
        // [0,0.33]: 1.0, (0.34,0.66]: 2.0, (0.67,0.99]: 3.0, (0.99, 1.0]: 4.0
        let test_cases = vec![
            (0.0, 1.0),
            (0.25, 1.0),
            (0.33, 1.0),
            (0.34, 2.0),
            (0.5, 2.0),
            (0.66, 2.0),
            (0.67, 3.0),
            (0.75, 3.0),
            (0.99, 3.0),
            (1.0, 4.0),
        ];

        for (q, val) in test_cases {
            assert_relative_eq!(dd.quantile(q).unwrap().unwrap(), val, max_relative = alpha);
        }
    }

    #[test]
    fn test_neg_quartiles() {
        let alpha = 0.01;
        let c = Config::new(alpha, 2048, 10e-9);
        let mut dd = DDSketch::new(c);

        // Initialize sketch with {-1.0, -2.0, -3.0, -4.0}
        for i in 1..5 {
            dd.add(-i as f64);
        }

        let test_cases = vec![
            (0.0, -4.0),
            (0.25, -4.0),
            (0.5, -3.0),
            (0.75, -2.0),
            (1.0, -1.0),
        ];

        for (q, val) in test_cases {
            assert_relative_eq!(dd.quantile(q).unwrap().unwrap(), val, max_relative = alpha);
        }
    }

    #[test]
    fn test_simple_quantile() {
        let c = Config::defaults();
        let mut dd = DDSketch::new(c);

        for i in 1..101 {
            dd.add(i as f64);
        }

        assert_eq!(dd.quantile(0.95).unwrap().unwrap().ceil(), 95.0);

        assert!(dd.quantile(-1.01).is_err());
        assert!(dd.quantile(1.01).is_err());
    }

    #[test]
    fn test_empty_sketch() {
        let c = Config::defaults();
        let dd = DDSketch::new(c);

        assert_eq!(dd.quantile(0.98).unwrap(), None);
        assert_eq!(dd.max(), None);
        assert_eq!(dd.min(), None);
        assert_eq!(dd.sum(), None);
        assert_eq!(dd.count(), 0);

        assert!(dd.quantile(1.01).is_err());
    }

    #[test]
    fn test_basic_histogram_data() {
        let values = &[
            0.754225035,
            0.752900282,
            0.752812246,
            0.752602367,
            0.754310155,
            0.753525981,
            0.752981082,
            0.752715536,
            0.751667941,
            0.755079054,
            0.753528150,
            0.755188464,
            0.752508723,
            0.750064549,
            0.753960428,
            0.751139298,
            0.752523560,
            0.753253428,
            0.753498342,
            0.751858358,
            0.752104636,
            0.753841300,
            0.754467374,
            0.753814334,
            0.750881719,
            0.753182556,
            0.752576884,
            0.753945708,
            0.753571911,
            0.752314573,
            0.752586651,
        ];

        let c = Config::defaults();
        let mut dd = DDSketch::new(c);

        for value in values {
            dd.add(*value);
        }

        assert_eq!(dd.max(), Some(0.755188464));
        assert_eq!(dd.min(), Some(0.750064549));
        assert_eq!(dd.count(), 31);
        assert_eq!(dd.sum(), Some(23.343630625000003));

        assert!(dd.quantile(0.25).unwrap().is_some());
        assert!(dd.quantile(0.5).unwrap().is_some());
        assert!(dd.quantile(0.75).unwrap().is_some());
    }

    #[test]
    fn test_length() {
        let mut dd = DDSketch::default();
        assert_eq!(dd.length(), 0);

        dd.add(1.0);
        assert_eq!(dd.length(), 128);
        dd.add(2.0);
        dd.add(3.0);
        assert_eq!(dd.length(), 128);

        dd.add(-1.0);
        assert_eq!(dd.length(), 256);
        dd.add(-2.0);
        dd.add(-3.0);
        assert_eq!(dd.length(), 256);
    }
}
@@ -1,813 +0,0 @@
//! Java-compatible binary encoding/decoding for DDSketch.
//!
//! This module implements the binary format used by the Java
//! `com.datadoghq.sketch.ddsketch.DDSketchWithExactSummaryStatistics` class
//! from the DataDog/sketches-java library. It enables cross-language
//! serialization so that sketches produced in Rust can be deserialized
//! and merged by Java consumers.

use std::fmt;

use crate::config::Config;
use crate::ddsketch::DDSketch;
use crate::store::Store;

// ---------------------------------------------------------------------------
// Flag byte layout
//
// Each flag byte packs a 2-bit type ordinal in the low bits and a 6-bit
// subflag in the upper bits: (subflag << 2) | type_ordinal
// See: https://github.com/DataDog/sketches-java/blob/master/src/main/java/com/datadoghq/sketch/ddsketch/encoding/Flag.java
// ---------------------------------------------------------------------------

/// The 2-bit type field occupying the low bits of every flag byte.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FlagType {
    SketchFeatures = 0,
    PositiveStore = 1,
    IndexMapping = 2,
    NegativeStore = 3,
}

impl FlagType {
    fn from_byte(b: u8) -> Option<Self> {
        match b & 0x03 {
            0 => Some(Self::SketchFeatures),
            1 => Some(Self::PositiveStore),
            2 => Some(Self::IndexMapping),
            3 => Some(Self::NegativeStore),
            _ => None,
        }
    }
}

/// Construct a flag byte from a subflag and a type.
const fn flag(subflag: u8, flag_type: FlagType) -> u8 {
    (subflag << 2) | (flag_type as u8)
}

// Pre-computed flag bytes for the sketch features we encode/decode.
const FLAG_INDEX_MAPPING_LOG: u8 = flag(0, FlagType::IndexMapping); // 0x02
const FLAG_ZERO_COUNT: u8 = flag(1, FlagType::SketchFeatures); // 0x04
const FLAG_COUNT: u8 = flag(0x28, FlagType::SketchFeatures); // 0xA0
const FLAG_SUM: u8 = flag(0x21, FlagType::SketchFeatures); // 0x84
const FLAG_MIN: u8 = flag(0x22, FlagType::SketchFeatures); // 0x88
const FLAG_MAX: u8 = flag(0x23, FlagType::SketchFeatures); // 0x8C

/// BinEncodingMode subflags for store flag bytes.
/// See: https://github.com/DataDog/sketches-java/blob/master/src/main/java/com/datadoghq/sketch/ddsketch/encoding/BinEncodingMode.java
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BinEncodingMode {
    IndexDeltasAndCounts = 1,
    IndexDeltas = 2,
    ContiguousCounts = 3,
}

impl BinEncodingMode {
    fn from_subflag(subflag: u8) -> Option<Self> {
        match subflag {
            1 => Some(Self::IndexDeltasAndCounts),
            2 => Some(Self::IndexDeltas),
            3 => Some(Self::ContiguousCounts),
            _ => None,
        }
    }
}

const VAR_DOUBLE_ROTATE_DISTANCE: u32 = 6;
const MAX_VAR_LEN_64: usize = 9;

const DEFAULT_MAX_BINS: u32 = 2048;

// ---------------------------------------------------------------------------
// Error type
// ---------------------------------------------------------------------------

#[derive(Debug, Clone)]
pub enum DecodeError {
    UnexpectedEof,
    InvalidFlag(u8),
    InvalidData(String),
}

impl fmt::Display for DecodeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::UnexpectedEof => write!(f, "unexpected end of input"),
            Self::InvalidFlag(b) => write!(f, "invalid flag byte: 0x{b:02X}"),
            Self::InvalidData(msg) => write!(f, "invalid data: {msg}"),
        }
    }
}

impl std::error::Error for DecodeError {}
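The flag-byte packing described in this module (`(subflag << 2) | type_ordinal`) can be verified in isolation. A minimal standalone sketch, with the type ordinals written as raw `u8` values instead of the module's `FlagType` enum:

```rust
// Each flag byte carries a 2-bit type ordinal in the low bits and a
// 6-bit subflag in the high bits.
fn flag(subflag: u8, type_ordinal: u8) -> u8 {
    (subflag << 2) | type_ordinal
}

fn main() {
    // SketchFeatures = 0, IndexMapping = 2 (ordinals from the enum above).
    assert_eq!(flag(0x28, 0), 0xA0); // FLAG_COUNT
    assert_eq!(flag(0x21, 0), 0x84); // FLAG_SUM
    assert_eq!(flag(0x22, 0), 0x88); // FLAG_MIN
    assert_eq!(flag(0x23, 0), 0x8C); // FLAG_MAX
    assert_eq!(flag(0, 2), 0x02); // FLAG_INDEX_MAPPING_LOG

    // Unpacking: low 2 bits give the type, high 6 bits the subflag.
    let b = 0xA0u8;
    assert_eq!(b & 0x03, 0);
    assert_eq!(b >> 2, 0x28);
    println!("flag byte layout verified");
}
```

This also shows why the `0x02` version prefix is ambiguous with `FLAG_INDEX_MAPPING_LOG`: both are the same byte, which is why the decoder peeks at the byte after the prefix before skipping it.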
// ---------------------------------------------------------------------------
// VarEncoding — bit-exact port of Java VarEncodingHelper
// See: https://github.com/DataDog/sketches-java/blob/master/src/main/java/com/datadoghq/sketch/ddsketch/encoding/VarEncodingHelper.java
// ---------------------------------------------------------------------------

fn encode_unsigned_var_long(out: &mut Vec<u8>, mut value: u64) {
    let length = ((63 - value.leading_zeros() as i32) / 7).clamp(0, 8);
    for _ in 0..length {
        out.push((value as u8) | 0x80);
        value >>= 7;
    }
    out.push(value as u8);
}

fn decode_unsigned_var_long(input: &mut &[u8]) -> Result<u64, DecodeError> {
    let mut value: u64 = 0;
    let mut shift: u32 = 0;
    loop {
        let next = read_byte(input)?;
        if next < 0x80 || shift == 56 {
            return Ok(value | (u64::from(next) << shift));
        }
        value |= (u64::from(next) & 0x7F) << shift;
        shift += 7;
    }
}

/// ZigZag encode then var-long encode.
fn encode_signed_var_long(out: &mut Vec<u8>, value: i64) {
    let encoded = ((value >> 63) ^ (value << 1)) as u64;
    encode_unsigned_var_long(out, encoded);
}

fn decode_signed_var_long(input: &mut &[u8]) -> Result<i64, DecodeError> {
    let encoded = decode_unsigned_var_long(input)?;
    Ok(((encoded >> 1) as i64) ^ -((encoded & 1) as i64))
}

fn double_to_var_bits(value: f64) -> u64 {
    let bits = f64::to_bits(value + 1.0).wrapping_sub(f64::to_bits(1.0));
    bits.rotate_left(VAR_DOUBLE_ROTATE_DISTANCE)
}

fn var_bits_to_double(bits: u64) -> f64 {
    f64::from_bits(
        bits.rotate_right(VAR_DOUBLE_ROTATE_DISTANCE)
            .wrapping_add(f64::to_bits(1.0)),
    ) - 1.0
}

fn encode_var_double(out: &mut Vec<u8>, value: f64) {
    let mut bits = double_to_var_bits(value);
    for _ in 0..MAX_VAR_LEN_64 - 1 {
        let next = (bits >> 57) as u8;
        bits <<= 7;
        if bits == 0 {
            out.push(next);
            return;
        }
        out.push(next | 0x80);
    }
    out.push((bits >> 56) as u8);
}

fn decode_var_double(input: &mut &[u8]) -> Result<f64, DecodeError> {
    let mut bits: u64 = 0;
    let mut shift: i32 = 57; // 8*8 - 7
    loop {
        let next = read_byte(input)?;
        if shift == 1 {
            bits |= u64::from(next);
            break;
        }
        if next < 0x80 {
            bits |= u64::from(next) << shift;
            break;
        }
        bits |= (u64::from(next) & 0x7F) << shift;
        shift -= 7;
    }
    Ok(var_bits_to_double(bits))
}
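The var-long and ZigZag transforms above are easy to sanity-check byte by byte. This is a standalone sketch that copies the unsigned encoder from this module and inlines the ZigZag fold, so it runs without the crate:

```rust
// Unsigned var-long: 7 data bits per byte, MSB set on continuation bytes,
// least-significant group first.
fn encode_unsigned_var_long(out: &mut Vec<u8>, mut value: u64) {
    let length = ((63 - value.leading_zeros() as i32) / 7).clamp(0, 8);
    for _ in 0..length {
        out.push((value as u8) | 0x80);
        value >>= 7;
    }
    out.push(value as u8);
}

// ZigZag folds signed values so that small magnitudes (of either sign)
// become small unsigned codes: 0, -1, 1, -2 -> 0, 1, 2, 3.
fn zigzag(value: i64) -> u64 {
    ((value >> 63) ^ (value << 1)) as u64
}

fn main() {
    let mut buf = Vec::new();
    encode_unsigned_var_long(&mut buf, 128);
    // 128 = 0b1000_0000 needs two bytes: continuation 0x80, then 0x01.
    assert_eq!(buf, [0x80, 0x01]);

    assert_eq!(zigzag(0), 0);
    assert_eq!(zigzag(-1), 1);
    assert_eq!(zigzag(1), 2);
    assert_eq!(zigzag(-2), 3);
    println!("varint examples hold");
}
```

The VarDouble variant differs from the var-long: it emits the most-significant bits first after the rotate-by-6, so doubles with few significant mantissa bits (small integers, powers of two) encode in a single byte.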
// ---------------------------------------------------------------------------
// Byte-level helpers
// ---------------------------------------------------------------------------

fn read_byte(input: &mut &[u8]) -> Result<u8, DecodeError> {
    match input.split_first() {
        Some((&byte, rest)) => {
            *input = rest;
            Ok(byte)
        }
        None => Err(DecodeError::UnexpectedEof),
    }
}

fn write_f64_le(out: &mut Vec<u8>, value: f64) {
    out.extend_from_slice(&value.to_le_bytes());
}

fn read_f64_le(input: &mut &[u8]) -> Result<f64, DecodeError> {
    if input.len() < 8 {
        return Err(DecodeError::UnexpectedEof);
    }
    let (bytes, rest) = input.split_at(8);
    *input = rest;
    // bytes is guaranteed to be length 8 by the split_at above.
    let arr = [
        bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
    ];
    Ok(f64::from_le_bytes(arr))
}

// ---------------------------------------------------------------------------
// Store encoding/decoding
// See: https://github.com/DataDog/sketches-java/blob/master/src/main/java/com/datadoghq/sketch/ddsketch/store/DenseStore.java (encode/decode methods)
// ---------------------------------------------------------------------------

/// Collect non-zero bins in the store as (absolute_index, count) pairs.
///
/// Allocation is acceptable here: this runs once per encode and the Vec
/// has at most `max_num_bins` entries.
fn collect_non_zero_bins(store: &Store) -> Vec<(i32, u64)> {
    if store.count == 0 {
        return Vec::new();
    }
    let start = (store.min_key - store.offset) as usize;
    let end = ((store.max_key - store.offset + 1) as usize).min(store.bins.len());
    store.bins[start..end]
        .iter()
        .enumerate()
        .filter(|&(_, &count)| count > 0)
        .map(|(i, &count)| (start as i32 + i as i32 + store.offset, count))
        .collect()
}

fn encode_store(out: &mut Vec<u8>, store: &Store, flag_type: FlagType) {
    let bins = collect_non_zero_bins(store);
    if bins.is_empty() {
        return;
    }

    out.push(flag(BinEncodingMode::IndexDeltasAndCounts as u8, flag_type));
    encode_unsigned_var_long(out, bins.len() as u64);

    let mut prev_index: i64 = 0;
    for &(index, count) in &bins {
        encode_signed_var_long(out, i64::from(index) - prev_index);
        encode_var_double(out, count as f64);
        prev_index = i64::from(index);
    }
}

fn decode_store(input: &mut &[u8], subflag: u8, bin_limit: usize) -> Result<Store, DecodeError> {
    let mode = BinEncodingMode::from_subflag(subflag).ok_or_else(|| {
        DecodeError::InvalidData(format!("unknown bin encoding mode subflag: {subflag}"))
    })?;
    let num_bins = decode_unsigned_var_long(input)? as usize;
    let mut store = Store::new(bin_limit);

    match mode {
        BinEncodingMode::IndexDeltasAndCounts => {
            let mut index: i64 = 0;
            for _ in 0..num_bins {
                index += decode_signed_var_long(input)?;
                let count = decode_var_double(input)?;
                store.add_count(index as i32, count as u64);
            }
        }
        BinEncodingMode::IndexDeltas => {
            let mut index: i64 = 0;
            for _ in 0..num_bins {
                index += decode_signed_var_long(input)?;
                store.add_count(index as i32, 1);
            }
        }
        BinEncodingMode::ContiguousCounts => {
            let start_index = decode_signed_var_long(input)?;
            let index_delta = decode_signed_var_long(input)?;
            let mut index = start_index;
            for _ in 0..num_bins {
                let count = decode_var_double(input)?;
                store.add_count(index as i32, count as u64);
                index += index_delta;
            }
        }
    }

    Ok(store)
}
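The IndexDeltasAndCounts layout stores each bin index as a signed delta from the previous one, so runs of adjacent bins encode as small values that fit single var-long bytes. A standalone sketch of the delta round-trip (the helper names `to_deltas`/`from_deltas` are illustrative, not part of the crate):

```rust
// Encoding: bins sorted by index; emit index[i] - index[i-1] (prev starts at 0).
fn to_deltas(indices: &[i32]) -> Vec<i64> {
    let mut prev = 0i64;
    indices
        .iter()
        .map(|&i| {
            let d = i64::from(i) - prev;
            prev = i64::from(i);
            d
        })
        .collect()
}

// Decoding: accumulate the deltas back into absolute indices.
fn from_deltas(deltas: &[i64]) -> Vec<i64> {
    let mut index = 0i64;
    deltas
        .iter()
        .map(|&d| {
            index += d;
            index
        })
        .collect()
}

fn main() {
    let deltas = to_deltas(&[5, 7, 20]);
    assert_eq!(deltas, [5, 2, 13]);
    assert_eq!(from_deltas(&deltas), [5, 7, 20]);
    println!("delta round-trip ok");
}
```

Only the first delta can be large (the absolute position of the lowest bin); after that, dense histograms mostly emit deltas of 1.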
// ---------------------------------------------------------------------------
// Top-level encode / decode
// ---------------------------------------------------------------------------

/// Encode a DDSketch into the Java-compatible binary format.
///
/// The output follows the encoding order of
/// `DDSketchWithExactSummaryStatistics.encode()` then `DDSketch.encode()`:
///
/// 1. Summary statistics: COUNT, MIN, MAX (if count > 0)
/// 2. SUM (if sum != 0)
/// 3. Index mapping (LOG layout): gamma, indexOffset
/// 4. Zero count (if > 0)
/// 5. Positive store bins
/// 6. Negative store bins
pub fn encode_to_java_bytes(sketch: &DDSketch) -> Vec<u8> {
    let mut out = Vec::new();
    let count = sketch.count() as f64;

    // Summary statistics (DDSketchWithExactSummaryStatistics.encode)
    if count != 0.0 {
        out.push(FLAG_COUNT);
        encode_var_double(&mut out, count);
        out.push(FLAG_MIN);
        write_f64_le(&mut out, sketch.min);
        out.push(FLAG_MAX);
        write_f64_le(&mut out, sketch.max);
    }
    if sketch.sum != 0.0 {
        out.push(FLAG_SUM);
        write_f64_le(&mut out, sketch.sum);
    }

    // DDSketch.encode: index mapping + zero count + stores
    out.push(FLAG_INDEX_MAPPING_LOG);
    write_f64_le(&mut out, sketch.config.gamma);
    write_f64_le(&mut out, 0.0_f64);

    if sketch.zero_count != 0 {
        out.push(FLAG_ZERO_COUNT);
        encode_var_double(&mut out, sketch.zero_count as f64);
    }

    encode_store(&mut out, &sketch.store, FlagType::PositiveStore);
    encode_store(&mut out, &sketch.negative_store, FlagType::NegativeStore);

    out
}

/// Decode a DDSketch from the Java-compatible binary format.
///
/// Accepts bytes with or without a `0x02` version prefix.
pub fn decode_from_java_bytes(bytes: &[u8]) -> Result<DDSketch, DecodeError> {
    if bytes.is_empty() {
        return Err(DecodeError::UnexpectedEof);
    }

    let mut input = bytes;

    // Skip optional version prefix (0x02 followed by a valid flag byte).
    if input.len() >= 2 && input[0] == 0x02 && is_valid_flag_byte(input[1]) {
        input = &input[1..];
    }

    let mut gamma: Option<f64> = None;
    let mut zero_count: f64 = 0.0;
    let mut sum: f64 = 0.0;
    let mut min: f64 = f64::INFINITY;
    let mut max: f64 = f64::NEG_INFINITY;
    let mut positive_store: Option<Store> = None;
    let mut negative_store: Option<Store> = None;

    while !input.is_empty() {
        let flag_byte = read_byte(&mut input)?;
        let flag_type =
            FlagType::from_byte(flag_byte).ok_or(DecodeError::InvalidFlag(flag_byte))?;
        let subflag = flag_byte >> 2;

        match flag_type {
            FlagType::IndexMapping => {
                gamma = Some(read_f64_le(&mut input)?);
                let _index_offset = read_f64_le(&mut input)?;
            }
            FlagType::SketchFeatures => match flag_byte {
                FLAG_ZERO_COUNT => zero_count += decode_var_double(&mut input)?,
                FLAG_COUNT => {
                    let _count = decode_var_double(&mut input)?;
                }
                FLAG_SUM => sum = read_f64_le(&mut input)?,
                FLAG_MIN => min = read_f64_le(&mut input)?,
                FLAG_MAX => max = read_f64_le(&mut input)?,
                _ => return Err(DecodeError::InvalidFlag(flag_byte)),
            },
            FlagType::PositiveStore => {
                positive_store =
                    Some(decode_store(&mut input, subflag, DEFAULT_MAX_BINS as usize)?);
            }
            FlagType::NegativeStore => {
                negative_store =
                    Some(decode_store(&mut input, subflag, DEFAULT_MAX_BINS as usize)?);
            }
        }
    }

    let g = gamma.unwrap_or_else(|| Config::defaults().gamma);
    let config = Config::from_gamma(g);
    let store = positive_store.unwrap_or_else(|| Store::new(config.max_num_bins as usize));
    let neg = negative_store.unwrap_or_else(|| Store::new(config.max_num_bins as usize));

    Ok(DDSketch {
        config,
        store,
        negative_store: neg,
        min,
        max,
        sum,
        zero_count: zero_count as u64,
    })
}

/// Check whether a byte is a valid flag byte for the DDSketch binary format.
fn is_valid_flag_byte(b: u8) -> bool {
    // Known sketch-feature flags
    if matches!(
        b,
        FLAG_ZERO_COUNT | FLAG_COUNT | FLAG_SUM | FLAG_MIN | FLAG_MAX | FLAG_INDEX_MAPPING_LOG
    ) {
        return true;
    }
    let Some(flag_type) = FlagType::from_byte(b) else {
        return false;
    };
    let subflag = b >> 2;
    match flag_type {
        FlagType::PositiveStore | FlagType::NegativeStore => (1..=3).contains(&subflag),
        FlagType::IndexMapping => subflag <= 4, // LOG=0, LOG_LINEAR=1 .. LOG_QUARTIC=4
        _ => false,
    }
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Config, DDSketch};

    // --- VarEncoding unit tests ---

    #[test]
    fn test_unsigned_var_long_zero() {
        let mut buf = Vec::new();
        encode_unsigned_var_long(&mut buf, 0);
        assert_eq!(buf, [0x00]);

        let mut input = buf.as_slice();
        assert_eq!(decode_unsigned_var_long(&mut input).unwrap(), 0);
        assert!(input.is_empty());
    }

    #[test]
    fn test_unsigned_var_long_small() {
        let mut buf = Vec::new();
        encode_unsigned_var_long(&mut buf, 1);
        assert_eq!(buf, [0x01]);

        let mut input = buf.as_slice();
        assert_eq!(decode_unsigned_var_long(&mut input).unwrap(), 1);
    }

    #[test]
    fn test_unsigned_var_long_128() {
        let mut buf = Vec::new();
        encode_unsigned_var_long(&mut buf, 128);
        assert_eq!(buf, [0x80, 0x01]);

        let mut input = buf.as_slice();
        assert_eq!(decode_unsigned_var_long(&mut input).unwrap(), 128);
    }

    #[test]
    fn test_unsigned_var_long_roundtrip() {
        for v in [0u64, 1, 127, 128, 255, 256, 16383, 16384, u64::MAX] {
            let mut buf = Vec::new();
            encode_unsigned_var_long(&mut buf, v);
            let mut input = buf.as_slice();
            let decoded = decode_unsigned_var_long(&mut input).unwrap();
            assert_eq!(decoded, v, "roundtrip failed for {}", v);
            assert!(input.is_empty());
        }
    }

    #[test]
    fn test_signed_var_long_roundtrip() {
        for v in [0i64, 1, -1, 63, -64, 64, -65, i64::MAX, i64::MIN] {
            let mut buf = Vec::new();
            encode_signed_var_long(&mut buf, v);
            let mut input = buf.as_slice();
            let decoded = decode_signed_var_long(&mut input).unwrap();
            assert_eq!(decoded, v, "roundtrip failed for {}", v);
            assert!(input.is_empty());
        }
    }

    #[test]
    fn test_var_double_roundtrip() {
        for v in [0.0, 1.0, 2.0, 5.0, 15.0, 42.0, 100.0, 1e-9, 1e15, 0.5, 7.77] {
            let mut buf = Vec::new();
            encode_var_double(&mut buf, v);
            let mut input = buf.as_slice();
            let decoded = decode_var_double(&mut input).unwrap();
            assert!(
                (decoded - v).abs() < 1e-15 || decoded == v,
                "roundtrip failed for {}: got {}",
                v,
                decoded,
            );
            assert!(input.is_empty());
        }
    }

    #[test]
    fn test_var_double_small_integers() {
        let mut buf = Vec::new();
        encode_var_double(&mut buf, 1.0);
        assert_eq!(buf.len(), 1, "VarDouble(1.0) should be 1 byte");

        buf.clear();
        encode_var_double(&mut buf, 5.0);
        assert_eq!(buf.len(), 1, "VarDouble(5.0) should be 1 byte");
    }

    // --- DDSketch encode/decode roundtrip tests ---

    #[test]
    fn test_encode_empty_sketch() {
        let sketch = DDSketch::new(Config::defaults());
        let bytes = sketch.to_java_bytes();
        assert!(!bytes.is_empty());

        let decoded = DDSketch::from_java_bytes(&bytes).unwrap();
        assert_eq!(decoded.count(), 0);
        assert_eq!(decoded.min(), None);
        assert_eq!(decoded.max(), None);
        assert_eq!(decoded.sum(), None);
    }

    #[test]
    fn test_encode_simple_sketch() {
        let mut sketch = DDSketch::new(Config::defaults());
        for v in [1.0, 2.0, 3.0, 4.0, 5.0] {
            sketch.add(v);
        }

        let bytes = sketch.to_java_bytes();
        let decoded = DDSketch::from_java_bytes(&bytes).unwrap();

        assert_eq!(decoded.count(), 5);
        assert_eq!(decoded.min(), Some(1.0));
        assert_eq!(decoded.max(), Some(5.0));
        assert_eq!(decoded.sum(), Some(15.0));

        assert_quantiles_match(&sketch, &decoded, &[0.5, 0.9, 0.95, 0.99]);
    }

    #[test]
    fn test_encode_single_value() {
        let mut sketch = DDSketch::new(Config::defaults());
        sketch.add(42.0);

        let bytes = sketch.to_java_bytes();
        let decoded = DDSketch::from_java_bytes(&bytes).unwrap();

        assert_eq!(decoded.count(), 1);
        assert_eq!(decoded.min(), Some(42.0));
        assert_eq!(decoded.max(), Some(42.0));
        assert_eq!(decoded.sum(), Some(42.0));
    }

    #[test]
    fn test_encode_negative_values() {
        let mut sketch = DDSketch::new(Config::defaults());
        for v in [-3.0, -1.0, 2.0, 5.0] {
            sketch.add(v);
        }

        let bytes = sketch.to_java_bytes();
        let decoded = DDSketch::from_java_bytes(&bytes).unwrap();

        assert_eq!(decoded.count(), 4);
        assert_eq!(decoded.min(), Some(-3.0));
        assert_eq!(decoded.max(), Some(5.0));
        assert_eq!(decoded.sum(), Some(3.0));

        assert_quantiles_match(&sketch, &decoded, &[0.0, 0.25, 0.5, 0.75, 1.0]);
    }

    #[test]
    fn test_encode_with_zero_value() {
        let mut sketch = DDSketch::new(Config::defaults());
        for v in [0.0, 1.0, 2.0] {
            sketch.add(v);
        }

        let bytes = sketch.to_java_bytes();
        let decoded = DDSketch::from_java_bytes(&bytes).unwrap();

        assert_eq!(decoded.count(), 3);
        assert_eq!(decoded.min(), Some(0.0));
        assert_eq!(decoded.max(), Some(2.0));
        assert_eq!(decoded.sum(), Some(3.0));
        assert_eq!(decoded.zero_count, 1);
    }

    #[test]
    fn test_encode_large_range() {
        let mut sketch = DDSketch::new(Config::defaults());
        sketch.add(0.001);
        sketch.add(1_000_000.0);

        let bytes = sketch.to_java_bytes();
        let decoded = DDSketch::from_java_bytes(&bytes).unwrap();

        assert_eq!(decoded.count(), 2);
        assert_eq!(decoded.min(), Some(0.001));
        assert_eq!(decoded.max(), Some(1_000_000.0));
    }

    #[test]
    fn test_encode_with_version_prefix() {
        let mut sketch = DDSketch::new(Config::defaults());
        for v in [1.0, 2.0, 3.0] {
            sketch.add(v);
        }

        let bytes = sketch.to_java_bytes();

        // Simulate Java's toByteArrayV2: prepend 0x02
        let mut v2_bytes = vec![0x02];
        v2_bytes.extend_from_slice(&bytes);

        let decoded = DDSketch::from_java_bytes(&v2_bytes).unwrap();
        assert_eq!(decoded.count(), 3);
        assert_eq!(decoded.min(), Some(1.0));
        assert_eq!(decoded.max(), Some(3.0));
    }

    #[test]
fn test_byte_level_encoding() {
|
||||
let mut sketch = DDSketch::new(Config::defaults());
|
||||
sketch.add(1.0);
|
||||
|
||||
let bytes = sketch.to_java_bytes();
|
||||
|
||||
assert_eq!(bytes[0], FLAG_COUNT, "first byte should be COUNT flag");
|
||||
assert!(
|
||||
bytes.contains(&FLAG_INDEX_MAPPING_LOG),
|
||||
"should contain index mapping flag"
|
||||
);
|
||||
}
|
||||
|
||||
// --- Cross-language golden byte tests ---
|
||||
//
|
||||
// Golden bytes generated by Java's DDSketchWithExactSummaryStatistics.encode()
|
||||
// using LogarithmicMapping(0.01) + CollapsingLowestDenseStore(2048).
|
||||
|
||||
const GOLDEN_SIMPLE: &str = "a00588000000000000f03f8c0000000000001440840000000000002e4002fd4a815abf52f03f000000000000000005050002440228021e021602";
|
||||
const GOLDEN_SINGLE: &str = "a0028800000000000045408c000000000000454084000000000000454002fd4a815abf52f03f00000000000000000501f40202";
|
||||
const GOLDEN_NEGATIVE: &str = "a084408800000000000008c08c000000000000144084000000000000084002fd4a815abf52f03f0000000000000000050244025c02070200026c02";
|
||||
const GOLDEN_ZERO: &str = "a0048800000000000000008c000000000000004084000000000000084002fd4a815abf52f03f00000000000000000402050200024402";
|
||||
const GOLDEN_EMPTY: &str = "02fd4a815abf52f03f0000000000000000";
|
||||
const GOLDEN_MANY: &str = "a08d1488000000000000f03f8c0000000000005940840000000000bab34002fd4a815abf52f03f000000000000000005550002440228021e021602120210020c020c020c0208020a020802060208020602060206020602040206020402040204020402040204020402040204020202040202020402020204020202020204020202020202020402020202020202020202020202020202020202020202020202020202020203020202020202020302020202020302020202020302020203020202030202020302030202020302030203020202030203020302030202";
|
||||
|
||||
fn hex_to_bytes(hex: &str) -> Vec<u8> {
|
||||
(0..hex.len())
|
||||
.step_by(2)
|
||||
.map(|i| u8::from_str_radix(&hex[i..i + 2], 16).unwrap())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn bytes_to_hex(bytes: &[u8]) -> String {
|
||||
bytes.iter().map(|b| format!("{b:02x}")).collect()
|
||||
}
|
||||
|
||||
fn assert_golden(label: &str, sketch: &DDSketch, golden_hex: &str) {
|
||||
let bytes = sketch.to_java_bytes();
|
||||
let expected = hex_to_bytes(golden_hex);
|
||||
assert_eq!(
|
||||
bytes,
|
||||
expected,
|
||||
"Rust encoding doesn't match Java golden bytes for {}.\nRust: {}\nJava: {}",
|
||||
label,
|
||||
bytes_to_hex(&bytes),
|
||||
golden_hex,
|
||||
);
|
||||
}
|
||||
|
||||
fn assert_quantiles_match(a: &DDSketch, b: &DDSketch, quantiles: &[f64]) {
|
||||
for &q in quantiles {
|
||||
let va = a.quantile(q).unwrap().unwrap();
|
||||
let vb = b.quantile(q).unwrap().unwrap();
|
||||
assert!(
|
||||
(va - vb).abs() / va.abs().max(1e-15) < 1e-12,
|
||||
"quantile({}) mismatch: {} vs {}",
|
||||
q,
|
||||
va,
|
||||
vb,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cross_language_simple() {
|
||||
let mut sketch = DDSketch::new(Config::defaults());
|
||||
for v in [1.0, 2.0, 3.0, 4.0, 5.0] {
|
||||
sketch.add(v);
|
||||
}
|
||||
assert_golden("SIMPLE", &sketch, GOLDEN_SIMPLE);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cross_language_single() {
|
||||
let mut sketch = DDSketch::new(Config::defaults());
|
||||
sketch.add(42.0);
|
||||
assert_golden("SINGLE", &sketch, GOLDEN_SINGLE);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cross_language_negative() {
|
||||
let mut sketch = DDSketch::new(Config::defaults());
|
||||
for v in [-3.0, -1.0, 2.0, 5.0] {
|
||||
sketch.add(v);
|
||||
}
|
||||
assert_golden("NEGATIVE", &sketch, GOLDEN_NEGATIVE);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cross_language_zero() {
|
||||
let mut sketch = DDSketch::new(Config::defaults());
|
||||
for v in [0.0, 1.0, 2.0] {
|
||||
sketch.add(v);
|
||||
}
|
||||
assert_golden("ZERO", &sketch, GOLDEN_ZERO);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cross_language_empty() {
|
||||
let sketch = DDSketch::new(Config::defaults());
|
||||
assert_golden("EMPTY", &sketch, GOLDEN_EMPTY);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cross_language_many() {
|
||||
let mut sketch = DDSketch::new(Config::defaults());
|
||||
for i in 1..=100 {
|
||||
sketch.add(i as f64);
|
||||
}
|
||||
assert_golden("MANY", &sketch, GOLDEN_MANY);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_java_golden_bytes() {
|
||||
for (name, hex) in [
|
||||
("SIMPLE", GOLDEN_SIMPLE),
|
||||
("SINGLE", GOLDEN_SINGLE),
|
||||
("NEGATIVE", GOLDEN_NEGATIVE),
|
||||
("ZERO", GOLDEN_ZERO),
|
||||
("EMPTY", GOLDEN_EMPTY),
|
||||
("MANY", GOLDEN_MANY),
|
||||
] {
|
||||
let bytes = hex_to_bytes(hex);
|
||||
let result = DDSketch::from_java_bytes(&bytes);
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"failed to decode {}: {:?}",
|
||||
name,
|
||||
result.err()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_decode_many_values() {
|
||||
let mut sketch = DDSketch::new(Config::defaults());
|
||||
for i in 1..=100 {
|
||||
sketch.add(i as f64);
|
||||
}
|
||||
|
||||
let bytes = sketch.to_java_bytes();
|
||||
let decoded = DDSketch::from_java_bytes(&bytes).unwrap();
|
||||
|
||||
assert_eq!(decoded.count(), 100);
|
||||
assert_eq!(decoded.min(), Some(1.0));
|
||||
assert_eq!(decoded.max(), Some(100.0));
|
||||
assert_eq!(decoded.sum(), Some(5050.0));
|
||||
|
||||
let alpha = 0.01;
|
||||
let orig_p95 = sketch.quantile(0.95).unwrap().unwrap();
|
||||
let dec_p95 = decoded.quantile(0.95).unwrap().unwrap();
|
||||
assert!(
|
||||
(orig_p95 - dec_p95).abs() / orig_p95 < alpha,
|
||||
"p95 mismatch: {} vs {}",
|
||||
orig_p95,
|
||||
dec_p95,
|
||||
);
|
||||
}
|
||||
}
|
||||
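The golden-byte tests above drive encoding through hex strings. As a self-contained sketch of that hex round-trip (the helper names mirror the test helpers above; nothing beyond the standard library is assumed):

```rust
// Hex string -> bytes -> hex string must be the identity for valid input.
fn hex_to_bytes(hex: &str) -> Vec<u8> {
    (0..hex.len())
        .step_by(2)
        .map(|i| u8::from_str_radix(&hex[i..i + 2], 16).unwrap())
        .collect()
}

fn bytes_to_hex(bytes: &[u8]) -> String {
    bytes.iter().map(|b| format!("{b:02x}")).collect()
}

fn main() {
    // First four bytes of one of the golden constants.
    let golden = "a0028800";
    let bytes = hex_to_bytes(golden);
    assert_eq!(bytes, vec![0xa0, 0x02, 0x88, 0x00]);
    assert_eq!(bytes_to_hex(&bytes), golden);
}
```

This is why the golden constants can be compared byte-for-byte: a mismatch reported by `assert_golden` prints both sides back through `bytes_to_hex` for easy diffing.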
@@ -1,52 +0,0 @@
//! This crate provides a direct port of the [Golang](https://github.com/DataDog/sketches-go)
//! [DDSketch](https://arxiv.org/pdf/1908.10693.pdf) implementation to Rust. All efforts
//! have been made to keep this as close to the original implementation as possible, with a few
//! tweaks to get closer to idiomatic Rust.
//!
//! # Usage
//!
//! Add multiple samples to a DDSketch and invoke the `quantile` method to pull any quantile from
//! *0.0* to *1.0*.
//!
//! ```rust
//! use sketches_ddsketch::{Config, DDSketch};
//!
//! let c = Config::defaults();
//! let mut d = DDSketch::new(c);
//!
//! d.add(1.0);
//! d.add(1.0);
//! d.add(1.0);
//!
//! let q = d.quantile(0.50).unwrap();
//!
//! assert!(q < Some(1.02));
//! assert!(q > Some(0.98));
//! ```
//!
//! Sketches can also be merged.
//!
//! ```rust
//! use sketches_ddsketch::{Config, DDSketch};
//!
//! let c = Config::defaults();
//! let mut d1 = DDSketch::new(c);
//! let mut d2 = DDSketch::new(c);
//!
//! d1.add(1.0);
//! d2.add(2.0);
//! d2.add(2.0);
//!
//! d1.merge(&d2);
//!
//! assert_eq!(d1.count(), 3);
//! ```

pub use self::config::Config;
pub use self::ddsketch::{DDSketch, DDSketchError};
pub use self::encoding::DecodeError;

mod config;
mod ddsketch;
pub mod encoding;
mod store;
@@ -1,252 +0,0 @@
#[cfg(feature = "use_serde")]
use serde::{Deserialize, Serialize};

const CHUNK_SIZE: i32 = 128;

// Divide the `dividend` by the `divisor`, rounding towards positive infinity.
//
// Similar to the nightly-only `std::i32::div_ceil`.
fn div_ceil(dividend: i32, divisor: i32) -> i32 {
    (dividend + divisor - 1) / divisor
}

/// CollapsingLowestDenseStore
#[derive(Clone, Debug)]
#[cfg_attr(feature = "use_serde", derive(Serialize, Deserialize))]
pub struct Store {
    pub(crate) bins: Vec<u64>,
    pub(crate) count: u64,
    pub(crate) min_key: i32,
    pub(crate) max_key: i32,
    pub(crate) offset: i32,
    pub(crate) bin_limit: usize,
    is_collapsed: bool,
}

impl Store {
    pub fn new(bin_limit: usize) -> Self {
        Store {
            bins: Vec::new(),
            count: 0,
            min_key: i32::MAX,
            max_key: i32::MIN,
            offset: 0,
            bin_limit,
            is_collapsed: false,
        }
    }

    /// Return the number of bins.
    pub fn length(&self) -> i32 {
        self.bins.len() as i32
    }

    pub fn is_empty(&self) -> bool {
        self.bins.is_empty()
    }

    pub fn add(&mut self, key: i32) {
        let idx = self.get_index(key);
        self.bins[idx] += 1;
        self.count += 1;
    }

    /// See Java: https://github.com/DataDog/sketches-java/blob/master/src/main/java/com/datadoghq/sketch/ddsketch/store/DenseStore.java (add(int index, double count) method)
    pub(crate) fn add_count(&mut self, key: i32, count: u64) {
        let idx = self.get_index(key);
        self.bins[idx] += count;
        self.count += count;
    }

    fn get_index(&mut self, key: i32) -> usize {
        if key < self.min_key {
            if self.is_collapsed {
                return 0;
            }

            self.extend_range(key, None);
            if self.is_collapsed {
                return 0;
            }
        } else if key > self.max_key {
            self.extend_range(key, None);
        }

        (key - self.offset) as usize
    }

    fn extend_range(&mut self, key: i32, second_key: Option<i32>) {
        let second_key = second_key.unwrap_or(key);
        let new_min_key = i32::min(key, i32::min(second_key, self.min_key));
        let new_max_key = i32::max(key, i32::max(second_key, self.max_key));

        if self.is_empty() {
            let new_len = self.get_new_length(new_min_key, new_max_key);
            self.bins.resize(new_len, 0);
            self.offset = new_min_key;
            self.adjust(new_min_key, new_max_key);
        } else if new_min_key >= self.min_key && new_max_key < self.offset + self.length() {
            self.min_key = new_min_key;
            self.max_key = new_max_key;
        } else {
            // Grow bins
            let new_length = self.get_new_length(new_min_key, new_max_key);
            if new_length > self.length() as usize {
                self.bins.resize(new_length, 0);
            }
            self.adjust(new_min_key, new_max_key);
        }
    }

    fn get_new_length(&self, new_min_key: i32, new_max_key: i32) -> usize {
        let desired_length = new_max_key - new_min_key + 1;
        usize::min(
            (CHUNK_SIZE * div_ceil(desired_length, CHUNK_SIZE)) as usize,
            self.bin_limit,
        )
    }

    fn adjust(&mut self, new_min_key: i32, new_max_key: i32) {
        if new_max_key - new_min_key + 1 > self.length() {
            let new_min_key = new_max_key - self.length() + 1;

            if new_min_key >= self.max_key {
                // Put everything in the first bin.
                self.offset = new_min_key;
                self.min_key = new_min_key;
                self.bins.fill(0);
                self.bins[0] = self.count;
            } else {
                let shift = self.offset - new_min_key;
                if shift < 0 {
                    let collapse_start_index = (self.min_key - self.offset) as usize;
                    let collapse_end_index = (new_min_key - self.offset) as usize;
                    let collapsed_count: u64 = self.bins[collapse_start_index..collapse_end_index]
                        .iter()
                        .sum();
                    let zero_len = (new_min_key - self.min_key) as usize;
                    self.bins.splice(
                        collapse_start_index..collapse_end_index,
                        std::iter::repeat_n(0, zero_len),
                    );
                    self.bins[collapse_end_index] += collapsed_count;
                }
                self.min_key = new_min_key;
                self.shift_bins(shift);
            }

            self.max_key = new_max_key;
            self.is_collapsed = true;
        } else {
            self.center_bins(new_min_key, new_max_key);
            self.min_key = new_min_key;
            self.max_key = new_max_key;
        }
    }

    fn shift_bins(&mut self, shift: i32) {
        if shift > 0 {
            let shift = shift as usize;
            self.bins.rotate_right(shift);
            for idx in 0..shift {
                self.bins[idx] = 0;
            }
        } else {
            let shift = shift.unsigned_abs() as usize;
            for idx in 0..shift {
                self.bins[idx] = 0;
            }
            self.bins.rotate_left(shift);
        }

        self.offset -= shift;
    }

    fn center_bins(&mut self, new_min_key: i32, new_max_key: i32) {
        let middle_key = new_min_key + (new_max_key - new_min_key + 1) / 2;
        let shift = self.offset + self.length() / 2 - middle_key;
        self.shift_bins(shift)
    }

    pub fn key_at_rank(&self, rank: u64) -> i32 {
        let mut n = 0;
        for (i, bin) in self.bins.iter().enumerate() {
            n += *bin;
            if n > rank {
                return i as i32 + self.offset;
            }
        }

        self.max_key
    }

    pub fn count(&self) -> u64 {
        self.count
    }

    pub fn merge(&mut self, other: &Store) {
        if other.count == 0 {
            return;
        }

        if self.count == 0 {
            self.copy(other);
            return;
        }

        if other.min_key < self.min_key || other.max_key > self.max_key {
            self.extend_range(other.min_key, Some(other.max_key));
        }

        // The collapse indices are in `other`'s coordinate space
        // (key - other.offset), so the collapsed range is read from
        // `other.bins`, not `self.bins`.
        let collapse_start_index = other.min_key - other.offset;
        let mut collapse_end_index = i32::min(self.min_key, other.max_key + 1) - other.offset;
        if collapse_end_index > collapse_start_index {
            let collapsed_count: u64 = other.bins
                [collapse_start_index as usize..collapse_end_index as usize]
                .iter()
                .sum();
            self.bins[0] += collapsed_count;
        } else {
            collapse_end_index = collapse_start_index;
        }

        for key in (collapse_end_index + other.offset)..(other.max_key + 1) {
            self.bins[(key - self.offset) as usize] += other.bins[(key - other.offset) as usize]
        }

        self.count += other.count;
    }

    fn copy(&mut self, o: &Store) {
        self.bins = o.bins.clone();
        self.count = o.count;
        self.min_key = o.min_key;
        self.max_key = o.max_key;
        self.offset = o.offset;
        self.bin_limit = o.bin_limit;
        self.is_collapsed = o.is_collapsed;
    }
}

#[cfg(test)]
mod tests {
    use crate::store::Store;

    #[test]
    fn test_simple_store() {
        let mut s = Store::new(2048);

        for i in 0..2048 {
            s.add(i);
        }
    }

    #[test]
    fn test_simple_store_rev() {
        let mut s = Store::new(2048);

        for i in (0..2048).rev() {
            s.add(i);
        }
    }
}
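The store grows its bin array in `CHUNK_SIZE` multiples via `div_ceil`, capped at `bin_limit`. A self-contained sketch of that sizing rule (the constant and arithmetic mirror the file above; `new_length` is a hypothetical free-function rendering of the method):

```rust
const CHUNK_SIZE: i32 = 128;

// Ceiling division for positive inputs, as in the store's helper.
fn div_ceil(dividend: i32, divisor: i32) -> i32 {
    (dividend + divisor - 1) / divisor
}

// New bin-array length: round the desired key span up to a whole number
// of chunks, but never exceed the configured bin limit.
fn new_length(desired_length: i32, bin_limit: usize) -> usize {
    usize::min(
        (CHUNK_SIZE * div_ceil(desired_length, CHUNK_SIZE)) as usize,
        bin_limit,
    )
}

fn main() {
    assert_eq!(div_ceil(7, 2), 4); // 3.5 rounds up to 4
    assert_eq!(new_length(1, 2048), 128); // one key still allocates a full chunk
    assert_eq!(new_length(129, 2048), 256); // spills into a second chunk
    assert_eq!(new_length(10_000, 2048), 2048); // capped by bin_limit -> collapsing kicks in
}
```

Hitting the `bin_limit` cap is what triggers the collapsing behavior tested by `test_simple_store` above: further low keys fold into bin 0.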
@@ -1,88 +0,0 @@
use std::cmp::Ordering;
use std::f64::NAN;

pub struct Dataset {
    values: Vec<f64>,
    sum: f64,
    sorted: bool,
}

fn cmp_f64(a: &f64, b: &f64) -> Ordering {
    assert!(!a.is_nan() && !b.is_nan());

    if a < b {
        Ordering::Less
    } else if a > b {
        Ordering::Greater
    } else {
        Ordering::Equal
    }
}

impl Dataset {
    pub fn new() -> Self {
        Dataset {
            values: Vec::new(),
            sum: 0.0,
            sorted: false,
        }
    }

    pub fn add(&mut self, value: f64) {
        self.values.push(value);
        self.sum += value;
        self.sorted = false;
    }

    // pub fn quantile(&mut self, q: f64) -> f64 {
    //     self.lower_quantile(q)
    // }

    pub fn lower_quantile(&mut self, q: f64) -> f64 {
        if q < 0.0 || q > 1.0 || self.values.is_empty() {
            return NAN;
        }

        self.sort();
        let rank = q * (self.values.len() - 1) as f64;

        self.values[rank.floor() as usize]
    }

    pub fn upper_quantile(&mut self, q: f64) -> f64 {
        if q < 0.0 || q > 1.0 || self.values.is_empty() {
            return NAN;
        }

        self.sort();
        let rank = q * (self.values.len() - 1) as f64;
        self.values[rank.ceil() as usize]
    }

    pub fn min(&mut self) -> f64 {
        self.sort();
        self.values[0]
    }

    pub fn max(&mut self) -> f64 {
        self.sort();
        self.values[self.values.len() - 1]
    }

    pub fn sum(&self) -> f64 {
        self.sum
    }

    pub fn count(&self) -> usize {
        self.values.len()
    }

    fn sort(&mut self) {
        if self.sorted {
            return;
        }

        self.values.sort_by(cmp_f64);
        self.sorted = true;
    }
}
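The reference `Dataset` computes exact quantiles by rank on the sorted values: `rank = q * (n - 1)`, with the lower quantile taking `floor(rank)` and the upper taking `ceil(rank)`. A standalone sketch of just that rank rule (free functions over a pre-sorted slice, for illustration only):

```rust
// Rank-based exact quantiles, as in the Dataset helper above.
// The caller is responsible for passing a sorted, non-empty slice.
fn lower_quantile(sorted: &[f64], q: f64) -> f64 {
    let rank = q * (sorted.len() - 1) as f64;
    sorted[rank.floor() as usize]
}

fn upper_quantile(sorted: &[f64], q: f64) -> f64 {
    let rank = q * (sorted.len() - 1) as f64;
    sorted[rank.ceil() as usize]
}

fn main() {
    let v = [1.0, 2.0, 3.0, 4.0];
    // q = 0.5 on 4 values: rank = 1.5, so lower = v[1], upper = v[2].
    assert_eq!(lower_quantile(&v, 0.5), 2.0);
    assert_eq!(upper_quantile(&v, 0.5), 3.0);
}
```

The integration tests below bracket each DDSketch quantile between these two exact values, widened by the relative-accuracy bound `TEST_ALPHA`.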
@@ -1,100 +0,0 @@
extern crate rand;
extern crate rand_distr;

use rand::prelude::*;

pub trait Generator {
    fn generate(&mut self) -> f64;
}

// Constant generator
//
pub struct Constant {
    value: f64,
}
impl Constant {
    pub fn new(value: f64) -> Self {
        Constant { value }
    }
}
impl Generator for Constant {
    fn generate(&mut self) -> f64 {
        self.value
    }
}

// Linear generator
//
pub struct Linear {
    current_value: f64,
    step: f64,
}
impl Linear {
    pub fn new(start_value: f64, step: f64) -> Self {
        Linear {
            current_value: start_value,
            step,
        }
    }
}
impl Generator for Linear {
    fn generate(&mut self) -> f64 {
        let value = self.current_value;
        self.current_value += self.step;
        value
    }
}

// Normal distribution generator
//
pub struct Normal {
    distr: rand_distr::Normal<f64>,
}
impl Normal {
    pub fn new(mean: f64, stddev: f64) -> Self {
        Normal {
            distr: rand_distr::Normal::new(mean, stddev).unwrap(),
        }
    }
}
impl Generator for Normal {
    fn generate(&mut self) -> f64 {
        self.distr.sample(&mut rand::thread_rng())
    }
}

// Lognormal distribution generator
//
pub struct Lognormal {
    distr: rand_distr::LogNormal<f64>,
}
impl Lognormal {
    pub fn new(mean: f64, stddev: f64) -> Self {
        Lognormal {
            distr: rand_distr::LogNormal::new(mean, stddev).unwrap(),
        }
    }
}
impl Generator for Lognormal {
    fn generate(&mut self) -> f64 {
        self.distr.sample(&mut rand::thread_rng())
    }
}

// Exponential distribution generator
//
pub struct Exponential {
    distr: rand_distr::Exp<f64>,
}
impl Exponential {
    pub fn new(lambda: f64) -> Self {
        Exponential {
            distr: rand_distr::Exp::new(lambda).unwrap(),
        }
    }
}
impl Generator for Exponential {
    fn generate(&mut self) -> f64 {
        self.distr.sample(&mut rand::thread_rng())
    }
}
@@ -1,2 +0,0 @@
pub mod dataset;
pub mod generator;
@@ -1,316 +0,0 @@
mod common;
use std::time::Instant;

use common::dataset::Dataset;
use common::generator;
use common::generator::Generator;
use sketches_ddsketch::{Config, DDSketch};

const TEST_ALPHA: f64 = 0.01;
const TEST_MAX_BINS: u32 = 1024;
const TEST_MIN_VALUE: f64 = 1.0e-9;

// Used for float equality
const TEST_ERROR_THRESH: f64 = 1.0e-9;

const TEST_SIZES: [usize; 5] = [3, 5, 10, 100, 1000];
const TEST_QUANTILES: [f64; 10] = [0.0, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99, 0.999, 1.0];

#[test]
fn test_constant() {
    evaluate_sketches(|| Box::new(generator::Constant::new(42.0)));
}

#[test]
fn test_linear() {
    evaluate_sketches(|| Box::new(generator::Linear::new(0.0, 1.0)));
}

#[test]
fn test_normal() {
    evaluate_sketches(|| Box::new(generator::Normal::new(35.0, 1.0)));
}

#[test]
fn test_lognormal() {
    evaluate_sketches(|| Box::new(generator::Lognormal::new(0.0, 2.0)));
}

#[test]
fn test_exponential() {
    evaluate_sketches(|| Box::new(generator::Exponential::new(2.0)));
}

fn evaluate_test_sizes(f: impl Fn(usize)) {
    for sz in &TEST_SIZES {
        f(*sz);
    }
}

fn evaluate_sketches(gen_factory: impl Fn() -> Box<dyn generator::Generator>) {
    evaluate_test_sizes(|sz: usize| {
        let mut generator = gen_factory();
        evaluate_sketch(sz, &mut generator);
    });
}

fn new_config() -> Config {
    Config::new(TEST_ALPHA, TEST_MAX_BINS, TEST_MIN_VALUE)
}

fn assert_float_eq(a: f64, b: f64) {
    assert!((a - b).abs() < TEST_ERROR_THRESH, "{} != {}", a, b);
}

fn evaluate_sketch(count: usize, generator: &mut Box<dyn generator::Generator>) {
    let c = new_config();
    let mut g = DDSketch::new(c);

    let mut d = Dataset::new();

    for _i in 0..count {
        let value = generator.generate();

        g.add(value);
        d.add(value);
    }

    compare_sketches(&mut d, &g);
}

fn compare_sketches(d: &mut Dataset, g: &DDSketch) {
    for q in &TEST_QUANTILES {
        let lower = d.lower_quantile(*q);
        let upper = d.upper_quantile(*q);

        let min_expected = if lower < 0.0 {
            lower * (1.0 + TEST_ALPHA)
        } else {
            lower * (1.0 - TEST_ALPHA)
        };

        let max_expected = if upper > 0.0 {
            upper * (1.0 + TEST_ALPHA)
        } else {
            upper * (1.0 - TEST_ALPHA)
        };

        let quantile = g.quantile(*q).unwrap().unwrap();

        assert!(
            min_expected <= quantile,
            "Lower than min, quantile: {}, wanted {} <= {}",
            *q, min_expected, quantile
        );
        assert!(
            quantile <= max_expected,
            "Higher than max, quantile: {}, wanted {} <= {}",
            *q, quantile, max_expected
        );

        // Verify that repeated calls return the same result (quantile takes &self,
        // so it cannot mutate the sketch).
        let quantile2 = g.quantile(*q).unwrap().unwrap();
        assert_eq!(quantile, quantile2);
    }

    assert_eq!(g.min().unwrap(), d.min());
    assert_eq!(g.max().unwrap(), d.max());
    assert_float_eq(g.sum().unwrap(), d.sum());
    assert_eq!(g.count(), d.count());
}

#[test]
fn test_merge_normal() {
    evaluate_test_sizes(|sz: usize| {
        let c = new_config();
        let mut d = Dataset::new();
        let mut g1 = DDSketch::new(c);

        let mut generator1 = generator::Normal::new(35.0, 1.0);
        for _ in (0..sz).step_by(3) {
            let value = generator1.generate();
            g1.add(value);
            d.add(value);
        }
        let mut g2 = DDSketch::new(c);
        let mut generator2 = generator::Normal::new(50.0, 2.0);
        for _ in (1..sz).step_by(3) {
            let value = generator2.generate();
            g2.add(value);
            d.add(value);
        }
        g1.merge(&g2).unwrap();

        let mut g3 = DDSketch::new(c);
        let mut generator3 = generator::Normal::new(40.0, 0.5);
        for _ in (2..sz).step_by(3) {
            let value = generator3.generate();
            g3.add(value);
            d.add(value);
        }
        g1.merge(&g3).unwrap();

        compare_sketches(&mut d, &g1);
    });
}

#[test]
fn test_merge_empty() {
    evaluate_test_sizes(|sz: usize| {
        let c = new_config();

        let mut d = Dataset::new();

        let mut g1 = DDSketch::new(c);
        let mut g2 = DDSketch::new(c);
        let mut generator = generator::Exponential::new(5.0);

        for _ in 0..sz {
            let value = generator.generate();
            g2.add(value);
            d.add(value);
        }
        g1.merge(&g2).unwrap();
        compare_sketches(&mut d, &g1);

        let g3 = DDSketch::new(c);
        g2.merge(&g3).unwrap();
        compare_sketches(&mut d, &g2);
    });
}

#[test]
fn test_merge_mixed() {
    evaluate_test_sizes(|sz: usize| {
        let c = new_config();
        let mut d = Dataset::new();
        let mut g1 = DDSketch::new(c);

        let mut generator1 = generator::Normal::new(100.0, 1.0);
        for _ in (0..sz).step_by(3) {
            let value = generator1.generate();
            g1.add(value);
            d.add(value);
        }

        let mut g2 = DDSketch::new(c);
        let mut generator2 = generator::Exponential::new(5.0);
        for _ in (1..sz).step_by(3) {
            let value = generator2.generate();
            g2.add(value);
            d.add(value);
        }
        g1.merge(&g2).unwrap();

        let mut g3 = DDSketch::new(c);
        let mut generator3 = generator::Exponential::new(0.1);
        for _ in (2..sz).step_by(3) {
            let value = generator3.generate();
            g3.add(value);
            d.add(value);
        }
        g1.merge(&g3).unwrap();

        compare_sketches(&mut d, &g1);
    })
}

#[test]
fn test_merge_incompatible() {
    let c1 = Config::new(TEST_ALPHA, TEST_MAX_BINS, TEST_MIN_VALUE);
    let c2 = Config::new(TEST_ALPHA * 2.0, TEST_MAX_BINS, TEST_MIN_VALUE);

    let mut d1 = DDSketch::new(c1);
    let d2 = DDSketch::new(c2);

    assert!(d1.merge(&d2).is_err());

    let c3 = Config::new(TEST_ALPHA, TEST_MAX_BINS, TEST_MIN_VALUE * 10.0);
    let d3 = DDSketch::new(c3);

    assert!(d1.merge(&d3).is_err());

    let c4 = Config::new(TEST_ALPHA, TEST_MAX_BINS * 2, TEST_MIN_VALUE);
    let d4 = DDSketch::new(c4);

    assert!(d1.merge(&d4).is_err());

    // Merging with an identically configured sketch should work.
    let c5 = Config::new(TEST_ALPHA, TEST_MAX_BINS, TEST_MIN_VALUE);
    let dsame = DDSketch::new(c5);
    assert!(d1.merge(&dsame).is_ok());
}

#[test]
#[ignore]
fn test_performance_insert() {
    let c = Config::defaults();
    let mut g = DDSketch::new(c);
    let mut gen = generator::Normal::new(1000.0, 500.0);
    let count = 300_000_000;

    let mut values = Vec::new();
    for _ in 0..count {
        values.push(gen.generate());
    }

    let start_time = Instant::now();
    for value in values {
        g.add(value);
    }

    // This simply ensures the operations don't get optimized out as ignored
    let quantile = g.quantile(0.50).unwrap().unwrap();

    let elapsed = start_time.elapsed().as_micros() as f64;
    let elapsed = elapsed / 1_000_000.0;

    println!(
        "RESULT: p50={:.2} => Added {}M samples in {:2} secs ({:.2}M samples/sec)",
        quantile,
        count / 1_000_000,
        elapsed,
        (count as f64) / 1_000_000.0 / elapsed
    );
}

#[test]
#[ignore]
fn test_performance_merge() {
    let c = Config::defaults();
    let mut gen = generator::Normal::new(1000.0, 500.0);
    let merge_count = 500_000;
    let sample_count = 1_000;
    let mut sketches = Vec::new();

    for _ in 0..merge_count {
        let mut d = DDSketch::new(c);
        for _ in 0..sample_count {
            d.add(gen.generate());
        }
        sketches.push(d);
    }

    let mut base = DDSketch::new(c);

    let start_time = Instant::now();
    for sketch in &sketches {
        base.merge(sketch).unwrap();
    }

    let elapsed = start_time.elapsed().as_micros() as f64;
    let elapsed = elapsed / 1_000_000.0;

    println!(
        "RESULT: Merged {} sketches in {:2} secs ({:.2} merges/sec)",
        merge_count,
        elapsed,
        (merge_count as f64) / elapsed
    );
}
@@ -95,21 +95,11 @@ pub(crate) fn get_all_ff_reader_or_empty(
     allowed_column_types: Option<&[ColumnType]>,
     fallback_type: ColumnType,
 ) -> crate::Result<Vec<(columnar::Column<u64>, ColumnType)>> {
-    let mut ff_field_with_type = get_all_ff_readers(reader, field_name, allowed_column_types)?;
+    let ff_fields = reader.fast_fields();
+    let mut ff_field_with_type =
+        ff_fields.u64_lenient_for_type_all(allowed_column_types, field_name)?;
     if ff_field_with_type.is_empty() {
         ff_field_with_type.push((Column::build_empty_column(reader.num_docs()), fallback_type));
     }
     Ok(ff_field_with_type)
 }
-
-/// Get all fast field reader.
-pub(crate) fn get_all_ff_readers(
-    reader: &SegmentReader,
-    field_name: &str,
-    allowed_column_types: Option<&[ColumnType]>,
-) -> crate::Result<Vec<(columnar::Column<u64>, ColumnType)>> {
-    let ff_fields = reader.fast_fields();
-    let ff_field_with_type =
-        ff_fields.u64_lenient_for_type_all(allowed_column_types, field_name)?;
-    Ok(ff_field_with_type)
-}
@@ -1,4 +1,4 @@
-use columnar::{Column, ColumnBlockAccessor, ColumnType, StrColumn};
+use columnar::{Column, ColumnType, StrColumn};
 use common::BitSet;
 use rustc_hash::FxHashSet;
 use serde::Serialize;
@@ -9,18 +9,17 @@ use crate::aggregation::accessor_helpers::{
     get_numeric_or_date_column_types,
 };
 use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations};
-pub use crate::aggregation::bucket::{CompositeAggReqData, CompositeSourceAccessors};
 use crate::aggregation::bucket::{
-    build_segment_filter_collector, build_segment_range_collector, CompositeAggregation,
     FilterAggReqData, HistogramAggReqData, HistogramBounds, IncludeExcludeParam,
-    MissingTermAggReqData, RangeAggReqData, SegmentCompositeCollector, SegmentHistogramCollector,
-    TermMissingAgg, TermsAggReqData, TermsAggregation, TermsAggregationInternal,
+    MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector, SegmentHistogramCollector,
+    SegmentRangeCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
+    TermsAggregationInternal,
 };
 use crate::aggregation::metric::{
-    build_segment_stats_collector, AverageAggregation, CardinalityAggReqData,
-    CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation, MaxAggregation,
-    MetricAggReqData, MinAggregation, SegmentCardinalityCollector, SegmentExtendedStatsCollector,
-    SegmentPercentilesCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData,
+    AverageAggregation, CardinalityAggReqData, CardinalityAggregationReq, CountAggregation,
+    ExtendedStatsAggregation, MaxAggregation, MetricAggReqData, MinAggregation,
+    SegmentCardinalityCollector, SegmentExtendedStatsCollector, SegmentPercentilesCollector,
+    SegmentStatsCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData,
     TopHitsSegmentCollector,
 };
 use crate::aggregation::segment_agg_result::{
@@ -36,7 +35,6 @@ pub struct AggregationsSegmentCtx {
|
||||
/// Request data for each aggregation type.
|
||||
pub per_request: PerRequestAggSegCtx,
|
||||
pub context: AggContextParams,
|
||||
pub column_block_accessor: ColumnBlockAccessor<u64>,
|
||||
}
|
||||
|
||||
impl AggregationsSegmentCtx {
@@ -74,12 +72,6 @@ impl AggregationsSegmentCtx {
        self.per_request.filter_req_data.push(Some(Box::new(data)));
        self.per_request.filter_req_data.len() - 1
    }
    pub(crate) fn push_composite_req_data(&mut self, data: CompositeAggReqData) -> usize {
        self.per_request
            .composite_req_data
            .push(Some(Box::new(data)));
        self.per_request.composite_req_data.len() - 1
    }

    #[inline]
    pub(crate) fn get_term_req_data(&self, idx: usize) -> &TermsAggReqData {
@@ -116,19 +108,20 @@ impl AggregationsSegmentCtx {
            .expect("range_req_data slot is empty (taken)")
    }
    #[inline]
    pub(crate) fn get_composite_req_data(&self, idx: usize) -> &CompositeAggReqData {
        self.per_request.composite_req_data[idx]
    pub(crate) fn get_filter_req_data(&self, idx: usize) -> &FilterAggReqData {
        self.per_request.filter_req_data[idx]
            .as_deref()
            .expect("composite_req_data slot is empty (taken)")
            .expect("filter_req_data slot is empty (taken)")
    }

    // ---------- mutable getters ----------

    #[inline]
    pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData {
        &mut self.per_request.stats_metric_req_data[idx]
    pub(crate) fn get_term_req_data_mut(&mut self, idx: usize) -> &mut TermsAggReqData {
        self.per_request.term_req_data[idx]
            .as_deref_mut()
            .expect("term_req_data slot is empty (taken)")
    }

    #[inline]
    pub(crate) fn get_cardinality_req_data_mut(
        &mut self,
@@ -136,21 +129,33 @@ impl AggregationsSegmentCtx {
    ) -> &mut CardinalityAggReqData {
        &mut self.per_request.cardinality_req_data[idx]
    }

    #[inline]
    pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData {
        &mut self.per_request.stats_metric_req_data[idx]
    }
    #[inline]
    pub(crate) fn get_histogram_req_data_mut(&mut self, idx: usize) -> &mut HistogramAggReqData {
        self.per_request.histogram_req_data[idx]
            .as_deref_mut()
            .expect("histogram_req_data slot is empty (taken)")
    }

    // ---------- take / put (terms, histogram, range) ----------

    /// Move out the boxed Terms request at `idx`, leaving `None`.
    #[inline]
    pub(crate) fn get_composite_req_data_mut(&mut self, idx: usize) -> &mut CompositeAggReqData {
        self.per_request.composite_req_data[idx]
            .as_deref_mut()
            .expect("composite_req_data slot is empty (taken)")
    pub(crate) fn take_term_req_data(&mut self, idx: usize) -> Box<TermsAggReqData> {
        self.per_request.term_req_data[idx]
            .take()
            .expect("term_req_data slot is empty (taken)")
    }

    // ---------- take / put (terms, histogram, range, composite) ----------
    /// Put back a Terms request into an empty slot at `idx`.
    #[inline]
    pub(crate) fn put_back_term_req_data(&mut self, idx: usize, value: Box<TermsAggReqData>) {
        debug_assert!(self.per_request.term_req_data[idx].is_none());
        self.per_request.term_req_data[idx] = Some(value);
    }

    /// Move out the boxed Histogram request at `idx`, leaving `None`.
    #[inline]
@@ -200,25 +205,6 @@ impl AggregationsSegmentCtx {
        debug_assert!(self.per_request.filter_req_data[idx].is_none());
        self.per_request.filter_req_data[idx] = Some(value);
    }

    /// Move out the Composite request at `idx`.
    #[inline]
    pub(crate) fn take_composite_req_data(&mut self, idx: usize) -> Box<CompositeAggReqData> {
        self.per_request.composite_req_data[idx]
            .take()
            .expect("composite_req_data slot is empty (taken)")
    }

    /// Put back a Composite request into an empty slot at `idx`.
    #[inline]
    pub(crate) fn put_back_composite_req_data(
        &mut self,
        idx: usize,
        value: Box<CompositeAggReqData>,
    ) {
        debug_assert!(self.per_request.composite_req_data[idx].is_none());
        self.per_request.composite_req_data[idx] = Some(value);
    }
}

/// Each type of aggregation has its own request data struct. This struct holds
@@ -238,8 +224,6 @@ pub struct PerRequestAggSegCtx {
    pub range_req_data: Vec<Option<Box<RangeAggReqData>>>,
    /// FilterAggReqData contains the request data for a filter aggregation.
    pub filter_req_data: Vec<Option<Box<FilterAggReqData>>>,
    /// CompositeAggReqData contains the request data for a composite aggregation.
    pub composite_req_data: Vec<Option<Box<CompositeAggReqData>>>,
    /// Shared by avg, min, max, sum, stats, extended_stats, count
    pub stats_metric_req_data: Vec<MetricAggReqData>,
    /// CardinalityAggReqData contains the request data for a cardinality aggregation.
@@ -295,11 +279,6 @@ impl PerRequestAggSegCtx {
            .iter()
            .map(|t| t.get_memory_consumption())
            .sum::<usize>()
            + self
                .composite_req_data
                .iter()
                .map(|t| t.as_ref().unwrap().get_memory_consumption())
                .sum::<usize>()
            + self.agg_tree.len() * std::mem::size_of::<AggRefNode>()
    }

@@ -336,17 +315,11 @@ impl PerRequestAggSegCtx {
                .expect("filter_req_data slot is empty (taken)")
                .name
                .as_str(),
            AggKind::Composite => &self.composite_req_data[idx]
                .as_deref()
                .expect("composite_req_data slot is empty (taken)")
                .name
                .as_str(),
        }
    }

    /// Convert the aggregation tree into a serializable struct representation.
    /// Each node contains: { name, kind, children }.
    #[allow(dead_code)]
    pub fn get_view_tree(&self) -> Vec<AggTreeViewNode> {
        fn node_to_view(node: &AggRefNode, pr: &PerRequestAggSegCtx) -> AggTreeViewNode {
            let mut children: Vec<AggTreeViewNode> =
@@ -372,19 +345,12 @@ impl PerRequestAggSegCtx {
pub(crate) fn build_segment_agg_collectors_root(
    req: &mut AggregationsSegmentCtx,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
    build_segment_agg_collectors_generic(req, &req.per_request.agg_tree.clone())
    build_segment_agg_collectors(req, &req.per_request.agg_tree.clone())
}

pub(crate) fn build_segment_agg_collectors(
    req: &mut AggregationsSegmentCtx,
    nodes: &[AggRefNode],
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
    build_segment_agg_collectors_generic(req, nodes)
}

fn build_segment_agg_collectors_generic(
    req: &mut AggregationsSegmentCtx,
    nodes: &[AggRefNode],
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
    let mut collectors = Vec::new();
    for node in nodes.iter() {
@@ -422,8 +388,6 @@ pub(crate) fn build_segment_agg_collector(
            Ok(Box::new(SegmentCardinalityCollector::from_req(
                req_data.column_type,
                node.idx_in_req_data,
                req_data.accessor.clone(),
                req_data.missing_value_for_accessor,
            )))
        }
        AggKind::StatsKind(stats_type) => {
@@ -434,21 +398,20 @@ pub(crate) fn build_segment_agg_collector(
                | StatsType::Count
                | StatsType::Max
                | StatsType::Min
                | StatsType::Stats => build_segment_stats_collector(req_data),
                StatsType::ExtendedStats(sigma) => Ok(Box::new(
                    SegmentExtendedStatsCollector::from_req(req_data, sigma),
                )),
                StatsType::Percentiles => {
                    let req_data = req.get_metric_req_data_mut(node.idx_in_req_data);
                    Ok(Box::new(
                        SegmentPercentilesCollector::from_req_and_validate(
                            req_data.field_type,
                            req_data.missing_u64,
                            req_data.accessor.clone(),
                            node.idx_in_req_data,
                        ),
                    ))
                | StatsType::Stats => Ok(Box::new(SegmentStatsCollector::from_req(
                    node.idx_in_req_data,
                ))),
                StatsType::ExtendedStats(sigma) => {
                    Ok(Box::new(SegmentExtendedStatsCollector::from_req(
                        req_data.field_type,
                        sigma,
                        node.idx_in_req_data,
                        req_data.missing,
                    )))
                }
                StatsType::Percentiles => Ok(Box::new(
                    SegmentPercentilesCollector::from_req_and_validate(node.idx_in_req_data)?,
                )),
            }
        }
        AggKind::TopHits => {
@@ -465,9 +428,10 @@ pub(crate) fn build_segment_agg_collector(
        AggKind::DateHistogram => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
            req, node,
        )?)),
        AggKind::Range => Ok(build_segment_range_collector(req, node)?),
        AggKind::Filter => build_segment_filter_collector(req, node),
        AggKind::Composite => Ok(Box::new(SegmentCompositeCollector::from_req_and_validate(
        AggKind::Range => Ok(Box::new(SegmentRangeCollector::from_req_and_validate(
            req, node,
        )?)),
        AggKind::Filter => Ok(Box::new(SegmentFilterCollector::from_req_and_validate(
            req, node,
        )?)),
    }
@@ -500,7 +464,6 @@ pub enum AggKind {
    DateHistogram,
    Range,
    Filter,
    Composite,
}

impl AggKind {
@@ -516,7 +479,6 @@ impl AggKind {
            AggKind::DateHistogram => "DateHistogram",
            AggKind::Range => "Range",
            AggKind::Filter => "Filter",
            AggKind::Composite => "Composite",
        }
    }
}
@@ -531,7 +493,6 @@ pub(crate) fn build_aggregations_data_from_req(
    let mut data = AggregationsSegmentCtx {
        per_request: Default::default(),
        context,
        column_block_accessor: ColumnBlockAccessor::default(),
    };

    for (name, agg) in aggs.iter() {
@@ -560,9 +521,9 @@ fn build_nodes(
            let idx_in_req_data = data.push_range_req_data(RangeAggReqData {
                accessor,
                field_type,
                column_block_accessor: Default::default(),
                name: agg_name.to_string(),
                req: range_req.clone(),
                is_top_level,
            });
            let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?;
            Ok(vec![AggRefNode {
@@ -580,7 +541,9 @@ fn build_nodes(
            let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData {
                accessor,
                field_type,
                column_block_accessor: Default::default(),
                name: agg_name.to_string(),
                sub_aggregation_blueprint: None,
                req: histo_req.clone(),
                is_date_histogram: false,
                bounds: HistogramBounds {
@@ -605,7 +568,9 @@ fn build_nodes(
            let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData {
                accessor,
                field_type,
                column_block_accessor: Default::default(),
                name: agg_name.to_string(),
                sub_aggregation_blueprint: None,
                req: histo_req,
                is_date_histogram: true,
                bounds: HistogramBounds {
@@ -685,6 +650,7 @@ fn build_nodes(
            let idx_in_req_data = data.push_metric_req_data(MetricAggReqData {
                accessor,
                field_type,
                column_block_accessor: Default::default(),
                name: agg_name.to_string(),
                collecting_for,
                missing: *missing,
@@ -712,6 +678,7 @@ fn build_nodes(
            let idx_in_req_data = data.push_metric_req_data(MetricAggReqData {
                accessor,
                field_type,
                column_block_accessor: Default::default(),
                name: agg_name.to_string(),
                collecting_for: StatsType::Percentiles,
                missing: percentiles_req.missing,
@@ -786,7 +753,6 @@ fn build_nodes(
                segment_reader: reader.clone(),
                evaluator,
                matching_docs_buffer,
                is_top_level,
            });
            let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?;
            Ok(vec![AggRefNode {
@@ -795,14 +761,6 @@ fn build_nodes(
                children,
            }])
        }
        AggregationVariants::Composite(composite_req) => Ok(vec![build_composite_node(
            agg_name,
            reader,
            segment_ordinal,
            data,
            &req.sub_aggregation,
            composite_req,
        )?]),
    }
}

@@ -937,7 +895,7 @@ fn build_terms_or_cardinality_nodes(
        });
    }

    // Add one node per accessor
    // Add one node per accessor to mirror previous behavior and allow per-type missing handling.
    for (accessor, column_type) in column_and_types {
        let missing_value_for_accessor = if use_special_missing_agg {
            None
@@ -968,8 +926,11 @@ fn build_terms_or_cardinality_nodes(
                column_type,
                str_dict_column: str_dict_column.clone(),
                missing_value_for_accessor,
                column_block_accessor: Default::default(),
                name: agg_name.to_string(),
                req: TermsAggregationInternal::from_req(req),
                // Will be filled later when building collectors
                sub_aggregation_blueprint: None,
                sug_aggregations: sub_aggs.clone(),
                allowed_term_ids,
                is_top_level,
@@ -982,6 +943,7 @@ fn build_terms_or_cardinality_nodes(
                column_type,
                str_dict_column: str_dict_column.clone(),
                missing_value_for_accessor,
                column_block_accessor: Default::default(),
                name: agg_name.to_string(),
                req: req.clone(),
            });
@@ -998,35 +960,6 @@ fn build_terms_or_cardinality_nodes(
    Ok(nodes)
}

fn build_composite_node(
    agg_name: &str,
    reader: &SegmentReader,
    segment_ordinal: SegmentOrdinal,
    data: &mut AggregationsSegmentCtx,
    sub_aggs: &Aggregations,
    req: &CompositeAggregation,
) -> crate::Result<AggRefNode> {
    let mut composite_accessors = Vec::with_capacity(req.sources.len());
    for source in &req.sources {
        let source_after_key_opt = req.after.get(source.name()).map(|k| &k.0);
        let source_accessor =
            CompositeSourceAccessors::build_for_source(reader, source, source_after_key_opt)?;
        composite_accessors.push(source_accessor);
    }
    let agg = CompositeAggReqData {
        name: agg_name.to_string(),
        req: req.clone(),
        composite_accessors,
    };
    let idx = data.push_composite_req_data(agg);
    let children = build_children(sub_aggs, reader, segment_ordinal, data)?;
    Ok(AggRefNode {
        kind: AggKind::Composite,
        idx_in_req_data: idx,
        children,
    })
}

/// Builds a single BitSet of allowed term ordinals for a string dictionary column according to
/// include/exclude parameters.
fn build_allowed_term_ids_for_str(

@@ -40,7 +40,6 @@ use super::metric::{
    MaxAggregation, MinAggregation, PercentilesAggregationReq, StatsAggregation, SumAggregation,
    TopHitsAggregationReq,
};
use crate::aggregation::bucket::CompositeAggregation;

/// The top-level aggregation request structure, which contains [`Aggregation`] and their user
/// defined names. It is also used in buckets aggregations to define sub-aggregations.
@@ -135,9 +134,6 @@ pub enum AggregationVariants {
    /// Filter documents into a single bucket.
    #[serde(rename = "filter")]
    Filter(FilterAggregation),
    /// Put data into multi level paginated buckets.
    #[serde(rename = "composite")]
    Composite(CompositeAggregation),

    // Metric aggregation types
    /// Computes the average of the extracted values.
@@ -184,11 +180,6 @@ impl AggregationVariants {
            AggregationVariants::Histogram(histogram) => vec![histogram.field.as_str()],
            AggregationVariants::DateHistogram(histogram) => vec![histogram.field.as_str()],
            AggregationVariants::Filter(filter) => filter.get_fast_field_names(),
            AggregationVariants::Composite(composite) => composite
                .sources
                .iter()
                .map(|source_map| source_map.field())
                .collect(),
            AggregationVariants::Average(avg) => vec![avg.field_name()],
            AggregationVariants::Count(count) => vec![count.field_name()],
            AggregationVariants::Max(max) => vec![max.field_name()],
@@ -223,12 +214,6 @@ impl AggregationVariants {
            _ => None,
        }
    }
    pub(crate) fn as_composite(&self) -> Option<&CompositeAggregation> {
        match &self {
            AggregationVariants::Composite(composite) => Some(composite),
            _ => None,
        }
    }
    pub(crate) fn as_percentile(&self) -> Option<&PercentilesAggregationReq> {
        match &self {
            AggregationVariants::Percentiles(percentile_req) => Some(percentile_req),

@@ -13,8 +13,6 @@ use super::metric::{
    ExtendedStats, PercentilesMetricResult, SingleMetricResult, Stats, TopHitsMetricResult,
};
use super::{AggregationError, Key};
use crate::aggregation::bucket::AfterKey;
use crate::aggregation::intermediate_agg_result::CompositeIntermediateKey;
use crate::TantivyError;

#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
@@ -160,16 +158,6 @@ pub enum BucketResult {
    },
    /// This is the filter result - a single bucket with sub-aggregations
    Filter(FilterBucketResult),
    /// This is the composite aggregation result
    Composite {
        /// The buckets
        ///
        /// See [`CompositeAggregation`](super::bucket::CompositeAggregation)
        buckets: Vec<CompositeBucketEntry>,
        /// The key to start after when paginating
        #[serde(skip_serializing_if = "FxHashMap::is_empty")]
        after_key: FxHashMap<String, AfterKey>,
    },
}

impl BucketResult {
@@ -191,9 +179,6 @@ impl BucketResult {
                // Only count sub-aggregation buckets
                filter_result.sub_aggregations.get_bucket_count()
            }
            BucketResult::Composite { buckets, .. } => {
                buckets.iter().map(|bucket| bucket.get_bucket_count()).sum()
            }
        }
    }
}
@@ -352,130 +337,3 @@ pub struct FilterBucketResult {
    #[serde(flatten)]
    pub sub_aggregations: AggregationResults,
}

/// The JSON mappable key to identify a composite bucket.
///
/// This is similar to `Key`, but composite keys can also be boolean and null.
///
/// Note the type information loss compared to `CompositeIntermediateKey`.
/// Pagination is performed using `AfterKey`, which encodes type information.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum CompositeKey {
    /// Boolean key
    Bool(bool),
    /// String key
    Str(String),
    /// `i64` key
    I64(i64),
    /// `u64` key
    U64(u64),
    /// `f64` key
    F64(f64),
    /// Null key
    Null,
}
impl Eq for CompositeKey {}
impl std::hash::Hash for CompositeKey {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        core::mem::discriminant(self).hash(state);
        match self {
            Self::Bool(val) => val.hash(state),
            Self::Str(text) => text.hash(state),
            Self::F64(val) => val.to_bits().hash(state),
            Self::U64(val) => val.hash(state),
            Self::I64(val) => val.hash(state),
            Self::Null => {}
        }
    }
}
impl PartialEq for CompositeKey {
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Bool(l), Self::Bool(r)) => l == r,
            (Self::Str(l), Self::Str(r)) => l == r,
            (Self::F64(l), Self::F64(r)) => l.to_bits() == r.to_bits(),
            (Self::I64(l), Self::I64(r)) => l == r,
            (Self::U64(l), Self::U64(r)) => l == r,
            (Self::Null, Self::Null) => true,
            (
                Self::Bool(_)
                | Self::Str(_)
                | Self::F64(_)
                | Self::I64(_)
                | Self::U64(_)
                | Self::Null,
                _,
            ) => false,
        }
    }
}
impl From<CompositeIntermediateKey> for CompositeKey {
    fn from(value: CompositeIntermediateKey) -> Self {
        match value {
            CompositeIntermediateKey::Str(s) => Self::Str(s),
            CompositeIntermediateKey::IpAddr(s) => {
                // Prefer to use the IPv4 representation if possible
                if let Some(ip) = s.to_ipv4_mapped() {
                    Self::Str(ip.to_string())
                } else {
                    Self::Str(s.to_string())
                }
            }
            CompositeIntermediateKey::F64(f) => Self::F64(f),
            CompositeIntermediateKey::Bool(f) => Self::Bool(f),
            CompositeIntermediateKey::U64(f) => Self::U64(f),
            CompositeIntermediateKey::I64(f) => Self::I64(f),
            CompositeIntermediateKey::DateTime(f) => Self::I64(f / 1_000_000), // Convert ns to ms
            CompositeIntermediateKey::Null => Self::Null,
        }
    }
}

/// This is the default entry for a bucket, which contains a composite key, count, and optionally
/// sub-aggregations.
/// ...
/// "my_composite": {
///   "buckets": [
///     {
///       "key": {
///         "date": 1494201600000,
///         "product": "rocky"
///       },
///       "doc_count": 5
///     },
///     {
///       "key": {
///         "date": 1494201600000,
///         "product": "balboa"
///       },
///       "doc_count": 2
///     },
///     {
///       "key": {
///         "date": 1494201700000,
///         "product": "john"
///       },
///       "doc_count": 3
///     }
///   ]
/// }
/// ...
/// }
/// ```
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CompositeBucketEntry {
    /// The identifier of the bucket.
    pub key: FxHashMap<String, CompositeKey>,
    /// Number of documents in the bucket.
    pub doc_count: u64,
    #[serde(flatten)]
    /// Sub-aggregations in this bucket.
    pub sub_aggregation: AggregationResults,
}

impl CompositeBucketEntry {
    pub(crate) fn get_bucket_count(&self) -> u64 {
        1 + self.sub_aggregation.get_bucket_count()
    }
}

@@ -2,441 +2,15 @@ use serde_json::Value;

use crate::aggregation::agg_req::{Aggregation, Aggregations};
use crate::aggregation::agg_result::AggregationResults;
use crate::aggregation::buf_collector::DOC_BLOCK_SIZE;
use crate::aggregation::collector::AggregationCollector;
use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults;
use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms};
use crate::aggregation::DistributedAggregationCollector;
use crate::docset::COLLECT_BLOCK_BUFFER_LEN;
use crate::query::{AllQuery, TermQuery};
use crate::schema::{IndexRecordOption, Schema, FAST};
use crate::{Index, IndexWriter, Term};

// The following tests ensure that each bucket aggregation type correctly functions as a
// sub-aggregation of another bucket aggregation in two scenarios:
// 1) The parent has more buckets than the child sub-aggregation
// 2) The child sub-aggregation has more buckets than the parent
//
// These scenarios exercise the bucket id mapping and sub-aggregation routing logic.

#[test]
fn test_terms_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
    let index = get_test_index_2_segments(false)?;

    // Case A: parent has more buckets than child
    // Parent: range with 4 buckets
    // Child: terms on text -> 2 buckets
    let agg_parent_more: Aggregations = serde_json::from_value(json!({
        "parent_range": {
            "range": {
                "field": "score",
                "ranges": [
                    {"to": 3.0},
                    {"from": 3.0, "to": 7.0},
                    {"from": 7.0, "to": 20.0},
                    {"from": 20.0}
                ]
            },
            "aggs": {
                "child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}}
            }
        }
    }))
    .unwrap();

    let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
    // Exact expected structure and counts
    assert_eq!(
        res["parent_range"]["buckets"],
        json!([
            {
                "key": "*-3",
                "doc_count": 1,
                "to": 3.0,
                "child_terms": {
                    "buckets": [
                        {"doc_count": 1, "key": "cool"}
                    ],
                    "sum_other_doc_count": 0
                }
            },
            {
                "key": "3-7",
                "doc_count": 3,
                "from": 3.0,
                "to": 7.0,
                "child_terms": {
                    "buckets": [
                        {"doc_count": 2, "key": "cool"},
                        {"doc_count": 1, "key": "nohit"}
                    ],
                    "sum_other_doc_count": 0
                }
            },
            {
                "key": "7-20",
                "doc_count": 3,
                "from": 7.0,
                "to": 20.0,
                "child_terms": {
                    "buckets": [
                        {"doc_count": 3, "key": "cool"}
                    ],
                    "sum_other_doc_count": 0
                }
            },
            {
                "key": "20-*",
                "doc_count": 2,
                "from": 20.0,
                "child_terms": {
                    "buckets": [
                        {"doc_count": 1, "key": "cool"},
                        {"doc_count": 1, "key": "nohit"}
                    ],
                    "sum_other_doc_count": 0
                }
            }
        ])
    );

    // Case B: child has more buckets than parent
    // Parent: histogram on score with large interval -> 1 bucket
    // Child: terms on text -> 2 buckets (cool/nohit)
    let agg_child_more: Aggregations = serde_json::from_value(json!({
        "parent_hist": {
            "histogram": {"field": "score", "interval": 100.0},
            "aggs": {
                "child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}}
            }
        }
    }))
    .unwrap();

    let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
    assert_eq!(
        res["parent_hist"],
        json!({
            "buckets": [
                {
                    "key": 0.0,
                    "doc_count": 9,
                    "child_terms": {
                        "buckets": [
                            {"doc_count": 7, "key": "cool"},
                            {"doc_count": 2, "key": "nohit"}
                        ],
                        "sum_other_doc_count": 0
                    }
                }
            ]
        })
    );

    Ok(())
}

#[test]
fn test_range_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
    let index = get_test_index_2_segments(false)?;

    // Case A: parent has more buckets than child
    // Parent: range with 5 buckets
    // Child: coarse range with 3 buckets
    let agg_parent_more: Aggregations = serde_json::from_value(json!({
        "parent_range": {
            "range": {
                "field": "score",
                "ranges": [
                    {"to": 3.0},
                    {"from": 3.0, "to": 7.0},
                    {"from": 7.0, "to": 11.0},
                    {"from": 11.0, "to": 20.0},
                    {"from": 20.0}
                ]
            },
            "aggs": {
                "child_range": {
                    "range": {
                        "field": "score",
                        "ranges": [
                            {"to": 3.0},
                            {"from": 3.0, "to": 20.0}
                        ]
                    }
                }
            }
        }
    }))
    .unwrap();
    let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
    assert_eq!(
        res["parent_range"]["buckets"],
        json!([
            {"key": "*-3", "doc_count": 1, "to": 3.0,
                "child_range": {"buckets": [
                    {"key": "*-3", "doc_count": 1, "to": 3.0},
                    {"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0},
                    {"key": "20-*", "doc_count": 0, "from": 20.0}
                ]}
            },
            {"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0,
                "child_range": {"buckets": [
                    {"key": "*-3", "doc_count": 0, "to": 3.0},
                    {"key": "3-20", "doc_count": 3, "from": 3.0, "to": 20.0},
                    {"key": "20-*", "doc_count": 0, "from": 20.0}
                ]}
            },
            {"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0,
                "child_range": {"buckets": [
                    {"key": "*-3", "doc_count": 0, "to": 3.0},
                    {"key": "3-20", "doc_count": 1, "from": 3.0, "to": 20.0},
                    {"key": "20-*", "doc_count": 0, "from": 20.0}
                ]}
            },
            {"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0,
                "child_range": {"buckets": [
                    {"key": "*-3", "doc_count": 0, "to": 3.0},
                    {"key": "3-20", "doc_count": 2, "from": 3.0, "to": 20.0},
                    {"key": "20-*", "doc_count": 0, "from": 20.0}
                ]}
            },
            {"key": "20-*", "doc_count": 2, "from": 20.0,
                "child_range": {"buckets": [
                    {"key": "*-3", "doc_count": 0, "to": 3.0},
                    {"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0},
                    {"key": "20-*", "doc_count": 2, "from": 20.0}
                ]}
            }
        ])
    );

    // Case B: child has more buckets than parent
    // Parent: terms on text (2 buckets)
    // Child: range with 4 buckets
    let agg_child_more: Aggregations = serde_json::from_value(json!({
        "parent_terms": {
            "terms": {"field": "text"},
            "aggs": {
                "child_range": {
                    "range": {
                        "field": "score",
                        "ranges": [
                            {"to": 3.0},
                            {"from": 3.0, "to": 7.0},
                            {"from": 7.0, "to": 20.0}
                        ]
                    }
                }
            }
        }
    }))
    .unwrap();
    let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;

    assert_eq!(
        res["parent_terms"],
        json!({
            "buckets": [
                {
                    "key": "cool",
                    "doc_count": 7,
                    "child_range": {
                        "buckets": [
                            {"key": "*-3", "doc_count": 1, "to": 3.0},
                            {"key": "3-7", "doc_count": 2, "from": 3.0, "to": 7.0},
                            {"key": "7-20", "doc_count": 3, "from": 7.0, "to": 20.0},
                            {"key": "20-*", "doc_count": 1, "from": 20.0}
                        ]
                    }
                },
                {
                    "key": "nohit",
                    "doc_count": 2,
                    "child_range": {
                        "buckets": [
                            {"key": "*-3", "doc_count": 0, "to": 3.0},
                            {"key": "3-7", "doc_count": 1, "from": 3.0, "to": 7.0},
                            {"key": "7-20", "doc_count": 0, "from": 7.0, "to": 20.0},
                            {"key": "20-*", "doc_count": 1, "from": 20.0}
                        ]
                    }
                }
            ],
            "doc_count_error_upper_bound": 0,
            "sum_other_doc_count": 0
        })
    );

    Ok(())
}

#[test]
fn test_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
    let index = get_test_index_2_segments(false)?;

    // Case A: parent has more buckets than child
    // Parent: range with several ranges
    // Child: histogram with large interval (single bucket per parent)
    let agg_parent_more: Aggregations = serde_json::from_value(json!({
        "parent_range": {
            "range": {
                "field": "score",
                "ranges": [
                    {"to": 3.0},
                    {"from": 3.0, "to": 7.0},
                    {"from": 7.0, "to": 11.0},
                    {"from": 11.0, "to": 20.0},
                    {"from": 20.0}
                ]
            },
            "aggs": {
                "child_hist": {"histogram": {"field": "score", "interval": 100.0}}
            }
        }
    }))
    .unwrap();
    let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
    assert_eq!(
        res["parent_range"]["buckets"],
        json!([
            {"key": "*-3", "doc_count": 1, "to": 3.0,
                "child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]}
            },
            {"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0,
                "child_hist": {"buckets": [ {"key": 0.0, "doc_count": 3} ]}
            },
            {"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0,
                "child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]}
            },
            {"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0,
                "child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]}
            },
            {"key": "20-*", "doc_count": 2, "from": 20.0,
                "child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]}
            }
        ])
    );

    // Case B: child has more buckets than parent
    // Parent: terms on text -> 2 buckets
    // Child: histogram with small interval -> multiple buckets including empties
    let agg_child_more: Aggregations = serde_json::from_value(json!({
        "parent_terms": {
            "terms": {"field": "text"},
            "aggs": {
                "child_hist": {"histogram": {"field": "score", "interval": 10.0}}
            }
        }
    }))
    .unwrap();
    let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
    assert_eq!(
        res["parent_terms"],
        json!({
            "buckets": [
                {
                    "key": "cool",
                    "doc_count": 7,
                    "child_hist": {
                        "buckets": [
                            {"key": 0.0, "doc_count": 4},
                            {"key": 10.0, "doc_count": 2},
                            {"key": 20.0, "doc_count": 0},
                            {"key": 30.0, "doc_count": 0},
                            {"key": 40.0, "doc_count": 1}
                        ]
                    }
                },
                {
                    "key": "nohit",
                    "doc_count": 2,
                    "child_hist": {
                        "buckets": [
                            {"key": 0.0, "doc_count": 1},
                            {"key": 10.0, "doc_count": 0},
                            {"key": 20.0, "doc_count": 0},
                            {"key": 30.0, "doc_count": 0},
                            {"key": 40.0, "doc_count": 1}
                        ]
                    }
                }
            ],
            "doc_count_error_upper_bound": 0,
            "sum_other_doc_count": 0
        })
    );

    Ok(())
}

#[test]
fn test_date_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
    let index = get_test_index_2_segments(false)?;

    // Case A: parent has more buckets than child
    // Parent: range with several buckets
    // Child: date_histogram with 30d -> single bucket per parent
    let agg_parent_more: Aggregations = serde_json::from_value(json!({
        "parent_range": {
            "range": {
                "field": "score",
                "ranges": [
                    {"to": 3.0},
                    {"from": 3.0, "to": 7.0},
                    {"from": 7.0, "to": 11.0},
                    {"from": 11.0, "to": 20.0},
                    {"from": 20.0}
                ]
            },
            "aggs": {
                "child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "30d"}}
            }
        }
    }))
    .unwrap();
    let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
    let buckets = res["parent_range"]["buckets"].as_array().unwrap();
    // Verify each parent bucket has exactly one child date bucket with matching doc_count
    for bucket in buckets {
        let parent_count = bucket["doc_count"].as_u64().unwrap();
        let child_buckets = bucket["child_date_hist"]["buckets"].as_array().unwrap();
        assert_eq!(child_buckets.len(), 1);
        assert_eq!(child_buckets[0]["doc_count"], parent_count);
    }

    // Case B: child has more buckets than parent
    // Parent: terms on text (2 buckets)
    // Child: date_histogram with 1d -> multiple buckets
    let agg_child_more: Aggregations = serde_json::from_value(json!({
        "parent_terms": {
            "terms": {"field": "text"},
            "aggs": {
                "child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "1d"}}
            }
        }
    }))
    .unwrap();
    let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
    let buckets = res["parent_terms"]["buckets"].as_array().unwrap();

    // cool bucket
    assert_eq!(buckets[0]["key"], "cool");
    let cool_buckets = buckets[0]["child_date_hist"]["buckets"].as_array().unwrap();
    assert_eq!(cool_buckets.len(), 3);
    assert_eq!(cool_buckets[0]["doc_count"], 1); // day 0
    assert_eq!(cool_buckets[1]["doc_count"], 4); // day 1
    assert_eq!(cool_buckets[2]["doc_count"], 2); // day 2

    // nohit bucket
    assert_eq!(buckets[1]["key"], "nohit");
    let nohit_buckets = buckets[1]["child_date_hist"]["buckets"].as_array().unwrap();
    assert_eq!(nohit_buckets.len(), 2);
    assert_eq!(nohit_buckets[0]["doc_count"], 1); // day 1
    assert_eq!(nohit_buckets[1]["doc_count"], 1); // day 2

    Ok(())
}

fn get_avg_req(field_name: &str) -> Aggregation {
    serde_json::from_value(json!({
        "avg": {
@@ -451,10 +25,6 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector {
}

// *** EVERY BUCKET-TYPE SHOULD BE TESTED HERE ***
// Note: The flushing part of these tests is outdated, since the buffering changed after
// converting the collection into one collector per request instead of one per bucket.
//
// However, they remain useful as they test complex aggregation requests.
fn test_aggregation_flushing(
    merge_segments: bool,
    use_distributed_collector: bool,
@@ -467,9 +37,8 @@ fn test_aggregation_flushing(

    let reader = index.reader()?;

    assert_eq!(COLLECT_BLOCK_BUFFER_LEN, 64);
    // In the tree we cache documents of COLLECT_BLOCK_BUFFER_LEN before passing them down as one
    // block.
    assert_eq!(DOC_BLOCK_SIZE, 64);
    // In the tree we cache documents of DOC_BLOCK_SIZE before passing them down as one block.
    //
    // Build a request so that on the first level we have one full cache, which is then flushed.
    // The same cache should have some residue docs at the end, which are flushed (Range 0-70).

@@ -1,515 +0,0 @@
use std::fmt::Debug;
use std::net::Ipv6Addr;

use columnar::column_values::{CompactHit, CompactSpaceU64Accessor};
use columnar::{Column, ColumnType, MonotonicallyMappableToU64, StrColumn, TermOrdHit};

use crate::aggregation::accessor_helpers::{get_all_ff_readers, get_numeric_or_date_column_types};
use crate::aggregation::bucket::composite::numeric_types::num_proj;
use crate::aggregation::bucket::composite::numeric_types::num_proj::ProjectedNumber;
use crate::aggregation::bucket::composite::ToTypePaginationOrder;
use crate::aggregation::bucket::{
    parse_into_milliseconds, CalendarInterval, CompositeAggregation, CompositeAggregationSource,
    MissingOrder, Order,
};
use crate::aggregation::intermediate_agg_result::CompositeIntermediateKey;
use crate::{SegmentReader, TantivyError};

/// Contains all information required by the SegmentCompositeCollector to perform the
/// composite aggregation on a segment.
pub struct CompositeAggReqData {
    /// The name of the aggregation.
    pub name: String,
    /// The normalized composite aggregation request.
    pub req: CompositeAggregation,
    /// Accessors for each source; each source can have multiple accessors (columns).
    pub composite_accessors: Vec<CompositeSourceAccessors>,
}

impl CompositeAggReqData {
    /// Estimate the memory consumption of this struct in bytes.
    pub fn get_memory_consumption(&self) -> usize {
        std::mem::size_of::<Self>()
            + self.composite_accessors.len() * std::mem::size_of::<CompositeSourceAccessors>()
    }
}

/// Accessors for a single column in a composite source.
pub struct CompositeAccessor {
    /// The fast field column
    pub column: Column<u64>,
    /// The column type
    pub column_type: ColumnType,
    /// Term dictionary if the column type is Str
    ///
    /// Only used by term sources
    pub str_dict_column: Option<StrColumn>,
    /// Parsed date interval for date histogram sources
    pub date_histogram_interval: PrecomputedDateInterval,
}

/// Accessors to all the columns that belong to the field of a composite source.
pub struct CompositeSourceAccessors {
    /// The accessors for this source
    pub accessors: Vec<CompositeAccessor>,
    /// The key after which to start collecting results. Applies to the first
    /// column of the source.
    pub after_key: PrecomputedAfterKey,

    /// The column index the after_key applies to. The after_key only applies to
    /// one column. Columns before should be skipped. Columns after should be
    /// kept without comparison to the after_key.
    pub after_key_accessor_idx: usize,

    /// Whether to skip missing values because of the after_key. Skipping only
    /// applies if the values for previous columns were exactly equal to the
    /// corresponding after keys (is_on_after_key).
    pub skip_missing: bool,

    /// The after key was set to null to indicate that the last collected key
    /// was a missing value.
    pub is_after_key_explicit_missing: bool,
}

impl CompositeSourceAccessors {
    /// Creates a new set of accessors for the composite source.
    ///
    /// Precomputes some values to make collection faster.
    pub fn build_for_source(
        reader: &SegmentReader,
        source: &CompositeAggregationSource,
        // First option is None when no after key was set in the query, the
        // second option is None when the after key was set but its value for
        // this source was set to `null`
        source_after_key_opt: Option<&CompositeIntermediateKey>,
    ) -> crate::Result<Self> {
        let is_after_key_explicit_missing = source_after_key_opt
            .map(|after_key| matches!(after_key, CompositeIntermediateKey::Null))
            .unwrap_or(false);
        let mut skip_missing = false;
        if let Some(CompositeIntermediateKey::Null) = source_after_key_opt {
            if !source.missing_bucket() {
                return Err(TantivyError::InvalidArgument(
                    "the 'after' key for a source cannot be null when 'missing_bucket' is false"
                        .to_string(),
                ));
            }
        } else if source_after_key_opt.is_some() {
            // if missing buckets come first and we have a non null after key, we skip missing
            if MissingOrder::First == source.missing_order() {
                skip_missing = true;
            }
            if MissingOrder::Default == source.missing_order() && Order::Asc == source.order() {
                skip_missing = true;
            }
        };

        match source {
            CompositeAggregationSource::Terms(source) => {
                let allowed_column_types = [
                    ColumnType::I64,
                    ColumnType::U64,
                    ColumnType::F64,
                    ColumnType::Str,
                    ColumnType::DateTime,
                    ColumnType::Bool,
                    ColumnType::IpAddr,
                    // ColumnType::Bytes Unsupported
                ];
                let mut columns_and_types =
                    get_all_ff_readers(reader, &source.field, Some(&allowed_column_types))?;

                // Sort columns by their pagination order and determine which to skip
                columns_and_types.sort_by_key(|(_, col_type)| col_type.column_pagination_order());
                if source.order == Order::Desc {
                    columns_and_types.reverse();
                }
                let after_key_accessor_idx = find_first_column_to_collect(
                    &columns_and_types,
                    source_after_key_opt,
                    source.missing_order,
                    source.order,
                )?;

                let source_collectors: Vec<CompositeAccessor> = columns_and_types
                    .into_iter()
                    .map(|(column, column_type)| {
                        Ok(CompositeAccessor {
                            column,
                            column_type,
                            str_dict_column: reader.fast_fields().str(&source.field)?,
                            date_histogram_interval: PrecomputedDateInterval::NotApplicable,
                        })
                    })
                    .collect::<crate::Result<_>>()?;

                let after_key = if let Some(first_col) =
                    source_collectors.get(after_key_accessor_idx)
                {
                    match source_after_key_opt {
                        Some(after_key) => PrecomputedAfterKey::precompute(
                            first_col,
                            after_key,
                            &source.field,
                            source.missing_order,
                            source.order,
                        )?,
                        None => {
                            precompute_missing_after_key(false, source.missing_order, source.order)
                        }
                    }
                } else {
                    // if no columns, we don't care about the after_key
                    PrecomputedAfterKey::Next(0)
                };

                Ok(CompositeSourceAccessors {
                    accessors: source_collectors,
                    is_after_key_explicit_missing,
                    skip_missing,
                    after_key,
                    after_key_accessor_idx,
                })
            }
            CompositeAggregationSource::Histogram(source) => {
                let column_and_types: Vec<(Column, ColumnType)> = get_all_ff_readers(
                    reader,
                    &source.field,
                    Some(get_numeric_or_date_column_types()),
                )?;
                let source_collectors: Vec<CompositeAccessor> = column_and_types
                    .into_iter()
                    .map(|(column, column_type)| {
                        Ok(CompositeAccessor {
                            column,
                            column_type,
                            str_dict_column: None,
                            date_histogram_interval: PrecomputedDateInterval::NotApplicable,
                        })
                    })
                    .collect::<crate::Result<_>>()?;
                let after_key = match source_after_key_opt {
                    Some(CompositeIntermediateKey::F64(key)) => {
                        let normalized_key = *key / source.interval;
                        num_proj::f64_to_i64(normalized_key).into()
                    }
                    Some(CompositeIntermediateKey::Null) => {
                        precompute_missing_after_key(true, source.missing_order, source.order)
                    }
                    None => precompute_missing_after_key(true, source.missing_order, source.order),
                    _ => {
                        return Err(crate::TantivyError::InvalidArgument(
                            "After key type invalid for interval composite source".to_string(),
                        ));
                    }
                };
                Ok(CompositeSourceAccessors {
                    accessors: source_collectors,
                    is_after_key_explicit_missing,
                    skip_missing,
                    after_key,
                    after_key_accessor_idx: 0,
                })
            }
            CompositeAggregationSource::DateHistogram(source) => {
                let column_and_types =
                    get_all_ff_readers(reader, &source.field, Some(&[ColumnType::DateTime]))?;
                let date_histogram_interval =
                    PrecomputedDateInterval::from_date_histogram_source_intervals(
                        &source.fixed_interval,
                        source.calendar_interval,
                    )?;
                let source_collectors: Vec<CompositeAccessor> = column_and_types
                    .into_iter()
                    .map(|(column, column_type)| {
                        Ok(CompositeAccessor {
                            column,
                            column_type,
                            str_dict_column: None,
                            date_histogram_interval,
                        })
                    })
                    .collect::<crate::Result<_>>()?;
                let after_key = match source_after_key_opt {
                    Some(CompositeIntermediateKey::DateTime(key)) => {
                        PrecomputedAfterKey::Exact(key.to_u64())
                    }
                    Some(CompositeIntermediateKey::Null) => {
                        precompute_missing_after_key(true, source.missing_order, source.order)
                    }
                    None => precompute_missing_after_key(true, source.missing_order, source.order),
                    _ => {
                        return Err(crate::TantivyError::InvalidArgument(
                            "After key type invalid for interval composite source".to_string(),
                        ));
                    }
                };
                Ok(CompositeSourceAccessors {
                    accessors: source_collectors,
                    is_after_key_explicit_missing,
                    skip_missing,
                    after_key,
                    after_key_accessor_idx: 0,
                })
            }
        }
    }
}

/// Finds the index of the first column we should start collecting from to
/// resume the pagination from the after_key.
fn find_first_column_to_collect<T>(
    sorted_columns: &[(T, ColumnType)],
    after_key_opt: Option<&CompositeIntermediateKey>,
    missing_order: MissingOrder,
    order: Order,
) -> crate::Result<usize> {
    let after_key = match after_key_opt {
        None => return Ok(0), // No pagination, start from beginning
        Some(key) => key,
    };
    // Handle null after_key (we were on a missing value last time)
    if matches!(after_key, CompositeIntermediateKey::Null) {
        return match (missing_order, order) {
            // Missing values come first, so all columns remain
            (MissingOrder::First, _) | (MissingOrder::Default, Order::Asc) => Ok(0),
            // Missing values come last, so all columns are done
            (MissingOrder::Last, _) | (MissingOrder::Default, Order::Desc) => {
                Ok(sorted_columns.len())
            }
        };
    }
    // Find the first column whose type order matches or follows the after_key's
    // type in the pagination sequence
    let after_key_column_order = after_key.column_pagination_order();
    for (idx, (_, col_type)) in sorted_columns.iter().enumerate() {
        let col_order = col_type.column_pagination_order();
        let is_first_to_collect = match order {
            Order::Asc => col_order >= after_key_column_order,
            Order::Desc => col_order <= after_key_column_order,
        };
        if is_first_to_collect {
            return Ok(idx);
        }
    }
    // All columns are before the after_key, nothing left to collect
    Ok(sorted_columns.len())
}

fn precompute_missing_after_key(
    is_after_key_explicit_missing: bool,
    missing_order: MissingOrder,
    order: Order,
) -> PrecomputedAfterKey {
    let after_last = PrecomputedAfterKey::AfterLast;
    let before_first = PrecomputedAfterKey::Next(0);
    match (is_after_key_explicit_missing, missing_order, order) {
        (true, MissingOrder::First, Order::Asc) => before_first,
        (true, MissingOrder::First, Order::Desc) => after_last,
        (true, MissingOrder::Last, Order::Asc) => after_last,
        (true, MissingOrder::Last, Order::Desc) => before_first,
        (true, MissingOrder::Default, Order::Asc) => before_first,
        (true, MissingOrder::Default, Order::Desc) => after_last,
        (false, _, Order::Asc) => before_first,
        (false, _, Order::Desc) => after_last,
    }
}
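The decision table in `precompute_missing_after_key` can be exercised on its own. The sketch below uses simplified local stand-ins for the crate's types (`AfterKey`, `MissingOrder`, `Order` here are illustrative re-declarations, not the real definitions): `Next(0)` acts as a "before first" key that keeps everything in an ascending traversal, while `AfterLast` marks that nothing remains.

```rust
// Illustrative stand-ins for the crate's types; not the real definitions.
#[derive(Debug, PartialEq)]
enum AfterKey {
    Next(u64), // "before first" when the payload is 0
    AfterLast, // nothing left to collect
}

#[derive(Clone, Copy)]
enum MissingOrder {
    First,
    Last,
    Default,
}

#[derive(Clone, Copy)]
enum Order {
    Asc,
    Desc,
}

// Mirrors the table above: given whether the previous page ended on the
// missing bucket, decide where the next page resumes.
fn missing_after_key(explicit_missing: bool, missing: MissingOrder, order: Order) -> AfterKey {
    match (explicit_missing, missing, order) {
        (true, MissingOrder::First, Order::Asc) => AfterKey::Next(0),
        (true, MissingOrder::First, Order::Desc) => AfterKey::AfterLast,
        (true, MissingOrder::Last, Order::Asc) => AfterKey::AfterLast,
        (true, MissingOrder::Last, Order::Desc) => AfterKey::Next(0),
        (true, MissingOrder::Default, Order::Asc) => AfterKey::Next(0),
        (true, MissingOrder::Default, Order::Desc) => AfterKey::AfterLast,
        (false, _, Order::Asc) => AfterKey::Next(0),
        (false, _, Order::Desc) => AfterKey::AfterLast,
    }
}

fn main() {
    // The missing bucket sorted first, so every regular value still remains.
    assert_eq!(missing_after_key(true, MissingOrder::First, Order::Asc), AfterKey::Next(0));
    // The missing bucket sorted last, so it was the final bucket of the page.
    assert_eq!(missing_after_key(true, MissingOrder::Last, Order::Asc), AfterKey::AfterLast);
    println!("ok");
}
```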

/// A parsed representation of the date interval for date histogram sources
#[derive(Clone, Copy, Debug)]
pub enum PrecomputedDateInterval {
    /// This is not a date histogram source
    NotApplicable,
    /// Source was configured with a fixed interval
    FixedNanoseconds(i64),
    /// Source was configured with a calendar interval
    Calendar(CalendarInterval),
}

impl PrecomputedDateInterval {
    /// Validates the date histogram source interval fields and parses a date interval from them.
    pub fn from_date_histogram_source_intervals(
        fixed_interval: &Option<String>,
        calendar_interval: Option<CalendarInterval>,
    ) -> crate::Result<Self> {
        match (fixed_interval, calendar_interval) {
            (Some(_), Some(_)) | (None, None) => Err(TantivyError::InvalidArgument(
                "date histogram source must have one and only one of fixed_interval or \
                 calendar_interval set"
                    .to_string(),
            )),
            (Some(fixed_interval), None) => {
                let fixed_interval_ms = parse_into_milliseconds(fixed_interval)?;
                Ok(PrecomputedDateInterval::FixedNanoseconds(
                    fixed_interval_ms * 1_000_000,
                ))
            }
            (None, Some(calendar_interval)) => {
                Ok(PrecomputedDateInterval::Calendar(calendar_interval))
            }
        }
    }
}

/// The after key projected to the u64 column space
///
/// Some column types (term, IP) might not have an exact representation of the
/// specified after key
#[derive(Debug)]
pub enum PrecomputedAfterKey {
    /// The after key could be exactly represented in the column space.
    Exact(u64),
    /// The after key could not be represented exactly, so this is the next
    /// closest one.
    Next(u64),
    /// The after key could not be represented in the column space; it is
    /// greater than all values.
    AfterLast,
}

impl From<TermOrdHit> for PrecomputedAfterKey {
    fn from(hit: TermOrdHit) -> Self {
        match hit {
            TermOrdHit::Exact(ord) => PrecomputedAfterKey::Exact(ord),
            // TermOrdHit represents AfterLast as Next(u64::MAX), we keep it as is
            TermOrdHit::Next(ord) => PrecomputedAfterKey::Next(ord),
        }
    }
}

impl From<CompactHit> for PrecomputedAfterKey {
    fn from(hit: CompactHit) -> Self {
        match hit {
            CompactHit::Exact(ord) => PrecomputedAfterKey::Exact(ord as u64),
            CompactHit::Next(ord) => PrecomputedAfterKey::Next(ord as u64),
            CompactHit::AfterLast => PrecomputedAfterKey::AfterLast,
        }
    }
}

impl<T: MonotonicallyMappableToU64> From<ProjectedNumber<T>> for PrecomputedAfterKey {
    fn from(num: ProjectedNumber<T>) -> Self {
        match num {
            ProjectedNumber::Exact(number) => PrecomputedAfterKey::Exact(number.to_u64()),
            ProjectedNumber::Next(number) => PrecomputedAfterKey::Next(number.to_u64()),
            ProjectedNumber::AfterLast => PrecomputedAfterKey::AfterLast,
        }
    }
}

// /!\ These operators only make sense if both values are in the same column space
impl PrecomputedAfterKey {
    pub fn equals(&self, column_value: u64) -> bool {
        match self {
            PrecomputedAfterKey::Exact(v) => *v == column_value,
            PrecomputedAfterKey::Next(_) => false,
            PrecomputedAfterKey::AfterLast => false,
        }
    }

    pub fn gt(&self, column_value: u64) -> bool {
        match self {
            PrecomputedAfterKey::Exact(v) => *v > column_value,
            PrecomputedAfterKey::Next(v) => *v > column_value,
            PrecomputedAfterKey::AfterLast => true,
        }
    }

    pub fn lt(&self, column_value: u64) -> bool {
        match self {
            PrecomputedAfterKey::Exact(v) => *v < column_value,
            // a value equal to the next is greater than the after key
            PrecomputedAfterKey::Next(v) => *v <= column_value,
            PrecomputedAfterKey::AfterLast => false,
        }
    }

    fn precompute_ip_addr(column: &Column<u64>, key: &Ipv6Addr) -> crate::Result<Self> {
        let compact_space_accessor = column
            .values
            .clone()
            .downcast_arc::<CompactSpaceU64Accessor>()
            .map_err(|_| {
                TantivyError::AggregationError(crate::aggregation::AggregationError::InternalError(
                    "type mismatch: could not downcast to CompactSpaceU64Accessor".to_string(),
                ))
            })?;
        let ip_u128 = key.to_bits();
        let ip_next_compact = compact_space_accessor.u128_to_next_compact(ip_u128);
        Ok(ip_next_compact.into())
    }

    fn precompute_term_ord(
        str_dict_column: &Option<StrColumn>,
        key: &str,
        field: &str,
    ) -> crate::Result<Self> {
        let dict = str_dict_column
            .as_ref()
            .expect("dictionary missing for str accessor")
            .dictionary();
        let next_ord = dict.term_ord_or_next(key).map_err(|_| {
            TantivyError::InvalidArgument(format!(
                "failed to lookup after_key '{}' for field '{}'",
                key, field
            ))
        })?;
        Ok(next_ord.into())
    }

    /// Projects the after key into the column space of the given accessor.
    ///
    /// The computed after key will not take care of skipping entire columns
    /// when the after key type is ordered after the accessor's type; that
    /// should be performed earlier.
    pub fn precompute(
        composite_accessor: &CompositeAccessor,
        source_after_key: &CompositeIntermediateKey,
        field: &str,
        missing_order: MissingOrder,
        order: Order,
    ) -> crate::Result<Self> {
        use CompositeIntermediateKey as CIKey;
        let precomputed_key = match (composite_accessor.column_type, source_after_key) {
            (ColumnType::Bytes, _) => panic!("unsupported"),
            // null after key
            (_, CIKey::Null) => precompute_missing_after_key(false, missing_order, order),
            // numerical
            (ColumnType::I64, CIKey::I64(k)) => PrecomputedAfterKey::Exact(k.to_u64()),
            (ColumnType::I64, CIKey::U64(k)) => num_proj::u64_to_i64(*k).into(),
            (ColumnType::I64, CIKey::F64(k)) => num_proj::f64_to_i64(*k).into(),
            (ColumnType::U64, CIKey::I64(k)) => num_proj::i64_to_u64(*k).into(),
            (ColumnType::U64, CIKey::U64(k)) => PrecomputedAfterKey::Exact(*k),
            (ColumnType::U64, CIKey::F64(k)) => num_proj::f64_to_u64(*k).into(),
            (ColumnType::F64, CIKey::I64(k)) => num_proj::i64_to_f64(*k).into(),
            (ColumnType::F64, CIKey::U64(k)) => num_proj::u64_to_f64(*k).into(),
            (ColumnType::F64, CIKey::F64(k)) => PrecomputedAfterKey::Exact(k.to_u64()),
            // boolean
            (ColumnType::Bool, CIKey::Bool(key)) => PrecomputedAfterKey::Exact(key.to_u64()),
            // string
            (ColumnType::Str, CIKey::Str(key)) => PrecomputedAfterKey::precompute_term_ord(
                &composite_accessor.str_dict_column,
                key,
                field,
            )?,
            // date time
            (ColumnType::DateTime, CIKey::DateTime(key)) => {
                PrecomputedAfterKey::Exact(key.to_u64())
            }
            // ip address
            (ColumnType::IpAddr, CIKey::IpAddr(key)) => {
                PrecomputedAfterKey::precompute_ip_addr(&composite_accessor.column, key)?
            }
            // assume the column's type is ordered after the after_key's type
            _ => PrecomputedAfterKey::keep_all(order),
        };
        Ok(precomputed_key)
    }

    fn keep_all(order: Order) -> Self {
        match order {
            Order::Asc => PrecomputedAfterKey::Next(0),
            Order::Desc => PrecomputedAfterKey::Next(u64::MAX),
        }
    }
}
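The `Exact`/`Next` distinction in the comparison operators can be demonstrated in isolation. The sketch below re-declares a minimal local copy of the enum and of `lt` purely for illustration (the real type lives in this module): for `Exact(v)` only strictly greater column values come after the key, while for `Next(v)` the value `v` itself also comes after, since the original after key was strictly smaller than `v`.

```rust
// Minimal local re-declaration for illustration; not the crate's definition.
#[derive(Debug)]
enum PrecomputedAfterKey {
    Exact(u64),
    Next(u64),
    AfterLast,
}

impl PrecomputedAfterKey {
    // "Is the after key strictly before this column value?"
    fn lt(&self, column_value: u64) -> bool {
        match self {
            PrecomputedAfterKey::Exact(v) => *v < column_value,
            // A value equal to `Next` is already past the after key.
            PrecomputedAfterKey::Next(v) => *v <= column_value,
            PrecomputedAfterKey::AfterLast => false,
        }
    }
}

fn main() {
    // Exact(5): only strictly greater values come after the key.
    assert!(!PrecomputedAfterKey::Exact(5).lt(5));
    assert!(PrecomputedAfterKey::Exact(5).lt(6));
    // Next(5): 5 itself was never the after key, so it also comes after.
    assert!(PrecomputedAfterKey::Next(5).lt(5));
    // AfterLast: nothing comes after.
    assert!(!PrecomputedAfterKey::AfterLast.lt(u64::MAX));
    println!("ok");
}
```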

@@ -1,140 +0,0 @@
use time::convert::{Day, Nanosecond};
use time::{Time, UtcDateTime};

const NS_IN_DAY: i64 = Nanosecond::per_t::<i128>(Day) as i64;

/// Computes the timestamp in nanoseconds corresponding to the beginning of the
/// year (January 1st at midnight UTC).
pub(super) fn try_year_bucket(timestamp_ns: i64) -> crate::Result<i64> {
    year_bucket_using_time_crate(timestamp_ns).map_err(|e| {
        crate::TantivyError::InvalidArgument(format!(
            "Failed to compute year bucket for timestamp {}: {}",
            timestamp_ns,
            e.to_string()
        ))
    })
}

/// Computes the timestamp in nanoseconds corresponding to the beginning of the
/// month (1st at midnight UTC).
pub(super) fn try_month_bucket(timestamp_ns: i64) -> crate::Result<i64> {
    month_bucket_using_time_crate(timestamp_ns).map_err(|e| {
        crate::TantivyError::InvalidArgument(format!(
            "Failed to compute month bucket for timestamp {}: {}",
            timestamp_ns,
            e.to_string()
        ))
    })
}

/// Computes the timestamp in nanoseconds corresponding to the beginning of the
/// week (Monday at midnight UTC).
pub(super) fn week_bucket(timestamp_ns: i64) -> i64 {
    // 1970-01-01 was a Thursday (weekday = 3 with 0 = Monday)
    let days_since_epoch = timestamp_ns.div_euclid(NS_IN_DAY);
    // Find the weekday: 0=Monday, ..., 6=Sunday
    let weekday = (days_since_epoch + 3).rem_euclid(7);
    let monday_days_since_epoch = days_since_epoch - weekday;
    monday_days_since_epoch * NS_IN_DAY
}
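Because the function above relies only on integer arithmetic, the Monday-snapping can be checked in a standalone sketch; the day constant below is spelled out directly (86,400 seconds per day) instead of going through the `time` crate:

```rust
// Standalone sketch of the week-bucket arithmetic above (illustrative only).
const NS_IN_DAY: i64 = 86_400 * 1_000_000_000;

fn week_bucket(timestamp_ns: i64) -> i64 {
    // 1970-01-01 was a Thursday, so day 0 has weekday 3 (0 = Monday).
    let days_since_epoch = timestamp_ns.div_euclid(NS_IN_DAY);
    let weekday = (days_since_epoch + 3).rem_euclid(7);
    (days_since_epoch - weekday) * NS_IN_DAY
}

fn main() {
    // Day 4 since the epoch is Monday 1970-01-05; mid-week days snap back to it.
    assert_eq!(week_bucket(4 * NS_IN_DAY), 4 * NS_IN_DAY);
    assert_eq!(week_bucket(6 * NS_IN_DAY + 1234), 4 * NS_IN_DAY);
    // The epoch itself (a Thursday) snaps to the previous Monday, 1969-12-29.
    assert_eq!(week_bucket(0), -3 * NS_IN_DAY);
    println!("ok");
}
```

`div_euclid`/`rem_euclid` (rather than `/` and `%`) are what make the pre-epoch case round toward the earlier Monday instead of toward zero.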

fn year_bucket_using_time_crate(timestamp_ns: i64) -> Result<i64, time::Error> {
    let timestamp_ns = UtcDateTime::from_unix_timestamp_nanos(timestamp_ns as i128)?
        .replace_ordinal(1)?
        .replace_time(Time::MIDNIGHT)
        .unix_timestamp_nanos();
    Ok(timestamp_ns as i64)
}

fn month_bucket_using_time_crate(timestamp_ns: i64) -> Result<i64, time::Error> {
    let timestamp_ns = UtcDateTime::from_unix_timestamp_nanos(timestamp_ns as i128)?
        .replace_day(1)?
        .replace_time(Time::MIDNIGHT)
        .unix_timestamp_nanos();
    Ok(timestamp_ns as i64)
}

#[cfg(test)]
mod tests {
    use std::i64;

    use time::format_description::well_known::Iso8601;
    use time::UtcDateTime;

    use super::*;

    fn ts_ns(iso: &str) -> i64 {
        UtcDateTime::parse(iso, &Iso8601::DEFAULT)
            .unwrap()
            .unix_timestamp_nanos() as i64
    }

    #[test]
    fn test_year_bucket() {
        let ts = ts_ns("1970-01-01T00:00:00Z");
        let res = try_year_bucket(ts).unwrap();
        assert_eq!(res, ts_ns("1970-01-01T00:00:00Z"));

        let ts = ts_ns("1970-06-01T10:00:01.010Z");
        let res = try_year_bucket(ts).unwrap();
        assert_eq!(res, ts_ns("1970-01-01T00:00:00Z"));

        let ts = ts_ns("2008-12-31T23:59:59.999999999Z"); // leap year
        let res = try_year_bucket(ts).unwrap();
        assert_eq!(res, ts_ns("2008-01-01T00:00:00Z"));

        let ts = ts_ns("2008-01-01T00:00:00Z"); // leap year
        let res = try_year_bucket(ts).unwrap();
        assert_eq!(res, ts_ns("2008-01-01T00:00:00Z"));

        let ts = ts_ns("2010-12-31T23:59:59.999999999Z");
        let res = try_year_bucket(ts).unwrap();
        assert_eq!(res, ts_ns("2010-01-01T00:00:00Z"));

        let ts = ts_ns("1972-06-01T00:10:00Z");
        let res = try_year_bucket(ts).unwrap();
        assert_eq!(res, ts_ns("1972-01-01T00:00:00Z"));
    }

    #[test]
    fn test_month_bucket() {
        let ts = ts_ns("1970-01-15T00:00:00Z");
        let res = try_month_bucket(ts).unwrap();
        assert_eq!(res, ts_ns("1970-01-01T00:00:00Z"));

        let ts = ts_ns("1970-02-01T00:00:00Z");
        let res = try_month_bucket(ts).unwrap();
        assert_eq!(res, ts_ns("1970-02-01T00:00:00Z"));

        let ts = ts_ns("2000-01-31T23:59:59.999999999Z");
        let res = try_month_bucket(ts).unwrap();
        assert_eq!(res, ts_ns("2000-01-01T00:00:00Z"));
    }

    #[test]
    fn test_week_bucket() {
        let ts = ts_ns("1970-01-05T00:00:00Z"); // Monday
        let res = week_bucket(ts);
        assert_eq!(res, ts_ns("1970-01-05T00:00:00Z"));

        let ts = ts_ns("1970-01-05T23:59:59Z"); // Monday
        let res = week_bucket(ts);
        assert_eq!(res, ts_ns("1970-01-05T00:00:00Z"));

        let ts = ts_ns("1970-01-07T01:13:00Z"); // Wednesday
        let res = week_bucket(ts);
        assert_eq!(res, ts_ns("1970-01-05T00:00:00Z"));

        let ts = ts_ns("1970-01-11T23:59:59.999999999Z"); // Sunday
        let res = week_bucket(ts);
        assert_eq!(res, ts_ns("1970-01-05T00:00:00Z"));

        let ts = ts_ns("2025-10-16T10:41:59.010Z"); // Thursday
        let res = week_bucket(ts);
        assert_eq!(res, ts_ns("2025-10-13T00:00:00Z"));

        let ts = ts_ns("1970-01-01T00:00:00Z"); // Thursday
        let res = week_bucket(ts);
        assert_eq!(res, ts_ns("1969-12-29T00:00:00Z")); // Negative
    }
}
@@ -1,595 +0,0 @@
use std::fmt::Debug;
use std::net::Ipv6Addr;

use columnar::column_values::CompactSpaceU64Accessor;
use columnar::{
    Column, ColumnType, Dictionary, MonotonicallyMappableToU128, MonotonicallyMappableToU64,
    NumericalValue, StrColumn,
};
use rustc_hash::FxHashMap;
use smallvec::SmallVec;

use crate::aggregation::agg_data::{
    build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::bucket::composite::accessors::{
    CompositeAccessor, CompositeAggReqData, PrecomputedDateInterval,
};
use crate::aggregation::bucket::composite::calendar_interval;
use crate::aggregation::bucket::composite::map::{DynArrayHeapMap, MAX_DYN_ARRAY_SIZE};
use crate::aggregation::bucket::{
    CalendarInterval, CompositeAggregationSource, MissingOrder, Order,
};
use crate::aggregation::intermediate_agg_result::{
    CompositeIntermediateKey, IntermediateAggregationResult, IntermediateAggregationResults,
    IntermediateBucketResult, IntermediateCompositeBucketEntry, IntermediateCompositeBucketResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::BucketId;
use crate::TantivyError;

#[derive(Debug)]
struct CompositeBucketCollector {
    count: u32,
}

impl CompositeBucketCollector {
    fn new() -> Self {
        CompositeBucketCollector { count: 0 }
    }
    #[inline]
    fn collect(&mut self) {
        self.count += 1;
    }
}

/// The value is represented as a tuple of:
/// - the column index or missing value sentinel
///   - if the value is present, store the accessor index + 1
///   - if the value is missing, store 0 (for missing first) or u8::MAX (for missing last)
/// - the fast field value u64 representation
///   - 0 if the field is missing
///   - regular u64 repr if the ordering is ascending
///   - bitwise NOT of the u64 repr if the ordering is descending
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Default, Hash)]
struct InternalValueRepr(u8, u64);

impl InternalValueRepr {
    #[inline]
    fn new_term(raw: u64, accessor_idx: u8, order: Order) -> Self {
        match order {
            Order::Asc => InternalValueRepr(accessor_idx + 1, raw),
            Order::Desc => InternalValueRepr(accessor_idx + 1, !raw),
        }
    }
    /// For histogram, the source column does not matter
    #[inline]
    fn new_histogram(raw: u64, order: Order) -> Self {
        match order {
            Order::Asc => InternalValueRepr(1, raw),
            Order::Desc => InternalValueRepr(1, !raw),
        }
    }
    #[inline]
    fn new_missing(order: Order, missing_order: MissingOrder) -> Self {
        let column_idx = match (missing_order, order) {
            (MissingOrder::First, _) => 0,
            (MissingOrder::Last, _) => u8::MAX,
            (MissingOrder::Default, Order::Asc) => 0,
            (MissingOrder::Default, Order::Desc) => u8::MAX,
        };
        InternalValueRepr(column_idx, 0)
    }
    #[inline]
    fn decode(self, order: Order) -> Option<(u8, u64)> {
        if self.0 == u8::MAX || self.0 == 0 {
            return None;
        }
        match order {
            Order::Asc => Some((self.0 - 1, self.1)),
            Order::Desc => Some((self.0 - 1, !self.1)),
        }
    }
}

/// The collector puts values from the fast field into the correct buckets and
/// does a conversion to the correct datatype.
#[derive(Debug)]
pub struct SegmentCompositeCollector {
    buckets: DynArrayHeapMap<InternalValueRepr, CompositeBucketCollector>,
    accessor_idx: usize,
}

impl SegmentAggregationCollector for SegmentCompositeCollector {
    fn add_intermediate_aggregation_result(
        &mut self,
        agg_data: &AggregationsSegmentCtx,
        results: &mut IntermediateAggregationResults,
        _parent_bucket_id: BucketId,
    ) -> crate::Result<()> {
        let name = agg_data
            .get_composite_req_data(self.accessor_idx)
            .name
            .clone();

        let buckets = self.into_intermediate_bucket_result(agg_data)?;
        results.push(
            name,
            IntermediateAggregationResult::Bucket(IntermediateBucketResult::Composite { buckets }),
        )?;

        Ok(())
    }

    #[inline]
    fn collect(
        &mut self,
        _parent_bucket_id: BucketId,
        docs: &[crate::DocId],
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        let mem_pre = self.get_memory_consumption();
        let composite_agg_data = agg_data.take_composite_req_data(self.accessor_idx);

        for doc in docs {
            let mut sub_level_values = SmallVec::new();
            recursive_key_visitor(
                *doc,
                agg_data,
                &composite_agg_data,
                0,
                &mut sub_level_values,
                &mut self.buckets,
                true,
            )?;
        }
        agg_data.put_back_composite_req_data(self.accessor_idx, composite_agg_data);

        let mem_delta = self.get_memory_consumption() - mem_pre;
        if mem_delta > 0 {
            agg_data.context.limits.add_memory_consumed(mem_delta)?;
        }

        Ok(())
    }

    fn prepare_max_bucket(
        &mut self,
        _max_bucket: BucketId,
        _agg_data: &AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        Ok(())
    }

    fn flush(&mut self, _agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
        Ok(())
    }
}

impl SegmentCompositeCollector {
    fn get_memory_consumption(&self) -> u64 {
        self.buckets.memory_consumption()
    }

    pub(crate) fn from_req_and_validate(
        req_data: &mut AggregationsSegmentCtx,
        node: &AggRefNode,
    ) -> crate::Result<Self> {
        validate_req(req_data, node.idx_in_req_data)?;

        if !node.children.is_empty() {
            let _sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?;
        }

        let composite_req_data = req_data.get_composite_req_data(node.idx_in_req_data);
        Ok(SegmentCompositeCollector {
            buckets: DynArrayHeapMap::try_new(composite_req_data.req.sources.len())?,
            accessor_idx: node.idx_in_req_data,
        })
    }

    #[inline]
    pub(crate) fn into_intermediate_bucket_result(
        &mut self,
        agg_data: &AggregationsSegmentCtx,
    ) -> crate::Result<IntermediateCompositeBucketResult> {
        let mut dict: FxHashMap<Vec<CompositeIntermediateKey>, IntermediateCompositeBucketEntry> =
            Default::default();
        dict.reserve(self.buckets.size());
        let composite_data = agg_data.get_composite_req_data(self.accessor_idx);
        let buckets = std::mem::replace(
            &mut self.buckets,
            DynArrayHeapMap::try_new(composite_data.req.sources.len())
                .expect("already validated source count"),
        );
        for (key_internal_repr, agg) in buckets.into_iter() {
            let key = resolve_key(&key_internal_repr, composite_data)?;
            let sub_aggregation_res = IntermediateAggregationResults::default();

            dict.insert(
                key,
                IntermediateCompositeBucketEntry {
                    doc_count: agg.count,
                    sub_aggregation: sub_aggregation_res,
                },
            );
        }

        Ok(IntermediateCompositeBucketResult {
            entries: dict,
            target_size: composite_data.req.size,
            orders: composite_data
                .req
                .sources
                .iter()
                .map(|source| match source {
                    CompositeAggregationSource::Terms(t) => (t.order, t.missing_order),
                    CompositeAggregationSource::Histogram(h) => (h.order, h.missing_order),
                    CompositeAggregationSource::DateHistogram(d) => (d.order, d.missing_order),
                })
                .collect(),
        })
    }
}

fn validate_req(req_data: &mut AggregationsSegmentCtx, accessor_idx: usize) -> crate::Result<()> {
    let composite_data = req_data.get_composite_req_data(accessor_idx);
    let req = &composite_data.req;
    if req.sources.is_empty() {
        return Err(TantivyError::InvalidArgument(
            "composite aggregation must have at least one source".to_string(),
        ));
    }
    if req.size == 0 {
        return Err(TantivyError::InvalidArgument(
            "composite aggregation 'size' must be > 0".to_string(),
        ));
    }
    let column_types_for_sources = composite_data.composite_accessors.iter().map(|item| {
        item.accessors
            .iter()
            .map(|a| a.column_type)
            .collect::<Vec<_>>()
    });

    for column_types in column_types_for_sources {
        if column_types.len() > MAX_DYN_ARRAY_SIZE {
            return Err(TantivyError::InvalidArgument(format!(
                "composite aggregation source supports maximum {MAX_DYN_ARRAY_SIZE} sources",
            )));
        }
        if column_types.contains(&ColumnType::Bytes) {
            return Err(TantivyError::InvalidArgument(
                "composite aggregation does not support 'bytes' field type".to_string(),
            ));
        }
    }
    Ok(())
}

fn collect_bucket_with_limit(
    agg_data: &mut AggregationsSegmentCtx,
    composite_agg_data: &CompositeAggReqData,
    buckets: &mut DynArrayHeapMap<InternalValueRepr, CompositeBucketCollector>,
    key: &[InternalValueRepr],
) -> crate::Result<()> {
    if (buckets.size() as u32) < composite_agg_data.req.size {
        buckets
            .get_or_insert_with(key, CompositeBucketCollector::new)
            .collect();
        return Ok(());
    }

    if let Some(entry) = buckets.get_mut(key) {
        entry.collect();
        return Ok(());
    }

    if let Some(highest_key) = buckets.peek_highest() {
        if key < highest_key {
            buckets.evict_highest();
            buckets
                .get_or_insert_with(key, CompositeBucketCollector::new)
                .collect();
        }
    }

    let _ = agg_data;
    Ok(())
}

/// Converts the composite key from its internal column space representation
/// (segment specific) into its intermediate form.
fn resolve_key(
    internal_key: &[InternalValueRepr],
    agg_data: &CompositeAggReqData,
) -> crate::Result<Vec<CompositeIntermediateKey>> {
    internal_key
        .iter()
        .enumerate()
        .map(|(idx, val)| {
            resolve_internal_value_repr(
                *val,
                &agg_data.req.sources[idx],
                &agg_data.composite_accessors[idx].accessors,
            )
        })
        .collect()
}

fn resolve_internal_value_repr(
    internal_value_repr: InternalValueRepr,
    source: &CompositeAggregationSource,
    composite_accessors: &[CompositeAccessor],
) -> crate::Result<CompositeIntermediateKey> {
    let decoded_value_opt = match source {
        CompositeAggregationSource::Terms(source) => internal_value_repr.decode(source.order),
        CompositeAggregationSource::Histogram(source) => internal_value_repr.decode(source.order),
        CompositeAggregationSource::DateHistogram(source) => {
            internal_value_repr.decode(source.order)
        }
    };
    let Some((decoded_accessor_idx, val)) = decoded_value_opt else {
        return Ok(CompositeIntermediateKey::Null);
    };
    let key = match source {
        CompositeAggregationSource::Terms(_) => {
            let CompositeAccessor {
                column_type,
                str_dict_column,
                column,
                ..
            } = &composite_accessors[decoded_accessor_idx as usize];
            resolve_term(val, column_type, str_dict_column, column)?
        }
        CompositeAggregationSource::Histogram(source) => {
            CompositeIntermediateKey::F64(i64::from_u64(val) as f64 * source.interval)
        }
        CompositeAggregationSource::DateHistogram(_) => {
            CompositeIntermediateKey::DateTime(i64::from_u64(val))
        }
    };

    Ok(key)
}

fn resolve_term(
    val: u64,
    column_type: &ColumnType,
    str_dict_column: &Option<StrColumn>,
    column: &Column,
) -> crate::Result<CompositeIntermediateKey> {
    let key = if *column_type == ColumnType::Str {
        let fallback_dict = Dictionary::empty();
        let term_dict = str_dict_column
            .as_ref()
            .map(|el| el.dictionary())
            .unwrap_or_else(|| &fallback_dict);

        let mut buffer = Vec::new();
        term_dict.ord_to_term(val, &mut buffer)?;
        CompositeIntermediateKey::Str(
            String::from_utf8(buffer.to_vec()).expect("could not convert to String"),
        )
    } else if *column_type == ColumnType::DateTime {
        let val = i64::from_u64(val);
        CompositeIntermediateKey::DateTime(val)
    } else if *column_type == ColumnType::Bool {
        let val = bool::from_u64(val);
        CompositeIntermediateKey::Bool(val)
    } else if *column_type == ColumnType::IpAddr {
        let compact_space_accessor = column
            .values
            .clone()
            .downcast_arc::<CompactSpaceU64Accessor>()
            .map_err(|_| {
                TantivyError::AggregationError(crate::aggregation::AggregationError::InternalError(
                    "Type mismatch: Could not downcast to CompactSpaceU64Accessor".to_string(),
                ))
            })?;
        let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
        let val = Ipv6Addr::from_u128(val);
        CompositeIntermediateKey::IpAddr(val)
    } else if *column_type == ColumnType::U64 {
        CompositeIntermediateKey::U64(val)
    } else if *column_type == ColumnType::I64 {
        CompositeIntermediateKey::I64(i64::from_u64(val))
    } else {
        let val = f64::from_u64(val);
        let val: NumericalValue = val.into();

        match val.normalize() {
            NumericalValue::U64(val) => CompositeIntermediateKey::U64(val),
            NumericalValue::I64(val) => CompositeIntermediateKey::I64(val),
            NumericalValue::F64(val) => CompositeIntermediateKey::F64(val),
        }
    };
    Ok(key)
}

/// Depth-first walk of the accessors to build the composite key combinations
/// and update the buckets.
fn recursive_key_visitor(
    doc_id: crate::DocId,
    agg_data: &mut AggregationsSegmentCtx,
    composite_agg_data: &CompositeAggReqData,
    source_idx_for_recursion: usize,
    sub_level_values: &mut SmallVec<[InternalValueRepr; MAX_DYN_ARRAY_SIZE]>,
    buckets: &mut DynArrayHeapMap<InternalValueRepr, CompositeBucketCollector>,
    is_on_after_key: bool,
) -> crate::Result<()> {
    if source_idx_for_recursion == composite_agg_data.req.sources.len() {
        if !is_on_after_key {
            collect_bucket_with_limit(
                agg_data,
                composite_agg_data,
                buckets,
                sub_level_values,
            )?;
        }
        return Ok(());
    }

    let current_level_accessors = &composite_agg_data.composite_accessors[source_idx_for_recursion];
    let current_level_source = &composite_agg_data.req.sources[source_idx_for_recursion];
    let mut missing = true;
    for (accessor_idx, accessor) in current_level_accessors.accessors.iter().enumerate() {
        let values = accessor.column.values_for_doc(doc_id);
        for value in values {
            missing = false;
            match current_level_source {
                CompositeAggregationSource::Terms(_) => {
                    let precedes_after_key_type =
                        accessor_idx < current_level_accessors.after_key_accessor_idx;
                    if is_on_after_key && precedes_after_key_type {
                        break;
                    }
                    let matches_after_key_type =
                        accessor_idx == current_level_accessors.after_key_accessor_idx;

                    if matches_after_key_type && is_on_after_key {
                        let should_skip = match current_level_source.order() {
                            Order::Asc => current_level_accessors.after_key.gt(value),
                            Order::Desc => current_level_accessors.after_key.lt(value),
                        };
                        if should_skip {
                            continue;
                        }
                    }
                    sub_level_values.push(InternalValueRepr::new_term(
                        value,
                        accessor_idx as u8,
                        current_level_source.order(),
                    ));
                    let still_on_after_key =
                        matches_after_key_type && current_level_accessors.after_key.equals(value);
                    recursive_key_visitor(
                        doc_id,
                        agg_data,
                        composite_agg_data,
                        source_idx_for_recursion + 1,
                        sub_level_values,
                        buckets,
                        is_on_after_key && still_on_after_key,
                    )?;
                    sub_level_values.pop();
                }
                CompositeAggregationSource::Histogram(source) => {
                    let float_value = match accessor.column_type {
                        ColumnType::U64 => value as f64,
                        ColumnType::I64 => i64::from_u64(value) as f64,
                        ColumnType::DateTime => i64::from_u64(value) as f64 / 1_000_000.,
                        ColumnType::F64 => f64::from_u64(value),
                        _ => {
                            panic!(
                                "unexpected type {:?}. This should not happen",
                                accessor.column_type
                            )
                        }
                    };
                    let bucket_index = (float_value / source.interval).floor() as i64;
                    let bucket_value = i64::to_u64(bucket_index);
                    if is_on_after_key {
                        let should_skip = match current_level_source.order() {
                            Order::Asc => current_level_accessors.after_key.gt(bucket_value),
                            Order::Desc => current_level_accessors.after_key.lt(bucket_value),
                        };
                        if should_skip {
                            continue;
                        }
                    }
                    sub_level_values.push(InternalValueRepr::new_histogram(
                        bucket_value,
                        current_level_source.order(),
                    ));
                    let still_on_after_key = current_level_accessors.after_key.equals(bucket_value);
                    recursive_key_visitor(
                        doc_id,
                        agg_data,
                        composite_agg_data,
                        source_idx_for_recursion + 1,
                        sub_level_values,
                        buckets,
                        is_on_after_key && still_on_after_key,
                    )?;
                    sub_level_values.pop();
                }
                CompositeAggregationSource::DateHistogram(_) => {
                    let value_ns = match accessor.column_type {
                        ColumnType::DateTime => i64::from_u64(value),
                        _ => {
                            panic!(
                                "unexpected type {:?}. This should not happen",
                                accessor.column_type
                            )
                        }
                    };
                    let bucket_index = match accessor.date_histogram_interval {
                        PrecomputedDateInterval::FixedNanoseconds(fixed_interval_ns) => {
                            (value_ns / fixed_interval_ns) * fixed_interval_ns
                        }
                        PrecomputedDateInterval::Calendar(CalendarInterval::Year) => {
                            calendar_interval::try_year_bucket(value_ns)?
                        }
                        PrecomputedDateInterval::Calendar(CalendarInterval::Month) => {
                            calendar_interval::try_month_bucket(value_ns)?
                        }
                        PrecomputedDateInterval::Calendar(CalendarInterval::Week) => {
                            calendar_interval::week_bucket(value_ns)
                        }
                        PrecomputedDateInterval::NotApplicable => {
                            panic!("interval not precomputed for date histogram source")
                        }
                    };
                    let bucket_value = i64::to_u64(bucket_index);
                    if is_on_after_key {
                        let should_skip = match current_level_source.order() {
                            Order::Asc => current_level_accessors.after_key.gt(bucket_value),
                            Order::Desc => current_level_accessors.after_key.lt(bucket_value),
                        };
                        if should_skip {
                            continue;
                        }
                    }
                    sub_level_values.push(InternalValueRepr::new_histogram(
                        bucket_value,
                        current_level_source.order(),
                    ));
                    let still_on_after_key = current_level_accessors.after_key.equals(bucket_value);
                    recursive_key_visitor(
                        doc_id,
                        agg_data,
                        composite_agg_data,
                        source_idx_for_recursion + 1,
                        sub_level_values,
                        buckets,
                        is_on_after_key && still_on_after_key,
                    )?;
                    sub_level_values.pop();
                }
            };
        }
    }
    if missing && current_level_source.missing_bucket() {
        if is_on_after_key && current_level_accessors.skip_missing {
            return Ok(());
        }
        sub_level_values.push(InternalValueRepr::new_missing(
            current_level_source.order(),
            current_level_source.missing_order(),
        ));
        recursive_key_visitor(
            doc_id,
            agg_data,
            composite_agg_data,
            source_idx_for_recursion + 1,
            sub_level_values,
            buckets,
            is_on_after_key && current_level_accessors.is_after_key_explicit_missing,
        )?;
        sub_level_values.pop();
    }
    Ok(())
}
@@ -1,364 +0,0 @@
use std::collections::BinaryHeap;
use std::fmt::Debug;
use std::hash::Hash;

use rustc_hash::FxHashMap;
use smallvec::SmallVec;

use crate::TantivyError;

/// Map backed by a hash map for fast access and a binary heap to track the
/// highest key. The key is an array of fixed size S.
#[derive(Clone, Debug)]
struct ArrayHeapMap<K: Ord, V, const S: usize> {
    pub(crate) buckets: FxHashMap<[K; S], V>,
    pub(crate) heap: BinaryHeap<[K; S]>,
}

impl<K: Ord, V, const S: usize> Default for ArrayHeapMap<K, V, S> {
    fn default() -> Self {
        ArrayHeapMap {
            buckets: FxHashMap::default(),
            heap: BinaryHeap::default(),
        }
    }
}

impl<K: Eq + Hash + Clone + Ord, V, const S: usize> ArrayHeapMap<K, V, S> {
    /// Panics if the length of `key` is not S.
    fn get_or_insert_with<F: FnOnce() -> V>(&mut self, key: &[K], f: F) -> &mut V {
        let key_array: &[K; S] = key.try_into().expect("Key length mismatch");
        self.buckets.entry(key_array.clone()).or_insert_with(|| {
            self.heap.push(key_array.clone());
            f()
        })
    }

    /// Panics if the length of `key` is not S.
    fn get_mut(&mut self, key: &[K]) -> Option<&mut V> {
        let key_array: &[K; S] = key.try_into().expect("Key length mismatch");
        self.buckets.get_mut(key_array)
    }

    fn peek_highest(&self) -> Option<&[K]> {
        self.heap.peek().map(|k_array| k_array.as_slice())
    }

    fn evict_highest(&mut self) {
        if let Some(highest) = self.heap.pop() {
            self.buckets.remove(&highest);
        }
    }

    fn memory_consumption(&self) -> u64 {
        let key_size = std::mem::size_of::<[K; S]>();
        let map_size = (key_size + std::mem::size_of::<V>()) * self.buckets.capacity();
        let heap_size = key_size * self.heap.capacity();
        (map_size + heap_size) as u64
    }
}

impl<K: Copy + Ord + Clone + 'static, V: 'static, const S: usize> ArrayHeapMap<K, V, S> {
    fn into_iter(self) -> Box<dyn Iterator<Item = (SmallVec<[K; MAX_DYN_ARRAY_SIZE]>, V)>> {
        Box::new(
            self.buckets
                .into_iter()
                .map(|(k, v)| (SmallVec::from_slice(&k), v)),
        )
    }

    fn values_mut<'a>(&'a mut self) -> Box<dyn Iterator<Item = &'a mut V> + 'a> {
        Box::new(self.buckets.values_mut())
    }
}

pub(super) const MAX_DYN_ARRAY_SIZE: usize = 16;
const MAX_DYN_ARRAY_SIZE_PLUS_ONE: usize = MAX_DYN_ARRAY_SIZE + 1;

/// A map optimized for memory footprint, fast access and efficient eviction of
/// the highest key.
///
/// Keys are inlined arrays of size 1 to [MAX_DYN_ARRAY_SIZE] but for a given
/// instance the key size is fixed. This allows avoiding heap allocations for
/// the keys.
#[derive(Clone, Debug)]
pub(super) struct DynArrayHeapMap<K: Ord, V>(DynArrayHeapMapInner<K, V>);

/// Wrapper around ArrayHeapMap to dynamically dispatch on the array size.
#[derive(Clone, Debug)]
enum DynArrayHeapMapInner<K: Ord, V> {
    Dim1(ArrayHeapMap<K, V, 1>),
    Dim2(ArrayHeapMap<K, V, 2>),
    Dim3(ArrayHeapMap<K, V, 3>),
    Dim4(ArrayHeapMap<K, V, 4>),
    Dim5(ArrayHeapMap<K, V, 5>),
    Dim6(ArrayHeapMap<K, V, 6>),
    Dim7(ArrayHeapMap<K, V, 7>),
    Dim8(ArrayHeapMap<K, V, 8>),
    Dim9(ArrayHeapMap<K, V, 9>),
    Dim10(ArrayHeapMap<K, V, 10>),
    Dim11(ArrayHeapMap<K, V, 11>),
    Dim12(ArrayHeapMap<K, V, 12>),
    Dim13(ArrayHeapMap<K, V, 13>),
    Dim14(ArrayHeapMap<K, V, 14>),
    Dim15(ArrayHeapMap<K, V, 15>),
    Dim16(ArrayHeapMap<K, V, 16>),
}

impl<K: Ord, V> DynArrayHeapMap<K, V> {
    /// Creates a new heap map with dynamic array keys of size `key_dimension`.
    pub(super) fn try_new(key_dimension: usize) -> crate::Result<Self> {
        let inner = match key_dimension {
            0 => {
                return Err(TantivyError::InvalidArgument(
                    "DynArrayHeapMap dimension must be at least 1".to_string(),
                ))
            }
            1 => DynArrayHeapMapInner::Dim1(ArrayHeapMap::default()),
            2 => DynArrayHeapMapInner::Dim2(ArrayHeapMap::default()),
            3 => DynArrayHeapMapInner::Dim3(ArrayHeapMap::default()),
            4 => DynArrayHeapMapInner::Dim4(ArrayHeapMap::default()),
            5 => DynArrayHeapMapInner::Dim5(ArrayHeapMap::default()),
            6 => DynArrayHeapMapInner::Dim6(ArrayHeapMap::default()),
            7 => DynArrayHeapMapInner::Dim7(ArrayHeapMap::default()),
            8 => DynArrayHeapMapInner::Dim8(ArrayHeapMap::default()),
            9 => DynArrayHeapMapInner::Dim9(ArrayHeapMap::default()),
            10 => DynArrayHeapMapInner::Dim10(ArrayHeapMap::default()),
            11 => DynArrayHeapMapInner::Dim11(ArrayHeapMap::default()),
            12 => DynArrayHeapMapInner::Dim12(ArrayHeapMap::default()),
            13 => DynArrayHeapMapInner::Dim13(ArrayHeapMap::default()),
            14 => DynArrayHeapMapInner::Dim14(ArrayHeapMap::default()),
            15 => DynArrayHeapMapInner::Dim15(ArrayHeapMap::default()),
            16 => DynArrayHeapMapInner::Dim16(ArrayHeapMap::default()),
            MAX_DYN_ARRAY_SIZE_PLUS_ONE.. => {
                return Err(TantivyError::InvalidArgument(format!(
                    "DynArrayHeapMap supports maximum {MAX_DYN_ARRAY_SIZE} dimensions, got \
                     {key_dimension}",
                )))
            }
        };
        Ok(DynArrayHeapMap(inner))
    }

    /// Number of elements in the map. This is not the dimension of the keys.
    pub(super) fn size(&self) -> usize {
        match &self.0 {
            DynArrayHeapMapInner::Dim1(map) => map.buckets.len(),
            DynArrayHeapMapInner::Dim2(map) => map.buckets.len(),
            DynArrayHeapMapInner::Dim3(map) => map.buckets.len(),
            DynArrayHeapMapInner::Dim4(map) => map.buckets.len(),
            DynArrayHeapMapInner::Dim5(map) => map.buckets.len(),
            DynArrayHeapMapInner::Dim6(map) => map.buckets.len(),
            DynArrayHeapMapInner::Dim7(map) => map.buckets.len(),
            DynArrayHeapMapInner::Dim8(map) => map.buckets.len(),
            DynArrayHeapMapInner::Dim9(map) => map.buckets.len(),
            DynArrayHeapMapInner::Dim10(map) => map.buckets.len(),
            DynArrayHeapMapInner::Dim11(map) => map.buckets.len(),
            DynArrayHeapMapInner::Dim12(map) => map.buckets.len(),
            DynArrayHeapMapInner::Dim13(map) => map.buckets.len(),
            DynArrayHeapMapInner::Dim14(map) => map.buckets.len(),
            DynArrayHeapMapInner::Dim15(map) => map.buckets.len(),
            DynArrayHeapMapInner::Dim16(map) => map.buckets.len(),
        }
    }
}

impl<K: Ord + Hash + Clone, V> DynArrayHeapMap<K, V> {
    /// Gets a mutable reference to the value corresponding to `key`, or inserts a new
    /// value created by calling `f`.
    ///
    /// Panics if the length of `key` does not match the key dimension of the map.
    pub(super) fn get_or_insert_with<F: FnOnce() -> V>(&mut self, key: &[K], f: F) -> &mut V {
        match &mut self.0 {
            DynArrayHeapMapInner::Dim1(map) => map.get_or_insert_with(key, f),
            DynArrayHeapMapInner::Dim2(map) => map.get_or_insert_with(key, f),
            DynArrayHeapMapInner::Dim3(map) => map.get_or_insert_with(key, f),
            DynArrayHeapMapInner::Dim4(map) => map.get_or_insert_with(key, f),
            DynArrayHeapMapInner::Dim5(map) => map.get_or_insert_with(key, f),
            DynArrayHeapMapInner::Dim6(map) => map.get_or_insert_with(key, f),
            DynArrayHeapMapInner::Dim7(map) => map.get_or_insert_with(key, f),
            DynArrayHeapMapInner::Dim8(map) => map.get_or_insert_with(key, f),
            DynArrayHeapMapInner::Dim9(map) => map.get_or_insert_with(key, f),
            DynArrayHeapMapInner::Dim10(map) => map.get_or_insert_with(key, f),
            DynArrayHeapMapInner::Dim11(map) => map.get_or_insert_with(key, f),
            DynArrayHeapMapInner::Dim12(map) => map.get_or_insert_with(key, f),
            DynArrayHeapMapInner::Dim13(map) => map.get_or_insert_with(key, f),
            DynArrayHeapMapInner::Dim14(map) => map.get_or_insert_with(key, f),
            DynArrayHeapMapInner::Dim15(map) => map.get_or_insert_with(key, f),
            DynArrayHeapMapInner::Dim16(map) => map.get_or_insert_with(key, f),
        }
    }

    /// Returns a mutable reference to the value corresponding to `key`.
    ///
    /// Panics if the length of `key` does not match the key dimension of the map.
    pub fn get_mut(&mut self, key: &[K]) -> Option<&mut V> {
        match &mut self.0 {
            DynArrayHeapMapInner::Dim1(map) => map.get_mut(key),
            DynArrayHeapMapInner::Dim2(map) => map.get_mut(key),
            DynArrayHeapMapInner::Dim3(map) => map.get_mut(key),
            DynArrayHeapMapInner::Dim4(map) => map.get_mut(key),
            DynArrayHeapMapInner::Dim5(map) => map.get_mut(key),
            DynArrayHeapMapInner::Dim6(map) => map.get_mut(key),
            DynArrayHeapMapInner::Dim7(map) => map.get_mut(key),
            DynArrayHeapMapInner::Dim8(map) => map.get_mut(key),
            DynArrayHeapMapInner::Dim9(map) => map.get_mut(key),
            DynArrayHeapMapInner::Dim10(map) => map.get_mut(key),
            DynArrayHeapMapInner::Dim11(map) => map.get_mut(key),
            DynArrayHeapMapInner::Dim12(map) => map.get_mut(key),
            DynArrayHeapMapInner::Dim13(map) => map.get_mut(key),
            DynArrayHeapMapInner::Dim14(map) => map.get_mut(key),
            DynArrayHeapMapInner::Dim15(map) => map.get_mut(key),
            DynArrayHeapMapInner::Dim16(map) => map.get_mut(key),
        }
    }

    /// Returns a reference to the highest key in the map.
    pub(super) fn peek_highest(&self) -> Option<&[K]> {
        match &self.0 {
            DynArrayHeapMapInner::Dim1(map) => map.peek_highest(),
            DynArrayHeapMapInner::Dim2(map) => map.peek_highest(),
            DynArrayHeapMapInner::Dim3(map) => map.peek_highest(),
            DynArrayHeapMapInner::Dim4(map) => map.peek_highest(),
            DynArrayHeapMapInner::Dim5(map) => map.peek_highest(),
            DynArrayHeapMapInner::Dim6(map) => map.peek_highest(),
            DynArrayHeapMapInner::Dim7(map) => map.peek_highest(),
            DynArrayHeapMapInner::Dim8(map) => map.peek_highest(),
            DynArrayHeapMapInner::Dim9(map) => map.peek_highest(),
            DynArrayHeapMapInner::Dim10(map) => map.peek_highest(),
            DynArrayHeapMapInner::Dim11(map) => map.peek_highest(),
            DynArrayHeapMapInner::Dim12(map) => map.peek_highest(),
            DynArrayHeapMapInner::Dim13(map) => map.peek_highest(),
            DynArrayHeapMapInner::Dim14(map) => map.peek_highest(),
            DynArrayHeapMapInner::Dim15(map) => map.peek_highest(),
            DynArrayHeapMapInner::Dim16(map) => map.peek_highest(),
        }
    }

    /// Removes the entry with the highest key from the map.
    pub(super) fn evict_highest(&mut self) {
        match &mut self.0 {
            DynArrayHeapMapInner::Dim1(map) => map.evict_highest(),
            DynArrayHeapMapInner::Dim2(map) => map.evict_highest(),
            DynArrayHeapMapInner::Dim3(map) => map.evict_highest(),
            DynArrayHeapMapInner::Dim4(map) => map.evict_highest(),
            DynArrayHeapMapInner::Dim5(map) => map.evict_highest(),
            DynArrayHeapMapInner::Dim6(map) => map.evict_highest(),
            DynArrayHeapMapInner::Dim7(map) => map.evict_highest(),
            DynArrayHeapMapInner::Dim8(map) => map.evict_highest(),
            DynArrayHeapMapInner::Dim9(map) => map.evict_highest(),
            DynArrayHeapMapInner::Dim10(map) => map.evict_highest(),
            DynArrayHeapMapInner::Dim11(map) => map.evict_highest(),
            DynArrayHeapMapInner::Dim12(map) => map.evict_highest(),
            DynArrayHeapMapInner::Dim13(map) => map.evict_highest(),
            DynArrayHeapMapInner::Dim14(map) => map.evict_highest(),
            DynArrayHeapMapInner::Dim15(map) => map.evict_highest(),
            DynArrayHeapMapInner::Dim16(map) => map.evict_highest(),
        }
    }

    pub(crate) fn memory_consumption(&self) -> u64 {
        match &self.0 {
            DynArrayHeapMapInner::Dim1(map) => map.memory_consumption(),
            DynArrayHeapMapInner::Dim2(map) => map.memory_consumption(),
            DynArrayHeapMapInner::Dim3(map) => map.memory_consumption(),
            DynArrayHeapMapInner::Dim4(map) => map.memory_consumption(),
            DynArrayHeapMapInner::Dim5(map) => map.memory_consumption(),
            DynArrayHeapMapInner::Dim6(map) => map.memory_consumption(),
            DynArrayHeapMapInner::Dim7(map) => map.memory_consumption(),
            DynArrayHeapMapInner::Dim8(map) => map.memory_consumption(),
            DynArrayHeapMapInner::Dim9(map) => map.memory_consumption(),
            DynArrayHeapMapInner::Dim10(map) => map.memory_consumption(),
            DynArrayHeapMapInner::Dim11(map) => map.memory_consumption(),
            DynArrayHeapMapInner::Dim12(map) => map.memory_consumption(),
            DynArrayHeapMapInner::Dim13(map) => map.memory_consumption(),
            DynArrayHeapMapInner::Dim14(map) => map.memory_consumption(),
            DynArrayHeapMapInner::Dim15(map) => map.memory_consumption(),
|
||||
DynArrayHeapMapInner::Dim16(map) => map.memory_consumption(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: Ord + Clone + Copy + 'static, V: 'static> DynArrayHeapMap<K, V> {
|
||||
/// Turns this map into an iterator over key-value pairs.
|
||||
pub fn into_iter(self) -> impl Iterator<Item = (SmallVec<[K; MAX_DYN_ARRAY_SIZE]>, V)> {
|
||||
match self.0 {
|
||||
DynArrayHeapMapInner::Dim1(map) => map.into_iter(),
|
||||
DynArrayHeapMapInner::Dim2(map) => map.into_iter(),
|
||||
DynArrayHeapMapInner::Dim3(map) => map.into_iter(),
|
||||
DynArrayHeapMapInner::Dim4(map) => map.into_iter(),
|
||||
DynArrayHeapMapInner::Dim5(map) => map.into_iter(),
|
||||
DynArrayHeapMapInner::Dim6(map) => map.into_iter(),
|
||||
DynArrayHeapMapInner::Dim7(map) => map.into_iter(),
|
||||
DynArrayHeapMapInner::Dim8(map) => map.into_iter(),
|
||||
DynArrayHeapMapInner::Dim9(map) => map.into_iter(),
|
||||
DynArrayHeapMapInner::Dim10(map) => map.into_iter(),
|
||||
DynArrayHeapMapInner::Dim11(map) => map.into_iter(),
|
||||
DynArrayHeapMapInner::Dim12(map) => map.into_iter(),
|
||||
DynArrayHeapMapInner::Dim13(map) => map.into_iter(),
|
||||
DynArrayHeapMapInner::Dim14(map) => map.into_iter(),
|
||||
DynArrayHeapMapInner::Dim15(map) => map.into_iter(),
|
||||
DynArrayHeapMapInner::Dim16(map) => map.into_iter(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an iterator over mutable references to the values in the map.
|
||||
pub(super) fn values_mut(&mut self) -> impl Iterator<Item = &mut V> {
|
||||
match &mut self.0 {
|
||||
DynArrayHeapMapInner::Dim1(map) => map.values_mut(),
|
||||
DynArrayHeapMapInner::Dim2(map) => map.values_mut(),
|
||||
DynArrayHeapMapInner::Dim3(map) => map.values_mut(),
|
||||
DynArrayHeapMapInner::Dim4(map) => map.values_mut(),
|
||||
DynArrayHeapMapInner::Dim5(map) => map.values_mut(),
|
||||
DynArrayHeapMapInner::Dim6(map) => map.values_mut(),
|
||||
DynArrayHeapMapInner::Dim7(map) => map.values_mut(),
|
||||
DynArrayHeapMapInner::Dim8(map) => map.values_mut(),
|
||||
DynArrayHeapMapInner::Dim9(map) => map.values_mut(),
|
||||
DynArrayHeapMapInner::Dim10(map) => map.values_mut(),
|
||||
DynArrayHeapMapInner::Dim11(map) => map.values_mut(),
|
||||
DynArrayHeapMapInner::Dim12(map) => map.values_mut(),
|
||||
DynArrayHeapMapInner::Dim13(map) => map.values_mut(),
|
||||
DynArrayHeapMapInner::Dim14(map) => map.values_mut(),
|
||||
DynArrayHeapMapInner::Dim15(map) => map.values_mut(),
|
||||
DynArrayHeapMapInner::Dim16(map) => map.values_mut(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_dyn_array_heap_map() {
|
||||
let mut map = DynArrayHeapMap::<u32, &str>::try_new(2).unwrap();
|
||||
// insert
|
||||
let key1 = [1u32, 2u32];
|
||||
let key2 = [2u32, 1u32];
|
||||
map.get_or_insert_with(&key1, || "a");
|
||||
map.get_or_insert_with(&key2, || "b");
|
||||
assert_eq!(map.size(), 2);
|
||||
|
||||
// evict highest
|
||||
assert_eq!(map.peek_highest(), Some(&key2[..]));
|
||||
map.evict_highest();
|
||||
assert_eq!(map.size(), 1);
|
||||
assert_eq!(map.peek_highest(), Some(&key1[..]));
|
||||
|
||||
// mutable iterator
|
||||
{
|
||||
let mut mut_iter = map.values_mut();
|
||||
let v = mut_iter.next().unwrap();
|
||||
assert_eq!(*v, "a");
|
||||
*v = "c";
|
||||
assert_eq!(mut_iter.next(), None);
|
||||
}
|
||||
|
||||
// into_iter
|
||||
let mut iter = map.into_iter();
|
||||
let (k, v) = iter.next().unwrap();
|
||||
assert_eq!(k.as_slice(), &key1);
|
||||
assert_eq!(v, "c");
|
||||
assert_eq!(iter.next(), None);
|
||||
}
|
||||
}
|
||||
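The sixteen near-identical match arms above are the cost of dispatching a runtime key arity onto statically sized array keys: each variant is monomorphized over a fixed-size `[K; N]`, so the hot path avoids heap-allocated keys. A miniature sketch of the same pattern (the `TinyDynMap` name, the two-arity limit, and the `BTreeMap` backing are all illustrative stand-ins, not the real type):

```rust
use std::collections::BTreeMap;

// Miniature of the dispatch pattern above: one enum variant per key arity,
// so every map is monomorphized over a fixed-size array key.
enum TinyDynMap<V> {
    Dim1(BTreeMap<[u32; 1], V>),
    Dim2(BTreeMap<[u32; 2], V>),
}

impl<V> TinyDynMap<V> {
    fn try_new(dim: usize) -> Option<Self> {
        match dim {
            1 => Some(TinyDynMap::Dim1(BTreeMap::new())),
            2 => Some(TinyDynMap::Dim2(BTreeMap::new())),
            _ => None, // the real type goes up to Dim16
        }
    }

    // Every method pays one match to reach the concrete map, exactly like
    // the sixteen-arm matches above.
    fn len(&self) -> usize {
        match self {
            TinyDynMap::Dim1(m) => m.len(),
            TinyDynMap::Dim2(m) => m.len(),
        }
    }

    fn insert(&mut self, key: &[u32], value: V) {
        match self {
            TinyDynMap::Dim1(m) => {
                m.insert(key.try_into().expect("key arity mismatch"), value);
            }
            TinyDynMap::Dim2(m) => {
                m.insert(key.try_into().expect("key arity mismatch"), value);
            }
        }
    }
}

fn main() {
    let mut map = TinyDynMap::try_new(2).unwrap();
    map.insert(&[1, 2], "a");
    map.insert(&[2, 1], "b");
    assert_eq!(map.len(), 2);
    assert!(TinyDynMap::<&str>::try_new(3).is_none());
}
```

The repetition is usually tamed with a macro in real code; the structure stays the same.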
File diff suppressed because it is too large
@@ -1,460 +0,0 @@
/// This module helps compare numerical values of different types (i64, u64,
/// and f64).
pub(super) mod num_cmp {
    use std::cmp::Ordering;

    use crate::TantivyError;

    pub fn cmp_i64_f64(left_i: i64, right_f: f64) -> crate::Result<Ordering> {
        if right_f.is_nan() {
            return Err(TantivyError::InvalidArgument(
                "NaN comparison is not supported".to_string(),
            ));
        }

        // If right_f < i64::MIN, then left_i > right_f (i64::MIN=-2^63 can be
        // exactly represented as f64)
        if right_f < i64::MIN as f64 {
            return Ok(Ordering::Greater);
        }
        // If right_f >= i64::MAX, then left_i < right_f (i64::MAX=2^63-1 cannot
        // be exactly represented as f64)
        if right_f >= i64::MAX as f64 {
            return Ok(Ordering::Less);
        }

        // Now right_f is in (i64::MIN, i64::MAX), so `right_f as i64` is
        // well-defined (truncation toward 0)
        let right_as_i = right_f as i64;

        let result = match left_i.cmp(&right_as_i) {
            Ordering::Less => Ordering::Less,
            Ordering::Greater => Ordering::Greater,
            Ordering::Equal => {
                // they have the same integer part, compare the fraction
                let rem = right_f - (right_as_i as f64);
                if rem == 0.0 {
                    Ordering::Equal
                } else if right_f > 0.0 {
                    Ordering::Less
                } else {
                    Ordering::Greater
                }
            }
        };
        Ok(result)
    }

    pub fn cmp_u64_f64(left_u: u64, right_f: f64) -> crate::Result<Ordering> {
        if right_f.is_nan() {
            return Err(TantivyError::InvalidArgument(
                "NaN comparison is not supported".to_string(),
            ));
        }

        // Negative floats are always less than any u64 >= 0
        if right_f < 0.0 {
            return Ok(Ordering::Greater);
        }

        // If right_f > u64::MAX, then left_u < right_f (u64::MAX=2^64-1 cannot
        // be exactly represented as f64)
        let max_as_f = u64::MAX as f64;
        if right_f > max_as_f {
            return Ok(Ordering::Less);
        }

        // Now right_f is in [0, u64::MAX], so `right_f as u64` is well-defined
        // (truncation toward 0)
        let right_as_u = right_f as u64;

        let result = match left_u.cmp(&right_as_u) {
            Ordering::Less => Ordering::Less,
            Ordering::Greater => Ordering::Greater,
            Ordering::Equal => {
                // they have the same integer part, compare the fraction
                let rem = right_f - (right_as_u as f64);
                if rem == 0.0 {
                    Ordering::Equal
                } else {
                    Ordering::Less
                }
            }
        };
        Ok(result)
    }

    pub fn cmp_i64_u64(left_i: i64, right_u: u64) -> Ordering {
        if left_i < 0 {
            Ordering::Less
        } else {
            let left_as_u = left_i as u64;
            left_as_u.cmp(&right_u)
        }
    }
}
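The guard-then-truncate approach above matters because the seemingly obvious alternative, casting the integer side to f64, silently rounds once values exceed 2^53. A standalone sketch of the same idea for the i64/f64 case (a hypothetical free function, not this module's API; the NaN error path is omitted for brevity):

```rust
use std::cmp::Ordering;

/// Compare an i64 with a finite f64 without precision loss, mirroring the
/// truncate-toward-zero strategy described above.
fn cmp_i64_f64_sketch(left: i64, right: f64) -> Ordering {
    // Out-of-range guards: -2^63 is exactly representable as f64, 2^63-1 is not.
    if right < i64::MIN as f64 {
        return Ordering::Greater;
    }
    if right >= i64::MAX as f64 {
        return Ordering::Less;
    }
    // right is now in range, so the cast (truncation toward zero) is well-defined.
    let right_trunc = right as i64;
    match left.cmp(&right_trunc) {
        Ordering::Equal => {
            // Same integer part: the discarded fraction breaks the tie.
            let frac = right - right_trunc as f64;
            if frac == 0.0 {
                Ordering::Equal
            } else if right > 0.0 {
                Ordering::Less
            } else {
                Ordering::Greater
            }
        }
        other => other,
    }
}

fn main() {
    // A naive `(left as f64) == right` would wrongly report equality here,
    // since 2^53 + 1 rounds to 2^53 when cast to f64:
    assert_eq!(
        cmp_i64_f64_sketch(9_007_199_254_740_993, 9_007_199_254_740_992.0),
        Ordering::Greater
    );
    assert_eq!(cmp_i64_f64_sketch(-5, -4.5), Ordering::Less);
}
```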

/// This module helps project numerical values onto other numerical types.
/// When the target value space cannot exactly represent the source value, the
/// next representable value is returned (or AfterLast if the source value is
/// larger than the largest representable value).
///
/// All functions in this module assume that f64 values are not NaN.
pub(super) mod num_proj {
    #[derive(Debug, PartialEq)]
    pub enum ProjectedNumber<T> {
        Exact(T),
        Next(T),
        AfterLast,
    }

    pub fn i64_to_u64(value: i64) -> ProjectedNumber<u64> {
        if value < 0 {
            ProjectedNumber::Next(0)
        } else {
            ProjectedNumber::Exact(value as u64)
        }
    }

    pub fn u64_to_i64(value: u64) -> ProjectedNumber<i64> {
        if value > i64::MAX as u64 {
            ProjectedNumber::AfterLast
        } else {
            ProjectedNumber::Exact(value as i64)
        }
    }

    pub fn f64_to_u64(value: f64) -> ProjectedNumber<u64> {
        if value < 0.0 {
            ProjectedNumber::Next(0)
        } else if value > u64::MAX as f64 {
            ProjectedNumber::AfterLast
        } else if value.fract() == 0.0 {
            ProjectedNumber::Exact(value as u64)
        } else {
            // casting f64 to u64 truncates toward zero
            ProjectedNumber::Next(value as u64 + 1)
        }
    }

    pub fn f64_to_i64(value: f64) -> ProjectedNumber<i64> {
        if value < (i64::MIN as f64) {
            ProjectedNumber::Next(i64::MIN)
        } else if value >= (i64::MAX as f64) {
            ProjectedNumber::AfterLast
        } else if value.fract() == 0.0 {
            ProjectedNumber::Exact(value as i64)
        } else if value > 0.0 {
            // casting f64 to i64 truncates toward zero
            ProjectedNumber::Next(value as i64 + 1)
        } else {
            ProjectedNumber::Next(value as i64)
        }
    }

    pub fn i64_to_f64(value: i64) -> ProjectedNumber<f64> {
        let value_f = value as f64;
        let k_roundtrip = value_f as i64;
        if k_roundtrip == value {
            // between -2^53 and 2^53 all i64 are exactly represented as f64
            ProjectedNumber::Exact(value_f)
        } else {
            // very large/small i64 values are approximated by the closest f64
            if k_roundtrip > value {
                ProjectedNumber::Next(value_f)
            } else {
                ProjectedNumber::Next(value_f.next_up())
            }
        }
    }

    pub fn u64_to_f64(value: u64) -> ProjectedNumber<f64> {
        let value_f = value as f64;
        let k_roundtrip = value_f as u64;
        if k_roundtrip == value {
            // between 0 and 2^53 all u64 are exactly represented as f64
            ProjectedNumber::Exact(value_f)
        } else if k_roundtrip > value {
            ProjectedNumber::Next(value_f)
        } else {
            ProjectedNumber::Next(value_f.next_up())
        }
    }
}
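The "next representable value" contract documented above is what makes these projections safe for range-bound rewriting: rounding always moves up, so a projected lower bound never admits extra values. A self-contained sketch of the f64-to-i64 case (hypothetical standalone `Projected` enum and function, not the module's API; assumes non-NaN input like the original):

```rust
/// Result of projecting a value into another numeric domain.
#[derive(Debug, PartialEq)]
enum Projected<T> {
    Exact(T),
    Next(T),
    AfterLast,
}

/// Project a non-NaN f64 onto i64, rounding up to the next representable
/// integer when the source carries a fractional part.
fn f64_to_i64_sketch(value: f64) -> Projected<i64> {
    if value < i64::MIN as f64 {
        Projected::Next(i64::MIN)
    } else if value >= i64::MAX as f64 {
        Projected::AfterLast
    } else if value.fract() == 0.0 {
        Projected::Exact(value as i64)
    } else if value > 0.0 {
        // `as` truncates toward zero, so step up by one for positive inputs
        Projected::Next(value as i64 + 1)
    } else {
        // for negative inputs truncation already lands on the next integer up
        Projected::Next(value as i64)
    }
}

fn main() {
    assert_eq!(f64_to_i64_sketch(42.0), Projected::Exact(42));
    assert_eq!(f64_to_i64_sketch(42.1), Projected::Next(43));
    assert_eq!(f64_to_i64_sketch(-42.1), Projected::Next(-42));
    assert_eq!(f64_to_i64_sketch(1e20), Projected::AfterLast);
}
```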
|
||||
#[cfg(test)]
|
||||
mod num_cmp_tests {
|
||||
use std::cmp::Ordering;
|
||||
|
||||
use super::num_cmp::*;
|
||||
|
||||
#[test]
|
||||
fn test_cmp_u64_f64() {
|
||||
// Basic comparisons
|
||||
assert_eq!(cmp_u64_f64(5, 5.0).unwrap(), Ordering::Equal);
|
||||
assert_eq!(cmp_u64_f64(5, 6.0).unwrap(), Ordering::Less);
|
||||
assert_eq!(cmp_u64_f64(6, 5.0).unwrap(), Ordering::Greater);
|
||||
assert_eq!(cmp_u64_f64(0, 0.0).unwrap(), Ordering::Equal);
|
||||
assert_eq!(cmp_u64_f64(0, 0.1).unwrap(), Ordering::Less);
|
||||
|
||||
// Negative float values should always be less than any u64
|
||||
assert_eq!(cmp_u64_f64(0, -0.1).unwrap(), Ordering::Greater);
|
||||
assert_eq!(cmp_u64_f64(5, -5.0).unwrap(), Ordering::Greater);
|
||||
assert_eq!(cmp_u64_f64(u64::MAX, -1e20).unwrap(), Ordering::Greater);
|
||||
|
||||
// Tests with extreme values
|
||||
assert_eq!(cmp_u64_f64(u64::MAX, 1e20).unwrap(), Ordering::Less);
|
||||
|
||||
// Precision edge cases: large u64 that loses precision when converted to f64
|
||||
// => 2^54, exactly represented as f64
|
||||
let large_f64 = 18_014_398_509_481_984.0;
|
||||
let large_u64 = 18_014_398_509_481_984;
|
||||
// prove that large_u64 is exactly represented as f64
|
||||
assert_eq!(large_u64 as f64, large_f64);
|
||||
assert_eq!(cmp_u64_f64(large_u64, large_f64).unwrap(), Ordering::Equal);
|
||||
// => (2^54 + 1) cannot be exactly represented in f64
|
||||
let large_u64_plus_1 = 18_014_398_509_481_985;
|
||||
// prove that it is represented as f64 by large_f64
|
||||
assert_eq!(large_u64_plus_1 as f64, large_f64);
|
||||
assert_eq!(
|
||||
cmp_u64_f64(large_u64_plus_1, large_f64).unwrap(),
|
||||
Ordering::Greater
|
||||
);
|
||||
// => (2^54 - 1) cannot be exactly represented in f64
|
||||
let large_u64_minus_1 = 18_014_398_509_481_983;
|
||||
// prove that it is also represented as f64 by large_f64
|
||||
assert_eq!(large_u64_minus_1 as f64, large_f64);
|
||||
assert_eq!(
|
||||
cmp_u64_f64(large_u64_minus_1, large_f64).unwrap(),
|
||||
Ordering::Less
|
||||
);
|
||||
|
||||
// NaN comparison results in an error
|
||||
assert!(cmp_u64_f64(0, f64::NAN).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cmp_i64_f64() {
|
||||
// Basic comparisons
|
||||
assert_eq!(cmp_i64_f64(5, 5.0).unwrap(), Ordering::Equal);
|
||||
assert_eq!(cmp_i64_f64(5, 6.0).unwrap(), Ordering::Less);
|
||||
assert_eq!(cmp_i64_f64(6, 5.0).unwrap(), Ordering::Greater);
|
||||
assert_eq!(cmp_i64_f64(-5, -5.0).unwrap(), Ordering::Equal);
|
||||
assert_eq!(cmp_i64_f64(-5, -4.0).unwrap(), Ordering::Less);
|
||||
assert_eq!(cmp_i64_f64(-4, -5.0).unwrap(), Ordering::Greater);
|
||||
assert_eq!(cmp_i64_f64(-5, 5.0).unwrap(), Ordering::Less);
|
||||
assert_eq!(cmp_i64_f64(5, -5.0).unwrap(), Ordering::Greater);
|
||||
assert_eq!(cmp_i64_f64(0, -0.1).unwrap(), Ordering::Greater);
|
||||
assert_eq!(cmp_i64_f64(0, 0.1).unwrap(), Ordering::Less);
|
||||
assert_eq!(cmp_i64_f64(-1, -0.5).unwrap(), Ordering::Less);
|
||||
assert_eq!(cmp_i64_f64(-1, 0.0).unwrap(), Ordering::Less);
|
||||
assert_eq!(cmp_i64_f64(0, 0.0).unwrap(), Ordering::Equal);
|
||||
|
||||
// Tests with extreme values
|
||||
assert_eq!(cmp_i64_f64(i64::MAX, 1e20).unwrap(), Ordering::Less);
|
||||
assert_eq!(cmp_i64_f64(i64::MIN, -1e20).unwrap(), Ordering::Greater);
|
||||
|
||||
// Precision edge cases: large i64 that loses precision when converted to f64
|
||||
// => 2^54, exactly represented as f64
|
||||
let large_f64 = 18_014_398_509_481_984.0;
|
||||
let large_i64 = 18_014_398_509_481_984;
|
||||
// prove that large_i64 is exactly represented as f64
|
||||
assert_eq!(large_i64 as f64, large_f64);
|
||||
assert_eq!(cmp_i64_f64(large_i64, large_f64).unwrap(), Ordering::Equal);
|
||||
// => (1_i64 << 54) + 1 cannot be exactly represented in f64
|
||||
let large_i64_plus_1 = 18_014_398_509_481_985;
|
||||
// prove that it is represented as f64 by large_f64
|
||||
assert_eq!(large_i64_plus_1 as f64, large_f64);
|
||||
assert_eq!(
|
||||
cmp_i64_f64(large_i64_plus_1, large_f64).unwrap(),
|
||||
Ordering::Greater
|
||||
);
|
||||
// => (1_i64 << 54) - 1 cannot be exactly represented in f64
|
||||
let large_i64_minus_1 = 18_014_398_509_481_983;
|
||||
// prove that it is also represented as f64 by large_f64
|
||||
assert_eq!(large_i64_minus_1 as f64, large_f64);
|
||||
assert_eq!(
|
||||
cmp_i64_f64(large_i64_minus_1, large_f64).unwrap(),
|
||||
Ordering::Less
|
||||
);
|
||||
|
||||
// Same precision edge case but with negative values
|
||||
// => -2^54, exactly represented as f64
|
||||
let large_neg_f64 = -18_014_398_509_481_984.0;
|
||||
let large_neg_i64 = -18_014_398_509_481_984;
|
||||
// prove that large_neg_i64 is exactly represented as f64
|
||||
assert_eq!(large_neg_i64 as f64, large_neg_f64);
|
||||
assert_eq!(
|
||||
cmp_i64_f64(large_neg_i64, large_neg_f64).unwrap(),
|
||||
Ordering::Equal
|
||||
);
|
||||
// => (-2^54 + 1) cannot be exactly represented in f64
|
||||
let large_neg_i64_plus_1 = -18_014_398_509_481_985;
|
||||
// prove that it is represented as f64 by large_neg_f64
|
||||
assert_eq!(large_neg_i64_plus_1 as f64, large_neg_f64);
|
||||
assert_eq!(
|
||||
cmp_i64_f64(large_neg_i64_plus_1, large_neg_f64).unwrap(),
|
||||
Ordering::Less
|
||||
);
|
||||
// => (-2^54 - 1) cannot be exactly represented in f64
|
||||
let large_neg_i64_minus_1 = -18_014_398_509_481_983;
|
||||
// prove that it is also represented as f64 by large_neg_f64
|
||||
assert_eq!(large_neg_i64_minus_1 as f64, large_neg_f64);
|
||||
assert_eq!(
|
||||
cmp_i64_f64(large_neg_i64_minus_1, large_neg_f64).unwrap(),
|
||||
Ordering::Greater
|
||||
);
|
||||
|
||||
// NaN comparison results in an error
|
||||
assert!(cmp_i64_f64(0, f64::NAN).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cmp_i64_u64() {
|
||||
// Test with negative i64 values (should always be less than any u64)
|
||||
assert_eq!(cmp_i64_u64(-1, 0), Ordering::Less);
|
||||
assert_eq!(cmp_i64_u64(i64::MIN, 0), Ordering::Less);
|
||||
assert_eq!(cmp_i64_u64(i64::MIN, u64::MAX), Ordering::Less);
|
||||
|
||||
// Test with positive i64 values
|
||||
assert_eq!(cmp_i64_u64(0, 0), Ordering::Equal);
|
||||
assert_eq!(cmp_i64_u64(1, 0), Ordering::Greater);
|
||||
assert_eq!(cmp_i64_u64(1, 1), Ordering::Equal);
|
||||
assert_eq!(cmp_i64_u64(0, 1), Ordering::Less);
|
||||
assert_eq!(cmp_i64_u64(5, 10), Ordering::Less);
|
||||
assert_eq!(cmp_i64_u64(10, 5), Ordering::Greater);
|
||||
|
||||
// Test with values near i64::MAX and u64 conversion
|
||||
assert_eq!(cmp_i64_u64(i64::MAX, i64::MAX as u64), Ordering::Equal);
|
||||
assert_eq!(cmp_i64_u64(i64::MAX, (i64::MAX as u64) + 1), Ordering::Less);
|
||||
assert_eq!(cmp_i64_u64(i64::MAX, u64::MAX), Ordering::Less);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod num_proj_tests {
|
||||
use super::num_proj::{self, ProjectedNumber};
|
||||
|
||||
#[test]
|
||||
fn test_i64_to_u64() {
|
||||
assert_eq!(num_proj::i64_to_u64(-1), ProjectedNumber::Next(0));
|
||||
assert_eq!(num_proj::i64_to_u64(i64::MIN), ProjectedNumber::Next(0));
|
||||
assert_eq!(num_proj::i64_to_u64(0), ProjectedNumber::Exact(0));
|
||||
assert_eq!(num_proj::i64_to_u64(42), ProjectedNumber::Exact(42));
|
||||
assert_eq!(
|
||||
num_proj::i64_to_u64(i64::MAX),
|
||||
ProjectedNumber::Exact(i64::MAX as u64)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_u64_to_i64() {
|
||||
assert_eq!(num_proj::u64_to_i64(0), ProjectedNumber::Exact(0));
|
||||
assert_eq!(num_proj::u64_to_i64(42), ProjectedNumber::Exact(42));
|
||||
assert_eq!(
|
||||
num_proj::u64_to_i64(i64::MAX as u64),
|
||||
ProjectedNumber::Exact(i64::MAX)
|
||||
);
|
||||
assert_eq!(
|
||||
num_proj::u64_to_i64((i64::MAX as u64) + 1),
|
||||
ProjectedNumber::AfterLast
|
||||
);
|
||||
assert_eq!(num_proj::u64_to_i64(u64::MAX), ProjectedNumber::AfterLast);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_f64_to_u64() {
|
||||
assert_eq!(num_proj::f64_to_u64(-1e25), ProjectedNumber::Next(0));
|
||||
assert_eq!(num_proj::f64_to_u64(-0.1), ProjectedNumber::Next(0));
|
||||
assert_eq!(num_proj::f64_to_u64(1e20), ProjectedNumber::AfterLast);
|
||||
assert_eq!(
|
||||
num_proj::f64_to_u64(f64::INFINITY),
|
||||
ProjectedNumber::AfterLast
|
||||
);
|
||||
assert_eq!(num_proj::f64_to_u64(0.0), ProjectedNumber::Exact(0));
|
||||
assert_eq!(num_proj::f64_to_u64(42.0), ProjectedNumber::Exact(42));
|
||||
assert_eq!(num_proj::f64_to_u64(0.5), ProjectedNumber::Next(1));
|
||||
assert_eq!(num_proj::f64_to_u64(42.1), ProjectedNumber::Next(43));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_f64_to_i64() {
|
||||
assert_eq!(num_proj::f64_to_i64(-1e20), ProjectedNumber::Next(i64::MIN));
|
||||
assert_eq!(
|
||||
num_proj::f64_to_i64(f64::NEG_INFINITY),
|
||||
ProjectedNumber::Next(i64::MIN)
|
||||
);
|
||||
assert_eq!(num_proj::f64_to_i64(1e20), ProjectedNumber::AfterLast);
|
||||
assert_eq!(
|
||||
num_proj::f64_to_i64(f64::INFINITY),
|
||||
ProjectedNumber::AfterLast
|
||||
);
|
||||
assert_eq!(num_proj::f64_to_i64(0.0), ProjectedNumber::Exact(0));
|
||||
assert_eq!(num_proj::f64_to_i64(42.0), ProjectedNumber::Exact(42));
|
||||
assert_eq!(num_proj::f64_to_i64(-42.0), ProjectedNumber::Exact(-42));
|
||||
assert_eq!(num_proj::f64_to_i64(0.5), ProjectedNumber::Next(1));
|
||||
assert_eq!(num_proj::f64_to_i64(42.1), ProjectedNumber::Next(43));
|
||||
assert_eq!(num_proj::f64_to_i64(-0.5), ProjectedNumber::Next(0));
|
||||
assert_eq!(num_proj::f64_to_i64(-42.1), ProjectedNumber::Next(-42));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_i64_to_f64() {
|
||||
assert_eq!(num_proj::i64_to_f64(0), ProjectedNumber::Exact(0.0));
|
||||
assert_eq!(num_proj::i64_to_f64(42), ProjectedNumber::Exact(42.0));
|
||||
assert_eq!(num_proj::i64_to_f64(-42), ProjectedNumber::Exact(-42.0));
|
||||
|
||||
let max_exact = 9_007_199_254_740_992; // 2^53
|
||||
assert_eq!(
|
||||
num_proj::i64_to_f64(max_exact),
|
||||
ProjectedNumber::Exact(max_exact as f64)
|
||||
);
|
||||
|
||||
// Test values that cannot be exactly represented as f64 (integers above 2^53)
|
||||
let large_i64 = 9_007_199_254_740_993; // 2^53 + 1
|
||||
let closest_f64 = 9_007_199_254_740_992.0;
|
||||
assert_eq!(large_i64 as f64, closest_f64);
|
||||
if let ProjectedNumber::Next(val) = num_proj::i64_to_f64(large_i64) {
|
||||
// Verify that the returned float is different from the direct cast
|
||||
assert!(val > closest_f64);
|
||||
assert!(val - closest_f64 < 2. * f64::EPSILON * closest_f64);
|
||||
} else {
|
||||
panic!("Expected ProjectedNumber::Next for large_i64");
|
||||
}
|
||||
|
||||
// Test with very large negative value
|
||||
let large_neg_i64 = -9_007_199_254_740_993; // -(2^53 + 1)
|
||||
let closest_neg_f64 = -9_007_199_254_740_992.0;
|
||||
assert_eq!(large_neg_i64 as f64, closest_neg_f64);
|
||||
if let ProjectedNumber::Next(val) = num_proj::i64_to_f64(large_neg_i64) {
|
||||
// Verify that the returned float is the closest representable f64
|
||||
assert_eq!(val, closest_neg_f64);
|
||||
} else {
|
||||
panic!("Expected ProjectedNumber::Next for large_neg_i64");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_u64_to_f64() {
|
||||
assert_eq!(num_proj::u64_to_f64(0), ProjectedNumber::Exact(0.0));
|
||||
assert_eq!(num_proj::u64_to_f64(42), ProjectedNumber::Exact(42.0));
|
||||
|
||||
// Test the largest u64 value that can be exactly represented as f64 (2^53)
|
||||
let max_exact = 9_007_199_254_740_992; // 2^53
|
||||
assert_eq!(
|
||||
num_proj::u64_to_f64(max_exact),
|
||||
ProjectedNumber::Exact(max_exact as f64)
|
||||
);
|
||||
|
||||
// Test values that cannot be exactly represented as f64 (integers above 2^53)
|
||||
let large_u64 = 9_007_199_254_740_993; // 2^53 + 1
|
||||
let closest_f64 = 9_007_199_254_740_992.0;
|
||||
assert_eq!(large_u64 as f64, closest_f64);
|
||||
if let ProjectedNumber::Next(val) = num_proj::u64_to_f64(large_u64) {
|
||||
// Verify that the returned float is different from the direct cast
|
||||
assert!(val > closest_f64);
|
||||
assert!(val - closest_f64 < 2. * f64::EPSILON * closest_f64);
|
||||
} else {
|
||||
panic!("Expected ProjectedNumber::Next for large_u64");
|
||||
}
|
||||
}
|
||||
}
|
||||
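The magic constants in these tests all sit around 2^53 and 2^54 because f64 has a 53-bit significand: every integer up to 2^53 round-trips through f64 exactly, and beyond that the representable doubles thin out. A tiny standalone demonstration of the boundary the tests rely on:

```rust
fn main() {
    // f64 has a 53-bit significand, so every integer up to 2^53 round-trips:
    let exact: u64 = (1u64 << 53) - 1;
    assert_eq!(exact as f64 as u64, exact);

    // Above 2^53 the spacing between representable doubles is 2, so 2^53 + 1
    // rounds (ties-to-even) down to 2^53 and the round-trip loses 1:
    let lossy: u64 = (1u64 << 53) + 1;
    assert_eq!(lossy as f64 as u64, lossy - 1);
}
```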
@@ -6,14 +6,10 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use crate::aggregation::agg_data::{
|
||||
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
|
||||
};
|
||||
use crate::aggregation::cached_sub_aggs::{
|
||||
CachedSubAggs, HighCardSubAggCache, LowCardSubAggCache, SubAggCache,
|
||||
};
|
||||
use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
|
||||
use crate::aggregation::BucketId;
|
||||
use crate::aggregation::segment_agg_result::{CollectorClone, SegmentAggregationCollector};
|
||||
use crate::docset::DocSet;
|
||||
use crate::query::{AllQuery, EnableScoring, Query, QueryParser};
|
||||
use crate::schema::Schema;
|
||||
@@ -408,18 +404,15 @@ pub struct FilterAggReqData {
|
||||
pub evaluator: DocumentQueryEvaluator,
|
||||
/// Reusable buffer for matching documents to minimize allocations during collection
|
||||
pub matching_docs_buffer: Vec<DocId>,
|
||||
/// True if this filter aggregation is at the top level of the aggregation tree (not nested).
|
||||
pub is_top_level: bool,
|
||||
}
|
||||
|
||||
impl FilterAggReqData {
|
||||
pub(crate) fn get_memory_consumption(&self) -> usize {
|
||||
// Estimate: name + segment reader reference + bitset + buffer capacity
|
||||
self.name.len()
|
||||
+ std::mem::size_of::<SegmentReader>()
|
||||
+ self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes)
|
||||
+ self.matching_docs_buffer.capacity() * std::mem::size_of::<DocId>()
|
||||
+ std::mem::size_of::<bool>()
|
||||
+ std::mem::size_of::<SegmentReader>()
|
||||
+ self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes)
|
||||
+ self.matching_docs_buffer.capacity() * std::mem::size_of::<DocId>()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -496,24 +489,17 @@ impl Debug for DocumentQueryEvaluator {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Copy)]
|
||||
struct DocCount {
|
||||
doc_count: u64,
|
||||
bucket_id: BucketId,
|
||||
}
|
||||
|
||||
/// Segment collector for filter aggregation
|
||||
pub struct SegmentFilterCollector<C: SubAggCache> {
|
||||
/// Document counts per parent bucket
|
||||
parent_buckets: Vec<DocCount>,
|
||||
pub struct SegmentFilterCollector {
|
||||
/// Document count in this bucket
|
||||
doc_count: u64,
|
||||
/// Sub-aggregation collectors
|
||||
sub_aggregations: Option<CachedSubAggs<C>>,
|
||||
bucket_id_provider: BucketIdProvider,
|
||||
sub_aggregations: Option<Box<dyn SegmentAggregationCollector>>,
|
||||
/// Accessor index for this filter aggregation (to access FilterAggReqData)
|
||||
accessor_idx: usize,
|
||||
}
|
||||
|
||||
impl<C: SubAggCache> SegmentFilterCollector<C> {
|
||||
impl SegmentFilterCollector {
|
||||
/// Create a new filter segment collector following the new agg_data pattern
|
||||
pub(crate) fn from_req_and_validate(
|
||||
req: &mut AggregationsSegmentCtx,
|
||||
@@ -525,75 +511,47 @@ impl<C: SubAggCache> SegmentFilterCollector<C> {
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let sub_agg_collector = sub_agg_collector.map(CachedSubAggs::new);
|
||||
|
||||
Ok(SegmentFilterCollector {
|
||||
parent_buckets: Vec::new(),
|
||||
doc_count: 0,
|
||||
sub_aggregations: sub_agg_collector,
|
||||
accessor_idx: node.idx_in_req_data,
|
||||
bucket_id_provider: BucketIdProvider::default(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn build_segment_filter_collector(
|
||||
req: &mut AggregationsSegmentCtx,
|
||||
node: &AggRefNode,
|
||||
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
|
||||
let is_top_level = req.per_request.filter_req_data[node.idx_in_req_data]
|
||||
.as_ref()
|
||||
.expect("filter_req_data slot is empty")
|
||||
.is_top_level;
|
||||
|
||||
if is_top_level {
|
||||
Ok(Box::new(
|
||||
SegmentFilterCollector::<LowCardSubAggCache>::from_req_and_validate(req, node)?,
|
||||
))
|
||||
} else {
|
||||
Ok(Box::new(
|
||||
SegmentFilterCollector::<HighCardSubAggCache>::from_req_and_validate(req, node)?,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: SubAggCache> Debug for SegmentFilterCollector<C> {
|
||||
impl Debug for SegmentFilterCollector {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("SegmentFilterCollector")
|
||||
.field("buckets", &self.parent_buckets)
|
||||
.field("doc_count", &self.doc_count)
|
||||
.field("has_sub_aggs", &self.sub_aggregations.is_some())
|
||||
.field("accessor_idx", &self.accessor_idx)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: SubAggCache> SegmentAggregationCollector for SegmentFilterCollector<C> {
|
||||
impl CollectorClone for SegmentFilterCollector {
|
||||
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector> {
|
||||
// For now, panic - this needs proper implementation with weight recreation
|
||||
panic!("SegmentFilterCollector cloning not yet implemented - requires weight recreation")
|
||||
}
|
||||
}
|
||||
|
||||
impl SegmentAggregationCollector for SegmentFilterCollector {
|
||||
fn add_intermediate_aggregation_result(
|
||||
&mut self,
|
||||
self: Box<Self>,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
parent_bucket_id: BucketId,
|
||||
) -> crate::Result<()> {
|
||||
let mut sub_results = IntermediateAggregationResults::default();
|
||||
let bucket_opt = self.parent_buckets.get(parent_bucket_id as usize);
|
||||
|
||||
if let Some(sub_aggs) = &mut self.sub_aggregations {
|
||||
sub_aggs
|
||||
.get_sub_agg_collector()
|
||||
.add_intermediate_aggregation_result(
|
||||
agg_data,
|
||||
&mut sub_results,
|
||||
// Here we create a new bucket ID for sub-aggregations if the bucket doesn't
|
||||
// exist, so that sub-aggregations can still produce results (e.g., zero doc
|
||||
// count)
|
||||
bucket_opt
|
||||
.map(|bucket| bucket.bucket_id)
|
||||
.unwrap_or(self.bucket_id_provider.next_bucket_id()),
|
||||
)?;
|
||||
if let Some(sub_aggs) = self.sub_aggregations {
|
||||
sub_aggs.add_intermediate_aggregation_result(agg_data, &mut sub_results)?;
|
||||
}
|
||||
|
||||
// Create the filter bucket result
|
||||
let filter_bucket_result = IntermediateBucketResult::Filter {
|
||||
doc_count: bucket_opt.map(|b| b.doc_count).unwrap_or(0),
doc_count: self.doc_count,
sub_aggregations: sub_results,
};

@@ -612,17 +570,32 @@ impl<C: SubAggCache> SegmentAggregationCollector for SegmentFilterCollector<C> {
Ok(())
}

fn collect(
fn collect(&mut self, doc: DocId, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
// Access the evaluator from FilterAggReqData
let req_data = agg_data.get_filter_req_data(self.accessor_idx);

// O(1) BitSet lookup to check if document matches filter
if req_data.evaluator.matches_document(doc) {
self.doc_count += 1;

// If we have sub-aggregations, collect on them for this filtered document
if let Some(sub_aggs) = &mut self.sub_aggregations {
sub_aggs.collect(doc, agg_data)?;
}
}
Ok(())
}

#[inline]
fn collect_block(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
docs: &[DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
if docs.is_empty() {
return Ok(());
}

let mut bucket = self.parent_buckets[parent_bucket_id as usize];
// Take the request data to avoid borrow checker issues with sub-aggregations
let mut req = agg_data.take_filter_req_data(self.accessor_idx);

@@ -631,24 +604,18 @@ impl<C: SubAggCache> SegmentAggregationCollector for SegmentFilterCollector<C> {
req.evaluator
.filter_batch(docs, &mut req.matching_docs_buffer);

bucket.doc_count += req.matching_docs_buffer.len() as u64;
self.doc_count += req.matching_docs_buffer.len() as u64;

// Batch process sub-aggregations if we have matches
if !req.matching_docs_buffer.is_empty() {
if let Some(sub_aggs) = &mut self.sub_aggregations {
for &doc_id in &req.matching_docs_buffer {
sub_aggs.push(bucket.bucket_id, doc_id);
}
// Use collect_block for better sub-aggregation performance
sub_aggs.collect_block(&req.matching_docs_buffer, agg_data)?;
}
}

// Put the request data back
agg_data.put_back_filter_req_data(self.accessor_idx, req);
if let Some(sub_aggs) = &mut self.sub_aggregations {
sub_aggs.check_flush_local(agg_data)?;
}
// put back bucket
self.parent_buckets[parent_bucket_id as usize] = bucket;

Ok(())
}
@@ -659,21 +626,6 @@ impl<C: SubAggCache> SegmentAggregationCollector for SegmentFilterCollector<C> {
}
Ok(())
}

fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
let bucket_id = self.bucket_id_provider.next_bucket_id();
self.parent_buckets.push(DocCount {
doc_count: 0,
bucket_id,
});
}
Ok(())
}
}

/// Intermediate result for filter aggregation
@@ -1567,9 +1519,9 @@ mod tests {
let searcher = reader.searcher();

let agg = json!({
"test": {
"filter": deserialized,
"aggs": { "count": { "value_count": { "field": "brand" } } }
"test": {
"filter": deserialized,
"aggs": { "count": { "value_count": { "field": "brand" } } }
}
});

@@ -207,7 +207,7 @@ fn parse_offset_into_milliseconds(input: &str) -> Result<i64, AggregationError>
}
}

pub(crate) fn parse_into_milliseconds(input: &str) -> Result<i64, AggregationError> {
fn parse_into_milliseconds(input: &str) -> Result<i64, AggregationError> {
let split_boundary = input
.as_bytes()
.iter()

@@ -1,6 +1,6 @@
use std::cmp::Ordering;

use columnar::{Column, ColumnType};
use columnar::{Column, ColumnBlockAccessor, ColumnType};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use tantivy_bitpacker::minmax;
@@ -8,14 +8,14 @@ use tantivy_bitpacker::minmax;
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::agg_limits::MemoryConsumption;
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::agg_result::BucketEntry;
use crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardCachedSubAggs};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateHistogramBucketEntry,
};
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::TantivyError;

@@ -26,8 +26,13 @@ pub struct HistogramAggReqData {
pub accessor: Column<u64>,
/// The field type of the fast field.
pub field_type: ColumnType,
/// The column block accessor to access the fast field values.
pub column_block_accessor: ColumnBlockAccessor<u64>,
/// The name of the aggregation.
pub name: String,
/// The sub aggregation blueprint, used to create sub aggregations for each bucket.
/// Will be filled during initialization of the collector.
pub sub_aggregation_blueprint: Option<Box<dyn SegmentAggregationCollector>>,
/// The histogram aggregation request.
pub req: HistogramAggregation,
/// True if this is a date_histogram aggregation.
@@ -252,24 +257,18 @@ impl HistogramBounds {
pub(crate) struct SegmentHistogramBucketEntry {
pub key: f64,
pub doc_count: u64,
pub bucket_id: BucketId,
}

impl SegmentHistogramBucketEntry {
pub(crate) fn into_intermediate_bucket_entry(
self,
sub_aggregation: &mut Option<HighCardCachedSubAggs>,
sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateHistogramBucketEntry> {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
if let Some(sub_aggregation) = sub_aggregation {
sub_aggregation
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut sub_aggregation_res,
self.bucket_id,
)?;
.add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?;
}
Ok(IntermediateHistogramBucketEntry {
key: self.key,
@@ -279,38 +278,27 @@ impl SegmentHistogramBucketEntry {
}
}

#[derive(Clone, Debug, Default)]
struct HistogramBuckets {
pub buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
}

/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct SegmentHistogramCollector {
/// The buckets containing the aggregation data.
/// One Histogram bucket per parent bucket id.
parent_buckets: Vec<HistogramBuckets>,
sub_agg: Option<HighCardCachedSubAggs>,
buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
sub_aggregations: FxHashMap<i64, Box<dyn SegmentAggregationCollector>>,
accessor_idx: usize,
bucket_id_provider: BucketIdProvider,
}

impl SegmentAggregationCollector for SegmentHistogramCollector {
fn add_intermediate_aggregation_result(
&mut self,
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_data
.get_histogram_req_data(self.accessor_idx)
.name
.clone();
// TODO: avoid prepare_max_bucket here and handle empty buckets.
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let histogram = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);
let bucket = self.add_intermediate_bucket_result(agg_data, histogram)?;
let bucket = self.into_intermediate_bucket_result(agg_data)?;
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;

Ok(())
@@ -319,40 +307,44 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
#[inline]
fn collect(
&mut self,
parent_bucket_id: BucketId,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}

#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req = agg_data.take_histogram_req_data(self.accessor_idx);
let mut req = agg_data.take_histogram_req_data(self.accessor_idx);
let mem_pre = self.get_memory_consumption();
let buckets = &mut self.parent_buckets[parent_bucket_id as usize].buckets;

let bounds = req.bounds;
let interval = req.req.interval;
let offset = req.offset;
let get_bucket_pos = |val| get_bucket_pos_f64(val, interval, offset) as i64;

agg_data
.column_block_accessor
.fetch_block(docs, &req.accessor);
for (doc, val) in agg_data
req.column_block_accessor.fetch_block(docs, &req.accessor);
for (doc, val) in req
.column_block_accessor
.iter_docid_vals(docs, &req.accessor)
{
let val = f64_from_fastfield_u64(val, req.field_type);
let val = f64_from_fastfield_u64(val, &req.field_type);
let bucket_pos = get_bucket_pos(val);
if bounds.contains(val) {
let bucket = buckets.entry(bucket_pos).or_insert_with(|| {
let bucket = self.buckets.entry(bucket_pos).or_insert_with(|| {
let key = get_bucket_key_from_pos(bucket_pos as f64, interval, offset);
SegmentHistogramBucketEntry {
key,
doc_count: 0,
bucket_id: self.bucket_id_provider.next_bucket_id(),
}
SegmentHistogramBucketEntry { key, doc_count: 0 }
});
bucket.doc_count += 1;
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.push(bucket.bucket_id, doc);
if let Some(sub_aggregation_blueprint) = req.sub_aggregation_blueprint.as_ref() {
self.sub_aggregations
.entry(bucket_pos)
.or_insert_with(|| sub_aggregation_blueprint.clone())
.collect(doc, agg_data)?;
}
}
}
@@ -366,30 +358,14 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
.add_memory_consumed(mem_delta as u64)?;
}

if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.check_flush_local(agg_data)?;
}

Ok(())
}

fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if let Some(sub_aggregation) = &mut self.sub_agg {
for sub_aggregation in self.sub_aggregations.values_mut() {
sub_aggregation.flush(agg_data)?;
}
Ok(())
}

fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
self.parent_buckets.push(HistogramBuckets {
buckets: FxHashMap::default(),
});
}
Ok(())
}
}
@@ -397,19 +373,22 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
impl SegmentHistogramCollector {
fn get_memory_consumption(&self) -> usize {
let self_mem = std::mem::size_of::<Self>();
let buckets_mem = self.parent_buckets.len() * std::mem::size_of::<HistogramBuckets>();
self_mem + buckets_mem
let sub_aggs_mem = self.sub_aggregations.memory_consumption();
let buckets_mem = self.buckets.memory_consumption();
self_mem + sub_aggs_mem + buckets_mem
}
/// Converts the collector result into an intermediate bucket result.
fn add_intermediate_bucket_result(
&mut self,
pub fn into_intermediate_bucket_result(
self,
agg_data: &AggregationsSegmentCtx,
histogram: HistogramBuckets,
) -> crate::Result<IntermediateBucketResult> {
let mut buckets = Vec::with_capacity(histogram.buckets.len());
let mut buckets = Vec::with_capacity(self.buckets.len());

for bucket in histogram.buckets.into_values() {
let bucket_res = bucket.into_intermediate_bucket_entry(&mut self.sub_agg, agg_data);
for (bucket_pos, bucket) in self.buckets {
let bucket_res = bucket.into_intermediate_bucket_entry(
self.sub_aggregations.get(&bucket_pos).cloned(),
agg_data,
);

buckets.push(bucket_res?);
}
@@ -429,7 +408,7 @@ impl SegmentHistogramCollector {
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let sub_agg = if !node.children.is_empty() {
let blueprint = if !node.children.is_empty() {
Some(build_segment_agg_collectors(agg_data, &node.children)?)
} else {
None
@@ -444,13 +423,13 @@ impl SegmentHistogramCollector {
max: f64::MAX,
});
req_data.offset = req_data.req.offset.unwrap_or(0.0);
let sub_agg = sub_agg.map(CachedSubAggs::new);

req_data.sub_aggregation_blueprint = blueprint;

Ok(Self {
parent_buckets: Default::default(),
sub_agg,
buckets: Default::default(),
sub_aggregations: Default::default(),
accessor_idx: node.idx_in_req_data,
bucket_id_provider: BucketIdProvider::default(),
})
}
}

@@ -22,7 +22,6 @@
//! - [Range](RangeAggregation)
//! - [Terms](TermsAggregation)

mod composite;
mod filter;
mod histogram;
mod range;
@@ -32,7 +31,6 @@ mod term_missing_agg;
use std::collections::HashMap;
use std::fmt;

pub use composite::*;
pub use filter::*;
pub use histogram::*;
pub use range::*;

@@ -1,22 +1,18 @@
use std::fmt::Debug;
use std::ops::Range;

use columnar::{Column, ColumnType};
use columnar::{Column, ColumnBlockAccessor, ColumnType};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};

use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::agg_limits::AggregationLimitsGuard;
use crate::aggregation::cached_sub_aggs::{
CachedSubAggs, HighCardSubAggCache, LowCardCachedSubAggs, LowCardSubAggCache, SubAggCache,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateRangeBucketEntry, IntermediateRangeBucketResult,
};
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::TantivyError;

@@ -27,12 +23,12 @@ pub struct RangeAggReqData {
pub accessor: Column<u64>,
/// The type of the fast field.
pub field_type: ColumnType,
/// The column block accessor to access the fast field values.
pub column_block_accessor: ColumnBlockAccessor<u64>,
/// The range aggregation request.
pub req: RangeAggregation,
/// The name of the aggregation.
pub name: String,
/// Whether this is a top-level aggregation.
pub is_top_level: bool,
}

impl RangeAggReqData {
@@ -155,47 +151,19 @@ pub(crate) struct SegmentRangeAndBucketEntry {

/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
pub struct SegmentRangeCollector<C: SubAggCache> {
#[derive(Clone, Debug)]
pub struct SegmentRangeCollector {
/// The buckets containing the aggregation data.
/// One for each ParentBucketId
parent_buckets: Vec<Vec<SegmentRangeAndBucketEntry>>,
buckets: Vec<SegmentRangeAndBucketEntry>,
column_type: ColumnType,
pub(crate) accessor_idx: usize,
sub_agg: Option<CachedSubAggs<C>>,
/// Here things get a bit weird. We need to assign unique bucket ids across all
/// parent buckets. So we keep track of the next available bucket id here.
/// This allows a kind of flattening of the bucket ids across all parent buckets.
/// E.g. in nested aggregations:
/// Term Agg -> Range aggregation -> Stats aggregation
/// E.g. the Term Agg creates 3 buckets ["INFO", "ERROR", "WARN"], each of these has a Range
/// aggregation with 4 buckets. The Range aggregation will create buckets with ids:
/// - INFO: 0,1,2,3
/// - ERROR: 4,5,6,7
/// - WARN: 8,9,10,11
///
/// This allows the Stats aggregation to have unique bucket ids to refer to.
bucket_id_provider: BucketIdProvider,
limits: AggregationLimitsGuard,
}

impl<C: SubAggCache> Debug for SegmentRangeCollector<C> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentRangeCollector")
.field("parent_buckets_len", &self.parent_buckets.len())
.field("column_type", &self.column_type)
.field("accessor_idx", &self.accessor_idx)
.field("has_sub_agg", &self.sub_agg.is_some())
.finish()
}
}

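The flattened bucket-id scheme described in the `bucket_id_provider` doc comment above can be sketched in isolation. This is a minimal, hypothetical sketch (the types are simplified stand-ins, not the actual tantivy structs): a single shared counter hands out ids, so every bucket created under any parent bucket gets a globally unique id for sub-aggregations to key on.

```rust
// Hypothetical simplified types illustrating the flattened bucket-id scheme.
type BucketId = u32;

#[derive(Default)]
struct BucketIdProvider {
    next: BucketId,
}

impl BucketIdProvider {
    // Hand out the next globally unique bucket id.
    fn next_bucket_id(&mut self) -> BucketId {
        let id = self.next;
        self.next += 1;
        id
    }
}

fn main() {
    let mut provider = BucketIdProvider::default();
    // Three parent buckets ("INFO", "ERROR", "WARN"), each creating 4 range buckets,
    // all drawing from the same provider:
    let ids: Vec<Vec<BucketId>> = (0..3)
        .map(|_| (0..4).map(|_| provider.next_bucket_id()).collect())
        .collect();
    assert_eq!(ids[0], vec![0, 1, 2, 3]);   // INFO
    assert_eq!(ids[1], vec![4, 5, 6, 7]);   // ERROR
    assert_eq!(ids[2], vec![8, 9, 10, 11]); // WARN
}
```

This matches the id layout given in the doc comment (INFO: 0-3, ERROR: 4-7, WARN: 8-11).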
/// TODO: Bad naming, there's also SegmentRangeAndBucketEntry
#[derive(Clone)]
pub(crate) struct SegmentRangeBucketEntry {
pub key: Key,
pub doc_count: u64,
// pub sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
pub bucket_id: BucketId,
pub sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
/// The from range of the bucket. Equals `f64::MIN` when `None`.
pub from: Option<f64>,
/// The to range of the bucket. Equals `f64::MAX` when `None`. Open interval, `to` is not
@@ -216,50 +184,48 @@ impl Debug for SegmentRangeBucketEntry {
impl SegmentRangeBucketEntry {
pub(crate) fn into_intermediate_bucket_entry(
self,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateRangeBucketEntry> {
let sub_aggregation = IntermediateAggregationResults::default();
let mut sub_aggregation_res = IntermediateAggregationResults::default();
if let Some(sub_aggregation) = self.sub_aggregation {
sub_aggregation
.add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?
} else {
Default::default()
};

Ok(IntermediateRangeBucketEntry {
key: self.key.into(),
doc_count: self.doc_count,
sub_aggregation_res: sub_aggregation,
sub_aggregation: sub_aggregation_res,
from: self.from,
to: self.to,
})
}
}

impl<C: SubAggCache> SegmentAggregationCollector for SegmentRangeCollector<C> {
impl SegmentAggregationCollector for SegmentRangeCollector {
fn add_intermediate_aggregation_result(
&mut self,
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let field_type = self.column_type;
let name = agg_data
.get_range_req_data(self.accessor_idx)
.name
.to_string();

let buckets = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);

let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = buckets
let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = self
.buckets
.into_iter()
.map(|range_bucket| {
let bucket_id = range_bucket.bucket.bucket_id;
let mut agg = range_bucket.bucket.into_intermediate_bucket_entry()?;
if let Some(sub_aggregation) = &mut self.sub_agg {
sub_aggregation
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut agg.sub_aggregation_res,
bucket_id,
)?;
}
Ok((range_to_string(&range_bucket.range, &field_type)?, agg))
.map(move |range_bucket| {
Ok((
range_to_string(&range_bucket.range, &field_type)?,
range_bucket
.bucket
.into_intermediate_bucket_entry(agg_data)?,
))
})
.collect::<crate::Result<_>>()?;

@@ -276,114 +242,73 @@ impl<C: SubAggCache> SegmentAggregationCollector for SegmentRangeCollector<C> {
#[inline]
fn collect(
&mut self,
parent_bucket_id: BucketId,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}

#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req = agg_data.take_range_req_data(self.accessor_idx);
// Take request data to avoid borrow conflicts during sub-aggregation
let mut req = agg_data.take_range_req_data(self.accessor_idx);

agg_data
.column_block_accessor
.fetch_block(docs, &req.accessor);
req.column_block_accessor.fetch_block(docs, &req.accessor);

let buckets = &mut self.parent_buckets[parent_bucket_id as usize];

for (doc, val) in agg_data
for (doc, val) in req
.column_block_accessor
.iter_docid_vals(docs, &req.accessor)
{
let bucket_pos = get_bucket_pos(val, buckets);
let bucket = &mut buckets[bucket_pos];
let bucket_pos = self.get_bucket_pos(val);
let bucket = &mut self.buckets[bucket_pos];
bucket.bucket.doc_count += 1;
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.push(bucket.bucket.bucket_id, doc);
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() {
sub_agg.collect(doc, agg_data)?;
}
}

agg_data.put_back_range_req_data(self.accessor_idx, req);
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.check_flush_local(agg_data)?;
}

Ok(())
}

fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.flush(agg_data)?;
for bucket in self.buckets.iter_mut() {
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() {
sub_agg.flush(agg_data)?;
}
}
Ok(())
}

fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
let new_buckets = self.create_new_buckets(agg_data)?;
self.parent_buckets.push(new_buckets);
}

Ok(())
}
}
/// Build a concrete `SegmentRangeCollector` with either a Vec- or HashMap-backed
/// bucket storage, depending on the column type and aggregation level.
pub(crate) fn build_segment_range_collector(
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
let accessor_idx = node.idx_in_req_data;
let req_data = agg_data.get_range_req_data(node.idx_in_req_data);
let field_type = req_data.field_type;

// TODO: A better metric instead of is_top_level would be the number of buckets expected.
// E.g. If range agg is not top level, but the parent is a bucket agg with less than 10 buckets,
// we are still in low cardinality territory.
let is_low_card = req_data.is_top_level && req_data.req.ranges.len() <= 64;

let sub_agg = if !node.children.is_empty() {
Some(build_segment_agg_collectors(agg_data, &node.children)?)
} else {
None
};

if is_low_card {
Ok(Box::new(SegmentRangeCollector::<LowCardSubAggCache> {
sub_agg: sub_agg.map(LowCardCachedSubAggs::new),
column_type: field_type,
accessor_idx,
parent_buckets: Vec::new(),
bucket_id_provider: BucketIdProvider::default(),
limits: agg_data.context.limits.clone(),
}))
} else {
Ok(Box::new(SegmentRangeCollector::<HighCardSubAggCache> {
sub_agg: sub_agg.map(CachedSubAggs::new),
column_type: field_type,
accessor_idx,
parent_buckets: Vec::new(),
bucket_id_provider: BucketIdProvider::default(),
limits: agg_data.context.limits.clone(),
}))
}
}

impl<C: SubAggCache> SegmentRangeCollector<C> {
pub(crate) fn create_new_buckets(
&mut self,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<Vec<SegmentRangeAndBucketEntry>> {
let field_type = self.column_type;
let req_data = agg_data.get_range_req_data(self.accessor_idx);
impl SegmentRangeCollector {
pub(crate) fn from_req_and_validate(
req_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let accessor_idx = node.idx_in_req_data;
let (field_type, ranges) = {
let req_view = req_data.get_range_req_data(node.idx_in_req_data);
(req_view.field_type, req_view.req.ranges.clone())
};

// The range input on the request is f64.
// We need to convert to u64 ranges, because we read the values as u64.
// The mapping from the conversion is monotonic so ordering is preserved.
let buckets: Vec<_> = extend_validate_ranges(&req_data.req.ranges, &field_type)?
let sub_agg_prototype = if !node.children.is_empty() {
Some(build_segment_agg_collectors(req_data, &node.children)?)
} else {
None
};

let buckets: Vec<_> = extend_validate_ranges(&ranges, &field_type)?
.iter()
.map(|range| {
let bucket_id = self.bucket_id_provider.next_bucket_id();
let key = range
.key
.clone()
@@ -392,20 +317,20 @@ impl<C: SubAggCache> SegmentRangeCollector<C> {
let to = if range.range.end == u64::MAX {
None
} else {
Some(f64_from_fastfield_u64(range.range.end, field_type))
Some(f64_from_fastfield_u64(range.range.end, &field_type))
};
let from = if range.range.start == u64::MIN {
None
} else {
Some(f64_from_fastfield_u64(range.range.start, field_type))
Some(f64_from_fastfield_u64(range.range.start, &field_type))
};
// let sub_aggregation = sub_agg_prototype.clone();
let sub_aggregation = sub_agg_prototype.clone();

Ok(SegmentRangeAndBucketEntry {
range: range.range.clone(),
bucket: SegmentRangeBucketEntry {
doc_count: 0,
bucket_id,
sub_aggregation,
key,
from,
to,
@@ -414,19 +339,26 @@ impl<C: SubAggCache> SegmentRangeCollector<C> {
})
.collect::<crate::Result<_>>()?;

self.limits.add_memory_consumed(
req_data.context.limits.add_memory_consumed(
buckets.len() as u64 * std::mem::size_of::<SegmentRangeAndBucketEntry>() as u64,
)?;
Ok(buckets)

Ok(SegmentRangeCollector {
buckets,
column_type: field_type,
accessor_idx,
})
}

#[inline]
fn get_bucket_pos(&self, val: u64) -> usize {
let pos = self
.buckets
.binary_search_by_key(&val, |probe| probe.range.start)
.unwrap_or_else(|pos| pos - 1);
debug_assert!(self.buckets[pos].range.contains(&val));
pos
}
}
#[inline]
fn get_bucket_pos(val: u64, buckets: &[SegmentRangeAndBucketEntry]) -> usize {
let pos = buckets
.binary_search_by_key(&val, |probe| probe.range.start)
.unwrap_or_else(|pos| pos - 1);
debug_assert!(buckets[pos].range.contains(&val));
pos
}

/// Converts the user provided f64 range value to fast field value space.
@@ -524,7 +456,7 @@ pub(crate) fn range_to_string(
let val = i64::from_u64(val);
format_date(val)
} else {
Ok(f64_from_fastfield_u64(val, *field_type).to_string())
Ok(f64_from_fastfield_u64(val, field_type).to_string())
}
};

@@ -554,7 +486,7 @@ mod tests {
pub fn get_collector_from_ranges(
ranges: Vec<RangeAggregationRange>,
field_type: ColumnType,
) -> SegmentRangeCollector<HighCardSubAggCache> {
) -> SegmentRangeCollector {
let req = RangeAggregation {
field: "dummy".to_string(),
ranges,
@@ -574,33 +506,30 @@ mod tests {
let to = if range.range.end == u64::MAX {
None
} else {
Some(f64_from_fastfield_u64(range.range.end, field_type))
Some(f64_from_fastfield_u64(range.range.end, &field_type))
};
let from = if range.range.start == u64::MIN {
None
} else {
Some(f64_from_fastfield_u64(range.range.start, field_type))
Some(f64_from_fastfield_u64(range.range.start, &field_type))
};
SegmentRangeAndBucketEntry {
range: range.range.clone(),
bucket: SegmentRangeBucketEntry {
doc_count: 0,
sub_aggregation: None,
key,
from,
to,
bucket_id: 0,
},
}
})
.collect();

SegmentRangeCollector {
parent_buckets: vec![buckets],
buckets,
column_type: field_type,
accessor_idx: 0,
sub_agg: None,
bucket_id_provider: Default::default(),
limits: AggregationLimitsGuard::default(),
}
}

@@ -847,7 +776,7 @@ mod tests {
let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);

let buckets = collector.parent_buckets[0].clone();
let buckets = collector.buckets;
assert_eq!(buckets[0].range.start, u64::MIN);
assert_eq!(buckets[0].range.end, 10f64.to_u64());
assert_eq!(buckets[1].range.start, 10f64.to_u64());
@@ -870,7 +799,7 @@ mod tests {
];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);

let buckets = collector.parent_buckets[0].clone();
let buckets = collector.buckets;
assert_eq!(buckets[0].range.start, u64::MIN);
assert_eq!(buckets[0].range.end, 10f64.to_u64());
assert_eq!(buckets[1].range.start, 10f64.to_u64());
@@ -885,7 +814,7 @@ mod tests {
let buckets = vec![(-10f64..-1f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);

let buckets = collector.parent_buckets[0].clone();
let buckets = collector.buckets;
assert_eq!(&buckets[0].bucket.key.to_string(), "*--10");
assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "-1-*");
}
@@ -894,7 +823,7 @@ mod tests {
let buckets = vec![(0f64..10f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);

let buckets = collector.parent_buckets[0].clone();
let buckets = collector.buckets;
assert_eq!(&buckets[0].bucket.key.to_string(), "*-0");
assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "10-*");
}
@@ -903,7 +832,7 @@ mod tests {
fn range_binary_search_test_u64() {
let check_ranges = |ranges: Vec<RangeAggregationRange>| {
let collector = get_collector_from_ranges(ranges, ColumnType::U64);
let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);
let search = |val: u64| collector.get_bucket_pos(val);

assert_eq!(search(u64::MIN), 0);
assert_eq!(search(9), 0);
@@ -949,7 +878,7 @@ mod tests {
let ranges = vec![(10.0..100.0).into()];

let collector = get_collector_from_ranges(ranges, ColumnType::F64);
let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);
let search = |val: u64| collector.get_bucket_pos(val);

assert_eq!(search(u64::MIN), 0);
assert_eq!(search(9f64.to_u64()), 0);
@@ -961,3 +890,63 @@ mod tests {
// the max value
}
}

#[cfg(all(test, feature = "unstable"))]
mod bench {

use itertools::Itertools;
use rand::seq::SliceRandom;
use rand::thread_rng;

use super::*;
use crate::aggregation::bucket::range::tests::get_collector_from_ranges;

const TOTAL_DOCS: u64 = 1_000_000u64;
const NUM_DOCS: u64 = 50_000u64;

fn get_collector_with_buckets(num_buckets: u64, num_docs: u64) -> SegmentRangeCollector {
let bucket_size = num_docs / num_buckets;
let mut buckets: Vec<RangeAggregationRange> = vec![];
for i in 0..num_buckets {
let bucket_start = (i * bucket_size) as f64;
buckets.push((bucket_start..bucket_start + bucket_size as f64).into())
}

get_collector_from_ranges(buckets, ColumnType::U64)
}

fn get_rand_docs(total_docs: u64, num_docs_returned: u64) -> Vec<u64> {
let mut rng = thread_rng();

let all_docs = (0..total_docs - 1).collect_vec();
let mut vals = all_docs
.as_slice()
.choose_multiple(&mut rng, num_docs_returned as usize)
.cloned()
.collect_vec();
vals.sort();
vals
}

fn bench_range_binary_search(b: &mut test::Bencher, num_buckets: u64) {
let collector = get_collector_with_buckets(num_buckets, TOTAL_DOCS);
let vals = get_rand_docs(TOTAL_DOCS, NUM_DOCS);
b.iter(|| {
let mut bucket_pos = 0;
for val in &vals {
bucket_pos = collector.get_bucket_pos(*val);
}
bucket_pos
})
}

#[bench]
fn bench_range_100_buckets(b: &mut test::Bencher) {
bench_range_binary_search(b, 100)
}

#[bench]
fn bench_range_10_buckets(b: &mut test::Bencher) {
bench_range_binary_search(b, 10)
}
}

File diff suppressed because it is too large
@@ -5,13 +5,11 @@ use crate::aggregation::agg_data::{
    build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::bucket::term_agg::TermsAggregation;
use crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardCachedSubAggs};
use crate::aggregation::intermediate_agg_result::{
    IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
    IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult,
};
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::BucketId;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;

/// Special aggregation to handle missing values for term aggregations.
/// This missing aggregation will check multiple columns for existence.
@@ -37,55 +35,41 @@ impl MissingTermAggReqData {
    }
}

#[derive(Default, Debug, Clone)]
struct MissingCount {
    missing_count: u32,
    bucket_id: BucketId,
}

/// The specialized missing term aggregation.
#[derive(Default, Debug)]
#[derive(Default, Debug, Clone)]
pub struct TermMissingAgg {
    missing_count: u32,
    accessor_idx: usize,
    sub_agg: Option<HighCardCachedSubAggs>,
    /// Idx = parent bucket id, Value = missing count for that bucket
    missing_count_per_bucket: Vec<MissingCount>,
    bucket_id_provider: BucketIdProvider,
    sub_agg: Option<Box<dyn SegmentAggregationCollector>>,
}
impl TermMissingAgg {
    pub(crate) fn new(
        agg_data: &mut AggregationsSegmentCtx,
        req_data: &mut AggregationsSegmentCtx,
        node: &AggRefNode,
    ) -> crate::Result<Self> {
        let has_sub_aggregations = !node.children.is_empty();
        let accessor_idx = node.idx_in_req_data;
        let sub_agg = if has_sub_aggregations {
            let sub_aggregation = build_segment_agg_collectors(agg_data, &node.children)?;
            let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?;
            Some(sub_aggregation)
        } else {
            None
        };

        let sub_agg = sub_agg.map(CachedSubAggs::new);
        let bucket_id_provider = BucketIdProvider::default();

        Ok(Self {
            accessor_idx,
            sub_agg,
            missing_count_per_bucket: Vec::new(),
            bucket_id_provider,
            ..Default::default()
        })
    }
}

impl SegmentAggregationCollector for TermMissingAgg {
    fn add_intermediate_aggregation_result(
        &mut self,
        self: Box<Self>,
        agg_data: &AggregationsSegmentCtx,
        results: &mut IntermediateAggregationResults,
        parent_bucket_id: BucketId,
    ) -> crate::Result<()> {
        self.prepare_max_bucket(parent_bucket_id, agg_data)?;
        let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
        let term_agg = &req_data.req;
        let missing = term_agg
@@ -96,16 +80,13 @@ impl SegmentAggregationCollector for TermMissingAgg {
        let mut entries: FxHashMap<IntermediateKey, IntermediateTermBucketEntry> =
            Default::default();

        let missing_count = &self.missing_count_per_bucket[parent_bucket_id as usize];
        let mut missing_entry = IntermediateTermBucketEntry {
            doc_count: missing_count.missing_count,
            doc_count: self.missing_count,
            sub_aggregation: Default::default(),
        };
        if let Some(sub_agg) = &mut self.sub_agg {
        if let Some(sub_agg) = self.sub_agg {
            let mut res = IntermediateAggregationResults::default();
            sub_agg
                .get_sub_agg_collector()
                .add_intermediate_aggregation_result(agg_data, &mut res, missing_count.bucket_id)?;
            sub_agg.add_intermediate_aggregation_result(agg_data, &mut res)?;
            missing_entry.sub_aggregation = res;
        }
        entries.insert(missing.into(), missing_entry);
@@ -128,52 +109,30 @@ impl SegmentAggregationCollector for TermMissingAgg {

    fn collect(
        &mut self,
        parent_bucket_id: BucketId,
        doc: crate::DocId,
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
        let has_value = req_data
            .accessors
            .iter()
            .any(|(acc, _)| acc.index.has_value(doc));
        if !has_value {
            self.missing_count += 1;
            if let Some(sub_agg) = self.sub_agg.as_mut() {
                sub_agg.collect(doc, agg_data)?;
            }
        }
        Ok(())
    }

    fn collect_block(
        &mut self,
        docs: &[crate::DocId],
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        let bucket = &mut self.missing_count_per_bucket[parent_bucket_id as usize];
        let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);

        for doc in docs {
            let doc = *doc;
            let has_value = req_data
                .accessors
                .iter()
                .any(|(acc, _)| acc.index.has_value(doc));
            if !has_value {
                bucket.missing_count += 1;

                if let Some(sub_agg) = self.sub_agg.as_mut() {
                    sub_agg.push(bucket.bucket_id, doc);
                }
            }
        }

        if let Some(sub_agg) = self.sub_agg.as_mut() {
            sub_agg.check_flush_local(agg_data)?;
        }
        Ok(())
    }

    fn prepare_max_bucket(
        &mut self,
        max_bucket: BucketId,
        _agg_data: &AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        while self.missing_count_per_bucket.len() <= max_bucket as usize {
            let bucket_id = self.bucket_id_provider.next_bucket_id();
            self.missing_count_per_bucket.push(MissingCount {
                missing_count: 0,
                bucket_id,
            });
        }
        Ok(())
    }

    fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
        if let Some(sub_agg) = self.sub_agg.as_mut() {
            sub_agg.flush(agg_data)?;
            self.collect(*doc, agg_data)?;
        }
        Ok(())
    }
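The removed branch tracks one `MissingCount` slot per parent bucket, grown lazily by `prepare_max_bucket` with ids handed out by a `BucketIdProvider`. A minimal standalone sketch of that pattern (illustrative names, not tantivy's API):

```rust
// Sketch of the `prepare_max_bucket` pattern: per-parent-bucket state is grown
// lazily to cover the highest bucket id seen, each new slot getting a fresh id.
#[derive(Default)]
struct BucketIdProvider {
    next: u32,
}

impl BucketIdProvider {
    fn next_bucket_id(&mut self) -> u32 {
        let id = self.next;
        self.next += 1;
        id
    }
}

#[derive(Debug, PartialEq)]
struct MissingCount {
    missing_count: u32,
    bucket_id: u32,
}

fn prepare_max_bucket(
    counts: &mut Vec<MissingCount>,
    provider: &mut BucketIdProvider,
    max_bucket: u32,
) {
    // Ensure indices 0..=max_bucket exist; new slots start with a zero count.
    while counts.len() <= max_bucket as usize {
        let bucket_id = provider.next_bucket_id();
        counts.push(MissingCount { missing_count: 0, bucket_id });
    }
}

fn run() -> (usize, u32) {
    let mut counts = Vec::new();
    let mut provider = BucketIdProvider::default();
    prepare_max_bucket(&mut counts, &mut provider, 2);
    counts[2].missing_count += 1; // a doc with no value in parent bucket 2
    (counts.len(), counts[2].bucket_id)
}

fn main() {
    // Three slots exist (buckets 0..=2); slot 2 received id 2.
    assert_eq!(run(), (3, 2));
}
```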
src/aggregation/buf_collector.rs (new file, 87 lines)
@@ -0,0 +1,87 @@
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::DocId;

#[cfg(test)]
pub(crate) const DOC_BLOCK_SIZE: usize = 64;

#[cfg(not(test))]
pub(crate) const DOC_BLOCK_SIZE: usize = 256;

pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE];

/// BufAggregationCollector buffers documents before calling collect_block().
#[derive(Clone)]
pub(crate) struct BufAggregationCollector {
    pub(crate) collector: Box<dyn SegmentAggregationCollector>,
    staged_docs: DocBlock,
    num_staged_docs: usize,
}

impl std::fmt::Debug for BufAggregationCollector {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct("SegmentAggregationResultsCollector")
            .field("staged_docs", &&self.staged_docs[..self.num_staged_docs])
            .field("num_staged_docs", &self.num_staged_docs)
            .finish()
    }
}

impl BufAggregationCollector {
    pub fn new(collector: Box<dyn SegmentAggregationCollector>) -> Self {
        Self {
            collector,
            num_staged_docs: 0,
            staged_docs: [0; DOC_BLOCK_SIZE],
        }
    }
}

impl SegmentAggregationCollector for BufAggregationCollector {
    #[inline]
    fn add_intermediate_aggregation_result(
        self: Box<Self>,
        agg_data: &AggregationsSegmentCtx,
        results: &mut IntermediateAggregationResults,
    ) -> crate::Result<()> {
        Box::new(self.collector).add_intermediate_aggregation_result(agg_data, results)
    }

    #[inline]
    fn collect(
        &mut self,
        doc: crate::DocId,
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        self.staged_docs[self.num_staged_docs] = doc;
        self.num_staged_docs += 1;
        if self.num_staged_docs == self.staged_docs.len() {
            self.collector
                .collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?;
            self.num_staged_docs = 0;
        }
        Ok(())
    }

    #[inline]
    fn collect_block(
        &mut self,
        docs: &[crate::DocId],
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        self.collector.collect_block(docs, agg_data)?;
        Ok(())
    }

    #[inline]
    fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
        self.collector
            .collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?;
        self.num_staged_docs = 0;

        self.collector.flush(agg_data)?;

        Ok(())
    }
}
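The buffering idea above (stage up to `DOC_BLOCK_SIZE` doc ids, then hand them to the wrapped collector as a single block to amortize per-document dynamic dispatch) can be sketched standalone; the names below are illustrative, not tantivy's API:

```rust
// Standalone sketch of the block-buffering strategy in BufAggregationCollector.
// DOC_BLOCK_SIZE is 4 here for illustration (tantivy uses 256 outside tests).
const DOC_BLOCK_SIZE: usize = 4;

struct BufCollector {
    staged: [u32; DOC_BLOCK_SIZE],
    len: usize,
    // Stands in for the wrapped `SegmentAggregationCollector`: records each
    // block that was forwarded via `collect_block`.
    forwarded: Vec<Vec<u32>>,
}

impl BufCollector {
    fn new() -> Self {
        Self { staged: [0; DOC_BLOCK_SIZE], len: 0, forwarded: Vec::new() }
    }

    // Per-document entry point: stage the doc; forward once a block is full.
    fn collect(&mut self, doc: u32) {
        self.staged[self.len] = doc;
        self.len += 1;
        if self.len == DOC_BLOCK_SIZE {
            self.flush();
        }
    }

    // Forward whatever is staged (also called once at the end).
    fn flush(&mut self) {
        if self.len > 0 {
            self.forwarded.push(self.staged[..self.len].to_vec());
            self.len = 0;
        }
    }
}

fn run() -> Vec<Vec<u32>> {
    let mut c = BufCollector::new();
    for doc in 0..6 {
        c.collect(doc);
    }
    c.flush(); // final partial block
    c.forwarded
}

fn main() {
    // Six docs become one full block of 4 and one partial block of 2.
    assert_eq!(run(), vec![vec![0, 1, 2, 3], vec![4, 5]]);
}
```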
@@ -1,245 +0,0 @@
use std::fmt::Debug;

use super::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::bucket::MAX_NUM_TERMS_FOR_VEC;
use crate::aggregation::BucketId;
use crate::DocId;

/// A cache for sub-aggregations, storing doc ids per bucket id.
/// Depending on the cardinality of the parent aggregation, we use different
/// storage strategies.
///
/// ## Low Cardinality
/// Cardinality here refers to the number of unique flattened buckets that can be created
/// by the parent aggregation.
/// Flattened buckets are the result of combining all buckets per collector
/// into a single list of buckets, where each bucket is identified by its BucketId.
///
/// ## Usage
/// Since this is caching for sub-aggregations, it is only used by bucket
/// aggregations.
///
/// TODO: consider using a more advanced data structure for high cardinality
/// aggregations.
/// What this datastructure does in general is to group docs by bucket id.
#[derive(Debug)]
pub(crate) struct CachedSubAggs<C: SubAggCache> {
    cache: C,
    sub_agg_collector: Box<dyn SegmentAggregationCollector>,
    num_docs: usize,
}

pub type LowCardCachedSubAggs = CachedSubAggs<LowCardSubAggCache>;
pub type HighCardCachedSubAggs = CachedSubAggs<HighCardSubAggCache>;

const FLUSH_THRESHOLD: usize = 2048;

/// A trait for caching sub-aggregation doc ids per bucket id.
/// Different implementations can be used depending on the cardinality
/// of the parent aggregation.
pub trait SubAggCache: Debug {
    fn new() -> Self;
    fn push(&mut self, bucket_id: BucketId, doc_id: DocId);
    fn flush_local(
        &mut self,
        sub_agg: &mut Box<dyn SegmentAggregationCollector>,
        agg_data: &mut AggregationsSegmentCtx,
        force: bool,
    ) -> crate::Result<()>;
}

impl<Backend: SubAggCache + Debug> CachedSubAggs<Backend> {
    pub fn new(sub_agg: Box<dyn SegmentAggregationCollector>) -> Self {
        Self {
            cache: Backend::new(),
            sub_agg_collector: sub_agg,
            num_docs: 0,
        }
    }

    pub fn get_sub_agg_collector(&mut self) -> &mut Box<dyn SegmentAggregationCollector> {
        &mut self.sub_agg_collector
    }

    #[inline]
    pub fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
        self.cache.push(bucket_id, doc_id);
        self.num_docs += 1;
    }

    /// Check if we need to flush based on the number of documents cached.
    /// If so, flushes the cache to the provided aggregation collector.
    pub fn check_flush_local(
        &mut self,
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        if self.num_docs >= FLUSH_THRESHOLD {
            self.cache
                .flush_local(&mut self.sub_agg_collector, agg_data, false)?;
            self.num_docs = 0;
        }
        Ok(())
    }

    /// Note: this _does_ flush the sub aggregations.
    pub fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
        if self.num_docs != 0 {
            self.cache
                .flush_local(&mut self.sub_agg_collector, agg_data, true)?;
            self.num_docs = 0;
        }
        self.sub_agg_collector.flush(agg_data)?;
        Ok(())
    }
}

/// Number of partitions for high cardinality sub-aggregation cache.
const NUM_PARTITIONS: usize = 16;

#[derive(Debug)]
pub(crate) struct HighCardSubAggCache {
    /// This weird partitioning is used to do some cheap grouping on the bucket ids.
    /// bucket ids are dense, e.g. when we don't detect the cardinality as low cardinality,
    /// but there are just 16 bucket ids, each bucket id will go to its own partition.
    ///
    /// We want to keep this cheap, because high cardinality aggregations can have a lot of
    /// buckets, and there may be nothing to group.
    partitions: Box<[PartitionEntry; NUM_PARTITIONS]>,
}

impl HighCardSubAggCache {
    #[inline]
    fn clear(&mut self) {
        for partition in self.partitions.iter_mut() {
            partition.clear();
        }
    }
}

#[derive(Debug, Clone, Default)]
struct PartitionEntry {
    bucket_ids: Vec<BucketId>,
    docs: Vec<DocId>,
}

impl PartitionEntry {
    #[inline]
    fn clear(&mut self) {
        self.bucket_ids.clear();
        self.docs.clear();
    }
}

impl SubAggCache for HighCardSubAggCache {
    fn new() -> Self {
        Self {
            partitions: Box::new(core::array::from_fn(|_| PartitionEntry::default())),
        }
    }

    fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
        let idx = bucket_id % NUM_PARTITIONS as u32;
        let slot = &mut self.partitions[idx as usize];
        slot.bucket_ids.push(bucket_id);
        slot.docs.push(doc_id);
    }

    fn flush_local(
        &mut self,
        sub_agg: &mut Box<dyn SegmentAggregationCollector>,
        agg_data: &mut AggregationsSegmentCtx,
        _force: bool,
    ) -> crate::Result<()> {
        let mut max_bucket = 0u32;
        for partition in self.partitions.iter() {
            if let Some(&local_max) = partition.bucket_ids.iter().max() {
                max_bucket = max_bucket.max(local_max);
            }
        }

        sub_agg.prepare_max_bucket(max_bucket, agg_data)?;

        for slot in self.partitions.iter() {
            if !slot.bucket_ids.is_empty() {
                // Reduce dynamic dispatch overhead by collecting a full partition in one call.
                sub_agg.collect_multiple(&slot.bucket_ids, &slot.docs, agg_data)?;
            }
        }

        self.clear();
        Ok(())
    }
}

#[derive(Debug)]
pub(crate) struct LowCardSubAggCache {
    /// Cache doc ids per bucket for sub-aggregations.
    ///
    /// The outer Vec is indexed by BucketId.
    per_bucket_docs: Vec<Vec<DocId>>,
}

impl LowCardSubAggCache {
    #[inline]
    fn clear(&mut self) {
        for v in &mut self.per_bucket_docs {
            v.clear();
        }
    }
}

impl SubAggCache for LowCardSubAggCache {
    fn new() -> Self {
        Self {
            per_bucket_docs: Vec::new(),
        }
    }

    fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
        let idx = bucket_id as usize;
        if self.per_bucket_docs.len() <= idx {
            self.per_bucket_docs.resize_with(idx + 1, Vec::new);
        }
        self.per_bucket_docs[idx].push(doc_id);
    }

    fn flush_local(
        &mut self,
        sub_agg: &mut Box<dyn SegmentAggregationCollector>,
        agg_data: &mut AggregationsSegmentCtx,
        force: bool,
    ) -> crate::Result<()> {
        // Pre-aggregated: call collect per bucket.
        let max_bucket = (self.per_bucket_docs.len() as BucketId).saturating_sub(1);
        sub_agg.prepare_max_bucket(max_bucket, agg_data)?;
        // The threshold above which we flush buckets individually.
        // Note: We need to make sure that we don't lock ourselves into a situation where we hit
        // the FLUSH_THRESHOLD, but never flush any buckets. (except the final flush)
        let mut bucket_treshold = FLUSH_THRESHOLD / (self.per_bucket_docs.len().max(1) * 2);
        const _: () = {
            // MAX_NUM_TERMS_FOR_VEC threshold is used for term aggregations
            // Note: There may be other flexible values, for other aggregations, but we can use the
            // const value here as a upper bound. (better than nothing)
            let bucket_treshold_limit = FLUSH_THRESHOLD / (MAX_NUM_TERMS_FOR_VEC as usize * 2);
            assert!(
                bucket_treshold_limit > 0,
                "Bucket threshold must be greater than 0"
            );
        };
        if force {
            bucket_treshold = 0;
        }
        for (bucket_id, docs) in self
            .per_bucket_docs
            .iter()
            .enumerate()
            .filter(|(_, docs)| docs.len() > bucket_treshold)
        {
            sub_agg.collect(bucket_id as BucketId, docs, agg_data)?;
        }

        self.clear();
        Ok(())
    }
}
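The deleted `HighCardSubAggCache` spreads `(bucket_id, doc_id)` pairs over fixed partitions by `bucket_id % NUM_PARTITIONS`, so a flush can hand each partition to the sub-aggregation in one call. A standalone sketch of that grouping (illustrative names, smaller partition count):

```rust
// Sketch of the HighCardSubAggCache partitioning: pairs go to a slot chosen by
// bucket_id % NUM_PARTITIONS; flushing drains each non-empty slot as one group.
const NUM_PARTITIONS: usize = 4; // illustrative; the deleted code used 16

#[derive(Default)]
struct Partition {
    bucket_ids: Vec<u32>,
    docs: Vec<u32>,
}

struct PartitionedCache {
    partitions: Vec<Partition>,
}

impl PartitionedCache {
    fn new() -> Self {
        Self { partitions: (0..NUM_PARTITIONS).map(|_| Partition::default()).collect() }
    }

    fn push(&mut self, bucket_id: u32, doc_id: u32) {
        let slot = &mut self.partitions[(bucket_id as usize) % NUM_PARTITIONS];
        slot.bucket_ids.push(bucket_id);
        slot.docs.push(doc_id);
    }

    // Returns (bucket_ids, docs) per non-empty partition, mimicking the
    // `collect_multiple` call issued for each slot by `flush_local`.
    fn drain(&mut self) -> Vec<(Vec<u32>, Vec<u32>)> {
        self.partitions
            .iter_mut()
            .filter(|p| !p.bucket_ids.is_empty())
            .map(|p| (std::mem::take(&mut p.bucket_ids), std::mem::take(&mut p.docs)))
            .collect()
    }
}

fn run() -> Vec<(Vec<u32>, Vec<u32>)> {
    let mut cache = PartitionedCache::new();
    // bucket ids 0 and 4 share partition 0; bucket id 1 goes to partition 1.
    cache.push(0, 10);
    cache.push(4, 11);
    cache.push(1, 12);
    cache.drain()
}

fn main() {
    assert_eq!(run(), vec![(vec![0, 4], vec![10, 11]), (vec![1], vec![12])]);
}
```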
@@ -1,9 +1,9 @@
use super::agg_req::Aggregations;
use super::agg_result::AggregationResults;
use super::cached_sub_aggs::LowCardCachedSubAggs;
use super::buf_collector::BufAggregationCollector;
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::SegmentAggregationCollector;
use super::AggContextParams;
// group buffering strategy is chosen explicitly by callers; no need to hash-group on the fly.
use crate::aggregation::agg_data::{
    build_aggregations_data_from_req, build_segment_agg_collectors_root, AggregationsSegmentCtx,
};
@@ -136,7 +136,7 @@ fn merge_fruits(
/// `AggregationSegmentCollector` does the aggregation collection on a segment.
pub struct AggregationSegmentCollector {
    aggs_with_accessor: AggregationsSegmentCtx,
    agg_collector: LowCardCachedSubAggs,
    agg_collector: BufAggregationCollector,
    error: Option<TantivyError>,
}

@@ -151,11 +151,8 @@ impl AggregationSegmentCollector {
    ) -> crate::Result<Self> {
        let mut agg_data =
            build_aggregations_data_from_req(agg, reader, segment_ordinal, context.clone())?;
        let mut result =
            LowCardCachedSubAggs::new(build_segment_agg_collectors_root(&mut agg_data)?);
        result
            .get_sub_agg_collector()
            .prepare_max_bucket(0, &agg_data)?; // prepare for bucket zero
        let result =
            BufAggregationCollector::new(build_segment_agg_collectors_root(&mut agg_data)?);

        Ok(AggregationSegmentCollector {
            aggs_with_accessor: agg_data,
@@ -173,31 +170,26 @@ impl SegmentCollector for AggregationSegmentCollector {
        if self.error.is_some() {
            return;
        }
        self.agg_collector.push(0, doc);
        match self
        if let Err(err) = self
            .agg_collector
            .check_flush_local(&mut self.aggs_with_accessor)
            .collect(doc, &mut self.aggs_with_accessor)
        {
            Ok(_) => {}
            Err(e) => {
                self.error = Some(e);
            }
            self.error = Some(err);
        }
    }

    /// The query pushes the documents to the collector via this method.
    ///
    /// Only valid for Collectors that ignore docs
    fn collect_block(&mut self, docs: &[DocId]) {
        if self.error.is_some() {
            return;
        }

        match self.agg_collector.get_sub_agg_collector().collect(
            0,
            docs,
            &mut self.aggs_with_accessor,
        ) {
            Ok(_) => {}
            Err(e) => {
                self.error = Some(e);
            }
        if let Err(err) = self
            .agg_collector
            .collect_block(docs, &mut self.aggs_with_accessor)
        {
            self.error = Some(err);
        }
    }

@@ -208,13 +200,10 @@ impl SegmentCollector for AggregationSegmentCollector {
        self.agg_collector.flush(&mut self.aggs_with_accessor)?;

        let mut sub_aggregation_res = IntermediateAggregationResults::default();
        self.agg_collector
            .get_sub_agg_collector()
            .add_intermediate_aggregation_result(
                &self.aggs_with_accessor,
                &mut sub_aggregation_res,
                0,
            )?;
        Box::new(self.agg_collector).add_intermediate_aggregation_result(
            &self.aggs_with_accessor,
            &mut sub_aggregation_res,
        )?;

        Ok(sub_aggregation_res)
    }
@@ -25,12 +25,9 @@ use super::metric::{
|
||||
use super::segment_agg_result::AggregationLimitsGuard;
|
||||
use super::{format_date, AggregationError, Key, SerializedKey};
|
||||
use crate::aggregation::agg_result::{
|
||||
AggregationResults, BucketEntries, BucketEntry, CompositeBucketEntry, FilterBucketResult,
|
||||
};
|
||||
use crate::aggregation::bucket::{
|
||||
composite_intermediate_key_ordering, CompositeAggregation, MissingOrder,
|
||||
TermsAggregationInternal,
|
||||
AggregationResults, BucketEntries, BucketEntry, FilterBucketResult,
|
||||
};
|
||||
use crate::aggregation::bucket::TermsAggregationInternal;
|
||||
use crate::aggregation::metric::CardinalityCollector;
|
||||
use crate::TantivyError;
|
||||
|
||||
@@ -93,19 +90,6 @@ impl From<IntermediateKey> for Key {
|
||||
|
||||
impl Eq for IntermediateKey {}
|
||||
|
||||
impl std::fmt::Display for IntermediateKey {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
IntermediateKey::Str(val) => f.write_str(val),
|
||||
IntermediateKey::F64(val) => f.write_str(&val.to_string()),
|
||||
IntermediateKey::U64(val) => f.write_str(&val.to_string()),
|
||||
IntermediateKey::I64(val) => f.write_str(&val.to_string()),
|
||||
IntermediateKey::Bool(val) => f.write_str(&val.to_string()),
|
||||
IntermediateKey::IpAddr(val) => f.write_str(&val.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::hash::Hash for IntermediateKey {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
core::mem::discriminant(self).hash(state);
|
||||
@@ -121,21 +105,6 @@ impl std::hash::Hash for IntermediateKey {
|
||||
}
|
||||
|
||||
impl IntermediateAggregationResults {
|
||||
/// Returns a reference to the intermediate aggregation result for the given key.
|
||||
pub fn get(&self, key: &str) -> Option<&IntermediateAggregationResult> {
|
||||
self.aggs_res.get(key)
|
||||
}
|
||||
|
||||
/// Removes and returns the intermediate aggregation result for the given key.
|
||||
pub fn remove(&mut self, key: &str) -> Option<IntermediateAggregationResult> {
|
||||
self.aggs_res.remove(key)
|
||||
}
|
||||
|
||||
/// Returns an iterator over the keys in the intermediate aggregation results.
|
||||
pub fn keys(&self) -> impl Iterator<Item = &String> {
|
||||
self.aggs_res.keys()
|
||||
}
|
||||
|
||||
/// Add a result
|
||||
pub fn push(&mut self, key: String, value: IntermediateAggregationResult) -> crate::Result<()> {
|
||||
let entry = self.aggs_res.entry(key);
|
||||
@@ -249,11 +218,6 @@ pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult
|
||||
is_date_agg: true,
|
||||
})
|
||||
}
|
||||
Composite(_) => {
|
||||
IntermediateAggregationResult::Bucket(IntermediateBucketResult::Composite {
|
||||
buckets: Default::default(),
|
||||
})
|
||||
}
|
||||
Average(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::Average(
|
||||
IntermediateAverage::default(),
|
||||
)),
|
||||
@@ -481,11 +445,6 @@ pub enum IntermediateBucketResult {
|
||||
/// Sub-aggregation results
|
||||
sub_aggregations: IntermediateAggregationResults,
|
||||
},
|
||||
/// Composite aggregation
|
||||
Composite {
|
||||
/// The composite buckets
|
||||
buckets: IntermediateCompositeBucketResult,
|
||||
},
|
||||
}
|
||||
|
||||
impl IntermediateBucketResult {
|
||||
@@ -581,13 +540,6 @@ impl IntermediateBucketResult {
|
||||
sub_aggregations: final_sub_aggregations,
|
||||
}))
|
||||
}
|
||||
IntermediateBucketResult::Composite { buckets } => buckets.into_final_result(
|
||||
req.agg
|
||||
.as_composite()
|
||||
.expect("unexpected aggregation, expected composite aggregation"),
|
||||
req.sub_aggregation(),
|
||||
limits,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -654,16 +606,6 @@ impl IntermediateBucketResult {
|
||||
*doc_count_left += doc_count_right;
|
||||
sub_aggs_left.merge_fruits(sub_aggs_right)?;
|
||||
}
|
||||
(
|
||||
IntermediateBucketResult::Composite {
|
||||
buckets: buckets_left,
|
||||
},
|
||||
IntermediateBucketResult::Composite {
|
||||
buckets: buckets_right,
|
||||
},
|
||||
) => {
|
||||
buckets_left.merge_fruits(buckets_right)?;
|
||||
}
|
||||
(IntermediateBucketResult::Range(_), _) => {
|
||||
panic!("try merge on different types")
|
||||
}
|
||||
@@ -676,9 +618,6 @@ impl IntermediateBucketResult {
|
||||
(IntermediateBucketResult::Filter { .. }, _) => {
|
||||
panic!("try merge on different types")
|
||||
}
|
||||
(IntermediateBucketResult::Composite { .. }, _) => {
|
||||
panic!("try merge on different types")
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -700,21 +639,6 @@ pub struct IntermediateTermBucketResult {
|
||||
}
|
||||
|
||||
impl IntermediateTermBucketResult {
|
||||
/// Returns a reference to the map of bucket entries keyed by [`IntermediateKey`].
|
||||
pub fn entries(&self) -> &FxHashMap<IntermediateKey, IntermediateTermBucketEntry> {
|
||||
&self.entries
|
||||
}
|
||||
|
||||
/// Returns the count of documents not included in the returned buckets.
|
||||
pub fn sum_other_doc_count(&self) -> u64 {
|
||||
self.sum_other_doc_count
|
||||
}
|
||||
|
||||
/// Returns the upper bound of the error on document counts in the returned buckets.
|
||||
pub fn doc_count_error_upper_bound(&self) -> u64 {
|
||||
self.doc_count_error_upper_bound
|
||||
}
|
||||
|
||||
pub(crate) fn into_final_result(
|
||||
self,
|
||||
req: &TermsAggregation,
|
||||
@@ -868,7 +792,7 @@ pub struct IntermediateRangeBucketEntry {
|
||||
/// The number of documents in the bucket.
|
||||
pub doc_count: u64,
|
||||
/// The sub_aggregation in this bucket.
|
||||
pub sub_aggregation_res: IntermediateAggregationResults,
|
||||
pub sub_aggregation: IntermediateAggregationResults,
|
||||
/// The from range of the bucket. Equals `f64::MIN` when `None`.
|
||||
pub from: Option<f64>,
|
||||
/// The to range of the bucket. Equals `f64::MAX` when `None`.
|
||||
@@ -887,7 +811,7 @@ impl IntermediateRangeBucketEntry {
|
||||
key: self.key.into(),
|
||||
doc_count: self.doc_count,
|
||||
sub_aggregation: self
|
||||
.sub_aggregation_res
|
||||
.sub_aggregation
|
||||
.into_final_result_internal(req, limits)?,
|
||||
to: self.to,
|
||||
from: self.from,
|
||||
@@ -896,7 +820,7 @@ impl IntermediateRangeBucketEntry {
|
||||
};
|
||||
|
||||
// If we have a date type on the histogram buckets, we add the `key_as_string` field as
|
||||
// rfc3339
|
||||
// rfc339
|
||||
if column_type == Some(ColumnType::DateTime) {
|
||||
if let Some(val) = range_bucket_entry.to {
|
||||
let key_as_string = format_date(val as i64)?;
|
||||
@@ -922,212 +846,6 @@ pub struct IntermediateTermBucketEntry {
|
||||
pub sub_aggregation: IntermediateAggregationResults,
|
||||
}
|
||||
|
||||
/// Entry for the composite bucket.
|
||||
pub type IntermediateCompositeBucketEntry = IntermediateTermBucketEntry;
|
||||
|
||||
/// The fully typed key for composite aggregation
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub enum CompositeIntermediateKey {
|
||||
/// Bool key
|
||||
Bool(bool),
|
||||
/// String key
|
||||
Str(String),
|
||||
/// Float key
|
||||
F64(f64),
|
||||
/// Signed integer key
|
||||
I64(i64),
|
||||
/// Unsigned integer key
|
||||
U64(u64),
|
||||
/// DateTime key, nanoseconds since epoch
|
||||
DateTime(i64),
|
||||
/// IP Address key
|
||||
IpAddr(Ipv6Addr),
|
||||
/// Missing value key
|
||||
Null,
|
||||
}
|
||||
|
||||
impl Eq for CompositeIntermediateKey {}
|
||||
|
||||
impl std::hash::Hash for CompositeIntermediateKey {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
core::mem::discriminant(self).hash(state);
|
||||
match self {
|
||||
CompositeIntermediateKey::Bool(val) => val.hash(state),
|
||||
CompositeIntermediateKey::Str(text) => text.hash(state),
|
||||
CompositeIntermediateKey::F64(val) => val.to_bits().hash(state),
|
||||
CompositeIntermediateKey::U64(val) => val.hash(state),
|
||||
CompositeIntermediateKey::I64(val) => val.hash(state),
|
||||
CompositeIntermediateKey::DateTime(val) => val.hash(state),
|
||||
CompositeIntermediateKey::IpAddr(val) => val.hash(state),
|
||||
CompositeIntermediateKey::Null => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Composite aggregation page.
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct IntermediateCompositeBucketResult {
    #[serde(
        serialize_with = "serialize_composite_entries",
        deserialize_with = "deserialize_composite_entries"
    )]
    pub(crate) entries: FxHashMap<Vec<CompositeIntermediateKey>, IntermediateCompositeBucketEntry>,
    pub(crate) target_size: u32,
    pub(crate) orders: Vec<(Order, MissingOrder)>,
}

fn serialize_composite_entries<S>(
    entries: &FxHashMap<Vec<CompositeIntermediateKey>, IntermediateCompositeBucketEntry>,
    serializer: S,
) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    use serde::ser::SerializeSeq;
    let mut seq = serializer.serialize_seq(Some(entries.len()))?;
    for (k, v) in entries {
        seq.serialize_element(&(k, v))?;
    }
    seq.end()
}

fn deserialize_composite_entries<'de, D>(
    deserializer: D,
) -> Result<FxHashMap<Vec<CompositeIntermediateKey>, IntermediateCompositeBucketEntry>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let vec: Vec<(Vec<CompositeIntermediateKey>, IntermediateCompositeBucketEntry)> =
        serde::Deserialize::deserialize(deserializer)?;
    Ok(vec.into_iter().collect())
}
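The custom serde helpers exist because formats like JSON only allow string map keys: a map keyed by `Vec<CompositeIntermediateKey>` is therefore written as a sequence of `(key, value)` pairs and rebuilt by collecting them on deserialize. The same pairs roundtrip can be sketched with std only (serde elided; the `u32`/`u64` types are stand-ins):

```rust
use std::collections::HashMap;

// Stand-in for the composite key; any Eq + Hash key works the same way.
type Key = Vec<u32>;

// "Serialize": flatten the map into (key, value) pairs, the shape
// the serde helpers above write as a sequence.
fn to_pairs(map: &HashMap<Key, u64>) -> Vec<(Key, u64)> {
    map.iter().map(|(k, v)| (k.clone(), *v)).collect()
}

// "Deserialize": rebuild the map by collecting the pairs.
fn from_pairs(pairs: Vec<(Key, u64)>) -> HashMap<Key, u64> {
    pairs.into_iter().collect()
}

fn main() {
    let mut map = HashMap::new();
    map.insert(vec![1, 2], 10u64);
    map.insert(vec![3], 20u64);
    // The roundtrip is lossless as long as keys are unique, which a map guarantees.
    let roundtrip = from_pairs(to_pairs(&map));
    assert_eq!(map, roundtrip);
    println!("ok");
}
```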
impl IntermediateCompositeBucketResult {
    pub(crate) fn into_final_result(
        self,
        req: &CompositeAggregation,
        sub_aggregation_req: &Aggregations,
        limits: &mut AggregationLimitsGuard,
    ) -> crate::Result<BucketResult> {
        let trimmed_entry_vec =
            trim_composite_buckets(self.entries, &self.orders, self.target_size)?;
        let after_key = if trimmed_entry_vec.len() == req.size as usize {
            trimmed_entry_vec
                .last()
                .map(|bucket| {
                    let (intermediate_key, _entry) = bucket;
                    intermediate_key
                        .iter()
                        .enumerate()
                        .map(|(idx, intermediate_key)| {
                            let source = &req.sources[idx];
                            (source.name().to_string(), intermediate_key.clone().into())
                        })
                        .collect()
                })
                .unwrap()
        } else {
            FxHashMap::default()
        };

        let buckets = trimmed_entry_vec
            .into_iter()
            .map(|(intermediate_key, entry)| {
                let key = intermediate_key
                    .into_iter()
                    .enumerate()
                    .map(|(idx, intermediate_key)| {
                        let source = &req.sources[idx];
                        (source.name().to_string(), intermediate_key.into())
                    })
                    .collect();
                Ok(CompositeBucketEntry {
                    key,
                    doc_count: entry.doc_count as u64,
                    sub_aggregation: entry
                        .sub_aggregation
                        .into_final_result_internal(sub_aggregation_req, limits)?,
                })
            })
            .collect::<crate::Result<Vec<_>>>()?;

        Ok(BucketResult::Composite { after_key, buckets })
    }

    fn merge_fruits(&mut self, other: IntermediateCompositeBucketResult) -> crate::Result<()> {
        merge_maps(&mut self.entries, other.entries)?;
        if self.entries.len() as u32 > 2 * self.target_size {
            // 2x factor used to avoid trimming too often (expensive operation)
            // an optimal threshold could probably be figured out
            self.trim()?;
        }
        Ok(())
    }

    /// Trim the composite buckets to the target size, according to the ordering.
    ///
    /// Returns an error if the ordering comparison fails.
    pub(crate) fn trim(&mut self) -> crate::Result<()> {
        if self.entries.len() as u32 <= self.target_size {
            return Ok(());
        }

        let sorted_entries = trim_composite_buckets(
            std::mem::take(&mut self.entries),
            &self.orders,
            self.target_size,
        )?;

        self.entries = sorted_entries.into_iter().collect();
        Ok(())
    }
}
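`merge_fruits` only pays for a trim once the map has grown past `2 * target_size`, so the expensive sort-and-truncate is amortized over many merges instead of running on every one, while the final result is still exact because any entry outside the current top-k can never re-enter it. The policy can be sketched in isolation (plain `u64` values stand in for bucket entries, and "smallest k" for the composite ordering):

```rust
// Keep the `target` smallest values, but only compact once the buffer has
// doubled past the target — the same 2x amortization policy as merge_fruits.
struct Trimmed {
    entries: Vec<u64>,
    target: usize,
    trims: usize, // how many times we actually paid for a sort
}

impl Trimmed {
    fn push(&mut self, val: u64) {
        self.entries.push(val);
        if self.entries.len() > 2 * self.target {
            self.entries.sort_unstable();
            self.entries.truncate(self.target);
            self.trims += 1;
        }
    }

    // Final pass, equivalent to the trim() done before producing results.
    fn finish(mut self) -> (Vec<u64>, usize) {
        self.entries.sort_unstable();
        self.entries.truncate(self.target);
        (self.entries, self.trims)
    }
}

fn main() {
    let mut t = Trimmed { entries: Vec::new(), target: 4, trims: 0 };
    for v in (0..100u64).rev() {
        t.push(v);
    }
    let (top, trims) = t.finish();
    // Still the exact 4 smallest values, despite intermediate discards.
    assert_eq!(top, vec![0, 1, 2, 3]);
    // Far fewer sorts than insertions.
    assert!(trims < 100);
    println!("trims: {trims}");
}
```

Discarding is safe because once a value falls outside the `target` smallest seen so far, at least `target` smaller values already exist, so it can never be part of the final result.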
fn trim_composite_buckets(
    entries: FxHashMap<Vec<CompositeIntermediateKey>, IntermediateCompositeBucketEntry>,
    orders: &[(Order, MissingOrder)],
    target_size: u32,
) -> crate::Result<
    Vec<(
        Vec<CompositeIntermediateKey>,
        IntermediateCompositeBucketEntry,
    )>,
> {
    let mut entries: Vec<_> = entries.into_iter().collect();
    let mut sort_error: Option<TantivyError> = None;
    entries.sort_by(|(left_key, _), (right_key, _)| {
        // Only attempt sorting if we haven't encountered an error yet
        if sort_error.is_some() {
            return Ordering::Equal; // Return a default, we'll handle the error after sorting
        }

        for i in 0..orders.len() {
            match composite_intermediate_key_ordering(
                &left_key[i],
                &right_key[i],
                orders[i].0,
                orders[i].1,
            ) {
                Ok(ordering) if ordering != Ordering::Equal => return ordering,
                Ok(_) => continue, // Equal, try next key
                Err(err) => {
                    sort_error = Some(err);
                    break;
                }
            }
        }
        Ordering::Equal
    });

    // If we encountered an error during sorting, return it now
    if let Some(err) = sort_error {
        return Err(err);
    }

    entries.truncate(target_size as usize);
    Ok(entries)
}
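`sort_by` comparators must return an `Ordering`, not a `Result`, so `trim_composite_buckets` stashes the first comparison error in a local `Option`, answers `Ordering::Equal` for the remainder of the sort, and surfaces the error afterwards. The pattern in isolation, with a deliberately fallible comparator:

```rust
use std::cmp::Ordering;

// A comparison that can fail (standing in for incompatible key types in the real code).
fn try_cmp(a: i64, b: i64) -> Result<Ordering, String> {
    if a < 0 || b < 0 {
        return Err("negative keys are not comparable".to_string());
    }
    Ok(a.cmp(&b))
}

fn sort_fallible(mut vals: Vec<i64>) -> Result<Vec<i64>, String> {
    let mut sort_error: Option<String> = None;
    vals.sort_by(|a, b| {
        // Once an error is recorded, short-circuit with a neutral answer;
        // the sort still terminates and the error is reported afterwards.
        if sort_error.is_some() {
            return Ordering::Equal;
        }
        match try_cmp(*a, *b) {
            Ok(ord) => ord,
            Err(err) => {
                sort_error = Some(err);
                Ordering::Equal
            }
        }
    });
    if let Some(err) = sort_error {
        return Err(err);
    }
    Ok(vals)
}

fn main() {
    assert_eq!(sort_fallible(vec![3, 1, 2]).unwrap(), vec![1, 2, 3]);
    assert!(sort_fallible(vec![3, -1, 2]).is_err());
    println!("ok");
}
```

Note that the element order is unspecified when an error occurs, which is fine here because the caller discards the result and propagates the error.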
impl MergeFruits for IntermediateTermBucketEntry {
    fn merge_fruits(&mut self, other: IntermediateTermBucketEntry) -> crate::Result<()> {
        self.doc_count += other.doc_count;
@@ -1139,8 +857,7 @@ impl MergeFruits for IntermediateTermBucketEntry {
impl MergeFruits for IntermediateRangeBucketEntry {
    fn merge_fruits(&mut self, other: IntermediateRangeBucketEntry) -> crate::Result<()> {
        self.doc_count += other.doc_count;
        self.sub_aggregation_res
            .merge_fruits(other.sub_aggregation_res)?;
        self.sub_aggregation.merge_fruits(other.sub_aggregation)?;
        Ok(())
    }
}
@@ -1170,7 +887,7 @@ mod tests {
            IntermediateRangeBucketEntry {
                key: IntermediateKey::Str(key.to_string()),
                doc_count: *doc_count,
                sub_aggregation_res: Default::default(),
                sub_aggregation: Default::default(),
                from: None,
                to: None,
            },
@@ -1203,7 +920,7 @@ mod tests {
                doc_count: *doc_count,
                from: None,
                to: None,
                sub_aggregation_res: get_sub_test_tree(&[(
                sub_aggregation: get_sub_test_tree(&[(
                    sub_aggregation_key.to_string(),
                    *sub_aggregation_count,
                )]),

@@ -52,15 +52,11 @@ pub struct IntermediateAverage {

impl IntermediateAverage {
    /// Creates a new [`IntermediateAverage`] instance from a [`SegmentStatsCollector`].
    pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
        Self { stats }
    pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
        Self {
            stats: collector.stats,
        }
    }

    /// Returns a reference to the underlying [`IntermediateStats`].
    pub fn stats(&self) -> &IntermediateStats {
        &self.stats
    }

    /// Merges the other intermediate result into self.
    pub fn merge_fruits(&mut self, other: IntermediateAverage) {
        self.stats.merge_fruits(other.stats);

@@ -1,11 +1,12 @@
use std::hash::Hash;
use std::collections::hash_map::DefaultHasher;
use std::hash::{BuildHasher, Hasher};

use columnar::column_values::CompactSpaceU64Accessor;
use columnar::{Column, ColumnType, Dictionary, StrColumn};
use columnar::{Column, ColumnBlockAccessor, ColumnType, Dictionary, StrColumn};
use common::f64_to_u64;
use datasketches::hll::{HllSketch, HllType, HllUnion};
use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};
use rustc_hash::FxHashSet;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde::{Deserialize, Serialize};

use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
@@ -15,17 +16,29 @@ use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::TantivyError;

/// Log2 of the number of registers for the HLL sketch.
/// 2^11 = 2048 registers, giving ~2.3% relative error and ~1KB per sketch (Hll4).
const LG_K: u8 = 11;
#[derive(Clone, Debug, Serialize, Deserialize)]
struct BuildSaltedHasher {
    salt: u8,
}

impl BuildHasher for BuildSaltedHasher {
    type Hasher = DefaultHasher;

    fn build_hasher(&self) -> Self::Hasher {
        let mut hasher = DefaultHasher::new();
        hasher.write_u8(self.salt);

        hasher
    }
}

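`BuildSaltedHasher` seeds every new hasher with a salt byte before any value bytes are fed in, so the same `u64` hashed under different salts produces different hash streams. A std-only sketch of the same idea (the `SaltedBuild` name is illustrative):

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{BuildHasher, Hash, Hasher};

// Mirrors BuildSaltedHasher: prime each new hasher with a salt byte.
struct SaltedBuild {
    salt: u8,
}

impl BuildHasher for SaltedBuild {
    type Hasher = DefaultHasher;

    fn build_hasher(&self) -> DefaultHasher {
        let mut hasher = DefaultHasher::new();
        hasher.write_u8(self.salt);
        hasher
    }
}

fn salted_hash(salt: u8, value: u64) -> u64 {
    let build = SaltedBuild { salt };
    let mut hasher = build.build_hasher();
    value.hash(&mut hasher);
    hasher.finish()
}

fn main() {
    // Same value, same salt: deterministic (DefaultHasher::new() uses fixed keys).
    assert_eq!(salted_hash(1, 42), salted_hash(1, 42));
    // Same value, different salts: distinct hash streams, so e.g. bool `false`
    // (salt A, value 0) does not collide with i64 `0` (salt B, value 0).
    assert_ne!(salted_hash(1, 0), salted_hash(2, 0));
    println!("ok");
}
```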
/// # Cardinality
///
/// The cardinality aggregation allows for computing an estimate
/// of the number of different values in a data set based on the
/// Apache DataSketches HyperLogLog algorithm. This is particularly useful for
/// understanding the uniqueness of values in a large dataset where counting
/// each unique value individually would be computationally expensive.
/// HyperLogLog++ algorithm. This is particularly useful for understanding the
/// uniqueness of values in a large dataset where counting each unique value
/// individually would be computationally expensive.
///
/// For example, you might use a cardinality aggregation to estimate the number
/// of unique visitors to a website by aggregating on a field that contains
@@ -93,6 +106,8 @@ pub struct CardinalityAggReqData {
    pub str_dict_column: Option<StrColumn>,
    /// The missing value normalized to the internal u64 representation of the field type.
    pub missing_value_for_accessor: Option<u64>,
    /// The column block accessor to access the fast field values.
    pub(crate) column_block_accessor: ColumnBlockAccessor<u64>,
    /// The name of the aggregation.
    pub name: String,
    /// The aggregation request.
@@ -120,34 +135,45 @@ impl CardinalityAggregationReq {
    }
}

#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct SegmentCardinalityCollector {
    buckets: Vec<SegmentCardinalityCollectorBucket>,
    accessor_idx: usize,
    /// The column accessor to access the fast field values.
    accessor: Column<u64>,
    /// The column_type of the field.
    column_type: ColumnType,
    /// The missing value normalized to the internal u64 representation of the field type.
    missing_value_for_accessor: Option<u64>,
}

#[derive(Clone, Debug, PartialEq, Default)]
pub(crate) struct SegmentCardinalityCollectorBucket {
    cardinality: CardinalityCollector,
    entries: FxHashSet<u64>,
    accessor_idx: usize,
}
impl SegmentCardinalityCollectorBucket {
    pub fn new(column_type: ColumnType) -> Self {

impl SegmentCardinalityCollector {
    pub fn from_req(column_type: ColumnType, accessor_idx: usize) -> Self {
        Self {
            cardinality: CardinalityCollector::new(column_type as u8),
            entries: FxHashSet::default(),
            entries: Default::default(),
            accessor_idx,
        }
    }

    fn fetch_block_with_field(
        &mut self,
        docs: &[crate::DocId],
        agg_data: &mut CardinalityAggReqData,
    ) {
        if let Some(missing) = agg_data.missing_value_for_accessor {
            agg_data.column_block_accessor.fetch_block_with_missing(
                docs,
                &agg_data.accessor,
                missing,
            );
        } else {
            agg_data
                .column_block_accessor
                .fetch_block(docs, &agg_data.accessor);
        }
    }

    fn into_intermediate_metric_result(
        mut self,
        req_data: &CardinalityAggReqData,
        agg_data: &AggregationsSegmentCtx,
    ) -> crate::Result<IntermediateMetricResult> {
        let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
        if req_data.column_type == ColumnType::Str {
            let fallback_dict = Dictionary::empty();
            let dict = req_data
@@ -168,10 +194,9 @@ impl SegmentCardinalityCollectorBucket {
                    term_ids.push(term_ord as u32);
                }
            }

            term_ids.sort_unstable();
            dict.sorted_ords_to_term_cb(term_ids.iter().map(|term| *term as u64), |term| {
                self.cardinality.insert(term);
                self.cardinality.sketch.insert_any(&term);
                Ok(())
            })?;
            if has_missing {
@@ -182,17 +207,17 @@ impl SegmentCardinalityCollectorBucket {
                );
                match missing_key {
                    Key::Str(missing) => {
                        self.cardinality.insert(missing.as_str());
                        self.cardinality.sketch.insert_any(&missing);
                    }
                    Key::F64(val) => {
                        let val = f64_to_u64(*val);
                        self.cardinality.insert(val);
                        self.cardinality.sketch.insert_any(&val);
                    }
                    Key::U64(val) => {
                        self.cardinality.insert(*val);
                        self.cardinality.sketch.insert_any(&val);
                    }
                    Key::I64(val) => {
                        self.cardinality.insert(*val);
                        self.cardinality.sketch.insert_any(&val);
                    }
                }
            }
@@ -202,49 +227,16 @@ impl SegmentCardinalityCollectorBucket {
    }
}

impl SegmentCardinalityCollector {
    pub fn from_req(
        column_type: ColumnType,
        accessor_idx: usize,
        accessor: Column<u64>,
        missing_value_for_accessor: Option<u64>,
    ) -> Self {
        Self {
            buckets: vec![SegmentCardinalityCollectorBucket::new(column_type); 1],
            column_type,
            accessor_idx,
            accessor,
            missing_value_for_accessor,
        }
    }

    fn fetch_block_with_field(
        &mut self,
        docs: &[crate::DocId],
        agg_data: &mut AggregationsSegmentCtx,
    ) {
        agg_data.column_block_accessor.fetch_block_with_missing(
            docs,
            &self.accessor,
            self.missing_value_for_accessor,
        );
    }
}

impl SegmentAggregationCollector for SegmentCardinalityCollector {
    fn add_intermediate_aggregation_result(
        &mut self,
        self: Box<Self>,
        agg_data: &AggregationsSegmentCtx,
        results: &mut IntermediateAggregationResults,
        parent_bucket_id: BucketId,
    ) -> crate::Result<()> {
        self.prepare_max_bucket(parent_bucket_id, agg_data)?;
        let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
        let name = req_data.name.to_string();
        // take the bucket in buckets and replace it with a new empty one
        let bucket = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);

        let intermediate_result = bucket.into_intermediate_metric_result(req_data)?;
        let intermediate_result = self.into_intermediate_metric_result(agg_data)?;
        results.push(
            name,
            IntermediateAggregationResult::Metric(intermediate_result),
@@ -255,20 +247,27 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {

    fn collect(
        &mut self,
        parent_bucket_id: BucketId,
        doc: crate::DocId,
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        self.collect_block(&[doc], agg_data)
    }

    fn collect_block(
        &mut self,
        docs: &[crate::DocId],
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        self.fetch_block_with_field(docs, agg_data);
        let bucket = &mut self.buckets[parent_bucket_id as usize];
        let req_data = agg_data.get_cardinality_req_data_mut(self.accessor_idx);
        self.fetch_block_with_field(docs, req_data);

        let col_block_accessor = &agg_data.column_block_accessor;
        if self.column_type == ColumnType::Str {
        let col_block_accessor = &req_data.column_block_accessor;
        if req_data.column_type == ColumnType::Str {
            for term_ord in col_block_accessor.iter_vals() {
                bucket.entries.insert(term_ord);
                self.entries.insert(term_ord);
            }
        } else if self.column_type == ColumnType::IpAddr {
            let compact_space_accessor = self
        } else if req_data.column_type == ColumnType::IpAddr {
            let compact_space_accessor = req_data
                .accessor
                .values
                .clone()
@@ -283,43 +282,23 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
            })?;
            for val in col_block_accessor.iter_vals() {
                let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
                bucket.cardinality.insert(val);
                self.cardinality.sketch.insert_any(&val);
            }
        } else {
            for val in col_block_accessor.iter_vals() {
                bucket.cardinality.insert(val);
                self.cardinality.sketch.insert_any(&val);
            }
        }

        Ok(())
    }

    fn prepare_max_bucket(
        &mut self,
        max_bucket: BucketId,
        _agg_data: &AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        if max_bucket as usize >= self.buckets.len() {
            self.buckets.resize_with(max_bucket as usize + 1, || {
                SegmentCardinalityCollectorBucket::new(self.column_type)
            });
        }
        Ok(())
    }
}

#[derive(Clone, Debug)]
/// The cardinality collector used during segment collection and for merging results.
/// Uses Apache DataSketches HLL (lg_k=11, Hll4) for compact binary serialization
/// and cross-language compatibility (e.g. Java `datasketches` library).
#[derive(Clone, Debug, Serialize, Deserialize)]
/// The percentiles collector used during segment collection and for merging results.
pub struct CardinalityCollector {
    sketch: HllSketch,
    /// Salt derived from `ColumnType`, used to differentiate values of different column types
    /// that map to the same u64 (e.g. bool `false` = 0 vs i64 `0`).
    /// Not serialized: only needed during insertion, not after sketch registers are populated.
    salt: u8,
    sketch: HyperLogLogPlus<u64, BuildSaltedHasher>,
}

impl Default for CardinalityCollector {
    fn default() -> Self {
        Self::new(0)
@@ -332,52 +311,25 @@ impl PartialEq for CardinalityCollector {
    }
}

impl Serialize for CardinalityCollector {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let bytes = self.sketch.serialize();
        serializer.serialize_bytes(&bytes)
    }
}

impl<'de> Deserialize<'de> for CardinalityCollector {
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let bytes: Vec<u8> = Deserialize::deserialize(deserializer)?;
        let sketch = HllSketch::deserialize(&bytes).map_err(serde::de::Error::custom)?;
        Ok(Self { sketch, salt: 0 })
    }
}

impl CardinalityCollector {
    /// Compute the final cardinality estimate.
    pub fn finalize(self) -> Option<f64> {
        Some(self.sketch.clone().count().trunc())
    }

    fn new(salt: u8) -> Self {
        Self {
            sketch: HllSketch::new(LG_K, HllType::Hll4),
            salt,
            sketch: HyperLogLogPlus::new(16, BuildSaltedHasher { salt }).unwrap(),
        }
    }

    /// Insert a value into the HLL sketch, salted by the column type.
    /// The salt ensures that identical u64 values from different column types
    /// (e.g. bool `false` vs i64 `0`) are counted as distinct.
    pub(crate) fn insert<T: Hash>(&mut self, value: T) {
        self.sketch.update((self.salt, value));
    }

    /// Compute the final cardinality estimate.
    pub fn finalize(self) -> Option<f64> {
        Some(self.sketch.estimate().trunc())
    }

    /// Serialize the HLL sketch to its compact binary representation.
    /// The format is cross-language compatible with Apache DataSketches (Java, C++, Python).
    pub fn to_sketch_bytes(&self) -> Vec<u8> {
        self.sketch.serialize()
    }

    pub(crate) fn merge_fruits(&mut self, right: CardinalityCollector) -> crate::Result<()> {
        let mut union = HllUnion::new(LG_K);
        union.update(&self.sketch);
        union.update(&right.sketch);
        self.sketch = union.get_result(HllType::Hll4);
        self.sketch.merge(&right.sketch).map_err(|err| {
            TantivyError::AggregationError(AggregationError::InternalError(format!(
                "Error while merging cardinality {err:?}"
            )))
        })?;

        Ok(())
    }
}
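`merge_fruits` unions the two sketches. For HyperLogLog-family sketches a union is lossless because each register stores a maximum, and max is commutative, associative, and idempotent: the union of two sketches is simply the elementwise maximum of their register arrays. A toy illustration over plain register arrays (this is the register-merge idea only, not the DataSketches wire format):

```rust
// Elementwise-max merge of two HLL-style register arrays.
// Real HLL registers hold "max leading-zero count seen" per bucket,
// so the union of two streams is exactly the per-register max.
fn merge_registers(left: &mut [u8], right: &[u8]) {
    assert_eq!(left.len(), right.len());
    for (l, r) in left.iter_mut().zip(right) {
        *l = (*l).max(*r);
    }
}

fn main() {
    let mut a = vec![3, 0, 5, 1];
    let b = vec![2, 4, 5, 0];
    merge_registers(&mut a, &b);
    assert_eq!(a, vec![3, 4, 5, 1]);
    // Merging is idempotent: unioning the same sketch again changes nothing,
    // which is why re-merging segment results cannot inflate the estimate.
    let same = a.clone();
    merge_registers(&mut a, &same);
    assert_eq!(a, vec![3, 4, 5, 1]);
    println!("ok");
}
```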
@@ -539,75 +491,4 @@ mod tests {

        Ok(())
    }

    #[test]
    fn cardinality_collector_serde_roundtrip() {
        use super::CardinalityCollector;

        let mut collector = CardinalityCollector::default();
        collector.insert("hello");
        collector.insert("world");
        collector.insert("hello"); // duplicate

        let serialized = serde_json::to_vec(&collector).unwrap();
        let deserialized: CardinalityCollector = serde_json::from_slice(&serialized).unwrap();

        let original_estimate = collector.finalize().unwrap();
        let roundtrip_estimate = deserialized.finalize().unwrap();
        assert_eq!(original_estimate, roundtrip_estimate);
        assert_eq!(original_estimate, 2.0);
    }

    #[test]
    fn cardinality_collector_merge() {
        use super::CardinalityCollector;

        let mut left = CardinalityCollector::default();
        left.insert("a");
        left.insert("b");

        let mut right = CardinalityCollector::default();
        right.insert("b");
        right.insert("c");

        left.merge_fruits(right).unwrap();
        let estimate = left.finalize().unwrap();
        assert_eq!(estimate, 3.0);
    }

    #[test]
    fn cardinality_collector_serialize_deserialize_binary() {
        use datasketches::hll::HllSketch;

        use super::CardinalityCollector;

        let mut collector = CardinalityCollector::default();
        collector.insert("apple");
        collector.insert("banana");
        collector.insert("cherry");

        let bytes = collector.to_sketch_bytes();
        let deserialized = HllSketch::deserialize(&bytes).unwrap();
        assert!((deserialized.estimate() - 3.0).abs() < 0.01);
    }

    #[test]
    fn cardinality_collector_salt_differentiates_types() {
        use super::CardinalityCollector;

        // Without salt, the same u64 value from different column types would collide
        let mut collector_bool = CardinalityCollector::new(5); // e.g. ColumnType::Bool
        collector_bool.insert(0u64); // false
        collector_bool.insert(1u64); // true

        let mut collector_i64 = CardinalityCollector::new(2); // e.g. ColumnType::I64
        collector_i64.insert(0u64);
        collector_i64.insert(1u64);

        // Merge them
        collector_bool.merge_fruits(collector_i64).unwrap();
        let estimate = collector_bool.finalize().unwrap();
        // Should be 4 because the salt makes (5, 0) != (2, 0) and (5, 1) != (2, 1)
        assert_eq!(estimate, 4.0);
    }
}

@@ -52,8 +52,10 @@ pub struct IntermediateCount {

impl IntermediateCount {
    /// Creates a new [`IntermediateCount`] instance from a [`SegmentStatsCollector`].
    pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
        Self { stats }
    pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
        Self {
            stats: collector.stats,
        }
    }
    /// Merges the other intermediate result into self.
    pub fn merge_fruits(&mut self, other: IntermediateCount) {

@@ -8,9 +8,10 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
    IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::metric::MetricAggReqData;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::TantivyError;
use crate::{DocId, TantivyError};

/// A multi-value metric aggregation that computes a collection of extended statistics
/// on numeric values that are extracted
@@ -317,28 +318,51 @@ impl IntermediateExtendedStats {
    }
}

#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct SegmentExtendedStatsCollector {
    name: String,
    missing: Option<u64>,
    field_type: ColumnType,
    accessor: columnar::Column<u64>,
    buckets: Vec<IntermediateExtendedStats>,
    sigma: Option<f64>,
    pub(crate) extended_stats: IntermediateExtendedStats,
    pub(crate) accessor_idx: usize,
    val_cache: Vec<u64>,
}

impl SegmentExtendedStatsCollector {
    pub fn from_req(req: &MetricAggReqData, sigma: Option<f64>) -> Self {
        let missing = req
            .missing
            .and_then(|val| f64_to_fastfield_u64(val, &req.field_type));
    pub fn from_req(
        field_type: ColumnType,
        sigma: Option<f64>,
        accessor_idx: usize,
        missing: Option<f64>,
    ) -> Self {
        let missing = missing.and_then(|val| f64_to_fastfield_u64(val, &field_type));
        Self {
            name: req.name.clone(),
            field_type: req.field_type,
            accessor: req.accessor.clone(),
            field_type,
            extended_stats: IntermediateExtendedStats::with_sigma(sigma),
            accessor_idx,
            missing,
            buckets: vec![IntermediateExtendedStats::with_sigma(sigma); 16],
            sigma,
            val_cache: Default::default(),
        }
    }
    #[inline]
    pub(crate) fn collect_block_with_field(
        &mut self,
        docs: &[DocId],
        req_data: &mut MetricAggReqData,
    ) {
        if let Some(missing) = self.missing.as_ref() {
            req_data.column_block_accessor.fetch_block_with_missing(
                docs,
                &req_data.accessor,
                *missing,
            );
        } else {
            req_data
                .column_block_accessor
                .fetch_block(docs, &req_data.accessor);
        }
        for val in req_data.column_block_accessor.iter_vals() {
            let val1 = f64_from_fastfield_u64(val, &self.field_type);
            self.extended_stats.collect(val1);
        }
    }
}
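`collect_block_with_field` takes one of two paths: when a `missing` value is configured it fetches with `fetch_block_with_missing`, which substitutes the configured value for docs that have no value in the fast field; otherwise it does a plain fetch and valueless docs contribute nothing. The per-doc substitution rule can be sketched with plain `Option`s (the `collect_with_missing` helper is illustrative, not the tantivy API):

```rust
// Hypothetical per-doc column: None means the doc has no value in the fast field.
fn collect_with_missing(docs: &[Option<f64>], missing: Option<f64>) -> Vec<f64> {
    let mut out = Vec::new();
    for doc_val in docs {
        match (doc_val, missing) {
            // Doc has a value: always collected.
            (Some(v), _) => out.push(*v),
            // No value, but a `missing` default is configured: substitute it.
            (None, Some(m)) => out.push(m),
            // No value and no default: the doc contributes nothing.
            (None, None) => {}
        }
    }
    out
}

fn main() {
    let docs = [Some(1.0), None, Some(3.0)];
    assert_eq!(collect_with_missing(&docs, Some(0.0)), vec![1.0, 0.0, 3.0]);
    assert_eq!(collect_with_missing(&docs, None), vec![1.0, 3.0]);
    println!("ok");
}
```

Note that substituting a `missing` value changes the doc count feeding the statistics, which is exactly why the two fetch paths are kept separate.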
@@ -346,18 +370,15 @@ impl SegmentExtendedStatsCollector {
impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
    #[inline]
    fn add_intermediate_aggregation_result(
        &mut self,
        self: Box<Self>,
        agg_data: &AggregationsSegmentCtx,
        results: &mut IntermediateAggregationResults,
        parent_bucket_id: BucketId,
    ) -> crate::Result<()> {
        let name = self.name.clone();
        self.prepare_max_bucket(parent_bucket_id, agg_data)?;
        let extended_stats = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
        let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone();
        results.push(
            name,
            IntermediateAggregationResult::Metric(IntermediateMetricResult::ExtendedStats(
                extended_stats,
                self.extended_stats,
            )),
        )?;

@@ -367,36 +388,39 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
    #[inline]
    fn collect(
        &mut self,
        parent_bucket_id: BucketId,
        docs: &[crate::DocId],
        doc: crate::DocId,
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        let mut extended_stats = self.buckets[parent_bucket_id as usize].clone();

        agg_data
            .column_block_accessor
            .fetch_block_with_missing(docs, &self.accessor, self.missing);
        for val in agg_data.column_block_accessor.iter_vals() {
            let val1 = f64_from_fastfield_u64(val, self.field_type);
            extended_stats.collect(val1);
        let req_data = agg_data.get_metric_req_data(self.accessor_idx);
        if let Some(missing) = self.missing {
            let mut has_val = false;
            for val in req_data.accessor.values_for_doc(doc) {
                let val1 = f64_from_fastfield_u64(val, &self.field_type);
                self.extended_stats.collect(val1);
                has_val = true;
            }
            if !has_val {
                self.extended_stats
                    .collect(f64_from_fastfield_u64(missing, &self.field_type));
            }
        } else {
            for val in req_data.accessor.values_for_doc(doc) {
                let val1 = f64_from_fastfield_u64(val, &self.field_type);
                self.extended_stats.collect(val1);
            }
        }

        // store back
        self.buckets[parent_bucket_id as usize] = extended_stats;

        Ok(())
    }

    fn prepare_max_bucket(
    #[inline]
    fn collect_block(
        &mut self,
        max_bucket: BucketId,
        _agg_data: &AggregationsSegmentCtx,
        docs: &[crate::DocId],
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        if self.buckets.len() <= max_bucket as usize {
            self.buckets.resize_with(max_bucket as usize + 1, || {
                IntermediateExtendedStats::with_sigma(self.sigma)
            });
        }
        let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
        self.collect_block_with_field(docs, req_data);
        Ok(())
    }
}

@@ -52,8 +52,10 @@ pub struct IntermediateMax {

impl IntermediateMax {
    /// Creates a new [`IntermediateMax`] instance from a [`SegmentStatsCollector`].
    pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
        Self { stats }
    pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
        Self {
            stats: collector.stats,
        }
    }
    /// Merges the other intermediate result into self.
    pub fn merge_fruits(&mut self, other: IntermediateMax) {

@@ -52,8 +52,10 @@ pub struct IntermediateMin {

impl IntermediateMin {
    /// Creates a new [`IntermediateMin`] instance from a [`SegmentStatsCollector`].
    pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
        Self { stats }
    pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
        Self {
            stats: collector.stats,
        }
    }
    /// Merges the other intermediate result into self.
    pub fn merge_fruits(&mut self, other: IntermediateMin) {

@@ -31,7 +31,7 @@ use std::collections::HashMap;

pub use average::*;
pub use cardinality::*;
use columnar::{Column, ColumnType};
use columnar::{Column, ColumnBlockAccessor, ColumnType};
pub use count::*;
pub use extended_stats::*;
pub use max::*;
@@ -55,6 +55,8 @@ pub struct MetricAggReqData {
    pub field_type: ColumnType,
    /// The missing value normalized to the internal u64 representation of the field type.
    pub missing_u64: Option<u64>,
    /// The column block accessor to access the fast field values.
    pub column_block_accessor: ColumnBlockAccessor<u64>,
    /// The column accessor to access the fast field values.
    pub accessor: Column<u64>,
    /// Used when converting to intermediate result
@@ -107,11 +109,8 @@ pub enum PercentileValues {
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
/// The entry when requesting percentiles with keyed: false
pub struct PercentileValuesVecEntry {
    /// Percentile
    pub key: f64,

    /// Value at the percentile
    pub value: f64,
    key: f64,
    value: f64,
}

/// Single-metric aggregations use this common result structure.

@@ -7,9 +7,10 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
    IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::metric::MetricAggReqData;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::TantivyError;
use crate::{DocId, TantivyError};

/// # Percentiles
///
@@ -130,16 +131,10 @@ impl PercentilesAggregationReq {
    }
}

#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct SegmentPercentilesCollector {
    pub(crate) buckets: Vec<PercentilesCollector>,
    pub(crate) percentiles: PercentilesCollector,
    pub(crate) accessor_idx: usize,
    /// The type of the field.
    pub field_type: ColumnType,
    /// The missing value normalized to the internal u64 representation of the field type.
    pub missing_u64: Option<u64>,
    /// The column accessor to access the fast field values.
    pub accessor: Column<u64>,
}

#[derive(Clone, Serialize, Deserialize)]
@@ -222,12 +217,6 @@ impl PercentilesCollector {
        self.sketch.add(val);
    }

    /// Encode the underlying DDSketch to Java-compatible binary format
    /// for cross-language serialization with Java consumers.
    pub fn to_sketch_bytes(&self) -> Vec<u8> {
        self.sketch.to_java_bytes()
    }

    pub(crate) fn merge_fruits(&mut self, right: PercentilesCollector) -> crate::Result<()> {
        self.sketch.merge(&right.sketch).map_err(|err| {
            TantivyError::AggregationError(AggregationError::InternalError(format!(
@@ -240,18 +229,33 @@ impl PercentilesCollector {
    }
}

impl SegmentPercentilesCollector {
    pub fn from_req_and_validate(
        field_type: ColumnType,
        missing_u64: Option<u64>,
        accessor: Column<u64>,
        accessor_idx: usize,
    ) -> Self {
        Self {
            buckets: Vec::with_capacity(64),
            field_type,
            missing_u64,
            accessor,
    pub fn from_req_and_validate(accessor_idx: usize) -> crate::Result<Self> {
        Ok(Self {
            percentiles: PercentilesCollector::new(),
            accessor_idx,
        })
    }
    #[inline]
    pub(crate) fn collect_block_with_field(
        &mut self,
        docs: &[DocId],
        req_data: &mut MetricAggReqData,
    ) {
        if let Some(missing) = req_data.missing_u64.as_ref() {
            req_data.column_block_accessor.fetch_block_with_missing(
                docs,
                &req_data.accessor,
                *missing,
            );
        } else {
            req_data
                .column_block_accessor
                .fetch_block(docs, &req_data.accessor);
        }

        for val in req_data.column_block_accessor.iter_vals() {
            let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
            self.percentiles.collect(val1);
        }
    }
}
||||
@@ -259,18 +263,12 @@ impl SegmentPercentilesCollector {
|
||||
impl SegmentAggregationCollector for SegmentPercentilesCollector {
|
||||
#[inline]
|
||||
fn add_intermediate_aggregation_result(
|
||||
&mut self,
|
||||
self: Box<Self>,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
parent_bucket_id: BucketId,
|
||||
) -> crate::Result<()> {
|
||||
let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone();
|
||||
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
|
||||
// Swap collector with an empty one to avoid cloning
|
||||
let percentiles_collector = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
|
||||
|
||||
let intermediate_metric_result =
|
||||
IntermediateMetricResult::Percentiles(percentiles_collector);
|
||||
let intermediate_metric_result = IntermediateMetricResult::Percentiles(self.percentiles);
|
||||
|
||||
results.push(
|
||||
name,
|
||||
@@ -283,33 +281,40 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector {
|
||||
#[inline]
|
||||
fn collect(
|
||||
&mut self,
|
||||
parent_bucket_id: BucketId,
|
||||
docs: &[crate::DocId],
|
||||
doc: crate::DocId,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let percentiles = &mut self.buckets[parent_bucket_id as usize];
|
||||
agg_data.column_block_accessor.fetch_block_with_missing(
|
||||
docs,
|
||||
&self.accessor,
|
||||
self.missing_u64,
|
||||
);
|
||||
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
|
||||
|
||||
for val in agg_data.column_block_accessor.iter_vals() {
|
||||
let val1 = f64_from_fastfield_u64(val, self.field_type);
|
||||
percentiles.collect(val1);
|
||||
if let Some(missing) = req_data.missing_u64 {
|
||||
let mut has_val = false;
|
||||
for val in req_data.accessor.values_for_doc(doc) {
|
||||
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
|
||||
self.percentiles.collect(val1);
|
||||
has_val = true;
|
||||
}
|
||||
if !has_val {
|
||||
self.percentiles
|
||||
.collect(f64_from_fastfield_u64(missing, &req_data.field_type));
|
||||
}
|
||||
} else {
|
||||
for val in req_data.accessor.values_for_doc(doc) {
|
||||
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
|
||||
self.percentiles.collect(val1);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
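The per-doc `collect` above substitutes a configured `missing` value when a document yields no values. That pattern can be sketched in isolation; the `Sink` type and `f64` samples here are illustrative stand-ins, not tantivy's collector types:

```rust
/// Collects f64 samples, substituting a default when a doc has no values,
/// mirroring the missing-value handling in the diff above.
struct Sink {
    samples: Vec<f64>,
}

impl Sink {
    fn collect_doc(&mut self, vals: &[f64], missing: Option<f64>) {
        let mut has_val = false;
        for &v in vals {
            self.samples.push(v);
            has_val = true;
        }
        // Only when the doc produced no value does the `missing` default count.
        if !has_val {
            if let Some(m) = missing {
                self.samples.push(m);
            }
        }
    }
}

fn main() {
    let mut sink = Sink { samples: Vec::new() };
    sink.collect_doc(&[1.0, 2.0], Some(-1.0)); // values present: missing ignored
    sink.collect_doc(&[], Some(-1.0)); // no values: missing substituted
    sink.collect_doc(&[], None); // no values, no default: nothing recorded
    assert_eq!(sink.samples, vec![1.0, 2.0, -1.0]);
}
```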

    fn prepare_max_bucket(
    #[inline]
    fn collect_block(
        &mut self,
        max_bucket: BucketId,
        _agg_data: &AggregationsSegmentCtx,
        docs: &[crate::DocId],
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        while self.buckets.len() <= max_bucket as usize {
            self.buckets.push(PercentilesCollector::new());
        }
        let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
        self.collect_block_with_field(docs, req_data);
        Ok(())
    }
}
@@ -616,11 +621,11 @@ mod tests {

        assert_eq!(
            res["range_with_stats"]["buckets"][0]["percentiles"]["values"]["1.0"],
            5.002829575110705
            5.0028295751107414
        );
        assert_eq!(
            res["range_with_stats"]["buckets"][0]["percentiles"]["values"]["99.0"],
            10.07469668951133
            10.07469668951144
        );

        Ok(())
@@ -665,8 +670,8 @@ mod tests {

        let res = exec_request_with_query(agg_req, &index, None)?;

        assert_eq!(res["percentiles"]["values"]["1.0"], 5.002829575110705);
        assert_eq!(res["percentiles"]["values"]["99.0"], 10.07469668951133);
        assert_eq!(res["percentiles"]["values"]["1.0"], 5.0028295751107414);
        assert_eq!(res["percentiles"]["values"]["99.0"], 10.07469668951144);

        Ok(())
    }

@@ -1,6 +1,5 @@
use std::fmt::Debug;

use columnar::{Column, ColumnType};
use serde::{Deserialize, Serialize};

use super::*;
@@ -8,9 +7,10 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
    IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::metric::MetricAggReqData;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::TantivyError;
use crate::{DocId, TantivyError};

/// A multi-value metric aggregation that computes a collection of statistics on numeric values that
/// are extracted from the aggregated documents.
@@ -83,7 +83,7 @@ impl Stats {

/// Intermediate result of the stats aggregation that can be combined with other intermediate
/// results.
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct IntermediateStats {
    /// The number of extracted values.
    pub(crate) count: u64,
@@ -110,16 +110,6 @@ impl Default for IntermediateStats {
}

impl IntermediateStats {
    /// Returns the number of values collected.
    pub fn count(&self) -> u64 {
        self.count
    }

    /// Returns the sum of all values collected.
    pub fn sum(&self) -> f64 {
        self.sum
    }

    /// Merges the other stats intermediate result into self.
    pub fn merge_fruits(&mut self, other: IntermediateStats) {
        self.count += other.count;
@@ -197,75 +187,75 @@ pub enum StatsType {
    Percentiles,
}

fn create_collector<const TYPE_ID: u8>(
    req: &MetricAggReqData,
) -> Box<dyn SegmentAggregationCollector> {
    Box::new(SegmentStatsCollector::<TYPE_ID> {
        name: req.name.clone(),
        collecting_for: req.collecting_for,
        is_number_or_date_type: req.is_number_or_date_type,
        missing_u64: req.missing_u64,
        accessor: req.accessor.clone(),
        buckets: vec![IntermediateStats::default()],
    })
#[derive(Clone, Debug)]
pub(crate) struct SegmentStatsCollector {
    pub(crate) stats: IntermediateStats,
    pub(crate) accessor_idx: usize,
}

/// Build a concrete `SegmentStatsCollector` depending on the column type.
pub(crate) fn build_segment_stats_collector(
    req: &MetricAggReqData,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
    match req.field_type {
        ColumnType::I64 => Ok(create_collector::<{ ColumnType::I64 as u8 }>(req)),
        ColumnType::U64 => Ok(create_collector::<{ ColumnType::U64 as u8 }>(req)),
        ColumnType::F64 => Ok(create_collector::<{ ColumnType::F64 as u8 }>(req)),
        ColumnType::Bool => Ok(create_collector::<{ ColumnType::Bool as u8 }>(req)),
        ColumnType::DateTime => Ok(create_collector::<{ ColumnType::DateTime as u8 }>(req)),
        ColumnType::Bytes => Ok(create_collector::<{ ColumnType::Bytes as u8 }>(req)),
        ColumnType::Str => Ok(create_collector::<{ ColumnType::Str as u8 }>(req)),
        ColumnType::IpAddr => Ok(create_collector::<{ ColumnType::IpAddr as u8 }>(req)),
impl SegmentStatsCollector {
    pub fn from_req(accessor_idx: usize) -> Self {
        Self {
            stats: IntermediateStats::default(),
            accessor_idx,
        }
    }
    #[inline]
    pub(crate) fn collect_block_with_field(
        &mut self,
        docs: &[DocId],
        req_data: &mut MetricAggReqData,
    ) {
        if let Some(missing) = req_data.missing_u64.as_ref() {
            req_data.column_block_accessor.fetch_block_with_missing(
                docs,
                &req_data.accessor,
                *missing,
            );
        } else {
            req_data
                .column_block_accessor
                .fetch_block(docs, &req_data.accessor);
        }
        if req_data.is_number_or_date_type {
            for val in req_data.column_block_accessor.iter_vals() {
                let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
                self.stats.collect(val1);
            }
        } else {
            for _val in req_data.column_block_accessor.iter_vals() {
                // we ignore the value and simply record that we got something
                self.stats.collect(0.0);
            }
        }
    }
}

#[repr(C)]
#[derive(Clone, Debug)]
pub(crate) struct SegmentStatsCollector<const COLUMN_TYPE_ID: u8> {
    pub(crate) missing_u64: Option<u64>,
    pub(crate) accessor: Column<u64>,
    pub(crate) is_number_or_date_type: bool,
    pub(crate) buckets: Vec<IntermediateStats>,
    pub(crate) name: String,
    pub(crate) collecting_for: StatsType,
}

impl<const COLUMN_TYPE_ID: u8> SegmentAggregationCollector
    for SegmentStatsCollector<COLUMN_TYPE_ID>
{
impl SegmentAggregationCollector for SegmentStatsCollector {
    #[inline]
    fn add_intermediate_aggregation_result(
        &mut self,
        self: Box<Self>,
        agg_data: &AggregationsSegmentCtx,
        results: &mut IntermediateAggregationResults,
        parent_bucket_id: BucketId,
    ) -> crate::Result<()> {
        let name = self.name.clone();
        let req = agg_data.get_metric_req_data(self.accessor_idx);
        let name = req.name.clone();

        self.prepare_max_bucket(parent_bucket_id, agg_data)?;
        let stats = self.buckets[parent_bucket_id as usize];
        let intermediate_metric_result = match self.collecting_for {
        let intermediate_metric_result = match req.collecting_for {
            StatsType::Average => {
                IntermediateMetricResult::Average(IntermediateAverage::from_stats(stats))
                IntermediateMetricResult::Average(IntermediateAverage::from_collector(*self))
            }
            StatsType::Count => {
                IntermediateMetricResult::Count(IntermediateCount::from_stats(stats))
                IntermediateMetricResult::Count(IntermediateCount::from_collector(*self))
            }
            StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_stats(stats)),
            StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_stats(stats)),
            StatsType::Stats => IntermediateMetricResult::Stats(stats),
            StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_stats(stats)),
            StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_collector(*self)),
            StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_collector(*self)),
            StatsType::Stats => IntermediateMetricResult::Stats(self.stats),
            StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_collector(*self)),
            _ => {
                return Err(TantivyError::InvalidArgument(format!(
                    "Unsupported stats type for stats aggregation: {:?}",
                    self.collecting_for
                    req.collecting_for
                )))
            }
        };
@@ -281,67 +271,41 @@ impl<const COLUMN_TYPE_ID: u8> SegmentAggregationCollector
    #[inline]
    fn collect(
        &mut self,
        parent_bucket_id: BucketId,
        doc: crate::DocId,
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        let req_data = agg_data.get_metric_req_data(self.accessor_idx);
        if let Some(missing) = req_data.missing_u64 {
            let mut has_val = false;
            for val in req_data.accessor.values_for_doc(doc) {
                let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
                self.stats.collect(val1);
                has_val = true;
            }
            if !has_val {
                self.stats
                    .collect(f64_from_fastfield_u64(missing, &req_data.field_type));
            }
        } else {
            for val in req_data.accessor.values_for_doc(doc) {
                let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
                self.stats.collect(val1);
            }
        }

        Ok(())
    }

    #[inline]
    fn collect_block(
        &mut self,
        docs: &[crate::DocId],
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        // TODO: remove once we fetch all values for all bucket ids in one go
        if docs.len() == 1 && self.missing_u64.is_none() {
            collect_stats::<COLUMN_TYPE_ID>(
                &mut self.buckets[parent_bucket_id as usize],
                self.accessor.values_for_doc(docs[0]),
                self.is_number_or_date_type,
            )?;

            return Ok(());
        }
        agg_data.column_block_accessor.fetch_block_with_missing(
            docs,
            &self.accessor,
            self.missing_u64,
        );
        collect_stats::<COLUMN_TYPE_ID>(
            &mut self.buckets[parent_bucket_id as usize],
            agg_data.column_block_accessor.iter_vals(),
            self.is_number_or_date_type,
        )?;

        let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
        self.collect_block_with_field(docs, req_data);
        Ok(())
    }

    fn prepare_max_bucket(
        &mut self,
        max_bucket: BucketId,
        _agg_data: &AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        let required_buckets = (max_bucket as usize) + 1;
        if self.buckets.len() < required_buckets {
            self.buckets
                .resize_with(required_buckets, IntermediateStats::default);
        }
        Ok(())
    }
}

#[inline]
fn collect_stats<const COLUMN_TYPE_ID: u8>(
    stats: &mut IntermediateStats,
    vals: impl Iterator<Item = u64>,
    is_number_or_date_type: bool,
) -> crate::Result<()> {
    if is_number_or_date_type {
        for val in vals {
            let val1 = convert_to_f64::<COLUMN_TYPE_ID>(val);
            stats.collect(val1);
        }
    } else {
        for _val in vals {
            // we ignore the value and simply record that we got something
            stats.collect(0.0);
        }
    }

    Ok(())
}

#[cfg(test)]

@@ -52,8 +52,10 @@ pub struct IntermediateSum {

impl IntermediateSum {
    /// Creates a new [`IntermediateSum`] instance from a [`SegmentStatsCollector`].
    pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
        Self { stats }
    pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
        Self {
            stats: collector.stats,
        }
    }
    /// Merges the other intermediate result into self.
    pub fn merge_fruits(&mut self, other: IntermediateSum) {

@@ -15,11 +15,12 @@ use crate::aggregation::intermediate_agg_result::{
    IntermediateAggregationResult, IntermediateMetricResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::{AggregationError, BucketId};
use crate::aggregation::AggregationError;
use crate::collector::sort_key::ReverseComparator;
use crate::collector::TopNComputer;
use crate::schema::OwnedValue;
use crate::{DocAddress, DocId, SegmentOrdinal};
// duplicate import removed; already imported above

/// Contains all information required by the TopHitsSegmentCollector to perform the
/// top_hits aggregation on a segment.
@@ -471,10 +472,7 @@ impl TopHitsTopNComputer {
    /// Create a new TopHitsCollector
    pub fn new(req: &TopHitsAggregationReq) -> Self {
        Self {
            top_n: TopNComputer::new_with_comparator(
                req.size + req.from.unwrap_or(0),
                ReverseComparator,
            ),
            top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)),
            req: req.clone(),
        }
    }
@@ -520,8 +518,7 @@ impl TopHitsTopNComputer {
pub(crate) struct TopHitsSegmentCollector {
    segment_ordinal: SegmentOrdinal,
    accessor_idx: usize,
    buckets: Vec<TopNComputer<Vec<DocValueAndOrder>, DocAddress, ReverseComparator>>,
    num_hits: usize,
    top_n: TopNComputer<Vec<DocValueAndOrder>, DocAddress, ReverseComparator>,
}

impl TopHitsSegmentCollector {
@@ -530,29 +527,19 @@ impl TopHitsSegmentCollector {
        accessor_idx: usize,
        segment_ordinal: SegmentOrdinal,
    ) -> Self {
        let num_hits = req.size + req.from.unwrap_or(0);
        Self {
            num_hits,
            top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)),
            segment_ordinal,
            accessor_idx,
            buckets: vec![TopNComputer::new_with_comparator(num_hits, ReverseComparator); 1],
        }
    }
    fn get_top_hits_computer(
        &mut self,
        parent_bucket_id: BucketId,
    fn into_top_hits_collector(
        self,
        value_accessors: &HashMap<String, Vec<DynamicColumn>>,
        req: &TopHitsAggregationReq,
    ) -> TopHitsTopNComputer {
        if parent_bucket_id as usize >= self.buckets.len() {
            return TopHitsTopNComputer::new(req);
        }
        let top_n = std::mem::replace(
            &mut self.buckets[parent_bucket_id as usize],
            TopNComputer::new(0),
        );
        let mut top_hits_computer = TopHitsTopNComputer::new(req);
        let top_results = top_n.into_vec();
        let top_results = self.top_n.into_vec();

        for res in top_results {
            let doc_value_fields = req.get_document_field_data(value_accessors, res.doc.doc_id);
@@ -567,24 +554,54 @@ impl TopHitsSegmentCollector {

        top_hits_computer
    }

    /// TODO add a specialized variant for a single sort field
    fn collect_with(
        &mut self,
        doc_id: crate::DocId,
        req: &TopHitsAggregationReq,
        accessors: &[(Column<u64>, ColumnType)],
    ) -> crate::Result<()> {
        let sorts: Vec<DocValueAndOrder> = req
            .sort
            .iter()
            .enumerate()
            .map(|(idx, KeyOrder { order, .. })| {
                let order = *order;
                let value = accessors
                    .get(idx)
                    .expect("could not find field in accessors")
                    .0
                    .values_for_doc(doc_id)
                    .next();
                DocValueAndOrder { value, order }
            })
            .collect();

        self.top_n.push(
            sorts,
            DocAddress {
                segment_ord: self.segment_ordinal,
                doc_id,
            },
        );
        Ok(())
    }
}

impl SegmentAggregationCollector for TopHitsSegmentCollector {
    fn add_intermediate_aggregation_result(
        &mut self,
        self: Box<Self>,
        agg_data: &AggregationsSegmentCtx,
        results: &mut crate::aggregation::intermediate_agg_result::IntermediateAggregationResults,
        parent_bucket_id: BucketId,
    ) -> crate::Result<()> {
        let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);

        let value_accessors = &req_data.value_accessors;

        let intermediate_result = IntermediateMetricResult::TopHits(self.get_top_hits_computer(
            parent_bucket_id,
            value_accessors,
            &req_data.req,
        ));
        let intermediate_result = IntermediateMetricResult::TopHits(
            self.into_top_hits_collector(value_accessors, &req_data.req),
        );
        results.push(
            req_data.name.to_string(),
            IntermediateAggregationResult::Metric(intermediate_result),
@@ -594,54 +611,24 @@ impl SegmentAggregationCollector for TopHitsSegmentCollector {
    /// TODO: Consider a caching layer to reduce the call overhead
    fn collect(
        &mut self,
        parent_bucket_id: BucketId,
        docs: &[crate::DocId],
        doc_id: crate::DocId,
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        let top_n = &mut self.buckets[parent_bucket_id as usize];
        let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
        let req = &req_data.req;
        let accessors = &req_data.accessors;
        for &doc_id in docs {
            // TODO: this is terrible, a new vec is allocated for every doc
            // We can fetch blocks instead
            // We don't need to store the order for every value
            let sorts: Vec<DocValueAndOrder> = req
                .sort
                .iter()
                .enumerate()
                .map(|(idx, KeyOrder { order, .. })| {
                    let order = *order;
                    let value = accessors
                        .get(idx)
                        .expect("could not find field in accessors")
                        .0
                        .values_for_doc(doc_id)
                        .next();
                    DocValueAndOrder { value, order }
                })
                .collect();

            top_n.push(
                sorts,
                DocAddress {
                    segment_ord: self.segment_ordinal,
                    doc_id,
                },
            );
        }
        self.collect_with(doc_id, &req_data.req, &req_data.accessors)?;
        Ok(())
    }

    fn prepare_max_bucket(
    fn collect_block(
        &mut self,
        max_bucket: BucketId,
        _agg_data: &AggregationsSegmentCtx,
        docs: &[crate::DocId],
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        self.buckets.resize(
            (max_bucket as usize) + 1,
            TopNComputer::new_with_comparator(self.num_hits, ReverseComparator),
        );
        let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
        // TODO: Consider getting fields with the column block accessor.
        for doc in docs {
            self.collect_with(*doc, &req_data.req, &req_data.accessors)?;
        }
        Ok(())
    }
}
@@ -759,7 +746,7 @@ mod tests {
            ],
            "from": 0,
        }
    }
}
}))
.unwrap();

@@ -888,7 +875,7 @@ mod tests {
            "mixed.*",
        ],
    }
}
}
}))?;

let collector = AggregationCollector::from_aggs(d, Default::default());

@@ -133,7 +133,7 @@ mod agg_limits;
pub mod agg_req;
pub mod agg_result;
pub mod bucket;
pub(crate) mod cached_sub_aggs;
mod buf_collector;
mod collector;
mod date;
mod error;
@@ -162,19 +162,6 @@ use serde::{Deserialize, Deserializer, Serialize};

use crate::tokenizer::TokenizerManager;

/// A bucket id is a dense identifier for a bucket within an aggregation.
/// It is used to index into a Vec that holds per-bucket data.
///
/// For example, in a terms aggregation, each unique term will be assigned an incremental BucketId.
/// This BucketId will be forwarded to sub-aggregations to identify the parent bucket.
///
/// This allows a single AggregationCollector instance per aggregation
/// to handle multiple buckets efficiently.
///
/// The API to call sub-aggregations is therefore a &[(BucketId, &[DocId])].
/// For that we'll need a buffer. One Vec per bucket aggregation is needed.
pub type BucketId = u32;

/// Context parameters for aggregation execution
///
/// This struct holds shared resources needed during aggregation execution:
@@ -348,37 +335,19 @@ impl Display for Key {
    }
}

pub(crate) fn convert_to_f64<const COLUMN_TYPE_ID: u8>(val: u64) -> f64 {
    if COLUMN_TYPE_ID == ColumnType::U64 as u8 {
        val as f64
    } else if COLUMN_TYPE_ID == ColumnType::I64 as u8
        || COLUMN_TYPE_ID == ColumnType::DateTime as u8
    {
        i64::from_u64(val) as f64
    } else if COLUMN_TYPE_ID == ColumnType::F64 as u8 {
        f64::from_u64(val)
    } else if COLUMN_TYPE_ID == ColumnType::Bool as u8 {
        val as f64
    } else {
        panic!(
            "ColumnType ID {} cannot be converted to f64 metric",
            COLUMN_TYPE_ID
        )
    }
}

/// Inverse of `to_fastfield_u64`. Used to convert to `f64` for metrics.
///
/// # Panics
/// Only `u64`, `f64`, `date`, and `i64` are supported.
pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: ColumnType) -> f64 {
pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: &ColumnType) -> f64 {
    match field_type {
        ColumnType::U64 => convert_to_f64::<{ ColumnType::U64 as u8 }>(val),
        ColumnType::I64 => convert_to_f64::<{ ColumnType::I64 as u8 }>(val),
        ColumnType::F64 => convert_to_f64::<{ ColumnType::F64 as u8 }>(val),
        ColumnType::Bool => convert_to_f64::<{ ColumnType::Bool as u8 }>(val),
        ColumnType::DateTime => convert_to_f64::<{ ColumnType::DateTime as u8 }>(val),
        _ => panic!("unexpected type {field_type:?}. This should not happen"),
        ColumnType::U64 => val as f64,
        ColumnType::I64 | ColumnType::DateTime => i64::from_u64(val) as f64,
        ColumnType::F64 => f64::from_u64(val),
        ColumnType::Bool => val as f64,
        _ => {
            panic!("unexpected type {field_type:?}. This should not happen")
        }
    }
}
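The `from_u64` conversions above rely on an order-preserving mapping between `i64`/`f64` and `u64`. The bit tricks can be sketched standalone; these are illustrative re-implementations of that kind of monotone mapping, not tantivy's `MonotonicallyMappableToU64` trait itself:

```rust
// Order-preserving i64 <-> u64 mapping: flipping the sign bit shifts the
// signed range onto [0, u64::MAX] while keeping the ordering intact.
fn i64_to_u64(v: i64) -> u64 {
    (v as u64) ^ (1u64 << 63)
}
fn i64_from_u64(v: u64) -> i64 {
    (v ^ (1u64 << 63)) as i64
}

// Order-preserving f64 <-> u64 mapping (non-NaN values): positives get the
// top half of the u64 range, negatives are bit-inverted to reverse their
// magnitude ordering.
fn f64_to_u64(v: f64) -> u64 {
    let bits = v.to_bits();
    if v.is_sign_positive() {
        bits | (1u64 << 63)
    } else {
        !bits
    }
}
fn f64_from_u64(v: u64) -> f64 {
    if v & (1u64 << 63) != 0 {
        f64::from_bits(v & !(1u64 << 63))
    } else {
        f64::from_bits(!v)
    }
}

fn main() {
    // The u64 images sort in the same order as the original floats.
    let xs = [-2.5f64, -0.5, 0.0, 1.5, 3.0];
    let mapped: Vec<u64> = xs.iter().map(|&x| f64_to_u64(x)).collect();
    assert!(mapped.windows(2).all(|w| w[0] < w[1]));
    // Round-trips are lossless.
    for &x in &xs {
        assert_eq!(f64_from_u64(f64_to_u64(x)), x);
    }
    assert_eq!(i64_from_u64(i64_to_u64(-42)), -42);
    assert!(i64_to_u64(-1) < i64_to_u64(0));
}
```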

@@ -8,67 +8,25 @@ use std::fmt::Debug;
pub(crate) use super::agg_limits::AggregationLimitsGuard;
use super::intermediate_agg_result::IntermediateAggregationResults;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::BucketId;

/// Monotonically increasing provider of BucketIds.
#[derive(Debug, Clone, Default)]
pub struct BucketIdProvider(u32);
impl BucketIdProvider {
    /// Get the next BucketId.
    pub fn next_bucket_id(&mut self) -> BucketId {
        let bucket_id = self.0;
        self.0 += 1;
        bucket_id
    }
}
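Because `BucketIdProvider` hands out dense ids (0, 1, 2, ...), per-bucket state can live in a plain `Vec` indexed by id rather than a hash map. A minimal sketch of that usage pattern; the `IdProvider` counter and count vector here are illustrative, not the real collector state:

```rust
// Dense, monotonically increasing bucket ids, mirroring BucketIdProvider.
#[derive(Default)]
struct IdProvider(u32);

impl IdProvider {
    fn next_bucket_id(&mut self) -> u32 {
        let id = self.0;
        self.0 += 1;
        id
    }
}

fn main() {
    let mut provider = IdProvider::default();
    // Dense ids index directly into a Vec of per-bucket data.
    let mut per_bucket_counts: Vec<u64> = Vec::new();
    for _term in ["a", "b", "c"] {
        let id = provider.next_bucket_id() as usize;
        if per_bucket_counts.len() <= id {
            per_bucket_counts.resize(id + 1, 0);
        }
        per_bucket_counts[id] += 1;
    }
    assert_eq!(per_bucket_counts, vec![1, 1, 1]);
}
```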

/// A SegmentAggregationCollector is used to collect aggregation results.
pub trait SegmentAggregationCollector: Debug {
pub trait SegmentAggregationCollector: CollectorClone + Debug {
    fn add_intermediate_aggregation_result(
        &mut self,
        self: Box<Self>,
        agg_data: &AggregationsSegmentCtx,
        results: &mut IntermediateAggregationResults,
        parent_bucket_id: BucketId,
    ) -> crate::Result<()>;

    /// Note: The caller needs to call `prepare_max_bucket` before calling `collect`.
    fn collect(
        &mut self,
        parent_bucket_id: BucketId,
        docs: &[crate::DocId],
        doc: crate::DocId,
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()>;

    /// Collect docs for multiple buckets in one call.
    /// Minimizes dynamic dispatch overhead when collecting many buckets.
    ///
    /// Note: The caller needs to call `prepare_max_bucket` before calling `collect`.
    fn collect_multiple(
    fn collect_block(
        &mut self,
        bucket_ids: &[BucketId],
        docs: &[crate::DocId],
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        debug_assert_eq!(bucket_ids.len(), docs.len());
        let mut start = 0;
        while start < bucket_ids.len() {
            let bucket_id = bucket_ids[start];
            let mut end = start + 1;
            while end < bucket_ids.len() && bucket_ids[end] == bucket_id {
                end += 1;
            }
            self.collect(bucket_id, &docs[start..end], agg_data)?;
            start = end;
        }
        Ok(())
    }
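The default `collect_multiple` above batches runs of consecutive equal bucket ids into a single `collect` call, so dynamic dispatch happens once per run rather than once per doc. The run-detection loop can be sketched in isolation; the `(id, slice)` output shape here is for illustration only:

```rust
// Groups parallel (bucket_ids, docs) arrays into runs of equal bucket ids,
// following the same two-pointer loop as the default collect_multiple.
fn group_runs<'a>(bucket_ids: &[u32], docs: &'a [u32]) -> Vec<(u32, &'a [u32])> {
    assert_eq!(bucket_ids.len(), docs.len());
    let mut runs = Vec::new();
    let mut start = 0;
    while start < bucket_ids.len() {
        let bucket_id = bucket_ids[start];
        let mut end = start + 1;
        // Extend the run while the bucket id stays the same.
        while end < bucket_ids.len() && bucket_ids[end] == bucket_id {
            end += 1;
        }
        runs.push((bucket_id, &docs[start..end]));
        start = end;
    }
    runs
}

fn main() {
    let runs = group_runs(&[0, 0, 1, 1, 1, 0], &[10, 11, 12, 13, 14, 15]);
    assert_eq!(
        runs,
        vec![(0, &[10, 11][..]), (1, &[12, 13, 14][..]), (0, &[15][..])]
    );
}
```

Note that equal bucket ids separated by a different id form distinct runs, so the `collect` target only changes at run boundaries.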

    /// Prepare the collector for collecting up to BucketId `max_bucket`.
    /// This is useful so allocations can be done ahead of collecting.
    fn prepare_max_bucket(
        &mut self,
        max_bucket: BucketId,
        agg_data: &AggregationsSegmentCtx,
    ) -> crate::Result<()>;

    /// Finalize method. Some aggregators collect blocks of docs before calling `collect_block`.
@@ -78,7 +36,26 @@ pub trait SegmentAggregationCollector: Debug {
    }
}

#[derive(Default)]
/// A helper trait to enable cloning of Box<dyn SegmentAggregationCollector>
pub trait CollectorClone {
    fn clone_box(&self) -> Box<dyn SegmentAggregationCollector>;
}

impl<T> CollectorClone for T
where T: 'static + SegmentAggregationCollector + Clone
{
    fn clone_box(&self) -> Box<dyn SegmentAggregationCollector> {
        Box::new(self.clone())
    }
}

impl Clone for Box<dyn SegmentAggregationCollector> {
    fn clone(&self) -> Box<dyn SegmentAggregationCollector> {
        self.clone_box()
    }
}
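The `CollectorClone` helper above is the standard clone-box pattern for making `Clone` work through a trait object (`Clone` itself is not object safe because it returns `Self`). A self-contained sketch of the same technique on a toy `Collector` trait:

```rust
use std::fmt::Debug;

trait Collector: CollectorClone + Debug {
    fn name(&self) -> &str;
}

// Helper trait so Box<dyn Collector> can be cloned even though Clone itself
// is not object safe.
trait CollectorClone {
    fn clone_box(&self) -> Box<dyn Collector>;
}

// Blanket impl: every 'static + Clone implementor of Collector gets clone_box.
impl<T> CollectorClone for T
where T: 'static + Collector + Clone
{
    fn clone_box(&self) -> Box<dyn Collector> {
        Box::new(self.clone())
    }
}

// Clone for the boxed trait object delegates to clone_box.
impl Clone for Box<dyn Collector> {
    fn clone(&self) -> Box<dyn Collector> {
        self.clone_box()
    }
}

#[derive(Clone, Debug)]
struct Count(u64);

impl Collector for Count {
    fn name(&self) -> &str {
        "count"
    }
}

fn main() {
    let boxed: Box<dyn Collector> = Box::new(Count(7));
    let copy = boxed.clone(); // goes through clone_box
    assert_eq!(copy.name(), "count");
}
```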
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
/// The GenericSegmentAggregationResultsCollector is the generic version of the collector, which
|
||||
/// can handle arbitrary complexity of sub-aggregations. Ideally we never have to pick this one
|
||||
/// and can provide specialized versions instead, that remove some of its overhead.
|
||||
@@ -96,13 +73,12 @@ impl Debug for GenericSegmentAggregationResultsCollector {
|
||||
|
||||
impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
|
||||
fn add_intermediate_aggregation_result(
|
||||
&mut self,
|
||||
self: Box<Self>,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
parent_bucket_id: BucketId,
|
||||
) -> crate::Result<()> {
|
||||
for agg in &mut self.aggs {
|
||||
agg.add_intermediate_aggregation_result(agg_data, results, parent_bucket_id)?;
|
||||
for agg in self.aggs {
|
||||
agg.add_intermediate_aggregation_result(agg_data, results)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -110,13 +86,23 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
|
||||
|
||||
fn collect(
|
||||
&mut self,
|
||||
parent_bucket_id: BucketId,
|
||||
doc: crate::DocId,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
self.collect_block(&[doc], agg_data)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
for collector in &mut self.aggs {
|
||||
collector.collect(parent_bucket_id, docs, agg_data)?;
|
||||
collector.collect_block(docs, agg_data)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -126,15 +112,4 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn prepare_max_bucket(
|
||||
&mut self,
|
||||
max_bucket: BucketId,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
for collector in &mut self.aggs {
|
||||
collector.prepare_max_bucket(max_bucket, agg_data)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -486,9 +486,9 @@ mod tests {
    use std::collections::BTreeSet;

    use columnar::Dictionary;
    use rand::distr::Uniform;
    use rand::distributions::Uniform;
    use rand::prelude::SliceRandom;
    use rand::{rng, Rng};
    use rand::{thread_rng, Rng};

    use super::{FacetCollector, FacetCounts};
    use crate::collector::facet_collector::compress_mapping;
@@ -731,7 +731,7 @@ mod tests {
        let schema = schema_builder.build();
        let index = Index::create_in_ram(schema);

        let uniform = Uniform::new_inclusive(1, 100_000).unwrap();
        let uniform = Uniform::new_inclusive(1, 100_000);
        let mut docs: Vec<TantivyDocument> =
            vec![("a", 10), ("b", 100), ("c", 7), ("d", 12), ("e", 21)]
                .into_iter()
@@ -741,11 +741,14 @@ mod tests {
                    std::iter::repeat_n(doc, count)
                })
                .map(|mut doc| {
                    doc.add_facet(facet_field, &format!("/facet/{}", rng().sample(uniform)));
                    doc.add_facet(
                        facet_field,
                        &format!("/facet/{}", thread_rng().sample(uniform)),
                    );
                    doc
                })
                .collect();
        docs[..].shuffle(&mut rng());
        docs[..].shuffle(&mut thread_rng());

        let mut index_writer: IndexWriter = index.writer_for_tests().unwrap();
        for doc in docs {
@@ -819,8 +822,8 @@ mod tests {
#[cfg(all(test, feature = "unstable"))]
mod bench {

    use rand::rng;
    use rand::seq::SliceRandom;
    use rand::thread_rng;
    use test::Bencher;

    use crate::collector::FacetCollector;
@@ -843,7 +846,7 @@ mod bench {
            }
        }
        // 40425 docs
        docs[..].shuffle(&mut rng());
        docs[..].shuffle(&mut thread_rng());

        let mut index_writer: IndexWriter = index.writer_for_tests().unwrap();
        for doc in docs {

@@ -1,50 +1,25 @@
mod order;
mod sort_by_bytes;
mod sort_by_erased_type;
mod sort_by_score;
mod sort_by_static_fast_value;
mod sort_by_string;
mod sort_key_computer;

pub use order::*;
pub use sort_by_bytes::SortByBytes;
pub use sort_by_erased_type::SortByErasedType;
pub use sort_by_score::SortBySimilarityScore;
pub use sort_by_static_fast_value::SortByStaticFastValue;
pub use sort_by_string::SortByString;
pub use sort_key_computer::{SegmentSortKeyComputer, SortKeyComputer};

#[cfg(test)]
pub(crate) mod tests {

    // By spec, regardless of whether ascending or descending order was requested, in presence of a
    // tie, we sort by ascending doc id/doc address.
    pub(crate) fn sort_hits<TSortKey: Ord, D: Ord>(
        hits: &mut [ComparableDoc<TSortKey, D>],
        order: Order,
    ) {
        if order.is_asc() {
            hits.sort_by(|l, r| l.sort_key.cmp(&r.sort_key).then(l.doc.cmp(&r.doc)));
        } else {
            hits.sort_by(|l, r| {
                l.sort_key
                    .cmp(&r.sort_key)
                    .reverse() // This is descending
                    .then(l.doc.cmp(&r.doc))
            });
        }
    }

mod tests {
    use std::collections::HashMap;
    use std::ops::Range;

    use crate::collector::sort_key::{
        SortByErasedType, SortBySimilarityScore, SortByStaticFastValue, SortByString,
    };
    use crate::collector::sort_key::{SortBySimilarityScore, SortByStaticFastValue, SortByString};
    use crate::collector::{ComparableDoc, DocSetCollector, TopDocs};
    use crate::indexer::NoMergePolicy;
    use crate::query::{AllQuery, QueryParser};
    use crate::schema::{OwnedValue, Schema, FAST, TEXT};
    use crate::schema::{Schema, FAST, TEXT};
    use crate::{DocAddress, Document, Index, Order, Score, Searcher};

    fn make_index() -> crate::Result<Index> {
@@ -319,9 +294,11 @@ pub(crate) mod tests {
                (SortBySimilarityScore, score_order),
                (SortByString::for_field("city"), city_order),
            ));
            let results: Vec<((Score, Option<String>), DocAddress)> =
                searcher.search(&AllQuery, &top_collector)?;
            Ok(results.into_iter().map(|(f, doc)| (f, ids[&doc])).collect())
            Ok(searcher
                .search(&AllQuery, &top_collector)?
                .into_iter()
                .map(|(f, doc)| (f, ids[&doc]))
                .collect())
        }

        assert_eq!(
@@ -346,51 +323,6 @@ pub(crate) mod tests {
        Ok(())
    }

    #[test]
    fn test_order_by_score_then_owned_value() -> crate::Result<()> {
        let index = make_index()?;

        type SortKey = (Score, OwnedValue);

        fn query(
            index: &Index,
            score_order: Order,
            city_order: Order,
        ) -> crate::Result<Vec<(SortKey, u64)>> {
            let searcher = index.reader()?.searcher();
            let ids = id_mapping(&searcher);

            let top_collector = TopDocs::with_limit(4).order_by::<(Score, OwnedValue)>((
                (SortBySimilarityScore, score_order),
                (SortByErasedType::for_field("city"), city_order),
            ));
            let results: Vec<((Score, OwnedValue), DocAddress)> =
                searcher.search(&AllQuery, &top_collector)?;
            Ok(results.into_iter().map(|(f, doc)| (f, ids[&doc])).collect())
        }

        assert_eq!(
            &query(&index, Order::Asc, Order::Asc)?,
            &[
                ((1.0, OwnedValue::Str("austin".to_owned())), 0),
                ((1.0, OwnedValue::Str("greenville".to_owned())), 1),
                ((1.0, OwnedValue::Str("tokyo".to_owned())), 2),
                ((1.0, OwnedValue::Null), 3),
            ]
        );

        assert_eq!(
            &query(&index, Order::Asc, Order::Desc)?,
            &[
                ((1.0, OwnedValue::Str("tokyo".to_owned())), 2),
                ((1.0, OwnedValue::Str("greenville".to_owned())), 1),
                ((1.0, OwnedValue::Str("austin".to_owned())), 0),
                ((1.0, OwnedValue::Null), 3),
            ]
        );
        Ok(())
    }

    use proptest::prelude::*;

    proptest! {
@@ -440,10 +372,15 @@ pub(crate) mod tests {

            // Using the TopDocs collector should always be equivalent to sorting, skipping the
            // offset, and then taking the limit.
            let sorted_docs: Vec<_> = {
                let mut comparable_docs: Vec<ComparableDoc<_, _>> =
            let sorted_docs: Vec<_> = if order.is_desc() {
                let mut comparable_docs: Vec<ComparableDoc<_, _, true>> =
                    all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
                sort_hits(&mut comparable_docs, order);
                comparable_docs.sort();
                comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
            } else {
                let mut comparable_docs: Vec<ComparableDoc<_, _, false>> =
                    all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
                comparable_docs.sort();
                comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
            };
            let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::<Vec<_>>();

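As a side note on the spec quoted in the tests module above (ties always break by ascending doc id, even for a descending sort), here is a minimal standalone sketch of that rule using plain `(sort_key, doc_id)` tuples instead of tantivy's `ComparableDoc`; the tuple representation is purely illustrative:

```rust
fn main() {
    // (sort_key, doc_id) pairs. Ties on the sort key resolve by ascending
    // doc id even though the sort key itself is ordered descending.
    let mut hits: Vec<(u64, u32)> = vec![(10, 3), (20, 1), (10, 0)];
    hits.sort_by(|l, r| {
        l.0.cmp(&r.0)
            .reverse() // descending by sort key
            .then(l.1.cmp(&r.1)) // ascending doc id on ties
    });
    assert_eq!(hits, vec![(20, 1), (10, 0), (10, 3)]);
    println!("{hits:?}");
}
```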
@@ -1,116 +1,36 @@
use std::cmp::Ordering;

use columnar::MonotonicallyMappableToU64;
use serde::{Deserialize, Serialize};

use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::schema::{OwnedValue, Schema};
use crate::schema::Schema;
use crate::{DocId, Order, Score};

fn compare_owned_value<const NULLS_FIRST: bool>(lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
    match (lhs, rhs) {
        (OwnedValue::Null, OwnedValue::Null) => Ordering::Equal,
        (OwnedValue::Null, _) => {
            if NULLS_FIRST {
                Ordering::Less
            } else {
                Ordering::Greater
            }
        }
        (_, OwnedValue::Null) => {
            if NULLS_FIRST {
                Ordering::Greater
            } else {
                Ordering::Less
            }
        }
        (OwnedValue::Str(a), OwnedValue::Str(b)) => a.cmp(b),
        (OwnedValue::PreTokStr(a), OwnedValue::PreTokStr(b)) => a.cmp(b),
        (OwnedValue::U64(a), OwnedValue::U64(b)) => a.cmp(b),
        (OwnedValue::I64(a), OwnedValue::I64(b)) => a.cmp(b),
        (OwnedValue::F64(a), OwnedValue::F64(b)) => a.to_u64().cmp(&b.to_u64()),
        (OwnedValue::Bool(a), OwnedValue::Bool(b)) => a.cmp(b),
        (OwnedValue::Date(a), OwnedValue::Date(b)) => a.cmp(b),
        (OwnedValue::Facet(a), OwnedValue::Facet(b)) => a.cmp(b),
        (OwnedValue::Bytes(a), OwnedValue::Bytes(b)) => a.cmp(b),
        (OwnedValue::IpAddr(a), OwnedValue::IpAddr(b)) => a.cmp(b),
        (OwnedValue::U64(a), OwnedValue::I64(b)) => {
            if *b < 0 {
                Ordering::Greater
            } else {
                a.cmp(&(*b as u64))
            }
        }
        (OwnedValue::I64(a), OwnedValue::U64(b)) => {
            if *a < 0 {
                Ordering::Less
            } else {
                (*a as u64).cmp(b)
            }
        }
        (OwnedValue::U64(a), OwnedValue::F64(b)) => (*a as f64).to_u64().cmp(&b.to_u64()),
        (OwnedValue::F64(a), OwnedValue::U64(b)) => a.to_u64().cmp(&(*b as f64).to_u64()),
        (OwnedValue::I64(a), OwnedValue::F64(b)) => (*a as f64).to_u64().cmp(&b.to_u64()),
        (OwnedValue::F64(a), OwnedValue::I64(b)) => a.to_u64().cmp(&(*b as f64).to_u64()),
        (a, b) => {
            let ord = a.discriminant_value().cmp(&b.discriminant_value());
            // If the discriminant is equal, it's because a new type was added, but hasn't been
            // included in this `match` statement.
            assert!(
                ord != Ordering::Equal,
                "Unimplemented comparison for type of {a:?}, {b:?}"
            );
            ord
        }
    }
}

/// Comparator trait defining the order in which documents should be ordered.
pub trait Comparator<T>: Send + Sync + std::fmt::Debug + Default {
    /// Return the order between two values.
    fn compare(&self, lhs: &T, rhs: &T) -> Ordering;
}

/// Compare values naturally (e.g. 1 < 2).
///
/// When used with `TopDocs`, which reverses the order, this results in a
/// "Descending" sort (Greatest values first).
///
/// `None` (or Null for `OwnedValue`) values are considered to be smaller than any other value,
/// and will therefore appear last in a descending sort (e.g. `[Some(20), Some(10), None]`).
/// With the natural comparator, the top k collector will return
/// the top documents in decreasing order.
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
pub struct NaturalComparator;

impl<T: PartialOrd> Comparator<T> for NaturalComparator {
    #[inline(always)]
    fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
        lhs.partial_cmp(rhs).unwrap_or(Ordering::Equal)
        lhs.partial_cmp(rhs).unwrap()
    }
}

/// A (partial) implementation of comparison for OwnedValue.
/// Sorts document in reverse order.
///
/// Intended for use within columns of homogenous types, and so will panic for OwnedValues with
/// mismatched types. The one exception is Null, for which we do define all comparisons.
impl Comparator<OwnedValue> for NaturalComparator {
    #[inline(always)]
    fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
        compare_owned_value::</* NULLS_FIRST= */ true>(lhs, rhs)
    }
}

/// Compare values in reverse (e.g. 2 < 1).
///
/// When used with `TopDocs`, which reverses the order, this results in an
/// "Ascending" sort (Smallest values first).
///
/// `None` is considered smaller than `Some` in the underlying comparator, but because the
/// comparison is reversed, `None` is effectively treated as the lowest value in the resulting
/// Ascending sort (e.g. `[None, Some(10), Some(20)]`).
/// If the sort key is None, it will considered as the lowest value, and will therefore appear
/// first.
///
/// The ReverseComparator does not necessarily imply that the sort order is reversed compared
/// to the NaturalComparator. In presence of a tie on the sort key, documents will always be
/// sorted by ascending `DocId`/`DocAddress` in TopN results, regardless of the sort key's order.
/// to the NaturalComparator. In presence of a tie, both version will retain the higher doc ids.
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
pub struct ReverseComparator;

@@ -123,15 +43,11 @@ where NaturalComparator: Comparator<T>
    }
}

/// Compare values in reverse, but treating `None` as lower than `Some`.
///
/// When used with `TopDocs`, which reverses the order, this results in an
/// "Ascending" sort (Smallest values first), but with `None` values appearing last
/// (e.g. `[Some(10), Some(20), None]`).
/// Sorts document in reverse order, but considers None as having the lowest value.
///
/// This is usually what is wanted when sorting by a field in an ascending order.
/// For instance, in an e-commerce website, if sorting by price ascending,
/// the cheapest items would appear first, and items without a price would appear last.
/// For instance, in a e-commerce website, if I sort by price ascending, I most likely want the
/// cheapest items first, and the items without a price at last.
#[derive(Debug, Copy, Clone, Default)]
pub struct ReverseNoneIsLowerComparator;

@@ -191,84 +107,6 @@ impl Comparator<String> for ReverseNoneIsLowerComparator {
    }
}

impl Comparator<OwnedValue> for ReverseNoneIsLowerComparator {
    #[inline(always)]
    fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
        compare_owned_value::</* NULLS_FIRST= */ false>(rhs, lhs)
    }
}

/// Compare values naturally, but treating `None` as higher than `Some`.
///
/// When used with `TopDocs`, which reverses the order, this results in a
/// "Descending" sort (Greatest values first), but with `None` values appearing first
/// (e.g. `[None, Some(20), Some(10)]`).
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
pub struct NaturalNoneIsHigherComparator;

impl<T> Comparator<Option<T>> for NaturalNoneIsHigherComparator
where NaturalComparator: Comparator<T>
{
    #[inline(always)]
    fn compare(&self, lhs_opt: &Option<T>, rhs_opt: &Option<T>) -> Ordering {
        match (lhs_opt, rhs_opt) {
            (None, None) => Ordering::Equal,
            (None, Some(_)) => Ordering::Greater,
            (Some(_), None) => Ordering::Less,
            (Some(lhs), Some(rhs)) => NaturalComparator.compare(lhs, rhs),
        }
    }
}

impl Comparator<u32> for NaturalNoneIsHigherComparator {
    #[inline(always)]
    fn compare(&self, lhs: &u32, rhs: &u32) -> Ordering {
        NaturalComparator.compare(lhs, rhs)
    }
}

impl Comparator<u64> for NaturalNoneIsHigherComparator {
    #[inline(always)]
    fn compare(&self, lhs: &u64, rhs: &u64) -> Ordering {
        NaturalComparator.compare(lhs, rhs)
    }
}

impl Comparator<f64> for NaturalNoneIsHigherComparator {
    #[inline(always)]
    fn compare(&self, lhs: &f64, rhs: &f64) -> Ordering {
        NaturalComparator.compare(lhs, rhs)
    }
}

impl Comparator<f32> for NaturalNoneIsHigherComparator {
    #[inline(always)]
    fn compare(&self, lhs: &f32, rhs: &f32) -> Ordering {
        NaturalComparator.compare(lhs, rhs)
    }
}

impl Comparator<i64> for NaturalNoneIsHigherComparator {
    #[inline(always)]
    fn compare(&self, lhs: &i64, rhs: &i64) -> Ordering {
        NaturalComparator.compare(lhs, rhs)
    }
}

impl Comparator<String> for NaturalNoneIsHigherComparator {
    #[inline(always)]
    fn compare(&self, lhs: &String, rhs: &String) -> Ordering {
        NaturalComparator.compare(lhs, rhs)
    }
}

impl Comparator<OwnedValue> for NaturalNoneIsHigherComparator {
    #[inline(always)]
    fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
        compare_owned_value::</* NULLS_FIRST= */ false>(lhs, rhs)
    }
}

/// An enum representing the different sort orders.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Default)]
pub enum ComparatorEnum {
@@ -277,10 +115,8 @@ pub enum ComparatorEnum {
    Natural,
    /// Reverse order (See [ReverseComparator])
    Reverse,
    /// Reverse order by treating None as the lowest value. (See [ReverseNoneLowerComparator])
    /// Reverse order by treating None as the lowest value.(See [ReverseNoneLowerComparator])
    ReverseNoneLower,
    /// Natural order but treating None as the highest value. (See [NaturalNoneIsHigherComparator])
    NaturalNoneHigher,
}

impl From<Order> for ComparatorEnum {
@@ -297,7 +133,6 @@ where
    ReverseNoneIsLowerComparator: Comparator<T>,
    NaturalComparator: Comparator<T>,
    ReverseComparator: Comparator<T>,
    NaturalNoneIsHigherComparator: Comparator<T>,
{
    #[inline(always)]
    fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
@@ -305,7 +140,6 @@ where
            ComparatorEnum::Natural => NaturalComparator.compare(lhs, rhs),
            ComparatorEnum::Reverse => ReverseComparator.compare(lhs, rhs),
            ComparatorEnum::ReverseNoneLower => ReverseNoneIsLowerComparator.compare(lhs, rhs),
            ComparatorEnum::NaturalNoneHigher => NaturalNoneIsHigherComparator.compare(lhs, rhs),
        }
    }
}
@@ -488,12 +322,11 @@ impl<TSegmentSortKeyComputer, TSegmentSortKey, TComparator> SegmentSortKeyComputer
    for SegmentSortKeyComputerWithComparator<TSegmentSortKeyComputer, TComparator>
where
    TSegmentSortKeyComputer: SegmentSortKeyComputer<SegmentSortKey = TSegmentSortKey>,
    TSegmentSortKey: Clone + 'static + Sync + Send,
    TSegmentSortKey: PartialOrd + Clone + 'static + Sync + Send,
    TComparator: Comparator<TSegmentSortKey> + 'static + Sync + Send,
{
    type SortKey = TSegmentSortKeyComputer::SortKey;
    type SegmentSortKey = TSegmentSortKey;
    type SegmentComparator = TComparator;

    fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
        self.segment_sort_key_computer.segment_sort_key(doc, score)
@@ -513,55 +346,3 @@ where
            .convert_segment_sort_key(sort_key)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::OwnedValue;

    #[test]
    fn test_natural_none_is_higher() {
        let comp = NaturalNoneIsHigherComparator;
        let null = None;
        let v1 = Some(1_u64);
        let v2 = Some(2_u64);

        // NaturalNoneIsGreaterComparator logic:
        // 1. Delegates to NaturalComparator for non-nulls.
        //    NaturalComparator compare(2, 1) -> 2.cmp(1) -> Greater.
        assert_eq!(comp.compare(&v2, &v1), Ordering::Greater);

        // 2. Treats None (Null) as Greater than any value.
        //    compare(None, Some(2)) should be Greater.
        assert_eq!(comp.compare(&null, &v2), Ordering::Greater);

        // compare(Some(1), None) should be Less.
        assert_eq!(comp.compare(&v1, &null), Ordering::Less);

        // compare(None, None) should be Equal.
        assert_eq!(comp.compare(&null, &null), Ordering::Equal);
    }

    #[test]
    fn test_mixed_ownedvalue_compare() {
        let u = OwnedValue::U64(10);
        let i = OwnedValue::I64(10);
        let f = OwnedValue::F64(10.0);

        let nc = NaturalComparator;
        assert_eq!(nc.compare(&u, &i), Ordering::Equal);
        assert_eq!(nc.compare(&u, &f), Ordering::Equal);
        assert_eq!(nc.compare(&i, &f), Ordering::Equal);

        let u2 = OwnedValue::U64(11);
        assert_eq!(nc.compare(&u2, &f), Ordering::Greater);

        let s = OwnedValue::Str("a".to_string());
        // Str < U64
        assert_eq!(nc.compare(&s, &u), Ordering::Less);
        // Str < I64
        assert_eq!(nc.compare(&s, &i), Ordering::Less);
        // Str < F64
        assert_eq!(nc.compare(&s, &f), Ordering::Less);
    }
}

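The comparator doc comments above describe a `None`-last ascending order (the e-commerce "items without a price sort last" case). A minimal standalone sketch of that ordering in plain std Rust, without the tantivy `Comparator` trait (the price data here is illustrative):

```rust
fn main() {
    // Ascending sort where missing values (None) land last, mirroring the
    // behavior documented for the None-is-lower reverse comparator above.
    let mut prices: Vec<Option<u64>> = vec![Some(20), None, Some(10)];
    prices.sort_by(|l, r| match (l, r) {
        (None, None) => std::cmp::Ordering::Equal,
        (None, Some(_)) => std::cmp::Ordering::Greater, // None after any price
        (Some(_), None) => std::cmp::Ordering::Less,
        (Some(a), Some(b)) => a.cmp(b),
    });
    assert_eq!(prices, vec![Some(10), Some(20), None]);
    println!("{prices:?}");
}
```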
@@ -1,168 +0,0 @@
use columnar::BytesColumn;

use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::termdict::TermOrdinal;
use crate::{DocId, Score};

/// Sort by the first value of a bytes column.
///
/// If the field is multivalued, only the first value is considered.
///
/// Documents that do not have this value are still considered.
/// Their sort key will simply be `None`.
#[derive(Debug, Clone)]
pub struct SortByBytes {
    column_name: String,
}

impl SortByBytes {
    /// Creates a new sort by bytes sort key computer.
    pub fn for_field(column_name: impl ToString) -> Self {
        SortByBytes {
            column_name: column_name.to_string(),
        }
    }
}

impl SortKeyComputer for SortByBytes {
    type SortKey = Option<Vec<u8>>;
    type Child = ByBytesColumnSegmentSortKeyComputer;
    type Comparator = NaturalComparator;

    fn segment_sort_key_computer(
        &self,
        segment_reader: &crate::SegmentReader,
    ) -> crate::Result<Self::Child> {
        let bytes_column_opt = segment_reader.fast_fields().bytes(&self.column_name)?;
        Ok(ByBytesColumnSegmentSortKeyComputer { bytes_column_opt })
    }
}

/// Segment-level sort key computer for bytes columns.
pub struct ByBytesColumnSegmentSortKeyComputer {
    bytes_column_opt: Option<BytesColumn>,
}

impl SegmentSortKeyComputer for ByBytesColumnSegmentSortKeyComputer {
    type SortKey = Option<Vec<u8>>;
    type SegmentSortKey = Option<TermOrdinal>;
    type SegmentComparator = NaturalComparator;

    #[inline(always)]
    fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Option<TermOrdinal> {
        let bytes_column = self.bytes_column_opt.as_ref()?;
        bytes_column.ords().first(doc)
    }

    fn convert_segment_sort_key(&self, term_ord_opt: Option<TermOrdinal>) -> Option<Vec<u8>> {
        // TODO: Individual lookups to the dictionary like this are very likely to repeatedly
        // decompress the same blocks. See https://github.com/quickwit-oss/tantivy/issues/2776
        let term_ord = term_ord_opt?;
        let bytes_column = self.bytes_column_opt.as_ref()?;
        let mut bytes = Vec::new();
        bytes_column
            .dictionary()
            .ord_to_term(term_ord, &mut bytes)
            .ok()?;
        Some(bytes)
    }
}

#[cfg(test)]
mod tests {
    use super::SortByBytes;
    use crate::collector::TopDocs;
    use crate::query::AllQuery;
    use crate::schema::{BytesOptions, Schema, FAST, INDEXED};
    use crate::{Index, IndexWriter, Order, TantivyDocument};

    #[test]
    fn test_sort_by_bytes_asc() -> crate::Result<()> {
        let mut schema_builder = Schema::builder();
        let bytes_field = schema_builder
            .add_bytes_field("data", BytesOptions::default().set_fast().set_indexed());
        let id_field = schema_builder.add_u64_field("id", FAST | INDEXED);
        let schema = schema_builder.build();
        let index = Index::create_in_ram(schema);
        let mut index_writer: IndexWriter = index.writer_for_tests()?;

        // Insert documents with byte values in non-sorted order
        let test_data: Vec<(u64, Vec<u8>)> = vec![
            (1, vec![0x02, 0x00]),
            (2, vec![0x00, 0x10]),
            (3, vec![0x01, 0x00]),
            (4, vec![0x00, 0x20]),
        ];

        for (id, bytes) in &test_data {
            let mut doc = TantivyDocument::new();
            doc.add_u64(id_field, *id);
            doc.add_bytes(bytes_field, bytes);
            index_writer.add_document(doc)?;
        }
        index_writer.commit()?;

        let reader = index.reader()?;
        let searcher = reader.searcher();

        // Sort ascending by bytes
        let top_docs =
            TopDocs::with_limit(10).order_by((SortByBytes::for_field("data"), Order::Asc));
        let results: Vec<(Option<Vec<u8>>, _)> = searcher.search(&AllQuery, &top_docs)?;

        // Expected order: [0x00,0x10], [0x00,0x20], [0x01,0x00], [0x02,0x00]
        let sorted_bytes: Vec<Option<Vec<u8>>> = results.into_iter().map(|(b, _)| b).collect();
        assert_eq!(
            sorted_bytes,
            vec![
                Some(vec![0x00, 0x10]),
                Some(vec![0x00, 0x20]),
                Some(vec![0x01, 0x00]),
                Some(vec![0x02, 0x00]),
            ]
        );

        Ok(())
    }

    #[test]
    fn test_sort_by_bytes_desc() -> crate::Result<()> {
        let mut schema_builder = Schema::builder();
        let bytes_field = schema_builder
            .add_bytes_field("data", BytesOptions::default().set_fast().set_indexed());
        let schema = schema_builder.build();
        let index = Index::create_in_ram(schema);
        let mut index_writer: IndexWriter = index.writer_for_tests()?;

        let test_data: Vec<Vec<u8>> = vec![vec![0x00, 0x10], vec![0x02, 0x00], vec![0x01, 0x00]];

        for bytes in &test_data {
            let mut doc = TantivyDocument::new();
            doc.add_bytes(bytes_field, bytes);
            index_writer.add_document(doc)?;
        }
        index_writer.commit()?;

        let reader = index.reader()?;
        let searcher = reader.searcher();

        // Sort descending by bytes
        let top_docs =
            TopDocs::with_limit(10).order_by((SortByBytes::for_field("data"), Order::Desc));
        let results: Vec<(Option<Vec<u8>>, _)> = searcher.search(&AllQuery, &top_docs)?;

        // Expected order (descending): [0x02,0x00], [0x01,0x00], [0x00,0x10]
        let sorted_bytes: Vec<Option<Vec<u8>>> = results.into_iter().map(|(b, _)| b).collect();
        assert_eq!(
            sorted_bytes,
            vec![
                Some(vec![0x02, 0x00]),
                Some(vec![0x01, 0x00]),
                Some(vec![0x00, 0x10]),
            ]
        );

        Ok(())
    }
}
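The ascending order asserted in the `SortByBytes` tests above is plain lexicographic byte-slice order, which `Vec<u8>` already implements via `Ord` in std. A minimal sketch with the same byte values:

```rust
fn main() {
    // Vec<u8> compares lexicographically in std, which matches the order the
    // SortByBytes ascending test above asserts.
    let mut data: Vec<Vec<u8>> = vec![
        vec![0x02, 0x00],
        vec![0x00, 0x10],
        vec![0x01, 0x00],
        vec![0x00, 0x20],
    ];
    data.sort();
    assert_eq!(
        data,
        vec![
            vec![0x00, 0x10],
            vec![0x00, 0x20],
            vec![0x01, 0x00],
            vec![0x02, 0x00],
        ]
    );
}
```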
@@ -1,430 +0,0 @@
use columnar::{ColumnType, MonotonicallyMappableToU64};

use crate::collector::sort_key::{
    NaturalComparator, SortByBytes, SortBySimilarityScore, SortByStaticFastValue, SortByString,
};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::fastfield::FastFieldNotAvailableError;
use crate::schema::OwnedValue;
use crate::{DateTime, DocId, Score};

/// Sort by the boxed / OwnedValue representation of either a fast field, or of the score.
///
/// Using the OwnedValue representation allows for type erasure, and can be useful when sort orders
/// are not known until runtime. But it comes with a performance cost: wherever possible, prefer to
/// use a SortKeyComputer implementation with a known-type at compile time.
#[derive(Debug, Clone)]
pub enum SortByErasedType {
    /// Sort by a fast field
    Field(String),
    /// Sort by score
    Score,
}

impl SortByErasedType {
    /// Creates a new sort key computer which will sort by the given fast field column, with type
    /// erasure.
    pub fn for_field(column_name: impl ToString) -> Self {
        Self::Field(column_name.to_string())
    }

    /// Creates a new sort key computer which will sort by score, with type erasure.
    pub fn for_score() -> Self {
        Self::Score
    }
}

trait ErasedSegmentSortKeyComputer: Send + Sync {
    fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64>;
    fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue;
}

struct ErasedSegmentSortKeyComputerWrapper<C, F> {
    inner: C,
    converter: F,
}

impl<C, F> ErasedSegmentSortKeyComputer for ErasedSegmentSortKeyComputerWrapper<C, F>
where
    C: SegmentSortKeyComputer<SegmentSortKey = Option<u64>> + Send + Sync,
    F: Fn(C::SortKey) -> OwnedValue + Send + Sync + 'static,
{
    fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64> {
        self.inner.segment_sort_key(doc, score)
    }

    fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue {
        let val = self.inner.convert_segment_sort_key(sort_key);
        (self.converter)(val)
    }
}

struct ScoreSegmentSortKeyComputer {
    segment_computer: SortBySimilarityScore,
}

impl ErasedSegmentSortKeyComputer for ScoreSegmentSortKeyComputer {
    fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64> {
        let score_value: f64 = self.segment_computer.segment_sort_key(doc, score).into();
        Some(score_value.to_u64())
    }

    fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue {
        let score_value: u64 = sort_key.expect("This implementation always produces a score.");
        OwnedValue::F64(f64::from_u64(score_value))
    }
}

impl SortKeyComputer for SortByErasedType {
    type SortKey = OwnedValue;
    type Child = ErasedColumnSegmentSortKeyComputer;
    type Comparator = NaturalComparator;

    fn requires_scoring(&self) -> bool {
        matches!(self, Self::Score)
    }

    fn segment_sort_key_computer(
        &self,
        segment_reader: &crate::SegmentReader,
    ) -> crate::Result<Self::Child> {
        let inner: Box<dyn ErasedSegmentSortKeyComputer> = match self {
            Self::Field(column_name) => {
                let fast_fields = segment_reader.fast_fields();
                // TODO: We currently double-open the column to avoid relying on the implementation
                // details of `SortByString` or `SortByStaticFastValue`. Once
                // https://github.com/quickwit-oss/tantivy/issues/2776 is resolved, we should
                // consider directly constructing the appropriate `SegmentSortKeyComputer` type for
                // the column that we open here.
                let (_column, column_type) =
                    fast_fields.u64_lenient(column_name)?.ok_or_else(|| {
                        FastFieldNotAvailableError {
                            field_name: column_name.to_owned(),
                        }
                    })?;

                match column_type {
                    ColumnType::Str => {
                        let computer = SortByString::for_field(column_name);
                        let inner = computer.segment_sort_key_computer(segment_reader)?;
                        Box::new(ErasedSegmentSortKeyComputerWrapper {
                            inner,
                            converter: |val: Option<String>| {
                                val.map(OwnedValue::Str).unwrap_or(OwnedValue::Null)
                            },
                        })
                    }
                    ColumnType::Bytes => {
                        let computer = SortByBytes::for_field(column_name);
                        let inner = computer.segment_sort_key_computer(segment_reader)?;
                        Box::new(ErasedSegmentSortKeyComputerWrapper {
                            inner,
                            converter: |val: Option<Vec<u8>>| {
                                val.map(OwnedValue::Bytes).unwrap_or(OwnedValue::Null)
                            },
                        })
                    }
                    ColumnType::U64 => {
                        let computer = SortByStaticFastValue::<u64>::for_field(column_name);
                        let inner = computer.segment_sort_key_computer(segment_reader)?;
                        Box::new(ErasedSegmentSortKeyComputerWrapper {
                            inner,
                            converter: |val: Option<u64>| {
                                val.map(OwnedValue::U64).unwrap_or(OwnedValue::Null)
                            },
                        })
                    }
                    ColumnType::I64 => {
                        let computer = SortByStaticFastValue::<i64>::for_field(column_name);
                        let inner = computer.segment_sort_key_computer(segment_reader)?;
                        Box::new(ErasedSegmentSortKeyComputerWrapper {
                            inner,
                            converter: |val: Option<i64>| {
                                val.map(OwnedValue::I64).unwrap_or(OwnedValue::Null)
                            },
                        })
                    }
                    ColumnType::F64 => {
                        let computer = SortByStaticFastValue::<f64>::for_field(column_name);
                        let inner = computer.segment_sort_key_computer(segment_reader)?;
                        Box::new(ErasedSegmentSortKeyComputerWrapper {
                            inner,
                            converter: |val: Option<f64>| {
                                val.map(OwnedValue::F64).unwrap_or(OwnedValue::Null)
                            },
                        })
                    }
                    ColumnType::Bool => {
                        let computer = SortByStaticFastValue::<bool>::for_field(column_name);
                        let inner = computer.segment_sort_key_computer(segment_reader)?;
                        Box::new(ErasedSegmentSortKeyComputerWrapper {
                            inner,
                            converter: |val: Option<bool>| {
                                val.map(OwnedValue::Bool).unwrap_or(OwnedValue::Null)
                            },
                        })
                    }
                    ColumnType::DateTime => {
                        let computer = SortByStaticFastValue::<DateTime>::for_field(column_name);
                        let inner = computer.segment_sort_key_computer(segment_reader)?;
                        Box::new(ErasedSegmentSortKeyComputerWrapper {
                            inner,
                            converter: |val: Option<DateTime>| {
                                val.map(OwnedValue::Date).unwrap_or(OwnedValue::Null)
                            },
                        })
                    }
                    column_type => {
                        return Err(crate::TantivyError::SchemaError(format!(
                            "Field `{}` is of type {column_type:?}, which is not supported for \
                             sorting by owned value yet.",
                            column_name
                        )))
                    }
                }
            }
            Self::Score => Box::new(ScoreSegmentSortKeyComputer {
                segment_computer: SortBySimilarityScore,
|
||||
}),
|
||||
};
|
||||
Ok(ErasedColumnSegmentSortKeyComputer { inner })
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ErasedColumnSegmentSortKeyComputer {
|
||||
inner: Box<dyn ErasedSegmentSortKeyComputer>,
|
||||
}
|
||||
|
||||
impl SegmentSortKeyComputer for ErasedColumnSegmentSortKeyComputer {
|
||||
type SortKey = OwnedValue;
|
||||
type SegmentSortKey = Option<u64>;
|
||||
type SegmentComparator = NaturalComparator;
|
||||
|
||||
#[inline(always)]
|
||||
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64> {
|
||||
self.inner.segment_sort_key(doc, score)
|
||||
}
|
||||
|
||||
fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> OwnedValue {
|
||||
self.inner.convert_segment_sort_key(segment_sort_key)
|
||||
}
|
||||
}
|
||||
|
||||
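The match above erases each concretely-typed computer behind a `Box<dyn ErasedSegmentSortKeyComputer>`, pairing it with a converter closure that maps the typed key into `OwnedValue`. The same pattern can be sketched in miniature; every name below (`Value`, `Wrapper`, `ByU64`) is an illustrative stand-in, not tantivy's actual API:

```rust
// A common "owned value" that all typed sort keys are converted into,
// playing the role that `OwnedValue` plays above (illustrative only).
#[derive(Debug, PartialEq)]
enum Value {
    U64(u64),
    Str(String),
    Null,
}

// Object-safe erased interface: every computer ultimately yields a `Value`.
trait ErasedComputer {
    fn sort_key(&mut self, doc: u32) -> Value;
}

// A typed computer produces its native key type.
trait TypedComputer {
    type Key;
    fn key(&mut self, doc: u32) -> Option<Self::Key>;
}

// Wrapper erasing `C` behind `ErasedComputer` via a converter function,
// mirroring `ErasedSegmentSortKeyComputerWrapper` and its `converter` field.
struct Wrapper<C: TypedComputer> {
    inner: C,
    converter: fn(Option<C::Key>) -> Value,
}

impl<C: TypedComputer> ErasedComputer for Wrapper<C> {
    fn sort_key(&mut self, doc: u32) -> Value {
        (self.converter)(self.inner.key(doc))
    }
}

// A toy u64 "column": one optional value per document.
struct ByU64(Vec<Option<u64>>);

impl TypedComputer for ByU64 {
    type Key = u64;
    fn key(&mut self, doc: u32) -> Option<u64> {
        self.0[doc as usize]
    }
}

fn main() {
    // Dispatch on a runtime "column type", as the match above does.
    let mut erased: Box<dyn ErasedComputer> = Box::new(Wrapper {
        inner: ByU64(vec![Some(10), None]),
        converter: |v| v.map(Value::U64).unwrap_or(Value::Null),
    });
    assert_eq!(erased.sort_key(0), Value::U64(10));
    assert_eq!(erased.sort_key(1), Value::Null);
}
```

A missing value becomes `Value::Null` only at conversion time, so the hot per-document path stays on the typed key.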
#[cfg(test)]
mod tests {
    use crate::collector::sort_key::{ComparatorEnum, SortByErasedType};
    use crate::collector::TopDocs;
    use crate::query::AllQuery;
    use crate::schema::{OwnedValue, Schema, FAST, TEXT};
    use crate::Index;

    #[test]
    fn test_sort_by_owned_u64() {
        let mut schema_builder = Schema::builder();
        let id_field = schema_builder.add_u64_field("id", FAST);
        let schema = schema_builder.build();
        let index = Index::create_in_ram(schema);
        let mut writer = index.writer_for_tests().unwrap();
        writer.add_document(doc!(id_field => 10u64)).unwrap();
        writer.add_document(doc!(id_field => 2u64)).unwrap();
        writer.add_document(doc!()).unwrap();
        writer.commit().unwrap();

        let reader = index.reader().unwrap();
        let searcher = reader.searcher();

        let collector = TopDocs::with_limit(10)
            .order_by((SortByErasedType::for_field("id"), ComparatorEnum::Natural));
        let top_docs = searcher.search(&AllQuery, &collector).unwrap();

        let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();

        assert_eq!(
            values,
            vec![OwnedValue::U64(10), OwnedValue::U64(2), OwnedValue::Null]
        );

        let collector = TopDocs::with_limit(10).order_by((
            SortByErasedType::for_field("id"),
            ComparatorEnum::ReverseNoneLower,
        ));
        let top_docs = searcher.search(&AllQuery, &collector).unwrap();

        let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();

        assert_eq!(
            values,
            vec![OwnedValue::U64(2), OwnedValue::U64(10), OwnedValue::Null]
        );
    }

    #[test]
    fn test_sort_by_owned_string() {
        let mut schema_builder = Schema::builder();
        let city_field = schema_builder.add_text_field("city", FAST | TEXT);
        let schema = schema_builder.build();
        let index = Index::create_in_ram(schema);
        let mut writer = index.writer_for_tests().unwrap();
        writer.add_document(doc!(city_field => "tokyo")).unwrap();
        writer.add_document(doc!(city_field => "austin")).unwrap();
        writer.add_document(doc!()).unwrap();
        writer.commit().unwrap();

        let reader = index.reader().unwrap();
        let searcher = reader.searcher();

        let collector = TopDocs::with_limit(10).order_by((
            SortByErasedType::for_field("city"),
            ComparatorEnum::ReverseNoneLower,
        ));
        let top_docs = searcher.search(&AllQuery, &collector).unwrap();

        let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();

        assert_eq!(
            values,
            vec![
                OwnedValue::Str("austin".to_string()),
                OwnedValue::Str("tokyo".to_string()),
                OwnedValue::Null
            ]
        );
    }

    #[test]
    fn test_sort_by_owned_bytes() {
        let mut schema_builder = Schema::builder();
        let data_field = schema_builder.add_bytes_field("data", FAST);
        let schema = schema_builder.build();
        let index = Index::create_in_ram(schema);
        let mut writer = index.writer_for_tests().unwrap();
        writer
            .add_document(doc!(data_field => vec![0x03u8, 0x00]))
            .unwrap();
        writer
            .add_document(doc!(data_field => vec![0x01u8, 0x00]))
            .unwrap();
        writer
            .add_document(doc!(data_field => vec![0x02u8, 0x00]))
            .unwrap();
        writer.add_document(doc!()).unwrap();
        writer.commit().unwrap();

        let reader = index.reader().unwrap();
        let searcher = reader.searcher();

        // Sort descending (Natural - highest first)
        let collector = TopDocs::with_limit(10)
            .order_by((SortByErasedType::for_field("data"), ComparatorEnum::Natural));
        let top_docs = searcher.search(&AllQuery, &collector).unwrap();

        let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();

        assert_eq!(
            values,
            vec![
                OwnedValue::Bytes(vec![0x03, 0x00]),
                OwnedValue::Bytes(vec![0x02, 0x00]),
                OwnedValue::Bytes(vec![0x01, 0x00]),
                OwnedValue::Null
            ]
        );

        // Sort ascending (ReverseNoneLower - lowest first, nulls last)
        let collector = TopDocs::with_limit(10).order_by((
            SortByErasedType::for_field("data"),
            ComparatorEnum::ReverseNoneLower,
        ));
        let top_docs = searcher.search(&AllQuery, &collector).unwrap();

        let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();

        assert_eq!(
            values,
            vec![
                OwnedValue::Bytes(vec![0x01, 0x00]),
                OwnedValue::Bytes(vec![0x02, 0x00]),
                OwnedValue::Bytes(vec![0x03, 0x00]),
                OwnedValue::Null
            ]
        );
    }

    #[test]
    fn test_sort_by_owned_reverse() {
        let mut schema_builder = Schema::builder();
        let id_field = schema_builder.add_u64_field("id", FAST);
        let schema = schema_builder.build();
        let index = Index::create_in_ram(schema);
        let mut writer = index.writer_for_tests().unwrap();
        writer.add_document(doc!(id_field => 10u64)).unwrap();
        writer.add_document(doc!(id_field => 2u64)).unwrap();
        writer.add_document(doc!()).unwrap();
        writer.commit().unwrap();

        let reader = index.reader().unwrap();
        let searcher = reader.searcher();

        let collector = TopDocs::with_limit(10)
            .order_by((SortByErasedType::for_field("id"), ComparatorEnum::Reverse));
        let top_docs = searcher.search(&AllQuery, &collector).unwrap();

        let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();

        assert_eq!(
            values,
            vec![OwnedValue::Null, OwnedValue::U64(2), OwnedValue::U64(10)]
        );
    }

    #[test]
    fn test_sort_by_owned_score() {
        let mut schema_builder = Schema::builder();
        let body_field = schema_builder.add_text_field("body", TEXT);
        let schema = schema_builder.build();
        let index = Index::create_in_ram(schema);
        let mut writer = index.writer_for_tests().unwrap();
        writer.add_document(doc!(body_field => "a a")).unwrap();
        writer.add_document(doc!(body_field => "a")).unwrap();
        writer.commit().unwrap();

        let reader = index.reader().unwrap();
        let searcher = reader.searcher();
        let query_parser = crate::query::QueryParser::for_index(&index, vec![body_field]);
        let query = query_parser.parse_query("a").unwrap();

        // Sort by score descending (Natural)
        let collector = TopDocs::with_limit(10)
            .order_by((SortByErasedType::for_score(), ComparatorEnum::Natural));
        let top_docs = searcher.search(&query, &collector).unwrap();

        let values: Vec<f64> = top_docs
            .into_iter()
            .map(|(key, _)| match key {
                OwnedValue::F64(val) => val,
                _ => panic!("Wrong type {key:?}"),
            })
            .collect();

        assert_eq!(values.len(), 2);
        assert!(values[0] > values[1]);

        // Sort by score ascending (ReverseNoneLower)
        let collector = TopDocs::with_limit(10).order_by((
            SortByErasedType::for_score(),
            ComparatorEnum::ReverseNoneLower,
        ));
        let top_docs = searcher.search(&query, &collector).unwrap();

        let values: Vec<f64> = top_docs
            .into_iter()
            .map(|(key, _)| match key {
                OwnedValue::F64(val) => val,
                _ => panic!("Wrong type {key:?}"),
            })
            .collect();

        assert_eq!(values.len(), 2);
        assert!(values[0] < values[1]);
    }
}
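The comparator semantics these tests rely on (with `Natural`, the largest key wins and a missing value sorts below every present one; with `ReverseNoneLower`, the smallest key wins but missing values still land last) can be sketched over plain `Option<u64>`. The helpers below are illustrative, not the library's `ComparatorEnum` implementation:

```rust
use std::cmp::Ordering;

// Natural: compare `Option`s directly; Rust orders `None` below `Some(_)`,
// so with largest-first output `None` ends up last.
fn natural(a: &Option<u64>, b: &Option<u64>) -> Ordering {
    a.cmp(b)
}

// ReverseNoneLower (illustrative): present values compare in reversed order
// (so the smallest comes out first), while `None` is still treated as the
// lowest key and therefore stays at the end of the result.
fn reverse_none_lower(a: &Option<u64>, b: &Option<u64>) -> Ordering {
    match (a, b) {
        (None, None) => Ordering::Equal,
        (None, Some(_)) => Ordering::Less,
        (Some(_), None) => Ordering::Greater,
        (Some(x), Some(y)) => y.cmp(x), // reversed for ascending output
    }
}

// "Best" keys first: sort descending under the comparator, the order in
// which a top-N collector would emit its hits.
fn top(
    mut vals: Vec<Option<u64>>,
    cmp: fn(&Option<u64>, &Option<u64>) -> Ordering,
) -> Vec<Option<u64>> {
    vals.sort_by(|a, b| cmp(b, a));
    vals
}

fn main() {
    let docs = vec![Some(10), Some(2), None];
    assert_eq!(top(docs.clone(), natural), vec![Some(10), Some(2), None]);
    assert_eq!(top(docs, reverse_none_lower), vec![Some(2), Some(10), None]);
}
```

These two assertions reproduce the orderings checked by `test_sort_by_owned_u64` above.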
@@ -63,8 +63,8 @@ impl SortKeyComputer for SortBySimilarityScore {

impl SegmentSortKeyComputer for SortBySimilarityScore {
    type SortKey = Score;

    type SegmentSortKey = Score;
    type SegmentComparator = NaturalComparator;

    #[inline(always)]
    fn segment_sort_key(&mut self, _doc: DocId, score: Score) -> Score {

@@ -34,7 +34,9 @@ impl<T: FastValue> SortByStaticFastValue<T> {

impl<T: FastValue> SortKeyComputer for SortByStaticFastValue<T> {
    type Child = SortByFastValueSegmentSortKeyComputer<T>;

    type SortKey = Option<T>;

    type Comparator = NaturalComparator;

    fn check_schema(&self, schema: &crate::schema::Schema) -> crate::Result<()> {

@@ -82,8 +84,8 @@ pub struct SortByFastValueSegmentSortKeyComputer<T> {

impl<T: FastValue> SegmentSortKeyComputer for SortByFastValueSegmentSortKeyComputer<T> {
    type SortKey = Option<T>;

    type SegmentSortKey = Option<u64>;
    type SegmentComparator = NaturalComparator;

    #[inline(always)]
    fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Self::SegmentSortKey {

@@ -30,7 +30,9 @@ impl SortByString {

impl SortKeyComputer for SortByString {
    type SortKey = Option<String>;

    type Child = ByStringColumnSegmentSortKeyComputer;

    type Comparator = NaturalComparator;

    fn segment_sort_key_computer(

@@ -48,8 +50,8 @@ pub struct ByStringColumnSegmentSortKeyComputer {

impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
    type SortKey = Option<String>;

    type SegmentSortKey = Option<TermOrdinal>;
    type SegmentComparator = NaturalComparator;

    #[inline(always)]
    fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Option<TermOrdinal> {

@@ -58,8 +60,6 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
    }

    fn convert_segment_sort_key(&self, term_ord_opt: Option<TermOrdinal>) -> Option<String> {
        // TODO: Individual lookups to the dictionary like this are very likely to repeatedly
        // decompress the same blocks. See https://github.com/quickwit-oss/tantivy/issues/2776
        let term_ord = term_ord_opt?;
        let str_column = self.str_column_opt.as_ref()?;
        let mut bytes = Vec::new();

@@ -12,21 +12,13 @@ use crate::{DocAddress, DocId, Result, Score, SegmentReader};
/// It is the segment local version of the [`SortKeyComputer`].
pub trait SegmentSortKeyComputer: 'static {
    /// The final score being emitted.
    type SortKey: 'static + Send + Sync + Clone;
    type SortKey: 'static + PartialOrd + Send + Sync + Clone;

    /// Sort key used at the segment level by the `SegmentSortKeyComputer`.
    ///
    /// It is typically small like a `u64`, and is meant to be converted
    /// to the final score at the end of the collection of the segment.
    type SegmentSortKey: 'static + Clone + Send + Sync + Clone;

    /// Comparator type.
    type SegmentComparator: Comparator<Self::SegmentSortKey> + 'static;

    /// Returns the segment sort key comparator.
    fn segment_comparator(&self) -> Self::SegmentComparator {
        Self::SegmentComparator::default()
    }
    type SegmentSortKey: 'static + PartialOrd + Clone + Send + Sync + Clone;

    /// Computes the sort key for the given document and score.
    fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey;

@@ -55,7 +47,7 @@ pub trait SegmentSortKeyComputer: 'static {
        left: &Self::SegmentSortKey,
        right: &Self::SegmentSortKey,
    ) -> Ordering {
        self.segment_comparator().compare(left, right)
        NaturalComparator.compare(left, right)
    }

    /// Implementing this method makes it possible to avoid computing

@@ -89,7 +81,7 @@ pub trait SegmentSortKeyComputer: 'static {
/// the sort key at a segment scale.
pub trait SortKeyComputer: Sync {
    /// The sort key type.
    type SortKey: 'static + Send + Sync + Clone + std::fmt::Debug;
    type SortKey: 'static + Send + Sync + PartialOrd + Clone + std::fmt::Debug;
    /// Type of the associated [`SegmentSortKeyComputer`].
    type Child: SegmentSortKeyComputer<SortKey = Self::SortKey>;
    /// Comparator type.

@@ -144,7 +136,10 @@ where
    HeadSortKeyComputer: SortKeyComputer,
    TailSortKeyComputer: SortKeyComputer,
{
    type SortKey = (HeadSortKeyComputer::SortKey, TailSortKeyComputer::SortKey);
    type SortKey = (
        <HeadSortKeyComputer::Child as SegmentSortKeyComputer>::SortKey,
        <TailSortKeyComputer::Child as SegmentSortKeyComputer>::SortKey,
    );
    type Child = (HeadSortKeyComputer::Child, TailSortKeyComputer::Child);

    type Comparator = (

@@ -193,11 +188,6 @@ where
        TailSegmentSortKeyComputer::SegmentSortKey,
    );

    type SegmentComparator = (
        HeadSegmentSortKeyComputer::SegmentComparator,
        TailSegmentSortKeyComputer::SegmentComparator,
    );

/// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on
/// its ordering.
///

@@ -279,12 +269,11 @@ impl<T, PreviousScore, NewScore> SegmentSortKeyComputer
    for MappedSegmentSortKeyComputer<T, PreviousScore, NewScore>
where
    T: SegmentSortKeyComputer<SortKey = PreviousScore>,
    PreviousScore: 'static + Clone + Send + Sync,
    NewScore: 'static + Clone + Send + Sync,
    PreviousScore: 'static + Clone + Send + Sync + PartialOrd,
    NewScore: 'static + Clone + Send + Sync + PartialOrd,
{
    type SortKey = NewScore;
    type SegmentSortKey = T::SegmentSortKey;
    type SegmentComparator = T::SegmentComparator;

    fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
        self.sort_key_computer.segment_sort_key(doc, score)

@@ -474,7 +463,6 @@ where
{
    type SortKey = TSortKey;
    type SegmentSortKey = TSortKey;
    type SegmentComparator = NaturalComparator;

    fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> TSortKey {
        (self)(doc)

@@ -160,7 +160,7 @@ mod tests {
        expected: &[(crate::Score, usize)],
    ) {
        let mut vals: Vec<(crate::Score, usize)> = (0..10).map(|val| (val as f32, val)).collect();
        vals.shuffle(&mut rand::rng());
        vals.shuffle(&mut rand::thread_rng());
        let vals_merged = merge_top_k(vals.into_iter(), doc_range, ComparatorEnum::from(order));
        assert_eq!(&vals_merged, expected);
    }
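The fast-value hunks above keep `SegmentSortKey = Option<u64>`: values are compared through an order-preserving image in `u64` during segment collection and only converted back to the typed key at the end. The classic order-preserving encodings for `i64` and `f64` look like the sketch below; this is the standard technique for sortable columnar data, and tantivy's exact bit layout may differ in detail:

```rust
// Order-preserving injections into u64: if a < b then enc(a) < enc(b).

fn encode_i64(val: i64) -> u64 {
    // Flip the sign bit so negative values sort below positive ones.
    (val as u64) ^ (1 << 63)
}

fn encode_f64(val: f64) -> u64 {
    // IEEE 754 trick: positive floats keep their bit order once the sign
    // bit is flipped; negative floats must have all of their bits inverted
    // (their magnitude bits sort in the wrong direction).
    let bits = val.to_bits();
    if bits & (1 << 63) == 0 {
        bits ^ (1 << 63)
    } else {
        !bits
    }
}

fn main() {
    let ints = [-5i64, -1, 0, 3];
    let enc: Vec<u64> = ints.iter().map(|&v| encode_i64(v)).collect();
    assert!(enc.windows(2).all(|w| w[0] < w[1]));

    let floats = [-2.5f64, -0.0, 1.0, 7.25];
    let encf: Vec<u64> = floats.iter().map(|&v| encode_f64(v)).collect();
    assert!(encf.windows(2).all(|w| w[0] < w[1]));
}
```

Because the encoded `u64`s sort with plain integer comparison, a single comparator can serve every fast-value column type.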
@@ -1,22 +1,64 @@

use std::cmp::Ordering;

use serde::{Deserialize, Serialize};

/// Contains a feature (field, score, etc.) of a document along with the document address.
///
/// Used only by TopNComputer, which implements the actual comparison via a `Comparator`.
#[derive(Clone, Default, Eq, PartialEq, Serialize, Deserialize)]
pub struct ComparableDoc<T, D> {
/// It guarantees stable sorting: in case of a tie on the feature, the document
/// address is used.
///
/// The REVERSE_ORDER generic parameter controls whether the by-feature order
/// should be reversed, which is useful for achieving for example largest-first
/// semantics without having to wrap the feature in a `Reverse`.
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct ComparableDoc<T, D, const REVERSE_ORDER: bool = false> {
    /// The feature of the document. In practice, this
    /// is a type which can be compared with a `Comparator<T>`.
    /// is any type that implements `PartialOrd`.
    pub sort_key: T,
    /// The document address. In practice, this is either a `DocId` or `DocAddress`.
    /// The document address. In practice, this is any
    /// type that implements `PartialOrd`, and is guaranteed
    /// to be unique for each document.
    pub doc: D,
}

impl<T: std::fmt::Debug, D: std::fmt::Debug> std::fmt::Debug for ComparableDoc<T, D> {
impl<T: std::fmt::Debug, D: std::fmt::Debug, const R: bool> std::fmt::Debug
    for ComparableDoc<T, D, R>
{
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct("ComparableDoc")
        f.debug_struct(format!("ComparableDoc<_, _, {R}>").as_str())
            .field("feature", &self.sort_key)
            .field("doc", &self.doc)
            .finish()
    }
}

impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialOrd for ComparableDoc<T, D, R> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<T: PartialOrd, D: PartialOrd, const R: bool> Ord for ComparableDoc<T, D, R> {
    #[inline]
    fn cmp(&self, other: &Self) -> Ordering {
        let by_feature = self
            .sort_key
            .partial_cmp(&other.sort_key)
            .map(|ord| if R { ord.reverse() } else { ord })
            .unwrap_or(Ordering::Equal);

        let lazy_by_doc_address = || self.doc.partial_cmp(&other.doc).unwrap_or(Ordering::Equal);

        // In case of a tie on the feature, we sort by ascending
        // `DocAddress` in order to ensure a stable sorting of the
        // documents.
        by_feature.then_with(lazy_by_doc_address)
    }
}

impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialEq for ComparableDoc<T, D, R> {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl<T: PartialOrd, D: PartialOrd, const R: bool> Eq for ComparableDoc<T, D, R> {}
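The `ComparableDoc<T, D, REVERSE_ORDER>` shown in the hunk above folds the sort direction into a const generic and breaks sort-key ties by document address for stable results. A reduced, self-contained sketch of the same idea (serde and collector wiring omitted; `Comparable` is an illustrative stand-in, not tantivy's type):

```rust
use std::cmp::Ordering;

// Minimal stand-in for the struct above: sort key + doc id, with the
// direction of the by-key comparison chosen by a const generic parameter.
struct Comparable<T, const REVERSE: bool = false> {
    sort_key: T,
    doc: u32,
}

impl<T: PartialOrd, const R: bool> PartialEq for Comparable<T, R> {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl<T: PartialOrd, const R: bool> Eq for Comparable<T, R> {}

impl<T: PartialOrd, const R: bool> PartialOrd for Comparable<T, R> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<T: PartialOrd, const R: bool> Ord for Comparable<T, R> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Incomparable keys (e.g. NaN) fall back to Equal, as above.
        let by_key = self
            .sort_key
            .partial_cmp(&other.sort_key)
            .map(|ord| if R { ord.reverse() } else { ord })
            .unwrap_or(Ordering::Equal);
        // Tie-break on doc id so equal keys still sort deterministically.
        by_key.then_with(|| self.doc.cmp(&other.doc))
    }
}

fn main() {
    // REVERSE = true: the largest key sorts first.
    let mut docs: Vec<Comparable<u64, true>> = vec![
        Comparable { sort_key: 3, doc: 1 },
        Comparable { sort_key: 7, doc: 0 },
        Comparable { sort_key: 3, doc: 0 },
    ];
    docs.sort();
    let order: Vec<(u64, u32)> = docs.iter().map(|c| (c.sort_key, c.doc)).collect();
    // The tie on key 3 is broken by ascending doc id.
    assert_eq!(order, vec![(7, 0), (3, 0), (3, 1)]);
}
```

Baking the direction into the type lets a `BinaryHeap` or `sort` run without a separate comparator object, which is exactly what removing the `Comparator` plumbing in the hunks above enables.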
Some files were not shown because too many files have changed in this diff