Compare commits

..

1 Commits

Author SHA1 Message Date
Pascal Seitz
b4139bc353 chore: Release 2025-08-20 18:21:53 +08:00
216 changed files with 5049 additions and 16263 deletions

View File

@@ -15,11 +15,11 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Install Rust
run: rustup toolchain install nightly-2025-12-01 --profile minimal --component llvm-tools-preview
run: rustup toolchain install nightly-2024-07-01 --profile minimal --component llvm-tools-preview
- uses: Swatinem/rust-cache@v2
- uses: taiki-e/install-action@cargo-llvm-cov
- name: Generate code coverage
run: cargo +nightly-2025-12-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
run: cargo +nightly-2024-07-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
continue-on-error: true

View File

@@ -39,11 +39,11 @@ jobs:
- name: Check Formatting
run: cargo +nightly fmt --all -- --check
- name: Check Stable Compilation
run: cargo build --all-features
- name: Check Bench Compilation
run: cargo +nightly bench --no-run --profile=dev --all-features
@@ -59,10 +59,10 @@ jobs:
strategy:
matrix:
features:
- { label: "all", flags: "mmap,stopwords,lz4-compression,zstd-compression,failpoints,stemmer" }
- { label: "quickwit", flags: "mmap,quickwit,failpoints" }
- { label: "none", flags: "" }
features: [
{ label: "all", flags: "mmap,stopwords,lz4-compression,zstd-compression,failpoints" },
{ label: "quickwit", flags: "mmap,quickwit,failpoints" }
]
name: test-${{ matrix.features.label}}
@@ -80,21 +80,7 @@ jobs:
- uses: Swatinem/rust-cache@v2
- name: Run tests
run: |
# if matrix.feature.flags is empty then run on --lib to avoid compiling examples
# (as most of them rely on mmap) otherwise run all
if [ -z "${{ matrix.features.flags }}" ]; then
cargo +stable nextest run --lib --no-default-features --verbose --workspace
else
cargo +stable nextest run --features ${{ matrix.features.flags }} --no-default-features --verbose --workspace
fi
run: cargo +stable nextest run --features ${{ matrix.features.flags }} --verbose --workspace
- name: Run doctests
run: |
# if matrix.feature.flags is empty then run on --lib to avoid compiling examples
# (as most of them rely on mmap) otherwise run all
if [ -z "${{ matrix.features.flags }}" ]; then
echo "no doctest for no feature flag"
else
cargo +stable test --doc --features ${{ matrix.features.flags }} --verbose --workspace
fi
run: cargo +stable test --doc --features ${{ matrix.features.flags }} --verbose --workspace

View File

@@ -14,18 +14,6 @@ Tantivy 0.25
- Support mixed field types in query parser [#2676](https://github.com/quickwit-oss/tantivy/pull/2676)(@trinity-1686a)
- Add per-field size details [#2679](https://github.com/quickwit-oss/tantivy/pull/2679)(@fulmicoton)
Tantivy 0.24.2
================================
- Fix TopNComputer for reverse order. [#2672](https://github.com/quickwit-oss/tantivy/pull/2672)(@stuhood @PSeitz)
Affected queries are [order_by_fast_field](https://docs.rs/tantivy/latest/tantivy/collector/struct.TopDocs.html#method.order_by_fast_field) and
[order_by_u64_field](https://docs.rs/tantivy/latest/tantivy/collector/struct.TopDocs.html#method.order_by_u64_field)
for `Order::Asc`
Tantivy 0.24.1
================================
- Fix: bump required rust version to 1.81
Tantivy 0.24
================================
Tantivy 0.24 will be backwards compatible with indices created with v0.22 and v0.21. The new minimum rust version will be 1.75. Tantivy 0.23 will be skipped.
@@ -78,7 +66,7 @@ This will slightly increase space and access time. [#2439](https://github.com/qu
- **Store DateTime as nanoseconds in doc store** DateTime in the doc store was truncated to microseconds previously. This removes this truncation, while still keeping backwards compatibility. [#2486](https://github.com/quickwit-oss/tantivy/pull/2486)(@PSeitz)
- **Performance/Memory**
- **Performace/Memory**
- lift clauses in LogicalAst for optimized ast during execution [#2449](https://github.com/quickwit-oss/tantivy/pull/2449)(@PSeitz)
- Use Vec instead of BTreeMap to back OwnedValue object [#2364](https://github.com/quickwit-oss/tantivy/pull/2364)(@fulmicoton)
- Replace TantivyDocument with CompactDoc. CompactDoc is much smaller and provides similar performance. [#2402](https://github.com/quickwit-oss/tantivy/pull/2402)(@PSeitz)
@@ -108,14 +96,6 @@ This will slightly increase space and access time. [#2439](https://github.com/qu
- Fix trait bound of StoreReader::iter [#2360](https://github.com/quickwit-oss/tantivy/pull/2360)(@adamreichold)
- remove read_postings_no_deletes [#2526](https://github.com/quickwit-oss/tantivy/pull/2526)(@PSeitz)
Tantivy 0.22.1
================================
- Fix TopNComputer for reverse order. [#2672](https://github.com/quickwit-oss/tantivy/pull/2672)(@stuhood @PSeitz)
Affected queries are [order_by_fast_field](https://docs.rs/tantivy/latest/tantivy/collector/struct.TopDocs.html#method.order_by_fast_field) and
[order_by_u64_field](https://docs.rs/tantivy/latest/tantivy/collector/struct.TopDocs.html#method.order_by_u64_field)
for `Order::Asc`
Tantivy 0.22
================================

View File

@@ -1,6 +1,6 @@
[package]
name = "tantivy"
version = "0.26.0"
version = "0.25.0"
authors = ["Paul Masurel <paul.masurel@gmail.com>"]
license = "MIT"
categories = ["database-implementations", "data-structures"]
@@ -37,9 +37,9 @@ fs4 = { version = "0.13.1", optional = true }
levenshtein_automata = "0.2.1"
uuid = { version = "1.0.0", features = ["v4", "serde"] }
crossbeam-channel = "0.5.4"
rust-stemmers = { version = "1.2.0", optional = true }
rust-stemmers = "1.2.0"
downcast-rs = "2.0.1"
bitpacking = { version = "0.9.3", default-features = false, features = [
bitpacking = { version = "0.9.2", default-features = false, features = [
"bitpacker4x",
] }
census = "0.4.2"
@@ -69,18 +69,17 @@ hyperloglogplus = { version = "0.4.1", features = ["const-loop"] }
futures-util = { version = "0.3.28", optional = true }
futures-channel = { version = "0.3.28", optional = true }
fnv = "1.0.7"
typetag = "0.2.21"
[target.'cfg(windows)'.dependencies]
winapi = "0.3.9"
[dev-dependencies]
binggan = "0.14.2"
binggan = "0.14.0"
rand = "0.8.5"
maplit = "1.0.2"
matches = "0.1.9"
pretty_assertions = "1.2.1"
proptest = "1.7.0"
proptest = "1.0.0"
test-log = "0.2.10"
futures = "0.3.21"
paste = "1.0.11"
@@ -88,7 +87,7 @@ more-asserts = "0.3.1"
rand_distr = "0.4.3"
time = { version = "0.3.10", features = ["serde-well-known", "macros"] }
postcard = { version = "1.0.4", features = [
"use-std",
"use-std",
], default-features = false }
[target.'cfg(not(windows))'.dev-dependencies]
@@ -113,8 +112,7 @@ debug-assertions = true
overflow-checks = true
[features]
default = ["mmap", "stopwords", "lz4-compression", "columnar-zstd-compression", "stemmer"]
stemmer = ["rust-stemmers"]
default = ["mmap", "stopwords", "lz4-compression", "columnar-zstd-compression"]
mmap = ["fs4", "tempfile", "memmap2"]
stopwords = []
@@ -169,23 +167,3 @@ harness = false
[[bench]]
name = "agg_bench"
harness = false
[[bench]]
name = "exists_json"
harness = false
[[bench]]
name = "range_query"
harness = false
[[bench]]
name = "and_or_queries"
harness = false
[[bench]]
name = "range_queries"
harness = false
[[bench]]
name = "bool_queries_with_range"
harness = false

View File

@@ -23,6 +23,8 @@ performance for different types of queries/collections.
Your mileage WILL vary depending on the nature of queries and their load.
<img src="doc/assets/images/searchbenchmark.png">
Details about the benchmark can be found at this [repository](https://github.com/quickwit-oss/search-benchmark-game).
## Features
@@ -123,7 +125,6 @@ You can also find other bindings on [GitHub](https://github.com/search?q=tantivy
- [seshat](https://github.com/matrix-org/seshat/): A matrix message database/indexer
- [tantiny](https://github.com/baygeldin/tantiny): Tiny full-text search for Ruby
- [lnx](https://github.com/lnx-search/lnx): adaptable, typo tolerant search engine with a REST API
- [Bichon](https://github.com/rustmailer/bichon): A lightweight, high-performance Rust email archiver with WebUI
- and [more](https://github.com/search?q=tantivy)!
### On average, how much faster is Tantivy compared to Lucene?

View File

@@ -1,4 +1,4 @@
# Releasing a new Tantivy Version
# Release a new Tantivy Version
## Steps
@@ -10,29 +10,12 @@
6. Set git tag with new version
[`cargo-release`](https://github.com/crate-ci/cargo-release) will help us with steps 1-5:
In conjucation with `cargo-release` Steps 1-4 (I'm not sure if the change detection works):
Set new packages to version 0.0.0
Replace prev-tag-name
```bash
cargo release --workspace --no-publish -v --prev-tag-name 0.24 --push-remote origin minor --no-tag
cargo release --workspace --no-publish -v --prev-tag-name 0.19 --push-remote origin minor --no-tag --execute
```
`no-tag` or it will create tags for all the subpackages
cargo release will _not_ ignore unchanged packages, but it will print warnings for them.
e.g. "warning: updating ownedbytes to 0.10.0 despite no changes made since tag 0.24"
We need to manually ignore these unchanged packages
```bash
cargo release --workspace --no-publish -v --prev-tag-name 0.24 --push-remote origin minor --no-tag --exclude tokenizer-api
```
Add `--execute` to actually publish the packages, otherwise it will only print the commands that would be run.
### Tag Version
```bash
git tag 0.25.0
git push upstream tag 0.25.0
```
no-tag or it will create tags for all the subpackages

View File

@@ -10,7 +10,7 @@ rename FastFieldReaders::open to load
remove fast field reader
find a way to unify the two DateTime.
re-add type check in the filter wrapper
readd type check in the filter wrapper
add unit test on columnar list columns.

View File

@@ -1,6 +1,5 @@
use binggan::plugins::PeakMemAllocPlugin;
use binggan::{black_box, InputGroup, PeakMemAlloc, INSTRUMENTED_SYSTEM};
use rand::distributions::WeightedIndex;
use rand::prelude::SliceRandom;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
@@ -54,41 +53,26 @@ fn bench_agg(mut group: InputGroup<Index>) {
register!(group, stats_f64);
register!(group, extendedstats_f64);
register!(group, percentiles_f64);
register!(group, terms_7);
register!(group, terms_all_unique);
register!(group, terms_150_000);
register!(group, terms_few);
register!(group, terms_many);
register!(group, terms_many_top_1000);
register!(group, terms_many_order_by_term);
register!(group, terms_many_with_top_hits);
register!(group, terms_all_unique_with_avg_sub_agg);
register!(group, terms_many_with_avg_sub_agg);
register!(group, terms_status_with_avg_sub_agg);
register!(group, terms_status_with_histogram);
register!(group, terms_zipf_1000);
register!(group, terms_zipf_1000_with_histogram);
register!(group, terms_zipf_1000_with_avg_sub_agg);
register!(group, terms_many_json_mixed_type_with_avg_sub_agg);
register!(group, cardinality_agg);
register!(group, terms_status_with_cardinality_agg);
register!(group, terms_few_with_cardinality_agg);
register!(group, range_agg);
register!(group, range_agg_with_avg_sub_agg);
register!(group, range_agg_with_term_agg_status);
register!(group, range_agg_with_term_agg_few);
register!(group, range_agg_with_term_agg_many);
register!(group, histogram);
register!(group, histogram_hard_bounds);
register!(group, histogram_with_avg_sub_agg);
register!(group, histogram_with_term_agg_status);
register!(group, avg_and_range_with_avg_sub_agg);
// Filter aggregation benchmarks
register!(group, filter_agg_all_query_count_agg);
register!(group, filter_agg_term_query_count_agg);
register!(group, filter_agg_all_query_with_sub_aggs);
register!(group, filter_agg_term_query_with_sub_aggs);
group.run();
}
@@ -139,12 +123,12 @@ fn extendedstats_f64(index: &Index) {
}
fn percentiles_f64(index: &Index) {
let agg_req = json!({
"mypercentiles": {
"percentiles": {
"field": "score_f64",
"percents": [ 95, 99, 99.9 ]
}
"mypercentiles": {
"percentiles": {
"field": "score_f64",
"percents": [ 95, 99, 99.9 ]
}
}
});
execute_agg(index, agg_req);
}
@@ -159,10 +143,10 @@ fn cardinality_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn terms_status_with_cardinality_agg(index: &Index) {
fn terms_few_with_cardinality_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"terms": { "field": "text_few_terms" },
"aggs": {
"cardinality": {
"cardinality": {
@@ -175,20 +159,13 @@ fn terms_status_with_cardinality_agg(index: &Index) {
execute_agg(index, agg_req);
}
fn terms_7(index: &Index) {
fn terms_few(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_few_terms_status" } },
"my_texts": { "terms": { "field": "text_few_terms" } },
});
execute_agg(index, agg_req);
}
fn terms_all_unique(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_all_unique_terms" } },
});
execute_agg(index, agg_req);
}
fn terms_150_000(index: &Index) {
fn terms_many(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_many_terms" } },
});
@@ -236,72 +213,6 @@ fn terms_many_with_avg_sub_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn terms_all_unique_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_all_unique_terms" },
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
},
});
execute_agg(index, agg_req);
}
fn terms_status_with_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"aggs": {
"histo": {"histogram": { "field": "score_f64", "interval": 10 }}
}
}
});
execute_agg(index, agg_req);
}
fn terms_zipf_1000_with_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_1000_terms_zipf" },
"aggs": {
"histo": {"histogram": { "field": "score_f64", "interval": 10 }}
}
}
});
execute_agg(index, agg_req);
}
fn terms_status_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
},
});
execute_agg(index, agg_req);
}
fn terms_zipf_1000_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_1000_terms_zipf" },
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
},
});
execute_agg(index, agg_req);
}
fn terms_zipf_1000(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_1000_terms_zipf" } },
});
execute_agg(index, agg_req);
}
fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
@@ -357,7 +268,7 @@ fn range_agg_with_avg_sub_agg(index: &Index) {
execute_agg(index, agg_req);
}
fn range_agg_with_term_agg_status(index: &Index) {
fn range_agg_with_term_agg_few(index: &Index) {
let agg_req = json!({
"rangef64": {
"range": {
@@ -372,7 +283,7 @@ fn range_agg_with_term_agg_status(index: &Index) {
]
},
"aggs": {
"my_texts": { "terms": { "field": "text_few_terms_status" } },
"my_texts": { "terms": { "field": "text_few_terms" } },
}
},
});
@@ -428,17 +339,6 @@ fn histogram_with_avg_sub_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn histogram_with_term_agg_status(index: &Index) {
let agg_req = json!({
"rangef64": {
"histogram": { "field": "score_f64", "interval": 10 },
"aggs": {
"my_texts": { "terms": { "field": "text_few_terms_status" } }
}
}
});
execute_agg(index, agg_req);
}
fn avg_and_range_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"rangef64": {
@@ -478,13 +378,6 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector {
}
fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
// Flag to use existing index
let reuse_index = std::env::var("REUSE_AGG_BENCH_INDEX").is_ok();
if reuse_index && std::path::Path::new("agg_bench").exists() {
return Index::open_in_dir("agg_bench");
}
// crreate dir
std::fs::create_dir_all("agg_bench")?;
let mut schema_builder = Schema::builder();
let text_fieldtype = tantivy::schema::TextOptions::default()
.set_indexing_options(
@@ -493,47 +386,20 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
.set_stored();
let text_field = schema_builder.add_text_field("text", text_fieldtype);
let json_field = schema_builder.add_json_field("json", FAST);
let text_field_all_unique_terms =
schema_builder.add_text_field("text_all_unique_terms", STRING | FAST);
let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST);
let text_field_few_terms_status =
schema_builder.add_text_field("text_few_terms_status", STRING | FAST);
let text_field_1000_terms_zipf =
schema_builder.add_text_field("text_1000_terms_zipf", STRING | FAST);
let text_field_few_terms = schema_builder.add_text_field("text_few_terms", STRING | FAST);
let score_fieldtype = tantivy::schema::NumericOptions::default().set_fast();
let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone());
let score_field_f64 = schema_builder.add_f64_field("score_f64", score_fieldtype.clone());
let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype);
// use tmp dir
let index = if reuse_index {
Index::create_in_dir("agg_bench", schema_builder.build())?
} else {
Index::create_from_tempdir(schema_builder.build())?
};
// Approximate log proportions
let status_field_data = [
("INFO", 8000),
("ERROR", 300),
("WARN", 1200),
("DEBUG", 500),
("OK", 500),
("CRITICAL", 20),
("EMERGENCY", 1),
];
let log_level_distribution =
WeightedIndex::new(status_field_data.iter().map(|item| item.1)).unwrap();
let index = Index::create_from_tempdir(schema_builder.build())?;
let few_terms_data = ["INFO", "ERROR", "WARN", "DEBUG"];
let lg_norm = rand_distr::LogNormal::new(2.996f64, 0.979f64).unwrap();
let many_terms_data = (0..150_000)
.map(|num| format!("author{num}"))
.collect::<Vec<_>>();
// Prepare 1000 unique terms sampled using a Zipf distribution.
// Exponent ~1.1 approximates top-20 terms covering around ~20%.
let terms_1000: Vec<String> = (1..=1000).map(|i| format!("term_{i}")).collect();
let zipf_1000 = rand_distr::Zipf::new(1000, 1.1f64).unwrap();
{
let mut rng = StdRng::from_seed([1u8; 32]);
let mut index_writer = index.writer_with_num_threads(1, 200_000_000)?;
@@ -543,25 +409,15 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
index_writer.add_document(doc!())?;
}
if cardinality == Cardinality::Multivalued {
let log_level_sample_a = status_field_data[log_level_distribution.sample(&mut rng)].0;
let log_level_sample_b = status_field_data[log_level_distribution.sample(&mut rng)].0;
let idx_a = zipf_1000.sample(&mut rng) as usize - 1;
let idx_b = zipf_1000.sample(&mut rng) as usize - 1;
let term_1000_a = &terms_1000[idx_a];
let term_1000_b = &terms_1000[idx_b];
index_writer.add_document(doc!(
json_field => json!({"mixed_type": 10.0}),
json_field => json!({"mixed_type": 10.0}),
text_field => "cool",
text_field => "cool",
text_field_all_unique_terms => "cool",
text_field_all_unique_terms => "coolo",
text_field_many_terms => "cool",
text_field_many_terms => "cool",
text_field_few_terms_status => log_level_sample_a,
text_field_few_terms_status => log_level_sample_b,
text_field_1000_terms_zipf => term_1000_a.as_str(),
text_field_1000_terms_zipf => term_1000_b.as_str(),
text_field_few_terms => "cool",
text_field_few_terms => "cool",
score_field => 1u64,
score_field => 1u64,
score_field_f64 => lg_norm.sample(&mut rng),
@@ -586,10 +442,8 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
index_writer.add_document(doc!(
text_field => "cool",
json_field => json,
text_field_all_unique_terms => format!("unique_term_{}", rng.gen::<u64>()),
text_field_many_terms => many_terms_data.choose(&mut rng).unwrap().to_string(),
text_field_few_terms_status => status_field_data[log_level_distribution.sample(&mut rng)].0,
text_field_1000_terms_zipf => terms_1000[zipf_1000.sample(&mut rng) as usize - 1].as_str(),
text_field_few_terms => few_terms_data.choose(&mut rng).unwrap().to_string(),
score_field => val as u64,
score_field_f64 => lg_norm.sample(&mut rng),
score_field_i64 => val as i64,
@@ -606,61 +460,3 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
Ok(index)
}
// Filter aggregation benchmarks
fn filter_agg_all_query_count_agg(index: &Index) {
let agg_req = json!({
"filtered": {
"filter": "*",
"aggs": {
"count": { "value_count": { "field": "score" } }
}
}
});
execute_agg(index, agg_req);
}
fn filter_agg_term_query_count_agg(index: &Index) {
let agg_req = json!({
"filtered": {
"filter": "text:cool",
"aggs": {
"count": { "value_count": { "field": "score" } }
}
}
});
execute_agg(index, agg_req);
}
fn filter_agg_all_query_with_sub_aggs(index: &Index) {
let agg_req = json!({
"filtered": {
"filter": "*",
"aggs": {
"avg_score": { "avg": { "field": "score" } },
"stats_score": { "stats": { "field": "score_f64" } },
"terms_text": {
"terms": { "field": "text_few_terms_status" }
}
}
}
});
execute_agg(index, agg_req);
}
fn filter_agg_term_query_with_sub_aggs(index: &Index) {
let agg_req = json!({
"filtered": {
"filter": "text:cool",
"aggs": {
"avg_score": { "avg": { "field": "score" } },
"stats_score": { "stats": { "field": "score_f64" } },
"terms_text": {
"terms": { "field": "text_few_terms_status" }
}
}
}
});
execute_agg(index, agg_req);
}

View File

@@ -1,218 +0,0 @@
// Benchmarks boolean conjunction queries using binggan.
//
// Whats measured:
// - Or and And queries with varying selectivity (only `Term` queries for now on leafs)
// - Nested AND/OR combinations (on multiple fields)
// - No-scoring path using the Count collector (focus on iterator/skip performance)
// - Top-K retrieval (k=10) using the TopDocs collector
//
// Corpus model:
// - Synthetic docs; each token a/b/c is independently included per doc
// - If none of a/b/c are included, emit a neutral filler token to keep doc length similar
//
// Notes:
// - After optimization, when scoring is disabled Tantivy reads doc-only postings
// (IndexRecordOption::Basic), avoiding frequency decoding overhead.
// - This bench isolates boolean iteration speed and intersection/union cost.
// - Use `cargo bench --bench boolean_conjunction` to run.
use binggan::{black_box, BenchGroup, BenchRunner};
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::collector::sort_key::SortByStaticFastValue;
use tantivy::collector::{Collector, Count, TopDocs};
use tantivy::query::{Query, QueryParser};
use tantivy::schema::{Schema, FAST, TEXT};
use tantivy::{doc, Index, Order, ReloadPolicy, Searcher};
#[derive(Clone)]
struct BenchIndex {
#[allow(dead_code)]
index: Index,
searcher: Searcher,
query_parser: QueryParser,
}
/// Build a single index containing both fields (title, body) and
/// return two BenchIndex views:
/// - single_field: QueryParser defaults to only "body"
/// - multi_field: QueryParser defaults to ["title", "body"]
fn build_shared_indices(num_docs: usize, p_a: f32, p_b: f32, p_c: f32) -> (BenchIndex, BenchIndex) {
// Unified schema (two text fields)
let mut schema_builder = Schema::builder();
let f_title = schema_builder.add_text_field("title", TEXT);
let f_body = schema_builder.add_text_field("body", TEXT);
let f_score = schema_builder.add_u64_field("score", FAST);
let f_score2 = schema_builder.add_u64_field("score2", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
// Populate index with stable RNG for reproducibility.
let mut rng = StdRng::from_seed([7u8; 32]);
// Populate: spread each present token 90/10 to body/title
{
let mut writer = index.writer_with_num_threads(1, 500_000_000).unwrap();
for _ in 0..num_docs {
let has_a = rng.gen_bool(p_a as f64);
let has_b = rng.gen_bool(p_b as f64);
let has_c = rng.gen_bool(p_c as f64);
let score = rng.gen_range(0u64..100u64);
let score2 = rng.gen_range(0u64..100_000u64);
let mut title_tokens: Vec<&str> = Vec::new();
let mut body_tokens: Vec<&str> = Vec::new();
if has_a {
if rng.gen_bool(0.1) {
title_tokens.push("a");
} else {
body_tokens.push("a");
}
}
if has_b {
if rng.gen_bool(0.1) {
title_tokens.push("b");
} else {
body_tokens.push("b");
}
}
if has_c {
if rng.gen_bool(0.1) {
title_tokens.push("c");
} else {
body_tokens.push("c");
}
}
if title_tokens.is_empty() && body_tokens.is_empty() {
body_tokens.push("z");
}
writer
.add_document(doc!(
f_title=>title_tokens.join(" "),
f_body=>body_tokens.join(" "),
f_score=>score,
f_score2=>score2,
))
.unwrap();
}
writer.commit().unwrap();
}
// Prepare reader/searcher once.
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::Manual)
.try_into()
.unwrap();
let searcher = reader.searcher();
// Build two query parsers with different default fields.
let qp_single = QueryParser::for_index(&index, vec![f_body]);
let qp_multi = QueryParser::for_index(&index, vec![f_title, f_body]);
let single_view = BenchIndex {
index: index.clone(),
searcher: searcher.clone(),
query_parser: qp_single,
};
let multi_view = BenchIndex {
index,
searcher,
query_parser: qp_multi,
};
(single_view, multi_view)
}
fn main() {
// Prepare corpora with varying selectivity. Build one index per corpus
// and derive two views (single-field vs multi-field) from it.
let scenarios = vec![
(
"N=1M, p(a)=5%, p(b)=1%, p(c)=15%".to_string(),
1_000_000,
0.05,
0.01,
0.15,
),
(
"N=1M, p(a)=1%, p(b)=1%, p(c)=15%".to_string(),
1_000_000,
0.01,
0.01,
0.15,
),
];
let queries = &["a", "+a +b", "+a +b +c", "a OR b", "a OR b OR c"];
let mut runner = BenchRunner::new();
for (label, n, pa, pb, pc) in scenarios {
let (single_view, multi_view) = build_shared_indices(n, pa, pb, pc);
for (view_name, bench_index) in [("single_field", single_view), ("multi_field", multi_view)]
{
// Single-field group: default field is body only
let mut group = runner.new_group();
group.set_name(format!("{}{}", view_name, label));
for query_str in queries {
add_bench_task(&mut group, &bench_index, query_str, Count, "count");
add_bench_task(
&mut group,
&bench_index,
query_str,
TopDocs::with_limit(10).order_by_score(),
"top10",
);
add_bench_task(
&mut group,
&bench_index,
query_str,
TopDocs::with_limit(10).order_by_fast_field::<u64>("score", Order::Asc),
"top10_by_ff",
);
add_bench_task(
&mut group,
&bench_index,
query_str,
TopDocs::with_limit(10).order_by((
SortByStaticFastValue::<u64>::for_field("score"),
SortByStaticFastValue::<u64>::for_field("score2"),
)),
"top10_by_2ff",
);
}
group.run();
}
}
}
fn add_bench_task<C: Collector + 'static>(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query_str: &str,
collector: C,
collector_name: &str,
) {
let task_name = format!("{}_{}", query_str.replace(" ", "_"), collector_name);
let query = bench_index.query_parser.parse_query(query_str).unwrap();
let search_task = SearchTask {
searcher: bench_index.searcher.clone(),
collector,
query,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
struct SearchTask<C: Collector> {
searcher: Searcher,
collector: C,
query: Box<dyn Query>,
}
impl<C: Collector> SearchTask<C> {
#[inline(never)]
pub fn run(&self) -> usize {
self.searcher.search(&self.query, &self.collector).unwrap();
1
}
}

View File

@@ -1,288 +0,0 @@
use binggan::{black_box, BenchGroup, BenchRunner};
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::collector::{Collector, Count, DocSetCollector, TopDocs};
use tantivy::query::{Query, QueryParser};
use tantivy::schema::{Schema, FAST, INDEXED, TEXT};
use tantivy::{doc, Index, Order, ReloadPolicy, Searcher};
#[derive(Clone)]
struct BenchIndex {
#[allow(dead_code)]
index: Index,
searcher: Searcher,
query_parser: QueryParser,
}
fn build_shared_indices(num_docs: usize, p_title_a: f32, distribution: &str) -> BenchIndex {
// Unified schema
let mut schema_builder = Schema::builder();
let f_title = schema_builder.add_text_field("title", TEXT);
let f_num_rand = schema_builder.add_u64_field("num_rand", INDEXED);
let f_num_asc = schema_builder.add_u64_field("num_asc", INDEXED);
let f_num_rand_fast = schema_builder.add_u64_field("num_rand_fast", INDEXED | FAST);
let f_num_asc_fast = schema_builder.add_u64_field("num_asc_fast", INDEXED | FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
// Populate index with stable RNG for reproducibility.
let mut rng = StdRng::from_seed([7u8; 32]);
{
let mut writer = index.writer_with_num_threads(1, 4_000_000_000).unwrap();
match distribution {
"dense" => {
for doc_id in 0..num_docs {
// Always add title to avoid empty documents
let title_token = if rng.gen_bool(p_title_a as f64) {
"a"
} else {
"b"
};
let num_rand = rng.gen_range(0u64..1000u64);
let num_asc = (doc_id / 10000) as u64;
writer
.add_document(doc!(
f_title=>title_token,
f_num_rand=>num_rand,
f_num_asc=>num_asc,
f_num_rand_fast=>num_rand,
f_num_asc_fast=>num_asc,
))
.unwrap();
}
}
"sparse" => {
for doc_id in 0..num_docs {
// Always add title to avoid empty documents
let title_token = if rng.gen_bool(p_title_a as f64) {
"a"
} else {
"b"
};
let num_rand = rng.gen_range(0u64..10000000u64);
let num_asc = doc_id as u64;
writer
.add_document(doc!(
f_title=>title_token,
f_num_rand=>num_rand,
f_num_asc=>num_asc,
f_num_rand_fast=>num_rand,
f_num_asc_fast=>num_asc,
))
.unwrap();
}
}
_ => {
panic!("Unsupported distribution type");
}
}
writer.commit().unwrap();
}
// Prepare reader/searcher once.
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::Manual)
.try_into()
.unwrap();
let searcher = reader.searcher();
// Build query parser for title field
let qp_title = QueryParser::for_index(&index, vec![f_title]);
BenchIndex {
index,
searcher,
query_parser: qp_title,
}
}
fn main() {
// Prepare corpora with varying scenarios
let scenarios = vec![
(
"dense and 99% a".to_string(),
10_000_000,
0.99,
"dense",
0,
9,
),
(
"dense and 99% a".to_string(),
10_000_000,
0.99,
"dense",
990,
999,
),
(
"sparse and 99% a".to_string(),
10_000_000,
0.99,
"sparse",
0,
9,
),
(
"sparse and 99% a".to_string(),
10_000_000,
0.99,
"sparse",
9_999_990,
9_999_999,
),
];
let mut runner = BenchRunner::new();
for (scenario_id, n, p_title_a, num_rand_distribution, range_low, range_high) in scenarios {
// Build index for this scenario
let bench_index = build_shared_indices(n, p_title_a, num_rand_distribution);
// Create benchmark group
let mut group = runner.new_group();
// Now set the name (this moves scenario_id)
group.set_name(scenario_id);
// Define all four field types
let field_names = ["num_rand", "num_asc", "num_rand_fast", "num_asc_fast"];
// Define the three terms we want to test with
let terms = ["a", "b", "z"];
// Generate all combinations of terms and field names
let mut queries = Vec::new();
for &term in &terms {
for &field_name in &field_names {
let query_str = format!(
"{} AND {}:[{} TO {}]",
term, field_name, range_low, range_high
);
queries.push((query_str, field_name.to_string()));
}
}
let query_str = format!(
"{}:[{} TO {}] AND {}:[{} TO {}]",
"num_rand_fast", range_low, range_high, "num_asc_fast", range_low, range_high
);
queries.push((query_str, "num_asc_fast".to_string()));
// Run all benchmark tasks for each query and its corresponding field name
for (query_str, field_name) in queries {
run_benchmark_tasks(&mut group, &bench_index, &query_str, &field_name);
}
group.run();
}
}
/// Run all benchmark tasks for a given query string and field name
fn run_benchmark_tasks(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query_str: &str,
field_name: &str,
) {
// Test count
add_bench_task(bench_group, bench_index, query_str, Count, "count");
// Test all results
add_bench_task(
bench_group,
bench_index,
query_str,
DocSetCollector,
"all results",
);
// Test top 100 by the field (if it's a FAST field)
if field_name.ends_with("_fast") {
// Ascending order
{
let collector_name = format!("top100_by_{}_asc", field_name);
let field_name_owned = field_name.to_string();
add_bench_task(
bench_group,
bench_index,
query_str,
TopDocs::with_limit(100).order_by_fast_field::<u64>(field_name_owned, Order::Asc),
&collector_name,
);
}
// Descending order
{
let collector_name = format!("top100_by_{}_desc", field_name);
let field_name_owned = field_name.to_string();
add_bench_task(
bench_group,
bench_index,
query_str,
TopDocs::with_limit(100).order_by_fast_field::<u64>(field_name_owned, Order::Desc),
&collector_name,
);
}
}
}
fn add_bench_task<C: Collector + 'static>(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query_str: &str,
collector: C,
collector_name: &str,
) {
let task_name = format!("{}_{}", query_str.replace(" ", "_"), collector_name);
let query = bench_index.query_parser.parse_query(query_str).unwrap();
let search_task = SearchTask {
searcher: bench_index.searcher.clone(),
collector,
query,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
struct SearchTask<C: Collector> {
searcher: Searcher,
collector: C,
query: Box<dyn Query>,
}
impl<C: Collector> SearchTask<C> {
#[inline(never)]
pub fn run(&self) -> usize {
let result = self.searcher.search(&self.query, &self.collector).unwrap();
if let Some(count) = (&result as &dyn std::any::Any).downcast_ref::<usize>() {
*count
} else if let Some(top_docs) = (&result as &dyn std::any::Any)
.downcast_ref::<Vec<(Option<u64>, tantivy::DocAddress)>>()
{
top_docs.len()
} else if let Some(top_docs) =
(&result as &dyn std::any::Any).downcast_ref::<Vec<(u64, tantivy::DocAddress)>>()
{
top_docs.len()
} else if let Some(doc_set) = (&result as &dyn std::any::Any)
.downcast_ref::<std::collections::HashSet<tantivy::DocAddress>>()
{
doc_set.len()
} else {
eprintln!(
"Unknown collector result type: {:?}",
std::any::type_name::<C::Fruit>()
);
0
}
}
}

View File

@@ -1,69 +0,0 @@
use binggan::plugins::PeakMemAllocPlugin;
use binggan::{black_box, InputGroup, PeakMemAlloc, INSTRUMENTED_SYSTEM};
use serde_json::json;
use tantivy::collector::Count;
use tantivy::query::ExistsQuery;
use tantivy::schema::{Schema, FAST, TEXT};
use tantivy::{doc, Index};
#[global_allocator]
pub static GLOBAL: &PeakMemAlloc<std::alloc::System> = &INSTRUMENTED_SYSTEM;
fn main() {
let doc_count: usize = 500_000;
let subfield_counts: &[usize] = &[1, 2, 3, 4, 5, 6, 7, 8, 16, 256, 4096, 65536, 262144];
let indices: Vec<(String, Index)> = subfield_counts
.iter()
.map(|&sub_fields| {
(
format!("subfields={sub_fields}"),
build_index_with_json_subfields(doc_count, sub_fields),
)
})
.collect();
let mut group = InputGroup::new_with_inputs(indices);
group.add_plugin(PeakMemAllocPlugin::new(GLOBAL));
group.config().num_iter_group = Some(1);
group.config().num_iter_bench = Some(1);
group.register("exists_json", exists_json_union);
group.run();
}
fn exists_json_union(index: &Index) {
let reader = index.reader().expect("reader");
let searcher = reader.searcher();
let query = ExistsQuery::new("json".to_string(), true);
let count = searcher.search(&query, &Count).expect("exists search");
// Prevents optimizer from eliding the search
black_box(count);
}
fn build_index_with_json_subfields(num_docs: usize, num_subfields: usize) -> Index {
// Schema: single JSON field stored as FAST to support ExistsQuery.
let mut schema_builder = Schema::builder();
let json_field = schema_builder.add_json_field("json", TEXT | FAST);
let schema = schema_builder.build();
let index = Index::create_from_tempdir(schema).expect("create index");
{
let mut index_writer = index
.writer_with_num_threads(1, 200_000_000)
.expect("writer");
for i in 0..num_docs {
let sub = i % num_subfields;
// Only one subpath set per document; rotate subpaths so that
// no single subpath is full, but the union covers all docs.
let v = json!({ format!("field_{sub}"): i as u64 });
index_writer
.add_document(doc!(json_field => v))
.expect("add_document");
}
index_writer.commit().expect("commit");
}
index
}

View File

@@ -1,365 +0,0 @@
use std::ops::Bound;
use binggan::{black_box, BenchGroup, BenchRunner};
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::collector::{Count, DocSetCollector, TopDocs};
use tantivy::query::RangeQuery;
use tantivy::schema::{Schema, FAST, INDEXED};
use tantivy::{doc, Index, Order, ReloadPolicy, Searcher, Term};
#[derive(Clone)]
struct BenchIndex {
#[allow(dead_code)]
index: Index,
searcher: Searcher,
}
fn build_shared_indices(num_docs: usize, distribution: &str) -> BenchIndex {
// Schema with fast fields only
let mut schema_builder = Schema::builder();
let f_num_rand_fast = schema_builder.add_u64_field("num_rand_fast", INDEXED | FAST);
let f_num_asc_fast = schema_builder.add_u64_field("num_asc_fast", INDEXED | FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
// Populate index with stable RNG for reproducibility.
let mut rng = StdRng::from_seed([7u8; 32]);
{
let mut writer = index.writer_with_num_threads(1, 4_000_000_000).unwrap();
match distribution {
"dense" => {
for doc_id in 0..num_docs {
let num_rand = rng.gen_range(0u64..1000u64);
let num_asc = (doc_id / 10000) as u64;
writer
.add_document(doc!(
f_num_rand_fast=>num_rand,
f_num_asc_fast=>num_asc,
))
.unwrap();
}
}
"sparse" => {
for doc_id in 0..num_docs {
let num_rand = rng.gen_range(0u64..10000000u64);
let num_asc = doc_id as u64;
writer
.add_document(doc!(
f_num_rand_fast=>num_rand,
f_num_asc_fast=>num_asc,
))
.unwrap();
}
}
_ => {
panic!("Unsupported distribution type");
}
}
writer.commit().unwrap();
}
// Prepare reader/searcher once.
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::Manual)
.try_into()
.unwrap();
let searcher = reader.searcher();
BenchIndex { index, searcher }
}
fn main() {
// Prepare corpora with varying scenarios
let scenarios = vec![
// Dense distribution - random values in small range (0-999)
(
"dense_values_search_low_value_range".to_string(),
10_000_000,
"dense",
0,
9,
),
(
"dense_values_search_high_value_range".to_string(),
10_000_000,
"dense",
990,
999,
),
(
"dense_values_search_out_of_range".to_string(),
10_000_000,
"dense",
1000,
1002,
),
(
"sparse_values_search_low_value_range".to_string(),
10_000_000,
"sparse",
0,
9,
),
(
"sparse_values_search_high_value_range".to_string(),
10_000_000,
"sparse",
9_999_990,
9_999_999,
),
(
"sparse_values_search_out_of_range".to_string(),
10_000_000,
"sparse",
10_000_000,
10_000_002,
),
];
let mut runner = BenchRunner::new();
for (scenario_id, n, num_rand_distribution, range_low, range_high) in scenarios {
// Build index for this scenario
let bench_index = build_shared_indices(n, num_rand_distribution);
// Create benchmark group
let mut group = runner.new_group();
// Now set the name (this moves scenario_id)
group.set_name(scenario_id);
// Define fast field types
let field_names = ["num_rand_fast", "num_asc_fast"];
// Generate range queries for fast fields
for &field_name in &field_names {
// Create the range query
let field = bench_index.searcher.schema().get_field(field_name).unwrap();
let lower_term = Term::from_field_u64(field, range_low);
let upper_term = Term::from_field_u64(field, range_high);
let query = RangeQuery::new(Bound::Included(lower_term), Bound::Included(upper_term));
run_benchmark_tasks(
&mut group,
&bench_index,
query,
field_name,
range_low,
range_high,
);
}
group.run();
}
}
/// Run all benchmark tasks for a given range query and field name
fn run_benchmark_tasks(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
field_name: &str,
range_low: u64,
range_high: u64,
) {
// Test count
add_bench_task_count(
bench_group,
bench_index,
query.clone(),
"count",
field_name,
range_low,
range_high,
);
// Test top 100 by the field (ascending order)
{
let collector_name = format!("top100_by_{}_asc", field_name);
let field_name_owned = field_name.to_string();
add_bench_task_top100_asc(
bench_group,
bench_index,
query.clone(),
&collector_name,
field_name,
range_low,
range_high,
field_name_owned,
);
}
// Test top 100 by the field (descending order)
{
let collector_name = format!("top100_by_{}_desc", field_name);
let field_name_owned = field_name.to_string();
add_bench_task_top100_desc(
bench_group,
bench_index,
query,
&collector_name,
field_name,
range_low,
range_high,
field_name_owned,
);
}
}
fn add_bench_task_count(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
collector_name: &str,
field_name: &str,
range_low: u64,
range_high: u64,
) {
let task_name = format!(
"range_{}_[{} TO {}]_{}",
field_name, range_low, range_high, collector_name
);
let search_task = CountSearchTask {
searcher: bench_index.searcher.clone(),
query,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
fn add_bench_task_docset(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
collector_name: &str,
field_name: &str,
range_low: u64,
range_high: u64,
) {
let task_name = format!(
"range_{}_[{} TO {}]_{}",
field_name, range_low, range_high, collector_name
);
let search_task = DocSetSearchTask {
searcher: bench_index.searcher.clone(),
query,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
fn add_bench_task_top100_asc(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
collector_name: &str,
field_name: &str,
range_low: u64,
range_high: u64,
field_name_owned: String,
) {
let task_name = format!(
"range_{}_[{} TO {}]_{}",
field_name, range_low, range_high, collector_name
);
let search_task = Top100AscSearchTask {
searcher: bench_index.searcher.clone(),
query,
field_name: field_name_owned,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
fn add_bench_task_top100_desc(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
collector_name: &str,
field_name: &str,
range_low: u64,
range_high: u64,
field_name_owned: String,
) {
let task_name = format!(
"range_{}_[{} TO {}]_{}",
field_name, range_low, range_high, collector_name
);
let search_task = Top100DescSearchTask {
searcher: bench_index.searcher.clone(),
query,
field_name: field_name_owned,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
struct CountSearchTask {
searcher: Searcher,
query: RangeQuery,
}
impl CountSearchTask {
#[inline(never)]
pub fn run(&self) -> usize {
self.searcher.search(&self.query, &Count).unwrap()
}
}
struct DocSetSearchTask {
searcher: Searcher,
query: RangeQuery,
}
impl DocSetSearchTask {
#[inline(never)]
pub fn run(&self) -> usize {
let result = self.searcher.search(&self.query, &DocSetCollector).unwrap();
result.len()
}
}
struct Top100AscSearchTask {
searcher: Searcher,
query: RangeQuery,
field_name: String,
}
impl Top100AscSearchTask {
#[inline(never)]
pub fn run(&self) -> usize {
let collector =
TopDocs::with_limit(100).order_by_fast_field::<u64>(&self.field_name, Order::Asc);
let result = self.searcher.search(&self.query, &collector).unwrap();
for (_score, doc_address) in &result {
let _doc: tantivy::TantivyDocument = self.searcher.doc(*doc_address).unwrap();
}
result.len()
}
}
struct Top100DescSearchTask {
searcher: Searcher,
query: RangeQuery,
field_name: String,
}
impl Top100DescSearchTask {
#[inline(never)]
pub fn run(&self) -> usize {
let collector =
TopDocs::with_limit(100).order_by_fast_field::<u64>(&self.field_name, Order::Desc);
let result = self.searcher.search(&self.query, &collector).unwrap();
for (_score, doc_address) in &result {
let _doc: tantivy::TantivyDocument = self.searcher.doc(*doc_address).unwrap();
}
result.len()
}
}

View File

@@ -1,260 +0,0 @@
use std::fmt::Display;
use std::net::Ipv6Addr;
use std::ops::RangeInclusive;
use binggan::plugins::PeakMemAllocPlugin;
use binggan::{black_box, BenchRunner, OutputValue, PeakMemAlloc, INSTRUMENTED_SYSTEM};
use columnar::MonotonicallyMappableToU128;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use tantivy::collector::{Count, TopDocs};
use tantivy::query::QueryParser;
use tantivy::schema::*;
use tantivy::{doc, Index};
#[global_allocator]
pub static GLOBAL: &PeakMemAlloc<std::alloc::System> = &INSTRUMENTED_SYSTEM;
fn main() {
bench_range_query();
}
fn bench_range_query() {
let index = get_index_0_to_100();
let mut runner = BenchRunner::new();
runner.add_plugin(PeakMemAllocPlugin::new(GLOBAL));
runner.set_name("range_query on u64");
let field_name_and_descr: Vec<_> = vec![
("id", "Single Valued Range Field"),
("ids", "Multi Valued Range Field"),
];
let range_num_hits = vec![
("90_percent", get_90_percent()),
("10_percent", get_10_percent()),
("1_percent", get_1_percent()),
];
test_range(&mut runner, &index, &field_name_and_descr, range_num_hits);
runner.set_name("range_query on ip");
let field_name_and_descr: Vec<_> = vec![
("ip", "Single Valued Range Field"),
("ips", "Multi Valued Range Field"),
];
let range_num_hits = vec![
("90_percent", get_90_percent_ip()),
("10_percent", get_10_percent_ip()),
("1_percent", get_1_percent_ip()),
];
test_range(&mut runner, &index, &field_name_and_descr, range_num_hits);
}
fn test_range<T: Display>(
runner: &mut BenchRunner,
index: &Index,
field_name_and_descr: &[(&str, &str)],
range_num_hits: Vec<(&str, RangeInclusive<T>)>,
) {
for (field, suffix) in field_name_and_descr {
let term_num_hits = vec![
("", ""),
("1_percent", "veryfew"),
("10_percent", "few"),
("90_percent", "most"),
];
let mut group = runner.new_group();
group.set_name(suffix);
// all intersect combinations
for (range_name, range) in &range_num_hits {
for (term_name, term) in &term_num_hits {
let index = &index;
let test_name = if term_name.is_empty() {
format!("id_range_hit_{}", range_name)
} else {
format!(
"id_range_hit_{}_intersect_with_term_{}",
range_name, term_name
)
};
group.register(test_name, move |_| {
let query = if term_name.is_empty() {
"".to_string()
} else {
format!("AND id_name:{}", term)
};
black_box(execute_query(field, range, &query, index));
});
}
}
group.run();
}
}
fn get_index_0_to_100() -> Index {
let mut rng = StdRng::from_seed([1u8; 32]);
let num_vals = 100_000;
let docs: Vec<_> = (0..num_vals)
.map(|_i| {
let id_name = if rng.gen_bool(0.01) {
"veryfew".to_string() // 1%
} else if rng.gen_bool(0.1) {
"few".to_string() // 9%
} else {
"most".to_string() // 90%
};
Doc {
id_name,
id: rng.gen_range(0..100),
// Multiply by 1000, so that we create most buckets in the compact space
// The benches depend on this range to select n-percent of elements with the
// methods below.
ip: Ipv6Addr::from_u128(rng.gen_range(0..100) * 1000),
}
})
.collect();
create_index_from_docs(&docs)
}
#[derive(Clone, Debug)]
pub struct Doc {
pub id_name: String,
pub id: u64,
pub ip: Ipv6Addr,
}
pub fn create_index_from_docs(docs: &[Doc]) -> Index {
let mut schema_builder = Schema::builder();
let id_u64_field = schema_builder.add_u64_field("id", INDEXED | STORED | FAST);
let ids_u64_field =
schema_builder.add_u64_field("ids", NumericOptions::default().set_fast().set_indexed());
let id_f64_field = schema_builder.add_f64_field("id_f64", INDEXED | STORED | FAST);
let ids_f64_field = schema_builder.add_f64_field(
"ids_f64",
NumericOptions::default().set_fast().set_indexed(),
);
let id_i64_field = schema_builder.add_i64_field("id_i64", INDEXED | STORED | FAST);
let ids_i64_field = schema_builder.add_i64_field(
"ids_i64",
NumericOptions::default().set_fast().set_indexed(),
);
let text_field = schema_builder.add_text_field("id_name", STRING | STORED);
let text_field2 = schema_builder.add_text_field("id_name_fast", STRING | STORED | FAST);
let ip_field = schema_builder.add_ip_addr_field("ip", FAST);
let ips_field = schema_builder.add_ip_addr_field("ips", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
{
let mut index_writer = index.writer_with_num_threads(1, 50_000_000).unwrap();
for doc in docs.iter() {
index_writer
.add_document(doc!(
ids_i64_field => doc.id as i64,
ids_i64_field => doc.id as i64,
ids_f64_field => doc.id as f64,
ids_f64_field => doc.id as f64,
ids_u64_field => doc.id,
ids_u64_field => doc.id,
id_u64_field => doc.id,
id_f64_field => doc.id as f64,
id_i64_field => doc.id as i64,
text_field => doc.id_name.to_string(),
text_field2 => doc.id_name.to_string(),
ips_field => doc.ip,
ips_field => doc.ip,
ip_field => doc.ip,
))
.unwrap();
}
index_writer.commit().unwrap();
}
index
}
fn get_90_percent() -> RangeInclusive<u64> {
0..=90
}
fn get_10_percent() -> RangeInclusive<u64> {
0..=10
}
fn get_1_percent() -> RangeInclusive<u64> {
10..=10
}
fn get_90_percent_ip() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(0);
let end = Ipv6Addr::from_u128(90 * 1000);
start..=end
}
fn get_10_percent_ip() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(0);
let end = Ipv6Addr::from_u128(10 * 1000);
start..=end
}
fn get_1_percent_ip() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(10 * 1000);
let end = Ipv6Addr::from_u128(10 * 1000);
start..=end
}
struct NumHits {
count: usize,
}
impl OutputValue for NumHits {
fn column_title() -> &'static str {
"NumHits"
}
fn format(&self) -> Option<String> {
Some(self.count.to_string())
}
}
fn execute_query<T: Display>(
field: &str,
id_range: &RangeInclusive<T>,
suffix: &str,
index: &Index,
) -> NumHits {
let gen_query_inclusive = |from: &T, to: &T| {
format!(
"{}:[{} TO {}] {}",
field,
&from.to_string(),
&to.to_string(),
suffix
)
};
let query = gen_query_inclusive(id_range.start(), id_range.end());
execute_query_(&query, index)
}
fn execute_query_(query: &str, index: &Index) -> NumHits {
let query_from_text = |text: &str| {
QueryParser::for_index(index, vec![])
.parse_query(text)
.unwrap()
};
let query = query_from_text(query);
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let num_hits = searcher
.search(&query, &(TopDocs::with_limit(10).order_by_score(), Count))
.unwrap()
.1;
NumHits { count: num_hits }
}

View File

@@ -48,7 +48,7 @@ impl BitPacker {
pub fn flush<TWrite: io::Write + ?Sized>(&mut self, output: &mut TWrite) -> io::Result<()> {
if self.mini_buffer_written > 0 {
let num_bytes = self.mini_buffer_written.div_ceil(8);
let num_bytes = (self.mini_buffer_written + 7) / 8;
let bytes = self.mini_buffer.to_le_bytes();
output.write_all(&bytes[..num_bytes])?;
self.mini_buffer_written = 0;
@@ -138,7 +138,7 @@ impl BitUnpacker {
// We use `usize` here to avoid overflow issues.
let end_bit_read = (end_idx as usize) * self.num_bits;
let end_byte_read = end_bit_read.div_ceil(8);
let end_byte_read = (end_bit_read + 7) / 8;
assert!(
end_byte_read <= data.len(),
"Requested index is out of bounds."
@@ -258,7 +258,7 @@ mod test {
bitpacker.write(val, num_bits, &mut data).unwrap();
}
bitpacker.close(&mut data).unwrap();
assert_eq!(data.len(), ((num_bits as usize) * len).div_ceil(8));
assert_eq!(data.len(), ((num_bits as usize) * len + 7) / 8);
let bitunpacker = BitUnpacker::new(num_bits);
(bitunpacker, vals, data)
}
@@ -304,7 +304,7 @@ mod test {
bitpacker.write(val, num_bits, &mut buffer).unwrap();
}
bitpacker.flush(&mut buffer).unwrap();
assert_eq!(buffer.len(), (vals.len() * num_bits as usize).div_ceil(8));
assert_eq!(buffer.len(), (vals.len() * num_bits as usize + 7) / 8);
let bitunpacker = BitUnpacker::new(num_bits);
let max_val = if num_bits == 64 {
u64::MAX

View File

@@ -140,10 +140,10 @@ impl BlockedBitpacker {
pub fn iter(&self) -> impl Iterator<Item = u64> + '_ {
// todo performance: we could decompress a whole block and cache it instead
let bitpacked_elems = self.offset_and_bits.len() * BLOCK_SIZE;
(0..bitpacked_elems)
let iter = (0..bitpacked_elems)
.map(move |idx| self.get(idx))
.chain(self.buffer.iter().cloned())
.chain(self.buffer.iter().cloned());
iter
}
}

View File

@@ -19,7 +19,7 @@ fn u32_to_i32(val: u32) -> i32 {
#[inline]
unsafe fn u32_to_i32_avx2(vals_u32x8s: DataType) -> DataType {
const HIGHEST_BIT_MASK: DataType = from_u32x8([HIGHEST_BIT; NUM_LANES]);
unsafe { op_xor(vals_u32x8s, HIGHEST_BIT_MASK) }
op_xor(vals_u32x8s, HIGHEST_BIT_MASK)
}
pub fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
@@ -66,19 +66,17 @@ unsafe fn filter_vec_avx2_aux(
]);
const SHIFT: __m256i = from_u32x8([NUM_LANES as u32; NUM_LANES]);
for _ in 0..num_words {
unsafe {
let word = load_unaligned(input);
let word = u32_to_i32_avx2(word);
let keeper_bitset = compute_filter_bitset(word, range_simd.clone());
let added_len = keeper_bitset.count_ones();
let filtered_doc_ids = compact(ids, keeper_bitset);
store_unaligned(output_tail as *mut __m256i, filtered_doc_ids);
output_tail = output_tail.offset(added_len as isize);
ids = op_add(ids, SHIFT);
input = input.offset(1);
}
let word = load_unaligned(input);
let word = u32_to_i32_avx2(word);
let keeper_bitset = compute_filter_bitset(word, range_simd.clone());
let added_len = keeper_bitset.count_ones();
let filtered_doc_ids = compact(ids, keeper_bitset);
store_unaligned(output_tail as *mut __m256i, filtered_doc_ids);
output_tail = output_tail.offset(added_len as isize);
ids = op_add(ids, SHIFT);
input = input.offset(1);
}
unsafe { output_tail.offset_from(output) as usize }
output_tail.offset_from(output) as usize
}
#[inline]
@@ -94,7 +92,8 @@ unsafe fn compute_filter_bitset(val: __m256i, range: std::ops::RangeInclusive<__
let too_low = op_greater(*range.start(), val);
let too_high = op_greater(val, *range.end());
let inside = op_or(too_low, too_high);
255 - std::arch::x86_64::_mm256_movemask_ps(_mm256_castsi256_ps(inside)) as u8
255 - std::arch::x86_64::_mm256_movemask_ps(std::mem::transmute::<DataType, __m256>(inside))
as u8
}
union U8x32 {

View File

@@ -73,7 +73,7 @@ The crate introduces the following concepts.
`Columnar` is an equivalent of a dataframe.
It maps `column_key` to `Column`.
A `Column<T>` associates a `RowId` (u32) to any
A `Column<T>` asssociates a `RowId` (u32) to any
number of values.
This is made possible by wrapping a `ColumnIndex` and a `ColumnValue` object.

View File

@@ -89,6 +89,13 @@ fn main() {
black_box(sum);
});
group.register("first_block_fetch", |column| {
let mut block: Vec<Option<u64>> = vec![None; 64];
let fetch_docids = (0..64).collect::<Vec<_>>();
column.first_vals(&fetch_docids, &mut block);
black_box(block[0]);
});
group.register("first_block_single_calls", |column| {
let mut block: Vec<Option<u64>> = vec![None; 64];
let fetch_docids = (0..64).collect::<Vec<_>>();

View File

@@ -29,20 +29,12 @@ impl<T: PartialOrd + Copy + std::fmt::Debug + Send + Sync + 'static + Default>
}
}
#[inline]
pub fn fetch_block_with_missing(
&mut self,
docs: &[u32],
accessor: &Column<T>,
missing: Option<T>,
) {
pub fn fetch_block_with_missing(&mut self, docs: &[u32], accessor: &Column<T>, missing: T) {
self.fetch_block(docs, accessor);
// no missing values
if accessor.index.get_cardinality().is_full() {
return;
}
let Some(missing) = missing else {
return;
};
// We can compare docid_cache length with docs to find missing docs
// For multi value columns we can't rely on the length and always need to scan

View File

@@ -85,8 +85,8 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
}
#[inline]
pub fn first(&self, doc_id: DocId) -> Option<T> {
self.values_for_doc(doc_id).next()
pub fn first(&self, row_id: RowId) -> Option<T> {
self.values_for_doc(row_id).next()
}
/// Load the first value for each docid in the provided slice.
@@ -131,8 +131,6 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
self.index.docids_to_rowids(doc_ids, doc_ids_out, row_ids)
}
/// Get an iterator over the values for the provided docid.
#[inline]
pub fn values_for_doc(&self, doc_id: DocId) -> impl Iterator<Item = T> + '_ {
self.index
.value_row_ids(doc_id)
@@ -160,6 +158,15 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
.select_batch_in_place(selected_docid_range.start, doc_ids);
}
/// Fills the output vector with the (possibly multiple values that are associated_with
/// `row_id`.
///
/// This method clears the `output` vector.
pub fn fill_vals(&self, row_id: RowId, output: &mut Vec<T>) {
output.clear();
output.extend(self.values_for_doc(row_id));
}
pub fn first_or_default_col(self, default_value: T) -> Arc<dyn ColumnValues<T>> {
Arc::new(FirstValueWithDefault {
column: self,

View File

@@ -56,7 +56,7 @@ fn get_doc_ids_with_values<'a>(
ColumnIndex::Full => Box::new(doc_range),
ColumnIndex::Optional(optional_index) => Box::new(
optional_index
.iter_non_null_docs()
.iter_docs()
.map(move |row| row + doc_range.start),
),
ColumnIndex::Multivalued(multivalued_index) => match multivalued_index {
@@ -73,7 +73,7 @@ fn get_doc_ids_with_values<'a>(
MultiValueIndex::MultiValueIndexV2(multivalued_index) => Box::new(
multivalued_index
.optional_index
.iter_non_null_docs()
.iter_docs()
.map(move |row| row + doc_range.start),
),
},
@@ -105,11 +105,10 @@ fn get_num_values_iterator<'a>(
) -> Box<dyn Iterator<Item = u32> + 'a> {
match column_index {
ColumnIndex::Empty { .. } => Box::new(std::iter::empty()),
ColumnIndex::Full => Box::new(std::iter::repeat_n(1u32, num_docs as usize)),
ColumnIndex::Optional(optional_index) => Box::new(std::iter::repeat_n(
1u32,
optional_index.num_non_nulls() as usize,
)),
ColumnIndex::Full => Box::new(std::iter::repeat(1u32).take(num_docs as usize)),
ColumnIndex::Optional(optional_index) => {
Box::new(std::iter::repeat(1u32).take(optional_index.num_non_nulls() as usize))
}
ColumnIndex::Multivalued(multivalued_index) => Box::new(
multivalued_index
.get_start_index_column()
@@ -178,7 +177,7 @@ impl<'a> Iterable<RowId> for StackedOptionalIndex<'a> {
ColumnIndex::Full => Box::new(columnar_row_range),
ColumnIndex::Optional(optional_index) => Box::new(
optional_index
.iter_non_null_docs()
.iter_docs()
.map(move |row_id: RowId| columnar_row_range.start + row_id),
),
ColumnIndex::Multivalued(_) => {

View File

@@ -215,32 +215,6 @@ impl MultiValueIndex {
}
}
/// Returns an iterator over document ids that have at least one value.
pub fn iter_non_null_docs(&self) -> Box<dyn Iterator<Item = DocId> + '_> {
match self {
MultiValueIndex::MultiValueIndexV1(idx) => {
let mut doc: DocId = 0u32;
let num_docs = idx.num_docs();
Box::new(std::iter::from_fn(move || {
// This is not the most efficient way to do this, but it's legacy code.
while doc < num_docs {
let cur = doc;
doc += 1;
let start = idx.start_index_column.get_val(cur);
let end = idx.start_index_column.get_val(cur + 1);
if end > start {
return Some(cur);
}
}
None
}))
}
MultiValueIndex::MultiValueIndexV2(idx) => {
Box::new(idx.optional_index.iter_non_null_docs())
}
}
}
/// Converts a list of ranks (row ids of values) in a 1:n index to the corresponding list of
/// docids. Positions are converted inplace to docids.
///

View File

@@ -1,4 +1,4 @@
use std::io;
use std::io::{self, Write};
use std::sync::Arc;
mod set;
@@ -11,7 +11,7 @@ use set_block::{
};
use crate::iterable::Iterable;
use crate::{DocId, RowId};
use crate::{DocId, InvalidData, RowId};
/// The threshold for for number of elements after which we switch to dense block encoding.
///
@@ -88,7 +88,7 @@ pub struct OptionalIndex {
impl Iterable<u32> for &OptionalIndex {
fn boxed_iter(&self) -> Box<dyn Iterator<Item = u32> + '_> {
Box::new(self.iter_non_null_docs())
Box::new(self.iter_docs())
}
}
@@ -280,9 +280,8 @@ impl OptionalIndex {
self.num_non_null_docs
}
pub fn iter_non_null_docs(&self) -> impl Iterator<Item = RowId> + '_ {
// TODO optimize. We could iterate over the blocks directly.
// We use the dense value ids and retrieve the doc ids via select.
pub fn iter_docs(&self) -> impl Iterator<Item = RowId> + '_ {
// TODO optimize
let mut select_batch = self.select_cursor();
(0..self.num_non_null_docs).map(move |rank| select_batch.select(rank))
}
@@ -335,6 +334,38 @@ enum Block<'a> {
Sparse(SparseBlock<'a>),
}
#[derive(Debug, Copy, Clone)]
enum OptionalIndexCodec {
Dense = 0,
Sparse = 1,
}
impl OptionalIndexCodec {
fn to_code(self) -> u8 {
self as u8
}
fn try_from_code(code: u8) -> Result<Self, InvalidData> {
match code {
0 => Ok(Self::Dense),
1 => Ok(Self::Sparse),
_ => Err(InvalidData),
}
}
}
impl BinarySerializable for OptionalIndexCodec {
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
writer.write_all(&[self.to_code()])
}
fn deserialize<R: io::Read>(reader: &mut R) -> io::Result<Self> {
let optional_codec_code = u8::deserialize(reader)?;
let optional_codec = Self::try_from_code(optional_codec_code)?;
Ok(optional_codec)
}
}
fn serialize_optional_index_block(block_els: &[u16], out: &mut impl io::Write) -> io::Result<()> {
let is_sparse = is_sparse(block_els.len() as u32);
if is_sparse {

View File

@@ -164,11 +164,7 @@ fn test_optional_index_large() {
fn test_optional_index_iter_aux(row_ids: &[RowId], num_rows: RowId) {
let optional_index = OptionalIndex::for_test(num_rows, row_ids);
assert_eq!(optional_index.num_docs(), num_rows);
assert!(
optional_index
.iter_non_null_docs()
.eq(row_ids.iter().copied())
);
assert!(optional_index.iter_docs().eq(row_ids.iter().copied()));
}
#[test]

View File

@@ -1,7 +1,7 @@
use std::fmt::Debug;
use std::net::Ipv6Addr;
/// Monotonic maps a value to u128 value space
/// Montonic maps a value to u128 value space
/// Monotonic mapping enables `PartialOrd` on u128 space without conversion to original space.
pub trait MonotonicallyMappableToU128: 'static + PartialOrd + Copy + Debug + Send + Sync {
/// Converts a value to u128.

View File

@@ -185,10 +185,10 @@ impl CompactSpaceBuilder {
let mut covered_space = Vec::with_capacity(self.blanks.len());
// beginning of the blanks
if let Some(first_blank_start) = self.blanks.first().map(RangeInclusive::start)
&& *first_blank_start != 0
{
covered_space.push(0..=first_blank_start - 1);
if let Some(first_blank_start) = self.blanks.first().map(RangeInclusive::start) {
if *first_blank_start != 0 {
covered_space.push(0..=first_blank_start - 1);
}
}
// Between the blanks
@@ -202,10 +202,10 @@ impl CompactSpaceBuilder {
covered_space.extend(between_blanks);
// end of the blanks
if let Some(last_blank_end) = self.blanks.last().map(RangeInclusive::end)
&& *last_blank_end != u128::MAX
{
covered_space.push(last_blank_end + 1..=u128::MAX);
if let Some(last_blank_end) = self.blanks.last().map(RangeInclusive::end) {
if *last_blank_end != u128::MAX {
covered_space.push(last_blank_end + 1..=u128::MAX);
}
}
if covered_space.is_empty() {

View File

@@ -41,6 +41,12 @@ fn transform_range_before_linear_transformation(
if range.is_empty() {
return None;
}
if stats.min_value > *range.end() {
return None;
}
if stats.max_value < *range.start() {
return None;
}
let shifted_range =
range.start().saturating_sub(stats.min_value)..=range.end().saturating_sub(stats.min_value);
let start_before_gcd_multiplication: u64 = div_ceil(*shifted_range.start(), stats.gcd);
@@ -99,7 +105,7 @@ impl ColumnCodecEstimator for BitpackedCodecEstimator {
fn estimate(&self, stats: &ColumnStats) -> Option<u64> {
let num_bits_per_value = num_bits(stats);
Some(stats.num_bytes() + (stats.num_rows as u64 * (num_bits_per_value as u64)).div_ceil(8))
Some(stats.num_bytes() + (stats.num_rows as u64 * (num_bits_per_value as u64) + 7) / 8)
}
fn serialize(

View File

@@ -8,7 +8,7 @@ use crate::column_values::ColumnValues;
const MID_POINT: u64 = (1u64 << 32) - 1u64;
/// `Line` describes a line function `y: ax + b` using integer
/// arithmetic.
/// arithmetics.
///
/// The slope is in fact a decimal split into a 32 bit integer value,
/// and a 32-bit decimal value.
@@ -94,7 +94,7 @@ impl Line {
// `(i, ys[])`.
//
// The best intercept therefore has the form
// `y[i] - line.eval(i)` (using wrapping arithmetic).
// `y[i] - line.eval(i)` (using wrapping arithmetics).
// In other words, the best intercept is one of the `y - Line::eval(ys[i])`
// and our task is just to pick the one that minimizes our error.
//

View File

@@ -117,7 +117,7 @@ impl ColumnCodecEstimator for LinearCodecEstimator {
Some(
stats.num_bytes()
+ linear_params.num_bytes()
+ (num_bits as u64 * stats.num_rows as u64).div_ceil(8),
+ (num_bits as u64 * stats.num_rows as u64 + 7) / 8,
)
}

View File

@@ -52,7 +52,7 @@ pub trait ColumnCodecEstimator<T = u64>: 'static {
) -> io::Result<()>;
}
/// A column codec describes a column serialization format.
/// A column codec describes a colunm serialization format.
pub trait ColumnCodec<T: PartialOrd = u64> {
/// Specialized `ColumnValues` type.
type ColumnValues: ColumnValues<T> + 'static;

View File

@@ -367,7 +367,7 @@ fn is_empty_after_merge(
ColumnIndex::Empty { .. } => true,
ColumnIndex::Full => alive_bitset.len() == 0,
ColumnIndex::Optional(optional_index) => {
for doc in optional_index.iter_non_null_docs() {
for doc in optional_index.iter_docs() {
if alive_bitset.contains(doc) {
return false;
}

View File

@@ -244,7 +244,7 @@ impl SymbolValue for UnorderedId {
fn compute_num_bytes_for_u64(val: u64) -> usize {
let msb = (64u32 - val.leading_zeros()) as usize;
msb.div_ceil(8)
(msb + 7) / 8
}
fn encode_zig_zag(n: i64) -> u64 {

View File

@@ -3,8 +3,7 @@ use std::sync::Arc;
use std::{fmt, io};
use common::file_slice::FileSlice;
use common::{ByteCount, DateTime, OwnedBytes};
use serde::{Deserialize, Serialize};
use common::{ByteCount, DateTime, HasLen, OwnedBytes};
use crate::column::{BytesColumn, Column, StrColumn};
use crate::column_values::{StrictlyMonotonicFn, monotonic_map_column};
@@ -318,89 +317,10 @@ impl DynamicColumnHandle {
}
pub fn num_bytes(&self) -> ByteCount {
self.file_slice.num_bytes()
}
/// Legacy helper returning the column space usage.
pub fn column_and_dictionary_num_bytes(&self) -> io::Result<ColumnSpaceUsage> {
self.space_usage()
}
/// Return the space usage of the column, optionally broken down by dictionary and column
/// values.
///
/// For dictionary encoded columns (strings and bytes), this splits the total footprint into
/// the dictionary and the remaining column data (including index and values).
/// For all other column types, the dictionary size is `None` and the column size
/// equals the total bytes.
pub fn space_usage(&self) -> io::Result<ColumnSpaceUsage> {
let total_num_bytes = self.num_bytes();
let dynamic_column = self.open()?;
let dictionary_num_bytes = match &dynamic_column {
DynamicColumn::Bytes(bytes_column) => bytes_column.dictionary().num_bytes(),
DynamicColumn::Str(str_column) => str_column.dictionary().num_bytes(),
_ => {
return Ok(ColumnSpaceUsage::new(self.num_bytes(), None));
}
};
assert!(dictionary_num_bytes <= total_num_bytes);
let column_num_bytes =
ByteCount::from(total_num_bytes.get_bytes() - dictionary_num_bytes.get_bytes());
Ok(ColumnSpaceUsage::new(
column_num_bytes,
Some(dictionary_num_bytes),
))
self.file_slice.len().into()
}
pub fn column_type(&self) -> ColumnType {
self.column_type
}
}
/// Represents space usage of a column.
///
/// `column_num_bytes` tracks the column payload (index, values and footer).
/// For dictionary encoded columns, `dictionary_num_bytes` captures the dictionary footprint.
/// [`ColumnSpaceUsage::total_num_bytes`] returns the sum of both parts.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ColumnSpaceUsage {
column_num_bytes: ByteCount,
dictionary_num_bytes: Option<ByteCount>,
}
impl ColumnSpaceUsage {
pub(crate) fn new(
column_num_bytes: ByteCount,
dictionary_num_bytes: Option<ByteCount>,
) -> Self {
ColumnSpaceUsage {
column_num_bytes,
dictionary_num_bytes,
}
}
pub fn column_num_bytes(&self) -> ByteCount {
self.column_num_bytes
}
pub fn dictionary_num_bytes(&self) -> Option<ByteCount> {
self.dictionary_num_bytes
}
pub fn total_num_bytes(&self) -> ByteCount {
self.column_num_bytes + self.dictionary_num_bytes.unwrap_or_default()
}
/// Merge two space usage values by summing their components.
pub fn merge(&self, other: &ColumnSpaceUsage) -> ColumnSpaceUsage {
let dictionary_num_bytes = match (self.dictionary_num_bytes, other.dictionary_num_bytes) {
(Some(lhs), Some(rhs)) => Some(lhs + rhs),
(Some(val), None) | (None, Some(val)) => Some(val),
(None, None) => None,
};
ColumnSpaceUsage {
column_num_bytes: self.column_num_bytes + other.column_num_bytes,
dictionary_num_bytes,
}
}
}

View File

@@ -48,7 +48,7 @@ pub use columnar::{
use sstable::VoidSSTable;
pub use value::{NumericalType, NumericalValue};
pub use self::dynamic_column::{ColumnSpaceUsage, DynamicColumn, DynamicColumnHandle};
pub use self::dynamic_column::{DynamicColumn, DynamicColumnHandle};
pub type RowId = u32;
pub type DocId = u32;

View File

@@ -60,7 +60,7 @@ fn test_dataframe_writer_bool() {
let DynamicColumn::Bool(bool_col) = dyn_bool_col else {
panic!();
};
let vals: Vec<Option<bool>> = (0..5).map(|doc_id| bool_col.first(doc_id)).collect();
let vals: Vec<Option<bool>> = (0..5).map(|row_id| bool_col.first(row_id)).collect();
assert_eq!(&vals, &[None, Some(false), None, Some(true), None,]);
}
@@ -108,7 +108,7 @@ fn test_dataframe_writer_ip_addr() {
let DynamicColumn::IpAddr(ip_col) = dyn_bool_col else {
panic!();
};
let vals: Vec<Option<Ipv6Addr>> = (0..5).map(|doc_id| ip_col.first(doc_id)).collect();
let vals: Vec<Option<Ipv6Addr>> = (0..5).map(|row_id| ip_col.first(row_id)).collect();
assert_eq!(
&vals,
&[
@@ -169,7 +169,7 @@ fn test_dictionary_encoded_str() {
let DynamicColumn::Str(str_col) = col_handles[0].open().unwrap() else {
panic!();
};
let index: Vec<Option<u64>> = (0..5).map(|doc_id| str_col.ords().first(doc_id)).collect();
let index: Vec<Option<u64>> = (0..5).map(|row_id| str_col.ords().first(row_id)).collect();
assert_eq!(index, &[None, Some(0), None, Some(2), Some(1)]);
assert_eq!(str_col.num_rows(), 5);
let mut term_buffer = String::new();
@@ -204,7 +204,7 @@ fn test_dictionary_encoded_bytes() {
panic!();
};
let index: Vec<Option<u64>> = (0..5)
.map(|doc_id| bytes_col.ords().first(doc_id))
.map(|row_id| bytes_col.ords().first(row_id))
.collect();
assert_eq!(index, &[None, Some(0), None, Some(2), Some(1)]);
assert_eq!(bytes_col.num_rows(), 5);

View File

@@ -1,5 +1,3 @@
use std::str::FromStr;
use common::DateTime;
use crate::InvalidData;
@@ -11,23 +9,6 @@ pub enum NumericalValue {
F64(f64),
}
impl FromStr for NumericalValue {
type Err = ();
fn from_str(s: &str) -> Result<Self, ()> {
if let Ok(val_i64) = s.parse::<i64>() {
return Ok(val_i64.into());
}
if let Ok(val_u64) = s.parse::<u64>() {
return Ok(val_u64.into());
}
if let Ok(val_f64) = s.parse::<f64>() {
return Ok(NumericalValue::from(val_f64).normalize());
}
Err(())
}
}
impl NumericalValue {
pub fn numerical_type(&self) -> NumericalType {
match self {
@@ -45,7 +26,7 @@ impl NumericalValue {
if val <= i64::MAX as u64 {
NumericalValue::I64(val as i64)
} else {
NumericalValue::U64(val)
NumericalValue::F64(val as f64)
}
}
NumericalValue::I64(val) => NumericalValue::I64(val),
@@ -160,7 +141,6 @@ impl Coerce for DateTime {
#[cfg(test)]
mod tests {
use super::NumericalType;
use crate::NumericalValue;
#[test]
fn test_numerical_type_code() {
@@ -173,58 +153,4 @@ mod tests {
}
assert_eq!(num_numerical_type, 3);
}
#[test]
fn test_parse_numerical() {
assert_eq!(
"123".parse::<NumericalValue>().unwrap(),
NumericalValue::I64(123)
);
assert_eq!(
"18446744073709551615".parse::<NumericalValue>().unwrap(),
NumericalValue::U64(18446744073709551615u64)
);
assert_eq!(
"1.0".parse::<NumericalValue>().unwrap(),
NumericalValue::I64(1i64)
);
assert_eq!(
"1.1".parse::<NumericalValue>().unwrap(),
NumericalValue::F64(1.1f64)
);
assert_eq!(
"-1.0".parse::<NumericalValue>().unwrap(),
NumericalValue::I64(-1i64)
);
}
#[test]
fn test_normalize_numerical() {
assert_eq!(
NumericalValue::from(1u64).normalize(),
NumericalValue::I64(1i64),
);
let limit_val = i64::MAX as u64 + 1u64;
assert_eq!(
NumericalValue::from(limit_val).normalize(),
NumericalValue::U64(limit_val),
);
assert_eq!(
NumericalValue::from(-1i64).normalize(),
NumericalValue::I64(-1i64),
);
assert_eq!(
NumericalValue::from(-2.0f64).normalize(),
NumericalValue::I64(-2i64),
);
assert_eq!(
NumericalValue::from(-2.1f64).normalize(),
NumericalValue::F64(-2.1f64),
);
let large_float = 2.0f64.powf(70.0f64);
assert_eq!(
NumericalValue::from(large_float).normalize(),
NumericalValue::F64(large_float),
);
}
}

View File

@@ -181,17 +181,9 @@ pub struct BitSet {
len: u64,
max_value: u32,
}
impl std::fmt::Debug for BitSet {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("BitSet")
.field("len", &self.len)
.field("max_value", &self.max_value)
.finish()
}
}
fn num_buckets(max_val: u32) -> u32 {
max_val.div_ceil(64u32)
(max_val + 63u32) / 64u32
}
impl BitSet {

View File

@@ -28,9 +28,7 @@ impl BinarySerializable for VIntU128 {
writer.write_all(&buffer)
}
#[allow(clippy::unbuffered_bytes)]
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
#[allow(clippy::unbuffered_bytes)]
let mut bytes = reader.bytes();
let mut result = 0u128;
let mut shift = 0u64;
@@ -197,9 +195,7 @@ impl BinarySerializable for VInt {
writer.write_all(&buffer[0..num_bytes])
}
#[allow(clippy::unbuffered_bytes)]
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
#[allow(clippy::unbuffered_bytes)]
let mut bytes = reader.bytes();
let mut result = 0u64;
let mut shift = 0u64;

Binary file not shown.

After

Width:  |  Height:  |  Size: 653 KiB

View File

@@ -208,7 +208,7 @@ fn main() -> tantivy::Result<()> {
// is the role of the `TopDocs` collector.
// We can now perform our query.
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
// The actual documents still need to be
// retrieved from Tantivy's store.
@@ -226,7 +226,7 @@ fn main() -> tantivy::Result<()> {
let query = query_parser.parse_query("title:sea^20 body:whale^70")?;
let (_score, doc_address) = searcher
.search(&query, &TopDocs::with_limit(1).order_by_score())?
.search(&query, &TopDocs::with_limit(1))?
.into_iter()
.next()
.unwrap();

View File

@@ -100,7 +100,7 @@ fn main() -> tantivy::Result<()> {
// here we want to get a hit on the 'ken' in Frankenstein
let query = query_parser.parse_query("ken")?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
for (_, doc_address) in top_docs {
let retrieved_doc: TantivyDocument = searcher.doc(doc_address)?;

View File

@@ -50,14 +50,14 @@ fn main() -> tantivy::Result<()> {
{
// Simple exact search on the date
let query = query_parser.parse_query("occurred_at:\"2022-06-22T12:53:50.53Z\"")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(5).order_by_score())?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(5))?;
assert_eq!(count_docs.len(), 1);
}
{
// Range query on the date field
let query = query_parser
.parse_query(r#"occurred_at:[2022-06-22T12:58:00Z TO 2022-06-23T00:00:00Z}"#)?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(4).order_by_score())?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(4))?;
assert_eq!(count_docs.len(), 1);
for (_score, doc_address) in count_docs {
let retrieved_doc = searcher.doc::<TantivyDocument>(doc_address)?;

View File

@@ -28,7 +28,7 @@ fn extract_doc_given_isbn(
// The second argument is here to tell we don't care about decoding positions,
// or term frequencies.
let term_query = TermQuery::new(isbn_term.clone(), IndexRecordOption::Basic);
let top_docs = searcher.search(&term_query, &TopDocs::with_limit(1).order_by_score())?;
let top_docs = searcher.search(&term_query, &TopDocs::with_limit(1))?;
if let Some((_score, doc_address)) = top_docs.first() {
let doc = searcher.doc(*doc_address)?;

View File

@@ -1,212 +0,0 @@
// # Filter Aggregation Example
//
// This example demonstrates filter aggregations - creating buckets of documents
// matching specific queries, with nested aggregations computed on each bucket.
//
// Filter aggregations are useful for computing metrics on different subsets of
// your data in a single query, like "average price overall + average price for
// electronics + count of in-stock items".
use serde_json::json;
use tantivy::aggregation::agg_req::Aggregations;
use tantivy::aggregation::AggregationCollector;
use tantivy::query::AllQuery;
use tantivy::schema::{Schema, FAST, INDEXED, TEXT};
use tantivy::{doc, Index};
fn main() -> tantivy::Result<()> {
// Create a simple product schema
let mut schema_builder = Schema::builder();
schema_builder.add_text_field("category", TEXT | FAST);
schema_builder.add_text_field("brand", TEXT | FAST);
schema_builder.add_u64_field("price", FAST);
schema_builder.add_f64_field("rating", FAST);
schema_builder.add_bool_field("in_stock", FAST | INDEXED);
let schema = schema_builder.build();
// Create index and add sample products
let index = Index::create_in_ram(schema.clone());
let mut writer = index.writer(50_000_000)?;
writer.add_document(doc!(
schema.get_field("category")? => "electronics",
schema.get_field("brand")? => "apple",
schema.get_field("price")? => 999u64,
schema.get_field("rating")? => 4.5f64,
schema.get_field("in_stock")? => true
))?;
writer.add_document(doc!(
schema.get_field("category")? => "electronics",
schema.get_field("brand")? => "samsung",
schema.get_field("price")? => 799u64,
schema.get_field("rating")? => 4.2f64,
schema.get_field("in_stock")? => true
))?;
writer.add_document(doc!(
schema.get_field("category")? => "clothing",
schema.get_field("brand")? => "nike",
schema.get_field("price")? => 120u64,
schema.get_field("rating")? => 4.1f64,
schema.get_field("in_stock")? => false
))?;
writer.add_document(doc!(
schema.get_field("category")? => "books",
schema.get_field("brand")? => "penguin",
schema.get_field("price")? => 25u64,
schema.get_field("rating")? => 4.8f64,
schema.get_field("in_stock")? => true
))?;
writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
// Example 1: Basic filter with metric aggregation
println!("=== Example 1: Electronics average price ===");
let agg_req = json!({
"electronics": {
"filter": "category:electronics",
"aggs": {
"avg_price": { "avg": { "field": "price" } }
}
}
});
let agg: Aggregations = serde_json::from_value(agg_req)?;
let collector = AggregationCollector::from_aggs(agg, Default::default());
let result = searcher.search(&AllQuery, &collector)?;
let expected = json!({
"electronics": {
"doc_count": 2,
"avg_price": { "value": 899.0 }
}
});
assert_eq!(serde_json::to_value(&result)?, expected);
println!("{}\n", serde_json::to_string_pretty(&result)?);
// Example 2: Multiple independent filters
println!("=== Example 2: Multiple filters in one query ===");
let agg_req = json!({
"electronics": {
"filter": "category:electronics",
"aggs": { "avg_price": { "avg": { "field": "price" } } }
},
"in_stock": {
"filter": "in_stock:true",
"aggs": { "count": { "value_count": { "field": "brand" } } }
},
"high_rated": {
"filter": "rating:[4.5 TO *]",
"aggs": { "count": { "value_count": { "field": "brand" } } }
}
});
let agg: Aggregations = serde_json::from_value(agg_req)?;
let collector = AggregationCollector::from_aggs(agg, Default::default());
let result = searcher.search(&AllQuery, &collector)?;
let expected = json!({
"electronics": {
"doc_count": 2,
"avg_price": { "value": 899.0 }
},
"in_stock": {
"doc_count": 3,
"count": { "value": 3.0 }
},
"high_rated": {
"doc_count": 2,
"count": { "value": 2.0 }
}
});
assert_eq!(serde_json::to_value(&result)?, expected);
println!("{}\n", serde_json::to_string_pretty(&result)?);
// Example 3: Nested filters - progressive refinement
println!("=== Example 3: Nested filters ===");
let agg_req = json!({
"in_stock": {
"filter": "in_stock:true",
"aggs": {
"electronics": {
"filter": "category:electronics",
"aggs": {
"expensive": {
"filter": "price:[800 TO *]",
"aggs": {
"avg_rating": { "avg": { "field": "rating" } }
}
}
}
}
}
}
});
let agg: Aggregations = serde_json::from_value(agg_req)?;
let collector = AggregationCollector::from_aggs(agg, Default::default());
let result = searcher.search(&AllQuery, &collector)?;
let expected = json!({
"in_stock": {
"doc_count": 3, // apple, samsung, penguin
"electronics": {
"doc_count": 2, // apple, samsung
"expensive": {
"doc_count": 1, // only apple (999)
"avg_rating": { "value": 4.5 }
}
}
}
});
assert_eq!(serde_json::to_value(&result)?, expected);
println!("{}\n", serde_json::to_string_pretty(&result)?);
// Example 4: Filter with sub-aggregation (terms)
println!("=== Example 4: Filter with terms sub-aggregation ===");
let agg_req = json!({
"electronics": {
"filter": "category:electronics",
"aggs": {
"by_brand": {
"terms": { "field": "brand" },
"aggs": {
"avg_price": { "avg": { "field": "price" } }
}
}
}
}
});
let agg: Aggregations = serde_json::from_value(agg_req)?;
let collector = AggregationCollector::from_aggs(agg, Default::default());
let result = searcher.search(&AllQuery, &collector)?;
let expected = json!({
"electronics": {
"doc_count": 2,
"by_brand": {
"buckets": [
{
"key": "samsung",
"doc_count": 1,
"avg_price": { "value": 799.0 }
},
{
"key": "apple",
"doc_count": 1,
"avg_price": { "value": 999.0 }
}
],
"sum_other_doc_count": 0,
"doc_count_error_upper_bound": 0
}
}
});
assert_eq!(serde_json::to_value(&result)?, expected);
println!("{}", serde_json::to_string_pretty(&result)?);
Ok(())
}

View File

@@ -85,6 +85,7 @@ fn main() -> tantivy::Result<()> {
index_writer.add_document(doc!(
title => "The Diary of a Young Girl",
))?;
index_writer.commit()?;
// ### Committing
//
@@ -145,7 +146,7 @@ fn main() -> tantivy::Result<()> {
let query = FuzzyTermQuery::new(term, 2, true);
let (top_docs, count) = searcher
.search(&query, &(TopDocs::with_limit(5).order_by_score(), Count))
.search(&query, &(TopDocs::with_limit(5), Count))
.unwrap();
assert_eq!(count, 3);
assert_eq!(top_docs.len(), 3);

View File

@@ -69,25 +69,25 @@ fn main() -> tantivy::Result<()> {
{
// Inclusive range queries
let query = query_parser.parse_query("ip:[192.168.0.80 TO 192.168.0.100]")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(5).order_by_score())?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(5))?;
assert_eq!(count_docs.len(), 1);
}
{
// Exclusive range queries
let query = query_parser.parse_query("ip:{192.168.0.80 TO 192.168.1.100]")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
assert_eq!(count_docs.len(), 0);
}
{
// Find docs with IP addresses smaller equal 192.168.1.100
let query = query_parser.parse_query("ip:[* TO 192.168.1.100]")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
assert_eq!(count_docs.len(), 2);
}
{
// Find docs with IP addresses smaller than 192.168.1.100
let query = query_parser.parse_query("ip:[* TO 192.168.1.100}")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
assert_eq!(count_docs.len(), 2);
}

View File

@@ -59,12 +59,12 @@ fn main() -> tantivy::Result<()> {
let query_parser = QueryParser::for_index(&index, vec![event_type, attributes]);
{
let query = query_parser.parse_query("target:submit-button")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
assert_eq!(count_docs.len(), 2);
}
{
let query = query_parser.parse_query("target:submit")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
assert_eq!(count_docs.len(), 2);
}
{
@@ -74,33 +74,33 @@ fn main() -> tantivy::Result<()> {
}
{
let query = query_parser.parse_query("click AND cart.product_id:133")?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
assert_eq!(hits.len(), 1);
}
{
// The sub-fields in the json field marked as default field still need to be explicitly
// addressed
let query = query_parser.parse_query("click AND 133")?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
assert_eq!(hits.len(), 0);
}
{
// Default json fields are ignored if they collide with the schema
let query = query_parser.parse_query("event_type:holiday-sale")?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
assert_eq!(hits.len(), 0);
}
// # Query via full attribute path
{
// This only searches in our schema's `event_type` field
let query = query_parser.parse_query("event_type:click")?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
assert_eq!(hits.len(), 2);
}
{
// Default json fields can still be accessed by full path
let query = query_parser.parse_query("attributes.event_type:holiday-sale")?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
assert_eq!(hits.len(), 1);
}
Ok(())

View File

@@ -63,7 +63,7 @@ fn main() -> Result<()> {
// but not "in the Gulf Stream".
let query = query_parser.parse_query("\"in the su\"*")?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
let mut titles = top_docs
.into_iter()
.map(|(_score, doc_address)| {

View File

@@ -107,8 +107,7 @@ fn main() -> tantivy::Result<()> {
IndexRecordOption::Basic,
);
let (top_docs, count) =
searcher.search(&query, &(TopDocs::with_limit(2).order_by_score(), Count))?;
let (top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2), Count))?;
assert_eq!(count, 2);
@@ -129,8 +128,7 @@ fn main() -> tantivy::Result<()> {
IndexRecordOption::Basic,
);
let (_top_docs, count) =
searcher.search(&query, &(TopDocs::with_limit(2).order_by_score(), Count))?;
let (_top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2), Count))?;
assert_eq!(count, 0);

View File

@@ -50,7 +50,7 @@ fn main() -> tantivy::Result<()> {
let query_parser = QueryParser::for_index(&index, vec![title, body]);
let query = query_parser.parse_query("sycamore spring")?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
let snippet_generator = SnippetGenerator::create(&searcher, &*query, body)?;

View File

@@ -102,7 +102,7 @@ fn main() -> tantivy::Result<()> {
// stop words are applied on the query as well.
// The following will be equivalent to `title:frankenstein`
let query = query_parser.parse_query("title:\"the Frankenstein\"")?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
for (score, doc_address) in top_docs {
let retrieved_doc: TantivyDocument = searcher.doc(doc_address)?;

View File

@@ -164,7 +164,7 @@ fn main() -> tantivy::Result<()> {
move |doc_id: DocId| Reverse(price[doc_id as usize])
};
let most_expensive_first = TopDocs::with_limit(10).order_by(score_by_price);
let most_expensive_first = TopDocs::with_limit(10).custom_score(score_by_price);
let hits = searcher.search(&query, &most_expensive_first)?;
assert_eq!(

View File

@@ -15,5 +15,3 @@ edition = "2024"
nom = "7"
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.140"
ordered-float = "5.0.0"
fnv = "1.0.7"

View File

@@ -117,22 +117,6 @@ where F: nom::Parser<I, (O, ErrorList), Infallible> {
}
}
pub(crate) fn terminated_infallible<I, O1, O2, F, G>(
mut first: F,
mut second: G,
) -> impl FnMut(I) -> JResult<I, O1>
where
F: nom::Parser<I, (O1, ErrorList), Infallible>,
G: nom::Parser<I, (O2, ErrorList), Infallible>,
{
move |input: I| {
let (input, (o1, mut err)) = first.parse(input)?;
let (input, (_, mut err2)) = second.parse(input)?;
err.append(&mut err2);
Ok((input, (o1, err)))
}
}
pub(crate) fn delimited_infallible<I, O1, O2, O3, F, G, H>(
mut first: F,
mut second: G,

View File

@@ -31,17 +31,7 @@ pub fn parse_query_lenient(query: &str) -> (UserInputAst, Vec<LenientError>) {
#[cfg(test)]
mod tests {
use crate::{UserInputAst, parse_query, parse_query_lenient};
#[test]
fn test_deduplication() {
let ast: UserInputAst = parse_query("a a").unwrap();
let json = serde_json::to_string(&ast).unwrap();
assert_eq!(
json,
r#"{"type":"bool","clauses":[[null,{"type":"literal","field_name":null,"phrase":"a","delimiter":"none","slop":0,"prefix":false}]]}"#
);
}
use crate::{parse_query, parse_query_lenient};
#[test]
fn test_parse_query_serialization() {

View File

@@ -1,7 +1,6 @@
use std::borrow::Cow;
use std::iter::once;
use fnv::FnvHashSet;
use nom::IResult;
use nom::branch::alt;
use nom::bytes::complete::tag;
@@ -69,7 +68,7 @@ fn interpret_escape(source: &str) -> String {
/// Consume a word outside of any context.
// TODO should support escape sequences
fn word(inp: &str) -> IResult<&str, Cow<'_, str>> {
fn word(inp: &str) -> IResult<&str, Cow<str>> {
map_res(
recognize(tuple((
alt((
@@ -306,14 +305,15 @@ fn term_group_infallible(inp: &str) -> JResult<&str, UserInputAst> {
let (inp, (field_name, _, _, _)) =
tuple((field_name, multispace0, char('('), multispace0))(inp).expect("precondition failed");
delimited_infallible(
let res = delimited_infallible(
nothing,
map(ast_infallible, |(mut ast, errors)| {
ast.set_default_field(field_name.to_string());
(ast, errors)
}),
opt_i_err(char(')'), "expected ')'"),
)(inp)
)(inp);
res
}
fn exists(inp: &str) -> IResult<&str, UserInputLeaf> {
@@ -367,10 +367,7 @@ fn literal(inp: &str) -> IResult<&str, UserInputAst> {
// something (a field name) got parsed before
alt((
map(
tuple((
opt(field_name),
alt((range, set, exists, regex, term_or_phrase)),
)),
tuple((opt(field_name), alt((range, set, exists, term_or_phrase)))),
|(field_name, leaf): (Option<String>, UserInputLeaf)| leaf.set_field(field_name).into(),
),
term_group,
@@ -392,10 +389,6 @@ fn literal_no_group_infallible(inp: &str) -> JResult<&str, Option<UserInputAst>>
value((), peek(one_of("{[><"))),
map(range_infallible, |(range, errs)| (Some(range), errs)),
),
(
value((), peek(one_of("/"))),
map(regex_infallible, |(regex, errs)| (Some(regex), errs)),
),
),
delimited_infallible(space0_infallible, term_or_phrase_infallible, nothing),
),
@@ -696,61 +689,6 @@ fn set_infallible(mut inp: &str) -> JResult<&str, UserInputLeaf> {
}
}
fn regex(inp: &str) -> IResult<&str, UserInputLeaf> {
map(
terminated(
delimited(
char('/'),
many1(alt((preceded(char('\\'), char('/')), none_of("/")))),
char('/'),
),
peek(alt((multispace1, eof))),
),
|elements| UserInputLeaf::Regex {
field: None,
pattern: elements.into_iter().collect::<String>(),
},
)(inp)
}
fn regex_infallible(inp: &str) -> JResult<&str, UserInputLeaf> {
match terminated_infallible(
delimited_infallible(
opt_i_err(char('/'), "missing delimiter /"),
opt_i(many1(alt((preceded(char('\\'), char('/')), none_of("/"))))),
opt_i_err(char('/'), "missing delimiter /"),
),
opt_i_err(
peek(alt((multispace1, eof))),
"expected whitespace or end of input",
),
)(inp)
{
Ok((rest, (elements_part, errors))) => {
let pattern = match elements_part {
Some(elements_part) => elements_part.into_iter().collect(),
None => String::new(),
};
let res = UserInputLeaf::Regex {
field: None,
pattern,
};
Ok((rest, (res, errors)))
}
Err(e) => {
let errs = vec![LenientErrorInternal {
pos: inp.len(),
message: e.to_string(),
}];
let res = UserInputLeaf::Regex {
field: None,
pattern: String::new(),
};
Ok((inp, (res, errs)))
}
}
}
fn negate(expr: UserInputAst) -> UserInputAst {
expr.unary(Occur::MustNot)
}
@@ -758,17 +696,7 @@ fn negate(expr: UserInputAst) -> UserInputAst {
fn leaf(inp: &str) -> IResult<&str, UserInputAst> {
alt((
delimited(char('('), ast, char(')')),
map(
terminated(
char('*'),
peek(alt((
value((), multispace1),
value((), char(')')),
value((), eof),
))),
),
|_| UserInputAst::from(UserInputLeaf::All),
),
map(char('*'), |_| UserInputAst::from(UserInputLeaf::All)),
map(preceded(tuple((tag("NOT"), multispace1)), leaf), negate),
literal,
))(inp)
@@ -789,17 +717,7 @@ fn leaf_infallible(inp: &str) -> JResult<&str, Option<UserInputAst>> {
),
),
(
value(
(),
terminated(
char('*'),
peek(alt((
value((), multispace1),
value((), char(')')),
value((), eof),
))),
),
),
value((), char('*')),
map(nothing, |_| {
(Some(UserInputAst::from(UserInputLeaf::All)), Vec::new())
}),
@@ -835,7 +753,7 @@ fn boosted_leaf(inp: &str) -> IResult<&str, UserInputAst> {
tuple((leaf, fallible(boost))),
|(leaf, boost_opt)| match boost_opt {
Some(boost) if (boost - 1.0).abs() > f64::EPSILON => {
UserInputAst::Boost(Box::new(leaf), boost.into())
UserInputAst::Boost(Box::new(leaf), boost)
}
_ => leaf,
},
@@ -847,7 +765,7 @@ fn boosted_leaf_infallible(inp: &str) -> JResult<&str, Option<UserInputAst>> {
tuple_infallible((leaf_infallible, boost)),
|((leaf, boost_opt), error)| match boost_opt {
Some(boost) if (boost - 1.0).abs() > f64::EPSILON => (
leaf.map(|leaf| UserInputAst::Boost(Box::new(leaf), boost.into())),
leaf.map(|leaf| UserInputAst::Boost(Box::new(leaf), boost)),
error,
),
_ => (leaf, error),
@@ -1098,25 +1016,12 @@ pub fn parse_to_ast_lenient(query_str: &str) -> (UserInputAst, Vec<LenientError>
(rewrite_ast(res), errors)
}
/// Removes unnecessary children clauses in AST
///
/// Motivated by [issue #1433](https://github.com/quickwit-oss/tantivy/issues/1433)
fn rewrite_ast(mut input: UserInputAst) -> UserInputAst {
if let UserInputAst::Clause(sub_clauses) = &mut input {
// call rewrite_ast recursively on children clauses if applicable
let mut new_clauses = Vec::with_capacity(sub_clauses.len());
for (occur, clause) in sub_clauses.drain(..) {
let rewritten_clause = rewrite_ast(clause);
new_clauses.push((occur, rewritten_clause));
}
*sub_clauses = new_clauses;
// remove duplicate child clauses
// e.g. (+a +b) OR (+c +d) OR (+a +b) => (+a +b) OR (+c +d)
let mut seen = FnvHashSet::default();
sub_clauses.retain(|term| seen.insert(term.clone()));
// Removes unnecessary children clauses in AST
//
// Motivated by [issue #1433](https://github.com/quickwit-oss/tantivy/issues/1433)
for term in sub_clauses {
if let UserInputAst::Clause(terms) = &mut input {
for term in terms {
rewrite_ast_clause(term);
}
}
@@ -1691,21 +1596,6 @@ mod test {
test_parse_query_to_ast_helper("abc:a b", "(*\"abc\":a *b)");
test_parse_query_to_ast_helper("abc:\"a b\"", "\"abc\":\"a b\"");
test_parse_query_to_ast_helper("foo:[1 TO 5]", "\"foo\":[\"1\" TO \"5\"]");
// Phrase prefixed with *
test_parse_query_to_ast_helper("foo:(*A)", "\"foo\":*A");
test_parse_query_to_ast_helper("*A", "*A");
test_parse_query_to_ast_helper("(*A)", "*A");
test_parse_query_to_ast_helper("foo:(A OR B)", "(?\"foo\":A ?\"foo\":B)");
test_parse_query_to_ast_helper("foo:(A* OR B*)", "(?\"foo\":A* ?\"foo\":B*)");
test_parse_query_to_ast_helper("foo:(*A OR *B)", "(?\"foo\":*A ?\"foo\":*B)");
}
#[test]
fn test_parse_query_all() {
test_parse_query_to_ast_helper("*", "*");
test_parse_query_to_ast_helper("(*)", "*");
test_parse_query_to_ast_helper("(* )", "*");
}
#[test]
@@ -1804,63 +1694,6 @@ mod test {
test_is_parse_err(r#"!bc:def"#, "!bc:def");
}
#[test]
fn test_regex_parser() {
let r = parse_to_ast(r#"a:/joh?n(ath[oa]n)/"#);
assert!(r.is_ok(), "Failed to parse custom query: {r:?}");
let (_, input) = r.unwrap();
match input {
UserInputAst::Leaf(leaf) => match leaf.as_ref() {
UserInputLeaf::Regex { field, pattern } => {
assert_eq!(field, &Some("a".to_string()));
assert_eq!(pattern, "joh?n(ath[oa]n)");
}
_ => panic!("Expected a regex leaf, got {leaf:?}"),
},
_ => panic!("Expected a leaf"),
}
let r = parse_to_ast(r#"a:/\\/cgi-bin\\/luci.*/"#);
assert!(r.is_ok(), "Failed to parse custom query: {r:?}");
let (_, input) = r.unwrap();
match input {
UserInputAst::Leaf(leaf) => match leaf.as_ref() {
UserInputLeaf::Regex { field, pattern } => {
assert_eq!(field, &Some("a".to_string()));
assert_eq!(pattern, "\\/cgi-bin\\/luci.*");
}
_ => panic!("Expected a regex leaf, got {leaf:?}"),
},
_ => panic!("Expected a leaf"),
}
}
#[test]
fn test_regex_parser_lenient() {
let literal = |query| literal_infallible(query).unwrap().1;
let (res, errs) = literal(r#"a:/joh?n(ath[oa]n)/"#);
let expected = UserInputLeaf::Regex {
field: Some("a".to_string()),
pattern: "joh?n(ath[oa]n)".to_string(),
}
.into();
assert_eq!(res.unwrap(), expected);
assert!(errs.is_empty(), "Expected no errors, got: {errs:?}");
let (res, errs) = literal("title:/joh?n(ath[oa]n)");
let expected = UserInputLeaf::Regex {
field: Some("title".to_string()),
pattern: "joh?n(ath[oa]n)".to_string(),
}
.into();
assert_eq!(res.unwrap(), expected);
assert_eq!(errs.len(), 1, "Expected 1 error, got: {errs:?}");
assert_eq!(
errs[0].message, "missing delimiter /",
"Unexpected error message",
);
}
#[test]
fn test_space_before_value() {
test_parse_query_to_ast_helper("field : a", r#""field":a"#);

View File

@@ -5,7 +5,7 @@ use serde::Serialize;
use crate::Occur;
#[derive(PartialEq, Eq, Hash, Clone, Serialize)]
#[derive(PartialEq, Clone, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum UserInputLeaf {
@@ -23,10 +23,6 @@ pub enum UserInputLeaf {
Exists {
field: String,
},
Regex {
field: Option<String>,
pattern: String,
},
}
impl UserInputLeaf {
@@ -50,7 +46,6 @@ impl UserInputLeaf {
UserInputLeaf::Exists { field: _ } => UserInputLeaf::Exists {
field: field.expect("Exist query without a field isn't allowed"),
},
UserInputLeaf::Regex { field: _, pattern } => UserInputLeaf::Regex { field, pattern },
}
}
@@ -108,19 +103,11 @@ impl Debug for UserInputLeaf {
UserInputLeaf::Exists { field } => {
write!(formatter, "$exists(\"{field}\")")
}
UserInputLeaf::Regex { field, pattern } => {
if let Some(field) = field {
// TODO properly escape field (in case of \")
write!(formatter, "\"{field}\":")?;
}
// TODO properly escape pattern (in case of \")
write!(formatter, "/{pattern}/")
}
}
}
}
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, Serialize)]
#[derive(Copy, Clone, Eq, PartialEq, Debug, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum Delimiter {
SingleQuotes,
@@ -128,7 +115,7 @@ pub enum Delimiter {
None,
}
#[derive(PartialEq, Eq, Hash, Clone, Serialize)]
#[derive(PartialEq, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
pub struct UserInputLiteral {
pub field_name: Option<String>,
@@ -167,7 +154,7 @@ impl fmt::Debug for UserInputLiteral {
}
}
#[derive(PartialEq, Eq, Hash, Debug, Clone, Serialize)]
#[derive(PartialEq, Debug, Clone, Serialize)]
#[serde(tag = "type", content = "value")]
#[serde(rename_all = "snake_case")]
pub enum UserInputBound {
@@ -204,11 +191,11 @@ impl UserInputBound {
}
}
#[derive(PartialEq, Eq, Hash, Clone, Serialize)]
#[derive(PartialEq, Clone, Serialize)]
#[serde(into = "UserInputAstSerde")]
pub enum UserInputAst {
Clause(Vec<(Option<Occur>, UserInputAst)>),
Boost(Box<UserInputAst>, ordered_float::OrderedFloat<f64>),
Boost(Box<UserInputAst>, f64),
Leaf(Box<UserInputLeaf>),
}
@@ -230,10 +217,9 @@ impl From<UserInputAst> for UserInputAstSerde {
fn from(ast: UserInputAst) -> Self {
match ast {
UserInputAst::Clause(clause) => UserInputAstSerde::Bool { clauses: clause },
UserInputAst::Boost(underlying, boost) => UserInputAstSerde::Boost {
underlying,
boost: boost.into_inner(),
},
UserInputAst::Boost(underlying, boost) => {
UserInputAstSerde::Boost { underlying, boost }
}
UserInputAst::Leaf(leaf) => UserInputAstSerde::Leaf(leaf),
}
}
@@ -392,7 +378,7 @@ mod tests {
#[test]
fn test_boost_serialization() {
let inner_ast = UserInputAst::Leaf(Box::new(UserInputLeaf::All));
let boost_ast = UserInputAst::Boost(Box::new(inner_ast), 2.5.into());
let boost_ast = UserInputAst::Boost(Box::new(inner_ast), 2.5);
let json = serde_json::to_string(&boost_ast).unwrap();
assert_eq!(
json,
@@ -419,7 +405,7 @@ mod tests {
}))),
),
])),
2.5.into(),
2.5,
);
let json = serde_json::to_string(&boost_ast).unwrap();
assert_eq!(

View File

@@ -20,16 +20,17 @@ Contains all metric aggregations, like average aggregation. Metric aggregations
#### agg_req
agg_req contains the users aggregation request. Deserialization from json is compatible with elasticsearch aggregation requests.
#### agg_data
agg_data contains the users aggregation request enriched with fast field accessors etc, which are
#### agg_req_with_accessor
agg_req_with_accessor contains the users aggregation request enriched with fast field accessors etc, which are
used during collection.
#### segment_agg_result
segment_agg_result contains the aggregation result tree, which is used for collection of a segment.
agg_data is passed during collection.
The tree from agg_req_with_accessor is passed during collection.
#### intermediate_agg_result
intermediate_agg_result contains the aggregation tree for merging with other trees.
#### agg_result
agg_result contains the final aggregation tree.

View File

@@ -1,105 +0,0 @@
//! This will enhance the request tree with access to the fastfield and metadata.
use std::io;
use columnar::{Column, ColumnType};
use crate::aggregation::{f64_to_fastfield_u64, Key};
use crate::index::SegmentReader;
/// Get the missing value as internal u64 representation
///
/// For terms we use u64::MAX as sentinel value
/// For numerical data we convert the value into the representation
/// we would get from the fast field, when we open it as u64_lenient_for_type.
///
/// That way we can use it the same way as if it would come from the fastfield.
pub(crate) fn get_missing_val_as_u64_lenient(
column_type: ColumnType,
column_max_value: u64,
missing: &Key,
field_name: &str,
) -> crate::Result<Option<u64>> {
let missing_val = match missing {
Key::Str(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
// Allow fallback to number on text fields
Key::F64(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
Key::U64(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
Key::I64(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
Key::F64(val) if column_type.numerical_type().is_some() => {
f64_to_fastfield_u64(*val, &column_type)
}
// NOTE: We may loose precision of the passed missing value by casting i64 and u64 to f64.
Key::I64(val) if column_type.numerical_type().is_some() => {
f64_to_fastfield_u64(*val as f64, &column_type)
}
Key::U64(val) if column_type.numerical_type().is_some() => {
f64_to_fastfield_u64(*val as f64, &column_type)
}
_ => {
return Err(crate::TantivyError::InvalidArgument(format!(
"Missing value {missing:?} for field {field_name} is not supported for column \
type {column_type:?}"
)));
}
};
Ok(missing_val)
}
pub(crate) fn get_numeric_or_date_column_types() -> &'static [ColumnType] {
&[
ColumnType::F64,
ColumnType::U64,
ColumnType::I64,
ColumnType::DateTime,
]
}
/// Get fast field reader or empty as default.
pub(crate) fn get_ff_reader(
reader: &SegmentReader,
field_name: &str,
allowed_column_types: Option<&[ColumnType]>,
) -> crate::Result<(columnar::Column<u64>, ColumnType)> {
let ff_fields = reader.fast_fields();
let ff_field_with_type = ff_fields
.u64_lenient_for_type(allowed_column_types, field_name)?
.unwrap_or_else(|| {
(
Column::build_empty_column(reader.num_docs()),
ColumnType::U64,
)
});
Ok(ff_field_with_type)
}
pub(crate) fn get_dynamic_columns(
reader: &SegmentReader,
field_name: &str,
) -> crate::Result<Vec<columnar::DynamicColumn>> {
let ff_fields = reader.fast_fields().dynamic_column_handles(field_name)?;
let cols = ff_fields
.iter()
.map(|h| h.open())
.collect::<io::Result<_>>()?;
assert!(!ff_fields.is_empty(), "field {field_name} not found");
Ok(cols)
}
/// Get all fast field reader or empty as default.
///
/// Is guaranteed to return at least one column.
pub(crate) fn get_all_ff_reader_or_empty(
reader: &SegmentReader,
field_name: &str,
allowed_column_types: Option<&[ColumnType]>,
fallback_type: ColumnType,
) -> crate::Result<Vec<(columnar::Column<u64>, ColumnType)>> {
let ff_fields = reader.fast_fields();
let mut ff_field_with_type =
ff_fields.u64_lenient_for_type_all(allowed_column_types, field_name)?;
if ff_field_with_type.is_empty() {
ff_field_with_type.push((Column::build_empty_column(reader.num_docs()), fallback_type));
}
Ok(ff_field_with_type)
}

File diff suppressed because it is too large Load Diff

View File

@@ -35,7 +35,6 @@ pub struct AggregationLimitsGuard {
/// Allocated memory with this guard.
allocated_with_the_guard: u64,
}
impl Clone for AggregationLimitsGuard {
fn clone(&self) -> Self {
Self {
@@ -71,7 +70,7 @@ impl AggregationLimitsGuard {
/// *memory_limit*
/// memory_limit is defined in bytes.
/// Aggregation fails when the estimated memory consumption of the aggregation is higher than
/// memory_limit.
/// memory_limit.
/// memory_limit will default to `DEFAULT_MEMORY_LIMIT` (500MB)
///
/// *bucket_limit*

View File

@@ -26,14 +26,12 @@
//! let _agg_req: Aggregations = serde_json::from_str(elasticsearch_compatible_json_req).unwrap();
//! ```
use std::collections::HashSet;
use std::collections::{HashMap, HashSet};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use super::bucket::{
DateHistogramAggregationReq, FilterAggregation, HistogramAggregation, RangeAggregation,
TermsAggregation,
DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation,
};
use super::metric::{
AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation,
@@ -45,7 +43,7 @@ use super::metric::{
/// defined names. It is also used in buckets aggregations to define sub-aggregations.
///
/// The key is the user defined name of the aggregation.
pub type Aggregations = FxHashMap<String, Aggregation>;
pub type Aggregations = HashMap<String, Aggregation>;
/// Aggregation request.
///
@@ -131,9 +129,6 @@ pub enum AggregationVariants {
/// Put data into buckets of terms.
#[serde(rename = "terms")]
Terms(TermsAggregation),
/// Filter documents into a single bucket.
#[serde(rename = "filter")]
Filter(FilterAggregation),
// Metric aggregation types
/// Computes the average of the extracted values.
@@ -179,7 +174,6 @@ impl AggregationVariants {
AggregationVariants::Range(range) => vec![range.field.as_str()],
AggregationVariants::Histogram(histogram) => vec![histogram.field.as_str()],
AggregationVariants::DateHistogram(histogram) => vec![histogram.field.as_str()],
AggregationVariants::Filter(filter) => filter.get_fast_field_names(),
AggregationVariants::Average(avg) => vec![avg.field_name()],
AggregationVariants::Count(count) => vec![count.field_name()],
AggregationVariants::Max(max) => vec![max.field_name()],
@@ -214,6 +208,13 @@ impl AggregationVariants {
_ => None,
}
}
pub(crate) fn as_top_hits(&self) -> Option<&TopHitsAggregationReq> {
match &self {
AggregationVariants::TopHits(top_hits) => Some(top_hits),
_ => None,
}
}
pub(crate) fn as_percentile(&self) -> Option<&PercentilesAggregationReq> {
match &self {
AggregationVariants::Percentiles(percentile_req) => Some(percentile_req),

View File

@@ -0,0 +1,471 @@
//! This will enhance the request tree with access to the fastfield and metadata.
use std::collections::HashMap;
use std::io;
use columnar::{Column, ColumnBlockAccessor, ColumnType, DynamicColumn, StrColumn};
use super::agg_req::{Aggregation, AggregationVariants, Aggregations};
use super::bucket::{
DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation,
};
use super::metric::{
AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation,
MaxAggregation, MinAggregation, StatsAggregation, SumAggregation,
};
use super::segment_agg_result::AggregationLimitsGuard;
use super::VecWithNames;
use crate::aggregation::{f64_to_fastfield_u64, Key};
use crate::index::SegmentReader;
use crate::SegmentOrdinal;
#[derive(Default)]
pub(crate) struct AggregationsWithAccessor {
pub aggs: VecWithNames<AggregationWithAccessor>,
}
impl AggregationsWithAccessor {
fn from_data(aggs: VecWithNames<AggregationWithAccessor>) -> Self {
Self { aggs }
}
pub fn is_empty(&self) -> bool {
self.aggs.is_empty()
}
}
pub struct AggregationWithAccessor {
pub(crate) segment_ordinal: SegmentOrdinal,
/// In general there can be buckets without fast field access, e.g. buckets that are created
/// based on search terms. That is not that case currently, but eventually this needs to be
/// Option or moved.
pub(crate) accessor: Column<u64>,
/// Load insert u64 for missing use case
pub(crate) missing_value_for_accessor: Option<u64>,
pub(crate) str_dict_column: Option<StrColumn>,
pub(crate) field_type: ColumnType,
pub(crate) sub_aggregation: AggregationsWithAccessor,
pub(crate) limits: AggregationLimitsGuard,
pub(crate) column_block_accessor: ColumnBlockAccessor<u64>,
/// Used for missing term aggregation, which checks all columns for existence.
/// And also for `top_hits` aggregation, which may sort on multiple fields.
/// By convention the missing aggregation is chosen, when this property is set
/// (instead bein set in `agg`).
/// If this needs to used by other aggregations, we need to refactor this.
// NOTE: we can make all other aggregations use this instead of the `accessor` and `field_type`
// (making them obsolete) But will it have a performance impact?
pub(crate) accessors: Vec<(Column<u64>, ColumnType)>,
/// Map field names to all associated column accessors.
/// This field is used for `docvalue_fields`, which is currently only supported for `top_hits`.
pub(crate) value_accessors: HashMap<String, Vec<DynamicColumn>>,
pub(crate) agg: Aggregation,
}
impl AggregationWithAccessor {
/// May return multiple accessors if the aggregation is e.g. on mixed field types.
fn try_from_agg(
agg: &Aggregation,
sub_aggregation: &Aggregations,
reader: &SegmentReader,
segment_ordinal: SegmentOrdinal,
limits: AggregationLimitsGuard,
) -> crate::Result<Vec<AggregationWithAccessor>> {
let mut agg = agg.clone();
let add_agg_with_accessor = |agg: &Aggregation,
accessor: Column<u64>,
column_type: ColumnType,
aggs: &mut Vec<AggregationWithAccessor>|
-> crate::Result<()> {
let res = AggregationWithAccessor {
segment_ordinal,
accessor,
accessors: Default::default(),
value_accessors: Default::default(),
field_type: column_type,
sub_aggregation: get_aggs_with_segment_accessor_and_validate(
sub_aggregation,
reader,
segment_ordinal,
&limits,
)?,
agg: agg.clone(),
limits: limits.clone(),
missing_value_for_accessor: None,
str_dict_column: None,
column_block_accessor: Default::default(),
};
aggs.push(res);
Ok(())
};
let add_agg_with_accessors = |agg: &Aggregation,
accessors: Vec<(Column<u64>, ColumnType)>,
aggs: &mut Vec<AggregationWithAccessor>,
value_accessors: HashMap<String, Vec<DynamicColumn>>|
-> crate::Result<()> {
let (accessor, field_type) = accessors.first().expect("at least one accessor");
let limits = limits.clone();
let res = AggregationWithAccessor {
segment_ordinal,
// TODO: We should do away with the `accessor` field altogether
accessor: accessor.clone(),
value_accessors,
field_type: *field_type,
accessors,
sub_aggregation: get_aggs_with_segment_accessor_and_validate(
sub_aggregation,
reader,
segment_ordinal,
&limits,
)?,
agg: agg.clone(),
limits,
missing_value_for_accessor: None,
str_dict_column: None,
column_block_accessor: Default::default(),
};
aggs.push(res);
Ok(())
};
let mut res: Vec<AggregationWithAccessor> = Vec::new();
use AggregationVariants::*;
match agg.agg {
Range(RangeAggregation {
field: ref field_name,
..
}) => {
let (accessor, column_type) =
get_ff_reader(reader, field_name, Some(get_numeric_or_date_column_types()))?;
add_agg_with_accessor(&agg, accessor, column_type, &mut res)?;
}
Histogram(HistogramAggregation {
field: ref field_name,
..
}) => {
let (accessor, column_type) =
get_ff_reader(reader, field_name, Some(get_numeric_or_date_column_types()))?;
add_agg_with_accessor(&agg, accessor, column_type, &mut res)?;
}
DateHistogram(DateHistogramAggregationReq {
field: ref field_name,
..
}) => {
let (accessor, column_type) =
// Only DateTime is supported for DateHistogram
get_ff_reader(reader, field_name, Some(&[ColumnType::DateTime]))?;
add_agg_with_accessor(&agg, accessor, column_type, &mut res)?;
}
Terms(TermsAggregation {
field: ref field_name,
ref missing,
..
})
| Cardinality(CardinalityAggregationReq {
field: ref field_name,
ref missing,
..
}) => {
let str_dict_column = reader.fast_fields().str(field_name)?;
let allowed_column_types = [
ColumnType::I64,
ColumnType::U64,
ColumnType::F64,
ColumnType::Str,
ColumnType::DateTime,
ColumnType::Bool,
ColumnType::IpAddr,
// ColumnType::Bytes Unsupported
];
// In case the column is empty we want the shim column to match the missing type
let fallback_type = missing
.as_ref()
.map(|missing| match missing {
Key::Str(_) => ColumnType::Str,
Key::F64(_) => ColumnType::F64,
Key::I64(_) => ColumnType::I64,
Key::U64(_) => ColumnType::U64,
})
.unwrap_or(ColumnType::U64);
let column_and_types = get_all_ff_reader_or_empty(
reader,
field_name,
Some(&allowed_column_types),
fallback_type,
)?;
let missing_and_more_than_one_col = column_and_types.len() > 1 && missing.is_some();
let text_on_non_text_col = column_and_types.len() == 1
&& column_and_types[0].1.numerical_type().is_some()
&& missing
.as_ref()
.map(|m| matches!(m, Key::Str(_)))
.unwrap_or(false);
// Actually we could convert the text to a number and have the fast path, if it is
// provided in Rfc3339 format. But this use case is probably common
// enough to justify the effort.
let text_on_date_col = column_and_types.len() == 1
&& column_and_types[0].1 == ColumnType::DateTime
&& missing
.as_ref()
.map(|m| matches!(m, Key::Str(_)))
.unwrap_or(false);
let use_special_missing_agg =
missing_and_more_than_one_col || text_on_non_text_col || text_on_date_col;
if use_special_missing_agg {
let column_and_types =
get_all_ff_reader_or_empty(reader, field_name, None, fallback_type)?;
let accessors = column_and_types
.iter()
.map(|c_t| (c_t.0.clone(), c_t.1))
.collect();
add_agg_with_accessors(&agg, accessors, &mut res, Default::default())?;
}
for (accessor, column_type) in column_and_types {
let missing_value_term_agg = if use_special_missing_agg {
None
} else {
missing.clone()
};
let missing_value_for_accessor =
if let Some(missing) = missing_value_term_agg.as_ref() {
get_missing_val_as_u64_lenient(
column_type,
missing,
agg.agg.get_fast_field_names()[0],
)?
} else {
None
};
let limits = limits.clone();
let agg = AggregationWithAccessor {
segment_ordinal,
missing_value_for_accessor,
accessor,
accessors: Default::default(),
value_accessors: Default::default(),
field_type: column_type,
sub_aggregation: get_aggs_with_segment_accessor_and_validate(
sub_aggregation,
reader,
segment_ordinal,
&limits,
)?,
agg: agg.clone(),
str_dict_column: str_dict_column.clone(),
limits,
column_block_accessor: Default::default(),
};
res.push(agg);
}
}
Average(AverageAggregation {
field: ref field_name,
..
})
| Max(MaxAggregation {
field: ref field_name,
..
})
| Min(MinAggregation {
field: ref field_name,
..
})
| Stats(StatsAggregation {
field: ref field_name,
..
})
| ExtendedStats(ExtendedStatsAggregation {
field: ref field_name,
..
})
| Sum(SumAggregation {
field: ref field_name,
..
}) => {
let (accessor, column_type) =
get_ff_reader(reader, field_name, Some(get_numeric_or_date_column_types()))?;
add_agg_with_accessor(&agg, accessor, column_type, &mut res)?;
}
Count(CountAggregation {
field: ref field_name,
..
}) => {
let allowed_column_types = [
ColumnType::I64,
ColumnType::U64,
ColumnType::F64,
ColumnType::Str,
ColumnType::DateTime,
ColumnType::Bool,
ColumnType::IpAddr,
// ColumnType::Bytes Unsupported
];
let (accessor, column_type) =
get_ff_reader(reader, field_name, Some(&allowed_column_types))?;
add_agg_with_accessor(&agg, accessor, column_type, &mut res)?;
}
Percentiles(ref percentiles) => {
let (accessor, column_type) = get_ff_reader(
reader,
percentiles.field_name(),
Some(get_numeric_or_date_column_types()),
)?;
add_agg_with_accessor(&agg, accessor, column_type, &mut res)?;
}
TopHits(ref mut top_hits) => {
top_hits.validate_and_resolve_field_names(reader.fast_fields().columnar())?;
let accessors: Vec<(Column<u64>, ColumnType)> = top_hits
.field_names()
.iter()
.map(|field| {
get_ff_reader(reader, field, Some(get_numeric_or_date_column_types()))
})
.collect::<crate::Result<_>>()?;
let value_accessors = top_hits
.value_field_names()
.iter()
.map(|field_name| {
Ok((
field_name.to_string(),
get_dynamic_columns(reader, field_name)?,
))
})
.collect::<crate::Result<_>>()?;
add_agg_with_accessors(&agg, accessors, &mut res, value_accessors)?;
}
};
Ok(res)
}
}
/// Get the missing value as internal u64 representation
///
/// For terms we use u64::MAX as sentinel value
/// For numerical data we convert the value into the representation
/// we would get from the fast field, when we open it as u64_lenient_for_type.
///
/// That way we can use it the same way as if it would come from the fastfield.
fn get_missing_val_as_u64_lenient(
column_type: ColumnType,
missing: &Key,
field_name: &str,
) -> crate::Result<Option<u64>> {
let missing_val = match missing {
Key::Str(_) if column_type == ColumnType::Str => Some(u64::MAX),
// Allow fallback to number on text fields
Key::F64(_) if column_type == ColumnType::Str => Some(u64::MAX),
Key::U64(_) if column_type == ColumnType::Str => Some(u64::MAX),
Key::I64(_) if column_type == ColumnType::Str => Some(u64::MAX),
Key::F64(val) if column_type.numerical_type().is_some() => {
f64_to_fastfield_u64(*val, &column_type)
}
// NOTE: We may loose precision of the passed missing value by casting i64 and u64 to f64.
Key::I64(val) if column_type.numerical_type().is_some() => {
f64_to_fastfield_u64(*val as f64, &column_type)
}
Key::U64(val) if column_type.numerical_type().is_some() => {
f64_to_fastfield_u64(*val as f64, &column_type)
}
_ => {
return Err(crate::TantivyError::InvalidArgument(format!(
"Missing value {missing:?} for field {field_name} is not supported for column \
type {column_type:?}"
)));
}
};
Ok(missing_val)
}
fn get_numeric_or_date_column_types() -> &'static [ColumnType] {
&[
ColumnType::F64,
ColumnType::U64,
ColumnType::I64,
ColumnType::DateTime,
]
}
pub(crate) fn get_aggs_with_segment_accessor_and_validate(
aggs: &Aggregations,
reader: &SegmentReader,
segment_ordinal: SegmentOrdinal,
limits: &AggregationLimitsGuard,
) -> crate::Result<AggregationsWithAccessor> {
let mut aggss = Vec::new();
for (key, agg) in aggs.iter() {
let aggs = AggregationWithAccessor::try_from_agg(
agg,
agg.sub_aggregation(),
reader,
segment_ordinal,
limits.clone(),
)?;
for agg in aggs {
aggss.push((key.to_string(), agg));
}
}
Ok(AggregationsWithAccessor::from_data(
VecWithNames::from_entries(aggss),
))
}
/// Get fast field reader or empty as default.
fn get_ff_reader(
reader: &SegmentReader,
field_name: &str,
allowed_column_types: Option<&[ColumnType]>,
) -> crate::Result<(columnar::Column<u64>, ColumnType)> {
let ff_fields = reader.fast_fields();
let ff_field_with_type = ff_fields
.u64_lenient_for_type(allowed_column_types, field_name)?
.unwrap_or_else(|| {
(
Column::build_empty_column(reader.num_docs()),
ColumnType::U64,
)
});
Ok(ff_field_with_type)
}
fn get_dynamic_columns(
reader: &SegmentReader,
field_name: &str,
) -> crate::Result<Vec<columnar::DynamicColumn>> {
let ff_fields = reader.fast_fields().dynamic_column_handles(field_name)?;
let cols = ff_fields
.iter()
.map(|h| h.open())
.collect::<io::Result<_>>()?;
assert!(!ff_fields.is_empty(), "field {field_name} not found");
Ok(cols)
}
/// Get all fast field reader or empty as default.
///
/// Is guaranteed to return at least one column.
fn get_all_ff_reader_or_empty(
reader: &SegmentReader,
field_name: &str,
allowed_column_types: Option<&[ColumnType]>,
fallback_type: ColumnType,
) -> crate::Result<Vec<(columnar::Column<u64>, ColumnType)>> {
let ff_fields = reader.fast_fields();
let mut ff_field_with_type =
ff_fields.u64_lenient_for_type_all(allowed_column_types, field_name)?;
if ff_field_with_type.is_empty() {
ff_field_with_type.push((Column::build_empty_column(reader.num_docs()), fallback_type));
}
Ok(ff_field_with_type)
}

View File

@@ -16,7 +16,7 @@ use super::{AggregationError, Key};
use crate::TantivyError;
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
/// The final aggregation result.
/// The final aggegation result.
pub struct AggregationResults(pub FxHashMap<String, AggregationResult>);
impl AggregationResults {
@@ -156,8 +156,6 @@ pub enum BucketResult {
/// The upper bound error for the doc count of each term.
doc_count_error_upper_bound: Option<u64>,
},
/// This is the filter result - a single bucket with sub-aggregations
Filter(FilterBucketResult),
}
impl BucketResult {
@@ -174,11 +172,6 @@ impl BucketResult {
sum_other_doc_count: _,
doc_count_error_upper_bound: _,
} => buckets.iter().map(|bucket| bucket.get_bucket_count()).sum(),
BucketResult::Filter(filter_result) => {
// Filter doesn't add to bucket count - it's not a user-facing bucket
// Only count sub-aggregation buckets
filter_result.sub_aggregations.get_bucket_count()
}
}
}
}
@@ -315,25 +308,3 @@ impl RangeBucketEntry {
1 + self.sub_aggregation.get_bucket_count()
}
}
/// This is the filter bucket result, which contains the document count and sub-aggregations.
///
/// # JSON Format
/// ```json
/// {
/// "electronics_only": {
/// "doc_count": 2,
/// "avg_price": {
/// "value": 150.0
/// }
/// }
/// }
/// ```
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct FilterBucketResult {
/// Number of documents in the filter bucket
pub doc_count: u64,
/// Sub-aggregation results
#[serde(flatten)]
pub sub_aggregations: AggregationResults,
}

View File

@@ -2,441 +2,16 @@ use serde_json::Value;
use crate::aggregation::agg_req::{Aggregation, Aggregations};
use crate::aggregation::agg_result::AggregationResults;
use crate::aggregation::buf_collector::DOC_BLOCK_SIZE;
use crate::aggregation::collector::AggregationCollector;
use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults;
use crate::aggregation::segment_agg_result::AggregationLimitsGuard;
use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms};
use crate::aggregation::DistributedAggregationCollector;
use crate::docset::COLLECT_BLOCK_BUFFER_LEN;
use crate::query::{AllQuery, TermQuery};
use crate::schema::{IndexRecordOption, Schema, FAST};
use crate::{Index, IndexWriter, Term};
// The following tests ensure that each bucket aggregation type correctly functions as a
// sub-aggregation of another bucket aggregation in two scenarios:
// 1) The parent has more buckets than the child sub-aggregation
// 2) The child sub-aggregation has more buckets than the parent
//
// These scenarios exercise the bucket id mapping and sub-aggregation routing logic.
#[test]
fn test_terms_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with 4 buckets
// Child: terms on text -> 2 buckets
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
// Exact expected structure and counts
assert_eq!(
res["parent_range"]["buckets"],
json!([
{
"key": "*-3",
"doc_count": 1,
"to": 3.0,
"child_terms": {
"buckets": [
{"doc_count": 1, "key": "cool"}
],
"sum_other_doc_count": 0
}
},
{
"key": "3-7",
"doc_count": 3,
"from": 3.0,
"to": 7.0,
"child_terms": {
"buckets": [
{"doc_count": 2, "key": "cool"},
{"doc_count": 1, "key": "nohit"}
],
"sum_other_doc_count": 0
}
},
{
"key": "7-20",
"doc_count": 3,
"from": 7.0,
"to": 20.0,
"child_terms": {
"buckets": [
{"doc_count": 3, "key": "cool"}
],
"sum_other_doc_count": 0
}
},
{
"key": "20-*",
"doc_count": 2,
"from": 20.0,
"child_terms": {
"buckets": [
{"doc_count": 1, "key": "cool"},
{"doc_count": 1, "key": "nohit"}
],
"sum_other_doc_count": 0
}
}
])
);
// Case B: child has more buckets than parent
// Parent: histogram on score with large interval -> 1 bucket
// Child: terms on text -> 2 buckets (cool/nohit)
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_hist": {
"histogram": {"field": "score", "interval": 100.0},
"aggs": {
"child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
assert_eq!(
res["parent_hist"],
json!({
"buckets": [
{
"key": 0.0,
"doc_count": 9,
"child_terms": {
"buckets": [
{"doc_count": 7, "key": "cool"},
{"doc_count": 2, "key": "nohit"}
],
"sum_other_doc_count": 0
}
}
]
})
);
Ok(())
}
#[test]
fn test_range_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with 5 buckets
// Child: coarse range with 3 buckets
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 11.0},
{"from": 11.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 20.0}
]
}
}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
assert_eq!(
res["parent_range"]["buckets"],
json!([
{"key": "*-3", "doc_count": 1, "to": 3.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 1, "to": 3.0},
{"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 3, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 1, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 2, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "20-*", "doc_count": 2, "from": 20.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 2, "from": 20.0}
]}
}
])
);
// Case B: child has more buckets than parent
// Parent: terms on text (2 buckets)
// Child: range with 4 buckets
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_terms": {
"terms": {"field": "text"},
"aggs": {
"child_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 20.0}
]
}
}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
assert_eq!(
res["parent_terms"],
json!({
"buckets": [
{
"key": "cool",
"doc_count": 7,
"child_range": {
"buckets": [
{"key": "*-3", "doc_count": 1, "to": 3.0},
{"key": "3-7", "doc_count": 2, "from": 3.0, "to": 7.0},
{"key": "7-20", "doc_count": 3, "from": 7.0, "to": 20.0},
{"key": "20-*", "doc_count": 1, "from": 20.0}
]
}
},
{
"key": "nohit",
"doc_count": 2,
"child_range": {
"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-7", "doc_count": 1, "from": 3.0, "to": 7.0},
{"key": "7-20", "doc_count": 0, "from": 7.0, "to": 20.0},
{"key": "20-*", "doc_count": 1, "from": 20.0}
]
}
}
],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 0
})
);
Ok(())
}
#[test]
fn test_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with several ranges
// Child: histogram with large interval (single bucket per parent)
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 11.0},
{"from": 11.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_hist": {"histogram": {"field": "score", "interval": 100.0}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
assert_eq!(
res["parent_range"]["buckets"],
json!([
{"key": "*-3", "doc_count": 1, "to": 3.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]}
},
{"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 3} ]}
},
{"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]}
},
{"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]}
},
{"key": "20-*", "doc_count": 2, "from": 20.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]}
}
])
);
// Case B: child has more buckets than parent
// Parent: terms on text -> 2 buckets
// Child: histogram with small interval -> multiple buckets including empties
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_terms": {
"terms": {"field": "text"},
"aggs": {
"child_hist": {"histogram": {"field": "score", "interval": 10.0}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
assert_eq!(
res["parent_terms"],
json!({
"buckets": [
{
"key": "cool",
"doc_count": 7,
"child_hist": {
"buckets": [
{"key": 0.0, "doc_count": 4},
{"key": 10.0, "doc_count": 2},
{"key": 20.0, "doc_count": 0},
{"key": 30.0, "doc_count": 0},
{"key": 40.0, "doc_count": 1}
]
}
},
{
"key": "nohit",
"doc_count": 2,
"child_hist": {
"buckets": [
{"key": 0.0, "doc_count": 1},
{"key": 10.0, "doc_count": 0},
{"key": 20.0, "doc_count": 0},
{"key": 30.0, "doc_count": 0},
{"key": 40.0, "doc_count": 1}
]
}
}
],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 0
})
);
Ok(())
}
#[test]
fn test_date_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with several buckets
// Child: date_histogram with 30d -> single bucket per parent
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 11.0},
{"from": 11.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "30d"}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
let buckets = res["parent_range"]["buckets"].as_array().unwrap();
// Verify each parent bucket has exactly one child date bucket with matching doc_count
for bucket in buckets {
let parent_count = bucket["doc_count"].as_u64().unwrap();
let child_buckets = bucket["child_date_hist"]["buckets"].as_array().unwrap();
assert_eq!(child_buckets.len(), 1);
assert_eq!(child_buckets[0]["doc_count"], parent_count);
}
// Case B: child has more buckets than parent
// Parent: terms on text (2 buckets)
// Child: date_histogram with 1d -> multiple buckets
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_terms": {
"terms": {"field": "text"},
"aggs": {
"child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "1d"}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
let buckets = res["parent_terms"]["buckets"].as_array().unwrap();
// cool bucket
assert_eq!(buckets[0]["key"], "cool");
let cool_buckets = buckets[0]["child_date_hist"]["buckets"].as_array().unwrap();
assert_eq!(cool_buckets.len(), 3);
assert_eq!(cool_buckets[0]["doc_count"], 1); // day 0
assert_eq!(cool_buckets[1]["doc_count"], 4); // day 1
assert_eq!(cool_buckets[2]["doc_count"], 2); // day 2
// nohit bucket
assert_eq!(buckets[1]["key"], "nohit");
let nohit_buckets = buckets[1]["child_date_hist"]["buckets"].as_array().unwrap();
assert_eq!(nohit_buckets.len(), 2);
assert_eq!(nohit_buckets[0]["doc_count"], 1); // day 1
assert_eq!(nohit_buckets[1]["doc_count"], 1); // day 2
Ok(())
}
fn get_avg_req(field_name: &str) -> Aggregation {
serde_json::from_value(json!({
"avg": {
@@ -451,10 +26,6 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector {
}
// *** EVERY BUCKET-TYPE SHOULD BE TESTED HERE ***
// Note: The flushng part of these tests are outdated, since the buffering change after converting
// the collection into one collector per request instead of per bucket.
//
// However they are useful as they test a complex aggregation requests.
fn test_aggregation_flushing(
merge_segments: bool,
use_distributed_collector: bool,
@@ -467,9 +38,8 @@ fn test_aggregation_flushing(
let reader = index.reader()?;
assert_eq!(COLLECT_BLOCK_BUFFER_LEN, 64);
// In the tree we cache documents of COLLECT_BLOCK_BUFFER_LEN before passing them down as one
// block.
assert_eq!(DOC_BLOCK_SIZE, 64);
// In the tree we cache Documents of DOC_BLOCK_SIZE, before passing them down as one block.
//
// Build a request so that on the first level we have one full cache, which is then flushed.
// The same cache should have some residue docs at the end, which are flushed (Range 0-70)
@@ -558,8 +128,10 @@ fn test_aggregation_flushing(
.unwrap();
let agg_res: AggregationResults = if use_distributed_collector {
let collector =
DistributedAggregationCollector::from_aggs(agg_req.clone(), Default::default());
let collector = DistributedAggregationCollector::from_aggs(
agg_req.clone(),
AggregationLimitsGuard::default(),
);
let searcher = reader.searcher();
let intermediate_agg_result = searcher.search(&AllQuery, &collector).unwrap();

File diff suppressed because it is too large Load Diff

View File

@@ -1,49 +1,25 @@
use std::cmp::Ordering;
use columnar::{Column, ColumnType};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use tantivy_bitpacker::minmax;
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::agg_limits::MemoryConsumption;
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::agg_req_with_accessor::{
AggregationWithAccessor, AggregationsWithAccessor,
};
use crate::aggregation::agg_result::BucketEntry;
use crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardCachedSubAggs};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateHistogramBucketEntry,
};
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::segment_agg_result::{
build_segment_agg_collector, SegmentAggregationCollector,
};
use crate::aggregation::*;
use crate::TantivyError;
/// Contains all information required by the SegmentHistogramCollector to perform the
/// histogram or date_histogram aggregation on a segment.
pub struct HistogramAggReqData {
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
/// The field type of the fast field.
pub field_type: ColumnType,
/// The name of the aggregation.
pub name: String,
/// The histogram aggregation request.
pub req: HistogramAggregation,
/// True if this is a date_histogram aggregation.
pub is_date_histogram: bool,
/// The bounds to limit the buckets to.
pub bounds: HistogramBounds,
/// The offset used to calculate the bucket position.
pub offset: f64,
}
impl HistogramAggReqData {
/// Estimate the memory consumption of this struct in bytes.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
}
}
/// Histogram is a bucket aggregation, where buckets are created dynamically for given `interval`.
/// Each document value is rounded down to its bucket.
///
@@ -252,24 +228,18 @@ impl HistogramBounds {
pub(crate) struct SegmentHistogramBucketEntry {
pub key: f64,
pub doc_count: u64,
pub bucket_id: BucketId,
}
impl SegmentHistogramBucketEntry {
pub(crate) fn into_intermediate_bucket_entry(
self,
sub_aggregation: &mut Option<HighCardCachedSubAggs>,
agg_data: &AggregationsSegmentCtx,
sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
agg_with_accessor: &AggregationsWithAccessor,
) -> crate::Result<IntermediateHistogramBucketEntry> {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
if let Some(sub_aggregation) = sub_aggregation {
sub_aggregation
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut sub_aggregation_res,
self.bucket_id,
)?;
.add_intermediate_aggregation_result(agg_with_accessor, &mut sub_aggregation_res)?;
}
Ok(IntermediateHistogramBucketEntry {
key: self.key,
@@ -279,38 +249,31 @@ impl SegmentHistogramBucketEntry {
}
}
#[derive(Clone, Debug, Default)]
struct HistogramBuckets {
pub buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
}
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct SegmentHistogramCollector {
/// The buckets containing the aggregation data.
/// One Histogram bucket per parent bucket id.
parent_buckets: Vec<HistogramBuckets>,
sub_agg: Option<HighCardCachedSubAggs>,
buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
sub_aggregations: FxHashMap<i64, Box<dyn SegmentAggregationCollector>>,
sub_aggregation_blueprint: Option<Box<dyn SegmentAggregationCollector>>,
column_type: ColumnType,
interval: f64,
offset: f64,
bounds: HistogramBounds,
accessor_idx: usize,
bucket_id_provider: BucketIdProvider,
}
impl SegmentAggregationCollector for SegmentHistogramCollector {
fn add_intermediate_aggregation_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_data
.get_histogram_req_data(self.accessor_idx)
.name
.clone();
// TODO: avoid prepare_max_bucket here and handle empty buckets.
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let histogram = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);
let bucket = self.add_intermediate_bucket_result(agg_data, histogram)?;
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let agg_with_accessor = &agg_with_accessor.aggs.values[self.accessor_idx];
let bucket = self.into_intermediate_bucket_result(agg_with_accessor)?;
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
Ok(())
@@ -319,77 +282,72 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
#[inline]
fn collect(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
let req = agg_data.take_histogram_req_data(self.accessor_idx);
self.collect_block(&[doc], agg_with_accessor)
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx];
let mem_pre = self.get_memory_consumption();
let buckets = &mut self.parent_buckets[parent_bucket_id as usize].buckets;
let bounds = req.bounds;
let interval = req.req.interval;
let offset = req.offset;
let get_bucket_pos = |val| get_bucket_pos_f64(val, interval, offset) as i64;
let bounds = self.bounds;
let interval = self.interval;
let offset = self.offset;
let get_bucket_pos = |val| (get_bucket_pos_f64(val, interval, offset) as i64);
agg_data
bucket_agg_accessor
.column_block_accessor
.fetch_block(docs, &req.accessor);
for (doc, val) in agg_data
.fetch_block(docs, &bucket_agg_accessor.accessor);
for (doc, val) in bucket_agg_accessor
.column_block_accessor
.iter_docid_vals(docs, &req.accessor)
.iter_docid_vals(docs, &bucket_agg_accessor.accessor)
{
let val = f64_from_fastfield_u64(val, req.field_type);
let val = self.f64_from_fastfield_u64(val);
let bucket_pos = get_bucket_pos(val);
if bounds.contains(val) {
let bucket = buckets.entry(bucket_pos).or_insert_with(|| {
let bucket = self.buckets.entry(bucket_pos).or_insert_with(|| {
let key = get_bucket_key_from_pos(bucket_pos as f64, interval, offset);
SegmentHistogramBucketEntry {
key,
doc_count: 0,
bucket_id: self.bucket_id_provider.next_bucket_id(),
}
SegmentHistogramBucketEntry { key, doc_count: 0 }
});
bucket.doc_count += 1;
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.push(bucket.bucket_id, doc);
if let Some(sub_aggregation_blueprint) = self.sub_aggregation_blueprint.as_mut() {
self.sub_aggregations
.entry(bucket_pos)
.or_insert_with(|| sub_aggregation_blueprint.clone())
.collect(doc, &mut bucket_agg_accessor.sub_aggregation)?;
}
}
}
agg_data.put_back_histogram_req_data(self.accessor_idx, req);
let mem_delta = self.get_memory_consumption() - mem_pre;
if mem_delta > 0 {
agg_data
.context
bucket_agg_accessor
.limits
.add_memory_consumed(mem_delta as u64)?;
}
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.check_flush_local(agg_data)?;
}
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if let Some(sub_aggregation) = &mut self.sub_agg {
sub_aggregation.flush(agg_data)?;
}
Ok(())
}
fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
let sub_aggregation_accessor =
&mut agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation;
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
self.parent_buckets.push(HistogramBuckets {
buckets: FxHashMap::default(),
});
for sub_aggregation in self.sub_aggregations.values_mut() {
sub_aggregation.flush(sub_aggregation_accessor)?;
}
Ok(())
}
}
@@ -397,62 +355,72 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
impl SegmentHistogramCollector {
fn get_memory_consumption(&self) -> usize {
let self_mem = std::mem::size_of::<Self>();
let buckets_mem = self.parent_buckets.len() * std::mem::size_of::<HistogramBuckets>();
self_mem + buckets_mem
let sub_aggs_mem = self.sub_aggregations.memory_consumption();
let buckets_mem = self.buckets.memory_consumption();
self_mem + sub_aggs_mem + buckets_mem
}
/// Converts the collector result into a intermediate bucket result.
fn add_intermediate_bucket_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
histogram: HistogramBuckets,
pub fn into_intermediate_bucket_result(
self,
agg_with_accessor: &AggregationWithAccessor,
) -> crate::Result<IntermediateBucketResult> {
let mut buckets = Vec::with_capacity(histogram.buckets.len());
let mut buckets = Vec::with_capacity(self.buckets.len());
for bucket in histogram.buckets.into_values() {
let bucket_res = bucket.into_intermediate_bucket_entry(&mut self.sub_agg, agg_data);
for (bucket_pos, bucket) in self.buckets {
let bucket_res = bucket.into_intermediate_bucket_entry(
self.sub_aggregations.get(&bucket_pos).cloned(),
&agg_with_accessor.sub_aggregation,
);
buckets.push(bucket_res?);
}
buckets.sort_unstable_by(|b1, b2| b1.key.total_cmp(&b2.key));
let is_date_agg = agg_data
.get_histogram_req_data(self.accessor_idx)
.field_type
== ColumnType::DateTime;
Ok(IntermediateBucketResult::Histogram {
buckets,
is_date_agg,
is_date_agg: self.column_type == ColumnType::DateTime,
})
}
pub(crate) fn from_req_and_validate(
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
mut req: HistogramAggregation,
sub_aggregation: &mut AggregationsWithAccessor,
field_type: ColumnType,
accessor_idx: usize,
) -> crate::Result<Self> {
let sub_agg = if !node.children.is_empty() {
Some(build_segment_agg_collectors(agg_data, &node.children)?)
} else {
None
};
let req_data = agg_data.get_histogram_req_data_mut(node.idx_in_req_data);
req_data.req.validate()?;
if req_data.field_type == ColumnType::DateTime && !req_data.is_date_histogram {
req_data.req.normalize_date_time();
req.validate()?;
if field_type == ColumnType::DateTime {
req.normalize_date_time();
}
req_data.bounds = req_data.req.hard_bounds.unwrap_or(HistogramBounds {
let sub_aggregation_blueprint = if sub_aggregation.is_empty() {
None
} else {
let sub_aggregation = build_segment_agg_collector(sub_aggregation)?;
Some(sub_aggregation)
};
let bounds = req.hard_bounds.unwrap_or(HistogramBounds {
min: f64::MIN,
max: f64::MAX,
});
req_data.offset = req_data.req.offset.unwrap_or(0.0);
let sub_agg = sub_agg.map(CachedSubAggs::new);
Ok(Self {
parent_buckets: Default::default(),
sub_agg,
accessor_idx: node.idx_in_req_data,
bucket_id_provider: BucketIdProvider::default(),
buckets: Default::default(),
column_type: field_type,
interval: req.interval,
offset: req.offset.unwrap_or(0.0),
bounds,
sub_aggregations: Default::default(),
sub_aggregation_blueprint,
accessor_idx,
})
}
#[inline]
fn f64_from_fastfield_u64(&self, val: u64) -> f64 {
f64_from_fastfield_u64(val, &self.column_type)
}
}
#[inline]

View File

@@ -22,7 +22,6 @@
//! - [Range](RangeAggregation)
//! - [Terms](TermsAggregation)
mod filter;
mod histogram;
mod range;
mod term_agg;
@@ -31,7 +30,6 @@ mod term_missing_agg;
use std::collections::HashMap;
use std::fmt;
pub use filter::*;
pub use histogram::*;
pub use range::*;
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};

View File

@@ -1,47 +1,20 @@
use std::fmt::Debug;
use std::ops::Range;
use columnar::{Column, ColumnType};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::agg_limits::AggregationLimitsGuard;
use crate::aggregation::cached_sub_aggs::{
CachedSubAggs, HighCardSubAggCache, LowCardCachedSubAggs, LowCardSubAggCache, SubAggCache,
};
use crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateRangeBucketEntry, IntermediateRangeBucketResult,
};
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::segment_agg_result::{
build_segment_agg_collector, SegmentAggregationCollector,
};
use crate::aggregation::*;
use crate::TantivyError;
/// Contains all information required by the SegmentRangeCollector to perform the
/// range aggregation on a segment.
pub struct RangeAggReqData {
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
/// The type of the fast field.
pub field_type: ColumnType,
/// The range aggregation request.
pub req: RangeAggregation,
/// The name of the aggregation.
pub name: String,
/// Whether this is a top-level aggregation.
pub is_top_level: bool,
}
impl RangeAggReqData {
/// Estimate the memory consumption of this struct in bytes.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
}
}
/// Provide user-defined buckets to aggregate on.
///
/// Two special buckets will automatically be created to cover the whole range of values.
@@ -155,47 +128,19 @@ pub(crate) struct SegmentRangeAndBucketEntry {
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
pub struct SegmentRangeCollector<C: SubAggCache> {
#[derive(Clone, Debug)]
pub struct SegmentRangeCollector {
/// The buckets containing the aggregation data.
/// One for each ParentBucketId
parent_buckets: Vec<Vec<SegmentRangeAndBucketEntry>>,
buckets: Vec<SegmentRangeAndBucketEntry>,
column_type: ColumnType,
pub(crate) accessor_idx: usize,
sub_agg: Option<CachedSubAggs<C>>,
/// Here things get a bit weird. We need to assign unique bucket ids across all
/// parent buckets. So we keep track of the next available bucket id here.
/// This allows a kind of flattening of the bucket ids across all parent buckets.
/// E.g. in nested aggregations:
/// Term Agg -> Range aggregation -> Stats aggregation
/// E.g. the Term Agg creates 3 buckets ["INFO", "ERROR", "WARN"], each of these has a Range
/// aggregation with 4 buckets. The Range aggregation will create buckets with ids:
/// - INFO: 0,1,2,3
/// - ERROR: 4,5,6,7
/// - WARN: 8,9,10,11
///
/// This allows the Stats aggregation to have unique bucket ids to refer to.
bucket_id_provider: BucketIdProvider,
limits: AggregationLimitsGuard,
}
impl<C: SubAggCache> Debug for SegmentRangeCollector<C> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentRangeCollector")
.field("parent_buckets_len", &self.parent_buckets.len())
.field("column_type", &self.column_type)
.field("accessor_idx", &self.accessor_idx)
.field("has_sub_agg", &self.sub_agg.is_some())
.finish()
}
}
/// TODO: Bad naming, there's also SegmentRangeAndBucketEntry
#[derive(Clone)]
pub(crate) struct SegmentRangeBucketEntry {
pub key: Key,
pub doc_count: u64,
// pub sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
pub bucket_id: BucketId,
pub sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
/// The from range of the bucket. Equals `f64::MIN` when `None`.
pub from: Option<f64>,
/// The to range of the bucket. Equals `f64::MAX` when `None`. Open interval, `to` is not
@@ -216,50 +161,46 @@ impl Debug for SegmentRangeBucketEntry {
impl SegmentRangeBucketEntry {
pub(crate) fn into_intermediate_bucket_entry(
self,
agg_with_accessor: &AggregationsWithAccessor,
) -> crate::Result<IntermediateRangeBucketEntry> {
let sub_aggregation = IntermediateAggregationResults::default();
let mut sub_aggregation_res = IntermediateAggregationResults::default();
if let Some(sub_aggregation) = self.sub_aggregation {
sub_aggregation
.add_intermediate_aggregation_result(agg_with_accessor, &mut sub_aggregation_res)?
} else {
Default::default()
};
Ok(IntermediateRangeBucketEntry {
key: self.key.into(),
doc_count: self.doc_count,
sub_aggregation_res: sub_aggregation,
sub_aggregation: sub_aggregation_res,
from: self.from,
to: self.to,
})
}
}
impl<C: SubAggCache> SegmentAggregationCollector for SegmentRangeCollector<C> {
impl SegmentAggregationCollector for SegmentRangeCollector {
fn add_intermediate_aggregation_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let field_type = self.column_type;
let name = agg_data
.get_range_req_data(self.accessor_idx)
.name
.to_string();
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let sub_agg = &agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation;
let buckets = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);
let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = buckets
let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = self
.buckets
.into_iter()
.map(|range_bucket| {
let bucket_id = range_bucket.bucket.bucket_id;
let mut agg = range_bucket.bucket.into_intermediate_bucket_entry()?;
if let Some(sub_aggregation) = &mut self.sub_agg {
sub_aggregation
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut agg.sub_aggregation_res,
bucket_id,
)?;
}
Ok((range_to_string(&range_bucket.range, &field_type)?, agg))
.map(move |range_bucket| {
Ok((
range_to_string(&range_bucket.range, &field_type)?,
range_bucket
.bucket
.into_intermediate_bucket_entry(sub_agg)?,
))
})
.collect::<crate::Result<_>>()?;
@@ -276,114 +217,69 @@ impl<C: SubAggCache> SegmentAggregationCollector for SegmentRangeCollector<C> {
#[inline]
fn collect(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
let req = agg_data.take_range_req_data(self.accessor_idx);
self.collect_block(&[doc], agg_with_accessor)
}
agg_data
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx];
bucket_agg_accessor
.column_block_accessor
.fetch_block(docs, &req.accessor);
.fetch_block(docs, &bucket_agg_accessor.accessor);
let buckets = &mut self.parent_buckets[parent_bucket_id as usize];
for (doc, val) in agg_data
for (doc, val) in bucket_agg_accessor
.column_block_accessor
.iter_docid_vals(docs, &req.accessor)
.iter_docid_vals(docs, &bucket_agg_accessor.accessor)
{
let bucket_pos = get_bucket_pos(val, buckets);
let bucket = &mut buckets[bucket_pos];
let bucket_pos = self.get_bucket_pos(val);
let bucket = &mut self.buckets[bucket_pos];
bucket.bucket.doc_count += 1;
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.push(bucket.bucket.bucket_id, doc);
if let Some(sub_aggregation) = &mut bucket.bucket.sub_aggregation {
sub_aggregation.collect(doc, &mut bucket_agg_accessor.sub_aggregation)?;
}
}
agg_data.put_back_range_req_data(self.accessor_idx, req);
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.check_flush_local(agg_data)?;
}
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.flush(agg_data)?;
}
Ok(())
}
fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
let sub_aggregation_accessor =
&mut agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation;
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
let new_buckets = self.create_new_buckets(agg_data)?;
self.parent_buckets.push(new_buckets);
for bucket in self.buckets.iter_mut() {
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() {
sub_agg.flush(sub_aggregation_accessor)?;
}
}
Ok(())
}
}
/// Build a concrete `SegmentRangeCollector` with either a Vec- or HashMap-backed
/// bucket storage, depending on the column type and aggregation level.
pub(crate) fn build_segment_range_collector(
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
let accessor_idx = node.idx_in_req_data;
let req_data = agg_data.get_range_req_data(node.idx_in_req_data);
let field_type = req_data.field_type;
// TODO: A better metric instead of is_top_level would be the number of buckets expected.
// E.g. If range agg is not top level, but the parent is a bucket agg with less than 10 buckets,
// we can are still in low cardinality territory.
let is_low_card = req_data.is_top_level && req_data.req.ranges.len() <= 64;
let sub_agg = if !node.children.is_empty() {
Some(build_segment_agg_collectors(agg_data, &node.children)?)
} else {
None
};
if is_low_card {
Ok(Box::new(SegmentRangeCollector::<LowCardSubAggCache> {
sub_agg: sub_agg.map(LowCardCachedSubAggs::new),
column_type: field_type,
accessor_idx,
parent_buckets: Vec::new(),
bucket_id_provider: BucketIdProvider::default(),
limits: agg_data.context.limits.clone(),
}))
} else {
Ok(Box::new(SegmentRangeCollector::<HighCardSubAggCache> {
sub_agg: sub_agg.map(CachedSubAggs::new),
column_type: field_type,
accessor_idx,
parent_buckets: Vec::new(),
bucket_id_provider: BucketIdProvider::default(),
limits: agg_data.context.limits.clone(),
}))
}
}
impl<C: SubAggCache> SegmentRangeCollector<C> {
pub(crate) fn create_new_buckets(
&mut self,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<Vec<SegmentRangeAndBucketEntry>> {
let field_type = self.column_type;
let req_data = agg_data.get_range_req_data(self.accessor_idx);
impl SegmentRangeCollector {
pub(crate) fn from_req_and_validate(
req: &RangeAggregation,
sub_aggregation: &mut AggregationsWithAccessor,
limits: &mut AggregationLimitsGuard,
field_type: ColumnType,
accessor_idx: usize,
) -> crate::Result<Self> {
// The range input on the request is f64.
// We need to convert to u64 ranges, because we read the values as u64.
// The mapping from the conversion is monotonic so ordering is preserved.
let buckets: Vec<_> = extend_validate_ranges(&req_data.req.ranges, &field_type)?
let buckets: Vec<_> = extend_validate_ranges(&req.ranges, &field_type)?
.iter()
.map(|range| {
let bucket_id = self.bucket_id_provider.next_bucket_id();
let key = range
.key
.clone()
@@ -392,20 +288,24 @@ impl<C: SubAggCache> SegmentRangeCollector<C> {
let to = if range.range.end == u64::MAX {
None
} else {
Some(f64_from_fastfield_u64(range.range.end, field_type))
Some(f64_from_fastfield_u64(range.range.end, &field_type))
};
let from = if range.range.start == u64::MIN {
None
} else {
Some(f64_from_fastfield_u64(range.range.start, field_type))
Some(f64_from_fastfield_u64(range.range.start, &field_type))
};
let sub_aggregation = if sub_aggregation.is_empty() {
None
} else {
Some(build_segment_agg_collector(sub_aggregation)?)
};
// let sub_aggregation = sub_agg_prototype.clone();
Ok(SegmentRangeAndBucketEntry {
range: range.range.clone(),
bucket: SegmentRangeBucketEntry {
doc_count: 0,
bucket_id,
sub_aggregation,
key,
from,
to,
@@ -414,19 +314,26 @@ impl<C: SubAggCache> SegmentRangeCollector<C> {
})
.collect::<crate::Result<_>>()?;
self.limits.add_memory_consumed(
limits.add_memory_consumed(
buckets.len() as u64 * std::mem::size_of::<SegmentRangeAndBucketEntry>() as u64,
)?;
Ok(buckets)
Ok(SegmentRangeCollector {
buckets,
column_type: field_type,
accessor_idx,
})
}
#[inline]
fn get_bucket_pos(&self, val: u64) -> usize {
let pos = self
.buckets
.binary_search_by_key(&val, |probe| probe.range.start)
.unwrap_or_else(|pos| pos - 1);
debug_assert!(self.buckets[pos].range.contains(&val));
pos
}
}
#[inline]
fn get_bucket_pos(val: u64, buckets: &[SegmentRangeAndBucketEntry]) -> usize {
let pos = buckets
.binary_search_by_key(&val, |probe| probe.range.start)
.unwrap_or_else(|pos| pos - 1);
debug_assert!(buckets[pos].range.contains(&val));
pos
}
/// Converts the user provided f64 range value to fast field value space.
@@ -524,7 +431,7 @@ pub(crate) fn range_to_string(
let val = i64::from_u64(val);
format_date(val)
} else {
Ok(f64_from_fastfield_u64(val, *field_type).to_string())
Ok(f64_from_fastfield_u64(val, field_type).to_string())
}
};
@@ -554,54 +461,21 @@ mod tests {
pub fn get_collector_from_ranges(
ranges: Vec<RangeAggregationRange>,
field_type: ColumnType,
) -> SegmentRangeCollector<HighCardSubAggCache> {
) -> SegmentRangeCollector {
let req = RangeAggregation {
field: "dummy".to_string(),
ranges,
..Default::default()
};
// Build buckets directly as in from_req_and_validate without AggregationsData
let buckets: Vec<_> = extend_validate_ranges(&req.ranges, &field_type)
.expect("unexpected error in extend_validate_ranges")
.iter()
.map(|range| {
let key = range
.key
.clone()
.map(|key| Ok(Key::Str(key)))
.unwrap_or_else(|| range_to_key(&range.range, &field_type))
.expect("unexpected error in range_to_key");
let to = if range.range.end == u64::MAX {
None
} else {
Some(f64_from_fastfield_u64(range.range.end, field_type))
};
let from = if range.range.start == u64::MIN {
None
} else {
Some(f64_from_fastfield_u64(range.range.start, field_type))
};
SegmentRangeAndBucketEntry {
range: range.range.clone(),
bucket: SegmentRangeBucketEntry {
doc_count: 0,
key,
from,
to,
bucket_id: 0,
},
}
})
.collect();
SegmentRangeCollector {
parent_buckets: vec![buckets],
column_type: field_type,
accessor_idx: 0,
sub_agg: None,
bucket_id_provider: Default::default(),
limits: AggregationLimitsGuard::default(),
}
SegmentRangeCollector::from_req_and_validate(
&req,
&mut Default::default(),
&mut AggregationLimitsGuard::default(),
field_type,
0,
)
.expect("unexpected error")
}
#[test]
@@ -847,7 +721,7 @@ mod tests {
let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.parent_buckets[0].clone();
let buckets = collector.buckets;
assert_eq!(buckets[0].range.start, u64::MIN);
assert_eq!(buckets[0].range.end, 10f64.to_u64());
assert_eq!(buckets[1].range.start, 10f64.to_u64());
@@ -870,7 +744,7 @@ mod tests {
];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.parent_buckets[0].clone();
let buckets = collector.buckets;
assert_eq!(buckets[0].range.start, u64::MIN);
assert_eq!(buckets[0].range.end, 10f64.to_u64());
assert_eq!(buckets[1].range.start, 10f64.to_u64());
@@ -885,7 +759,7 @@ mod tests {
let buckets = vec![(-10f64..-1f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.parent_buckets[0].clone();
let buckets = collector.buckets;
assert_eq!(&buckets[0].bucket.key.to_string(), "*--10");
assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "-1-*");
}
@@ -894,7 +768,7 @@ mod tests {
let buckets = vec![(0f64..10f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.parent_buckets[0].clone();
let buckets = collector.buckets;
assert_eq!(&buckets[0].bucket.key.to_string(), "*-0");
assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "10-*");
}
@@ -903,7 +777,7 @@ mod tests {
fn range_binary_search_test_u64() {
let check_ranges = |ranges: Vec<RangeAggregationRange>| {
let collector = get_collector_from_ranges(ranges, ColumnType::U64);
let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);
let search = |val: u64| collector.get_bucket_pos(val);
assert_eq!(search(u64::MIN), 0);
assert_eq!(search(9), 0);
@@ -949,7 +823,7 @@ mod tests {
let ranges = vec![(10.0..100.0).into()];
let collector = get_collector_from_ranges(ranges, ColumnType::F64);
let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);
let search = |val: u64| collector.get_bucket_pos(val);
assert_eq!(search(u64::MIN), 0);
assert_eq!(search(9f64.to_u64()), 0);
@@ -961,3 +835,63 @@ mod tests {
// the max value
}
}
#[cfg(all(test, feature = "unstable"))]
mod bench {
use itertools::Itertools;
use rand::seq::SliceRandom;
use rand::thread_rng;
use super::*;
use crate::aggregation::bucket::range::tests::get_collector_from_ranges;
const TOTAL_DOCS: u64 = 1_000_000u64;
const NUM_DOCS: u64 = 50_000u64;
fn get_collector_with_buckets(num_buckets: u64, num_docs: u64) -> SegmentRangeCollector {
let bucket_size = num_docs / num_buckets;
let mut buckets: Vec<RangeAggregationRange> = vec![];
for i in 0..num_buckets {
let bucket_start = (i * bucket_size) as f64;
buckets.push((bucket_start..bucket_start + bucket_size as f64).into())
}
get_collector_from_ranges(buckets, ColumnType::U64)
}
fn get_rand_docs(total_docs: u64, num_docs_returned: u64) -> Vec<u64> {
let mut rng = thread_rng();
let all_docs = (0..total_docs - 1).collect_vec();
let mut vals = all_docs
.as_slice()
.choose_multiple(&mut rng, num_docs_returned as usize)
.cloned()
.collect_vec();
vals.sort();
vals
}
fn bench_range_binary_search(b: &mut test::Bencher, num_buckets: u64) {
let collector = get_collector_with_buckets(num_buckets, TOTAL_DOCS);
let vals = get_rand_docs(TOTAL_DOCS, NUM_DOCS);
b.iter(|| {
let mut bucket_pos = 0;
for val in &vals {
bucket_pos = collector.get_bucket_pos(*val);
}
bucket_pos
})
}
#[bench]
fn bench_range_100_buckets(b: &mut test::Bencher) {
bench_range_binary_search(b, 100)
}
#[bench]
fn bench_range_10_buckets(b: &mut test::Bencher) {
bench_range_binary_search(b, 10)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,93 +1,55 @@
use columnar::{Column, ColumnType};
use rustc_hash::FxHashMap;
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::bucket::term_agg::TermsAggregation;
use crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardCachedSubAggs};
use crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult,
};
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::BucketId;
/// Special aggregation to handle missing values for term aggregations.
/// This missing aggregation will check multiple columns for existence.
///
/// This is needed when:
/// - The field is multi-valued and we therefore have multiple columns
/// - The field is not text and missing is provided as string (we cannot use the numeric missing
/// value optimization)
#[derive(Default)]
pub struct MissingTermAggReqData {
/// The accessors to check for existence of a value.
pub accessors: Vec<(Column<u64>, ColumnType)>,
/// The name of the aggregation.
pub name: String,
/// The original terms aggregation request.
pub req: TermsAggregation,
}
impl MissingTermAggReqData {
/// Estimate the memory consumption of this struct in bytes.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
}
}
#[derive(Default, Debug, Clone)]
struct MissingCount {
missing_count: u32,
bucket_id: BucketId,
}
use crate::aggregation::segment_agg_result::{
build_segment_agg_collector, SegmentAggregationCollector,
};
/// The specialized missing term aggregation.
#[derive(Default, Debug)]
#[derive(Default, Debug, Clone)]
pub struct TermMissingAgg {
missing_count: u32,
accessor_idx: usize,
sub_agg: Option<HighCardCachedSubAggs>,
/// Idx = parent bucket id, Value = missing count for that bucket
missing_count_per_bucket: Vec<MissingCount>,
bucket_id_provider: BucketIdProvider,
sub_agg: Option<Box<dyn SegmentAggregationCollector>>,
}
impl TermMissingAgg {
pub(crate) fn new(
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
accessor_idx: usize,
sub_aggregations: &mut AggregationsWithAccessor,
) -> crate::Result<Self> {
let has_sub_aggregations = !node.children.is_empty();
let accessor_idx = node.idx_in_req_data;
let has_sub_aggregations = !sub_aggregations.is_empty();
let sub_agg = if has_sub_aggregations {
let sub_aggregation = build_segment_agg_collectors(agg_data, &node.children)?;
let sub_aggregation = build_segment_agg_collector(sub_aggregations)?;
Some(sub_aggregation)
} else {
None
};
let sub_agg = sub_agg.map(CachedSubAggs::new);
let bucket_id_provider = BucketIdProvider::default();
Ok(Self {
accessor_idx,
sub_agg,
missing_count_per_bucket: Vec::new(),
bucket_id_provider,
..Default::default()
})
}
}
impl SegmentAggregationCollector for TermMissingAgg {
fn add_intermediate_aggregation_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
let term_agg = &req_data.req;
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let agg_with_accessor = &agg_with_accessor.aggs.values[self.accessor_idx];
let term_agg = agg_with_accessor
.agg
.agg
.as_term()
.expect("TermMissingAgg collector must be term agg req");
let missing = term_agg
.missing
.as_ref()
@@ -96,16 +58,16 @@ impl SegmentAggregationCollector for TermMissingAgg {
let mut entries: FxHashMap<IntermediateKey, IntermediateTermBucketEntry> =
Default::default();
let missing_count = &self.missing_count_per_bucket[parent_bucket_id as usize];
let mut missing_entry = IntermediateTermBucketEntry {
doc_count: missing_count.missing_count,
doc_count: self.missing_count,
sub_aggregation: Default::default(),
};
if let Some(sub_agg) = &mut self.sub_agg {
if let Some(sub_agg) = self.sub_agg {
let mut res = IntermediateAggregationResults::default();
sub_agg
.get_sub_agg_collector()
.add_intermediate_aggregation_result(agg_data, &mut res, missing_count.bucket_id)?;
sub_agg.add_intermediate_aggregation_result(
&agg_with_accessor.sub_aggregation,
&mut res,
)?;
missing_entry.sub_aggregation = res;
}
entries.insert(missing.into(), missing_entry);
@@ -118,62 +80,37 @@ impl SegmentAggregationCollector for TermMissingAgg {
},
};
results.push(
req_data.name.to_string(),
IntermediateAggregationResult::Bucket(bucket),
)?;
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
Ok(())
}
fn collect(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
let bucket = &mut self.missing_count_per_bucket[parent_bucket_id as usize];
let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
for doc in docs {
let doc = *doc;
let has_value = req_data
.accessors
.iter()
.any(|(acc, _)| acc.index.has_value(doc));
if !has_value {
bucket.missing_count += 1;
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.push(bucket.bucket_id, doc);
}
let agg = &mut agg_with_accessor.aggs.values[self.accessor_idx];
let has_value = agg
.accessors
.iter()
.any(|(acc, _)| acc.index.has_value(doc));
if !has_value {
self.missing_count += 1;
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.collect(doc, &mut agg.sub_aggregation)?;
}
}
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.check_flush_local(agg_data)?;
}
Ok(())
}
fn prepare_max_bucket(
fn collect_block(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
while self.missing_count_per_bucket.len() <= max_bucket as usize {
let bucket_id = self.bucket_id_provider.next_bucket_id();
self.missing_count_per_bucket.push(MissingCount {
missing_count: 0,
bucket_id,
});
}
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.flush(agg_data)?;
for doc in docs {
self.collect(*doc, agg_with_accessor)?;
}
Ok(())
}

View File

@@ -0,0 +1,83 @@
use super::agg_req_with_accessor::AggregationsWithAccessor;
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::SegmentAggregationCollector;
use crate::DocId;
pub(crate) const DOC_BLOCK_SIZE: usize = 64;
pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE];
/// BufAggregationCollector buffers documents before calling collect_block().
#[derive(Clone)]
pub(crate) struct BufAggregationCollector {
pub(crate) collector: Box<dyn SegmentAggregationCollector>,
staged_docs: DocBlock,
num_staged_docs: usize,
}
impl std::fmt::Debug for BufAggregationCollector {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentAggregationResultsCollector")
.field("staged_docs", &&self.staged_docs[..self.num_staged_docs])
.field("num_staged_docs", &self.num_staged_docs)
.finish()
}
}
impl BufAggregationCollector {
pub fn new(collector: Box<dyn SegmentAggregationCollector>) -> Self {
Self {
collector,
num_staged_docs: 0,
staged_docs: [0; DOC_BLOCK_SIZE],
}
}
}
impl SegmentAggregationCollector for BufAggregationCollector {
#[inline]
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
Box::new(self.collector).add_intermediate_aggregation_result(agg_with_accessor, results)
}
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
self.staged_docs[self.num_staged_docs] = doc;
self.num_staged_docs += 1;
if self.num_staged_docs == self.staged_docs.len() {
self.collector
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_with_accessor)?;
self.num_staged_docs = 0;
}
Ok(())
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
self.collector.collect_block(docs, agg_with_accessor)?;
Ok(())
}
#[inline]
fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
self.collector
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_with_accessor)?;
self.num_staged_docs = 0;
self.collector.flush(agg_with_accessor)?;
Ok(())
}
}

View File

@@ -1,245 +0,0 @@
use std::fmt::Debug;
use super::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::bucket::MAX_NUM_TERMS_FOR_VEC;
use crate::aggregation::BucketId;
use crate::DocId;
/// A cache for sub-aggregations, storing doc ids per bucket id.
/// Depending on the cardinality of the parent aggregation, we use different
/// storage strategies.
///
/// ## Low Cardinality
/// Cardinality here refers to the number of unique flattened buckets that can be created
/// by the parent aggregation.
/// Flattened buckets are the result of combining all buckets per collector
/// into a single list of buckets, where each bucket is identified by its BucketId.
///
/// ## Usage
/// Since this is caching for sub-aggregations, it is only used by bucket
/// aggregations.
///
/// TODO: consider using a more advanced data structure for high cardinality
/// aggregations.
/// What this datastructure does in general is to group docs by bucket id.
#[derive(Debug)]
pub(crate) struct CachedSubAggs<C: SubAggCache> {
cache: C,
sub_agg_collector: Box<dyn SegmentAggregationCollector>,
num_docs: usize,
}
pub type LowCardCachedSubAggs = CachedSubAggs<LowCardSubAggCache>;
pub type HighCardCachedSubAggs = CachedSubAggs<HighCardSubAggCache>;
const FLUSH_THRESHOLD: usize = 2048;
/// A trait for caching sub-aggregation doc ids per bucket id.
/// Different implementations can be used depending on the cardinality
/// of the parent aggregation.
pub trait SubAggCache: Debug {
fn new() -> Self;
fn push(&mut self, bucket_id: BucketId, doc_id: DocId);
fn flush_local(
&mut self,
sub_agg: &mut Box<dyn SegmentAggregationCollector>,
agg_data: &mut AggregationsSegmentCtx,
force: bool,
) -> crate::Result<()>;
}
impl<Backend: SubAggCache + Debug> CachedSubAggs<Backend> {
pub fn new(sub_agg: Box<dyn SegmentAggregationCollector>) -> Self {
Self {
cache: Backend::new(),
sub_agg_collector: sub_agg,
num_docs: 0,
}
}
pub fn get_sub_agg_collector(&mut self) -> &mut Box<dyn SegmentAggregationCollector> {
&mut self.sub_agg_collector
}
#[inline]
pub fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
self.cache.push(bucket_id, doc_id);
self.num_docs += 1;
}
/// Check if we need to flush based on the number of documents cached.
/// If so, flushes the cache to the provided aggregation collector.
pub fn check_flush_local(
&mut self,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
if self.num_docs >= FLUSH_THRESHOLD {
self.cache
.flush_local(&mut self.sub_agg_collector, agg_data, false)?;
self.num_docs = 0;
}
Ok(())
}
/// Note: this _does_ flush the sub aggregations.
pub fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if self.num_docs != 0 {
self.cache
.flush_local(&mut self.sub_agg_collector, agg_data, true)?;
self.num_docs = 0;
}
self.sub_agg_collector.flush(agg_data)?;
Ok(())
}
}
/// Number of partitions for high cardinality sub-aggregation cache.
const NUM_PARTITIONS: usize = 16;
#[derive(Debug)]
pub(crate) struct HighCardSubAggCache {
/// This weird partitioning is used to do some cheap grouping on the bucket ids.
/// bucket ids are dense, e.g. when we don't detect the cardinality as low cardinality,
/// but there are just 16 bucket ids, each bucket id will go to its own partition.
///
/// We want to keep this cheap, because high cardinality aggregations can have a lot of
/// buckets, and there may be nothing to group.
partitions: Box<[PartitionEntry; NUM_PARTITIONS]>,
}
impl HighCardSubAggCache {
#[inline]
fn clear(&mut self) {
for partition in self.partitions.iter_mut() {
partition.clear();
}
}
}
#[derive(Debug, Clone, Default)]
struct PartitionEntry {
bucket_ids: Vec<BucketId>,
docs: Vec<DocId>,
}
impl PartitionEntry {
#[inline]
fn clear(&mut self) {
self.bucket_ids.clear();
self.docs.clear();
}
}
impl SubAggCache for HighCardSubAggCache {
fn new() -> Self {
Self {
partitions: Box::new(core::array::from_fn(|_| PartitionEntry::default())),
}
}
fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
let idx = bucket_id % NUM_PARTITIONS as u32;
let slot = &mut self.partitions[idx as usize];
slot.bucket_ids.push(bucket_id);
slot.docs.push(doc_id);
}
fn flush_local(
&mut self,
sub_agg: &mut Box<dyn SegmentAggregationCollector>,
agg_data: &mut AggregationsSegmentCtx,
_force: bool,
) -> crate::Result<()> {
let mut max_bucket = 0u32;
for partition in self.partitions.iter() {
if let Some(&local_max) = partition.bucket_ids.iter().max() {
max_bucket = max_bucket.max(local_max);
}
}
sub_agg.prepare_max_bucket(max_bucket, agg_data)?;
for slot in self.partitions.iter() {
if !slot.bucket_ids.is_empty() {
// Reduce dynamic dispatch overhead by collecting a full partition in one call.
sub_agg.collect_multiple(&slot.bucket_ids, &slot.docs, agg_data)?;
}
}
self.clear();
Ok(())
}
}
#[derive(Debug)]
pub(crate) struct LowCardSubAggCache {
/// Cache doc ids per bucket for sub-aggregations.
///
/// The outer Vec is indexed by BucketId.
per_bucket_docs: Vec<Vec<DocId>>,
}
impl LowCardSubAggCache {
#[inline]
fn clear(&mut self) {
for v in &mut self.per_bucket_docs {
v.clear();
}
}
}
impl SubAggCache for LowCardSubAggCache {
fn new() -> Self {
Self {
per_bucket_docs: Vec::new(),
}
}
fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
let idx = bucket_id as usize;
if self.per_bucket_docs.len() <= idx {
self.per_bucket_docs.resize_with(idx + 1, Vec::new);
}
self.per_bucket_docs[idx].push(doc_id);
}
fn flush_local(
&mut self,
sub_agg: &mut Box<dyn SegmentAggregationCollector>,
agg_data: &mut AggregationsSegmentCtx,
force: bool,
) -> crate::Result<()> {
// Pre-aggregated: call collect per bucket.
let max_bucket = (self.per_bucket_docs.len() as BucketId).saturating_sub(1);
sub_agg.prepare_max_bucket(max_bucket, agg_data)?;
// The threshold above which we flush buckets individually.
// Note: We need to make sure that we don't lock ourselves into a situation where we hit
// the FLUSH_THRESHOLD, but never flush any buckets. (except the final flush)
let mut bucket_treshold = FLUSH_THRESHOLD / (self.per_bucket_docs.len().max(1) * 2);
const _: () = {
// MAX_NUM_TERMS_FOR_VEC threshold is used for term aggregations
// Note: There may be other flexible values, for other aggregations, but we can use the
// const value here as a upper bound. (better than nothing)
let bucket_treshold_limit = FLUSH_THRESHOLD / (MAX_NUM_TERMS_FOR_VEC as usize * 2);
assert!(
bucket_treshold_limit > 0,
"Bucket threshold must be greater than 0"
);
};
if force {
bucket_treshold = 0;
}
for (bucket_id, docs) in self
.per_bucket_docs
.iter()
.enumerate()
.filter(|(_, docs)| docs.len() > bucket_treshold)
{
sub_agg.collect(bucket_id as BucketId, docs, agg_data)?;
}
self.clear();
Ok(())
}
}

View File

@@ -1,12 +1,12 @@
use super::agg_req::Aggregations;
use super::agg_req_with_accessor::AggregationsWithAccessor;
use super::agg_result::AggregationResults;
use super::cached_sub_aggs::LowCardCachedSubAggs;
use super::buf_collector::BufAggregationCollector;
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::AggContextParams;
// group buffering strategy is chosen explicitly by callers; no need to hash-group on the fly.
use crate::aggregation::agg_data::{
build_aggregations_data_from_req, build_segment_agg_collectors_root, AggregationsSegmentCtx,
use super::segment_agg_result::{
build_segment_agg_collector, AggregationLimitsGuard, SegmentAggregationCollector,
};
use crate::aggregation::agg_req_with_accessor::get_aggs_with_segment_accessor_and_validate;
use crate::collector::{Collector, SegmentCollector};
use crate::index::SegmentReader;
use crate::{DocId, SegmentOrdinal, TantivyError};
@@ -22,7 +22,7 @@ pub const DEFAULT_MEMORY_LIMIT: u64 = 500_000_000;
/// The collector collects all aggregations by the underlying aggregation request.
pub struct AggregationCollector {
agg: Aggregations,
context: AggContextParams,
limits: AggregationLimitsGuard,
}
impl AggregationCollector {
@@ -30,8 +30,8 @@ impl AggregationCollector {
///
/// Aggregation fails when the limits in `AggregationLimits` is exceeded. (memory limit and
/// bucket limit)
pub fn from_aggs(agg: Aggregations, context: AggContextParams) -> Self {
Self { agg, context }
pub fn from_aggs(agg: Aggregations, limits: AggregationLimitsGuard) -> Self {
Self { agg, limits }
}
}
@@ -45,7 +45,7 @@ impl AggregationCollector {
/// into the final `AggregationResults` via the `into_final_result()` method.
pub struct DistributedAggregationCollector {
agg: Aggregations,
context: AggContextParams,
limits: AggregationLimitsGuard,
}
impl DistributedAggregationCollector {
@@ -53,8 +53,8 @@ impl DistributedAggregationCollector {
///
/// Aggregation fails when the limits in `AggregationLimits` is exceeded. (memory limit and
/// bucket limit)
pub fn from_aggs(agg: Aggregations, context: AggContextParams) -> Self {
Self { agg, context }
pub fn from_aggs(agg: Aggregations, limits: AggregationLimitsGuard) -> Self {
Self { agg, limits }
}
}
@@ -72,7 +72,7 @@ impl Collector for DistributedAggregationCollector {
&self.agg,
reader,
segment_local_id,
&self.context,
&self.limits,
)
}
@@ -102,7 +102,7 @@ impl Collector for AggregationCollector {
&self.agg,
reader,
segment_local_id,
&self.context,
&self.limits,
)
}
@@ -115,7 +115,7 @@ impl Collector for AggregationCollector {
segment_fruits: Vec<<Self::Child as SegmentCollector>::Fruit>,
) -> crate::Result<Self::Fruit> {
let res = merge_fruits(segment_fruits)?;
res.into_final_result(self.agg.clone(), self.context.limits.clone())
res.into_final_result(self.agg.clone(), self.limits.clone())
}
}
@@ -135,8 +135,8 @@ fn merge_fruits(
/// `AggregationSegmentCollector` does the aggregation collection on a segment.
pub struct AggregationSegmentCollector {
aggs_with_accessor: AggregationsSegmentCtx,
agg_collector: LowCardCachedSubAggs,
aggs_with_accessor: AggregationsWithAccessor,
agg_collector: BufAggregationCollector,
error: Option<TantivyError>,
}
@@ -147,18 +147,14 @@ impl AggregationSegmentCollector {
agg: &Aggregations,
reader: &SegmentReader,
segment_ordinal: SegmentOrdinal,
context: &AggContextParams,
limits: &AggregationLimitsGuard,
) -> crate::Result<Self> {
let mut agg_data =
build_aggregations_data_from_req(agg, reader, segment_ordinal, context.clone())?;
let mut result =
LowCardCachedSubAggs::new(build_segment_agg_collectors_root(&mut agg_data)?);
result
.get_sub_agg_collector()
.prepare_max_bucket(0, &agg_data)?; // prepare for bucket zero
let mut aggs_with_accessor =
get_aggs_with_segment_accessor_and_validate(agg, reader, segment_ordinal, limits)?;
let result =
BufAggregationCollector::new(build_segment_agg_collector(&mut aggs_with_accessor)?);
Ok(AggregationSegmentCollector {
aggs_with_accessor: agg_data,
aggs_with_accessor,
agg_collector: result,
error: None,
})
@@ -173,31 +169,26 @@ impl SegmentCollector for AggregationSegmentCollector {
if self.error.is_some() {
return;
}
self.agg_collector.push(0, doc);
match self
if let Err(err) = self
.agg_collector
.check_flush_local(&mut self.aggs_with_accessor)
.collect(doc, &mut self.aggs_with_accessor)
{
Ok(_) => {}
Err(e) => {
self.error = Some(e);
}
self.error = Some(err);
}
}
/// The query pushes the documents to the collector via this method.
///
/// Only valid for Collectors that ignore docs
fn collect_block(&mut self, docs: &[DocId]) {
if self.error.is_some() {
return;
}
match self.agg_collector.get_sub_agg_collector().collect(
0,
docs,
&mut self.aggs_with_accessor,
) {
Ok(_) => {}
Err(e) => {
self.error = Some(e);
}
if let Err(err) = self
.agg_collector
.collect_block(docs, &mut self.aggs_with_accessor)
{
self.error = Some(err);
}
}
@@ -208,13 +199,10 @@ impl SegmentCollector for AggregationSegmentCollector {
self.agg_collector.flush(&mut self.aggs_with_accessor)?;
let mut sub_aggregation_res = IntermediateAggregationResults::default();
self.agg_collector
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
&self.aggs_with_accessor,
&mut sub_aggregation_res,
0,
)?;
Box::new(self.agg_collector).add_intermediate_aggregation_result(
&self.aggs_with_accessor,
&mut sub_aggregation_res,
)?;
Ok(sub_aggregation_res)
}

View File

@@ -24,9 +24,7 @@ use super::metric::{
};
use super::segment_agg_result::AggregationLimitsGuard;
use super::{format_date, AggregationError, Key, SerializedKey};
use crate::aggregation::agg_result::{
AggregationResults, BucketEntries, BucketEntry, FilterBucketResult,
};
use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry};
use crate::aggregation::bucket::TermsAggregationInternal;
use crate::aggregation::metric::CardinalityCollector;
use crate::TantivyError;
@@ -181,17 +179,12 @@ impl IntermediateAggregationResults {
}
/// Merge another intermediate aggregation result into this result.
pub fn merge_fruits(&mut self, mut other: IntermediateAggregationResults) -> crate::Result<()> {
for (key, left) in self.aggs_res.iter_mut() {
if let Some(key) = other.aggs_res.remove(key) {
left.merge_fruits(key)?;
}
}
// Move remainder of other aggs_res into self.
// Note: Currently we don't expect this to happen, as we create empty intermediate results
// via [IntermediateAggregationResults::empty_from_req].
for (key, value) in other.aggs_res {
self.aggs_res.insert(key, value);
///
/// The order of the values need to be the same on both results. This is ensured when the same
/// (key values) are present on the underlying `VecWithNames` struct.
pub fn merge_fruits(&mut self, other: IntermediateAggregationResults) -> crate::Result<()> {
for (left, right) in self.aggs_res.values_mut().zip(other.aggs_res.into_values()) {
left.merge_fruits(right)?;
}
Ok(())
}
@@ -248,16 +241,11 @@ pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult
Cardinality(_) => IntermediateAggregationResult::Metric(
IntermediateMetricResult::Cardinality(CardinalityCollector::default()),
),
Filter(_) => IntermediateAggregationResult::Bucket(IntermediateBucketResult::Filter {
doc_count: 0,
sub_aggregations: IntermediateAggregationResults::default(),
}),
}
}
/// An aggregation is either a bucket or a metric.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(clippy::large_enum_variant)]
pub enum IntermediateAggregationResult {
/// Bucket variant
Bucket(IntermediateBucketResult),
@@ -438,13 +426,6 @@ pub enum IntermediateBucketResult {
/// The term buckets
buckets: IntermediateTermBucketResult,
},
/// Filter aggregation - a single bucket with sub-aggregations
Filter {
/// Document count in the filter bucket
doc_count: u64,
/// Sub-aggregation results
sub_aggregations: IntermediateAggregationResults,
},
}
impl IntermediateBucketResult {
@@ -528,18 +509,6 @@ impl IntermediateBucketResult {
req.sub_aggregation(),
limits,
),
IntermediateBucketResult::Filter {
doc_count,
sub_aggregations,
} => {
// Convert sub-aggregation results to final format
let final_sub_aggregations = sub_aggregations
.into_final_result(req.sub_aggregation().clone(), limits.clone())?;
Ok(BucketResult::Filter(FilterBucketResult {
doc_count,
sub_aggregations: final_sub_aggregations,
}))
}
}
}
@@ -593,19 +562,6 @@ impl IntermediateBucketResult {
*buckets_left = buckets?;
}
(
IntermediateBucketResult::Filter {
doc_count: doc_count_left,
sub_aggregations: sub_aggs_left,
},
IntermediateBucketResult::Filter {
doc_count: doc_count_right,
sub_aggregations: sub_aggs_right,
},
) => {
*doc_count_left += doc_count_right;
sub_aggs_left.merge_fruits(sub_aggs_right)?;
}
(IntermediateBucketResult::Range(_), _) => {
panic!("try merge on different types")
}
@@ -615,9 +571,6 @@ impl IntermediateBucketResult {
(IntermediateBucketResult::Terms { .. }, _) => {
panic!("try merge on different types")
}
(IntermediateBucketResult::Filter { .. }, _) => {
panic!("try merge on different types")
}
}
Ok(())
}
@@ -792,7 +745,7 @@ pub struct IntermediateRangeBucketEntry {
/// The number of documents in the bucket.
pub doc_count: u64,
/// The sub_aggregation in this bucket.
pub sub_aggregation_res: IntermediateAggregationResults,
pub sub_aggregation: IntermediateAggregationResults,
/// The from range of the bucket. Equals `f64::MIN` when `None`.
pub from: Option<f64>,
/// The to range of the bucket. Equals `f64::MAX` when `None`.
@@ -811,7 +764,7 @@ impl IntermediateRangeBucketEntry {
key: self.key.into(),
doc_count: self.doc_count,
sub_aggregation: self
.sub_aggregation_res
.sub_aggregation
.into_final_result_internal(req, limits)?,
to: self.to,
from: self.from,
@@ -857,8 +810,7 @@ impl MergeFruits for IntermediateTermBucketEntry {
impl MergeFruits for IntermediateRangeBucketEntry {
fn merge_fruits(&mut self, other: IntermediateRangeBucketEntry) -> crate::Result<()> {
self.doc_count += other.doc_count;
self.sub_aggregation_res
.merge_fruits(other.sub_aggregation_res)?;
self.sub_aggregation.merge_fruits(other.sub_aggregation)?;
Ok(())
}
}
@@ -888,7 +840,7 @@ mod tests {
IntermediateRangeBucketEntry {
key: IntermediateKey::Str(key.to_string()),
doc_count: *doc_count,
sub_aggregation_res: Default::default(),
sub_aggregation: Default::default(),
from: None,
to: None,
},
@@ -921,7 +873,7 @@ mod tests {
doc_count: *doc_count,
from: None,
to: None,
sub_aggregation_res: get_sub_test_tree(&[(
sub_aggregation: get_sub_test_tree(&[(
sub_aggregation_key.to_string(),
*sub_aggregation_count,
)]),

View File

@@ -52,8 +52,10 @@ pub struct IntermediateAverage {
impl IntermediateAverage {
/// Creates a new [`IntermediateAverage`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateAverage) {

View File

@@ -2,13 +2,15 @@ use std::collections::hash_map::DefaultHasher;
use std::hash::{BuildHasher, Hasher};
use columnar::column_values::CompactSpaceU64Accessor;
use columnar::{Column, ColumnType, Dictionary, StrColumn};
use columnar::Dictionary;
use common::f64_to_u64;
use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};
use rustc_hash::FxHashSet;
use serde::{Deserialize, Serialize};
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::agg_req_with_accessor::{
AggregationWithAccessor, AggregationsWithAccessor,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
@@ -95,30 +97,6 @@ pub struct CardinalityAggregationReq {
pub missing: Option<Key>,
}
/// Contains all information required by the SegmentCardinalityCollector to perform the
/// cardinality aggregation on a segment.
pub struct CardinalityAggReqData {
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
/// The column_type of the field.
pub column_type: ColumnType,
/// The string dictionary column if the field is of type string.
pub str_dict_column: Option<StrColumn>,
/// The missing value normalized to the internal u64 representation of the field type.
pub missing_value_for_accessor: Option<u64>,
/// The name of the aggregation.
pub name: String,
/// The aggregation request.
pub req: CardinalityAggregationReq,
}
impl CardinalityAggReqData {
/// Estimate the memory consumption of this struct in bytes.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
}
}
impl CardinalityAggregationReq {
/// Creates a new [`CardinalityAggregationReq`] instance from a field name.
pub fn from_field_name(field_name: String) -> Self {
@@ -133,37 +111,51 @@ impl CardinalityAggregationReq {
}
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct SegmentCardinalityCollector {
buckets: Vec<SegmentCardinalityCollectorBucket>,
accessor_idx: usize,
/// The column accessor to access the fast field values.
accessor: Column<u64>,
/// The column_type of the field.
column_type: ColumnType,
/// The missing value normalized to the internal u64 representation of the field type.
missing_value_for_accessor: Option<u64>,
}
#[derive(Clone, Debug, PartialEq, Default)]
pub(crate) struct SegmentCardinalityCollectorBucket {
cardinality: CardinalityCollector,
entries: FxHashSet<u64>,
column_type: ColumnType,
accessor_idx: usize,
missing: Option<Key>,
}
impl SegmentCardinalityCollectorBucket {
pub fn new(column_type: ColumnType) -> Self {
impl SegmentCardinalityCollector {
pub fn from_req(column_type: ColumnType, accessor_idx: usize, missing: &Option<Key>) -> Self {
Self {
cardinality: CardinalityCollector::new(column_type as u8),
entries: FxHashSet::default(),
entries: Default::default(),
column_type,
accessor_idx,
missing: missing.clone(),
}
}
fn fetch_block_with_field(
&mut self,
docs: &[crate::DocId],
agg_accessor: &mut AggregationWithAccessor,
) {
if let Some(missing) = agg_accessor.missing_value_for_accessor {
agg_accessor.column_block_accessor.fetch_block_with_missing(
docs,
&agg_accessor.accessor,
missing,
);
} else {
agg_accessor
.column_block_accessor
.fetch_block(docs, &agg_accessor.accessor);
}
}
fn into_intermediate_metric_result(
mut self,
req_data: &CardinalityAggReqData,
agg_with_accessor: &AggregationWithAccessor,
) -> crate::Result<IntermediateMetricResult> {
if req_data.column_type == ColumnType::Str {
if self.column_type == ColumnType::Str {
let fallback_dict = Dictionary::empty();
let dict = req_data
let dict = agg_with_accessor
.str_dict_column
.as_ref()
.map(|el| el.dictionary())
@@ -181,7 +173,6 @@ impl SegmentCardinalityCollectorBucket {
term_ids.push(term_ord as u32);
}
}
term_ids.sort_unstable();
dict.sorted_ords_to_term_cb(term_ids.iter().map(|term| *term as u64), |term| {
self.cardinality.sketch.insert_any(&term);
@@ -189,10 +180,10 @@ impl SegmentCardinalityCollectorBucket {
})?;
if has_missing {
// Replace missing with the actual value provided
let missing_key =
req_data.req.missing.as_ref().expect(
"Found sentinel value u64::MAX for term_ord but `missing` is not set",
);
let missing_key = self
.missing
.as_ref()
.expect("Found sentinel value u64::MAX for term_ord but `missing` is not set");
match missing_key {
Key::Str(missing) => {
self.cardinality.sketch.insert_any(&missing);
@@ -215,49 +206,16 @@ impl SegmentCardinalityCollectorBucket {
}
}
impl SegmentCardinalityCollector {
pub fn from_req(
column_type: ColumnType,
accessor_idx: usize,
accessor: Column<u64>,
missing_value_for_accessor: Option<u64>,
) -> Self {
Self {
buckets: vec![SegmentCardinalityCollectorBucket::new(column_type); 1],
column_type,
accessor_idx,
accessor,
missing_value_for_accessor,
}
}
fn fetch_block_with_field(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) {
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&self.accessor,
self.missing_value_for_accessor,
);
}
}
impl SegmentAggregationCollector for SegmentCardinalityCollector {
fn add_intermediate_aggregation_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
let name = req_data.name.to_string();
// take the bucket in buckets and replace it with a new empty one
let bucket = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let agg_with_accessor = &agg_with_accessor.aggs.values[self.accessor_idx];
let intermediate_result = bucket.into_intermediate_metric_result(req_data)?;
let intermediate_result = self.into_intermediate_metric_result(agg_with_accessor)?;
results.push(
name,
IntermediateAggregationResult::Metric(intermediate_result),
@@ -268,20 +226,27 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
fn collect(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
self.fetch_block_with_field(docs, agg_data);
let bucket = &mut self.buckets[parent_bucket_id as usize];
self.collect_block(&[doc], agg_with_accessor)
}
let col_block_accessor = &agg_data.column_block_accessor;
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx];
self.fetch_block_with_field(docs, bucket_agg_accessor);
let col_block_accessor = &bucket_agg_accessor.column_block_accessor;
if self.column_type == ColumnType::Str {
for term_ord in col_block_accessor.iter_vals() {
bucket.entries.insert(term_ord);
self.entries.insert(term_ord);
}
} else if self.column_type == ColumnType::IpAddr {
let compact_space_accessor = self
let compact_space_accessor = bucket_agg_accessor
.accessor
.values
.clone()
@@ -296,29 +261,16 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
})?;
for val in col_block_accessor.iter_vals() {
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
bucket.cardinality.sketch.insert_any(&val);
self.cardinality.sketch.insert_any(&val);
}
} else {
for val in col_block_accessor.iter_vals() {
bucket.cardinality.sketch.insert_any(&val);
self.cardinality.sketch.insert_any(&val);
}
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
if max_bucket as usize >= self.buckets.len() {
self.buckets.resize_with(max_bucket as usize + 1, || {
SegmentCardinalityCollectorBucket::new(self.column_type)
});
}
Ok(())
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]

View File

@@ -52,8 +52,10 @@ pub struct IntermediateCount {
impl IntermediateCount {
/// Creates a new [`IntermediateCount`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateCount) {

View File

@@ -4,13 +4,15 @@ use std::mem;
use serde::{Deserialize, Serialize};
use super::*;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::agg_req_with_accessor::{
AggregationWithAccessor, AggregationsWithAccessor,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::TantivyError;
use crate::{DocId, TantivyError};
/// A multi-value metric aggregation that computes a collection of extended statistics
/// on numeric values that are extracted
@@ -61,7 +63,7 @@ impl ExtendedStatsAggregation {
/// Extended stats contains a collection of statistics
/// they extends stats adding variance, standard deviation
/// and bound information
/// and bound informations
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ExtendedStats {
/// The number of documents.
@@ -317,28 +319,51 @@ impl IntermediateExtendedStats {
}
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct SegmentExtendedStatsCollector {
name: String,
missing: Option<u64>,
field_type: ColumnType,
accessor: columnar::Column<u64>,
buckets: Vec<IntermediateExtendedStats>,
sigma: Option<f64>,
pub(crate) extended_stats: IntermediateExtendedStats,
pub(crate) accessor_idx: usize,
val_cache: Vec<u64>,
}
impl SegmentExtendedStatsCollector {
pub fn from_req(req: &MetricAggReqData, sigma: Option<f64>) -> Self {
let missing = req
.missing
.and_then(|val| f64_to_fastfield_u64(val, &req.field_type));
pub fn from_req(
field_type: ColumnType,
sigma: Option<f64>,
accessor_idx: usize,
missing: Option<f64>,
) -> Self {
let missing = missing.and_then(|val| f64_to_fastfield_u64(val, &field_type));
Self {
name: req.name.clone(),
field_type: req.field_type,
accessor: req.accessor.clone(),
field_type,
extended_stats: IntermediateExtendedStats::with_sigma(sigma),
accessor_idx,
missing,
buckets: vec![IntermediateExtendedStats::with_sigma(sigma); 16],
sigma,
val_cache: Default::default(),
}
}
#[inline]
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
agg_accessor: &mut AggregationWithAccessor,
) {
if let Some(missing) = self.missing.as_ref() {
agg_accessor.column_block_accessor.fetch_block_with_missing(
docs,
&agg_accessor.accessor,
*missing,
);
} else {
agg_accessor
.column_block_accessor
.fetch_block(docs, &agg_accessor.accessor);
}
for val in agg_accessor.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.extended_stats.collect(val1);
}
}
}
@@ -346,18 +371,15 @@ impl SegmentExtendedStatsCollector {
impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
#[inline]
fn add_intermediate_aggregation_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = self.name.clone();
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let extended_stats = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
results.push(
name,
IntermediateAggregationResult::Metric(IntermediateMetricResult::ExtendedStats(
extended_stats,
self.extended_stats,
)),
)?;
@@ -367,36 +389,39 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
#[inline]
fn collect(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
let mut extended_stats = self.buckets[parent_bucket_id as usize].clone();
agg_data
.column_block_accessor
.fetch_block_with_missing(docs, &self.accessor, self.missing);
for val in agg_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, self.field_type);
extended_stats.collect(val1);
let field = &agg_with_accessor.aggs.values[self.accessor_idx].accessor;
if let Some(missing) = self.missing {
let mut has_val = false;
for val in field.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.extended_stats.collect(val1);
has_val = true;
}
if !has_val {
self.extended_stats
.collect(f64_from_fastfield_u64(missing, &self.field_type));
}
} else {
for val in field.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.extended_stats.collect(val1);
}
}
// store back
self.buckets[parent_bucket_id as usize] = extended_stats;
Ok(())
}
fn prepare_max_bucket(
#[inline]
fn collect_block(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
if self.buckets.len() <= max_bucket as usize {
self.buckets.resize_with(max_bucket as usize + 1, || {
IntermediateExtendedStats::with_sigma(self.sigma)
});
}
let field = &mut agg_with_accessor.aggs.values[self.accessor_idx];
self.collect_block_with_field(docs, field);
Ok(())
}
}

View File

@@ -52,8 +52,10 @@ pub struct IntermediateMax {
impl IntermediateMax {
/// Creates a new [`IntermediateMax`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateMax) {

View File

@@ -52,8 +52,10 @@ pub struct IntermediateMin {
impl IntermediateMin {
/// Creates a new [`IntermediateMin`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateMin) {

View File

@@ -31,7 +31,6 @@ use std::collections::HashMap;
pub use average::*;
pub use cardinality::*;
use columnar::{Column, ColumnType};
pub use count::*;
pub use extended_stats::*;
pub use max::*;
@@ -45,33 +44,6 @@ pub use top_hits::*;
use crate::schema::OwnedValue;
/// Contains all information required by metric aggregations like avg, min, max, sum, stats,
/// extended_stats, count, percentiles.
#[repr(C)]
pub struct MetricAggReqData {
/// True if the field is of number or date type.
pub is_number_or_date_type: bool,
/// The type of the field.
pub field_type: ColumnType,
/// The missing value normalized to the internal u64 representation of the field type.
pub missing_u64: Option<u64>,
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
/// Used when converting to intermediate result
pub collecting_for: StatsType,
/// The missing value
pub missing: Option<f64>,
/// The name of the aggregation.
pub name: String,
}
impl MetricAggReqData {
/// Estimate the memory consumption of this struct in bytes.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
}
}
/// Single-metric aggregations use this common result structure.
///
/// Main reason to wrap it in value is to match elasticsearch output structure.

View File

@@ -3,13 +3,15 @@ use std::fmt::Debug;
use serde::{Deserialize, Serialize};
use super::*;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::agg_req_with_accessor::{
AggregationWithAccessor, AggregationsWithAccessor,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::TantivyError;
use crate::{DocId, TantivyError};
/// # Percentiles
///
@@ -110,8 +112,7 @@ impl PercentilesAggregationReq {
&self.field
}
/// Validates the request parameters.
pub fn validate(&self) -> crate::Result<()> {
fn validate(&self) -> crate::Result<()> {
if let Some(percents) = self.percents.as_ref() {
let all_in_range = percents
.iter()
@@ -130,16 +131,12 @@ impl PercentilesAggregationReq {
}
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct SegmentPercentilesCollector {
pub(crate) buckets: Vec<PercentilesCollector>,
field_type: ColumnType,
pub(crate) percentiles: PercentilesCollector,
pub(crate) accessor_idx: usize,
/// The type of the field.
pub field_type: ColumnType,
/// The missing value normalized to the internal u64 representation of the field type.
pub missing_u64: Option<u64>,
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
missing: Option<u64>,
}
#[derive(Clone, Serialize, Deserialize)]
@@ -235,17 +232,43 @@ impl PercentilesCollector {
impl SegmentPercentilesCollector {
pub fn from_req_and_validate(
req: &PercentilesAggregationReq,
field_type: ColumnType,
missing_u64: Option<u64>,
accessor: Column<u64>,
accessor_idx: usize,
) -> Self {
Self {
buckets: Vec::with_capacity(64),
) -> crate::Result<Self> {
req.validate()?;
let missing = req
.missing
.and_then(|val| f64_to_fastfield_u64(val, &field_type));
Ok(Self {
field_type,
missing_u64,
accessor,
percentiles: PercentilesCollector::new(),
accessor_idx,
missing,
})
}
#[inline]
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
agg_accessor: &mut AggregationWithAccessor,
) {
if let Some(missing) = self.missing.as_ref() {
agg_accessor.column_block_accessor.fetch_block_with_missing(
docs,
&agg_accessor.accessor,
*missing,
);
} else {
agg_accessor
.column_block_accessor
.fetch_block(docs, &agg_accessor.accessor);
}
for val in agg_accessor.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.percentiles.collect(val1);
}
}
}
@@ -253,18 +276,12 @@ impl SegmentPercentilesCollector {
impl SegmentAggregationCollector for SegmentPercentilesCollector {
#[inline]
fn add_intermediate_aggregation_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone();
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
// Swap collector with an empty one to avoid cloning
let percentiles_collector = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
let intermediate_metric_result =
IntermediateMetricResult::Percentiles(percentiles_collector);
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let intermediate_metric_result = IntermediateMetricResult::Percentiles(self.percentiles);
results.push(
name,
@@ -277,33 +294,40 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector {
#[inline]
fn collect(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
let percentiles = &mut self.buckets[parent_bucket_id as usize];
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&self.accessor,
self.missing_u64,
);
let field = &agg_with_accessor.aggs.values[self.accessor_idx].accessor;
for val in agg_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, self.field_type);
percentiles.collect(val1);
if let Some(missing) = self.missing {
let mut has_val = false;
for val in field.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.percentiles.collect(val1);
has_val = true;
}
if !has_val {
self.percentiles
.collect(f64_from_fastfield_u64(missing, &self.field_type));
}
} else {
for val in field.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.percentiles.collect(val1);
}
}
Ok(())
}
fn prepare_max_bucket(
#[inline]
fn collect_block(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
while self.buckets.len() <= max_bucket as usize {
self.buckets.push(PercentilesCollector::new());
}
let field = &mut agg_with_accessor.aggs.values[self.accessor_idx];
self.collect_block_with_field(docs, field);
Ok(())
}
}

View File

@@ -1,16 +1,17 @@
use std::fmt::Debug;
use columnar::{Column, ColumnType};
use serde::{Deserialize, Serialize};
use super::*;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::agg_req_with_accessor::{
AggregationWithAccessor, AggregationsWithAccessor,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::TantivyError;
use crate::{DocId, TantivyError};
/// A multi-value metric aggregation that computes a collection of statistics on numeric values that
/// are extracted from the aggregated documents.
@@ -83,7 +84,7 @@ impl Stats {
/// Intermediate result of the stats aggregation that can be combined with other intermediate
/// results.
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct IntermediateStats {
/// The number of extracted values.
pub(crate) count: u64,
@@ -165,98 +166,106 @@ impl IntermediateStats {
}
}
/// The type of stats aggregation to perform.
/// Note that not all stats types are supported in the stats aggregation.
#[derive(Clone, Copy, Debug)]
pub enum StatsType {
/// The average of the values.
#[derive(Clone, Debug, PartialEq)]
pub(crate) enum SegmentStatsType {
Average,
/// The count of the values.
Count,
/// The maximum value.
Max,
/// The minimum value.
Min,
/// The stats (count, sum, min, max, avg) of the values.
Stats,
/// The extended stats (count, sum, min, max, avg, sum_of_squares, variance, std_deviation,
ExtendedStats(Option<f64>), // sigma
/// The sum of the values.
Sum,
/// The percentiles of the values.
Percentiles,
}
fn create_collector<const TYPE_ID: u8>(
req: &MetricAggReqData,
) -> Box<dyn SegmentAggregationCollector> {
Box::new(SegmentStatsCollector::<TYPE_ID> {
name: req.name.clone(),
collecting_for: req.collecting_for,
is_number_or_date_type: req.is_number_or_date_type,
missing_u64: req.missing_u64,
accessor: req.accessor.clone(),
buckets: vec![IntermediateStats::default()],
})
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct SegmentStatsCollector {
missing: Option<u64>,
field_type: ColumnType,
pub(crate) collecting_for: SegmentStatsType,
pub(crate) stats: IntermediateStats,
pub(crate) accessor_idx: usize,
val_cache: Vec<u64>,
}
/// Build a concrete `SegmentStatsCollector` depending on the column type.
pub(crate) fn build_segment_stats_collector(
req: &MetricAggReqData,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
match req.field_type {
ColumnType::I64 => Ok(create_collector::<{ ColumnType::I64 as u8 }>(req)),
ColumnType::U64 => Ok(create_collector::<{ ColumnType::U64 as u8 }>(req)),
ColumnType::F64 => Ok(create_collector::<{ ColumnType::F64 as u8 }>(req)),
ColumnType::Bool => Ok(create_collector::<{ ColumnType::Bool as u8 }>(req)),
ColumnType::DateTime => Ok(create_collector::<{ ColumnType::DateTime as u8 }>(req)),
ColumnType::Bytes => Ok(create_collector::<{ ColumnType::Bytes as u8 }>(req)),
ColumnType::Str => Ok(create_collector::<{ ColumnType::Str as u8 }>(req)),
ColumnType::IpAddr => Ok(create_collector::<{ ColumnType::IpAddr as u8 }>(req)),
impl SegmentStatsCollector {
pub fn from_req(
field_type: ColumnType,
collecting_for: SegmentStatsType,
accessor_idx: usize,
missing: Option<f64>,
) -> Self {
let missing = missing.and_then(|val| f64_to_fastfield_u64(val, &field_type));
Self {
field_type,
collecting_for,
stats: IntermediateStats::default(),
accessor_idx,
missing,
val_cache: Default::default(),
}
}
#[inline]
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
agg_accessor: &mut AggregationWithAccessor,
) {
if let Some(missing) = self.missing.as_ref() {
agg_accessor.column_block_accessor.fetch_block_with_missing(
docs,
&agg_accessor.accessor,
*missing,
);
} else {
agg_accessor
.column_block_accessor
.fetch_block(docs, &agg_accessor.accessor);
}
if [
ColumnType::I64,
ColumnType::U64,
ColumnType::F64,
ColumnType::DateTime,
]
.contains(&self.field_type)
{
for val in agg_accessor.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.stats.collect(val1);
}
} else {
for _val in agg_accessor.column_block_accessor.iter_vals() {
// we ignore the value and simply record that we got something
self.stats.collect(0.0);
}
}
}
}
#[repr(C)]
#[derive(Clone, Debug)]
pub(crate) struct SegmentStatsCollector<const COLUMN_TYPE_ID: u8> {
pub(crate) missing_u64: Option<u64>,
pub(crate) accessor: Column<u64>,
pub(crate) is_number_or_date_type: bool,
pub(crate) buckets: Vec<IntermediateStats>,
pub(crate) name: String,
pub(crate) collecting_for: StatsType,
}
impl<const COLUMN_TYPE_ID: u8> SegmentAggregationCollector
for SegmentStatsCollector<COLUMN_TYPE_ID>
{
impl SegmentAggregationCollector for SegmentStatsCollector {
#[inline]
fn add_intermediate_aggregation_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = self.name.clone();
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let stats = self.buckets[parent_bucket_id as usize];
let intermediate_metric_result = match self.collecting_for {
StatsType::Average => {
IntermediateMetricResult::Average(IntermediateAverage::from_stats(stats))
SegmentStatsType::Average => {
IntermediateMetricResult::Average(IntermediateAverage::from_collector(*self))
}
StatsType::Count => {
IntermediateMetricResult::Count(IntermediateCount::from_stats(stats))
SegmentStatsType::Count => {
IntermediateMetricResult::Count(IntermediateCount::from_collector(*self))
}
StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_stats(stats)),
StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_stats(stats)),
StatsType::Stats => IntermediateMetricResult::Stats(stats),
StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_stats(stats)),
_ => {
return Err(TantivyError::InvalidArgument(format!(
"Unsupported stats type for stats aggregation: {:?}",
self.collecting_for
)))
SegmentStatsType::Max => {
IntermediateMetricResult::Max(IntermediateMax::from_collector(*self))
}
SegmentStatsType::Min => {
IntermediateMetricResult::Min(IntermediateMin::from_collector(*self))
}
SegmentStatsType::Stats => IntermediateMetricResult::Stats(self.stats),
SegmentStatsType::Sum => {
IntermediateMetricResult::Sum(IntermediateSum::from_collector(*self))
}
};
@@ -271,69 +280,43 @@ impl<const COLUMN_TYPE_ID: u8> SegmentAggregationCollector
#[inline]
fn collect(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
// TODO: remove once we fetch all values for all bucket ids in one go
if docs.len() == 1 && self.missing_u64.is_none() {
collect_stats::<COLUMN_TYPE_ID>(
&mut self.buckets[parent_bucket_id as usize],
self.accessor.values_for_doc(docs[0]),
self.is_number_or_date_type,
)?;
return Ok(());
let field = &agg_with_accessor.aggs.values[self.accessor_idx].accessor;
if let Some(missing) = self.missing {
let mut has_val = false;
for val in field.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.stats.collect(val1);
has_val = true;
}
if !has_val {
self.stats
.collect(f64_from_fastfield_u64(missing, &self.field_type));
}
} else {
for val in field.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.stats.collect(val1);
}
}
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&self.accessor,
self.missing_u64,
);
collect_stats::<COLUMN_TYPE_ID>(
&mut self.buckets[parent_bucket_id as usize],
agg_data.column_block_accessor.iter_vals(),
self.is_number_or_date_type,
)?;
Ok(())
}
fn prepare_max_bucket(
#[inline]
fn collect_block(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
let required_buckets = (max_bucket as usize) + 1;
if self.buckets.len() < required_buckets {
self.buckets
.resize_with(required_buckets, IntermediateStats::default);
}
let field = &mut agg_with_accessor.aggs.values[self.accessor_idx];
self.collect_block_with_field(docs, field);
Ok(())
}
}
#[inline]
fn collect_stats<const COLUMN_TYPE_ID: u8>(
stats: &mut IntermediateStats,
vals: impl Iterator<Item = u64>,
is_number_or_date_type: bool,
) -> crate::Result<()> {
if is_number_or_date_type {
for val in vals {
let val1 = convert_to_f64::<COLUMN_TYPE_ID>(val);
stats.collect(val1);
}
} else {
for _val in vals {
// we ignore the value and simply record that we got something
stats.collect(0.0);
}
}
Ok(())
}
#[cfg(test)]
mod tests {
use serde_json::Value;

View File

@@ -52,8 +52,10 @@ pub struct IntermediateSum {
impl IntermediateSum {
/// Creates a new [`IntermediateSum`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateSum) {

View File

@@ -9,41 +9,16 @@ use serde::ser::SerializeMap;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use super::{TopHitsMetricResult, TopHitsVecEntry};
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::bucket::Order;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateMetricResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::{AggregationError, BucketId};
use crate::collector::sort_key::ReverseComparator;
use crate::aggregation::AggregationError;
use crate::collector::TopNComputer;
use crate::schema::OwnedValue;
use crate::{DocAddress, DocId, SegmentOrdinal};
/// Contains all information required by the TopHitsSegmentCollector to perform the
/// top_hits aggregation on a segment.
#[derive(Default)]
pub struct TopHitsAggReqData {
/// The accessors to access the fast field values.
pub accessors: Vec<(Column<u64>, ColumnType)>,
/// The accessors to access the fast field values for retrieving document fields.
pub value_accessors: HashMap<String, Vec<DynamicColumn>>,
/// The ordinal of the segment this request data is for.
pub segment_ordinal: SegmentOrdinal,
/// The name of the aggregation.
pub name: String,
/// The top_hits aggregation request.
pub req: TopHitsAggregationReq,
}
impl TopHitsAggReqData {
/// Estimate the memory consumption of this struct in bytes.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
}
}
/// # Top Hits
///
/// The top hits aggregation is a useful tool to answer questions like:
@@ -458,7 +433,7 @@ impl Eq for DocSortValuesAndFields {}
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct TopHitsTopNComputer {
req: TopHitsAggregationReq,
top_n: TopNComputer<DocSortValuesAndFields, DocAddress, ReverseComparator>,
top_n: TopNComputer<DocSortValuesAndFields, DocAddress, false>,
}
impl std::cmp::PartialEq for TopHitsTopNComputer {
@@ -471,10 +446,7 @@ impl TopHitsTopNComputer {
/// Create a new TopHitsCollector
pub fn new(req: &TopHitsAggregationReq) -> Self {
Self {
top_n: TopNComputer::new_with_comparator(
req.size + req.from.unwrap_or(0),
ReverseComparator,
),
top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)),
req: req.clone(),
}
}
@@ -485,7 +457,7 @@ impl TopHitsTopNComputer {
pub(crate) fn merge_fruits(&mut self, other_fruit: Self) -> crate::Result<()> {
for doc in other_fruit.top_n.into_vec() {
self.collect(doc.sort_key, doc.doc);
self.collect(doc.feature, doc.doc);
}
Ok(())
}
@@ -497,9 +469,9 @@ impl TopHitsTopNComputer {
.into_sorted_vec()
.into_iter()
.map(|doc| TopHitsVecEntry {
sort: doc.sort_key.sorts.iter().map(|f| f.value).collect(),
sort: doc.feature.sorts.iter().map(|f| f.value).collect(),
doc_value_fields: doc
.sort_key
.feature
.doc_value_fields
.into_iter()
.map(|(k, v)| (k, v.into()))
@@ -520,8 +492,7 @@ impl TopHitsTopNComputer {
pub(crate) struct TopHitsSegmentCollector {
segment_ordinal: SegmentOrdinal,
accessor_idx: usize,
buckets: Vec<TopNComputer<Vec<DocValueAndOrder>, DocAddress, ReverseComparator>>,
num_hits: usize,
top_n: TopNComputer<Vec<DocValueAndOrder>, DocAddress, false>,
}
impl TopHitsSegmentCollector {
@@ -530,35 +501,25 @@ impl TopHitsSegmentCollector {
accessor_idx: usize,
segment_ordinal: SegmentOrdinal,
) -> Self {
let num_hits = req.size + req.from.unwrap_or(0);
Self {
num_hits,
top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)),
segment_ordinal,
accessor_idx,
buckets: vec![TopNComputer::new_with_comparator(num_hits, ReverseComparator); 1],
}
}
fn get_top_hits_computer(
&mut self,
parent_bucket_id: BucketId,
fn into_top_hits_collector(
self,
value_accessors: &HashMap<String, Vec<DynamicColumn>>,
req: &TopHitsAggregationReq,
) -> TopHitsTopNComputer {
if parent_bucket_id as usize >= self.buckets.len() {
return TopHitsTopNComputer::new(req);
}
let top_n = std::mem::replace(
&mut self.buckets[parent_bucket_id as usize],
TopNComputer::new(0),
);
let mut top_hits_computer = TopHitsTopNComputer::new(req);
let top_results = top_n.into_vec();
let top_results = self.top_n.into_vec();
for res in top_results {
let doc_value_fields = req.get_document_field_data(value_accessors, res.doc.doc_id);
top_hits_computer.collect(
DocSortValuesAndFields {
sorts: res.sort_key,
sorts: res.feature,
doc_value_fields,
},
res.doc,
@@ -567,26 +528,61 @@ impl TopHitsSegmentCollector {
top_hits_computer
}
/// TODO add a specialized variant for a single sort field
fn collect_with(
&mut self,
doc_id: crate::DocId,
req: &TopHitsAggregationReq,
accessors: &[(Column<u64>, ColumnType)],
) -> crate::Result<()> {
let sorts: Vec<DocValueAndOrder> = req
.sort
.iter()
.enumerate()
.map(|(idx, KeyOrder { order, .. })| {
let order = *order;
let value = accessors
.get(idx)
.expect("could not find field in accessors")
.0
.values_for_doc(doc_id)
.next();
DocValueAndOrder { value, order }
})
.collect();
self.top_n.push(
sorts,
DocAddress {
segment_ord: self.segment_ordinal,
doc_id,
},
);
Ok(())
}
}
impl SegmentAggregationCollector for TopHitsSegmentCollector {
fn add_intermediate_aggregation_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
self: Box<Self>,
agg_with_accessor: &crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor,
results: &mut crate::aggregation::intermediate_agg_result::IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let value_accessors = &req_data.value_accessors;
let value_accessors = &agg_with_accessor.aggs.values[self.accessor_idx].value_accessors;
let tophits_req = &agg_with_accessor.aggs.values[self.accessor_idx]
.agg
.agg
.as_top_hits()
.expect("aggregation request must be of type top hits");
let intermediate_result = IntermediateMetricResult::TopHits(self.get_top_hits_computer(
parent_bucket_id,
value_accessors,
&req_data.req,
));
let intermediate_result = IntermediateMetricResult::TopHits(
self.into_top_hits_collector(value_accessors, tophits_req),
);
results.push(
req_data.name.to_string(),
name,
IntermediateAggregationResult::Metric(intermediate_result),
)
}
@@ -594,54 +590,34 @@ impl SegmentAggregationCollector for TopHitsSegmentCollector {
/// TODO: Consider a caching layer to reduce the call overhead
fn collect(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
doc_id: crate::DocId,
agg_with_accessor: &mut crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor,
) -> crate::Result<()> {
let top_n = &mut self.buckets[parent_bucket_id as usize];
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
let req = &req_data.req;
let accessors = &req_data.accessors;
for &doc_id in docs {
// TODO: this is terrible, a new vec is allocated for every doc
// We can fetch blocks instead
// We don't need to store the order for every value
let sorts: Vec<DocValueAndOrder> = req
.sort
.iter()
.enumerate()
.map(|(idx, KeyOrder { order, .. })| {
let order = *order;
let value = accessors
.get(idx)
.expect("could not find field in accessors")
.0
.values_for_doc(doc_id)
.next();
DocValueAndOrder { value, order }
})
.collect();
top_n.push(
sorts,
DocAddress {
segment_ord: self.segment_ordinal,
doc_id,
},
);
}
let tophits_req = &agg_with_accessor.aggs.values[self.accessor_idx]
.agg
.agg
.as_top_hits()
.expect("aggregation request must be of type top hits");
let accessors = &agg_with_accessor.aggs.values[self.accessor_idx].accessors;
self.collect_with(doc_id, tophits_req, accessors)?;
Ok(())
}
fn prepare_max_bucket(
fn collect_block(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
docs: &[crate::DocId],
agg_with_accessor: &mut crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor,
) -> crate::Result<()> {
self.buckets.resize(
(max_bucket as usize) + 1,
TopNComputer::new_with_comparator(self.num_hits, ReverseComparator),
);
let tophits_req = &agg_with_accessor.aggs.values[self.accessor_idx]
.agg
.agg
.as_top_hits()
.expect("aggregation request must be of type top hits");
let accessors = &agg_with_accessor.aggs.values[self.accessor_idx].accessors;
// TODO: Consider getting fields with the column block accessor.
for doc in docs {
self.collect_with(*doc, tophits_req, accessors)?;
}
Ok(())
}
}
@@ -659,7 +635,6 @@ mod tests {
use crate::aggregation::bucket::tests::get_test_index_from_docs;
use crate::aggregation::tests::get_test_index_from_values;
use crate::aggregation::AggregationCollector;
use crate::collector::sort_key::ReverseComparator;
use crate::collector::ComparableDoc;
use crate::query::AllQuery;
use crate::schema::OwnedValue;
@@ -675,7 +650,7 @@ mod tests {
fn collector_with_capacity(capacity: usize) -> super::TopHitsTopNComputer {
super::TopHitsTopNComputer {
top_n: super::TopNComputer::new_with_comparator(capacity, ReverseComparator),
top_n: super::TopNComputer::new(capacity),
req: Default::default(),
}
}
@@ -759,7 +734,7 @@ mod tests {
],
"from": 0,
}
}
}
}))
.unwrap();
@@ -789,12 +764,12 @@ mod tests {
#[test]
fn test_top_hits_collector_single_feature() -> crate::Result<()> {
let docs = vec![
ComparableDoc::<_, _> {
ComparableDoc::<_, _, false> {
doc: crate::DocAddress {
segment_ord: 0,
doc_id: 0,
},
sort_key: DocSortValuesAndFields {
feature: DocSortValuesAndFields {
sorts: vec![DocValueAndOrder {
value: Some(1),
order: Order::Asc,
@@ -807,7 +782,7 @@ mod tests {
segment_ord: 0,
doc_id: 2,
},
sort_key: DocSortValuesAndFields {
feature: DocSortValuesAndFields {
sorts: vec![DocValueAndOrder {
value: Some(3),
order: Order::Asc,
@@ -820,7 +795,7 @@ mod tests {
segment_ord: 0,
doc_id: 1,
},
sort_key: DocSortValuesAndFields {
feature: DocSortValuesAndFields {
sorts: vec![DocValueAndOrder {
value: Some(5),
order: Order::Asc,
@@ -832,7 +807,7 @@ mod tests {
let mut collector = collector_with_capacity(3);
for doc in docs.clone() {
collector.collect(doc.sort_key, doc.doc);
collector.collect(doc.feature, doc.doc);
}
let res = collector.into_final_result();
@@ -842,15 +817,15 @@ mod tests {
super::TopHitsMetricResult {
hits: vec![
super::TopHitsVecEntry {
sort: vec![docs[0].sort_key.sorts[0].value],
sort: vec![docs[0].feature.sorts[0].value],
doc_value_fields: Default::default(),
},
super::TopHitsVecEntry {
sort: vec![docs[1].sort_key.sorts[0].value],
sort: vec![docs[1].feature.sorts[0].value],
doc_value_fields: Default::default(),
},
super::TopHitsVecEntry {
sort: vec![docs[2].sort_key.sorts[0].value],
sort: vec![docs[2].feature.sorts[0].value],
doc_value_fields: Default::default(),
},
]
@@ -888,7 +863,7 @@ mod tests {
"mixed.*",
],
}
}
}
}))?;
let collector = AggregationCollector::from_aggs(d, Default::default());

View File

@@ -127,13 +127,12 @@
//! [`AggregationResults`](agg_result::AggregationResults) via the
//! [`into_final_result`](intermediate_agg_result::IntermediateAggregationResults::into_final_result) method.
mod accessor_helpers;
mod agg_data;
mod agg_limits;
pub mod agg_req;
mod agg_req_with_accessor;
pub mod agg_result;
pub mod bucket;
pub(crate) mod cached_sub_aggs;
mod buf_collector;
mod collector;
mod date;
mod error;
@@ -141,6 +140,7 @@ pub mod intermediate_agg_result;
pub mod metric;
mod segment_agg_result;
use std::collections::HashMap;
use std::fmt::Display;
#[cfg(test)]
@@ -160,41 +160,6 @@ use itertools::Itertools;
use serde::de::{self, Visitor};
use serde::{Deserialize, Deserializer, Serialize};
use crate::tokenizer::TokenizerManager;
/// A bucket id is a dense identifier for a bucket within an aggregation.
/// It is used to index into a Vec that hold per-bucket data.
///
/// For example, in a terms aggregation, each unique term will be assigned a incremental BucketId.
/// This BucketId will be forwarded to sub-aggregations to identify the parent bucket.
///
/// This allows to have a single AggregationCollector instance per aggregation,
/// that can handle multiple buckets efficiently.
///
/// The API to call sub-aggregations is therefore a &[(BucketId, &[DocId])].
/// For that we'll need a buffer. One Vec per bucket aggregation is needed.
pub type BucketId = u32;
/// Context parameters for aggregation execution
///
/// This struct holds shared resources needed during aggregation execution:
/// - `limits`: Memory and bucket limits for the aggregation
/// - `tokenizers`: TokenizerManager for parsing query strings in filter aggregations
#[derive(Clone, Default)]
pub struct AggContextParams {
/// Aggregation limits (memory and bucket count)
pub limits: AggregationLimitsGuard,
/// Tokenizer manager for query string parsing
pub tokenizers: TokenizerManager,
}
impl AggContextParams {
/// Create new aggregation context parameters
pub fn new(limits: AggregationLimitsGuard, tokenizers: TokenizerManager) -> Self {
Self { limits, tokenizers }
}
}
fn parse_str_into_f64<E: de::Error>(value: &str) -> Result<f64, E> {
let parsed = value
.parse::<f64>()
@@ -292,6 +257,80 @@ where D: Deserializer<'de> {
deserializer.deserialize_any(StringOrFloatVisitor)
}
/// Represents an associative array `(key => values)` in a very efficient manner.
#[derive(PartialEq, Serialize, Deserialize)]
pub(crate) struct VecWithNames<T> {
pub(crate) values: Vec<T>,
keys: Vec<String>,
}
impl<T: Clone> Clone for VecWithNames<T> {
fn clone(&self) -> Self {
Self {
values: self.values.clone(),
keys: self.keys.clone(),
}
}
}
impl<T> Default for VecWithNames<T> {
fn default() -> Self {
Self {
values: Default::default(),
keys: Default::default(),
}
}
}
impl<T: std::fmt::Debug> std::fmt::Debug for VecWithNames<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_map().entries(self.iter()).finish()
}
}
impl<T> From<HashMap<String, T>> for VecWithNames<T> {
fn from(map: HashMap<String, T>) -> Self {
VecWithNames::from_entries(map.into_iter().collect_vec())
}
}
impl<T> VecWithNames<T> {
fn from_entries(mut entries: Vec<(String, T)>) -> Self {
// Sort to ensure order of elements match across multiple instances
entries.sort_by(|left, right| left.0.cmp(&right.0));
let mut data = Vec::with_capacity(entries.len());
let mut data_names = Vec::with_capacity(entries.len());
for entry in entries {
data_names.push(entry.0);
data.push(entry.1);
}
VecWithNames {
values: data,
keys: data_names,
}
}
fn iter(&self) -> impl Iterator<Item = (&str, &T)> + '_ {
self.keys().zip(self.values.iter())
}
fn keys(&self) -> impl Iterator<Item = &str> + '_ {
self.keys.iter().map(|key| key.as_str())
}
fn values_mut(&mut self) -> impl Iterator<Item = &mut T> + '_ {
self.values.iter_mut()
}
fn is_empty(&self) -> bool {
self.keys.is_empty()
}
fn len(&self) -> usize {
self.keys.len()
}
fn get(&self, name: &str) -> Option<&T> {
self.keys()
.position(|key| key == name)
.map(|pos| &self.values[pos])
}
}
/// The serialized key is used in a `HashMap`.
pub type SerializedKey = String;
@@ -348,37 +387,19 @@ impl Display for Key {
}
}
pub(crate) fn convert_to_f64<const COLUMN_TYPE_ID: u8>(val: u64) -> f64 {
if COLUMN_TYPE_ID == ColumnType::U64 as u8 {
val as f64
} else if COLUMN_TYPE_ID == ColumnType::I64 as u8
|| COLUMN_TYPE_ID == ColumnType::DateTime as u8
{
i64::from_u64(val) as f64
} else if COLUMN_TYPE_ID == ColumnType::F64 as u8 {
f64::from_u64(val)
} else if COLUMN_TYPE_ID == ColumnType::Bool as u8 {
val as f64
} else {
panic!(
"ColumnType ID {} cannot be converted to f64 metric",
COLUMN_TYPE_ID
)
}
}
/// Inverse of `to_fastfield_u64`. Used to convert to `f64` for metrics.
///
/// # Panics
/// Only `u64`, `f64`, `date`, and `i64` are supported.
pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: ColumnType) -> f64 {
pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: &ColumnType) -> f64 {
match field_type {
ColumnType::U64 => convert_to_f64::<{ ColumnType::U64 as u8 }>(val),
ColumnType::I64 => convert_to_f64::<{ ColumnType::I64 as u8 }>(val),
ColumnType::F64 => convert_to_f64::<{ ColumnType::F64 as u8 }>(val),
ColumnType::Bool => convert_to_f64::<{ ColumnType::Bool as u8 }>(val),
ColumnType::DateTime => convert_to_f64::<{ ColumnType::DateTime as u8 }>(val),
_ => panic!("unexpected type {field_type:?}. This should not happen"),
ColumnType::U64 => val as f64,
ColumnType::I64 | ColumnType::DateTime => i64::from_u64(val) as f64,
ColumnType::F64 => f64::from_u64(val),
ColumnType::Bool => val as f64,
_ => {
panic!("unexpected type {field_type:?}. This should not happen")
}
}
}
@@ -443,10 +464,7 @@ mod tests {
query: Option<(&str, &str)>,
limits: AggregationLimitsGuard,
) -> crate::Result<Value> {
let collector = AggregationCollector::from_aggs(
agg_req,
AggContextParams::new(limits, index.tokenizers().clone()),
);
let collector = AggregationCollector::from_aggs(agg_req, limits);
let reader = index.reader()?;
let searcher = reader.searcher();

View File

@@ -6,79 +6,179 @@
use std::fmt::Debug;
pub(crate) use super::agg_limits::AggregationLimitsGuard;
use super::agg_req::AggregationVariants;
use super::agg_req_with_accessor::{AggregationWithAccessor, AggregationsWithAccessor};
use super::bucket::{SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector};
use super::intermediate_agg_result::IntermediateAggregationResults;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::BucketId;
use super::metric::{
AverageAggregation, CountAggregation, ExtendedStatsAggregation, MaxAggregation, MinAggregation,
SegmentPercentilesCollector, SegmentStatsCollector, SegmentStatsType, StatsAggregation,
SumAggregation,
};
use crate::aggregation::bucket::TermMissingAgg;
use crate::aggregation::metric::{
CardinalityAggregationReq, SegmentCardinalityCollector, SegmentExtendedStatsCollector,
TopHitsSegmentCollector,
};
/// Monotonically increasing provider of BucketIds.
#[derive(Debug, Clone, Default)]
pub struct BucketIdProvider(u32);
impl BucketIdProvider {
/// Get the next BucketId.
pub fn next_bucket_id(&mut self) -> BucketId {
let bucket_id = self.0;
self.0 += 1;
bucket_id
}
}
/// A SegmentAggregationCollector is used to collect aggregation results.
pub trait SegmentAggregationCollector: Debug {
pub(crate) trait SegmentAggregationCollector: CollectorClone + Debug {
fn add_intermediate_aggregation_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()>;
/// Note: The caller needs to call `prepare_max_bucket` before calling `collect`.
fn collect(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()>;
/// Collect docs for multiple buckets in one call.
/// Minimizes dynamic dispatch overhead when collecting many buckets.
///
/// Note: The caller needs to call `prepare_max_bucket` before calling `collect`.
fn collect_multiple(
fn collect_block(
&mut self,
bucket_ids: &[BucketId],
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
debug_assert_eq!(bucket_ids.len(), docs.len());
let mut start = 0;
while start < bucket_ids.len() {
let bucket_id = bucket_ids[start];
let mut end = start + 1;
while end < bucket_ids.len() && bucket_ids[end] == bucket_id {
end += 1;
}
self.collect(bucket_id, &docs[start..end], agg_data)?;
start = end;
}
Ok(())
}
/// Prepare the collector for collecting up to BucketId `max_bucket`.
/// This is useful so we can split allocation ahead of time of collecting.
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
agg_data: &AggregationsSegmentCtx,
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()>;
/// Finalize method. Some Aggregator collect blocks of docs before calling `collect_block`.
/// This method ensures those staged docs will be collected.
fn flush(&mut self, _agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
fn flush(&mut self, _agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
Ok(())
}
}
#[derive(Default)]
pub(crate) trait CollectorClone {
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector>;
}
impl<T> CollectorClone for T
where T: 'static + SegmentAggregationCollector + Clone
{
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector> {
Box::new(self.clone())
}
}
impl Clone for Box<dyn SegmentAggregationCollector> {
fn clone(&self) -> Box<dyn SegmentAggregationCollector> {
self.clone_box()
}
}
pub(crate) fn build_segment_agg_collector(
req: &mut AggregationsWithAccessor,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
// Single collector special case
if req.aggs.len() == 1 {
let req = &mut req.aggs.values[0];
let accessor_idx = 0;
return build_single_agg_segment_collector(req, accessor_idx);
}
let agg = GenericSegmentAggregationResultsCollector::from_req_and_validate(req)?;
Ok(Box::new(agg))
}
pub(crate) fn build_single_agg_segment_collector(
req: &mut AggregationWithAccessor,
accessor_idx: usize,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
use AggregationVariants::*;
match &req.agg.agg {
Terms(terms_req) => {
if req.accessors.is_empty() {
Ok(Box::new(SegmentTermCollector::from_req_and_validate(
terms_req,
&mut req.sub_aggregation,
req.field_type,
accessor_idx,
)?))
} else {
Ok(Box::new(TermMissingAgg::new(
accessor_idx,
&mut req.sub_aggregation,
)?))
}
}
Range(range_req) => Ok(Box::new(SegmentRangeCollector::from_req_and_validate(
range_req,
&mut req.sub_aggregation,
&mut req.limits,
req.field_type,
accessor_idx,
)?)),
Histogram(histogram) => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
histogram.clone(),
&mut req.sub_aggregation,
req.field_type,
accessor_idx,
)?)),
DateHistogram(histogram) => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
histogram.to_histogram_req()?,
&mut req.sub_aggregation,
req.field_type,
accessor_idx,
)?)),
Average(AverageAggregation { missing, .. }) => {
Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Average,
accessor_idx,
*missing,
)))
}
Count(CountAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Count,
accessor_idx,
*missing,
))),
Max(MaxAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Max,
accessor_idx,
*missing,
))),
Min(MinAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Min,
accessor_idx,
*missing,
))),
Stats(StatsAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Stats,
accessor_idx,
*missing,
))),
ExtendedStats(ExtendedStatsAggregation { missing, sigma, .. }) => Ok(Box::new(
SegmentExtendedStatsCollector::from_req(req.field_type, *sigma, accessor_idx, *missing),
)),
Sum(SumAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Sum,
accessor_idx,
*missing,
))),
Percentiles(percentiles_req) => Ok(Box::new(
SegmentPercentilesCollector::from_req_and_validate(
percentiles_req,
req.field_type,
accessor_idx,
)?,
)),
TopHits(top_hits_req) => Ok(Box::new(TopHitsSegmentCollector::from_req(
top_hits_req,
accessor_idx,
req.segment_ordinal,
))),
Cardinality(CardinalityAggregationReq { missing, .. }) => Ok(Box::new(
SegmentCardinalityCollector::from_req(req.field_type, accessor_idx, missing),
)),
}
}
#[derive(Clone, Default)]
/// The GenericSegmentAggregationResultsCollector is the generic version of the collector, which
/// can handle arbitrary complexity of sub-aggregations. Ideally we never have to pick this one
/// and can provide specialized versions instead, that remove some of its overhead.
@@ -96,13 +196,12 @@ impl Debug for GenericSegmentAggregationResultsCollector {
impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
fn add_intermediate_aggregation_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
for agg in &mut self.aggs {
agg.add_intermediate_aggregation_result(agg_data, results, parent_bucket_id)?;
for agg in self.aggs {
agg.add_intermediate_aggregation_result(agg_with_accessor, results)?;
}
Ok(())
@@ -110,31 +209,43 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
fn collect(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
for collector in &mut self.aggs {
collector.collect(parent_bucket_id, docs, agg_data)?;
}
self.collect_block(&[doc], agg_with_accessor)?;
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
for collector in &mut self.aggs {
collector.flush(agg_data)?;
}
Ok(())
}
fn prepare_max_bucket(
fn collect_block(
&mut self,
max_bucket: BucketId,
agg_data: &AggregationsSegmentCtx,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
for collector in &mut self.aggs {
collector.prepare_max_bucket(max_bucket, agg_data)?;
collector.collect_block(docs, agg_with_accessor)?;
}
Ok(())
}
fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
for collector in &mut self.aggs {
collector.flush(agg_with_accessor)?;
}
Ok(())
}
}
impl GenericSegmentAggregationResultsCollector {
pub(crate) fn from_req_and_validate(req: &mut AggregationsWithAccessor) -> crate::Result<Self> {
let aggs = req
.aggs
.values_mut()
.enumerate()
.map(|(accessor_idx, req)| build_single_agg_segment_collector(req, accessor_idx))
.collect::<crate::Result<Vec<Box<dyn SegmentAggregationCollector>>>>()?;
Ok(GenericSegmentAggregationResultsCollector { aggs })
}
}

View File

@@ -1,49 +0,0 @@
mod postings;
mod standard;
use std::borrow::Cow;
use serde::{Deserialize, Serialize};
pub use standard::StandardCodec;
pub trait Codec: Clone + std::fmt::Debug + Send + Sync + 'static {
type PostingsCodec;
const NAME: &'static str;
fn from_json_props(json_value: &serde_json::Value) -> crate::Result<Self>;
fn to_json_props(&self) -> serde_json::Value;
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct CodecConfiguration {
name: Cow<'static, str>,
#[serde(default, skip_serializing_if = "serde_json::Value::is_null")]
props: serde_json::Value,
}
impl CodecConfiguration {
pub fn from_codec<C: Codec>(codec: &C) -> Self {
CodecConfiguration {
name: Cow::Borrowed(C::NAME),
props: codec.to_json_props(),
}
}
pub fn to_codec<C: Codec>(&self) -> crate::Result<C> {
if self.name != C::NAME {
return Err(crate::TantivyError::InvalidArgument(format!(
"Codec name mismatch: expected {}, got {}",
C::NAME,
self.name
)));
}
C::from_json_props(&self.props)
}
}
impl Default for CodecConfiguration {
fn default() -> Self {
CodecConfiguration::from_codec(&StandardCodec)
}
}

View File

@@ -1,23 +0,0 @@
use std::io;
use crate::fieldnorm::FieldNormReader;
use crate::schema::IndexRecordOption;
use crate::{DocId, Score};
pub trait PostingsCodec {
type PostingsSerializer: PostingsSerializer;
}
pub trait PostingsSerializer {
fn new(
avg_fieldnorm: Score,
mode: IndexRecordOption,
fieldnorm_reader: Option<FieldNormReader>,
) -> Self;
fn new_term(&mut self, term_doc_freq: u32, record_term_freq: bool);
fn write_doc(&mut self, doc_id: DocId, term_freq: u32);
fn close_term(&mut self, doc_freq: u32, wrt: &mut impl io::Write) -> io::Result<()>;
}

View File

@@ -1,29 +0,0 @@
use serde::{Deserialize, Serialize};
use crate::codec::standard::postings::StandardPostingsCodec;
use crate::codec::Codec;
mod postings;
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct StandardCodec;
impl Codec for StandardCodec {
type PostingsCodec = StandardPostingsCodec;
const NAME: &'static str = "standard";
fn from_json_props(json_value: &serde_json::Value) -> crate::Result<Self> {
if !json_value.is_null() {
return Err(crate::TantivyError::InvalidArgument(format!(
"Codec property for the StandardCodec are unexpected. expected null, got {}",
json_value.as_str().unwrap_or("null")
)));
}
Ok(StandardCodec)
}
fn to_json_props(&self) -> serde_json::Value {
serde_json::Value::Null
}
}

View File

@@ -1,50 +0,0 @@
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
use crate::DocId;
pub struct Block {
doc_ids: [DocId; COMPRESSION_BLOCK_SIZE],
term_freqs: [u32; COMPRESSION_BLOCK_SIZE],
len: usize,
}
impl Block {
pub fn new() -> Self {
Block {
doc_ids: [0u32; COMPRESSION_BLOCK_SIZE],
term_freqs: [0u32; COMPRESSION_BLOCK_SIZE],
len: 0,
}
}
pub fn doc_ids(&self) -> &[DocId] {
&self.doc_ids[..self.len]
}
pub fn term_freqs(&self) -> &[u32] {
&self.term_freqs[..self.len]
}
pub fn clear(&mut self) {
self.len = 0;
}
pub fn append_doc(&mut self, doc: DocId, term_freq: u32) {
let len = self.len;
self.doc_ids[len] = doc;
self.term_freqs[len] = term_freq;
self.len = len + 1;
}
pub fn is_full(&self) -> bool {
self.len == COMPRESSION_BLOCK_SIZE
}
pub fn is_empty(&self) -> bool {
self.len == 0
}
pub fn last_doc(&self) -> DocId {
assert_eq!(self.len, COMPRESSION_BLOCK_SIZE);
self.doc_ids[COMPRESSION_BLOCK_SIZE - 1]
}
}

View File

@@ -1,13 +0,0 @@
use crate::codec::postings::PostingsCodec;
mod block;
mod postings_serializer;
mod skip;
pub use postings_serializer::StandardPostingsSerializer;
pub struct StandardPostingsCodec;
impl PostingsCodec for StandardPostingsCodec {
type PostingsSerializer = StandardPostingsSerializer;
}

View File

@@ -1,187 +0,0 @@
use std::cmp::Ordering;
use std::io::{self, Write as _};
use common::{BinarySerializable as _, VInt};
use crate::codec::postings::PostingsSerializer;
use crate::codec::standard::postings::block::Block;
use crate::codec::standard::postings::skip::SkipSerializer;
use crate::fieldnorm::FieldNormReader;
use crate::postings::compression::{BlockEncoder, VIntEncoder as _, COMPRESSION_BLOCK_SIZE};
use crate::query::Bm25Weight;
use crate::schema::IndexRecordOption;
use crate::{DocId, Score};
pub struct StandardPostingsSerializer {
last_doc_id_encoded: u32,
block_encoder: BlockEncoder,
block: Box<Block>,
postings_write: Vec<u8>,
skip_write: SkipSerializer,
mode: IndexRecordOption,
fieldnorm_reader: Option<FieldNormReader>,
bm25_weight: Option<Bm25Weight>,
avg_fieldnorm: Score, /* Average number of term in the field for that segment.
* this value is used to compute the block wand information. */
term_has_freq: bool,
}
impl PostingsSerializer for StandardPostingsSerializer {
fn new(
avg_fieldnorm: Score,
mode: IndexRecordOption,
fieldnorm_reader: Option<FieldNormReader>,
) -> StandardPostingsSerializer {
Self {
block_encoder: BlockEncoder::new(),
block: Box::new(Block::new()),
postings_write: Vec::new(),
skip_write: SkipSerializer::new(),
last_doc_id_encoded: 0u32,
mode,
fieldnorm_reader,
bm25_weight: None,
avg_fieldnorm,
term_has_freq: false,
}
}
fn new_term(&mut self, term_doc_freq: u32, record_term_freq: bool) {
self.bm25_weight = None;
self.term_has_freq = self.mode.has_freq() && record_term_freq;
if !self.term_has_freq {
return;
}
let num_docs_in_segment: u64 =
if let Some(fieldnorm_reader) = self.fieldnorm_reader.as_ref() {
fieldnorm_reader.num_docs() as u64
} else {
return;
};
if num_docs_in_segment == 0 {
return;
}
self.bm25_weight = Some(Bm25Weight::for_one_term_without_explain(
term_doc_freq as u64,
num_docs_in_segment,
self.avg_fieldnorm,
));
}
fn write_doc(&mut self, doc_id: DocId, term_freq: u32) {
self.block.append_doc(doc_id, term_freq);
if self.block.is_full() {
self.write_block();
}
}
fn close_term(
&mut self,
doc_freq: u32,
output_write: &mut impl std::io::Write,
) -> io::Result<()> {
if !self.block.is_empty() {
// we have doc ids waiting to be written
// this happens when the number of doc ids is
// not a perfect multiple of our block size.
//
// In that case, the remaining part is encoded
// using variable int encoding.
{
let block_encoded = self
.block_encoder
.compress_vint_sorted(self.block.doc_ids(), self.last_doc_id_encoded);
self.postings_write.write_all(block_encoded)?;
}
// ... Idem for term frequencies
if self.term_has_freq {
let block_encoded = self
.block_encoder
.compress_vint_unsorted(self.block.term_freqs());
self.postings_write.write_all(block_encoded)?;
}
self.block.clear();
}
if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 {
let skip_data = self.skip_write.data();
VInt(skip_data.len() as u64).serialize(output_write)?;
output_write.write_all(skip_data)?;
}
output_write.write_all(&self.postings_write[..])?;
self.skip_write.clear();
self.postings_write.clear();
self.bm25_weight = None;
Ok(())
}
}
impl StandardPostingsSerializer {
fn write_block(&mut self) {
{
// encode the doc ids
let (num_bits, block_encoded): (u8, &[u8]) = self
.block_encoder
.compress_block_sorted(self.block.doc_ids(), self.last_doc_id_encoded);
self.last_doc_id_encoded = self.block.last_doc();
self.skip_write
.write_doc(self.last_doc_id_encoded, num_bits);
// last el block 0, offset block 1,
self.postings_write.extend(block_encoded);
}
if self.term_has_freq {
let (num_bits, block_encoded): (u8, &[u8]) = self
.block_encoder
.compress_block_unsorted(self.block.term_freqs(), true);
self.postings_write.extend(block_encoded);
self.skip_write.write_term_freq(num_bits);
if self.mode.has_positions() {
// We serialize the sum of term freqs within the skip information
// in order to navigate through positions.
let sum_freq = self.block.term_freqs().iter().cloned().sum();
self.skip_write.write_total_term_freq(sum_freq);
}
let mut blockwand_params = (0u8, 0u32);
if let Some(bm25_weight) = self.bm25_weight.as_ref() {
if let Some(fieldnorm_reader) = self.fieldnorm_reader.as_ref() {
let docs = self.block.doc_ids().iter().cloned();
let term_freqs = self.block.term_freqs().iter().cloned();
let fieldnorms = docs.map(|doc| fieldnorm_reader.fieldnorm_id(doc));
blockwand_params = fieldnorms
.zip(term_freqs)
.max_by(
|(left_fieldnorm_id, left_term_freq),
(right_fieldnorm_id, right_term_freq)| {
let left_score =
bm25_weight.tf_factor(*left_fieldnorm_id, *left_term_freq);
let right_score =
bm25_weight.tf_factor(*right_fieldnorm_id, *right_term_freq);
left_score
.partial_cmp(&right_score)
.unwrap_or(Ordering::Equal)
},
)
.unwrap();
}
}
let (fieldnorm_id, term_freq) = blockwand_params;
self.skip_write.write_blockwand_max(fieldnorm_id, term_freq);
}
self.block.clear();
}
fn clear(&mut self) {
self.block.clear();
self.last_doc_id_encoded = 0;
}
}

View File

@@ -1,448 +0,0 @@
use crate::directory::OwnedBytes;
use crate::postings::compression::{compressed_block_size, COMPRESSION_BLOCK_SIZE};
use crate::query::Bm25Weight;
use crate::schema::IndexRecordOption;
use crate::{DocId, Score, TERMINATED};
// doc num bits uses the following encoding:
// given 0b a b cdefgh
// |1|2|3| 4 |
// - 1: unused
// - 2: is delta-1 encoded. 0 if not, 1, if yes
// - 3: unused
// - 4: a 5 bit number in 0..32, the actual bitwidth. Bitpacking could in theory say this is 32
// (requiring a 6th bit), but the biggest doc_id we can want to encode is TERMINATED-1, which can
// be represented on 31b without delta encoding.
fn encode_bitwidth(bitwidth: u8, delta_1: bool) -> u8 {
assert!(bitwidth < 32);
bitwidth | ((delta_1 as u8) << 6)
}
fn decode_bitwidth(raw_bitwidth: u8) -> (u8, bool) {
let delta_1 = ((raw_bitwidth >> 6) & 1) != 0;
let bitwidth = raw_bitwidth & 0x1f;
(bitwidth, delta_1)
}
#[inline]
fn encode_block_wand_max_tf(max_tf: u32) -> u8 {
max_tf.min(u8::MAX as u32) as u8
}
#[inline]
fn decode_block_wand_max_tf(max_tf_code: u8) -> u32 {
if max_tf_code == u8::MAX {
u32::MAX
} else {
max_tf_code as u32
}
}
#[inline]
fn read_u32(data: &[u8]) -> u32 {
u32::from_le_bytes(data[..4].try_into().unwrap())
}
#[inline]
fn write_u32(val: u32, buf: &mut Vec<u8>) {
buf.extend_from_slice(&val.to_le_bytes());
}
pub struct SkipSerializer {
buffer: Vec<u8>,
}
impl SkipSerializer {
pub fn new() -> SkipSerializer {
SkipSerializer { buffer: Vec::new() }
}
pub fn write_doc(&mut self, last_doc: DocId, doc_num_bits: u8) {
write_u32(last_doc, &mut self.buffer);
self.buffer.push(encode_bitwidth(doc_num_bits, true));
}
pub fn write_term_freq(&mut self, tf_num_bits: u8) {
self.buffer.push(tf_num_bits);
}
pub fn write_total_term_freq(&mut self, tf_sum: u32) {
write_u32(tf_sum, &mut self.buffer);
}
pub fn write_blockwand_max(&mut self, fieldnorm_id: u8, term_freq: u32) {
let block_wand_tf = encode_block_wand_max_tf(term_freq);
self.buffer
.extend_from_slice(&[fieldnorm_id, block_wand_tf]);
}
pub fn data(&self) -> &[u8] {
&self.buffer[..]
}
pub fn clear(&mut self) {
self.buffer.clear();
}
}
#[derive(Clone)]
pub(crate) struct SkipReader {
last_doc_in_block: DocId,
pub(crate) last_doc_in_previous_block: DocId,
owned_read: OwnedBytes,
skip_info: IndexRecordOption,
byte_offset: usize,
remaining_docs: u32, // number of docs remaining, including the
// documents in the current block.
block_info: BlockInfo,
position_offset: u64,
}
#[derive(Clone, Eq, PartialEq, Copy, Debug)]
pub(crate) enum BlockInfo {
BitPacked {
doc_num_bits: u8,
strict_delta_encoded: bool,
tf_num_bits: u8,
tf_sum: u32,
block_wand_fieldnorm_id: u8,
block_wand_term_freq: u32,
},
VInt {
num_docs: u32,
},
}
impl Default for BlockInfo {
fn default() -> Self {
BlockInfo::VInt { num_docs: 0u32 }
}
}
impl SkipReader {
pub fn new(data: OwnedBytes, doc_freq: u32, skip_info: IndexRecordOption) -> SkipReader {
let mut skip_reader = SkipReader {
last_doc_in_block: if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 {
0
} else {
TERMINATED
},
last_doc_in_previous_block: 0u32,
owned_read: data,
skip_info,
block_info: BlockInfo::VInt { num_docs: doc_freq },
byte_offset: 0,
remaining_docs: doc_freq,
position_offset: 0u64,
};
if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 {
skip_reader.read_block_info();
}
skip_reader
}
pub fn reset(&mut self, data: OwnedBytes, doc_freq: u32) {
self.last_doc_in_block = if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 {
0
} else {
TERMINATED
};
self.last_doc_in_previous_block = 0u32;
self.owned_read = data;
self.block_info = BlockInfo::VInt { num_docs: doc_freq };
self.byte_offset = 0;
self.remaining_docs = doc_freq;
self.position_offset = 0u64;
if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 {
self.read_block_info();
}
}
// Returns the block max score for this block if available.
//
// The block max score is available for all full bitpacked block,
// but no available for the last VInt encoded incomplete block.
pub fn block_max_score(&self, bm25_weight: &Bm25Weight) -> Option<Score> {
match self.block_info {
BlockInfo::BitPacked {
block_wand_fieldnorm_id,
block_wand_term_freq,
..
} => Some(bm25_weight.score(block_wand_fieldnorm_id, block_wand_term_freq)),
BlockInfo::VInt { .. } => None,
}
}
pub(crate) fn last_doc_in_block(&self) -> DocId {
self.last_doc_in_block
}
pub fn position_offset(&self) -> u64 {
self.position_offset
}
#[inline]
pub fn byte_offset(&self) -> usize {
self.byte_offset
}
fn read_block_info(&mut self) {
let bytes = self.owned_read.as_slice();
let advance_len: usize;
self.last_doc_in_block = read_u32(bytes);
let (doc_num_bits, strict_delta_encoded) = decode_bitwidth(bytes[4]);
match self.skip_info {
IndexRecordOption::Basic => {
advance_len = 5;
self.block_info = BlockInfo::BitPacked {
doc_num_bits,
strict_delta_encoded,
tf_num_bits: 0,
tf_sum: 0,
block_wand_fieldnorm_id: 0,
block_wand_term_freq: 0,
};
}
IndexRecordOption::WithFreqs => {
let tf_num_bits = bytes[5];
let block_wand_fieldnorm_id = bytes[6];
let block_wand_term_freq = decode_block_wand_max_tf(bytes[7]);
advance_len = 8;
self.block_info = BlockInfo::BitPacked {
doc_num_bits,
strict_delta_encoded,
tf_num_bits,
tf_sum: 0,
block_wand_fieldnorm_id,
block_wand_term_freq,
};
}
IndexRecordOption::WithFreqsAndPositions => {
let tf_num_bits = bytes[5];
let tf_sum = read_u32(&bytes[6..10]);
let block_wand_fieldnorm_id = bytes[10];
let block_wand_term_freq = decode_block_wand_max_tf(bytes[11]);
advance_len = 12;
self.block_info = BlockInfo::BitPacked {
doc_num_bits,
strict_delta_encoded,
tf_num_bits,
tf_sum,
block_wand_fieldnorm_id,
block_wand_term_freq,
};
}
}
self.owned_read.advance(advance_len);
}
pub fn block_info(&self) -> BlockInfo {
self.block_info
}
/// Advance the skip reader to the block that may contain the target.
///
/// If the target is larger than all documents, the skip_reader
/// then advance to the last Variable In block.
pub fn seek(&mut self, target: DocId) -> bool {
if self.last_doc_in_block() >= target {
return false;
}
loop {
self.advance();
if self.last_doc_in_block() >= target {
return true;
}
}
}
pub fn advance(&mut self) {
match self.block_info {
BlockInfo::BitPacked {
doc_num_bits,
tf_num_bits,
tf_sum,
..
} => {
self.remaining_docs -= COMPRESSION_BLOCK_SIZE as u32;
self.byte_offset += compressed_block_size(doc_num_bits + tf_num_bits);
self.position_offset += tf_sum as u64;
}
BlockInfo::VInt { num_docs } => {
debug_assert_eq!(num_docs, self.remaining_docs);
self.remaining_docs = 0;
self.byte_offset = usize::MAX;
}
}
self.last_doc_in_previous_block = self.last_doc_in_block;
if self.remaining_docs >= COMPRESSION_BLOCK_SIZE as u32 {
self.read_block_info();
} else {
self.last_doc_in_block = TERMINATED;
self.block_info = BlockInfo::VInt {
num_docs: self.remaining_docs,
};
}
}
}
#[cfg(test)]
mod tests {
use super::{
decode_bitwidth, encode_bitwidth, BlockInfo, IndexRecordOption, SkipReader, SkipSerializer,
};
use crate::directory::OwnedBytes;
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
#[test]
fn test_encode_block_wand_max_tf() {
for tf in 0..255 {
assert_eq!(super::encode_block_wand_max_tf(tf), tf as u8);
}
for &tf in &[255, 256, 1_000_000, u32::MAX] {
assert_eq!(super::encode_block_wand_max_tf(tf), 255);
}
}
#[test]
fn test_decode_block_wand_max_tf() {
for tf in 0..255 {
assert_eq!(super::decode_block_wand_max_tf(tf), tf as u32);
}
assert_eq!(super::decode_block_wand_max_tf(255), u32::MAX);
}
#[test]
fn test_skip_with_freq() {
let buf = {
let mut skip_serializer = SkipSerializer::new();
skip_serializer.write_doc(1u32, 2u8);
skip_serializer.write_term_freq(3u8);
skip_serializer.write_blockwand_max(13u8, 3u32);
skip_serializer.write_doc(5u32, 5u8);
skip_serializer.write_term_freq(2u8);
skip_serializer.write_blockwand_max(8u8, 2u32);
skip_serializer.data().to_owned()
};
let doc_freq = 3u32 + (COMPRESSION_BLOCK_SIZE * 2) as u32;
let mut skip_reader =
SkipReader::new(OwnedBytes::new(buf), doc_freq, IndexRecordOption::WithFreqs);
assert_eq!(skip_reader.last_doc_in_block(), 1u32);
assert_eq!(
skip_reader.block_info,
BlockInfo::BitPacked {
doc_num_bits: 2u8,
strict_delta_encoded: true,
tf_num_bits: 3u8,
tf_sum: 0,
block_wand_fieldnorm_id: 13,
block_wand_term_freq: 3
}
);
skip_reader.advance();
assert_eq!(skip_reader.last_doc_in_block(), 5u32);
assert_eq!(
skip_reader.block_info(),
BlockInfo::BitPacked {
doc_num_bits: 5u8,
strict_delta_encoded: true,
tf_num_bits: 2u8,
tf_sum: 0,
block_wand_fieldnorm_id: 8,
block_wand_term_freq: 2
}
);
skip_reader.advance();
assert_eq!(skip_reader.block_info(), BlockInfo::VInt { num_docs: 3u32 });
skip_reader.advance();
assert_eq!(skip_reader.block_info(), BlockInfo::VInt { num_docs: 0u32 });
skip_reader.advance();
assert_eq!(skip_reader.block_info(), BlockInfo::VInt { num_docs: 0u32 });
}
#[test]
fn test_skip_no_freq() {
let buf = {
let mut skip_serializer = SkipSerializer::new();
skip_serializer.write_doc(1u32, 2u8);
skip_serializer.write_doc(5u32, 5u8);
skip_serializer.data().to_owned()
};
let doc_freq = 3u32 + (COMPRESSION_BLOCK_SIZE * 2) as u32;
let mut skip_reader =
SkipReader::new(OwnedBytes::new(buf), doc_freq, IndexRecordOption::Basic);
assert_eq!(skip_reader.last_doc_in_block(), 1u32);
assert_eq!(
skip_reader.block_info(),
BlockInfo::BitPacked {
doc_num_bits: 2u8,
strict_delta_encoded: true,
tf_num_bits: 0,
tf_sum: 0u32,
block_wand_fieldnorm_id: 0,
block_wand_term_freq: 0
}
);
skip_reader.advance();
assert_eq!(skip_reader.last_doc_in_block(), 5u32);
assert_eq!(
skip_reader.block_info(),
BlockInfo::BitPacked {
doc_num_bits: 5u8,
strict_delta_encoded: true,
tf_num_bits: 0,
tf_sum: 0u32,
block_wand_fieldnorm_id: 0,
block_wand_term_freq: 0
}
);
skip_reader.advance();
assert_eq!(skip_reader.block_info(), BlockInfo::VInt { num_docs: 3u32 });
skip_reader.advance();
assert_eq!(skip_reader.block_info(), BlockInfo::VInt { num_docs: 0u32 });
skip_reader.advance();
assert_eq!(skip_reader.block_info(), BlockInfo::VInt { num_docs: 0u32 });
}
#[test]
fn test_skip_multiple_of_block_size() {
let buf = {
let mut skip_serializer = SkipSerializer::new();
skip_serializer.write_doc(1u32, 2u8);
skip_serializer.data().to_owned()
};
let doc_freq = COMPRESSION_BLOCK_SIZE as u32;
let mut skip_reader =
SkipReader::new(OwnedBytes::new(buf), doc_freq, IndexRecordOption::Basic);
assert_eq!(skip_reader.last_doc_in_block(), 1u32);
assert_eq!(
skip_reader.block_info(),
BlockInfo::BitPacked {
doc_num_bits: 2u8,
strict_delta_encoded: true,
tf_num_bits: 0,
tf_sum: 0u32,
block_wand_fieldnorm_id: 0,
block_wand_term_freq: 0
}
);
skip_reader.advance();
assert_eq!(skip_reader.block_info(), BlockInfo::VInt { num_docs: 0u32 });
}
#[test]
fn test_encode_decode_bitwidth() {
for bitwidth in 0..32 {
for delta_1 in [false, true] {
assert_eq!(
(bitwidth, delta_1),
decode_bitwidth(encode_bitwidth(bitwidth, delta_1))
);
}
}
assert_eq!(0b01000010, encode_bitwidth(0b10, true));
assert_eq!(0b00000010, encode_bitwidth(0b10, false));
}
}

View File

@@ -0,0 +1,121 @@
use crate::collector::top_collector::{TopCollector, TopSegmentCollector};
use crate::collector::{Collector, SegmentCollector};
use crate::{DocAddress, DocId, Score, SegmentReader};
pub(crate) struct CustomScoreTopCollector<TCustomScorer, TScore = Score> {
custom_scorer: TCustomScorer,
collector: TopCollector<TScore>,
}
impl<TCustomScorer, TScore> CustomScoreTopCollector<TCustomScorer, TScore>
where TScore: Clone + PartialOrd
{
pub(crate) fn new(
custom_scorer: TCustomScorer,
collector: TopCollector<TScore>,
) -> CustomScoreTopCollector<TCustomScorer, TScore> {
CustomScoreTopCollector {
custom_scorer,
collector,
}
}
}
/// A custom segment scorer makes it possible to define any kind of score
/// for a given document belonging to a specific segment.
///
/// It is the segment local version of the [`CustomScorer`].
pub trait CustomSegmentScorer<TScore>: 'static {
/// Computes the score of a specific `doc`.
fn score(&mut self, doc: DocId) -> TScore;
}
/// `CustomScorer` makes it possible to define any kind of score.
///
/// The `CustomerScorer` itself does not make much of the computation itself.
/// Instead, it helps constructing `Self::Child` instances that will compute
/// the score at a segment scale.
pub trait CustomScorer<TScore>: Sync {
/// Type of the associated [`CustomSegmentScorer`].
type Child: CustomSegmentScorer<TScore>;
/// Builds a child scorer for a specific segment. The child scorer is associated with
/// a specific segment.
fn segment_scorer(&self, segment_reader: &SegmentReader) -> crate::Result<Self::Child>;
}
impl<TCustomScorer, TScore> Collector for CustomScoreTopCollector<TCustomScorer, TScore>
where
TCustomScorer: CustomScorer<TScore> + Send + Sync,
TScore: 'static + PartialOrd + Clone + Send + Sync,
{
type Fruit = Vec<(TScore, DocAddress)>;
type Child = CustomScoreTopSegmentCollector<TCustomScorer::Child, TScore>;
fn for_segment(
&self,
segment_local_id: u32,
segment_reader: &SegmentReader,
) -> crate::Result<Self::Child> {
let segment_collector = self.collector.for_segment(segment_local_id, segment_reader);
let segment_scorer = self.custom_scorer.segment_scorer(segment_reader)?;
Ok(CustomScoreTopSegmentCollector {
segment_collector,
segment_scorer,
})
}
fn requires_scoring(&self) -> bool {
false
}
fn merge_fruits(&self, segment_fruits: Vec<Self::Fruit>) -> crate::Result<Self::Fruit> {
self.collector.merge_fruits(segment_fruits)
}
}
pub struct CustomScoreTopSegmentCollector<T, TScore>
where
TScore: 'static + PartialOrd + Clone + Send + Sync + Sized,
T: CustomSegmentScorer<TScore>,
{
segment_collector: TopSegmentCollector<TScore>,
segment_scorer: T,
}
impl<T, TScore> SegmentCollector for CustomScoreTopSegmentCollector<T, TScore>
where
TScore: 'static + PartialOrd + Clone + Send + Sync,
T: 'static + CustomSegmentScorer<TScore>,
{
type Fruit = Vec<(TScore, DocAddress)>;
fn collect(&mut self, doc: DocId, _score: Score) {
let score = self.segment_scorer.score(doc);
self.segment_collector.collect(doc, score);
}
fn harvest(self) -> Vec<(TScore, DocAddress)> {
self.segment_collector.harvest()
}
}
impl<F, TScore, T> CustomScorer<TScore> for F
where
F: 'static + Send + Sync + Fn(&SegmentReader) -> T,
T: CustomSegmentScorer<TScore>,
{
type Child = T;
fn segment_scorer(&self, segment_reader: &SegmentReader) -> crate::Result<Self::Child> {
Ok((self)(segment_reader))
}
}
impl<F, TScore> CustomSegmentScorer<TScore> for F
where F: 'static + FnMut(DocId) -> TScore
{
fn score(&mut self, doc: DocId) -> TScore {
(self)(doc)
}
}

View File

@@ -484,6 +484,7 @@ impl FacetCounts {
#[cfg(test)]
mod tests {
use std::collections::BTreeSet;
use std::iter;
use columnar::Dictionary;
use rand::distributions::Uniform;

View File

@@ -12,7 +12,6 @@ use std::marker::PhantomData;
use columnar::{BytesColumn, Column, DynamicColumn, HasAssociatedColumnType};
use crate::collector::{Collector, SegmentCollector};
use crate::schema::Schema;
use crate::{DocId, Score, SegmentReader};
/// The `FilterCollector` filters docs using a fast field value and a predicate.
@@ -50,13 +49,13 @@ use crate::{DocId, Score, SegmentReader};
///
/// let query_parser = QueryParser::for_index(&index, vec![title]);
/// let query = query_parser.parse_query("diary")?;
/// let no_filter_collector = FilterCollector::new("price".to_string(), |value: u64| value > 20_120u64, TopDocs::with_limit(2).order_by_score());
/// let no_filter_collector = FilterCollector::new("price".to_string(), |value: u64| value > 20_120u64, TopDocs::with_limit(2));
/// let top_docs = searcher.search(&query, &no_filter_collector)?;
///
/// assert_eq!(top_docs.len(), 1);
/// assert_eq!(top_docs[0].1, DocAddress::new(0, 1));
///
/// let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new("price".to_string(), |value| value < 5u64, TopDocs::with_limit(2).order_by_score());
/// let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new("price".to_string(), |value| value < 5u64, TopDocs::with_limit(2));
/// let filtered_top_docs = searcher.search(&query, &filter_all_collector)?;
///
/// assert_eq!(filtered_top_docs.len(), 0);
@@ -105,11 +104,6 @@ where
type Child = FilterSegmentCollector<TCollector::Child, TPredicate, TPredicateValue>;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.collector.check_schema(schema)?;
Ok(())
}
fn for_segment(
&self,
segment_local_id: u32,
@@ -126,7 +120,6 @@ where
segment_collector,
predicate: self.predicate.clone(),
t_predicate_value: PhantomData,
filtered_docs: Vec::with_capacity(crate::COLLECT_BLOCK_BUFFER_LEN),
})
}
@@ -147,7 +140,6 @@ pub struct FilterSegmentCollector<TSegmentCollector, TPredicate, TPredicateValue
segment_collector: TSegmentCollector,
predicate: TPredicate,
t_predicate_value: PhantomData<TPredicateValue>,
filtered_docs: Vec<DocId>,
}
impl<TSegmentCollector, TPredicate, TPredicateValue>
@@ -184,20 +176,6 @@ where
}
}
fn collect_block(&mut self, docs: &[DocId]) {
self.filtered_docs.clear();
for &doc in docs {
// TODO: `accept_document` could be further optimized to do batch lookups of column
// values for single-valued columns.
if self.accept_document(doc) {
self.filtered_docs.push(doc);
}
}
if !self.filtered_docs.is_empty() {
self.segment_collector.collect_block(&self.filtered_docs);
}
}
fn harvest(self) -> TSegmentCollector::Fruit {
self.segment_collector.harvest()
}
@@ -240,7 +218,7 @@ where
///
/// let query_parser = QueryParser::for_index(&index, vec![title]);
/// let query = query_parser.parse_query("diary")?;
/// let filter_collector = BytesFilterCollector::new("barcode".to_string(), |bytes: &[u8]| bytes.starts_with(b"01"), TopDocs::with_limit(2).order_by_score());
/// let filter_collector = BytesFilterCollector::new("barcode".to_string(), |bytes: &[u8]| bytes.starts_with(b"01"), TopDocs::with_limit(2));
/// let top_docs = searcher.search(&query, &filter_collector)?;
///
/// assert_eq!(top_docs.len(), 1);
@@ -280,10 +258,6 @@ where
type Child = BytesFilterSegmentCollector<TCollector::Child, TPredicate>;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.collector.check_schema(schema)
}
fn for_segment(
&self,
segment_local_id: u32,
@@ -300,7 +274,6 @@ where
segment_collector,
predicate: self.predicate.clone(),
buffer: Vec::new(),
filtered_docs: Vec::with_capacity(crate::COLLECT_BLOCK_BUFFER_LEN),
})
}
@@ -323,7 +296,6 @@ where TPredicate: 'static
segment_collector: TSegmentCollector,
predicate: TPredicate,
buffer: Vec<u8>,
filtered_docs: Vec<DocId>,
}
impl<TSegmentCollector, TPredicate> BytesFilterSegmentCollector<TSegmentCollector, TPredicate>
@@ -362,20 +334,6 @@ where
}
}
fn collect_block(&mut self, docs: &[DocId]) {
self.filtered_docs.clear();
for &doc in docs {
// TODO: `accept_document` could be further optimized to do batch lookups of column
// values for single-valued columns.
if self.accept_document(doc) {
self.filtered_docs.push(doc);
}
}
if !self.filtered_docs.is_empty() {
self.segment_collector.collect_block(&self.filtered_docs);
}
}
fn harvest(self) -> TSegmentCollector::Fruit {
self.segment_collector.harvest()
}

View File

@@ -57,7 +57,7 @@
//! # let query_parser = QueryParser::for_index(&index, vec![title]);
//! # let query = query_parser.parse_query("diary")?;
//! let (doc_count, top_docs): (usize, Vec<(Score, DocAddress)>) =
//! searcher.search(&query, &(Count, TopDocs::with_limit(2).order_by_score()))?;
//! searcher.search(&query, &(Count, TopDocs::with_limit(2)))?;
//! # Ok(())
//! # }
//! ```
@@ -83,15 +83,11 @@
use downcast_rs::impl_downcast;
use crate::schema::Schema;
use crate::{DocId, Score, SegmentOrdinal, SegmentReader};
mod count_collector;
pub use self::count_collector::Count;
/// Sort keys
pub mod sort_key;
mod histogram_collector;
pub use histogram_collector::HistogramCollector;
@@ -99,13 +95,16 @@ mod multi_collector;
pub use self::multi_collector::{FruitHandle, MultiCollector, MultiFruit};
mod top_collector;
pub use self::top_collector::ComparableDoc;
mod top_score_collector;
pub use self::top_collector::ComparableDoc;
pub use self::top_score_collector::{TopDocs, TopNComputer};
mod sort_key_top_collector;
pub use self::sort_key::{SegmentSortKeyComputer, SortKeyComputer};
mod custom_score_top_collector;
pub use self::custom_score_top_collector::{CustomScorer, CustomSegmentScorer};
mod tweak_score_top_collector;
pub use self::tweak_score_top_collector::{ScoreSegmentTweaker, ScoreTweaker};
mod facet_collector;
pub use self::facet_collector::{FacetCollector, FacetCounts};
use crate::query::Weight;
@@ -146,11 +145,6 @@ pub trait Collector: Sync + Send {
/// Type of the `SegmentCollector` associated with this collector.
type Child: SegmentCollector;
/// Returns an error if the schema is not compatible with the collector.
fn check_schema(&self, _schema: &Schema) -> crate::Result<()> {
Ok(())
}
/// `set_segment` is called before beginning to enumerate
/// on this segment.
fn for_segment(
@@ -176,50 +170,41 @@ pub trait Collector: Sync + Send {
segment_ord: u32,
reader: &SegmentReader,
) -> crate::Result<<Self::Child as SegmentCollector>::Fruit> {
let with_scoring = self.requires_scoring();
let mut segment_collector = self.for_segment(segment_ord, reader)?;
default_collect_segment_impl(&mut segment_collector, weight, reader, with_scoring)?;
match (reader.alive_bitset(), self.requires_scoring()) {
(Some(alive_bitset), true) => {
weight.for_each(reader, &mut |doc, score| {
if alive_bitset.is_alive(doc) {
segment_collector.collect(doc, score);
}
})?;
}
(Some(alive_bitset), false) => {
weight.for_each_no_score(reader, &mut |docs| {
for doc in docs.iter().cloned() {
if alive_bitset.is_alive(doc) {
segment_collector.collect(doc, 0.0);
}
}
})?;
}
(None, true) => {
weight.for_each(reader, &mut |doc, score| {
segment_collector.collect(doc, score);
})?;
}
(None, false) => {
weight.for_each_no_score(reader, &mut |docs| {
segment_collector.collect_block(docs);
})?;
}
}
Ok(segment_collector.harvest())
}
}
pub(crate) fn default_collect_segment_impl<TSegmentCollector: SegmentCollector>(
segment_collector: &mut TSegmentCollector,
weight: &dyn Weight,
reader: &SegmentReader,
with_scoring: bool,
) -> crate::Result<()> {
match (reader.alive_bitset(), with_scoring) {
(Some(alive_bitset), true) => {
weight.for_each(reader, &mut |doc, score| {
if alive_bitset.is_alive(doc) {
segment_collector.collect(doc, score);
}
})?;
}
(Some(alive_bitset), false) => {
weight.for_each_no_score(reader, &mut |docs| {
for doc in docs.iter().cloned() {
if alive_bitset.is_alive(doc) {
segment_collector.collect(doc, 0.0);
}
}
})?;
}
(None, true) => {
weight.for_each(reader, &mut |doc, score| {
segment_collector.collect(doc, score);
})?;
}
(None, false) => {
weight.for_each_no_score(reader, &mut |docs| {
segment_collector.collect_block(docs);
})?;
}
}
Ok(())
}
impl<TSegmentCollector: SegmentCollector> SegmentCollector for Option<TSegmentCollector> {
type Fruit = Option<TSegmentCollector::Fruit>;
@@ -229,12 +214,6 @@ impl<TSegmentCollector: SegmentCollector> SegmentCollector for Option<TSegmentCo
}
}
fn collect_block(&mut self, docs: &[DocId]) {
if let Some(segment_collector) = self {
segment_collector.collect_block(docs);
}
}
fn harvest(self) -> Self::Fruit {
self.map(|segment_collector| segment_collector.harvest())
}
@@ -245,13 +224,6 @@ impl<TCollector: Collector> Collector for Option<TCollector> {
type Child = Option<<TCollector as Collector>::Child>;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
if let Some(underlying_collector) = self {
underlying_collector.check_schema(schema)?;
}
Ok(())
}
fn for_segment(
&self,
segment_local_id: SegmentOrdinal,
@@ -327,12 +299,6 @@ where
type Fruit = (Left::Fruit, Right::Fruit);
type Child = (Left::Child, Right::Child);
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)?;
self.1.check_schema(schema)?;
Ok(())
}
fn for_segment(
&self,
segment_local_id: u32,
@@ -376,11 +342,6 @@ where
self.1.collect(doc, score);
}
fn collect_block(&mut self, docs: &[DocId]) {
self.0.collect_block(docs);
self.1.collect_block(docs);
}
fn harvest(self) -> <Self as SegmentCollector>::Fruit {
(self.0.harvest(), self.1.harvest())
}
@@ -397,13 +358,6 @@ where
type Fruit = (One::Fruit, Two::Fruit, Three::Fruit);
type Child = (One::Child, Two::Child, Three::Child);
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)?;
self.1.check_schema(schema)?;
self.2.check_schema(schema)?;
Ok(())
}
fn for_segment(
&self,
segment_local_id: u32,
@@ -453,12 +407,6 @@ where
self.2.collect(doc, score);
}
fn collect_block(&mut self, docs: &[DocId]) {
self.0.collect_block(docs);
self.1.collect_block(docs);
self.2.collect_block(docs);
}
fn harvest(self) -> <Self as SegmentCollector>::Fruit {
(self.0.harvest(), self.1.harvest(), self.2.harvest())
}
@@ -476,14 +424,6 @@ where
type Fruit = (One::Fruit, Two::Fruit, Three::Fruit, Four::Fruit);
type Child = (One::Child, Two::Child, Three::Child, Four::Child);
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)?;
self.1.check_schema(schema)?;
self.2.check_schema(schema)?;
self.3.check_schema(schema)?;
Ok(())
}
fn for_segment(
&self,
segment_local_id: u32,
@@ -542,13 +482,6 @@ where
self.3.collect(doc, score);
}
fn collect_block(&mut self, docs: &[DocId]) {
self.0.collect_block(docs);
self.1.collect_block(docs);
self.2.collect_block(docs);
self.3.collect_block(docs);
}
fn harvest(self) -> <Self as SegmentCollector>::Fruit {
(
self.0.harvest(),

View File

@@ -3,7 +3,6 @@ use std::ops::Deref;
use super::{Collector, SegmentCollector};
use crate::collector::Fruit;
use crate::schema::Schema;
use crate::{DocId, Score, SegmentOrdinal, SegmentReader, TantivyError};
/// MultiFruit keeps Fruits from every nested Collector
@@ -17,10 +16,6 @@ impl<TCollector: Collector> Collector for CollectorWrapper<TCollector> {
type Fruit = Box<dyn Fruit>;
type Child = Box<dyn BoxableSegmentCollector>;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)
}
fn for_segment(
&self,
segment_local_id: u32,
@@ -152,7 +147,7 @@ impl<TFruit: Fruit> FruitHandle<TFruit> {
/// let searcher = reader.searcher();
///
/// let mut collectors = MultiCollector::new();
/// let top_docs_handle = collectors.add_collector(TopDocs::with_limit(2).order_by_score());
/// let top_docs_handle = collectors.add_collector(TopDocs::with_limit(2));
/// let count_handle = collectors.add_collector(Count);
/// let query_parser = QueryParser::for_index(&index, vec![title]);
/// let query = query_parser.parse_query("diary").unwrap();
@@ -199,13 +194,6 @@ impl Collector for MultiCollector<'_> {
type Fruit = MultiFruit;
type Child = MultiCollectorChild;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
for collector in &self.collector_wrappers {
collector.check_schema(schema)?;
}
Ok(())
}
fn for_segment(
&self,
segment_local_id: SegmentOrdinal,
@@ -262,12 +250,6 @@ impl SegmentCollector for MultiCollectorChild {
}
}
fn collect_block(&mut self, docs: &[DocId]) {
for child in &mut self.children {
child.collect_block(docs);
}
}
fn harvest(self) -> MultiFruit {
MultiFruit {
sub_fruits: self
@@ -311,7 +293,7 @@ mod tests {
let query = TermQuery::new(term, IndexRecordOption::Basic);
let mut collectors = MultiCollector::new();
let topdocs_handler = collectors.add_collector(TopDocs::with_limit(2).order_by_score());
let topdocs_handler = collectors.add_collector(TopDocs::with_limit(2));
let count_handler = collectors.add_collector(Count);
let mut multifruits = searcher.search(&query, &collectors).unwrap();

Some files were not shown because too many files have changed in this diff Show More