mirror of
https://github.com/quickwit-oss/tantivy.git
synced 2026-01-03 07:42:54 +00:00
Compare commits
66 Commits
clippy-and
...
flatheadmi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
643639f14b | ||
|
|
f85a27068d | ||
|
|
1619e05bc5 | ||
|
|
5d03c600ba | ||
|
|
32beb06382 | ||
|
|
d8bc0e7c99 | ||
|
|
79622f1f0b | ||
|
|
d26d6c34fc | ||
|
|
6da54fa5da | ||
|
|
9f10279681 | ||
|
|
68009bb25b | ||
|
|
459456ca28 | ||
|
|
dbbc8c3f65 | ||
|
|
d3049cb323 | ||
|
|
ccdf399cd7 | ||
|
|
2dc46b235e | ||
|
|
f38140f72f | ||
|
|
0996bea7ac | ||
|
|
1c66567efc | ||
|
|
b2a9bb279d | ||
|
|
558c99fa2d | ||
|
|
43b5f34721 | ||
|
|
63c66005db | ||
|
|
7d513a44c5 | ||
|
|
ca87fcd454 | ||
|
|
08a92675dc | ||
|
|
f7f4b354d6 | ||
|
|
25d44fcec8 | ||
|
|
842fe9295f | ||
|
|
f88b7200b2 | ||
|
|
8725594d47 | ||
|
|
43a784671a | ||
|
|
c363bbd23d | ||
|
|
70e591e230 | ||
|
|
5277367cb0 | ||
|
|
8b02bff9b8 | ||
|
|
60225bdd45 | ||
|
|
938bfec8b7 | ||
|
|
dabcaa5809 | ||
|
|
d410a3b0c0 | ||
|
|
fc93391d0e | ||
|
|
f8e79271ab | ||
|
|
33835b6a01 | ||
|
|
270ca5123c | ||
|
|
714366d3b9 | ||
|
|
40659d4d07 | ||
|
|
e1e131a804 | ||
|
|
70da310b2d | ||
|
|
85010b589a | ||
|
|
2340dca628 | ||
|
|
71a26d5b24 | ||
|
|
203751f2fe | ||
|
|
7963b0b4aa | ||
|
|
d5eefca11d | ||
|
|
5d6c8de23e | ||
|
|
a06365f39f | ||
|
|
f4b374110f | ||
|
|
c37af9c1ff | ||
|
|
33794a114c | ||
|
|
8676a1f57b | ||
|
|
021ff2ad63 | ||
|
|
39e027667b | ||
|
|
a1d65c3df3 | ||
|
|
2e4615c2d3 | ||
|
|
610091e2c4 | ||
|
|
d4b090124c |
28
CHANGELOG.md
28
CHANGELOG.md
@@ -2,14 +2,30 @@ Tantivy 0.25
|
||||
================================
|
||||
|
||||
## Bugfixes
|
||||
- fix union performance regression in tantivy 0.24 [#2663](https://github.com/quickwit-oss/tantivy/pull/2663)(@PSeitz-dd)
|
||||
- fix union performance regression in tantivy 0.24 [#2663](https://github.com/quickwit-oss/tantivy/pull/2663)(@PSeitz)
|
||||
- make zstd optional in sstable [#2633](https://github.com/quickwit-oss/tantivy/pull/2633)(@Parth)
|
||||
- Fix TopDocs::order_by_string_fast_field for asc order [#2672](https://github.com/quickwit-oss/tantivy/pull/2672)(@stuhood @PSeitz)
|
||||
|
||||
## Features/Improvements
|
||||
- add docs/example and Vec<u32> values to sstable [#2660](https://github.com/quickwit-oss/tantivy/pull/2660)(@PSeitz)
|
||||
- Add string fast field support to `TopDocs`. [#2642](https://github.com/quickwit-oss/tantivy/pull/2642)(@stuhood)
|
||||
- update edition to 2024 [#2620](https://github.com/quickwit-oss/tantivy/pull/2620)(@PSeitz)
|
||||
- Allow optional spaces between the field name and the value in the query parser [#2678](https://github.com/quickwit-oss/tantivy/pull/2678)(@Darkheir)
|
||||
- Support mixed field types in query parser [#2676](https://github.com/quickwit-oss/tantivy/pull/2676)(@trinity-1686a)
|
||||
- Add per-field size details [#2679](https://github.com/quickwit-oss/tantivy/pull/2679)(@fulmicoton)
|
||||
|
||||
Tantivy 0.24.2
|
||||
================================
|
||||
- Fix TopNComputer for reverse order. [#2672](https://github.com/quickwit-oss/tantivy/pull/2672)(@stuhood @PSeitz)
|
||||
|
||||
Affected queries are [order_by_fast_field](https://docs.rs/tantivy/latest/tantivy/collector/struct.TopDocs.html#method.order_by_fast_field) and
|
||||
[order_by_u64_field](https://docs.rs/tantivy/latest/tantivy/collector/struct.TopDocs.html#method.order_by_u64_field)
|
||||
for `Order::Asc`
|
||||
|
||||
Tantivy 0.24.1
|
||||
================================
|
||||
- Fix: bump required rust version to 1.81
|
||||
|
||||
Tantivy 0.24
|
||||
================================
|
||||
Tantivy 0.24 will be backwards compatible with indices created with v0.22 and v0.21. The new minimum rust version will be 1.75. Tantivy 0.23 will be skipped.
|
||||
@@ -62,7 +78,7 @@ This will slightly increase space and access time. [#2439](https://github.com/qu
|
||||
|
||||
- **Store DateTime as nanoseconds in doc store** DateTime in the doc store was truncated to microseconds previously. This removes this truncation, while still keeping backwards compatibility. [#2486](https://github.com/quickwit-oss/tantivy/pull/2486)(@PSeitz)
|
||||
|
||||
- **Performace/Memory**
|
||||
- **Performance/Memory**
|
||||
- lift clauses in LogicalAst for optimized ast during execution [#2449](https://github.com/quickwit-oss/tantivy/pull/2449)(@PSeitz)
|
||||
- Use Vec instead of BTreeMap to back OwnedValue object [#2364](https://github.com/quickwit-oss/tantivy/pull/2364)(@fulmicoton)
|
||||
- Replace TantivyDocument with CompactDoc. CompactDoc is much smaller and provides similar performance. [#2402](https://github.com/quickwit-oss/tantivy/pull/2402)(@PSeitz)
|
||||
@@ -92,6 +108,14 @@ This will slightly increase space and access time. [#2439](https://github.com/qu
|
||||
- Fix trait bound of StoreReader::iter [#2360](https://github.com/quickwit-oss/tantivy/pull/2360)(@adamreichold)
|
||||
- remove read_postings_no_deletes [#2526](https://github.com/quickwit-oss/tantivy/pull/2526)(@PSeitz)
|
||||
|
||||
Tantivy 0.22.1
|
||||
================================
|
||||
- Fix TopNComputer for reverse order. [#2672](https://github.com/quickwit-oss/tantivy/pull/2672)(@stuhood @PSeitz)
|
||||
|
||||
Affected queries are [order_by_fast_field](https://docs.rs/tantivy/latest/tantivy/collector/struct.TopDocs.html#method.order_by_fast_field) and
|
||||
[order_by_u64_field](https://docs.rs/tantivy/latest/tantivy/collector/struct.TopDocs.html#method.order_by_u64_field)
|
||||
for `Order::Asc`
|
||||
|
||||
Tantivy 0.22
|
||||
================================
|
||||
|
||||
|
||||
29
Cargo.toml
29
Cargo.toml
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tantivy"
|
||||
version = "0.24.0"
|
||||
version = "0.26.0"
|
||||
authors = ["Paul Masurel <paul.masurel@gmail.com>"]
|
||||
license = "MIT"
|
||||
categories = ["database-implementations", "data-structures"]
|
||||
@@ -56,19 +56,22 @@ itertools = "0.14.0"
|
||||
measure_time = "0.9.0"
|
||||
arc-swap = "1.5.0"
|
||||
bon = "3.3.1"
|
||||
i_triangle = "0.38.0"
|
||||
|
||||
columnar = { version = "0.5", path = "./columnar", package = "tantivy-columnar" }
|
||||
sstable = { version = "0.5", path = "./sstable", package = "tantivy-sstable", optional = true }
|
||||
stacker = { version = "0.5", path = "./stacker", package = "tantivy-stacker" }
|
||||
query-grammar = { version = "0.24.0", path = "./query-grammar", package = "tantivy-query-grammar" }
|
||||
tantivy-bitpacker = { version = "0.8", path = "./bitpacker" }
|
||||
common = { version = "0.9", path = "./common/", package = "tantivy-common" }
|
||||
tokenizer-api = { version = "0.5", path = "./tokenizer-api", package = "tantivy-tokenizer-api" }
|
||||
columnar = { version = "0.6", path = "./columnar", package = "tantivy-columnar" }
|
||||
sstable = { version = "0.6", path = "./sstable", package = "tantivy-sstable", optional = true }
|
||||
stacker = { version = "0.6", path = "./stacker", package = "tantivy-stacker" }
|
||||
query-grammar = { version = "0.25.0", path = "./query-grammar", package = "tantivy-query-grammar" }
|
||||
tantivy-bitpacker = { version = "0.9", path = "./bitpacker" }
|
||||
common = { version = "0.10", path = "./common/", package = "tantivy-common" }
|
||||
tokenizer-api = { version = "0.6", path = "./tokenizer-api", package = "tantivy-tokenizer-api" }
|
||||
sketches-ddsketch = { version = "0.3.0", features = ["use_serde"] }
|
||||
hyperloglogplus = { version = "0.4.1", features = ["const-loop"] }
|
||||
futures-util = { version = "0.3.28", optional = true }
|
||||
futures-channel = { version = "0.3.28", optional = true }
|
||||
fnv = "1.0.7"
|
||||
typetag = "0.2.21"
|
||||
geo-types = "0.7.17"
|
||||
|
||||
[target.'cfg(windows)'.dependencies]
|
||||
winapi = "0.3.9"
|
||||
@@ -87,7 +90,7 @@ more-asserts = "0.3.1"
|
||||
rand_distr = "0.4.3"
|
||||
time = { version = "0.3.10", features = ["serde-well-known", "macros"] }
|
||||
postcard = { version = "1.0.4", features = [
|
||||
"use-std",
|
||||
"use-std",
|
||||
], default-features = false }
|
||||
|
||||
[target.'cfg(not(windows))'.dev-dependencies]
|
||||
@@ -167,3 +170,11 @@ harness = false
|
||||
[[bench]]
|
||||
name = "agg_bench"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "exists_json"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "and_or_queries"
|
||||
harness = false
|
||||
|
||||
@@ -23,8 +23,6 @@ performance for different types of queries/collections.
|
||||
|
||||
Your mileage WILL vary depending on the nature of queries and their load.
|
||||
|
||||
<img src="doc/assets/images/searchbenchmark.png">
|
||||
|
||||
Details about the benchmark can be found at this [repository](https://github.com/quickwit-oss/search-benchmark-game).
|
||||
|
||||
## Features
|
||||
|
||||
27
RELEASE.md
27
RELEASE.md
@@ -1,4 +1,4 @@
|
||||
# Release a new Tantivy Version
|
||||
# Releasing a new Tantivy Version
|
||||
|
||||
## Steps
|
||||
|
||||
@@ -10,12 +10,29 @@
|
||||
6. Set git tag with new version
|
||||
|
||||
|
||||
In conjucation with `cargo-release` Steps 1-4 (I'm not sure if the change detection works):
|
||||
Set new packages to version 0.0.0
|
||||
[`cargo-release`](https://github.com/crate-ci/cargo-release) will help us with steps 1-5:
|
||||
|
||||
Replace prev-tag-name
|
||||
```bash
|
||||
cargo release --workspace --no-publish -v --prev-tag-name 0.19 --push-remote origin minor --no-tag --execute
|
||||
cargo release --workspace --no-publish -v --prev-tag-name 0.24 --push-remote origin minor --no-tag
|
||||
```
|
||||
|
||||
no-tag or it will create tags for all the subpackages
|
||||
`no-tag` or it will create tags for all the subpackages
|
||||
|
||||
cargo release will _not_ ignore unchanged packages, but it will print warnings for them.
|
||||
e.g. "warning: updating ownedbytes to 0.10.0 despite no changes made since tag 0.24"
|
||||
|
||||
We need to manually ignore these unchanged packages
|
||||
```bash
|
||||
cargo release --workspace --no-publish -v --prev-tag-name 0.24 --push-remote origin minor --no-tag --exclude tokenizer-api
|
||||
```
|
||||
|
||||
Add `--execute` to actually publish the packages, otherwise it will only print the commands that would be run.
|
||||
|
||||
### Tag Version
|
||||
```bash
|
||||
git tag 0.25.0
|
||||
git push upstream tag 0.25.0
|
||||
```
|
||||
|
||||
|
||||
|
||||
2
TODO.txt
2
TODO.txt
@@ -10,7 +10,7 @@ rename FastFieldReaders::open to load
|
||||
remove fast field reader
|
||||
|
||||
find a way to unify the two DateTime.
|
||||
readd type check in the filter wrapper
|
||||
re-add type check in the filter wrapper
|
||||
|
||||
add unit test on columnar list columns.
|
||||
|
||||
|
||||
@@ -59,6 +59,8 @@ fn bench_agg(mut group: InputGroup<Index>) {
|
||||
register!(group, terms_many_order_by_term);
|
||||
register!(group, terms_many_with_top_hits);
|
||||
register!(group, terms_many_with_avg_sub_agg);
|
||||
register!(group, terms_few_with_avg_sub_agg);
|
||||
|
||||
register!(group, terms_many_json_mixed_type_with_avg_sub_agg);
|
||||
|
||||
register!(group, cardinality_agg);
|
||||
@@ -71,8 +73,15 @@ fn bench_agg(mut group: InputGroup<Index>) {
|
||||
register!(group, histogram);
|
||||
register!(group, histogram_hard_bounds);
|
||||
register!(group, histogram_with_avg_sub_agg);
|
||||
register!(group, histogram_with_term_agg_few);
|
||||
register!(group, avg_and_range_with_avg_sub_agg);
|
||||
|
||||
// Filter aggregation benchmarks
|
||||
register!(group, filter_agg_all_query_count_agg);
|
||||
register!(group, filter_agg_term_query_count_agg);
|
||||
register!(group, filter_agg_all_query_with_sub_aggs);
|
||||
register!(group, filter_agg_term_query_with_sub_aggs);
|
||||
|
||||
group.run();
|
||||
}
|
||||
|
||||
@@ -213,6 +222,19 @@ fn terms_many_with_avg_sub_agg(index: &Index) {
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_few_with_avg_sub_agg(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
"terms": { "field": "text_few_terms" },
|
||||
"aggs": {
|
||||
"average_f64": { "avg": { "field": "score_f64" } }
|
||||
}
|
||||
},
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
@@ -339,6 +361,17 @@ fn histogram_with_avg_sub_agg(index: &Index) {
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn histogram_with_term_agg_few(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"rangef64": {
|
||||
"histogram": { "field": "score_f64", "interval": 10 },
|
||||
"aggs": {
|
||||
"my_texts": { "terms": { "field": "text_few_terms" } }
|
||||
}
|
||||
}
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn avg_and_range_with_avg_sub_agg(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"rangef64": {
|
||||
@@ -460,3 +493,61 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
|
||||
|
||||
Ok(index)
|
||||
}
|
||||
|
||||
// Filter aggregation benchmarks
|
||||
|
||||
fn filter_agg_all_query_count_agg(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"filtered": {
|
||||
"filter": "*",
|
||||
"aggs": {
|
||||
"count": { "value_count": { "field": "score" } }
|
||||
}
|
||||
}
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn filter_agg_term_query_count_agg(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"filtered": {
|
||||
"filter": "text:cool",
|
||||
"aggs": {
|
||||
"count": { "value_count": { "field": "score" } }
|
||||
}
|
||||
}
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn filter_agg_all_query_with_sub_aggs(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"filtered": {
|
||||
"filter": "*",
|
||||
"aggs": {
|
||||
"avg_score": { "avg": { "field": "score" } },
|
||||
"stats_score": { "stats": { "field": "score_f64" } },
|
||||
"terms_text": {
|
||||
"terms": { "field": "text_few_terms" }
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn filter_agg_term_query_with_sub_aggs(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"filtered": {
|
||||
"filter": "text:cool",
|
||||
"aggs": {
|
||||
"avg_score": { "avg": { "field": "score" } },
|
||||
"stats_score": { "stats": { "field": "score_f64" } },
|
||||
"terms_text": {
|
||||
"terms": { "field": "text_few_terms" }
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
218
benches/and_or_queries.rs
Normal file
218
benches/and_or_queries.rs
Normal file
@@ -0,0 +1,218 @@
|
||||
// Benchmarks boolean conjunction queries using binggan.
|
||||
//
|
||||
// What’s measured:
|
||||
// - Or and And queries with varying selectivity (only `Term` queries for now on leafs)
|
||||
// - Nested AND/OR combinations (on multiple fields)
|
||||
// - No-scoring path using the Count collector (focus on iterator/skip performance)
|
||||
// - Top-K retrieval (k=10) using the TopDocs collector
|
||||
//
|
||||
// Corpus model:
|
||||
// - Synthetic docs; each token a/b/c is independently included per doc
|
||||
// - If none of a/b/c are included, emit a neutral filler token to keep doc length similar
|
||||
//
|
||||
// Notes:
|
||||
// - After optimization, when scoring is disabled Tantivy reads doc-only postings
|
||||
// (IndexRecordOption::Basic), avoiding frequency decoding overhead.
|
||||
// - This bench isolates boolean iteration speed and intersection/union cost.
|
||||
// - Use `cargo bench --bench boolean_conjunction` to run.
|
||||
|
||||
use binggan::{black_box, BenchGroup, BenchRunner};
|
||||
use rand::prelude::*;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::SeedableRng;
|
||||
use tantivy::collector::sort_key::SortByStaticFastValue;
|
||||
use tantivy::collector::{Collector, Count, TopDocs};
|
||||
use tantivy::query::{Query, QueryParser};
|
||||
use tantivy::schema::{Schema, FAST, TEXT};
|
||||
use tantivy::{doc, Index, Order, ReloadPolicy, Searcher};
|
||||
|
||||
#[derive(Clone)]
|
||||
struct BenchIndex {
|
||||
#[allow(dead_code)]
|
||||
index: Index,
|
||||
searcher: Searcher,
|
||||
query_parser: QueryParser,
|
||||
}
|
||||
|
||||
/// Build a single index containing both fields (title, body) and
|
||||
/// return two BenchIndex views:
|
||||
/// - single_field: QueryParser defaults to only "body"
|
||||
/// - multi_field: QueryParser defaults to ["title", "body"]
|
||||
fn build_shared_indices(num_docs: usize, p_a: f32, p_b: f32, p_c: f32) -> (BenchIndex, BenchIndex) {
|
||||
// Unified schema (two text fields)
|
||||
let mut schema_builder = Schema::builder();
|
||||
let f_title = schema_builder.add_text_field("title", TEXT);
|
||||
let f_body = schema_builder.add_text_field("body", TEXT);
|
||||
let f_score = schema_builder.add_u64_field("score", FAST);
|
||||
let f_score2 = schema_builder.add_u64_field("score2", FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema.clone());
|
||||
|
||||
// Populate index with stable RNG for reproducibility.
|
||||
let mut rng = StdRng::from_seed([7u8; 32]);
|
||||
|
||||
// Populate: spread each present token 90/10 to body/title
|
||||
{
|
||||
let mut writer = index.writer_with_num_threads(1, 500_000_000).unwrap();
|
||||
for _ in 0..num_docs {
|
||||
let has_a = rng.gen_bool(p_a as f64);
|
||||
let has_b = rng.gen_bool(p_b as f64);
|
||||
let has_c = rng.gen_bool(p_c as f64);
|
||||
let score = rng.gen_range(0u64..100u64);
|
||||
let score2 = rng.gen_range(0u64..100_000u64);
|
||||
let mut title_tokens: Vec<&str> = Vec::new();
|
||||
let mut body_tokens: Vec<&str> = Vec::new();
|
||||
if has_a {
|
||||
if rng.gen_bool(0.1) {
|
||||
title_tokens.push("a");
|
||||
} else {
|
||||
body_tokens.push("a");
|
||||
}
|
||||
}
|
||||
if has_b {
|
||||
if rng.gen_bool(0.1) {
|
||||
title_tokens.push("b");
|
||||
} else {
|
||||
body_tokens.push("b");
|
||||
}
|
||||
}
|
||||
if has_c {
|
||||
if rng.gen_bool(0.1) {
|
||||
title_tokens.push("c");
|
||||
} else {
|
||||
body_tokens.push("c");
|
||||
}
|
||||
}
|
||||
if title_tokens.is_empty() && body_tokens.is_empty() {
|
||||
body_tokens.push("z");
|
||||
}
|
||||
writer
|
||||
.add_document(doc!(
|
||||
f_title=>title_tokens.join(" "),
|
||||
f_body=>body_tokens.join(" "),
|
||||
f_score=>score,
|
||||
f_score2=>score2,
|
||||
))
|
||||
.unwrap();
|
||||
}
|
||||
writer.commit().unwrap();
|
||||
}
|
||||
|
||||
// Prepare reader/searcher once.
|
||||
let reader = index
|
||||
.reader_builder()
|
||||
.reload_policy(ReloadPolicy::Manual)
|
||||
.try_into()
|
||||
.unwrap();
|
||||
let searcher = reader.searcher();
|
||||
|
||||
// Build two query parsers with different default fields.
|
||||
let qp_single = QueryParser::for_index(&index, vec![f_body]);
|
||||
let qp_multi = QueryParser::for_index(&index, vec![f_title, f_body]);
|
||||
|
||||
let single_view = BenchIndex {
|
||||
index: index.clone(),
|
||||
searcher: searcher.clone(),
|
||||
query_parser: qp_single,
|
||||
};
|
||||
let multi_view = BenchIndex {
|
||||
index,
|
||||
searcher,
|
||||
query_parser: qp_multi,
|
||||
};
|
||||
(single_view, multi_view)
|
||||
}
|
||||
|
||||
fn main() {
|
||||
// Prepare corpora with varying selectivity. Build one index per corpus
|
||||
// and derive two views (single-field vs multi-field) from it.
|
||||
let scenarios = vec![
|
||||
(
|
||||
"N=1M, p(a)=5%, p(b)=1%, p(c)=15%".to_string(),
|
||||
1_000_000,
|
||||
0.05,
|
||||
0.01,
|
||||
0.15,
|
||||
),
|
||||
(
|
||||
"N=1M, p(a)=1%, p(b)=1%, p(c)=15%".to_string(),
|
||||
1_000_000,
|
||||
0.01,
|
||||
0.01,
|
||||
0.15,
|
||||
),
|
||||
];
|
||||
|
||||
let queries = &["a", "+a +b", "+a +b +c", "a OR b", "a OR b OR c"];
|
||||
|
||||
let mut runner = BenchRunner::new();
|
||||
for (label, n, pa, pb, pc) in scenarios {
|
||||
let (single_view, multi_view) = build_shared_indices(n, pa, pb, pc);
|
||||
|
||||
for (view_name, bench_index) in [("single_field", single_view), ("multi_field", multi_view)]
|
||||
{
|
||||
// Single-field group: default field is body only
|
||||
let mut group = runner.new_group();
|
||||
group.set_name(format!("{} — {}", view_name, label));
|
||||
for query_str in queries {
|
||||
add_bench_task(&mut group, &bench_index, query_str, Count, "count");
|
||||
add_bench_task(
|
||||
&mut group,
|
||||
&bench_index,
|
||||
query_str,
|
||||
TopDocs::with_limit(10).order_by_score(),
|
||||
"top10",
|
||||
);
|
||||
add_bench_task(
|
||||
&mut group,
|
||||
&bench_index,
|
||||
query_str,
|
||||
TopDocs::with_limit(10).order_by_fast_field::<u64>("score", Order::Asc),
|
||||
"top10_by_ff",
|
||||
);
|
||||
add_bench_task(
|
||||
&mut group,
|
||||
&bench_index,
|
||||
query_str,
|
||||
TopDocs::with_limit(10).order_by((
|
||||
SortByStaticFastValue::<u64>::for_field("score"),
|
||||
SortByStaticFastValue::<u64>::for_field("score2"),
|
||||
)),
|
||||
"top10_by_2ff",
|
||||
);
|
||||
}
|
||||
group.run();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn add_bench_task<C: Collector + 'static>(
|
||||
bench_group: &mut BenchGroup,
|
||||
bench_index: &BenchIndex,
|
||||
query_str: &str,
|
||||
collector: C,
|
||||
collector_name: &str,
|
||||
) {
|
||||
let task_name = format!("{}_{}", query_str.replace(" ", "_"), collector_name);
|
||||
let query = bench_index.query_parser.parse_query(query_str).unwrap();
|
||||
let search_task = SearchTask {
|
||||
searcher: bench_index.searcher.clone(),
|
||||
collector,
|
||||
query,
|
||||
};
|
||||
bench_group.register(task_name, move |_| black_box(search_task.run()));
|
||||
}
|
||||
|
||||
struct SearchTask<C: Collector> {
|
||||
searcher: Searcher,
|
||||
collector: C,
|
||||
query: Box<dyn Query>,
|
||||
}
|
||||
|
||||
impl<C: Collector> SearchTask<C> {
|
||||
#[inline(never)]
|
||||
pub fn run(&self) -> usize {
|
||||
self.searcher.search(&self.query, &self.collector).unwrap();
|
||||
1
|
||||
}
|
||||
}
|
||||
69
benches/exists_json.rs
Normal file
69
benches/exists_json.rs
Normal file
@@ -0,0 +1,69 @@
|
||||
use binggan::plugins::PeakMemAllocPlugin;
|
||||
use binggan::{black_box, InputGroup, PeakMemAlloc, INSTRUMENTED_SYSTEM};
|
||||
use serde_json::json;
|
||||
use tantivy::collector::Count;
|
||||
use tantivy::query::ExistsQuery;
|
||||
use tantivy::schema::{Schema, FAST, TEXT};
|
||||
use tantivy::{doc, Index};
|
||||
|
||||
#[global_allocator]
|
||||
pub static GLOBAL: &PeakMemAlloc<std::alloc::System> = &INSTRUMENTED_SYSTEM;
|
||||
|
||||
fn main() {
|
||||
let doc_count: usize = 500_000;
|
||||
let subfield_counts: &[usize] = &[1, 2, 3, 4, 5, 6, 7, 8, 16, 256, 4096, 65536, 262144];
|
||||
|
||||
let indices: Vec<(String, Index)> = subfield_counts
|
||||
.iter()
|
||||
.map(|&sub_fields| {
|
||||
(
|
||||
format!("subfields={sub_fields}"),
|
||||
build_index_with_json_subfields(doc_count, sub_fields),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut group = InputGroup::new_with_inputs(indices);
|
||||
group.add_plugin(PeakMemAllocPlugin::new(GLOBAL));
|
||||
|
||||
group.config().num_iter_group = Some(1);
|
||||
group.config().num_iter_bench = Some(1);
|
||||
group.register("exists_json", exists_json_union);
|
||||
|
||||
group.run();
|
||||
}
|
||||
|
||||
fn exists_json_union(index: &Index) {
|
||||
let reader = index.reader().expect("reader");
|
||||
let searcher = reader.searcher();
|
||||
let query = ExistsQuery::new("json".to_string(), true);
|
||||
let count = searcher.search(&query, &Count).expect("exists search");
|
||||
// Prevents optimizer from eliding the search
|
||||
black_box(count);
|
||||
}
|
||||
|
||||
fn build_index_with_json_subfields(num_docs: usize, num_subfields: usize) -> Index {
|
||||
// Schema: single JSON field stored as FAST to support ExistsQuery.
|
||||
let mut schema_builder = Schema::builder();
|
||||
let json_field = schema_builder.add_json_field("json", TEXT | FAST);
|
||||
let schema = schema_builder.build();
|
||||
|
||||
let index = Index::create_from_tempdir(schema).expect("create index");
|
||||
{
|
||||
let mut index_writer = index
|
||||
.writer_with_num_threads(1, 200_000_000)
|
||||
.expect("writer");
|
||||
for i in 0..num_docs {
|
||||
let sub = i % num_subfields;
|
||||
// Only one subpath set per document; rotate subpaths so that
|
||||
// no single subpath is full, but the union covers all docs.
|
||||
let v = json!({ format!("field_{sub}"): i as u64 });
|
||||
index_writer
|
||||
.add_document(doc!(json_field => v))
|
||||
.expect("add_document");
|
||||
}
|
||||
index_writer.commit().expect("commit");
|
||||
}
|
||||
|
||||
index
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tantivy-bitpacker"
|
||||
version = "0.8.0"
|
||||
version = "0.9.0"
|
||||
edition = "2024"
|
||||
authors = ["Paul Masurel <paul.masurel@gmail.com>"]
|
||||
license = "MIT"
|
||||
|
||||
@@ -48,7 +48,7 @@ impl BitPacker {
|
||||
|
||||
pub fn flush<TWrite: io::Write + ?Sized>(&mut self, output: &mut TWrite) -> io::Result<()> {
|
||||
if self.mini_buffer_written > 0 {
|
||||
let num_bytes = (self.mini_buffer_written + 7) / 8;
|
||||
let num_bytes = self.mini_buffer_written.div_ceil(8);
|
||||
let bytes = self.mini_buffer.to_le_bytes();
|
||||
output.write_all(&bytes[..num_bytes])?;
|
||||
self.mini_buffer_written = 0;
|
||||
@@ -138,7 +138,7 @@ impl BitUnpacker {
|
||||
|
||||
// We use `usize` here to avoid overflow issues.
|
||||
let end_bit_read = (end_idx as usize) * self.num_bits;
|
||||
let end_byte_read = (end_bit_read + 7) / 8;
|
||||
let end_byte_read = end_bit_read.div_ceil(8);
|
||||
assert!(
|
||||
end_byte_read <= data.len(),
|
||||
"Requested index is out of bounds."
|
||||
@@ -258,7 +258,7 @@ mod test {
|
||||
bitpacker.write(val, num_bits, &mut data).unwrap();
|
||||
}
|
||||
bitpacker.close(&mut data).unwrap();
|
||||
assert_eq!(data.len(), ((num_bits as usize) * len + 7) / 8);
|
||||
assert_eq!(data.len(), ((num_bits as usize) * len).div_ceil(8));
|
||||
let bitunpacker = BitUnpacker::new(num_bits);
|
||||
(bitunpacker, vals, data)
|
||||
}
|
||||
@@ -304,7 +304,7 @@ mod test {
|
||||
bitpacker.write(val, num_bits, &mut buffer).unwrap();
|
||||
}
|
||||
bitpacker.flush(&mut buffer).unwrap();
|
||||
assert_eq!(buffer.len(), (vals.len() * num_bits as usize + 7) / 8);
|
||||
assert_eq!(buffer.len(), (vals.len() * num_bits as usize).div_ceil(8));
|
||||
let bitunpacker = BitUnpacker::new(num_bits);
|
||||
let max_val = if num_bits == 64 {
|
||||
u64::MAX
|
||||
|
||||
@@ -140,10 +140,10 @@ impl BlockedBitpacker {
|
||||
pub fn iter(&self) -> impl Iterator<Item = u64> + '_ {
|
||||
// todo performance: we could decompress a whole block and cache it instead
|
||||
let bitpacked_elems = self.offset_and_bits.len() * BLOCK_SIZE;
|
||||
let iter = (0..bitpacked_elems)
|
||||
|
||||
(0..bitpacked_elems)
|
||||
.map(move |idx| self.get(idx))
|
||||
.chain(self.buffer.iter().cloned());
|
||||
iter
|
||||
.chain(self.buffer.iter().cloned())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tantivy-columnar"
|
||||
version = "0.5.0"
|
||||
version = "0.6.0"
|
||||
edition = "2024"
|
||||
license = "MIT"
|
||||
homepage = "https://github.com/quickwit-oss/tantivy"
|
||||
@@ -12,10 +12,10 @@ categories = ["database-implementations", "data-structures", "compression"]
|
||||
itertools = "0.14.0"
|
||||
fastdivide = "0.4.0"
|
||||
|
||||
stacker = { version= "0.5", path = "../stacker", package="tantivy-stacker"}
|
||||
sstable = { version= "0.5", path = "../sstable", package = "tantivy-sstable" }
|
||||
common = { version= "0.9", path = "../common", package = "tantivy-common" }
|
||||
tantivy-bitpacker = { version= "0.8", path = "../bitpacker/" }
|
||||
stacker = { version= "0.6", path = "../stacker", package="tantivy-stacker"}
|
||||
sstable = { version= "0.6", path = "../sstable", package = "tantivy-sstable" }
|
||||
common = { version= "0.10", path = "../common", package = "tantivy-common" }
|
||||
tantivy-bitpacker = { version= "0.9", path = "../bitpacker/" }
|
||||
serde = "1.0.152"
|
||||
downcast-rs = "2.0.1"
|
||||
|
||||
@@ -33,6 +33,29 @@ harness = false
|
||||
name = "bench_access"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "bench_first_vals"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "bench_values_u64"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "bench_values_u128"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "bench_create_column_values"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "bench_column_values_get"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "bench_optional_index"
|
||||
harness = false
|
||||
|
||||
[features]
|
||||
unstable = []
|
||||
zstd-compression = ["sstable/zstd-compression"]
|
||||
|
||||
@@ -73,7 +73,7 @@ The crate introduces the following concepts.
|
||||
`Columnar` is an equivalent of a dataframe.
|
||||
It maps `column_key` to `Column`.
|
||||
|
||||
A `Column<T>` asssociates a `RowId` (u32) to any
|
||||
A `Column<T>` associates a `RowId` (u32) to any
|
||||
number of values.
|
||||
|
||||
This is made possible by wrapping a `ColumnIndex` and a `ColumnValue` object.
|
||||
|
||||
@@ -19,7 +19,7 @@ fn main() {
|
||||
|
||||
let mut add_card = |card1: Card| {
|
||||
inputs.push((
|
||||
format!("{card1}"),
|
||||
card1.to_string(),
|
||||
generate_columnar_and_open(card1, NUM_DOCS),
|
||||
));
|
||||
};
|
||||
@@ -50,6 +50,7 @@ fn bench_group(mut runner: InputGroup<Column>) {
|
||||
let mut buffer = vec![None; BLOCK_SIZE];
|
||||
for i in (0..NUM_DOCS).step_by(BLOCK_SIZE) {
|
||||
// fill docs
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
for idx in 0..BLOCK_SIZE {
|
||||
docs[idx] = idx as u32 + i;
|
||||
}
|
||||
|
||||
61
columnar/benches/bench_column_values_get.rs
Normal file
61
columnar/benches/bench_column_values_get.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use binggan::{InputGroup, black_box};
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use tantivy_columnar::ColumnValues;
|
||||
use tantivy_columnar::column_values::{CodecType, serialize_and_load_u64_based_column_values};
|
||||
|
||||
fn get_data() -> Vec<u64> {
|
||||
let mut rng = StdRng::seed_from_u64(2u64);
|
||||
let mut data: Vec<_> = (100..55_000_u64)
|
||||
.map(|num| num + rng.r#gen::<u8>() as u64)
|
||||
.collect();
|
||||
data.push(99_000);
|
||||
data.insert(1000, 2000);
|
||||
data.insert(2000, 100);
|
||||
data.insert(3000, 4100);
|
||||
data.insert(4000, 100);
|
||||
data.insert(5000, 800);
|
||||
data
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
fn value_iter() -> impl Iterator<Item = u64> {
|
||||
0..20_000
|
||||
}
|
||||
|
||||
type Col = Arc<dyn ColumnValues<u64>>;
|
||||
|
||||
fn main() {
|
||||
let data = get_data();
|
||||
let inputs: Vec<(String, Col)> = vec![
|
||||
(
|
||||
"bitpacked".to_string(),
|
||||
serialize_and_load_u64_based_column_values(&data.as_slice(), &[CodecType::Bitpacked]),
|
||||
),
|
||||
(
|
||||
"linear".to_string(),
|
||||
serialize_and_load_u64_based_column_values(&data.as_slice(), &[CodecType::Linear]),
|
||||
),
|
||||
(
|
||||
"blockwise_linear".to_string(),
|
||||
serialize_and_load_u64_based_column_values(
|
||||
&data.as_slice(),
|
||||
&[CodecType::BlockwiseLinear],
|
||||
),
|
||||
),
|
||||
];
|
||||
|
||||
let mut group: InputGroup<Col> = InputGroup::new_with_inputs(inputs);
|
||||
|
||||
group.register("fastfield_get", |col: &Col| {
|
||||
let mut sum = 0u64;
|
||||
for pos in value_iter() {
|
||||
sum = sum.wrapping_add(col.get_val(pos as u32));
|
||||
}
|
||||
black_box(sum);
|
||||
});
|
||||
|
||||
group.run();
|
||||
}
|
||||
44
columnar/benches/bench_create_column_values.rs
Normal file
44
columnar/benches/bench_create_column_values.rs
Normal file
@@ -0,0 +1,44 @@
|
||||
use binggan::{InputGroup, black_box};
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use tantivy_columnar::column_values::{CodecType, serialize_u64_based_column_values};
|
||||
|
||||
fn get_data() -> Vec<u64> {
|
||||
let mut rng = StdRng::seed_from_u64(2u64);
|
||||
let mut data: Vec<_> = (100..55_000_u64)
|
||||
.map(|num| num + rng.r#gen::<u8>() as u64)
|
||||
.collect();
|
||||
data.push(99_000);
|
||||
data.insert(1000, 2000);
|
||||
data.insert(2000, 100);
|
||||
data.insert(3000, 4100);
|
||||
data.insert(4000, 100);
|
||||
data.insert(5000, 800);
|
||||
data
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let data = get_data();
|
||||
let mut group: InputGroup<(CodecType, Vec<u64>)> = InputGroup::new_with_inputs(vec![
|
||||
(
|
||||
"bitpacked codec".to_string(),
|
||||
(CodecType::Bitpacked, data.clone()),
|
||||
),
|
||||
(
|
||||
"linear codec".to_string(),
|
||||
(CodecType::Linear, data.clone()),
|
||||
),
|
||||
(
|
||||
"blockwise linear codec".to_string(),
|
||||
(CodecType::BlockwiseLinear, data.clone()),
|
||||
),
|
||||
]);
|
||||
|
||||
group.register("serialize column_values", |data| {
|
||||
let mut buffer = Vec::new();
|
||||
serialize_u64_based_column_values(&data.1.as_slice(), &[data.0], &mut buffer).unwrap();
|
||||
black_box(buffer.len());
|
||||
});
|
||||
|
||||
group.run();
|
||||
}
|
||||
@@ -1,12 +1,9 @@
|
||||
#![feature(test)]
|
||||
extern crate test;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use binggan::{InputGroup, black_box};
|
||||
use rand::prelude::*;
|
||||
use tantivy_columnar::column_values::{CodecType, serialize_and_load_u64_based_column_values};
|
||||
use tantivy_columnar::*;
|
||||
use test::{Bencher, black_box};
|
||||
|
||||
struct Columns {
|
||||
pub optional: Column,
|
||||
@@ -68,88 +65,38 @@ pub fn serialize_and_load(column: &[u64], codec_type: CodecType) -> Arc<dyn Colu
|
||||
serialize_and_load_u64_based_column_values(&column, &[codec_type])
|
||||
}
|
||||
|
||||
fn run_bench_on_column_full_scan(b: &mut Bencher, column: Column) {
|
||||
let num_iter = black_box(NUM_VALUES);
|
||||
b.iter(|| {
|
||||
fn main() {
|
||||
let Columns {
|
||||
optional,
|
||||
full,
|
||||
multi,
|
||||
} = get_test_columns();
|
||||
|
||||
let inputs = vec![
|
||||
("full".to_string(), full),
|
||||
("optional".to_string(), optional),
|
||||
("multi".to_string(), multi),
|
||||
];
|
||||
|
||||
let mut group = InputGroup::new_with_inputs(inputs);
|
||||
|
||||
group.register("first_full_scan", |column| {
|
||||
let mut sum = 0u64;
|
||||
for i in 0..num_iter as u32 {
|
||||
for i in 0..NUM_VALUES as u32 {
|
||||
let val = column.first(i);
|
||||
sum += val.unwrap_or(0);
|
||||
}
|
||||
sum
|
||||
black_box(sum);
|
||||
});
|
||||
}
|
||||
fn run_bench_on_column_block_fetch(b: &mut Bencher, column: Column) {
|
||||
let mut block: Vec<Option<u64>> = vec![None; 64];
|
||||
let fetch_docids = (0..64).collect::<Vec<_>>();
|
||||
b.iter(move || {
|
||||
column.first_vals(&fetch_docids, &mut block);
|
||||
block[0]
|
||||
});
|
||||
}
|
||||
fn run_bench_on_column_block_single_calls(b: &mut Bencher, column: Column) {
|
||||
let mut block: Vec<Option<u64>> = vec![None; 64];
|
||||
let fetch_docids = (0..64).collect::<Vec<_>>();
|
||||
b.iter(move || {
|
||||
|
||||
group.register("first_block_single_calls", |column| {
|
||||
let mut block: Vec<Option<u64>> = vec![None; 64];
|
||||
let fetch_docids = (0..64).collect::<Vec<_>>();
|
||||
for i in 0..fetch_docids.len() {
|
||||
block[i] = column.first(fetch_docids[i]);
|
||||
}
|
||||
block[0]
|
||||
black_box(block[0]);
|
||||
});
|
||||
}
|
||||
|
||||
/// Column first method
|
||||
#[bench]
|
||||
fn bench_get_first_on_full_column_full_scan(b: &mut Bencher) {
|
||||
let column = get_test_columns().full;
|
||||
run_bench_on_column_full_scan(b, column);
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_get_first_on_optional_column_full_scan(b: &mut Bencher) {
|
||||
let column = get_test_columns().optional;
|
||||
run_bench_on_column_full_scan(b, column);
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_get_first_on_multi_column_full_scan(b: &mut Bencher) {
|
||||
let column = get_test_columns().multi;
|
||||
run_bench_on_column_full_scan(b, column);
|
||||
}
|
||||
|
||||
/// Block fetch column accessor
|
||||
#[bench]
|
||||
fn bench_get_block_first_on_optional_column(b: &mut Bencher) {
|
||||
let column = get_test_columns().optional;
|
||||
run_bench_on_column_block_fetch(b, column);
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_get_block_first_on_multi_column(b: &mut Bencher) {
|
||||
let column = get_test_columns().multi;
|
||||
run_bench_on_column_block_fetch(b, column);
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_get_block_first_on_full_column(b: &mut Bencher) {
|
||||
let column = get_test_columns().full;
|
||||
run_bench_on_column_block_fetch(b, column);
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_get_block_first_on_optional_column_single_calls(b: &mut Bencher) {
|
||||
let column = get_test_columns().optional;
|
||||
run_bench_on_column_block_single_calls(b, column);
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_get_block_first_on_multi_column_single_calls(b: &mut Bencher) {
|
||||
let column = get_test_columns().multi;
|
||||
run_bench_on_column_block_single_calls(b, column);
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_get_block_first_on_full_column_single_calls(b: &mut Bencher) {
|
||||
let column = get_test_columns().full;
|
||||
run_bench_on_column_block_single_calls(b, column);
|
||||
group.run();
|
||||
}
|
||||
|
||||
106
columnar/benches/bench_optional_index.rs
Normal file
106
columnar/benches/bench_optional_index.rs
Normal file
@@ -0,0 +1,106 @@
|
||||
use binggan::{InputGroup, black_box};
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use tantivy_columnar::column_index::{OptionalIndex, Set};
|
||||
|
||||
const TOTAL_NUM_VALUES: u32 = 1_000_000;
|
||||
|
||||
fn gen_optional_index(fill_ratio: f64) -> OptionalIndex {
|
||||
let mut rng: StdRng = StdRng::from_seed([1u8; 32]);
|
||||
let vals: Vec<u32> = (0..TOTAL_NUM_VALUES)
|
||||
.map(|_| rng.gen_bool(fill_ratio))
|
||||
.enumerate()
|
||||
.filter(|(_pos, val)| *val)
|
||||
.map(|(pos, _)| pos as u32)
|
||||
.collect();
|
||||
OptionalIndex::for_test(TOTAL_NUM_VALUES, &vals)
|
||||
}
|
||||
|
||||
fn random_range_iterator(
|
||||
start: u32,
|
||||
end: u32,
|
||||
avg_step_size: u32,
|
||||
avg_deviation: u32,
|
||||
) -> impl Iterator<Item = u32> {
|
||||
let mut rng: StdRng = StdRng::from_seed([1u8; 32]);
|
||||
let mut current = start;
|
||||
std::iter::from_fn(move || {
|
||||
current += rng.gen_range(avg_step_size - avg_deviation..=avg_step_size + avg_deviation);
|
||||
if current >= end { None } else { Some(current) }
|
||||
})
|
||||
}
|
||||
|
||||
fn n_percent_step_iterator(percent: f32, num_values: u32) -> impl Iterator<Item = u32> {
|
||||
let ratio = percent / 100.0;
|
||||
let step_size = (1f32 / ratio) as u32;
|
||||
let deviation = step_size - 1;
|
||||
random_range_iterator(0, num_values, step_size, deviation)
|
||||
}
|
||||
|
||||
fn walk_over_data(codec: &OptionalIndex, avg_step_size: u32) -> Option<u32> {
|
||||
walk_over_data_from_positions(
|
||||
codec,
|
||||
random_range_iterator(0, TOTAL_NUM_VALUES, avg_step_size, 0),
|
||||
)
|
||||
}
|
||||
|
||||
fn walk_over_data_from_positions(
|
||||
codec: &OptionalIndex,
|
||||
positions: impl Iterator<Item = u32>,
|
||||
) -> Option<u32> {
|
||||
let mut dense_idx: Option<u32> = None;
|
||||
for idx in positions {
|
||||
dense_idx = dense_idx.or(codec.rank_if_exists(idx));
|
||||
}
|
||||
dense_idx
|
||||
}
|
||||
|
||||
fn main() {
|
||||
// Build separate inputs for each fill ratio.
|
||||
let inputs: Vec<(String, OptionalIndex)> = vec![
|
||||
("fill=1%".to_string(), gen_optional_index(0.01)),
|
||||
("fill=5%".to_string(), gen_optional_index(0.05)),
|
||||
("fill=10%".to_string(), gen_optional_index(0.10)),
|
||||
("fill=50%".to_string(), gen_optional_index(0.50)),
|
||||
("fill=90%".to_string(), gen_optional_index(0.90)),
|
||||
];
|
||||
|
||||
let mut group: InputGroup<OptionalIndex> = InputGroup::new_with_inputs(inputs);
|
||||
|
||||
// Translate orig->codec (rank_if_exists) with sampling
|
||||
group.register("orig_to_codec_10pct_hit", |codec: &OptionalIndex| {
|
||||
black_box(walk_over_data(codec, 100));
|
||||
});
|
||||
group.register("orig_to_codec_1pct_hit", |codec: &OptionalIndex| {
|
||||
black_box(walk_over_data(codec, 1000));
|
||||
});
|
||||
group.register("orig_to_codec_full_scan", |codec: &OptionalIndex| {
|
||||
black_box(walk_over_data_from_positions(codec, 0..TOTAL_NUM_VALUES));
|
||||
});
|
||||
|
||||
// Translate codec->orig (select/select_batch) on sampled ranks
|
||||
fn bench_translate_codec_to_orig_util(codec: &OptionalIndex, percent_hit: f32) {
|
||||
let num_non_nulls = codec.num_non_nulls();
|
||||
let idxs: Vec<u32> = if percent_hit == 100.0f32 {
|
||||
(0..num_non_nulls).collect()
|
||||
} else {
|
||||
n_percent_step_iterator(percent_hit, num_non_nulls).collect()
|
||||
};
|
||||
let mut output = vec![0u32; idxs.len()];
|
||||
output.copy_from_slice(&idxs[..]);
|
||||
codec.select_batch(&mut output);
|
||||
black_box(output);
|
||||
}
|
||||
|
||||
group.register("codec_to_orig_0.005pct_hit", |codec: &OptionalIndex| {
|
||||
bench_translate_codec_to_orig_util(codec, 0.005);
|
||||
});
|
||||
group.register("codec_to_orig_10pct_hit", |codec: &OptionalIndex| {
|
||||
bench_translate_codec_to_orig_util(codec, 10.0);
|
||||
});
|
||||
group.register("codec_to_orig_full_scan", |codec: &OptionalIndex| {
|
||||
bench_translate_codec_to_orig_util(codec, 100.0);
|
||||
});
|
||||
|
||||
group.run();
|
||||
}
|
||||
@@ -1,15 +1,12 @@
|
||||
#![feature(test)]
|
||||
|
||||
use std::ops::RangeInclusive;
|
||||
use std::sync::Arc;
|
||||
|
||||
use binggan::{InputGroup, black_box};
|
||||
use common::OwnedBytes;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::{Rng, SeedableRng, random};
|
||||
use tantivy_columnar::ColumnValues;
|
||||
use test::Bencher;
|
||||
extern crate test;
|
||||
|
||||
// TODO does this make sense for IPv6 ?
|
||||
fn generate_random() -> Vec<u64> {
|
||||
@@ -47,78 +44,77 @@ fn get_data_50percent_item() -> Vec<u128> {
|
||||
}
|
||||
data.push(SINGLE_ITEM);
|
||||
data.shuffle(&mut rng);
|
||||
let data = data.iter().map(|el| *el as u128).collect::<Vec<_>>();
|
||||
data
|
||||
data.iter().map(|el| *el as u128).collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_intfastfield_getrange_u128_50percent_hit(b: &mut Bencher) {
|
||||
fn main() {
|
||||
let data = get_data_50percent_item();
|
||||
let column = get_u128_column_from_data(&data);
|
||||
let column_range = get_u128_column_from_data(&data);
|
||||
let column_random = get_u128_column_random();
|
||||
|
||||
b.iter(|| {
|
||||
struct Inputs {
|
||||
data: Vec<u128>,
|
||||
column_range: Arc<dyn ColumnValues<u128>>,
|
||||
column_random: Arc<dyn ColumnValues<u128>>,
|
||||
}
|
||||
|
||||
let inputs = Inputs {
|
||||
data,
|
||||
column_range,
|
||||
column_random,
|
||||
};
|
||||
let mut group: InputGroup<Inputs> =
|
||||
InputGroup::new_with_inputs(vec![("u128 benches".to_string(), inputs)]);
|
||||
|
||||
group.register(
|
||||
"intfastfield_getrange_u128_50percent_hit",
|
||||
|inp: &Inputs| {
|
||||
let mut positions = Vec::new();
|
||||
inp.column_range.get_row_ids_for_value_range(
|
||||
*FIFTY_PERCENT_RANGE.start() as u128..=*FIFTY_PERCENT_RANGE.end() as u128,
|
||||
0..inp.data.len() as u32,
|
||||
&mut positions,
|
||||
);
|
||||
black_box(positions.len());
|
||||
},
|
||||
);
|
||||
|
||||
group.register("intfastfield_getrange_u128_single_hit", |inp: &Inputs| {
|
||||
let mut positions = Vec::new();
|
||||
column.get_row_ids_for_value_range(
|
||||
*FIFTY_PERCENT_RANGE.start() as u128..=*FIFTY_PERCENT_RANGE.end() as u128,
|
||||
0..data.len() as u32,
|
||||
&mut positions,
|
||||
);
|
||||
positions
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_intfastfield_getrange_u128_single_hit(b: &mut Bencher) {
|
||||
let data = get_data_50percent_item();
|
||||
let column = get_u128_column_from_data(&data);
|
||||
|
||||
b.iter(|| {
|
||||
let mut positions = Vec::new();
|
||||
column.get_row_ids_for_value_range(
|
||||
inp.column_range.get_row_ids_for_value_range(
|
||||
*SINGLE_ITEM_RANGE.start() as u128..=*SINGLE_ITEM_RANGE.end() as u128,
|
||||
0..data.len() as u32,
|
||||
0..inp.data.len() as u32,
|
||||
&mut positions,
|
||||
);
|
||||
positions
|
||||
black_box(positions.len());
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_intfastfield_getrange_u128_hit_all(b: &mut Bencher) {
|
||||
let data = get_data_50percent_item();
|
||||
let column = get_u128_column_from_data(&data);
|
||||
|
||||
b.iter(|| {
|
||||
group.register("intfastfield_getrange_u128_hit_all", |inp: &Inputs| {
|
||||
let mut positions = Vec::new();
|
||||
column.get_row_ids_for_value_range(0..=u128::MAX, 0..data.len() as u32, &mut positions);
|
||||
positions
|
||||
inp.column_range.get_row_ids_for_value_range(
|
||||
0..=u128::MAX,
|
||||
0..inp.data.len() as u32,
|
||||
&mut positions,
|
||||
);
|
||||
black_box(positions.len());
|
||||
});
|
||||
}
|
||||
// U128 RANGE END
|
||||
|
||||
#[bench]
|
||||
fn bench_intfastfield_scan_all_fflookup_u128(b: &mut Bencher) {
|
||||
let column = get_u128_column_random();
|
||||
|
||||
b.iter(|| {
|
||||
group.register("intfastfield_scan_all_fflookup_u128", |inp: &Inputs| {
|
||||
let mut a = 0u128;
|
||||
for i in 0u64..column.num_vals() as u64 {
|
||||
a += column.get_val(i as u32);
|
||||
for i in 0u64..inp.column_random.num_vals() as u64 {
|
||||
a += inp.column_random.get_val(i as u32);
|
||||
}
|
||||
a
|
||||
black_box(a);
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_intfastfield_jumpy_stride5_u128(b: &mut Bencher) {
|
||||
let column = get_u128_column_random();
|
||||
|
||||
b.iter(|| {
|
||||
let n = column.num_vals();
|
||||
group.register("intfastfield_jumpy_stride5_u128", |inp: &Inputs| {
|
||||
let n = inp.column_random.num_vals();
|
||||
let mut a = 0u128;
|
||||
for i in (0..n / 5).map(|val| val * 5) {
|
||||
a += column.get_val(i);
|
||||
a += inp.column_random.get_val(i);
|
||||
}
|
||||
a
|
||||
black_box(a);
|
||||
});
|
||||
|
||||
group.run();
|
||||
}
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
#![feature(test)]
|
||||
extern crate test;
|
||||
|
||||
use std::ops::RangeInclusive;
|
||||
use std::sync::Arc;
|
||||
|
||||
use binggan::{InputGroup, black_box};
|
||||
use rand::prelude::*;
|
||||
use tantivy_columnar::column_values::{CodecType, serialize_and_load_u64_based_column_values};
|
||||
use tantivy_columnar::*;
|
||||
use test::Bencher;
|
||||
|
||||
// Warning: this generates the same permutation at each call
|
||||
fn generate_permutation() -> Vec<u64> {
|
||||
@@ -27,37 +24,11 @@ pub fn serialize_and_load(column: &[u64], codec_type: CodecType) -> Arc<dyn Colu
|
||||
serialize_and_load_u64_based_column_values(&column, &[codec_type])
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_intfastfield_jumpy_veclookup(b: &mut Bencher) {
|
||||
let permutation = generate_permutation();
|
||||
let n = permutation.len();
|
||||
b.iter(|| {
|
||||
let mut a = 0u64;
|
||||
for _ in 0..n {
|
||||
a = permutation[a as usize];
|
||||
}
|
||||
a
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_intfastfield_jumpy_fflookup_bitpacked(b: &mut Bencher) {
|
||||
let permutation = generate_permutation();
|
||||
let n = permutation.len();
|
||||
let column: Arc<dyn ColumnValues<u64>> = serialize_and_load(&permutation, CodecType::Bitpacked);
|
||||
b.iter(|| {
|
||||
let mut a = 0u64;
|
||||
for _ in 0..n {
|
||||
a = column.get_val(a as u32);
|
||||
}
|
||||
a
|
||||
});
|
||||
}
|
||||
|
||||
const FIFTY_PERCENT_RANGE: RangeInclusive<u64> = 1..=50;
|
||||
const SINGLE_ITEM: u64 = 90;
|
||||
const SINGLE_ITEM_RANGE: RangeInclusive<u64> = 90..=90;
|
||||
const ONE_PERCENT_ITEM_RANGE: RangeInclusive<u64> = 49..=49;
|
||||
|
||||
fn get_data_50percent_item() -> Vec<u128> {
|
||||
let mut rng = StdRng::from_seed([1u8; 32]);
|
||||
|
||||
@@ -69,135 +40,122 @@ fn get_data_50percent_item() -> Vec<u128> {
|
||||
data.push(SINGLE_ITEM);
|
||||
|
||||
data.shuffle(&mut rng);
|
||||
let data = data.iter().map(|el| *el as u128).collect::<Vec<_>>();
|
||||
data
|
||||
data.iter().map(|el| *el as u128).collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
// U64 RANGE START
|
||||
#[bench]
|
||||
fn bench_intfastfield_getrange_u64_50percent_hit(b: &mut Bencher) {
|
||||
let data = get_data_50percent_item();
|
||||
let data = data.iter().map(|el| *el as u64).collect::<Vec<_>>();
|
||||
let column: Arc<dyn ColumnValues<u64>> = serialize_and_load(&data, CodecType::Bitpacked);
|
||||
b.iter(|| {
|
||||
let mut positions = Vec::new();
|
||||
column.get_row_ids_for_value_range(
|
||||
FIFTY_PERCENT_RANGE,
|
||||
0..data.len() as u32,
|
||||
&mut positions,
|
||||
);
|
||||
positions
|
||||
});
|
||||
}
|
||||
type VecCol = (Vec<u64>, Arc<dyn ColumnValues<u64>>);
|
||||
|
||||
#[bench]
|
||||
fn bench_intfastfield_getrange_u64_1percent_hit(b: &mut Bencher) {
|
||||
let data = get_data_50percent_item();
|
||||
let data = data.iter().map(|el| *el as u64).collect::<Vec<_>>();
|
||||
let column: Arc<dyn ColumnValues<u64>> = serialize_and_load(&data, CodecType::Bitpacked);
|
||||
|
||||
b.iter(|| {
|
||||
let mut positions = Vec::new();
|
||||
column.get_row_ids_for_value_range(
|
||||
ONE_PERCENT_ITEM_RANGE,
|
||||
0..data.len() as u32,
|
||||
&mut positions,
|
||||
);
|
||||
positions
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_intfastfield_getrange_u64_single_hit(b: &mut Bencher) {
|
||||
let data = get_data_50percent_item();
|
||||
let data = data.iter().map(|el| *el as u64).collect::<Vec<_>>();
|
||||
let column: Arc<dyn ColumnValues<u64>> = serialize_and_load(&data, CodecType::Bitpacked);
|
||||
|
||||
b.iter(|| {
|
||||
let mut positions = Vec::new();
|
||||
column.get_row_ids_for_value_range(SINGLE_ITEM_RANGE, 0..data.len() as u32, &mut positions);
|
||||
positions
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_intfastfield_getrange_u64_hit_all(b: &mut Bencher) {
|
||||
let data = get_data_50percent_item();
|
||||
let data = data.iter().map(|el| *el as u64).collect::<Vec<_>>();
|
||||
let column: Arc<dyn ColumnValues<u64>> = serialize_and_load(&data, CodecType::Bitpacked);
|
||||
|
||||
b.iter(|| {
|
||||
let mut positions = Vec::new();
|
||||
column.get_row_ids_for_value_range(0..=u64::MAX, 0..data.len() as u32, &mut positions);
|
||||
positions
|
||||
});
|
||||
}
|
||||
// U64 RANGE END
|
||||
|
||||
#[bench]
|
||||
fn bench_intfastfield_stride7_vec(b: &mut Bencher) {
|
||||
fn bench_access() {
|
||||
let permutation = generate_permutation();
|
||||
let n = permutation.len();
|
||||
b.iter(|| {
|
||||
let column_perm: Arc<dyn ColumnValues<u64>> =
|
||||
serialize_and_load(&permutation, CodecType::Bitpacked);
|
||||
|
||||
let permutation_gcd = generate_permutation_gcd();
|
||||
let column_perm_gcd: Arc<dyn ColumnValues<u64>> =
|
||||
serialize_and_load(&permutation_gcd, CodecType::Bitpacked);
|
||||
|
||||
let mut group: InputGroup<VecCol> = InputGroup::new_with_inputs(vec![
|
||||
(
|
||||
"access".to_string(),
|
||||
(permutation.clone(), column_perm.clone()),
|
||||
),
|
||||
(
|
||||
"access_gcd".to_string(),
|
||||
(permutation_gcd.clone(), column_perm_gcd.clone()),
|
||||
),
|
||||
]);
|
||||
|
||||
group.register("stride7_vec", |inp: &VecCol| {
|
||||
let n = inp.0.len();
|
||||
let mut a = 0u64;
|
||||
for i in (0..n / 7).map(|val| val * 7) {
|
||||
a += permutation[i as usize];
|
||||
a += inp.0[i];
|
||||
}
|
||||
a
|
||||
black_box(a);
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_intfastfield_stride7_fflookup(b: &mut Bencher) {
|
||||
let permutation = generate_permutation();
|
||||
let n = permutation.len();
|
||||
let column: Arc<dyn ColumnValues<u64>> = serialize_and_load(&permutation, CodecType::Bitpacked);
|
||||
b.iter(|| {
|
||||
let mut a = 0;
|
||||
group.register("fullscan_vec", |inp: &VecCol| {
|
||||
let mut a = 0u64;
|
||||
for i in 0..inp.0.len() {
|
||||
a += inp.0[i];
|
||||
}
|
||||
black_box(a);
|
||||
});
|
||||
|
||||
group.register("stride7_column_values", |inp: &VecCol| {
|
||||
let n = inp.1.num_vals() as usize;
|
||||
let mut a = 0u64;
|
||||
for i in (0..n / 7).map(|val| val * 7) {
|
||||
a += column.get_val(i as u32);
|
||||
a += inp.1.get_val(i as u32);
|
||||
}
|
||||
a
|
||||
black_box(a);
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_intfastfield_scan_all_fflookup(b: &mut Bencher) {
|
||||
let permutation = generate_permutation();
|
||||
let n = permutation.len();
|
||||
let column: Arc<dyn ColumnValues<u64>> = serialize_and_load(&permutation, CodecType::Bitpacked);
|
||||
let column_ref = column.as_ref();
|
||||
b.iter(|| {
|
||||
let mut a = 0u64;
|
||||
for i in 0u32..n as u32 {
|
||||
a += column_ref.get_val(i);
|
||||
}
|
||||
a
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_intfastfield_scan_all_fflookup_gcd(b: &mut Bencher) {
|
||||
let permutation = generate_permutation_gcd();
|
||||
let n = permutation.len();
|
||||
let column: Arc<dyn ColumnValues<u64>> = serialize_and_load(&permutation, CodecType::Bitpacked);
|
||||
b.iter(|| {
|
||||
group.register("fullscan_column_values", |inp: &VecCol| {
|
||||
let mut a = 0u64;
|
||||
let n = inp.1.num_vals() as usize;
|
||||
for i in 0..n {
|
||||
a += column.get_val(i as u32);
|
||||
a += inp.1.get_val(i as u32);
|
||||
}
|
||||
a
|
||||
black_box(a);
|
||||
});
|
||||
|
||||
group.run();
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_intfastfield_scan_all_vec(b: &mut Bencher) {
|
||||
let permutation = generate_permutation();
|
||||
b.iter(|| {
|
||||
let mut a = 0u64;
|
||||
for i in 0..permutation.len() {
|
||||
a += permutation[i as usize] as u64;
|
||||
}
|
||||
a
|
||||
});
|
||||
fn bench_range() {
|
||||
let data_50 = get_data_50percent_item();
|
||||
let data_u64 = data_50.iter().map(|el| *el as u64).collect::<Vec<_>>();
|
||||
let column_data: Arc<dyn ColumnValues<u64>> =
|
||||
serialize_and_load(&data_u64, CodecType::Bitpacked);
|
||||
|
||||
let mut group: InputGroup<Arc<dyn ColumnValues<u64>>> =
|
||||
InputGroup::new_with_inputs(vec![("dist_50pct_item".to_string(), column_data.clone())]);
|
||||
|
||||
group.register(
|
||||
"fastfield_getrange_u64_50percent_hit",
|
||||
|col: &Arc<dyn ColumnValues<u64>>| {
|
||||
let mut positions = Vec::new();
|
||||
col.get_row_ids_for_value_range(FIFTY_PERCENT_RANGE, 0..col.num_vals(), &mut positions);
|
||||
black_box(positions.len());
|
||||
},
|
||||
);
|
||||
|
||||
group.register(
|
||||
"fastfield_getrange_u64_1percent_hit",
|
||||
|col: &Arc<dyn ColumnValues<u64>>| {
|
||||
let mut positions = Vec::new();
|
||||
col.get_row_ids_for_value_range(
|
||||
ONE_PERCENT_ITEM_RANGE,
|
||||
0..col.num_vals(),
|
||||
&mut positions,
|
||||
);
|
||||
black_box(positions.len());
|
||||
},
|
||||
);
|
||||
|
||||
group.register(
|
||||
"fastfield_getrange_u64_single_hit",
|
||||
|col: &Arc<dyn ColumnValues<u64>>| {
|
||||
let mut positions = Vec::new();
|
||||
col.get_row_ids_for_value_range(SINGLE_ITEM_RANGE, 0..col.num_vals(), &mut positions);
|
||||
black_box(positions.len());
|
||||
},
|
||||
);
|
||||
|
||||
group.register(
|
||||
"fastfield_getrange_u64_hit_all",
|
||||
|col: &Arc<dyn ColumnValues<u64>>| {
|
||||
let mut positions = Vec::new();
|
||||
col.get_row_ids_for_value_range(0..=u64::MAX, 0..col.num_vals(), &mut positions);
|
||||
black_box(positions.len());
|
||||
},
|
||||
);
|
||||
|
||||
group.run();
|
||||
}
|
||||
|
||||
fn main() {
|
||||
bench_access();
|
||||
bench_range();
|
||||
}
|
||||
|
||||
@@ -131,6 +131,8 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
|
||||
self.index.docids_to_rowids(doc_ids, doc_ids_out, row_ids)
|
||||
}
|
||||
|
||||
/// Get an iterator over the values for the provided docid.
|
||||
#[inline]
|
||||
pub fn values_for_doc(&self, doc_id: DocId) -> impl Iterator<Item = T> + '_ {
|
||||
self.index
|
||||
.value_row_ids(doc_id)
|
||||
@@ -158,15 +160,6 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
|
||||
.select_batch_in_place(selected_docid_range.start, doc_ids);
|
||||
}
|
||||
|
||||
/// Fills the output vector with the (possibly multiple values that are associated_with
|
||||
/// `row_id`.
|
||||
///
|
||||
/// This method clears the `output` vector.
|
||||
pub fn fill_vals(&self, row_id: RowId, output: &mut Vec<T>) {
|
||||
output.clear();
|
||||
output.extend(self.values_for_doc(row_id));
|
||||
}
|
||||
|
||||
pub fn first_or_default_col(self, default_value: T) -> Arc<dyn ColumnValues<T>> {
|
||||
Arc::new(FirstValueWithDefault {
|
||||
column: self,
|
||||
|
||||
@@ -56,7 +56,7 @@ fn get_doc_ids_with_values<'a>(
|
||||
ColumnIndex::Full => Box::new(doc_range),
|
||||
ColumnIndex::Optional(optional_index) => Box::new(
|
||||
optional_index
|
||||
.iter_docs()
|
||||
.iter_non_null_docs()
|
||||
.map(move |row| row + doc_range.start),
|
||||
),
|
||||
ColumnIndex::Multivalued(multivalued_index) => match multivalued_index {
|
||||
@@ -73,7 +73,7 @@ fn get_doc_ids_with_values<'a>(
|
||||
MultiValueIndex::MultiValueIndexV2(multivalued_index) => Box::new(
|
||||
multivalued_index
|
||||
.optional_index
|
||||
.iter_docs()
|
||||
.iter_non_null_docs()
|
||||
.map(move |row| row + doc_range.start),
|
||||
),
|
||||
},
|
||||
@@ -105,10 +105,11 @@ fn get_num_values_iterator<'a>(
|
||||
) -> Box<dyn Iterator<Item = u32> + 'a> {
|
||||
match column_index {
|
||||
ColumnIndex::Empty { .. } => Box::new(std::iter::empty()),
|
||||
ColumnIndex::Full => Box::new(std::iter::repeat(1u32).take(num_docs as usize)),
|
||||
ColumnIndex::Optional(optional_index) => {
|
||||
Box::new(std::iter::repeat(1u32).take(optional_index.num_non_nulls() as usize))
|
||||
}
|
||||
ColumnIndex::Full => Box::new(std::iter::repeat_n(1u32, num_docs as usize)),
|
||||
ColumnIndex::Optional(optional_index) => Box::new(std::iter::repeat_n(
|
||||
1u32,
|
||||
optional_index.num_non_nulls() as usize,
|
||||
)),
|
||||
ColumnIndex::Multivalued(multivalued_index) => Box::new(
|
||||
multivalued_index
|
||||
.get_start_index_column()
|
||||
@@ -177,7 +178,7 @@ impl<'a> Iterable<RowId> for StackedOptionalIndex<'a> {
|
||||
ColumnIndex::Full => Box::new(columnar_row_range),
|
||||
ColumnIndex::Optional(optional_index) => Box::new(
|
||||
optional_index
|
||||
.iter_docs()
|
||||
.iter_non_null_docs()
|
||||
.map(move |row_id: RowId| columnar_row_range.start + row_id),
|
||||
),
|
||||
ColumnIndex::Multivalued(_) => {
|
||||
|
||||
@@ -215,6 +215,32 @@ impl MultiValueIndex {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an iterator over document ids that have at least one value.
|
||||
pub fn iter_non_null_docs(&self) -> Box<dyn Iterator<Item = DocId> + '_> {
|
||||
match self {
|
||||
MultiValueIndex::MultiValueIndexV1(idx) => {
|
||||
let mut doc: DocId = 0u32;
|
||||
let num_docs = idx.num_docs();
|
||||
Box::new(std::iter::from_fn(move || {
|
||||
// This is not the most efficient way to do this, but it's legacy code.
|
||||
while doc < num_docs {
|
||||
let cur = doc;
|
||||
doc += 1;
|
||||
let start = idx.start_index_column.get_val(cur);
|
||||
let end = idx.start_index_column.get_val(cur + 1);
|
||||
if end > start {
|
||||
return Some(cur);
|
||||
}
|
||||
}
|
||||
None
|
||||
}))
|
||||
}
|
||||
MultiValueIndex::MultiValueIndexV2(idx) => {
|
||||
Box::new(idx.optional_index.iter_non_null_docs())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a list of ranks (row ids of values) in a 1:n index to the corresponding list of
|
||||
/// docids. Positions are converted inplace to docids.
|
||||
///
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::io::{self, Write};
|
||||
use std::io;
|
||||
use std::sync::Arc;
|
||||
|
||||
mod set;
|
||||
@@ -11,7 +11,7 @@ use set_block::{
|
||||
};
|
||||
|
||||
use crate::iterable::Iterable;
|
||||
use crate::{DocId, InvalidData, RowId};
|
||||
use crate::{DocId, RowId};
|
||||
|
||||
/// The threshold for for number of elements after which we switch to dense block encoding.
|
||||
///
|
||||
@@ -88,7 +88,7 @@ pub struct OptionalIndex {
|
||||
|
||||
impl Iterable<u32> for &OptionalIndex {
|
||||
fn boxed_iter(&self) -> Box<dyn Iterator<Item = u32> + '_> {
|
||||
Box::new(self.iter_docs())
|
||||
Box::new(self.iter_non_null_docs())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -280,8 +280,9 @@ impl OptionalIndex {
|
||||
self.num_non_null_docs
|
||||
}
|
||||
|
||||
pub fn iter_docs(&self) -> impl Iterator<Item = RowId> + '_ {
|
||||
// TODO optimize
|
||||
pub fn iter_non_null_docs(&self) -> impl Iterator<Item = RowId> + '_ {
|
||||
// TODO optimize. We could iterate over the blocks directly.
|
||||
// We use the dense value ids and retrieve the doc ids via select.
|
||||
let mut select_batch = self.select_cursor();
|
||||
(0..self.num_non_null_docs).map(move |rank| select_batch.select(rank))
|
||||
}
|
||||
@@ -334,38 +335,6 @@ enum Block<'a> {
|
||||
Sparse(SparseBlock<'a>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
enum OptionalIndexCodec {
|
||||
Dense = 0,
|
||||
Sparse = 1,
|
||||
}
|
||||
|
||||
impl OptionalIndexCodec {
|
||||
fn to_code(self) -> u8 {
|
||||
self as u8
|
||||
}
|
||||
|
||||
fn try_from_code(code: u8) -> Result<Self, InvalidData> {
|
||||
match code {
|
||||
0 => Ok(Self::Dense),
|
||||
1 => Ok(Self::Sparse),
|
||||
_ => Err(InvalidData),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BinarySerializable for OptionalIndexCodec {
|
||||
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
|
||||
writer.write_all(&[self.to_code()])
|
||||
}
|
||||
|
||||
fn deserialize<R: io::Read>(reader: &mut R) -> io::Result<Self> {
|
||||
let optional_codec_code = u8::deserialize(reader)?;
|
||||
let optional_codec = Self::try_from_code(optional_codec_code)?;
|
||||
Ok(optional_codec)
|
||||
}
|
||||
}
|
||||
|
||||
fn serialize_optional_index_block(block_els: &[u16], out: &mut impl io::Write) -> io::Result<()> {
|
||||
let is_sparse = is_sparse(block_els.len() as u32);
|
||||
if is_sparse {
|
||||
|
||||
@@ -164,7 +164,11 @@ fn test_optional_index_large() {
|
||||
fn test_optional_index_iter_aux(row_ids: &[RowId], num_rows: RowId) {
|
||||
let optional_index = OptionalIndex::for_test(num_rows, row_ids);
|
||||
assert_eq!(optional_index.num_docs(), num_rows);
|
||||
assert!(optional_index.iter_docs().eq(row_ids.iter().copied()));
|
||||
assert!(
|
||||
optional_index
|
||||
.iter_non_null_docs()
|
||||
.eq(row_ids.iter().copied())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -219,170 +223,3 @@ fn test_optional_index_for_tests() {
|
||||
assert!(!optional_index.contains(3));
|
||||
assert_eq!(optional_index.num_docs(), 4);
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "unstable"))]
|
||||
mod bench {
|
||||
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use test::Bencher;
|
||||
|
||||
use super::*;
|
||||
|
||||
const TOTAL_NUM_VALUES: u32 = 1_000_000;
|
||||
fn gen_bools(fill_ratio: f64) -> OptionalIndex {
|
||||
let mut out = Vec::new();
|
||||
let mut rng: StdRng = StdRng::from_seed([1u8; 32]);
|
||||
let vals: Vec<RowId> = (0..TOTAL_NUM_VALUES)
|
||||
.map(|_| rng.gen_bool(fill_ratio))
|
||||
.enumerate()
|
||||
.filter(|(_pos, val)| *val)
|
||||
.map(|(pos, _)| pos as RowId)
|
||||
.collect();
|
||||
serialize_optional_index(&&vals[..], TOTAL_NUM_VALUES, &mut out).unwrap();
|
||||
|
||||
open_optional_index(OwnedBytes::new(out)).unwrap()
|
||||
}
|
||||
|
||||
fn random_range_iterator(
|
||||
start: u32,
|
||||
end: u32,
|
||||
avg_step_size: u32,
|
||||
avg_deviation: u32,
|
||||
) -> impl Iterator<Item = u32> {
|
||||
let mut rng: StdRng = StdRng::from_seed([1u8; 32]);
|
||||
let mut current = start;
|
||||
std::iter::from_fn(move || {
|
||||
current += rng.gen_range(avg_step_size - avg_deviation..=avg_step_size + avg_deviation);
|
||||
if current >= end { None } else { Some(current) }
|
||||
})
|
||||
}
|
||||
|
||||
fn n_percent_step_iterator(percent: f32, num_values: u32) -> impl Iterator<Item = u32> {
|
||||
let ratio = percent / 100.0;
|
||||
let step_size = (1f32 / ratio) as u32;
|
||||
let deviation = step_size - 1;
|
||||
random_range_iterator(0, num_values, step_size, deviation)
|
||||
}
|
||||
|
||||
fn walk_over_data(codec: &OptionalIndex, avg_step_size: u32) -> Option<u32> {
|
||||
walk_over_data_from_positions(
|
||||
codec,
|
||||
random_range_iterator(0, TOTAL_NUM_VALUES, avg_step_size, 0),
|
||||
)
|
||||
}
|
||||
|
||||
fn walk_over_data_from_positions(
|
||||
codec: &OptionalIndex,
|
||||
positions: impl Iterator<Item = u32>,
|
||||
) -> Option<u32> {
|
||||
let mut dense_idx: Option<u32> = None;
|
||||
for idx in positions {
|
||||
dense_idx = dense_idx.or(codec.rank_if_exists(idx));
|
||||
}
|
||||
dense_idx
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_orig_to_codec_1percent_filled_10percent_hit(bench: &mut Bencher) {
|
||||
let codec = gen_bools(0.01f64);
|
||||
bench.iter(|| walk_over_data(&codec, 100));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_orig_to_codec_5percent_filled_10percent_hit(bench: &mut Bencher) {
|
||||
let codec = gen_bools(0.05f64);
|
||||
bench.iter(|| walk_over_data(&codec, 100));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_orig_to_codec_5percent_filled_1percent_hit(bench: &mut Bencher) {
|
||||
let codec = gen_bools(0.05f64);
|
||||
bench.iter(|| walk_over_data(&codec, 1000));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_orig_to_codec_full_scan_1percent_filled(bench: &mut Bencher) {
|
||||
let codec = gen_bools(0.01f64);
|
||||
bench.iter(|| walk_over_data_from_positions(&codec, 0..TOTAL_NUM_VALUES));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_orig_to_codec_full_scan_10percent_filled(bench: &mut Bencher) {
|
||||
let codec = gen_bools(0.1f64);
|
||||
bench.iter(|| walk_over_data_from_positions(&codec, 0..TOTAL_NUM_VALUES));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_orig_to_codec_full_scan_90percent_filled(bench: &mut Bencher) {
|
||||
let codec = gen_bools(0.9f64);
|
||||
bench.iter(|| walk_over_data_from_positions(&codec, 0..TOTAL_NUM_VALUES));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_orig_to_codec_10percent_filled_1percent_hit(bench: &mut Bencher) {
|
||||
let codec = gen_bools(0.1f64);
|
||||
bench.iter(|| walk_over_data(&codec, 100));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_orig_to_codec_50percent_filled_1percent_hit(bench: &mut Bencher) {
|
||||
let codec = gen_bools(0.5f64);
|
||||
bench.iter(|| walk_over_data(&codec, 100));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_orig_to_codec_90percent_filled_1percent_hit(bench: &mut Bencher) {
|
||||
let codec = gen_bools(0.9f64);
|
||||
bench.iter(|| walk_over_data(&codec, 100));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_codec_to_orig_1percent_filled_0comma005percent_hit(bench: &mut Bencher) {
|
||||
bench_translate_codec_to_orig_util(0.01f64, 0.005f32, bench);
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_codec_to_orig_10percent_filled_0comma005percent_hit(bench: &mut Bencher) {
|
||||
bench_translate_codec_to_orig_util(0.1f64, 0.005f32, bench);
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_codec_to_orig_1percent_filled_10percent_hit(bench: &mut Bencher) {
|
||||
bench_translate_codec_to_orig_util(0.01f64, 10f32, bench);
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_codec_to_orig_1percent_filled_full_scan(bench: &mut Bencher) {
|
||||
bench_translate_codec_to_orig_util(0.01f64, 100f32, bench);
|
||||
}
|
||||
|
||||
fn bench_translate_codec_to_orig_util(
|
||||
percent_filled: f64,
|
||||
percent_hit: f32,
|
||||
bench: &mut Bencher,
|
||||
) {
|
||||
let codec = gen_bools(percent_filled);
|
||||
let num_non_nulls = codec.num_non_nulls();
|
||||
let idxs: Vec<u32> = if percent_hit == 100.0f32 {
|
||||
(0..num_non_nulls).collect()
|
||||
} else {
|
||||
n_percent_step_iterator(percent_hit, num_non_nulls).collect()
|
||||
};
|
||||
let mut output = vec![0u32; idxs.len()];
|
||||
bench.iter(|| {
|
||||
output.copy_from_slice(&idxs[..]);
|
||||
codec.select_batch(&mut output);
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_codec_to_orig_90percent_filled_0comma005percent_hit(bench: &mut Bencher) {
|
||||
bench_translate_codec_to_orig_util(0.9f64, 0.005, bench);
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_codec_to_orig_90percent_filled_full_scan(bench: &mut Bencher) {
|
||||
bench_translate_codec_to_orig_util(0.9f64, 100.0f32, bench);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,139 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use common::OwnedBytes;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use test::{self, Bencher};
|
||||
|
||||
use super::*;
|
||||
use crate::column_values::u64_based::*;
|
||||
|
||||
fn get_data() -> Vec<u64> {
|
||||
let mut rng = StdRng::seed_from_u64(2u64);
|
||||
let mut data: Vec<_> = (100..55000_u64)
|
||||
.map(|num| num + rng.r#gen::<u8>() as u64)
|
||||
.collect();
|
||||
data.push(99_000);
|
||||
data.insert(1000, 2000);
|
||||
data.insert(2000, 100);
|
||||
data.insert(3000, 4100);
|
||||
data.insert(4000, 100);
|
||||
data.insert(5000, 800);
|
||||
data
|
||||
}
|
||||
|
||||
fn compute_stats(vals: impl Iterator<Item = u64>) -> ColumnStats {
|
||||
let mut stats_collector = StatsCollector::default();
|
||||
for val in vals {
|
||||
stats_collector.collect(val);
|
||||
}
|
||||
stats_collector.stats()
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
fn value_iter() -> impl Iterator<Item = u64> {
|
||||
0..20_000
|
||||
}
|
||||
|
||||
fn get_reader_for_bench<Codec: ColumnCodec>(data: &[u64]) -> Codec::ColumnValues {
|
||||
let mut bytes = Vec::new();
|
||||
let stats = compute_stats(data.iter().cloned());
|
||||
let mut codec_serializer = Codec::estimator();
|
||||
for val in data {
|
||||
codec_serializer.collect(*val);
|
||||
}
|
||||
codec_serializer
|
||||
.serialize(&stats, Box::new(data.iter().copied()).as_mut(), &mut bytes)
|
||||
.unwrap();
|
||||
|
||||
Codec::load(OwnedBytes::new(bytes)).unwrap()
|
||||
}
|
||||
|
||||
fn bench_get<Codec: ColumnCodec>(b: &mut Bencher, data: &[u64]) {
|
||||
let col = get_reader_for_bench::<Codec>(data);
|
||||
b.iter(|| {
|
||||
let mut sum = 0u64;
|
||||
for pos in value_iter() {
|
||||
let val = col.get_val(pos as u32);
|
||||
sum = sum.wrapping_add(val);
|
||||
}
|
||||
sum
|
||||
});
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
fn bench_get_dynamic_helper(b: &mut Bencher, col: Arc<dyn ColumnValues>) {
|
||||
b.iter(|| {
|
||||
let mut sum = 0u64;
|
||||
for pos in value_iter() {
|
||||
let val = col.get_val(pos as u32);
|
||||
sum = sum.wrapping_add(val);
|
||||
}
|
||||
sum
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_get_dynamic<Codec: ColumnCodec>(b: &mut Bencher, data: &[u64]) {
|
||||
let col = Arc::new(get_reader_for_bench::<Codec>(data));
|
||||
bench_get_dynamic_helper(b, col);
|
||||
}
|
||||
fn bench_create<Codec: ColumnCodec>(b: &mut Bencher, data: &[u64]) {
|
||||
let stats = compute_stats(data.iter().cloned());
|
||||
|
||||
let mut bytes = Vec::new();
|
||||
b.iter(|| {
|
||||
bytes.clear();
|
||||
let mut codec_serializer = Codec::estimator();
|
||||
for val in data.iter().take(1024) {
|
||||
codec_serializer.collect(*val);
|
||||
}
|
||||
|
||||
codec_serializer.serialize(&stats, Box::new(data.iter().copied()).as_mut(), &mut bytes)
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_fastfield_bitpack_create(b: &mut Bencher) {
|
||||
let data: Vec<_> = get_data();
|
||||
bench_create::<BitpackedCodec>(b, &data);
|
||||
}
|
||||
#[bench]
|
||||
fn bench_fastfield_linearinterpol_create(b: &mut Bencher) {
|
||||
let data: Vec<_> = get_data();
|
||||
bench_create::<LinearCodec>(b, &data);
|
||||
}
|
||||
#[bench]
|
||||
fn bench_fastfield_multilinearinterpol_create(b: &mut Bencher) {
|
||||
let data: Vec<_> = get_data();
|
||||
bench_create::<BlockwiseLinearCodec>(b, &data);
|
||||
}
|
||||
#[bench]
|
||||
fn bench_fastfield_bitpack_get(b: &mut Bencher) {
|
||||
let data: Vec<_> = get_data();
|
||||
bench_get::<BitpackedCodec>(b, &data);
|
||||
}
|
||||
#[bench]
|
||||
fn bench_fastfield_bitpack_get_dynamic(b: &mut Bencher) {
|
||||
let data: Vec<_> = get_data();
|
||||
bench_get_dynamic::<BitpackedCodec>(b, &data);
|
||||
}
|
||||
#[bench]
|
||||
fn bench_fastfield_linearinterpol_get(b: &mut Bencher) {
|
||||
let data: Vec<_> = get_data();
|
||||
bench_get::<LinearCodec>(b, &data);
|
||||
}
|
||||
#[bench]
|
||||
fn bench_fastfield_linearinterpol_get_dynamic(b: &mut Bencher) {
|
||||
let data: Vec<_> = get_data();
|
||||
bench_get_dynamic::<LinearCodec>(b, &data);
|
||||
}
|
||||
#[bench]
|
||||
fn bench_fastfield_multilinearinterpol_get(b: &mut Bencher) {
|
||||
let data: Vec<_> = get_data();
|
||||
bench_get::<BlockwiseLinearCodec>(b, &data);
|
||||
}
|
||||
#[bench]
|
||||
fn bench_fastfield_multilinearinterpol_get_dynamic(b: &mut Bencher) {
|
||||
let data: Vec<_> = get_data();
|
||||
bench_get_dynamic::<BlockwiseLinearCodec>(b, &data);
|
||||
}
|
||||
@@ -242,6 +242,3 @@ impl<T: Copy + PartialOrd + Debug + 'static> ColumnValues<T> for Arc<dyn ColumnV
|
||||
.get_row_ids_for_value_range(range, doc_id_range, positions)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "unstable"))]
|
||||
mod bench;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::fmt::Debug;
|
||||
use std::net::Ipv6Addr;
|
||||
|
||||
/// Montonic maps a value to u128 value space
|
||||
/// Monotonic maps a value to u128 value space
|
||||
/// Monotonic mapping enables `PartialOrd` on u128 space without conversion to original space.
|
||||
pub trait MonotonicallyMappableToU128: 'static + PartialOrd + Copy + Debug + Send + Sync {
|
||||
/// Converts a value to u128.
|
||||
|
||||
@@ -185,10 +185,10 @@ impl CompactSpaceBuilder {
|
||||
let mut covered_space = Vec::with_capacity(self.blanks.len());
|
||||
|
||||
// beginning of the blanks
|
||||
if let Some(first_blank_start) = self.blanks.first().map(RangeInclusive::start) {
|
||||
if *first_blank_start != 0 {
|
||||
covered_space.push(0..=first_blank_start - 1);
|
||||
}
|
||||
if let Some(first_blank_start) = self.blanks.first().map(RangeInclusive::start)
|
||||
&& *first_blank_start != 0
|
||||
{
|
||||
covered_space.push(0..=first_blank_start - 1);
|
||||
}
|
||||
|
||||
// Between the blanks
|
||||
@@ -202,10 +202,10 @@ impl CompactSpaceBuilder {
|
||||
covered_space.extend(between_blanks);
|
||||
|
||||
// end of the blanks
|
||||
if let Some(last_blank_end) = self.blanks.last().map(RangeInclusive::end) {
|
||||
if *last_blank_end != u128::MAX {
|
||||
covered_space.push(last_blank_end + 1..=u128::MAX);
|
||||
}
|
||||
if let Some(last_blank_end) = self.blanks.last().map(RangeInclusive::end)
|
||||
&& *last_blank_end != u128::MAX
|
||||
{
|
||||
covered_space.push(last_blank_end + 1..=u128::MAX);
|
||||
}
|
||||
|
||||
if covered_space.is_empty() {
|
||||
|
||||
@@ -105,7 +105,7 @@ impl ColumnCodecEstimator for BitpackedCodecEstimator {
|
||||
|
||||
fn estimate(&self, stats: &ColumnStats) -> Option<u64> {
|
||||
let num_bits_per_value = num_bits(stats);
|
||||
Some(stats.num_bytes() + (stats.num_rows as u64 * (num_bits_per_value as u64) + 7) / 8)
|
||||
Some(stats.num_bytes() + (stats.num_rows as u64 * (num_bits_per_value as u64)).div_ceil(8))
|
||||
}
|
||||
|
||||
fn serialize(
|
||||
|
||||
@@ -8,7 +8,7 @@ use crate::column_values::ColumnValues;
|
||||
const MID_POINT: u64 = (1u64 << 32) - 1u64;
|
||||
|
||||
/// `Line` describes a line function `y: ax + b` using integer
|
||||
/// arithmetics.
|
||||
/// arithmetic.
|
||||
///
|
||||
/// The slope is in fact a decimal split into a 32 bit integer value,
|
||||
/// and a 32-bit decimal value.
|
||||
@@ -94,7 +94,7 @@ impl Line {
|
||||
// `(i, ys[])`.
|
||||
//
|
||||
// The best intercept therefore has the form
|
||||
// `y[i] - line.eval(i)` (using wrapping arithmetics).
|
||||
// `y[i] - line.eval(i)` (using wrapping arithmetic).
|
||||
// In other words, the best intercept is one of the `y - Line::eval(ys[i])`
|
||||
// and our task is just to pick the one that minimizes our error.
|
||||
//
|
||||
|
||||
@@ -117,7 +117,7 @@ impl ColumnCodecEstimator for LinearCodecEstimator {
|
||||
Some(
|
||||
stats.num_bytes()
|
||||
+ linear_params.num_bytes()
|
||||
+ (num_bits as u64 * stats.num_rows as u64 + 7) / 8,
|
||||
+ (num_bits as u64 * stats.num_rows as u64).div_ceil(8),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ pub trait ColumnCodecEstimator<T = u64>: 'static {
|
||||
) -> io::Result<()>;
|
||||
}
|
||||
|
||||
/// A column codec describes a colunm serialization format.
|
||||
/// A column codec describes a column serialization format.
|
||||
pub trait ColumnCodec<T: PartialOrd = u64> {
|
||||
/// Specialized `ColumnValues` type.
|
||||
type ColumnValues: ColumnValues<T> + 'static;
|
||||
|
||||
@@ -367,7 +367,7 @@ fn is_empty_after_merge(
|
||||
ColumnIndex::Empty { .. } => true,
|
||||
ColumnIndex::Full => alive_bitset.len() == 0,
|
||||
ColumnIndex::Optional(optional_index) => {
|
||||
for doc in optional_index.iter_docs() {
|
||||
for doc in optional_index.iter_non_null_docs() {
|
||||
if alive_bitset.contains(doc) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -244,7 +244,7 @@ impl SymbolValue for UnorderedId {
|
||||
|
||||
fn compute_num_bytes_for_u64(val: u64) -> usize {
|
||||
let msb = (64u32 - val.leading_zeros()) as usize;
|
||||
(msb + 7) / 8
|
||||
msb.div_ceil(8)
|
||||
}
|
||||
|
||||
fn encode_zig_zag(n: i64) -> u64 {
|
||||
|
||||
@@ -17,15 +17,10 @@
|
||||
//! column.
|
||||
//! - [column_values]: Stores the values of a column in a dense format.
|
||||
|
||||
#![cfg_attr(all(feature = "unstable", test), feature(test))]
|
||||
|
||||
#[cfg(test)]
|
||||
#[macro_use]
|
||||
extern crate more_asserts;
|
||||
|
||||
#[cfg(all(test, feature = "unstable"))]
|
||||
extern crate test;
|
||||
|
||||
use std::fmt::Display;
|
||||
use std::io;
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::str::FromStr;
|
||||
|
||||
use common::DateTime;
|
||||
|
||||
use crate::InvalidData;
|
||||
@@ -9,6 +11,23 @@ pub enum NumericalValue {
|
||||
F64(f64),
|
||||
}
|
||||
|
||||
impl FromStr for NumericalValue {
|
||||
type Err = ();
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, ()> {
|
||||
if let Ok(val_i64) = s.parse::<i64>() {
|
||||
return Ok(val_i64.into());
|
||||
}
|
||||
if let Ok(val_u64) = s.parse::<u64>() {
|
||||
return Ok(val_u64.into());
|
||||
}
|
||||
if let Ok(val_f64) = s.parse::<f64>() {
|
||||
return Ok(NumericalValue::from(val_f64).normalize());
|
||||
}
|
||||
Err(())
|
||||
}
|
||||
}
|
||||
|
||||
impl NumericalValue {
|
||||
pub fn numerical_type(&self) -> NumericalType {
|
||||
match self {
|
||||
@@ -26,7 +45,7 @@ impl NumericalValue {
|
||||
if val <= i64::MAX as u64 {
|
||||
NumericalValue::I64(val as i64)
|
||||
} else {
|
||||
NumericalValue::F64(val as f64)
|
||||
NumericalValue::U64(val)
|
||||
}
|
||||
}
|
||||
NumericalValue::I64(val) => NumericalValue::I64(val),
|
||||
@@ -141,6 +160,7 @@ impl Coerce for DateTime {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::NumericalType;
|
||||
use crate::NumericalValue;
|
||||
|
||||
#[test]
|
||||
fn test_numerical_type_code() {
|
||||
@@ -153,4 +173,58 @@ mod tests {
|
||||
}
|
||||
assert_eq!(num_numerical_type, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_numerical() {
|
||||
assert_eq!(
|
||||
"123".parse::<NumericalValue>().unwrap(),
|
||||
NumericalValue::I64(123)
|
||||
);
|
||||
assert_eq!(
|
||||
"18446744073709551615".parse::<NumericalValue>().unwrap(),
|
||||
NumericalValue::U64(18446744073709551615u64)
|
||||
);
|
||||
assert_eq!(
|
||||
"1.0".parse::<NumericalValue>().unwrap(),
|
||||
NumericalValue::I64(1i64)
|
||||
);
|
||||
assert_eq!(
|
||||
"1.1".parse::<NumericalValue>().unwrap(),
|
||||
NumericalValue::F64(1.1f64)
|
||||
);
|
||||
assert_eq!(
|
||||
"-1.0".parse::<NumericalValue>().unwrap(),
|
||||
NumericalValue::I64(-1i64)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalize_numerical() {
|
||||
assert_eq!(
|
||||
NumericalValue::from(1u64).normalize(),
|
||||
NumericalValue::I64(1i64),
|
||||
);
|
||||
let limit_val = i64::MAX as u64 + 1u64;
|
||||
assert_eq!(
|
||||
NumericalValue::from(limit_val).normalize(),
|
||||
NumericalValue::U64(limit_val),
|
||||
);
|
||||
assert_eq!(
|
||||
NumericalValue::from(-1i64).normalize(),
|
||||
NumericalValue::I64(-1i64),
|
||||
);
|
||||
assert_eq!(
|
||||
NumericalValue::from(-2.0f64).normalize(),
|
||||
NumericalValue::I64(-2i64),
|
||||
);
|
||||
assert_eq!(
|
||||
NumericalValue::from(-2.1f64).normalize(),
|
||||
NumericalValue::F64(-2.1f64),
|
||||
);
|
||||
let large_float = 2.0f64.powf(70.0f64);
|
||||
assert_eq!(
|
||||
NumericalValue::from(large_float).normalize(),
|
||||
NumericalValue::F64(large_float),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tantivy-common"
|
||||
version = "0.9.0"
|
||||
version = "0.10.0"
|
||||
authors = ["Paul Masurel <paul@quickwit.io>", "Pascal Seitz <pascal@quickwit.io>"]
|
||||
license = "MIT"
|
||||
edition = "2024"
|
||||
|
||||
@@ -183,7 +183,7 @@ pub struct BitSet {
|
||||
}
|
||||
|
||||
fn num_buckets(max_val: u32) -> u32 {
|
||||
(max_val + 63u32) / 64u32
|
||||
max_val.div_ceil(64u32)
|
||||
}
|
||||
|
||||
impl BitSet {
|
||||
|
||||
@@ -28,7 +28,9 @@ impl BinarySerializable for VIntU128 {
|
||||
writer.write_all(&buffer)
|
||||
}
|
||||
|
||||
#[allow(clippy::unbuffered_bytes)]
|
||||
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
|
||||
#[allow(clippy::unbuffered_bytes)]
|
||||
let mut bytes = reader.bytes();
|
||||
let mut result = 0u128;
|
||||
let mut shift = 0u64;
|
||||
@@ -195,7 +197,9 @@ impl BinarySerializable for VInt {
|
||||
writer.write_all(&buffer[0..num_bytes])
|
||||
}
|
||||
|
||||
#[allow(clippy::unbuffered_bytes)]
|
||||
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
|
||||
#[allow(clippy::unbuffered_bytes)]
|
||||
let mut bytes = reader.bytes();
|
||||
let mut result = 0u64;
|
||||
let mut shift = 0u64;
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 653 KiB |
@@ -208,7 +208,7 @@ fn main() -> tantivy::Result<()> {
|
||||
// is the role of the `TopDocs` collector.
|
||||
|
||||
// We can now perform our query.
|
||||
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
|
||||
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
|
||||
|
||||
// The actual documents still need to be
|
||||
// retrieved from Tantivy's store.
|
||||
@@ -226,7 +226,7 @@ fn main() -> tantivy::Result<()> {
|
||||
let query = query_parser.parse_query("title:sea^20 body:whale^70")?;
|
||||
|
||||
let (_score, doc_address) = searcher
|
||||
.search(&query, &TopDocs::with_limit(1))?
|
||||
.search(&query, &TopDocs::with_limit(1).order_by_score())?
|
||||
.into_iter()
|
||||
.next()
|
||||
.unwrap();
|
||||
|
||||
@@ -100,7 +100,7 @@ fn main() -> tantivy::Result<()> {
|
||||
// here we want to get a hit on the 'ken' in Frankenstein
|
||||
let query = query_parser.parse_query("ken")?;
|
||||
|
||||
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
|
||||
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
|
||||
|
||||
for (_, doc_address) in top_docs {
|
||||
let retrieved_doc: TantivyDocument = searcher.doc(doc_address)?;
|
||||
|
||||
@@ -50,14 +50,14 @@ fn main() -> tantivy::Result<()> {
|
||||
{
|
||||
// Simple exact search on the date
|
||||
let query = query_parser.parse_query("occurred_at:\"2022-06-22T12:53:50.53Z\"")?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(5))?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(5).order_by_score())?;
|
||||
assert_eq!(count_docs.len(), 1);
|
||||
}
|
||||
{
|
||||
// Range query on the date field
|
||||
let query = query_parser
|
||||
.parse_query(r#"occurred_at:[2022-06-22T12:58:00Z TO 2022-06-23T00:00:00Z}"#)?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(4))?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(4).order_by_score())?;
|
||||
assert_eq!(count_docs.len(), 1);
|
||||
for (_score, doc_address) in count_docs {
|
||||
let retrieved_doc = searcher.doc::<TantivyDocument>(doc_address)?;
|
||||
|
||||
@@ -28,7 +28,7 @@ fn extract_doc_given_isbn(
|
||||
// The second argument is here to tell we don't care about decoding positions,
|
||||
// or term frequencies.
|
||||
let term_query = TermQuery::new(isbn_term.clone(), IndexRecordOption::Basic);
|
||||
let top_docs = searcher.search(&term_query, &TopDocs::with_limit(1))?;
|
||||
let top_docs = searcher.search(&term_query, &TopDocs::with_limit(1).order_by_score())?;
|
||||
|
||||
if let Some((_score, doc_address)) = top_docs.first() {
|
||||
let doc = searcher.doc(*doc_address)?;
|
||||
|
||||
212
examples/filter_aggregation.rs
Normal file
212
examples/filter_aggregation.rs
Normal file
@@ -0,0 +1,212 @@
|
||||
// # Filter Aggregation Example
|
||||
//
|
||||
// This example demonstrates filter aggregations - creating buckets of documents
|
||||
// matching specific queries, with nested aggregations computed on each bucket.
|
||||
//
|
||||
// Filter aggregations are useful for computing metrics on different subsets of
|
||||
// your data in a single query, like "average price overall + average price for
|
||||
// electronics + count of in-stock items".
|
||||
|
||||
use serde_json::json;
|
||||
use tantivy::aggregation::agg_req::Aggregations;
|
||||
use tantivy::aggregation::AggregationCollector;
|
||||
use tantivy::query::AllQuery;
|
||||
use tantivy::schema::{Schema, FAST, INDEXED, TEXT};
|
||||
use tantivy::{doc, Index};
|
||||
|
||||
fn main() -> tantivy::Result<()> {
|
||||
// Create a simple product schema
|
||||
let mut schema_builder = Schema::builder();
|
||||
schema_builder.add_text_field("category", TEXT | FAST);
|
||||
schema_builder.add_text_field("brand", TEXT | FAST);
|
||||
schema_builder.add_u64_field("price", FAST);
|
||||
schema_builder.add_f64_field("rating", FAST);
|
||||
schema_builder.add_bool_field("in_stock", FAST | INDEXED);
|
||||
let schema = schema_builder.build();
|
||||
|
||||
// Create index and add sample products
|
||||
let index = Index::create_in_ram(schema.clone());
|
||||
let mut writer = index.writer(50_000_000)?;
|
||||
|
||||
writer.add_document(doc!(
|
||||
schema.get_field("category")? => "electronics",
|
||||
schema.get_field("brand")? => "apple",
|
||||
schema.get_field("price")? => 999u64,
|
||||
schema.get_field("rating")? => 4.5f64,
|
||||
schema.get_field("in_stock")? => true
|
||||
))?;
|
||||
writer.add_document(doc!(
|
||||
schema.get_field("category")? => "electronics",
|
||||
schema.get_field("brand")? => "samsung",
|
||||
schema.get_field("price")? => 799u64,
|
||||
schema.get_field("rating")? => 4.2f64,
|
||||
schema.get_field("in_stock")? => true
|
||||
))?;
|
||||
writer.add_document(doc!(
|
||||
schema.get_field("category")? => "clothing",
|
||||
schema.get_field("brand")? => "nike",
|
||||
schema.get_field("price")? => 120u64,
|
||||
schema.get_field("rating")? => 4.1f64,
|
||||
schema.get_field("in_stock")? => false
|
||||
))?;
|
||||
writer.add_document(doc!(
|
||||
schema.get_field("category")? => "books",
|
||||
schema.get_field("brand")? => "penguin",
|
||||
schema.get_field("price")? => 25u64,
|
||||
schema.get_field("rating")? => 4.8f64,
|
||||
schema.get_field("in_stock")? => true
|
||||
))?;
|
||||
|
||||
writer.commit()?;
|
||||
|
||||
let reader = index.reader()?;
|
||||
let searcher = reader.searcher();
|
||||
|
||||
// Example 1: Basic filter with metric aggregation
|
||||
println!("=== Example 1: Electronics average price ===");
|
||||
let agg_req = json!({
|
||||
"electronics": {
|
||||
"filter": "category:electronics",
|
||||
"aggs": {
|
||||
"avg_price": { "avg": { "field": "price" } }
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let agg: Aggregations = serde_json::from_value(agg_req)?;
|
||||
let collector = AggregationCollector::from_aggs(agg, Default::default());
|
||||
let result = searcher.search(&AllQuery, &collector)?;
|
||||
|
||||
let expected = json!({
|
||||
"electronics": {
|
||||
"doc_count": 2,
|
||||
"avg_price": { "value": 899.0 }
|
||||
}
|
||||
});
|
||||
assert_eq!(serde_json::to_value(&result)?, expected);
|
||||
println!("{}\n", serde_json::to_string_pretty(&result)?);
|
||||
|
||||
// Example 2: Multiple independent filters
|
||||
println!("=== Example 2: Multiple filters in one query ===");
|
||||
let agg_req = json!({
|
||||
"electronics": {
|
||||
"filter": "category:electronics",
|
||||
"aggs": { "avg_price": { "avg": { "field": "price" } } }
|
||||
},
|
||||
"in_stock": {
|
||||
"filter": "in_stock:true",
|
||||
"aggs": { "count": { "value_count": { "field": "brand" } } }
|
||||
},
|
||||
"high_rated": {
|
||||
"filter": "rating:[4.5 TO *]",
|
||||
"aggs": { "count": { "value_count": { "field": "brand" } } }
|
||||
}
|
||||
});
|
||||
|
||||
let agg: Aggregations = serde_json::from_value(agg_req)?;
|
||||
let collector = AggregationCollector::from_aggs(agg, Default::default());
|
||||
let result = searcher.search(&AllQuery, &collector)?;
|
||||
|
||||
let expected = json!({
|
||||
"electronics": {
|
||||
"doc_count": 2,
|
||||
"avg_price": { "value": 899.0 }
|
||||
},
|
||||
"in_stock": {
|
||||
"doc_count": 3,
|
||||
"count": { "value": 3.0 }
|
||||
},
|
||||
"high_rated": {
|
||||
"doc_count": 2,
|
||||
"count": { "value": 2.0 }
|
||||
}
|
||||
});
|
||||
assert_eq!(serde_json::to_value(&result)?, expected);
|
||||
println!("{}\n", serde_json::to_string_pretty(&result)?);
|
||||
|
||||
// Example 3: Nested filters - progressive refinement
|
||||
println!("=== Example 3: Nested filters ===");
|
||||
let agg_req = json!({
|
||||
"in_stock": {
|
||||
"filter": "in_stock:true",
|
||||
"aggs": {
|
||||
"electronics": {
|
||||
"filter": "category:electronics",
|
||||
"aggs": {
|
||||
"expensive": {
|
||||
"filter": "price:[800 TO *]",
|
||||
"aggs": {
|
||||
"avg_rating": { "avg": { "field": "rating" } }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let agg: Aggregations = serde_json::from_value(agg_req)?;
|
||||
let collector = AggregationCollector::from_aggs(agg, Default::default());
|
||||
let result = searcher.search(&AllQuery, &collector)?;
|
||||
|
||||
let expected = json!({
|
||||
"in_stock": {
|
||||
"doc_count": 3, // apple, samsung, penguin
|
||||
"electronics": {
|
||||
"doc_count": 2, // apple, samsung
|
||||
"expensive": {
|
||||
"doc_count": 1, // only apple (999)
|
||||
"avg_rating": { "value": 4.5 }
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
assert_eq!(serde_json::to_value(&result)?, expected);
|
||||
println!("{}\n", serde_json::to_string_pretty(&result)?);
|
||||
|
||||
// Example 4: Filter with sub-aggregation (terms)
|
||||
println!("=== Example 4: Filter with terms sub-aggregation ===");
|
||||
let agg_req = json!({
|
||||
"electronics": {
|
||||
"filter": "category:electronics",
|
||||
"aggs": {
|
||||
"by_brand": {
|
||||
"terms": { "field": "brand" },
|
||||
"aggs": {
|
||||
"avg_price": { "avg": { "field": "price" } }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let agg: Aggregations = serde_json::from_value(agg_req)?;
|
||||
let collector = AggregationCollector::from_aggs(agg, Default::default());
|
||||
let result = searcher.search(&AllQuery, &collector)?;
|
||||
|
||||
let expected = json!({
|
||||
"electronics": {
|
||||
"doc_count": 2,
|
||||
"by_brand": {
|
||||
"buckets": [
|
||||
{
|
||||
"key": "samsung",
|
||||
"doc_count": 1,
|
||||
"avg_price": { "value": 799.0 }
|
||||
},
|
||||
{
|
||||
"key": "apple",
|
||||
"doc_count": 1,
|
||||
"avg_price": { "value": 999.0 }
|
||||
}
|
||||
],
|
||||
"sum_other_doc_count": 0,
|
||||
"doc_count_error_upper_bound": 0
|
||||
}
|
||||
}
|
||||
});
|
||||
assert_eq!(serde_json::to_value(&result)?, expected);
|
||||
println!("{}", serde_json::to_string_pretty(&result)?);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -85,7 +85,6 @@ fn main() -> tantivy::Result<()> {
|
||||
index_writer.add_document(doc!(
|
||||
title => "The Diary of a Young Girl",
|
||||
))?;
|
||||
index_writer.commit()?;
|
||||
|
||||
// ### Committing
|
||||
//
|
||||
@@ -146,7 +145,7 @@ fn main() -> tantivy::Result<()> {
|
||||
let query = FuzzyTermQuery::new(term, 2, true);
|
||||
|
||||
let (top_docs, count) = searcher
|
||||
.search(&query, &(TopDocs::with_limit(5), Count))
|
||||
.search(&query, &(TopDocs::with_limit(5).order_by_score(), Count))
|
||||
.unwrap();
|
||||
assert_eq!(count, 3);
|
||||
assert_eq!(top_docs.len(), 3);
|
||||
|
||||
66
examples/geo_json.rs
Normal file
66
examples/geo_json.rs
Normal file
@@ -0,0 +1,66 @@
|
||||
use geo_types::Point;
|
||||
use tantivy::collector::TopDocs;
|
||||
use tantivy::query::SpatialQuery;
|
||||
use tantivy::schema::{Schema, Value, SPATIAL, STORED, TEXT};
|
||||
use tantivy::spatial::point::GeoPoint;
|
||||
use tantivy::{Index, IndexWriter, TantivyDocument};
|
||||
fn main() -> tantivy::Result<()> {
|
||||
let mut schema_builder = Schema::builder();
|
||||
schema_builder.add_json_field("properties", STORED | TEXT);
|
||||
schema_builder.add_spatial_field("geometry", STORED | SPATIAL);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema.clone());
|
||||
let mut index_writer: IndexWriter = index.writer(50_000_000)?;
|
||||
let doc = TantivyDocument::parse_json(
|
||||
&schema,
|
||||
r#"{
|
||||
"type":"Feature",
|
||||
"geometry":{
|
||||
"type":"Polygon",
|
||||
"coordinates":[[[-99.483911,45.577697],[-99.483869,45.571457],[-99.481739,45.571461],[-99.474881,45.571584],[-99.473167,45.571615],[-99.463394,45.57168],[-99.463391,45.57883],[-99.463368,45.586076],[-99.48177,45.585926],[-99.48384,45.585953],[-99.483885,45.57873],[-99.483911,45.577697]]]
|
||||
},
|
||||
"properties":{
|
||||
"admin_level":"8",
|
||||
"border_type":"city",
|
||||
"boundary":"administrative",
|
||||
"gnis:feature_id":"1267426",
|
||||
"name":"Hosmer",
|
||||
"place":"city",
|
||||
"source":"TIGER/Line® 2008 Place Shapefiles (http://www.census.gov/geo/www/tiger/)",
|
||||
"wikidata":"Q2442118",
|
||||
"wikipedia":"en:Hosmer, South Dakota"
|
||||
}
|
||||
}"#,
|
||||
)?;
|
||||
index_writer.add_document(doc)?;
|
||||
index_writer.commit()?;
|
||||
|
||||
let reader = index.reader()?;
|
||||
let searcher = reader.searcher();
|
||||
let field = schema.get_field("geometry").unwrap();
|
||||
let query = SpatialQuery::new(
|
||||
field,
|
||||
[
|
||||
GeoPoint {
|
||||
lon: -99.49,
|
||||
lat: 45.56,
|
||||
},
|
||||
GeoPoint {
|
||||
lon: -99.45,
|
||||
lat: 45.59,
|
||||
},
|
||||
],
|
||||
tantivy::query::SpatialQueryType::Intersects,
|
||||
);
|
||||
let hits = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
|
||||
for (_score, doc_address) in &hits {
|
||||
let retrieved_doc: TantivyDocument = searcher.doc(*doc_address)?;
|
||||
if let Some(field_value) = retrieved_doc.get_first(field) {
|
||||
if let Some(geometry_box) = field_value.as_value().into_geometry() {
|
||||
println!("Retrieved geometry: {:?}", geometry_box);
|
||||
}
|
||||
}
|
||||
}
|
||||
assert_eq!(hits.len(), 1);
|
||||
Ok(())
|
||||
}
|
||||
@@ -69,25 +69,25 @@ fn main() -> tantivy::Result<()> {
|
||||
{
|
||||
// Inclusive range queries
|
||||
let query = query_parser.parse_query("ip:[192.168.0.80 TO 192.168.0.100]")?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(5))?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(5).order_by_score())?;
|
||||
assert_eq!(count_docs.len(), 1);
|
||||
}
|
||||
{
|
||||
// Exclusive range queries
|
||||
let query = query_parser.parse_query("ip:{192.168.0.80 TO 192.168.1.100]")?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
|
||||
assert_eq!(count_docs.len(), 0);
|
||||
}
|
||||
{
|
||||
// Find docs with IP addresses smaller equal 192.168.1.100
|
||||
let query = query_parser.parse_query("ip:[* TO 192.168.1.100]")?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
|
||||
assert_eq!(count_docs.len(), 2);
|
||||
}
|
||||
{
|
||||
// Find docs with IP addresses smaller than 192.168.1.100
|
||||
let query = query_parser.parse_query("ip:[* TO 192.168.1.100}")?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
|
||||
assert_eq!(count_docs.len(), 2);
|
||||
}
|
||||
|
||||
|
||||
@@ -59,12 +59,12 @@ fn main() -> tantivy::Result<()> {
|
||||
let query_parser = QueryParser::for_index(&index, vec![event_type, attributes]);
|
||||
{
|
||||
let query = query_parser.parse_query("target:submit-button")?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
|
||||
assert_eq!(count_docs.len(), 2);
|
||||
}
|
||||
{
|
||||
let query = query_parser.parse_query("target:submit")?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
|
||||
assert_eq!(count_docs.len(), 2);
|
||||
}
|
||||
{
|
||||
@@ -74,33 +74,33 @@ fn main() -> tantivy::Result<()> {
|
||||
}
|
||||
{
|
||||
let query = query_parser.parse_query("click AND cart.product_id:133")?;
|
||||
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
|
||||
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
|
||||
assert_eq!(hits.len(), 1);
|
||||
}
|
||||
{
|
||||
// The sub-fields in the json field marked as default field still need to be explicitly
|
||||
// addressed
|
||||
let query = query_parser.parse_query("click AND 133")?;
|
||||
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
|
||||
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
|
||||
assert_eq!(hits.len(), 0);
|
||||
}
|
||||
{
|
||||
// Default json fields are ignored if they collide with the schema
|
||||
let query = query_parser.parse_query("event_type:holiday-sale")?;
|
||||
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
|
||||
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
|
||||
assert_eq!(hits.len(), 0);
|
||||
}
|
||||
// # Query via full attribute path
|
||||
{
|
||||
// This only searches in our schema's `event_type` field
|
||||
let query = query_parser.parse_query("event_type:click")?;
|
||||
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
|
||||
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
|
||||
assert_eq!(hits.len(), 2);
|
||||
}
|
||||
{
|
||||
// Default json fields can still be accessed by full path
|
||||
let query = query_parser.parse_query("attributes.event_type:holiday-sale")?;
|
||||
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
|
||||
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
|
||||
assert_eq!(hits.len(), 1);
|
||||
}
|
||||
Ok(())
|
||||
|
||||
@@ -63,7 +63,7 @@ fn main() -> Result<()> {
|
||||
// but not "in the Gulf Stream".
|
||||
let query = query_parser.parse_query("\"in the su\"*")?;
|
||||
|
||||
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
|
||||
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
|
||||
let mut titles = top_docs
|
||||
.into_iter()
|
||||
.map(|(_score, doc_address)| {
|
||||
|
||||
@@ -107,7 +107,8 @@ fn main() -> tantivy::Result<()> {
|
||||
IndexRecordOption::Basic,
|
||||
);
|
||||
|
||||
let (top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2), Count))?;
|
||||
let (top_docs, count) =
|
||||
searcher.search(&query, &(TopDocs::with_limit(2).order_by_score(), Count))?;
|
||||
|
||||
assert_eq!(count, 2);
|
||||
|
||||
@@ -128,7 +129,8 @@ fn main() -> tantivy::Result<()> {
|
||||
IndexRecordOption::Basic,
|
||||
);
|
||||
|
||||
let (_top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2), Count))?;
|
||||
let (_top_docs, count) =
|
||||
searcher.search(&query, &(TopDocs::with_limit(2).order_by_score(), Count))?;
|
||||
|
||||
assert_eq!(count, 0);
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ fn main() -> tantivy::Result<()> {
|
||||
let query_parser = QueryParser::for_index(&index, vec![title, body]);
|
||||
let query = query_parser.parse_query("sycamore spring")?;
|
||||
|
||||
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
|
||||
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
|
||||
|
||||
let snippet_generator = SnippetGenerator::create(&searcher, &*query, body)?;
|
||||
|
||||
|
||||
@@ -102,7 +102,7 @@ fn main() -> tantivy::Result<()> {
|
||||
// stop words are applied on the query as well.
|
||||
// The following will be equivalent to `title:frankenstein`
|
||||
let query = query_parser.parse_query("title:\"the Frankenstein\"")?;
|
||||
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
|
||||
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
|
||||
|
||||
for (score, doc_address) in top_docs {
|
||||
let retrieved_doc: TantivyDocument = searcher.doc(doc_address)?;
|
||||
|
||||
@@ -164,7 +164,7 @@ fn main() -> tantivy::Result<()> {
|
||||
move |doc_id: DocId| Reverse(price[doc_id as usize])
|
||||
};
|
||||
|
||||
let most_expensive_first = TopDocs::with_limit(10).custom_score(score_by_price);
|
||||
let most_expensive_first = TopDocs::with_limit(10).order_by(score_by_price);
|
||||
|
||||
let hits = searcher.search(&query, &most_expensive_first)?;
|
||||
assert_eq!(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tantivy-query-grammar"
|
||||
version = "0.24.0"
|
||||
version = "0.25.0"
|
||||
authors = ["Paul Masurel <paul.masurel@gmail.com>"]
|
||||
license = "MIT"
|
||||
categories = ["database-implementations", "data-structures"]
|
||||
@@ -15,3 +15,5 @@ edition = "2024"
|
||||
nom = "7"
|
||||
serde = { version = "1.0.219", features = ["derive"] }
|
||||
serde_json = "1.0.140"
|
||||
ordered-float = "5.0.0"
|
||||
fnv = "1.0.7"
|
||||
|
||||
@@ -117,6 +117,22 @@ where F: nom::Parser<I, (O, ErrorList), Infallible> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn terminated_infallible<I, O1, O2, F, G>(
|
||||
mut first: F,
|
||||
mut second: G,
|
||||
) -> impl FnMut(I) -> JResult<I, O1>
|
||||
where
|
||||
F: nom::Parser<I, (O1, ErrorList), Infallible>,
|
||||
G: nom::Parser<I, (O2, ErrorList), Infallible>,
|
||||
{
|
||||
move |input: I| {
|
||||
let (input, (o1, mut err)) = first.parse(input)?;
|
||||
let (input, (_, mut err2)) = second.parse(input)?;
|
||||
err.append(&mut err2);
|
||||
Ok((input, (o1, err)))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn delimited_infallible<I, O1, O2, O3, F, G, H>(
|
||||
mut first: F,
|
||||
mut second: G,
|
||||
|
||||
@@ -31,7 +31,17 @@ pub fn parse_query_lenient(query: &str) -> (UserInputAst, Vec<LenientError>) {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{parse_query, parse_query_lenient};
|
||||
use crate::{UserInputAst, parse_query, parse_query_lenient};
|
||||
|
||||
#[test]
|
||||
fn test_deduplication() {
|
||||
let ast: UserInputAst = parse_query("a a").unwrap();
|
||||
let json = serde_json::to_string(&ast).unwrap();
|
||||
assert_eq!(
|
||||
json,
|
||||
r#"{"type":"bool","clauses":[[null,{"type":"literal","field_name":null,"phrase":"a","delimiter":"none","slop":0,"prefix":false}]]}"#
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_query_serialization() {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::borrow::Cow;
|
||||
use std::iter::once;
|
||||
|
||||
use fnv::FnvHashSet;
|
||||
use nom::IResult;
|
||||
use nom::branch::alt;
|
||||
use nom::bytes::complete::tag;
|
||||
@@ -36,7 +37,7 @@ fn field_name(inp: &str) -> IResult<&str, String> {
|
||||
alt((first_char, escape_sequence())),
|
||||
many0(alt((simple_char, escape_sequence(), char('\\')))),
|
||||
)),
|
||||
char(':'),
|
||||
tuple((multispace0, char(':'), multispace0)),
|
||||
),
|
||||
|(first_char, next)| once(first_char).chain(next).collect(),
|
||||
)(inp)
|
||||
@@ -68,7 +69,7 @@ fn interpret_escape(source: &str) -> String {
|
||||
|
||||
/// Consume a word outside of any context.
|
||||
// TODO should support escape sequences
|
||||
fn word(inp: &str) -> IResult<&str, Cow<str>> {
|
||||
fn word(inp: &str) -> IResult<&str, Cow<'_, str>> {
|
||||
map_res(
|
||||
recognize(tuple((
|
||||
alt((
|
||||
@@ -305,15 +306,14 @@ fn term_group_infallible(inp: &str) -> JResult<&str, UserInputAst> {
|
||||
let (inp, (field_name, _, _, _)) =
|
||||
tuple((field_name, multispace0, char('('), multispace0))(inp).expect("precondition failed");
|
||||
|
||||
let res = delimited_infallible(
|
||||
delimited_infallible(
|
||||
nothing,
|
||||
map(ast_infallible, |(mut ast, errors)| {
|
||||
ast.set_default_field(field_name.to_string());
|
||||
(ast, errors)
|
||||
}),
|
||||
opt_i_err(char(')'), "expected ')'"),
|
||||
)(inp);
|
||||
res
|
||||
)(inp)
|
||||
}
|
||||
|
||||
fn exists(inp: &str) -> IResult<&str, UserInputLeaf> {
|
||||
@@ -367,7 +367,10 @@ fn literal(inp: &str) -> IResult<&str, UserInputAst> {
|
||||
// something (a field name) got parsed before
|
||||
alt((
|
||||
map(
|
||||
tuple((opt(field_name), alt((range, set, exists, term_or_phrase)))),
|
||||
tuple((
|
||||
opt(field_name),
|
||||
alt((range, set, exists, regex, term_or_phrase)),
|
||||
)),
|
||||
|(field_name, leaf): (Option<String>, UserInputLeaf)| leaf.set_field(field_name).into(),
|
||||
),
|
||||
term_group,
|
||||
@@ -389,6 +392,10 @@ fn literal_no_group_infallible(inp: &str) -> JResult<&str, Option<UserInputAst>>
|
||||
value((), peek(one_of("{[><"))),
|
||||
map(range_infallible, |(range, errs)| (Some(range), errs)),
|
||||
),
|
||||
(
|
||||
value((), peek(one_of("/"))),
|
||||
map(regex_infallible, |(regex, errs)| (Some(regex), errs)),
|
||||
),
|
||||
),
|
||||
delimited_infallible(space0_infallible, term_or_phrase_infallible, nothing),
|
||||
),
|
||||
@@ -689,6 +696,61 @@ fn set_infallible(mut inp: &str) -> JResult<&str, UserInputLeaf> {
|
||||
}
|
||||
}
|
||||
|
||||
fn regex(inp: &str) -> IResult<&str, UserInputLeaf> {
|
||||
map(
|
||||
terminated(
|
||||
delimited(
|
||||
char('/'),
|
||||
many1(alt((preceded(char('\\'), char('/')), none_of("/")))),
|
||||
char('/'),
|
||||
),
|
||||
peek(alt((multispace1, eof))),
|
||||
),
|
||||
|elements| UserInputLeaf::Regex {
|
||||
field: None,
|
||||
pattern: elements.into_iter().collect::<String>(),
|
||||
},
|
||||
)(inp)
|
||||
}
|
||||
|
||||
fn regex_infallible(inp: &str) -> JResult<&str, UserInputLeaf> {
|
||||
match terminated_infallible(
|
||||
delimited_infallible(
|
||||
opt_i_err(char('/'), "missing delimiter /"),
|
||||
opt_i(many1(alt((preceded(char('\\'), char('/')), none_of("/"))))),
|
||||
opt_i_err(char('/'), "missing delimiter /"),
|
||||
),
|
||||
opt_i_err(
|
||||
peek(alt((multispace1, eof))),
|
||||
"expected whitespace or end of input",
|
||||
),
|
||||
)(inp)
|
||||
{
|
||||
Ok((rest, (elements_part, errors))) => {
|
||||
let pattern = match elements_part {
|
||||
Some(elements_part) => elements_part.into_iter().collect(),
|
||||
None => String::new(),
|
||||
};
|
||||
let res = UserInputLeaf::Regex {
|
||||
field: None,
|
||||
pattern,
|
||||
};
|
||||
Ok((rest, (res, errors)))
|
||||
}
|
||||
Err(e) => {
|
||||
let errs = vec![LenientErrorInternal {
|
||||
pos: inp.len(),
|
||||
message: e.to_string(),
|
||||
}];
|
||||
let res = UserInputLeaf::Regex {
|
||||
field: None,
|
||||
pattern: String::new(),
|
||||
};
|
||||
Ok((inp, (res, errs)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn negate(expr: UserInputAst) -> UserInputAst {
|
||||
expr.unary(Occur::MustNot)
|
||||
}
|
||||
@@ -696,7 +758,17 @@ fn negate(expr: UserInputAst) -> UserInputAst {
|
||||
fn leaf(inp: &str) -> IResult<&str, UserInputAst> {
|
||||
alt((
|
||||
delimited(char('('), ast, char(')')),
|
||||
map(char('*'), |_| UserInputAst::from(UserInputLeaf::All)),
|
||||
map(
|
||||
terminated(
|
||||
char('*'),
|
||||
peek(alt((
|
||||
value((), multispace1),
|
||||
value((), char(')')),
|
||||
value((), eof),
|
||||
))),
|
||||
),
|
||||
|_| UserInputAst::from(UserInputLeaf::All),
|
||||
),
|
||||
map(preceded(tuple((tag("NOT"), multispace1)), leaf), negate),
|
||||
literal,
|
||||
))(inp)
|
||||
@@ -717,7 +789,17 @@ fn leaf_infallible(inp: &str) -> JResult<&str, Option<UserInputAst>> {
|
||||
),
|
||||
),
|
||||
(
|
||||
value((), char('*')),
|
||||
value(
|
||||
(),
|
||||
terminated(
|
||||
char('*'),
|
||||
peek(alt((
|
||||
value((), multispace1),
|
||||
value((), char(')')),
|
||||
value((), eof),
|
||||
))),
|
||||
),
|
||||
),
|
||||
map(nothing, |_| {
|
||||
(Some(UserInputAst::from(UserInputLeaf::All)), Vec::new())
|
||||
}),
|
||||
@@ -753,7 +835,7 @@ fn boosted_leaf(inp: &str) -> IResult<&str, UserInputAst> {
|
||||
tuple((leaf, fallible(boost))),
|
||||
|(leaf, boost_opt)| match boost_opt {
|
||||
Some(boost) if (boost - 1.0).abs() > f64::EPSILON => {
|
||||
UserInputAst::Boost(Box::new(leaf), boost)
|
||||
UserInputAst::Boost(Box::new(leaf), boost.into())
|
||||
}
|
||||
_ => leaf,
|
||||
},
|
||||
@@ -765,7 +847,7 @@ fn boosted_leaf_infallible(inp: &str) -> JResult<&str, Option<UserInputAst>> {
|
||||
tuple_infallible((leaf_infallible, boost)),
|
||||
|((leaf, boost_opt), error)| match boost_opt {
|
||||
Some(boost) if (boost - 1.0).abs() > f64::EPSILON => (
|
||||
leaf.map(|leaf| UserInputAst::Boost(Box::new(leaf), boost)),
|
||||
leaf.map(|leaf| UserInputAst::Boost(Box::new(leaf), boost.into())),
|
||||
error,
|
||||
),
|
||||
_ => (leaf, error),
|
||||
@@ -1016,12 +1098,25 @@ pub fn parse_to_ast_lenient(query_str: &str) -> (UserInputAst, Vec<LenientError>
|
||||
(rewrite_ast(res), errors)
|
||||
}
|
||||
|
||||
/// Removes unnecessary children clauses in AST
|
||||
///
|
||||
/// Motivated by [issue #1433](https://github.com/quickwit-oss/tantivy/issues/1433)
|
||||
fn rewrite_ast(mut input: UserInputAst) -> UserInputAst {
|
||||
if let UserInputAst::Clause(terms) = &mut input {
|
||||
for term in terms {
|
||||
if let UserInputAst::Clause(sub_clauses) = &mut input {
|
||||
// call rewrite_ast recursively on children clauses if applicable
|
||||
let mut new_clauses = Vec::with_capacity(sub_clauses.len());
|
||||
for (occur, clause) in sub_clauses.drain(..) {
|
||||
let rewritten_clause = rewrite_ast(clause);
|
||||
new_clauses.push((occur, rewritten_clause));
|
||||
}
|
||||
*sub_clauses = new_clauses;
|
||||
|
||||
// remove duplicate child clauses
|
||||
// e.g. (+a +b) OR (+c +d) OR (+a +b) => (+a +b) OR (+c +d)
|
||||
let mut seen = FnvHashSet::default();
|
||||
sub_clauses.retain(|term| seen.insert(term.clone()));
|
||||
|
||||
// Removes unnecessary children clauses in AST
|
||||
//
|
||||
// Motivated by [issue #1433](https://github.com/quickwit-oss/tantivy/issues/1433)
|
||||
for term in sub_clauses {
|
||||
rewrite_ast_clause(term);
|
||||
}
|
||||
}
|
||||
@@ -1283,6 +1378,10 @@ mod test {
|
||||
super::field_name("~my~field:a"),
|
||||
Ok(("a", "~my~field".to_string()))
|
||||
);
|
||||
assert_eq!(
|
||||
super::field_name(".my.field.name : a"),
|
||||
Ok(("a", ".my.field.name".to_string()))
|
||||
);
|
||||
for special_char in SPECIAL_CHARS.iter() {
|
||||
let query = &format!("\\{special_char}my\\{special_char}field:a");
|
||||
assert_eq!(
|
||||
@@ -1592,6 +1691,21 @@ mod test {
|
||||
test_parse_query_to_ast_helper("abc:a b", "(*\"abc\":a *b)");
|
||||
test_parse_query_to_ast_helper("abc:\"a b\"", "\"abc\":\"a b\"");
|
||||
test_parse_query_to_ast_helper("foo:[1 TO 5]", "\"foo\":[\"1\" TO \"5\"]");
|
||||
|
||||
// Phrase prefixed with *
|
||||
test_parse_query_to_ast_helper("foo:(*A)", "\"foo\":*A");
|
||||
test_parse_query_to_ast_helper("*A", "*A");
|
||||
test_parse_query_to_ast_helper("(*A)", "*A");
|
||||
test_parse_query_to_ast_helper("foo:(A OR B)", "(?\"foo\":A ?\"foo\":B)");
|
||||
test_parse_query_to_ast_helper("foo:(A* OR B*)", "(?\"foo\":A* ?\"foo\":B*)");
|
||||
test_parse_query_to_ast_helper("foo:(*A OR *B)", "(?\"foo\":*A ?\"foo\":*B)");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_query_all() {
|
||||
test_parse_query_to_ast_helper("*", "*");
|
||||
test_parse_query_to_ast_helper("(*)", "*");
|
||||
test_parse_query_to_ast_helper("(* )", "*");
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -1689,4 +1803,72 @@ mod test {
|
||||
fn test_invalid_field() {
|
||||
test_is_parse_err(r#"!bc:def"#, "!bc:def");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_regex_parser() {
|
||||
let r = parse_to_ast(r#"a:/joh?n(ath[oa]n)/"#);
|
||||
assert!(r.is_ok(), "Failed to parse custom query: {r:?}");
|
||||
let (_, input) = r.unwrap();
|
||||
match input {
|
||||
UserInputAst::Leaf(leaf) => match leaf.as_ref() {
|
||||
UserInputLeaf::Regex { field, pattern } => {
|
||||
assert_eq!(field, &Some("a".to_string()));
|
||||
assert_eq!(pattern, "joh?n(ath[oa]n)");
|
||||
}
|
||||
_ => panic!("Expected a regex leaf, got {leaf:?}"),
|
||||
},
|
||||
_ => panic!("Expected a leaf"),
|
||||
}
|
||||
let r = parse_to_ast(r#"a:/\\/cgi-bin\\/luci.*/"#);
|
||||
assert!(r.is_ok(), "Failed to parse custom query: {r:?}");
|
||||
let (_, input) = r.unwrap();
|
||||
match input {
|
||||
UserInputAst::Leaf(leaf) => match leaf.as_ref() {
|
||||
UserInputLeaf::Regex { field, pattern } => {
|
||||
assert_eq!(field, &Some("a".to_string()));
|
||||
assert_eq!(pattern, "\\/cgi-bin\\/luci.*");
|
||||
}
|
||||
_ => panic!("Expected a regex leaf, got {leaf:?}"),
|
||||
},
|
||||
_ => panic!("Expected a leaf"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_regex_parser_lenient() {
|
||||
let literal = |query| literal_infallible(query).unwrap().1;
|
||||
|
||||
let (res, errs) = literal(r#"a:/joh?n(ath[oa]n)/"#);
|
||||
let expected = UserInputLeaf::Regex {
|
||||
field: Some("a".to_string()),
|
||||
pattern: "joh?n(ath[oa]n)".to_string(),
|
||||
}
|
||||
.into();
|
||||
assert_eq!(res.unwrap(), expected);
|
||||
assert!(errs.is_empty(), "Expected no errors, got: {errs:?}");
|
||||
|
||||
let (res, errs) = literal("title:/joh?n(ath[oa]n)");
|
||||
let expected = UserInputLeaf::Regex {
|
||||
field: Some("title".to_string()),
|
||||
pattern: "joh?n(ath[oa]n)".to_string(),
|
||||
}
|
||||
.into();
|
||||
assert_eq!(res.unwrap(), expected);
|
||||
assert_eq!(errs.len(), 1, "Expected 1 error, got: {errs:?}");
|
||||
assert_eq!(
|
||||
errs[0].message, "missing delimiter /",
|
||||
"Unexpected error message",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_space_before_value() {
|
||||
test_parse_query_to_ast_helper("field : a", r#""field":a"#);
|
||||
test_parse_query_to_ast_helper("field: a", r#""field":a"#);
|
||||
test_parse_query_to_ast_helper("field :a", r#""field":a"#);
|
||||
test_parse_query_to_ast_helper(
|
||||
"field : 'happy tax payer' AND other_field : 1",
|
||||
r#"(+"field":'happy tax payer' +"other_field":1)"#,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ use serde::Serialize;
|
||||
|
||||
use crate::Occur;
|
||||
|
||||
#[derive(PartialEq, Clone, Serialize)]
|
||||
#[derive(PartialEq, Eq, Hash, Clone, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum UserInputLeaf {
|
||||
@@ -23,6 +23,10 @@ pub enum UserInputLeaf {
|
||||
Exists {
|
||||
field: String,
|
||||
},
|
||||
Regex {
|
||||
field: Option<String>,
|
||||
pattern: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl UserInputLeaf {
|
||||
@@ -46,6 +50,7 @@ impl UserInputLeaf {
|
||||
UserInputLeaf::Exists { field: _ } => UserInputLeaf::Exists {
|
||||
field: field.expect("Exist query without a field isn't allowed"),
|
||||
},
|
||||
UserInputLeaf::Regex { field: _, pattern } => UserInputLeaf::Regex { field, pattern },
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,11 +108,19 @@ impl Debug for UserInputLeaf {
|
||||
UserInputLeaf::Exists { field } => {
|
||||
write!(formatter, "$exists(\"{field}\")")
|
||||
}
|
||||
UserInputLeaf::Regex { field, pattern } => {
|
||||
if let Some(field) = field {
|
||||
// TODO properly escape field (in case of \")
|
||||
write!(formatter, "\"{field}\":")?;
|
||||
}
|
||||
// TODO properly escape pattern (in case of \")
|
||||
write!(formatter, "/{pattern}/")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Eq, PartialEq, Debug, Serialize)]
|
||||
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum Delimiter {
|
||||
SingleQuotes,
|
||||
@@ -115,7 +128,7 @@ pub enum Delimiter {
|
||||
None,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Clone, Serialize)]
|
||||
#[derive(PartialEq, Eq, Hash, Clone, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct UserInputLiteral {
|
||||
pub field_name: Option<String>,
|
||||
@@ -154,7 +167,7 @@ impl fmt::Debug for UserInputLiteral {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Debug, Clone, Serialize)]
|
||||
#[derive(PartialEq, Eq, Hash, Debug, Clone, Serialize)]
|
||||
#[serde(tag = "type", content = "value")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum UserInputBound {
|
||||
@@ -191,11 +204,11 @@ impl UserInputBound {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Clone, Serialize)]
|
||||
#[derive(PartialEq, Eq, Hash, Clone, Serialize)]
|
||||
#[serde(into = "UserInputAstSerde")]
|
||||
pub enum UserInputAst {
|
||||
Clause(Vec<(Option<Occur>, UserInputAst)>),
|
||||
Boost(Box<UserInputAst>, f64),
|
||||
Boost(Box<UserInputAst>, ordered_float::OrderedFloat<f64>),
|
||||
Leaf(Box<UserInputLeaf>),
|
||||
}
|
||||
|
||||
@@ -217,9 +230,10 @@ impl From<UserInputAst> for UserInputAstSerde {
|
||||
fn from(ast: UserInputAst) -> Self {
|
||||
match ast {
|
||||
UserInputAst::Clause(clause) => UserInputAstSerde::Bool { clauses: clause },
|
||||
UserInputAst::Boost(underlying, boost) => {
|
||||
UserInputAstSerde::Boost { underlying, boost }
|
||||
}
|
||||
UserInputAst::Boost(underlying, boost) => UserInputAstSerde::Boost {
|
||||
underlying,
|
||||
boost: boost.into_inner(),
|
||||
},
|
||||
UserInputAst::Leaf(leaf) => UserInputAstSerde::Leaf(leaf),
|
||||
}
|
||||
}
|
||||
@@ -378,7 +392,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_boost_serialization() {
|
||||
let inner_ast = UserInputAst::Leaf(Box::new(UserInputLeaf::All));
|
||||
let boost_ast = UserInputAst::Boost(Box::new(inner_ast), 2.5);
|
||||
let boost_ast = UserInputAst::Boost(Box::new(inner_ast), 2.5.into());
|
||||
let json = serde_json::to_string(&boost_ast).unwrap();
|
||||
assert_eq!(
|
||||
json,
|
||||
@@ -405,7 +419,7 @@ mod tests {
|
||||
}))),
|
||||
),
|
||||
])),
|
||||
2.5,
|
||||
2.5.into(),
|
||||
);
|
||||
let json = serde_json::to_string(&boost_ast).unwrap();
|
||||
assert_eq!(
|
||||
|
||||
@@ -20,17 +20,16 @@ Contains all metric aggregations, like average aggregation. Metric aggregations
|
||||
#### agg_req
|
||||
agg_req contains the users aggregation request. Deserialization from json is compatible with elasticsearch aggregation requests.
|
||||
|
||||
#### agg_req_with_accessor
|
||||
agg_req_with_accessor contains the users aggregation request enriched with fast field accessors etc, which are
|
||||
#### agg_data
|
||||
agg_data contains the users aggregation request enriched with fast field accessors etc, which are
|
||||
used during collection.
|
||||
|
||||
#### segment_agg_result
|
||||
segment_agg_result contains the aggregation result tree, which is used for collection of a segment.
|
||||
The tree from agg_req_with_accessor is passed during collection.
|
||||
agg_data is passed during collection.
|
||||
|
||||
#### intermediate_agg_result
|
||||
intermediate_agg_result contains the aggregation tree for merging with other trees.
|
||||
|
||||
#### agg_result
|
||||
agg_result contains the final aggregation tree.
|
||||
|
||||
|
||||
105
src/aggregation/accessor_helpers.rs
Normal file
105
src/aggregation/accessor_helpers.rs
Normal file
@@ -0,0 +1,105 @@
|
||||
//! This will enhance the request tree with access to the fastfield and metadata.
|
||||
|
||||
use std::io;
|
||||
|
||||
use columnar::{Column, ColumnType};
|
||||
|
||||
use crate::aggregation::{f64_to_fastfield_u64, Key};
|
||||
use crate::index::SegmentReader;
|
||||
|
||||
/// Get the missing value as internal u64 representation
|
||||
///
|
||||
/// For terms we use u64::MAX as sentinel value
|
||||
/// For numerical data we convert the value into the representation
|
||||
/// we would get from the fast field, when we open it as u64_lenient_for_type.
|
||||
///
|
||||
/// That way we can use it the same way as if it would come from the fastfield.
|
||||
pub(crate) fn get_missing_val_as_u64_lenient(
|
||||
column_type: ColumnType,
|
||||
column_max_value: u64,
|
||||
missing: &Key,
|
||||
field_name: &str,
|
||||
) -> crate::Result<Option<u64>> {
|
||||
let missing_val = match missing {
|
||||
Key::Str(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
|
||||
// Allow fallback to number on text fields
|
||||
Key::F64(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
|
||||
Key::U64(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
|
||||
Key::I64(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
|
||||
Key::F64(val) if column_type.numerical_type().is_some() => {
|
||||
f64_to_fastfield_u64(*val, &column_type)
|
||||
}
|
||||
// NOTE: We may loose precision of the passed missing value by casting i64 and u64 to f64.
|
||||
Key::I64(val) if column_type.numerical_type().is_some() => {
|
||||
f64_to_fastfield_u64(*val as f64, &column_type)
|
||||
}
|
||||
Key::U64(val) if column_type.numerical_type().is_some() => {
|
||||
f64_to_fastfield_u64(*val as f64, &column_type)
|
||||
}
|
||||
_ => {
|
||||
return Err(crate::TantivyError::InvalidArgument(format!(
|
||||
"Missing value {missing:?} for field {field_name} is not supported for column \
|
||||
type {column_type:?}"
|
||||
)));
|
||||
}
|
||||
};
|
||||
Ok(missing_val)
|
||||
}
|
||||
|
||||
pub(crate) fn get_numeric_or_date_column_types() -> &'static [ColumnType] {
|
||||
&[
|
||||
ColumnType::F64,
|
||||
ColumnType::U64,
|
||||
ColumnType::I64,
|
||||
ColumnType::DateTime,
|
||||
]
|
||||
}
|
||||
|
||||
/// Get fast field reader or empty as default.
|
||||
pub(crate) fn get_ff_reader(
|
||||
reader: &SegmentReader,
|
||||
field_name: &str,
|
||||
allowed_column_types: Option<&[ColumnType]>,
|
||||
) -> crate::Result<(columnar::Column<u64>, ColumnType)> {
|
||||
let ff_fields = reader.fast_fields();
|
||||
let ff_field_with_type = ff_fields
|
||||
.u64_lenient_for_type(allowed_column_types, field_name)?
|
||||
.unwrap_or_else(|| {
|
||||
(
|
||||
Column::build_empty_column(reader.num_docs()),
|
||||
ColumnType::U64,
|
||||
)
|
||||
});
|
||||
Ok(ff_field_with_type)
|
||||
}
|
||||
|
||||
pub(crate) fn get_dynamic_columns(
|
||||
reader: &SegmentReader,
|
||||
field_name: &str,
|
||||
) -> crate::Result<Vec<columnar::DynamicColumn>> {
|
||||
let ff_fields = reader.fast_fields().dynamic_column_handles(field_name)?;
|
||||
let cols = ff_fields
|
||||
.iter()
|
||||
.map(|h| h.open())
|
||||
.collect::<io::Result<_>>()?;
|
||||
assert!(!ff_fields.is_empty(), "field {field_name} not found");
|
||||
Ok(cols)
|
||||
}
|
||||
|
||||
/// Get all fast field reader or empty as default.
|
||||
///
|
||||
/// Is guaranteed to return at least one column.
|
||||
pub(crate) fn get_all_ff_reader_or_empty(
|
||||
reader: &SegmentReader,
|
||||
field_name: &str,
|
||||
allowed_column_types: Option<&[ColumnType]>,
|
||||
fallback_type: ColumnType,
|
||||
) -> crate::Result<Vec<(columnar::Column<u64>, ColumnType)>> {
|
||||
let ff_fields = reader.fast_fields();
|
||||
let mut ff_field_with_type =
|
||||
ff_fields.u64_lenient_for_type_all(allowed_column_types, field_name)?;
|
||||
if ff_field_with_type.is_empty() {
|
||||
ff_field_with_type.push((Column::build_empty_column(reader.num_docs()), fallback_type));
|
||||
}
|
||||
Ok(ff_field_with_type)
|
||||
}
|
||||
1095
src/aggregation/agg_data.rs
Normal file
1095
src/aggregation/agg_data.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -35,6 +35,7 @@ pub struct AggregationLimitsGuard {
|
||||
/// Allocated memory with this guard.
|
||||
allocated_with_the_guard: u64,
|
||||
}
|
||||
|
||||
impl Clone for AggregationLimitsGuard {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
@@ -70,7 +71,7 @@ impl AggregationLimitsGuard {
|
||||
/// *memory_limit*
|
||||
/// memory_limit is defined in bytes.
|
||||
/// Aggregation fails when the estimated memory consumption of the aggregation is higher than
|
||||
/// memory_limit.
|
||||
/// memory_limit.
|
||||
/// memory_limit will default to `DEFAULT_MEMORY_LIMIT` (500MB)
|
||||
///
|
||||
/// *bucket_limit*
|
||||
|
||||
@@ -26,12 +26,14 @@
|
||||
//! let _agg_req: Aggregations = serde_json::from_str(elasticsearch_compatible_json_req).unwrap();
|
||||
//! ```
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::collections::HashSet;
|
||||
|
||||
use rustc_hash::FxHashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::bucket::{
|
||||
DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation,
|
||||
DateHistogramAggregationReq, FilterAggregation, HistogramAggregation, RangeAggregation,
|
||||
TermsAggregation,
|
||||
};
|
||||
use super::metric::{
|
||||
AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation,
|
||||
@@ -43,7 +45,7 @@ use super::metric::{
|
||||
/// defined names. It is also used in buckets aggregations to define sub-aggregations.
|
||||
///
|
||||
/// The key is the user defined name of the aggregation.
|
||||
pub type Aggregations = HashMap<String, Aggregation>;
|
||||
pub type Aggregations = FxHashMap<String, Aggregation>;
|
||||
|
||||
/// Aggregation request.
|
||||
///
|
||||
@@ -129,6 +131,9 @@ pub enum AggregationVariants {
|
||||
/// Put data into buckets of terms.
|
||||
#[serde(rename = "terms")]
|
||||
Terms(TermsAggregation),
|
||||
/// Filter documents into a single bucket.
|
||||
#[serde(rename = "filter")]
|
||||
Filter(FilterAggregation),
|
||||
|
||||
// Metric aggregation types
|
||||
/// Computes the average of the extracted values.
|
||||
@@ -174,6 +179,7 @@ impl AggregationVariants {
|
||||
AggregationVariants::Range(range) => vec![range.field.as_str()],
|
||||
AggregationVariants::Histogram(histogram) => vec![histogram.field.as_str()],
|
||||
AggregationVariants::DateHistogram(histogram) => vec![histogram.field.as_str()],
|
||||
AggregationVariants::Filter(filter) => filter.get_fast_field_names(),
|
||||
AggregationVariants::Average(avg) => vec![avg.field_name()],
|
||||
AggregationVariants::Count(count) => vec![count.field_name()],
|
||||
AggregationVariants::Max(max) => vec![max.field_name()],
|
||||
@@ -208,13 +214,6 @@ impl AggregationVariants {
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
pub(crate) fn as_top_hits(&self) -> Option<&TopHitsAggregationReq> {
|
||||
match &self {
|
||||
AggregationVariants::TopHits(top_hits) => Some(top_hits),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn as_percentile(&self) -> Option<&PercentilesAggregationReq> {
|
||||
match &self {
|
||||
AggregationVariants::Percentiles(percentile_req) => Some(percentile_req),
|
||||
|
||||
@@ -1,471 +0,0 @@
|
||||
//! This will enhance the request tree with access to the fastfield and metadata.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
|
||||
use columnar::{Column, ColumnBlockAccessor, ColumnType, DynamicColumn, StrColumn};
|
||||
|
||||
use super::agg_req::{Aggregation, AggregationVariants, Aggregations};
|
||||
use super::bucket::{
|
||||
DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation,
|
||||
};
|
||||
use super::metric::{
|
||||
AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation,
|
||||
MaxAggregation, MinAggregation, StatsAggregation, SumAggregation,
|
||||
};
|
||||
use super::segment_agg_result::AggregationLimitsGuard;
|
||||
use super::VecWithNames;
|
||||
use crate::aggregation::{f64_to_fastfield_u64, Key};
|
||||
use crate::index::SegmentReader;
|
||||
use crate::SegmentOrdinal;
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct AggregationsWithAccessor {
|
||||
pub aggs: VecWithNames<AggregationWithAccessor>,
|
||||
}
|
||||
|
||||
impl AggregationsWithAccessor {
|
||||
fn from_data(aggs: VecWithNames<AggregationWithAccessor>) -> Self {
|
||||
Self { aggs }
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.aggs.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AggregationWithAccessor {
|
||||
pub(crate) segment_ordinal: SegmentOrdinal,
|
||||
/// In general there can be buckets without fast field access, e.g. buckets that are created
|
||||
/// based on search terms. That is not that case currently, but eventually this needs to be
|
||||
/// Option or moved.
|
||||
pub(crate) accessor: Column<u64>,
|
||||
/// Load insert u64 for missing use case
|
||||
pub(crate) missing_value_for_accessor: Option<u64>,
|
||||
pub(crate) str_dict_column: Option<StrColumn>,
|
||||
pub(crate) field_type: ColumnType,
|
||||
pub(crate) sub_aggregation: AggregationsWithAccessor,
|
||||
pub(crate) limits: AggregationLimitsGuard,
|
||||
pub(crate) column_block_accessor: ColumnBlockAccessor<u64>,
|
||||
/// Used for missing term aggregation, which checks all columns for existence.
|
||||
/// And also for `top_hits` aggregation, which may sort on multiple fields.
|
||||
/// By convention the missing aggregation is chosen, when this property is set
|
||||
/// (instead bein set in `agg`).
|
||||
/// If this needs to used by other aggregations, we need to refactor this.
|
||||
// NOTE: we can make all other aggregations use this instead of the `accessor` and `field_type`
|
||||
// (making them obsolete) But will it have a performance impact?
|
||||
pub(crate) accessors: Vec<(Column<u64>, ColumnType)>,
|
||||
/// Map field names to all associated column accessors.
|
||||
/// This field is used for `docvalue_fields`, which is currently only supported for `top_hits`.
|
||||
pub(crate) value_accessors: HashMap<String, Vec<DynamicColumn>>,
|
||||
pub(crate) agg: Aggregation,
|
||||
}
|
||||
|
||||
impl AggregationWithAccessor {
|
||||
/// May return multiple accessors if the aggregation is e.g. on mixed field types.
|
||||
fn try_from_agg(
|
||||
agg: &Aggregation,
|
||||
sub_aggregation: &Aggregations,
|
||||
reader: &SegmentReader,
|
||||
segment_ordinal: SegmentOrdinal,
|
||||
limits: AggregationLimitsGuard,
|
||||
) -> crate::Result<Vec<AggregationWithAccessor>> {
|
||||
let mut agg = agg.clone();
|
||||
|
||||
let add_agg_with_accessor = |agg: &Aggregation,
|
||||
accessor: Column<u64>,
|
||||
column_type: ColumnType,
|
||||
aggs: &mut Vec<AggregationWithAccessor>|
|
||||
-> crate::Result<()> {
|
||||
let res = AggregationWithAccessor {
|
||||
segment_ordinal,
|
||||
accessor,
|
||||
accessors: Default::default(),
|
||||
value_accessors: Default::default(),
|
||||
field_type: column_type,
|
||||
sub_aggregation: get_aggs_with_segment_accessor_and_validate(
|
||||
sub_aggregation,
|
||||
reader,
|
||||
segment_ordinal,
|
||||
&limits,
|
||||
)?,
|
||||
agg: agg.clone(),
|
||||
limits: limits.clone(),
|
||||
missing_value_for_accessor: None,
|
||||
str_dict_column: None,
|
||||
column_block_accessor: Default::default(),
|
||||
};
|
||||
aggs.push(res);
|
||||
Ok(())
|
||||
};
|
||||
|
||||
let add_agg_with_accessors = |agg: &Aggregation,
|
||||
accessors: Vec<(Column<u64>, ColumnType)>,
|
||||
aggs: &mut Vec<AggregationWithAccessor>,
|
||||
value_accessors: HashMap<String, Vec<DynamicColumn>>|
|
||||
-> crate::Result<()> {
|
||||
let (accessor, field_type) = accessors.first().expect("at least one accessor");
|
||||
let limits = limits.clone();
|
||||
let res = AggregationWithAccessor {
|
||||
segment_ordinal,
|
||||
// TODO: We should do away with the `accessor` field altogether
|
||||
accessor: accessor.clone(),
|
||||
value_accessors,
|
||||
field_type: *field_type,
|
||||
accessors,
|
||||
sub_aggregation: get_aggs_with_segment_accessor_and_validate(
|
||||
sub_aggregation,
|
||||
reader,
|
||||
segment_ordinal,
|
||||
&limits,
|
||||
)?,
|
||||
agg: agg.clone(),
|
||||
limits,
|
||||
missing_value_for_accessor: None,
|
||||
str_dict_column: None,
|
||||
column_block_accessor: Default::default(),
|
||||
};
|
||||
aggs.push(res);
|
||||
Ok(())
|
||||
};
|
||||
|
||||
let mut res: Vec<AggregationWithAccessor> = Vec::new();
|
||||
use AggregationVariants::*;
|
||||
|
||||
match agg.agg {
|
||||
Range(RangeAggregation {
|
||||
field: ref field_name,
|
||||
..
|
||||
}) => {
|
||||
let (accessor, column_type) =
|
||||
get_ff_reader(reader, field_name, Some(get_numeric_or_date_column_types()))?;
|
||||
add_agg_with_accessor(&agg, accessor, column_type, &mut res)?;
|
||||
}
|
||||
Histogram(HistogramAggregation {
|
||||
field: ref field_name,
|
||||
..
|
||||
}) => {
|
||||
let (accessor, column_type) =
|
||||
get_ff_reader(reader, field_name, Some(get_numeric_or_date_column_types()))?;
|
||||
add_agg_with_accessor(&agg, accessor, column_type, &mut res)?;
|
||||
}
|
||||
DateHistogram(DateHistogramAggregationReq {
|
||||
field: ref field_name,
|
||||
..
|
||||
}) => {
|
||||
let (accessor, column_type) =
|
||||
// Only DateTime is supported for DateHistogram
|
||||
get_ff_reader(reader, field_name, Some(&[ColumnType::DateTime]))?;
|
||||
add_agg_with_accessor(&agg, accessor, column_type, &mut res)?;
|
||||
}
|
||||
Terms(TermsAggregation {
|
||||
field: ref field_name,
|
||||
ref missing,
|
||||
..
|
||||
})
|
||||
| Cardinality(CardinalityAggregationReq {
|
||||
field: ref field_name,
|
||||
ref missing,
|
||||
..
|
||||
}) => {
|
||||
let str_dict_column = reader.fast_fields().str(field_name)?;
|
||||
let allowed_column_types = [
|
||||
ColumnType::I64,
|
||||
ColumnType::U64,
|
||||
ColumnType::F64,
|
||||
ColumnType::Str,
|
||||
ColumnType::DateTime,
|
||||
ColumnType::Bool,
|
||||
ColumnType::IpAddr,
|
||||
// ColumnType::Bytes Unsupported
|
||||
];
|
||||
|
||||
// In case the column is empty we want the shim column to match the missing type
|
||||
let fallback_type = missing
|
||||
.as_ref()
|
||||
.map(|missing| match missing {
|
||||
Key::Str(_) => ColumnType::Str,
|
||||
Key::F64(_) => ColumnType::F64,
|
||||
Key::I64(_) => ColumnType::I64,
|
||||
Key::U64(_) => ColumnType::U64,
|
||||
})
|
||||
.unwrap_or(ColumnType::U64);
|
||||
let column_and_types = get_all_ff_reader_or_empty(
|
||||
reader,
|
||||
field_name,
|
||||
Some(&allowed_column_types),
|
||||
fallback_type,
|
||||
)?;
|
||||
let missing_and_more_than_one_col = column_and_types.len() > 1 && missing.is_some();
|
||||
let text_on_non_text_col = column_and_types.len() == 1
|
||||
&& column_and_types[0].1.numerical_type().is_some()
|
||||
&& missing
|
||||
.as_ref()
|
||||
.map(|m| matches!(m, Key::Str(_)))
|
||||
.unwrap_or(false);
|
||||
|
||||
// Actually we could convert the text to a number and have the fast path, if it is
|
||||
// provided in Rfc3339 format. But this use case is probably common
|
||||
// enough to justify the effort.
|
||||
let text_on_date_col = column_and_types.len() == 1
|
||||
&& column_and_types[0].1 == ColumnType::DateTime
|
||||
&& missing
|
||||
.as_ref()
|
||||
.map(|m| matches!(m, Key::Str(_)))
|
||||
.unwrap_or(false);
|
||||
|
||||
let use_special_missing_agg =
|
||||
missing_and_more_than_one_col || text_on_non_text_col || text_on_date_col;
|
||||
if use_special_missing_agg {
|
||||
let column_and_types =
|
||||
get_all_ff_reader_or_empty(reader, field_name, None, fallback_type)?;
|
||||
|
||||
let accessors = column_and_types
|
||||
.iter()
|
||||
.map(|c_t| (c_t.0.clone(), c_t.1))
|
||||
.collect();
|
||||
add_agg_with_accessors(&agg, accessors, &mut res, Default::default())?;
|
||||
}
|
||||
|
||||
for (accessor, column_type) in column_and_types {
|
||||
let missing_value_term_agg = if use_special_missing_agg {
|
||||
None
|
||||
} else {
|
||||
missing.clone()
|
||||
};
|
||||
|
||||
let missing_value_for_accessor =
|
||||
if let Some(missing) = missing_value_term_agg.as_ref() {
|
||||
get_missing_val_as_u64_lenient(
|
||||
column_type,
|
||||
missing,
|
||||
agg.agg.get_fast_field_names()[0],
|
||||
)?
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let limits = limits.clone();
|
||||
let agg = AggregationWithAccessor {
|
||||
segment_ordinal,
|
||||
missing_value_for_accessor,
|
||||
accessor,
|
||||
accessors: Default::default(),
|
||||
value_accessors: Default::default(),
|
||||
field_type: column_type,
|
||||
sub_aggregation: get_aggs_with_segment_accessor_and_validate(
|
||||
sub_aggregation,
|
||||
reader,
|
||||
segment_ordinal,
|
||||
&limits,
|
||||
)?,
|
||||
agg: agg.clone(),
|
||||
str_dict_column: str_dict_column.clone(),
|
||||
limits,
|
||||
column_block_accessor: Default::default(),
|
||||
};
|
||||
res.push(agg);
|
||||
}
|
||||
}
|
||||
Average(AverageAggregation {
|
||||
field: ref field_name,
|
||||
..
|
||||
})
|
||||
| Max(MaxAggregation {
|
||||
field: ref field_name,
|
||||
..
|
||||
})
|
||||
| Min(MinAggregation {
|
||||
field: ref field_name,
|
||||
..
|
||||
})
|
||||
| Stats(StatsAggregation {
|
||||
field: ref field_name,
|
||||
..
|
||||
})
|
||||
| ExtendedStats(ExtendedStatsAggregation {
|
||||
field: ref field_name,
|
||||
..
|
||||
})
|
||||
| Sum(SumAggregation {
|
||||
field: ref field_name,
|
||||
..
|
||||
}) => {
|
||||
let (accessor, column_type) =
|
||||
get_ff_reader(reader, field_name, Some(get_numeric_or_date_column_types()))?;
|
||||
add_agg_with_accessor(&agg, accessor, column_type, &mut res)?;
|
||||
}
|
||||
Count(CountAggregation {
|
||||
field: ref field_name,
|
||||
..
|
||||
}) => {
|
||||
let allowed_column_types = [
|
||||
ColumnType::I64,
|
||||
ColumnType::U64,
|
||||
ColumnType::F64,
|
||||
ColumnType::Str,
|
||||
ColumnType::DateTime,
|
||||
ColumnType::Bool,
|
||||
ColumnType::IpAddr,
|
||||
// ColumnType::Bytes Unsupported
|
||||
];
|
||||
let (accessor, column_type) =
|
||||
get_ff_reader(reader, field_name, Some(&allowed_column_types))?;
|
||||
add_agg_with_accessor(&agg, accessor, column_type, &mut res)?;
|
||||
}
|
||||
Percentiles(ref percentiles) => {
|
||||
let (accessor, column_type) = get_ff_reader(
|
||||
reader,
|
||||
percentiles.field_name(),
|
||||
Some(get_numeric_or_date_column_types()),
|
||||
)?;
|
||||
add_agg_with_accessor(&agg, accessor, column_type, &mut res)?;
|
||||
}
|
||||
TopHits(ref mut top_hits) => {
|
||||
top_hits.validate_and_resolve_field_names(reader.fast_fields().columnar())?;
|
||||
let accessors: Vec<(Column<u64>, ColumnType)> = top_hits
|
||||
.field_names()
|
||||
.iter()
|
||||
.map(|field| {
|
||||
get_ff_reader(reader, field, Some(get_numeric_or_date_column_types()))
|
||||
})
|
||||
.collect::<crate::Result<_>>()?;
|
||||
|
||||
let value_accessors = top_hits
|
||||
.value_field_names()
|
||||
.iter()
|
||||
.map(|field_name| {
|
||||
Ok((
|
||||
field_name.to_string(),
|
||||
get_dynamic_columns(reader, field_name)?,
|
||||
))
|
||||
})
|
||||
.collect::<crate::Result<_>>()?;
|
||||
|
||||
add_agg_with_accessors(&agg, accessors, &mut res, value_accessors)?;
|
||||
}
|
||||
};
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the missing value as internal u64 representation
|
||||
///
|
||||
/// For terms we use u64::MAX as sentinel value
|
||||
/// For numerical data we convert the value into the representation
|
||||
/// we would get from the fast field, when we open it as u64_lenient_for_type.
|
||||
///
|
||||
/// That way we can use it the same way as if it would come from the fastfield.
|
||||
fn get_missing_val_as_u64_lenient(
|
||||
column_type: ColumnType,
|
||||
missing: &Key,
|
||||
field_name: &str,
|
||||
) -> crate::Result<Option<u64>> {
|
||||
let missing_val = match missing {
|
||||
Key::Str(_) if column_type == ColumnType::Str => Some(u64::MAX),
|
||||
// Allow fallback to number on text fields
|
||||
Key::F64(_) if column_type == ColumnType::Str => Some(u64::MAX),
|
||||
Key::U64(_) if column_type == ColumnType::Str => Some(u64::MAX),
|
||||
Key::I64(_) if column_type == ColumnType::Str => Some(u64::MAX),
|
||||
Key::F64(val) if column_type.numerical_type().is_some() => {
|
||||
f64_to_fastfield_u64(*val, &column_type)
|
||||
}
|
||||
// NOTE: We may loose precision of the passed missing value by casting i64 and u64 to f64.
|
||||
Key::I64(val) if column_type.numerical_type().is_some() => {
|
||||
f64_to_fastfield_u64(*val as f64, &column_type)
|
||||
}
|
||||
Key::U64(val) if column_type.numerical_type().is_some() => {
|
||||
f64_to_fastfield_u64(*val as f64, &column_type)
|
||||
}
|
||||
_ => {
|
||||
return Err(crate::TantivyError::InvalidArgument(format!(
|
||||
"Missing value {missing:?} for field {field_name} is not supported for column \
|
||||
type {column_type:?}"
|
||||
)));
|
||||
}
|
||||
};
|
||||
Ok(missing_val)
|
||||
}
|
||||
|
||||
fn get_numeric_or_date_column_types() -> &'static [ColumnType] {
|
||||
&[
|
||||
ColumnType::F64,
|
||||
ColumnType::U64,
|
||||
ColumnType::I64,
|
||||
ColumnType::DateTime,
|
||||
]
|
||||
}
|
||||
|
||||
pub(crate) fn get_aggs_with_segment_accessor_and_validate(
|
||||
aggs: &Aggregations,
|
||||
reader: &SegmentReader,
|
||||
segment_ordinal: SegmentOrdinal,
|
||||
limits: &AggregationLimitsGuard,
|
||||
) -> crate::Result<AggregationsWithAccessor> {
|
||||
let mut aggss = Vec::new();
|
||||
for (key, agg) in aggs.iter() {
|
||||
let aggs = AggregationWithAccessor::try_from_agg(
|
||||
agg,
|
||||
agg.sub_aggregation(),
|
||||
reader,
|
||||
segment_ordinal,
|
||||
limits.clone(),
|
||||
)?;
|
||||
for agg in aggs {
|
||||
aggss.push((key.to_string(), agg));
|
||||
}
|
||||
}
|
||||
Ok(AggregationsWithAccessor::from_data(
|
||||
VecWithNames::from_entries(aggss),
|
||||
))
|
||||
}
|
||||
|
||||
/// Get fast field reader or empty as default.
|
||||
fn get_ff_reader(
|
||||
reader: &SegmentReader,
|
||||
field_name: &str,
|
||||
allowed_column_types: Option<&[ColumnType]>,
|
||||
) -> crate::Result<(columnar::Column<u64>, ColumnType)> {
|
||||
let ff_fields = reader.fast_fields();
|
||||
let ff_field_with_type = ff_fields
|
||||
.u64_lenient_for_type(allowed_column_types, field_name)?
|
||||
.unwrap_or_else(|| {
|
||||
(
|
||||
Column::build_empty_column(reader.num_docs()),
|
||||
ColumnType::U64,
|
||||
)
|
||||
});
|
||||
Ok(ff_field_with_type)
|
||||
}
|
||||
|
||||
fn get_dynamic_columns(
|
||||
reader: &SegmentReader,
|
||||
field_name: &str,
|
||||
) -> crate::Result<Vec<columnar::DynamicColumn>> {
|
||||
let ff_fields = reader.fast_fields().dynamic_column_handles(field_name)?;
|
||||
let cols = ff_fields
|
||||
.iter()
|
||||
.map(|h| h.open())
|
||||
.collect::<io::Result<_>>()?;
|
||||
assert!(!ff_fields.is_empty(), "field {field_name} not found");
|
||||
Ok(cols)
|
||||
}
|
||||
|
||||
/// Get all fast field reader or empty as default.
|
||||
///
|
||||
/// Is guaranteed to return at least one column.
|
||||
fn get_all_ff_reader_or_empty(
|
||||
reader: &SegmentReader,
|
||||
field_name: &str,
|
||||
allowed_column_types: Option<&[ColumnType]>,
|
||||
fallback_type: ColumnType,
|
||||
) -> crate::Result<Vec<(columnar::Column<u64>, ColumnType)>> {
|
||||
let ff_fields = reader.fast_fields();
|
||||
let mut ff_field_with_type =
|
||||
ff_fields.u64_lenient_for_type_all(allowed_column_types, field_name)?;
|
||||
if ff_field_with_type.is_empty() {
|
||||
ff_field_with_type.push((Column::build_empty_column(reader.num_docs()), fallback_type));
|
||||
}
|
||||
Ok(ff_field_with_type)
|
||||
}
|
||||
@@ -16,7 +16,7 @@ use super::{AggregationError, Key};
|
||||
use crate::TantivyError;
|
||||
|
||||
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
|
||||
/// The final aggegation result.
|
||||
/// The final aggregation result.
|
||||
pub struct AggregationResults(pub FxHashMap<String, AggregationResult>);
|
||||
|
||||
impl AggregationResults {
|
||||
@@ -156,6 +156,8 @@ pub enum BucketResult {
|
||||
/// The upper bound error for the doc count of each term.
|
||||
doc_count_error_upper_bound: Option<u64>,
|
||||
},
|
||||
/// This is the filter result - a single bucket with sub-aggregations
|
||||
Filter(FilterBucketResult),
|
||||
}
|
||||
|
||||
impl BucketResult {
|
||||
@@ -172,6 +174,11 @@ impl BucketResult {
|
||||
sum_other_doc_count: _,
|
||||
doc_count_error_upper_bound: _,
|
||||
} => buckets.iter().map(|bucket| bucket.get_bucket_count()).sum(),
|
||||
BucketResult::Filter(filter_result) => {
|
||||
// Filter doesn't add to bucket count - it's not a user-facing bucket
|
||||
// Only count sub-aggregation buckets
|
||||
filter_result.sub_aggregations.get_bucket_count()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -308,3 +315,25 @@ impl RangeBucketEntry {
|
||||
1 + self.sub_aggregation.get_bucket_count()
|
||||
}
|
||||
}
|
||||
|
||||
/// This is the filter bucket result, which contains the document count and sub-aggregations.
|
||||
///
|
||||
/// # JSON Format
|
||||
/// ```json
|
||||
/// {
|
||||
/// "electronics_only": {
|
||||
/// "doc_count": 2,
|
||||
/// "avg_price": {
|
||||
/// "value": 150.0
|
||||
/// }
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct FilterBucketResult {
|
||||
/// Number of documents in the filter bucket
|
||||
pub doc_count: u64,
|
||||
/// Sub-aggregation results
|
||||
#[serde(flatten)]
|
||||
pub sub_aggregations: AggregationResults,
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ use crate::aggregation::agg_result::AggregationResults;
|
||||
use crate::aggregation::buf_collector::DOC_BLOCK_SIZE;
|
||||
use crate::aggregation::collector::AggregationCollector;
|
||||
use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults;
|
||||
use crate::aggregation::segment_agg_result::AggregationLimitsGuard;
|
||||
use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms};
|
||||
use crate::aggregation::DistributedAggregationCollector;
|
||||
use crate::query::{AllQuery, TermQuery};
|
||||
@@ -128,10 +127,8 @@ fn test_aggregation_flushing(
|
||||
.unwrap();
|
||||
|
||||
let agg_res: AggregationResults = if use_distributed_collector {
|
||||
let collector = DistributedAggregationCollector::from_aggs(
|
||||
agg_req.clone(),
|
||||
AggregationLimitsGuard::default(),
|
||||
);
|
||||
let collector =
|
||||
DistributedAggregationCollector::from_aggs(agg_req.clone(), Default::default());
|
||||
|
||||
let searcher = reader.searcher();
|
||||
let intermediate_agg_result = searcher.search(&AllQuery, &collector).unwrap();
|
||||
|
||||
1754
src/aggregation/bucket/filter.rs
Normal file
1754
src/aggregation/bucket/filter.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,25 +1,54 @@
|
||||
use std::cmp::Ordering;
|
||||
|
||||
use columnar::{Column, ColumnBlockAccessor, ColumnType};
|
||||
use rustc_hash::FxHashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tantivy_bitpacker::minmax;
|
||||
|
||||
use crate::aggregation::agg_data::{
|
||||
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
|
||||
};
|
||||
use crate::aggregation::agg_limits::MemoryConsumption;
|
||||
use crate::aggregation::agg_req::Aggregations;
|
||||
use crate::aggregation::agg_req_with_accessor::{
|
||||
AggregationWithAccessor, AggregationsWithAccessor,
|
||||
};
|
||||
use crate::aggregation::agg_result::BucketEntry;
|
||||
use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
|
||||
IntermediateHistogramBucketEntry,
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::{
|
||||
build_segment_agg_collector, SegmentAggregationCollector,
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::*;
|
||||
use crate::TantivyError;
|
||||
|
||||
/// Contains all information required by the SegmentHistogramCollector to perform the
|
||||
/// histogram or date_histogram aggregation on a segment.
|
||||
pub struct HistogramAggReqData {
|
||||
/// The column accessor to access the fast field values.
|
||||
pub accessor: Column<u64>,
|
||||
/// The field type of the fast field.
|
||||
pub field_type: ColumnType,
|
||||
/// The column block accessor to access the fast field values.
|
||||
pub column_block_accessor: ColumnBlockAccessor<u64>,
|
||||
/// The name of the aggregation.
|
||||
pub name: String,
|
||||
/// The sub aggregation blueprint, used to create sub aggregations for each bucket.
|
||||
/// Will be filled during initialization of the collector.
|
||||
pub sub_aggregation_blueprint: Option<Box<dyn SegmentAggregationCollector>>,
|
||||
/// The histogram aggregation request.
|
||||
pub req: HistogramAggregation,
|
||||
/// True if this is a date_histogram aggregation.
|
||||
pub is_date_histogram: bool,
|
||||
/// The bounds to limit the buckets to.
|
||||
pub bounds: HistogramBounds,
|
||||
/// The offset used to calculate the bucket position.
|
||||
pub offset: f64,
|
||||
}
|
||||
impl HistogramAggReqData {
|
||||
/// Estimate the memory consumption of this struct in bytes.
|
||||
pub fn get_memory_consumption(&self) -> usize {
|
||||
std::mem::size_of::<Self>()
|
||||
}
|
||||
}
|
||||
|
||||
/// Histogram is a bucket aggregation, where buckets are created dynamically for given `interval`.
|
||||
/// Each document value is rounded down to its bucket.
|
||||
///
|
||||
@@ -234,12 +263,12 @@ impl SegmentHistogramBucketEntry {
|
||||
pub(crate) fn into_intermediate_bucket_entry(
|
||||
self,
|
||||
sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
|
||||
agg_with_accessor: &AggregationsWithAccessor,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<IntermediateHistogramBucketEntry> {
|
||||
let mut sub_aggregation_res = IntermediateAggregationResults::default();
|
||||
if let Some(sub_aggregation) = sub_aggregation {
|
||||
sub_aggregation
|
||||
.add_intermediate_aggregation_result(agg_with_accessor, &mut sub_aggregation_res)?;
|
||||
.add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?;
|
||||
}
|
||||
Ok(IntermediateHistogramBucketEntry {
|
||||
key: self.key,
|
||||
@@ -256,24 +285,20 @@ pub struct SegmentHistogramCollector {
|
||||
/// The buckets containing the aggregation data.
|
||||
buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
|
||||
sub_aggregations: FxHashMap<i64, Box<dyn SegmentAggregationCollector>>,
|
||||
sub_aggregation_blueprint: Option<Box<dyn SegmentAggregationCollector>>,
|
||||
column_type: ColumnType,
|
||||
interval: f64,
|
||||
offset: f64,
|
||||
bounds: HistogramBounds,
|
||||
accessor_idx: usize,
|
||||
}
|
||||
|
||||
impl SegmentAggregationCollector for SegmentHistogramCollector {
|
||||
fn add_intermediate_aggregation_result(
|
||||
self: Box<Self>,
|
||||
agg_with_accessor: &AggregationsWithAccessor,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
) -> crate::Result<()> {
|
||||
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
|
||||
let agg_with_accessor = &agg_with_accessor.aggs.values[self.accessor_idx];
|
||||
|
||||
let bucket = self.into_intermediate_bucket_result(agg_with_accessor)?;
|
||||
let name = agg_data
|
||||
.get_histogram_req_data(self.accessor_idx)
|
||||
.name
|
||||
.clone();
|
||||
let bucket = self.into_intermediate_bucket_result(agg_data)?;
|
||||
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
|
||||
|
||||
Ok(())
|
||||
@@ -283,56 +308,52 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
|
||||
fn collect(
|
||||
&mut self,
|
||||
doc: crate::DocId,
|
||||
agg_with_accessor: &mut AggregationsWithAccessor,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
self.collect_block(&[doc], agg_with_accessor)
|
||||
self.collect_block(&[doc], agg_data)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_with_accessor: &mut AggregationsWithAccessor,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx];
|
||||
|
||||
let mut req = agg_data.take_histogram_req_data(self.accessor_idx);
|
||||
let mem_pre = self.get_memory_consumption();
|
||||
|
||||
let bounds = self.bounds;
|
||||
let interval = self.interval;
|
||||
let offset = self.offset;
|
||||
let get_bucket_pos = |val| (get_bucket_pos_f64(val, interval, offset) as i64);
|
||||
let bounds = req.bounds;
|
||||
let interval = req.req.interval;
|
||||
let offset = req.offset;
|
||||
let get_bucket_pos = |val| get_bucket_pos_f64(val, interval, offset) as i64;
|
||||
|
||||
bucket_agg_accessor
|
||||
req.column_block_accessor.fetch_block(docs, &req.accessor);
|
||||
for (doc, val) in req
|
||||
.column_block_accessor
|
||||
.fetch_block(docs, &bucket_agg_accessor.accessor);
|
||||
|
||||
for (doc, val) in bucket_agg_accessor
|
||||
.column_block_accessor
|
||||
.iter_docid_vals(docs, &bucket_agg_accessor.accessor)
|
||||
.iter_docid_vals(docs, &req.accessor)
|
||||
{
|
||||
let val = self.f64_from_fastfield_u64(val);
|
||||
|
||||
let val = f64_from_fastfield_u64(val, &req.field_type);
|
||||
let bucket_pos = get_bucket_pos(val);
|
||||
|
||||
if bounds.contains(val) {
|
||||
let bucket = self.buckets.entry(bucket_pos).or_insert_with(|| {
|
||||
let key = get_bucket_key_from_pos(bucket_pos as f64, interval, offset);
|
||||
SegmentHistogramBucketEntry { key, doc_count: 0 }
|
||||
});
|
||||
bucket.doc_count += 1;
|
||||
if let Some(sub_aggregation_blueprint) = self.sub_aggregation_blueprint.as_mut() {
|
||||
if let Some(sub_aggregation_blueprint) = req.sub_aggregation_blueprint.as_ref() {
|
||||
self.sub_aggregations
|
||||
.entry(bucket_pos)
|
||||
.or_insert_with(|| sub_aggregation_blueprint.clone())
|
||||
.collect(doc, &mut bucket_agg_accessor.sub_aggregation)?;
|
||||
.collect(doc, agg_data)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
agg_data.put_back_histogram_req_data(self.accessor_idx, req);
|
||||
|
||||
let mem_delta = self.get_memory_consumption() - mem_pre;
|
||||
if mem_delta > 0 {
|
||||
bucket_agg_accessor
|
||||
agg_data
|
||||
.context
|
||||
.limits
|
||||
.add_memory_consumed(mem_delta as u64)?;
|
||||
}
|
||||
@@ -340,12 +361,9 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
|
||||
let sub_aggregation_accessor =
|
||||
&mut agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation;
|
||||
|
||||
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
|
||||
for sub_aggregation in self.sub_aggregations.values_mut() {
|
||||
sub_aggregation.flush(sub_aggregation_accessor)?;
|
||||
sub_aggregation.flush(agg_data)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -362,65 +380,58 @@ impl SegmentHistogramCollector {
|
||||
/// Converts the collector result into a intermediate bucket result.
|
||||
pub fn into_intermediate_bucket_result(
|
||||
self,
|
||||
agg_with_accessor: &AggregationWithAccessor,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<IntermediateBucketResult> {
|
||||
let mut buckets = Vec::with_capacity(self.buckets.len());
|
||||
|
||||
for (bucket_pos, bucket) in self.buckets {
|
||||
let bucket_res = bucket.into_intermediate_bucket_entry(
|
||||
self.sub_aggregations.get(&bucket_pos).cloned(),
|
||||
&agg_with_accessor.sub_aggregation,
|
||||
agg_data,
|
||||
);
|
||||
|
||||
buckets.push(bucket_res?);
|
||||
}
|
||||
buckets.sort_unstable_by(|b1, b2| b1.key.total_cmp(&b2.key));
|
||||
|
||||
let is_date_agg = agg_data
|
||||
.get_histogram_req_data(self.accessor_idx)
|
||||
.field_type
|
||||
== ColumnType::DateTime;
|
||||
Ok(IntermediateBucketResult::Histogram {
|
||||
buckets,
|
||||
is_date_agg: self.column_type == ColumnType::DateTime,
|
||||
is_date_agg,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn from_req_and_validate(
|
||||
mut req: HistogramAggregation,
|
||||
sub_aggregation: &mut AggregationsWithAccessor,
|
||||
field_type: ColumnType,
|
||||
accessor_idx: usize,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
node: &AggRefNode,
|
||||
) -> crate::Result<Self> {
|
||||
req.validate()?;
|
||||
if field_type == ColumnType::DateTime {
|
||||
req.normalize_date_time();
|
||||
}
|
||||
|
||||
let sub_aggregation_blueprint = if sub_aggregation.is_empty() {
|
||||
None
|
||||
let blueprint = if !node.children.is_empty() {
|
||||
Some(build_segment_agg_collectors(agg_data, &node.children)?)
|
||||
} else {
|
||||
let sub_aggregation = build_segment_agg_collector(sub_aggregation)?;
|
||||
Some(sub_aggregation)
|
||||
None
|
||||
};
|
||||
|
||||
let bounds = req.hard_bounds.unwrap_or(HistogramBounds {
|
||||
let req_data = agg_data.get_histogram_req_data_mut(node.idx_in_req_data);
|
||||
req_data.req.validate()?;
|
||||
if req_data.field_type == ColumnType::DateTime && !req_data.is_date_histogram {
|
||||
req_data.req.normalize_date_time();
|
||||
}
|
||||
req_data.bounds = req_data.req.hard_bounds.unwrap_or(HistogramBounds {
|
||||
min: f64::MIN,
|
||||
max: f64::MAX,
|
||||
});
|
||||
req_data.offset = req_data.req.offset.unwrap_or(0.0);
|
||||
|
||||
req_data.sub_aggregation_blueprint = blueprint;
|
||||
|
||||
Ok(Self {
|
||||
buckets: Default::default(),
|
||||
column_type: field_type,
|
||||
interval: req.interval,
|
||||
offset: req.offset.unwrap_or(0.0),
|
||||
bounds,
|
||||
sub_aggregations: Default::default(),
|
||||
sub_aggregation_blueprint,
|
||||
accessor_idx,
|
||||
accessor_idx: node.idx_in_req_data,
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn f64_from_fastfield_u64(&self, val: u64) -> f64 {
|
||||
f64_from_fastfield_u64(val, &self.column_type)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
//! - [Range](RangeAggregation)
|
||||
//! - [Terms](TermsAggregation)
|
||||
|
||||
mod filter;
|
||||
mod histogram;
|
||||
mod range;
|
||||
mod term_agg;
|
||||
@@ -30,6 +31,7 @@ mod term_missing_agg;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
|
||||
pub use filter::*;
|
||||
pub use histogram::*;
|
||||
pub use range::*;
|
||||
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
|
||||
|
||||
@@ -1,20 +1,43 @@
|
||||
use std::fmt::Debug;
|
||||
use std::ops::Range;
|
||||
|
||||
use columnar::{Column, ColumnBlockAccessor, ColumnType};
|
||||
use rustc_hash::FxHashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor;
|
||||
use crate::aggregation::agg_data::{
|
||||
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
|
||||
};
|
||||
use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
|
||||
IntermediateRangeBucketEntry, IntermediateRangeBucketResult,
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::{
|
||||
build_segment_agg_collector, SegmentAggregationCollector,
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::*;
|
||||
use crate::TantivyError;
|
||||
|
||||
/// Contains all information required by the SegmentRangeCollector to perform the
|
||||
/// range aggregation on a segment.
|
||||
pub struct RangeAggReqData {
|
||||
/// The column accessor to access the fast field values.
|
||||
pub accessor: Column<u64>,
|
||||
/// The type of the fast field.
|
||||
pub field_type: ColumnType,
|
||||
/// The column block accessor to access the fast field values.
|
||||
pub column_block_accessor: ColumnBlockAccessor<u64>,
|
||||
/// The range aggregation request.
|
||||
pub req: RangeAggregation,
|
||||
/// The name of the aggregation.
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
impl RangeAggReqData {
|
||||
/// Estimate the memory consumption of this struct in bytes.
|
||||
pub fn get_memory_consumption(&self) -> usize {
|
||||
std::mem::size_of::<Self>()
|
||||
}
|
||||
}
|
||||
|
||||
/// Provide user-defined buckets to aggregate on.
|
||||
///
|
||||
/// Two special buckets will automatically be created to cover the whole range of values.
|
||||
@@ -161,12 +184,12 @@ impl Debug for SegmentRangeBucketEntry {
|
||||
impl SegmentRangeBucketEntry {
|
||||
pub(crate) fn into_intermediate_bucket_entry(
|
||||
self,
|
||||
agg_with_accessor: &AggregationsWithAccessor,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<IntermediateRangeBucketEntry> {
|
||||
let mut sub_aggregation_res = IntermediateAggregationResults::default();
|
||||
if let Some(sub_aggregation) = self.sub_aggregation {
|
||||
sub_aggregation
|
||||
.add_intermediate_aggregation_result(agg_with_accessor, &mut sub_aggregation_res)?
|
||||
.add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?
|
||||
} else {
|
||||
Default::default()
|
||||
};
|
||||
@@ -184,12 +207,14 @@ impl SegmentRangeBucketEntry {
|
||||
impl SegmentAggregationCollector for SegmentRangeCollector {
|
||||
fn add_intermediate_aggregation_result(
|
||||
self: Box<Self>,
|
||||
agg_with_accessor: &AggregationsWithAccessor,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
) -> crate::Result<()> {
|
||||
let field_type = self.column_type;
|
||||
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
|
||||
let sub_agg = &agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation;
|
||||
let name = agg_data
|
||||
.get_range_req_data(self.accessor_idx)
|
||||
.name
|
||||
.to_string();
|
||||
|
||||
let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = self
|
||||
.buckets
|
||||
@@ -199,7 +224,7 @@ impl SegmentAggregationCollector for SegmentRangeCollector {
|
||||
range_to_string(&range_bucket.range, &field_type)?,
|
||||
range_bucket
|
||||
.bucket
|
||||
.into_intermediate_bucket_entry(sub_agg)?,
|
||||
.into_intermediate_bucket_entry(agg_data)?,
|
||||
))
|
||||
})
|
||||
.collect::<crate::Result<_>>()?;
|
||||
@@ -218,66 +243,70 @@ impl SegmentAggregationCollector for SegmentRangeCollector {
|
||||
fn collect(
|
||||
&mut self,
|
||||
doc: crate::DocId,
|
||||
agg_with_accessor: &mut AggregationsWithAccessor,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
self.collect_block(&[doc], agg_with_accessor)
|
||||
self.collect_block(&[doc], agg_data)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_with_accessor: &mut AggregationsWithAccessor,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx];
|
||||
// Take request data to avoid borrow conflicts during sub-aggregation
|
||||
let mut req = agg_data.take_range_req_data(self.accessor_idx);
|
||||
|
||||
bucket_agg_accessor
|
||||
.column_block_accessor
|
||||
.fetch_block(docs, &bucket_agg_accessor.accessor);
|
||||
req.column_block_accessor.fetch_block(docs, &req.accessor);
|
||||
|
||||
for (doc, val) in bucket_agg_accessor
|
||||
for (doc, val) in req
|
||||
.column_block_accessor
|
||||
.iter_docid_vals(docs, &bucket_agg_accessor.accessor)
|
||||
.iter_docid_vals(docs, &req.accessor)
|
||||
{
|
||||
let bucket_pos = self.get_bucket_pos(val);
|
||||
|
||||
let bucket = &mut self.buckets[bucket_pos];
|
||||
|
||||
bucket.bucket.doc_count += 1;
|
||||
if let Some(sub_aggregation) = &mut bucket.bucket.sub_aggregation {
|
||||
sub_aggregation.collect(doc, &mut bucket_agg_accessor.sub_aggregation)?;
|
||||
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() {
|
||||
sub_agg.collect(doc, agg_data)?;
|
||||
}
|
||||
}
|
||||
|
||||
agg_data.put_back_range_req_data(self.accessor_idx, req);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
|
||||
let sub_aggregation_accessor =
|
||||
&mut agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation;
|
||||
|
||||
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
|
||||
for bucket in self.buckets.iter_mut() {
|
||||
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() {
|
||||
sub_agg.flush(sub_aggregation_accessor)?;
|
||||
sub_agg.flush(agg_data)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl SegmentRangeCollector {
|
||||
pub(crate) fn from_req_and_validate(
|
||||
req: &RangeAggregation,
|
||||
sub_aggregation: &mut AggregationsWithAccessor,
|
||||
limits: &mut AggregationLimitsGuard,
|
||||
field_type: ColumnType,
|
||||
accessor_idx: usize,
|
||||
req_data: &mut AggregationsSegmentCtx,
|
||||
node: &AggRefNode,
|
||||
) -> crate::Result<Self> {
|
||||
let accessor_idx = node.idx_in_req_data;
|
||||
let (field_type, ranges) = {
|
||||
let req_view = req_data.get_range_req_data(node.idx_in_req_data);
|
||||
(req_view.field_type, req_view.req.ranges.clone())
|
||||
};
|
||||
|
||||
// The range input on the request is f64.
|
||||
// We need to convert to u64 ranges, because we read the values as u64.
|
||||
// The mapping from the conversion is monotonic so ordering is preserved.
|
||||
let buckets: Vec<_> = extend_validate_ranges(&req.ranges, &field_type)?
|
||||
let sub_agg_prototype = if !node.children.is_empty() {
|
||||
Some(build_segment_agg_collectors(req_data, &node.children)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let buckets: Vec<_> = extend_validate_ranges(&ranges, &field_type)?
|
||||
.iter()
|
||||
.map(|range| {
|
||||
let key = range
|
||||
@@ -295,11 +324,7 @@ impl SegmentRangeCollector {
|
||||
} else {
|
||||
Some(f64_from_fastfield_u64(range.range.start, &field_type))
|
||||
};
|
||||
let sub_aggregation = if sub_aggregation.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(build_segment_agg_collector(sub_aggregation)?)
|
||||
};
|
||||
let sub_aggregation = sub_agg_prototype.clone();
|
||||
|
||||
Ok(SegmentRangeAndBucketEntry {
|
||||
range: range.range.clone(),
|
||||
@@ -314,7 +339,7 @@ impl SegmentRangeCollector {
|
||||
})
|
||||
.collect::<crate::Result<_>>()?;
|
||||
|
||||
limits.add_memory_consumed(
|
||||
req_data.context.limits.add_memory_consumed(
|
||||
buckets.len() as u64 * std::mem::size_of::<SegmentRangeAndBucketEntry>() as u64,
|
||||
)?;
|
||||
|
||||
@@ -467,15 +492,45 @@ mod tests {
|
||||
ranges,
|
||||
..Default::default()
|
||||
};
|
||||
// Build buckets directly as in from_req_and_validate without AggregationsData
|
||||
let buckets: Vec<_> = extend_validate_ranges(&req.ranges, &field_type)
|
||||
.expect("unexpected error in extend_validate_ranges")
|
||||
.iter()
|
||||
.map(|range| {
|
||||
let key = range
|
||||
.key
|
||||
.clone()
|
||||
.map(|key| Ok(Key::Str(key)))
|
||||
.unwrap_or_else(|| range_to_key(&range.range, &field_type))
|
||||
.expect("unexpected error in range_to_key");
|
||||
let to = if range.range.end == u64::MAX {
|
||||
None
|
||||
} else {
|
||||
Some(f64_from_fastfield_u64(range.range.end, &field_type))
|
||||
};
|
||||
let from = if range.range.start == u64::MIN {
|
||||
None
|
||||
} else {
|
||||
Some(f64_from_fastfield_u64(range.range.start, &field_type))
|
||||
};
|
||||
SegmentRangeAndBucketEntry {
|
||||
range: range.range.clone(),
|
||||
bucket: SegmentRangeBucketEntry {
|
||||
doc_count: 0,
|
||||
sub_aggregation: None,
|
||||
key,
|
||||
from,
|
||||
to,
|
||||
},
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
SegmentRangeCollector::from_req_and_validate(
|
||||
&req,
|
||||
&mut Default::default(),
|
||||
&mut AggregationLimitsGuard::default(),
|
||||
field_type,
|
||||
0,
|
||||
)
|
||||
.expect("unexpected error")
|
||||
SegmentRangeCollector {
|
||||
buckets,
|
||||
column_type: field_type,
|
||||
accessor_idx: 0,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,13 +1,39 @@
|
||||
use columnar::{Column, ColumnType};
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor;
|
||||
use crate::aggregation::agg_data::{
|
||||
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
|
||||
};
|
||||
use crate::aggregation::bucket::term_agg::TermsAggregation;
|
||||
use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
|
||||
IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult,
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::{
|
||||
build_segment_agg_collector, SegmentAggregationCollector,
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
|
||||
/// Special aggregation to handle missing values for term aggregations.
|
||||
/// This missing aggregation will check multiple columns for existence.
|
||||
///
|
||||
/// This is needed when:
|
||||
/// - The field is multi-valued and we therefore have multiple columns
|
||||
/// - The field is not text and missing is provided as string (we cannot use the numeric missing
|
||||
/// value optimization)
|
||||
#[derive(Default)]
|
||||
pub struct MissingTermAggReqData {
|
||||
/// The accessors to check for existence of a value.
|
||||
pub accessors: Vec<(Column<u64>, ColumnType)>,
|
||||
/// The name of the aggregation.
|
||||
pub name: String,
|
||||
/// The original terms aggregation request.
|
||||
pub req: TermsAggregation,
|
||||
}
|
||||
|
||||
impl MissingTermAggReqData {
|
||||
/// Estimate the memory consumption of this struct in bytes.
|
||||
pub fn get_memory_consumption(&self) -> usize {
|
||||
std::mem::size_of::<Self>()
|
||||
}
|
||||
}
|
||||
|
||||
/// The specialized missing term aggregation.
|
||||
#[derive(Default, Debug, Clone)]
|
||||
@@ -18,12 +44,13 @@ pub struct TermMissingAgg {
|
||||
}
|
||||
impl TermMissingAgg {
|
||||
pub(crate) fn new(
|
||||
accessor_idx: usize,
|
||||
sub_aggregations: &mut AggregationsWithAccessor,
|
||||
req_data: &mut AggregationsSegmentCtx,
|
||||
node: &AggRefNode,
|
||||
) -> crate::Result<Self> {
|
||||
let has_sub_aggregations = !sub_aggregations.is_empty();
|
||||
let has_sub_aggregations = !node.children.is_empty();
|
||||
let accessor_idx = node.idx_in_req_data;
|
||||
let sub_agg = if has_sub_aggregations {
|
||||
let sub_aggregation = build_segment_agg_collector(sub_aggregations)?;
|
||||
let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?;
|
||||
Some(sub_aggregation)
|
||||
} else {
|
||||
None
|
||||
@@ -40,16 +67,11 @@ impl TermMissingAgg {
|
||||
impl SegmentAggregationCollector for TermMissingAgg {
|
||||
fn add_intermediate_aggregation_result(
|
||||
self: Box<Self>,
|
||||
agg_with_accessor: &AggregationsWithAccessor,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
) -> crate::Result<()> {
|
||||
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
|
||||
let agg_with_accessor = &agg_with_accessor.aggs.values[self.accessor_idx];
|
||||
let term_agg = agg_with_accessor
|
||||
.agg
|
||||
.agg
|
||||
.as_term()
|
||||
.expect("TermMissingAgg collector must be term agg req");
|
||||
let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
|
||||
let term_agg = &req_data.req;
|
||||
let missing = term_agg
|
||||
.missing
|
||||
.as_ref()
|
||||
@@ -64,10 +86,7 @@ impl SegmentAggregationCollector for TermMissingAgg {
|
||||
};
|
||||
if let Some(sub_agg) = self.sub_agg {
|
||||
let mut res = IntermediateAggregationResults::default();
|
||||
sub_agg.add_intermediate_aggregation_result(
|
||||
&agg_with_accessor.sub_aggregation,
|
||||
&mut res,
|
||||
)?;
|
||||
sub_agg.add_intermediate_aggregation_result(agg_data, &mut res)?;
|
||||
missing_entry.sub_aggregation = res;
|
||||
}
|
||||
entries.insert(missing.into(), missing_entry);
|
||||
@@ -80,7 +99,10 @@ impl SegmentAggregationCollector for TermMissingAgg {
|
||||
},
|
||||
};
|
||||
|
||||
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
|
||||
results.push(
|
||||
req_data.name.to_string(),
|
||||
IntermediateAggregationResult::Bucket(bucket),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -88,17 +110,17 @@ impl SegmentAggregationCollector for TermMissingAgg {
|
||||
fn collect(
|
||||
&mut self,
|
||||
doc: crate::DocId,
|
||||
agg_with_accessor: &mut AggregationsWithAccessor,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let agg = &mut agg_with_accessor.aggs.values[self.accessor_idx];
|
||||
let has_value = agg
|
||||
let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
|
||||
let has_value = req_data
|
||||
.accessors
|
||||
.iter()
|
||||
.any(|(acc, _)| acc.index.has_value(doc));
|
||||
if !has_value {
|
||||
self.missing_count += 1;
|
||||
if let Some(sub_agg) = self.sub_agg.as_mut() {
|
||||
sub_agg.collect(doc, &mut agg.sub_aggregation)?;
|
||||
sub_agg.collect(doc, agg_data)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
@@ -107,10 +129,10 @@ impl SegmentAggregationCollector for TermMissingAgg {
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_with_accessor: &mut AggregationsWithAccessor,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
for doc in docs {
|
||||
self.collect(*doc, agg_with_accessor)?;
|
||||
self.collect(*doc, agg_data)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
use super::agg_req_with_accessor::AggregationsWithAccessor;
|
||||
use super::intermediate_agg_result::IntermediateAggregationResults;
|
||||
use super::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::agg_data::AggregationsSegmentCtx;
|
||||
use crate::DocId;
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) const DOC_BLOCK_SIZE: usize = 64;
|
||||
|
||||
#[cfg(not(test))]
|
||||
pub(crate) const DOC_BLOCK_SIZE: usize = 256;
|
||||
|
||||
pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE];
|
||||
|
||||
/// BufAggregationCollector buffers documents before calling collect_block().
|
||||
@@ -15,7 +20,7 @@ pub(crate) struct BufAggregationCollector {
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for BufAggregationCollector {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
f.debug_struct("SegmentAggregationResultsCollector")
|
||||
.field("staged_docs", &&self.staged_docs[..self.num_staged_docs])
|
||||
.field("num_staged_docs", &self.num_staged_docs)
|
||||
@@ -37,23 +42,23 @@ impl SegmentAggregationCollector for BufAggregationCollector {
|
||||
#[inline]
|
||||
fn add_intermediate_aggregation_result(
|
||||
self: Box<Self>,
|
||||
agg_with_accessor: &AggregationsWithAccessor,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
) -> crate::Result<()> {
|
||||
Box::new(self.collector).add_intermediate_aggregation_result(agg_with_accessor, results)
|
||||
Box::new(self.collector).add_intermediate_aggregation_result(agg_data, results)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn collect(
|
||||
&mut self,
|
||||
doc: crate::DocId,
|
||||
agg_with_accessor: &mut AggregationsWithAccessor,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
self.staged_docs[self.num_staged_docs] = doc;
|
||||
self.num_staged_docs += 1;
|
||||
if self.num_staged_docs == self.staged_docs.len() {
|
||||
self.collector
|
||||
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_with_accessor)?;
|
||||
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?;
|
||||
self.num_staged_docs = 0;
|
||||
}
|
||||
Ok(())
|
||||
@@ -63,20 +68,19 @@ impl SegmentAggregationCollector for BufAggregationCollector {
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_with_accessor: &mut AggregationsWithAccessor,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
self.collector.collect_block(docs, agg_with_accessor)?;
|
||||
|
||||
self.collector.collect_block(docs, agg_data)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
|
||||
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
|
||||
self.collector
|
||||
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_with_accessor)?;
|
||||
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?;
|
||||
self.num_staged_docs = 0;
|
||||
|
||||
self.collector.flush(agg_with_accessor)?;
|
||||
self.collector.flush(agg_data)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
use super::agg_req::Aggregations;
|
||||
use super::agg_req_with_accessor::AggregationsWithAccessor;
|
||||
use super::agg_result::AggregationResults;
|
||||
use super::buf_collector::BufAggregationCollector;
|
||||
use super::intermediate_agg_result::IntermediateAggregationResults;
|
||||
use super::segment_agg_result::{
|
||||
build_segment_agg_collector, AggregationLimitsGuard, SegmentAggregationCollector,
|
||||
use super::segment_agg_result::SegmentAggregationCollector;
|
||||
use super::AggContextParams;
|
||||
use crate::aggregation::agg_data::{
|
||||
build_aggregations_data_from_req, build_segment_agg_collectors_root, AggregationsSegmentCtx,
|
||||
};
|
||||
use crate::aggregation::agg_req_with_accessor::get_aggs_with_segment_accessor_and_validate;
|
||||
use crate::collector::{Collector, SegmentCollector};
|
||||
use crate::index::SegmentReader;
|
||||
use crate::{DocId, SegmentOrdinal, TantivyError};
|
||||
@@ -22,7 +22,7 @@ pub const DEFAULT_MEMORY_LIMIT: u64 = 500_000_000;
|
||||
/// The collector collects all aggregations by the underlying aggregation request.
|
||||
pub struct AggregationCollector {
|
||||
agg: Aggregations,
|
||||
limits: AggregationLimitsGuard,
|
||||
context: AggContextParams,
|
||||
}
|
||||
|
||||
impl AggregationCollector {
|
||||
@@ -30,8 +30,8 @@ impl AggregationCollector {
|
||||
///
|
||||
/// Aggregation fails when the limits in `AggregationLimits` is exceeded. (memory limit and
|
||||
/// bucket limit)
|
||||
pub fn from_aggs(agg: Aggregations, limits: AggregationLimitsGuard) -> Self {
|
||||
Self { agg, limits }
|
||||
pub fn from_aggs(agg: Aggregations, context: AggContextParams) -> Self {
|
||||
Self { agg, context }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ impl AggregationCollector {
|
||||
/// into the final `AggregationResults` via the `into_final_result()` method.
|
||||
pub struct DistributedAggregationCollector {
|
||||
agg: Aggregations,
|
||||
limits: AggregationLimitsGuard,
|
||||
context: AggContextParams,
|
||||
}
|
||||
|
||||
impl DistributedAggregationCollector {
|
||||
@@ -53,8 +53,8 @@ impl DistributedAggregationCollector {
|
||||
///
|
||||
/// Aggregation fails when the limits in `AggregationLimits` is exceeded. (memory limit and
|
||||
/// bucket limit)
|
||||
pub fn from_aggs(agg: Aggregations, limits: AggregationLimitsGuard) -> Self {
|
||||
Self { agg, limits }
|
||||
pub fn from_aggs(agg: Aggregations, context: AggContextParams) -> Self {
|
||||
Self { agg, context }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,7 +72,7 @@ impl Collector for DistributedAggregationCollector {
|
||||
&self.agg,
|
||||
reader,
|
||||
segment_local_id,
|
||||
&self.limits,
|
||||
&self.context,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -102,7 +102,7 @@ impl Collector for AggregationCollector {
|
||||
&self.agg,
|
||||
reader,
|
||||
segment_local_id,
|
||||
&self.limits,
|
||||
&self.context,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -115,7 +115,7 @@ impl Collector for AggregationCollector {
|
||||
segment_fruits: Vec<<Self::Child as SegmentCollector>::Fruit>,
|
||||
) -> crate::Result<Self::Fruit> {
|
||||
let res = merge_fruits(segment_fruits)?;
|
||||
res.into_final_result(self.agg.clone(), self.limits.clone())
|
||||
res.into_final_result(self.agg.clone(), self.context.limits.clone())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,7 +135,7 @@ fn merge_fruits(
|
||||
|
||||
/// `AggregationSegmentCollector` does the aggregation collection on a segment.
|
||||
pub struct AggregationSegmentCollector {
|
||||
aggs_with_accessor: AggregationsWithAccessor,
|
||||
aggs_with_accessor: AggregationsSegmentCtx,
|
||||
agg_collector: BufAggregationCollector,
|
||||
error: Option<TantivyError>,
|
||||
}
|
||||
@@ -147,14 +147,15 @@ impl AggregationSegmentCollector {
|
||||
agg: &Aggregations,
|
||||
reader: &SegmentReader,
|
||||
segment_ordinal: SegmentOrdinal,
|
||||
limits: &AggregationLimitsGuard,
|
||||
context: &AggContextParams,
|
||||
) -> crate::Result<Self> {
|
||||
let mut aggs_with_accessor =
|
||||
get_aggs_with_segment_accessor_and_validate(agg, reader, segment_ordinal, limits)?;
|
||||
let mut agg_data =
|
||||
build_aggregations_data_from_req(agg, reader, segment_ordinal, context.clone())?;
|
||||
let result =
|
||||
BufAggregationCollector::new(build_segment_agg_collector(&mut aggs_with_accessor)?);
|
||||
BufAggregationCollector::new(build_segment_agg_collectors_root(&mut agg_data)?);
|
||||
|
||||
Ok(AggregationSegmentCollector {
|
||||
aggs_with_accessor,
|
||||
aggs_with_accessor: agg_data,
|
||||
agg_collector: result,
|
||||
error: None,
|
||||
})
|
||||
|
||||
@@ -24,7 +24,9 @@ use super::metric::{
|
||||
};
|
||||
use super::segment_agg_result::AggregationLimitsGuard;
|
||||
use super::{format_date, AggregationError, Key, SerializedKey};
|
||||
use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry};
|
||||
use crate::aggregation::agg_result::{
|
||||
AggregationResults, BucketEntries, BucketEntry, FilterBucketResult,
|
||||
};
|
||||
use crate::aggregation::bucket::TermsAggregationInternal;
|
||||
use crate::aggregation::metric::CardinalityCollector;
|
||||
use crate::TantivyError;
|
||||
@@ -179,12 +181,17 @@ impl IntermediateAggregationResults {
|
||||
}
|
||||
|
||||
/// Merge another intermediate aggregation result into this result.
|
||||
///
|
||||
/// The order of the values need to be the same on both results. This is ensured when the same
|
||||
/// (key values) are present on the underlying `VecWithNames` struct.
|
||||
pub fn merge_fruits(&mut self, other: IntermediateAggregationResults) -> crate::Result<()> {
|
||||
for (left, right) in self.aggs_res.values_mut().zip(other.aggs_res.into_values()) {
|
||||
left.merge_fruits(right)?;
|
||||
pub fn merge_fruits(&mut self, mut other: IntermediateAggregationResults) -> crate::Result<()> {
|
||||
for (key, left) in self.aggs_res.iter_mut() {
|
||||
if let Some(key) = other.aggs_res.remove(key) {
|
||||
left.merge_fruits(key)?;
|
||||
}
|
||||
}
|
||||
// Move remainder of other aggs_res into self.
|
||||
// Note: Currently we don't expect this to happen, as we create empty intermediate results
|
||||
// via [IntermediateAggregationResults::empty_from_req].
|
||||
for (key, value) in other.aggs_res {
|
||||
self.aggs_res.insert(key, value);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -241,11 +248,16 @@ pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult
|
||||
Cardinality(_) => IntermediateAggregationResult::Metric(
|
||||
IntermediateMetricResult::Cardinality(CardinalityCollector::default()),
|
||||
),
|
||||
Filter(_) => IntermediateAggregationResult::Bucket(IntermediateBucketResult::Filter {
|
||||
doc_count: 0,
|
||||
sub_aggregations: IntermediateAggregationResults::default(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// An aggregation is either a bucket or a metric.
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
pub enum IntermediateAggregationResult {
|
||||
/// Bucket variant
|
||||
Bucket(IntermediateBucketResult),
|
||||
@@ -426,6 +438,13 @@ pub enum IntermediateBucketResult {
|
||||
/// The term buckets
|
||||
buckets: IntermediateTermBucketResult,
|
||||
},
|
||||
/// Filter aggregation - a single bucket with sub-aggregations
|
||||
Filter {
|
||||
/// Document count in the filter bucket
|
||||
doc_count: u64,
|
||||
/// Sub-aggregation results
|
||||
sub_aggregations: IntermediateAggregationResults,
|
||||
},
|
||||
}
|
||||
|
||||
impl IntermediateBucketResult {
|
||||
@@ -509,6 +528,18 @@ impl IntermediateBucketResult {
|
||||
req.sub_aggregation(),
|
||||
limits,
|
||||
),
|
||||
IntermediateBucketResult::Filter {
|
||||
doc_count,
|
||||
sub_aggregations,
|
||||
} => {
|
||||
// Convert sub-aggregation results to final format
|
||||
let final_sub_aggregations = sub_aggregations
|
||||
.into_final_result(req.sub_aggregation().clone(), limits.clone())?;
|
||||
Ok(BucketResult::Filter(FilterBucketResult {
|
||||
doc_count,
|
||||
sub_aggregations: final_sub_aggregations,
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -562,6 +593,19 @@ impl IntermediateBucketResult {
|
||||
|
||||
*buckets_left = buckets?;
|
||||
}
|
||||
(
|
||||
IntermediateBucketResult::Filter {
|
||||
doc_count: doc_count_left,
|
||||
sub_aggregations: sub_aggs_left,
|
||||
},
|
||||
IntermediateBucketResult::Filter {
|
||||
doc_count: doc_count_right,
|
||||
sub_aggregations: sub_aggs_right,
|
||||
},
|
||||
) => {
|
||||
*doc_count_left += doc_count_right;
|
||||
sub_aggs_left.merge_fruits(sub_aggs_right)?;
|
||||
}
|
||||
(IntermediateBucketResult::Range(_), _) => {
|
||||
panic!("try merge on different types")
|
||||
}
|
||||
@@ -571,6 +615,9 @@ impl IntermediateBucketResult {
|
||||
(IntermediateBucketResult::Terms { .. }, _) => {
|
||||
panic!("try merge on different types")
|
||||
}
|
||||
(IntermediateBucketResult::Filter { .. }, _) => {
|
||||
panic!("try merge on different types")
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -2,15 +2,13 @@ use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{BuildHasher, Hasher};
|
||||
|
||||
use columnar::column_values::CompactSpaceU64Accessor;
|
||||
use columnar::Dictionary;
|
||||
use columnar::{Column, ColumnBlockAccessor, ColumnType, Dictionary, StrColumn};
|
||||
use common::f64_to_u64;
|
||||
use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};
|
||||
use rustc_hash::FxHashSet;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::aggregation::agg_req_with_accessor::{
|
||||
AggregationWithAccessor, AggregationsWithAccessor,
|
||||
};
|
||||
use crate::aggregation::agg_data::AggregationsSegmentCtx;
|
||||
use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
|
||||
};
|
||||
@@ -97,6 +95,32 @@ pub struct CardinalityAggregationReq {
|
||||
pub missing: Option<Key>,
|
||||
}
|
||||
|
||||
/// Contains all information required by the SegmentCardinalityCollector to perform the
|
||||
/// cardinality aggregation on a segment.
|
||||
pub struct CardinalityAggReqData {
|
||||
/// The column accessor to access the fast field values.
|
||||
pub accessor: Column<u64>,
|
||||
/// The column_type of the field.
|
||||
pub column_type: ColumnType,
|
||||
/// The string dictionary column if the field is of type string.
|
||||
pub str_dict_column: Option<StrColumn>,
|
||||
/// The missing value normalized to the internal u64 representation of the field type.
|
||||
pub missing_value_for_accessor: Option<u64>,
|
||||
/// The column block accessor to access the fast field values.
|
||||
pub(crate) column_block_accessor: ColumnBlockAccessor<u64>,
|
||||
/// The name of the aggregation.
|
||||
pub name: String,
|
||||
/// The aggregation request.
|
||||
pub req: CardinalityAggregationReq,
|
||||
}
|
||||
|
||||
impl CardinalityAggReqData {
|
||||
/// Estimate the memory consumption of this struct in bytes.
|
||||
pub fn get_memory_consumption(&self) -> usize {
|
||||
std::mem::size_of::<Self>()
|
||||
}
|
||||
}
|
||||
|
||||
impl CardinalityAggregationReq {
|
||||
/// Creates a new [`CardinalityAggregationReq`] instance from a field name.
|
||||
pub fn from_field_name(field_name: String) -> Self {
|
||||
@@ -115,47 +139,44 @@ impl CardinalityAggregationReq {
|
||||
pub(crate) struct SegmentCardinalityCollector {
|
||||
cardinality: CardinalityCollector,
|
||||
entries: FxHashSet<u64>,
|
||||
column_type: ColumnType,
|
||||
accessor_idx: usize,
|
||||
missing: Option<Key>,
|
||||
}
|
||||
|
||||
impl SegmentCardinalityCollector {
|
||||
pub fn from_req(column_type: ColumnType, accessor_idx: usize, missing: &Option<Key>) -> Self {
|
||||
pub fn from_req(column_type: ColumnType, accessor_idx: usize) -> Self {
|
||||
Self {
|
||||
cardinality: CardinalityCollector::new(column_type as u8),
|
||||
entries: Default::default(),
|
||||
column_type,
|
||||
accessor_idx,
|
||||
missing: missing.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn fetch_block_with_field(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_accessor: &mut AggregationWithAccessor,
|
||||
agg_data: &mut CardinalityAggReqData,
|
||||
) {
|
||||
if let Some(missing) = agg_accessor.missing_value_for_accessor {
|
||||
agg_accessor.column_block_accessor.fetch_block_with_missing(
|
||||
if let Some(missing) = agg_data.missing_value_for_accessor {
|
||||
agg_data.column_block_accessor.fetch_block_with_missing(
|
||||
docs,
|
||||
&agg_accessor.accessor,
|
||||
&agg_data.accessor,
|
||||
missing,
|
||||
);
|
||||
} else {
|
||||
agg_accessor
|
||||
agg_data
|
||||
.column_block_accessor
|
||||
.fetch_block(docs, &agg_accessor.accessor);
|
||||
.fetch_block(docs, &agg_data.accessor);
|
||||
}
|
||||
}
|
||||
|
||||
fn into_intermediate_metric_result(
|
||||
mut self,
|
||||
agg_with_accessor: &AggregationWithAccessor,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<IntermediateMetricResult> {
|
||||
if self.column_type == ColumnType::Str {
|
||||
let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
|
||||
if req_data.column_type == ColumnType::Str {
|
||||
let fallback_dict = Dictionary::empty();
|
||||
let dict = agg_with_accessor
|
||||
let dict = req_data
|
||||
.str_dict_column
|
||||
.as_ref()
|
||||
.map(|el| el.dictionary())
|
||||
@@ -180,10 +201,10 @@ impl SegmentCardinalityCollector {
|
||||
})?;
|
||||
if has_missing {
|
||||
// Replace missing with the actual value provided
|
||||
let missing_key = self
|
||||
.missing
|
||||
.as_ref()
|
||||
.expect("Found sentinel value u64::MAX for term_ord but `missing` is not set");
|
||||
let missing_key =
|
||||
req_data.req.missing.as_ref().expect(
|
||||
"Found sentinel value u64::MAX for term_ord but `missing` is not set",
|
||||
);
|
||||
match missing_key {
|
||||
Key::Str(missing) => {
|
||||
self.cardinality.sketch.insert_any(&missing);
|
||||
@@ -209,13 +230,13 @@ impl SegmentCardinalityCollector {
|
||||
impl SegmentAggregationCollector for SegmentCardinalityCollector {
|
||||
fn add_intermediate_aggregation_result(
|
||||
self: Box<Self>,
|
||||
agg_with_accessor: &AggregationsWithAccessor,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
) -> crate::Result<()> {
|
||||
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
|
||||
let agg_with_accessor = &agg_with_accessor.aggs.values[self.accessor_idx];
|
||||
let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
|
||||
let name = req_data.name.to_string();
|
||||
|
||||
let intermediate_result = self.into_intermediate_metric_result(agg_with_accessor)?;
|
||||
let intermediate_result = self.into_intermediate_metric_result(agg_data)?;
|
||||
results.push(
|
||||
name,
|
||||
IntermediateAggregationResult::Metric(intermediate_result),
|
||||
@@ -227,26 +248,26 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
|
||||
fn collect(
|
||||
&mut self,
|
||||
doc: crate::DocId,
|
||||
agg_with_accessor: &mut AggregationsWithAccessor,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
self.collect_block(&[doc], agg_with_accessor)
|
||||
self.collect_block(&[doc], agg_data)
|
||||
}
|
||||
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_with_accessor: &mut AggregationsWithAccessor,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx];
|
||||
self.fetch_block_with_field(docs, bucket_agg_accessor);
|
||||
let req_data = agg_data.get_cardinality_req_data_mut(self.accessor_idx);
|
||||
self.fetch_block_with_field(docs, req_data);
|
||||
|
||||
let col_block_accessor = &bucket_agg_accessor.column_block_accessor;
|
||||
if self.column_type == ColumnType::Str {
|
||||
let col_block_accessor = &req_data.column_block_accessor;
|
||||
if req_data.column_type == ColumnType::Str {
|
||||
for term_ord in col_block_accessor.iter_vals() {
|
||||
self.entries.insert(term_ord);
|
||||
}
|
||||
} else if self.column_type == ColumnType::IpAddr {
|
||||
let compact_space_accessor = bucket_agg_accessor
|
||||
} else if req_data.column_type == ColumnType::IpAddr {
|
||||
let compact_space_accessor = req_data
|
||||
.accessor
|
||||
.values
|
||||
.clone()
|
||||
|
||||
@@ -4,12 +4,11 @@ use std::mem;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::*;
|
||||
use crate::aggregation::agg_req_with_accessor::{
|
||||
AggregationWithAccessor, AggregationsWithAccessor,
|
||||
};
|
||||
use crate::aggregation::agg_data::AggregationsSegmentCtx;
|
||||
use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
|
||||
};
|
||||
use crate::aggregation::metric::MetricAggReqData;
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::*;
|
||||
use crate::{DocId, TantivyError};
|
||||
@@ -63,7 +62,7 @@ impl ExtendedStatsAggregation {
|
||||
|
||||
/// Extended stats contains a collection of statistics
|
||||
/// they extends stats adding variance, standard deviation
|
||||
/// and bound informations
|
||||
/// and bound information
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ExtendedStats {
|
||||
/// The number of documents.
|
||||
@@ -348,20 +347,20 @@ impl SegmentExtendedStatsCollector {
|
||||
pub(crate) fn collect_block_with_field(
|
||||
&mut self,
|
||||
docs: &[DocId],
|
||||
agg_accessor: &mut AggregationWithAccessor,
|
||||
req_data: &mut MetricAggReqData,
|
||||
) {
|
||||
if let Some(missing) = self.missing.as_ref() {
|
||||
agg_accessor.column_block_accessor.fetch_block_with_missing(
|
||||
req_data.column_block_accessor.fetch_block_with_missing(
|
||||
docs,
|
||||
&agg_accessor.accessor,
|
||||
&req_data.accessor,
|
||||
*missing,
|
||||
);
|
||||
} else {
|
||||
agg_accessor
|
||||
req_data
|
||||
.column_block_accessor
|
||||
.fetch_block(docs, &agg_accessor.accessor);
|
||||
.fetch_block(docs, &req_data.accessor);
|
||||
}
|
||||
for val in agg_accessor.column_block_accessor.iter_vals() {
|
||||
for val in req_data.column_block_accessor.iter_vals() {
|
||||
let val1 = f64_from_fastfield_u64(val, &self.field_type);
|
||||
self.extended_stats.collect(val1);
|
||||
}
|
||||
@@ -372,10 +371,10 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
|
||||
#[inline]
|
||||
fn add_intermediate_aggregation_result(
|
||||
self: Box<Self>,
|
||||
agg_with_accessor: &AggregationsWithAccessor,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
) -> crate::Result<()> {
|
||||
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
|
||||
let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone();
|
||||
results.push(
|
||||
name,
|
||||
IntermediateAggregationResult::Metric(IntermediateMetricResult::ExtendedStats(
|
||||
@@ -390,12 +389,12 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
|
||||
fn collect(
|
||||
&mut self,
|
||||
doc: crate::DocId,
|
||||
agg_with_accessor: &mut AggregationsWithAccessor,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let field = &agg_with_accessor.aggs.values[self.accessor_idx].accessor;
|
||||
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
|
||||
if let Some(missing) = self.missing {
|
||||
let mut has_val = false;
|
||||
for val in field.values_for_doc(doc) {
|
||||
for val in req_data.accessor.values_for_doc(doc) {
|
||||
let val1 = f64_from_fastfield_u64(val, &self.field_type);
|
||||
self.extended_stats.collect(val1);
|
||||
has_val = true;
|
||||
@@ -405,7 +404,7 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
|
||||
.collect(f64_from_fastfield_u64(missing, &self.field_type));
|
||||
}
|
||||
} else {
|
||||
for val in field.values_for_doc(doc) {
|
||||
for val in req_data.accessor.values_for_doc(doc) {
|
||||
let val1 = f64_from_fastfield_u64(val, &self.field_type);
|
||||
self.extended_stats.collect(val1);
|
||||
}
|
||||
@@ -418,10 +417,10 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_with_accessor: &mut AggregationsWithAccessor,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let field = &mut agg_with_accessor.aggs.values[self.accessor_idx];
|
||||
self.collect_block_with_field(docs, field);
|
||||
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
|
||||
self.collect_block_with_field(docs, req_data);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@ use std::collections::HashMap;
|
||||
|
||||
pub use average::*;
|
||||
pub use cardinality::*;
|
||||
use columnar::{Column, ColumnBlockAccessor, ColumnType};
|
||||
pub use count::*;
|
||||
pub use extended_stats::*;
|
||||
pub use max::*;
|
||||
@@ -44,6 +45,35 @@ pub use top_hits::*;
|
||||
|
||||
use crate::schema::OwnedValue;
|
||||
|
||||
/// Contains all information required by metric aggregations like avg, min, max, sum, stats,
|
||||
/// extended_stats, count, percentiles.
|
||||
#[repr(C)]
|
||||
pub struct MetricAggReqData {
|
||||
/// True if the field is of number or date type.
|
||||
pub is_number_or_date_type: bool,
|
||||
/// The type of the field.
|
||||
pub field_type: ColumnType,
|
||||
/// The missing value normalized to the internal u64 representation of the field type.
|
||||
pub missing_u64: Option<u64>,
|
||||
/// The column block accessor to access the fast field values.
|
||||
pub column_block_accessor: ColumnBlockAccessor<u64>,
|
||||
/// The column accessor to access the fast field values.
|
||||
pub accessor: Column<u64>,
|
||||
/// Used when converting to intermediate result
|
||||
pub collecting_for: StatsType,
|
||||
/// The missing value
|
||||
pub missing: Option<f64>,
|
||||
/// The name of the aggregation.
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
impl MetricAggReqData {
|
||||
/// Estimate the memory consumption of this struct in bytes.
|
||||
pub fn get_memory_consumption(&self) -> usize {
|
||||
std::mem::size_of::<Self>()
|
||||
}
|
||||
}
|
||||
|
||||
/// Single-metric aggregations use this common result structure.
|
||||
///
|
||||
/// Main reason to wrap it in value is to match elasticsearch output structure.
|
||||
|
||||
@@ -3,12 +3,11 @@ use std::fmt::Debug;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::*;
|
||||
use crate::aggregation::agg_req_with_accessor::{
|
||||
AggregationWithAccessor, AggregationsWithAccessor,
|
||||
};
|
||||
use crate::aggregation::agg_data::AggregationsSegmentCtx;
|
||||
use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
|
||||
};
|
||||
use crate::aggregation::metric::MetricAggReqData;
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::*;
|
||||
use crate::{DocId, TantivyError};
|
||||
@@ -112,7 +111,8 @@ impl PercentilesAggregationReq {
|
||||
&self.field
|
||||
}
|
||||
|
||||
fn validate(&self) -> crate::Result<()> {
|
||||
/// Validates the request parameters.
|
||||
pub fn validate(&self) -> crate::Result<()> {
|
||||
if let Some(percents) = self.percents.as_ref() {
|
||||
let all_in_range = percents
|
||||
.iter()
|
||||
@@ -133,10 +133,8 @@ impl PercentilesAggregationReq {
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub(crate) struct SegmentPercentilesCollector {
|
||||
field_type: ColumnType,
|
||||
pub(crate) percentiles: PercentilesCollector,
|
||||
pub(crate) accessor_idx: usize,
|
||||
missing: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
@@ -231,43 +229,32 @@ impl PercentilesCollector {
|
||||
}
|
||||
|
||||
impl SegmentPercentilesCollector {
|
||||
pub fn from_req_and_validate(
|
||||
req: &PercentilesAggregationReq,
|
||||
field_type: ColumnType,
|
||||
accessor_idx: usize,
|
||||
) -> crate::Result<Self> {
|
||||
req.validate()?;
|
||||
let missing = req
|
||||
.missing
|
||||
.and_then(|val| f64_to_fastfield_u64(val, &field_type));
|
||||
|
||||
pub fn from_req_and_validate(accessor_idx: usize) -> crate::Result<Self> {
|
||||
Ok(Self {
|
||||
field_type,
|
||||
percentiles: PercentilesCollector::new(),
|
||||
accessor_idx,
|
||||
missing,
|
||||
})
|
||||
}
|
||||
#[inline]
|
||||
pub(crate) fn collect_block_with_field(
|
||||
&mut self,
|
||||
docs: &[DocId],
|
||||
agg_accessor: &mut AggregationWithAccessor,
|
||||
req_data: &mut MetricAggReqData,
|
||||
) {
|
||||
if let Some(missing) = self.missing.as_ref() {
|
||||
agg_accessor.column_block_accessor.fetch_block_with_missing(
|
||||
if let Some(missing) = req_data.missing_u64.as_ref() {
|
||||
req_data.column_block_accessor.fetch_block_with_missing(
|
||||
docs,
|
||||
&agg_accessor.accessor,
|
||||
&req_data.accessor,
|
||||
*missing,
|
||||
);
|
||||
} else {
|
||||
agg_accessor
|
||||
req_data
|
||||
.column_block_accessor
|
||||
.fetch_block(docs, &agg_accessor.accessor);
|
||||
.fetch_block(docs, &req_data.accessor);
|
||||
}
|
||||
|
||||
for val in agg_accessor.column_block_accessor.iter_vals() {
|
||||
let val1 = f64_from_fastfield_u64(val, &self.field_type);
|
||||
for val in req_data.column_block_accessor.iter_vals() {
|
||||
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
|
||||
self.percentiles.collect(val1);
|
||||
}
|
||||
}
|
||||
@@ -277,10 +264,10 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector {
|
||||
#[inline]
|
||||
fn add_intermediate_aggregation_result(
|
||||
self: Box<Self>,
|
||||
agg_with_accessor: &AggregationsWithAccessor,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
) -> crate::Result<()> {
|
||||
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
|
||||
let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone();
|
||||
let intermediate_metric_result = IntermediateMetricResult::Percentiles(self.percentiles);
|
||||
|
||||
results.push(
|
||||
@@ -295,24 +282,24 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector {
|
||||
fn collect(
|
||||
&mut self,
|
||||
doc: crate::DocId,
|
||||
agg_with_accessor: &mut AggregationsWithAccessor,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let field = &agg_with_accessor.aggs.values[self.accessor_idx].accessor;
|
||||
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
|
||||
|
||||
if let Some(missing) = self.missing {
|
||||
if let Some(missing) = req_data.missing_u64 {
|
||||
let mut has_val = false;
|
||||
for val in field.values_for_doc(doc) {
|
||||
let val1 = f64_from_fastfield_u64(val, &self.field_type);
|
||||
for val in req_data.accessor.values_for_doc(doc) {
|
||||
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
|
||||
self.percentiles.collect(val1);
|
||||
has_val = true;
|
||||
}
|
||||
if !has_val {
|
||||
self.percentiles
|
||||
.collect(f64_from_fastfield_u64(missing, &self.field_type));
|
||||
.collect(f64_from_fastfield_u64(missing, &req_data.field_type));
|
||||
}
|
||||
} else {
|
||||
for val in field.values_for_doc(doc) {
|
||||
let val1 = f64_from_fastfield_u64(val, &self.field_type);
|
||||
for val in req_data.accessor.values_for_doc(doc) {
|
||||
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
|
||||
self.percentiles.collect(val1);
|
||||
}
|
||||
}
|
||||
@@ -324,10 +311,10 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector {
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_with_accessor: &mut AggregationsWithAccessor,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let field = &mut agg_with_accessor.aggs.values[self.accessor_idx];
|
||||
self.collect_block_with_field(docs, field);
|
||||
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
|
||||
self.collect_block_with_field(docs, req_data);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,12 +3,11 @@ use std::fmt::Debug;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::*;
|
||||
use crate::aggregation::agg_req_with_accessor::{
|
||||
AggregationWithAccessor, AggregationsWithAccessor,
|
||||
};
|
||||
use crate::aggregation::agg_data::AggregationsSegmentCtx;
|
||||
use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
|
||||
};
|
||||
use crate::aggregation::metric::MetricAggReqData;
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::*;
|
||||
use crate::{DocId, TantivyError};
|
||||
@@ -166,74 +165,65 @@ impl IntermediateStats {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub(crate) enum SegmentStatsType {
|
||||
/// The type of stats aggregation to perform.
|
||||
/// Note that not all stats types are supported in the stats aggregation.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub enum StatsType {
|
||||
/// The average of the values.
|
||||
Average,
|
||||
/// The count of the values.
|
||||
Count,
|
||||
/// The maximum value.
|
||||
Max,
|
||||
/// The minimum value.
|
||||
Min,
|
||||
/// The stats (count, sum, min, max, avg) of the values.
|
||||
Stats,
|
||||
/// The extended stats (count, sum, min, max, avg, sum_of_squares, variance, std_deviation,
|
||||
ExtendedStats(Option<f64>), // sigma
|
||||
/// The sum of the values.
|
||||
Sum,
|
||||
/// The percentiles of the values.
|
||||
Percentiles,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct SegmentStatsCollector {
|
||||
missing: Option<u64>,
|
||||
field_type: ColumnType,
|
||||
pub(crate) collecting_for: SegmentStatsType,
|
||||
pub(crate) stats: IntermediateStats,
|
||||
pub(crate) accessor_idx: usize,
|
||||
val_cache: Vec<u64>,
|
||||
}
|
||||
|
||||
impl SegmentStatsCollector {
|
||||
pub fn from_req(
|
||||
field_type: ColumnType,
|
||||
collecting_for: SegmentStatsType,
|
||||
accessor_idx: usize,
|
||||
missing: Option<f64>,
|
||||
) -> Self {
|
||||
let missing = missing.and_then(|val| f64_to_fastfield_u64(val, &field_type));
|
||||
pub fn from_req(accessor_idx: usize) -> Self {
|
||||
Self {
|
||||
field_type,
|
||||
collecting_for,
|
||||
stats: IntermediateStats::default(),
|
||||
accessor_idx,
|
||||
missing,
|
||||
val_cache: Default::default(),
|
||||
}
|
||||
}
|
||||
#[inline]
|
||||
pub(crate) fn collect_block_with_field(
|
||||
&mut self,
|
||||
docs: &[DocId],
|
||||
agg_accessor: &mut AggregationWithAccessor,
|
||||
req_data: &mut MetricAggReqData,
|
||||
) {
|
||||
if let Some(missing) = self.missing.as_ref() {
|
||||
agg_accessor.column_block_accessor.fetch_block_with_missing(
|
||||
if let Some(missing) = req_data.missing_u64.as_ref() {
|
||||
req_data.column_block_accessor.fetch_block_with_missing(
|
||||
docs,
|
||||
&agg_accessor.accessor,
|
||||
&req_data.accessor,
|
||||
*missing,
|
||||
);
|
||||
} else {
|
||||
agg_accessor
|
||||
req_data
|
||||
.column_block_accessor
|
||||
.fetch_block(docs, &agg_accessor.accessor);
|
||||
.fetch_block(docs, &req_data.accessor);
|
||||
}
|
||||
if [
|
||||
ColumnType::I64,
|
||||
ColumnType::U64,
|
||||
ColumnType::F64,
|
||||
ColumnType::DateTime,
|
||||
]
|
||||
.contains(&self.field_type)
|
||||
{
|
||||
for val in agg_accessor.column_block_accessor.iter_vals() {
|
||||
let val1 = f64_from_fastfield_u64(val, &self.field_type);
|
||||
if req_data.is_number_or_date_type {
|
||||
for val in req_data.column_block_accessor.iter_vals() {
|
||||
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
|
||||
self.stats.collect(val1);
|
||||
}
|
||||
} else {
|
||||
for _val in agg_accessor.column_block_accessor.iter_vals() {
|
||||
for _val in req_data.column_block_accessor.iter_vals() {
|
||||
// we ignore the value and simply record that we got something
|
||||
self.stats.collect(0.0);
|
||||
}
|
||||
@@ -245,27 +235,28 @@ impl SegmentAggregationCollector for SegmentStatsCollector {
|
||||
#[inline]
|
||||
fn add_intermediate_aggregation_result(
|
||||
self: Box<Self>,
|
||||
agg_with_accessor: &AggregationsWithAccessor,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
) -> crate::Result<()> {
|
||||
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
|
||||
let req = agg_data.get_metric_req_data(self.accessor_idx);
|
||||
let name = req.name.clone();
|
||||
|
||||
let intermediate_metric_result = match self.collecting_for {
|
||||
SegmentStatsType::Average => {
|
||||
let intermediate_metric_result = match req.collecting_for {
|
||||
StatsType::Average => {
|
||||
IntermediateMetricResult::Average(IntermediateAverage::from_collector(*self))
|
||||
}
|
||||
SegmentStatsType::Count => {
|
||||
StatsType::Count => {
|
||||
IntermediateMetricResult::Count(IntermediateCount::from_collector(*self))
|
||||
}
|
||||
SegmentStatsType::Max => {
|
||||
IntermediateMetricResult::Max(IntermediateMax::from_collector(*self))
|
||||
}
|
||||
SegmentStatsType::Min => {
|
||||
IntermediateMetricResult::Min(IntermediateMin::from_collector(*self))
|
||||
}
|
||||
SegmentStatsType::Stats => IntermediateMetricResult::Stats(self.stats),
|
||||
SegmentStatsType::Sum => {
|
||||
IntermediateMetricResult::Sum(IntermediateSum::from_collector(*self))
|
||||
StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_collector(*self)),
|
||||
StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_collector(*self)),
|
||||
StatsType::Stats => IntermediateMetricResult::Stats(self.stats),
|
||||
StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_collector(*self)),
|
||||
_ => {
|
||||
return Err(TantivyError::InvalidArgument(format!(
|
||||
"Unsupported stats type for stats aggregation: {:?}",
|
||||
req.collecting_for
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
@@ -281,23 +272,23 @@ impl SegmentAggregationCollector for SegmentStatsCollector {
|
||||
fn collect(
|
||||
&mut self,
|
||||
doc: crate::DocId,
|
||||
agg_with_accessor: &mut AggregationsWithAccessor,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let field = &agg_with_accessor.aggs.values[self.accessor_idx].accessor;
|
||||
if let Some(missing) = self.missing {
|
||||
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
|
||||
if let Some(missing) = req_data.missing_u64 {
|
||||
let mut has_val = false;
|
||||
for val in field.values_for_doc(doc) {
|
||||
let val1 = f64_from_fastfield_u64(val, &self.field_type);
|
||||
for val in req_data.accessor.values_for_doc(doc) {
|
||||
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
|
||||
self.stats.collect(val1);
|
||||
has_val = true;
|
||||
}
|
||||
if !has_val {
|
||||
self.stats
|
||||
.collect(f64_from_fastfield_u64(missing, &self.field_type));
|
||||
.collect(f64_from_fastfield_u64(missing, &req_data.field_type));
|
||||
}
|
||||
} else {
|
||||
for val in field.values_for_doc(doc) {
|
||||
let val1 = f64_from_fastfield_u64(val, &self.field_type);
|
||||
for val in req_data.accessor.values_for_doc(doc) {
|
||||
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
|
||||
self.stats.collect(val1);
|
||||
}
|
||||
}
|
||||
@@ -309,10 +300,10 @@ impl SegmentAggregationCollector for SegmentStatsCollector {
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_with_accessor: &mut AggregationsWithAccessor,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let field = &mut agg_with_accessor.aggs.values[self.accessor_idx];
|
||||
self.collect_block_with_field(docs, field);
|
||||
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
|
||||
self.collect_block_with_field(docs, req_data);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,15 +9,41 @@ use serde::ser::SerializeMap;
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
|
||||
use super::{TopHitsMetricResult, TopHitsVecEntry};
|
||||
use crate::aggregation::agg_data::AggregationsSegmentCtx;
|
||||
use crate::aggregation::bucket::Order;
|
||||
use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateMetricResult,
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::AggregationError;
|
||||
use crate::collector::sort_key::ReverseComparator;
|
||||
use crate::collector::TopNComputer;
|
||||
use crate::schema::OwnedValue;
|
||||
use crate::{DocAddress, DocId, SegmentOrdinal};
|
||||
// duplicate import removed; already imported above
|
||||
|
||||
/// Contains all information required by the TopHitsSegmentCollector to perform the
|
||||
/// top_hits aggregation on a segment.
|
||||
#[derive(Default)]
|
||||
pub struct TopHitsAggReqData {
|
||||
/// The accessors to access the fast field values.
|
||||
pub accessors: Vec<(Column<u64>, ColumnType)>,
|
||||
/// The accessors to access the fast field values for retrieving document fields.
|
||||
pub value_accessors: HashMap<String, Vec<DynamicColumn>>,
|
||||
/// The ordinal of the segment this request data is for.
|
||||
pub segment_ordinal: SegmentOrdinal,
|
||||
/// The name of the aggregation.
|
||||
pub name: String,
|
||||
/// The top_hits aggregation request.
|
||||
pub req: TopHitsAggregationReq,
|
||||
}
|
||||
|
||||
impl TopHitsAggReqData {
|
||||
/// Estimate the memory consumption of this struct in bytes.
|
||||
pub fn get_memory_consumption(&self) -> usize {
|
||||
std::mem::size_of::<Self>()
|
||||
}
|
||||
}
|
||||
|
||||
/// # Top Hits
|
||||
///
|
||||
@@ -433,7 +459,7 @@ impl Eq for DocSortValuesAndFields {}
|
||||
#[derive(Clone, Serialize, Deserialize, Debug)]
|
||||
pub struct TopHitsTopNComputer {
|
||||
req: TopHitsAggregationReq,
|
||||
top_n: TopNComputer<DocSortValuesAndFields, DocAddress, false>,
|
||||
top_n: TopNComputer<DocSortValuesAndFields, DocAddress, ReverseComparator>,
|
||||
}
|
||||
|
||||
impl std::cmp::PartialEq for TopHitsTopNComputer {
|
||||
@@ -457,7 +483,7 @@ impl TopHitsTopNComputer {
|
||||
|
||||
pub(crate) fn merge_fruits(&mut self, other_fruit: Self) -> crate::Result<()> {
|
||||
for doc in other_fruit.top_n.into_vec() {
|
||||
self.collect(doc.feature, doc.doc);
|
||||
self.collect(doc.sort_key, doc.doc);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -469,9 +495,9 @@ impl TopHitsTopNComputer {
|
||||
.into_sorted_vec()
|
||||
.into_iter()
|
||||
.map(|doc| TopHitsVecEntry {
|
||||
sort: doc.feature.sorts.iter().map(|f| f.value).collect(),
|
||||
sort: doc.sort_key.sorts.iter().map(|f| f.value).collect(),
|
||||
doc_value_fields: doc
|
||||
.feature
|
||||
.sort_key
|
||||
.doc_value_fields
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v.into()))
|
||||
@@ -492,7 +518,7 @@ impl TopHitsTopNComputer {
|
||||
pub(crate) struct TopHitsSegmentCollector {
|
||||
segment_ordinal: SegmentOrdinal,
|
||||
accessor_idx: usize,
|
||||
top_n: TopNComputer<Vec<DocValueAndOrder>, DocAddress, false>,
|
||||
top_n: TopNComputer<Vec<DocValueAndOrder>, DocAddress, ReverseComparator>,
|
||||
}
|
||||
|
||||
impl TopHitsSegmentCollector {
|
||||
@@ -519,7 +545,7 @@ impl TopHitsSegmentCollector {
|
||||
let doc_value_fields = req.get_document_field_data(value_accessors, res.doc.doc_id);
|
||||
top_hits_computer.collect(
|
||||
DocSortValuesAndFields {
|
||||
sorts: res.feature,
|
||||
sorts: res.sort_key,
|
||||
doc_value_fields,
|
||||
},
|
||||
res.doc,
|
||||
@@ -566,23 +592,18 @@ impl TopHitsSegmentCollector {
|
||||
impl SegmentAggregationCollector for TopHitsSegmentCollector {
|
||||
fn add_intermediate_aggregation_result(
|
||||
self: Box<Self>,
|
||||
agg_with_accessor: &crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut crate::aggregation::intermediate_agg_result::IntermediateAggregationResults,
|
||||
) -> crate::Result<()> {
|
||||
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
|
||||
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
|
||||
|
||||
let value_accessors = &agg_with_accessor.aggs.values[self.accessor_idx].value_accessors;
|
||||
let tophits_req = &agg_with_accessor.aggs.values[self.accessor_idx]
|
||||
.agg
|
||||
.agg
|
||||
.as_top_hits()
|
||||
.expect("aggregation request must be of type top hits");
|
||||
let value_accessors = &req_data.value_accessors;
|
||||
|
||||
let intermediate_result = IntermediateMetricResult::TopHits(
|
||||
self.into_top_hits_collector(value_accessors, tophits_req),
|
||||
self.into_top_hits_collector(value_accessors, &req_data.req),
|
||||
);
|
||||
results.push(
|
||||
name,
|
||||
req_data.name.to_string(),
|
||||
IntermediateAggregationResult::Metric(intermediate_result),
|
||||
)
|
||||
}
|
||||
@@ -591,32 +612,22 @@ impl SegmentAggregationCollector for TopHitsSegmentCollector {
|
||||
fn collect(
|
||||
&mut self,
|
||||
doc_id: crate::DocId,
|
||||
agg_with_accessor: &mut crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let tophits_req = &agg_with_accessor.aggs.values[self.accessor_idx]
|
||||
.agg
|
||||
.agg
|
||||
.as_top_hits()
|
||||
.expect("aggregation request must be of type top hits");
|
||||
let accessors = &agg_with_accessor.aggs.values[self.accessor_idx].accessors;
|
||||
self.collect_with(doc_id, tophits_req, accessors)?;
|
||||
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
|
||||
self.collect_with(doc_id, &req_data.req, &req_data.accessors)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_with_accessor: &mut crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let tophits_req = &agg_with_accessor.aggs.values[self.accessor_idx]
|
||||
.agg
|
||||
.agg
|
||||
.as_top_hits()
|
||||
.expect("aggregation request must be of type top hits");
|
||||
let accessors = &agg_with_accessor.aggs.values[self.accessor_idx].accessors;
|
||||
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
|
||||
// TODO: Consider getting fields with the column block accessor.
|
||||
for doc in docs {
|
||||
self.collect_with(*doc, tophits_req, accessors)?;
|
||||
self.collect_with(*doc, &req_data.req, &req_data.accessors)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -635,6 +646,7 @@ mod tests {
|
||||
use crate::aggregation::bucket::tests::get_test_index_from_docs;
|
||||
use crate::aggregation::tests::get_test_index_from_values;
|
||||
use crate::aggregation::AggregationCollector;
|
||||
use crate::collector::sort_key::ReverseComparator;
|
||||
use crate::collector::ComparableDoc;
|
||||
use crate::query::AllQuery;
|
||||
use crate::schema::OwnedValue;
|
||||
@@ -650,7 +662,7 @@ mod tests {
|
||||
|
||||
fn collector_with_capacity(capacity: usize) -> super::TopHitsTopNComputer {
|
||||
super::TopHitsTopNComputer {
|
||||
top_n: super::TopNComputer::new(capacity),
|
||||
top_n: super::TopNComputer::new_with_comparator(capacity, ReverseComparator),
|
||||
req: Default::default(),
|
||||
}
|
||||
}
|
||||
@@ -764,12 +776,12 @@ mod tests {
|
||||
#[test]
|
||||
fn test_top_hits_collector_single_feature() -> crate::Result<()> {
|
||||
let docs = vec![
|
||||
ComparableDoc::<_, _, false> {
|
||||
ComparableDoc::<_, _> {
|
||||
doc: crate::DocAddress {
|
||||
segment_ord: 0,
|
||||
doc_id: 0,
|
||||
},
|
||||
feature: DocSortValuesAndFields {
|
||||
sort_key: DocSortValuesAndFields {
|
||||
sorts: vec![DocValueAndOrder {
|
||||
value: Some(1),
|
||||
order: Order::Asc,
|
||||
@@ -782,7 +794,7 @@ mod tests {
|
||||
segment_ord: 0,
|
||||
doc_id: 2,
|
||||
},
|
||||
feature: DocSortValuesAndFields {
|
||||
sort_key: DocSortValuesAndFields {
|
||||
sorts: vec![DocValueAndOrder {
|
||||
value: Some(3),
|
||||
order: Order::Asc,
|
||||
@@ -795,7 +807,7 @@ mod tests {
|
||||
segment_ord: 0,
|
||||
doc_id: 1,
|
||||
},
|
||||
feature: DocSortValuesAndFields {
|
||||
sort_key: DocSortValuesAndFields {
|
||||
sorts: vec![DocValueAndOrder {
|
||||
value: Some(5),
|
||||
order: Order::Asc,
|
||||
@@ -807,7 +819,7 @@ mod tests {
|
||||
|
||||
let mut collector = collector_with_capacity(3);
|
||||
for doc in docs.clone() {
|
||||
collector.collect(doc.feature, doc.doc);
|
||||
collector.collect(doc.sort_key, doc.doc);
|
||||
}
|
||||
|
||||
let res = collector.into_final_result();
|
||||
@@ -817,15 +829,15 @@ mod tests {
|
||||
super::TopHitsMetricResult {
|
||||
hits: vec![
|
||||
super::TopHitsVecEntry {
|
||||
sort: vec![docs[0].feature.sorts[0].value],
|
||||
sort: vec![docs[0].sort_key.sorts[0].value],
|
||||
doc_value_fields: Default::default(),
|
||||
},
|
||||
super::TopHitsVecEntry {
|
||||
sort: vec![docs[1].feature.sorts[0].value],
|
||||
sort: vec![docs[1].sort_key.sorts[0].value],
|
||||
doc_value_fields: Default::default(),
|
||||
},
|
||||
super::TopHitsVecEntry {
|
||||
sort: vec![docs[2].feature.sorts[0].value],
|
||||
sort: vec![docs[2].sort_key.sorts[0].value],
|
||||
doc_value_fields: Default::default(),
|
||||
},
|
||||
]
|
||||
|
||||
@@ -127,9 +127,10 @@
|
||||
//! [`AggregationResults`](agg_result::AggregationResults) via the
|
||||
//! [`into_final_result`](intermediate_agg_result::IntermediateAggregationResults::into_final_result) method.
|
||||
|
||||
mod accessor_helpers;
|
||||
mod agg_data;
|
||||
mod agg_limits;
|
||||
pub mod agg_req;
|
||||
mod agg_req_with_accessor;
|
||||
pub mod agg_result;
|
||||
pub mod bucket;
|
||||
mod buf_collector;
|
||||
@@ -140,7 +141,6 @@ pub mod intermediate_agg_result;
|
||||
pub mod metric;
|
||||
|
||||
mod segment_agg_result;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Display;
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -160,6 +160,28 @@ use itertools::Itertools;
|
||||
use serde::de::{self, Visitor};
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
|
||||
use crate::tokenizer::TokenizerManager;
|
||||
|
||||
/// Context parameters for aggregation execution
|
||||
///
|
||||
/// This struct holds shared resources needed during aggregation execution:
|
||||
/// - `limits`: Memory and bucket limits for the aggregation
|
||||
/// - `tokenizers`: TokenizerManager for parsing query strings in filter aggregations
|
||||
#[derive(Clone, Default)]
|
||||
pub struct AggContextParams {
|
||||
/// Aggregation limits (memory and bucket count)
|
||||
pub limits: AggregationLimitsGuard,
|
||||
/// Tokenizer manager for query string parsing
|
||||
pub tokenizers: TokenizerManager,
|
||||
}
|
||||
|
||||
impl AggContextParams {
|
||||
/// Create new aggregation context parameters
|
||||
pub fn new(limits: AggregationLimitsGuard, tokenizers: TokenizerManager) -> Self {
|
||||
Self { limits, tokenizers }
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_str_into_f64<E: de::Error>(value: &str) -> Result<f64, E> {
|
||||
let parsed = value
|
||||
.parse::<f64>()
|
||||
@@ -257,80 +279,6 @@ where D: Deserializer<'de> {
|
||||
deserializer.deserialize_any(StringOrFloatVisitor)
|
||||
}
|
||||
|
||||
/// Represents an associative array `(key => values)` in a very efficient manner.
|
||||
#[derive(PartialEq, Serialize, Deserialize)]
|
||||
pub(crate) struct VecWithNames<T> {
|
||||
pub(crate) values: Vec<T>,
|
||||
keys: Vec<String>,
|
||||
}
|
||||
|
||||
impl<T: Clone> Clone for VecWithNames<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
values: self.values.clone(),
|
||||
keys: self.keys.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Default for VecWithNames<T> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
values: Default::default(),
|
||||
keys: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: std::fmt::Debug> std::fmt::Debug for VecWithNames<T> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_map().entries(self.iter()).finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<HashMap<String, T>> for VecWithNames<T> {
|
||||
fn from(map: HashMap<String, T>) -> Self {
|
||||
VecWithNames::from_entries(map.into_iter().collect_vec())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> VecWithNames<T> {
|
||||
fn from_entries(mut entries: Vec<(String, T)>) -> Self {
|
||||
// Sort to ensure order of elements match across multiple instances
|
||||
entries.sort_by(|left, right| left.0.cmp(&right.0));
|
||||
let mut data = Vec::with_capacity(entries.len());
|
||||
let mut data_names = Vec::with_capacity(entries.len());
|
||||
for entry in entries {
|
||||
data_names.push(entry.0);
|
||||
data.push(entry.1);
|
||||
}
|
||||
VecWithNames {
|
||||
values: data,
|
||||
keys: data_names,
|
||||
}
|
||||
}
|
||||
fn iter(&self) -> impl Iterator<Item = (&str, &T)> + '_ {
|
||||
self.keys().zip(self.values.iter())
|
||||
}
|
||||
fn keys(&self) -> impl Iterator<Item = &str> + '_ {
|
||||
self.keys.iter().map(|key| key.as_str())
|
||||
}
|
||||
fn values_mut(&mut self) -> impl Iterator<Item = &mut T> + '_ {
|
||||
self.values.iter_mut()
|
||||
}
|
||||
fn is_empty(&self) -> bool {
|
||||
self.keys.is_empty()
|
||||
}
|
||||
fn len(&self) -> usize {
|
||||
self.keys.len()
|
||||
}
|
||||
fn get(&self, name: &str) -> Option<&T> {
|
||||
self.keys()
|
||||
.position(|key| key == name)
|
||||
.map(|pos| &self.values[pos])
|
||||
}
|
||||
}
|
||||
|
||||
/// The serialized key is used in a `HashMap`.
|
||||
pub type SerializedKey = String;
|
||||
|
||||
@@ -464,7 +412,10 @@ mod tests {
|
||||
query: Option<(&str, &str)>,
|
||||
limits: AggregationLimitsGuard,
|
||||
) -> crate::Result<Value> {
|
||||
let collector = AggregationCollector::from_aggs(agg_req, limits);
|
||||
let collector = AggregationCollector::from_aggs(
|
||||
agg_req,
|
||||
AggContextParams::new(limits, index.tokenizers().clone()),
|
||||
);
|
||||
|
||||
let reader = index.reader()?;
|
||||
let searcher = reader.searcher();
|
||||
|
||||
@@ -6,48 +6,38 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
pub(crate) use super::agg_limits::AggregationLimitsGuard;
|
||||
use super::agg_req::AggregationVariants;
|
||||
use super::agg_req_with_accessor::{AggregationWithAccessor, AggregationsWithAccessor};
|
||||
use super::bucket::{SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector};
|
||||
use super::intermediate_agg_result::IntermediateAggregationResults;
|
||||
use super::metric::{
|
||||
AverageAggregation, CountAggregation, ExtendedStatsAggregation, MaxAggregation, MinAggregation,
|
||||
SegmentPercentilesCollector, SegmentStatsCollector, SegmentStatsType, StatsAggregation,
|
||||
SumAggregation,
|
||||
};
|
||||
use crate::aggregation::bucket::TermMissingAgg;
|
||||
use crate::aggregation::metric::{
|
||||
CardinalityAggregationReq, SegmentCardinalityCollector, SegmentExtendedStatsCollector,
|
||||
TopHitsSegmentCollector,
|
||||
};
|
||||
use crate::aggregation::agg_data::AggregationsSegmentCtx;
|
||||
|
||||
pub(crate) trait SegmentAggregationCollector: CollectorClone + Debug {
|
||||
/// A SegmentAggregationCollector is used to collect aggregation results.
|
||||
pub trait SegmentAggregationCollector: CollectorClone + Debug {
|
||||
fn add_intermediate_aggregation_result(
|
||||
self: Box<Self>,
|
||||
agg_with_accessor: &AggregationsWithAccessor,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
) -> crate::Result<()>;
|
||||
|
||||
fn collect(
|
||||
&mut self,
|
||||
doc: crate::DocId,
|
||||
agg_with_accessor: &mut AggregationsWithAccessor,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()>;
|
||||
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_with_accessor: &mut AggregationsWithAccessor,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()>;
|
||||
|
||||
/// Finalize method. Some Aggregator collect blocks of docs before calling `collect_block`.
|
||||
/// This method ensures those staged docs will be collected.
|
||||
fn flush(&mut self, _agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
|
||||
fn flush(&mut self, _agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait CollectorClone {
|
||||
/// A helper trait to enable cloning of Box<dyn SegmentAggregationCollector>
|
||||
pub trait CollectorClone {
|
||||
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector>;
|
||||
}
|
||||
|
||||
@@ -65,119 +55,6 @@ impl Clone for Box<dyn SegmentAggregationCollector> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn build_segment_agg_collector(
|
||||
req: &mut AggregationsWithAccessor,
|
||||
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
|
||||
// Single collector special case
|
||||
if req.aggs.len() == 1 {
|
||||
let req = &mut req.aggs.values[0];
|
||||
let accessor_idx = 0;
|
||||
return build_single_agg_segment_collector(req, accessor_idx);
|
||||
}
|
||||
|
||||
let agg = GenericSegmentAggregationResultsCollector::from_req_and_validate(req)?;
|
||||
Ok(Box::new(agg))
|
||||
}
|
||||
|
||||
pub(crate) fn build_single_agg_segment_collector(
|
||||
req: &mut AggregationWithAccessor,
|
||||
accessor_idx: usize,
|
||||
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
|
||||
use AggregationVariants::*;
|
||||
match &req.agg.agg {
|
||||
Terms(terms_req) => {
|
||||
if req.accessors.is_empty() {
|
||||
Ok(Box::new(SegmentTermCollector::from_req_and_validate(
|
||||
terms_req,
|
||||
&mut req.sub_aggregation,
|
||||
req.field_type,
|
||||
accessor_idx,
|
||||
)?))
|
||||
} else {
|
||||
Ok(Box::new(TermMissingAgg::new(
|
||||
accessor_idx,
|
||||
&mut req.sub_aggregation,
|
||||
)?))
|
||||
}
|
||||
}
|
||||
Range(range_req) => Ok(Box::new(SegmentRangeCollector::from_req_and_validate(
|
||||
range_req,
|
||||
&mut req.sub_aggregation,
|
||||
&mut req.limits,
|
||||
req.field_type,
|
||||
accessor_idx,
|
||||
)?)),
|
||||
Histogram(histogram) => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
|
||||
histogram.clone(),
|
||||
&mut req.sub_aggregation,
|
||||
req.field_type,
|
||||
accessor_idx,
|
||||
)?)),
|
||||
DateHistogram(histogram) => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
|
||||
histogram.to_histogram_req()?,
|
||||
&mut req.sub_aggregation,
|
||||
req.field_type,
|
||||
accessor_idx,
|
||||
)?)),
|
||||
Average(AverageAggregation { missing, .. }) => {
|
||||
Ok(Box::new(SegmentStatsCollector::from_req(
|
||||
req.field_type,
|
||||
SegmentStatsType::Average,
|
||||
accessor_idx,
|
||||
*missing,
|
||||
)))
|
||||
}
|
||||
Count(CountAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
|
||||
req.field_type,
|
||||
SegmentStatsType::Count,
|
||||
accessor_idx,
|
||||
*missing,
|
||||
))),
|
||||
Max(MaxAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
|
||||
req.field_type,
|
||||
SegmentStatsType::Max,
|
||||
accessor_idx,
|
||||
*missing,
|
||||
))),
|
||||
Min(MinAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
|
||||
req.field_type,
|
||||
SegmentStatsType::Min,
|
||||
accessor_idx,
|
||||
*missing,
|
||||
))),
|
||||
Stats(StatsAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
|
||||
req.field_type,
|
||||
SegmentStatsType::Stats,
|
||||
accessor_idx,
|
||||
*missing,
|
||||
))),
|
||||
ExtendedStats(ExtendedStatsAggregation { missing, sigma, .. }) => Ok(Box::new(
|
||||
SegmentExtendedStatsCollector::from_req(req.field_type, *sigma, accessor_idx, *missing),
|
||||
)),
|
||||
Sum(SumAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
|
||||
req.field_type,
|
||||
SegmentStatsType::Sum,
|
||||
accessor_idx,
|
||||
*missing,
|
||||
))),
|
||||
Percentiles(percentiles_req) => Ok(Box::new(
|
||||
SegmentPercentilesCollector::from_req_and_validate(
|
||||
percentiles_req,
|
||||
req.field_type,
|
||||
accessor_idx,
|
||||
)?,
|
||||
)),
|
||||
TopHits(top_hits_req) => Ok(Box::new(TopHitsSegmentCollector::from_req(
|
||||
top_hits_req,
|
||||
accessor_idx,
|
||||
req.segment_ordinal,
|
||||
))),
|
||||
Cardinality(CardinalityAggregationReq { missing, .. }) => Ok(Box::new(
|
||||
SegmentCardinalityCollector::from_req(req.field_type, accessor_idx, missing),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
/// The GenericSegmentAggregationResultsCollector is the generic version of the collector, which
|
||||
/// can handle arbitrary complexity of sub-aggregations. Ideally we never have to pick this one
|
||||
@@ -197,11 +74,11 @@ impl Debug for GenericSegmentAggregationResultsCollector {
|
||||
impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
|
||||
fn add_intermediate_aggregation_result(
|
||||
self: Box<Self>,
|
||||
agg_with_accessor: &AggregationsWithAccessor,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
) -> crate::Result<()> {
|
||||
for agg in self.aggs {
|
||||
agg.add_intermediate_aggregation_result(agg_with_accessor, results)?;
|
||||
agg.add_intermediate_aggregation_result(agg_data, results)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -210,9 +87,9 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
|
||||
fn collect(
|
||||
&mut self,
|
||||
doc: crate::DocId,
|
||||
agg_with_accessor: &mut AggregationsWithAccessor,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
self.collect_block(&[doc], agg_with_accessor)?;
|
||||
self.collect_block(&[doc], agg_data)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -220,32 +97,19 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_with_accessor: &mut AggregationsWithAccessor,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
for collector in &mut self.aggs {
|
||||
collector.collect_block(docs, agg_with_accessor)?;
|
||||
collector.collect_block(docs, agg_data)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
|
||||
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
|
||||
for collector in &mut self.aggs {
|
||||
collector.flush(agg_with_accessor)?;
|
||||
collector.flush(agg_data)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl GenericSegmentAggregationResultsCollector {
|
||||
pub(crate) fn from_req_and_validate(req: &mut AggregationsWithAccessor) -> crate::Result<Self> {
|
||||
let aggs = req
|
||||
.aggs
|
||||
.values_mut()
|
||||
.enumerate()
|
||||
.map(|(accessor_idx, req)| build_single_agg_segment_collector(req, accessor_idx))
|
||||
.collect::<crate::Result<Vec<Box<dyn SegmentAggregationCollector>>>>()?;
|
||||
|
||||
Ok(GenericSegmentAggregationResultsCollector { aggs })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,121 +0,0 @@
|
||||
use crate::collector::top_collector::{TopCollector, TopSegmentCollector};
|
||||
use crate::collector::{Collector, SegmentCollector};
|
||||
use crate::{DocAddress, DocId, Score, SegmentReader};
|
||||
|
||||
pub(crate) struct CustomScoreTopCollector<TCustomScorer, TScore = Score> {
|
||||
custom_scorer: TCustomScorer,
|
||||
collector: TopCollector<TScore>,
|
||||
}
|
||||
|
||||
impl<TCustomScorer, TScore> CustomScoreTopCollector<TCustomScorer, TScore>
|
||||
where TScore: Clone + PartialOrd
|
||||
{
|
||||
pub(crate) fn new(
|
||||
custom_scorer: TCustomScorer,
|
||||
collector: TopCollector<TScore>,
|
||||
) -> CustomScoreTopCollector<TCustomScorer, TScore> {
|
||||
CustomScoreTopCollector {
|
||||
custom_scorer,
|
||||
collector,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A custom segment scorer makes it possible to define any kind of score
|
||||
/// for a given document belonging to a specific segment.
|
||||
///
|
||||
/// It is the segment local version of the [`CustomScorer`].
|
||||
pub trait CustomSegmentScorer<TScore>: 'static {
|
||||
/// Computes the score of a specific `doc`.
|
||||
fn score(&mut self, doc: DocId) -> TScore;
|
||||
}
|
||||
|
||||
/// `CustomScorer` makes it possible to define any kind of score.
|
||||
///
|
||||
/// The `CustomerScorer` itself does not make much of the computation itself.
|
||||
/// Instead, it helps constructing `Self::Child` instances that will compute
|
||||
/// the score at a segment scale.
|
||||
pub trait CustomScorer<TScore>: Sync {
|
||||
/// Type of the associated [`CustomSegmentScorer`].
|
||||
type Child: CustomSegmentScorer<TScore>;
|
||||
/// Builds a child scorer for a specific segment. The child scorer is associated with
|
||||
/// a specific segment.
|
||||
fn segment_scorer(&self, segment_reader: &SegmentReader) -> crate::Result<Self::Child>;
|
||||
}
|
||||
|
||||
impl<TCustomScorer, TScore> Collector for CustomScoreTopCollector<TCustomScorer, TScore>
|
||||
where
|
||||
TCustomScorer: CustomScorer<TScore> + Send + Sync,
|
||||
TScore: 'static + PartialOrd + Clone + Send + Sync,
|
||||
{
|
||||
type Fruit = Vec<(TScore, DocAddress)>;
|
||||
|
||||
type Child = CustomScoreTopSegmentCollector<TCustomScorer::Child, TScore>;
|
||||
|
||||
fn for_segment(
|
||||
&self,
|
||||
segment_local_id: u32,
|
||||
segment_reader: &SegmentReader,
|
||||
) -> crate::Result<Self::Child> {
|
||||
let segment_collector = self.collector.for_segment(segment_local_id, segment_reader);
|
||||
let segment_scorer = self.custom_scorer.segment_scorer(segment_reader)?;
|
||||
Ok(CustomScoreTopSegmentCollector {
|
||||
segment_collector,
|
||||
segment_scorer,
|
||||
})
|
||||
}
|
||||
|
||||
fn requires_scoring(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn merge_fruits(&self, segment_fruits: Vec<Self::Fruit>) -> crate::Result<Self::Fruit> {
|
||||
self.collector.merge_fruits(segment_fruits)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CustomScoreTopSegmentCollector<T, TScore>
|
||||
where
|
||||
TScore: 'static + PartialOrd + Clone + Send + Sync + Sized,
|
||||
T: CustomSegmentScorer<TScore>,
|
||||
{
|
||||
segment_collector: TopSegmentCollector<TScore>,
|
||||
segment_scorer: T,
|
||||
}
|
||||
|
||||
impl<T, TScore> SegmentCollector for CustomScoreTopSegmentCollector<T, TScore>
|
||||
where
|
||||
TScore: 'static + PartialOrd + Clone + Send + Sync,
|
||||
T: 'static + CustomSegmentScorer<TScore>,
|
||||
{
|
||||
type Fruit = Vec<(TScore, DocAddress)>;
|
||||
|
||||
fn collect(&mut self, doc: DocId, _score: Score) {
|
||||
let score = self.segment_scorer.score(doc);
|
||||
self.segment_collector.collect(doc, score);
|
||||
}
|
||||
|
||||
fn harvest(self) -> Vec<(TScore, DocAddress)> {
|
||||
self.segment_collector.harvest()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, TScore, T> CustomScorer<TScore> for F
|
||||
where
|
||||
F: 'static + Send + Sync + Fn(&SegmentReader) -> T,
|
||||
T: CustomSegmentScorer<TScore>,
|
||||
{
|
||||
type Child = T;
|
||||
|
||||
fn segment_scorer(&self, segment_reader: &SegmentReader) -> crate::Result<Self::Child> {
|
||||
Ok((self)(segment_reader))
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, TScore> CustomSegmentScorer<TScore> for F
|
||||
where F: 'static + FnMut(DocId) -> TScore
|
||||
{
|
||||
fn score(&mut self, doc: DocId) -> TScore {
|
||||
(self)(doc)
|
||||
}
|
||||
}
|
||||
@@ -484,7 +484,6 @@ impl FacetCounts {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::BTreeSet;
|
||||
use std::iter;
|
||||
|
||||
use columnar::Dictionary;
|
||||
use rand::distributions::Uniform;
|
||||
|
||||
@@ -12,6 +12,7 @@ use std::marker::PhantomData;
|
||||
use columnar::{BytesColumn, Column, DynamicColumn, HasAssociatedColumnType};
|
||||
|
||||
use crate::collector::{Collector, SegmentCollector};
|
||||
use crate::schema::Schema;
|
||||
use crate::{DocId, Score, SegmentReader};
|
||||
|
||||
/// The `FilterCollector` filters docs using a fast field value and a predicate.
|
||||
@@ -49,13 +50,13 @@ use crate::{DocId, Score, SegmentReader};
|
||||
///
|
||||
/// let query_parser = QueryParser::for_index(&index, vec![title]);
|
||||
/// let query = query_parser.parse_query("diary")?;
|
||||
/// let no_filter_collector = FilterCollector::new("price".to_string(), |value: u64| value > 20_120u64, TopDocs::with_limit(2));
|
||||
/// let no_filter_collector = FilterCollector::new("price".to_string(), |value: u64| value > 20_120u64, TopDocs::with_limit(2).order_by_score());
|
||||
/// let top_docs = searcher.search(&query, &no_filter_collector)?;
|
||||
///
|
||||
/// assert_eq!(top_docs.len(), 1);
|
||||
/// assert_eq!(top_docs[0].1, DocAddress::new(0, 1));
|
||||
///
|
||||
/// let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new("price".to_string(), |value| value < 5u64, TopDocs::with_limit(2));
|
||||
/// let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new("price".to_string(), |value| value < 5u64, TopDocs::with_limit(2).order_by_score());
|
||||
/// let filtered_top_docs = searcher.search(&query, &filter_all_collector)?;
|
||||
///
|
||||
/// assert_eq!(filtered_top_docs.len(), 0);
|
||||
@@ -104,6 +105,11 @@ where
|
||||
|
||||
type Child = FilterSegmentCollector<TCollector::Child, TPredicate, TPredicateValue>;
|
||||
|
||||
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
|
||||
self.collector.check_schema(schema)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn for_segment(
|
||||
&self,
|
||||
segment_local_id: u32,
|
||||
@@ -120,6 +126,7 @@ where
|
||||
segment_collector,
|
||||
predicate: self.predicate.clone(),
|
||||
t_predicate_value: PhantomData,
|
||||
filtered_docs: Vec::with_capacity(crate::COLLECT_BLOCK_BUFFER_LEN),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -140,6 +147,7 @@ pub struct FilterSegmentCollector<TSegmentCollector, TPredicate, TPredicateValue
|
||||
segment_collector: TSegmentCollector,
|
||||
predicate: TPredicate,
|
||||
t_predicate_value: PhantomData<TPredicateValue>,
|
||||
filtered_docs: Vec<DocId>,
|
||||
}
|
||||
|
||||
impl<TSegmentCollector, TPredicate, TPredicateValue>
|
||||
@@ -176,6 +184,20 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_block(&mut self, docs: &[DocId]) {
|
||||
self.filtered_docs.clear();
|
||||
for &doc in docs {
|
||||
// TODO: `accept_document` could be further optimized to do batch lookups of column
|
||||
// values for single-valued columns.
|
||||
if self.accept_document(doc) {
|
||||
self.filtered_docs.push(doc);
|
||||
}
|
||||
}
|
||||
if !self.filtered_docs.is_empty() {
|
||||
self.segment_collector.collect_block(&self.filtered_docs);
|
||||
}
|
||||
}
|
||||
|
||||
fn harvest(self) -> TSegmentCollector::Fruit {
|
||||
self.segment_collector.harvest()
|
||||
}
|
||||
@@ -218,7 +240,7 @@ where
|
||||
///
|
||||
/// let query_parser = QueryParser::for_index(&index, vec![title]);
|
||||
/// let query = query_parser.parse_query("diary")?;
|
||||
/// let filter_collector = BytesFilterCollector::new("barcode".to_string(), |bytes: &[u8]| bytes.starts_with(b"01"), TopDocs::with_limit(2));
|
||||
/// let filter_collector = BytesFilterCollector::new("barcode".to_string(), |bytes: &[u8]| bytes.starts_with(b"01"), TopDocs::with_limit(2).order_by_score());
|
||||
/// let top_docs = searcher.search(&query, &filter_collector)?;
|
||||
///
|
||||
/// assert_eq!(top_docs.len(), 1);
|
||||
@@ -258,6 +280,10 @@ where
|
||||
|
||||
type Child = BytesFilterSegmentCollector<TCollector::Child, TPredicate>;
|
||||
|
||||
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
|
||||
self.collector.check_schema(schema)
|
||||
}
|
||||
|
||||
fn for_segment(
|
||||
&self,
|
||||
segment_local_id: u32,
|
||||
@@ -274,6 +300,7 @@ where
|
||||
segment_collector,
|
||||
predicate: self.predicate.clone(),
|
||||
buffer: Vec::new(),
|
||||
filtered_docs: Vec::with_capacity(crate::COLLECT_BLOCK_BUFFER_LEN),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -296,6 +323,7 @@ where TPredicate: 'static
|
||||
segment_collector: TSegmentCollector,
|
||||
predicate: TPredicate,
|
||||
buffer: Vec<u8>,
|
||||
filtered_docs: Vec<DocId>,
|
||||
}
|
||||
|
||||
impl<TSegmentCollector, TPredicate> BytesFilterSegmentCollector<TSegmentCollector, TPredicate>
|
||||
@@ -334,6 +362,20 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_block(&mut self, docs: &[DocId]) {
|
||||
self.filtered_docs.clear();
|
||||
for &doc in docs {
|
||||
// TODO: `accept_document` could be further optimized to do batch lookups of column
|
||||
// values for single-valued columns.
|
||||
if self.accept_document(doc) {
|
||||
self.filtered_docs.push(doc);
|
||||
}
|
||||
}
|
||||
if !self.filtered_docs.is_empty() {
|
||||
self.segment_collector.collect_block(&self.filtered_docs);
|
||||
}
|
||||
}
|
||||
|
||||
fn harvest(self) -> TSegmentCollector::Fruit {
|
||||
self.segment_collector.harvest()
|
||||
}
|
||||
|
||||
@@ -57,7 +57,7 @@
|
||||
//! # let query_parser = QueryParser::for_index(&index, vec![title]);
|
||||
//! # let query = query_parser.parse_query("diary")?;
|
||||
//! let (doc_count, top_docs): (usize, Vec<(Score, DocAddress)>) =
|
||||
//! searcher.search(&query, &(Count, TopDocs::with_limit(2)))?;
|
||||
//! searcher.search(&query, &(Count, TopDocs::with_limit(2).order_by_score()))?;
|
||||
//! # Ok(())
|
||||
//! # }
|
||||
//! ```
|
||||
@@ -83,11 +83,15 @@
|
||||
|
||||
use downcast_rs::impl_downcast;
|
||||
|
||||
use crate::schema::Schema;
|
||||
use crate::{DocId, Score, SegmentOrdinal, SegmentReader};
|
||||
|
||||
mod count_collector;
|
||||
pub use self::count_collector::Count;
|
||||
|
||||
/// Sort keys
|
||||
pub mod sort_key;
|
||||
|
||||
mod histogram_collector;
|
||||
pub use histogram_collector::HistogramCollector;
|
||||
|
||||
@@ -95,16 +99,13 @@ mod multi_collector;
|
||||
pub use self::multi_collector::{FruitHandle, MultiCollector, MultiFruit};
|
||||
|
||||
mod top_collector;
|
||||
pub use self::top_collector::ComparableDoc;
|
||||
|
||||
mod top_score_collector;
|
||||
pub use self::top_collector::ComparableDoc;
|
||||
pub use self::top_score_collector::{TopDocs, TopNComputer};
|
||||
|
||||
mod custom_score_top_collector;
|
||||
pub use self::custom_score_top_collector::{CustomScorer, CustomSegmentScorer};
|
||||
|
||||
mod tweak_score_top_collector;
|
||||
pub use self::tweak_score_top_collector::{ScoreSegmentTweaker, ScoreTweaker};
|
||||
mod sort_key_top_collector;
|
||||
pub use self::sort_key::{SegmentSortKeyComputer, SortKeyComputer};
|
||||
mod facet_collector;
|
||||
pub use self::facet_collector::{FacetCollector, FacetCounts};
|
||||
use crate::query::Weight;
|
||||
@@ -145,6 +146,11 @@ pub trait Collector: Sync + Send {
|
||||
/// Type of the `SegmentCollector` associated with this collector.
|
||||
type Child: SegmentCollector;
|
||||
|
||||
/// Returns an error if the schema is not compatible with the collector.
|
||||
fn check_schema(&self, _schema: &Schema) -> crate::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// `set_segment` is called before beginning to enumerate
|
||||
/// on this segment.
|
||||
fn for_segment(
|
||||
@@ -170,41 +176,50 @@ pub trait Collector: Sync + Send {
|
||||
segment_ord: u32,
|
||||
reader: &SegmentReader,
|
||||
) -> crate::Result<<Self::Child as SegmentCollector>::Fruit> {
|
||||
let with_scoring = self.requires_scoring();
|
||||
let mut segment_collector = self.for_segment(segment_ord, reader)?;
|
||||
|
||||
match (reader.alive_bitset(), self.requires_scoring()) {
|
||||
(Some(alive_bitset), true) => {
|
||||
weight.for_each(reader, &mut |doc, score| {
|
||||
if alive_bitset.is_alive(doc) {
|
||||
segment_collector.collect(doc, score);
|
||||
}
|
||||
})?;
|
||||
}
|
||||
(Some(alive_bitset), false) => {
|
||||
weight.for_each_no_score(reader, &mut |docs| {
|
||||
for doc in docs.iter().cloned() {
|
||||
if alive_bitset.is_alive(doc) {
|
||||
segment_collector.collect(doc, 0.0);
|
||||
}
|
||||
}
|
||||
})?;
|
||||
}
|
||||
(None, true) => {
|
||||
weight.for_each(reader, &mut |doc, score| {
|
||||
segment_collector.collect(doc, score);
|
||||
})?;
|
||||
}
|
||||
(None, false) => {
|
||||
weight.for_each_no_score(reader, &mut |docs| {
|
||||
segment_collector.collect_block(docs);
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
default_collect_segment_impl(&mut segment_collector, weight, reader, with_scoring)?;
|
||||
Ok(segment_collector.harvest())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn default_collect_segment_impl<TSegmentCollector: SegmentCollector>(
|
||||
segment_collector: &mut TSegmentCollector,
|
||||
weight: &dyn Weight,
|
||||
reader: &SegmentReader,
|
||||
with_scoring: bool,
|
||||
) -> crate::Result<()> {
|
||||
match (reader.alive_bitset(), with_scoring) {
|
||||
(Some(alive_bitset), true) => {
|
||||
weight.for_each(reader, &mut |doc, score| {
|
||||
if alive_bitset.is_alive(doc) {
|
||||
segment_collector.collect(doc, score);
|
||||
}
|
||||
})?;
|
||||
}
|
||||
(Some(alive_bitset), false) => {
|
||||
weight.for_each_no_score(reader, &mut |docs| {
|
||||
for doc in docs.iter().cloned() {
|
||||
if alive_bitset.is_alive(doc) {
|
||||
segment_collector.collect(doc, 0.0);
|
||||
}
|
||||
}
|
||||
})?;
|
||||
}
|
||||
(None, true) => {
|
||||
weight.for_each(reader, &mut |doc, score| {
|
||||
segment_collector.collect(doc, score);
|
||||
})?;
|
||||
}
|
||||
(None, false) => {
|
||||
weight.for_each_no_score(reader, &mut |docs| {
|
||||
segment_collector.collect_block(docs);
|
||||
})?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl<TSegmentCollector: SegmentCollector> SegmentCollector for Option<TSegmentCollector> {
|
||||
type Fruit = Option<TSegmentCollector::Fruit>;
|
||||
|
||||
@@ -214,6 +229,12 @@ impl<TSegmentCollector: SegmentCollector> SegmentCollector for Option<TSegmentCo
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_block(&mut self, docs: &[DocId]) {
|
||||
if let Some(segment_collector) = self {
|
||||
segment_collector.collect_block(docs);
|
||||
}
|
||||
}
|
||||
|
||||
fn harvest(self) -> Self::Fruit {
|
||||
self.map(|segment_collector| segment_collector.harvest())
|
||||
}
|
||||
@@ -224,6 +245,13 @@ impl<TCollector: Collector> Collector for Option<TCollector> {
|
||||
|
||||
type Child = Option<<TCollector as Collector>::Child>;
|
||||
|
||||
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
|
||||
if let Some(underlying_collector) = self {
|
||||
underlying_collector.check_schema(schema)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn for_segment(
|
||||
&self,
|
||||
segment_local_id: SegmentOrdinal,
|
||||
@@ -299,6 +327,12 @@ where
|
||||
type Fruit = (Left::Fruit, Right::Fruit);
|
||||
type Child = (Left::Child, Right::Child);
|
||||
|
||||
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
|
||||
self.0.check_schema(schema)?;
|
||||
self.1.check_schema(schema)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn for_segment(
|
||||
&self,
|
||||
segment_local_id: u32,
|
||||
@@ -342,6 +376,11 @@ where
|
||||
self.1.collect(doc, score);
|
||||
}
|
||||
|
||||
fn collect_block(&mut self, docs: &[DocId]) {
|
||||
self.0.collect_block(docs);
|
||||
self.1.collect_block(docs);
|
||||
}
|
||||
|
||||
fn harvest(self) -> <Self as SegmentCollector>::Fruit {
|
||||
(self.0.harvest(), self.1.harvest())
|
||||
}
|
||||
@@ -358,6 +397,13 @@ where
|
||||
type Fruit = (One::Fruit, Two::Fruit, Three::Fruit);
|
||||
type Child = (One::Child, Two::Child, Three::Child);
|
||||
|
||||
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
|
||||
self.0.check_schema(schema)?;
|
||||
self.1.check_schema(schema)?;
|
||||
self.2.check_schema(schema)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn for_segment(
|
||||
&self,
|
||||
segment_local_id: u32,
|
||||
@@ -407,6 +453,12 @@ where
|
||||
self.2.collect(doc, score);
|
||||
}
|
||||
|
||||
fn collect_block(&mut self, docs: &[DocId]) {
|
||||
self.0.collect_block(docs);
|
||||
self.1.collect_block(docs);
|
||||
self.2.collect_block(docs);
|
||||
}
|
||||
|
||||
fn harvest(self) -> <Self as SegmentCollector>::Fruit {
|
||||
(self.0.harvest(), self.1.harvest(), self.2.harvest())
|
||||
}
|
||||
@@ -424,6 +476,14 @@ where
|
||||
type Fruit = (One::Fruit, Two::Fruit, Three::Fruit, Four::Fruit);
|
||||
type Child = (One::Child, Two::Child, Three::Child, Four::Child);
|
||||
|
||||
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
|
||||
self.0.check_schema(schema)?;
|
||||
self.1.check_schema(schema)?;
|
||||
self.2.check_schema(schema)?;
|
||||
self.3.check_schema(schema)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn for_segment(
|
||||
&self,
|
||||
segment_local_id: u32,
|
||||
@@ -482,6 +542,13 @@ where
|
||||
self.3.collect(doc, score);
|
||||
}
|
||||
|
||||
fn collect_block(&mut self, docs: &[DocId]) {
|
||||
self.0.collect_block(docs);
|
||||
self.1.collect_block(docs);
|
||||
self.2.collect_block(docs);
|
||||
self.3.collect_block(docs);
|
||||
}
|
||||
|
||||
fn harvest(self) -> <Self as SegmentCollector>::Fruit {
|
||||
(
|
||||
self.0.harvest(),
|
||||
|
||||
@@ -3,6 +3,7 @@ use std::ops::Deref;
|
||||
|
||||
use super::{Collector, SegmentCollector};
|
||||
use crate::collector::Fruit;
|
||||
use crate::schema::Schema;
|
||||
use crate::{DocId, Score, SegmentOrdinal, SegmentReader, TantivyError};
|
||||
|
||||
/// MultiFruit keeps Fruits from every nested Collector
|
||||
@@ -16,6 +17,10 @@ impl<TCollector: Collector> Collector for CollectorWrapper<TCollector> {
|
||||
type Fruit = Box<dyn Fruit>;
|
||||
type Child = Box<dyn BoxableSegmentCollector>;
|
||||
|
||||
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
|
||||
self.0.check_schema(schema)
|
||||
}
|
||||
|
||||
fn for_segment(
|
||||
&self,
|
||||
segment_local_id: u32,
|
||||
@@ -147,7 +152,7 @@ impl<TFruit: Fruit> FruitHandle<TFruit> {
|
||||
/// let searcher = reader.searcher();
|
||||
///
|
||||
/// let mut collectors = MultiCollector::new();
|
||||
/// let top_docs_handle = collectors.add_collector(TopDocs::with_limit(2));
|
||||
/// let top_docs_handle = collectors.add_collector(TopDocs::with_limit(2).order_by_score());
|
||||
/// let count_handle = collectors.add_collector(Count);
|
||||
/// let query_parser = QueryParser::for_index(&index, vec![title]);
|
||||
/// let query = query_parser.parse_query("diary").unwrap();
|
||||
@@ -194,6 +199,13 @@ impl Collector for MultiCollector<'_> {
|
||||
type Fruit = MultiFruit;
|
||||
type Child = MultiCollectorChild;
|
||||
|
||||
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
|
||||
for collector in &self.collector_wrappers {
|
||||
collector.check_schema(schema)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn for_segment(
|
||||
&self,
|
||||
segment_local_id: SegmentOrdinal,
|
||||
@@ -250,6 +262,12 @@ impl SegmentCollector for MultiCollectorChild {
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_block(&mut self, docs: &[DocId]) {
|
||||
for child in &mut self.children {
|
||||
child.collect_block(docs);
|
||||
}
|
||||
}
|
||||
|
||||
fn harvest(self) -> MultiFruit {
|
||||
MultiFruit {
|
||||
sub_fruits: self
|
||||
@@ -293,7 +311,7 @@ mod tests {
|
||||
let query = TermQuery::new(term, IndexRecordOption::Basic);
|
||||
|
||||
let mut collectors = MultiCollector::new();
|
||||
let topdocs_handler = collectors.add_collector(TopDocs::with_limit(2));
|
||||
let topdocs_handler = collectors.add_collector(TopDocs::with_limit(2).order_by_score());
|
||||
let count_handler = collectors.add_collector(Count);
|
||||
let mut multifruits = searcher.search(&query, &collectors).unwrap();
|
||||
|
||||
|
||||
393
src/collector/sort_key/mod.rs
Normal file
393
src/collector/sort_key/mod.rs
Normal file
@@ -0,0 +1,393 @@
|
||||
mod order;
|
||||
mod sort_by_score;
|
||||
mod sort_by_static_fast_value;
|
||||
mod sort_by_string;
|
||||
mod sort_key_computer;
|
||||
|
||||
pub use order::*;
|
||||
pub use sort_by_score::SortBySimilarityScore;
|
||||
pub use sort_by_static_fast_value::SortByStaticFastValue;
|
||||
pub use sort_by_string::SortByString;
|
||||
pub use sort_key_computer::{SegmentSortKeyComputer, SortKeyComputer};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::HashMap;
|
||||
use std::ops::Range;
|
||||
|
||||
use crate::collector::sort_key::{SortBySimilarityScore, SortByStaticFastValue, SortByString};
|
||||
use crate::collector::{ComparableDoc, DocSetCollector, TopDocs};
|
||||
use crate::indexer::NoMergePolicy;
|
||||
use crate::query::{AllQuery, QueryParser};
|
||||
use crate::schema::{Schema, FAST, TEXT};
|
||||
use crate::{DocAddress, Document, Index, Order, Score, Searcher};
|
||||
|
||||
fn make_index() -> crate::Result<Index> {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let id = schema_builder.add_u64_field("id", FAST);
|
||||
let city = schema_builder.add_text_field("city", TEXT | FAST);
|
||||
let catchphrase = schema_builder.add_text_field("catchphrase", TEXT);
|
||||
let altitude = schema_builder.add_f64_field("altitude", FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
|
||||
fn create_segment(index: &Index, docs: Vec<impl Document>) -> crate::Result<()> {
|
||||
let mut index_writer = index.writer_for_tests()?;
|
||||
index_writer.set_merge_policy(Box::new(NoMergePolicy));
|
||||
for doc in docs {
|
||||
index_writer.add_document(doc)?;
|
||||
}
|
||||
index_writer.commit()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
create_segment(
|
||||
&index,
|
||||
vec![
|
||||
doc!(
|
||||
id => 0_u64,
|
||||
city => "austin",
|
||||
catchphrase => "Hills, Barbeque, Glow",
|
||||
altitude => 149.0,
|
||||
),
|
||||
doc!(
|
||||
id => 1_u64,
|
||||
city => "greenville",
|
||||
catchphrase => "Grow, Glow, Glow",
|
||||
altitude => 27.0,
|
||||
),
|
||||
],
|
||||
)?;
|
||||
create_segment(
|
||||
&index,
|
||||
vec![doc!(
|
||||
id => 2_u64,
|
||||
city => "tokyo",
|
||||
catchphrase => "Glow, Glow, Glow",
|
||||
altitude => 40.0,
|
||||
)],
|
||||
)?;
|
||||
create_segment(
|
||||
&index,
|
||||
vec![doc!(
|
||||
id => 3_u64,
|
||||
catchphrase => "No, No, No",
|
||||
altitude => 0.0,
|
||||
)],
|
||||
)?;
|
||||
Ok(index)
|
||||
}
|
||||
|
||||
// NOTE: You cannot determine the SegmentIds that will be generated for Segments
|
||||
// ahead of time, so DocAddresses must be mapped back to a unique id for each Searcher.
|
||||
fn id_mapping(searcher: &Searcher) -> HashMap<DocAddress, u64> {
|
||||
searcher
|
||||
.search(&AllQuery, &DocSetCollector)
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|doc_address| {
|
||||
let column = searcher.segment_readers()[doc_address.segment_ord as usize]
|
||||
.fast_fields()
|
||||
.u64("id")
|
||||
.unwrap();
|
||||
(doc_address, column.first(doc_address.doc_id).unwrap())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_order_by_string() -> crate::Result<()> {
|
||||
let index = make_index()?;
|
||||
|
||||
#[track_caller]
|
||||
fn assert_query(
|
||||
index: &Index,
|
||||
order: Order,
|
||||
doc_range: Range<usize>,
|
||||
expected: Vec<(Option<String>, u64)>,
|
||||
) -> crate::Result<()> {
|
||||
let searcher = index.reader()?.searcher();
|
||||
let ids = id_mapping(&searcher);
|
||||
|
||||
// Try as primitive.
|
||||
let top_collector = TopDocs::for_doc_range(doc_range)
|
||||
.order_by((SortByString::for_field("city"), order));
|
||||
let actual = searcher
|
||||
.search(&AllQuery, &top_collector)?
|
||||
.into_iter()
|
||||
.map(|(sort_key_opt, doc)| (sort_key_opt, ids[&doc]))
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(actual, expected);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
assert_query(
|
||||
&index,
|
||||
Order::Asc,
|
||||
0..4,
|
||||
vec![
|
||||
(Some("austin".to_owned()), 0),
|
||||
(Some("greenville".to_owned()), 1),
|
||||
(Some("tokyo".to_owned()), 2),
|
||||
(None, 3),
|
||||
],
|
||||
)?;
|
||||
|
||||
assert_query(
|
||||
&index,
|
||||
Order::Asc,
|
||||
0..3,
|
||||
vec![
|
||||
(Some("austin".to_owned()), 0),
|
||||
(Some("greenville".to_owned()), 1),
|
||||
(Some("tokyo".to_owned()), 2),
|
||||
],
|
||||
)?;
|
||||
|
||||
assert_query(
|
||||
&index,
|
||||
Order::Asc,
|
||||
0..2,
|
||||
vec![
|
||||
(Some("austin".to_owned()), 0),
|
||||
(Some("greenville".to_owned()), 1),
|
||||
],
|
||||
)?;
|
||||
|
||||
assert_query(
|
||||
&index,
|
||||
Order::Asc,
|
||||
0..1,
|
||||
vec![(Some("austin".to_string()), 0)],
|
||||
)?;
|
||||
|
||||
assert_query(
|
||||
&index,
|
||||
Order::Asc,
|
||||
1..3,
|
||||
vec![
|
||||
(Some("greenville".to_owned()), 1),
|
||||
(Some("tokyo".to_owned()), 2),
|
||||
],
|
||||
)?;
|
||||
|
||||
assert_query(
|
||||
&index,
|
||||
Order::Desc,
|
||||
0..4,
|
||||
vec![
|
||||
(Some("tokyo".to_owned()), 2),
|
||||
(Some("greenville".to_owned()), 1),
|
||||
(Some("austin".to_owned()), 0),
|
||||
(None, 3),
|
||||
],
|
||||
)?;
|
||||
|
||||
assert_query(
|
||||
&index,
|
||||
Order::Desc,
|
||||
1..3,
|
||||
vec![
|
||||
(Some("greenville".to_owned()), 1),
|
||||
(Some("austin".to_owned()), 0),
|
||||
],
|
||||
)?;
|
||||
|
||||
assert_query(
|
||||
&index,
|
||||
Order::Desc,
|
||||
0..1,
|
||||
vec![(Some("tokyo".to_owned()), 2)],
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_order_by_f64() -> crate::Result<()> {
|
||||
let index = make_index()?;
|
||||
|
||||
fn assert_query(
|
||||
index: &Index,
|
||||
order: Order,
|
||||
expected: Vec<(Option<f64>, u64)>,
|
||||
) -> crate::Result<()> {
|
||||
let searcher = index.reader()?.searcher();
|
||||
let ids = id_mapping(&searcher);
|
||||
|
||||
// Try as primitive.
|
||||
let top_collector = TopDocs::with_limit(3)
|
||||
.order_by((SortByStaticFastValue::<f64>::for_field("altitude"), order));
|
||||
let actual = searcher
|
||||
.search(&AllQuery, &top_collector)?
|
||||
.into_iter()
|
||||
.map(|(altitude_opt, doc)| (altitude_opt, ids[&doc]))
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(actual, expected);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
assert_query(
|
||||
&index,
|
||||
Order::Asc,
|
||||
vec![(Some(0.0), 3), (Some(27.0), 1), (Some(40.0), 2)],
|
||||
)?;
|
||||
|
||||
assert_query(
|
||||
&index,
|
||||
Order::Desc,
|
||||
vec![(Some(149.0), 0), (Some(40.0), 2), (Some(27.0), 1)],
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_order_by_score() -> crate::Result<()> {
|
||||
let index = make_index()?;
|
||||
|
||||
fn query(index: &Index, order: Order) -> crate::Result<Vec<(Score, u64)>> {
|
||||
let searcher = index.reader()?.searcher();
|
||||
let ids = id_mapping(&searcher);
|
||||
|
||||
let top_collector = TopDocs::with_limit(4).order_by((SortBySimilarityScore, order));
|
||||
let field = index.schema().get_field("catchphrase").unwrap();
|
||||
let query_parser = QueryParser::for_index(index, vec![field]);
|
||||
let text_query = query_parser.parse_query("glow")?;
|
||||
|
||||
Ok(searcher
|
||||
.search(&text_query, &top_collector)?
|
||||
.into_iter()
|
||||
.map(|(score, doc)| (score, ids[&doc]))
|
||||
.collect())
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
&query(&index, Order::Desc)?,
|
||||
&[(0.5604893, 2), (0.4904281, 1), (0.35667497, 0),]
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
&query(&index, Order::Asc)?,
|
||||
&[(0.35667497, 0), (0.4904281, 1), (0.5604893, 2),]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_order_by_score_then_string() -> crate::Result<()> {
|
||||
let index = make_index()?;
|
||||
|
||||
type SortKey = (Score, Option<String>);
|
||||
|
||||
fn query(
|
||||
index: &Index,
|
||||
score_order: Order,
|
||||
city_order: Order,
|
||||
) -> crate::Result<Vec<(SortKey, u64)>> {
|
||||
let searcher = index.reader()?.searcher();
|
||||
let ids = id_mapping(&searcher);
|
||||
|
||||
let top_collector = TopDocs::with_limit(4).order_by((
|
||||
(SortBySimilarityScore, score_order),
|
||||
(SortByString::for_field("city"), city_order),
|
||||
));
|
||||
Ok(searcher
|
||||
.search(&AllQuery, &top_collector)?
|
||||
.into_iter()
|
||||
.map(|(f, doc)| (f, ids[&doc]))
|
||||
.collect())
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
&query(&index, Order::Asc, Order::Asc)?,
|
||||
&[
|
||||
((1.0, Some("austin".to_owned())), 0),
|
||||
((1.0, Some("greenville".to_owned())), 1),
|
||||
((1.0, Some("tokyo".to_owned())), 2),
|
||||
((1.0, None), 3),
|
||||
]
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
&query(&index, Order::Asc, Order::Desc)?,
|
||||
&[
|
||||
((1.0, Some("tokyo".to_owned())), 2),
|
||||
((1.0, Some("greenville".to_owned())), 1),
|
||||
((1.0, Some("austin".to_owned())), 0),
|
||||
((1.0, None), 3),
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
use proptest::prelude::*;
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn test_order_by_string_prop(
|
||||
order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
|
||||
limit in 1..64_usize,
|
||||
offset in 0..64_usize,
|
||||
segments_terms in
|
||||
proptest::collection::vec(
|
||||
proptest::collection::vec(0..32_u8, 1..32_usize),
|
||||
0..8_usize,
|
||||
)
|
||||
) {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let city = schema_builder.add_text_field("city", TEXT | FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut index_writer = index.writer_for_tests()?;
|
||||
|
||||
// A Vec<Vec<u8>>, where the outer Vec represents segments, and the inner Vec
|
||||
// represents terms.
|
||||
for segment_terms in segments_terms.into_iter() {
|
||||
for term in segment_terms.into_iter() {
|
||||
let term = format!("{term:0>3}");
|
||||
index_writer.add_document(doc!(
|
||||
city => term,
|
||||
))?;
|
||||
}
|
||||
index_writer.commit()?;
|
||||
}
|
||||
|
||||
let searcher = index.reader()?.searcher();
|
||||
let top_n_results = searcher.search(&AllQuery, &TopDocs::with_limit(limit)
|
||||
.and_offset(offset)
|
||||
.order_by_string_fast_field("city", order))?;
|
||||
let all_results = searcher.search(&AllQuery, &DocSetCollector)?.into_iter().map(|doc_address| {
|
||||
// Get the term for this address.
|
||||
let column = searcher.segment_readers()[doc_address.segment_ord as usize].fast_fields().str("city").unwrap().unwrap();
|
||||
let value = column.term_ords(doc_address.doc_id).next().map(|term_ord| {
|
||||
let mut city = Vec::new();
|
||||
column.dictionary().ord_to_term(term_ord, &mut city).unwrap();
|
||||
String::try_from(city).unwrap()
|
||||
});
|
||||
(value, doc_address)
|
||||
});
|
||||
|
||||
// Using the TopDocs collector should always be equivalent to sorting, skipping the
|
||||
// offset, and then taking the limit.
|
||||
let sorted_docs: Vec<_> = if order.is_desc() {
|
||||
let mut comparable_docs: Vec<ComparableDoc<_, _, true>> =
|
||||
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
|
||||
comparable_docs.sort();
|
||||
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
|
||||
} else {
|
||||
let mut comparable_docs: Vec<ComparableDoc<_, _, false>> =
|
||||
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
|
||||
comparable_docs.sort();
|
||||
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
|
||||
};
|
||||
let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::<Vec<_>>();
|
||||
prop_assert_eq!(
|
||||
expected_docs,
|
||||
top_n_results
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
348
src/collector/sort_key/order.rs
Normal file
348
src/collector/sort_key/order.rs
Normal file
@@ -0,0 +1,348 @@
|
||||
use std::cmp::Ordering;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
|
||||
use crate::schema::Schema;
|
||||
use crate::{DocId, Order, Score};
|
||||
|
||||
/// Comparator trait defining the order in which documents should be ordered.
|
||||
pub trait Comparator<T>: Send + Sync + std::fmt::Debug + Default {
|
||||
/// Return the order between two values.
|
||||
fn compare(&self, lhs: &T, rhs: &T) -> Ordering;
|
||||
}
|
||||
|
||||
/// With the natural comparator, the top k collector will return
|
||||
/// the top documents in decreasing order.
|
||||
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct NaturalComparator;
|
||||
|
||||
impl<T: PartialOrd> Comparator<T> for NaturalComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
|
||||
lhs.partial_cmp(rhs).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Sorts document in reverse order.
|
||||
///
|
||||
/// If the sort key is None, it will considered as the lowest value, and will therefore appear
|
||||
/// first.
|
||||
///
|
||||
/// The ReverseComparator does not necessarily imply that the sort order is reversed compared
|
||||
/// to the NaturalComparator. In presence of a tie, both version will retain the higher doc ids.
|
||||
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct ReverseComparator;
|
||||
|
||||
impl<T> Comparator<T> for ReverseComparator
|
||||
where NaturalComparator: Comparator<T>
|
||||
{
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
|
||||
NaturalComparator.compare(rhs, lhs)
|
||||
}
|
||||
}
|
||||
|
||||
/// Sorts document in reverse order, but considers None as having the lowest value.
|
||||
///
|
||||
/// This is usually what is wanted when sorting by a field in an ascending order.
|
||||
/// For instance, in a e-commerce website, if I sort by price ascending, I most likely want the
|
||||
/// cheapest items first, and the items without a price at last.
|
||||
#[derive(Debug, Copy, Clone, Default)]
|
||||
pub struct ReverseNoneIsLowerComparator;
|
||||
|
||||
impl<T> Comparator<Option<T>> for ReverseNoneIsLowerComparator
|
||||
where ReverseComparator: Comparator<T>
|
||||
{
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs_opt: &Option<T>, rhs_opt: &Option<T>) -> Ordering {
|
||||
match (lhs_opt, rhs_opt) {
|
||||
(None, None) => Ordering::Equal,
|
||||
(None, Some(_)) => Ordering::Less,
|
||||
(Some(_), None) => Ordering::Greater,
|
||||
(Some(lhs), Some(rhs)) => ReverseComparator.compare(lhs, rhs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<u32> for ReverseNoneIsLowerComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &u32, rhs: &u32) -> Ordering {
|
||||
ReverseComparator.compare(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<u64> for ReverseNoneIsLowerComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &u64, rhs: &u64) -> Ordering {
|
||||
ReverseComparator.compare(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<f64> for ReverseNoneIsLowerComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &f64, rhs: &f64) -> Ordering {
|
||||
ReverseComparator.compare(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<f32> for ReverseNoneIsLowerComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &f32, rhs: &f32) -> Ordering {
|
||||
ReverseComparator.compare(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<i64> for ReverseNoneIsLowerComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &i64, rhs: &i64) -> Ordering {
|
||||
ReverseComparator.compare(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<String> for ReverseNoneIsLowerComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &String, rhs: &String) -> Ordering {
|
||||
ReverseComparator.compare(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
/// An enum representing the different sort orders.
|
||||
#[derive(Debug, Clone, Copy, Eq, PartialEq, Default)]
|
||||
pub enum ComparatorEnum {
|
||||
/// Natural order (See [NaturalComparator])
|
||||
#[default]
|
||||
Natural,
|
||||
/// Reverse order (See [ReverseComparator])
|
||||
Reverse,
|
||||
/// Reverse order by treating None as the lowest value.(See [ReverseNoneLowerComparator])
|
||||
ReverseNoneLower,
|
||||
}
|
||||
|
||||
impl From<Order> for ComparatorEnum {
|
||||
fn from(order: Order) -> Self {
|
||||
match order {
|
||||
Order::Asc => ComparatorEnum::ReverseNoneLower,
|
||||
Order::Desc => ComparatorEnum::Natural,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Comparator<T> for ComparatorEnum
|
||||
where
|
||||
ReverseNoneIsLowerComparator: Comparator<T>,
|
||||
NaturalComparator: Comparator<T>,
|
||||
ReverseComparator: Comparator<T>,
|
||||
{
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
|
||||
match self {
|
||||
ComparatorEnum::Natural => NaturalComparator.compare(lhs, rhs),
|
||||
ComparatorEnum::Reverse => ReverseComparator.compare(lhs, rhs),
|
||||
ComparatorEnum::ReverseNoneLower => ReverseNoneIsLowerComparator.compare(lhs, rhs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Head, Tail, LeftComparator, RightComparator> Comparator<(Head, Tail)>
|
||||
for (LeftComparator, RightComparator)
|
||||
where
|
||||
LeftComparator: Comparator<Head>,
|
||||
RightComparator: Comparator<Tail>,
|
||||
{
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &(Head, Tail), rhs: &(Head, Tail)) -> Ordering {
|
||||
self.0
|
||||
.compare(&lhs.0, &rhs.0)
|
||||
.then_with(|| self.1.compare(&lhs.1, &rhs.1))
|
||||
}
|
||||
}
|
||||
|
||||
impl<Type1, Type2, Type3, Comparator1, Comparator2, Comparator3> Comparator<(Type1, (Type2, Type3))>
|
||||
for (Comparator1, Comparator2, Comparator3)
|
||||
where
|
||||
Comparator1: Comparator<Type1>,
|
||||
Comparator2: Comparator<Type2>,
|
||||
Comparator3: Comparator<Type3>,
|
||||
{
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &(Type1, (Type2, Type3)), rhs: &(Type1, (Type2, Type3))) -> Ordering {
|
||||
self.0
|
||||
.compare(&lhs.0, &rhs.0)
|
||||
.then_with(|| self.1.compare(&lhs.1 .0, &rhs.1 .0))
|
||||
.then_with(|| self.2.compare(&lhs.1 .1, &rhs.1 .1))
|
||||
}
|
||||
}
|
||||
|
||||
impl<Type1, Type2, Type3, Comparator1, Comparator2, Comparator3> Comparator<(Type1, Type2, Type3)>
|
||||
for (Comparator1, Comparator2, Comparator3)
|
||||
where
|
||||
Comparator1: Comparator<Type1>,
|
||||
Comparator2: Comparator<Type2>,
|
||||
Comparator3: Comparator<Type3>,
|
||||
{
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &(Type1, Type2, Type3), rhs: &(Type1, Type2, Type3)) -> Ordering {
|
||||
self.0
|
||||
.compare(&lhs.0, &rhs.0)
|
||||
.then_with(|| self.1.compare(&lhs.1, &rhs.1))
|
||||
.then_with(|| self.2.compare(&lhs.2, &rhs.2))
|
||||
}
|
||||
}
|
||||
|
||||
impl<Type1, Type2, Type3, Type4, Comparator1, Comparator2, Comparator3, Comparator4>
|
||||
Comparator<(Type1, (Type2, (Type3, Type4)))>
|
||||
for (Comparator1, Comparator2, Comparator3, Comparator4)
|
||||
where
|
||||
Comparator1: Comparator<Type1>,
|
||||
Comparator2: Comparator<Type2>,
|
||||
Comparator3: Comparator<Type3>,
|
||||
Comparator4: Comparator<Type4>,
|
||||
{
|
||||
#[inline(always)]
|
||||
fn compare(
|
||||
&self,
|
||||
lhs: &(Type1, (Type2, (Type3, Type4))),
|
||||
rhs: &(Type1, (Type2, (Type3, Type4))),
|
||||
) -> Ordering {
|
||||
self.0
|
||||
.compare(&lhs.0, &rhs.0)
|
||||
.then_with(|| self.1.compare(&lhs.1 .0, &rhs.1 .0))
|
||||
.then_with(|| self.2.compare(&lhs.1 .1 .0, &rhs.1 .1 .0))
|
||||
.then_with(|| self.3.compare(&lhs.1 .1 .1, &rhs.1 .1 .1))
|
||||
}
|
||||
}
|
||||
|
||||
impl<Type1, Type2, Type3, Type4, Comparator1, Comparator2, Comparator3, Comparator4>
|
||||
Comparator<(Type1, Type2, Type3, Type4)>
|
||||
for (Comparator1, Comparator2, Comparator3, Comparator4)
|
||||
where
|
||||
Comparator1: Comparator<Type1>,
|
||||
Comparator2: Comparator<Type2>,
|
||||
Comparator3: Comparator<Type3>,
|
||||
Comparator4: Comparator<Type4>,
|
||||
{
|
||||
#[inline(always)]
|
||||
fn compare(
|
||||
&self,
|
||||
lhs: &(Type1, Type2, Type3, Type4),
|
||||
rhs: &(Type1, Type2, Type3, Type4),
|
||||
) -> Ordering {
|
||||
self.0
|
||||
.compare(&lhs.0, &rhs.0)
|
||||
.then_with(|| self.1.compare(&lhs.1, &rhs.1))
|
||||
.then_with(|| self.2.compare(&lhs.2, &rhs.2))
|
||||
.then_with(|| self.3.compare(&lhs.3, &rhs.3))
|
||||
}
|
||||
}
|
||||
|
||||
impl<TSortKeyComputer> SortKeyComputer for (TSortKeyComputer, ComparatorEnum)
|
||||
where
|
||||
TSortKeyComputer: SortKeyComputer,
|
||||
ComparatorEnum: Comparator<TSortKeyComputer::SortKey>,
|
||||
ComparatorEnum: Comparator<
|
||||
<<TSortKeyComputer as SortKeyComputer>::Child as SegmentSortKeyComputer>::SegmentSortKey,
|
||||
>,
|
||||
{
|
||||
type SortKey = TSortKeyComputer::SortKey;
|
||||
|
||||
type Child = SegmentSortKeyComputerWithComparator<TSortKeyComputer::Child, Self::Comparator>;
|
||||
|
||||
type Comparator = ComparatorEnum;
|
||||
|
||||
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
|
||||
self.0.check_schema(schema)
|
||||
}
|
||||
|
||||
fn requires_scoring(&self) -> bool {
|
||||
self.0.requires_scoring()
|
||||
}
|
||||
|
||||
fn comparator(&self) -> Self::Comparator {
|
||||
self.1
|
||||
}
|
||||
|
||||
fn segment_sort_key_computer(
|
||||
&self,
|
||||
segment_reader: &crate::SegmentReader,
|
||||
) -> crate::Result<Self::Child> {
|
||||
let child = self.0.segment_sort_key_computer(segment_reader)?;
|
||||
Ok(SegmentSortKeyComputerWithComparator {
|
||||
segment_sort_key_computer: child,
|
||||
comparator: self.comparator(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<TSortKeyComputer> SortKeyComputer for (TSortKeyComputer, Order)
|
||||
where
|
||||
TSortKeyComputer: SortKeyComputer,
|
||||
ComparatorEnum: Comparator<TSortKeyComputer::SortKey>,
|
||||
ComparatorEnum: Comparator<
|
||||
<<TSortKeyComputer as SortKeyComputer>::Child as SegmentSortKeyComputer>::SegmentSortKey,
|
||||
>,
|
||||
{
|
||||
type SortKey = TSortKeyComputer::SortKey;
|
||||
|
||||
type Child = SegmentSortKeyComputerWithComparator<TSortKeyComputer::Child, Self::Comparator>;
|
||||
|
||||
type Comparator = ComparatorEnum;
|
||||
|
||||
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
|
||||
self.0.check_schema(schema)
|
||||
}
|
||||
|
||||
fn requires_scoring(&self) -> bool {
|
||||
self.0.requires_scoring()
|
||||
}
|
||||
|
||||
fn comparator(&self) -> Self::Comparator {
|
||||
self.1.into()
|
||||
}
|
||||
|
||||
fn segment_sort_key_computer(
|
||||
&self,
|
||||
segment_reader: &crate::SegmentReader,
|
||||
) -> crate::Result<Self::Child> {
|
||||
let child = self.0.segment_sort_key_computer(segment_reader)?;
|
||||
Ok(SegmentSortKeyComputerWithComparator {
|
||||
segment_sort_key_computer: child,
|
||||
comparator: self.comparator(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A segment sort key computer with a custom ordering.
|
||||
pub struct SegmentSortKeyComputerWithComparator<TSegmentSortKeyComputer, TComparator> {
|
||||
segment_sort_key_computer: TSegmentSortKeyComputer,
|
||||
comparator: TComparator,
|
||||
}
|
||||
|
||||
impl<TSegmentSortKeyComputer, TSegmentSortKey, TComparator> SegmentSortKeyComputer
|
||||
for SegmentSortKeyComputerWithComparator<TSegmentSortKeyComputer, TComparator>
|
||||
where
|
||||
TSegmentSortKeyComputer: SegmentSortKeyComputer<SegmentSortKey = TSegmentSortKey>,
|
||||
TSegmentSortKey: PartialOrd + Clone + 'static + Sync + Send,
|
||||
TComparator: Comparator<TSegmentSortKey> + 'static + Sync + Send,
|
||||
{
|
||||
type SortKey = TSegmentSortKeyComputer::SortKey;
|
||||
type SegmentSortKey = TSegmentSortKey;
|
||||
|
||||
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
|
||||
self.segment_sort_key_computer.segment_sort_key(doc, score)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn compare_segment_sort_key(
|
||||
&self,
|
||||
left: &Self::SegmentSortKey,
|
||||
right: &Self::SegmentSortKey,
|
||||
) -> Ordering {
|
||||
self.comparator.compare(left, right)
|
||||
}
|
||||
|
||||
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
|
||||
self.segment_sort_key_computer
|
||||
.convert_segment_sort_key(sort_key)
|
||||
}
|
||||
}
|
||||
77
src/collector/sort_key/sort_by_score.rs
Normal file
77
src/collector/sort_key/sort_by_score.rs
Normal file
@@ -0,0 +1,77 @@
|
||||
use crate::collector::sort_key::NaturalComparator;
|
||||
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer, TopNComputer};
|
||||
use crate::{DocAddress, DocId, Score};
|
||||
|
||||
/// Sort by similarity score.
|
||||
#[derive(Clone, Debug, Copy)]
|
||||
pub struct SortBySimilarityScore;
|
||||
|
||||
impl SortKeyComputer for SortBySimilarityScore {
|
||||
type SortKey = Score;
|
||||
|
||||
type Child = SortBySimilarityScore;
|
||||
|
||||
type Comparator = NaturalComparator;
|
||||
|
||||
fn requires_scoring(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn segment_sort_key_computer(
|
||||
&self,
|
||||
_segment_reader: &crate::SegmentReader,
|
||||
) -> crate::Result<Self::Child> {
|
||||
Ok(SortBySimilarityScore)
|
||||
}
|
||||
|
||||
// Sorting by score is special in that it allows for the Block-Wand optimization.
|
||||
fn collect_segment_top_k(
|
||||
&self,
|
||||
k: usize,
|
||||
weight: &dyn crate::query::Weight,
|
||||
reader: &crate::SegmentReader,
|
||||
segment_ord: u32,
|
||||
) -> crate::Result<Vec<(Self::SortKey, DocAddress)>> {
|
||||
let mut top_n: TopNComputer<Score, DocId, Self::Comparator> =
|
||||
TopNComputer::new_with_comparator(k, self.comparator());
|
||||
|
||||
if let Some(alive_bitset) = reader.alive_bitset() {
|
||||
let mut threshold = Score::MIN;
|
||||
top_n.threshold = Some(threshold);
|
||||
weight.for_each_pruning(Score::MIN, reader, &mut |doc, score| {
|
||||
if alive_bitset.is_deleted(doc) {
|
||||
return threshold;
|
||||
}
|
||||
top_n.push(score, doc);
|
||||
threshold = top_n.threshold.unwrap_or(Score::MIN);
|
||||
threshold
|
||||
})?;
|
||||
} else {
|
||||
weight.for_each_pruning(Score::MIN, reader, &mut |doc, score| {
|
||||
top_n.push(score, doc);
|
||||
top_n.threshold.unwrap_or(Score::MIN)
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(top_n
|
||||
.into_vec()
|
||||
.into_iter()
|
||||
.map(|cid| (cid.sort_key, DocAddress::new(segment_ord, cid.doc)))
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
impl SegmentSortKeyComputer for SortBySimilarityScore {
|
||||
type SortKey = Score;
|
||||
|
||||
type SegmentSortKey = Score;
|
||||
|
||||
#[inline(always)]
|
||||
fn segment_sort_key(&mut self, _doc: DocId, score: Score) -> Score {
|
||||
score
|
||||
}
|
||||
|
||||
fn convert_segment_sort_key(&self, score: Score) -> Score {
|
||||
score
|
||||
}
|
||||
}
|
||||
98
src/collector/sort_key/sort_by_static_fast_value.rs
Normal file
98
src/collector/sort_key/sort_by_static_fast_value.rs
Normal file
@@ -0,0 +1,98 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use columnar::Column;
|
||||
|
||||
use crate::collector::sort_key::NaturalComparator;
|
||||
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
|
||||
use crate::fastfield::{FastFieldNotAvailableError, FastValue};
|
||||
use crate::{DocId, Score, SegmentReader};
|
||||
|
||||
/// Sorts by a fast value (u64, i64, f64, bool).
|
||||
///
|
||||
/// The field must appear explicitly in the schema, with the right type, and declared as
|
||||
/// a fast field..
|
||||
///
|
||||
/// If the field is multivalued, only the first value is considered.
|
||||
///
|
||||
/// Documents that do not have this value are still considered.
|
||||
/// Their sort key will simply be `None`.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SortByStaticFastValue<T: FastValue> {
|
||||
field: String,
|
||||
typ: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: FastValue> SortByStaticFastValue<T> {
|
||||
/// Creates a new `SortByStaticFastValue` instance for the given field.
|
||||
pub fn for_field(column_name: impl ToString) -> SortByStaticFastValue<T> {
|
||||
Self {
|
||||
field: column_name.to_string(),
|
||||
typ: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FastValue> SortKeyComputer for SortByStaticFastValue<T> {
|
||||
type Child = SortByFastValueSegmentSortKeyComputer<T>;
|
||||
|
||||
type SortKey = Option<T>;
|
||||
|
||||
type Comparator = NaturalComparator;
|
||||
|
||||
fn check_schema(&self, schema: &crate::schema::Schema) -> crate::Result<()> {
|
||||
// At the segment sort key computer level, we rely on the u64 representation.
|
||||
// The mapping is monotonic, so it is sufficient to compute our top-K docs.
|
||||
let field = schema.get_field(&self.field)?;
|
||||
let field_entry = schema.get_field_entry(field);
|
||||
if !field_entry.is_fast() {
|
||||
return Err(crate::TantivyError::SchemaError(format!(
|
||||
"Field `{}` is not a fast field.",
|
||||
self.field,
|
||||
)));
|
||||
}
|
||||
let schema_type = field_entry.field_type().value_type();
|
||||
if schema_type != T::to_type() {
|
||||
return Err(crate::TantivyError::SchemaError(format!(
|
||||
"Field `{}` is of type {schema_type:?}, not of the type {:?}.",
|
||||
&self.field,
|
||||
T::to_type()
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn segment_sort_key_computer(
|
||||
&self,
|
||||
segment_reader: &SegmentReader,
|
||||
) -> crate::Result<Self::Child> {
|
||||
let sort_column_opt = segment_reader.fast_fields().u64_lenient(&self.field)?;
|
||||
let (sort_column, _sort_column_type) =
|
||||
sort_column_opt.ok_or_else(|| FastFieldNotAvailableError {
|
||||
field_name: self.field.clone(),
|
||||
})?;
|
||||
Ok(SortByFastValueSegmentSortKeyComputer {
|
||||
sort_column,
|
||||
typ: PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SortByFastValueSegmentSortKeyComputer<T> {
|
||||
sort_column: Column<u64>,
|
||||
typ: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: FastValue> SegmentSortKeyComputer for SortByFastValueSegmentSortKeyComputer<T> {
|
||||
type SortKey = Option<T>;
|
||||
|
||||
type SegmentSortKey = Option<u64>;
|
||||
|
||||
#[inline(always)]
|
||||
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Self::SegmentSortKey {
|
||||
self.sort_column.first(doc)
|
||||
}
|
||||
|
||||
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
|
||||
sort_key.map(T::from_u64)
|
||||
}
|
||||
}
|
||||
72
src/collector/sort_key/sort_by_string.rs
Normal file
72
src/collector/sort_key/sort_by_string.rs
Normal file
@@ -0,0 +1,72 @@
|
||||
use columnar::StrColumn;
|
||||
|
||||
use crate::collector::sort_key::NaturalComparator;
|
||||
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
|
||||
use crate::termdict::TermOrdinal;
|
||||
use crate::{DocId, Score};
|
||||
|
||||
/// Sort by the first value of a string column.
|
||||
///
|
||||
/// The string can be dynamic (coming from a json field)
|
||||
/// or static (being specificaly defined in the configuration).
|
||||
///
|
||||
/// If the field is multivalued, only the first value is considered.
|
||||
///
|
||||
/// Documents that do not have this value are still considered.
|
||||
/// Their sort key will simply be `None`.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SortByString {
|
||||
column_name: String,
|
||||
}
|
||||
|
||||
impl SortByString {
|
||||
/// Creates a new sort by string sort key computer.
|
||||
pub fn for_field(column_name: impl ToString) -> Self {
|
||||
SortByString {
|
||||
column_name: column_name.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SortKeyComputer for SortByString {
|
||||
type SortKey = Option<String>;
|
||||
|
||||
type Child = ByStringColumnSegmentSortKeyComputer;
|
||||
|
||||
type Comparator = NaturalComparator;
|
||||
|
||||
fn segment_sort_key_computer(
|
||||
&self,
|
||||
segment_reader: &crate::SegmentReader,
|
||||
) -> crate::Result<Self::Child> {
|
||||
let str_column_opt = segment_reader.fast_fields().str(&self.column_name)?;
|
||||
Ok(ByStringColumnSegmentSortKeyComputer { str_column_opt })
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ByStringColumnSegmentSortKeyComputer {
|
||||
str_column_opt: Option<StrColumn>,
|
||||
}
|
||||
|
||||
impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
|
||||
type SortKey = Option<String>;
|
||||
|
||||
type SegmentSortKey = Option<TermOrdinal>;
|
||||
|
||||
#[inline(always)]
|
||||
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Option<TermOrdinal> {
|
||||
let str_column = self.str_column_opt.as_ref()?;
|
||||
str_column.ords().first(doc)
|
||||
}
|
||||
|
||||
fn convert_segment_sort_key(&self, term_ord_opt: Option<TermOrdinal>) -> Option<String> {
|
||||
let term_ord = term_ord_opt?;
|
||||
let str_column = self.str_column_opt.as_ref()?;
|
||||
let mut bytes = Vec::new();
|
||||
str_column
|
||||
.dictionary()
|
||||
.ord_to_term(term_ord, &mut bytes)
|
||||
.ok()?;
|
||||
String::try_from(bytes).ok()
|
||||
}
|
||||
}
|
||||
631
src/collector/sort_key/sort_key_computer.rs
Normal file
631
src/collector/sort_key/sort_key_computer.rs
Normal file
@@ -0,0 +1,631 @@
|
||||
use std::cmp::Ordering;
|
||||
|
||||
use crate::collector::sort_key::{Comparator, NaturalComparator};
|
||||
use crate::collector::sort_key_top_collector::TopBySortKeySegmentCollector;
|
||||
use crate::collector::{default_collect_segment_impl, SegmentCollector as _, TopNComputer};
|
||||
use crate::schema::Schema;
|
||||
use crate::{DocAddress, DocId, Result, Score, SegmentReader};
|
||||
|
||||
/// A `SegmentSortKeyComputer` makes it possible to modify the default score
|
||||
/// for a given document belonging to a specific segment.
|
||||
///
|
||||
/// It is the segment local version of the [`SortKeyComputer`].
|
||||
pub trait SegmentSortKeyComputer: 'static {
|
||||
/// The final score being emitted.
|
||||
type SortKey: 'static + PartialOrd + Send + Sync + Clone;
|
||||
|
||||
/// Sort key used by at the segment level by the `SegmentSortKeyComputer`.
|
||||
///
|
||||
/// It is typically small like a `u64`, and is meant to be converted
|
||||
/// to the final score at the end of the collection of the segment.
|
||||
type SegmentSortKey: 'static + PartialOrd + Clone + Send + Sync + Clone;
|
||||
|
||||
/// Computes the sort key for the given document and score.
|
||||
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey;
|
||||
|
||||
/// Computes the sort key and pushes the document in a TopN Computer.
|
||||
///
|
||||
/// When using a tuple as the sorting key, the sort key is evaluated in a lazy manner.
|
||||
#[inline(always)]
|
||||
fn compute_sort_key_and_collect<C: Comparator<Self::SegmentSortKey>>(
|
||||
&mut self,
|
||||
doc: DocId,
|
||||
score: Score,
|
||||
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
|
||||
) {
|
||||
let sort_key = self.segment_sort_key(doc, score);
|
||||
top_n_computer.push(sort_key, doc);
|
||||
}
|
||||
|
||||
/// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on
|
||||
/// its ordering.
|
||||
///
|
||||
/// This method must be consistent with the `SortKey` ordering.
|
||||
#[inline(always)]
|
||||
fn compare_segment_sort_key(
|
||||
&self,
|
||||
left: &Self::SegmentSortKey,
|
||||
right: &Self::SegmentSortKey,
|
||||
) -> Ordering {
|
||||
NaturalComparator.compare(left, right)
|
||||
}
|
||||
|
||||
/// Implementing this method makes it possible to avoid computing
|
||||
/// a sort_key entirely if we can assess that it won't pass a threshold
|
||||
/// with a partial computation.
|
||||
///
|
||||
/// This is currently used for lexicographic sorting.
|
||||
fn accept_sort_key_lazy(
|
||||
&mut self,
|
||||
doc_id: DocId,
|
||||
score: Score,
|
||||
threshold: &Self::SegmentSortKey,
|
||||
) -> Option<(Ordering, Self::SegmentSortKey)> {
|
||||
let sort_key = self.segment_sort_key(doc_id, score);
|
||||
let cmp = self.compare_segment_sort_key(&sort_key, threshold);
|
||||
if cmp == Ordering::Less {
|
||||
None
|
||||
} else {
|
||||
Some((cmp, sort_key))
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a segment level sort key into the global sort key.
|
||||
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey;
|
||||
}
|
||||
|
||||
/// `SortKeyComputer` defines the sort key to be used by a TopK Collector.
|
||||
///
|
||||
/// The `SortKeyComputer` itself does not make much of the computation itself.
|
||||
/// Instead, it helps constructing `Self::Child` instances that will compute
|
||||
/// the sort key at a segment scale.
|
||||
pub trait SortKeyComputer: Sync {
|
||||
/// The sort key type.
|
||||
type SortKey: 'static + Send + Sync + PartialOrd + Clone + std::fmt::Debug;
|
||||
/// Type of the associated [`SegmentSortKeyComputer`].
|
||||
type Child: SegmentSortKeyComputer<SortKey = Self::SortKey>;
|
||||
/// Comparator type.
|
||||
type Comparator: Comparator<Self::SortKey>
|
||||
+ Comparator<<Self::Child as SegmentSortKeyComputer>::SegmentSortKey>
|
||||
+ 'static;
|
||||
|
||||
/// Checks whether the schema is compatible with the sort key computer.
|
||||
fn check_schema(&self, _schema: &Schema) -> crate::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the sort key comparator.
|
||||
fn comparator(&self) -> Self::Comparator {
|
||||
Self::Comparator::default()
|
||||
}
|
||||
|
||||
/// Indicates whether the sort key actually uses the similarity score (by default BM25).
|
||||
/// If set to false, the similary score might not be computed (as an optimization),
|
||||
/// and the score fed in the segment sort key computer could take any value.
|
||||
fn requires_scoring(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
/// Sorting by score has a overriding implementation for BM25 scores, using Block-WAND.
|
||||
fn collect_segment_top_k(
|
||||
&self,
|
||||
k: usize,
|
||||
weight: &dyn crate::query::Weight,
|
||||
reader: &crate::SegmentReader,
|
||||
segment_ord: u32,
|
||||
) -> crate::Result<Vec<(Self::SortKey, DocAddress)>> {
|
||||
let with_scoring = self.requires_scoring();
|
||||
let segment_sort_key_computer = self.segment_sort_key_computer(reader)?;
|
||||
let topn_computer = TopNComputer::new_with_comparator(k, self.comparator());
|
||||
let mut segment_top_key_collector = TopBySortKeySegmentCollector {
|
||||
topn_computer,
|
||||
segment_ord,
|
||||
segment_sort_key_computer,
|
||||
};
|
||||
default_collect_segment_impl(&mut segment_top_key_collector, weight, reader, with_scoring)?;
|
||||
Ok(segment_top_key_collector.harvest())
|
||||
}
|
||||
|
||||
/// Builds a child sort key computer for a specific segment.
|
||||
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child>;
|
||||
}
|
||||
|
||||
impl<HeadSortKeyComputer, TailSortKeyComputer> SortKeyComputer
|
||||
for (HeadSortKeyComputer, TailSortKeyComputer)
|
||||
where
|
||||
HeadSortKeyComputer: SortKeyComputer,
|
||||
TailSortKeyComputer: SortKeyComputer,
|
||||
{
|
||||
type SortKey = (
|
||||
<HeadSortKeyComputer::Child as SegmentSortKeyComputer>::SortKey,
|
||||
<TailSortKeyComputer::Child as SegmentSortKeyComputer>::SortKey,
|
||||
);
|
||||
type Child = (HeadSortKeyComputer::Child, TailSortKeyComputer::Child);
|
||||
|
||||
type Comparator = (
|
||||
HeadSortKeyComputer::Comparator,
|
||||
TailSortKeyComputer::Comparator,
|
||||
);
|
||||
|
||||
fn comparator(&self) -> Self::Comparator {
|
||||
(self.0.comparator(), self.1.comparator())
|
||||
}
|
||||
|
||||
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
|
||||
Ok((
|
||||
self.0.segment_sort_key_computer(segment_reader)?,
|
||||
self.1.segment_sort_key_computer(segment_reader)?,
|
||||
))
|
||||
}
|
||||
|
||||
/// Checks whether the schema is compatible with the sort key computer.
|
||||
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
|
||||
self.0.check_schema(schema)?;
|
||||
self.1.check_schema(schema)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Indicates whether the sort key actually uses the similarity score (by default BM25).
|
||||
/// If set to false, the similary score might not be computed (as an optimization),
|
||||
/// and the score fed in the segment sort key computer could take any value.
|
||||
fn requires_scoring(&self) -> bool {
|
||||
self.0.requires_scoring() || self.1.requires_scoring()
|
||||
}
|
||||
}
|
||||
|
||||
impl<HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer> SegmentSortKeyComputer
|
||||
for (HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer)
|
||||
where
|
||||
HeadSegmentSortKeyComputer: SegmentSortKeyComputer,
|
||||
TailSegmentSortKeyComputer: SegmentSortKeyComputer,
|
||||
{
|
||||
type SortKey = (
|
||||
HeadSegmentSortKeyComputer::SortKey,
|
||||
TailSegmentSortKeyComputer::SortKey,
|
||||
);
|
||||
type SegmentSortKey = (
|
||||
HeadSegmentSortKeyComputer::SegmentSortKey,
|
||||
TailSegmentSortKeyComputer::SegmentSortKey,
|
||||
);
|
||||
|
||||
/// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on
|
||||
/// its ordering.
|
||||
///
|
||||
/// By default, it uses the natural ordering.
|
||||
#[inline]
|
||||
fn compare_segment_sort_key(
|
||||
&self,
|
||||
left: &Self::SegmentSortKey,
|
||||
right: &Self::SegmentSortKey,
|
||||
) -> Ordering {
|
||||
self.0
|
||||
.compare_segment_sort_key(&left.0, &right.0)
|
||||
.then_with(|| self.1.compare_segment_sort_key(&left.1, &right.1))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn compute_sort_key_and_collect<C: Comparator<Self::SegmentSortKey>>(
|
||||
&mut self,
|
||||
doc: DocId,
|
||||
score: Score,
|
||||
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
|
||||
) {
|
||||
let sort_key: Self::SegmentSortKey;
|
||||
if let Some(threshold) = &top_n_computer.threshold {
|
||||
if let Some((_cmp, lazy_sort_key)) = self.accept_sort_key_lazy(doc, score, threshold) {
|
||||
sort_key = lazy_sort_key;
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
sort_key = self.segment_sort_key(doc, score);
|
||||
};
|
||||
top_n_computer.append_doc(doc, sort_key);
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
|
||||
let head_sort_key = self.0.segment_sort_key(doc, score);
|
||||
let tail_sort_key = self.1.segment_sort_key(doc, score);
|
||||
(head_sort_key, tail_sort_key)
|
||||
}
|
||||
|
||||
fn accept_sort_key_lazy(
|
||||
&mut self,
|
||||
doc_id: DocId,
|
||||
score: Score,
|
||||
threshold: &Self::SegmentSortKey,
|
||||
) -> Option<(Ordering, Self::SegmentSortKey)> {
|
||||
let (head_threshold, tail_threshold) = threshold;
|
||||
let (head_cmp, head_sort_key) =
|
||||
self.0.accept_sort_key_lazy(doc_id, score, head_threshold)?;
|
||||
if head_cmp == Ordering::Equal {
|
||||
let (tail_cmp, tail_sort_key) =
|
||||
self.1.accept_sort_key_lazy(doc_id, score, tail_threshold)?;
|
||||
Some((tail_cmp, (head_sort_key, tail_sort_key)))
|
||||
} else {
|
||||
let tail_sort_key = self.1.segment_sort_key(doc_id, score);
|
||||
Some((head_cmp, (head_sort_key, tail_sort_key)))
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
|
||||
let (head_sort_key, tail_sort_key) = sort_key;
|
||||
(
|
||||
self.0.convert_segment_sort_key(head_sort_key),
|
||||
self.1.convert_segment_sort_key(tail_sort_key),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// This struct is used as an adapter to take a sort key computer and map its score to another
|
||||
/// new sort key.
|
||||
pub struct MappedSegmentSortKeyComputer<T, PreviousSortKey, NewSortKey> {
|
||||
sort_key_computer: T,
|
||||
map: fn(PreviousSortKey) -> NewSortKey,
|
||||
}
|
||||
|
||||
impl<T, PreviousScore, NewScore> SegmentSortKeyComputer
|
||||
for MappedSegmentSortKeyComputer<T, PreviousScore, NewScore>
|
||||
where
|
||||
T: SegmentSortKeyComputer<SortKey = PreviousScore>,
|
||||
PreviousScore: 'static + Clone + Send + Sync + PartialOrd,
|
||||
NewScore: 'static + Clone + Send + Sync + PartialOrd,
|
||||
{
|
||||
type SortKey = NewScore;
|
||||
type SegmentSortKey = T::SegmentSortKey;
|
||||
|
||||
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
|
||||
self.sort_key_computer.segment_sort_key(doc, score)
|
||||
}
|
||||
|
||||
fn accept_sort_key_lazy(
|
||||
&mut self,
|
||||
doc_id: DocId,
|
||||
score: Score,
|
||||
threshold: &Self::SegmentSortKey,
|
||||
) -> Option<(Ordering, Self::SegmentSortKey)> {
|
||||
self.sort_key_computer
|
||||
.accept_sort_key_lazy(doc_id, score, threshold)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn compute_sort_key_and_collect<C: Comparator<Self::SegmentSortKey>>(
|
||||
&mut self,
|
||||
doc: DocId,
|
||||
score: Score,
|
||||
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
|
||||
) {
|
||||
self.sort_key_computer
|
||||
.compute_sort_key_and_collect(doc, score, top_n_computer);
|
||||
}
|
||||
|
||||
fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> Self::SortKey {
|
||||
(self.map)(
|
||||
self.sort_key_computer
|
||||
.convert_segment_sort_key(segment_sort_key),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// We then re-use our (head, tail) implement and our mapper by seeing mapping any tuple (a, b, c,
|
||||
// ...) as the chain (a, (b, (c, ...)))
|
||||
|
||||
impl<SortKeyComputer1, SortKeyComputer2, SortKeyComputer3> SortKeyComputer
|
||||
for (SortKeyComputer1, SortKeyComputer2, SortKeyComputer3)
|
||||
where
|
||||
SortKeyComputer1: SortKeyComputer,
|
||||
SortKeyComputer2: SortKeyComputer,
|
||||
SortKeyComputer3: SortKeyComputer,
|
||||
{
|
||||
type SortKey = (
|
||||
SortKeyComputer1::SortKey,
|
||||
SortKeyComputer2::SortKey,
|
||||
SortKeyComputer3::SortKey,
|
||||
);
|
||||
type Child = MappedSegmentSortKeyComputer<
|
||||
<(SortKeyComputer1, (SortKeyComputer2, SortKeyComputer3)) as SortKeyComputer>::Child,
|
||||
(
|
||||
SortKeyComputer1::SortKey,
|
||||
(SortKeyComputer2::SortKey, SortKeyComputer3::SortKey),
|
||||
),
|
||||
Self::SortKey,
|
||||
>;
|
||||
|
||||
type Comparator = (
|
||||
SortKeyComputer1::Comparator,
|
||||
SortKeyComputer2::Comparator,
|
||||
SortKeyComputer3::Comparator,
|
||||
);
|
||||
|
||||
fn comparator(&self) -> Self::Comparator {
|
||||
(
|
||||
self.0.comparator(),
|
||||
self.1.comparator(),
|
||||
self.2.comparator(),
|
||||
)
|
||||
}
|
||||
|
||||
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
|
||||
let sort_key_computer1 = self.0.segment_sort_key_computer(segment_reader)?;
|
||||
let sort_key_computer2 = self.1.segment_sort_key_computer(segment_reader)?;
|
||||
let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?;
|
||||
let map = |(sort_key1, (sort_key2, sort_key3))| (sort_key1, sort_key2, sort_key3);
|
||||
Ok(MappedSegmentSortKeyComputer {
|
||||
sort_key_computer: (sort_key_computer1, (sort_key_computer2, sort_key_computer3)),
|
||||
map,
|
||||
})
|
||||
}
|
||||
|
||||
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
|
||||
self.0.check_schema(schema)?;
|
||||
self.1.check_schema(schema)?;
|
||||
self.2.check_schema(schema)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn requires_scoring(&self) -> bool {
|
||||
self.0.requires_scoring() || self.1.requires_scoring() || self.2.requires_scoring()
|
||||
}
|
||||
}
|
||||
|
||||
impl<SortKeyComputer1, SortKeyComputer2, SortKeyComputer3, SortKeyComputer4> SortKeyComputer
|
||||
for (
|
||||
SortKeyComputer1,
|
||||
SortKeyComputer2,
|
||||
SortKeyComputer3,
|
||||
SortKeyComputer4,
|
||||
)
|
||||
where
|
||||
SortKeyComputer1: SortKeyComputer,
|
||||
SortKeyComputer2: SortKeyComputer,
|
||||
SortKeyComputer3: SortKeyComputer,
|
||||
SortKeyComputer4: SortKeyComputer,
|
||||
{
|
||||
type Child = MappedSegmentSortKeyComputer<
|
||||
<(
|
||||
SortKeyComputer1,
|
||||
(SortKeyComputer2, (SortKeyComputer3, SortKeyComputer4)),
|
||||
) as SortKeyComputer>::Child,
|
||||
(
|
||||
SortKeyComputer1::SortKey,
|
||||
(
|
||||
SortKeyComputer2::SortKey,
|
||||
(SortKeyComputer3::SortKey, SortKeyComputer4::SortKey),
|
||||
),
|
||||
),
|
||||
Self::SortKey,
|
||||
>;
|
||||
type SortKey = (
|
||||
SortKeyComputer1::SortKey,
|
||||
SortKeyComputer2::SortKey,
|
||||
SortKeyComputer3::SortKey,
|
||||
SortKeyComputer4::SortKey,
|
||||
);
|
||||
type Comparator = (
|
||||
SortKeyComputer1::Comparator,
|
||||
SortKeyComputer2::Comparator,
|
||||
SortKeyComputer3::Comparator,
|
||||
SortKeyComputer4::Comparator,
|
||||
);
|
||||
|
||||
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
|
||||
let sort_key_computer1 = self.0.segment_sort_key_computer(segment_reader)?;
|
||||
let sort_key_computer2 = self.1.segment_sort_key_computer(segment_reader)?;
|
||||
let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?;
|
||||
let sort_key_computer4 = self.3.segment_sort_key_computer(segment_reader)?;
|
||||
Ok(MappedSegmentSortKeyComputer {
|
||||
sort_key_computer: (
|
||||
sort_key_computer1,
|
||||
(sort_key_computer2, (sort_key_computer3, sort_key_computer4)),
|
||||
),
|
||||
map: |(sort_key1, (sort_key2, (sort_key3, sort_key4)))| {
|
||||
(sort_key1, sort_key2, sort_key3, sort_key4)
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
|
||||
self.0.check_schema(schema)?;
|
||||
self.1.check_schema(schema)?;
|
||||
self.2.check_schema(schema)?;
|
||||
self.3.check_schema(schema)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn requires_scoring(&self) -> bool {
|
||||
self.0.requires_scoring()
|
||||
|| self.1.requires_scoring()
|
||||
|| self.2.requires_scoring()
|
||||
|| self.3.requires_scoring()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, SegmentF, TSortKey> SortKeyComputer for F
|
||||
where
|
||||
F: 'static + Send + Sync + Fn(&SegmentReader) -> SegmentF,
|
||||
SegmentF: 'static + FnMut(DocId) -> TSortKey,
|
||||
TSortKey: 'static + PartialOrd + Clone + Send + Sync + std::fmt::Debug,
|
||||
{
|
||||
type SortKey = TSortKey;
|
||||
type Child = SegmentF;
|
||||
type Comparator = NaturalComparator;
|
||||
|
||||
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
|
||||
Ok((self)(segment_reader))
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, TSortKey> SegmentSortKeyComputer for F
|
||||
where
|
||||
F: 'static + FnMut(DocId) -> TSortKey,
|
||||
TSortKey: 'static + PartialOrd + Clone + Send + Sync,
|
||||
{
|
||||
type SortKey = TSortKey;
|
||||
type SegmentSortKey = TSortKey;
|
||||
|
||||
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> TSortKey {
|
||||
(self)(doc)
|
||||
}
|
||||
|
||||
/// Convert a segment level score into the global level score.
|
||||
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
|
||||
sort_key
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::cmp::Ordering;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
|
||||
use crate::schema::Schema;
|
||||
use crate::{DocId, Index, Order, SegmentReader};
|
||||
|
||||
fn build_test_index() -> Index {
|
||||
let schema = Schema::builder().build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut index_writer = index.writer_for_tests().unwrap();
|
||||
index_writer
|
||||
.add_document(crate::TantivyDocument::default())
|
||||
.unwrap();
|
||||
index_writer.commit().unwrap();
|
||||
index
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lazy_score_computer() {
|
||||
let score_computer_primary = |_segment_reader: &SegmentReader| |_doc: DocId| 200u32;
|
||||
let call_count = Arc::new(AtomicUsize::new(0));
|
||||
let call_count_clone = call_count.clone();
|
||||
let score_computer_secondary = move |_segment_reader: &SegmentReader| {
|
||||
let call_count_new_clone = call_count_clone.clone();
|
||||
move |_doc: DocId| {
|
||||
call_count_new_clone.fetch_add(1, AtomicOrdering::SeqCst);
|
||||
"b"
|
||||
}
|
||||
};
|
||||
let lazy_score_computer = (score_computer_primary, score_computer_secondary);
|
||||
let index = build_test_index();
|
||||
let searcher = index.reader().unwrap().searcher();
|
||||
let mut segment_sort_key_computer = lazy_score_computer
|
||||
.segment_sort_key_computer(searcher.segment_reader(0))
|
||||
.unwrap();
|
||||
let expected_sort_key = (200, "b");
|
||||
{
|
||||
let sort_key_opt =
|
||||
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(100u32, "a"));
|
||||
assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
|
||||
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 1);
|
||||
}
|
||||
{
|
||||
let sort_key_opt =
|
||||
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(100u32, "c"));
|
||||
assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
|
||||
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 2);
|
||||
}
|
||||
{
|
||||
let sort_key_opt =
|
||||
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(200u32, "a"));
|
||||
assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
|
||||
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 3);
|
||||
}
|
||||
{
|
||||
let sort_key_opt =
|
||||
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(200u32, "c"));
|
||||
assert!(sort_key_opt.is_none());
|
||||
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
|
||||
}
|
||||
{
|
||||
let sort_key_opt =
|
||||
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(300u32, "a"));
|
||||
assert_eq!(sort_key_opt, None);
|
||||
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
|
||||
}
|
||||
{
|
||||
let sort_key_opt =
|
||||
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(300u32, "c"));
|
||||
assert_eq!(sort_key_opt, None);
|
||||
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
|
||||
}
|
||||
{
|
||||
let sort_key_opt =
|
||||
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &expected_sort_key);
|
||||
assert_eq!(sort_key_opt, Some((Ordering::Equal, expected_sort_key)));
|
||||
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 5);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lazy_score_computer_dynamic_ordering() {
|
||||
let score_computer_primary = |_segment_reader: &SegmentReader| |_doc: DocId| 200u32;
|
||||
let call_count = Arc::new(AtomicUsize::new(0));
|
||||
let call_count_clone = call_count.clone();
|
||||
let score_computer_secondary = move |_segment_reader: &SegmentReader| {
|
||||
let call_count_new_clone = call_count_clone.clone();
|
||||
move |_doc: DocId| {
|
||||
call_count_new_clone.fetch_add(1, AtomicOrdering::SeqCst);
|
||||
2u32
|
||||
}
|
||||
};
|
||||
let lazy_score_computer = (
|
||||
(score_computer_primary, Order::Desc),
|
||||
(score_computer_secondary, Order::Asc),
|
||||
);
|
||||
let index = build_test_index();
|
||||
let searcher = index.reader().unwrap().searcher();
|
||||
let mut segment_sort_key_computer = lazy_score_computer
|
||||
.segment_sort_key_computer(searcher.segment_reader(0))
|
||||
.unwrap();
|
||||
let expected_sort_key = (200, 2u32);
|
||||
|
||||
{
|
||||
let sort_key_opt =
|
||||
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(100u32, 1u32));
|
||||
assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
|
||||
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 1);
|
||||
}
|
||||
{
|
||||
let sort_key_opt =
|
||||
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(100u32, 3u32));
|
||||
assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
|
||||
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 2);
|
||||
}
|
||||
{
|
||||
let sort_key_opt =
|
||||
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(200u32, 1u32));
|
||||
assert!(sort_key_opt.is_none());
|
||||
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 3);
|
||||
}
|
||||
{
|
||||
let sort_key_opt =
|
||||
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(200u32, 3u32));
|
||||
assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
|
||||
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
|
||||
}
|
||||
{
|
||||
let sort_key_opt =
|
||||
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(300u32, 1u32));
|
||||
assert_eq!(sort_key_opt, None);
|
||||
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
|
||||
}
|
||||
{
|
||||
let sort_key_opt =
|
||||
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(300u32, 3u32));
|
||||
assert_eq!(sort_key_opt, None);
|
||||
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
|
||||
}
|
||||
{
|
||||
let sort_key_opt =
|
||||
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &expected_sort_key);
|
||||
assert_eq!(sort_key_opt, Some((Ordering::Equal, expected_sort_key)));
|
||||
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 5);
|
||||
}
|
||||
assert_eq!(
|
||||
segment_sort_key_computer.convert_segment_sort_key(expected_sort_key),
|
||||
(200u32, 2u32)
|
||||
);
|
||||
}
|
||||
}
|
||||
193
src/collector/sort_key_top_collector.rs
Normal file
193
src/collector/sort_key_top_collector.rs
Normal file
@@ -0,0 +1,193 @@
|
||||
use std::ops::Range;
|
||||
|
||||
use crate::collector::sort_key::{Comparator, SegmentSortKeyComputer, SortKeyComputer};
|
||||
use crate::collector::{Collector, SegmentCollector, TopNComputer};
|
||||
use crate::query::Weight;
|
||||
use crate::schema::Schema;
|
||||
use crate::{DocAddress, DocId, Result, Score, SegmentReader};
|
||||
|
||||
pub(crate) struct TopBySortKeyCollector<TSortKeyComputer> {
|
||||
sort_key_computer: TSortKeyComputer,
|
||||
doc_range: Range<usize>,
|
||||
}
|
||||
|
||||
impl<TSortKeyComputer> TopBySortKeyCollector<TSortKeyComputer> {
|
||||
pub fn new(sort_key_computer: TSortKeyComputer, doc_range: Range<usize>) -> Self {
|
||||
TopBySortKeyCollector {
|
||||
sort_key_computer,
|
||||
doc_range,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<TSortKeyComputer> Collector for TopBySortKeyCollector<TSortKeyComputer>
|
||||
where TSortKeyComputer: SortKeyComputer + Send + Sync + 'static
|
||||
{
|
||||
type Fruit = Vec<(TSortKeyComputer::SortKey, DocAddress)>;
|
||||
|
||||
type Child =
|
||||
TopBySortKeySegmentCollector<TSortKeyComputer::Child, TSortKeyComputer::Comparator>;
|
||||
|
||||
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
|
||||
self.sort_key_computer.check_schema(schema)
|
||||
}
|
||||
|
||||
fn for_segment(&self, segment_ord: u32, segment_reader: &SegmentReader) -> Result<Self::Child> {
|
||||
let segment_sort_key_computer = self
|
||||
.sort_key_computer
|
||||
.segment_sort_key_computer(segment_reader)?;
|
||||
let topn_computer = TopNComputer::new_with_comparator(
|
||||
self.doc_range.end,
|
||||
self.sort_key_computer.comparator(),
|
||||
);
|
||||
Ok(TopBySortKeySegmentCollector {
|
||||
topn_computer,
|
||||
segment_ord,
|
||||
segment_sort_key_computer,
|
||||
})
|
||||
}
|
||||
|
||||
fn requires_scoring(&self) -> bool {
|
||||
self.sort_key_computer.requires_scoring()
|
||||
}
|
||||
|
||||
fn merge_fruits(&self, segment_fruits: Vec<Self::Fruit>) -> Result<Self::Fruit> {
|
||||
Ok(merge_top_k(
|
||||
segment_fruits.into_iter().flatten(),
|
||||
self.doc_range.clone(),
|
||||
self.sort_key_computer.comparator(),
|
||||
))
|
||||
}
|
||||
|
||||
fn collect_segment(
|
||||
&self,
|
||||
weight: &dyn Weight,
|
||||
segment_ord: u32,
|
||||
reader: &SegmentReader,
|
||||
) -> crate::Result<Vec<(TSortKeyComputer::SortKey, DocAddress)>> {
|
||||
let k = self.doc_range.end;
|
||||
let docs = self
|
||||
.sort_key_computer
|
||||
.collect_segment_top_k(k, weight, reader, segment_ord)?;
|
||||
Ok(docs)
|
||||
}
|
||||
}
|
||||
|
||||
fn merge_top_k<D: Ord, TSortKey: Clone + std::fmt::Debug, C: Comparator<TSortKey>>(
|
||||
sort_key_docs: impl Iterator<Item = (TSortKey, D)>,
|
||||
doc_range: Range<usize>,
|
||||
comparator: C,
|
||||
) -> Vec<(TSortKey, D)> {
|
||||
if doc_range.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
let mut top_collector: TopNComputer<TSortKey, D, C> =
|
||||
TopNComputer::new_with_comparator(doc_range.end, comparator);
|
||||
for (sort_key, doc) in sort_key_docs {
|
||||
top_collector.push(sort_key, doc);
|
||||
}
|
||||
top_collector
|
||||
.into_sorted_vec()
|
||||
.into_iter()
|
||||
.skip(doc_range.start)
|
||||
.map(|cdoc| (cdoc.sort_key, cdoc.doc))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub struct TopBySortKeySegmentCollector<TSegmentSortKeyComputer, C>
|
||||
where
|
||||
TSegmentSortKeyComputer: SegmentSortKeyComputer,
|
||||
C: Comparator<TSegmentSortKeyComputer::SegmentSortKey>,
|
||||
{
|
||||
pub(crate) topn_computer: TopNComputer<TSegmentSortKeyComputer::SegmentSortKey, DocId, C>,
|
||||
pub(crate) segment_ord: u32,
|
||||
pub(crate) segment_sort_key_computer: TSegmentSortKeyComputer,
|
||||
}
|
||||
|
||||
impl<TSegmentSortKeyComputer, C> SegmentCollector
|
||||
for TopBySortKeySegmentCollector<TSegmentSortKeyComputer, C>
|
||||
where
|
||||
TSegmentSortKeyComputer: 'static + SegmentSortKeyComputer,
|
||||
C: Comparator<TSegmentSortKeyComputer::SegmentSortKey> + 'static,
|
||||
{
|
||||
type Fruit = Vec<(TSegmentSortKeyComputer::SortKey, DocAddress)>;
|
||||
|
||||
fn collect(&mut self, doc: DocId, score: Score) {
|
||||
self.segment_sort_key_computer.compute_sort_key_and_collect(
|
||||
doc,
|
||||
score,
|
||||
&mut self.topn_computer,
|
||||
);
|
||||
}
|
||||
|
||||
fn harvest(self) -> Self::Fruit {
|
||||
let segment_ord = self.segment_ord;
|
||||
let segment_hits: Vec<(TSegmentSortKeyComputer::SortKey, DocAddress)> = self
|
||||
.topn_computer
|
||||
.into_vec()
|
||||
.into_iter()
|
||||
.map(|comparable_doc| {
|
||||
let sort_key = self
|
||||
.segment_sort_key_computer
|
||||
.convert_segment_sort_key(comparable_doc.sort_key);
|
||||
(
|
||||
sort_key,
|
||||
DocAddress {
|
||||
segment_ord,
|
||||
doc_id: comparable_doc.doc,
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
segment_hits
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::ops::Range;
|
||||
|
||||
use rand;
|
||||
use rand::seq::SliceRandom as _;
|
||||
|
||||
use super::merge_top_k;
|
||||
use crate::collector::sort_key::ComparatorEnum;
|
||||
use crate::Order;
|
||||
|
||||
fn test_merge_top_k_aux(
|
||||
order: Order,
|
||||
doc_range: Range<usize>,
|
||||
expected: &[(crate::Score, usize)],
|
||||
) {
|
||||
let mut vals: Vec<(crate::Score, usize)> = (0..10).map(|val| (val as f32, val)).collect();
|
||||
vals.shuffle(&mut rand::thread_rng());
|
||||
let vals_merged = merge_top_k(vals.into_iter(), doc_range, ComparatorEnum::from(order));
|
||||
assert_eq!(&vals_merged, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge_top_k() {
|
||||
test_merge_top_k_aux(Order::Asc, 0..0, &[]);
|
||||
test_merge_top_k_aux(Order::Asc, 3..3, &[]);
|
||||
test_merge_top_k_aux(Order::Asc, 0..3, &[(0.0f32, 0), (1.0f32, 1), (2.0f32, 2)]);
|
||||
test_merge_top_k_aux(
|
||||
Order::Asc,
|
||||
0..11,
|
||||
&[
|
||||
(0.0f32, 0),
|
||||
(1.0f32, 1),
|
||||
(2.0f32, 2),
|
||||
(3.0f32, 3),
|
||||
(4.0f32, 4),
|
||||
(5.0f32, 5),
|
||||
(6.0f32, 6),
|
||||
(7.0f32, 7),
|
||||
(8.0f32, 8),
|
||||
(9.0f32, 9),
|
||||
],
|
||||
);
|
||||
test_merge_top_k_aux(Order::Asc, 1..3, &[(1.0f32, 1), (2.0f32, 2)]);
|
||||
test_merge_top_k_aux(Order::Desc, 0..2, &[(9.0f32, 9), (8.0f32, 8)]);
|
||||
test_merge_top_k_aux(Order::Desc, 2..4, &[(7.0f32, 7), (6.0f32, 6)]);
|
||||
}
|
||||
}
|
||||
@@ -40,7 +40,7 @@ pub fn test_filter_collector() -> crate::Result<()> {
|
||||
let filter_some_collector = FilterCollector::new(
|
||||
"price".to_string(),
|
||||
&|value: u64| value > 20_120u64,
|
||||
TopDocs::with_limit(2),
|
||||
TopDocs::with_limit(2).order_by_score(),
|
||||
);
|
||||
let top_docs = searcher.search(&query, &filter_some_collector)?;
|
||||
|
||||
@@ -50,7 +50,7 @@ pub fn test_filter_collector() -> crate::Result<()> {
|
||||
let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new(
|
||||
"price".to_string(),
|
||||
&|value| value < 5u64,
|
||||
TopDocs::with_limit(2),
|
||||
TopDocs::with_limit(2).order_by_score(),
|
||||
);
|
||||
let filtered_top_docs = searcher.search(&query, &filter_all_collector).unwrap();
|
||||
|
||||
@@ -62,8 +62,11 @@ pub fn test_filter_collector() -> crate::Result<()> {
|
||||
> 0
|
||||
}
|
||||
|
||||
let filter_dates_collector =
|
||||
FilterCollector::new("date".to_string(), &date_filter, TopDocs::with_limit(5));
|
||||
let filter_dates_collector = FilterCollector::new(
|
||||
"date".to_string(),
|
||||
&date_filter,
|
||||
TopDocs::with_limit(5).order_by_score(),
|
||||
);
|
||||
let filtered_date_docs = searcher.search(&query, &filter_dates_collector)?;
|
||||
|
||||
assert_eq!(filtered_date_docs.len(), 2);
|
||||
|
||||
@@ -1,12 +1,7 @@
|
||||
use std::cmp::Ordering;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::top_score_collector::TopNComputer;
|
||||
use crate::index::SegmentReader;
|
||||
use crate::{DocAddress, DocId, SegmentOrdinal};
|
||||
|
||||
/// Contains a feature (field, score, etc.) of a document along with the document address.
|
||||
///
|
||||
/// It guarantees stable sorting: in case of a tie on the feature, the document
|
||||
@@ -19,7 +14,7 @@ use crate::{DocAddress, DocId, SegmentOrdinal};
|
||||
pub struct ComparableDoc<T, D, const REVERSE_ORDER: bool = false> {
|
||||
/// The feature of the document. In practice, this is
|
||||
/// is any type that implements `PartialOrd`.
|
||||
pub feature: T,
|
||||
pub sort_key: T,
|
||||
/// The document address. In practice, this is any
|
||||
/// type that implements `PartialOrd`, and is guaranteed
|
||||
/// to be unique for each document.
|
||||
@@ -28,9 +23,9 @@ pub struct ComparableDoc<T, D, const REVERSE_ORDER: bool = false> {
|
||||
impl<T: std::fmt::Debug, D: std::fmt::Debug, const R: bool> std::fmt::Debug
|
||||
for ComparableDoc<T, D, R>
|
||||
{
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
f.debug_struct(format!("ComparableDoc<_, _ {R}").as_str())
|
||||
.field("feature", &self.feature)
|
||||
.field("feature", &self.sort_key)
|
||||
.field("doc", &self.doc)
|
||||
.finish()
|
||||
}
|
||||
@@ -46,8 +41,8 @@ impl<T: PartialOrd, D: PartialOrd, const R: bool> Ord for ComparableDoc<T, D, R>
|
||||
#[inline]
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
let by_feature = self
|
||||
.feature
|
||||
.partial_cmp(&other.feature)
|
||||
.sort_key
|
||||
.partial_cmp(&other.sort_key)
|
||||
.map(|ord| if R { ord.reverse() } else { ord })
|
||||
.unwrap_or(Ordering::Equal);
|
||||
|
||||
@@ -67,308 +62,3 @@ impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialEq for ComparableDoc<T,
|
||||
}
|
||||
|
||||
impl<T: PartialOrd, D: PartialOrd, const R: bool> Eq for ComparableDoc<T, D, R> {}
|
||||
|
||||
pub(crate) struct TopCollector<T> {
|
||||
pub limit: usize,
|
||||
pub offset: usize,
|
||||
_marker: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T> TopCollector<T>
|
||||
where T: PartialOrd + Clone
|
||||
{
|
||||
/// Creates a top collector, with a number of documents equal to "limit".
|
||||
///
|
||||
/// # Panics
|
||||
/// The method panics if limit is 0
|
||||
pub fn with_limit(limit: usize) -> TopCollector<T> {
|
||||
assert!(limit >= 1, "Limit must be strictly greater than 0.");
|
||||
Self {
|
||||
limit,
|
||||
offset: 0,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Skip the first "offset" documents when collecting.
|
||||
///
|
||||
/// This is equivalent to `OFFSET` in MySQL or PostgreSQL and `start` in
|
||||
/// Lucene's TopDocsCollector.
|
||||
pub fn and_offset(mut self, offset: usize) -> TopCollector<T> {
|
||||
self.offset = offset;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn merge_fruits(
|
||||
&self,
|
||||
children: Vec<Vec<(T, DocAddress)>>,
|
||||
) -> crate::Result<Vec<(T, DocAddress)>> {
|
||||
if self.limit == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let mut top_collector: TopNComputer<_, _> = TopNComputer::new(self.limit + self.offset);
|
||||
for child_fruit in children {
|
||||
for (feature, doc) in child_fruit {
|
||||
top_collector.push(feature, doc);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(top_collector
|
||||
.into_sorted_vec()
|
||||
.into_iter()
|
||||
.skip(self.offset)
|
||||
.map(|cdoc| (cdoc.feature, cdoc.doc))
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub(crate) fn for_segment<F: PartialOrd + Clone>(
|
||||
&self,
|
||||
segment_id: SegmentOrdinal,
|
||||
_: &SegmentReader,
|
||||
) -> TopSegmentCollector<F> {
|
||||
TopSegmentCollector::new(segment_id, self.limit + self.offset)
|
||||
}
|
||||
|
||||
/// Create a new TopCollector with the same limit and offset.
|
||||
///
|
||||
/// Ideally we would use Into but the blanket implementation seems to cause the Scorer traits
|
||||
/// to fail.
|
||||
#[doc(hidden)]
|
||||
pub(crate) fn into_tscore<TScore: PartialOrd + Clone>(self) -> TopCollector<TScore> {
|
||||
TopCollector {
|
||||
limit: self.limit,
|
||||
offset: self.offset,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The Top Collector keeps track of the K documents
|
||||
/// sorted by type `T`.
|
||||
///
|
||||
/// The implementation is based on a repeatedly truncating on the median after K * 2 documents
|
||||
/// The theoretical complexity for collecting the top `K` out of `n` documents
|
||||
/// is `O(n + K)`.
|
||||
pub(crate) struct TopSegmentCollector<T> {
|
||||
/// We reverse the order of the feature in order to
|
||||
/// have top-semantics instead of bottom semantics.
|
||||
topn_computer: TopNComputer<T, DocId>,
|
||||
segment_ord: u32,
|
||||
}
|
||||
|
||||
impl<T: PartialOrd + Clone> TopSegmentCollector<T> {
|
||||
fn new(segment_ord: SegmentOrdinal, limit: usize) -> TopSegmentCollector<T> {
|
||||
TopSegmentCollector {
|
||||
topn_computer: TopNComputer::new(limit),
|
||||
segment_ord,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: PartialOrd + Clone> TopSegmentCollector<T> {
|
||||
pub fn harvest(self) -> Vec<(T, DocAddress)> {
|
||||
let segment_ord = self.segment_ord;
|
||||
self.topn_computer
|
||||
.into_sorted_vec()
|
||||
.into_iter()
|
||||
.map(|comparable_doc| {
|
||||
(
|
||||
comparable_doc.feature,
|
||||
DocAddress {
|
||||
segment_ord,
|
||||
doc_id: comparable_doc.doc,
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Collects a document scored by the given feature
|
||||
///
|
||||
/// It collects documents until it has reached the max capacity. Once it reaches capacity, it
|
||||
/// will compare the lowest scoring item with the given one and keep whichever is greater.
|
||||
#[inline]
|
||||
pub fn collect(&mut self, doc: DocId, feature: T) {
|
||||
self.topn_computer.push(feature, doc);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{TopCollector, TopSegmentCollector};
|
||||
use crate::DocAddress;
|
||||
|
||||
#[test]
|
||||
fn test_top_collector_not_at_capacity() {
|
||||
let mut top_collector = TopSegmentCollector::new(0, 4);
|
||||
top_collector.collect(1, 0.8);
|
||||
top_collector.collect(3, 0.2);
|
||||
top_collector.collect(5, 0.3);
|
||||
assert_eq!(
|
||||
top_collector.harvest(),
|
||||
vec![
|
||||
(0.8, DocAddress::new(0, 1)),
|
||||
(0.3, DocAddress::new(0, 5)),
|
||||
(0.2, DocAddress::new(0, 3))
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_top_collector_at_capacity() {
|
||||
let mut top_collector = TopSegmentCollector::new(0, 4);
|
||||
top_collector.collect(1, 0.8);
|
||||
top_collector.collect(3, 0.2);
|
||||
top_collector.collect(5, 0.3);
|
||||
top_collector.collect(7, 0.9);
|
||||
top_collector.collect(9, -0.2);
|
||||
assert_eq!(
|
||||
top_collector.harvest(),
|
||||
vec![
|
||||
(0.9, DocAddress::new(0, 7)),
|
||||
(0.8, DocAddress::new(0, 1)),
|
||||
(0.3, DocAddress::new(0, 5)),
|
||||
(0.2, DocAddress::new(0, 3))
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_top_segment_collector_stable_ordering_for_equal_feature() {
|
||||
// given that the documents are collected in ascending doc id order,
|
||||
// when harvesting we have to guarantee stable sorting in case of a tie
|
||||
// on the score
|
||||
let doc_ids_collection = [4, 5, 6];
|
||||
let score = 3.3f32;
|
||||
|
||||
let mut top_collector_limit_2 = TopSegmentCollector::new(0, 2);
|
||||
for id in &doc_ids_collection {
|
||||
top_collector_limit_2.collect(*id, score);
|
||||
}
|
||||
|
||||
let mut top_collector_limit_3 = TopSegmentCollector::new(0, 3);
|
||||
for id in &doc_ids_collection {
|
||||
top_collector_limit_3.collect(*id, score);
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
top_collector_limit_2.harvest(),
|
||||
top_collector_limit_3.harvest()[..2].to_vec(),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_top_collector_with_limit_and_offset() {
|
||||
let collector = TopCollector::with_limit(2).and_offset(1);
|
||||
|
||||
let results = collector
|
||||
.merge_fruits(vec![vec![
|
||||
(0.9, DocAddress::new(0, 1)),
|
||||
(0.8, DocAddress::new(0, 2)),
|
||||
(0.7, DocAddress::new(0, 3)),
|
||||
(0.6, DocAddress::new(0, 4)),
|
||||
(0.5, DocAddress::new(0, 5)),
|
||||
]])
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
results,
|
||||
vec![(0.8, DocAddress::new(0, 2)), (0.7, DocAddress::new(0, 3)),]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_top_collector_with_limit_larger_than_set_and_offset() {
|
||||
let collector = TopCollector::with_limit(2).and_offset(1);
|
||||
|
||||
let results = collector
|
||||
.merge_fruits(vec![vec![
|
||||
(0.9, DocAddress::new(0, 1)),
|
||||
(0.8, DocAddress::new(0, 2)),
|
||||
]])
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(results, vec![(0.8, DocAddress::new(0, 2)),]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_top_collector_with_limit_and_offset_larger_than_set() {
|
||||
let collector = TopCollector::with_limit(2).and_offset(20);
|
||||
|
||||
let results = collector
|
||||
.merge_fruits(vec![vec![
|
||||
(0.9, DocAddress::new(0, 1)),
|
||||
(0.8, DocAddress::new(0, 2)),
|
||||
]])
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(results, vec![]);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "unstable"))]
|
||||
mod bench {
|
||||
use test::Bencher;
|
||||
|
||||
use super::TopSegmentCollector;
|
||||
|
||||
#[bench]
|
||||
fn bench_top_segment_collector_collect_not_at_capacity(b: &mut Bencher) {
|
||||
let mut top_collector = TopSegmentCollector::new(0, 400);
|
||||
|
||||
b.iter(|| {
|
||||
for i in 0..100 {
|
||||
top_collector.collect(i, 0.8);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_top_segment_collector_collect_at_capacity(b: &mut Bencher) {
|
||||
let mut top_collector = TopSegmentCollector::new(0, 100);
|
||||
|
||||
for i in 0..100 {
|
||||
top_collector.collect(i, 0.8);
|
||||
}
|
||||
|
||||
b.iter(|| {
|
||||
for i in 0..100 {
|
||||
top_collector.collect(i, 0.8);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_top_segment_collector_collect_and_harvest_many_ties(b: &mut Bencher) {
|
||||
b.iter(|| {
|
||||
let mut top_collector = TopSegmentCollector::new(0, 100);
|
||||
|
||||
for i in 0..100 {
|
||||
top_collector.collect(i, 0.8);
|
||||
}
|
||||
|
||||
// it would be nice to be able to do the setup N times but still
|
||||
// measure only harvest(). We can't since harvest() consumes
|
||||
// the top_collector.
|
||||
top_collector.harvest()
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_top_segment_collector_collect_and_harvest_no_tie(b: &mut Bencher) {
|
||||
b.iter(|| {
|
||||
let mut top_collector = TopSegmentCollector::new(0, 100);
|
||||
let mut score = 1.0;
|
||||
|
||||
for i in 0..100 {
|
||||
score += 1.0;
|
||||
top_collector.collect(i, score);
|
||||
}
|
||||
|
||||
// it would be nice to be able to do the setup N times but still
|
||||
// measure only harvest(). We can't since harvest() consumes
|
||||
// the top_collector.
|
||||
top_collector.harvest()
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user