mirror of
https://github.com/quickwit-oss/tantivy.git
synced 2026-01-10 11:02:55 +00:00
Compare commits
23 Commits
stuhood.la
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d904630e6a | ||
|
|
65b5a1a306 | ||
|
|
db2ecc6057 | ||
|
|
77505c3d03 | ||
|
|
735c588f4f | ||
|
|
242a1531bf | ||
|
|
6443b63177 | ||
|
|
4987495ee4 | ||
|
|
b11605f045 | ||
|
|
75d7989cc6 | ||
|
|
923f0508f2 | ||
|
|
e0b62e00ac | ||
|
|
ce97beb86f | ||
|
|
c0f21a45ae | ||
|
|
73657dff77 | ||
|
|
e3c9be1f92 | ||
|
|
ba61ed6ef3 | ||
|
|
d0e1600135 | ||
|
|
e9020d17d4 | ||
|
|
5ba0031f7d | ||
|
|
22dde8f9ae | ||
|
|
14cc24614e | ||
|
|
8a1079b2dc |
29
.github/workflows/coverage.yml
vendored
Normal file
29
.github/workflows/coverage.yml
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
name: Coverage
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
|
||||
# Ensures that we cancel running jobs for the same PR / same workflow.
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
coverage:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Install Rust
|
||||
run: rustup toolchain install nightly-2025-12-01 --profile minimal --component llvm-tools-preview
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- uses: taiki-e/install-action@cargo-llvm-cov
|
||||
- name: Generate code coverage
|
||||
run: cargo +nightly-2025-12-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v3
|
||||
continue-on-error: true
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos
|
||||
files: lcov.info
|
||||
fail_ci_if_error: true
|
||||
34
.github/workflows/test.yml
vendored
34
.github/workflows/test.yml
vendored
@@ -39,11 +39,11 @@ jobs:
|
||||
|
||||
- name: Check Formatting
|
||||
run: cargo +nightly fmt --all -- --check
|
||||
|
||||
|
||||
- name: Check Stable Compilation
|
||||
run: cargo build --all-features
|
||||
|
||||
|
||||
|
||||
- name: Check Bench Compilation
|
||||
run: cargo +nightly bench --no-run --profile=dev --all-features
|
||||
|
||||
@@ -59,10 +59,10 @@ jobs:
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
features: [
|
||||
{ label: "all", flags: "mmap,stopwords,lz4-compression,zstd-compression,failpoints" },
|
||||
{ label: "quickwit", flags: "mmap,quickwit,failpoints" }
|
||||
]
|
||||
features:
|
||||
- { label: "all", flags: "mmap,stopwords,lz4-compression,zstd-compression,failpoints,stemmer" }
|
||||
- { label: "quickwit", flags: "mmap,quickwit,failpoints" }
|
||||
- { label: "none", flags: "" }
|
||||
|
||||
name: test-${{ matrix.features.label}}
|
||||
|
||||
@@ -76,13 +76,25 @@ jobs:
|
||||
profile: minimal
|
||||
override: true
|
||||
|
||||
- uses: taiki-e/install-action@v2
|
||||
with:
|
||||
tool: 'nextest'
|
||||
- uses: taiki-e/install-action@nextest
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
|
||||
- name: Run tests
|
||||
run: cargo +stable nextest run --features ${{ matrix.features.flags }} --verbose --workspace
|
||||
run: |
|
||||
# if matrix.feature.flags is empty then run on --lib to avoid compiling examples
|
||||
# (as most of them rely on mmap) otherwise run all
|
||||
if [ -z "${{ matrix.features.flags }}" ]; then
|
||||
cargo +stable nextest run --lib --no-default-features --verbose --workspace
|
||||
else
|
||||
cargo +stable nextest run --features ${{ matrix.features.flags }} --no-default-features --verbose --workspace
|
||||
fi
|
||||
|
||||
- name: Run doctests
|
||||
run: cargo +stable test --doc --features ${{ matrix.features.flags }} --verbose --workspace
|
||||
run: |
|
||||
# if matrix.feature.flags is empty then run on --lib to avoid compiling examples
|
||||
# (as most of them rely on mmap) otherwise run all
|
||||
if [ -z "${{ matrix.features.flags }}" ]; then
|
||||
echo "no doctest for no feature flag"
|
||||
else
|
||||
cargo +stable test --doc --features ${{ matrix.features.flags }} --verbose --workspace
|
||||
fi
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -6,6 +6,7 @@ target
|
||||
target/debug
|
||||
.vscode
|
||||
target/release
|
||||
Cargo.lock
|
||||
benchmark
|
||||
.DS_Store
|
||||
*.bk
|
||||
@@ -14,7 +15,3 @@ trace.dat
|
||||
cargo-timing*
|
||||
control
|
||||
variable
|
||||
|
||||
# for `sample record -p`
|
||||
profile.json
|
||||
profile.json.gz
|
||||
|
||||
2361
Cargo.lock
generated
2361
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
55
Cargo.toml
55
Cargo.toml
@@ -21,11 +21,11 @@ byteorder = "1.4.3"
|
||||
crc32fast = "1.3.2"
|
||||
once_cell = "1.10.0"
|
||||
regex = { version = "1.5.5", default-features = false, features = [
|
||||
"std",
|
||||
"unicode",
|
||||
"std",
|
||||
"unicode",
|
||||
] }
|
||||
aho-corasick = "1.0"
|
||||
tantivy-fst = { git = "https://github.com/paradedb/fst.git" }
|
||||
tantivy-fst = "0.5"
|
||||
memmap2 = { version = "0.9.0", optional = true }
|
||||
lz4_flex = { version = "0.11", default-features = false, optional = true }
|
||||
zstd = { version = "0.13", optional = true, default-features = false }
|
||||
@@ -37,11 +37,10 @@ fs4 = { version = "0.13.1", optional = true }
|
||||
levenshtein_automata = "0.2.1"
|
||||
uuid = { version = "1.0.0", features = ["v4", "serde"] }
|
||||
crossbeam-channel = "0.5.4"
|
||||
rust-stemmers = "1.2.0"
|
||||
tantivy-stemmers = { version = "0.4.0", default-features = false, features = ["polish_yarovoy"] }
|
||||
rust-stemmers = { version = "1.2.0", optional = true }
|
||||
downcast-rs = "2.0.1"
|
||||
bitpacking = { version = "0.9.2", default-features = false, features = [
|
||||
"bitpacker4x",
|
||||
bitpacking = { version = "0.9.3", default-features = false, features = [
|
||||
"bitpacker4x",
|
||||
] }
|
||||
census = "0.4.2"
|
||||
rustc-hash = "2.0.0"
|
||||
@@ -49,10 +48,6 @@ thiserror = "2.0.1"
|
||||
htmlescape = "0.3.1"
|
||||
fail = { version = "0.5.0", optional = true }
|
||||
time = { version = "0.3.35", features = ["serde-well-known"] }
|
||||
# TODO: We have integer wrappers with PartialOrd, and a misfeature of
|
||||
# `deranged` causes inference to fail in a bunch of cases. See
|
||||
# https://github.com/jhpratt/deranged/issues/18#issuecomment-2746844093
|
||||
deranged = "=0.4.0"
|
||||
smallvec = "1.8.0"
|
||||
rayon = "1.5.2"
|
||||
lru = "0.12.0"
|
||||
@@ -74,19 +69,18 @@ hyperloglogplus = { version = "0.4.1", features = ["const-loop"] }
|
||||
futures-util = { version = "0.3.28", optional = true }
|
||||
futures-channel = { version = "0.3.28", optional = true }
|
||||
fnv = "1.0.7"
|
||||
parking_lot = "0.12.4"
|
||||
typetag = "0.2.21"
|
||||
|
||||
[target.'cfg(windows)'.dependencies]
|
||||
winapi = "0.3.9"
|
||||
|
||||
[dev-dependencies]
|
||||
binggan = "0.14.0"
|
||||
binggan = "0.14.2"
|
||||
rand = "0.8.5"
|
||||
maplit = "1.0.2"
|
||||
matches = "0.1.9"
|
||||
pretty_assertions = "1.2.1"
|
||||
proptest = "1.0.0"
|
||||
proptest = "1.7.0"
|
||||
test-log = "0.2.10"
|
||||
futures = "0.3.21"
|
||||
paste = "1.0.11"
|
||||
@@ -94,7 +88,7 @@ more-asserts = "0.3.1"
|
||||
rand_distr = "0.4.3"
|
||||
time = { version = "0.3.10", features = ["serde-well-known", "macros"] }
|
||||
postcard = { version = "1.0.4", features = [
|
||||
"use-std",
|
||||
"use-std",
|
||||
], default-features = false }
|
||||
|
||||
[target.'cfg(not(windows))'.dev-dependencies]
|
||||
@@ -119,7 +113,8 @@ debug-assertions = true
|
||||
overflow-checks = true
|
||||
|
||||
[features]
|
||||
default = ["mmap", "stopwords", "lz4-compression", "columnar-zstd-compression"]
|
||||
default = ["mmap", "stopwords", "lz4-compression", "columnar-zstd-compression", "stemmer"]
|
||||
stemmer = ["rust-stemmers"]
|
||||
mmap = ["fs4", "tempfile", "memmap2"]
|
||||
stopwords = []
|
||||
|
||||
@@ -141,14 +136,14 @@ compare_hash_only = ["stacker/compare_hash_only"]
|
||||
|
||||
[workspace]
|
||||
members = [
|
||||
"query-grammar",
|
||||
"bitpacker",
|
||||
"common",
|
||||
"ownedbytes",
|
||||
"stacker",
|
||||
"sstable",
|
||||
"tokenizer-api",
|
||||
"columnar",
|
||||
"query-grammar",
|
||||
"bitpacker",
|
||||
"common",
|
||||
"ownedbytes",
|
||||
"stacker",
|
||||
"sstable",
|
||||
"tokenizer-api",
|
||||
"columnar",
|
||||
]
|
||||
|
||||
# Following the "fail" crate best practises, we isolate
|
||||
@@ -179,6 +174,18 @@ harness = false
|
||||
name = "exists_json"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "range_query"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "and_or_queries"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "range_queries"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "bool_queries_with_range"
|
||||
harness = false
|
||||
|
||||
@@ -54,33 +54,33 @@ fn bench_agg(mut group: InputGroup<Index>) {
|
||||
register!(group, stats_f64);
|
||||
register!(group, extendedstats_f64);
|
||||
register!(group, percentiles_f64);
|
||||
register!(group, terms_few);
|
||||
register!(group, terms_7);
|
||||
register!(group, terms_all_unique);
|
||||
register!(group, terms_many);
|
||||
register!(group, terms_150_000);
|
||||
register!(group, terms_many_top_1000);
|
||||
register!(group, terms_many_order_by_term);
|
||||
register!(group, terms_many_with_top_hits);
|
||||
register!(group, terms_all_unique_with_avg_sub_agg);
|
||||
register!(group, terms_many_with_avg_sub_agg);
|
||||
register!(group, terms_few_with_avg_sub_agg);
|
||||
register!(group, terms_status_with_avg_sub_agg);
|
||||
register!(group, terms_status);
|
||||
register!(group, terms_few_with_histogram);
|
||||
register!(group, terms_status_with_histogram);
|
||||
register!(group, terms_zipf_1000);
|
||||
register!(group, terms_zipf_1000_with_histogram);
|
||||
register!(group, terms_zipf_1000_with_avg_sub_agg);
|
||||
|
||||
register!(group, terms_many_json_mixed_type_with_avg_sub_agg);
|
||||
|
||||
register!(group, cardinality_agg);
|
||||
register!(group, terms_few_with_cardinality_agg);
|
||||
register!(group, terms_status_with_cardinality_agg);
|
||||
|
||||
register!(group, range_agg);
|
||||
register!(group, range_agg_with_avg_sub_agg);
|
||||
register!(group, range_agg_with_term_agg_few);
|
||||
register!(group, range_agg_with_term_agg_status);
|
||||
register!(group, range_agg_with_term_agg_many);
|
||||
register!(group, histogram);
|
||||
register!(group, histogram_hard_bounds);
|
||||
register!(group, histogram_with_avg_sub_agg);
|
||||
register!(group, histogram_with_term_agg_few);
|
||||
register!(group, histogram_with_term_agg_status);
|
||||
register!(group, avg_and_range_with_avg_sub_agg);
|
||||
|
||||
// Filter aggregation benchmarks
|
||||
@@ -159,10 +159,10 @@ fn cardinality_agg(index: &Index) {
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn terms_few_with_cardinality_agg(index: &Index) {
|
||||
fn terms_status_with_cardinality_agg(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
"terms": { "field": "text_few_terms" },
|
||||
"terms": { "field": "text_few_terms_status" },
|
||||
"aggs": {
|
||||
"cardinality": {
|
||||
"cardinality": {
|
||||
@@ -175,13 +175,7 @@ fn terms_few_with_cardinality_agg(index: &Index) {
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_few(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": { "terms": { "field": "text_few_terms" } },
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn terms_status(index: &Index) {
|
||||
fn terms_7(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": { "terms": { "field": "text_few_terms_status" } },
|
||||
});
|
||||
@@ -194,7 +188,7 @@ fn terms_all_unique(index: &Index) {
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_many(index: &Index) {
|
||||
fn terms_150_000(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": { "terms": { "field": "text_many_terms" } },
|
||||
});
|
||||
@@ -253,17 +247,6 @@ fn terms_all_unique_with_avg_sub_agg(index: &Index) {
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn terms_few_with_histogram(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
"terms": { "field": "text_few_terms" },
|
||||
"aggs": {
|
||||
"histo": {"histogram": { "field": "score_f64", "interval": 10 }}
|
||||
}
|
||||
}
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn terms_status_with_histogram(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
@@ -276,17 +259,18 @@ fn terms_status_with_histogram(index: &Index) {
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_few_with_avg_sub_agg(index: &Index) {
|
||||
fn terms_zipf_1000_with_histogram(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
"terms": { "field": "text_few_terms" },
|
||||
"terms": { "field": "text_1000_terms_zipf" },
|
||||
"aggs": {
|
||||
"average_f64": { "avg": { "field": "score_f64" } }
|
||||
"histo": {"histogram": { "field": "score_f64", "interval": 10 }}
|
||||
}
|
||||
},
|
||||
}
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_status_with_avg_sub_agg(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
@@ -299,6 +283,25 @@ fn terms_status_with_avg_sub_agg(index: &Index) {
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_zipf_1000_with_avg_sub_agg(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
"terms": { "field": "text_1000_terms_zipf" },
|
||||
"aggs": {
|
||||
"average_f64": { "avg": { "field": "score_f64" } }
|
||||
}
|
||||
},
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_zipf_1000(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": { "terms": { "field": "text_1000_terms_zipf" } },
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
@@ -354,7 +357,7 @@ fn range_agg_with_avg_sub_agg(index: &Index) {
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn range_agg_with_term_agg_few(index: &Index) {
|
||||
fn range_agg_with_term_agg_status(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"rangef64": {
|
||||
"range": {
|
||||
@@ -369,7 +372,7 @@ fn range_agg_with_term_agg_few(index: &Index) {
|
||||
]
|
||||
},
|
||||
"aggs": {
|
||||
"my_texts": { "terms": { "field": "text_few_terms" } },
|
||||
"my_texts": { "terms": { "field": "text_few_terms_status" } },
|
||||
}
|
||||
},
|
||||
});
|
||||
@@ -425,12 +428,12 @@ fn histogram_with_avg_sub_agg(index: &Index) {
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn histogram_with_term_agg_few(index: &Index) {
|
||||
fn histogram_with_term_agg_status(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"rangef64": {
|
||||
"histogram": { "field": "score_f64", "interval": 10 },
|
||||
"aggs": {
|
||||
"my_texts": { "terms": { "field": "text_few_terms" } }
|
||||
"my_texts": { "terms": { "field": "text_few_terms_status" } }
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -475,6 +478,13 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector {
|
||||
}
|
||||
|
||||
fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
|
||||
// Flag to use existing index
|
||||
let reuse_index = std::env::var("REUSE_AGG_BENCH_INDEX").is_ok();
|
||||
if reuse_index && std::path::Path::new("agg_bench").exists() {
|
||||
return Index::open_in_dir("agg_bench");
|
||||
}
|
||||
// crreate dir
|
||||
std::fs::create_dir_all("agg_bench")?;
|
||||
let mut schema_builder = Schema::builder();
|
||||
let text_fieldtype = tantivy::schema::TextOptions::default()
|
||||
.set_indexing_options(
|
||||
@@ -486,24 +496,44 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
|
||||
let text_field_all_unique_terms =
|
||||
schema_builder.add_text_field("text_all_unique_terms", STRING | FAST);
|
||||
let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST);
|
||||
let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST);
|
||||
let text_field_few_terms = schema_builder.add_text_field("text_few_terms", STRING | FAST);
|
||||
let text_field_few_terms_status =
|
||||
schema_builder.add_text_field("text_few_terms_status", STRING | FAST);
|
||||
let text_field_1000_terms_zipf =
|
||||
schema_builder.add_text_field("text_1000_terms_zipf", STRING | FAST);
|
||||
let score_fieldtype = tantivy::schema::NumericOptions::default().set_fast();
|
||||
let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone());
|
||||
let score_field_f64 = schema_builder.add_f64_field("score_f64", score_fieldtype.clone());
|
||||
let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype);
|
||||
let index = Index::create_from_tempdir(schema_builder.build())?;
|
||||
let few_terms_data = ["INFO", "ERROR", "WARN", "DEBUG"];
|
||||
// Approximate production log proportions: INFO dominant, WARN and DEBUG occasional, ERROR rare.
|
||||
let log_level_distribution = WeightedIndex::new([80u32, 3, 12, 5]).unwrap();
|
||||
// use tmp dir
|
||||
let index = if reuse_index {
|
||||
Index::create_in_dir("agg_bench", schema_builder.build())?
|
||||
} else {
|
||||
Index::create_from_tempdir(schema_builder.build())?
|
||||
};
|
||||
// Approximate log proportions
|
||||
let status_field_data = [
|
||||
("INFO", 8000),
|
||||
("ERROR", 300),
|
||||
("WARN", 1200),
|
||||
("DEBUG", 500),
|
||||
("OK", 500),
|
||||
("CRITICAL", 20),
|
||||
("EMERGENCY", 1),
|
||||
];
|
||||
let log_level_distribution =
|
||||
WeightedIndex::new(status_field_data.iter().map(|item| item.1)).unwrap();
|
||||
|
||||
let lg_norm = rand_distr::LogNormal::new(2.996f64, 0.979f64).unwrap();
|
||||
|
||||
let many_terms_data = (0..150_000)
|
||||
.map(|num| format!("author{num}"))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Prepare 1000 unique terms sampled using a Zipf distribution.
|
||||
// Exponent ~1.1 approximates top-20 terms covering around ~20%.
|
||||
let terms_1000: Vec<String> = (1..=1000).map(|i| format!("term_{i}")).collect();
|
||||
let zipf_1000 = rand_distr::Zipf::new(1000, 1.1f64).unwrap();
|
||||
|
||||
{
|
||||
let mut rng = StdRng::from_seed([1u8; 32]);
|
||||
let mut index_writer = index.writer_with_num_threads(1, 200_000_000)?;
|
||||
@@ -513,8 +543,12 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
|
||||
index_writer.add_document(doc!())?;
|
||||
}
|
||||
if cardinality == Cardinality::Multivalued {
|
||||
let log_level_sample_a = few_terms_data[log_level_distribution.sample(&mut rng)];
|
||||
let log_level_sample_b = few_terms_data[log_level_distribution.sample(&mut rng)];
|
||||
let log_level_sample_a = status_field_data[log_level_distribution.sample(&mut rng)].0;
|
||||
let log_level_sample_b = status_field_data[log_level_distribution.sample(&mut rng)].0;
|
||||
let idx_a = zipf_1000.sample(&mut rng) as usize - 1;
|
||||
let idx_b = zipf_1000.sample(&mut rng) as usize - 1;
|
||||
let term_1000_a = &terms_1000[idx_a];
|
||||
let term_1000_b = &terms_1000[idx_b];
|
||||
index_writer.add_document(doc!(
|
||||
json_field => json!({"mixed_type": 10.0}),
|
||||
json_field => json!({"mixed_type": 10.0}),
|
||||
@@ -524,10 +558,10 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
|
||||
text_field_all_unique_terms => "coolo",
|
||||
text_field_many_terms => "cool",
|
||||
text_field_many_terms => "cool",
|
||||
text_field_few_terms => "cool",
|
||||
text_field_few_terms => "cool",
|
||||
text_field_few_terms_status => log_level_sample_a,
|
||||
text_field_few_terms_status => log_level_sample_b,
|
||||
text_field_1000_terms_zipf => term_1000_a.as_str(),
|
||||
text_field_1000_terms_zipf => term_1000_b.as_str(),
|
||||
score_field => 1u64,
|
||||
score_field => 1u64,
|
||||
score_field_f64 => lg_norm.sample(&mut rng),
|
||||
@@ -554,8 +588,8 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
|
||||
json_field => json,
|
||||
text_field_all_unique_terms => format!("unique_term_{}", rng.gen::<u64>()),
|
||||
text_field_many_terms => many_terms_data.choose(&mut rng).unwrap().to_string(),
|
||||
text_field_few_terms => few_terms_data.choose(&mut rng).unwrap().to_string(),
|
||||
text_field_few_terms_status => few_terms_data[log_level_distribution.sample(&mut rng)],
|
||||
text_field_few_terms_status => status_field_data[log_level_distribution.sample(&mut rng)].0,
|
||||
text_field_1000_terms_zipf => terms_1000[zipf_1000.sample(&mut rng) as usize - 1].as_str(),
|
||||
score_field => val as u64,
|
||||
score_field_f64 => lg_norm.sample(&mut rng),
|
||||
score_field_i64 => val as i64,
|
||||
@@ -607,7 +641,7 @@ fn filter_agg_all_query_with_sub_aggs(index: &Index) {
|
||||
"avg_score": { "avg": { "field": "score" } },
|
||||
"stats_score": { "stats": { "field": "score_f64" } },
|
||||
"terms_text": {
|
||||
"terms": { "field": "text_few_terms" }
|
||||
"terms": { "field": "text_few_terms_status" }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -623,7 +657,7 @@ fn filter_agg_term_query_with_sub_aggs(index: &Index) {
|
||||
"avg_score": { "avg": { "field": "score" } },
|
||||
"stats_score": { "stats": { "field": "score_f64" } },
|
||||
"terms_text": {
|
||||
"terms": { "field": "text_few_terms" }
|
||||
"terms": { "field": "text_few_terms_status" }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
288
benches/bool_queries_with_range.rs
Normal file
288
benches/bool_queries_with_range.rs
Normal file
@@ -0,0 +1,288 @@
|
||||
use binggan::{black_box, BenchGroup, BenchRunner};
|
||||
use rand::prelude::*;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::SeedableRng;
|
||||
use tantivy::collector::{Collector, Count, DocSetCollector, TopDocs};
|
||||
use tantivy::query::{Query, QueryParser};
|
||||
use tantivy::schema::{Schema, FAST, INDEXED, TEXT};
|
||||
use tantivy::{doc, Index, Order, ReloadPolicy, Searcher};
|
||||
|
||||
#[derive(Clone)]
|
||||
struct BenchIndex {
|
||||
#[allow(dead_code)]
|
||||
index: Index,
|
||||
searcher: Searcher,
|
||||
query_parser: QueryParser,
|
||||
}
|
||||
|
||||
fn build_shared_indices(num_docs: usize, p_title_a: f32, distribution: &str) -> BenchIndex {
|
||||
// Unified schema
|
||||
let mut schema_builder = Schema::builder();
|
||||
let f_title = schema_builder.add_text_field("title", TEXT);
|
||||
let f_num_rand = schema_builder.add_u64_field("num_rand", INDEXED);
|
||||
let f_num_asc = schema_builder.add_u64_field("num_asc", INDEXED);
|
||||
let f_num_rand_fast = schema_builder.add_u64_field("num_rand_fast", INDEXED | FAST);
|
||||
let f_num_asc_fast = schema_builder.add_u64_field("num_asc_fast", INDEXED | FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema.clone());
|
||||
|
||||
// Populate index with stable RNG for reproducibility.
|
||||
let mut rng = StdRng::from_seed([7u8; 32]);
|
||||
|
||||
{
|
||||
let mut writer = index.writer_with_num_threads(1, 4_000_000_000).unwrap();
|
||||
|
||||
match distribution {
|
||||
"dense" => {
|
||||
for doc_id in 0..num_docs {
|
||||
// Always add title to avoid empty documents
|
||||
let title_token = if rng.gen_bool(p_title_a as f64) {
|
||||
"a"
|
||||
} else {
|
||||
"b"
|
||||
};
|
||||
|
||||
let num_rand = rng.gen_range(0u64..1000u64);
|
||||
|
||||
let num_asc = (doc_id / 10000) as u64;
|
||||
|
||||
writer
|
||||
.add_document(doc!(
|
||||
f_title=>title_token,
|
||||
f_num_rand=>num_rand,
|
||||
f_num_asc=>num_asc,
|
||||
f_num_rand_fast=>num_rand,
|
||||
f_num_asc_fast=>num_asc,
|
||||
))
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
"sparse" => {
|
||||
for doc_id in 0..num_docs {
|
||||
// Always add title to avoid empty documents
|
||||
let title_token = if rng.gen_bool(p_title_a as f64) {
|
||||
"a"
|
||||
} else {
|
||||
"b"
|
||||
};
|
||||
|
||||
let num_rand = rng.gen_range(0u64..10000000u64);
|
||||
|
||||
let num_asc = doc_id as u64;
|
||||
|
||||
writer
|
||||
.add_document(doc!(
|
||||
f_title=>title_token,
|
||||
f_num_rand=>num_rand,
|
||||
f_num_asc=>num_asc,
|
||||
f_num_rand_fast=>num_rand,
|
||||
f_num_asc_fast=>num_asc,
|
||||
))
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
panic!("Unsupported distribution type");
|
||||
}
|
||||
}
|
||||
writer.commit().unwrap();
|
||||
}
|
||||
|
||||
// Prepare reader/searcher once.
|
||||
let reader = index
|
||||
.reader_builder()
|
||||
.reload_policy(ReloadPolicy::Manual)
|
||||
.try_into()
|
||||
.unwrap();
|
||||
let searcher = reader.searcher();
|
||||
|
||||
// Build query parser for title field
|
||||
let qp_title = QueryParser::for_index(&index, vec![f_title]);
|
||||
|
||||
BenchIndex {
|
||||
index,
|
||||
searcher,
|
||||
query_parser: qp_title,
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
// Prepare corpora with varying scenarios
|
||||
let scenarios = vec![
|
||||
(
|
||||
"dense and 99% a".to_string(),
|
||||
10_000_000,
|
||||
0.99,
|
||||
"dense",
|
||||
0,
|
||||
9,
|
||||
),
|
||||
(
|
||||
"dense and 99% a".to_string(),
|
||||
10_000_000,
|
||||
0.99,
|
||||
"dense",
|
||||
990,
|
||||
999,
|
||||
),
|
||||
(
|
||||
"sparse and 99% a".to_string(),
|
||||
10_000_000,
|
||||
0.99,
|
||||
"sparse",
|
||||
0,
|
||||
9,
|
||||
),
|
||||
(
|
||||
"sparse and 99% a".to_string(),
|
||||
10_000_000,
|
||||
0.99,
|
||||
"sparse",
|
||||
9_999_990,
|
||||
9_999_999,
|
||||
),
|
||||
];
|
||||
|
||||
let mut runner = BenchRunner::new();
|
||||
for (scenario_id, n, p_title_a, num_rand_distribution, range_low, range_high) in scenarios {
|
||||
// Build index for this scenario
|
||||
let bench_index = build_shared_indices(n, p_title_a, num_rand_distribution);
|
||||
|
||||
// Create benchmark group
|
||||
let mut group = runner.new_group();
|
||||
|
||||
// Now set the name (this moves scenario_id)
|
||||
group.set_name(scenario_id);
|
||||
|
||||
// Define all four field types
|
||||
let field_names = ["num_rand", "num_asc", "num_rand_fast", "num_asc_fast"];
|
||||
|
||||
// Define the three terms we want to test with
|
||||
let terms = ["a", "b", "z"];
|
||||
|
||||
// Generate all combinations of terms and field names
|
||||
let mut queries = Vec::new();
|
||||
for &term in &terms {
|
||||
for &field_name in &field_names {
|
||||
let query_str = format!(
|
||||
"{} AND {}:[{} TO {}]",
|
||||
term, field_name, range_low, range_high
|
||||
);
|
||||
queries.push((query_str, field_name.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
let query_str = format!(
|
||||
"{}:[{} TO {}] AND {}:[{} TO {}]",
|
||||
"num_rand_fast", range_low, range_high, "num_asc_fast", range_low, range_high
|
||||
);
|
||||
queries.push((query_str, "num_asc_fast".to_string()));
|
||||
|
||||
// Run all benchmark tasks for each query and its corresponding field name
|
||||
for (query_str, field_name) in queries {
|
||||
run_benchmark_tasks(&mut group, &bench_index, &query_str, &field_name);
|
||||
}
|
||||
|
||||
group.run();
|
||||
}
|
||||
}
|
||||
|
||||
/// Run all benchmark tasks for a given query string and field name
|
||||
fn run_benchmark_tasks(
|
||||
bench_group: &mut BenchGroup,
|
||||
bench_index: &BenchIndex,
|
||||
query_str: &str,
|
||||
field_name: &str,
|
||||
) {
|
||||
// Test count
|
||||
add_bench_task(bench_group, bench_index, query_str, Count, "count");
|
||||
|
||||
// Test all results
|
||||
add_bench_task(
|
||||
bench_group,
|
||||
bench_index,
|
||||
query_str,
|
||||
DocSetCollector,
|
||||
"all results",
|
||||
);
|
||||
|
||||
// Test top 100 by the field (if it's a FAST field)
|
||||
if field_name.ends_with("_fast") {
|
||||
// Ascending order
|
||||
{
|
||||
let collector_name = format!("top100_by_{}_asc", field_name);
|
||||
let field_name_owned = field_name.to_string();
|
||||
add_bench_task(
|
||||
bench_group,
|
||||
bench_index,
|
||||
query_str,
|
||||
TopDocs::with_limit(100).order_by_fast_field::<u64>(field_name_owned, Order::Asc),
|
||||
&collector_name,
|
||||
);
|
||||
}
|
||||
|
||||
// Descending order
|
||||
{
|
||||
let collector_name = format!("top100_by_{}_desc", field_name);
|
||||
let field_name_owned = field_name.to_string();
|
||||
add_bench_task(
|
||||
bench_group,
|
||||
bench_index,
|
||||
query_str,
|
||||
TopDocs::with_limit(100).order_by_fast_field::<u64>(field_name_owned, Order::Desc),
|
||||
&collector_name,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn add_bench_task<C: Collector + 'static>(
|
||||
bench_group: &mut BenchGroup,
|
||||
bench_index: &BenchIndex,
|
||||
query_str: &str,
|
||||
collector: C,
|
||||
collector_name: &str,
|
||||
) {
|
||||
let task_name = format!("{}_{}", query_str.replace(" ", "_"), collector_name);
|
||||
let query = bench_index.query_parser.parse_query(query_str).unwrap();
|
||||
let search_task = SearchTask {
|
||||
searcher: bench_index.searcher.clone(),
|
||||
collector,
|
||||
query,
|
||||
};
|
||||
bench_group.register(task_name, move |_| black_box(search_task.run()));
|
||||
}
|
||||
|
||||
struct SearchTask<C: Collector> {
|
||||
searcher: Searcher,
|
||||
collector: C,
|
||||
query: Box<dyn Query>,
|
||||
}
|
||||
|
||||
impl<C: Collector> SearchTask<C> {
|
||||
#[inline(never)]
|
||||
pub fn run(&self) -> usize {
|
||||
let result = self.searcher.search(&self.query, &self.collector).unwrap();
|
||||
if let Some(count) = (&result as &dyn std::any::Any).downcast_ref::<usize>() {
|
||||
*count
|
||||
} else if let Some(top_docs) = (&result as &dyn std::any::Any)
|
||||
.downcast_ref::<Vec<(Option<u64>, tantivy::DocAddress)>>()
|
||||
{
|
||||
top_docs.len()
|
||||
} else if let Some(top_docs) =
|
||||
(&result as &dyn std::any::Any).downcast_ref::<Vec<(u64, tantivy::DocAddress)>>()
|
||||
{
|
||||
top_docs.len()
|
||||
} else if let Some(doc_set) = (&result as &dyn std::any::Any)
|
||||
.downcast_ref::<std::collections::HashSet<tantivy::DocAddress>>()
|
||||
{
|
||||
doc_set.len()
|
||||
} else {
|
||||
eprintln!(
|
||||
"Unknown collector result type: {:?}",
|
||||
std::any::type_name::<C::Fruit>()
|
||||
);
|
||||
0
|
||||
}
|
||||
}
|
||||
}
|
||||
365
benches/range_queries.rs
Normal file
365
benches/range_queries.rs
Normal file
@@ -0,0 +1,365 @@
|
||||
use std::ops::Bound;
|
||||
|
||||
use binggan::{black_box, BenchGroup, BenchRunner};
|
||||
use rand::prelude::*;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::SeedableRng;
|
||||
use tantivy::collector::{Count, DocSetCollector, TopDocs};
|
||||
use tantivy::query::RangeQuery;
|
||||
use tantivy::schema::{Schema, FAST, INDEXED};
|
||||
use tantivy::{doc, Index, Order, ReloadPolicy, Searcher, Term};
|
||||
|
||||
#[derive(Clone)]
|
||||
struct BenchIndex {
|
||||
#[allow(dead_code)]
|
||||
index: Index,
|
||||
searcher: Searcher,
|
||||
}
|
||||
|
||||
fn build_shared_indices(num_docs: usize, distribution: &str) -> BenchIndex {
|
||||
// Schema with fast fields only
|
||||
let mut schema_builder = Schema::builder();
|
||||
let f_num_rand_fast = schema_builder.add_u64_field("num_rand_fast", INDEXED | FAST);
|
||||
let f_num_asc_fast = schema_builder.add_u64_field("num_asc_fast", INDEXED | FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema.clone());
|
||||
|
||||
// Populate index with stable RNG for reproducibility.
|
||||
let mut rng = StdRng::from_seed([7u8; 32]);
|
||||
|
||||
{
|
||||
let mut writer = index.writer_with_num_threads(1, 4_000_000_000).unwrap();
|
||||
|
||||
match distribution {
|
||||
"dense" => {
|
||||
for doc_id in 0..num_docs {
|
||||
let num_rand = rng.gen_range(0u64..1000u64);
|
||||
let num_asc = (doc_id / 10000) as u64;
|
||||
|
||||
writer
|
||||
.add_document(doc!(
|
||||
f_num_rand_fast=>num_rand,
|
||||
f_num_asc_fast=>num_asc,
|
||||
))
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
"sparse" => {
|
||||
for doc_id in 0..num_docs {
|
||||
let num_rand = rng.gen_range(0u64..10000000u64);
|
||||
let num_asc = doc_id as u64;
|
||||
|
||||
writer
|
||||
.add_document(doc!(
|
||||
f_num_rand_fast=>num_rand,
|
||||
f_num_asc_fast=>num_asc,
|
||||
))
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
panic!("Unsupported distribution type");
|
||||
}
|
||||
}
|
||||
writer.commit().unwrap();
|
||||
}
|
||||
|
||||
// Prepare reader/searcher once.
|
||||
let reader = index
|
||||
.reader_builder()
|
||||
.reload_policy(ReloadPolicy::Manual)
|
||||
.try_into()
|
||||
.unwrap();
|
||||
let searcher = reader.searcher();
|
||||
|
||||
BenchIndex { index, searcher }
|
||||
}
|
||||
|
||||
fn main() {
|
||||
// Prepare corpora with varying scenarios
|
||||
let scenarios = vec![
|
||||
// Dense distribution - random values in small range (0-999)
|
||||
(
|
||||
"dense_values_search_low_value_range".to_string(),
|
||||
10_000_000,
|
||||
"dense",
|
||||
0,
|
||||
9,
|
||||
),
|
||||
(
|
||||
"dense_values_search_high_value_range".to_string(),
|
||||
10_000_000,
|
||||
"dense",
|
||||
990,
|
||||
999,
|
||||
),
|
||||
(
|
||||
"dense_values_search_out_of_range".to_string(),
|
||||
10_000_000,
|
||||
"dense",
|
||||
1000,
|
||||
1002,
|
||||
),
|
||||
(
|
||||
"sparse_values_search_low_value_range".to_string(),
|
||||
10_000_000,
|
||||
"sparse",
|
||||
0,
|
||||
9,
|
||||
),
|
||||
(
|
||||
"sparse_values_search_high_value_range".to_string(),
|
||||
10_000_000,
|
||||
"sparse",
|
||||
9_999_990,
|
||||
9_999_999,
|
||||
),
|
||||
(
|
||||
"sparse_values_search_out_of_range".to_string(),
|
||||
10_000_000,
|
||||
"sparse",
|
||||
10_000_000,
|
||||
10_000_002,
|
||||
),
|
||||
];
|
||||
|
||||
let mut runner = BenchRunner::new();
|
||||
for (scenario_id, n, num_rand_distribution, range_low, range_high) in scenarios {
|
||||
// Build index for this scenario
|
||||
let bench_index = build_shared_indices(n, num_rand_distribution);
|
||||
|
||||
// Create benchmark group
|
||||
let mut group = runner.new_group();
|
||||
|
||||
// Now set the name (this moves scenario_id)
|
||||
group.set_name(scenario_id);
|
||||
|
||||
// Define fast field types
|
||||
let field_names = ["num_rand_fast", "num_asc_fast"];
|
||||
|
||||
// Generate range queries for fast fields
|
||||
for &field_name in &field_names {
|
||||
// Create the range query
|
||||
let field = bench_index.searcher.schema().get_field(field_name).unwrap();
|
||||
let lower_term = Term::from_field_u64(field, range_low);
|
||||
let upper_term = Term::from_field_u64(field, range_high);
|
||||
|
||||
let query = RangeQuery::new(Bound::Included(lower_term), Bound::Included(upper_term));
|
||||
|
||||
run_benchmark_tasks(
|
||||
&mut group,
|
||||
&bench_index,
|
||||
query,
|
||||
field_name,
|
||||
range_low,
|
||||
range_high,
|
||||
);
|
||||
}
|
||||
|
||||
group.run();
|
||||
}
|
||||
}
|
||||
|
||||
/// Run all benchmark tasks for a given range query and field name
|
||||
fn run_benchmark_tasks(
|
||||
bench_group: &mut BenchGroup,
|
||||
bench_index: &BenchIndex,
|
||||
query: RangeQuery,
|
||||
field_name: &str,
|
||||
range_low: u64,
|
||||
range_high: u64,
|
||||
) {
|
||||
// Test count
|
||||
add_bench_task_count(
|
||||
bench_group,
|
||||
bench_index,
|
||||
query.clone(),
|
||||
"count",
|
||||
field_name,
|
||||
range_low,
|
||||
range_high,
|
||||
);
|
||||
|
||||
// Test top 100 by the field (ascending order)
|
||||
{
|
||||
let collector_name = format!("top100_by_{}_asc", field_name);
|
||||
let field_name_owned = field_name.to_string();
|
||||
add_bench_task_top100_asc(
|
||||
bench_group,
|
||||
bench_index,
|
||||
query.clone(),
|
||||
&collector_name,
|
||||
field_name,
|
||||
range_low,
|
||||
range_high,
|
||||
field_name_owned,
|
||||
);
|
||||
}
|
||||
|
||||
// Test top 100 by the field (descending order)
|
||||
{
|
||||
let collector_name = format!("top100_by_{}_desc", field_name);
|
||||
let field_name_owned = field_name.to_string();
|
||||
add_bench_task_top100_desc(
|
||||
bench_group,
|
||||
bench_index,
|
||||
query,
|
||||
&collector_name,
|
||||
field_name,
|
||||
range_low,
|
||||
range_high,
|
||||
field_name_owned,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn add_bench_task_count(
|
||||
bench_group: &mut BenchGroup,
|
||||
bench_index: &BenchIndex,
|
||||
query: RangeQuery,
|
||||
collector_name: &str,
|
||||
field_name: &str,
|
||||
range_low: u64,
|
||||
range_high: u64,
|
||||
) {
|
||||
let task_name = format!(
|
||||
"range_{}_[{} TO {}]_{}",
|
||||
field_name, range_low, range_high, collector_name
|
||||
);
|
||||
|
||||
let search_task = CountSearchTask {
|
||||
searcher: bench_index.searcher.clone(),
|
||||
query,
|
||||
};
|
||||
bench_group.register(task_name, move |_| black_box(search_task.run()));
|
||||
}
|
||||
|
||||
fn add_bench_task_docset(
|
||||
bench_group: &mut BenchGroup,
|
||||
bench_index: &BenchIndex,
|
||||
query: RangeQuery,
|
||||
collector_name: &str,
|
||||
field_name: &str,
|
||||
range_low: u64,
|
||||
range_high: u64,
|
||||
) {
|
||||
let task_name = format!(
|
||||
"range_{}_[{} TO {}]_{}",
|
||||
field_name, range_low, range_high, collector_name
|
||||
);
|
||||
|
||||
let search_task = DocSetSearchTask {
|
||||
searcher: bench_index.searcher.clone(),
|
||||
query,
|
||||
};
|
||||
bench_group.register(task_name, move |_| black_box(search_task.run()));
|
||||
}
|
||||
|
||||
fn add_bench_task_top100_asc(
|
||||
bench_group: &mut BenchGroup,
|
||||
bench_index: &BenchIndex,
|
||||
query: RangeQuery,
|
||||
collector_name: &str,
|
||||
field_name: &str,
|
||||
range_low: u64,
|
||||
range_high: u64,
|
||||
field_name_owned: String,
|
||||
) {
|
||||
let task_name = format!(
|
||||
"range_{}_[{} TO {}]_{}",
|
||||
field_name, range_low, range_high, collector_name
|
||||
);
|
||||
|
||||
let search_task = Top100AscSearchTask {
|
||||
searcher: bench_index.searcher.clone(),
|
||||
query,
|
||||
field_name: field_name_owned,
|
||||
};
|
||||
bench_group.register(task_name, move |_| black_box(search_task.run()));
|
||||
}
|
||||
|
||||
fn add_bench_task_top100_desc(
|
||||
bench_group: &mut BenchGroup,
|
||||
bench_index: &BenchIndex,
|
||||
query: RangeQuery,
|
||||
collector_name: &str,
|
||||
field_name: &str,
|
||||
range_low: u64,
|
||||
range_high: u64,
|
||||
field_name_owned: String,
|
||||
) {
|
||||
let task_name = format!(
|
||||
"range_{}_[{} TO {}]_{}",
|
||||
field_name, range_low, range_high, collector_name
|
||||
);
|
||||
|
||||
let search_task = Top100DescSearchTask {
|
||||
searcher: bench_index.searcher.clone(),
|
||||
query,
|
||||
field_name: field_name_owned,
|
||||
};
|
||||
bench_group.register(task_name, move |_| black_box(search_task.run()));
|
||||
}
|
||||
|
||||
struct CountSearchTask {
|
||||
searcher: Searcher,
|
||||
query: RangeQuery,
|
||||
}
|
||||
|
||||
impl CountSearchTask {
|
||||
#[inline(never)]
|
||||
pub fn run(&self) -> usize {
|
||||
self.searcher.search(&self.query, &Count).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
struct DocSetSearchTask {
|
||||
searcher: Searcher,
|
||||
query: RangeQuery,
|
||||
}
|
||||
|
||||
impl DocSetSearchTask {
|
||||
#[inline(never)]
|
||||
pub fn run(&self) -> usize {
|
||||
let result = self.searcher.search(&self.query, &DocSetCollector).unwrap();
|
||||
result.len()
|
||||
}
|
||||
}
|
||||
|
||||
struct Top100AscSearchTask {
|
||||
searcher: Searcher,
|
||||
query: RangeQuery,
|
||||
field_name: String,
|
||||
}
|
||||
|
||||
impl Top100AscSearchTask {
|
||||
#[inline(never)]
|
||||
pub fn run(&self) -> usize {
|
||||
let collector =
|
||||
TopDocs::with_limit(100).order_by_fast_field::<u64>(&self.field_name, Order::Asc);
|
||||
let result = self.searcher.search(&self.query, &collector).unwrap();
|
||||
for (_score, doc_address) in &result {
|
||||
let _doc: tantivy::TantivyDocument = self.searcher.doc(*doc_address).unwrap();
|
||||
}
|
||||
result.len()
|
||||
}
|
||||
}
|
||||
|
||||
struct Top100DescSearchTask {
|
||||
searcher: Searcher,
|
||||
query: RangeQuery,
|
||||
field_name: String,
|
||||
}
|
||||
|
||||
impl Top100DescSearchTask {
|
||||
#[inline(never)]
|
||||
pub fn run(&self) -> usize {
|
||||
let collector =
|
||||
TopDocs::with_limit(100).order_by_fast_field::<u64>(&self.field_name, Order::Desc);
|
||||
let result = self.searcher.search(&self.query, &collector).unwrap();
|
||||
for (_score, doc_address) in &result {
|
||||
let _doc: tantivy::TantivyDocument = self.searcher.doc(*doc_address).unwrap();
|
||||
}
|
||||
result.len()
|
||||
}
|
||||
}
|
||||
260
benches/range_query.rs
Normal file
260
benches/range_query.rs
Normal file
@@ -0,0 +1,260 @@
|
||||
use std::fmt::Display;
|
||||
use std::net::Ipv6Addr;
|
||||
use std::ops::RangeInclusive;
|
||||
|
||||
use binggan::plugins::PeakMemAllocPlugin;
|
||||
use binggan::{black_box, BenchRunner, OutputValue, PeakMemAlloc, INSTRUMENTED_SYSTEM};
|
||||
use columnar::MonotonicallyMappableToU128;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use tantivy::collector::{Count, TopDocs};
|
||||
use tantivy::query::QueryParser;
|
||||
use tantivy::schema::*;
|
||||
use tantivy::{doc, Index};
|
||||
|
||||
#[global_allocator]
|
||||
pub static GLOBAL: &PeakMemAlloc<std::alloc::System> = &INSTRUMENTED_SYSTEM;
|
||||
|
||||
fn main() {
|
||||
bench_range_query();
|
||||
}
|
||||
|
||||
fn bench_range_query() {
|
||||
let index = get_index_0_to_100();
|
||||
let mut runner = BenchRunner::new();
|
||||
runner.add_plugin(PeakMemAllocPlugin::new(GLOBAL));
|
||||
|
||||
runner.set_name("range_query on u64");
|
||||
let field_name_and_descr: Vec<_> = vec![
|
||||
("id", "Single Valued Range Field"),
|
||||
("ids", "Multi Valued Range Field"),
|
||||
];
|
||||
let range_num_hits = vec![
|
||||
("90_percent", get_90_percent()),
|
||||
("10_percent", get_10_percent()),
|
||||
("1_percent", get_1_percent()),
|
||||
];
|
||||
|
||||
test_range(&mut runner, &index, &field_name_and_descr, range_num_hits);
|
||||
|
||||
runner.set_name("range_query on ip");
|
||||
let field_name_and_descr: Vec<_> = vec![
|
||||
("ip", "Single Valued Range Field"),
|
||||
("ips", "Multi Valued Range Field"),
|
||||
];
|
||||
let range_num_hits = vec![
|
||||
("90_percent", get_90_percent_ip()),
|
||||
("10_percent", get_10_percent_ip()),
|
||||
("1_percent", get_1_percent_ip()),
|
||||
];
|
||||
|
||||
test_range(&mut runner, &index, &field_name_and_descr, range_num_hits);
|
||||
}
|
||||
|
||||
fn test_range<T: Display>(
|
||||
runner: &mut BenchRunner,
|
||||
index: &Index,
|
||||
field_name_and_descr: &[(&str, &str)],
|
||||
range_num_hits: Vec<(&str, RangeInclusive<T>)>,
|
||||
) {
|
||||
for (field, suffix) in field_name_and_descr {
|
||||
let term_num_hits = vec![
|
||||
("", ""),
|
||||
("1_percent", "veryfew"),
|
||||
("10_percent", "few"),
|
||||
("90_percent", "most"),
|
||||
];
|
||||
let mut group = runner.new_group();
|
||||
group.set_name(suffix);
|
||||
// all intersect combinations
|
||||
for (range_name, range) in &range_num_hits {
|
||||
for (term_name, term) in &term_num_hits {
|
||||
let index = &index;
|
||||
let test_name = if term_name.is_empty() {
|
||||
format!("id_range_hit_{}", range_name)
|
||||
} else {
|
||||
format!(
|
||||
"id_range_hit_{}_intersect_with_term_{}",
|
||||
range_name, term_name
|
||||
)
|
||||
};
|
||||
group.register(test_name, move |_| {
|
||||
let query = if term_name.is_empty() {
|
||||
"".to_string()
|
||||
} else {
|
||||
format!("AND id_name:{}", term)
|
||||
};
|
||||
black_box(execute_query(field, range, &query, index));
|
||||
});
|
||||
}
|
||||
}
|
||||
group.run();
|
||||
}
|
||||
}
|
||||
|
||||
fn get_index_0_to_100() -> Index {
|
||||
let mut rng = StdRng::from_seed([1u8; 32]);
|
||||
let num_vals = 100_000;
|
||||
let docs: Vec<_> = (0..num_vals)
|
||||
.map(|_i| {
|
||||
let id_name = if rng.gen_bool(0.01) {
|
||||
"veryfew".to_string() // 1%
|
||||
} else if rng.gen_bool(0.1) {
|
||||
"few".to_string() // 9%
|
||||
} else {
|
||||
"most".to_string() // 90%
|
||||
};
|
||||
Doc {
|
||||
id_name,
|
||||
id: rng.gen_range(0..100),
|
||||
// Multiply by 1000, so that we create most buckets in the compact space
|
||||
// The benches depend on this range to select n-percent of elements with the
|
||||
// methods below.
|
||||
ip: Ipv6Addr::from_u128(rng.gen_range(0..100) * 1000),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
create_index_from_docs(&docs)
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Doc {
|
||||
pub id_name: String,
|
||||
pub id: u64,
|
||||
pub ip: Ipv6Addr,
|
||||
}
|
||||
|
||||
pub fn create_index_from_docs(docs: &[Doc]) -> Index {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let id_u64_field = schema_builder.add_u64_field("id", INDEXED | STORED | FAST);
|
||||
let ids_u64_field =
|
||||
schema_builder.add_u64_field("ids", NumericOptions::default().set_fast().set_indexed());
|
||||
|
||||
let id_f64_field = schema_builder.add_f64_field("id_f64", INDEXED | STORED | FAST);
|
||||
let ids_f64_field = schema_builder.add_f64_field(
|
||||
"ids_f64",
|
||||
NumericOptions::default().set_fast().set_indexed(),
|
||||
);
|
||||
|
||||
let id_i64_field = schema_builder.add_i64_field("id_i64", INDEXED | STORED | FAST);
|
||||
let ids_i64_field = schema_builder.add_i64_field(
|
||||
"ids_i64",
|
||||
NumericOptions::default().set_fast().set_indexed(),
|
||||
);
|
||||
|
||||
let text_field = schema_builder.add_text_field("id_name", STRING | STORED);
|
||||
let text_field2 = schema_builder.add_text_field("id_name_fast", STRING | STORED | FAST);
|
||||
|
||||
let ip_field = schema_builder.add_ip_addr_field("ip", FAST);
|
||||
let ips_field = schema_builder.add_ip_addr_field("ips", FAST);
|
||||
|
||||
let schema = schema_builder.build();
|
||||
|
||||
let index = Index::create_in_ram(schema);
|
||||
|
||||
{
|
||||
let mut index_writer = index.writer_with_num_threads(1, 50_000_000).unwrap();
|
||||
for doc in docs.iter() {
|
||||
index_writer
|
||||
.add_document(doc!(
|
||||
ids_i64_field => doc.id as i64,
|
||||
ids_i64_field => doc.id as i64,
|
||||
ids_f64_field => doc.id as f64,
|
||||
ids_f64_field => doc.id as f64,
|
||||
ids_u64_field => doc.id,
|
||||
ids_u64_field => doc.id,
|
||||
id_u64_field => doc.id,
|
||||
id_f64_field => doc.id as f64,
|
||||
id_i64_field => doc.id as i64,
|
||||
text_field => doc.id_name.to_string(),
|
||||
text_field2 => doc.id_name.to_string(),
|
||||
ips_field => doc.ip,
|
||||
ips_field => doc.ip,
|
||||
ip_field => doc.ip,
|
||||
))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
index_writer.commit().unwrap();
|
||||
}
|
||||
index
|
||||
}
|
||||
|
||||
fn get_90_percent() -> RangeInclusive<u64> {
|
||||
0..=90
|
||||
}
|
||||
|
||||
fn get_10_percent() -> RangeInclusive<u64> {
|
||||
0..=10
|
||||
}
|
||||
|
||||
fn get_1_percent() -> RangeInclusive<u64> {
|
||||
10..=10
|
||||
}
|
||||
|
||||
fn get_90_percent_ip() -> RangeInclusive<Ipv6Addr> {
|
||||
let start = Ipv6Addr::from_u128(0);
|
||||
let end = Ipv6Addr::from_u128(90 * 1000);
|
||||
start..=end
|
||||
}
|
||||
|
||||
fn get_10_percent_ip() -> RangeInclusive<Ipv6Addr> {
|
||||
let start = Ipv6Addr::from_u128(0);
|
||||
let end = Ipv6Addr::from_u128(10 * 1000);
|
||||
start..=end
|
||||
}
|
||||
|
||||
fn get_1_percent_ip() -> RangeInclusive<Ipv6Addr> {
|
||||
let start = Ipv6Addr::from_u128(10 * 1000);
|
||||
let end = Ipv6Addr::from_u128(10 * 1000);
|
||||
start..=end
|
||||
}
|
||||
|
||||
struct NumHits {
|
||||
count: usize,
|
||||
}
|
||||
impl OutputValue for NumHits {
|
||||
fn column_title() -> &'static str {
|
||||
"NumHits"
|
||||
}
|
||||
fn format(&self) -> Option<String> {
|
||||
Some(self.count.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
fn execute_query<T: Display>(
|
||||
field: &str,
|
||||
id_range: &RangeInclusive<T>,
|
||||
suffix: &str,
|
||||
index: &Index,
|
||||
) -> NumHits {
|
||||
let gen_query_inclusive = |from: &T, to: &T| {
|
||||
format!(
|
||||
"{}:[{} TO {}] {}",
|
||||
field,
|
||||
&from.to_string(),
|
||||
&to.to_string(),
|
||||
suffix
|
||||
)
|
||||
};
|
||||
|
||||
let query = gen_query_inclusive(id_range.start(), id_range.end());
|
||||
execute_query_(&query, index)
|
||||
}
|
||||
|
||||
fn execute_query_(query: &str, index: &Index) -> NumHits {
|
||||
let query_from_text = |text: &str| {
|
||||
QueryParser::for_index(index, vec![])
|
||||
.parse_query(text)
|
||||
.unwrap()
|
||||
};
|
||||
let query = query_from_text(query);
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
let num_hits = searcher
|
||||
.search(&query, &(TopDocs::with_limit(10).order_by_score(), Count))
|
||||
.unwrap()
|
||||
.1;
|
||||
NumHits { count: num_hits }
|
||||
}
|
||||
@@ -11,6 +11,9 @@ keywords = []
|
||||
documentation = "https://docs.rs/tantivy-bitpacker/latest/tantivy_bitpacker"
|
||||
homepage = "https://github.com/quickwit-oss/tantivy"
|
||||
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
bitpacking = { version = "0.9.2", default-features = false, features = ["bitpacker1x"] }
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ impl BitPacker {
|
||||
|
||||
pub fn flush<TWrite: io::Write + ?Sized>(&mut self, output: &mut TWrite) -> io::Result<()> {
|
||||
if self.mini_buffer_written > 0 {
|
||||
let num_bytes = (self.mini_buffer_written + 7) / 8;
|
||||
let num_bytes = self.mini_buffer_written.div_ceil(8);
|
||||
let bytes = self.mini_buffer.to_le_bytes();
|
||||
output.write_all(&bytes[..num_bytes])?;
|
||||
self.mini_buffer_written = 0;
|
||||
@@ -65,16 +65,10 @@ impl BitPacker {
|
||||
|
||||
#[derive(Clone, Debug, Default, Copy)]
|
||||
pub struct BitUnpacker {
|
||||
num_bits: u32,
|
||||
num_bits: usize,
|
||||
mask: u64,
|
||||
}
|
||||
|
||||
pub type BlockNumber = usize;
|
||||
|
||||
// 16k
|
||||
const BLOCK_SIZE_MIN_POW: u8 = 14;
|
||||
const BLOCK_SIZE_MIN: usize = 2 << BLOCK_SIZE_MIN_POW;
|
||||
|
||||
impl BitUnpacker {
|
||||
/// Creates a bit unpacker, that assumes the same bitwidth for all values.
|
||||
///
|
||||
@@ -88,9 +82,8 @@ impl BitUnpacker {
|
||||
} else {
|
||||
(1u64 << num_bits) - 1u64
|
||||
};
|
||||
|
||||
BitUnpacker {
|
||||
num_bits: u32::from(num_bits),
|
||||
num_bits: usize::from(num_bits),
|
||||
mask,
|
||||
}
|
||||
}
|
||||
@@ -99,69 +92,16 @@ impl BitUnpacker {
|
||||
self.num_bits as u8
|
||||
}
|
||||
|
||||
/// Calculates a block number for the given `idx`.
|
||||
#[inline]
|
||||
pub fn block_num(&self, idx: u32) -> BlockNumber {
|
||||
// Find the address in bits of the index.
|
||||
let addr_in_bits = (idx * self.num_bits) as usize;
|
||||
|
||||
// Then round down to the nearest byte.
|
||||
let addr_in_bytes = addr_in_bits >> 3;
|
||||
|
||||
// And compute the containing BlockNumber.
|
||||
addr_in_bytes >> (BLOCK_SIZE_MIN_POW + 1)
|
||||
}
|
||||
|
||||
/// Given a block number and dataset length, calculates a data Range for the block.
|
||||
pub fn block(&self, block: BlockNumber, data_len: usize) -> Range<usize> {
|
||||
let block_addr = block << (BLOCK_SIZE_MIN_POW + 1);
|
||||
// We extend the end of the block by a constant factor, so that it overlaps the next
|
||||
// block. That ensures that we never need to read on a block boundary.
|
||||
block_addr..(std::cmp::min(block_addr + BLOCK_SIZE_MIN + 8, data_len))
|
||||
}
|
||||
|
||||
/// Calculates the number of blocks for the given data_len.
|
||||
///
|
||||
/// Usually only called at startup to pre-allocate structures.
|
||||
pub fn block_count(&self, data_len: usize) -> usize {
|
||||
let block_count = data_len / (BLOCK_SIZE_MIN as usize);
|
||||
if data_len % (BLOCK_SIZE_MIN as usize) == 0 {
|
||||
block_count
|
||||
} else {
|
||||
block_count + 1
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a range within the data which covers the given id_range.
|
||||
///
|
||||
/// NOTE: This method is used for batch reads which bypass blocks to avoid dealing with block
|
||||
/// boundaries.
|
||||
#[inline]
|
||||
pub fn block_oblivious_range(&self, id_range: Range<u32>, data_len: usize) -> Range<usize> {
|
||||
let start_in_bits = id_range.start * self.num_bits;
|
||||
let start = (start_in_bits >> 3) as usize;
|
||||
let end_in_bits = id_range.end * self.num_bits;
|
||||
let end = (end_in_bits >> 3) as usize;
|
||||
// TODO: We fetch more than we need and then truncate.
|
||||
start..(std::cmp::min(end + 8, data_len))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn get(&self, idx: u32, data: &[u8]) -> u64 {
|
||||
self.get_from_subset(idx, 0, data)
|
||||
}
|
||||
|
||||
/// Get the value at the given idx, which must exist within the given subset of the data.
|
||||
#[inline]
|
||||
pub fn get_from_subset(&self, idx: u32, data_offset: usize, data: &[u8]) -> u64 {
|
||||
let addr_in_bits = idx * self.num_bits;
|
||||
let addr = (addr_in_bits >> 3) as usize - data_offset;
|
||||
let addr_in_bits = idx as usize * self.num_bits;
|
||||
let addr = addr_in_bits >> 3;
|
||||
if addr + 8 > data.len() {
|
||||
if self.num_bits == 0 {
|
||||
return 0;
|
||||
}
|
||||
let bit_shift = addr_in_bits & 7;
|
||||
return self.get_slow_path(addr, bit_shift, data);
|
||||
return self.get_slow_path(addr, bit_shift as u32, data);
|
||||
}
|
||||
let bit_shift = addr_in_bits & 7;
|
||||
let bytes: [u8; 8] = (&data[addr..addr + 8]).try_into().unwrap();
|
||||
@@ -173,7 +113,6 @@ impl BitUnpacker {
|
||||
#[inline(never)]
|
||||
fn get_slow_path(&self, addr: usize, bit_shift: u32, data: &[u8]) -> u64 {
|
||||
let mut bytes: [u8; 8] = [0u8; 8];
|
||||
|
||||
let available_bytes = data.len() - addr;
|
||||
// This function is meant to only be called if we did not have 8 bytes to load.
|
||||
debug_assert!(available_bytes < 8);
|
||||
@@ -189,25 +128,26 @@ impl BitUnpacker {
|
||||
// #Panics
|
||||
//
|
||||
// This methods panics if `num_bits` is > 32.
|
||||
fn get_batch_u32s(&self, start_idx: u32, data_offset: usize, data: &[u8], output: &mut [u32]) {
|
||||
fn get_batch_u32s(&self, start_idx: u32, data: &[u8], output: &mut [u32]) {
|
||||
assert!(
|
||||
self.bit_width() <= 32,
|
||||
"Bitwidth must be <= 32 to use this method."
|
||||
);
|
||||
|
||||
let end_idx = start_idx + output.len() as u32;
|
||||
let end_idx: u32 = start_idx + output.len() as u32;
|
||||
|
||||
let end_bit_read = end_idx * self.num_bits;
|
||||
let end_byte_read = (end_bit_read + 7) / 8;
|
||||
// We use `usize` here to avoid overflow issues.
|
||||
let end_bit_read = (end_idx as usize) * self.num_bits;
|
||||
let end_byte_read = end_bit_read.div_ceil(8);
|
||||
assert!(
|
||||
end_byte_read as usize <= data_offset + data.len(),
|
||||
end_byte_read <= data.len(),
|
||||
"Requested index is out of bounds."
|
||||
);
|
||||
|
||||
// Simple slow implementation of get_batch_u32s, to deal with our ramps.
|
||||
let get_batch_ramp = |start_idx: u32, output: &mut [u32]| {
|
||||
for (out, idx) in output.iter_mut().zip(start_idx..) {
|
||||
*out = self.get_from_subset(idx, data_offset, data) as u32;
|
||||
*out = self.get(idx, data) as u32;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -220,24 +160,24 @@ impl BitUnpacker {
|
||||
// We want the start of the fast track to start align with bytes.
|
||||
// A sufficient condition is to start with an idx that is a multiple of 8,
|
||||
// so highway start is the closest multiple of 8 that is >= start_idx.
|
||||
let entrance_ramp_len = 8 - (start_idx % 8) % 8;
|
||||
let entrance_ramp_len: u32 = 8 - (start_idx % 8) % 8;
|
||||
|
||||
let highway_start: u32 = start_idx + entrance_ramp_len;
|
||||
|
||||
if highway_start + BitPacker1x::BLOCK_LEN as u32 > end_idx {
|
||||
if highway_start + (BitPacker1x::BLOCK_LEN as u32) > end_idx {
|
||||
// We don't have enough values to have even a single block of highway.
|
||||
// Let's just supply the values the simple way.
|
||||
get_batch_ramp(start_idx, output);
|
||||
return;
|
||||
}
|
||||
|
||||
let num_blocks: u32 = (end_idx - highway_start) / BitPacker1x::BLOCK_LEN as u32;
|
||||
let num_blocks: usize = (end_idx - highway_start) as usize / BitPacker1x::BLOCK_LEN;
|
||||
|
||||
// Entrance ramp
|
||||
get_batch_ramp(start_idx, &mut output[..entrance_ramp_len as usize]);
|
||||
|
||||
// Highway
|
||||
let mut offset = ((highway_start * self.num_bits) as usize / 8) - data_offset;
|
||||
let mut offset = (highway_start as usize * self.num_bits) / 8;
|
||||
let mut output_cursor = (highway_start - start_idx) as usize;
|
||||
for _ in 0..num_blocks {
|
||||
offset += BitPacker1x.decompress(
|
||||
@@ -249,7 +189,7 @@ impl BitUnpacker {
|
||||
}
|
||||
|
||||
// Exit ramp
|
||||
let highway_end = highway_start + num_blocks * BitPacker1x::BLOCK_LEN as u32;
|
||||
let highway_end: u32 = highway_start + (num_blocks * BitPacker1x::BLOCK_LEN) as u32;
|
||||
get_batch_ramp(highway_end, &mut output[output_cursor..]);
|
||||
}
|
||||
|
||||
@@ -259,27 +199,16 @@ impl BitUnpacker {
|
||||
id_range: Range<u32>,
|
||||
data: &[u8],
|
||||
positions: &mut Vec<u32>,
|
||||
) {
|
||||
self.get_ids_for_value_range_from_subset(range, id_range, 0, data, positions)
|
||||
}
|
||||
|
||||
pub fn get_ids_for_value_range_from_subset(
|
||||
&self,
|
||||
range: RangeInclusive<u64>,
|
||||
id_range: Range<u32>,
|
||||
data_offset: usize,
|
||||
data: &[u8],
|
||||
positions: &mut Vec<u32>,
|
||||
) {
|
||||
if self.bit_width() > 32 {
|
||||
self.get_ids_for_value_range_slow(range, id_range, data_offset, data, positions)
|
||||
self.get_ids_for_value_range_slow(range, id_range, data, positions)
|
||||
} else {
|
||||
if *range.start() > u32::MAX as u64 {
|
||||
positions.clear();
|
||||
return;
|
||||
}
|
||||
let range_u32 = (*range.start() as u32)..=(*range.end()).min(u32::MAX as u64) as u32;
|
||||
self.get_ids_for_value_range_fast(range_u32, id_range, data_offset, data, positions)
|
||||
self.get_ids_for_value_range_fast(range_u32, id_range, data, positions)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -287,7 +216,6 @@ impl BitUnpacker {
|
||||
&self,
|
||||
range: RangeInclusive<u64>,
|
||||
id_range: Range<u32>,
|
||||
data_offset: usize,
|
||||
data: &[u8],
|
||||
positions: &mut Vec<u32>,
|
||||
) {
|
||||
@@ -295,7 +223,7 @@ impl BitUnpacker {
|
||||
for i in id_range {
|
||||
// If we cared we could make this branchless, but the slow implementation should rarely
|
||||
// kick in.
|
||||
let val = self.get_from_subset(i, data_offset, data);
|
||||
let val = self.get(i, data);
|
||||
if range.contains(&val) {
|
||||
positions.push(i);
|
||||
}
|
||||
@@ -306,12 +234,11 @@ impl BitUnpacker {
|
||||
&self,
|
||||
value_range: RangeInclusive<u32>,
|
||||
id_range: Range<u32>,
|
||||
data_offset: usize,
|
||||
data: &[u8],
|
||||
positions: &mut Vec<u32>,
|
||||
) {
|
||||
positions.resize(id_range.len(), 0u32);
|
||||
self.get_batch_u32s(id_range.start, data_offset, data, positions);
|
||||
self.get_batch_u32s(id_range.start, data, positions);
|
||||
crate::filter_vec::filter_vec_in_place(value_range, id_range.start, positions)
|
||||
}
|
||||
}
|
||||
@@ -402,14 +329,14 @@ mod test {
|
||||
fn test_get_batch_panics_over_32_bits() {
|
||||
let bitunpacker = BitUnpacker::new(33);
|
||||
let mut output: [u32; 1] = [0u32];
|
||||
bitunpacker.get_batch_u32s(0, 0, &[0, 0, 0, 0, 0, 0, 0, 0], &mut output[..]);
|
||||
bitunpacker.get_batch_u32s(0, &[0, 0, 0, 0, 0, 0, 0, 0], &mut output[..]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_batch_limit() {
|
||||
let bitunpacker = BitUnpacker::new(1);
|
||||
let mut output: [u32; 3] = [0u32, 0u32, 0u32];
|
||||
bitunpacker.get_batch_u32s(8 * 4 - 3, 0, &[0u8, 0u8, 0u8, 0u8], &mut output[..]);
|
||||
bitunpacker.get_batch_u32s(8 * 4 - 3, &[0u8, 0u8, 0u8, 0u8], &mut output[..]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -418,7 +345,7 @@ mod test {
|
||||
let bitunpacker = BitUnpacker::new(1);
|
||||
let mut output: [u32; 3] = [0u32, 0u32, 0u32];
|
||||
// We are missing exactly one bit.
|
||||
bitunpacker.get_batch_u32s(8 * 4 - 2, 0, &[0u8, 0u8, 0u8, 0u8], &mut output[..]);
|
||||
bitunpacker.get_batch_u32s(8 * 4 - 2, &[0u8, 0u8, 0u8, 0u8], &mut output[..]);
|
||||
}
|
||||
|
||||
proptest::proptest! {
|
||||
@@ -441,7 +368,7 @@ mod test {
|
||||
for len in [0, 1, 2, 32, 33, 34, 64] {
|
||||
for start_idx in 0u32..32u32 {
|
||||
output.resize(len, 0);
|
||||
bitunpacker.get_batch_u32s(start_idx, 0, &buffer, &mut output);
|
||||
bitunpacker.get_batch_u32s(start_idx, &buffer, &mut output);
|
||||
for (i, output_byte) in output.iter().enumerate() {
|
||||
let expected = (start_idx + i as u32) & mask;
|
||||
assert_eq!(*output_byte, expected);
|
||||
|
||||
@@ -16,7 +16,7 @@ stacker = { version= "0.6", path = "../stacker", package="tantivy-stacker"}
|
||||
sstable = { version= "0.6", path = "../sstable", package = "tantivy-sstable" }
|
||||
common = { version= "0.10", path = "../common", package = "tantivy-common" }
|
||||
tantivy-bitpacker = { version= "0.9", path = "../bitpacker/" }
|
||||
serde = { version = "1.0.152", features = ["derive"] }
|
||||
serde = "1.0.152"
|
||||
downcast-rs = "2.0.1"
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use binggan::{InputGroup, black_box};
|
||||
use common::*;
|
||||
use tantivy_columnar::{Column, ValueRange};
|
||||
use tantivy_columnar::Column;
|
||||
|
||||
pub mod common;
|
||||
|
||||
@@ -46,16 +46,16 @@ fn bench_group(mut runner: InputGroup<Column>) {
|
||||
runner.register("access_first_vals", |column| {
|
||||
let mut sum = 0;
|
||||
const BLOCK_SIZE: usize = 32;
|
||||
let mut docs = Vec::with_capacity(BLOCK_SIZE);
|
||||
let mut buffer = Vec::with_capacity(BLOCK_SIZE);
|
||||
let mut docs = vec![0; BLOCK_SIZE];
|
||||
let mut buffer = vec![None; BLOCK_SIZE];
|
||||
for i in (0..NUM_DOCS).step_by(BLOCK_SIZE) {
|
||||
docs.clear();
|
||||
// fill docs
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
for idx in 0..BLOCK_SIZE {
|
||||
docs.push(idx as u32 + i);
|
||||
docs[idx] = idx as u32 + i;
|
||||
}
|
||||
|
||||
buffer.clear();
|
||||
column.first_vals_in_value_range(&mut docs, &mut buffer, ValueRange::All);
|
||||
column.first_vals(&docs, &mut buffer);
|
||||
for val in buffer.iter() {
|
||||
let Some(val) = val else { continue };
|
||||
sum += *val;
|
||||
|
||||
@@ -40,14 +40,7 @@ fn main() {
|
||||
let columnar_readers = columnar_readers.iter().collect::<Vec<_>>();
|
||||
let merge_row_order = StackMergeOrder::stack(&columnar_readers[..]);
|
||||
|
||||
merge_columnar(
|
||||
&columnar_readers,
|
||||
&[],
|
||||
merge_row_order.into(),
|
||||
&mut out,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
merge_columnar(&columnar_readers, &[], merge_row_order.into(), &mut out).unwrap();
|
||||
Some(out.len() as u64)
|
||||
},
|
||||
);
|
||||
|
||||
@@ -29,12 +29,20 @@ impl<T: PartialOrd + Copy + std::fmt::Debug + Send + Sync + 'static + Default>
|
||||
}
|
||||
}
|
||||
#[inline]
|
||||
pub fn fetch_block_with_missing(&mut self, docs: &[u32], accessor: &Column<T>, missing: T) {
|
||||
pub fn fetch_block_with_missing(
|
||||
&mut self,
|
||||
docs: &[u32],
|
||||
accessor: &Column<T>,
|
||||
missing: Option<T>,
|
||||
) {
|
||||
self.fetch_block(docs, accessor);
|
||||
// no missing values
|
||||
if accessor.index.get_cardinality().is_full() {
|
||||
return;
|
||||
}
|
||||
let Some(missing) = missing else {
|
||||
return;
|
||||
};
|
||||
|
||||
// We can compare docid_cache length with docs to find missing docs
|
||||
// For multi value columns we can't rely on the length and always need to scan
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
mod dictionary_encoded;
|
||||
mod serialize;
|
||||
|
||||
use std::cell::RefCell;
|
||||
use std::fmt::{self, Debug};
|
||||
use std::io::Write;
|
||||
use std::ops::{Range, RangeInclusive};
|
||||
@@ -20,11 +19,6 @@ use crate::column_values::monotonic_mapping::StrictlyMonotonicMappingToInternal;
|
||||
use crate::column_values::{ColumnValues, monotonic_map_column};
|
||||
use crate::{Cardinality, DocId, EmptyColumnValues, MonotonicallyMappableToU64, RowId};
|
||||
|
||||
thread_local! {
|
||||
static ROWS: RefCell<Vec<RowId>> = const { RefCell::new(Vec::new()) };
|
||||
static DOCS: RefCell<Vec<DocId>> = const { RefCell::new(Vec::new()) };
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Column<T = u64> {
|
||||
pub index: ColumnIndex,
|
||||
@@ -91,8 +85,33 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn first(&self, row_id: RowId) -> Option<T> {
|
||||
self.values_for_doc(row_id).next()
|
||||
pub fn first(&self, doc_id: DocId) -> Option<T> {
|
||||
self.values_for_doc(doc_id).next()
|
||||
}
|
||||
|
||||
/// Load the first value for each docid in the provided slice.
|
||||
#[inline]
|
||||
pub fn first_vals(&self, docids: &[DocId], output: &mut [Option<T>]) {
|
||||
match &self.index {
|
||||
ColumnIndex::Empty { .. } => {}
|
||||
ColumnIndex::Full => self.values.get_vals_opt(docids, output),
|
||||
ColumnIndex::Optional(optional_index) => {
|
||||
for (i, docid) in docids.iter().enumerate() {
|
||||
output[i] = optional_index
|
||||
.rank_if_exists(*docid)
|
||||
.map(|rowid| self.values.get_val(rowid));
|
||||
}
|
||||
}
|
||||
ColumnIndex::Multivalued(multivalued_index) => {
|
||||
for (i, docid) in docids.iter().enumerate() {
|
||||
let range = multivalued_index.range(*docid);
|
||||
let is_empty = range.start == range.end;
|
||||
if !is_empty {
|
||||
output[i] = Some(self.values.get_val(range.start));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Translates a block of docids to row_ids.
|
||||
@@ -124,7 +143,7 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
|
||||
#[inline]
|
||||
pub fn get_docids_for_value_range(
|
||||
&self,
|
||||
value_range: ValueRange<T>,
|
||||
value_range: RangeInclusive<T>,
|
||||
selected_docid_range: Range<u32>,
|
||||
doc_ids: &mut Vec<u32>,
|
||||
) {
|
||||
@@ -149,181 +168,6 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
|
||||
}
|
||||
}
|
||||
|
||||
// Separate impl block for methods requiring `Default` for `T`.
|
||||
impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static + Default> Column<T> {
|
||||
/// Load the first value for each docid in the provided slice.
|
||||
///
|
||||
/// The `docids` vector is mutated: documents that do not match the `value_range` are removed.
|
||||
/// The `values` vector is populated with the values of the remaining documents.
|
||||
#[inline]
|
||||
pub fn first_vals_in_value_range(
|
||||
&self,
|
||||
input_docs: &[DocId],
|
||||
output: &mut Vec<crate::ComparableDoc<Option<T>, DocId>>,
|
||||
value_range: ValueRange<T>,
|
||||
) {
|
||||
match (&self.index, value_range) {
|
||||
(ColumnIndex::Empty { .. }, value_range) => {
|
||||
let nulls_match = match &value_range {
|
||||
ValueRange::All => true,
|
||||
ValueRange::Inclusive(_) => false,
|
||||
ValueRange::GreaterThan(_, nulls_match) => *nulls_match,
|
||||
ValueRange::GreaterThanOrEqual(_, nulls_match) => *nulls_match,
|
||||
ValueRange::LessThan(_, nulls_match) => *nulls_match,
|
||||
ValueRange::LessThanOrEqual(_, nulls_match) => *nulls_match,
|
||||
};
|
||||
if nulls_match {
|
||||
for &doc in input_docs {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc,
|
||||
sort_key: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
(ColumnIndex::Full, value_range) => {
|
||||
self.values
|
||||
.get_vals_in_value_range(input_docs, input_docs, output, value_range);
|
||||
}
|
||||
(ColumnIndex::Optional(optional_index), value_range) => {
|
||||
let nulls_match = match &value_range {
|
||||
ValueRange::All => true,
|
||||
ValueRange::Inclusive(_) => false,
|
||||
ValueRange::GreaterThan(_, nulls_match) => *nulls_match,
|
||||
ValueRange::GreaterThanOrEqual(_, nulls_match) => *nulls_match,
|
||||
ValueRange::LessThan(_, nulls_match) => *nulls_match,
|
||||
ValueRange::LessThanOrEqual(_, nulls_match) => *nulls_match,
|
||||
};
|
||||
|
||||
let fallback_needed = ROWS.with(|rows_cell| {
|
||||
DOCS.with(|docs_cell| {
|
||||
let mut rows = rows_cell.borrow_mut();
|
||||
let mut docs = docs_cell.borrow_mut();
|
||||
rows.clear();
|
||||
docs.clear();
|
||||
|
||||
let mut has_nulls = false;
|
||||
|
||||
for &doc_id in input_docs {
|
||||
if let Some(row_id) = optional_index.rank_if_exists(doc_id) {
|
||||
rows.push(row_id);
|
||||
docs.push(doc_id);
|
||||
} else {
|
||||
has_nulls = true;
|
||||
if nulls_match {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !has_nulls || !nulls_match {
|
||||
self.values.get_vals_in_value_range(
|
||||
&rows,
|
||||
&docs,
|
||||
output,
|
||||
value_range.clone(),
|
||||
);
|
||||
return false;
|
||||
}
|
||||
true
|
||||
})
|
||||
});
|
||||
|
||||
if fallback_needed {
|
||||
for &doc_id in input_docs {
|
||||
if let Some(row_id) = optional_index.rank_if_exists(doc_id) {
|
||||
let val = self.values.get_val(row_id);
|
||||
let value_matches = match &value_range {
|
||||
ValueRange::All => true,
|
||||
ValueRange::Inclusive(r) => r.contains(&val),
|
||||
ValueRange::GreaterThan(t, _) => val > *t,
|
||||
ValueRange::GreaterThanOrEqual(t, _) => val >= *t,
|
||||
ValueRange::LessThan(t, _) => val < *t,
|
||||
ValueRange::LessThanOrEqual(t, _) => val <= *t,
|
||||
};
|
||||
|
||||
if value_matches {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc_id,
|
||||
sort_key: Some(val),
|
||||
});
|
||||
}
|
||||
} else if nulls_match {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc_id,
|
||||
sort_key: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
(ColumnIndex::Multivalued(multivalued_index), value_range) => {
|
||||
let nulls_match = match &value_range {
|
||||
ValueRange::All => true,
|
||||
ValueRange::Inclusive(_) => false,
|
||||
ValueRange::GreaterThan(_, nulls_match) => *nulls_match,
|
||||
ValueRange::GreaterThanOrEqual(_, nulls_match) => *nulls_match,
|
||||
ValueRange::LessThan(_, nulls_match) => *nulls_match,
|
||||
ValueRange::LessThanOrEqual(_, nulls_match) => *nulls_match,
|
||||
};
|
||||
for i in 0..input_docs.len() {
|
||||
let docid = input_docs[i];
|
||||
let row_range = multivalued_index.range(docid);
|
||||
let is_empty = row_range.start == row_range.end;
|
||||
if !is_empty {
|
||||
let val = self.values.get_val(row_range.start);
|
||||
let matches = match &value_range {
|
||||
ValueRange::All => true,
|
||||
ValueRange::Inclusive(r) => r.contains(&val),
|
||||
ValueRange::GreaterThan(t, _) => val > *t,
|
||||
ValueRange::GreaterThanOrEqual(t, _) => val >= *t,
|
||||
ValueRange::LessThan(t, _) => val < *t,
|
||||
ValueRange::LessThanOrEqual(t, _) => val <= *t,
|
||||
};
|
||||
if matches {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: docid,
|
||||
sort_key: Some(val),
|
||||
});
|
||||
}
|
||||
} else if nulls_match {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: docid,
|
||||
sort_key: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A range of values.
|
||||
///
|
||||
/// This type is intended to be used in batch APIs, where the cost of unpacking the enum
|
||||
/// is outweighed by the time spent processing a batch.
|
||||
///
|
||||
/// Implementers should pattern match on the variants to use optimized loops for each case.
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum ValueRange<T> {
|
||||
/// A range that includes both start and end.
|
||||
Inclusive(RangeInclusive<T>),
|
||||
/// A range that matches all values.
|
||||
All,
|
||||
/// A range that matches all values greater than the threshold.
|
||||
/// The boolean flag indicates if null values should be included.
|
||||
GreaterThan(T, bool),
|
||||
/// A range that matches all values greater than or equal to the threshold.
|
||||
/// The boolean flag indicates if null values should be included.
|
||||
GreaterThanOrEqual(T, bool),
|
||||
/// A range that matches all values less than the threshold.
|
||||
/// The boolean flag indicates if null values should be included.
|
||||
LessThan(T, bool),
|
||||
/// A range that matches all values less than or equal to the threshold.
|
||||
/// The boolean flag indicates if null values should be included.
|
||||
LessThanOrEqual(T, bool),
|
||||
}
|
||||
|
||||
impl BinarySerializable for Cardinality {
|
||||
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> std::io::Result<()> {
|
||||
self.to_code().serialize(writer)
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::io;
|
||||
use std::io::Write;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common::file_slice::FileSlice;
|
||||
use common::OwnedBytes;
|
||||
use sstable::Dictionary;
|
||||
|
||||
use crate::column::{BytesColumn, Column};
|
||||
@@ -41,13 +41,12 @@ pub fn serialize_column_mappable_to_u64<T: MonotonicallyMappableToU64>(
|
||||
}
|
||||
|
||||
pub fn open_column_u64<T: MonotonicallyMappableToU64>(
|
||||
file_slice: FileSlice,
|
||||
bytes: OwnedBytes,
|
||||
format_version: Version,
|
||||
) -> io::Result<Column<T>> {
|
||||
let (body, column_index_num_bytes_payload) = file_slice.split_from_end(4);
|
||||
let (body, column_index_num_bytes_payload) = bytes.rsplit(4);
|
||||
let column_index_num_bytes = u32::from_le_bytes(
|
||||
column_index_num_bytes_payload
|
||||
.read_bytes()?
|
||||
.as_slice()
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
@@ -62,13 +61,12 @@ pub fn open_column_u64<T: MonotonicallyMappableToU64>(
|
||||
}
|
||||
|
||||
pub fn open_column_u128<T: MonotonicallyMappableToU128>(
|
||||
file_slice: FileSlice,
|
||||
bytes: OwnedBytes,
|
||||
format_version: Version,
|
||||
) -> io::Result<Column<T>> {
|
||||
let (body, column_index_num_bytes_payload) = file_slice.split_from_end(4);
|
||||
let (body, column_index_num_bytes_payload) = bytes.rsplit(4);
|
||||
let column_index_num_bytes = u32::from_le_bytes(
|
||||
column_index_num_bytes_payload
|
||||
.read_bytes()?
|
||||
.as_slice()
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
@@ -86,13 +84,12 @@ pub fn open_column_u128<T: MonotonicallyMappableToU128>(
|
||||
///
|
||||
/// See [`open_u128_as_compact_u64`] for more details.
|
||||
pub fn open_column_u128_as_compact_u64(
|
||||
file_slice: FileSlice,
|
||||
bytes: OwnedBytes,
|
||||
format_version: Version,
|
||||
) -> io::Result<Column<u64>> {
|
||||
let (body, column_index_num_bytes_payload) = file_slice.split_from_end(4);
|
||||
let (body, column_index_num_bytes_payload) = bytes.rsplit(4);
|
||||
let column_index_num_bytes = u32::from_le_bytes(
|
||||
column_index_num_bytes_payload
|
||||
.read_bytes()?
|
||||
.as_slice()
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
@@ -106,21 +103,11 @@ pub fn open_column_u128_as_compact_u64(
|
||||
})
|
||||
}
|
||||
|
||||
pub fn open_column_bytes(
|
||||
file_slice: FileSlice,
|
||||
format_version: Version,
|
||||
) -> io::Result<BytesColumn> {
|
||||
let (body, dictionary_len_bytes) = file_slice.split_from_end(4);
|
||||
let dictionary_len = u32::from_le_bytes(
|
||||
dictionary_len_bytes
|
||||
.read_bytes()?
|
||||
.as_slice()
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
);
|
||||
pub fn open_column_bytes(data: OwnedBytes, format_version: Version) -> io::Result<BytesColumn> {
|
||||
let (body, dictionary_len_bytes) = data.rsplit(4);
|
||||
let dictionary_len = u32::from_le_bytes(dictionary_len_bytes.as_slice().try_into().unwrap());
|
||||
let (dictionary_bytes, column_bytes) = body.split(dictionary_len as usize);
|
||||
|
||||
let dictionary = Arc::new(Dictionary::open(dictionary_bytes)?);
|
||||
let dictionary = Arc::new(Dictionary::from_bytes(dictionary_bytes)?);
|
||||
let term_ord_column = crate::column::open_column_u64::<u64>(column_bytes, format_version)?;
|
||||
Ok(BytesColumn {
|
||||
dictionary,
|
||||
@@ -128,7 +115,7 @@ pub fn open_column_bytes(
|
||||
})
|
||||
}
|
||||
|
||||
pub fn open_column_str(file_slice: FileSlice, format_version: Version) -> io::Result<StrColumn> {
|
||||
let bytes_column = open_column_bytes(file_slice, format_version)?;
|
||||
pub fn open_column_str(data: OwnedBytes, format_version: Version) -> io::Result<StrColumn> {
|
||||
let bytes_column = open_column_bytes(data, format_version)?;
|
||||
Ok(StrColumn::wrap(bytes_column))
|
||||
}
|
||||
|
||||
@@ -95,7 +95,7 @@ pub fn merge_column_index<'a>(
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use common::file_slice::FileSlice;
|
||||
use common::OwnedBytes;
|
||||
|
||||
use crate::column_index::merge::detect_cardinality;
|
||||
use crate::column_index::multivalued_index::{
|
||||
@@ -178,7 +178,7 @@ mod tests {
|
||||
let mut output = Vec::new();
|
||||
serialize_multivalued_index(&start_index_iterable, &mut output).unwrap();
|
||||
let multivalue =
|
||||
open_multivalued_index(FileSlice::from(output), crate::Version::V2).unwrap();
|
||||
open_multivalued_index(OwnedBytes::new(output), crate::Version::V2).unwrap();
|
||||
let start_indexes: Vec<RowId> = multivalue.get_start_index_column().iter().collect();
|
||||
assert_eq!(&start_indexes, &[0, 3, 5]);
|
||||
}
|
||||
@@ -216,7 +216,7 @@ mod tests {
|
||||
let mut output = Vec::new();
|
||||
serialize_multivalued_index(&start_index_iterable, &mut output).unwrap();
|
||||
let multivalue =
|
||||
open_multivalued_index(FileSlice::from(output), crate::Version::V2).unwrap();
|
||||
open_multivalued_index(OwnedBytes::new(output), crate::Version::V2).unwrap();
|
||||
let start_indexes: Vec<RowId> = multivalue.get_start_index_column().iter().collect();
|
||||
assert_eq!(&start_indexes, &[0, 3, 5, 6]);
|
||||
}
|
||||
|
||||
@@ -3,8 +3,7 @@ use std::io::Write;
|
||||
use std::ops::Range;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common::CountingWriter;
|
||||
use common::file_slice::FileSlice;
|
||||
use common::{CountingWriter, OwnedBytes};
|
||||
|
||||
use super::optional_index::{open_optional_index, serialize_optional_index};
|
||||
use super::{OptionalIndex, SerializableOptionalIndex, Set};
|
||||
@@ -45,26 +44,21 @@ pub fn serialize_multivalued_index(
|
||||
}
|
||||
|
||||
pub fn open_multivalued_index(
|
||||
file_slice: FileSlice,
|
||||
bytes: OwnedBytes,
|
||||
format_version: Version,
|
||||
) -> io::Result<MultiValueIndex> {
|
||||
match format_version {
|
||||
Version::V1 => {
|
||||
let start_index_column: Arc<dyn ColumnValues<RowId>> =
|
||||
load_u64_based_column_values(file_slice)?;
|
||||
load_u64_based_column_values(bytes)?;
|
||||
Ok(MultiValueIndex::MultiValueIndexV1(MultiValueIndexV1 {
|
||||
start_index_column,
|
||||
}))
|
||||
}
|
||||
Version::V2 => {
|
||||
let (body_bytes, optional_index_len) = file_slice.split_from_end(4);
|
||||
let optional_index_len = u32::from_le_bytes(
|
||||
optional_index_len
|
||||
.read_bytes()?
|
||||
.as_slice()
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
);
|
||||
let (body_bytes, optional_index_len) = bytes.rsplit(4);
|
||||
let optional_index_len =
|
||||
u32::from_le_bytes(optional_index_len.as_slice().try_into().unwrap());
|
||||
let (optional_index_bytes, start_index_bytes) =
|
||||
body_bytes.split(optional_index_len as usize);
|
||||
let optional_index = open_optional_index(optional_index_bytes)?;
|
||||
@@ -191,8 +185,8 @@ impl MultiValueIndex {
|
||||
};
|
||||
let mut buffer = Vec::new();
|
||||
serialize_multivalued_index(&serializable_multivalued_index, &mut buffer).unwrap();
|
||||
let file_slice = FileSlice::from(buffer);
|
||||
open_multivalued_index(file_slice, Version::V2).unwrap()
|
||||
let bytes = OwnedBytes::new(buffer);
|
||||
open_multivalued_index(bytes, Version::V2).unwrap()
|
||||
}
|
||||
|
||||
pub fn get_start_index_column(&self) -> &Arc<dyn crate::ColumnValues<RowId>> {
|
||||
@@ -339,7 +333,7 @@ mod tests {
|
||||
use std::ops::Range;
|
||||
|
||||
use super::MultiValueIndex;
|
||||
use crate::{ColumnarReader, DynamicColumn, ValueRange};
|
||||
use crate::{ColumnarReader, DynamicColumn};
|
||||
|
||||
fn index_to_pos_helper(
|
||||
index: &MultiValueIndex,
|
||||
@@ -419,7 +413,7 @@ mod tests {
|
||||
assert_eq!(row_id_range, 0..4);
|
||||
|
||||
let check = |range, expected| {
|
||||
let full_range = ValueRange::All;
|
||||
let full_range = 0..=u64::MAX;
|
||||
let mut docids = Vec::new();
|
||||
column.get_docids_for_value_range(full_range, range, &mut docids);
|
||||
assert_eq!(docids, expected);
|
||||
|
||||
@@ -4,7 +4,6 @@ use std::sync::Arc;
|
||||
mod set;
|
||||
mod set_block;
|
||||
|
||||
use common::file_slice::FileSlice;
|
||||
use common::{BinarySerializable, OwnedBytes, VInt};
|
||||
pub use set::{SelectCursor, Set, SetCodec};
|
||||
use set_block::{
|
||||
@@ -269,8 +268,8 @@ impl OptionalIndex {
|
||||
);
|
||||
let mut buffer = Vec::new();
|
||||
serialize_optional_index(&row_ids, num_rows, &mut buffer).unwrap();
|
||||
let file_slice = FileSlice::from(buffer);
|
||||
open_optional_index(file_slice).unwrap()
|
||||
let bytes = OwnedBytes::new(buffer);
|
||||
open_optional_index(bytes).unwrap()
|
||||
}
|
||||
|
||||
pub fn num_docs(&self) -> RowId {
|
||||
@@ -487,17 +486,10 @@ fn deserialize_optional_index_block_metadatas(
|
||||
(block_metas.into_boxed_slice(), non_null_rows_before_block)
|
||||
}
|
||||
|
||||
pub fn open_optional_index(file_slice: FileSlice) -> io::Result<OptionalIndex> {
|
||||
let (bytes, num_non_empty_blocks_bytes) = file_slice.split_from_end(2);
|
||||
let num_non_empty_block_bytes = u16::from_le_bytes(
|
||||
num_non_empty_blocks_bytes
|
||||
.read_bytes()?
|
||||
.as_slice()
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let mut bytes = bytes.read_bytes()?;
|
||||
pub fn open_optional_index(bytes: OwnedBytes) -> io::Result<OptionalIndex> {
|
||||
let (mut bytes, num_non_empty_blocks_bytes) = bytes.rsplit(2);
|
||||
let num_non_empty_block_bytes =
|
||||
u16::from_le_bytes(num_non_empty_blocks_bytes.as_slice().try_into().unwrap());
|
||||
let num_docs = VInt::deserialize_u64(&mut bytes)? as u32;
|
||||
let block_metas_num_bytes =
|
||||
num_non_empty_block_bytes as usize * SERIALIZED_BLOCK_META_NUM_BYTES;
|
||||
|
||||
@@ -59,7 +59,7 @@ fn test_with_random_sets_simple() {
|
||||
let vals = 10..ELEMENTS_PER_BLOCK * 2;
|
||||
let mut out: Vec<u8> = Vec::new();
|
||||
serialize_optional_index(&vals, 100, &mut out).unwrap();
|
||||
let null_index = open_optional_index(FileSlice::from(out)).unwrap();
|
||||
let null_index = open_optional_index(OwnedBytes::new(out)).unwrap();
|
||||
let ranks: Vec<u32> = (65_472u32..65_473u32).collect();
|
||||
let els: Vec<u32> = ranks.iter().copied().map(|rank| rank + 10).collect();
|
||||
let mut select_cursor = null_index.select_cursor();
|
||||
@@ -102,7 +102,7 @@ impl<'a> Iterable<RowId> for &'a [bool] {
|
||||
fn test_null_index(data: &[bool]) {
|
||||
let mut out: Vec<u8> = Vec::new();
|
||||
serialize_optional_index(&data, data.len() as RowId, &mut out).unwrap();
|
||||
let null_index = open_optional_index(FileSlice::from(out)).unwrap();
|
||||
let null_index = open_optional_index(OwnedBytes::new(out)).unwrap();
|
||||
let orig_idx_with_value: Vec<u32> = data
|
||||
.iter()
|
||||
.enumerate()
|
||||
@@ -223,170 +223,3 @@ fn test_optional_index_for_tests() {
|
||||
assert!(!optional_index.contains(3));
|
||||
assert_eq!(optional_index.num_docs(), 4);
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "unstable"))]
|
||||
mod bench {
|
||||
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use test::Bencher;
|
||||
|
||||
use super::*;
|
||||
|
||||
const TOTAL_NUM_VALUES: u32 = 1_000_000;
|
||||
fn gen_bools(fill_ratio: f64) -> OptionalIndex {
|
||||
let mut out = Vec::new();
|
||||
let mut rng: StdRng = StdRng::from_seed([1u8; 32]);
|
||||
let vals: Vec<RowId> = (0..TOTAL_NUM_VALUES)
|
||||
.map(|_| rng.gen_bool(fill_ratio))
|
||||
.enumerate()
|
||||
.filter(|(_pos, val)| *val)
|
||||
.map(|(pos, _)| pos as RowId)
|
||||
.collect();
|
||||
serialize_optional_index(&&vals[..], TOTAL_NUM_VALUES, &mut out).unwrap();
|
||||
|
||||
open_optional_index(FileSlice::from(out)).unwrap()
|
||||
}
|
||||
|
||||
fn random_range_iterator(
|
||||
start: u32,
|
||||
end: u32,
|
||||
avg_step_size: u32,
|
||||
avg_deviation: u32,
|
||||
) -> impl Iterator<Item = u32> {
|
||||
let mut rng: StdRng = StdRng::from_seed([1u8; 32]);
|
||||
let mut current = start;
|
||||
std::iter::from_fn(move || {
|
||||
current += rng.gen_range(avg_step_size - avg_deviation..=avg_step_size + avg_deviation);
|
||||
if current >= end { None } else { Some(current) }
|
||||
})
|
||||
}
|
||||
|
||||
fn n_percent_step_iterator(percent: f32, num_values: u32) -> impl Iterator<Item = u32> {
|
||||
let ratio = percent / 100.0;
|
||||
let step_size = (1f32 / ratio) as u32;
|
||||
let deviation = step_size - 1;
|
||||
random_range_iterator(0, num_values, step_size, deviation)
|
||||
}
|
||||
|
||||
fn walk_over_data(codec: &OptionalIndex, avg_step_size: u32) -> Option<u32> {
|
||||
walk_over_data_from_positions(
|
||||
codec,
|
||||
random_range_iterator(0, TOTAL_NUM_VALUES, avg_step_size, 0),
|
||||
)
|
||||
}
|
||||
|
||||
fn walk_over_data_from_positions(
|
||||
codec: &OptionalIndex,
|
||||
positions: impl Iterator<Item = u32>,
|
||||
) -> Option<u32> {
|
||||
let mut dense_idx: Option<u32> = None;
|
||||
for idx in positions {
|
||||
dense_idx = dense_idx.or(codec.rank_if_exists(idx));
|
||||
}
|
||||
dense_idx
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_orig_to_codec_1percent_filled_10percent_hit(bench: &mut Bencher) {
|
||||
let codec = gen_bools(0.01f64);
|
||||
bench.iter(|| walk_over_data(&codec, 100));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_orig_to_codec_5percent_filled_10percent_hit(bench: &mut Bencher) {
|
||||
let codec = gen_bools(0.05f64);
|
||||
bench.iter(|| walk_over_data(&codec, 100));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_orig_to_codec_5percent_filled_1percent_hit(bench: &mut Bencher) {
|
||||
let codec = gen_bools(0.05f64);
|
||||
bench.iter(|| walk_over_data(&codec, 1000));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_orig_to_codec_full_scan_1percent_filled(bench: &mut Bencher) {
|
||||
let codec = gen_bools(0.01f64);
|
||||
bench.iter(|| walk_over_data_from_positions(&codec, 0..TOTAL_NUM_VALUES));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_orig_to_codec_full_scan_10percent_filled(bench: &mut Bencher) {
|
||||
let codec = gen_bools(0.1f64);
|
||||
bench.iter(|| walk_over_data_from_positions(&codec, 0..TOTAL_NUM_VALUES));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_orig_to_codec_full_scan_90percent_filled(bench: &mut Bencher) {
|
||||
let codec = gen_bools(0.9f64);
|
||||
bench.iter(|| walk_over_data_from_positions(&codec, 0..TOTAL_NUM_VALUES));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_orig_to_codec_10percent_filled_1percent_hit(bench: &mut Bencher) {
|
||||
let codec = gen_bools(0.1f64);
|
||||
bench.iter(|| walk_over_data(&codec, 100));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_orig_to_codec_50percent_filled_1percent_hit(bench: &mut Bencher) {
|
||||
let codec = gen_bools(0.5f64);
|
||||
bench.iter(|| walk_over_data(&codec, 100));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_orig_to_codec_90percent_filled_1percent_hit(bench: &mut Bencher) {
|
||||
let codec = gen_bools(0.9f64);
|
||||
bench.iter(|| walk_over_data(&codec, 100));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_codec_to_orig_1percent_filled_0comma005percent_hit(bench: &mut Bencher) {
|
||||
bench_translate_codec_to_orig_util(0.01f64, 0.005f32, bench);
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_codec_to_orig_10percent_filled_0comma005percent_hit(bench: &mut Bencher) {
|
||||
bench_translate_codec_to_orig_util(0.1f64, 0.005f32, bench);
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_codec_to_orig_1percent_filled_10percent_hit(bench: &mut Bencher) {
|
||||
bench_translate_codec_to_orig_util(0.01f64, 10f32, bench);
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_codec_to_orig_1percent_filled_full_scan(bench: &mut Bencher) {
|
||||
bench_translate_codec_to_orig_util(0.01f64, 100f32, bench);
|
||||
}
|
||||
|
||||
fn bench_translate_codec_to_orig_util(
|
||||
percent_filled: f64,
|
||||
percent_hit: f32,
|
||||
bench: &mut Bencher,
|
||||
) {
|
||||
let codec = gen_bools(percent_filled);
|
||||
let num_non_nulls = codec.num_non_nulls();
|
||||
let idxs: Vec<u32> = if percent_hit == 100.0f32 {
|
||||
(0..num_non_nulls).collect()
|
||||
} else {
|
||||
n_percent_step_iterator(percent_hit, num_non_nulls).collect()
|
||||
};
|
||||
let mut output = vec![0u32; idxs.len()];
|
||||
bench.iter(|| {
|
||||
output.copy_from_slice(&idxs[..]);
|
||||
codec.select_batch(&mut output);
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_codec_to_orig_90percent_filled_0comma005percent_hit(bench: &mut Bencher) {
|
||||
bench_translate_codec_to_orig_util(0.9f64, 0.005, bench);
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_codec_to_orig_90percent_filled_full_scan(bench: &mut Bencher) {
|
||||
bench_translate_codec_to_orig_util(0.9f64, 100.0f32, bench);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
use std::io;
|
||||
use std::io::Write;
|
||||
|
||||
use common::file_slice::FileSlice;
|
||||
use common::{CountingWriter, HasLen};
|
||||
use common::{CountingWriter, OwnedBytes};
|
||||
|
||||
use super::OptionalIndex;
|
||||
use super::multivalued_index::SerializableMultivalueIndex;
|
||||
@@ -66,28 +65,27 @@ pub fn serialize_column_index(
|
||||
|
||||
/// Open a serialized column index.
|
||||
pub fn open_column_index(
|
||||
file_slice: FileSlice,
|
||||
mut bytes: OwnedBytes,
|
||||
format_version: Version,
|
||||
) -> io::Result<ColumnIndex> {
|
||||
if file_slice.len() == 0 {
|
||||
if bytes.is_empty() {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
"Failed to deserialize column index. Empty buffer.",
|
||||
));
|
||||
}
|
||||
let (header, body) = file_slice.split(1);
|
||||
let cardinality_code = header.read_bytes()?.as_slice()[0];
|
||||
let cardinality_code = bytes[0];
|
||||
let cardinality = Cardinality::try_from_code(cardinality_code)?;
|
||||
|
||||
bytes.advance(1);
|
||||
match cardinality {
|
||||
Cardinality::Full => Ok(ColumnIndex::Full),
|
||||
Cardinality::Optional => {
|
||||
let optional_index = super::optional_index::open_optional_index(body)?;
|
||||
let optional_index = super::optional_index::open_optional_index(bytes)?;
|
||||
Ok(ColumnIndex::Optional(optional_index))
|
||||
}
|
||||
Cardinality::Multivalued => {
|
||||
let multivalue_index =
|
||||
super::multivalued_index::open_multivalued_index(body, format_version)?;
|
||||
super::multivalued_index::open_multivalued_index(bytes, format_version)?;
|
||||
Ok(ColumnIndex::Multivalued(multivalue_index))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,15 +7,13 @@
|
||||
//! - Monotonically map values to u64/u128
|
||||
|
||||
use std::fmt::Debug;
|
||||
use std::ops::Range;
|
||||
use std::ops::{Range, RangeInclusive};
|
||||
use std::sync::Arc;
|
||||
|
||||
use downcast_rs::DowncastSync;
|
||||
pub use monotonic_mapping::{MonotonicallyMappableToU64, StrictlyMonotonicFn};
|
||||
pub use monotonic_mapping_u128::MonotonicallyMappableToU128;
|
||||
|
||||
use crate::column::ValueRange;
|
||||
|
||||
mod merge;
|
||||
pub(crate) mod monotonic_mapping;
|
||||
pub(crate) mod monotonic_mapping_u128;
|
||||
@@ -29,7 +27,8 @@ mod monotonic_column;
|
||||
pub(crate) use merge::MergedColumnValues;
|
||||
pub use stats::ColumnStats;
|
||||
pub use u64_based::{
|
||||
ALL_U64_CODEC_TYPES, CodecType, load_u64_based_column_values, serialize_u64_based_column_values,
|
||||
ALL_U64_CODEC_TYPES, CodecType, load_u64_based_column_values,
|
||||
serialize_and_load_u64_based_column_values, serialize_u64_based_column_values,
|
||||
};
|
||||
pub use u128_based::{
|
||||
CompactSpaceU64Accessor, open_u128_as_compact_u64, open_u128_mapped,
|
||||
@@ -110,307 +109,6 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync + DowncastSync {
|
||||
}
|
||||
}
|
||||
|
||||
/// Load the values for the provided docids.
|
||||
///
|
||||
/// The values are filtered by the provided value range.
|
||||
fn get_vals_in_value_range(
|
||||
&self,
|
||||
input_indexes: &[u32],
|
||||
input_doc_ids: &[u32],
|
||||
output: &mut Vec<crate::ComparableDoc<Option<T>, crate::DocId>>,
|
||||
value_range: ValueRange<T>,
|
||||
) {
|
||||
let len = input_indexes.len();
|
||||
let mut read_head = 0;
|
||||
|
||||
match value_range {
|
||||
ValueRange::All => {
|
||||
while read_head + 3 < len {
|
||||
let idx0 = input_indexes[read_head];
|
||||
let idx1 = input_indexes[read_head + 1];
|
||||
let idx2 = input_indexes[read_head + 2];
|
||||
let idx3 = input_indexes[read_head + 3];
|
||||
|
||||
let doc0 = input_doc_ids[read_head];
|
||||
let doc1 = input_doc_ids[read_head + 1];
|
||||
let doc2 = input_doc_ids[read_head + 2];
|
||||
let doc3 = input_doc_ids[read_head + 3];
|
||||
|
||||
let val0 = self.get_val(idx0);
|
||||
let val1 = self.get_val(idx1);
|
||||
let val2 = self.get_val(idx2);
|
||||
let val3 = self.get_val(idx3);
|
||||
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc0,
|
||||
sort_key: Some(val0),
|
||||
});
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc1,
|
||||
sort_key: Some(val1),
|
||||
});
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc2,
|
||||
sort_key: Some(val2),
|
||||
});
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc3,
|
||||
sort_key: Some(val3),
|
||||
});
|
||||
|
||||
read_head += 4;
|
||||
}
|
||||
}
|
||||
ValueRange::Inclusive(ref range) => {
|
||||
while read_head + 3 < len {
|
||||
let idx0 = input_indexes[read_head];
|
||||
let idx1 = input_indexes[read_head + 1];
|
||||
let idx2 = input_indexes[read_head + 2];
|
||||
let idx3 = input_indexes[read_head + 3];
|
||||
|
||||
let doc0 = input_doc_ids[read_head];
|
||||
let doc1 = input_doc_ids[read_head + 1];
|
||||
let doc2 = input_doc_ids[read_head + 2];
|
||||
let doc3 = input_doc_ids[read_head + 3];
|
||||
|
||||
let val0 = self.get_val(idx0);
|
||||
let val1 = self.get_val(idx1);
|
||||
let val2 = self.get_val(idx2);
|
||||
let val3 = self.get_val(idx3);
|
||||
|
||||
if range.contains(&val0) {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc0,
|
||||
sort_key: Some(val0),
|
||||
});
|
||||
}
|
||||
if range.contains(&val1) {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc1,
|
||||
sort_key: Some(val1),
|
||||
});
|
||||
}
|
||||
if range.contains(&val2) {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc2,
|
||||
sort_key: Some(val2),
|
||||
});
|
||||
}
|
||||
if range.contains(&val3) {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc3,
|
||||
sort_key: Some(val3),
|
||||
});
|
||||
}
|
||||
|
||||
read_head += 4;
|
||||
}
|
||||
}
|
||||
ValueRange::GreaterThan(ref threshold, _) => {
|
||||
while read_head + 3 < len {
|
||||
let idx0 = input_indexes[read_head];
|
||||
let idx1 = input_indexes[read_head + 1];
|
||||
let idx2 = input_indexes[read_head + 2];
|
||||
let idx3 = input_indexes[read_head + 3];
|
||||
|
||||
let doc0 = input_doc_ids[read_head];
|
||||
let doc1 = input_doc_ids[read_head + 1];
|
||||
let doc2 = input_doc_ids[read_head + 2];
|
||||
let doc3 = input_doc_ids[read_head + 3];
|
||||
|
||||
let val0 = self.get_val(idx0);
|
||||
let val1 = self.get_val(idx1);
|
||||
let val2 = self.get_val(idx2);
|
||||
let val3 = self.get_val(idx3);
|
||||
|
||||
if val0 > *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc0,
|
||||
sort_key: Some(val0),
|
||||
});
|
||||
}
|
||||
if val1 > *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc1,
|
||||
sort_key: Some(val1),
|
||||
});
|
||||
}
|
||||
if val2 > *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc2,
|
||||
sort_key: Some(val2),
|
||||
});
|
||||
}
|
||||
if val3 > *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc3,
|
||||
sort_key: Some(val3),
|
||||
});
|
||||
}
|
||||
|
||||
read_head += 4;
|
||||
}
|
||||
}
|
||||
ValueRange::GreaterThanOrEqual(ref threshold, _) => {
|
||||
while read_head + 3 < len {
|
||||
let idx0 = input_indexes[read_head];
|
||||
let idx1 = input_indexes[read_head + 1];
|
||||
let idx2 = input_indexes[read_head + 2];
|
||||
let idx3 = input_indexes[read_head + 3];
|
||||
|
||||
let doc0 = input_doc_ids[read_head];
|
||||
let doc1 = input_doc_ids[read_head + 1];
|
||||
let doc2 = input_doc_ids[read_head + 2];
|
||||
let doc3 = input_doc_ids[read_head + 3];
|
||||
|
||||
let val0 = self.get_val(idx0);
|
||||
let val1 = self.get_val(idx1);
|
||||
let val2 = self.get_val(idx2);
|
||||
let val3 = self.get_val(idx3);
|
||||
|
||||
if val0 >= *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc0,
|
||||
sort_key: Some(val0),
|
||||
});
|
||||
}
|
||||
if val1 >= *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc1,
|
||||
sort_key: Some(val1),
|
||||
});
|
||||
}
|
||||
if val2 >= *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc2,
|
||||
sort_key: Some(val2),
|
||||
});
|
||||
}
|
||||
if val3 >= *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc3,
|
||||
sort_key: Some(val3),
|
||||
});
|
||||
}
|
||||
|
||||
read_head += 4;
|
||||
}
|
||||
}
|
||||
ValueRange::LessThan(ref threshold, _) => {
|
||||
while read_head + 3 < len {
|
||||
let idx0 = input_indexes[read_head];
|
||||
let idx1 = input_indexes[read_head + 1];
|
||||
let idx2 = input_indexes[read_head + 2];
|
||||
let idx3 = input_indexes[read_head + 3];
|
||||
|
||||
let doc0 = input_doc_ids[read_head];
|
||||
let doc1 = input_doc_ids[read_head + 1];
|
||||
let doc2 = input_doc_ids[read_head + 2];
|
||||
let doc3 = input_doc_ids[read_head + 3];
|
||||
|
||||
let val0 = self.get_val(idx0);
|
||||
let val1 = self.get_val(idx1);
|
||||
let val2 = self.get_val(idx2);
|
||||
let val3 = self.get_val(idx3);
|
||||
|
||||
if val0 < *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc0,
|
||||
sort_key: Some(val0),
|
||||
});
|
||||
}
|
||||
if val1 < *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc1,
|
||||
sort_key: Some(val1),
|
||||
});
|
||||
}
|
||||
if val2 < *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc2,
|
||||
sort_key: Some(val2),
|
||||
});
|
||||
}
|
||||
if val3 < *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc3,
|
||||
sort_key: Some(val3),
|
||||
});
|
||||
}
|
||||
|
||||
read_head += 4;
|
||||
}
|
||||
}
|
||||
ValueRange::LessThanOrEqual(ref threshold, _) => {
|
||||
while read_head + 3 < len {
|
||||
let idx0 = input_indexes[read_head];
|
||||
let idx1 = input_indexes[read_head + 1];
|
||||
let idx2 = input_indexes[read_head + 2];
|
||||
let idx3 = input_indexes[read_head + 3];
|
||||
|
||||
let doc0 = input_doc_ids[read_head];
|
||||
let doc1 = input_doc_ids[read_head + 1];
|
||||
let doc2 = input_doc_ids[read_head + 2];
|
||||
let doc3 = input_doc_ids[read_head + 3];
|
||||
|
||||
let val0 = self.get_val(idx0);
|
||||
let val1 = self.get_val(idx1);
|
||||
let val2 = self.get_val(idx2);
|
||||
let val3 = self.get_val(idx3);
|
||||
|
||||
if val0 <= *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc0,
|
||||
sort_key: Some(val0),
|
||||
});
|
||||
}
|
||||
if val1 <= *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc1,
|
||||
sort_key: Some(val1),
|
||||
});
|
||||
}
|
||||
if val2 <= *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc2,
|
||||
sort_key: Some(val2),
|
||||
});
|
||||
}
|
||||
if val3 <= *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc3,
|
||||
sort_key: Some(val3),
|
||||
});
|
||||
}
|
||||
|
||||
read_head += 4;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Process remaining elements (0 to 3)
|
||||
while read_head < len {
|
||||
let idx = input_indexes[read_head];
|
||||
let doc = input_doc_ids[read_head];
|
||||
let val = self.get_val(idx);
|
||||
let matches = match value_range {
|
||||
// 'value_range' is still moved here. This is the outer `value_range`
|
||||
ValueRange::All => true,
|
||||
ValueRange::Inclusive(ref r) => r.contains(&val),
|
||||
ValueRange::GreaterThan(ref t, _) => val > *t,
|
||||
ValueRange::GreaterThanOrEqual(ref t, _) => val >= *t,
|
||||
ValueRange::LessThan(ref t, _) => val < *t,
|
||||
ValueRange::LessThanOrEqual(ref t, _) => val <= *t,
|
||||
};
|
||||
if matches {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc,
|
||||
sort_key: Some(val),
|
||||
});
|
||||
}
|
||||
read_head += 1;
|
||||
}
|
||||
}
|
||||
|
||||
/// Fills an output buffer with the fast field values
|
||||
/// associated with the `DocId` going from
|
||||
/// `start` to `start + output.len()`.
|
||||
@@ -431,54 +129,15 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync + DowncastSync {
|
||||
/// Note that position == docid for single value fast fields
|
||||
fn get_row_ids_for_value_range(
|
||||
&self,
|
||||
value_range: ValueRange<T>,
|
||||
value_range: RangeInclusive<T>,
|
||||
row_id_range: Range<RowId>,
|
||||
row_id_hits: &mut Vec<RowId>,
|
||||
) {
|
||||
let row_id_range = row_id_range.start..row_id_range.end.min(self.num_vals());
|
||||
match value_range {
|
||||
ValueRange::Inclusive(range) => {
|
||||
for idx in row_id_range {
|
||||
let val = self.get_val(idx);
|
||||
if range.contains(&val) {
|
||||
row_id_hits.push(idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
ValueRange::GreaterThan(threshold, _) => {
|
||||
for idx in row_id_range {
|
||||
let val = self.get_val(idx);
|
||||
if val > threshold {
|
||||
row_id_hits.push(idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
ValueRange::GreaterThanOrEqual(threshold, _) => {
|
||||
for idx in row_id_range {
|
||||
let val = self.get_val(idx);
|
||||
if val >= threshold {
|
||||
row_id_hits.push(idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
ValueRange::LessThan(threshold, _) => {
|
||||
for idx in row_id_range {
|
||||
let val = self.get_val(idx);
|
||||
if val < threshold {
|
||||
row_id_hits.push(idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
ValueRange::LessThanOrEqual(threshold, _) => {
|
||||
for idx in row_id_range {
|
||||
let val = self.get_val(idx);
|
||||
if val <= threshold {
|
||||
row_id_hits.push(idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
ValueRange::All => {
|
||||
row_id_hits.extend(row_id_range);
|
||||
for idx in row_id_range {
|
||||
let val = self.get_val(idx);
|
||||
if value_range.contains(&val) {
|
||||
row_id_hits.push(idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -534,17 +193,6 @@ impl<T: PartialOrd + Default> ColumnValues<T> for EmptyColumnValues {
|
||||
fn num_vals(&self) -> u32 {
|
||||
0
|
||||
}
|
||||
|
||||
fn get_vals_in_value_range(
|
||||
&self,
|
||||
input_indexes: &[u32],
|
||||
input_doc_ids: &[u32],
|
||||
output: &mut Vec<crate::ComparableDoc<Option<T>, crate::DocId>>,
|
||||
value_range: ValueRange<T>,
|
||||
) {
|
||||
let _ = (input_indexes, input_doc_ids, output, value_range);
|
||||
panic!("Internal Error: Called get_vals_in_value_range of empty column.")
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Copy + PartialOrd + Debug + 'static> ColumnValues<T> for Arc<dyn ColumnValues<T>> {
|
||||
@@ -558,18 +206,6 @@ impl<T: Copy + PartialOrd + Debug + 'static> ColumnValues<T> for Arc<dyn ColumnV
|
||||
self.as_ref().get_vals_opt(indexes, output)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn get_vals_in_value_range(
|
||||
&self,
|
||||
input_indexes: &[u32],
|
||||
input_doc_ids: &[u32],
|
||||
output: &mut Vec<crate::ComparableDoc<Option<T>, crate::DocId>>,
|
||||
value_range: ValueRange<T>,
|
||||
) {
|
||||
self.as_ref()
|
||||
.get_vals_in_value_range(input_indexes, input_doc_ids, output, value_range)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn min_value(&self) -> T {
|
||||
self.as_ref().min_value()
|
||||
@@ -598,7 +234,7 @@ impl<T: Copy + PartialOrd + Debug + 'static> ColumnValues<T> for Arc<dyn ColumnV
|
||||
#[inline(always)]
|
||||
fn get_row_ids_for_value_range(
|
||||
&self,
|
||||
range: ValueRange<T>,
|
||||
range: RangeInclusive<T>,
|
||||
doc_id_range: Range<u32>,
|
||||
positions: &mut Vec<u32>,
|
||||
) {
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
use std::ops::Range;
|
||||
use std::ops::{Range, RangeInclusive};
|
||||
|
||||
use crate::ColumnValues;
|
||||
use crate::column::ValueRange;
|
||||
use crate::column_values::monotonic_mapping::StrictlyMonotonicFn;
|
||||
|
||||
struct MonotonicMappingColumn<C, T, Input> {
|
||||
@@ -81,52 +80,16 @@ where
|
||||
|
||||
fn get_row_ids_for_value_range(
|
||||
&self,
|
||||
range: ValueRange<Output>,
|
||||
range: RangeInclusive<Output>,
|
||||
doc_id_range: Range<u32>,
|
||||
positions: &mut Vec<u32>,
|
||||
) {
|
||||
match range {
|
||||
ValueRange::Inclusive(range) => self.from_column.get_row_ids_for_value_range(
|
||||
ValueRange::Inclusive(
|
||||
self.monotonic_mapping.inverse(range.start().clone())
|
||||
..=self.monotonic_mapping.inverse(range.end().clone()),
|
||||
),
|
||||
doc_id_range,
|
||||
positions,
|
||||
),
|
||||
ValueRange::All => self.from_column.get_row_ids_for_value_range(
|
||||
ValueRange::All,
|
||||
doc_id_range,
|
||||
positions,
|
||||
),
|
||||
ValueRange::GreaterThan(threshold, _) => self.from_column.get_row_ids_for_value_range(
|
||||
ValueRange::GreaterThan(self.monotonic_mapping.inverse(threshold), false),
|
||||
doc_id_range,
|
||||
positions,
|
||||
),
|
||||
ValueRange::GreaterThanOrEqual(threshold, _) => {
|
||||
self.from_column.get_row_ids_for_value_range(
|
||||
ValueRange::GreaterThanOrEqual(
|
||||
self.monotonic_mapping.inverse(threshold),
|
||||
false,
|
||||
),
|
||||
doc_id_range,
|
||||
positions,
|
||||
)
|
||||
}
|
||||
ValueRange::LessThan(threshold, _) => self.from_column.get_row_ids_for_value_range(
|
||||
ValueRange::LessThan(self.monotonic_mapping.inverse(threshold), false),
|
||||
doc_id_range,
|
||||
positions,
|
||||
),
|
||||
ValueRange::LessThanOrEqual(threshold, _) => {
|
||||
self.from_column.get_row_ids_for_value_range(
|
||||
ValueRange::LessThanOrEqual(self.monotonic_mapping.inverse(threshold), false),
|
||||
doc_id_range,
|
||||
positions,
|
||||
)
|
||||
}
|
||||
}
|
||||
self.from_column.get_row_ids_for_value_range(
|
||||
self.monotonic_mapping.inverse(range.start().clone())
|
||||
..=self.monotonic_mapping.inverse(range.end().clone()),
|
||||
doc_id_range,
|
||||
positions,
|
||||
)
|
||||
}
|
||||
|
||||
// We voluntarily do not implement get_range as it yields a regression,
|
||||
|
||||
@@ -2,8 +2,7 @@ use std::io;
|
||||
use std::io::Write;
|
||||
use std::num::NonZeroU64;
|
||||
|
||||
use common::file_slice::FileSlice;
|
||||
use common::{BinarySerializable, HasLen, VInt};
|
||||
use common::{BinarySerializable, VInt};
|
||||
|
||||
use crate::RowId;
|
||||
|
||||
@@ -28,55 +27,6 @@ impl ColumnStats {
|
||||
}
|
||||
}
|
||||
|
||||
impl ColumnStats {
|
||||
/// Deserialize from the tail of the given FileSlice, and return the stats and remaining prefix
|
||||
/// FileSlice.
|
||||
pub fn deserialize_from_tail(file_slice: FileSlice) -> io::Result<(Self, FileSlice)> {
|
||||
// [`deserialize_with_size`] deserializes 4 variable-width encoded u64s, which
|
||||
// could end up being, in the worst case, 9 bytes each. this is where the 36 comes from
|
||||
let (stats, _) = file_slice.clone().split(36.min(file_slice.len())); // hope that's enough bytes
|
||||
let mut stats = stats.read_bytes()?;
|
||||
let (stats, stats_nbytes) = ColumnStats::deserialize_with_size(&mut stats)?;
|
||||
let (_, remainder) = file_slice.split(stats_nbytes);
|
||||
Ok((stats, remainder))
|
||||
}
|
||||
|
||||
/// Same as [`BinarySeerializable::deserialize`] but also returns the number of bytes
|
||||
/// consumed from the reader `R`
|
||||
fn deserialize_with_size<R: io::Read>(reader: &mut R) -> io::Result<(Self, usize)> {
|
||||
let mut nbytes = 0;
|
||||
|
||||
let (min_value, len) = VInt::deserialize_with_size(reader)?;
|
||||
let min_value = min_value.0;
|
||||
nbytes += len;
|
||||
|
||||
let (gcd, len) = VInt::deserialize_with_size(reader)?;
|
||||
let gcd = gcd.0;
|
||||
let gcd = NonZeroU64::new(gcd)
|
||||
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "GCD of 0 is forbidden"))?;
|
||||
nbytes += len;
|
||||
|
||||
let (amplitude, len) = VInt::deserialize_with_size(reader)?;
|
||||
let amplitude = amplitude.0 * gcd.get();
|
||||
let max_value = min_value + amplitude;
|
||||
nbytes += len;
|
||||
|
||||
let (num_rows, len) = VInt::deserialize_with_size(reader)?;
|
||||
let num_rows = num_rows.0 as RowId;
|
||||
nbytes += len;
|
||||
|
||||
Ok((
|
||||
ColumnStats {
|
||||
min_value,
|
||||
max_value,
|
||||
num_rows,
|
||||
gcd,
|
||||
},
|
||||
nbytes,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl BinarySerializable for ColumnStats {
|
||||
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
|
||||
VInt(self.min_value).serialize(writer)?;
|
||||
|
||||
@@ -25,7 +25,6 @@ use common::{BinarySerializable, CountingWriter, OwnedBytes, VInt, VIntU128};
|
||||
use tantivy_bitpacker::{BitPacker, BitUnpacker};
|
||||
|
||||
use crate::RowId;
|
||||
use crate::column::ValueRange;
|
||||
use crate::column_values::ColumnValues;
|
||||
|
||||
/// The cost per blank is quite hard actually, since blanks are delta encoded, the actual cost of
|
||||
@@ -339,48 +338,14 @@ impl ColumnValues<u64> for CompactSpaceU64Accessor {
|
||||
#[inline]
|
||||
fn get_row_ids_for_value_range(
|
||||
&self,
|
||||
value_range: ValueRange<u64>,
|
||||
value_range: RangeInclusive<u64>,
|
||||
position_range: Range<u32>,
|
||||
positions: &mut Vec<u32>,
|
||||
) {
|
||||
match value_range {
|
||||
ValueRange::Inclusive(value_range) => {
|
||||
let value_range = ValueRange::Inclusive(
|
||||
self.0.compact_to_u128(*value_range.start() as u32)
|
||||
..=self.0.compact_to_u128(*value_range.end() as u32),
|
||||
);
|
||||
self.0
|
||||
.get_row_ids_for_value_range(value_range, position_range, positions)
|
||||
}
|
||||
ValueRange::All => {
|
||||
let position_range = position_range.start..position_range.end.min(self.num_vals());
|
||||
positions.extend(position_range);
|
||||
}
|
||||
ValueRange::GreaterThan(threshold, _) => {
|
||||
let value_range =
|
||||
ValueRange::GreaterThan(self.0.compact_to_u128(threshold as u32), false);
|
||||
self.0
|
||||
.get_row_ids_for_value_range(value_range, position_range, positions)
|
||||
}
|
||||
ValueRange::GreaterThanOrEqual(threshold, _) => {
|
||||
let value_range =
|
||||
ValueRange::GreaterThanOrEqual(self.0.compact_to_u128(threshold as u32), false);
|
||||
self.0
|
||||
.get_row_ids_for_value_range(value_range, position_range, positions)
|
||||
}
|
||||
ValueRange::LessThan(threshold, _) => {
|
||||
let value_range =
|
||||
ValueRange::LessThan(self.0.compact_to_u128(threshold as u32), false);
|
||||
self.0
|
||||
.get_row_ids_for_value_range(value_range, position_range, positions)
|
||||
}
|
||||
ValueRange::LessThanOrEqual(threshold, _) => {
|
||||
let value_range =
|
||||
ValueRange::LessThanOrEqual(self.0.compact_to_u128(threshold as u32), false);
|
||||
self.0
|
||||
.get_row_ids_for_value_range(value_range, position_range, positions)
|
||||
}
|
||||
}
|
||||
let value_range = self.0.compact_to_u128(*value_range.start() as u32)
|
||||
..=self.0.compact_to_u128(*value_range.end() as u32);
|
||||
self.0
|
||||
.get_row_ids_for_value_range(value_range, position_range, positions)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -410,47 +375,10 @@ impl ColumnValues<u128> for CompactSpaceDecompressor {
|
||||
#[inline]
|
||||
fn get_row_ids_for_value_range(
|
||||
&self,
|
||||
value_range: ValueRange<u128>,
|
||||
value_range: RangeInclusive<u128>,
|
||||
position_range: Range<u32>,
|
||||
positions: &mut Vec<u32>,
|
||||
) {
|
||||
let value_range = match value_range {
|
||||
ValueRange::Inclusive(value_range) => value_range,
|
||||
ValueRange::All => {
|
||||
let position_range = position_range.start..position_range.end.min(self.num_vals());
|
||||
positions.extend(position_range);
|
||||
return;
|
||||
}
|
||||
ValueRange::GreaterThan(threshold, _) => {
|
||||
let max = self.max_value();
|
||||
if threshold >= max {
|
||||
return;
|
||||
}
|
||||
(threshold + 1)..=max
|
||||
}
|
||||
ValueRange::GreaterThanOrEqual(threshold, _) => {
|
||||
let max = self.max_value();
|
||||
if threshold > max {
|
||||
return;
|
||||
}
|
||||
threshold..=max
|
||||
}
|
||||
ValueRange::LessThan(threshold, _) => {
|
||||
let min = self.min_value();
|
||||
if threshold <= min {
|
||||
return;
|
||||
}
|
||||
min..=(threshold - 1)
|
||||
}
|
||||
ValueRange::LessThanOrEqual(threshold, _) => {
|
||||
let min = self.min_value();
|
||||
if threshold < min {
|
||||
return;
|
||||
}
|
||||
min..=threshold
|
||||
}
|
||||
};
|
||||
|
||||
if value_range.start() > value_range.end() {
|
||||
return;
|
||||
}
|
||||
@@ -632,7 +560,7 @@ mod tests {
|
||||
.collect::<Vec<_>>();
|
||||
let mut positions = Vec::new();
|
||||
decompressor.get_row_ids_for_value_range(
|
||||
ValueRange::Inclusive(range),
|
||||
range,
|
||||
0..decompressor.num_vals(),
|
||||
&mut positions,
|
||||
);
|
||||
@@ -676,11 +604,7 @@ mod tests {
|
||||
let val = *val;
|
||||
let pos = pos as u32;
|
||||
let mut positions = Vec::new();
|
||||
decomp.get_row_ids_for_value_range(
|
||||
ValueRange::Inclusive(val..=val),
|
||||
pos..pos + 1,
|
||||
&mut positions,
|
||||
);
|
||||
decomp.get_row_ids_for_value_range(val..=val, pos..pos + 1, &mut positions);
|
||||
assert_eq!(positions, vec![pos]);
|
||||
}
|
||||
|
||||
@@ -822,11 +746,7 @@ mod tests {
|
||||
doc_id_range: Range<u32>,
|
||||
) -> Vec<u32> {
|
||||
let mut positions = Vec::new();
|
||||
column.get_row_ids_for_value_range(
|
||||
ValueRange::Inclusive(value_range),
|
||||
doc_id_range,
|
||||
&mut positions,
|
||||
);
|
||||
column.get_row_ids_for_value_range(value_range, doc_id_range, &mut positions);
|
||||
positions
|
||||
}
|
||||
|
||||
@@ -849,7 +769,7 @@ mod tests {
|
||||
];
|
||||
let mut out = Vec::new();
|
||||
serialize_column_values_u128(&&vals[..], &mut out).unwrap();
|
||||
let decomp = open_u128_mapped(FileSlice::from(out)).unwrap();
|
||||
let decomp = open_u128_mapped(OwnedBytes::new(out)).unwrap();
|
||||
let complete_range = 0..vals.len() as u32;
|
||||
|
||||
assert_eq!(
|
||||
@@ -903,7 +823,6 @@ mod tests {
|
||||
let _data = test_aux_vals(vals);
|
||||
}
|
||||
|
||||
use common::file_slice::FileSlice;
|
||||
use proptest::prelude::*;
|
||||
|
||||
fn num_strategy() -> impl Strategy<Value = u128> {
|
||||
|
||||
@@ -5,8 +5,7 @@ use std::sync::Arc;
|
||||
|
||||
mod compact_space;
|
||||
|
||||
use common::file_slice::FileSlice;
|
||||
use common::{BinarySerializable, VInt};
|
||||
use common::{BinarySerializable, OwnedBytes, VInt};
|
||||
pub use compact_space::{
|
||||
CompactSpaceCompressor, CompactSpaceDecompressor, CompactSpaceU64Accessor,
|
||||
};
|
||||
@@ -102,9 +101,8 @@ impl U128FastFieldCodecType {
|
||||
|
||||
/// Returns the correct codec reader wrapped in the `Arc` for the data.
|
||||
pub fn open_u128_mapped<T: MonotonicallyMappableToU128 + Debug>(
|
||||
file_slice: FileSlice,
|
||||
mut bytes: OwnedBytes,
|
||||
) -> io::Result<Arc<dyn ColumnValues<T>>> {
|
||||
let mut bytes = file_slice.read_bytes()?;
|
||||
let header = U128Header::deserialize(&mut bytes)?;
|
||||
assert_eq!(header.codec_type, U128FastFieldCodecType::CompactSpace);
|
||||
let reader = CompactSpaceDecompressor::open(bytes)?;
|
||||
@@ -122,8 +120,7 @@ pub fn open_u128_mapped<T: MonotonicallyMappableToU128 + Debug>(
|
||||
/// # Notice
|
||||
/// In case there are new codecs added, check for usages of `CompactSpaceDecompressorU64` and
|
||||
/// also handle the new codecs.
|
||||
pub fn open_u128_as_compact_u64(file_slice: FileSlice) -> io::Result<Arc<dyn ColumnValues<u64>>> {
|
||||
let mut bytes = file_slice.read_bytes()?;
|
||||
pub fn open_u128_as_compact_u64(mut bytes: OwnedBytes) -> io::Result<Arc<dyn ColumnValues<u64>>> {
|
||||
let header = U128Header::deserialize(&mut bytes)?;
|
||||
assert_eq!(header.codec_type, U128FastFieldCodecType::CompactSpace);
|
||||
let reader = CompactSpaceU64Accessor::open(bytes)?;
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
use std::io::{self, Write};
|
||||
use std::num::NonZeroU64;
|
||||
use std::ops::{Range, RangeInclusive};
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
use common::file_slice::FileSlice;
|
||||
use common::{BinarySerializable, HasLen, OwnedBytes};
|
||||
use common::{BinarySerializable, OwnedBytes};
|
||||
use fastdivide::DividerU64;
|
||||
use tantivy_bitpacker::{BitPacker, BitUnpacker, compute_num_bits};
|
||||
|
||||
use crate::column::ValueRange;
|
||||
use crate::column_values::u64_based::{ColumnCodec, ColumnCodecEstimator, ColumnStats};
|
||||
use crate::{ColumnValues, RowId};
|
||||
|
||||
@@ -16,40 +13,9 @@ use crate::{ColumnValues, RowId};
|
||||
/// fast field is required.
|
||||
#[derive(Clone)]
|
||||
pub struct BitpackedReader {
|
||||
data: FileSlice,
|
||||
data: OwnedBytes,
|
||||
bit_unpacker: BitUnpacker,
|
||||
stats: ColumnStats,
|
||||
blocks: Arc<[OnceLock<Block>]>,
|
||||
}
|
||||
|
||||
impl BitpackedReader {
|
||||
#[inline(always)]
|
||||
fn unpack_val(&self, doc: u32) -> u64 {
|
||||
let block_num = self.bit_unpacker.block_num(doc);
|
||||
|
||||
if block_num == 0 && self.blocks.len() == 0 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let block = self.blocks[block_num].get_or_init(|| {
|
||||
let block_range = self.bit_unpacker.block(block_num, self.data.len());
|
||||
let offset = block_range.start;
|
||||
let data = self
|
||||
.data
|
||||
.slice(block_range)
|
||||
.read_bytes()
|
||||
.expect("Failed to read column values.");
|
||||
Block { offset, data }
|
||||
});
|
||||
|
||||
self.bit_unpacker
|
||||
.get_from_subset(doc, block.offset, &block.data)
|
||||
}
|
||||
}
|
||||
|
||||
struct Block {
|
||||
offset: usize,
|
||||
data: OwnedBytes,
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
@@ -75,12 +41,6 @@ fn transform_range_before_linear_transformation(
|
||||
if range.is_empty() {
|
||||
return None;
|
||||
}
|
||||
if stats.min_value > *range.end() {
|
||||
return None;
|
||||
}
|
||||
if stats.max_value < *range.start() {
|
||||
return None;
|
||||
}
|
||||
let shifted_range =
|
||||
range.start().saturating_sub(stats.min_value)..=range.end().saturating_sub(stats.min_value);
|
||||
let start_before_gcd_multiplication: u64 = div_ceil(*shifted_range.start(), stats.gcd);
|
||||
@@ -91,9 +51,8 @@ fn transform_range_before_linear_transformation(
|
||||
impl ColumnValues for BitpackedReader {
|
||||
#[inline(always)]
|
||||
fn get_val(&self, doc: u32) -> u64 {
|
||||
self.stats.min_value + self.stats.gcd.get() * self.unpack_val(doc)
|
||||
self.stats.min_value + self.stats.gcd.get() * self.bit_unpacker.get(doc, &self.data)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn min_value(&self) -> u64 {
|
||||
self.stats.min_value
|
||||
@@ -107,329 +66,24 @@ impl ColumnValues for BitpackedReader {
|
||||
self.stats.num_rows
|
||||
}
|
||||
|
||||
fn get_vals_in_value_range(
|
||||
&self,
|
||||
input_indexes: &[u32],
|
||||
input_doc_ids: &[u32],
|
||||
output: &mut Vec<crate::ComparableDoc<Option<u64>, crate::DocId>>,
|
||||
value_range: ValueRange<u64>,
|
||||
) {
|
||||
match value_range {
|
||||
ValueRange::All => {
|
||||
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc,
|
||||
sort_key: Some(self.get_val(idx)),
|
||||
});
|
||||
}
|
||||
}
|
||||
ValueRange::Inclusive(range) => {
|
||||
if let Some(transformed_range) =
|
||||
transform_range_before_linear_transformation(&self.stats, range)
|
||||
{
|
||||
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
|
||||
let raw_val = self.unpack_val(idx);
|
||||
if transformed_range.contains(&raw_val) {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc,
|
||||
sort_key: Some(
|
||||
self.stats.min_value + self.stats.gcd.get() * raw_val,
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ValueRange::GreaterThan(threshold, _) => {
|
||||
if threshold < self.stats.min_value {
|
||||
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc,
|
||||
sort_key: Some(self.get_val(idx)),
|
||||
});
|
||||
}
|
||||
} else if threshold >= self.stats.max_value {
|
||||
// All filtered out
|
||||
} else {
|
||||
let raw_threshold = (threshold - self.stats.min_value) / self.stats.gcd.get();
|
||||
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
|
||||
let raw_val = self.unpack_val(idx);
|
||||
if raw_val > raw_threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc,
|
||||
sort_key: Some(
|
||||
self.stats.min_value + self.stats.gcd.get() * raw_val,
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ValueRange::GreaterThanOrEqual(threshold, _) => {
|
||||
if threshold <= self.stats.min_value {
|
||||
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc,
|
||||
sort_key: Some(self.get_val(idx)),
|
||||
});
|
||||
}
|
||||
} else if threshold > self.stats.max_value {
|
||||
// All filtered out
|
||||
} else {
|
||||
let diff = threshold - self.stats.min_value;
|
||||
let gcd = self.stats.gcd.get();
|
||||
let raw_threshold = (diff + gcd - 1) / gcd;
|
||||
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
|
||||
let raw_val = self.unpack_val(idx);
|
||||
if raw_val >= raw_threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc,
|
||||
sort_key: Some(
|
||||
self.stats.min_value + self.stats.gcd.get() * raw_val,
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ValueRange::LessThan(threshold, _) => {
|
||||
if threshold > self.stats.max_value {
|
||||
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc,
|
||||
sort_key: Some(self.get_val(idx)),
|
||||
});
|
||||
}
|
||||
} else if threshold <= self.stats.min_value {
|
||||
// All filtered out
|
||||
} else {
|
||||
let diff = threshold - self.stats.min_value;
|
||||
let gcd = self.stats.gcd.get();
|
||||
let raw_threshold = if diff % gcd == 0 {
|
||||
diff / gcd
|
||||
} else {
|
||||
diff / gcd + 1
|
||||
};
|
||||
|
||||
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
|
||||
let raw_val = self.unpack_val(idx);
|
||||
if raw_val < raw_threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc,
|
||||
sort_key: Some(
|
||||
self.stats.min_value + self.stats.gcd.get() * raw_val,
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ValueRange::LessThanOrEqual(threshold, _) => {
|
||||
if threshold >= self.stats.max_value {
|
||||
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc,
|
||||
sort_key: Some(self.get_val(idx)),
|
||||
});
|
||||
}
|
||||
} else if threshold < self.stats.min_value {
|
||||
// All filtered out
|
||||
} else {
|
||||
let diff = threshold - self.stats.min_value;
|
||||
let gcd = self.stats.gcd.get();
|
||||
let raw_threshold = diff / gcd;
|
||||
|
||||
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
|
||||
let raw_val = self.unpack_val(idx);
|
||||
if raw_val <= raw_threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc,
|
||||
sort_key: Some(
|
||||
self.stats.min_value + self.stats.gcd.get() * raw_val,
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
fn get_row_ids_for_value_range(
|
||||
&self,
|
||||
range: ValueRange<u64>,
|
||||
range: RangeInclusive<u64>,
|
||||
doc_id_range: Range<u32>,
|
||||
positions: &mut Vec<u32>,
|
||||
) {
|
||||
match range {
|
||||
ValueRange::All => {
|
||||
positions.extend(doc_id_range);
|
||||
return;
|
||||
}
|
||||
ValueRange::Inclusive(range) => {
|
||||
let Some(transformed_range) =
|
||||
transform_range_before_linear_transformation(&self.stats, range)
|
||||
else {
|
||||
positions.clear();
|
||||
return;
|
||||
};
|
||||
// TODO: This does not use the `self.blocks` cache, because callers are usually
|
||||
// already doing sequential, and fairly dense reads. Fix it to
|
||||
// iterate over blocks if that assumption turns out to be incorrect!
|
||||
let data_range = self
|
||||
.bit_unpacker
|
||||
.block_oblivious_range(doc_id_range.clone(), self.data.len());
|
||||
let data_offset = data_range.start;
|
||||
let data_subset = self
|
||||
.data
|
||||
.slice(data_range)
|
||||
.read_bytes()
|
||||
.expect("Failed to read column values.");
|
||||
self.bit_unpacker.get_ids_for_value_range_from_subset(
|
||||
transformed_range,
|
||||
doc_id_range,
|
||||
data_offset,
|
||||
&data_subset,
|
||||
positions,
|
||||
);
|
||||
}
|
||||
ValueRange::GreaterThan(threshold, _) => {
|
||||
if threshold < self.stats.min_value {
|
||||
positions.extend(doc_id_range);
|
||||
return;
|
||||
}
|
||||
if threshold >= self.stats.max_value {
|
||||
return;
|
||||
}
|
||||
let raw_threshold = (threshold - self.stats.min_value) / self.stats.gcd.get();
|
||||
// We want raw > raw_threshold.
|
||||
// bit_unpacker.get_ids_for_value_range_from_subset takes a RangeInclusive.
|
||||
// We can construct a RangeInclusive: (raw_threshold + 1) ..= u64::MAX
|
||||
// But max raw value is known? (max_value - min_value) / gcd.
|
||||
let max_raw = (self.stats.max_value - self.stats.min_value) / self.stats.gcd.get();
|
||||
let transformed_range = (raw_threshold + 1)..=max_raw;
|
||||
|
||||
let data_range = self
|
||||
.bit_unpacker
|
||||
.block_oblivious_range(doc_id_range.clone(), self.data.len());
|
||||
let data_offset = data_range.start;
|
||||
let data_subset = self
|
||||
.data
|
||||
.slice(data_range)
|
||||
.read_bytes()
|
||||
.expect("Failed to read column values.");
|
||||
self.bit_unpacker.get_ids_for_value_range_from_subset(
|
||||
transformed_range,
|
||||
doc_id_range,
|
||||
data_offset,
|
||||
&data_subset,
|
||||
positions,
|
||||
);
|
||||
}
|
||||
ValueRange::GreaterThanOrEqual(threshold, _) => {
|
||||
if threshold <= self.stats.min_value {
|
||||
positions.extend(doc_id_range);
|
||||
return;
|
||||
}
|
||||
if threshold > self.stats.max_value {
|
||||
return;
|
||||
}
|
||||
let diff = threshold - self.stats.min_value;
|
||||
let gcd = self.stats.gcd.get();
|
||||
let raw_threshold = (diff + gcd - 1) / gcd;
|
||||
// We want raw >= raw_threshold.
|
||||
let max_raw = (self.stats.max_value - self.stats.min_value) / self.stats.gcd.get();
|
||||
let transformed_range = raw_threshold..=max_raw;
|
||||
|
||||
let data_range = self
|
||||
.bit_unpacker
|
||||
.block_oblivious_range(doc_id_range.clone(), self.data.len());
|
||||
let data_offset = data_range.start;
|
||||
let data_subset = self
|
||||
.data
|
||||
.slice(data_range)
|
||||
.read_bytes()
|
||||
.expect("Failed to read column values.");
|
||||
self.bit_unpacker.get_ids_for_value_range_from_subset(
|
||||
transformed_range,
|
||||
doc_id_range,
|
||||
data_offset,
|
||||
&data_subset,
|
||||
positions,
|
||||
);
|
||||
}
|
||||
ValueRange::LessThan(threshold, _) => {
|
||||
if threshold > self.stats.max_value {
|
||||
positions.extend(doc_id_range);
|
||||
return;
|
||||
}
|
||||
if threshold <= self.stats.min_value {
|
||||
return;
|
||||
}
|
||||
|
||||
let diff = threshold - self.stats.min_value;
|
||||
let gcd = self.stats.gcd.get();
|
||||
// We want raw < raw_threshold_limit
|
||||
// raw <= raw_threshold_limit - 1
|
||||
let raw_threshold_limit = if diff % gcd == 0 {
|
||||
diff / gcd
|
||||
} else {
|
||||
diff / gcd + 1
|
||||
};
|
||||
|
||||
if raw_threshold_limit == 0 {
|
||||
return;
|
||||
}
|
||||
let transformed_range = 0..=(raw_threshold_limit - 1);
|
||||
|
||||
let data_range = self
|
||||
.bit_unpacker
|
||||
.block_oblivious_range(doc_id_range.clone(), self.data.len());
|
||||
let data_offset = data_range.start;
|
||||
let data_subset = self
|
||||
.data
|
||||
.slice(data_range)
|
||||
.read_bytes()
|
||||
.expect("Failed to read column values.");
|
||||
self.bit_unpacker.get_ids_for_value_range_from_subset(
|
||||
transformed_range,
|
||||
doc_id_range,
|
||||
data_offset,
|
||||
&data_subset,
|
||||
positions,
|
||||
);
|
||||
}
|
||||
ValueRange::LessThanOrEqual(threshold, _) => {
|
||||
if threshold >= self.stats.max_value {
|
||||
positions.extend(doc_id_range);
|
||||
return;
|
||||
}
|
||||
if threshold < self.stats.min_value {
|
||||
return;
|
||||
}
|
||||
let diff = threshold - self.stats.min_value;
|
||||
let gcd = self.stats.gcd.get();
|
||||
// We want raw <= raw_threshold.
|
||||
let raw_threshold = diff / gcd;
|
||||
let transformed_range = 0..=raw_threshold;
|
||||
|
||||
let data_range = self
|
||||
.bit_unpacker
|
||||
.block_oblivious_range(doc_id_range.clone(), self.data.len());
|
||||
let data_offset = data_range.start;
|
||||
let data_subset = self
|
||||
.data
|
||||
.slice(data_range)
|
||||
.read_bytes()
|
||||
.expect("Failed to read column values.");
|
||||
self.bit_unpacker.get_ids_for_value_range_from_subset(
|
||||
transformed_range,
|
||||
doc_id_range,
|
||||
data_offset,
|
||||
&data_subset,
|
||||
positions,
|
||||
);
|
||||
}
|
||||
}
|
||||
let Some(transformed_range) =
|
||||
transform_range_before_linear_transformation(&self.stats, range)
|
||||
else {
|
||||
positions.clear();
|
||||
return;
|
||||
};
|
||||
self.bit_unpacker.get_ids_for_value_range(
|
||||
transformed_range,
|
||||
doc_id_range,
|
||||
&self.data,
|
||||
positions,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -473,20 +127,14 @@ impl ColumnCodec for BitpackedCodec {
|
||||
type Estimator = BitpackedCodecEstimator;
|
||||
|
||||
/// Opens a fast field given a file.
|
||||
fn load(file_slice: FileSlice) -> io::Result<Self::ColumnValues> {
|
||||
let (stats, data) = ColumnStats::deserialize_from_tail(file_slice)?;
|
||||
|
||||
fn load(mut data: OwnedBytes) -> io::Result<Self::ColumnValues> {
|
||||
let stats = ColumnStats::deserialize(&mut data)?;
|
||||
let num_bits = num_bits(&stats);
|
||||
let bit_unpacker = BitUnpacker::new(num_bits);
|
||||
let block_count = bit_unpacker.block_count(data.len());
|
||||
Ok(BitpackedReader {
|
||||
data,
|
||||
bit_unpacker,
|
||||
stats,
|
||||
blocks: (0..block_count)
|
||||
.into_iter()
|
||||
.map(|_| OnceLock::new())
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
use std::io;
|
||||
use std::io::Write;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use std::sync::Arc;
|
||||
use std::{io, iter};
|
||||
|
||||
use common::file_slice::FileSlice;
|
||||
use common::{BinarySerializable, CountingWriter, DeserializeFrom, HasLen, OwnedBytes};
|
||||
use common::{BinarySerializable, CountingWriter, DeserializeFrom, OwnedBytes};
|
||||
use fastdivide::DividerU64;
|
||||
use tantivy_bitpacker::{BitPacker, BitUnpacker, compute_num_bits};
|
||||
|
||||
@@ -174,63 +172,32 @@ impl ColumnCodec<u64> for BlockwiseLinearCodec {
|
||||
|
||||
type Estimator = BlockwiseLinearEstimator;
|
||||
|
||||
fn load(file_slice: FileSlice) -> io::Result<Self::ColumnValues> {
|
||||
let (stats, body) = ColumnStats::deserialize_from_tail(file_slice)?;
|
||||
|
||||
let (_, footer) = body.clone().split_from_end(4);
|
||||
|
||||
let footer_len: u32 = footer.read_bytes()?.as_slice().deserialize()?;
|
||||
let (data, footer) = body.split_from_end(footer_len as usize + 4);
|
||||
|
||||
let mut footer = footer.read_bytes()?;
|
||||
fn load(mut bytes: OwnedBytes) -> io::Result<Self::ColumnValues> {
|
||||
let stats = ColumnStats::deserialize(&mut bytes)?;
|
||||
let footer_len: u32 = (&bytes[bytes.len() - 4..]).deserialize()?;
|
||||
let footer_offset = bytes.len() - 4 - footer_len as usize;
|
||||
let (data, mut footer) = bytes.split(footer_offset);
|
||||
let num_blocks = compute_num_blocks(stats.num_rows);
|
||||
|
||||
let mut blocks: Vec<Block> = iter::repeat_with(|| Block::deserialize(&mut footer))
|
||||
.take(num_blocks as usize)
|
||||
.collect::<io::Result<_>>()?;
|
||||
let mut start_offset = 0;
|
||||
let mut blocks = Vec::with_capacity(num_blocks as usize);
|
||||
|
||||
for _ in 0..num_blocks {
|
||||
let mut block = Block::deserialize(&mut footer)?;
|
||||
let len = (block.bit_unpacker.bit_width() as usize) * BLOCK_SIZE as usize / 8;
|
||||
|
||||
for block in &mut blocks {
|
||||
block.data_start_offset = start_offset;
|
||||
blocks.push(BlockWithData {
|
||||
block,
|
||||
file_slice: data.slice(start_offset..(start_offset + len).min(data.len())),
|
||||
data: Default::default(),
|
||||
});
|
||||
|
||||
start_offset += len;
|
||||
start_offset += (block.bit_unpacker.bit_width() as usize) * BLOCK_SIZE as usize / 8;
|
||||
}
|
||||
Ok(BlockwiseLinearReader {
|
||||
blocks: blocks.into_boxed_slice().into(),
|
||||
data,
|
||||
stats,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct BlockWithData {
|
||||
block: Block,
|
||||
file_slice: FileSlice,
|
||||
data: OnceLock<OwnedBytes>,
|
||||
}
|
||||
|
||||
impl Deref for BlockWithData {
|
||||
type Target = Block;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.block
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for BlockWithData {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.block
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct BlockwiseLinearReader {
|
||||
blocks: Arc<[BlockWithData]>,
|
||||
blocks: Arc<[Block]>,
|
||||
data: OwnedBytes,
|
||||
stats: ColumnStats,
|
||||
}
|
||||
|
||||
@@ -241,9 +208,7 @@ impl ColumnValues for BlockwiseLinearReader {
|
||||
let idx_within_block = idx % BLOCK_SIZE;
|
||||
let block = &self.blocks[block_id];
|
||||
let interpoled_val: u64 = block.line.eval(idx_within_block);
|
||||
let block_bytes = block
|
||||
.data
|
||||
.get_or_init(|| block.file_slice.read_bytes().unwrap());
|
||||
let block_bytes = &self.data[block.data_start_offset..];
|
||||
let bitpacked_diff = block.bit_unpacker.get(idx_within_block, block_bytes);
|
||||
// TODO optimize me! the line parameters could be tweaked to include the multiplication and
|
||||
// remove the dependency.
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use std::io;
|
||||
|
||||
use common::file_slice::FileSlice;
|
||||
use common::{BinarySerializable, OwnedBytes};
|
||||
use tantivy_bitpacker::{BitPacker, BitUnpacker, compute_num_bits};
|
||||
|
||||
@@ -191,8 +190,7 @@ impl ColumnCodec for LinearCodec {
|
||||
|
||||
type Estimator = LinearCodecEstimator;
|
||||
|
||||
fn load(file_slice: FileSlice) -> io::Result<Self::ColumnValues> {
|
||||
let mut data = file_slice.read_bytes()?;
|
||||
fn load(mut data: OwnedBytes) -> io::Result<Self::ColumnValues> {
|
||||
let stats = ColumnStats::deserialize(&mut data)?;
|
||||
let linear_params = LinearParams::deserialize(&mut data)?;
|
||||
Ok(LinearReader {
|
||||
|
||||
@@ -8,8 +8,7 @@ use std::io;
|
||||
use std::io::Write;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common::BinarySerializable;
|
||||
use common::file_slice::FileSlice;
|
||||
use common::{BinarySerializable, OwnedBytes};
|
||||
|
||||
use crate::column_values::monotonic_mapping::{
|
||||
StrictlyMonotonicMappingInverter, StrictlyMonotonicMappingToInternal,
|
||||
@@ -61,7 +60,7 @@ pub trait ColumnCodec<T: PartialOrd = u64> {
|
||||
type Estimator: ColumnCodecEstimator + Default;
|
||||
|
||||
/// Loads a column that has been serialized using this codec.
|
||||
fn load(file_slice: FileSlice) -> io::Result<Self::ColumnValues>;
|
||||
fn load(bytes: OwnedBytes) -> io::Result<Self::ColumnValues>;
|
||||
|
||||
/// Returns an estimator.
|
||||
fn estimator() -> Self::Estimator {
|
||||
@@ -112,22 +111,20 @@ impl CodecType {
|
||||
|
||||
fn load<T: MonotonicallyMappableToU64>(
|
||||
&self,
|
||||
file_slice: FileSlice,
|
||||
bytes: OwnedBytes,
|
||||
) -> io::Result<Arc<dyn ColumnValues<T>>> {
|
||||
match self {
|
||||
CodecType::Bitpacked => load_specific_codec::<BitpackedCodec, T>(file_slice),
|
||||
CodecType::Linear => load_specific_codec::<LinearCodec, T>(file_slice),
|
||||
CodecType::BlockwiseLinear => {
|
||||
load_specific_codec::<BlockwiseLinearCodec, T>(file_slice)
|
||||
}
|
||||
CodecType::Bitpacked => load_specific_codec::<BitpackedCodec, T>(bytes),
|
||||
CodecType::Linear => load_specific_codec::<LinearCodec, T>(bytes),
|
||||
CodecType::BlockwiseLinear => load_specific_codec::<BlockwiseLinearCodec, T>(bytes),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn load_specific_codec<C: ColumnCodec, T: MonotonicallyMappableToU64>(
|
||||
file_slice: FileSlice,
|
||||
bytes: OwnedBytes,
|
||||
) -> io::Result<Arc<dyn ColumnValues<T>>> {
|
||||
let reader = C::load(file_slice)?;
|
||||
let reader = C::load(bytes)?;
|
||||
let reader_typed = monotonic_map_column(
|
||||
reader,
|
||||
StrictlyMonotonicMappingInverter::from(StrictlyMonotonicMappingToInternal::<T>::new()),
|
||||
@@ -192,28 +189,25 @@ pub fn serialize_u64_based_column_values<T: MonotonicallyMappableToU64>(
|
||||
///
|
||||
/// This method first identifies the codec off the first byte.
|
||||
pub fn load_u64_based_column_values<T: MonotonicallyMappableToU64>(
|
||||
file_slice: FileSlice,
|
||||
mut bytes: OwnedBytes,
|
||||
) -> io::Result<Arc<dyn ColumnValues<T>>> {
|
||||
let (header, body) = file_slice.split(1);
|
||||
let codec_type: CodecType = header
|
||||
.read_bytes()?
|
||||
.as_slice()
|
||||
.get(0)
|
||||
.cloned()
|
||||
let codec_type: CodecType = bytes
|
||||
.first()
|
||||
.copied()
|
||||
.and_then(CodecType::try_from_code)
|
||||
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Failed to read codec type"))?;
|
||||
codec_type.load(body)
|
||||
bytes.advance(1);
|
||||
codec_type.load(bytes)
|
||||
}
|
||||
|
||||
/// Helper function to serialize a column (autodetect from all codecs) and then open it
|
||||
#[cfg(test)]
|
||||
pub fn serialize_and_load_u64_based_column_values<T: MonotonicallyMappableToU64>(
|
||||
vals: &dyn Iterable,
|
||||
codec_types: &[CodecType],
|
||||
) -> Arc<dyn ColumnValues<T>> {
|
||||
let mut buffer = Vec::new();
|
||||
serialize_u64_based_column_values(vals, codec_types, &mut buffer).unwrap();
|
||||
load_u64_based_column_values::<T>(FileSlice::from(buffer)).unwrap()
|
||||
load_u64_based_column_values::<T>(OwnedBytes::new(buffer)).unwrap()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use common::HasLen;
|
||||
use proptest::prelude::*;
|
||||
use proptest::{prop_oneof, proptest};
|
||||
use rand::Rng;
|
||||
@@ -14,7 +13,7 @@ fn test_serialize_and_load_simple() {
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(buffer.len(), 7);
|
||||
let col = load_u64_based_column_values::<u64>(FileSlice::from(buffer)).unwrap();
|
||||
let col = load_u64_based_column_values::<u64>(OwnedBytes::new(buffer)).unwrap();
|
||||
assert_eq!(col.num_vals(), 3);
|
||||
assert_eq!(col.get_val(0), 1);
|
||||
assert_eq!(col.get_val(1), 2);
|
||||
@@ -31,7 +30,7 @@ fn test_empty_column_i64() {
|
||||
continue;
|
||||
}
|
||||
num_acceptable_codecs += 1;
|
||||
let col = load_u64_based_column_values::<i64>(FileSlice::from(buffer)).unwrap();
|
||||
let col = load_u64_based_column_values::<i64>(OwnedBytes::new(buffer)).unwrap();
|
||||
assert_eq!(col.num_vals(), 0);
|
||||
assert_eq!(col.min_value(), i64::MIN);
|
||||
assert_eq!(col.max_value(), i64::MIN);
|
||||
@@ -49,7 +48,7 @@ fn test_empty_column_u64() {
|
||||
continue;
|
||||
}
|
||||
num_acceptable_codecs += 1;
|
||||
let col = load_u64_based_column_values::<u64>(FileSlice::from(buffer)).unwrap();
|
||||
let col = load_u64_based_column_values::<u64>(OwnedBytes::new(buffer)).unwrap();
|
||||
assert_eq!(col.num_vals(), 0);
|
||||
assert_eq!(col.min_value(), u64::MIN);
|
||||
assert_eq!(col.max_value(), u64::MIN);
|
||||
@@ -67,7 +66,7 @@ fn test_empty_column_f64() {
|
||||
continue;
|
||||
}
|
||||
num_acceptable_codecs += 1;
|
||||
let col = load_u64_based_column_values::<f64>(FileSlice::from(buffer)).unwrap();
|
||||
let col = load_u64_based_column_values::<f64>(OwnedBytes::new(buffer)).unwrap();
|
||||
assert_eq!(col.num_vals(), 0);
|
||||
// FIXME. f64::MIN would be better!
|
||||
assert!(col.min_value().is_nan());
|
||||
@@ -98,7 +97,7 @@ pub(crate) fn create_and_validate<TColumnCodec: ColumnCodec>(
|
||||
|
||||
let actual_compression = buffer.len() as u64;
|
||||
|
||||
let reader = TColumnCodec::load(FileSlice::from(buffer)).unwrap();
|
||||
let reader = TColumnCodec::load(OwnedBytes::new(buffer)).unwrap();
|
||||
assert_eq!(reader.num_vals(), vals.len() as u32);
|
||||
let mut buffer = Vec::new();
|
||||
for (doc, orig_val) in vals.iter().copied().enumerate() {
|
||||
@@ -132,7 +131,7 @@ pub(crate) fn create_and_validate<TColumnCodec: ColumnCodec>(
|
||||
.collect();
|
||||
let mut positions = Vec::new();
|
||||
reader.get_row_ids_for_value_range(
|
||||
crate::column::ValueRange::Inclusive(vals[test_rand_idx]..=vals[test_rand_idx]),
|
||||
vals[test_rand_idx]..=vals[test_rand_idx],
|
||||
0..vals.len() as u32,
|
||||
&mut positions,
|
||||
);
|
||||
@@ -327,7 +326,7 @@ fn test_fastfield_gcd_i64_with_codec(codec_type: CodecType, num_vals: usize) ->
|
||||
&[codec_type],
|
||||
&mut buffer,
|
||||
)?;
|
||||
let buffer = FileSlice::from(buffer);
|
||||
let buffer = OwnedBytes::new(buffer);
|
||||
let column = crate::column_values::load_u64_based_column_values::<i64>(buffer.clone())?;
|
||||
assert_eq!(column.get_val(0), -4000i64);
|
||||
assert_eq!(column.get_val(1), -3000i64);
|
||||
@@ -344,7 +343,7 @@ fn test_fastfield_gcd_i64_with_codec(codec_type: CodecType, num_vals: usize) ->
|
||||
&[codec_type],
|
||||
&mut buffer_without_gcd,
|
||||
)?;
|
||||
let buffer_without_gcd = FileSlice::from(buffer_without_gcd);
|
||||
let buffer_without_gcd = OwnedBytes::new(buffer_without_gcd);
|
||||
assert!(buffer_without_gcd.len() > buffer.len());
|
||||
|
||||
Ok(())
|
||||
@@ -370,7 +369,7 @@ fn test_fastfield_gcd_u64_with_codec(codec_type: CodecType, num_vals: usize) ->
|
||||
&[codec_type],
|
||||
&mut buffer,
|
||||
)?;
|
||||
let buffer = FileSlice::from(buffer);
|
||||
let buffer = OwnedBytes::new(buffer);
|
||||
let column = crate::column_values::load_u64_based_column_values::<u64>(buffer.clone())?;
|
||||
assert_eq!(column.get_val(0), 1000u64);
|
||||
assert_eq!(column.get_val(1), 2000u64);
|
||||
@@ -387,7 +386,7 @@ fn test_fastfield_gcd_u64_with_codec(codec_type: CodecType, num_vals: usize) ->
|
||||
&[codec_type],
|
||||
&mut buffer_without_gcd,
|
||||
)?;
|
||||
let buffer_without_gcd = FileSlice::from(buffer_without_gcd);
|
||||
let buffer_without_gcd = OwnedBytes::new(buffer_without_gcd);
|
||||
assert!(buffer_without_gcd.len() > buffer.len());
|
||||
Ok(())
|
||||
}
|
||||
@@ -406,7 +405,7 @@ fn test_fastfield_gcd_u64() -> io::Result<()> {
|
||||
|
||||
#[test]
|
||||
pub fn test_fastfield2() {
|
||||
let test_fastfield = serialize_and_load_u64_based_column_values::<u64>(
|
||||
let test_fastfield = crate::column_values::serialize_and_load_u64_based_column_values::<u64>(
|
||||
&&[100u64, 200u64, 300u64][..],
|
||||
&ALL_U64_CODEC_TYPES,
|
||||
);
|
||||
|
||||
@@ -4,7 +4,6 @@ mod term_merger;
|
||||
|
||||
use std::collections::{BTreeMap, HashSet};
|
||||
use std::io;
|
||||
use std::io::ErrorKind;
|
||||
use std::net::Ipv6Addr;
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -79,7 +78,6 @@ pub fn merge_columnar(
|
||||
required_columns: &[(String, ColumnType)],
|
||||
merge_row_order: MergeRowOrder,
|
||||
output: &mut impl io::Write,
|
||||
cancel: impl Fn() -> bool,
|
||||
) -> io::Result<()> {
|
||||
let mut serializer = ColumnarSerializer::new(output);
|
||||
let num_docs_per_columnar = columnar_readers
|
||||
@@ -89,9 +87,6 @@ pub fn merge_columnar(
|
||||
|
||||
let columns_to_merge = group_columns_for_merge(columnar_readers, required_columns)?;
|
||||
for res in columns_to_merge {
|
||||
if cancel() {
|
||||
return Err(io::Error::new(ErrorKind::Interrupted, "Merge cancelled"));
|
||||
}
|
||||
let ((column_name, _column_type_category), grouped_columns) = res;
|
||||
let grouped_columns = grouped_columns.open(&merge_row_order)?;
|
||||
if grouped_columns.is_empty() {
|
||||
|
||||
@@ -205,7 +205,6 @@ fn test_merge_columnar_numbers() {
|
||||
&[],
|
||||
MergeRowOrder::Stack(stack_merge_order),
|
||||
&mut buffer,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
let columnar_reader = ColumnarReader::open(buffer).unwrap();
|
||||
@@ -234,7 +233,6 @@ fn test_merge_columnar_texts() {
|
||||
&[],
|
||||
MergeRowOrder::Stack(stack_merge_order),
|
||||
&mut buffer,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
let columnar_reader = ColumnarReader::open(buffer).unwrap();
|
||||
@@ -284,7 +282,6 @@ fn test_merge_columnar_byte() {
|
||||
&[],
|
||||
MergeRowOrder::Stack(stack_merge_order),
|
||||
&mut buffer,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
let columnar_reader = ColumnarReader::open(buffer).unwrap();
|
||||
@@ -341,7 +338,6 @@ fn test_merge_columnar_byte_with_missing() {
|
||||
&[],
|
||||
MergeRowOrder::Stack(stack_merge_order),
|
||||
&mut buffer,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
let columnar_reader = ColumnarReader::open(buffer).unwrap();
|
||||
@@ -394,7 +390,6 @@ fn test_merge_columnar_different_types() {
|
||||
&[],
|
||||
MergeRowOrder::Stack(stack_merge_order),
|
||||
&mut buffer,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
let columnar_reader = ColumnarReader::open(buffer).unwrap();
|
||||
@@ -460,7 +455,6 @@ fn test_merge_columnar_different_empty_cardinality() {
|
||||
&[],
|
||||
MergeRowOrder::Stack(stack_merge_order),
|
||||
&mut buffer,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
let columnar_reader = ColumnarReader::open(buffer).unwrap();
|
||||
@@ -571,7 +565,6 @@ proptest! {
|
||||
&[],
|
||||
MergeRowOrder::Stack(stack_merge_order),
|
||||
&mut out,
|
||||
|| false,
|
||||
).unwrap();
|
||||
|
||||
let merged_reader = ColumnarReader::open(out).unwrap();
|
||||
@@ -589,7 +582,6 @@ proptest! {
|
||||
&[],
|
||||
MergeRowOrder::Stack(stack_merge_order),
|
||||
&mut out,
|
||||
|| false,
|
||||
).unwrap();
|
||||
|
||||
}
|
||||
|
||||
@@ -71,14 +71,7 @@ fn test_format(path: &str) {
|
||||
let columnar_readers = vec![&reader, &reader2];
|
||||
let merge_row_order = StackMergeOrder::stack(&columnar_readers[..]);
|
||||
let mut out = Vec::new();
|
||||
merge_columnar(
|
||||
&columnar_readers,
|
||||
&[],
|
||||
merge_row_order.into(),
|
||||
&mut out,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
merge_columnar(&columnar_readers, &[], merge_row_order.into(), &mut out).unwrap();
|
||||
let reader = ColumnarReader::open(out).unwrap();
|
||||
check_columns(&reader);
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::sync::Arc;
|
||||
use std::{fmt, io};
|
||||
|
||||
use common::file_slice::FileSlice;
|
||||
use common::{ByteCount, DateTime};
|
||||
use common::{ByteCount, DateTime, OwnedBytes};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::column::{BytesColumn, Column, StrColumn};
|
||||
@@ -239,7 +239,8 @@ pub struct DynamicColumnHandle {
|
||||
impl DynamicColumnHandle {
|
||||
// TODO rename load
|
||||
pub fn open(&self) -> io::Result<DynamicColumn> {
|
||||
self.open_internal(self.file_slice.clone())
|
||||
let column_bytes: OwnedBytes = self.file_slice.read_bytes()?;
|
||||
self.open_internal(column_bytes)
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
@@ -258,15 +259,16 @@ impl DynamicColumnHandle {
|
||||
/// If not, the fastfield reader will returns the u64-value associated with the original
|
||||
/// FastValue.
|
||||
pub fn open_u64_lenient(&self) -> io::Result<Option<Column<u64>>> {
|
||||
let column_bytes = self.file_slice.read_bytes()?;
|
||||
match self.column_type {
|
||||
ColumnType::Str | ColumnType::Bytes => {
|
||||
let column: BytesColumn =
|
||||
crate::column::open_column_bytes(self.file_slice.clone(), self.format_version)?;
|
||||
crate::column::open_column_bytes(column_bytes, self.format_version)?;
|
||||
Ok(Some(column.term_ord_column))
|
||||
}
|
||||
ColumnType::IpAddr => {
|
||||
let column = crate::column::open_column_u128_as_compact_u64(
|
||||
self.file_slice.clone(),
|
||||
column_bytes,
|
||||
self.format_version,
|
||||
)?;
|
||||
Ok(Some(column))
|
||||
@@ -276,40 +278,40 @@ impl DynamicColumnHandle {
|
||||
| ColumnType::U64
|
||||
| ColumnType::F64
|
||||
| ColumnType::DateTime => {
|
||||
let column = crate::column::open_column_u64::<u64>(
|
||||
self.file_slice.clone(),
|
||||
self.format_version,
|
||||
)?;
|
||||
let column =
|
||||
crate::column::open_column_u64::<u64>(column_bytes, self.format_version)?;
|
||||
Ok(Some(column))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn open_internal(&self, file_slice: FileSlice) -> io::Result<DynamicColumn> {
|
||||
fn open_internal(&self, column_bytes: OwnedBytes) -> io::Result<DynamicColumn> {
|
||||
let dynamic_column: DynamicColumn = match self.column_type {
|
||||
ColumnType::Bytes => {
|
||||
crate::column::open_column_bytes(file_slice, self.format_version)?.into()
|
||||
crate::column::open_column_bytes(column_bytes, self.format_version)?.into()
|
||||
}
|
||||
ColumnType::Str => {
|
||||
crate::column::open_column_str(file_slice, self.format_version)?.into()
|
||||
crate::column::open_column_str(column_bytes, self.format_version)?.into()
|
||||
}
|
||||
ColumnType::I64 => {
|
||||
crate::column::open_column_u64::<i64>(file_slice, self.format_version)?.into()
|
||||
crate::column::open_column_u64::<i64>(column_bytes, self.format_version)?.into()
|
||||
}
|
||||
ColumnType::U64 => {
|
||||
crate::column::open_column_u64::<u64>(file_slice, self.format_version)?.into()
|
||||
crate::column::open_column_u64::<u64>(column_bytes, self.format_version)?.into()
|
||||
}
|
||||
ColumnType::F64 => {
|
||||
crate::column::open_column_u64::<f64>(file_slice, self.format_version)?.into()
|
||||
crate::column::open_column_u64::<f64>(column_bytes, self.format_version)?.into()
|
||||
}
|
||||
ColumnType::Bool => {
|
||||
crate::column::open_column_u64::<bool>(file_slice, self.format_version)?.into()
|
||||
crate::column::open_column_u64::<bool>(column_bytes, self.format_version)?.into()
|
||||
}
|
||||
ColumnType::IpAddr => {
|
||||
crate::column::open_column_u128::<Ipv6Addr>(file_slice, self.format_version)?.into()
|
||||
crate::column::open_column_u128::<Ipv6Addr>(column_bytes, self.format_version)?
|
||||
.into()
|
||||
}
|
||||
ColumnType::DateTime => {
|
||||
crate::column::open_column_u64::<DateTime>(file_slice, self.format_version)?.into()
|
||||
crate::column::open_column_u64::<DateTime>(column_bytes, self.format_version)?
|
||||
.into()
|
||||
}
|
||||
};
|
||||
Ok(dynamic_column)
|
||||
|
||||
@@ -29,7 +29,6 @@ mod column;
|
||||
pub mod column_index;
|
||||
pub mod column_values;
|
||||
mod columnar;
|
||||
mod comparable_doc;
|
||||
mod dictionary;
|
||||
mod dynamic_column;
|
||||
mod iterable;
|
||||
@@ -37,7 +36,7 @@ pub(crate) mod utils;
|
||||
mod value;
|
||||
|
||||
pub use block_accessor::ColumnBlockAccessor;
|
||||
pub use column::{BytesColumn, Column, StrColumn, ValueRange};
|
||||
pub use column::{BytesColumn, Column, StrColumn};
|
||||
pub use column_index::ColumnIndex;
|
||||
pub use column_values::{
|
||||
ColumnValues, EmptyColumnValues, MonotonicallyMappableToU64, MonotonicallyMappableToU128,
|
||||
@@ -46,7 +45,6 @@ pub use columnar::{
|
||||
CURRENT_VERSION, ColumnType, ColumnarReader, ColumnarWriter, HasAssociatedColumnType,
|
||||
MergeRowOrder, ShuffleMergeOrder, StackMergeOrder, Version, merge_columnar,
|
||||
};
|
||||
pub use comparable_doc::ComparableDoc;
|
||||
use sstable::VoidSSTable;
|
||||
pub use value::{NumericalType, NumericalValue};
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ fn test_dataframe_writer_bool() {
|
||||
let DynamicColumn::Bool(bool_col) = dyn_bool_col else {
|
||||
panic!();
|
||||
};
|
||||
let vals: Vec<Option<bool>> = (0..5).map(|row_id| bool_col.first(row_id)).collect();
|
||||
let vals: Vec<Option<bool>> = (0..5).map(|doc_id| bool_col.first(doc_id)).collect();
|
||||
assert_eq!(&vals, &[None, Some(false), None, Some(true), None,]);
|
||||
}
|
||||
|
||||
@@ -108,7 +108,7 @@ fn test_dataframe_writer_ip_addr() {
|
||||
let DynamicColumn::IpAddr(ip_col) = dyn_bool_col else {
|
||||
panic!();
|
||||
};
|
||||
let vals: Vec<Option<Ipv6Addr>> = (0..5).map(|row_id| ip_col.first(row_id)).collect();
|
||||
let vals: Vec<Option<Ipv6Addr>> = (0..5).map(|doc_id| ip_col.first(doc_id)).collect();
|
||||
assert_eq!(
|
||||
&vals,
|
||||
&[
|
||||
@@ -169,7 +169,7 @@ fn test_dictionary_encoded_str() {
|
||||
let DynamicColumn::Str(str_col) = col_handles[0].open().unwrap() else {
|
||||
panic!();
|
||||
};
|
||||
let index: Vec<Option<u64>> = (0..5).map(|row_id| str_col.ords().first(row_id)).collect();
|
||||
let index: Vec<Option<u64>> = (0..5).map(|doc_id| str_col.ords().first(doc_id)).collect();
|
||||
assert_eq!(index, &[None, Some(0), None, Some(2), Some(1)]);
|
||||
assert_eq!(str_col.num_rows(), 5);
|
||||
let mut term_buffer = String::new();
|
||||
@@ -204,7 +204,7 @@ fn test_dictionary_encoded_bytes() {
|
||||
panic!();
|
||||
};
|
||||
let index: Vec<Option<u64>> = (0..5)
|
||||
.map(|row_id| bytes_col.ords().first(row_id))
|
||||
.map(|doc_id| bytes_col.ords().first(doc_id))
|
||||
.collect();
|
||||
assert_eq!(index, &[None, Some(0), None, Some(2), Some(1)]);
|
||||
assert_eq!(bytes_col.num_rows(), 5);
|
||||
@@ -641,7 +641,7 @@ proptest! {
|
||||
let columnar_readers_arr: Vec<&ColumnarReader> = columnar_readers.iter().collect();
|
||||
let mut output: Vec<u8> = Vec::new();
|
||||
let stack_merge_order = StackMergeOrder::stack(&columnar_readers_arr[..]).into();
|
||||
crate::merge_columnar(&columnar_readers_arr[..], &[], stack_merge_order, &mut output, || false,).unwrap();
|
||||
crate::merge_columnar(&columnar_readers_arr[..], &[], stack_merge_order, &mut output).unwrap();
|
||||
let merged_columnar = ColumnarReader::open(output).unwrap();
|
||||
let concat_rows: Vec<Vec<(&'static str, ColumnValue)>> = columnar_docs.iter().flatten().cloned().collect();
|
||||
let expected_merged_columnar = build_columnar(&concat_rows[..]);
|
||||
@@ -665,7 +665,6 @@ fn test_columnar_merging_empty_columnar() {
|
||||
&[],
|
||||
crate::MergeRowOrder::Stack(stack_merge_order),
|
||||
&mut output,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
let merged_columnar = ColumnarReader::open(output).unwrap();
|
||||
@@ -703,7 +702,6 @@ fn test_columnar_merging_number_columns() {
|
||||
&[],
|
||||
crate::MergeRowOrder::Stack(stack_merge_order),
|
||||
&mut output,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
let merged_columnar = ColumnarReader::open(output).unwrap();
|
||||
@@ -777,7 +775,6 @@ fn test_columnar_merge_and_remap(
|
||||
&[],
|
||||
shuffle_merge_order.into(),
|
||||
&mut output,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
let merged_columnar = ColumnarReader::open(output).unwrap();
|
||||
@@ -820,7 +817,6 @@ fn test_columnar_merge_empty() {
|
||||
&[],
|
||||
shuffle_merge_order.into(),
|
||||
&mut output,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
let merged_columnar = ColumnarReader::open(output).unwrap();
|
||||
@@ -847,7 +843,6 @@ fn test_columnar_merge_single_str_column() {
|
||||
&[],
|
||||
shuffle_merge_order.into(),
|
||||
&mut output,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
let merged_columnar = ColumnarReader::open(output).unwrap();
|
||||
@@ -880,7 +875,6 @@ fn test_delete_decrease_cardinality() {
|
||||
&[],
|
||||
shuffle_merge_order.into(),
|
||||
&mut output,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
let merged_columnar = ColumnarReader::open(output).unwrap();
|
||||
|
||||
@@ -181,6 +181,14 @@ pub struct BitSet {
|
||||
len: u64,
|
||||
max_value: u32,
|
||||
}
|
||||
impl std::fmt::Debug for BitSet {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("BitSet")
|
||||
.field("len", &self.len)
|
||||
.field("max_value", &self.max_value)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
fn num_buckets(max_val: u32) -> u32 {
|
||||
max_val.div_ceil(64u32)
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
use std::cell::RefCell;
|
||||
use std::cmp::min;
|
||||
use std::io;
|
||||
use std::ops::Range;
|
||||
|
||||
use super::file_slice::FileSlice;
|
||||
use super::{HasLen, OwnedBytes};
|
||||
|
||||
const DEFAULT_BUFFER_MAX_SIZE: usize = 512 * 1024; // 512K
|
||||
|
||||
/// A buffered reader for a FileSlice.
|
||||
///
|
||||
/// Reads the underlying `FileSlice` in large, sequential chunks to amortize
|
||||
/// the cost of `read_bytes` calls, while keeping peak memory usage under control.
|
||||
///
|
||||
/// TODO: Rather than wrapping a `FileSlice` in buffering, it will usually be better to adjust a
|
||||
/// `FileHandle` to directly handle buffering itself.
|
||||
/// TODO: See: https://github.com/paradedb/paradedb/issues/3374
|
||||
pub struct BufferedFileSlice {
|
||||
file_slice: FileSlice,
|
||||
buffer: RefCell<OwnedBytes>,
|
||||
buffer_range: RefCell<Range<u64>>,
|
||||
buffer_max_size: usize,
|
||||
}
|
||||
|
||||
impl BufferedFileSlice {
|
||||
/// Creates a new `BufferedFileSlice`.
|
||||
///
|
||||
/// The `buffer_max_size` is the amount of data that will be read from the
|
||||
/// `FileSlice` on a buffer miss.
|
||||
pub fn new(file_slice: FileSlice, buffer_max_size: usize) -> Self {
|
||||
Self {
|
||||
file_slice,
|
||||
buffer: RefCell::new(OwnedBytes::empty()),
|
||||
buffer_range: RefCell::new(0..0),
|
||||
buffer_max_size,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new `BufferedFileSlice` with a default buffer max size.
|
||||
pub fn new_with_default_buffer_size(file_slice: FileSlice) -> Self {
|
||||
Self::new(file_slice, DEFAULT_BUFFER_MAX_SIZE)
|
||||
}
|
||||
|
||||
/// Creates an empty `BufferedFileSlice`.
|
||||
pub fn empty() -> Self {
|
||||
Self::new(FileSlice::empty(), 0)
|
||||
}
|
||||
|
||||
/// Returns an `OwnedBytes` corresponding to the given `required_range`.
|
||||
///
|
||||
/// If the requested range is not in the buffer, this will trigger a read
|
||||
/// from the underlying `FileSlice`.
|
||||
///
|
||||
/// If the requested range is larger than the buffer_max_size, it will be read directly from the
|
||||
/// source without buffering.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an `io::Error` if the underlying read fails or the range is
|
||||
/// out of bounds.
|
||||
pub fn get_bytes(&self, required_range: Range<u64>) -> io::Result<OwnedBytes> {
|
||||
let buffer_range = self.buffer_range.borrow();
|
||||
|
||||
// Cache miss condition: the required range is not fully contained in the current buffer.
|
||||
if required_range.start < buffer_range.start || required_range.end > buffer_range.end {
|
||||
drop(buffer_range); // release borrow before mutating
|
||||
|
||||
if required_range.end > self.file_slice.len() as u64 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
"Requested range extends beyond the end of the file slice.",
|
||||
));
|
||||
}
|
||||
|
||||
if (required_range.end - required_range.start) as usize > self.buffer_max_size {
|
||||
// This read is larger than our buffer max size.
|
||||
// Read it directly and bypass the buffer to avoid churning.
|
||||
return self
|
||||
.file_slice
|
||||
.read_bytes_slice(required_range.start as usize..required_range.end as usize);
|
||||
}
|
||||
|
||||
let new_buffer_start = required_range.start;
|
||||
let new_buffer_end = min(
|
||||
new_buffer_start + self.buffer_max_size as u64,
|
||||
self.file_slice.len() as u64,
|
||||
);
|
||||
let read_range = new_buffer_start..new_buffer_end;
|
||||
|
||||
let new_buffer = self
|
||||
.file_slice
|
||||
.read_bytes_slice(read_range.start as usize..read_range.end as usize)?;
|
||||
|
||||
self.buffer.replace(new_buffer);
|
||||
self.buffer_range.replace(read_range);
|
||||
}
|
||||
|
||||
// Now the data is guaranteed to be in the buffer.
|
||||
let buffer = self.buffer.borrow();
|
||||
let buffer_range = self.buffer_range.borrow();
|
||||
let local_start = (required_range.start - buffer_range.start) as usize;
|
||||
let local_end = (required_range.end - buffer_range.start) as usize;
|
||||
Ok(buffer.slice(local_start..local_end))
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::fs::File;
|
||||
use std::ops::{Deref, Range, RangeBounds};
|
||||
use std::path::Path;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use std::sync::Arc;
|
||||
use std::{fmt, io};
|
||||
|
||||
use async_trait::async_trait;
|
||||
@@ -339,27 +339,6 @@ impl FileHandle for OwnedBytes {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DeferredFileSlice {
|
||||
opener: Arc<dyn Fn() -> io::Result<FileSlice> + Send + Sync + 'static>,
|
||||
file_slice: OnceLock<std::io::Result<FileSlice>>,
|
||||
}
|
||||
|
||||
impl DeferredFileSlice {
|
||||
pub fn new(opener: impl Fn() -> io::Result<FileSlice> + Send + Sync + 'static) -> Self {
|
||||
DeferredFileSlice {
|
||||
opener: Arc::new(opener),
|
||||
file_slice: OnceLock::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn open(&self) -> io::Result<&FileSlice> {
|
||||
match self.file_slice.get_or_init(|| (self.opener)()) {
|
||||
Ok(file_slice) => Ok(file_slice),
|
||||
Err(e) => Err(io::Error::new(io::ErrorKind::Other, e.to_string())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::io;
|
||||
|
||||
@@ -6,7 +6,6 @@ pub use byteorder::LittleEndian as Endianness;
|
||||
|
||||
mod bitset;
|
||||
pub mod bounds;
|
||||
pub mod buffered_file_slice;
|
||||
mod byte_count;
|
||||
mod datetime;
|
||||
pub mod file_slice;
|
||||
|
||||
@@ -58,33 +58,6 @@ impl BinarySerializable for VIntU128 {
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
pub struct VInt(pub u64);
|
||||
|
||||
impl VInt {
|
||||
pub fn deserialize_with_size<R: Read>(reader: &mut R) -> io::Result<(Self, usize)> {
|
||||
let mut nbytes = 0;
|
||||
let mut bytes = reader.bytes();
|
||||
let mut result = 0u64;
|
||||
let mut shift = 0u64;
|
||||
loop {
|
||||
match bytes.next() {
|
||||
Some(Ok(b)) => {
|
||||
nbytes += 1;
|
||||
result |= u64::from(b % 128u8) << shift;
|
||||
if b >= STOP_BIT {
|
||||
return Ok((VInt(result), nbytes));
|
||||
}
|
||||
shift += 7;
|
||||
}
|
||||
_ => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Reach end of buffer while reading VInt",
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const STOP_BIT: u8 = 128;
|
||||
|
||||
#[inline]
|
||||
@@ -252,6 +225,7 @@ impl BinarySerializable for VInt {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use super::{BinarySerializable, VInt, serialize_vint_u32};
|
||||
|
||||
fn aux_test_vint(val: u64) {
|
||||
|
||||
@@ -1,86 +0,0 @@
|
||||
// # Multiple Snippets Example
|
||||
//
|
||||
// This example demonstrates how to return multiple text fragments
|
||||
// from a document, useful for long documents with matches in different locations.
|
||||
|
||||
use tantivy::collector::TopDocs;
|
||||
use tantivy::query::QueryParser;
|
||||
use tantivy::schema::*;
|
||||
use tantivy::snippet::SnippetGenerator;
|
||||
use tantivy::{doc, Index, IndexWriter};
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn main() -> tantivy::Result<()> {
|
||||
let index_path = TempDir::new()?;
|
||||
|
||||
// Define the schema
|
||||
let mut schema_builder = Schema::builder();
|
||||
let title = schema_builder.add_text_field("title", TEXT | STORED);
|
||||
let body = schema_builder.add_text_field("body", TEXT | STORED);
|
||||
let schema = schema_builder.build();
|
||||
|
||||
// Create the index
|
||||
let index = Index::create_in_dir(&index_path, schema)?;
|
||||
let mut index_writer: IndexWriter = index.writer(50_000_000)?;
|
||||
|
||||
// Index a long document with multiple occurrences of "rust"
|
||||
index_writer.add_document(doc!(
|
||||
title => "The Rust Programming Language",
|
||||
body => "Rust is a systems programming language that runs blazingly fast, prevents \
|
||||
segfaults, and guarantees thread safety. Lorem ipsum dolor sit amet, \
|
||||
consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore. \
|
||||
Rust empowers everyone to build reliable and efficient software. More filler \
|
||||
text to create distance between matches. Ut enim ad minim veniam, quis nostrud \
|
||||
exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. \
|
||||
The Rust compiler is known for its helpful error messages. Duis aute irure \
|
||||
dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla \
|
||||
pariatur. Rust has a strong type system and ownership model."
|
||||
))?;
|
||||
|
||||
index_writer.commit()?;
|
||||
|
||||
let reader = index.reader()?;
|
||||
let searcher = reader.searcher();
|
||||
let query_parser = QueryParser::for_index(&index, vec![title, body]);
|
||||
let query = query_parser.parse_query("rust")?;
|
||||
|
||||
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
|
||||
|
||||
// Create snippet generator
|
||||
let mut snippet_generator = SnippetGenerator::create(&searcher, &*query, body)?;
|
||||
|
||||
println!("=== Single Snippet (Default Behavior) ===\n");
|
||||
for (score, doc_address) in &top_docs {
|
||||
let doc = searcher.doc::<TantivyDocument>(*doc_address)?;
|
||||
let snippet = snippet_generator.snippet_from_doc(&doc);
|
||||
println!("Document score: {}", score);
|
||||
println!("Title: {}", doc.get_first(title).unwrap().as_str().unwrap());
|
||||
println!("Single snippet: {}\n", snippet.to_html());
|
||||
}
|
||||
|
||||
println!("\n=== Multiple Snippets (New Feature) ===\n");
|
||||
|
||||
// Configure to return multiple snippets
|
||||
// Get up to 3 snippets
|
||||
snippet_generator.set_snippets_limit(3);
|
||||
// Smaller fragments
|
||||
snippet_generator.set_max_num_chars(80);
|
||||
// By default, multiple snippets are sorted by score. You can change this to sort by position.
|
||||
// snippet_generator.set_sort_order(SnippetSortOrder::Position);
|
||||
|
||||
for (score, doc_address) in top_docs {
|
||||
let doc = searcher.doc::<TantivyDocument>(doc_address)?;
|
||||
let snippets = snippet_generator.snippets_from_doc(&doc);
|
||||
|
||||
println!("Document score: {}", score);
|
||||
println!("Title: {}", doc.get_first(title).unwrap().as_str().unwrap());
|
||||
println!("Found {} snippets:", snippets.len());
|
||||
|
||||
for (i, snippet) in snippets.iter().enumerate() {
|
||||
println!(" Snippet {}: {}", i + 1, snippet.to_html());
|
||||
}
|
||||
println!();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,3 +0,0 @@
|
||||
#! /bin/bash
|
||||
|
||||
cargo +stable nextest run --features quickwit,mmap,stopwords,lz4-compression,zstd-compression,failpoints --verbose --workspace
|
||||
@@ -1,4 +1,4 @@
|
||||
use columnar::{Column, ColumnType, StrColumn};
|
||||
use columnar::{Column, ColumnBlockAccessor, ColumnType, StrColumn};
|
||||
use common::BitSet;
|
||||
use rustc_hash::FxHashSet;
|
||||
use serde::Serialize;
|
||||
@@ -10,16 +10,16 @@ use crate::aggregation::accessor_helpers::{
|
||||
};
|
||||
use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations};
|
||||
use crate::aggregation::bucket::{
|
||||
FilterAggReqData, HistogramAggReqData, HistogramBounds, IncludeExcludeParam,
|
||||
MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector, SegmentHistogramCollector,
|
||||
SegmentRangeCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
|
||||
build_segment_filter_collector, build_segment_range_collector, FilterAggReqData,
|
||||
HistogramAggReqData, HistogramBounds, IncludeExcludeParam, MissingTermAggReqData,
|
||||
RangeAggReqData, SegmentHistogramCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
|
||||
TermsAggregationInternal,
|
||||
};
|
||||
use crate::aggregation::metric::{
|
||||
AverageAggregation, CardinalityAggReqData, CardinalityAggregationReq, CountAggregation,
|
||||
ExtendedStatsAggregation, MaxAggregation, MetricAggReqData, MinAggregation,
|
||||
SegmentCardinalityCollector, SegmentExtendedStatsCollector, SegmentPercentilesCollector,
|
||||
SegmentStatsCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData,
|
||||
build_segment_stats_collector, AverageAggregation, CardinalityAggReqData,
|
||||
CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation, MaxAggregation,
|
||||
MetricAggReqData, MinAggregation, SegmentCardinalityCollector, SegmentExtendedStatsCollector,
|
||||
SegmentPercentilesCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData,
|
||||
TopHitsSegmentCollector,
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::{
|
||||
@@ -35,6 +35,7 @@ pub struct AggregationsSegmentCtx {
|
||||
/// Request data for each aggregation type.
|
||||
pub per_request: PerRequestAggSegCtx,
|
||||
pub context: AggContextParams,
|
||||
pub column_block_accessor: ColumnBlockAccessor<u64>,
|
||||
}
|
||||
|
||||
impl AggregationsSegmentCtx {
|
||||
@@ -107,21 +108,14 @@ impl AggregationsSegmentCtx {
|
||||
.as_deref()
|
||||
.expect("range_req_data slot is empty (taken)")
|
||||
}
|
||||
#[inline]
|
||||
pub(crate) fn get_filter_req_data(&self, idx: usize) -> &FilterAggReqData {
|
||||
self.per_request.filter_req_data[idx]
|
||||
.as_deref()
|
||||
.expect("filter_req_data slot is empty (taken)")
|
||||
}
|
||||
|
||||
// ---------- mutable getters ----------
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn get_term_req_data_mut(&mut self, idx: usize) -> &mut TermsAggReqData {
|
||||
self.per_request.term_req_data[idx]
|
||||
.as_deref_mut()
|
||||
.expect("term_req_data slot is empty (taken)")
|
||||
pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData {
|
||||
&mut self.per_request.stats_metric_req_data[idx]
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn get_cardinality_req_data_mut(
|
||||
&mut self,
|
||||
@@ -129,10 +123,7 @@ impl AggregationsSegmentCtx {
|
||||
) -> &mut CardinalityAggReqData {
|
||||
&mut self.per_request.cardinality_req_data[idx]
|
||||
}
|
||||
#[inline]
|
||||
pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData {
|
||||
&mut self.per_request.stats_metric_req_data[idx]
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn get_histogram_req_data_mut(&mut self, idx: usize) -> &mut HistogramAggReqData {
|
||||
self.per_request.histogram_req_data[idx]
|
||||
@@ -142,21 +133,6 @@ impl AggregationsSegmentCtx {
|
||||
|
||||
// ---------- take / put (terms, histogram, range) ----------
|
||||
|
||||
/// Move out the boxed Terms request at `idx`, leaving `None`.
|
||||
#[inline]
|
||||
pub(crate) fn take_term_req_data(&mut self, idx: usize) -> Box<TermsAggReqData> {
|
||||
self.per_request.term_req_data[idx]
|
||||
.take()
|
||||
.expect("term_req_data slot is empty (taken)")
|
||||
}
|
||||
|
||||
/// Put back a Terms request into an empty slot at `idx`.
|
||||
#[inline]
|
||||
pub(crate) fn put_back_term_req_data(&mut self, idx: usize, value: Box<TermsAggReqData>) {
|
||||
debug_assert!(self.per_request.term_req_data[idx].is_none());
|
||||
self.per_request.term_req_data[idx] = Some(value);
|
||||
}
|
||||
|
||||
/// Move out the boxed Histogram request at `idx`, leaving `None`.
|
||||
#[inline]
|
||||
pub(crate) fn take_histogram_req_data(&mut self, idx: usize) -> Box<HistogramAggReqData> {
|
||||
@@ -320,6 +296,7 @@ impl PerRequestAggSegCtx {
|
||||
|
||||
/// Convert the aggregation tree into a serializable struct representation.
|
||||
/// Each node contains: { name, kind, children }.
|
||||
#[allow(dead_code)]
|
||||
pub fn get_view_tree(&self) -> Vec<AggTreeViewNode> {
|
||||
fn node_to_view(node: &AggRefNode, pr: &PerRequestAggSegCtx) -> AggTreeViewNode {
|
||||
let mut children: Vec<AggTreeViewNode> =
|
||||
@@ -345,12 +322,19 @@ impl PerRequestAggSegCtx {
|
||||
pub(crate) fn build_segment_agg_collectors_root(
|
||||
req: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
|
||||
build_segment_agg_collectors(req, &req.per_request.agg_tree.clone())
|
||||
build_segment_agg_collectors_generic(req, &req.per_request.agg_tree.clone())
|
||||
}
|
||||
|
||||
pub(crate) fn build_segment_agg_collectors(
|
||||
req: &mut AggregationsSegmentCtx,
|
||||
nodes: &[AggRefNode],
|
||||
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
|
||||
build_segment_agg_collectors_generic(req, nodes)
|
||||
}
|
||||
|
||||
fn build_segment_agg_collectors_generic(
|
||||
req: &mut AggregationsSegmentCtx,
|
||||
nodes: &[AggRefNode],
|
||||
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
|
||||
let mut collectors = Vec::new();
|
||||
for node in nodes.iter() {
|
||||
@@ -388,6 +372,8 @@ pub(crate) fn build_segment_agg_collector(
|
||||
Ok(Box::new(SegmentCardinalityCollector::from_req(
|
||||
req_data.column_type,
|
||||
node.idx_in_req_data,
|
||||
req_data.accessor.clone(),
|
||||
req_data.missing_value_for_accessor,
|
||||
)))
|
||||
}
|
||||
AggKind::StatsKind(stats_type) => {
|
||||
@@ -398,20 +384,21 @@ pub(crate) fn build_segment_agg_collector(
|
||||
| StatsType::Count
|
||||
| StatsType::Max
|
||||
| StatsType::Min
|
||||
| StatsType::Stats => Ok(Box::new(SegmentStatsCollector::from_req(
|
||||
node.idx_in_req_data,
|
||||
))),
|
||||
StatsType::ExtendedStats(sigma) => {
|
||||
Ok(Box::new(SegmentExtendedStatsCollector::from_req(
|
||||
req_data.field_type,
|
||||
sigma,
|
||||
node.idx_in_req_data,
|
||||
req_data.missing,
|
||||
)))
|
||||
}
|
||||
StatsType::Percentiles => Ok(Box::new(
|
||||
SegmentPercentilesCollector::from_req_and_validate(node.idx_in_req_data)?,
|
||||
| StatsType::Stats => build_segment_stats_collector(req_data),
|
||||
StatsType::ExtendedStats(sigma) => Ok(Box::new(
|
||||
SegmentExtendedStatsCollector::from_req(req_data, sigma),
|
||||
)),
|
||||
StatsType::Percentiles => {
|
||||
let req_data = req.get_metric_req_data_mut(node.idx_in_req_data);
|
||||
Ok(Box::new(
|
||||
SegmentPercentilesCollector::from_req_and_validate(
|
||||
req_data.field_type,
|
||||
req_data.missing_u64,
|
||||
req_data.accessor.clone(),
|
||||
node.idx_in_req_data,
|
||||
),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
AggKind::TopHits => {
|
||||
@@ -428,12 +415,8 @@ pub(crate) fn build_segment_agg_collector(
|
||||
AggKind::DateHistogram => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
|
||||
req, node,
|
||||
)?)),
|
||||
AggKind::Range => Ok(Box::new(SegmentRangeCollector::from_req_and_validate(
|
||||
req, node,
|
||||
)?)),
|
||||
AggKind::Filter => Ok(Box::new(SegmentFilterCollector::from_req_and_validate(
|
||||
req, node,
|
||||
)?)),
|
||||
AggKind::Range => Ok(build_segment_range_collector(req, node)?),
|
||||
AggKind::Filter => build_segment_filter_collector(req, node),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -493,6 +476,7 @@ pub(crate) fn build_aggregations_data_from_req(
|
||||
let mut data = AggregationsSegmentCtx {
|
||||
per_request: Default::default(),
|
||||
context,
|
||||
column_block_accessor: ColumnBlockAccessor::default(),
|
||||
};
|
||||
|
||||
for (name, agg) in aggs.iter() {
|
||||
@@ -521,9 +505,9 @@ fn build_nodes(
|
||||
let idx_in_req_data = data.push_range_req_data(RangeAggReqData {
|
||||
accessor,
|
||||
field_type,
|
||||
column_block_accessor: Default::default(),
|
||||
name: agg_name.to_string(),
|
||||
req: range_req.clone(),
|
||||
is_top_level,
|
||||
});
|
||||
let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?;
|
||||
Ok(vec![AggRefNode {
|
||||
@@ -541,9 +525,7 @@ fn build_nodes(
|
||||
let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData {
|
||||
accessor,
|
||||
field_type,
|
||||
column_block_accessor: Default::default(),
|
||||
name: agg_name.to_string(),
|
||||
sub_aggregation_blueprint: None,
|
||||
req: histo_req.clone(),
|
||||
is_date_histogram: false,
|
||||
bounds: HistogramBounds {
|
||||
@@ -568,9 +550,7 @@ fn build_nodes(
|
||||
let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData {
|
||||
accessor,
|
||||
field_type,
|
||||
column_block_accessor: Default::default(),
|
||||
name: agg_name.to_string(),
|
||||
sub_aggregation_blueprint: None,
|
||||
req: histo_req,
|
||||
is_date_histogram: true,
|
||||
bounds: HistogramBounds {
|
||||
@@ -650,7 +630,6 @@ fn build_nodes(
|
||||
let idx_in_req_data = data.push_metric_req_data(MetricAggReqData {
|
||||
accessor,
|
||||
field_type,
|
||||
column_block_accessor: Default::default(),
|
||||
name: agg_name.to_string(),
|
||||
collecting_for,
|
||||
missing: *missing,
|
||||
@@ -678,7 +657,6 @@ fn build_nodes(
|
||||
let idx_in_req_data = data.push_metric_req_data(MetricAggReqData {
|
||||
accessor,
|
||||
field_type,
|
||||
column_block_accessor: Default::default(),
|
||||
name: agg_name.to_string(),
|
||||
collecting_for: StatsType::Percentiles,
|
||||
missing: percentiles_req.missing,
|
||||
@@ -753,6 +731,7 @@ fn build_nodes(
|
||||
segment_reader: reader.clone(),
|
||||
evaluator,
|
||||
matching_docs_buffer,
|
||||
is_top_level,
|
||||
});
|
||||
let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?;
|
||||
Ok(vec![AggRefNode {
|
||||
@@ -895,7 +874,7 @@ fn build_terms_or_cardinality_nodes(
|
||||
});
|
||||
}
|
||||
|
||||
// Add one node per accessor to mirror previous behavior and allow per-type missing handling.
|
||||
// Add one node per accessor
|
||||
for (accessor, column_type) in column_and_types {
|
||||
let missing_value_for_accessor = if use_special_missing_agg {
|
||||
None
|
||||
@@ -926,11 +905,8 @@ fn build_terms_or_cardinality_nodes(
|
||||
column_type,
|
||||
str_dict_column: str_dict_column.clone(),
|
||||
missing_value_for_accessor,
|
||||
column_block_accessor: Default::default(),
|
||||
name: agg_name.to_string(),
|
||||
req: TermsAggregationInternal::from_req(req),
|
||||
// Will be filled later when building collectors
|
||||
sub_aggregation_blueprint: None,
|
||||
sug_aggregations: sub_aggs.clone(),
|
||||
allowed_term_ids,
|
||||
is_top_level,
|
||||
@@ -943,7 +919,6 @@ fn build_terms_or_cardinality_nodes(
|
||||
column_type,
|
||||
str_dict_column: str_dict_column.clone(),
|
||||
missing_value_for_accessor,
|
||||
column_block_accessor: Default::default(),
|
||||
name: agg_name.to_string(),
|
||||
req: req.clone(),
|
||||
});
|
||||
|
||||
@@ -2,15 +2,441 @@ use serde_json::Value;
|
||||
|
||||
use crate::aggregation::agg_req::{Aggregation, Aggregations};
|
||||
use crate::aggregation::agg_result::AggregationResults;
|
||||
use crate::aggregation::buf_collector::DOC_BLOCK_SIZE;
|
||||
use crate::aggregation::collector::AggregationCollector;
|
||||
use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults;
|
||||
use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms};
|
||||
use crate::aggregation::DistributedAggregationCollector;
|
||||
use crate::docset::COLLECT_BLOCK_BUFFER_LEN;
|
||||
use crate::query::{AllQuery, TermQuery};
|
||||
use crate::schema::{IndexRecordOption, Schema, FAST};
|
||||
use crate::{Index, IndexWriter, Term};
|
||||
|
||||
// The following tests ensure that each bucket aggregation type correctly functions as a
|
||||
// sub-aggregation of another bucket aggregation in two scenarios:
|
||||
// 1) The parent has more buckets than the child sub-aggregation
|
||||
// 2) The child sub-aggregation has more buckets than the parent
|
||||
//
|
||||
// These scenarios exercise the bucket id mapping and sub-aggregation routing logic.
|
||||
|
||||
#[test]
|
||||
fn test_terms_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
|
||||
let index = get_test_index_2_segments(false)?;
|
||||
|
||||
// Case A: parent has more buckets than child
|
||||
// Parent: range with 4 buckets
|
||||
// Child: terms on text -> 2 buckets
|
||||
let agg_parent_more: Aggregations = serde_json::from_value(json!({
|
||||
"parent_range": {
|
||||
"range": {
|
||||
"field": "score",
|
||||
"ranges": [
|
||||
{"to": 3.0},
|
||||
{"from": 3.0, "to": 7.0},
|
||||
{"from": 7.0, "to": 20.0},
|
||||
{"from": 20.0}
|
||||
]
|
||||
},
|
||||
"aggs": {
|
||||
"child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}}
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
|
||||
// Exact expected structure and counts
|
||||
assert_eq!(
|
||||
res["parent_range"]["buckets"],
|
||||
json!([
|
||||
{
|
||||
"key": "*-3",
|
||||
"doc_count": 1,
|
||||
"to": 3.0,
|
||||
"child_terms": {
|
||||
"buckets": [
|
||||
{"doc_count": 1, "key": "cool"}
|
||||
],
|
||||
"sum_other_doc_count": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"key": "3-7",
|
||||
"doc_count": 3,
|
||||
"from": 3.0,
|
||||
"to": 7.0,
|
||||
"child_terms": {
|
||||
"buckets": [
|
||||
{"doc_count": 2, "key": "cool"},
|
||||
{"doc_count": 1, "key": "nohit"}
|
||||
],
|
||||
"sum_other_doc_count": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"key": "7-20",
|
||||
"doc_count": 3,
|
||||
"from": 7.0,
|
||||
"to": 20.0,
|
||||
"child_terms": {
|
||||
"buckets": [
|
||||
{"doc_count": 3, "key": "cool"}
|
||||
],
|
||||
"sum_other_doc_count": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"key": "20-*",
|
||||
"doc_count": 2,
|
||||
"from": 20.0,
|
||||
"child_terms": {
|
||||
"buckets": [
|
||||
{"doc_count": 1, "key": "cool"},
|
||||
{"doc_count": 1, "key": "nohit"}
|
||||
],
|
||||
"sum_other_doc_count": 0
|
||||
}
|
||||
}
|
||||
])
|
||||
);
|
||||
|
||||
// Case B: child has more buckets than parent
|
||||
// Parent: histogram on score with large interval -> 1 bucket
|
||||
// Child: terms on text -> 2 buckets (cool/nohit)
|
||||
let agg_child_more: Aggregations = serde_json::from_value(json!({
|
||||
"parent_hist": {
|
||||
"histogram": {"field": "score", "interval": 100.0},
|
||||
"aggs": {
|
||||
"child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}}
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
|
||||
assert_eq!(
|
||||
res["parent_hist"],
|
||||
json!({
|
||||
"buckets": [
|
||||
{
|
||||
"key": 0.0,
|
||||
"doc_count": 9,
|
||||
"child_terms": {
|
||||
"buckets": [
|
||||
{"doc_count": 7, "key": "cool"},
|
||||
{"doc_count": 2, "key": "nohit"}
|
||||
],
|
||||
"sum_other_doc_count": 0
|
||||
}
|
||||
}
|
||||
]
|
||||
})
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_range_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
|
||||
let index = get_test_index_2_segments(false)?;
|
||||
|
||||
// Case A: parent has more buckets than child
|
||||
// Parent: range with 5 buckets
|
||||
// Child: coarse range with 3 buckets
|
||||
let agg_parent_more: Aggregations = serde_json::from_value(json!({
|
||||
"parent_range": {
|
||||
"range": {
|
||||
"field": "score",
|
||||
"ranges": [
|
||||
{"to": 3.0},
|
||||
{"from": 3.0, "to": 7.0},
|
||||
{"from": 7.0, "to": 11.0},
|
||||
{"from": 11.0, "to": 20.0},
|
||||
{"from": 20.0}
|
||||
]
|
||||
},
|
||||
"aggs": {
|
||||
"child_range": {
|
||||
"range": {
|
||||
"field": "score",
|
||||
"ranges": [
|
||||
{"to": 3.0},
|
||||
{"from": 3.0, "to": 20.0}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
|
||||
assert_eq!(
|
||||
res["parent_range"]["buckets"],
|
||||
json!([
|
||||
{"key": "*-3", "doc_count": 1, "to": 3.0,
|
||||
"child_range": {"buckets": [
|
||||
{"key": "*-3", "doc_count": 1, "to": 3.0},
|
||||
{"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0},
|
||||
{"key": "20-*", "doc_count": 0, "from": 20.0}
|
||||
]}
|
||||
},
|
||||
{"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0,
|
||||
"child_range": {"buckets": [
|
||||
{"key": "*-3", "doc_count": 0, "to": 3.0},
|
||||
{"key": "3-20", "doc_count": 3, "from": 3.0, "to": 20.0},
|
||||
{"key": "20-*", "doc_count": 0, "from": 20.0}
|
||||
]}
|
||||
},
|
||||
{"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0,
|
||||
"child_range": {"buckets": [
|
||||
{"key": "*-3", "doc_count": 0, "to": 3.0},
|
||||
{"key": "3-20", "doc_count": 1, "from": 3.0, "to": 20.0},
|
||||
{"key": "20-*", "doc_count": 0, "from": 20.0}
|
||||
]}
|
||||
},
|
||||
{"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0,
|
||||
"child_range": {"buckets": [
|
||||
{"key": "*-3", "doc_count": 0, "to": 3.0},
|
||||
{"key": "3-20", "doc_count": 2, "from": 3.0, "to": 20.0},
|
||||
{"key": "20-*", "doc_count": 0, "from": 20.0}
|
||||
]}
|
||||
},
|
||||
{"key": "20-*", "doc_count": 2, "from": 20.0,
|
||||
"child_range": {"buckets": [
|
||||
{"key": "*-3", "doc_count": 0, "to": 3.0},
|
||||
{"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0},
|
||||
{"key": "20-*", "doc_count": 2, "from": 20.0}
|
||||
]}
|
||||
}
|
||||
])
|
||||
);
|
||||
|
||||
// Case B: child has more buckets than parent
|
||||
// Parent: terms on text (2 buckets)
|
||||
// Child: range with 4 buckets
|
||||
let agg_child_more: Aggregations = serde_json::from_value(json!({
|
||||
"parent_terms": {
|
||||
"terms": {"field": "text"},
|
||||
"aggs": {
|
||||
"child_range": {
|
||||
"range": {
|
||||
"field": "score",
|
||||
"ranges": [
|
||||
{"to": 3.0},
|
||||
{"from": 3.0, "to": 7.0},
|
||||
{"from": 7.0, "to": 20.0}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
|
||||
|
||||
assert_eq!(
|
||||
res["parent_terms"],
|
||||
json!({
|
||||
"buckets": [
|
||||
{
|
||||
"key": "cool",
|
||||
"doc_count": 7,
|
||||
"child_range": {
|
||||
"buckets": [
|
||||
{"key": "*-3", "doc_count": 1, "to": 3.0},
|
||||
{"key": "3-7", "doc_count": 2, "from": 3.0, "to": 7.0},
|
||||
{"key": "7-20", "doc_count": 3, "from": 7.0, "to": 20.0},
|
||||
{"key": "20-*", "doc_count": 1, "from": 20.0}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"key": "nohit",
|
||||
"doc_count": 2,
|
||||
"child_range": {
|
||||
"buckets": [
|
||||
{"key": "*-3", "doc_count": 0, "to": 3.0},
|
||||
{"key": "3-7", "doc_count": 1, "from": 3.0, "to": 7.0},
|
||||
{"key": "7-20", "doc_count": 0, "from": 7.0, "to": 20.0},
|
||||
{"key": "20-*", "doc_count": 1, "from": 20.0}
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"doc_count_error_upper_bound": 0,
|
||||
"sum_other_doc_count": 0
|
||||
})
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
|
||||
let index = get_test_index_2_segments(false)?;
|
||||
|
||||
// Case A: parent has more buckets than child
|
||||
// Parent: range with several ranges
|
||||
// Child: histogram with large interval (single bucket per parent)
|
||||
let agg_parent_more: Aggregations = serde_json::from_value(json!({
|
||||
"parent_range": {
|
||||
"range": {
|
||||
"field": "score",
|
||||
"ranges": [
|
||||
{"to": 3.0},
|
||||
{"from": 3.0, "to": 7.0},
|
||||
{"from": 7.0, "to": 11.0},
|
||||
{"from": 11.0, "to": 20.0},
|
||||
{"from": 20.0}
|
||||
]
|
||||
},
|
||||
"aggs": {
|
||||
"child_hist": {"histogram": {"field": "score", "interval": 100.0}}
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
|
||||
assert_eq!(
|
||||
res["parent_range"]["buckets"],
|
||||
json!([
|
||||
{"key": "*-3", "doc_count": 1, "to": 3.0,
|
||||
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]}
|
||||
},
|
||||
{"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0,
|
||||
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 3} ]}
|
||||
},
|
||||
{"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0,
|
||||
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]}
|
||||
},
|
||||
{"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0,
|
||||
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]}
|
||||
},
|
||||
{"key": "20-*", "doc_count": 2, "from": 20.0,
|
||||
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]}
|
||||
}
|
||||
])
|
||||
);
|
||||
|
||||
// Case B: child has more buckets than parent
|
||||
// Parent: terms on text -> 2 buckets
|
||||
// Child: histogram with small interval -> multiple buckets including empties
|
||||
let agg_child_more: Aggregations = serde_json::from_value(json!({
|
||||
"parent_terms": {
|
||||
"terms": {"field": "text"},
|
||||
"aggs": {
|
||||
"child_hist": {"histogram": {"field": "score", "interval": 10.0}}
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
|
||||
assert_eq!(
|
||||
res["parent_terms"],
|
||||
json!({
|
||||
"buckets": [
|
||||
{
|
||||
"key": "cool",
|
||||
"doc_count": 7,
|
||||
"child_hist": {
|
||||
"buckets": [
|
||||
{"key": 0.0, "doc_count": 4},
|
||||
{"key": 10.0, "doc_count": 2},
|
||||
{"key": 20.0, "doc_count": 0},
|
||||
{"key": 30.0, "doc_count": 0},
|
||||
{"key": 40.0, "doc_count": 1}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"key": "nohit",
|
||||
"doc_count": 2,
|
||||
"child_hist": {
|
||||
"buckets": [
|
||||
{"key": 0.0, "doc_count": 1},
|
||||
{"key": 10.0, "doc_count": 0},
|
||||
{"key": 20.0, "doc_count": 0},
|
||||
{"key": 30.0, "doc_count": 0},
|
||||
{"key": 40.0, "doc_count": 1}
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"doc_count_error_upper_bound": 0,
|
||||
"sum_other_doc_count": 0
|
||||
})
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_date_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
|
||||
let index = get_test_index_2_segments(false)?;
|
||||
|
||||
// Case A: parent has more buckets than child
|
||||
// Parent: range with several buckets
|
||||
// Child: date_histogram with 30d -> single bucket per parent
|
||||
let agg_parent_more: Aggregations = serde_json::from_value(json!({
|
||||
"parent_range": {
|
||||
"range": {
|
||||
"field": "score",
|
||||
"ranges": [
|
||||
{"to": 3.0},
|
||||
{"from": 3.0, "to": 7.0},
|
||||
{"from": 7.0, "to": 11.0},
|
||||
{"from": 11.0, "to": 20.0},
|
||||
{"from": 20.0}
|
||||
]
|
||||
},
|
||||
"aggs": {
|
||||
"child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "30d"}}
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
|
||||
let buckets = res["parent_range"]["buckets"].as_array().unwrap();
|
||||
// Verify each parent bucket has exactly one child date bucket with matching doc_count
|
||||
for bucket in buckets {
|
||||
let parent_count = bucket["doc_count"].as_u64().unwrap();
|
||||
let child_buckets = bucket["child_date_hist"]["buckets"].as_array().unwrap();
|
||||
assert_eq!(child_buckets.len(), 1);
|
||||
assert_eq!(child_buckets[0]["doc_count"], parent_count);
|
||||
}
|
||||
|
||||
// Case B: child has more buckets than parent
|
||||
// Parent: terms on text (2 buckets)
|
||||
// Child: date_histogram with 1d -> multiple buckets
|
||||
let agg_child_more: Aggregations = serde_json::from_value(json!({
|
||||
"parent_terms": {
|
||||
"terms": {"field": "text"},
|
||||
"aggs": {
|
||||
"child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "1d"}}
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
|
||||
let buckets = res["parent_terms"]["buckets"].as_array().unwrap();
|
||||
|
||||
// cool bucket
|
||||
assert_eq!(buckets[0]["key"], "cool");
|
||||
let cool_buckets = buckets[0]["child_date_hist"]["buckets"].as_array().unwrap();
|
||||
assert_eq!(cool_buckets.len(), 3);
|
||||
assert_eq!(cool_buckets[0]["doc_count"], 1); // day 0
|
||||
assert_eq!(cool_buckets[1]["doc_count"], 4); // day 1
|
||||
assert_eq!(cool_buckets[2]["doc_count"], 2); // day 2
|
||||
|
||||
// nohit bucket
|
||||
assert_eq!(buckets[1]["key"], "nohit");
|
||||
let nohit_buckets = buckets[1]["child_date_hist"]["buckets"].as_array().unwrap();
|
||||
assert_eq!(nohit_buckets.len(), 2);
|
||||
assert_eq!(nohit_buckets[0]["doc_count"], 1); // day 1
|
||||
assert_eq!(nohit_buckets[1]["doc_count"], 1); // day 2
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_avg_req(field_name: &str) -> Aggregation {
|
||||
serde_json::from_value(json!({
|
||||
"avg": {
|
||||
@@ -25,6 +451,10 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector {
|
||||
}
|
||||
|
||||
// *** EVERY BUCKET-TYPE SHOULD BE TESTED HERE ***
|
||||
// Note: The flushng part of these tests are outdated, since the buffering change after converting
|
||||
// the collection into one collector per request instead of per bucket.
|
||||
//
|
||||
// However they are useful as they test a complex aggregation requests.
|
||||
fn test_aggregation_flushing(
|
||||
merge_segments: bool,
|
||||
use_distributed_collector: bool,
|
||||
@@ -37,8 +467,9 @@ fn test_aggregation_flushing(
|
||||
|
||||
let reader = index.reader()?;
|
||||
|
||||
assert_eq!(DOC_BLOCK_SIZE, 64);
|
||||
// In the tree we cache Documents of DOC_BLOCK_SIZE, before passing them down as one block.
|
||||
assert_eq!(COLLECT_BLOCK_BUFFER_LEN, 64);
|
||||
// In the tree we cache documents of COLLECT_BLOCK_BUFFER_LEN before passing them down as one
|
||||
// block.
|
||||
//
|
||||
// Build a request so that on the first level we have one full cache, which is then flushed.
|
||||
// The same cache should have some residue docs at the end, which are flushed (Range 0-70)
|
||||
|
||||
@@ -6,10 +6,14 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use crate::aggregation::agg_data::{
|
||||
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
|
||||
};
|
||||
use crate::aggregation::cached_sub_aggs::{
|
||||
CachedSubAggs, HighCardSubAggCache, LowCardSubAggCache, SubAggCache,
|
||||
};
|
||||
use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::{CollectorClone, SegmentAggregationCollector};
|
||||
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
|
||||
use crate::aggregation::BucketId;
|
||||
use crate::docset::DocSet;
|
||||
use crate::query::{AllQuery, EnableScoring, Query, QueryParser};
|
||||
use crate::schema::Schema;
|
||||
@@ -404,15 +408,18 @@ pub struct FilterAggReqData {
|
||||
pub evaluator: DocumentQueryEvaluator,
|
||||
/// Reusable buffer for matching documents to minimize allocations during collection
|
||||
pub matching_docs_buffer: Vec<DocId>,
|
||||
/// True if this filter aggregation is at the top level of the aggregation tree (not nested).
|
||||
pub is_top_level: bool,
|
||||
}
|
||||
|
||||
impl FilterAggReqData {
|
||||
pub(crate) fn get_memory_consumption(&self) -> usize {
|
||||
// Estimate: name + segment reader reference + bitset + buffer capacity
|
||||
self.name.len()
|
||||
+ std::mem::size_of::<SegmentReader>()
|
||||
+ self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes)
|
||||
+ self.matching_docs_buffer.capacity() * std::mem::size_of::<DocId>()
|
||||
+ std::mem::size_of::<SegmentReader>()
|
||||
+ self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes)
|
||||
+ self.matching_docs_buffer.capacity() * std::mem::size_of::<DocId>()
|
||||
+ std::mem::size_of::<bool>()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -489,17 +496,24 @@ impl Debug for DocumentQueryEvaluator {
|
||||
}
|
||||
}
|
||||
|
||||
/// Segment collector for filter aggregation
|
||||
pub struct SegmentFilterCollector {
|
||||
/// Document count in this bucket
|
||||
#[derive(Debug, Clone, PartialEq, Copy)]
|
||||
struct DocCount {
|
||||
doc_count: u64,
|
||||
bucket_id: BucketId,
|
||||
}
|
||||
|
||||
/// Segment collector for filter aggregation
|
||||
pub struct SegmentFilterCollector<C: SubAggCache> {
|
||||
/// Document counts per parent bucket
|
||||
parent_buckets: Vec<DocCount>,
|
||||
/// Sub-aggregation collectors
|
||||
sub_aggregations: Option<Box<dyn SegmentAggregationCollector>>,
|
||||
sub_aggregations: Option<CachedSubAggs<C>>,
|
||||
bucket_id_provider: BucketIdProvider,
|
||||
/// Accessor index for this filter aggregation (to access FilterAggReqData)
|
||||
accessor_idx: usize,
|
||||
}
|
||||
|
||||
impl SegmentFilterCollector {
|
||||
impl<C: SubAggCache> SegmentFilterCollector<C> {
|
||||
/// Create a new filter segment collector following the new agg_data pattern
|
||||
pub(crate) fn from_req_and_validate(
|
||||
req: &mut AggregationsSegmentCtx,
|
||||
@@ -511,47 +525,75 @@ impl SegmentFilterCollector {
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let sub_agg_collector = sub_agg_collector.map(CachedSubAggs::new);
|
||||
|
||||
Ok(SegmentFilterCollector {
|
||||
doc_count: 0,
|
||||
parent_buckets: Vec::new(),
|
||||
sub_aggregations: sub_agg_collector,
|
||||
accessor_idx: node.idx_in_req_data,
|
||||
bucket_id_provider: BucketIdProvider::default(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for SegmentFilterCollector {
|
||||
pub(crate) fn build_segment_filter_collector(
|
||||
req: &mut AggregationsSegmentCtx,
|
||||
node: &AggRefNode,
|
||||
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
|
||||
let is_top_level = req.per_request.filter_req_data[node.idx_in_req_data]
|
||||
.as_ref()
|
||||
.expect("filter_req_data slot is empty")
|
||||
.is_top_level;
|
||||
|
||||
if is_top_level {
|
||||
Ok(Box::new(
|
||||
SegmentFilterCollector::<LowCardSubAggCache>::from_req_and_validate(req, node)?,
|
||||
))
|
||||
} else {
|
||||
Ok(Box::new(
|
||||
SegmentFilterCollector::<HighCardSubAggCache>::from_req_and_validate(req, node)?,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: SubAggCache> Debug for SegmentFilterCollector<C> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("SegmentFilterCollector")
|
||||
.field("doc_count", &self.doc_count)
|
||||
.field("buckets", &self.parent_buckets)
|
||||
.field("has_sub_aggs", &self.sub_aggregations.is_some())
|
||||
.field("accessor_idx", &self.accessor_idx)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl CollectorClone for SegmentFilterCollector {
|
||||
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector> {
|
||||
// For now, panic - this needs proper implementation with weight recreation
|
||||
panic!("SegmentFilterCollector cloning not yet implemented - requires weight recreation")
|
||||
}
|
||||
}
|
||||
|
||||
impl SegmentAggregationCollector for SegmentFilterCollector {
|
||||
impl<C: SubAggCache> SegmentAggregationCollector for SegmentFilterCollector<C> {
|
||||
fn add_intermediate_aggregation_result(
|
||||
self: Box<Self>,
|
||||
&mut self,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
parent_bucket_id: BucketId,
|
||||
) -> crate::Result<()> {
|
||||
let mut sub_results = IntermediateAggregationResults::default();
|
||||
let bucket_opt = self.parent_buckets.get(parent_bucket_id as usize);
|
||||
|
||||
if let Some(sub_aggs) = self.sub_aggregations {
|
||||
sub_aggs.add_intermediate_aggregation_result(agg_data, &mut sub_results)?;
|
||||
if let Some(sub_aggs) = &mut self.sub_aggregations {
|
||||
sub_aggs
|
||||
.get_sub_agg_collector()
|
||||
.add_intermediate_aggregation_result(
|
||||
agg_data,
|
||||
&mut sub_results,
|
||||
// Here we create a new bucket ID for sub-aggregations if the bucket doesn't
|
||||
// exist, so that sub-aggregations can still produce results (e.g., zero doc
|
||||
// count)
|
||||
bucket_opt
|
||||
.map(|bucket| bucket.bucket_id)
|
||||
.unwrap_or(self.bucket_id_provider.next_bucket_id()),
|
||||
)?;
|
||||
}
|
||||
|
||||
// Create the filter bucket result
|
||||
let filter_bucket_result = IntermediateBucketResult::Filter {
|
||||
doc_count: self.doc_count,
|
||||
doc_count: bucket_opt.map(|b| b.doc_count).unwrap_or(0),
|
||||
sub_aggregations: sub_results,
|
||||
};
|
||||
|
||||
@@ -570,32 +612,17 @@ impl SegmentAggregationCollector for SegmentFilterCollector {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn collect(&mut self, doc: DocId, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
|
||||
// Access the evaluator from FilterAggReqData
|
||||
let req_data = agg_data.get_filter_req_data(self.accessor_idx);
|
||||
|
||||
// O(1) BitSet lookup to check if document matches filter
|
||||
if req_data.evaluator.matches_document(doc) {
|
||||
self.doc_count += 1;
|
||||
|
||||
// If we have sub-aggregations, collect on them for this filtered document
|
||||
if let Some(sub_aggs) = &mut self.sub_aggregations {
|
||||
sub_aggs.collect(doc, agg_data)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn collect_block(
|
||||
fn collect(
|
||||
&mut self,
|
||||
docs: &[DocId],
|
||||
parent_bucket_id: BucketId,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
if docs.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut bucket = self.parent_buckets[parent_bucket_id as usize];
|
||||
// Take the request data to avoid borrow checker issues with sub-aggregations
|
||||
let mut req = agg_data.take_filter_req_data(self.accessor_idx);
|
||||
|
||||
@@ -604,18 +631,24 @@ impl SegmentAggregationCollector for SegmentFilterCollector {
|
||||
req.evaluator
|
||||
.filter_batch(docs, &mut req.matching_docs_buffer);
|
||||
|
||||
self.doc_count += req.matching_docs_buffer.len() as u64;
|
||||
bucket.doc_count += req.matching_docs_buffer.len() as u64;
|
||||
|
||||
// Batch process sub-aggregations if we have matches
|
||||
if !req.matching_docs_buffer.is_empty() {
|
||||
if let Some(sub_aggs) = &mut self.sub_aggregations {
|
||||
// Use collect_block for better sub-aggregation performance
|
||||
sub_aggs.collect_block(&req.matching_docs_buffer, agg_data)?;
|
||||
for &doc_id in &req.matching_docs_buffer {
|
||||
sub_aggs.push(bucket.bucket_id, doc_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Put the request data back
|
||||
agg_data.put_back_filter_req_data(self.accessor_idx, req);
|
||||
if let Some(sub_aggs) = &mut self.sub_aggregations {
|
||||
sub_aggs.check_flush_local(agg_data)?;
|
||||
}
|
||||
// put back bucket
|
||||
self.parent_buckets[parent_bucket_id as usize] = bucket;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -626,6 +659,21 @@ impl SegmentAggregationCollector for SegmentFilterCollector {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn prepare_max_bucket(
|
||||
&mut self,
|
||||
max_bucket: BucketId,
|
||||
_agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
while self.parent_buckets.len() <= max_bucket as usize {
|
||||
let bucket_id = self.bucket_id_provider.next_bucket_id();
|
||||
self.parent_buckets.push(DocCount {
|
||||
doc_count: 0,
|
||||
bucket_id,
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Intermediate result for filter aggregation
|
||||
@@ -1519,9 +1567,9 @@ mod tests {
|
||||
let searcher = reader.searcher();
|
||||
|
||||
let agg = json!({
|
||||
"test": {
|
||||
"filter": deserialized,
|
||||
"aggs": { "count": { "value_count": { "field": "brand" } } }
|
||||
"test": {
|
||||
"filter": deserialized,
|
||||
"aggs": { "count": { "value_count": { "field": "brand" } } }
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::cmp::Ordering;
|
||||
|
||||
use columnar::{Column, ColumnBlockAccessor, ColumnType};
|
||||
use columnar::{Column, ColumnType};
|
||||
use rustc_hash::FxHashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tantivy_bitpacker::minmax;
|
||||
@@ -8,14 +8,14 @@ use tantivy_bitpacker::minmax;
|
||||
use crate::aggregation::agg_data::{
|
||||
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
|
||||
};
|
||||
use crate::aggregation::agg_limits::MemoryConsumption;
|
||||
use crate::aggregation::agg_req::Aggregations;
|
||||
use crate::aggregation::agg_result::BucketEntry;
|
||||
use crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardCachedSubAggs};
|
||||
use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
|
||||
IntermediateHistogramBucketEntry,
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
|
||||
use crate::aggregation::*;
|
||||
use crate::TantivyError;
|
||||
|
||||
@@ -26,13 +26,8 @@ pub struct HistogramAggReqData {
|
||||
pub accessor: Column<u64>,
|
||||
/// The field type of the fast field.
|
||||
pub field_type: ColumnType,
|
||||
/// The column block accessor to access the fast field values.
|
||||
pub column_block_accessor: ColumnBlockAccessor<u64>,
|
||||
/// The name of the aggregation.
|
||||
pub name: String,
|
||||
/// The sub aggregation blueprint, used to create sub aggregations for each bucket.
|
||||
/// Will be filled during initialization of the collector.
|
||||
pub sub_aggregation_blueprint: Option<Box<dyn SegmentAggregationCollector>>,
|
||||
/// The histogram aggregation request.
|
||||
pub req: HistogramAggregation,
|
||||
/// True if this is a date_histogram aggregation.
|
||||
@@ -257,18 +252,24 @@ impl HistogramBounds {
|
||||
pub(crate) struct SegmentHistogramBucketEntry {
|
||||
pub key: f64,
|
||||
pub doc_count: u64,
|
||||
pub bucket_id: BucketId,
|
||||
}
|
||||
|
||||
impl SegmentHistogramBucketEntry {
|
||||
pub(crate) fn into_intermediate_bucket_entry(
|
||||
self,
|
||||
sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
|
||||
sub_aggregation: &mut Option<HighCardCachedSubAggs>,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<IntermediateHistogramBucketEntry> {
|
||||
let mut sub_aggregation_res = IntermediateAggregationResults::default();
|
||||
if let Some(sub_aggregation) = sub_aggregation {
|
||||
sub_aggregation
|
||||
.add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?;
|
||||
.get_sub_agg_collector()
|
||||
.add_intermediate_aggregation_result(
|
||||
agg_data,
|
||||
&mut sub_aggregation_res,
|
||||
self.bucket_id,
|
||||
)?;
|
||||
}
|
||||
Ok(IntermediateHistogramBucketEntry {
|
||||
key: self.key,
|
||||
@@ -278,27 +279,38 @@ impl SegmentHistogramBucketEntry {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
struct HistogramBuckets {
|
||||
pub buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
|
||||
}
|
||||
|
||||
/// The collector puts values from the fast field into the correct buckets and does a conversion to
|
||||
/// the correct datatype.
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Debug)]
|
||||
pub struct SegmentHistogramCollector {
|
||||
/// The buckets containing the aggregation data.
|
||||
buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
|
||||
sub_aggregations: FxHashMap<i64, Box<dyn SegmentAggregationCollector>>,
|
||||
/// One Histogram bucket per parent bucket id.
|
||||
parent_buckets: Vec<HistogramBuckets>,
|
||||
sub_agg: Option<HighCardCachedSubAggs>,
|
||||
accessor_idx: usize,
|
||||
bucket_id_provider: BucketIdProvider,
|
||||
}
|
||||
|
||||
impl SegmentAggregationCollector for SegmentHistogramCollector {
|
||||
fn add_intermediate_aggregation_result(
|
||||
self: Box<Self>,
|
||||
&mut self,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
parent_bucket_id: BucketId,
|
||||
) -> crate::Result<()> {
|
||||
let name = agg_data
|
||||
.get_histogram_req_data(self.accessor_idx)
|
||||
.name
|
||||
.clone();
|
||||
let bucket = self.into_intermediate_bucket_result(agg_data)?;
|
||||
// TODO: avoid prepare_max_bucket here and handle empty buckets.
|
||||
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
|
||||
let histogram = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);
|
||||
let bucket = self.add_intermediate_bucket_result(agg_data, histogram)?;
|
||||
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
|
||||
|
||||
Ok(())
|
||||
@@ -307,44 +319,40 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
|
||||
#[inline]
|
||||
fn collect(
|
||||
&mut self,
|
||||
doc: crate::DocId,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
self.collect_block(&[doc], agg_data)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
parent_bucket_id: BucketId,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let mut req = agg_data.take_histogram_req_data(self.accessor_idx);
|
||||
let req = agg_data.take_histogram_req_data(self.accessor_idx);
|
||||
let mem_pre = self.get_memory_consumption();
|
||||
let buckets = &mut self.parent_buckets[parent_bucket_id as usize].buckets;
|
||||
|
||||
let bounds = req.bounds;
|
||||
let interval = req.req.interval;
|
||||
let offset = req.offset;
|
||||
let get_bucket_pos = |val| get_bucket_pos_f64(val, interval, offset) as i64;
|
||||
|
||||
req.column_block_accessor.fetch_block(docs, &req.accessor);
|
||||
for (doc, val) in req
|
||||
agg_data
|
||||
.column_block_accessor
|
||||
.fetch_block(docs, &req.accessor);
|
||||
for (doc, val) in agg_data
|
||||
.column_block_accessor
|
||||
.iter_docid_vals(docs, &req.accessor)
|
||||
{
|
||||
let val = f64_from_fastfield_u64(val, &req.field_type);
|
||||
let val = f64_from_fastfield_u64(val, req.field_type);
|
||||
let bucket_pos = get_bucket_pos(val);
|
||||
if bounds.contains(val) {
|
||||
let bucket = self.buckets.entry(bucket_pos).or_insert_with(|| {
|
||||
let bucket = buckets.entry(bucket_pos).or_insert_with(|| {
|
||||
let key = get_bucket_key_from_pos(bucket_pos as f64, interval, offset);
|
||||
SegmentHistogramBucketEntry { key, doc_count: 0 }
|
||||
SegmentHistogramBucketEntry {
|
||||
key,
|
||||
doc_count: 0,
|
||||
bucket_id: self.bucket_id_provider.next_bucket_id(),
|
||||
}
|
||||
});
|
||||
bucket.doc_count += 1;
|
||||
if let Some(sub_aggregation_blueprint) = req.sub_aggregation_blueprint.as_ref() {
|
||||
self.sub_aggregations
|
||||
.entry(bucket_pos)
|
||||
.or_insert_with(|| sub_aggregation_blueprint.clone())
|
||||
.collect(doc, agg_data)?;
|
||||
if let Some(sub_agg) = &mut self.sub_agg {
|
||||
sub_agg.push(bucket.bucket_id, doc);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -358,14 +366,30 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
|
||||
.add_memory_consumed(mem_delta as u64)?;
|
||||
}
|
||||
|
||||
if let Some(sub_agg) = &mut self.sub_agg {
|
||||
sub_agg.check_flush_local(agg_data)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
|
||||
for sub_aggregation in self.sub_aggregations.values_mut() {
|
||||
if let Some(sub_aggregation) = &mut self.sub_agg {
|
||||
sub_aggregation.flush(agg_data)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn prepare_max_bucket(
|
||||
&mut self,
|
||||
max_bucket: BucketId,
|
||||
_agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
while self.parent_buckets.len() <= max_bucket as usize {
|
||||
self.parent_buckets.push(HistogramBuckets {
|
||||
buckets: FxHashMap::default(),
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -373,22 +397,19 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
|
||||
impl SegmentHistogramCollector {
|
||||
fn get_memory_consumption(&self) -> usize {
|
||||
let self_mem = std::mem::size_of::<Self>();
|
||||
let sub_aggs_mem = self.sub_aggregations.memory_consumption();
|
||||
let buckets_mem = self.buckets.memory_consumption();
|
||||
self_mem + sub_aggs_mem + buckets_mem
|
||||
let buckets_mem = self.parent_buckets.len() * std::mem::size_of::<HistogramBuckets>();
|
||||
self_mem + buckets_mem
|
||||
}
|
||||
/// Converts the collector result into a intermediate bucket result.
|
||||
pub fn into_intermediate_bucket_result(
|
||||
self,
|
||||
fn add_intermediate_bucket_result(
|
||||
&mut self,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
histogram: HistogramBuckets,
|
||||
) -> crate::Result<IntermediateBucketResult> {
|
||||
let mut buckets = Vec::with_capacity(self.buckets.len());
|
||||
let mut buckets = Vec::with_capacity(histogram.buckets.len());
|
||||
|
||||
for (bucket_pos, bucket) in self.buckets {
|
||||
let bucket_res = bucket.into_intermediate_bucket_entry(
|
||||
self.sub_aggregations.get(&bucket_pos).cloned(),
|
||||
agg_data,
|
||||
);
|
||||
for bucket in histogram.buckets.into_values() {
|
||||
let bucket_res = bucket.into_intermediate_bucket_entry(&mut self.sub_agg, agg_data);
|
||||
|
||||
buckets.push(bucket_res?);
|
||||
}
|
||||
@@ -408,7 +429,7 @@ impl SegmentHistogramCollector {
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
node: &AggRefNode,
|
||||
) -> crate::Result<Self> {
|
||||
let blueprint = if !node.children.is_empty() {
|
||||
let sub_agg = if !node.children.is_empty() {
|
||||
Some(build_segment_agg_collectors(agg_data, &node.children)?)
|
||||
} else {
|
||||
None
|
||||
@@ -423,13 +444,13 @@ impl SegmentHistogramCollector {
|
||||
max: f64::MAX,
|
||||
});
|
||||
req_data.offset = req_data.req.offset.unwrap_or(0.0);
|
||||
|
||||
req_data.sub_aggregation_blueprint = blueprint;
|
||||
let sub_agg = sub_agg.map(CachedSubAggs::new);
|
||||
|
||||
Ok(Self {
|
||||
buckets: Default::default(),
|
||||
sub_aggregations: Default::default(),
|
||||
parent_buckets: Default::default(),
|
||||
sub_agg,
|
||||
accessor_idx: node.idx_in_req_data,
|
||||
bucket_id_provider: BucketIdProvider::default(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,18 +1,22 @@
|
||||
use std::fmt::Debug;
|
||||
use std::ops::Range;
|
||||
|
||||
use columnar::{Column, ColumnBlockAccessor, ColumnType};
|
||||
use columnar::{Column, ColumnType};
|
||||
use rustc_hash::FxHashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::aggregation::agg_data::{
|
||||
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
|
||||
};
|
||||
use crate::aggregation::agg_limits::AggregationLimitsGuard;
|
||||
use crate::aggregation::cached_sub_aggs::{
|
||||
CachedSubAggs, HighCardSubAggCache, LowCardCachedSubAggs, LowCardSubAggCache, SubAggCache,
|
||||
};
|
||||
use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
|
||||
IntermediateRangeBucketEntry, IntermediateRangeBucketResult,
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
|
||||
use crate::aggregation::*;
|
||||
use crate::TantivyError;
|
||||
|
||||
@@ -23,12 +27,12 @@ pub struct RangeAggReqData {
|
||||
pub accessor: Column<u64>,
|
||||
/// The type of the fast field.
|
||||
pub field_type: ColumnType,
|
||||
/// The column block accessor to access the fast field values.
|
||||
pub column_block_accessor: ColumnBlockAccessor<u64>,
|
||||
/// The range aggregation request.
|
||||
pub req: RangeAggregation,
|
||||
/// The name of the aggregation.
|
||||
pub name: String,
|
||||
/// Whether this is a top-level aggregation.
|
||||
pub is_top_level: bool,
|
||||
}
|
||||
|
||||
impl RangeAggReqData {
|
||||
@@ -151,19 +155,47 @@ pub(crate) struct SegmentRangeAndBucketEntry {
|
||||
|
||||
/// The collector puts values from the fast field into the correct buckets and does a conversion to
|
||||
/// the correct datatype.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct SegmentRangeCollector {
|
||||
pub struct SegmentRangeCollector<C: SubAggCache> {
|
||||
/// The buckets containing the aggregation data.
|
||||
buckets: Vec<SegmentRangeAndBucketEntry>,
|
||||
/// One for each ParentBucketId
|
||||
parent_buckets: Vec<Vec<SegmentRangeAndBucketEntry>>,
|
||||
column_type: ColumnType,
|
||||
pub(crate) accessor_idx: usize,
|
||||
sub_agg: Option<CachedSubAggs<C>>,
|
||||
/// Here things get a bit weird. We need to assign unique bucket ids across all
|
||||
/// parent buckets. So we keep track of the next available bucket id here.
|
||||
/// This allows a kind of flattening of the bucket ids across all parent buckets.
|
||||
/// E.g. in nested aggregations:
|
||||
/// Term Agg -> Range aggregation -> Stats aggregation
|
||||
/// E.g. the Term Agg creates 3 buckets ["INFO", "ERROR", "WARN"], each of these has a Range
|
||||
/// aggregation with 4 buckets. The Range aggregation will create buckets with ids:
|
||||
/// - INFO: 0,1,2,3
|
||||
/// - ERROR: 4,5,6,7
|
||||
/// - WARN: 8,9,10,11
|
||||
///
|
||||
/// This allows the Stats aggregation to have unique bucket ids to refer to.
|
||||
bucket_id_provider: BucketIdProvider,
|
||||
limits: AggregationLimitsGuard,
|
||||
}
|
||||
|
||||
impl<C: SubAggCache> Debug for SegmentRangeCollector<C> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("SegmentRangeCollector")
|
||||
.field("parent_buckets_len", &self.parent_buckets.len())
|
||||
.field("column_type", &self.column_type)
|
||||
.field("accessor_idx", &self.accessor_idx)
|
||||
.field("has_sub_agg", &self.sub_agg.is_some())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// TODO: Bad naming, there's also SegmentRangeAndBucketEntry
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct SegmentRangeBucketEntry {
|
||||
pub key: Key,
|
||||
pub doc_count: u64,
|
||||
pub sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
|
||||
// pub sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
|
||||
pub bucket_id: BucketId,
|
||||
/// The from range of the bucket. Equals `f64::MIN` when `None`.
|
||||
pub from: Option<f64>,
|
||||
/// The to range of the bucket. Equals `f64::MAX` when `None`. Open interval, `to` is not
|
||||
@@ -184,48 +216,50 @@ impl Debug for SegmentRangeBucketEntry {
|
||||
impl SegmentRangeBucketEntry {
|
||||
pub(crate) fn into_intermediate_bucket_entry(
|
||||
self,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<IntermediateRangeBucketEntry> {
|
||||
let mut sub_aggregation_res = IntermediateAggregationResults::default();
|
||||
if let Some(sub_aggregation) = self.sub_aggregation {
|
||||
sub_aggregation
|
||||
.add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?
|
||||
} else {
|
||||
Default::default()
|
||||
};
|
||||
let sub_aggregation = IntermediateAggregationResults::default();
|
||||
|
||||
Ok(IntermediateRangeBucketEntry {
|
||||
key: self.key.into(),
|
||||
doc_count: self.doc_count,
|
||||
sub_aggregation: sub_aggregation_res,
|
||||
sub_aggregation_res: sub_aggregation,
|
||||
from: self.from,
|
||||
to: self.to,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl SegmentAggregationCollector for SegmentRangeCollector {
|
||||
impl<C: SubAggCache> SegmentAggregationCollector for SegmentRangeCollector<C> {
|
||||
fn add_intermediate_aggregation_result(
|
||||
self: Box<Self>,
|
||||
&mut self,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
parent_bucket_id: BucketId,
|
||||
) -> crate::Result<()> {
|
||||
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
|
||||
let field_type = self.column_type;
|
||||
let name = agg_data
|
||||
.get_range_req_data(self.accessor_idx)
|
||||
.name
|
||||
.to_string();
|
||||
|
||||
let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = self
|
||||
.buckets
|
||||
let buckets = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);
|
||||
|
||||
let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = buckets
|
||||
.into_iter()
|
||||
.map(move |range_bucket| {
|
||||
Ok((
|
||||
range_to_string(&range_bucket.range, &field_type)?,
|
||||
range_bucket
|
||||
.bucket
|
||||
.into_intermediate_bucket_entry(agg_data)?,
|
||||
))
|
||||
.map(|range_bucket| {
|
||||
let bucket_id = range_bucket.bucket.bucket_id;
|
||||
let mut agg = range_bucket.bucket.into_intermediate_bucket_entry()?;
|
||||
if let Some(sub_aggregation) = &mut self.sub_agg {
|
||||
sub_aggregation
|
||||
.get_sub_agg_collector()
|
||||
.add_intermediate_aggregation_result(
|
||||
agg_data,
|
||||
&mut agg.sub_aggregation_res,
|
||||
bucket_id,
|
||||
)?;
|
||||
}
|
||||
Ok((range_to_string(&range_bucket.range, &field_type)?, agg))
|
||||
})
|
||||
.collect::<crate::Result<_>>()?;
|
||||
|
||||
@@ -242,73 +276,114 @@ impl SegmentAggregationCollector for SegmentRangeCollector {
|
||||
#[inline]
|
||||
fn collect(
|
||||
&mut self,
|
||||
doc: crate::DocId,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
self.collect_block(&[doc], agg_data)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
parent_bucket_id: BucketId,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
// Take request data to avoid borrow conflicts during sub-aggregation
|
||||
let mut req = agg_data.take_range_req_data(self.accessor_idx);
|
||||
let req = agg_data.take_range_req_data(self.accessor_idx);
|
||||
|
||||
req.column_block_accessor.fetch_block(docs, &req.accessor);
|
||||
agg_data
|
||||
.column_block_accessor
|
||||
.fetch_block(docs, &req.accessor);
|
||||
|
||||
for (doc, val) in req
|
||||
let buckets = &mut self.parent_buckets[parent_bucket_id as usize];
|
||||
|
||||
for (doc, val) in agg_data
|
||||
.column_block_accessor
|
||||
.iter_docid_vals(docs, &req.accessor)
|
||||
{
|
||||
let bucket_pos = self.get_bucket_pos(val);
|
||||
let bucket = &mut self.buckets[bucket_pos];
|
||||
let bucket_pos = get_bucket_pos(val, buckets);
|
||||
let bucket = &mut buckets[bucket_pos];
|
||||
bucket.bucket.doc_count += 1;
|
||||
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() {
|
||||
sub_agg.collect(doc, agg_data)?;
|
||||
if let Some(sub_agg) = self.sub_agg.as_mut() {
|
||||
sub_agg.push(bucket.bucket.bucket_id, doc);
|
||||
}
|
||||
}
|
||||
|
||||
agg_data.put_back_range_req_data(self.accessor_idx, req);
|
||||
if let Some(sub_agg) = self.sub_agg.as_mut() {
|
||||
sub_agg.check_flush_local(agg_data)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
|
||||
for bucket in self.buckets.iter_mut() {
|
||||
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() {
|
||||
sub_agg.flush(agg_data)?;
|
||||
}
|
||||
if let Some(sub_agg) = self.sub_agg.as_mut() {
|
||||
sub_agg.flush(agg_data)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn prepare_max_bucket(
|
||||
&mut self,
|
||||
max_bucket: BucketId,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
while self.parent_buckets.len() <= max_bucket as usize {
|
||||
let new_buckets = self.create_new_buckets(agg_data)?;
|
||||
self.parent_buckets.push(new_buckets);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
/// Build a concrete `SegmentRangeCollector` with either a Vec- or HashMap-backed
|
||||
/// bucket storage, depending on the column type and aggregation level.
|
||||
pub(crate) fn build_segment_range_collector(
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
node: &AggRefNode,
|
||||
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
|
||||
let accessor_idx = node.idx_in_req_data;
|
||||
let req_data = agg_data.get_range_req_data(node.idx_in_req_data);
|
||||
let field_type = req_data.field_type;
|
||||
|
||||
// TODO: A better metric instead of is_top_level would be the number of buckets expected.
|
||||
// E.g. If range agg is not top level, but the parent is a bucket agg with less than 10 buckets,
|
||||
// we can are still in low cardinality territory.
|
||||
let is_low_card = req_data.is_top_level && req_data.req.ranges.len() <= 64;
|
||||
|
||||
let sub_agg = if !node.children.is_empty() {
|
||||
Some(build_segment_agg_collectors(agg_data, &node.children)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if is_low_card {
|
||||
Ok(Box::new(SegmentRangeCollector::<LowCardSubAggCache> {
|
||||
sub_agg: sub_agg.map(LowCardCachedSubAggs::new),
|
||||
column_type: field_type,
|
||||
accessor_idx,
|
||||
parent_buckets: Vec::new(),
|
||||
bucket_id_provider: BucketIdProvider::default(),
|
||||
limits: agg_data.context.limits.clone(),
|
||||
}))
|
||||
} else {
|
||||
Ok(Box::new(SegmentRangeCollector::<HighCardSubAggCache> {
|
||||
sub_agg: sub_agg.map(CachedSubAggs::new),
|
||||
column_type: field_type,
|
||||
accessor_idx,
|
||||
parent_buckets: Vec::new(),
|
||||
bucket_id_provider: BucketIdProvider::default(),
|
||||
limits: agg_data.context.limits.clone(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
impl SegmentRangeCollector {
|
||||
pub(crate) fn from_req_and_validate(
|
||||
req_data: &mut AggregationsSegmentCtx,
|
||||
node: &AggRefNode,
|
||||
) -> crate::Result<Self> {
|
||||
let accessor_idx = node.idx_in_req_data;
|
||||
let (field_type, ranges) = {
|
||||
let req_view = req_data.get_range_req_data(node.idx_in_req_data);
|
||||
(req_view.field_type, req_view.req.ranges.clone())
|
||||
};
|
||||
|
||||
impl<C: SubAggCache> SegmentRangeCollector<C> {
|
||||
pub(crate) fn create_new_buckets(
|
||||
&mut self,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<Vec<SegmentRangeAndBucketEntry>> {
|
||||
let field_type = self.column_type;
|
||||
let req_data = agg_data.get_range_req_data(self.accessor_idx);
|
||||
// The range input on the request is f64.
|
||||
// We need to convert to u64 ranges, because we read the values as u64.
|
||||
// The mapping from the conversion is monotonic so ordering is preserved.
|
||||
let sub_agg_prototype = if !node.children.is_empty() {
|
||||
Some(build_segment_agg_collectors(req_data, &node.children)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let buckets: Vec<_> = extend_validate_ranges(&ranges, &field_type)?
|
||||
let buckets: Vec<_> = extend_validate_ranges(&req_data.req.ranges, &field_type)?
|
||||
.iter()
|
||||
.map(|range| {
|
||||
let bucket_id = self.bucket_id_provider.next_bucket_id();
|
||||
let key = range
|
||||
.key
|
||||
.clone()
|
||||
@@ -317,20 +392,20 @@ impl SegmentRangeCollector {
|
||||
let to = if range.range.end == u64::MAX {
|
||||
None
|
||||
} else {
|
||||
Some(f64_from_fastfield_u64(range.range.end, &field_type))
|
||||
Some(f64_from_fastfield_u64(range.range.end, field_type))
|
||||
};
|
||||
let from = if range.range.start == u64::MIN {
|
||||
None
|
||||
} else {
|
||||
Some(f64_from_fastfield_u64(range.range.start, &field_type))
|
||||
Some(f64_from_fastfield_u64(range.range.start, field_type))
|
||||
};
|
||||
let sub_aggregation = sub_agg_prototype.clone();
|
||||
// let sub_aggregation = sub_agg_prototype.clone();
|
||||
|
||||
Ok(SegmentRangeAndBucketEntry {
|
||||
range: range.range.clone(),
|
||||
bucket: SegmentRangeBucketEntry {
|
||||
doc_count: 0,
|
||||
sub_aggregation,
|
||||
bucket_id,
|
||||
key,
|
||||
from,
|
||||
to,
|
||||
@@ -339,27 +414,20 @@ impl SegmentRangeCollector {
|
||||
})
|
||||
.collect::<crate::Result<_>>()?;
|
||||
|
||||
req_data.context.limits.add_memory_consumed(
|
||||
self.limits.add_memory_consumed(
|
||||
buckets.len() as u64 * std::mem::size_of::<SegmentRangeAndBucketEntry>() as u64,
|
||||
)?;
|
||||
|
||||
Ok(SegmentRangeCollector {
|
||||
buckets,
|
||||
column_type: field_type,
|
||||
accessor_idx,
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn get_bucket_pos(&self, val: u64) -> usize {
|
||||
let pos = self
|
||||
.buckets
|
||||
.binary_search_by_key(&val, |probe| probe.range.start)
|
||||
.unwrap_or_else(|pos| pos - 1);
|
||||
debug_assert!(self.buckets[pos].range.contains(&val));
|
||||
pos
|
||||
Ok(buckets)
|
||||
}
|
||||
}
|
||||
#[inline]
|
||||
fn get_bucket_pos(val: u64, buckets: &[SegmentRangeAndBucketEntry]) -> usize {
|
||||
let pos = buckets
|
||||
.binary_search_by_key(&val, |probe| probe.range.start)
|
||||
.unwrap_or_else(|pos| pos - 1);
|
||||
debug_assert!(buckets[pos].range.contains(&val));
|
||||
pos
|
||||
}
|
||||
|
||||
/// Converts the user provided f64 range value to fast field value space.
|
||||
///
|
||||
@@ -456,7 +524,7 @@ pub(crate) fn range_to_string(
|
||||
let val = i64::from_u64(val);
|
||||
format_date(val)
|
||||
} else {
|
||||
Ok(f64_from_fastfield_u64(val, field_type).to_string())
|
||||
Ok(f64_from_fastfield_u64(val, *field_type).to_string())
|
||||
}
|
||||
};
|
||||
|
||||
@@ -486,7 +554,7 @@ mod tests {
|
||||
pub fn get_collector_from_ranges(
|
||||
ranges: Vec<RangeAggregationRange>,
|
||||
field_type: ColumnType,
|
||||
) -> SegmentRangeCollector {
|
||||
) -> SegmentRangeCollector<HighCardSubAggCache> {
|
||||
let req = RangeAggregation {
|
||||
field: "dummy".to_string(),
|
||||
ranges,
|
||||
@@ -506,30 +574,33 @@ mod tests {
|
||||
let to = if range.range.end == u64::MAX {
|
||||
None
|
||||
} else {
|
||||
Some(f64_from_fastfield_u64(range.range.end, &field_type))
|
||||
Some(f64_from_fastfield_u64(range.range.end, field_type))
|
||||
};
|
||||
let from = if range.range.start == u64::MIN {
|
||||
None
|
||||
} else {
|
||||
Some(f64_from_fastfield_u64(range.range.start, &field_type))
|
||||
Some(f64_from_fastfield_u64(range.range.start, field_type))
|
||||
};
|
||||
SegmentRangeAndBucketEntry {
|
||||
range: range.range.clone(),
|
||||
bucket: SegmentRangeBucketEntry {
|
||||
doc_count: 0,
|
||||
sub_aggregation: None,
|
||||
key,
|
||||
from,
|
||||
to,
|
||||
bucket_id: 0,
|
||||
},
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
SegmentRangeCollector {
|
||||
buckets,
|
||||
parent_buckets: vec![buckets],
|
||||
column_type: field_type,
|
||||
accessor_idx: 0,
|
||||
sub_agg: None,
|
||||
bucket_id_provider: Default::default(),
|
||||
limits: AggregationLimitsGuard::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -776,7 +847,7 @@ mod tests {
|
||||
let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()];
|
||||
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
|
||||
|
||||
let buckets = collector.buckets;
|
||||
let buckets = collector.parent_buckets[0].clone();
|
||||
assert_eq!(buckets[0].range.start, u64::MIN);
|
||||
assert_eq!(buckets[0].range.end, 10f64.to_u64());
|
||||
assert_eq!(buckets[1].range.start, 10f64.to_u64());
|
||||
@@ -799,7 +870,7 @@ mod tests {
|
||||
];
|
||||
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
|
||||
|
||||
let buckets = collector.buckets;
|
||||
let buckets = collector.parent_buckets[0].clone();
|
||||
assert_eq!(buckets[0].range.start, u64::MIN);
|
||||
assert_eq!(buckets[0].range.end, 10f64.to_u64());
|
||||
assert_eq!(buckets[1].range.start, 10f64.to_u64());
|
||||
@@ -814,7 +885,7 @@ mod tests {
|
||||
let buckets = vec![(-10f64..-1f64).into()];
|
||||
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
|
||||
|
||||
let buckets = collector.buckets;
|
||||
let buckets = collector.parent_buckets[0].clone();
|
||||
assert_eq!(&buckets[0].bucket.key.to_string(), "*--10");
|
||||
assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "-1-*");
|
||||
}
|
||||
@@ -823,7 +894,7 @@ mod tests {
|
||||
let buckets = vec![(0f64..10f64).into()];
|
||||
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
|
||||
|
||||
let buckets = collector.buckets;
|
||||
let buckets = collector.parent_buckets[0].clone();
|
||||
assert_eq!(&buckets[0].bucket.key.to_string(), "*-0");
|
||||
assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "10-*");
|
||||
}
|
||||
@@ -832,7 +903,7 @@ mod tests {
|
||||
fn range_binary_search_test_u64() {
|
||||
let check_ranges = |ranges: Vec<RangeAggregationRange>| {
|
||||
let collector = get_collector_from_ranges(ranges, ColumnType::U64);
|
||||
let search = |val: u64| collector.get_bucket_pos(val);
|
||||
let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);
|
||||
|
||||
assert_eq!(search(u64::MIN), 0);
|
||||
assert_eq!(search(9), 0);
|
||||
@@ -878,7 +949,7 @@ mod tests {
|
||||
let ranges = vec![(10.0..100.0).into()];
|
||||
|
||||
let collector = get_collector_from_ranges(ranges, ColumnType::F64);
|
||||
let search = |val: u64| collector.get_bucket_pos(val);
|
||||
let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);
|
||||
|
||||
assert_eq!(search(u64::MIN), 0);
|
||||
assert_eq!(search(9f64.to_u64()), 0);
|
||||
@@ -890,63 +961,3 @@ mod tests {
|
||||
// the max value
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "unstable"))]
|
||||
mod bench {
|
||||
|
||||
use itertools::Itertools;
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::thread_rng;
|
||||
|
||||
use super::*;
|
||||
use crate::aggregation::bucket::range::tests::get_collector_from_ranges;
|
||||
|
||||
const TOTAL_DOCS: u64 = 1_000_000u64;
|
||||
const NUM_DOCS: u64 = 50_000u64;
|
||||
|
||||
fn get_collector_with_buckets(num_buckets: u64, num_docs: u64) -> SegmentRangeCollector {
|
||||
let bucket_size = num_docs / num_buckets;
|
||||
let mut buckets: Vec<RangeAggregationRange> = vec![];
|
||||
for i in 0..num_buckets {
|
||||
let bucket_start = (i * bucket_size) as f64;
|
||||
buckets.push((bucket_start..bucket_start + bucket_size as f64).into())
|
||||
}
|
||||
|
||||
get_collector_from_ranges(buckets, ColumnType::U64)
|
||||
}
|
||||
|
||||
fn get_rand_docs(total_docs: u64, num_docs_returned: u64) -> Vec<u64> {
|
||||
let mut rng = thread_rng();
|
||||
|
||||
let all_docs = (0..total_docs - 1).collect_vec();
|
||||
let mut vals = all_docs
|
||||
.as_slice()
|
||||
.choose_multiple(&mut rng, num_docs_returned as usize)
|
||||
.cloned()
|
||||
.collect_vec();
|
||||
vals.sort();
|
||||
vals
|
||||
}
|
||||
|
||||
fn bench_range_binary_search(b: &mut test::Bencher, num_buckets: u64) {
|
||||
let collector = get_collector_with_buckets(num_buckets, TOTAL_DOCS);
|
||||
let vals = get_rand_docs(TOTAL_DOCS, NUM_DOCS);
|
||||
b.iter(|| {
|
||||
let mut bucket_pos = 0;
|
||||
for val in &vals {
|
||||
bucket_pos = collector.get_bucket_pos(*val);
|
||||
}
|
||||
bucket_pos
|
||||
})
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_range_100_buckets(b: &mut test::Bencher) {
|
||||
bench_range_binary_search(b, 100)
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_range_10_buckets(b: &mut test::Bencher) {
|
||||
bench_range_binary_search(b, 10)
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,11 +5,13 @@ use crate::aggregation::agg_data::{
|
||||
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
|
||||
};
|
||||
use crate::aggregation::bucket::term_agg::TermsAggregation;
|
||||
use crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardCachedSubAggs};
|
||||
use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
|
||||
IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult,
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
|
||||
use crate::aggregation::BucketId;
|
||||
|
||||
/// Special aggregation to handle missing values for term aggregations.
|
||||
/// This missing aggregation will check multiple columns for existence.
|
||||
@@ -35,41 +37,55 @@ impl MissingTermAggReqData {
|
||||
}
|
||||
}
|
||||
|
||||
/// The specialized missing term aggregation.
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct TermMissingAgg {
|
||||
struct MissingCount {
|
||||
missing_count: u32,
|
||||
bucket_id: BucketId,
|
||||
}
|
||||
|
||||
/// The specialized missing term aggregation.
|
||||
#[derive(Default, Debug)]
|
||||
pub struct TermMissingAgg {
|
||||
accessor_idx: usize,
|
||||
sub_agg: Option<Box<dyn SegmentAggregationCollector>>,
|
||||
sub_agg: Option<HighCardCachedSubAggs>,
|
||||
/// Idx = parent bucket id, Value = missing count for that bucket
|
||||
missing_count_per_bucket: Vec<MissingCount>,
|
||||
bucket_id_provider: BucketIdProvider,
|
||||
}
|
||||
impl TermMissingAgg {
|
||||
pub(crate) fn new(
|
||||
req_data: &mut AggregationsSegmentCtx,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
node: &AggRefNode,
|
||||
) -> crate::Result<Self> {
|
||||
let has_sub_aggregations = !node.children.is_empty();
|
||||
let accessor_idx = node.idx_in_req_data;
|
||||
let sub_agg = if has_sub_aggregations {
|
||||
let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?;
|
||||
let sub_aggregation = build_segment_agg_collectors(agg_data, &node.children)?;
|
||||
Some(sub_aggregation)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let sub_agg = sub_agg.map(CachedSubAggs::new);
|
||||
let bucket_id_provider = BucketIdProvider::default();
|
||||
|
||||
Ok(Self {
|
||||
accessor_idx,
|
||||
sub_agg,
|
||||
..Default::default()
|
||||
missing_count_per_bucket: Vec::new(),
|
||||
bucket_id_provider,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl SegmentAggregationCollector for TermMissingAgg {
|
||||
fn add_intermediate_aggregation_result(
|
||||
self: Box<Self>,
|
||||
&mut self,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
parent_bucket_id: BucketId,
|
||||
) -> crate::Result<()> {
|
||||
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
|
||||
let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
|
||||
let term_agg = &req_data.req;
|
||||
let missing = term_agg
|
||||
@@ -80,13 +96,16 @@ impl SegmentAggregationCollector for TermMissingAgg {
|
||||
let mut entries: FxHashMap<IntermediateKey, IntermediateTermBucketEntry> =
|
||||
Default::default();
|
||||
|
||||
let missing_count = &self.missing_count_per_bucket[parent_bucket_id as usize];
|
||||
let mut missing_entry = IntermediateTermBucketEntry {
|
||||
doc_count: self.missing_count,
|
||||
doc_count: missing_count.missing_count,
|
||||
sub_aggregation: Default::default(),
|
||||
};
|
||||
if let Some(sub_agg) = self.sub_agg {
|
||||
if let Some(sub_agg) = &mut self.sub_agg {
|
||||
let mut res = IntermediateAggregationResults::default();
|
||||
sub_agg.add_intermediate_aggregation_result(agg_data, &mut res)?;
|
||||
sub_agg
|
||||
.get_sub_agg_collector()
|
||||
.add_intermediate_aggregation_result(agg_data, &mut res, missing_count.bucket_id)?;
|
||||
missing_entry.sub_aggregation = res;
|
||||
}
|
||||
entries.insert(missing.into(), missing_entry);
|
||||
@@ -109,30 +128,52 @@ impl SegmentAggregationCollector for TermMissingAgg {
|
||||
|
||||
fn collect(
|
||||
&mut self,
|
||||
doc: crate::DocId,
|
||||
parent_bucket_id: BucketId,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let bucket = &mut self.missing_count_per_bucket[parent_bucket_id as usize];
|
||||
let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
|
||||
let has_value = req_data
|
||||
.accessors
|
||||
.iter()
|
||||
.any(|(acc, _)| acc.index.has_value(doc));
|
||||
if !has_value {
|
||||
self.missing_count += 1;
|
||||
if let Some(sub_agg) = self.sub_agg.as_mut() {
|
||||
sub_agg.collect(doc, agg_data)?;
|
||||
|
||||
for doc in docs {
|
||||
let doc = *doc;
|
||||
let has_value = req_data
|
||||
.accessors
|
||||
.iter()
|
||||
.any(|(acc, _)| acc.index.has_value(doc));
|
||||
if !has_value {
|
||||
bucket.missing_count += 1;
|
||||
|
||||
if let Some(sub_agg) = self.sub_agg.as_mut() {
|
||||
sub_agg.push(bucket.bucket_id, doc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(sub_agg) = self.sub_agg.as_mut() {
|
||||
sub_agg.check_flush_local(agg_data)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn collect_block(
|
||||
fn prepare_max_bucket(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
max_bucket: BucketId,
|
||||
_agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
for doc in docs {
|
||||
self.collect(*doc, agg_data)?;
|
||||
while self.missing_count_per_bucket.len() <= max_bucket as usize {
|
||||
let bucket_id = self.bucket_id_provider.next_bucket_id();
|
||||
self.missing_count_per_bucket.push(MissingCount {
|
||||
missing_count: 0,
|
||||
bucket_id,
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
|
||||
if let Some(sub_agg) = self.sub_agg.as_mut() {
|
||||
sub_agg.flush(agg_data)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
use super::intermediate_agg_result::IntermediateAggregationResults;
|
||||
use super::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::agg_data::AggregationsSegmentCtx;
|
||||
use crate::DocId;
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) const DOC_BLOCK_SIZE: usize = 64;
|
||||
|
||||
#[cfg(not(test))]
|
||||
pub(crate) const DOC_BLOCK_SIZE: usize = 256;
|
||||
|
||||
pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE];
|
||||
|
||||
/// BufAggregationCollector buffers documents before calling collect_block().
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct BufAggregationCollector {
|
||||
pub(crate) collector: Box<dyn SegmentAggregationCollector>,
|
||||
staged_docs: DocBlock,
|
||||
num_staged_docs: usize,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for BufAggregationCollector {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
f.debug_struct("SegmentAggregationResultsCollector")
|
||||
.field("staged_docs", &&self.staged_docs[..self.num_staged_docs])
|
||||
.field("num_staged_docs", &self.num_staged_docs)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl BufAggregationCollector {
|
||||
pub fn new(collector: Box<dyn SegmentAggregationCollector>) -> Self {
|
||||
Self {
|
||||
collector,
|
||||
num_staged_docs: 0,
|
||||
staged_docs: [0; DOC_BLOCK_SIZE],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SegmentAggregationCollector for BufAggregationCollector {
|
||||
#[inline]
|
||||
fn add_intermediate_aggregation_result(
|
||||
self: Box<Self>,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
) -> crate::Result<()> {
|
||||
Box::new(self.collector).add_intermediate_aggregation_result(agg_data, results)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn collect(
|
||||
&mut self,
|
||||
doc: crate::DocId,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
self.staged_docs[self.num_staged_docs] = doc;
|
||||
self.num_staged_docs += 1;
|
||||
if self.num_staged_docs == self.staged_docs.len() {
|
||||
self.collector
|
||||
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?;
|
||||
self.num_staged_docs = 0;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
self.collector.collect_block(docs, agg_data)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
|
||||
self.collector
|
||||
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?;
|
||||
self.num_staged_docs = 0;
|
||||
|
||||
self.collector.flush(agg_data)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
245
src/aggregation/cached_sub_aggs.rs
Normal file
245
src/aggregation/cached_sub_aggs.rs
Normal file
@@ -0,0 +1,245 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use super::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::agg_data::AggregationsSegmentCtx;
|
||||
use crate::aggregation::bucket::MAX_NUM_TERMS_FOR_VEC;
|
||||
use crate::aggregation::BucketId;
|
||||
use crate::DocId;
|
||||
|
||||
/// A cache for sub-aggregations, storing doc ids per bucket id.
|
||||
/// Depending on the cardinality of the parent aggregation, we use different
|
||||
/// storage strategies.
|
||||
///
|
||||
/// ## Low Cardinality
|
||||
/// Cardinality here refers to the number of unique flattened buckets that can be created
|
||||
/// by the parent aggregation.
|
||||
/// Flattened buckets are the result of combining all buckets per collector
|
||||
/// into a single list of buckets, where each bucket is identified by its BucketId.
|
||||
///
|
||||
/// ## Usage
|
||||
/// Since this is caching for sub-aggregations, it is only used by bucket
|
||||
/// aggregations.
|
||||
///
|
||||
/// TODO: consider using a more advanced data structure for high cardinality
|
||||
/// aggregations.
|
||||
/// What this datastructure does in general is to group docs by bucket id.
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct CachedSubAggs<C: SubAggCache> {
|
||||
cache: C,
|
||||
sub_agg_collector: Box<dyn SegmentAggregationCollector>,
|
||||
num_docs: usize,
|
||||
}
|
||||
|
||||
pub type LowCardCachedSubAggs = CachedSubAggs<LowCardSubAggCache>;
|
||||
pub type HighCardCachedSubAggs = CachedSubAggs<HighCardSubAggCache>;
|
||||
|
||||
const FLUSH_THRESHOLD: usize = 2048;
|
||||
|
||||
/// A trait for caching sub-aggregation doc ids per bucket id.
|
||||
/// Different implementations can be used depending on the cardinality
|
||||
/// of the parent aggregation.
|
||||
pub trait SubAggCache: Debug {
|
||||
fn new() -> Self;
|
||||
fn push(&mut self, bucket_id: BucketId, doc_id: DocId);
|
||||
fn flush_local(
|
||||
&mut self,
|
||||
sub_agg: &mut Box<dyn SegmentAggregationCollector>,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
force: bool,
|
||||
) -> crate::Result<()>;
|
||||
}
|
||||
|
||||
impl<Backend: SubAggCache + Debug> CachedSubAggs<Backend> {
|
||||
pub fn new(sub_agg: Box<dyn SegmentAggregationCollector>) -> Self {
|
||||
Self {
|
||||
cache: Backend::new(),
|
||||
sub_agg_collector: sub_agg,
|
||||
num_docs: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_sub_agg_collector(&mut self) -> &mut Box<dyn SegmentAggregationCollector> {
|
||||
&mut self.sub_agg_collector
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
|
||||
self.cache.push(bucket_id, doc_id);
|
||||
self.num_docs += 1;
|
||||
}
|
||||
|
||||
/// Check if we need to flush based on the number of documents cached.
|
||||
/// If so, flushes the cache to the provided aggregation collector.
|
||||
pub fn check_flush_local(
|
||||
&mut self,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
if self.num_docs >= FLUSH_THRESHOLD {
|
||||
self.cache
|
||||
.flush_local(&mut self.sub_agg_collector, agg_data, false)?;
|
||||
self.num_docs = 0;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Note: this _does_ flush the sub aggregations.
|
||||
pub fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
|
||||
if self.num_docs != 0 {
|
||||
self.cache
|
||||
.flush_local(&mut self.sub_agg_collector, agg_data, true)?;
|
||||
self.num_docs = 0;
|
||||
}
|
||||
self.sub_agg_collector.flush(agg_data)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Number of partitions for high cardinality sub-aggregation cache.
|
||||
const NUM_PARTITIONS: usize = 16;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct HighCardSubAggCache {
|
||||
/// This weird partitioning is used to do some cheap grouping on the bucket ids.
|
||||
/// bucket ids are dense, e.g. when we don't detect the cardinality as low cardinality,
|
||||
/// but there are just 16 bucket ids, each bucket id will go to its own partition.
|
||||
///
|
||||
/// We want to keep this cheap, because high cardinality aggregations can have a lot of
|
||||
/// buckets, and there may be nothing to group.
|
||||
partitions: Box<[PartitionEntry; NUM_PARTITIONS]>,
|
||||
}
|
||||
|
||||
impl HighCardSubAggCache {
|
||||
#[inline]
|
||||
fn clear(&mut self) {
|
||||
for partition in self.partitions.iter_mut() {
|
||||
partition.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
struct PartitionEntry {
|
||||
bucket_ids: Vec<BucketId>,
|
||||
docs: Vec<DocId>,
|
||||
}
|
||||
|
||||
impl PartitionEntry {
|
||||
#[inline]
|
||||
fn clear(&mut self) {
|
||||
self.bucket_ids.clear();
|
||||
self.docs.clear();
|
||||
}
|
||||
}
|
||||
|
||||
impl SubAggCache for HighCardSubAggCache {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
partitions: Box::new(core::array::from_fn(|_| PartitionEntry::default())),
|
||||
}
|
||||
}
|
||||
|
||||
fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
|
||||
let idx = bucket_id % NUM_PARTITIONS as u32;
|
||||
let slot = &mut self.partitions[idx as usize];
|
||||
slot.bucket_ids.push(bucket_id);
|
||||
slot.docs.push(doc_id);
|
||||
}
|
||||
|
||||
fn flush_local(
|
||||
&mut self,
|
||||
sub_agg: &mut Box<dyn SegmentAggregationCollector>,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
_force: bool,
|
||||
) -> crate::Result<()> {
|
||||
let mut max_bucket = 0u32;
|
||||
for partition in self.partitions.iter() {
|
||||
if let Some(&local_max) = partition.bucket_ids.iter().max() {
|
||||
max_bucket = max_bucket.max(local_max);
|
||||
}
|
||||
}
|
||||
|
||||
sub_agg.prepare_max_bucket(max_bucket, agg_data)?;
|
||||
|
||||
for slot in self.partitions.iter() {
|
||||
if !slot.bucket_ids.is_empty() {
|
||||
// Reduce dynamic dispatch overhead by collecting a full partition in one call.
|
||||
sub_agg.collect_multiple(&slot.bucket_ids, &slot.docs, agg_data)?;
|
||||
}
|
||||
}
|
||||
|
||||
self.clear();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct LowCardSubAggCache {
|
||||
/// Cache doc ids per bucket for sub-aggregations.
|
||||
///
|
||||
/// The outer Vec is indexed by BucketId.
|
||||
per_bucket_docs: Vec<Vec<DocId>>,
|
||||
}
|
||||
|
||||
impl LowCardSubAggCache {
|
||||
#[inline]
|
||||
fn clear(&mut self) {
|
||||
for v in &mut self.per_bucket_docs {
|
||||
v.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SubAggCache for LowCardSubAggCache {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
per_bucket_docs: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
|
||||
let idx = bucket_id as usize;
|
||||
if self.per_bucket_docs.len() <= idx {
|
||||
self.per_bucket_docs.resize_with(idx + 1, Vec::new);
|
||||
}
|
||||
self.per_bucket_docs[idx].push(doc_id);
|
||||
}
|
||||
|
||||
fn flush_local(
|
||||
&mut self,
|
||||
sub_agg: &mut Box<dyn SegmentAggregationCollector>,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
force: bool,
|
||||
) -> crate::Result<()> {
|
||||
// Pre-aggregated: call collect per bucket.
|
||||
let max_bucket = (self.per_bucket_docs.len() as BucketId).saturating_sub(1);
|
||||
sub_agg.prepare_max_bucket(max_bucket, agg_data)?;
|
||||
// The threshold above which we flush buckets individually.
|
||||
// Note: We need to make sure that we don't lock ourselves into a situation where we hit
|
||||
// the FLUSH_THRESHOLD, but never flush any buckets. (except the final flush)
|
||||
let mut bucket_treshold = FLUSH_THRESHOLD / (self.per_bucket_docs.len().max(1) * 2);
|
||||
const _: () = {
|
||||
// MAX_NUM_TERMS_FOR_VEC threshold is used for term aggregations
|
||||
// Note: There may be other flexible values, for other aggregations, but we can use the
|
||||
// const value here as a upper bound. (better than nothing)
|
||||
let bucket_treshold_limit = FLUSH_THRESHOLD / (MAX_NUM_TERMS_FOR_VEC as usize * 2);
|
||||
assert!(
|
||||
bucket_treshold_limit > 0,
|
||||
"Bucket threshold must be greater than 0"
|
||||
);
|
||||
};
|
||||
if force {
|
||||
bucket_treshold = 0;
|
||||
}
|
||||
for (bucket_id, docs) in self
|
||||
.per_bucket_docs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, docs)| docs.len() > bucket_treshold)
|
||||
{
|
||||
sub_agg.collect(bucket_id as BucketId, docs, agg_data)?;
|
||||
}
|
||||
|
||||
self.clear();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,9 @@
|
||||
use super::agg_req::Aggregations;
|
||||
use super::agg_result::AggregationResults;
|
||||
use super::buf_collector::BufAggregationCollector;
|
||||
use super::cached_sub_aggs::LowCardCachedSubAggs;
|
||||
use super::intermediate_agg_result::IntermediateAggregationResults;
|
||||
use super::segment_agg_result::SegmentAggregationCollector;
|
||||
use super::AggContextParams;
|
||||
// group buffering strategy is chosen explicitly by callers; no need to hash-group on the fly.
|
||||
use crate::aggregation::agg_data::{
|
||||
build_aggregations_data_from_req, build_segment_agg_collectors_root, AggregationsSegmentCtx,
|
||||
};
|
||||
@@ -136,7 +136,7 @@ fn merge_fruits(
|
||||
/// `AggregationSegmentCollector` does the aggregation collection on a segment.
|
||||
pub struct AggregationSegmentCollector {
|
||||
aggs_with_accessor: AggregationsSegmentCtx,
|
||||
agg_collector: BufAggregationCollector,
|
||||
agg_collector: LowCardCachedSubAggs,
|
||||
error: Option<TantivyError>,
|
||||
}
|
||||
|
||||
@@ -151,8 +151,11 @@ impl AggregationSegmentCollector {
|
||||
) -> crate::Result<Self> {
|
||||
let mut agg_data =
|
||||
build_aggregations_data_from_req(agg, reader, segment_ordinal, context.clone())?;
|
||||
let result =
|
||||
BufAggregationCollector::new(build_segment_agg_collectors_root(&mut agg_data)?);
|
||||
let mut result =
|
||||
LowCardCachedSubAggs::new(build_segment_agg_collectors_root(&mut agg_data)?);
|
||||
result
|
||||
.get_sub_agg_collector()
|
||||
.prepare_max_bucket(0, &agg_data)?; // prepare for bucket zero
|
||||
|
||||
Ok(AggregationSegmentCollector {
|
||||
aggs_with_accessor: agg_data,
|
||||
@@ -170,26 +173,31 @@ impl SegmentCollector for AggregationSegmentCollector {
|
||||
if self.error.is_some() {
|
||||
return;
|
||||
}
|
||||
if let Err(err) = self
|
||||
self.agg_collector.push(0, doc);
|
||||
match self
|
||||
.agg_collector
|
||||
.collect(doc, &mut self.aggs_with_accessor)
|
||||
.check_flush_local(&mut self.aggs_with_accessor)
|
||||
{
|
||||
self.error = Some(err);
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
self.error = Some(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The query pushes the documents to the collector via this method.
|
||||
///
|
||||
/// Only valid for Collectors that ignore docs
|
||||
fn collect_block(&mut self, docs: &[DocId]) {
|
||||
if self.error.is_some() {
|
||||
return;
|
||||
}
|
||||
if let Err(err) = self
|
||||
.agg_collector
|
||||
.collect_block(docs, &mut self.aggs_with_accessor)
|
||||
{
|
||||
self.error = Some(err);
|
||||
|
||||
match self.agg_collector.get_sub_agg_collector().collect(
|
||||
0,
|
||||
docs,
|
||||
&mut self.aggs_with_accessor,
|
||||
) {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
self.error = Some(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -200,10 +208,13 @@ impl SegmentCollector for AggregationSegmentCollector {
|
||||
self.agg_collector.flush(&mut self.aggs_with_accessor)?;
|
||||
|
||||
let mut sub_aggregation_res = IntermediateAggregationResults::default();
|
||||
Box::new(self.agg_collector).add_intermediate_aggregation_result(
|
||||
&self.aggs_with_accessor,
|
||||
&mut sub_aggregation_res,
|
||||
)?;
|
||||
self.agg_collector
|
||||
.get_sub_agg_collector()
|
||||
.add_intermediate_aggregation_result(
|
||||
&self.aggs_with_accessor,
|
||||
&mut sub_aggregation_res,
|
||||
0,
|
||||
)?;
|
||||
|
||||
Ok(sub_aggregation_res)
|
||||
}
|
||||
|
||||
@@ -792,7 +792,7 @@ pub struct IntermediateRangeBucketEntry {
|
||||
/// The number of documents in the bucket.
|
||||
pub doc_count: u64,
|
||||
/// The sub_aggregation in this bucket.
|
||||
pub sub_aggregation: IntermediateAggregationResults,
|
||||
pub sub_aggregation_res: IntermediateAggregationResults,
|
||||
/// The from range of the bucket. Equals `f64::MIN` when `None`.
|
||||
pub from: Option<f64>,
|
||||
/// The to range of the bucket. Equals `f64::MAX` when `None`.
|
||||
@@ -811,7 +811,7 @@ impl IntermediateRangeBucketEntry {
|
||||
key: self.key.into(),
|
||||
doc_count: self.doc_count,
|
||||
sub_aggregation: self
|
||||
.sub_aggregation
|
||||
.sub_aggregation_res
|
||||
.into_final_result_internal(req, limits)?,
|
||||
to: self.to,
|
||||
from: self.from,
|
||||
@@ -857,7 +857,8 @@ impl MergeFruits for IntermediateTermBucketEntry {
|
||||
impl MergeFruits for IntermediateRangeBucketEntry {
|
||||
fn merge_fruits(&mut self, other: IntermediateRangeBucketEntry) -> crate::Result<()> {
|
||||
self.doc_count += other.doc_count;
|
||||
self.sub_aggregation.merge_fruits(other.sub_aggregation)?;
|
||||
self.sub_aggregation_res
|
||||
.merge_fruits(other.sub_aggregation_res)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -887,7 +888,7 @@ mod tests {
|
||||
IntermediateRangeBucketEntry {
|
||||
key: IntermediateKey::Str(key.to_string()),
|
||||
doc_count: *doc_count,
|
||||
sub_aggregation: Default::default(),
|
||||
sub_aggregation_res: Default::default(),
|
||||
from: None,
|
||||
to: None,
|
||||
},
|
||||
@@ -920,7 +921,7 @@ mod tests {
|
||||
doc_count: *doc_count,
|
||||
from: None,
|
||||
to: None,
|
||||
sub_aggregation: get_sub_test_tree(&[(
|
||||
sub_aggregation_res: get_sub_test_tree(&[(
|
||||
sub_aggregation_key.to_string(),
|
||||
*sub_aggregation_count,
|
||||
)]),
|
||||
|
||||
@@ -52,10 +52,8 @@ pub struct IntermediateAverage {
|
||||
|
||||
impl IntermediateAverage {
|
||||
/// Creates a new [`IntermediateAverage`] instance from a [`SegmentStatsCollector`].
|
||||
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
|
||||
Self {
|
||||
stats: collector.stats,
|
||||
}
|
||||
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
|
||||
Self { stats }
|
||||
}
|
||||
/// Merges the other intermediate result into self.
|
||||
pub fn merge_fruits(&mut self, other: IntermediateAverage) {
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{BuildHasher, Hasher};
|
||||
|
||||
use columnar::column_values::CompactSpaceU64Accessor;
|
||||
use columnar::{Column, ColumnBlockAccessor, ColumnType, Dictionary, StrColumn};
|
||||
use columnar::{Column, ColumnType, Dictionary, StrColumn};
|
||||
use common::f64_to_u64;
|
||||
use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};
|
||||
use rustc_hash::FxHashSet;
|
||||
@@ -106,8 +106,6 @@ pub struct CardinalityAggReqData {
|
||||
pub str_dict_column: Option<StrColumn>,
|
||||
/// The missing value normalized to the internal u64 representation of the field type.
|
||||
pub missing_value_for_accessor: Option<u64>,
|
||||
/// The column block accessor to access the fast field values.
|
||||
pub(crate) column_block_accessor: ColumnBlockAccessor<u64>,
|
||||
/// The name of the aggregation.
|
||||
pub name: String,
|
||||
/// The aggregation request.
|
||||
@@ -135,45 +133,34 @@ impl CardinalityAggregationReq {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct SegmentCardinalityCollector {
|
||||
cardinality: CardinalityCollector,
|
||||
entries: FxHashSet<u64>,
|
||||
buckets: Vec<SegmentCardinalityCollectorBucket>,
|
||||
accessor_idx: usize,
|
||||
/// The column accessor to access the fast field values.
|
||||
accessor: Column<u64>,
|
||||
/// The column_type of the field.
|
||||
column_type: ColumnType,
|
||||
/// The missing value normalized to the internal u64 representation of the field type.
|
||||
missing_value_for_accessor: Option<u64>,
|
||||
}
|
||||
|
||||
impl SegmentCardinalityCollector {
|
||||
pub fn from_req(column_type: ColumnType, accessor_idx: usize) -> Self {
|
||||
#[derive(Clone, Debug, PartialEq, Default)]
|
||||
pub(crate) struct SegmentCardinalityCollectorBucket {
|
||||
cardinality: CardinalityCollector,
|
||||
entries: FxHashSet<u64>,
|
||||
}
|
||||
impl SegmentCardinalityCollectorBucket {
|
||||
pub fn new(column_type: ColumnType) -> Self {
|
||||
Self {
|
||||
cardinality: CardinalityCollector::new(column_type as u8),
|
||||
entries: Default::default(),
|
||||
accessor_idx,
|
||||
entries: FxHashSet::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn fetch_block_with_field(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut CardinalityAggReqData,
|
||||
) {
|
||||
if let Some(missing) = agg_data.missing_value_for_accessor {
|
||||
agg_data.column_block_accessor.fetch_block_with_missing(
|
||||
docs,
|
||||
&agg_data.accessor,
|
||||
missing,
|
||||
);
|
||||
} else {
|
||||
agg_data
|
||||
.column_block_accessor
|
||||
.fetch_block(docs, &agg_data.accessor);
|
||||
}
|
||||
}
|
||||
|
||||
fn into_intermediate_metric_result(
|
||||
mut self,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
req_data: &CardinalityAggReqData,
|
||||
) -> crate::Result<IntermediateMetricResult> {
|
||||
let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
|
||||
if req_data.column_type == ColumnType::Str {
|
||||
let fallback_dict = Dictionary::empty();
|
||||
let dict = req_data
|
||||
@@ -194,6 +181,7 @@ impl SegmentCardinalityCollector {
|
||||
term_ids.push(term_ord as u32);
|
||||
}
|
||||
}
|
||||
|
||||
term_ids.sort_unstable();
|
||||
dict.sorted_ords_to_term_cb(term_ids.iter().map(|term| *term as u64), |term| {
|
||||
self.cardinality.sketch.insert_any(&term);
|
||||
@@ -227,16 +215,49 @@ impl SegmentCardinalityCollector {
|
||||
}
|
||||
}
|
||||
|
||||
impl SegmentCardinalityCollector {
|
||||
pub fn from_req(
|
||||
column_type: ColumnType,
|
||||
accessor_idx: usize,
|
||||
accessor: Column<u64>,
|
||||
missing_value_for_accessor: Option<u64>,
|
||||
) -> Self {
|
||||
Self {
|
||||
buckets: vec![SegmentCardinalityCollectorBucket::new(column_type); 1],
|
||||
column_type,
|
||||
accessor_idx,
|
||||
accessor,
|
||||
missing_value_for_accessor,
|
||||
}
|
||||
}
|
||||
|
||||
fn fetch_block_with_field(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) {
|
||||
agg_data.column_block_accessor.fetch_block_with_missing(
|
||||
docs,
|
||||
&self.accessor,
|
||||
self.missing_value_for_accessor,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl SegmentAggregationCollector for SegmentCardinalityCollector {
|
||||
fn add_intermediate_aggregation_result(
|
||||
self: Box<Self>,
|
||||
&mut self,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
parent_bucket_id: BucketId,
|
||||
) -> crate::Result<()> {
|
||||
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
|
||||
let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
|
||||
let name = req_data.name.to_string();
|
||||
// take the bucket in buckets and replace it with a new empty one
|
||||
let bucket = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
|
||||
|
||||
let intermediate_result = self.into_intermediate_metric_result(agg_data)?;
|
||||
let intermediate_result = bucket.into_intermediate_metric_result(req_data)?;
|
||||
results.push(
|
||||
name,
|
||||
IntermediateAggregationResult::Metric(intermediate_result),
|
||||
@@ -247,27 +268,20 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
|
||||
|
||||
fn collect(
|
||||
&mut self,
|
||||
doc: crate::DocId,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
self.collect_block(&[doc], agg_data)
|
||||
}
|
||||
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
parent_bucket_id: BucketId,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let req_data = agg_data.get_cardinality_req_data_mut(self.accessor_idx);
|
||||
self.fetch_block_with_field(docs, req_data);
|
||||
self.fetch_block_with_field(docs, agg_data);
|
||||
let bucket = &mut self.buckets[parent_bucket_id as usize];
|
||||
|
||||
let col_block_accessor = &req_data.column_block_accessor;
|
||||
if req_data.column_type == ColumnType::Str {
|
||||
let col_block_accessor = &agg_data.column_block_accessor;
|
||||
if self.column_type == ColumnType::Str {
|
||||
for term_ord in col_block_accessor.iter_vals() {
|
||||
self.entries.insert(term_ord);
|
||||
bucket.entries.insert(term_ord);
|
||||
}
|
||||
} else if req_data.column_type == ColumnType::IpAddr {
|
||||
let compact_space_accessor = req_data
|
||||
} else if self.column_type == ColumnType::IpAddr {
|
||||
let compact_space_accessor = self
|
||||
.accessor
|
||||
.values
|
||||
.clone()
|
||||
@@ -282,16 +296,29 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
|
||||
})?;
|
||||
for val in col_block_accessor.iter_vals() {
|
||||
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
|
||||
self.cardinality.sketch.insert_any(&val);
|
||||
bucket.cardinality.sketch.insert_any(&val);
|
||||
}
|
||||
} else {
|
||||
for val in col_block_accessor.iter_vals() {
|
||||
self.cardinality.sketch.insert_any(&val);
|
||||
bucket.cardinality.sketch.insert_any(&val);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn prepare_max_bucket(
|
||||
&mut self,
|
||||
max_bucket: BucketId,
|
||||
_agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
if max_bucket as usize >= self.buckets.len() {
|
||||
self.buckets.resize_with(max_bucket as usize + 1, || {
|
||||
SegmentCardinalityCollectorBucket::new(self.column_type)
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
|
||||
@@ -52,10 +52,8 @@ pub struct IntermediateCount {
|
||||
|
||||
impl IntermediateCount {
|
||||
/// Creates a new [`IntermediateCount`] instance from a [`SegmentStatsCollector`].
|
||||
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
|
||||
Self {
|
||||
stats: collector.stats,
|
||||
}
|
||||
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
|
||||
Self { stats }
|
||||
}
|
||||
/// Merges the other intermediate result into self.
|
||||
pub fn merge_fruits(&mut self, other: IntermediateCount) {
|
||||
|
||||
@@ -8,10 +8,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
|
||||
use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
|
||||
};
|
||||
use crate::aggregation::metric::MetricAggReqData;
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::*;
|
||||
use crate::{DocId, TantivyError};
|
||||
use crate::TantivyError;
|
||||
|
||||
/// A multi-value metric aggregation that computes a collection of extended statistics
|
||||
/// on numeric values that are extracted
|
||||
@@ -318,51 +317,28 @@ impl IntermediateExtendedStats {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct SegmentExtendedStatsCollector {
|
||||
name: String,
|
||||
missing: Option<u64>,
|
||||
field_type: ColumnType,
|
||||
pub(crate) extended_stats: IntermediateExtendedStats,
|
||||
pub(crate) accessor_idx: usize,
|
||||
val_cache: Vec<u64>,
|
||||
accessor: columnar::Column<u64>,
|
||||
buckets: Vec<IntermediateExtendedStats>,
|
||||
sigma: Option<f64>,
|
||||
}
|
||||
|
||||
impl SegmentExtendedStatsCollector {
|
||||
pub fn from_req(
|
||||
field_type: ColumnType,
|
||||
sigma: Option<f64>,
|
||||
accessor_idx: usize,
|
||||
missing: Option<f64>,
|
||||
) -> Self {
|
||||
let missing = missing.and_then(|val| f64_to_fastfield_u64(val, &field_type));
|
||||
pub fn from_req(req: &MetricAggReqData, sigma: Option<f64>) -> Self {
|
||||
let missing = req
|
||||
.missing
|
||||
.and_then(|val| f64_to_fastfield_u64(val, &req.field_type));
|
||||
Self {
|
||||
field_type,
|
||||
extended_stats: IntermediateExtendedStats::with_sigma(sigma),
|
||||
accessor_idx,
|
||||
name: req.name.clone(),
|
||||
field_type: req.field_type,
|
||||
accessor: req.accessor.clone(),
|
||||
missing,
|
||||
val_cache: Default::default(),
|
||||
}
|
||||
}
|
||||
#[inline]
|
||||
pub(crate) fn collect_block_with_field(
|
||||
&mut self,
|
||||
docs: &[DocId],
|
||||
req_data: &mut MetricAggReqData,
|
||||
) {
|
||||
if let Some(missing) = self.missing.as_ref() {
|
||||
req_data.column_block_accessor.fetch_block_with_missing(
|
||||
docs,
|
||||
&req_data.accessor,
|
||||
*missing,
|
||||
);
|
||||
} else {
|
||||
req_data
|
||||
.column_block_accessor
|
||||
.fetch_block(docs, &req_data.accessor);
|
||||
}
|
||||
for val in req_data.column_block_accessor.iter_vals() {
|
||||
let val1 = f64_from_fastfield_u64(val, &self.field_type);
|
||||
self.extended_stats.collect(val1);
|
||||
buckets: vec![IntermediateExtendedStats::with_sigma(sigma); 16],
|
||||
sigma,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -370,15 +346,18 @@ impl SegmentExtendedStatsCollector {
|
||||
impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
|
||||
#[inline]
|
||||
fn add_intermediate_aggregation_result(
|
||||
self: Box<Self>,
|
||||
&mut self,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
parent_bucket_id: BucketId,
|
||||
) -> crate::Result<()> {
|
||||
let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone();
|
||||
let name = self.name.clone();
|
||||
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
|
||||
let extended_stats = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
|
||||
results.push(
|
||||
name,
|
||||
IntermediateAggregationResult::Metric(IntermediateMetricResult::ExtendedStats(
|
||||
self.extended_stats,
|
||||
extended_stats,
|
||||
)),
|
||||
)?;
|
||||
|
||||
@@ -388,39 +367,36 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
|
||||
#[inline]
|
||||
fn collect(
|
||||
&mut self,
|
||||
doc: crate::DocId,
|
||||
parent_bucket_id: BucketId,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
|
||||
if let Some(missing) = self.missing {
|
||||
let mut has_val = false;
|
||||
for val in req_data.accessor.values_for_doc(doc) {
|
||||
let val1 = f64_from_fastfield_u64(val, &self.field_type);
|
||||
self.extended_stats.collect(val1);
|
||||
has_val = true;
|
||||
}
|
||||
if !has_val {
|
||||
self.extended_stats
|
||||
.collect(f64_from_fastfield_u64(missing, &self.field_type));
|
||||
}
|
||||
} else {
|
||||
for val in req_data.accessor.values_for_doc(doc) {
|
||||
let val1 = f64_from_fastfield_u64(val, &self.field_type);
|
||||
self.extended_stats.collect(val1);
|
||||
}
|
||||
let mut extended_stats = self.buckets[parent_bucket_id as usize].clone();
|
||||
|
||||
agg_data
|
||||
.column_block_accessor
|
||||
.fetch_block_with_missing(docs, &self.accessor, self.missing);
|
||||
for val in agg_data.column_block_accessor.iter_vals() {
|
||||
let val1 = f64_from_fastfield_u64(val, self.field_type);
|
||||
extended_stats.collect(val1);
|
||||
}
|
||||
|
||||
// store back
|
||||
self.buckets[parent_bucket_id as usize] = extended_stats;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn collect_block(
|
||||
fn prepare_max_bucket(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
max_bucket: BucketId,
|
||||
_agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
|
||||
self.collect_block_with_field(docs, req_data);
|
||||
if self.buckets.len() <= max_bucket as usize {
|
||||
self.buckets.resize_with(max_bucket as usize + 1, || {
|
||||
IntermediateExtendedStats::with_sigma(self.sigma)
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,10 +52,8 @@ pub struct IntermediateMax {
|
||||
|
||||
impl IntermediateMax {
|
||||
/// Creates a new [`IntermediateMax`] instance from a [`SegmentStatsCollector`].
|
||||
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
|
||||
Self {
|
||||
stats: collector.stats,
|
||||
}
|
||||
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
|
||||
Self { stats }
|
||||
}
|
||||
/// Merges the other intermediate result into self.
|
||||
pub fn merge_fruits(&mut self, other: IntermediateMax) {
|
||||
|
||||
@@ -52,10 +52,8 @@ pub struct IntermediateMin {
|
||||
|
||||
impl IntermediateMin {
|
||||
/// Creates a new [`IntermediateMin`] instance from a [`SegmentStatsCollector`].
|
||||
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
|
||||
Self {
|
||||
stats: collector.stats,
|
||||
}
|
||||
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
|
||||
Self { stats }
|
||||
}
|
||||
/// Merges the other intermediate result into self.
|
||||
pub fn merge_fruits(&mut self, other: IntermediateMin) {
|
||||
|
||||
@@ -31,7 +31,7 @@ use std::collections::HashMap;
|
||||
|
||||
pub use average::*;
|
||||
pub use cardinality::*;
|
||||
use columnar::{Column, ColumnBlockAccessor, ColumnType};
|
||||
use columnar::{Column, ColumnType};
|
||||
pub use count::*;
|
||||
pub use extended_stats::*;
|
||||
pub use max::*;
|
||||
@@ -55,8 +55,6 @@ pub struct MetricAggReqData {
|
||||
pub field_type: ColumnType,
|
||||
/// The missing value normalized to the internal u64 representation of the field type.
|
||||
pub missing_u64: Option<u64>,
|
||||
/// The column block accessor to access the fast field values.
|
||||
pub column_block_accessor: ColumnBlockAccessor<u64>,
|
||||
/// The column accessor to access the fast field values.
|
||||
pub accessor: Column<u64>,
|
||||
/// Used when converting to intermediate result
|
||||
|
||||
@@ -7,10 +7,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
|
||||
use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
|
||||
};
|
||||
use crate::aggregation::metric::MetricAggReqData;
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::*;
|
||||
use crate::{DocId, TantivyError};
|
||||
use crate::TantivyError;
|
||||
|
||||
/// # Percentiles
|
||||
///
|
||||
@@ -131,10 +130,16 @@ impl PercentilesAggregationReq {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct SegmentPercentilesCollector {
|
||||
pub(crate) percentiles: PercentilesCollector,
|
||||
pub(crate) buckets: Vec<PercentilesCollector>,
|
||||
pub(crate) accessor_idx: usize,
|
||||
/// The type of the field.
|
||||
pub field_type: ColumnType,
|
||||
/// The missing value normalized to the internal u64 representation of the field type.
|
||||
pub missing_u64: Option<u64>,
|
||||
/// The column accessor to access the fast field values.
|
||||
pub accessor: Column<u64>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
@@ -229,33 +234,18 @@ impl PercentilesCollector {
|
||||
}
|
||||
|
||||
impl SegmentPercentilesCollector {
|
||||
pub fn from_req_and_validate(accessor_idx: usize) -> crate::Result<Self> {
|
||||
Ok(Self {
|
||||
percentiles: PercentilesCollector::new(),
|
||||
pub fn from_req_and_validate(
|
||||
field_type: ColumnType,
|
||||
missing_u64: Option<u64>,
|
||||
accessor: Column<u64>,
|
||||
accessor_idx: usize,
|
||||
) -> Self {
|
||||
Self {
|
||||
buckets: Vec::with_capacity(64),
|
||||
field_type,
|
||||
missing_u64,
|
||||
accessor,
|
||||
accessor_idx,
|
||||
})
|
||||
}
|
||||
#[inline]
|
||||
pub(crate) fn collect_block_with_field(
|
||||
&mut self,
|
||||
docs: &[DocId],
|
||||
req_data: &mut MetricAggReqData,
|
||||
) {
|
||||
if let Some(missing) = req_data.missing_u64.as_ref() {
|
||||
req_data.column_block_accessor.fetch_block_with_missing(
|
||||
docs,
|
||||
&req_data.accessor,
|
||||
*missing,
|
||||
);
|
||||
} else {
|
||||
req_data
|
||||
.column_block_accessor
|
||||
.fetch_block(docs, &req_data.accessor);
|
||||
}
|
||||
|
||||
for val in req_data.column_block_accessor.iter_vals() {
|
||||
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
|
||||
self.percentiles.collect(val1);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -263,12 +253,18 @@ impl SegmentPercentilesCollector {
|
||||
impl SegmentAggregationCollector for SegmentPercentilesCollector {
|
||||
#[inline]
|
||||
fn add_intermediate_aggregation_result(
|
||||
self: Box<Self>,
|
||||
&mut self,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
parent_bucket_id: BucketId,
|
||||
) -> crate::Result<()> {
|
||||
let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone();
|
||||
let intermediate_metric_result = IntermediateMetricResult::Percentiles(self.percentiles);
|
||||
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
|
||||
// Swap collector with an empty one to avoid cloning
|
||||
let percentiles_collector = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
|
||||
|
||||
let intermediate_metric_result =
|
||||
IntermediateMetricResult::Percentiles(percentiles_collector);
|
||||
|
||||
results.push(
|
||||
name,
|
||||
@@ -281,40 +277,33 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector {
|
||||
#[inline]
|
||||
fn collect(
|
||||
&mut self,
|
||||
doc: crate::DocId,
|
||||
parent_bucket_id: BucketId,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
|
||||
let percentiles = &mut self.buckets[parent_bucket_id as usize];
|
||||
agg_data.column_block_accessor.fetch_block_with_missing(
|
||||
docs,
|
||||
&self.accessor,
|
||||
self.missing_u64,
|
||||
);
|
||||
|
||||
if let Some(missing) = req_data.missing_u64 {
|
||||
let mut has_val = false;
|
||||
for val in req_data.accessor.values_for_doc(doc) {
|
||||
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
|
||||
self.percentiles.collect(val1);
|
||||
has_val = true;
|
||||
}
|
||||
if !has_val {
|
||||
self.percentiles
|
||||
.collect(f64_from_fastfield_u64(missing, &req_data.field_type));
|
||||
}
|
||||
} else {
|
||||
for val in req_data.accessor.values_for_doc(doc) {
|
||||
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
|
||||
self.percentiles.collect(val1);
|
||||
}
|
||||
for val in agg_data.column_block_accessor.iter_vals() {
|
||||
let val1 = f64_from_fastfield_u64(val, self.field_type);
|
||||
percentiles.collect(val1);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn collect_block(
|
||||
fn prepare_max_bucket(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
max_bucket: BucketId,
|
||||
_agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
|
||||
self.collect_block_with_field(docs, req_data);
|
||||
while self.buckets.len() <= max_bucket as usize {
|
||||
self.buckets.push(PercentilesCollector::new());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use columnar::{Column, ColumnType};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::*;
|
||||
@@ -7,10 +8,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
|
||||
use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
|
||||
};
|
||||
use crate::aggregation::metric::MetricAggReqData;
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::*;
|
||||
use crate::{DocId, TantivyError};
|
||||
use crate::TantivyError;
|
||||
|
||||
/// A multi-value metric aggregation that computes a collection of statistics on numeric values that
|
||||
/// are extracted from the aggregated documents.
|
||||
@@ -83,7 +83,7 @@ impl Stats {
|
||||
|
||||
/// Intermediate result of the stats aggregation that can be combined with other intermediate
|
||||
/// results.
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct IntermediateStats {
|
||||
/// The number of extracted values.
|
||||
pub(crate) count: u64,
|
||||
@@ -187,75 +187,75 @@ pub enum StatsType {
|
||||
Percentiles,
|
||||
}
|
||||
|
||||
fn create_collector<const TYPE_ID: u8>(
|
||||
req: &MetricAggReqData,
|
||||
) -> Box<dyn SegmentAggregationCollector> {
|
||||
Box::new(SegmentStatsCollector::<TYPE_ID> {
|
||||
name: req.name.clone(),
|
||||
collecting_for: req.collecting_for,
|
||||
is_number_or_date_type: req.is_number_or_date_type,
|
||||
missing_u64: req.missing_u64,
|
||||
accessor: req.accessor.clone(),
|
||||
buckets: vec![IntermediateStats::default()],
|
||||
})
|
||||
}
|
||||
|
||||
/// Build a concrete `SegmentStatsCollector` depending on the column type.
|
||||
pub(crate) fn build_segment_stats_collector(
|
||||
req: &MetricAggReqData,
|
||||
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
|
||||
match req.field_type {
|
||||
ColumnType::I64 => Ok(create_collector::<{ ColumnType::I64 as u8 }>(req)),
|
||||
ColumnType::U64 => Ok(create_collector::<{ ColumnType::U64 as u8 }>(req)),
|
||||
ColumnType::F64 => Ok(create_collector::<{ ColumnType::F64 as u8 }>(req)),
|
||||
ColumnType::Bool => Ok(create_collector::<{ ColumnType::Bool as u8 }>(req)),
|
||||
ColumnType::DateTime => Ok(create_collector::<{ ColumnType::DateTime as u8 }>(req)),
|
||||
ColumnType::Bytes => Ok(create_collector::<{ ColumnType::Bytes as u8 }>(req)),
|
||||
ColumnType::Str => Ok(create_collector::<{ ColumnType::Str as u8 }>(req)),
|
||||
ColumnType::IpAddr => Ok(create_collector::<{ ColumnType::IpAddr as u8 }>(req)),
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct SegmentStatsCollector {
|
||||
pub(crate) stats: IntermediateStats,
|
||||
pub(crate) accessor_idx: usize,
|
||||
pub(crate) struct SegmentStatsCollector<const COLUMN_TYPE_ID: u8> {
|
||||
pub(crate) missing_u64: Option<u64>,
|
||||
pub(crate) accessor: Column<u64>,
|
||||
pub(crate) is_number_or_date_type: bool,
|
||||
pub(crate) buckets: Vec<IntermediateStats>,
|
||||
pub(crate) name: String,
|
||||
pub(crate) collecting_for: StatsType,
|
||||
}
|
||||
|
||||
impl SegmentStatsCollector {
|
||||
pub fn from_req(accessor_idx: usize) -> Self {
|
||||
Self {
|
||||
stats: IntermediateStats::default(),
|
||||
accessor_idx,
|
||||
}
|
||||
}
|
||||
#[inline]
|
||||
pub(crate) fn collect_block_with_field(
|
||||
&mut self,
|
||||
docs: &[DocId],
|
||||
req_data: &mut MetricAggReqData,
|
||||
) {
|
||||
if let Some(missing) = req_data.missing_u64.as_ref() {
|
||||
req_data.column_block_accessor.fetch_block_with_missing(
|
||||
docs,
|
||||
&req_data.accessor,
|
||||
*missing,
|
||||
);
|
||||
} else {
|
||||
req_data
|
||||
.column_block_accessor
|
||||
.fetch_block(docs, &req_data.accessor);
|
||||
}
|
||||
if req_data.is_number_or_date_type {
|
||||
for val in req_data.column_block_accessor.iter_vals() {
|
||||
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
|
||||
self.stats.collect(val1);
|
||||
}
|
||||
} else {
|
||||
for _val in req_data.column_block_accessor.iter_vals() {
|
||||
// we ignore the value and simply record that we got something
|
||||
self.stats.collect(0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SegmentAggregationCollector for SegmentStatsCollector {
|
||||
impl<const COLUMN_TYPE_ID: u8> SegmentAggregationCollector
|
||||
for SegmentStatsCollector<COLUMN_TYPE_ID>
|
||||
{
|
||||
#[inline]
|
||||
fn add_intermediate_aggregation_result(
|
||||
self: Box<Self>,
|
||||
&mut self,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
parent_bucket_id: BucketId,
|
||||
) -> crate::Result<()> {
|
||||
let req = agg_data.get_metric_req_data(self.accessor_idx);
|
||||
let name = req.name.clone();
|
||||
let name = self.name.clone();
|
||||
|
||||
let intermediate_metric_result = match req.collecting_for {
|
||||
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
|
||||
let stats = self.buckets[parent_bucket_id as usize];
|
||||
let intermediate_metric_result = match self.collecting_for {
|
||||
StatsType::Average => {
|
||||
IntermediateMetricResult::Average(IntermediateAverage::from_collector(*self))
|
||||
IntermediateMetricResult::Average(IntermediateAverage::from_stats(stats))
|
||||
}
|
||||
StatsType::Count => {
|
||||
IntermediateMetricResult::Count(IntermediateCount::from_collector(*self))
|
||||
IntermediateMetricResult::Count(IntermediateCount::from_stats(stats))
|
||||
}
|
||||
StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_collector(*self)),
|
||||
StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_collector(*self)),
|
||||
StatsType::Stats => IntermediateMetricResult::Stats(self.stats),
|
||||
StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_collector(*self)),
|
||||
StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_stats(stats)),
|
||||
StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_stats(stats)),
|
||||
StatsType::Stats => IntermediateMetricResult::Stats(stats),
|
||||
StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_stats(stats)),
|
||||
_ => {
|
||||
return Err(TantivyError::InvalidArgument(format!(
|
||||
"Unsupported stats type for stats aggregation: {:?}",
|
||||
req.collecting_for
|
||||
self.collecting_for
|
||||
)))
|
||||
}
|
||||
};
|
||||
@@ -271,41 +271,67 @@ impl SegmentAggregationCollector for SegmentStatsCollector {
|
||||
#[inline]
|
||||
fn collect(
|
||||
&mut self,
|
||||
doc: crate::DocId,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
|
||||
if let Some(missing) = req_data.missing_u64 {
|
||||
let mut has_val = false;
|
||||
for val in req_data.accessor.values_for_doc(doc) {
|
||||
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
|
||||
self.stats.collect(val1);
|
||||
has_val = true;
|
||||
}
|
||||
if !has_val {
|
||||
self.stats
|
||||
.collect(f64_from_fastfield_u64(missing, &req_data.field_type));
|
||||
}
|
||||
} else {
|
||||
for val in req_data.accessor.values_for_doc(doc) {
|
||||
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
|
||||
self.stats.collect(val1);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
parent_bucket_id: BucketId,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
|
||||
self.collect_block_with_field(docs, req_data);
|
||||
// TODO: remove once we fetch all values for all bucket ids in one go
|
||||
if docs.len() == 1 && self.missing_u64.is_none() {
|
||||
collect_stats::<COLUMN_TYPE_ID>(
|
||||
&mut self.buckets[parent_bucket_id as usize],
|
||||
self.accessor.values_for_doc(docs[0]),
|
||||
self.is_number_or_date_type,
|
||||
)?;
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
agg_data.column_block_accessor.fetch_block_with_missing(
|
||||
docs,
|
||||
&self.accessor,
|
||||
self.missing_u64,
|
||||
);
|
||||
collect_stats::<COLUMN_TYPE_ID>(
|
||||
&mut self.buckets[parent_bucket_id as usize],
|
||||
agg_data.column_block_accessor.iter_vals(),
|
||||
self.is_number_or_date_type,
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn prepare_max_bucket(
|
||||
&mut self,
|
||||
max_bucket: BucketId,
|
||||
_agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let required_buckets = (max_bucket as usize) + 1;
|
||||
if self.buckets.len() < required_buckets {
|
||||
self.buckets
|
||||
.resize_with(required_buckets, IntermediateStats::default);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn collect_stats<const COLUMN_TYPE_ID: u8>(
|
||||
stats: &mut IntermediateStats,
|
||||
vals: impl Iterator<Item = u64>,
|
||||
is_number_or_date_type: bool,
|
||||
) -> crate::Result<()> {
|
||||
if is_number_or_date_type {
|
||||
for val in vals {
|
||||
let val1 = convert_to_f64::<COLUMN_TYPE_ID>(val);
|
||||
stats.collect(val1);
|
||||
}
|
||||
} else {
|
||||
for _val in vals {
|
||||
// we ignore the value and simply record that we got something
|
||||
stats.collect(0.0);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -52,10 +52,8 @@ pub struct IntermediateSum {
|
||||
|
||||
impl IntermediateSum {
|
||||
/// Creates a new [`IntermediateSum`] instance from a [`SegmentStatsCollector`].
|
||||
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
|
||||
Self {
|
||||
stats: collector.stats,
|
||||
}
|
||||
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
|
||||
Self { stats }
|
||||
}
|
||||
/// Merges the other intermediate result into self.
|
||||
pub fn merge_fruits(&mut self, other: IntermediateSum) {
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
use std::cmp::Ordering;
|
||||
use std::collections::HashMap;
|
||||
use std::net::Ipv6Addr;
|
||||
|
||||
use columnar::{Column, ColumnType, ColumnarReader, DynamicColumn, ValueRange};
|
||||
use columnar::{Column, ColumnType, ColumnarReader, DynamicColumn};
|
||||
use common::json_path_writer::JSON_PATH_SEGMENT_SEP_STR;
|
||||
use common::DateTime;
|
||||
use regex::Regex;
|
||||
@@ -16,12 +15,11 @@ use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateMetricResult,
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::AggregationError;
|
||||
use crate::collector::sort_key::{Comparator, ReverseComparator};
|
||||
use crate::aggregation::{AggregationError, BucketId};
|
||||
use crate::collector::sort_key::ReverseComparator;
|
||||
use crate::collector::TopNComputer;
|
||||
use crate::schema::OwnedValue;
|
||||
use crate::{DocAddress, DocId, SegmentOrdinal};
|
||||
// duplicate import removed; already imported above
|
||||
|
||||
/// Contains all information required by the TopHitsSegmentCollector to perform the
|
||||
/// top_hits aggregation on a segment.
|
||||
@@ -384,7 +382,7 @@ impl From<FastFieldValue> for OwnedValue {
|
||||
|
||||
/// Holds a fast field value in its u64 representation, and the order in which it should be sorted.
|
||||
#[derive(Clone, Serialize, Deserialize, Debug)]
|
||||
pub(crate) struct DocValueAndOrder {
|
||||
struct DocValueAndOrder {
|
||||
/// A fast field value in its u64 representation.
|
||||
value: Option<u64>,
|
||||
/// Sort order for the value
|
||||
@@ -456,37 +454,6 @@ impl PartialEq for DocSortValuesAndFields {
|
||||
|
||||
impl Eq for DocSortValuesAndFields {}
|
||||
|
||||
impl Comparator<DocSortValuesAndFields> for ReverseComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &DocSortValuesAndFields, rhs: &DocSortValuesAndFields) -> Ordering {
|
||||
rhs.cmp(lhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(
|
||||
&self,
|
||||
threshold: DocSortValuesAndFields,
|
||||
) -> ValueRange<DocSortValuesAndFields> {
|
||||
ValueRange::LessThan(threshold, true)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub(crate) struct TopHitsSegmentSortKey(pub Vec<DocValueAndOrder>);
|
||||
|
||||
impl Comparator<TopHitsSegmentSortKey> for ReverseComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &TopHitsSegmentSortKey, rhs: &TopHitsSegmentSortKey) -> Ordering {
|
||||
rhs.cmp(lhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(
|
||||
&self,
|
||||
threshold: TopHitsSegmentSortKey,
|
||||
) -> ValueRange<TopHitsSegmentSortKey> {
|
||||
ValueRange::LessThan(threshold, true)
|
||||
}
|
||||
}
|
||||
|
||||
/// The TopHitsCollector used for collecting over segments and merging results.
|
||||
#[derive(Clone, Serialize, Deserialize, Debug)]
|
||||
pub struct TopHitsTopNComputer {
|
||||
@@ -504,7 +471,10 @@ impl TopHitsTopNComputer {
|
||||
/// Create a new TopHitsCollector
|
||||
pub fn new(req: &TopHitsAggregationReq) -> Self {
|
||||
Self {
|
||||
top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)),
|
||||
top_n: TopNComputer::new_with_comparator(
|
||||
req.size + req.from.unwrap_or(0),
|
||||
ReverseComparator,
|
||||
),
|
||||
req: req.clone(),
|
||||
}
|
||||
}
|
||||
@@ -550,7 +520,8 @@ impl TopHitsTopNComputer {
|
||||
pub(crate) struct TopHitsSegmentCollector {
|
||||
segment_ordinal: SegmentOrdinal,
|
||||
accessor_idx: usize,
|
||||
top_n: TopNComputer<TopHitsSegmentSortKey, DocAddress, ReverseComparator>,
|
||||
buckets: Vec<TopNComputer<Vec<DocValueAndOrder>, DocAddress, ReverseComparator>>,
|
||||
num_hits: usize,
|
||||
}
|
||||
|
||||
impl TopHitsSegmentCollector {
|
||||
@@ -559,27 +530,35 @@ impl TopHitsSegmentCollector {
|
||||
accessor_idx: usize,
|
||||
segment_ordinal: SegmentOrdinal,
|
||||
) -> Self {
|
||||
let num_hits = req.size + req.from.unwrap_or(0);
|
||||
Self {
|
||||
top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)),
|
||||
num_hits,
|
||||
segment_ordinal,
|
||||
accessor_idx,
|
||||
buckets: vec![TopNComputer::new_with_comparator(num_hits, ReverseComparator); 1],
|
||||
}
|
||||
}
|
||||
fn into_top_hits_collector(
|
||||
self,
|
||||
fn get_top_hits_computer(
|
||||
&mut self,
|
||||
parent_bucket_id: BucketId,
|
||||
value_accessors: &HashMap<String, Vec<DynamicColumn>>,
|
||||
req: &TopHitsAggregationReq,
|
||||
) -> TopHitsTopNComputer {
|
||||
if parent_bucket_id as usize >= self.buckets.len() {
|
||||
return TopHitsTopNComputer::new(req);
|
||||
}
|
||||
let top_n = std::mem::replace(
|
||||
&mut self.buckets[parent_bucket_id as usize],
|
||||
TopNComputer::new(0),
|
||||
);
|
||||
let mut top_hits_computer = TopHitsTopNComputer::new(req);
|
||||
// Map TopHitsSegmentSortKey back to Vec<DocValueAndOrder> if needed or use directly
|
||||
// The TopNComputer here stores TopHitsSegmentSortKey.
|
||||
let top_results = self.top_n.into_vec();
|
||||
let top_results = top_n.into_vec();
|
||||
|
||||
for res in top_results {
|
||||
let doc_value_fields = req.get_document_field_data(value_accessors, res.doc.doc_id);
|
||||
top_hits_computer.collect(
|
||||
DocSortValuesAndFields {
|
||||
sorts: res.sort_key.0,
|
||||
sorts: res.sort_key,
|
||||
doc_value_fields,
|
||||
},
|
||||
res.doc,
|
||||
@@ -588,54 +567,24 @@ impl TopHitsSegmentCollector {
|
||||
|
||||
top_hits_computer
|
||||
}
|
||||
|
||||
/// TODO add a specialized variant for a single sort field
|
||||
fn collect_with(
|
||||
&mut self,
|
||||
doc_id: crate::DocId,
|
||||
req: &TopHitsAggregationReq,
|
||||
accessors: &[(Column<u64>, ColumnType)],
|
||||
) -> crate::Result<()> {
|
||||
let sorts: Vec<DocValueAndOrder> = req
|
||||
.sort
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, KeyOrder { order, .. })| {
|
||||
let order = *order;
|
||||
let value = accessors
|
||||
.get(idx)
|
||||
.expect("could not find field in accessors")
|
||||
.0
|
||||
.values_for_doc(doc_id)
|
||||
.next();
|
||||
DocValueAndOrder { value, order }
|
||||
})
|
||||
.collect();
|
||||
|
||||
self.top_n.push(
|
||||
TopHitsSegmentSortKey(sorts),
|
||||
DocAddress {
|
||||
segment_ord: self.segment_ordinal,
|
||||
doc_id,
|
||||
},
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl SegmentAggregationCollector for TopHitsSegmentCollector {
|
||||
fn add_intermediate_aggregation_result(
|
||||
self: Box<Self>,
|
||||
&mut self,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut crate::aggregation::intermediate_agg_result::IntermediateAggregationResults,
|
||||
parent_bucket_id: BucketId,
|
||||
) -> crate::Result<()> {
|
||||
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
|
||||
|
||||
let value_accessors = &req_data.value_accessors;
|
||||
|
||||
let intermediate_result = IntermediateMetricResult::TopHits(
|
||||
self.into_top_hits_collector(value_accessors, &req_data.req),
|
||||
);
|
||||
let intermediate_result = IntermediateMetricResult::TopHits(self.get_top_hits_computer(
|
||||
parent_bucket_id,
|
||||
value_accessors,
|
||||
&req_data.req,
|
||||
));
|
||||
results.push(
|
||||
req_data.name.to_string(),
|
||||
IntermediateAggregationResult::Metric(intermediate_result),
|
||||
@@ -645,26 +594,56 @@ impl SegmentAggregationCollector for TopHitsSegmentCollector {
|
||||
/// TODO: Consider a caching layer to reduce the call overhead
|
||||
fn collect(
|
||||
&mut self,
|
||||
doc_id: crate::DocId,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
|
||||
self.collect_with(doc_id, &req_data.req, &req_data.accessors)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
parent_bucket_id: BucketId,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let top_n = &mut self.buckets[parent_bucket_id as usize];
|
||||
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
|
||||
// TODO: Consider getting fields with the column block accessor.
|
||||
for doc in docs {
|
||||
self.collect_with(*doc, &req_data.req, &req_data.accessors)?;
|
||||
let req = &req_data.req;
|
||||
let accessors = &req_data.accessors;
|
||||
for &doc_id in docs {
|
||||
// TODO: this is terrible, a new vec is allocated for every doc
|
||||
// We can fetch blocks instead
|
||||
// We don't need to store the order for every value
|
||||
let sorts: Vec<DocValueAndOrder> = req
|
||||
.sort
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, KeyOrder { order, .. })| {
|
||||
let order = *order;
|
||||
let value = accessors
|
||||
.get(idx)
|
||||
.expect("could not find field in accessors")
|
||||
.0
|
||||
.values_for_doc(doc_id)
|
||||
.next();
|
||||
DocValueAndOrder { value, order }
|
||||
})
|
||||
.collect();
|
||||
|
||||
top_n.push(
|
||||
sorts,
|
||||
DocAddress {
|
||||
segment_ord: self.segment_ordinal,
|
||||
doc_id,
|
||||
},
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn prepare_max_bucket(
|
||||
&mut self,
|
||||
max_bucket: BucketId,
|
||||
_agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
self.buckets.resize(
|
||||
(max_bucket as usize) + 1,
|
||||
TopNComputer::new_with_comparator(self.num_hits, ReverseComparator),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -780,7 +759,7 @@ mod tests {
|
||||
],
|
||||
"from": 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
@@ -909,7 +888,7 @@ mod tests {
|
||||
"mixed.*",
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
}))?;
|
||||
|
||||
let collector = AggregationCollector::from_aggs(d, Default::default());
|
||||
|
||||
@@ -133,7 +133,7 @@ mod agg_limits;
|
||||
pub mod agg_req;
|
||||
pub mod agg_result;
|
||||
pub mod bucket;
|
||||
mod buf_collector;
|
||||
pub(crate) mod cached_sub_aggs;
|
||||
mod collector;
|
||||
mod date;
|
||||
mod error;
|
||||
@@ -162,6 +162,19 @@ use serde::{Deserialize, Deserializer, Serialize};
|
||||
|
||||
use crate::tokenizer::TokenizerManager;
|
||||
|
||||
/// A bucket id is a dense identifier for a bucket within an aggregation.
|
||||
/// It is used to index into a Vec that hold per-bucket data.
|
||||
///
|
||||
/// For example, in a terms aggregation, each unique term will be assigned a incremental BucketId.
|
||||
/// This BucketId will be forwarded to sub-aggregations to identify the parent bucket.
|
||||
///
|
||||
/// This allows to have a single AggregationCollector instance per aggregation,
|
||||
/// that can handle multiple buckets efficiently.
|
||||
///
|
||||
/// The API to call sub-aggregations is therefore a &[(BucketId, &[DocId])].
|
||||
/// For that we'll need a buffer. One Vec per bucket aggregation is needed.
|
||||
pub type BucketId = u32;
|
||||
|
||||
/// Context parameters for aggregation execution
|
||||
///
|
||||
/// This struct holds shared resources needed during aggregation execution:
|
||||
@@ -335,19 +348,37 @@ impl Display for Key {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn convert_to_f64<const COLUMN_TYPE_ID: u8>(val: u64) -> f64 {
|
||||
if COLUMN_TYPE_ID == ColumnType::U64 as u8 {
|
||||
val as f64
|
||||
} else if COLUMN_TYPE_ID == ColumnType::I64 as u8
|
||||
|| COLUMN_TYPE_ID == ColumnType::DateTime as u8
|
||||
{
|
||||
i64::from_u64(val) as f64
|
||||
} else if COLUMN_TYPE_ID == ColumnType::F64 as u8 {
|
||||
f64::from_u64(val)
|
||||
} else if COLUMN_TYPE_ID == ColumnType::Bool as u8 {
|
||||
val as f64
|
||||
} else {
|
||||
panic!(
|
||||
"ColumnType ID {} cannot be converted to f64 metric",
|
||||
COLUMN_TYPE_ID
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Inverse of `to_fastfield_u64`. Used to convert to `f64` for metrics.
|
||||
///
|
||||
/// # Panics
|
||||
/// Only `u64`, `f64`, `date`, and `i64` are supported.
|
||||
pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: &ColumnType) -> f64 {
|
||||
pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: ColumnType) -> f64 {
|
||||
match field_type {
|
||||
ColumnType::U64 => val as f64,
|
||||
ColumnType::I64 | ColumnType::DateTime => i64::from_u64(val) as f64,
|
||||
ColumnType::F64 => f64::from_u64(val),
|
||||
ColumnType::Bool => val as f64,
|
||||
_ => {
|
||||
panic!("unexpected type {field_type:?}. This should not happen")
|
||||
}
|
||||
ColumnType::U64 => convert_to_f64::<{ ColumnType::U64 as u8 }>(val),
|
||||
ColumnType::I64 => convert_to_f64::<{ ColumnType::I64 as u8 }>(val),
|
||||
ColumnType::F64 => convert_to_f64::<{ ColumnType::F64 as u8 }>(val),
|
||||
ColumnType::Bool => convert_to_f64::<{ ColumnType::Bool as u8 }>(val),
|
||||
ColumnType::DateTime => convert_to_f64::<{ ColumnType::DateTime as u8 }>(val),
|
||||
_ => panic!("unexpected type {field_type:?}. This should not happen"),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,25 +8,67 @@ use std::fmt::Debug;
|
||||
pub(crate) use super::agg_limits::AggregationLimitsGuard;
|
||||
use super::intermediate_agg_result::IntermediateAggregationResults;
|
||||
use crate::aggregation::agg_data::AggregationsSegmentCtx;
|
||||
use crate::aggregation::BucketId;
|
||||
|
||||
/// Monotonically increasing provider of BucketIds.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct BucketIdProvider(u32);
|
||||
impl BucketIdProvider {
|
||||
/// Get the next BucketId.
|
||||
pub fn next_bucket_id(&mut self) -> BucketId {
|
||||
let bucket_id = self.0;
|
||||
self.0 += 1;
|
||||
bucket_id
|
||||
}
|
||||
}
|
||||
|
||||
/// A SegmentAggregationCollector is used to collect aggregation results.
|
||||
pub trait SegmentAggregationCollector: CollectorClone + Debug {
|
||||
pub trait SegmentAggregationCollector: Debug {
|
||||
fn add_intermediate_aggregation_result(
|
||||
self: Box<Self>,
|
||||
&mut self,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
parent_bucket_id: BucketId,
|
||||
) -> crate::Result<()>;
|
||||
|
||||
/// Note: The caller needs to call `prepare_max_bucket` before calling `collect`.
|
||||
fn collect(
|
||||
&mut self,
|
||||
doc: crate::DocId,
|
||||
parent_bucket_id: BucketId,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()>;
|
||||
|
||||
fn collect_block(
|
||||
/// Collect docs for multiple buckets in one call.
|
||||
/// Minimizes dynamic dispatch overhead when collecting many buckets.
|
||||
///
|
||||
/// Note: The caller needs to call `prepare_max_bucket` before calling `collect`.
|
||||
fn collect_multiple(
|
||||
&mut self,
|
||||
bucket_ids: &[BucketId],
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
debug_assert_eq!(bucket_ids.len(), docs.len());
|
||||
let mut start = 0;
|
||||
while start < bucket_ids.len() {
|
||||
let bucket_id = bucket_ids[start];
|
||||
let mut end = start + 1;
|
||||
while end < bucket_ids.len() && bucket_ids[end] == bucket_id {
|
||||
end += 1;
|
||||
}
|
||||
self.collect(bucket_id, &docs[start..end], agg_data)?;
|
||||
start = end;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Prepare the collector for collecting up to BucketId `max_bucket`.
|
||||
/// This is useful so we can split allocation ahead of time of collecting.
|
||||
fn prepare_max_bucket(
|
||||
&mut self,
|
||||
max_bucket: BucketId,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<()>;
|
||||
|
||||
/// Finalize method. Some Aggregator collect blocks of docs before calling `collect_block`.
|
||||
@@ -36,26 +78,7 @@ pub trait SegmentAggregationCollector: CollectorClone + Debug {
|
||||
}
|
||||
}
|
||||
|
||||
/// A helper trait to enable cloning of Box<dyn SegmentAggregationCollector>
|
||||
pub trait CollectorClone {
|
||||
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector>;
|
||||
}
|
||||
|
||||
impl<T> CollectorClone for T
|
||||
where T: 'static + SegmentAggregationCollector + Clone
|
||||
{
|
||||
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for Box<dyn SegmentAggregationCollector> {
|
||||
fn clone(&self) -> Box<dyn SegmentAggregationCollector> {
|
||||
self.clone_box()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
#[derive(Default)]
|
||||
/// The GenericSegmentAggregationResultsCollector is the generic version of the collector, which
|
||||
/// can handle arbitrary complexity of sub-aggregations. Ideally we never have to pick this one
|
||||
/// and can provide specialized versions instead, that remove some of its overhead.
|
||||
@@ -73,12 +96,13 @@ impl Debug for GenericSegmentAggregationResultsCollector {
|
||||
|
||||
impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
|
||||
fn add_intermediate_aggregation_result(
|
||||
self: Box<Self>,
|
||||
&mut self,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
parent_bucket_id: BucketId,
|
||||
) -> crate::Result<()> {
|
||||
for agg in self.aggs {
|
||||
agg.add_intermediate_aggregation_result(agg_data, results)?;
|
||||
for agg in &mut self.aggs {
|
||||
agg.add_intermediate_aggregation_result(agg_data, results, parent_bucket_id)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -86,23 +110,13 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
|
||||
|
||||
fn collect(
|
||||
&mut self,
|
||||
doc: crate::DocId,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
self.collect_block(&[doc], agg_data)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
parent_bucket_id: BucketId,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
for collector in &mut self.aggs {
|
||||
collector.collect_block(docs, agg_data)?;
|
||||
collector.collect(parent_bucket_id, docs, agg_data)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -112,4 +126,15 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn prepare_max_bucket(
|
||||
&mut self,
|
||||
max_bucket: BucketId,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
for collector in &mut self.aggs {
|
||||
collector.prepare_max_bucket(max_bucket, agg_data)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -821,6 +821,7 @@ mod tests {
|
||||
|
||||
#[cfg(all(test, feature = "unstable"))]
|
||||
mod bench {
|
||||
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::thread_rng;
|
||||
use test::Bencher;
|
||||
|
||||
@@ -96,10 +96,11 @@ mod histogram_collector;
|
||||
pub use histogram_collector::HistogramCollector;
|
||||
|
||||
mod multi_collector;
|
||||
pub use columnar::ComparableDoc;
|
||||
|
||||
pub use self::multi_collector::{FruitHandle, MultiCollector, MultiFruit};
|
||||
|
||||
mod top_collector;
|
||||
pub use self::top_collector::ComparableDoc;
|
||||
|
||||
mod top_score_collector;
|
||||
pub use self::top_score_collector::{TopDocs, TopNComputer};
|
||||
|
||||
|
||||
@@ -281,6 +281,7 @@ impl SegmentCollector for MultiCollectorChild {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::collector::{Count, TopDocs};
|
||||
use crate::query::TermQuery;
|
||||
|
||||
@@ -13,13 +13,31 @@ pub use sort_by_string::SortByString;
|
||||
pub use sort_key_computer::{SegmentSortKeyComputer, SortKeyComputer};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
pub(crate) mod tests {
|
||||
|
||||
// By spec, regardless of whether ascending or descending order was requested, in presence of a
|
||||
// tie, we sort by ascending doc id/doc address.
|
||||
pub(crate) fn sort_hits<TSortKey: Ord, D: Ord>(
|
||||
hits: &mut [ComparableDoc<TSortKey, D>],
|
||||
order: Order,
|
||||
) {
|
||||
if order.is_asc() {
|
||||
hits.sort_by(|l, r| l.sort_key.cmp(&r.sort_key).then(l.doc.cmp(&r.doc)));
|
||||
} else {
|
||||
hits.sort_by(|l, r| {
|
||||
l.sort_key
|
||||
.cmp(&r.sort_key)
|
||||
.reverse() // This is descending
|
||||
.then(l.doc.cmp(&r.doc))
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::ops::Range;
|
||||
|
||||
use crate::collector::sort_key::{
|
||||
Comparator, NaturalComparator, ReverseComparator, SortByErasedType, SortBySimilarityScore,
|
||||
SortByStaticFastValue, SortByString,
|
||||
SortByErasedType, SortBySimilarityScore, SortByStaticFastValue, SortByString,
|
||||
};
|
||||
use crate::collector::{ComparableDoc, DocSetCollector, TopDocs};
|
||||
use crate::indexer::NoMergePolicy;
|
||||
@@ -371,52 +389,6 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_order_by_compound_fast_fields() -> crate::Result<()> {
|
||||
let index = make_index()?;
|
||||
|
||||
type CompoundSortKey = (Option<String>, Option<f64>);
|
||||
|
||||
fn assert_query(
|
||||
index: &Index,
|
||||
city_order: Order,
|
||||
altitude_order: Order,
|
||||
expected: Vec<(CompoundSortKey, u64)>,
|
||||
) -> crate::Result<()> {
|
||||
let searcher = index.reader()?.searcher();
|
||||
let ids = id_mapping(&searcher);
|
||||
|
||||
let top_collector = TopDocs::with_limit(4).order_by((
|
||||
(SortByString::for_field("city"), city_order),
|
||||
(
|
||||
SortByStaticFastValue::<f64>::for_field("altitude"),
|
||||
altitude_order,
|
||||
),
|
||||
));
|
||||
let actual = searcher
|
||||
.search(&AllQuery, &top_collector)?
|
||||
.into_iter()
|
||||
.map(|(key, doc)| (key, ids[&doc]))
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(actual, expected);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
assert_query(
|
||||
&index,
|
||||
Order::Asc,
|
||||
Order::Desc,
|
||||
vec![
|
||||
((Some("austin".to_owned()), Some(149.0)), 0),
|
||||
((Some("greenville".to_owned()), Some(27.0)), 1),
|
||||
((Some("tokyo".to_owned()), Some(40.0)), 2),
|
||||
((None, Some(0.0)), 3),
|
||||
],
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
use proptest::prelude::*;
|
||||
|
||||
proptest! {
|
||||
@@ -469,11 +441,7 @@ mod tests {
|
||||
let sorted_docs: Vec<_> = {
|
||||
let mut comparable_docs: Vec<ComparableDoc<_, _>> =
|
||||
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
|
||||
if order.is_desc() {
|
||||
comparable_docs.sort_by(|l, r| NaturalComparator.compare_doc(l, r));
|
||||
} else {
|
||||
comparable_docs.sort_by(|l, r| ReverseComparator.compare_doc(l, r));
|
||||
}
|
||||
sort_hits(&mut comparable_docs, order);
|
||||
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
|
||||
};
|
||||
let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::<Vec<_>>();
|
||||
@@ -483,197 +451,4 @@ mod tests {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn test_order_by_compound_prop(
|
||||
city_order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
|
||||
altitude_order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
|
||||
limit in 1..20_usize,
|
||||
offset in 0..20_usize,
|
||||
segments_data in proptest::collection::vec(
|
||||
proptest::collection::vec(
|
||||
(proptest::option::of("[a-c]"), proptest::option::of(0..50u64)),
|
||||
1..10_usize // segment size
|
||||
),
|
||||
1..4_usize // num segments
|
||||
)
|
||||
) {
|
||||
use crate::collector::sort_key::ComparatorEnum;
|
||||
use crate::TantivyDocument;
|
||||
|
||||
let mut schema_builder = Schema::builder();
|
||||
let city = schema_builder.add_text_field("city", TEXT | FAST);
|
||||
let altitude = schema_builder.add_u64_field("altitude", FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut index_writer = index.writer_for_tests().unwrap();
|
||||
|
||||
for segment_data in segments_data.into_iter() {
|
||||
for (city_val, altitude_val) in segment_data.into_iter() {
|
||||
let mut doc = TantivyDocument::default();
|
||||
if let Some(c) = city_val {
|
||||
doc.add_text(city, c);
|
||||
}
|
||||
if let Some(a) = altitude_val {
|
||||
doc.add_u64(altitude, a);
|
||||
}
|
||||
index_writer.add_document(doc).unwrap();
|
||||
}
|
||||
index_writer.commit().unwrap();
|
||||
}
|
||||
|
||||
let searcher = index.reader().unwrap().searcher();
|
||||
|
||||
let top_collector = TopDocs::with_limit(limit)
|
||||
.and_offset(offset)
|
||||
.order_by((
|
||||
(SortByString::for_field("city"), city_order),
|
||||
(
|
||||
SortByStaticFastValue::<u64>::for_field("altitude"),
|
||||
altitude_order,
|
||||
),
|
||||
));
|
||||
|
||||
let actual_results = searcher.search(&AllQuery, &top_collector).unwrap();
|
||||
let actual_doc_ids: Vec<DocAddress> =
|
||||
actual_results.into_iter().map(|(_, doc)| doc).collect();
|
||||
|
||||
// Verification logic
|
||||
let all_docs_collector = DocSetCollector;
|
||||
let all_docs = searcher.search(&AllQuery, &all_docs_collector).unwrap();
|
||||
|
||||
let docs_with_keys: Vec<((Option<String>, Option<u64>), DocAddress)> = all_docs
|
||||
.into_iter()
|
||||
.map(|doc_addr| {
|
||||
let reader = searcher.segment_reader(doc_addr.segment_ord);
|
||||
|
||||
let city_val = if let Some(col) = reader.fast_fields().str("city").unwrap() {
|
||||
let ord = col.ords().first(doc_addr.doc_id);
|
||||
if let Some(ord) = ord {
|
||||
let mut out = Vec::new();
|
||||
col.dictionary().ord_to_term(ord, &mut out).unwrap();
|
||||
String::from_utf8(out).ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let alt_val = if let Some((col, _)) = reader.fast_fields().u64_lenient("altitude").unwrap() {
|
||||
col.first(doc_addr.doc_id)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
((city_val, alt_val), doc_addr)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let city_comparator = ComparatorEnum::from(city_order);
|
||||
let alt_comparator = ComparatorEnum::from(altitude_order);
|
||||
let comparator = (city_comparator, alt_comparator);
|
||||
|
||||
let mut comparable_docs: Vec<ComparableDoc<_, _>> = docs_with_keys
|
||||
.into_iter()
|
||||
.map(|(sort_key, doc)| ComparableDoc { sort_key, doc })
|
||||
.collect();
|
||||
|
||||
comparable_docs.sort_by(|l, r| comparator.compare_doc(l, r));
|
||||
|
||||
let expected_results = comparable_docs
|
||||
.into_iter()
|
||||
.skip(offset)
|
||||
.take(limit)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let expected_doc_ids: Vec<DocAddress> =
|
||||
expected_results.into_iter().map(|cd| cd.doc).collect();
|
||||
|
||||
prop_assert_eq!(actual_doc_ids, expected_doc_ids);
|
||||
}
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn test_order_by_u64_prop(
|
||||
order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
|
||||
limit in 1..20_usize,
|
||||
offset in 0..20_usize,
|
||||
segments_data in proptest::collection::vec(
|
||||
proptest::collection::vec(
|
||||
proptest::option::of(0..100u64),
|
||||
1..1000_usize // segment size
|
||||
),
|
||||
1..4_usize // num segments
|
||||
)
|
||||
) {
|
||||
use crate::collector::sort_key::ComparatorEnum;
|
||||
use crate::TantivyDocument;
|
||||
|
||||
let mut schema_builder = Schema::builder();
|
||||
let field = schema_builder.add_u64_field("field", FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut index_writer = index.writer_for_tests().unwrap();
|
||||
|
||||
for segment_data in segments_data.into_iter() {
|
||||
for val in segment_data.into_iter() {
|
||||
let mut doc = TantivyDocument::default();
|
||||
if let Some(v) = val {
|
||||
doc.add_u64(field, v);
|
||||
}
|
||||
index_writer.add_document(doc).unwrap();
|
||||
}
|
||||
index_writer.commit().unwrap();
|
||||
}
|
||||
|
||||
let searcher = index.reader().unwrap().searcher();
|
||||
|
||||
let top_collector = TopDocs::with_limit(limit)
|
||||
.and_offset(offset)
|
||||
.order_by((SortByStaticFastValue::<u64>::for_field("field"), order));
|
||||
|
||||
let actual_results = searcher.search(&AllQuery, &top_collector).unwrap();
|
||||
let actual_doc_ids: Vec<DocAddress> =
|
||||
actual_results.into_iter().map(|(_, doc)| doc).collect();
|
||||
|
||||
// Verification logic
|
||||
let all_docs_collector = DocSetCollector;
|
||||
let all_docs = searcher.search(&AllQuery, &all_docs_collector).unwrap();
|
||||
|
||||
let docs_with_keys: Vec<(Option<u64>, DocAddress)> = all_docs
|
||||
.into_iter()
|
||||
.map(|doc_addr| {
|
||||
let reader = searcher.segment_reader(doc_addr.segment_ord);
|
||||
let val = if let Some((col, _)) = reader.fast_fields().u64_lenient("field").unwrap() {
|
||||
col.first(doc_addr.doc_id)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
(val, doc_addr)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let comparator = ComparatorEnum::from(order);
|
||||
let mut comparable_docs: Vec<ComparableDoc<_, _>> = docs_with_keys
|
||||
.into_iter()
|
||||
.map(|(sort_key, doc)| ComparableDoc { sort_key, doc })
|
||||
.collect();
|
||||
|
||||
comparable_docs.sort_by(|l, r| comparator.compare_doc(l, r));
|
||||
|
||||
let expected_results = comparable_docs
|
||||
.into_iter()
|
||||
.skip(offset)
|
||||
.take(limit)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let expected_doc_ids: Vec<DocAddress> =
|
||||
expected_results.into_iter().map(|cd| cd.doc).collect();
|
||||
|
||||
prop_assert_eq!(actual_doc_ids, expected_doc_ids);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use std::cmp::Ordering;
|
||||
|
||||
use columnar::{MonotonicallyMappableToU64, ValueRange};
|
||||
use columnar::MonotonicallyMappableToU64;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
|
||||
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
|
||||
use crate::schema::{OwnedValue, Schema};
|
||||
use crate::{DocId, Order, Score};
|
||||
|
||||
@@ -69,26 +69,6 @@ fn compare_owned_value<const NULLS_FIRST: bool>(lhs: &OwnedValue, rhs: &OwnedVal
|
||||
pub trait Comparator<T>: Send + Sync + std::fmt::Debug + Default {
|
||||
/// Return the order between two values.
|
||||
fn compare(&self, lhs: &T, rhs: &T) -> Ordering;
|
||||
/// Return the order between two ComparableDoc values, using the semantics which are
|
||||
/// implemented by TopNComputer.
|
||||
#[inline(always)]
|
||||
fn compare_doc<D: Ord>(
|
||||
&self,
|
||||
lhs: &ComparableDoc<T, D>,
|
||||
rhs: &ComparableDoc<T, D>,
|
||||
) -> Ordering {
|
||||
// TopNComputer sorts in descending order of the SortKey by default: we apply that ordering
|
||||
// here to ease comparison in testing.
|
||||
self.compare(&rhs.sort_key, &lhs.sort_key).then_with(|| {
|
||||
// In case of a tie on the sort key, we always sort by ascending `DocAddress` in order
|
||||
// to ensure a stable sorting of the documents, regardless of the sort key's order.
|
||||
// See the TopNComputer docs for more information.
|
||||
lhs.doc.cmp(&rhs.doc)
|
||||
})
|
||||
}
|
||||
|
||||
/// Return a `ValueRange` that matches all values that are greater than the provided threshold.
|
||||
fn threshold_to_valuerange(&self, threshold: T) -> ValueRange<T>;
|
||||
}
|
||||
|
||||
/// Compare values naturally (e.g. 1 < 2).
|
||||
@@ -104,11 +84,7 @@ pub struct NaturalComparator;
|
||||
impl<T: PartialOrd> Comparator<T> for NaturalComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
|
||||
lhs.partial_cmp(rhs).unwrap()
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: T) -> ValueRange<T> {
|
||||
ValueRange::GreaterThan(threshold, false)
|
||||
lhs.partial_cmp(rhs).unwrap_or(Ordering::Equal)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,10 +97,6 @@ impl Comparator<OwnedValue> for NaturalComparator {
|
||||
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
|
||||
compare_owned_value::</* NULLS_FIRST= */ true>(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: OwnedValue) -> ValueRange<OwnedValue> {
|
||||
ValueRange::GreaterThan(threshold, false)
|
||||
}
|
||||
}
|
||||
|
||||
/// Compare values in reverse (e.g. 2 < 1).
|
||||
@@ -142,69 +114,13 @@ impl Comparator<OwnedValue> for NaturalComparator {
|
||||
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct ReverseComparator;
|
||||
|
||||
macro_rules! impl_reverse_comparator_primitive {
|
||||
($($t:ty),*) => {
|
||||
$(
|
||||
impl Comparator<$t> for ReverseComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &$t, rhs: &$t) -> Ordering {
|
||||
NaturalComparator.compare(rhs, lhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: $t) -> ValueRange<$t> {
|
||||
ValueRange::LessThan(threshold, true)
|
||||
}
|
||||
}
|
||||
)*
|
||||
}
|
||||
}
|
||||
|
||||
impl_reverse_comparator_primitive!(
|
||||
bool,
|
||||
u8,
|
||||
u16,
|
||||
u32,
|
||||
u64,
|
||||
u128,
|
||||
usize,
|
||||
i8,
|
||||
i16,
|
||||
i32,
|
||||
i64,
|
||||
i128,
|
||||
isize,
|
||||
f32,
|
||||
f64,
|
||||
String,
|
||||
crate::DateTime,
|
||||
Vec<u8>,
|
||||
crate::schema::Facet
|
||||
);
|
||||
|
||||
impl<T: PartialOrd + Send + Sync + std::fmt::Debug + Clone + 'static> Comparator<Option<T>>
|
||||
for ReverseComparator
|
||||
impl<T> Comparator<T> for ReverseComparator
|
||||
where NaturalComparator: Comparator<T>
|
||||
{
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &Option<T>, rhs: &Option<T>) -> Ordering {
|
||||
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
|
||||
NaturalComparator.compare(rhs, lhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: Option<T>) -> ValueRange<Option<T>> {
|
||||
let is_some = threshold.is_some();
|
||||
ValueRange::LessThan(threshold, is_some)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<OwnedValue> for ReverseComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
|
||||
NaturalComparator.compare(rhs, lhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: OwnedValue) -> ValueRange<OwnedValue> {
|
||||
let is_not_null = !matches!(threshold, OwnedValue::Null);
|
||||
ValueRange::LessThan(threshold, is_not_null)
|
||||
}
|
||||
}
|
||||
|
||||
/// Compare values in reverse, but treating `None` as lower than `Some`.
|
||||
@@ -231,14 +147,6 @@ where ReverseComparator: Comparator<T>
|
||||
(Some(lhs), Some(rhs)) => ReverseComparator.compare(lhs, rhs),
|
||||
}
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: Option<T>) -> ValueRange<Option<T>> {
|
||||
if threshold.is_some() {
|
||||
ValueRange::LessThan(threshold, false)
|
||||
} else {
|
||||
ValueRange::GreaterThan(threshold, false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<u32> for ReverseNoneIsLowerComparator {
|
||||
@@ -246,10 +154,6 @@ impl Comparator<u32> for ReverseNoneIsLowerComparator {
|
||||
fn compare(&self, lhs: &u32, rhs: &u32) -> Ordering {
|
||||
ReverseComparator.compare(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: u32) -> ValueRange<u32> {
|
||||
ValueRange::LessThan(threshold, false)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<u64> for ReverseNoneIsLowerComparator {
|
||||
@@ -257,10 +161,6 @@ impl Comparator<u64> for ReverseNoneIsLowerComparator {
|
||||
fn compare(&self, lhs: &u64, rhs: &u64) -> Ordering {
|
||||
ReverseComparator.compare(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: u64) -> ValueRange<u64> {
|
||||
ValueRange::LessThan(threshold, false)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<f64> for ReverseNoneIsLowerComparator {
|
||||
@@ -268,10 +168,6 @@ impl Comparator<f64> for ReverseNoneIsLowerComparator {
|
||||
fn compare(&self, lhs: &f64, rhs: &f64) -> Ordering {
|
||||
ReverseComparator.compare(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: f64) -> ValueRange<f64> {
|
||||
ValueRange::LessThan(threshold, false)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<f32> for ReverseNoneIsLowerComparator {
|
||||
@@ -279,10 +175,6 @@ impl Comparator<f32> for ReverseNoneIsLowerComparator {
|
||||
fn compare(&self, lhs: &f32, rhs: &f32) -> Ordering {
|
||||
ReverseComparator.compare(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: f32) -> ValueRange<f32> {
|
||||
ValueRange::LessThan(threshold, false)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<i64> for ReverseNoneIsLowerComparator {
|
||||
@@ -290,10 +182,6 @@ impl Comparator<i64> for ReverseNoneIsLowerComparator {
|
||||
fn compare(&self, lhs: &i64, rhs: &i64) -> Ordering {
|
||||
ReverseComparator.compare(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: i64) -> ValueRange<i64> {
|
||||
ValueRange::LessThan(threshold, false)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<String> for ReverseNoneIsLowerComparator {
|
||||
@@ -301,10 +189,6 @@ impl Comparator<String> for ReverseNoneIsLowerComparator {
|
||||
fn compare(&self, lhs: &String, rhs: &String) -> Ordering {
|
||||
ReverseComparator.compare(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: String) -> ValueRange<String> {
|
||||
ValueRange::LessThan(threshold, false)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<OwnedValue> for ReverseNoneIsLowerComparator {
|
||||
@@ -312,10 +196,6 @@ impl Comparator<OwnedValue> for ReverseNoneIsLowerComparator {
|
||||
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
|
||||
compare_owned_value::</* NULLS_FIRST= */ false>(rhs, lhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: OwnedValue) -> ValueRange<OwnedValue> {
|
||||
ValueRange::LessThan(threshold, false)
|
||||
}
|
||||
}
|
||||
|
||||
/// Compare values naturally, but treating `None` as higher than `Some`.
|
||||
@@ -338,15 +218,6 @@ where NaturalComparator: Comparator<T>
|
||||
(Some(lhs), Some(rhs)) => NaturalComparator.compare(lhs, rhs),
|
||||
}
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: Option<T>) -> ValueRange<Option<T>> {
|
||||
if threshold.is_some() {
|
||||
let is_some = threshold.is_some();
|
||||
ValueRange::GreaterThan(threshold, is_some)
|
||||
} else {
|
||||
ValueRange::LessThan(threshold, false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<u32> for NaturalNoneIsHigherComparator {
|
||||
@@ -354,10 +225,6 @@ impl Comparator<u32> for NaturalNoneIsHigherComparator {
|
||||
fn compare(&self, lhs: &u32, rhs: &u32) -> Ordering {
|
||||
NaturalComparator.compare(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: u32) -> ValueRange<u32> {
|
||||
ValueRange::GreaterThan(threshold, true)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<u64> for NaturalNoneIsHigherComparator {
|
||||
@@ -365,10 +232,6 @@ impl Comparator<u64> for NaturalNoneIsHigherComparator {
|
||||
fn compare(&self, lhs: &u64, rhs: &u64) -> Ordering {
|
||||
NaturalComparator.compare(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: u64) -> ValueRange<u64> {
|
||||
ValueRange::GreaterThan(threshold, true)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<f64> for NaturalNoneIsHigherComparator {
|
||||
@@ -376,10 +239,6 @@ impl Comparator<f64> for NaturalNoneIsHigherComparator {
|
||||
fn compare(&self, lhs: &f64, rhs: &f64) -> Ordering {
|
||||
NaturalComparator.compare(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: f64) -> ValueRange<f64> {
|
||||
ValueRange::GreaterThan(threshold, true)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<f32> for NaturalNoneIsHigherComparator {
|
||||
@@ -387,10 +246,6 @@ impl Comparator<f32> for NaturalNoneIsHigherComparator {
|
||||
fn compare(&self, lhs: &f32, rhs: &f32) -> Ordering {
|
||||
NaturalComparator.compare(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: f32) -> ValueRange<f32> {
|
||||
ValueRange::GreaterThan(threshold, true)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<i64> for NaturalNoneIsHigherComparator {
|
||||
@@ -398,10 +253,6 @@ impl Comparator<i64> for NaturalNoneIsHigherComparator {
|
||||
fn compare(&self, lhs: &i64, rhs: &i64) -> Ordering {
|
||||
NaturalComparator.compare(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: i64) -> ValueRange<i64> {
|
||||
ValueRange::GreaterThan(threshold, true)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<String> for NaturalNoneIsHigherComparator {
|
||||
@@ -409,10 +260,6 @@ impl Comparator<String> for NaturalNoneIsHigherComparator {
|
||||
fn compare(&self, lhs: &String, rhs: &String) -> Ordering {
|
||||
NaturalComparator.compare(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: String) -> ValueRange<String> {
|
||||
ValueRange::GreaterThan(threshold, true)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<OwnedValue> for NaturalNoneIsHigherComparator {
|
||||
@@ -420,10 +267,6 @@ impl Comparator<OwnedValue> for NaturalNoneIsHigherComparator {
|
||||
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
|
||||
compare_owned_value::</* NULLS_FIRST= */ false>(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: OwnedValue) -> ValueRange<OwnedValue> {
|
||||
ValueRange::GreaterThan(threshold, true)
|
||||
}
|
||||
}
|
||||
|
||||
/// An enum representing the different sort orders.
|
||||
@@ -465,19 +308,6 @@ where
|
||||
ComparatorEnum::NaturalNoneHigher => NaturalNoneIsHigherComparator.compare(lhs, rhs),
|
||||
}
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: T) -> ValueRange<T> {
|
||||
match self {
|
||||
ComparatorEnum::Natural => NaturalComparator.threshold_to_valuerange(threshold),
|
||||
ComparatorEnum::Reverse => ReverseComparator.threshold_to_valuerange(threshold),
|
||||
ComparatorEnum::ReverseNoneLower => {
|
||||
ReverseNoneIsLowerComparator.threshold_to_valuerange(threshold)
|
||||
}
|
||||
ComparatorEnum::NaturalNoneHigher => {
|
||||
NaturalNoneIsHigherComparator.threshold_to_valuerange(threshold)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Head, Tail, LeftComparator, RightComparator> Comparator<(Head, Tail)>
|
||||
@@ -492,10 +322,6 @@ where
|
||||
.compare(&lhs.0, &rhs.0)
|
||||
.then_with(|| self.1.compare(&lhs.1, &rhs.1))
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: (Head, Tail)) -> ValueRange<(Head, Tail)> {
|
||||
ValueRange::GreaterThan(threshold, false)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Type1, Type2, Type3, Comparator1, Comparator2, Comparator3> Comparator<(Type1, (Type2, Type3))>
|
||||
@@ -512,13 +338,6 @@ where
|
||||
.then_with(|| self.1.compare(&lhs.1 .0, &rhs.1 .0))
|
||||
.then_with(|| self.2.compare(&lhs.1 .1, &rhs.1 .1))
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(
|
||||
&self,
|
||||
threshold: (Type1, (Type2, Type3)),
|
||||
) -> ValueRange<(Type1, (Type2, Type3))> {
|
||||
ValueRange::GreaterThan(threshold, false)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Type1, Type2, Type3, Comparator1, Comparator2, Comparator3> Comparator<(Type1, Type2, Type3)>
|
||||
@@ -535,13 +354,6 @@ where
|
||||
.then_with(|| self.1.compare(&lhs.1, &rhs.1))
|
||||
.then_with(|| self.2.compare(&lhs.2, &rhs.2))
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(
|
||||
&self,
|
||||
threshold: (Type1, Type2, Type3),
|
||||
) -> ValueRange<(Type1, Type2, Type3)> {
|
||||
ValueRange::GreaterThan(threshold, false)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Type1, Type2, Type3, Type4, Comparator1, Comparator2, Comparator3, Comparator4>
|
||||
@@ -565,13 +377,6 @@ where
|
||||
.then_with(|| self.2.compare(&lhs.1 .1 .0, &rhs.1 .1 .0))
|
||||
.then_with(|| self.3.compare(&lhs.1 .1 .1, &rhs.1 .1 .1))
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(
|
||||
&self,
|
||||
threshold: (Type1, (Type2, (Type3, Type4))),
|
||||
) -> ValueRange<(Type1, (Type2, (Type3, Type4)))> {
|
||||
ValueRange::GreaterThan(threshold, false)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Type1, Type2, Type3, Type4, Comparator1, Comparator2, Comparator3, Comparator4>
|
||||
@@ -595,13 +400,6 @@ where
|
||||
.then_with(|| self.2.compare(&lhs.2, &rhs.2))
|
||||
.then_with(|| self.3.compare(&lhs.3, &rhs.3))
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(
|
||||
&self,
|
||||
threshold: (Type1, Type2, Type3, Type4),
|
||||
) -> ValueRange<(Type1, Type2, Type3, Type4)> {
|
||||
ValueRange::GreaterThan(threshold, false)
|
||||
}
|
||||
}
|
||||
|
||||
impl<TSortKeyComputer> SortKeyComputer for (TSortKeyComputer, ComparatorEnum)
|
||||
@@ -691,32 +489,16 @@ impl<TSegmentSortKeyComputer, TSegmentSortKey, TComparator> SegmentSortKeyComput
|
||||
where
|
||||
TSegmentSortKeyComputer: SegmentSortKeyComputer<SegmentSortKey = TSegmentSortKey>,
|
||||
TSegmentSortKey: Clone + 'static + Sync + Send,
|
||||
TComparator: Comparator<TSegmentSortKey> + Clone + 'static + Sync + Send,
|
||||
TComparator: Comparator<TSegmentSortKey> + 'static + Sync + Send,
|
||||
{
|
||||
type SortKey = TSegmentSortKeyComputer::SortKey;
|
||||
type SegmentSortKey = TSegmentSortKey;
|
||||
type SegmentComparator = TComparator;
|
||||
type Buffer = TSegmentSortKeyComputer::Buffer;
|
||||
|
||||
fn segment_comparator(&self) -> Self::SegmentComparator {
|
||||
self.comparator.clone()
|
||||
}
|
||||
|
||||
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
|
||||
self.segment_sort_key_computer.segment_sort_key(doc, score)
|
||||
}
|
||||
|
||||
fn segment_sort_keys(
|
||||
&mut self,
|
||||
input_docs: &[DocId],
|
||||
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
|
||||
buffer: &mut Self::Buffer,
|
||||
filter: ValueRange<Self::SegmentSortKey>,
|
||||
) {
|
||||
self.segment_sort_key_computer
|
||||
.segment_sort_keys(input_docs, output, buffer, filter)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn compare_segment_sort_key(
|
||||
&self,
|
||||
@@ -737,13 +519,36 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::schema::OwnedValue;
|
||||
|
||||
#[test]
|
||||
fn test_natural_none_is_higher() {
|
||||
let comp = NaturalNoneIsHigherComparator;
|
||||
let null = None;
|
||||
let v1 = Some(1_u64);
|
||||
let v2 = Some(2_u64);
|
||||
|
||||
// NaturalNoneIsGreaterComparator logic:
|
||||
// 1. Delegates to NaturalComparator for non-nulls.
|
||||
// NaturalComparator compare(2, 1) -> 2.cmp(1) -> Greater.
|
||||
assert_eq!(comp.compare(&v2, &v1), Ordering::Greater);
|
||||
|
||||
// 2. Treats None (Null) as Greater than any value.
|
||||
// compare(None, Some(2)) should be Greater.
|
||||
assert_eq!(comp.compare(&null, &v2), Ordering::Greater);
|
||||
|
||||
// compare(Some(1), None) should be Less.
|
||||
assert_eq!(comp.compare(&v1, &null), Ordering::Less);
|
||||
|
||||
// compare(None, None) should be Equal.
|
||||
assert_eq!(comp.compare(&null, &null), Ordering::Equal);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mixed_ownedvalue_compare() {
|
||||
let u = OwnedValue::U64(10);
|
||||
let i = OwnedValue::I64(10);
|
||||
let f = OwnedValue::F64(10.0);
|
||||
|
||||
let nc = NaturalComparator::default();
|
||||
let nc = NaturalComparator;
|
||||
assert_eq!(nc.compare(&u, &i), Ordering::Equal);
|
||||
assert_eq!(nc.compare(&u, &f), Ordering::Equal);
|
||||
assert_eq!(nc.compare(&i, &f), Ordering::Equal);
|
||||
@@ -759,27 +564,4 @@ mod tests {
|
||||
// Str < F64
|
||||
assert_eq!(nc.compare(&s, &f), Ordering::Less);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_natural_none_is_higher() {
|
||||
let comp = NaturalNoneIsHigherComparator;
|
||||
let null = OwnedValue::Null;
|
||||
let v1 = OwnedValue::U64(1);
|
||||
let v2 = OwnedValue::U64(2);
|
||||
|
||||
// NaturalNoneIsGreaterComparator logic:
|
||||
// 1. Delegates to NaturalComparator for non-nulls.
|
||||
// NaturalComparator compare(2, 1) -> 2.cmp(1) -> Greater.
|
||||
assert_eq!(comp.compare(&v2, &v1), Ordering::Greater);
|
||||
|
||||
// 2. Treats None (Null) as Greater than any value.
|
||||
// compare(Null, 2) should be Greater.
|
||||
assert_eq!(comp.compare(&null, &v2), Ordering::Greater);
|
||||
|
||||
// compare(1, Null) should be Less.
|
||||
assert_eq!(comp.compare(&v1, &null), Ordering::Less);
|
||||
|
||||
// compare(Null, Null) should be Equal.
|
||||
assert_eq!(comp.compare(&null, &null), Ordering::Equal);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
use columnar::{ColumnType, MonotonicallyMappableToU64, ValueRange};
|
||||
use columnar::{ColumnType, MonotonicallyMappableToU64};
|
||||
|
||||
use crate::collector::sort_key::sort_by_score::SortBySimilarityScoreSegmentComputer;
|
||||
use crate::collector::sort_key::{
|
||||
NaturalComparator, SortBySimilarityScore, SortByStaticFastValue, SortByString,
|
||||
};
|
||||
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
|
||||
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
|
||||
use crate::fastfield::FastFieldNotAvailableError;
|
||||
use crate::schema::OwnedValue;
|
||||
use crate::{DateTime, DocId, Score};
|
||||
@@ -37,23 +36,12 @@ impl SortByErasedType {
|
||||
|
||||
trait ErasedSegmentSortKeyComputer: Send + Sync {
|
||||
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64>;
|
||||
fn segment_sort_keys(
|
||||
&mut self,
|
||||
input_docs: &[DocId],
|
||||
output: &mut Vec<ComparableDoc<Option<u64>, DocId>>,
|
||||
filter: ValueRange<Option<u64>>,
|
||||
);
|
||||
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue;
|
||||
}
|
||||
|
||||
struct ErasedSegmentSortKeyComputerWrapper<C, F>
|
||||
where
|
||||
C: SegmentSortKeyComputer<SegmentSortKey = Option<u64>> + Send + Sync,
|
||||
F: Fn(C::SortKey) -> OwnedValue + Send + Sync + 'static,
|
||||
{
|
||||
struct ErasedSegmentSortKeyComputerWrapper<C, F> {
|
||||
inner: C,
|
||||
converter: F,
|
||||
buffer: C::Buffer,
|
||||
}
|
||||
|
||||
impl<C, F> ErasedSegmentSortKeyComputer for ErasedSegmentSortKeyComputerWrapper<C, F>
|
||||
@@ -65,16 +53,6 @@ where
|
||||
self.inner.segment_sort_key(doc, score)
|
||||
}
|
||||
|
||||
fn segment_sort_keys(
|
||||
&mut self,
|
||||
input_docs: &[DocId],
|
||||
output: &mut Vec<ComparableDoc<Option<u64>, DocId>>,
|
||||
filter: ValueRange<Option<u64>>,
|
||||
) {
|
||||
self.inner
|
||||
.segment_sort_keys(input_docs, output, &mut self.buffer, filter)
|
||||
}
|
||||
|
||||
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue {
|
||||
let val = self.inner.convert_segment_sort_key(sort_key);
|
||||
(self.converter)(val)
|
||||
@@ -82,7 +60,7 @@ where
|
||||
}
|
||||
|
||||
struct ScoreSegmentSortKeyComputer {
|
||||
segment_computer: SortBySimilarityScoreSegmentComputer,
|
||||
segment_computer: SortBySimilarityScore,
|
||||
}
|
||||
|
||||
impl ErasedSegmentSortKeyComputer for ScoreSegmentSortKeyComputer {
|
||||
@@ -91,15 +69,6 @@ impl ErasedSegmentSortKeyComputer for ScoreSegmentSortKeyComputer {
|
||||
Some(score_value.to_u64())
|
||||
}
|
||||
|
||||
fn segment_sort_keys(
|
||||
&mut self,
|
||||
_input_docs: &[DocId],
|
||||
_output: &mut Vec<ComparableDoc<Option<u64>, DocId>>,
|
||||
_filter: ValueRange<Option<u64>>,
|
||||
) {
|
||||
unimplemented!("Batch computation not supported for score sorting")
|
||||
}
|
||||
|
||||
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue {
|
||||
let score_value: u64 = sort_key.expect("This implementation always produces a score.");
|
||||
OwnedValue::F64(f64::from_u64(score_value))
|
||||
@@ -143,7 +112,6 @@ impl SortKeyComputer for SortByErasedType {
|
||||
converter: |val: Option<String>| {
|
||||
val.map(OwnedValue::Str).unwrap_or(OwnedValue::Null)
|
||||
},
|
||||
buffer: Default::default(),
|
||||
})
|
||||
}
|
||||
ColumnType::U64 => {
|
||||
@@ -154,7 +122,6 @@ impl SortKeyComputer for SortByErasedType {
|
||||
converter: |val: Option<u64>| {
|
||||
val.map(OwnedValue::U64).unwrap_or(OwnedValue::Null)
|
||||
},
|
||||
buffer: Default::default(),
|
||||
})
|
||||
}
|
||||
ColumnType::I64 => {
|
||||
@@ -165,7 +132,6 @@ impl SortKeyComputer for SortByErasedType {
|
||||
converter: |val: Option<i64>| {
|
||||
val.map(OwnedValue::I64).unwrap_or(OwnedValue::Null)
|
||||
},
|
||||
buffer: Default::default(),
|
||||
})
|
||||
}
|
||||
ColumnType::F64 => {
|
||||
@@ -176,7 +142,6 @@ impl SortKeyComputer for SortByErasedType {
|
||||
converter: |val: Option<f64>| {
|
||||
val.map(OwnedValue::F64).unwrap_or(OwnedValue::Null)
|
||||
},
|
||||
buffer: Default::default(),
|
||||
})
|
||||
}
|
||||
ColumnType::Bool => {
|
||||
@@ -187,7 +152,6 @@ impl SortKeyComputer for SortByErasedType {
|
||||
converter: |val: Option<bool>| {
|
||||
val.map(OwnedValue::Bool).unwrap_or(OwnedValue::Null)
|
||||
},
|
||||
buffer: Default::default(),
|
||||
})
|
||||
}
|
||||
ColumnType::DateTime => {
|
||||
@@ -198,7 +162,6 @@ impl SortKeyComputer for SortByErasedType {
|
||||
converter: |val: Option<DateTime>| {
|
||||
val.map(OwnedValue::Date).unwrap_or(OwnedValue::Null)
|
||||
},
|
||||
buffer: Default::default(),
|
||||
})
|
||||
}
|
||||
column_type => {
|
||||
@@ -211,8 +174,7 @@ impl SortKeyComputer for SortByErasedType {
|
||||
}
|
||||
}
|
||||
Self::Score => Box::new(ScoreSegmentSortKeyComputer {
|
||||
segment_computer: SortBySimilarityScore
|
||||
.segment_sort_key_computer(segment_reader)?,
|
||||
segment_computer: SortBySimilarityScore,
|
||||
}),
|
||||
};
|
||||
Ok(ErasedColumnSegmentSortKeyComputer { inner })
|
||||
@@ -227,23 +189,12 @@ impl SegmentSortKeyComputer for ErasedColumnSegmentSortKeyComputer {
|
||||
type SortKey = OwnedValue;
|
||||
type SegmentSortKey = Option<u64>;
|
||||
type SegmentComparator = NaturalComparator;
|
||||
type Buffer = ();
|
||||
|
||||
#[inline(always)]
|
||||
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64> {
|
||||
self.inner.segment_sort_key(doc, score)
|
||||
}
|
||||
|
||||
fn segment_sort_keys(
|
||||
&mut self,
|
||||
input_docs: &[DocId],
|
||||
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
|
||||
_buffer: &mut Self::Buffer,
|
||||
filter: ValueRange<Self::SegmentSortKey>,
|
||||
) {
|
||||
self.inner.segment_sort_keys(input_docs, output, filter)
|
||||
}
|
||||
|
||||
fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> OwnedValue {
|
||||
self.inner.convert_segment_sort_key(segment_sort_key)
|
||||
}
|
||||
@@ -382,7 +333,7 @@ mod tests {
|
||||
.into_iter()
|
||||
.map(|(key, _)| match key {
|
||||
OwnedValue::F64(val) => val,
|
||||
_ => panic!("Wrong type {:?}", key),
|
||||
_ => panic!("Wrong type {key:?}"),
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -400,7 +351,7 @@ mod tests {
|
||||
.into_iter()
|
||||
.map(|(key, _)| match key {
|
||||
OwnedValue::F64(val) => val,
|
||||
_ => panic!("Wrong type {:?}", key),
|
||||
_ => panic!("Wrong type {key:?}"),
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
use columnar::ValueRange;
|
||||
|
||||
use crate::collector::sort_key::NaturalComparator;
|
||||
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer, TopNComputer};
|
||||
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer, TopNComputer};
|
||||
use crate::{DocAddress, DocId, Score};
|
||||
|
||||
/// Sort by similarity score.
|
||||
@@ -11,7 +9,7 @@ pub struct SortBySimilarityScore;
|
||||
impl SortKeyComputer for SortBySimilarityScore {
|
||||
type SortKey = Score;
|
||||
|
||||
type Child = SortBySimilarityScoreSegmentComputer;
|
||||
type Child = SortBySimilarityScore;
|
||||
|
||||
type Comparator = NaturalComparator;
|
||||
|
||||
@@ -23,7 +21,7 @@ impl SortKeyComputer for SortBySimilarityScore {
|
||||
&self,
|
||||
_segment_reader: &crate::SegmentReader,
|
||||
) -> crate::Result<Self::Child> {
|
||||
Ok(SortBySimilarityScoreSegmentComputer)
|
||||
Ok(SortBySimilarityScore)
|
||||
}
|
||||
|
||||
// Sorting by score is special in that it allows for the Block-Wand optimization.
|
||||
@@ -63,29 +61,16 @@ impl SortKeyComputer for SortBySimilarityScore {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SortBySimilarityScoreSegmentComputer;
|
||||
|
||||
impl SegmentSortKeyComputer for SortBySimilarityScoreSegmentComputer {
|
||||
impl SegmentSortKeyComputer for SortBySimilarityScore {
|
||||
type SortKey = Score;
|
||||
type SegmentSortKey = Score;
|
||||
type SegmentComparator = NaturalComparator;
|
||||
type Buffer = ();
|
||||
|
||||
#[inline(always)]
|
||||
fn segment_sort_key(&mut self, _doc: DocId, score: Score) -> Score {
|
||||
score
|
||||
}
|
||||
|
||||
fn segment_sort_keys(
|
||||
&mut self,
|
||||
_input_docs: &[DocId],
|
||||
_output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
|
||||
_buffer: &mut Self::Buffer,
|
||||
_filter: ValueRange<Self::SegmentSortKey>,
|
||||
) {
|
||||
unimplemented!("Batch computation not supported for score sorting")
|
||||
}
|
||||
|
||||
fn convert_segment_sort_key(&self, score: Score) -> Score {
|
||||
score
|
||||
}
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use columnar::{Column, ValueRange};
|
||||
use columnar::Column;
|
||||
|
||||
use crate::collector::sort_key::sort_key_computer::convert_optional_u64_range_to_u64_range;
|
||||
use crate::collector::sort_key::NaturalComparator;
|
||||
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
|
||||
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
|
||||
use crate::fastfield::{FastFieldNotAvailableError, FastValue};
|
||||
use crate::{DocId, Score, SegmentReader};
|
||||
|
||||
@@ -85,110 +84,13 @@ impl<T: FastValue> SegmentSortKeyComputer for SortByFastValueSegmentSortKeyCompu
|
||||
type SortKey = Option<T>;
|
||||
type SegmentSortKey = Option<u64>;
|
||||
type SegmentComparator = NaturalComparator;
|
||||
type Buffer = ();
|
||||
|
||||
#[inline(always)]
|
||||
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Self::SegmentSortKey {
|
||||
self.sort_column.first(doc)
|
||||
}
|
||||
|
||||
fn segment_sort_keys(
|
||||
&mut self,
|
||||
input_docs: &[DocId],
|
||||
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
|
||||
_buffer: &mut Self::Buffer,
|
||||
filter: ValueRange<Self::SegmentSortKey>,
|
||||
) {
|
||||
let u64_filter = convert_optional_u64_range_to_u64_range(filter);
|
||||
self.sort_column
|
||||
.first_vals_in_value_range(input_docs, output, u64_filter);
|
||||
}
|
||||
|
||||
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
|
||||
sort_key.map(T::from_u64)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::schema::{Schema, FAST};
|
||||
use crate::Index;
|
||||
|
||||
#[test]
|
||||
fn test_sort_by_fast_value_batch() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let field_col = schema_builder.add_u64_field("field", FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut index_writer = index.writer_for_tests().unwrap();
|
||||
|
||||
index_writer
|
||||
.add_document(crate::doc!(field_col => 10u64))
|
||||
.unwrap();
|
||||
index_writer
|
||||
.add_document(crate::doc!(field_col => 20u64))
|
||||
.unwrap();
|
||||
index_writer.add_document(crate::doc!()).unwrap();
|
||||
index_writer.commit().unwrap();
|
||||
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
let segment_reader = searcher.segment_reader(0);
|
||||
|
||||
let sorter = SortByStaticFastValue::<u64>::for_field("field");
|
||||
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
|
||||
|
||||
let mut docs = vec![0, 1, 2];
|
||||
let mut output = Vec::new();
|
||||
let mut buffer = ();
|
||||
computer.segment_sort_keys(&mut docs, &mut output, &mut buffer, ValueRange::All);
|
||||
|
||||
assert_eq!(
|
||||
output.iter().map(|c| c.sort_key).collect::<Vec<_>>(),
|
||||
&[Some(10), Some(20), None]
|
||||
);
|
||||
assert_eq!(output.iter().map(|c| c.doc).collect::<Vec<_>>(), &[0, 1, 2]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sort_by_fast_value_batch_with_filter() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let field_col = schema_builder.add_u64_field("field", FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut index_writer = index.writer_for_tests().unwrap();
|
||||
|
||||
index_writer
|
||||
.add_document(crate::doc!(field_col => 10u64))
|
||||
.unwrap();
|
||||
index_writer
|
||||
.add_document(crate::doc!(field_col => 20u64))
|
||||
.unwrap();
|
||||
index_writer.add_document(crate::doc!()).unwrap();
|
||||
index_writer.commit().unwrap();
|
||||
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
let segment_reader = searcher.segment_reader(0);
|
||||
|
||||
let sorter = SortByStaticFastValue::<u64>::for_field("field");
|
||||
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
|
||||
|
||||
let mut docs = vec![0, 1, 2];
|
||||
let mut output = Vec::new();
|
||||
let mut buffer = ();
|
||||
computer.segment_sort_keys(
|
||||
&mut docs,
|
||||
&mut output,
|
||||
&mut buffer,
|
||||
ValueRange::GreaterThan(Some(15u64), false /* inclusive */),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
output.iter().map(|c| c.sort_key).collect::<Vec<_>>(),
|
||||
&[Some(20)]
|
||||
);
|
||||
assert_eq!(output.iter().map(|c| c.doc).collect::<Vec<_>>(), &[1]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
use columnar::{StrColumn, ValueRange};
|
||||
use columnar::StrColumn;
|
||||
|
||||
use crate::collector::sort_key::sort_key_computer::{
|
||||
convert_optional_u64_range_to_u64_range, range_contains_none,
|
||||
};
|
||||
use crate::collector::sort_key::NaturalComparator;
|
||||
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
|
||||
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
|
||||
use crate::termdict::TermOrdinal;
|
||||
use crate::{DocId, Score};
|
||||
|
||||
@@ -53,7 +50,6 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
|
||||
type SortKey = Option<String>;
|
||||
type SegmentSortKey = Option<TermOrdinal>;
|
||||
type SegmentComparator = NaturalComparator;
|
||||
type Buffer = ();
|
||||
|
||||
#[inline(always)]
|
||||
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Option<TermOrdinal> {
|
||||
@@ -61,28 +57,6 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
|
||||
str_column.ords().first(doc)
|
||||
}
|
||||
|
||||
fn segment_sort_keys(
|
||||
&mut self,
|
||||
input_docs: &[DocId],
|
||||
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
|
||||
_buffer: &mut Self::Buffer,
|
||||
filter: ValueRange<Self::SegmentSortKey>,
|
||||
) {
|
||||
if let Some(str_column) = &self.str_column_opt {
|
||||
let u64_filter = convert_optional_u64_range_to_u64_range(filter);
|
||||
str_column
|
||||
.ords()
|
||||
.first_vals_in_value_range(input_docs, output, u64_filter);
|
||||
} else if range_contains_none(&filter) {
|
||||
for &doc in input_docs {
|
||||
output.push(ComparableDoc {
|
||||
doc,
|
||||
sort_key: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_segment_sort_key(&self, term_ord_opt: Option<TermOrdinal>) -> Option<String> {
|
||||
// TODO: Individual lookups to the dictionary like this are very likely to repeatedly
|
||||
// decompress the same blocks. See https://github.com/quickwit-oss/tantivy/issues/2776
|
||||
@@ -96,90 +70,3 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
|
||||
String::try_from(bytes).ok()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::schema::{Schema, FAST, TEXT};
|
||||
use crate::Index;
|
||||
|
||||
#[test]
|
||||
fn test_sort_by_string_batch() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let field_col = schema_builder.add_text_field("field", FAST | TEXT);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut index_writer = index.writer_for_tests().unwrap();
|
||||
|
||||
index_writer
|
||||
.add_document(crate::doc!(field_col => "a"))
|
||||
.unwrap();
|
||||
index_writer
|
||||
.add_document(crate::doc!(field_col => "c"))
|
||||
.unwrap();
|
||||
index_writer.add_document(crate::doc!()).unwrap();
|
||||
index_writer.commit().unwrap();
|
||||
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
let segment_reader = searcher.segment_reader(0);
|
||||
|
||||
let sorter = SortByString::for_field("field");
|
||||
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
|
||||
|
||||
let mut docs = vec![0, 1, 2];
|
||||
let mut output = Vec::new();
|
||||
let mut buffer = ();
|
||||
computer.segment_sort_keys(&mut docs, &mut output, &mut buffer, ValueRange::All);
|
||||
|
||||
assert_eq!(
|
||||
output.iter().map(|c| c.sort_key).collect::<Vec<_>>(),
|
||||
&[Some(0), Some(1), None]
|
||||
);
|
||||
assert_eq!(output.iter().map(|c| c.doc).collect::<Vec<_>>(), &[0, 1, 2]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sort_by_string_batch_with_filter() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let field_col = schema_builder.add_text_field("field", FAST | TEXT);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut index_writer = index.writer_for_tests().unwrap();
|
||||
|
||||
index_writer
|
||||
.add_document(crate::doc!(field_col => "a"))
|
||||
.unwrap();
|
||||
index_writer
|
||||
.add_document(crate::doc!(field_col => "c"))
|
||||
.unwrap();
|
||||
index_writer.add_document(crate::doc!()).unwrap();
|
||||
index_writer.commit().unwrap();
|
||||
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
let segment_reader = searcher.segment_reader(0);
|
||||
|
||||
let sorter = SortByString::for_field("field");
|
||||
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
|
||||
|
||||
let mut docs = vec![0, 1, 2];
|
||||
let mut output = Vec::new();
|
||||
// Filter: > "b". "a" is 0, "c" is 1.
|
||||
// We want > "a" (ord 0). So we filter > ord 0.
|
||||
// 0 is "a", 1 is "c".
|
||||
let mut buffer = ();
|
||||
computer.segment_sort_keys(
|
||||
&mut docs,
|
||||
&mut output,
|
||||
&mut buffer,
|
||||
ValueRange::GreaterThan(Some(0), false /* inclusive */),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
output.iter().map(|c| c.sort_key).collect::<Vec<_>>(),
|
||||
&[Some(1)]
|
||||
);
|
||||
assert_eq!(output.iter().map(|c| c.doc).collect::<Vec<_>>(), &[1]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
use std::cmp::Ordering;
|
||||
|
||||
use columnar::ValueRange;
|
||||
|
||||
use crate::collector::sort_key::{Comparator, NaturalComparator};
|
||||
use crate::collector::sort_key_top_collector::TopBySortKeySegmentCollector;
|
||||
use crate::collector::{
|
||||
default_collect_segment_impl, ComparableDoc, SegmentCollector as _, TopNComputer,
|
||||
};
|
||||
use crate::collector::{default_collect_segment_impl, SegmentCollector as _, TopNComputer};
|
||||
use crate::schema::Schema;
|
||||
use crate::{DocAddress, DocId, Result, Score, SegmentReader};
|
||||
|
||||
@@ -25,10 +21,7 @@ pub trait SegmentSortKeyComputer: 'static {
|
||||
type SegmentSortKey: 'static + Clone + Send + Sync + Clone;
|
||||
|
||||
/// Comparator type.
|
||||
type SegmentComparator: Comparator<Self::SegmentSortKey> + Clone + 'static;
|
||||
|
||||
/// Buffer type used for scratch space.
|
||||
type Buffer: Default + Send + Sync + 'static;
|
||||
type SegmentComparator: Comparator<Self::SegmentSortKey> + 'static;
|
||||
|
||||
/// Returns the segment sort key comparator.
|
||||
fn segment_comparator(&self) -> Self::SegmentComparator {
|
||||
@@ -38,18 +31,6 @@ pub trait SegmentSortKeyComputer: 'static {
|
||||
/// Computes the sort key for the given document and score.
|
||||
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey;
|
||||
|
||||
/// Computes the sort keys for a batch of documents.
|
||||
///
|
||||
/// The computed sort keys and document IDs are pushed into the `output` vector.
|
||||
/// The `buffer` is used for scratch space.
|
||||
fn segment_sort_keys(
|
||||
&mut self,
|
||||
input_docs: &[DocId],
|
||||
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
|
||||
buffer: &mut Self::Buffer,
|
||||
filter: ValueRange<Self::SegmentSortKey>,
|
||||
);
|
||||
|
||||
/// Computes the sort key and pushes the document in a TopN Computer.
|
||||
///
|
||||
/// When using a tuple as the sorting key, the sort key is evaluated in a lazy manner.
|
||||
@@ -58,32 +39,12 @@ pub trait SegmentSortKeyComputer: 'static {
|
||||
&mut self,
|
||||
doc: DocId,
|
||||
score: Score,
|
||||
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C, Self::Buffer>,
|
||||
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
|
||||
) {
|
||||
let sort_key = self.segment_sort_key(doc, score);
|
||||
top_n_computer.push(sort_key, doc);
|
||||
}
|
||||
|
||||
fn compute_sort_keys_and_collect<C: Comparator<Self::SegmentSortKey>>(
|
||||
&mut self,
|
||||
docs: &[DocId],
|
||||
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C, Self::Buffer>,
|
||||
) {
|
||||
// The capacity of a TopNComputer is larger than 2*n + COLLECT_BLOCK_BUFFER_LEN, so we
|
||||
// should always be able to `reserve` space for the entire block.
|
||||
top_n_computer.reserve(docs.len());
|
||||
|
||||
let comparator = self.segment_comparator();
|
||||
let value_range = if let Some(threshold) = &top_n_computer.threshold {
|
||||
comparator.threshold_to_valuerange(threshold.clone())
|
||||
} else {
|
||||
ValueRange::All
|
||||
};
|
||||
|
||||
let (buffer, scratch) = top_n_computer.buffer_and_scratch();
|
||||
self.segment_sort_keys(docs, buffer, scratch, value_range);
|
||||
}
|
||||
|
||||
/// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on
|
||||
/// its ordering.
|
||||
///
|
||||
@@ -97,6 +58,26 @@ pub trait SegmentSortKeyComputer: 'static {
|
||||
self.segment_comparator().compare(left, right)
|
||||
}
|
||||
|
||||
/// Implementing this method makes it possible to avoid computing
|
||||
/// a sort_key entirely if we can assess that it won't pass a threshold
|
||||
/// with a partial computation.
|
||||
///
|
||||
/// This is currently used for lexicographic sorting.
|
||||
fn accept_sort_key_lazy(
|
||||
&mut self,
|
||||
doc_id: DocId,
|
||||
score: Score,
|
||||
threshold: &Self::SegmentSortKey,
|
||||
) -> Option<(Ordering, Self::SegmentSortKey)> {
|
||||
let sort_key = self.segment_sort_key(doc_id, score);
|
||||
let cmp = self.compare_segment_sort_key(&sort_key, threshold);
|
||||
if cmp == Ordering::Less {
|
||||
None
|
||||
} else {
|
||||
Some((cmp, sort_key))
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a segment level sort key into the global sort key.
|
||||
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey;
|
||||
}
|
||||
@@ -164,8 +145,7 @@ where
|
||||
TailSortKeyComputer: SortKeyComputer,
|
||||
{
|
||||
type SortKey = (HeadSortKeyComputer::SortKey, TailSortKeyComputer::SortKey);
|
||||
type Child =
|
||||
ChainSegmentSortKeyComputer<HeadSortKeyComputer::Child, TailSortKeyComputer::Child>;
|
||||
type Child = (HeadSortKeyComputer::Child, TailSortKeyComputer::Child);
|
||||
|
||||
type Comparator = (
|
||||
HeadSortKeyComputer::Comparator,
|
||||
@@ -177,10 +157,10 @@ where
|
||||
}
|
||||
|
||||
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
|
||||
Ok(ChainSegmentSortKeyComputer {
|
||||
head: self.0.segment_sort_key_computer(segment_reader)?,
|
||||
tail: self.1.segment_sort_key_computer(segment_reader)?,
|
||||
})
|
||||
Ok((
|
||||
self.0.segment_sort_key_computer(segment_reader)?,
|
||||
self.1.segment_sort_key_computer(segment_reader)?,
|
||||
))
|
||||
}
|
||||
|
||||
/// Checks whether the schema is compatible with the sort key computer.
|
||||
@@ -198,91 +178,25 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ChainSegmentSortKeyComputer<Head, Tail>
|
||||
impl<HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer> SegmentSortKeyComputer
|
||||
for (HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer)
|
||||
where
|
||||
Head: SegmentSortKeyComputer,
|
||||
Tail: SegmentSortKeyComputer,
|
||||
HeadSegmentSortKeyComputer: SegmentSortKeyComputer,
|
||||
TailSegmentSortKeyComputer: SegmentSortKeyComputer,
|
||||
{
|
||||
head: Head,
|
||||
tail: Tail,
|
||||
}
|
||||
type SortKey = (
|
||||
HeadSegmentSortKeyComputer::SortKey,
|
||||
TailSegmentSortKeyComputer::SortKey,
|
||||
);
|
||||
type SegmentSortKey = (
|
||||
HeadSegmentSortKeyComputer::SegmentSortKey,
|
||||
TailSegmentSortKeyComputer::SegmentSortKey,
|
||||
);
|
||||
|
||||
pub struct ChainBuffer<HeadBuffer, TailBuffer, HeadKey, TailKey> {
|
||||
pub head: HeadBuffer,
|
||||
pub tail: TailBuffer,
|
||||
pub head_output: Vec<ComparableDoc<HeadKey, DocId>>,
|
||||
pub tail_output: Vec<ComparableDoc<TailKey, DocId>>,
|
||||
pub tail_input_docs: Vec<DocId>,
|
||||
}
|
||||
|
||||
impl<HeadBuffer: Default, TailBuffer: Default, HeadKey, TailKey> Default
|
||||
for ChainBuffer<HeadBuffer, TailBuffer, HeadKey, TailKey>
|
||||
{
|
||||
fn default() -> Self {
|
||||
ChainBuffer {
|
||||
head: HeadBuffer::default(),
|
||||
tail: TailBuffer::default(),
|
||||
head_output: Vec::new(),
|
||||
tail_output: Vec::new(),
|
||||
tail_input_docs: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Head, Tail> ChainSegmentSortKeyComputer<Head, Tail>
|
||||
where
|
||||
Head: SegmentSortKeyComputer,
|
||||
Tail: SegmentSortKeyComputer,
|
||||
{
|
||||
fn accept_sort_key_lazy(
|
||||
&mut self,
|
||||
doc_id: DocId,
|
||||
score: Score,
|
||||
threshold: &<Self as SegmentSortKeyComputer>::SegmentSortKey,
|
||||
) -> Option<(Ordering, <Self as SegmentSortKeyComputer>::SegmentSortKey)> {
|
||||
let (head_threshold, tail_threshold) = threshold;
|
||||
let head_sort_key = self.head.segment_sort_key(doc_id, score);
|
||||
let head_cmp = self
|
||||
.head
|
||||
.compare_segment_sort_key(&head_sort_key, head_threshold);
|
||||
if head_cmp == Ordering::Less {
|
||||
None
|
||||
} else if head_cmp == Ordering::Equal {
|
||||
let tail_sort_key = self.tail.segment_sort_key(doc_id, score);
|
||||
let tail_cmp = self
|
||||
.tail
|
||||
.compare_segment_sort_key(&tail_sort_key, tail_threshold);
|
||||
if tail_cmp == Ordering::Less {
|
||||
None
|
||||
} else {
|
||||
Some((tail_cmp, (head_sort_key, tail_sort_key)))
|
||||
}
|
||||
} else {
|
||||
let tail_sort_key = self.tail.segment_sort_key(doc_id, score);
|
||||
Some((head_cmp, (head_sort_key, tail_sort_key)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Head, Tail> SegmentSortKeyComputer for ChainSegmentSortKeyComputer<Head, Tail>
|
||||
where
|
||||
Head: SegmentSortKeyComputer,
|
||||
Tail: SegmentSortKeyComputer,
|
||||
{
|
||||
type SortKey = (Head::SortKey, Tail::SortKey);
|
||||
type SegmentSortKey = (Head::SegmentSortKey, Tail::SegmentSortKey);
|
||||
|
||||
type SegmentComparator = (Head::SegmentComparator, Tail::SegmentComparator);
|
||||
|
||||
type Buffer =
|
||||
ChainBuffer<Head::Buffer, Tail::Buffer, Head::SegmentSortKey, Tail::SegmentSortKey>;
|
||||
|
||||
fn segment_comparator(&self) -> Self::SegmentComparator {
|
||||
(
|
||||
self.head.segment_comparator(),
|
||||
self.tail.segment_comparator(),
|
||||
)
|
||||
}
|
||||
type SegmentComparator = (
|
||||
HeadSegmentSortKeyComputer::SegmentComparator,
|
||||
TailSegmentSortKeyComputer::SegmentComparator,
|
||||
);
|
||||
|
||||
/// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on
|
||||
/// its ordering.
|
||||
@@ -294,90 +208,9 @@ where
|
||||
left: &Self::SegmentSortKey,
|
||||
right: &Self::SegmentSortKey,
|
||||
) -> Ordering {
|
||||
self.head
|
||||
self.0
|
||||
.compare_segment_sort_key(&left.0, &right.0)
|
||||
.then_with(|| self.tail.compare_segment_sort_key(&left.1, &right.1))
|
||||
}
|
||||
|
||||
fn segment_sort_keys(
|
||||
&mut self,
|
||||
input_docs: &[DocId],
|
||||
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
|
||||
buffer: &mut Self::Buffer,
|
||||
filter: ValueRange<Self::SegmentSortKey>,
|
||||
) {
|
||||
let (head_filter, threshold) = match filter {
|
||||
ValueRange::GreaterThan((head_threshold, tail_threshold), _)
|
||||
| ValueRange::LessThan((head_threshold, tail_threshold), _) => {
|
||||
let head_cmp = self.head.segment_comparator();
|
||||
let strict_head_filter = head_cmp.threshold_to_valuerange(head_threshold.clone());
|
||||
let head_filter = match strict_head_filter {
|
||||
ValueRange::GreaterThan(t, m) => ValueRange::GreaterThanOrEqual(t, m),
|
||||
ValueRange::LessThan(t, m) => ValueRange::LessThanOrEqual(t, m),
|
||||
other => other,
|
||||
};
|
||||
(head_filter, Some((head_threshold, tail_threshold)))
|
||||
}
|
||||
_ => (ValueRange::All, None),
|
||||
};
|
||||
|
||||
buffer.head_output.clear();
|
||||
self.head.segment_sort_keys(
|
||||
input_docs,
|
||||
&mut buffer.head_output,
|
||||
&mut buffer.head,
|
||||
head_filter,
|
||||
);
|
||||
|
||||
if buffer.head_output.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
buffer.tail_output.clear();
|
||||
buffer.tail_input_docs.clear();
|
||||
for cd in &buffer.head_output {
|
||||
buffer.tail_input_docs.push(cd.doc);
|
||||
}
|
||||
|
||||
self.tail.segment_sort_keys(
|
||||
&buffer.tail_input_docs,
|
||||
&mut buffer.tail_output,
|
||||
&mut buffer.tail,
|
||||
ValueRange::All,
|
||||
);
|
||||
|
||||
let head_cmp = self.head.segment_comparator();
|
||||
let tail_cmp = self.tail.segment_comparator();
|
||||
|
||||
for (head_doc, tail_doc) in buffer
|
||||
.head_output
|
||||
.drain(..)
|
||||
.zip(buffer.tail_output.drain(..))
|
||||
{
|
||||
debug_assert_eq!(head_doc.doc, tail_doc.doc);
|
||||
let doc = head_doc.doc;
|
||||
let head_key = head_doc.sort_key;
|
||||
let tail_key = tail_doc.sort_key;
|
||||
|
||||
let accept = if let Some((head_threshold, tail_threshold)) = &threshold {
|
||||
let head_ord = head_cmp.compare(&head_key, head_threshold);
|
||||
let ord = if head_ord == Ordering::Equal {
|
||||
tail_cmp.compare(&tail_key, tail_threshold)
|
||||
} else {
|
||||
head_ord
|
||||
};
|
||||
ord == Ordering::Greater
|
||||
} else {
|
||||
true
|
||||
};
|
||||
|
||||
if accept {
|
||||
output.push(ComparableDoc {
|
||||
sort_key: (head_key, tail_key),
|
||||
doc,
|
||||
});
|
||||
}
|
||||
}
|
||||
.then_with(|| self.1.compare_segment_sort_key(&left.1, &right.1))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
@@ -385,7 +218,7 @@ where
|
||||
&mut self,
|
||||
doc: DocId,
|
||||
score: Score,
|
||||
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C, Self::Buffer>,
|
||||
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
|
||||
) {
|
||||
let sort_key: Self::SegmentSortKey;
|
||||
if let Some(threshold) = &top_n_computer.threshold {
|
||||
@@ -402,29 +235,48 @@ where
|
||||
|
||||
#[inline(always)]
|
||||
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
|
||||
let head_sort_key = self.head.segment_sort_key(doc, score);
|
||||
let tail_sort_key = self.tail.segment_sort_key(doc, score);
|
||||
let head_sort_key = self.0.segment_sort_key(doc, score);
|
||||
let tail_sort_key = self.1.segment_sort_key(doc, score);
|
||||
(head_sort_key, tail_sort_key)
|
||||
}
|
||||
|
||||
fn accept_sort_key_lazy(
|
||||
&mut self,
|
||||
doc_id: DocId,
|
||||
score: Score,
|
||||
threshold: &Self::SegmentSortKey,
|
||||
) -> Option<(Ordering, Self::SegmentSortKey)> {
|
||||
let (head_threshold, tail_threshold) = threshold;
|
||||
let (head_cmp, head_sort_key) =
|
||||
self.0.accept_sort_key_lazy(doc_id, score, head_threshold)?;
|
||||
if head_cmp == Ordering::Equal {
|
||||
let (tail_cmp, tail_sort_key) =
|
||||
self.1.accept_sort_key_lazy(doc_id, score, tail_threshold)?;
|
||||
Some((tail_cmp, (head_sort_key, tail_sort_key)))
|
||||
} else {
|
||||
let tail_sort_key = self.1.segment_sort_key(doc_id, score);
|
||||
Some((head_cmp, (head_sort_key, tail_sort_key)))
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
|
||||
let (head_sort_key, tail_sort_key) = sort_key;
|
||||
(
|
||||
self.head.convert_segment_sort_key(head_sort_key),
|
||||
self.tail.convert_segment_sort_key(tail_sort_key),
|
||||
self.0.convert_segment_sort_key(head_sort_key),
|
||||
self.1.convert_segment_sort_key(tail_sort_key),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// This struct is used as an adapter to take a sort key computer and map its score to another
|
||||
/// new sort key.
|
||||
pub struct MappedSegmentSortKeyComputer<T: SegmentSortKeyComputer, NewSortKey> {
|
||||
pub struct MappedSegmentSortKeyComputer<T, PreviousSortKey, NewSortKey> {
|
||||
sort_key_computer: T,
|
||||
map: fn(T::SortKey) -> NewSortKey,
|
||||
map: fn(PreviousSortKey) -> NewSortKey,
|
||||
}
|
||||
|
||||
impl<T, PreviousScore, NewScore> SegmentSortKeyComputer
|
||||
for MappedSegmentSortKeyComputer<T, NewScore>
|
||||
for MappedSegmentSortKeyComputer<T, PreviousScore, NewScore>
|
||||
where
|
||||
T: SegmentSortKeyComputer<SortKey = PreviousScore>,
|
||||
PreviousScore: 'static + Clone + Send + Sync,
|
||||
@@ -433,25 +285,19 @@ where
|
||||
type SortKey = NewScore;
|
||||
type SegmentSortKey = T::SegmentSortKey;
|
||||
type SegmentComparator = T::SegmentComparator;
|
||||
type Buffer = T::Buffer;
|
||||
|
||||
fn segment_comparator(&self) -> Self::SegmentComparator {
|
||||
self.sort_key_computer.segment_comparator()
|
||||
}
|
||||
|
||||
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
|
||||
self.sort_key_computer.segment_sort_key(doc, score)
|
||||
}
|
||||
|
||||
fn segment_sort_keys(
|
||||
fn accept_sort_key_lazy(
|
||||
&mut self,
|
||||
input_docs: &[DocId],
|
||||
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
|
||||
buffer: &mut Self::Buffer,
|
||||
filter: ValueRange<Self::SegmentSortKey>,
|
||||
) {
|
||||
doc_id: DocId,
|
||||
score: Score,
|
||||
threshold: &Self::SegmentSortKey,
|
||||
) -> Option<(Ordering, Self::SegmentSortKey)> {
|
||||
self.sort_key_computer
|
||||
.segment_sort_keys(input_docs, output, buffer, filter)
|
||||
.accept_sort_key_lazy(doc_id, score, threshold)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
@@ -459,21 +305,12 @@ where
|
||||
&mut self,
|
||||
doc: DocId,
|
||||
score: Score,
|
||||
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C, Self::Buffer>,
|
||||
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
|
||||
) {
|
||||
self.sort_key_computer
|
||||
.compute_sort_key_and_collect(doc, score, top_n_computer);
|
||||
}
|
||||
|
||||
fn compute_sort_keys_and_collect<C: Comparator<Self::SegmentSortKey>>(
|
||||
&mut self,
|
||||
docs: &[DocId],
|
||||
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C, Self::Buffer>,
|
||||
) {
|
||||
self.sort_key_computer
|
||||
.compute_sort_keys_and_collect(docs, top_n_computer);
|
||||
}
|
||||
|
||||
fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> Self::SortKey {
|
||||
(self.map)(
|
||||
self.sort_key_computer
|
||||
@@ -499,6 +336,10 @@ where
|
||||
);
|
||||
type Child = MappedSegmentSortKeyComputer<
|
||||
<(SortKeyComputer1, (SortKeyComputer2, SortKeyComputer3)) as SortKeyComputer>::Child,
|
||||
(
|
||||
SortKeyComputer1::SortKey,
|
||||
(SortKeyComputer2::SortKey, SortKeyComputer3::SortKey),
|
||||
),
|
||||
Self::SortKey,
|
||||
>;
|
||||
|
||||
@@ -522,13 +363,7 @@ where
|
||||
let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?;
|
||||
let map = |(sort_key1, (sort_key2, sort_key3))| (sort_key1, sort_key2, sort_key3);
|
||||
Ok(MappedSegmentSortKeyComputer {
|
||||
sort_key_computer: ChainSegmentSortKeyComputer {
|
||||
head: sort_key_computer1,
|
||||
tail: ChainSegmentSortKeyComputer {
|
||||
head: sort_key_computer2,
|
||||
tail: sort_key_computer3,
|
||||
},
|
||||
},
|
||||
sort_key_computer: (sort_key_computer1, (sort_key_computer2, sort_key_computer3)),
|
||||
map,
|
||||
})
|
||||
}
|
||||
@@ -563,6 +398,13 @@ where
|
||||
SortKeyComputer1,
|
||||
(SortKeyComputer2, (SortKeyComputer3, SortKeyComputer4)),
|
||||
) as SortKeyComputer>::Child,
|
||||
(
|
||||
SortKeyComputer1::SortKey,
|
||||
(
|
||||
SortKeyComputer2::SortKey,
|
||||
(SortKeyComputer3::SortKey, SortKeyComputer4::SortKey),
|
||||
),
|
||||
),
|
||||
Self::SortKey,
|
||||
>;
|
||||
type SortKey = (
|
||||
@@ -584,16 +426,10 @@ where
|
||||
let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?;
|
||||
let sort_key_computer4 = self.3.segment_sort_key_computer(segment_reader)?;
|
||||
Ok(MappedSegmentSortKeyComputer {
|
||||
sort_key_computer: ChainSegmentSortKeyComputer {
|
||||
head: sort_key_computer1,
|
||||
tail: ChainSegmentSortKeyComputer {
|
||||
head: sort_key_computer2,
|
||||
tail: ChainSegmentSortKeyComputer {
|
||||
head: sort_key_computer3,
|
||||
tail: sort_key_computer4,
|
||||
},
|
||||
},
|
||||
},
|
||||
sort_key_computer: (
|
||||
sort_key_computer1,
|
||||
(sort_key_computer2, (sort_key_computer3, sort_key_computer4)),
|
||||
),
|
||||
map: |(sort_key1, (sort_key2, (sort_key3, sort_key4)))| {
|
||||
(sort_key1, sort_key2, sort_key3, sort_key4)
|
||||
},
|
||||
@@ -616,13 +452,6 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
use std::marker::PhantomData;
|
||||
|
||||
pub struct FuncSegmentSortKeyComputer<F, TSortKey> {
|
||||
func: F,
|
||||
_phantom: PhantomData<TSortKey>,
|
||||
}
|
||||
|
||||
impl<F, SegmentF, TSortKey> SortKeyComputer for F
|
||||
where
|
||||
F: 'static + Send + Sync + Fn(&SegmentReader) -> SegmentF,
|
||||
@@ -630,18 +459,15 @@ where
|
||||
TSortKey: 'static + PartialOrd + Clone + Send + Sync + std::fmt::Debug,
|
||||
{
|
||||
type SortKey = TSortKey;
|
||||
type Child = FuncSegmentSortKeyComputer<SegmentF, TSortKey>;
|
||||
type Child = SegmentF;
|
||||
type Comparator = NaturalComparator;
|
||||
|
||||
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
|
||||
Ok(FuncSegmentSortKeyComputer {
|
||||
func: (self)(segment_reader),
|
||||
_phantom: PhantomData,
|
||||
})
|
||||
Ok((self)(segment_reader))
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, TSortKey> SegmentSortKeyComputer for FuncSegmentSortKeyComputer<F, TSortKey>
|
||||
impl<F, TSortKey> SegmentSortKeyComputer for F
|
||||
where
|
||||
F: 'static + FnMut(DocId) -> TSortKey,
|
||||
TSortKey: 'static + PartialOrd + Clone + Send + Sync,
|
||||
@@ -649,25 +475,9 @@ where
|
||||
type SortKey = TSortKey;
|
||||
type SegmentSortKey = TSortKey;
|
||||
type SegmentComparator = NaturalComparator;
|
||||
type Buffer = ();
|
||||
|
||||
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> TSortKey {
|
||||
(self.func)(doc)
|
||||
}
|
||||
|
||||
fn segment_sort_keys(
|
||||
&mut self,
|
||||
input_docs: &[DocId],
|
||||
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
|
||||
_buffer: &mut Self::Buffer,
|
||||
_filter: ValueRange<Self::SegmentSortKey>,
|
||||
) {
|
||||
for &doc in input_docs {
|
||||
output.push(ComparableDoc {
|
||||
sort_key: (self.func)(doc),
|
||||
doc,
|
||||
});
|
||||
}
|
||||
(self)(doc)
|
||||
}
|
||||
|
||||
/// Convert a segment level score into the global level score.
|
||||
@@ -676,75 +486,13 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn range_contains_none(range: &ValueRange<Option<u64>>) -> bool {
|
||||
match range {
|
||||
ValueRange::All => true,
|
||||
ValueRange::Inclusive(r) => r.contains(&None),
|
||||
ValueRange::GreaterThan(_threshold, match_nulls) => *match_nulls,
|
||||
ValueRange::GreaterThanOrEqual(_threshold, match_nulls) => *match_nulls,
|
||||
ValueRange::LessThan(_threshold, match_nulls) => *match_nulls,
|
||||
ValueRange::LessThanOrEqual(_threshold, match_nulls) => *match_nulls,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn convert_optional_u64_range_to_u64_range(
|
||||
range: ValueRange<Option<u64>>,
|
||||
) -> ValueRange<u64> {
|
||||
match range {
|
||||
ValueRange::Inclusive(r) => {
|
||||
let start = r.start().unwrap_or(0);
|
||||
let end = r.end().unwrap_or(u64::MAX);
|
||||
ValueRange::Inclusive(start..=end)
|
||||
}
|
||||
ValueRange::GreaterThan(Some(val), match_nulls) => {
|
||||
ValueRange::GreaterThan(val, match_nulls)
|
||||
}
|
||||
ValueRange::GreaterThan(None, match_nulls) => {
|
||||
if match_nulls {
|
||||
ValueRange::All
|
||||
} else {
|
||||
ValueRange::Inclusive(u64::MIN..=u64::MAX)
|
||||
}
|
||||
}
|
||||
ValueRange::GreaterThanOrEqual(Some(val), match_nulls) => {
|
||||
ValueRange::GreaterThanOrEqual(val, match_nulls)
|
||||
}
|
||||
ValueRange::GreaterThanOrEqual(None, match_nulls) => {
|
||||
if match_nulls {
|
||||
ValueRange::All
|
||||
} else {
|
||||
ValueRange::Inclusive(u64::MIN..=u64::MAX)
|
||||
}
|
||||
}
|
||||
ValueRange::LessThan(None, match_nulls) => {
|
||||
if match_nulls {
|
||||
ValueRange::LessThan(u64::MIN, true)
|
||||
} else {
|
||||
ValueRange::Inclusive(1..=0)
|
||||
}
|
||||
}
|
||||
ValueRange::LessThan(Some(val), match_nulls) => ValueRange::LessThan(val, match_nulls),
|
||||
ValueRange::LessThanOrEqual(None, match_nulls) => {
|
||||
if match_nulls {
|
||||
ValueRange::LessThan(u64::MIN, true)
|
||||
} else {
|
||||
ValueRange::Inclusive(1..=0)
|
||||
}
|
||||
}
|
||||
ValueRange::LessThanOrEqual(Some(val), match_nulls) => {
|
||||
ValueRange::LessThanOrEqual(val, match_nulls)
|
||||
}
|
||||
ValueRange::All => ValueRange::All,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::cmp::Ordering;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer, TopNComputer};
|
||||
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
|
||||
use crate::schema::Schema;
|
||||
use crate::{DocId, Index, Order, SegmentReader};
|
||||
|
||||
@@ -892,178 +640,4 @@ mod tests {
|
||||
(200u32, 2u32)
|
||||
);
|
||||
}
|
||||
#[test]
|
||||
fn test_batch_score_computer_edge_case() {
|
||||
let score_computer_primary = |_segment_reader: &SegmentReader| |_doc: DocId| 200u32;
|
||||
let score_computer_secondary = |_segment_reader: &SegmentReader| |_doc: DocId| "b";
|
||||
let lazy_score_computer = (score_computer_primary, score_computer_secondary);
|
||||
let index = build_test_index();
|
||||
let searcher = index.reader().unwrap().searcher();
|
||||
let mut segment_sort_key_computer = lazy_score_computer
|
||||
.segment_sort_key_computer(searcher.segment_reader(0))
|
||||
.unwrap();
|
||||
|
||||
let mut top_n_computer =
|
||||
TopNComputer::new_with_comparator(10, lazy_score_computer.comparator());
|
||||
// Threshold (200, "a"). Doc is (200, "b"). 200 == 200, "b" > "a". Should be accepted.
|
||||
top_n_computer.threshold = Some((200, "a"));
|
||||
|
||||
let docs = vec![0];
|
||||
segment_sort_key_computer.compute_sort_keys_and_collect(&docs, &mut top_n_computer);
|
||||
|
||||
let results = top_n_computer.into_sorted_vec();
|
||||
assert_eq!(results.len(), 1);
|
||||
let result = &results[0];
|
||||
assert_eq!(result.doc, 0);
|
||||
assert_eq!(result.sort_key, (200, "b"));
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod proptest_tests {
|
||||
use proptest::prelude::*;
|
||||
|
||||
use super::*;
|
||||
use crate::collector::sort_key::order::*;
|
||||
|
||||
// Re-implement logic to interpret ValueRange<Option<u64>> manually to verify expectations
|
||||
fn range_contains_opt(range: &ValueRange<Option<u64>>, val: &Option<u64>) -> bool {
|
||||
match range {
|
||||
ValueRange::All => true,
|
||||
ValueRange::Inclusive(r) => r.contains(val),
|
||||
ValueRange::GreaterThan(t, match_nulls) => {
|
||||
if val.is_none() {
|
||||
*match_nulls
|
||||
} else {
|
||||
val > t
|
||||
}
|
||||
}
|
||||
ValueRange::GreaterThanOrEqual(t, match_nulls) => {
|
||||
if val.is_none() {
|
||||
*match_nulls
|
||||
} else {
|
||||
val >= t
|
||||
}
|
||||
}
|
||||
ValueRange::LessThan(t, match_nulls) => {
|
||||
if val.is_none() {
|
||||
*match_nulls
|
||||
} else {
|
||||
val < t
|
||||
}
|
||||
}
|
||||
ValueRange::LessThanOrEqual(t, match_nulls) => {
|
||||
if val.is_none() {
|
||||
*match_nulls
|
||||
} else {
|
||||
val <= t
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn range_contains_u64(range: &ValueRange<u64>, val: &u64) -> bool {
|
||||
match range {
|
||||
ValueRange::All => true,
|
||||
ValueRange::Inclusive(r) => r.contains(val),
|
||||
ValueRange::GreaterThan(t, _) => val > t,
|
||||
ValueRange::GreaterThanOrEqual(t, _) => val >= t,
|
||||
ValueRange::LessThan(t, _) => val < t,
|
||||
ValueRange::LessThanOrEqual(t, _) => val <= t,
|
||||
}
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn test_comparator_consistency_natural_none_is_lower(
|
||||
threshold in any::<Option<u64>>(),
|
||||
val in any::<Option<u64>>()
|
||||
) {
|
||||
check_comparator::<NaturalComparator>(threshold, val)?;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_comparator_consistency_reverse(
|
||||
threshold in any::<Option<u64>>(),
|
||||
val in any::<Option<u64>>()
|
||||
) {
|
||||
check_comparator::<ReverseComparator>(threshold, val)?;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_comparator_consistency_reverse_none_is_lower(
|
||||
threshold in any::<Option<u64>>(),
|
||||
val in any::<Option<u64>>()
|
||||
) {
|
||||
check_comparator::<ReverseNoneIsLowerComparator>(threshold, val)?;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_comparator_consistency_natural_none_is_higher(
|
||||
threshold in any::<Option<u64>>(),
|
||||
val in any::<Option<u64>>()
|
||||
) {
|
||||
check_comparator::<NaturalNoneIsHigherComparator>(threshold, val)?;
|
||||
}
|
||||
}
|
||||
|
||||
fn check_comparator<C: Comparator<Option<u64>>>(
|
||||
threshold: Option<u64>,
|
||||
val: Option<u64>,
|
||||
) -> std::result::Result<(), proptest::test_runner::TestCaseError> {
|
||||
let comparator = C::default();
|
||||
let range = comparator.threshold_to_valuerange(threshold);
|
||||
let ordering = comparator.compare(&val, &threshold);
|
||||
let should_be_in_range = ordering == Ordering::Greater;
|
||||
|
||||
let in_range_opt = range_contains_opt(&range, &val);
|
||||
|
||||
prop_assert_eq!(
|
||||
in_range_opt,
|
||||
should_be_in_range,
|
||||
"Comparator consistency failed for {:?}. Threshold: {:?}, Val: {:?}, Range: {:?}, \
|
||||
Ordering: {:?}. range_contains_opt says {}, but compare says {}",
|
||||
std::any::type_name::<C>(),
|
||||
threshold,
|
||||
val,
|
||||
range,
|
||||
ordering,
|
||||
in_range_opt,
|
||||
should_be_in_range
|
||||
);
|
||||
|
||||
// Check range_contains_none
|
||||
let expected_none_in_range = range_contains_opt(&range, &None);
|
||||
let actual_none_in_range = range_contains_none(&range);
|
||||
prop_assert_eq!(
|
||||
actual_none_in_range,
|
||||
expected_none_in_range,
|
||||
"range_contains_none failed for {:?}. Range: {:?}. Expected (from \
|
||||
range_contains_opt): {}, Actual: {}",
|
||||
std::any::type_name::<C>(),
|
||||
range,
|
||||
expected_none_in_range,
|
||||
actual_none_in_range
|
||||
);
|
||||
|
||||
// Check convert_optional_u64_range_to_u64_range
|
||||
let u64_range = convert_optional_u64_range_to_u64_range(range.clone());
|
||||
if let Some(v) = val {
|
||||
let in_u64_range = range_contains_u64(&u64_range, &v);
|
||||
let in_opt_range = range_contains_opt(&range, &Some(v));
|
||||
prop_assert_eq!(
|
||||
in_u64_range,
|
||||
in_opt_range,
|
||||
"convert_optional_u64_range_to_u64_range failed for {:?}. Val: {:?}, OptRange: \
|
||||
{:?}, U64Range: {:?}. Opt says {}, U64 says {}",
|
||||
std::any::type_name::<C>(),
|
||||
v,
|
||||
range,
|
||||
u64_range,
|
||||
in_opt_range,
|
||||
in_u64_range
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -99,12 +99,7 @@ where
|
||||
TSegmentSortKeyComputer: SegmentSortKeyComputer,
|
||||
C: Comparator<TSegmentSortKeyComputer::SegmentSortKey>,
|
||||
{
|
||||
pub(crate) topn_computer: TopNComputer<
|
||||
TSegmentSortKeyComputer::SegmentSortKey,
|
||||
DocId,
|
||||
C,
|
||||
TSegmentSortKeyComputer::Buffer,
|
||||
>,
|
||||
pub(crate) topn_computer: TopNComputer<TSegmentSortKeyComputer::SegmentSortKey, DocId, C>,
|
||||
pub(crate) segment_ord: u32,
|
||||
pub(crate) segment_sort_key_computer: TSegmentSortKeyComputer,
|
||||
}
|
||||
@@ -125,11 +120,6 @@ where
|
||||
);
|
||||
}
|
||||
|
||||
fn collect_block(&mut self, docs: &[DocId]) {
|
||||
self.segment_sort_key_computer
|
||||
.compute_sort_keys_and_collect(docs, &mut self.topn_computer);
|
||||
}
|
||||
|
||||
fn harvest(self) -> Self::Fruit {
|
||||
let segment_ord = self.segment_ord;
|
||||
let segment_hits: Vec<(TSegmentSortKeyComputer::SortKey, DocAddress)> = self
|
||||
|
||||
@@ -2,7 +2,6 @@ use std::cmp::Ordering;
|
||||
use std::fmt;
|
||||
use std::ops::Range;
|
||||
|
||||
use columnar::ValueRange;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::Collector;
|
||||
@@ -11,7 +10,8 @@ use crate::collector::sort_key::{
|
||||
SortByStaticFastValue, SortByString,
|
||||
};
|
||||
use crate::collector::sort_key_top_collector::TopBySortKeyCollector;
|
||||
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
|
||||
use crate::collector::top_collector::ComparableDoc;
|
||||
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
|
||||
use crate::fastfield::FastValue;
|
||||
use crate::{DocAddress, DocId, Order, Score, SegmentReader};
|
||||
|
||||
@@ -481,22 +481,11 @@ where
|
||||
type SortKey = TSortKey;
|
||||
type SegmentSortKey = TSortKey;
|
||||
type SegmentComparator = NaturalComparator;
|
||||
type Buffer = ();
|
||||
|
||||
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> TSortKey {
|
||||
(self.sort_key_fn)(doc, score)
|
||||
}
|
||||
|
||||
fn segment_sort_keys(
|
||||
&mut self,
|
||||
_input_docs: &[DocId],
|
||||
_output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
|
||||
_buffer: &mut Self::Buffer,
|
||||
_filter: ValueRange<Self::SegmentSortKey>,
|
||||
) {
|
||||
unimplemented!("Batch computation is not supported for tweak score.")
|
||||
}
|
||||
|
||||
/// Convert a segment level score into the global level score.
|
||||
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
|
||||
sort_key
|
||||
@@ -520,14 +509,12 @@ where
|
||||
/// the ascending `DocId|DocAddress` tie-breaking behavior without additional comparisons.
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(from = "TopNComputerDeser<Score, D, C>")]
|
||||
pub struct TopNComputer<Score, D, C, Buffer = ()> {
|
||||
pub struct TopNComputer<Score, D, C> {
|
||||
/// The buffer reverses sort order to get top-semantics instead of bottom-semantics
|
||||
buffer: Vec<ComparableDoc<Score, D>>,
|
||||
top_n: usize,
|
||||
pub(crate) threshold: Option<Score>,
|
||||
comparator: C,
|
||||
#[serde(skip)]
|
||||
scratch: Buffer,
|
||||
}
|
||||
|
||||
// Intermediate struct for TopNComputer for deserialization, to keep vec capacity
|
||||
@@ -539,9 +526,7 @@ struct TopNComputerDeser<Score, D, C> {
|
||||
comparator: C,
|
||||
}
|
||||
|
||||
impl<Score, D, C, Buffer> From<TopNComputerDeser<Score, D, C>> for TopNComputer<Score, D, C, Buffer>
|
||||
where Buffer: Default
|
||||
{
|
||||
impl<Score, D, C> From<TopNComputerDeser<Score, D, C>> for TopNComputer<Score, D, C> {
|
||||
fn from(mut value: TopNComputerDeser<Score, D, C>) -> Self {
|
||||
let expected_cap = value.top_n.max(1) * 2;
|
||||
let current_cap = value.buffer.capacity();
|
||||
@@ -556,15 +541,12 @@ where Buffer: Default
|
||||
top_n: value.top_n,
|
||||
threshold: value.threshold,
|
||||
comparator: value.comparator,
|
||||
scratch: Buffer::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Score: std::fmt::Debug, D, C, Buffer> std::fmt::Debug for TopNComputer<Score, D, C, Buffer>
|
||||
where
|
||||
C: Comparator<Score>,
|
||||
Buffer: std::fmt::Debug,
|
||||
impl<Score: std::fmt::Debug, D, C> std::fmt::Debug for TopNComputer<Score, D, C>
|
||||
where C: Comparator<Score>
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> std::fmt::Result {
|
||||
f.debug_struct("TopNComputer")
|
||||
@@ -572,13 +554,12 @@ where
|
||||
.field("top_n", &self.top_n)
|
||||
.field("current_threshold", &self.threshold)
|
||||
.field("comparator", &self.comparator)
|
||||
.field("scratch", &self.scratch)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
// Custom clone to keep capacity
|
||||
impl<Score: Clone, D: Clone, C: Clone, Buffer: Clone> Clone for TopNComputer<Score, D, C, Buffer> {
|
||||
impl<Score: Clone, D: Clone, C: Clone> Clone for TopNComputer<Score, D, C> {
|
||||
fn clone(&self) -> Self {
|
||||
let mut buffer_clone = Vec::with_capacity(self.buffer.capacity());
|
||||
buffer_clone.extend(self.buffer.iter().cloned());
|
||||
@@ -587,17 +568,15 @@ impl<Score: Clone, D: Clone, C: Clone, Buffer: Clone> Clone for TopNComputer<Sco
|
||||
top_n: self.top_n,
|
||||
threshold: self.threshold.clone(),
|
||||
comparator: self.comparator.clone(),
|
||||
scratch: self.scratch.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<TSortKey, D> TopNComputer<TSortKey, D, ReverseComparator, ()>
|
||||
impl<TSortKey, D> TopNComputer<TSortKey, D, ReverseComparator>
|
||||
where
|
||||
D: Ord,
|
||||
TSortKey: Clone,
|
||||
NaturalComparator: Comparator<TSortKey>,
|
||||
ReverseComparator: Comparator<TSortKey>,
|
||||
{
|
||||
/// Create a new `TopNComputer`.
|
||||
/// Internally it will allocate a buffer of size `2 * top_n`.
|
||||
@@ -606,26 +585,33 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<TSortKey, D, C, Buffer> TopNComputer<TSortKey, D, C, Buffer>
|
||||
#[inline(always)]
|
||||
fn compare_for_top_k<TSortKey, D: Ord, C: Comparator<TSortKey>>(
|
||||
c: &C,
|
||||
lhs: &ComparableDoc<TSortKey, D>,
|
||||
rhs: &ComparableDoc<TSortKey, D>,
|
||||
) -> std::cmp::Ordering {
|
||||
c.compare(&lhs.sort_key, &rhs.sort_key)
|
||||
.reverse() // Reverse here because we want top K.
|
||||
.then_with(|| lhs.doc.cmp(&rhs.doc)) // Regardless of asc/desc, in presence of a tie, we
|
||||
// sort by doc id
|
||||
}
|
||||
|
||||
impl<TSortKey, D, C> TopNComputer<TSortKey, D, C>
|
||||
where
|
||||
D: Ord,
|
||||
TSortKey: Clone,
|
||||
C: Comparator<TSortKey>,
|
||||
Buffer: Default,
|
||||
{
|
||||
/// Create a new `TopNComputer`.
|
||||
/// Internally it will allocate a buffer of size `(top_n.max(1) * 2) +
|
||||
/// COLLECT_BLOCK_BUFFER_LEN`.
|
||||
/// Internally it will allocate a buffer of size `2 * top_n`.
|
||||
pub fn new_with_comparator(top_n: usize, comparator: C) -> Self {
|
||||
// We ensure that there is always enough space to include an entire block in the buffer if
|
||||
// need be, so that `push_block_lazy` can avoid checking capacity inside its loop.
|
||||
let vec_cap = (top_n.max(1) * 2) + crate::COLLECT_BLOCK_BUFFER_LEN;
|
||||
let vec_cap = top_n.max(1) * 2;
|
||||
TopNComputer {
|
||||
buffer: Vec::with_capacity(vec_cap),
|
||||
top_n,
|
||||
threshold: None,
|
||||
comparator,
|
||||
scratch: Buffer::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -649,34 +635,22 @@ where
|
||||
// At this point, we need to have established that the doc is above the threshold.
|
||||
#[inline(always)]
|
||||
pub(crate) fn append_doc(&mut self, doc: D, sort_key: TSortKey) {
|
||||
self.reserve(1);
|
||||
// This cannot panic, because we've reserved room for one element.
|
||||
let comparable_doc = ComparableDoc { doc, sort_key };
|
||||
push_assuming_capacity(comparable_doc, &mut self.buffer);
|
||||
}
|
||||
|
||||
// Ensure that there is capacity to push `additional` more elements without resizing.
|
||||
#[inline(always)]
|
||||
pub(crate) fn reserve(&mut self, additional: usize) {
|
||||
if self.buffer.len() + additional > self.buffer.capacity() {
|
||||
if self.buffer.len() == self.buffer.capacity() {
|
||||
let median = self.truncate_top_n();
|
||||
debug_assert!(self.buffer.len() + additional <= self.buffer.capacity());
|
||||
self.threshold = Some(median);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn buffer_and_scratch(
|
||||
&mut self,
|
||||
) -> (&mut Vec<ComparableDoc<TSortKey, D>>, &mut Buffer) {
|
||||
(&mut self.buffer, &mut self.scratch)
|
||||
// This cannot panic, because we truncate_median will at least remove one element, since
|
||||
// the min capacity is 2.
|
||||
let comparable_doc = ComparableDoc { doc, sort_key };
|
||||
push_assuming_capacity(comparable_doc, &mut self.buffer);
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
fn truncate_top_n(&mut self) -> TSortKey {
|
||||
// Use select_nth_unstable to find the top nth score
|
||||
let (_, median_el, _) = self
|
||||
.buffer
|
||||
.select_nth_unstable_by(self.top_n, |lhs, rhs| self.comparator.compare_doc(lhs, rhs));
|
||||
let (_, median_el, _) = self.buffer.select_nth_unstable_by(self.top_n, |lhs, rhs| {
|
||||
compare_for_top_k(&self.comparator, lhs, rhs)
|
||||
});
|
||||
|
||||
let median_score = median_el.sort_key.clone();
|
||||
// Remove all elements below the top_n
|
||||
@@ -691,7 +665,7 @@ where
|
||||
self.truncate_top_n();
|
||||
}
|
||||
self.buffer
|
||||
.sort_unstable_by(|left, right| self.comparator.compare_doc(left, right));
|
||||
.sort_unstable_by(|lhs, rhs| compare_for_top_k(&self.comparator, lhs, rhs));
|
||||
self.buffer
|
||||
}
|
||||
|
||||
@@ -710,7 +684,7 @@ where
|
||||
//
|
||||
// Panics if there is not enough capacity to add an element.
|
||||
#[inline(always)]
|
||||
pub fn push_assuming_capacity<T>(el: T, buf: &mut Vec<T>) {
|
||||
fn push_assuming_capacity<T>(el: T, buf: &mut Vec<T>) {
|
||||
let prev_len = buf.len();
|
||||
assert!(prev_len < buf.capacity());
|
||||
// This is mimicking the current (non-stabilized) implementation in std.
|
||||
@@ -727,10 +701,9 @@ mod tests {
|
||||
use proptest::prelude::*;
|
||||
|
||||
use super::{TopDocs, TopNComputer};
|
||||
use crate::collector::sort_key::{
|
||||
Comparator, ComparatorEnum, NaturalComparator, ReverseComparator,
|
||||
};
|
||||
use crate::collector::{Collector, ComparableDoc, DocSetCollector};
|
||||
use crate::collector::sort_key::{ComparatorEnum, NaturalComparator, ReverseComparator};
|
||||
use crate::collector::top_collector::ComparableDoc;
|
||||
use crate::collector::{Collector, DocSetCollector};
|
||||
use crate::query::{AllQuery, Query, QueryParser};
|
||||
use crate::schema::{Field, Schema, FAST, STORED, TEXT};
|
||||
use crate::time::format_description::well_known::Rfc3339;
|
||||
@@ -849,9 +822,9 @@ mod tests {
|
||||
for (feature, doc) in &docs {
|
||||
computer.push(*feature, *doc);
|
||||
}
|
||||
let mut comparable_docs =
|
||||
docs.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc }).collect::<Vec<_>>();
|
||||
comparable_docs.sort_by(|l, r| ReverseComparator.compare_doc(l, r));
|
||||
let mut comparable_docs: Vec<ComparableDoc<u64, u64>> =
|
||||
docs.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc }).collect();
|
||||
crate::collector::sort_key::tests::sort_hits(&mut comparable_docs, Order::Asc);
|
||||
comparable_docs.truncate(limit);
|
||||
prop_assert_eq!(
|
||||
computer.into_sorted_vec(),
|
||||
@@ -1435,11 +1408,11 @@ mod tests {
|
||||
#[test]
|
||||
fn test_top_field_collect_string_prop(
|
||||
order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
|
||||
limit in 1..32_usize,
|
||||
offset in 0..32_usize,
|
||||
limit in 1..256_usize,
|
||||
offset in 0..256_usize,
|
||||
segments_terms in
|
||||
proptest::collection::vec(
|
||||
proptest::collection::vec(0..64_u8, 1..256_usize),
|
||||
proptest::collection::vec(0..32_u8, 1..32_usize),
|
||||
0..8_usize,
|
||||
)
|
||||
) {
|
||||
@@ -1481,11 +1454,7 @@ mod tests {
|
||||
let sorted_docs: Vec<_> = {
|
||||
let mut comparable_docs: Vec<ComparableDoc<_, _>> =
|
||||
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
|
||||
if order.is_desc() {
|
||||
comparable_docs.sort_by(|l, r| NaturalComparator.compare_doc(l, r));
|
||||
} else {
|
||||
comparable_docs.sort_by(|l, r| ReverseComparator.compare_doc(l, r));
|
||||
}
|
||||
crate::collector::sort_key::tests::sort_hits(&mut comparable_docs, order);
|
||||
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
|
||||
};
|
||||
let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::<Vec<_>>();
|
||||
@@ -1764,8 +1733,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_top_n_computer_not_at_capacity() {
|
||||
let mut top_n_computer: TopNComputer<f32, u32, _, ()> =
|
||||
TopNComputer::new_with_comparator(4, NaturalComparator);
|
||||
let mut top_n_computer = TopNComputer::new_with_comparator(4, NaturalComparator);
|
||||
top_n_computer.append_doc(1, 0.8);
|
||||
top_n_computer.append_doc(3, 0.2);
|
||||
top_n_computer.append_doc(5, 0.3);
|
||||
@@ -1790,8 +1758,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_top_n_computer_at_capacity() {
|
||||
let mut top_collector: TopNComputer<f32, u32, _, ()> =
|
||||
TopNComputer::new_with_comparator(4, NaturalComparator);
|
||||
let mut top_collector = TopNComputer::new_with_comparator(4, NaturalComparator);
|
||||
top_collector.append_doc(1, 0.8);
|
||||
top_collector.append_doc(3, 0.2);
|
||||
top_collector.append_doc(5, 0.3);
|
||||
@@ -1828,14 +1795,12 @@ mod tests {
|
||||
let doc_ids_collection = [4, 5, 6];
|
||||
let score = 3.3f32;
|
||||
|
||||
let mut top_collector_limit_2: TopNComputer<f32, u32, _, ()> =
|
||||
TopNComputer::new_with_comparator(2, NaturalComparator);
|
||||
let mut top_collector_limit_2 = TopNComputer::new_with_comparator(2, NaturalComparator);
|
||||
for id in &doc_ids_collection {
|
||||
top_collector_limit_2.append_doc(*id, score);
|
||||
}
|
||||
|
||||
let mut top_collector_limit_3: TopNComputer<f32, u32, _, ()> =
|
||||
TopNComputer::new_with_comparator(3, NaturalComparator);
|
||||
let mut top_collector_limit_3 = TopNComputer::new_with_comparator(3, NaturalComparator);
|
||||
for id in &doc_ids_collection {
|
||||
top_collector_limit_3.append_doc(*id, score);
|
||||
}
|
||||
@@ -1856,16 +1821,15 @@ mod bench {
|
||||
|
||||
#[bench]
|
||||
fn bench_top_segment_collector_collect_at_capacity(b: &mut Bencher) {
|
||||
let mut top_collector: TopNComputer<f32, u32, _, ()> =
|
||||
TopNComputer::new_with_comparator(100, NaturalComparator);
|
||||
let mut top_collector = TopNComputer::new_with_comparator(100, NaturalComparator);
|
||||
|
||||
for i in 0..100 {
|
||||
top_collector.append_doc(i as u32, 0.8);
|
||||
top_collector.append_doc(i, 0.8);
|
||||
}
|
||||
|
||||
b.iter(|| {
|
||||
for i in 0..100 {
|
||||
top_collector.append_doc(i as u32, 0.8);
|
||||
top_collector.append_doc(i, 0.8);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -36,7 +36,6 @@ fn path_for_version(version: &str) -> String {
|
||||
/// feature flag quickwit uses a different dictionary type
|
||||
#[test]
|
||||
#[cfg(not(feature = "quickwit"))]
|
||||
#[ignore = "test incompatible with fixed-width footer changes"]
|
||||
fn test_format_6() {
|
||||
let path = path_for_version("6");
|
||||
|
||||
@@ -48,7 +47,6 @@ fn test_format_6() {
|
||||
/// feature flag quickwit uses a different dictionary type
|
||||
#[test]
|
||||
#[cfg(not(feature = "quickwit"))]
|
||||
#[ignore = "test incompatible with fixed-width footer changes"]
|
||||
fn test_format_7() {
|
||||
let path = path_for_version("7");
|
||||
|
||||
|
||||
@@ -48,7 +48,15 @@ impl Executor {
|
||||
F: Sized + Sync + Fn(A) -> crate::Result<R>,
|
||||
{
|
||||
match self {
|
||||
Executor::SingleThread => args.map(f).collect::<crate::Result<_>>(),
|
||||
Executor::SingleThread => {
|
||||
// Avoid `collect`, since the stacktrace is blown up by it, which makes profiling
|
||||
// harder.
|
||||
let mut result = Vec::with_capacity(args.size_hint().0);
|
||||
for arg in args {
|
||||
result.push(f(arg)?);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
Executor::ThreadPool(pool) => {
|
||||
let args: Vec<A> = args.collect();
|
||||
let num_fruits = args.len();
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use std::sync::Arc;
|
||||
use std::{fmt, io};
|
||||
|
||||
use crate::collector::Collector;
|
||||
@@ -86,7 +86,7 @@ impl Searcher {
|
||||
/// The searcher uses the segment ordinal to route the
|
||||
/// request to the right `Segment`.
|
||||
pub fn doc<D: DocumentDeserialize>(&self, doc_address: DocAddress) -> crate::Result<D> {
|
||||
let store_reader = &self.inner.store_readers()[doc_address.segment_ord as usize];
|
||||
let store_reader = &self.inner.store_readers[doc_address.segment_ord as usize];
|
||||
store_reader.get(doc_address.doc_id)
|
||||
}
|
||||
|
||||
@@ -96,7 +96,7 @@ impl Searcher {
|
||||
pub fn doc_store_cache_stats(&self) -> CacheStats {
|
||||
let cache_stats: CacheStats = self
|
||||
.inner
|
||||
.store_readers()
|
||||
.store_readers
|
||||
.iter()
|
||||
.map(|reader| reader.cache_stats())
|
||||
.sum();
|
||||
@@ -110,7 +110,7 @@ impl Searcher {
|
||||
doc_address: DocAddress,
|
||||
) -> crate::Result<D> {
|
||||
let executor = self.inner.index.search_executor();
|
||||
let store_reader = &self.inner.store_readers()[doc_address.segment_ord as usize];
|
||||
let store_reader = &self.inner.store_readers[doc_address.segment_ord as usize];
|
||||
store_reader.get_async(doc_address.doc_id, executor).await
|
||||
}
|
||||
|
||||
@@ -259,9 +259,8 @@ impl From<Arc<SearcherInner>> for Searcher {
|
||||
pub(crate) struct SearcherInner {
|
||||
schema: Schema,
|
||||
index: Index,
|
||||
doc_store_cache_num_blocks: usize,
|
||||
segment_readers: Vec<SegmentReader>,
|
||||
store_readers: OnceLock<Vec<StoreReader>>,
|
||||
store_readers: Vec<StoreReader>,
|
||||
generation: TrackedObject<SearcherGeneration>,
|
||||
}
|
||||
|
||||
@@ -282,30 +281,19 @@ impl SearcherInner {
|
||||
generation.segments(),
|
||||
"Set of segments referenced by this Searcher and its SearcherGeneration must match"
|
||||
);
|
||||
let store_readers: Vec<StoreReader> = segment_readers
|
||||
.iter()
|
||||
.map(|segment_reader| segment_reader.get_store_reader(doc_store_cache_num_blocks))
|
||||
.collect::<io::Result<Vec<_>>>()?;
|
||||
|
||||
Ok(SearcherInner {
|
||||
schema,
|
||||
index,
|
||||
doc_store_cache_num_blocks,
|
||||
segment_readers,
|
||||
store_readers: OnceLock::default(),
|
||||
store_readers,
|
||||
generation,
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn store_readers(&self) -> &[StoreReader] {
|
||||
self.store_readers.get_or_init(|| {
|
||||
self.segment_readers
|
||||
.iter()
|
||||
.map(|segment_reader| {
|
||||
segment_reader
|
||||
.get_store_reader(self.doc_store_cache_num_blocks)
|
||||
.expect("should be able to get store reader")
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for Searcher {
|
||||
|
||||
@@ -1,20 +1,12 @@
|
||||
use std::any::Any;
|
||||
use std::collections::HashSet;
|
||||
use std::io::Write;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::{fmt, io, thread};
|
||||
|
||||
use log::Level;
|
||||
|
||||
use crate::directory::directory_lock::Lock;
|
||||
use crate::directory::error::{DeleteError, LockError, OpenReadError, OpenWriteError};
|
||||
use crate::directory::{
|
||||
FileHandle, FileSlice, TerminatingWrite, WatchCallback, WatchHandle, WritePtr,
|
||||
};
|
||||
use crate::index::SegmentMetaInventory;
|
||||
use crate::IndexMeta;
|
||||
use crate::directory::{FileHandle, FileSlice, WatchCallback, WatchHandle, WritePtr};
|
||||
|
||||
/// Retry the logic of acquiring locks is pretty simple.
|
||||
/// We just retry `n` times after a given `duratio`, both
|
||||
@@ -64,7 +56,7 @@ impl<T: Send + Sync + 'static> From<Box<T>> for DirectoryLock {
|
||||
impl Drop for DirectoryLockGuard {
|
||||
fn drop(&mut self) {
|
||||
if let Err(e) = self.directory.delete(&self.path) {
|
||||
error!("Failed to remove the lock file. {:?}", e);
|
||||
error!("Failed to remove the lock file. {e:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -105,8 +97,6 @@ fn retry_policy(is_blocking: bool) -> RetryPolicy {
|
||||
}
|
||||
}
|
||||
|
||||
pub type DirectoryPanicHandler = Arc<dyn Fn(Box<dyn Any + Send>) + Send + Sync + 'static>;
|
||||
|
||||
/// Write-once read many (WORM) abstraction for where
|
||||
/// tantivy's data should be stored.
|
||||
///
|
||||
@@ -145,10 +135,6 @@ pub trait Directory: DirectoryClone + fmt::Debug + Send + Sync + 'static {
|
||||
/// Returns true if and only if the file exists
|
||||
fn exists(&self, path: &Path) -> Result<bool, OpenReadError>;
|
||||
|
||||
/// Returns a boxed `TerminatingWrite` object, to be passed into `open_write`
|
||||
/// which wraps it in a `BufWriter`
|
||||
fn open_write_inner(&self, path: &Path) -> Result<Box<dyn TerminatingWrite>, OpenWriteError>;
|
||||
|
||||
/// Opens a writer for the *virtual file* associated with
|
||||
/// a [`Path`].
|
||||
///
|
||||
@@ -175,12 +161,7 @@ pub trait Directory: DirectoryClone + fmt::Debug + Send + Sync + 'static {
|
||||
/// panic! if `flush` was not called.
|
||||
///
|
||||
/// The file may not previously exist.
|
||||
fn open_write(&self, path: &Path) -> Result<WritePtr, OpenWriteError> {
|
||||
Ok(io::BufWriter::with_capacity(
|
||||
self.bufwriter_capacity(),
|
||||
self.open_write_inner(path)?,
|
||||
))
|
||||
}
|
||||
fn open_write(&self, path: &Path) -> Result<WritePtr, OpenWriteError>;
|
||||
|
||||
/// Reads the full content file that has been written using
|
||||
/// [`Directory::atomic_write()`].
|
||||
@@ -242,75 +223,6 @@ pub trait Directory: DirectoryClone + fmt::Debug + Send + Sync + 'static {
|
||||
/// `OnCommitWithDelay` `ReloadPolicy`. Not implementing watch in a `Directory` only prevents
|
||||
/// the `OnCommitWithDelay` `ReloadPolicy` to work properly.
|
||||
fn watch(&self, watch_callback: WatchCallback) -> crate::Result<WatchHandle>;
|
||||
|
||||
/// Allows the directory to list managed files, overriding the ManagedDirectory's default
|
||||
/// list_managed_files
|
||||
fn list_managed_files(&self) -> crate::Result<HashSet<PathBuf>> {
|
||||
Err(crate::TantivyError::InternalError(
|
||||
"list_managed_files not implemented".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Allows the directory to register a file as managed, overriding the ManagedDirectory's
|
||||
/// default register_file_as_managed
|
||||
fn register_files_as_managed(
|
||||
&self,
|
||||
_files: Vec<PathBuf>,
|
||||
_overwrite: bool,
|
||||
) -> crate::Result<()> {
|
||||
Err(crate::TantivyError::InternalError(
|
||||
"register_files_as_managed not implemented".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Allows the directory to save IndexMeta, overriding the SegmentUpdater's default save_meta
|
||||
fn save_metas(
|
||||
&self,
|
||||
_metas: &IndexMeta,
|
||||
_previous_metas: &IndexMeta,
|
||||
_payload: &mut (dyn Any + '_),
|
||||
) -> crate::Result<()> {
|
||||
Err(crate::TantivyError::InternalError(
|
||||
"save_meta not implemented".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Allows the directory to load IndexMeta, overriding the SegmentUpdater's default load_meta
|
||||
fn load_metas(&self, _inventory: &SegmentMetaInventory) -> crate::Result<IndexMeta> {
|
||||
Err(crate::TantivyError::InternalError(
|
||||
"load_metas not implemented".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Returns true if this directory supports garbage collection. The default assumption is
|
||||
/// `true`
|
||||
fn supports_garbage_collection(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// Return a panic handler to be assigned to the various thread pools that may be created
|
||||
///
|
||||
/// The default is [`None`], which indicates that an unhandled panic from a thread pool will
|
||||
/// abort the process
|
||||
fn panic_handler(&self) -> Option<DirectoryPanicHandler> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Returns true if this directory is in a position of requiring that tantivy cancel
|
||||
/// whatever operation(s) it might be doing Typically this is just for the background
|
||||
/// merge processes but could be used for anything
|
||||
fn wants_cancel(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
/// Send a logging message to the Directory to handle in its own way
|
||||
fn log(&self, message: &str) {
|
||||
log!(Level::Info, "{message}");
|
||||
}
|
||||
|
||||
fn bufwriter_capacity(&self) -> usize {
|
||||
8192
|
||||
}
|
||||
}
|
||||
|
||||
/// DirectoryClone
|
||||
|
||||
@@ -58,9 +58,3 @@ pub static META_LOCK: Lazy<Lock> = Lazy::new(|| Lock {
|
||||
filepath: PathBuf::from(".tantivy-meta.lock"),
|
||||
is_blocking: true,
|
||||
});
|
||||
|
||||
#[allow(missing_docs)]
|
||||
pub static MANAGED_LOCK: Lazy<Lock> = Lazy::new(|| Lock {
|
||||
filepath: PathBuf::from(".tantivy-managed.lock"),
|
||||
is_blocking: true,
|
||||
});
|
||||
|
||||
@@ -7,14 +7,15 @@
|
||||
use std::io;
|
||||
use std::io::Write;
|
||||
|
||||
use common::{BinarySerializable, HasLen};
|
||||
use common::{BinarySerializable, CountingWriter, DeserializeFrom, FixedSize, HasLen};
|
||||
use crc32fast::Hasher;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::directory::error::Incompatibility;
|
||||
use crate::directory::{AntiCallToken, FileSlice, TerminatingWrite};
|
||||
use crate::{Version, INDEX_FORMAT_OLDEST_SUPPORTED_VERSION, INDEX_FORMAT_VERSION};
|
||||
|
||||
pub const FOOTER_LEN: usize = 24;
|
||||
const FOOTER_MAX_LEN: u32 = 50_000;
|
||||
|
||||
/// The magic byte of the footer to identify corruption
|
||||
/// or an old version of the footer.
|
||||
@@ -23,7 +24,7 @@ const FOOTER_MAGIC_NUMBER: u32 = 1337;
|
||||
type CrcHashU32 = u32;
|
||||
|
||||
/// A Footer is appended to every file
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct Footer {
|
||||
/// The version of the index format
|
||||
pub version: Version,
|
||||
@@ -40,45 +41,34 @@ impl Footer {
|
||||
pub(crate) fn crc(&self) -> CrcHashU32 {
|
||||
self.crc
|
||||
}
|
||||
pub fn append_footer<W: io::Write>(&self, write: &mut W) -> io::Result<()> {
|
||||
// 24 bytes
|
||||
BinarySerializable::serialize(&self.version.major, write)?;
|
||||
BinarySerializable::serialize(&self.version.minor, write)?;
|
||||
BinarySerializable::serialize(&self.version.patch, write)?;
|
||||
BinarySerializable::serialize(&self.version.index_format_version, write)?;
|
||||
BinarySerializable::serialize(&self.crc, write)?;
|
||||
pub(crate) fn append_footer<W: io::Write>(&self, mut write: &mut W) -> io::Result<()> {
|
||||
let mut counting_write = CountingWriter::wrap(&mut write);
|
||||
counting_write.write_all(serde_json::to_string(&self)?.as_ref())?;
|
||||
let footer_payload_len = counting_write.written_bytes();
|
||||
BinarySerializable::serialize(&(footer_payload_len as u32), write)?;
|
||||
BinarySerializable::serialize(&FOOTER_MAGIC_NUMBER, write)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Extracts the tantivy Footer from the file and returns the footer and the rest of the file
|
||||
pub fn extract_footer(file: FileSlice) -> io::Result<(Footer, FileSlice)> {
|
||||
if file.len() < FOOTER_LEN {
|
||||
if file.len() < 4 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
format!(
|
||||
"File corrupted. The file is too small to contain the {FOOTER_LEN} byte \
|
||||
footer (len={}).",
|
||||
"File corrupted. The file is smaller than 4 bytes (len={}).",
|
||||
file.len()
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
let (body_slice, footer_slice) = file.split_from_end(FOOTER_LEN);
|
||||
let footer_bytes = footer_slice.read_bytes()?;
|
||||
let mut footer_bytes = footer_bytes.as_slice();
|
||||
let footer_metadata_len = <(u32, u32)>::SIZE_IN_BYTES;
|
||||
let (footer_len, footer_magic_byte): (u32, u32) = file
|
||||
.slice_from_end(footer_metadata_len)
|
||||
.read_bytes()?
|
||||
.as_ref()
|
||||
.deserialize()?;
|
||||
|
||||
let footer = Footer {
|
||||
version: Version {
|
||||
major: u32::deserialize(&mut footer_bytes)?,
|
||||
minor: u32::deserialize(&mut footer_bytes)?,
|
||||
patch: u32::deserialize(&mut footer_bytes)?,
|
||||
index_format_version: u32::deserialize(&mut footer_bytes)?,
|
||||
},
|
||||
crc: u32::deserialize(&mut footer_bytes)?,
|
||||
};
|
||||
|
||||
let footer_magic_byte = u32::deserialize(&mut footer_bytes)?;
|
||||
if footer_magic_byte != FOOTER_MAGIC_NUMBER {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
@@ -88,12 +78,38 @@ impl Footer {
|
||||
));
|
||||
}
|
||||
|
||||
Ok((footer, body_slice))
|
||||
if footer_len > FOOTER_MAX_LEN {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!(
|
||||
"Footer seems invalid as it suggests a footer len of {footer_len}. File is \
|
||||
corrupted, or the index was created with a different & old version of \
|
||||
tantivy."
|
||||
),
|
||||
));
|
||||
}
|
||||
let total_footer_size = footer_len as usize + footer_metadata_len;
|
||||
if file.len() < total_footer_size {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
format!(
|
||||
"File corrupted. The file is smaller than it's footer bytes \
|
||||
(len={total_footer_size})."
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
let footer: Footer =
|
||||
serde_json::from_slice(&file.read_bytes_slice(
|
||||
file.len() - total_footer_size..file.len() - footer_metadata_len,
|
||||
)?)?;
|
||||
|
||||
let body = file.slice_to(file.len() - total_footer_size);
|
||||
Ok((footer, body))
|
||||
}
|
||||
|
||||
/// Confirms that the index will be read correctly by this version of tantivy
|
||||
/// Has to be called after `extract_footer` to make sure it's not accessing uninitialised memory
|
||||
#[allow(dead_code)]
|
||||
pub fn is_compatible(&self) -> Result<(), Incompatibility> {
|
||||
const SUPPORTED_INDEX_FORMAT_VERSION_RANGE: std::ops::RangeInclusive<u32> =
|
||||
INDEX_FORMAT_OLDEST_SUPPORTED_VERSION..=INDEX_FORMAT_VERSION;
|
||||
@@ -172,10 +188,6 @@ mod tests {
|
||||
fn test_deserialize_footer_missing_magic_byte() {
|
||||
let mut buf: Vec<u8> = vec![];
|
||||
BinarySerializable::serialize(&0_u32, &mut buf).unwrap();
|
||||
BinarySerializable::serialize(&0_u32, &mut buf).unwrap();
|
||||
BinarySerializable::serialize(&0_u32, &mut buf).unwrap();
|
||||
BinarySerializable::serialize(&0_u32, &mut buf).unwrap();
|
||||
BinarySerializable::serialize(&0_u32, &mut buf).unwrap();
|
||||
let wrong_magic_byte: u32 = 5555;
|
||||
BinarySerializable::serialize(&wrong_magic_byte, &mut buf).unwrap();
|
||||
|
||||
@@ -193,6 +205,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_deserialize_footer_wrong_filesize() {
|
||||
let mut buf: Vec<u8> = vec![];
|
||||
BinarySerializable::serialize(&100_u32, &mut buf).unwrap();
|
||||
BinarySerializable::serialize(&FOOTER_MAGIC_NUMBER, &mut buf).unwrap();
|
||||
|
||||
let owned_bytes = OwnedBytes::new(buf);
|
||||
@@ -202,7 +215,27 @@ mod tests {
|
||||
assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"File corrupted. The file is too small to contain the 24 byte footer (len=4)."
|
||||
"File corrupted. The file is smaller than it\'s footer bytes (len=108)."
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_too_large_footer() {
|
||||
let mut buf: Vec<u8> = vec![];
|
||||
|
||||
let footer_length = super::FOOTER_MAX_LEN + 1;
|
||||
BinarySerializable::serialize(&footer_length, &mut buf).unwrap();
|
||||
BinarySerializable::serialize(&FOOTER_MAGIC_NUMBER, &mut buf).unwrap();
|
||||
|
||||
let owned_bytes = OwnedBytes::new(buf);
|
||||
|
||||
let fileslice = FileSlice::new(Arc::new(owned_bytes));
|
||||
let err = Footer::extract_footer(fileslice).unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"Footer seems invalid as it suggests a footer len of 50001. File is corrupted, or the \
|
||||
index was created with a different & old version of tantivy."
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,22 +1,20 @@
|
||||
use std::any::Any;
|
||||
use std::collections::HashSet;
|
||||
use std::io::Write;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, RwLock, RwLockWriteGuard};
|
||||
use std::{io, result};
|
||||
|
||||
use crc32fast::Hasher;
|
||||
|
||||
use crate::core::MANAGED_FILEPATH;
|
||||
use crate::directory::error::{DeleteError, LockError, OpenReadError, OpenWriteError};
|
||||
use crate::directory::footer::{Footer, FooterProxy, FOOTER_LEN};
|
||||
use crate::directory::footer::{Footer, FooterProxy};
|
||||
use crate::directory::{
|
||||
DirectoryLock, DirectoryPanicHandler, FileHandle, FileSlice, GarbageCollectionResult, Lock,
|
||||
TerminatingWrite, WatchCallback, WatchHandle, MANAGED_LOCK, META_LOCK,
|
||||
DirectoryLock, FileHandle, FileSlice, GarbageCollectionResult, Lock, WatchCallback,
|
||||
WatchHandle, WritePtr, META_LOCK,
|
||||
};
|
||||
use crate::error::DataCorruption;
|
||||
use crate::index::SegmentMetaInventory;
|
||||
use crate::{Directory, IndexMeta};
|
||||
use crate::Directory;
|
||||
|
||||
/// Returns true if the file is "managed".
|
||||
/// Non-managed file are not subject to garbage collection.
|
||||
@@ -41,9 +39,9 @@ fn is_managed(path: &Path) -> bool {
|
||||
#[derive(Debug)]
|
||||
pub struct ManagedDirectory {
|
||||
directory: Box<dyn Directory>,
|
||||
meta_informations: Arc<RwLock<MetaInformation>>,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
#[derive(Debug, Default)]
|
||||
struct MetaInformation {
|
||||
managed_paths: HashSet<PathBuf>,
|
||||
@@ -53,9 +51,9 @@ struct MetaInformation {
|
||||
/// that were created by tantivy.
|
||||
fn save_managed_paths(
|
||||
directory: &dyn Directory,
|
||||
managed_paths: &HashSet<PathBuf>,
|
||||
wlock: &RwLockWriteGuard<'_, MetaInformation>,
|
||||
) -> io::Result<()> {
|
||||
let mut w = serde_json::to_vec(managed_paths)?;
|
||||
let mut w = serde_json::to_vec(&wlock.managed_paths)?;
|
||||
writeln!(&mut w)?;
|
||||
directory.atomic_write(&MANAGED_FILEPATH, &w[..])?;
|
||||
Ok(())
|
||||
@@ -64,37 +62,33 @@ fn save_managed_paths(
|
||||
impl ManagedDirectory {
|
||||
/// Wraps a directory as managed directory.
|
||||
pub fn wrap(directory: Box<dyn Directory>) -> crate::Result<ManagedDirectory> {
|
||||
Ok(ManagedDirectory { directory })
|
||||
}
|
||||
|
||||
pub fn list_managed_files(&self) -> crate::Result<HashSet<PathBuf>> {
|
||||
match self.directory.list_managed_files() {
|
||||
Ok(managed_files) => Ok(managed_files),
|
||||
Err(crate::TantivyError::InternalError(_)) => {
|
||||
match self.directory.atomic_read(&MANAGED_FILEPATH) {
|
||||
Ok(data) => {
|
||||
let managed_files_json = String::from_utf8_lossy(&data);
|
||||
let managed_files: HashSet<PathBuf> =
|
||||
serde_json::from_str(&managed_files_json).map_err(|e| {
|
||||
DataCorruption::new(
|
||||
MANAGED_FILEPATH.to_path_buf(),
|
||||
format!("Managed file cannot be deserialized: {e:?}. "),
|
||||
)
|
||||
})?;
|
||||
Ok(managed_files)
|
||||
}
|
||||
Err(OpenReadError::FileDoesNotExist(_)) => Ok(HashSet::new()),
|
||||
io_err @ Err(OpenReadError::IoError { .. }) => {
|
||||
Err(io_err.err().unwrap().into())
|
||||
}
|
||||
Err(OpenReadError::IncompatibleIndex(incompatibility)) => {
|
||||
// For the moment, this should never happen `meta.json`
|
||||
// do not have any footer and cannot detect incompatibility.
|
||||
Err(crate::TantivyError::IncompatibleIndex(incompatibility))
|
||||
}
|
||||
}
|
||||
match directory.atomic_read(&MANAGED_FILEPATH) {
|
||||
Ok(data) => {
|
||||
let managed_files_json = String::from_utf8_lossy(&data);
|
||||
let managed_files: HashSet<PathBuf> = serde_json::from_str(&managed_files_json)
|
||||
.map_err(|e| {
|
||||
DataCorruption::new(
|
||||
MANAGED_FILEPATH.to_path_buf(),
|
||||
format!("Managed file cannot be deserialized: {e:?}. "),
|
||||
)
|
||||
})?;
|
||||
Ok(ManagedDirectory {
|
||||
directory,
|
||||
meta_informations: Arc::new(RwLock::new(MetaInformation {
|
||||
managed_paths: managed_files,
|
||||
})),
|
||||
})
|
||||
}
|
||||
Err(OpenReadError::FileDoesNotExist(_)) => Ok(ManagedDirectory {
|
||||
directory,
|
||||
meta_informations: Arc::default(),
|
||||
}),
|
||||
io_err @ Err(OpenReadError::IoError { .. }) => Err(io_err.err().unwrap().into()),
|
||||
Err(OpenReadError::IncompatibleIndex(incompatibility)) => {
|
||||
// For the moment, this should never happen `meta.json`
|
||||
// do not have any footer and cannot detect incompatibility.
|
||||
Err(crate::TantivyError::IncompatibleIndex(incompatibility))
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -116,26 +110,9 @@ impl ManagedDirectory {
|
||||
&mut self,
|
||||
get_living_files: L,
|
||||
) -> crate::Result<GarbageCollectionResult> {
|
||||
if !self.supports_garbage_collection() {
|
||||
// the underlying directory does not support garbage collection.
|
||||
return Ok(GarbageCollectionResult {
|
||||
deleted_files: vec![],
|
||||
failed_to_delete_files: vec![],
|
||||
});
|
||||
}
|
||||
info!("Garbage collect");
|
||||
let mut files_to_delete = vec![];
|
||||
|
||||
// We're about to do an atomic write to managed.json, lock it down
|
||||
let _lock = self.acquire_lock(&MANAGED_LOCK)?;
|
||||
let managed_paths = match self.directory.list_managed_files() {
|
||||
Ok(managed_paths) => managed_paths,
|
||||
Err(crate::TantivyError::InternalError(_)) => {
|
||||
// If the managed.json file does not exist, we consider
|
||||
// that there is no managed file.
|
||||
self.list_managed_files()?
|
||||
}
|
||||
Err(err) => return Err(err),
|
||||
};
|
||||
// It is crucial to get the living files after acquiring the
|
||||
// read lock of meta information. That way, we
|
||||
// avoid the following scenario.
|
||||
@@ -147,6 +124,11 @@ impl ManagedDirectory {
|
||||
//
|
||||
// releasing the lock as .delete() will use it too.
|
||||
{
|
||||
let meta_informations_rlock = self
|
||||
.meta_informations
|
||||
.read()
|
||||
.expect("Managed directory rlock poisoned in garbage collect.");
|
||||
|
||||
// The point of this second "file" lock is to enforce the following scenario
|
||||
// 1) process B tries to load a new set of searcher.
|
||||
// The list of segments is loaded
|
||||
@@ -156,7 +138,7 @@ impl ManagedDirectory {
|
||||
match self.acquire_lock(&META_LOCK) {
|
||||
Ok(_meta_lock) => {
|
||||
let living_files = get_living_files();
|
||||
for managed_path in &managed_paths {
|
||||
for managed_path in &meta_informations_rlock.managed_paths {
|
||||
if !living_files.contains(managed_path) {
|
||||
files_to_delete.push(managed_path.clone());
|
||||
}
|
||||
@@ -199,18 +181,16 @@ impl ManagedDirectory {
|
||||
if !deleted_files.is_empty() {
|
||||
// update the list of managed files by removing
|
||||
// the file that were removed.
|
||||
let mut managed_paths_write = managed_paths;
|
||||
let mut meta_informations_wlock = self
|
||||
.meta_informations
|
||||
.write()
|
||||
.expect("Managed directory wlock poisoned (2).");
|
||||
let managed_paths_write = &mut meta_informations_wlock.managed_paths;
|
||||
for delete_file in &deleted_files {
|
||||
managed_paths_write.remove(delete_file);
|
||||
}
|
||||
self.directory.sync_directory()?;
|
||||
|
||||
if let Err(crate::TantivyError::InternalError(_)) = self
|
||||
.directory
|
||||
.register_files_as_managed(managed_paths_write.clone().into_iter().collect(), true)
|
||||
{
|
||||
save_managed_paths(self.directory.as_mut(), &managed_paths_write)?;
|
||||
}
|
||||
save_managed_paths(self.directory.as_mut(), &meta_informations_wlock)?;
|
||||
}
|
||||
|
||||
Ok(GarbageCollectionResult {
|
||||
@@ -235,39 +215,27 @@ impl ManagedDirectory {
|
||||
if !is_managed(filepath) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// We're about to do an atomic write to managed.json, lock it down
|
||||
let _lock = self
|
||||
.acquire_lock(&MANAGED_LOCK)
|
||||
.expect("must be able to acquire lock for managed.json");
|
||||
|
||||
if let Err(crate::TantivyError::InternalError(_)) = self
|
||||
.directory
|
||||
.register_files_as_managed(vec![filepath.to_owned()], false)
|
||||
{
|
||||
let mut managed_paths = self
|
||||
.list_managed_files()
|
||||
.expect("reading managed files should not fail");
|
||||
let has_changed = managed_paths.insert(filepath.to_owned());
|
||||
if !has_changed {
|
||||
return Ok(());
|
||||
}
|
||||
save_managed_paths(self.directory.as_ref(), &managed_paths)?;
|
||||
// This is not the first file we add.
|
||||
// Therefore, we are sure that `.managed.json` has been already
|
||||
// properly created and we do not need to sync its parent directory.
|
||||
//
|
||||
// (It might seem like a nicer solution to create the managed_json on the
|
||||
// creation of the ManagedDirectory instance but it would actually
|
||||
// prevent the use of read-only directories..)
|
||||
let managed_file_definitely_already_exists = managed_paths.len() > 1;
|
||||
if managed_file_definitely_already_exists {
|
||||
return Ok(());
|
||||
}
|
||||
let mut meta_wlock = self
|
||||
.meta_informations
|
||||
.write()
|
||||
.expect("Managed file lock poisoned");
|
||||
let has_changed = meta_wlock.managed_paths.insert(filepath.to_owned());
|
||||
if !has_changed {
|
||||
return Ok(());
|
||||
}
|
||||
save_managed_paths(self.directory.as_ref(), &meta_wlock)?;
|
||||
// This is not the first file we add.
|
||||
// Therefore, we are sure that `.managed.json` has been already
|
||||
// properly created and we do not need to sync its parent directory.
|
||||
//
|
||||
// (It might seem like a nicer solution to create the managed_json on the
|
||||
// creation of the ManagedDirectory instance but it would actually
|
||||
// prevent the use of read-only directories..)
|
||||
let managed_file_definitely_already_exists = meta_wlock.managed_paths.len() > 1;
|
||||
if managed_file_definitely_already_exists {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.directory.sync_directory()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -287,6 +255,17 @@ impl ManagedDirectory {
|
||||
let crc = hasher.finalize();
|
||||
Ok(footer.crc() == crc)
|
||||
}
|
||||
|
||||
/// List all managed files
|
||||
pub fn list_managed_files(&self) -> HashSet<PathBuf> {
|
||||
let managed_paths = self
|
||||
.meta_informations
|
||||
.read()
|
||||
.expect("Managed directory rlock poisoned in list damaged.")
|
||||
.managed_paths
|
||||
.clone();
|
||||
managed_paths
|
||||
}
|
||||
}
|
||||
|
||||
impl Directory for ManagedDirectory {
|
||||
@@ -297,32 +276,22 @@ impl Directory for ManagedDirectory {
|
||||
|
||||
fn open_read(&self, path: &Path) -> result::Result<FileSlice, OpenReadError> {
|
||||
let file_slice = self.directory.open_read(path)?;
|
||||
debug_assert!(
|
||||
{
|
||||
use common::HasLen;
|
||||
file_slice.len() >= FOOTER_LEN
|
||||
},
|
||||
"{} is too short",
|
||||
path.display()
|
||||
);
|
||||
let (reader, _) = file_slice.split_from_end(FOOTER_LEN);
|
||||
// NB: We do not read/validate the footer here -- we blindly skip it entirely
|
||||
let (footer, reader) = Footer::extract_footer(file_slice)
|
||||
.map_err(|io_error| OpenReadError::wrap_io_error(io_error, path.to_path_buf()))?;
|
||||
footer.is_compatible()?;
|
||||
Ok(reader)
|
||||
}
|
||||
|
||||
fn open_write_inner(
|
||||
&self,
|
||||
path: &Path,
|
||||
) -> result::Result<Box<dyn TerminatingWrite>, OpenWriteError> {
|
||||
fn open_write(&self, path: &Path) -> result::Result<WritePtr, OpenWriteError> {
|
||||
self.register_file_as_managed(path)
|
||||
.map_err(|io_error| OpenWriteError::wrap_io_error(io_error, path.to_path_buf()))?;
|
||||
Ok(Box::new(FooterProxy::new(
|
||||
Ok(io::BufWriter::new(Box::new(FooterProxy::new(
|
||||
self.directory
|
||||
.open_write(path)?
|
||||
.into_inner()
|
||||
.map_err(|_| ())
|
||||
.expect("buffer should be empty"),
|
||||
)))
|
||||
))))
|
||||
}
|
||||
|
||||
fn atomic_write(&self, path: &Path, data: &[u8]) -> io::Result<()> {
|
||||
@@ -354,45 +323,13 @@ impl Directory for ManagedDirectory {
|
||||
self.directory.sync_directory()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn save_metas(
|
||||
&self,
|
||||
metas: &IndexMeta,
|
||||
previous_metas: &IndexMeta,
|
||||
payload: &mut (dyn Any + '_),
|
||||
) -> crate::Result<()> {
|
||||
self.directory.save_metas(metas, previous_metas, payload)
|
||||
}
|
||||
|
||||
fn load_metas(&self, inventory: &SegmentMetaInventory) -> crate::Result<IndexMeta> {
|
||||
self.directory.load_metas(inventory)
|
||||
}
|
||||
|
||||
fn supports_garbage_collection(&self) -> bool {
|
||||
self.directory.supports_garbage_collection()
|
||||
}
|
||||
|
||||
fn panic_handler(&self) -> Option<DirectoryPanicHandler> {
|
||||
self.directory.panic_handler()
|
||||
}
|
||||
|
||||
fn wants_cancel(&self) -> bool {
|
||||
self.directory.wants_cancel()
|
||||
}
|
||||
|
||||
fn log(&self, message: &str) {
|
||||
self.directory.log(message);
|
||||
}
|
||||
|
||||
fn bufwriter_capacity(&self) -> usize {
|
||||
self.directory.bufwriter_capacity()
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for ManagedDirectory {
|
||||
fn clone(&self) -> ManagedDirectory {
|
||||
ManagedDirectory {
|
||||
directory: self.directory.box_clone(),
|
||||
meta_informations: Arc::clone(&self.meta_informations),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -400,6 +337,7 @@ impl Clone for ManagedDirectory {
|
||||
#[cfg(feature = "mmap")]
|
||||
#[cfg(test)]
|
||||
mod tests_mmap_specific {
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::io::Write;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
@@ -9,7 +9,6 @@ use crc32fast::Hasher;
|
||||
|
||||
use crate::directory::{WatchCallback, WatchCallbackList, WatchHandle};
|
||||
|
||||
#[allow(dead_code)]
|
||||
const POLLING_INTERVAL: Duration = Duration::from_millis(if cfg!(test) { 1 } else { 500 });
|
||||
|
||||
// Watches a file and executes registered callbacks when the file is modified.
|
||||
@@ -19,7 +18,6 @@ pub struct FileWatcher {
|
||||
state: Arc<AtomicUsize>, // 0: new, 1: runnable, 2: terminated
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl FileWatcher {
|
||||
pub fn new(path: &Path) -> FileWatcher {
|
||||
FileWatcher {
|
||||
@@ -1,12 +1,15 @@
|
||||
mod file_watcher;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::fs::{self, File, OpenOptions};
|
||||
use std::io::{self, Read, Write};
|
||||
use std::io::{self, BufWriter, Read, Write};
|
||||
use std::ops::Deref;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::{Arc, RwLock, Weak};
|
||||
|
||||
use common::StableDeref;
|
||||
use file_watcher::FileWatcher;
|
||||
use fs4::fs_std::FileExt;
|
||||
#[cfg(all(feature = "mmap", unix))]
|
||||
pub use memmap2::Advice;
|
||||
@@ -18,10 +21,9 @@ use crate::core::META_FILEPATH;
|
||||
use crate::directory::error::{
|
||||
DeleteError, LockError, OpenDirectoryError, OpenReadError, OpenWriteError,
|
||||
};
|
||||
use crate::directory::file_watcher::FileWatcher;
|
||||
use crate::directory::{
|
||||
AntiCallToken, Directory, DirectoryLock, FileHandle, Lock, OwnedBytes, TerminatingWrite,
|
||||
WatchCallback, WatchHandle,
|
||||
WatchCallback, WatchHandle, WritePtr,
|
||||
};
|
||||
|
||||
pub type ArcBytes = Arc<dyn Deref<Target = [u8]> + Send + Sync + 'static>;
|
||||
@@ -413,8 +415,8 @@ impl Directory for MmapDirectory {
|
||||
.map_err(|io_err| OpenReadError::wrap_io_error(io_err, path.to_path_buf()))
|
||||
}
|
||||
|
||||
fn open_write_inner(&self, path: &Path) -> Result<Box<dyn TerminatingWrite>, OpenWriteError> {
|
||||
debug!("Open Write {:?}", path);
|
||||
fn open_write(&self, path: &Path) -> Result<WritePtr, OpenWriteError> {
|
||||
debug!("Open Write {path:?}");
|
||||
let full_path = self.resolve_path(path);
|
||||
|
||||
let open_res = OpenOptions::new()
|
||||
@@ -443,7 +445,7 @@ impl Directory for MmapDirectory {
|
||||
// sync_directory() is called.
|
||||
|
||||
let writer = SafeFileWriter::new(file);
|
||||
Ok(Box::new(writer))
|
||||
Ok(BufWriter::new(Box::new(writer)))
|
||||
}
|
||||
|
||||
fn atomic_read(&self, path: &Path) -> Result<Vec<u8>, OpenReadError> {
|
||||
@@ -5,7 +5,6 @@ mod mmap_directory;
|
||||
|
||||
mod directory;
|
||||
mod directory_lock;
|
||||
mod file_watcher;
|
||||
pub mod footer;
|
||||
mod managed_directory;
|
||||
mod ram_directory;
|
||||
@@ -19,13 +18,12 @@ mod composite_file;
|
||||
use std::io::BufWriter;
|
||||
use std::path::PathBuf;
|
||||
|
||||
pub use common::buffered_file_slice::BufferedFileSlice;
|
||||
pub use common::file_slice::{FileHandle, FileSlice};
|
||||
pub use common::{AntiCallToken, OwnedBytes, TerminatingWrite};
|
||||
|
||||
pub(crate) use self::composite_file::{CompositeFile, CompositeWrite};
|
||||
pub use self::directory::{Directory, DirectoryClone, DirectoryLock, DirectoryPanicHandler};
|
||||
pub use self::directory_lock::{Lock, INDEX_WRITER_LOCK, MANAGED_LOCK, META_LOCK};
|
||||
pub use self::directory::{Directory, DirectoryClone, DirectoryLock};
|
||||
pub use self::directory_lock::{Lock, INDEX_WRITER_LOCK, META_LOCK};
|
||||
pub use self::ram_directory::RamDirectory;
|
||||
pub use self::watch_event_router::{WatchCallback, WatchCallbackList, WatchHandle};
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use std::collections::HashMap;
|
||||
use std::io::{self, Cursor, Write};
|
||||
use std::io::{self, BufWriter, Cursor, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::{fmt, result};
|
||||
@@ -11,7 +11,7 @@ use crate::core::META_FILEPATH;
|
||||
use crate::directory::error::{DeleteError, OpenReadError, OpenWriteError};
|
||||
use crate::directory::{
|
||||
AntiCallToken, Directory, FileSlice, TerminatingWrite, WatchCallback, WatchCallbackList,
|
||||
WatchHandle,
|
||||
WatchHandle, WritePtr,
|
||||
};
|
||||
|
||||
/// Writer associated with the [`RamDirectory`].
|
||||
@@ -197,7 +197,7 @@ impl Directory for RamDirectory {
|
||||
.exists(path))
|
||||
}
|
||||
|
||||
fn open_write_inner(&self, path: &Path) -> Result<Box<dyn TerminatingWrite>, OpenWriteError> {
|
||||
fn open_write(&self, path: &Path) -> Result<WritePtr, OpenWriteError> {
|
||||
let mut fs = self.fs.write().unwrap();
|
||||
let path_buf = PathBuf::from(path);
|
||||
let vec_writer = VecWriter::new(path_buf.clone(), self.clone());
|
||||
@@ -206,7 +206,7 @@ impl Directory for RamDirectory {
|
||||
if exists {
|
||||
Err(OpenWriteError::FileAlreadyExists(path_buf))
|
||||
} else {
|
||||
Ok(Box::new(vec_writer))
|
||||
Ok(BufWriter::new(Box::new(vec_writer)))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -40,6 +40,8 @@ pub trait DocSet: Send {
|
||||
/// of `DocSet` should support it.
|
||||
///
|
||||
/// Calling `seek(TERMINATED)` is also legal and is the normal way to consume a `DocSet`.
|
||||
///
|
||||
/// `target` has to be larger or equal to `.doc()` when calling `seek`.
|
||||
fn seek(&mut self, target: DocId) -> DocId {
|
||||
let mut doc = self.doc();
|
||||
debug_assert!(doc <= target);
|
||||
@@ -49,6 +51,33 @@ pub trait DocSet: Send {
|
||||
doc
|
||||
}
|
||||
|
||||
/// Seeks to the target if possible and returns true if the target is in the DocSet.
|
||||
///
|
||||
/// DocSets that already have an efficient `seek` method don't need to implement
|
||||
/// `seek_into_the_danger_zone`. All wrapper DocSets should forward
|
||||
/// `seek_into_the_danger_zone` to the underlying DocSet.
|
||||
///
|
||||
/// ## API Behaviour
|
||||
/// If `seek_into_the_danger_zone` is returning true, a call to `doc()` has to return target.
|
||||
/// If `seek_into_the_danger_zone` is returning false, a call to `doc()` may return any doc
|
||||
/// between the last doc that matched and target or a doc that is a valid next hit after
|
||||
/// target. The DocSet is considered to be in an invalid state until
|
||||
/// `seek_into_the_danger_zone` returns true again.
|
||||
///
|
||||
/// `target` needs to be equal or larger than `doc` when in a valid state.
|
||||
///
|
||||
/// Consecutive calls are not allowed to have decreasing `target` values.
|
||||
///
|
||||
/// # Warning
|
||||
/// This is an advanced API used by intersection. The API contract is tricky, avoid using it.
|
||||
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
|
||||
let current_doc = self.doc();
|
||||
if current_doc < target {
|
||||
self.seek(target);
|
||||
}
|
||||
self.doc() == target
|
||||
}
|
||||
|
||||
/// Fills a given mutable buffer with the next doc ids from the
|
||||
/// `DocSet`
|
||||
///
|
||||
@@ -94,6 +123,15 @@ pub trait DocSet: Send {
|
||||
/// which would be the number of documents in the DocSet.
|
||||
///
|
||||
/// By default this returns `size_hint()`.
|
||||
///
|
||||
/// DocSets may have vastly different cost depending on their type,
|
||||
/// e.g. an intersection with 10 hits is much cheaper than
|
||||
/// a phrase search with 10 hits, since it needs to load positions.
|
||||
///
|
||||
/// ### Future Work
|
||||
/// We may want to differentiate `DocSet` costs more more granular, e.g.
|
||||
/// creation_cost, advance_cost, seek_cost on to get a good estimation
|
||||
/// what query types to choose.
|
||||
fn cost(&self) -> u64 {
|
||||
self.size_hint() as u64
|
||||
}
|
||||
@@ -137,6 +175,10 @@ impl DocSet for &mut dyn DocSet {
|
||||
(**self).seek(target)
|
||||
}
|
||||
|
||||
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
|
||||
(**self).seek_into_the_danger_zone(target)
|
||||
}
|
||||
|
||||
fn doc(&self) -> u32 {
|
||||
(**self).doc()
|
||||
}
|
||||
@@ -169,6 +211,11 @@ impl<TDocSet: DocSet + ?Sized> DocSet for Box<TDocSet> {
|
||||
unboxed.seek(target)
|
||||
}
|
||||
|
||||
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
|
||||
let unboxed: &mut TDocSet = self.borrow_mut();
|
||||
unboxed.seek_into_the_danger_zone(target)
|
||||
}
|
||||
|
||||
fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize {
|
||||
let unboxed: &mut TDocSet = self.borrow_mut();
|
||||
unboxed.fill_buffer(buffer)
|
||||
|
||||
@@ -110,11 +110,6 @@ pub enum TantivyError {
|
||||
#[error("Deserialize error: {0}")]
|
||||
/// An error occurred while attempting to deserialize a document.
|
||||
DeserializeError(DeserializeError),
|
||||
/// The user requested the current operation be cancelled
|
||||
#[error("User requested cancel")]
|
||||
Cancelled,
|
||||
#[error("Segment Merging failed: {0:#?}")]
|
||||
MergeErrors(Vec<TantivyError>),
|
||||
}
|
||||
|
||||
impl From<io::Error> for TantivyError {
|
||||
|
||||
@@ -79,7 +79,7 @@ mod tests {
|
||||
use std::ops::{Range, RangeInclusive};
|
||||
use std::path::Path;
|
||||
|
||||
use columnar::{StrColumn, ValueRange};
|
||||
use columnar::StrColumn;
|
||||
use common::{ByteCount, DateTimePrecision, HasLen, TerminatingWrite};
|
||||
use once_cell::sync::Lazy;
|
||||
use rand::prelude::SliceRandom;
|
||||
@@ -395,7 +395,7 @@ mod tests {
|
||||
.unwrap()
|
||||
.first_or_default_col(0);
|
||||
for a in 0..n {
|
||||
assert_eq!(col.get_val(a as u32), permutation[a], "for doc {a}");
|
||||
assert_eq!(col.get_val(a as u32), permutation[a]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -944,7 +944,7 @@ mod tests {
|
||||
let test_range = |range: RangeInclusive<u64>| {
|
||||
let expected_count = numbers.iter().filter(|num| range.contains(*num)).count();
|
||||
let mut vec = vec![];
|
||||
field.get_row_ids_for_value_range(ValueRange::Inclusive(range), 0..u32::MAX, &mut vec);
|
||||
field.get_row_ids_for_value_range(range, 0..u32::MAX, &mut vec);
|
||||
assert_eq!(vec.len(), expected_count);
|
||||
};
|
||||
test_range(50..=50);
|
||||
@@ -1022,7 +1022,7 @@ mod tests {
|
||||
let test_range = |range: RangeInclusive<u64>| {
|
||||
let expected_count = numbers.iter().filter(|num| range.contains(*num)).count();
|
||||
let mut vec = vec![];
|
||||
field.get_row_ids_for_value_range(ValueRange::Inclusive(range), 0..u32::MAX, &mut vec);
|
||||
field.get_row_ids_for_value_range(range, 0..u32::MAX, &mut vec);
|
||||
assert_eq!(vec.len(), expected_count);
|
||||
};
|
||||
let test_range_variant = |start, stop| {
|
||||
|
||||
@@ -30,30 +30,22 @@ fn load_metas(
|
||||
directory: &dyn Directory,
|
||||
inventory: &SegmentMetaInventory,
|
||||
) -> crate::Result<IndexMeta> {
|
||||
match directory.load_metas(inventory) {
|
||||
Ok(metas) => Ok(metas),
|
||||
Err(crate::TantivyError::InternalError(_)) => {
|
||||
let meta_data = directory.atomic_read(&META_FILEPATH)?;
|
||||
let meta_string = String::from_utf8(meta_data).map_err(|_utf8_err| {
|
||||
error!("Meta data is not valid utf8.");
|
||||
DataCorruption::new(
|
||||
META_FILEPATH.to_path_buf(),
|
||||
"Meta file does not contain valid utf8 file.".to_string(),
|
||||
)
|
||||
})?;
|
||||
IndexMeta::deserialize(&meta_string, inventory)
|
||||
.map_err(|e| {
|
||||
DataCorruption::new(
|
||||
META_FILEPATH.to_path_buf(),
|
||||
format!(
|
||||
"Meta file cannot be deserialized. {e:?}. Content: {meta_string:?}"
|
||||
),
|
||||
)
|
||||
})
|
||||
.map_err(From::from)
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
let meta_data = directory.atomic_read(&META_FILEPATH)?;
|
||||
let meta_string = String::from_utf8(meta_data).map_err(|_utf8_err| {
|
||||
error!("Meta data is not valid utf8.");
|
||||
DataCorruption::new(
|
||||
META_FILEPATH.to_path_buf(),
|
||||
"Meta file does not contain valid utf8 file.".to_string(),
|
||||
)
|
||||
})?;
|
||||
IndexMeta::deserialize(&meta_string, inventory)
|
||||
.map_err(|e| {
|
||||
DataCorruption::new(
|
||||
META_FILEPATH.to_path_buf(),
|
||||
format!("Meta file cannot be deserialized. {e:?}. Content: {meta_string:?}"),
|
||||
)
|
||||
})
|
||||
.map_err(From::from)
|
||||
}
|
||||
|
||||
/// Save the index meta file.
|
||||
@@ -68,14 +60,16 @@ fn save_new_metas(
|
||||
index_settings: IndexSettings,
|
||||
directory: &dyn Directory,
|
||||
) -> crate::Result<()> {
|
||||
let empty_metas = IndexMeta {
|
||||
index_settings,
|
||||
segments: Vec::new(),
|
||||
schema,
|
||||
opstamp: 0u64,
|
||||
payload: None,
|
||||
};
|
||||
save_metas(&empty_metas, &empty_metas, directory)?;
|
||||
save_metas(
|
||||
&IndexMeta {
|
||||
index_settings,
|
||||
segments: Vec::new(),
|
||||
schema,
|
||||
opstamp: 0u64,
|
||||
payload: None,
|
||||
},
|
||||
directory,
|
||||
)?;
|
||||
directory.sync_directory()?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -588,7 +582,7 @@ impl Index {
|
||||
num_threads: usize,
|
||||
overall_memory_budget_in_bytes: usize,
|
||||
) -> crate::Result<IndexWriter<D>> {
|
||||
let memory_arena_in_bytes_per_thread = overall_memory_budget_in_bytes / num_threads.max(1);
|
||||
let memory_arena_in_bytes_per_thread = overall_memory_budget_in_bytes / num_threads;
|
||||
let options = IndexWriterOptions::builder()
|
||||
.num_worker_threads(num_threads)
|
||||
.memory_budget_per_thread(memory_arena_in_bytes_per_thread)
|
||||
@@ -661,11 +655,9 @@ impl Index {
|
||||
|
||||
/// Creates a new segment.
|
||||
pub fn new_segment(&self) -> Segment {
|
||||
self.new_segment_with_id(SegmentId::generate_random())
|
||||
}
|
||||
|
||||
pub fn new_segment_with_id(&self, segment_id: SegmentId) -> Segment {
|
||||
let segment_meta = self.inventory.new_segment_meta(segment_id, 0);
|
||||
let segment_meta = self
|
||||
.inventory
|
||||
.new_segment_meta(SegmentId::generate_random(), 0);
|
||||
self.segment(segment_meta)
|
||||
}
|
||||
|
||||
@@ -696,7 +688,7 @@ impl Index {
|
||||
|
||||
/// Returns the set of corrupted files
|
||||
pub fn validate_checksum(&self) -> crate::Result<HashSet<PathBuf>> {
|
||||
let managed_files = self.directory.list_managed_files()?;
|
||||
let managed_files = self.directory.list_managed_files();
|
||||
let active_segments_files: HashSet<PathBuf> = self
|
||||
.searchable_segment_metas()?
|
||||
.iter()
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use std::collections::HashSet;
|
||||
use std::fmt;
|
||||
use std::fmt::{Debug, Formatter};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::Arc;
|
||||
@@ -15,23 +14,15 @@ use crate::{Inventory, Opstamp, TrackedObject};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct DeleteMeta {
|
||||
pub num_deleted_docs: u32,
|
||||
num_deleted_docs: u32,
|
||||
pub opstamp: Opstamp,
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct SegmentMetaInventory {
|
||||
pub(crate) struct SegmentMetaInventory {
|
||||
inventory: Inventory<InnerSegmentMeta>,
|
||||
}
|
||||
|
||||
impl Debug for SegmentMetaInventory {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("SegmentMetaInventory")
|
||||
.field("inventory", &self.inventory.list())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl SegmentMetaInventory {
|
||||
/// Lists all living `SegmentMeta` object at the time of the call.
|
||||
pub fn all(&self) -> Vec<SegmentMeta> {
|
||||
@@ -59,7 +50,7 @@ impl SegmentMetaInventory {
|
||||
/// how many are deleted, etc.
|
||||
#[derive(Clone)]
|
||||
pub struct SegmentMeta {
|
||||
pub tracked: TrackedObject<InnerSegmentMeta>,
|
||||
tracked: TrackedObject<InnerSegmentMeta>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for SegmentMeta {
|
||||
@@ -219,15 +210,15 @@ impl SegmentMeta {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct InnerSegmentMeta {
|
||||
pub segment_id: SegmentId,
|
||||
pub max_doc: u32,
|
||||
struct InnerSegmentMeta {
|
||||
segment_id: SegmentId,
|
||||
max_doc: u32,
|
||||
pub deletes: Option<DeleteMeta>,
|
||||
/// If you want to avoid the SegmentComponent::TempStore file to be covered by
|
||||
/// garbage collection and deleted, set this to true. This is used during merge.
|
||||
#[serde(skip)]
|
||||
#[serde(default = "default_temp_store")]
|
||||
pub include_temp_doc_store: Arc<AtomicBool>,
|
||||
pub(crate) include_temp_doc_store: Arc<AtomicBool>,
|
||||
}
|
||||
fn default_temp_store() -> Arc<AtomicBool> {
|
||||
Arc::new(AtomicBool::new(false))
|
||||
@@ -413,7 +404,10 @@ mod tests {
|
||||
schema_builder.build()
|
||||
};
|
||||
let index_metas = IndexMeta {
|
||||
index_settings: IndexSettings::default(),
|
||||
index_settings: IndexSettings {
|
||||
docstore_compression: Compressor::None,
|
||||
..Default::default()
|
||||
},
|
||||
segments: Vec::new(),
|
||||
schema,
|
||||
opstamp: 0u64,
|
||||
@@ -422,7 +416,7 @@ mod tests {
|
||||
let json = serde_json::ser::to_string(&index_metas).expect("serialization failed");
|
||||
assert_eq!(
|
||||
json,
|
||||
r#"{"index_settings":{"docstore_compression":"lz4","docstore_blocksize":16384},"segments":[],"schema":[{"name":"text","type":"text","options":{"indexing":{"record":"position","fieldnorms":true,"tokenizer":"default"},"stored":false,"fast":false}}],"opstamp":0}"#
|
||||
r#"{"index_settings":{"docstore_compression":"none","docstore_blocksize":16384},"segments":[],"schema":[{"name":"text","type":"text","options":{"indexing":{"record":"position","fieldnorms":true,"tokenizer":"default"},"stored":false,"fast":false}}],"opstamp":0}"#
|
||||
);
|
||||
|
||||
let deser_meta: UntrackedIndexMeta = serde_json::from_str(&json).unwrap();
|
||||
@@ -503,6 +497,8 @@ mod tests {
|
||||
#[test]
|
||||
#[cfg(feature = "lz4-compression")]
|
||||
fn test_index_settings_default() {
|
||||
use crate::store::Compressor;
|
||||
|
||||
let mut index_settings = IndexSettings::default();
|
||||
assert_eq!(
|
||||
index_settings,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use std::io;
|
||||
|
||||
use common::file_slice::DeferredFileSlice;
|
||||
use common::json_path_writer::JSON_END_OF_PATH;
|
||||
use common::{BinarySerializable, ByteCount};
|
||||
#[cfg(feature = "quickwit")]
|
||||
@@ -31,7 +30,7 @@ use crate::termdict::TermDictionary;
|
||||
pub struct InvertedIndexReader {
|
||||
termdict: TermDictionary,
|
||||
postings_file_slice: FileSlice,
|
||||
positions_file_slice: DeferredFileSlice,
|
||||
positions_file_slice: FileSlice,
|
||||
record_option: IndexRecordOption,
|
||||
total_num_tokens: u64,
|
||||
}
|
||||
@@ -67,7 +66,7 @@ impl InvertedIndexReader {
|
||||
pub(crate) fn new(
|
||||
termdict: TermDictionary,
|
||||
postings_file_slice: FileSlice,
|
||||
positions_file_slice: DeferredFileSlice,
|
||||
positions_file_slice: FileSlice,
|
||||
record_option: IndexRecordOption,
|
||||
) -> io::Result<InvertedIndexReader> {
|
||||
let (total_num_tokens_slice, postings_body) = postings_file_slice.split(8);
|
||||
@@ -87,7 +86,7 @@ impl InvertedIndexReader {
|
||||
InvertedIndexReader {
|
||||
termdict: TermDictionary::empty(),
|
||||
postings_file_slice: FileSlice::empty(),
|
||||
positions_file_slice: DeferredFileSlice::new(|| Ok(FileSlice::empty())),
|
||||
positions_file_slice: FileSlice::empty(),
|
||||
record_option,
|
||||
total_num_tokens: 0u64,
|
||||
}
|
||||
@@ -212,7 +211,7 @@ impl InvertedIndexReader {
|
||||
.slice(term_info.postings_range.clone());
|
||||
BlockSegmentPostings::open(
|
||||
term_info.doc_freq,
|
||||
postings_data.read_bytes()?,
|
||||
postings_data,
|
||||
self.record_option,
|
||||
requested_option,
|
||||
)
|
||||
@@ -234,7 +233,6 @@ impl InvertedIndexReader {
|
||||
if option.has_positions() {
|
||||
let positions_data = self
|
||||
.positions_file_slice
|
||||
.open()?
|
||||
.read_bytes_slice(term_info.positions_range.clone())?;
|
||||
let position_reader = PositionReader::open(positions_data)?;
|
||||
Some(position_reader)
|
||||
@@ -344,7 +342,6 @@ impl InvertedIndexReader {
|
||||
if with_positions {
|
||||
let positions = self
|
||||
.positions_file_slice
|
||||
.open()?
|
||||
.read_bytes_slice_async(term_info.positions_range.clone());
|
||||
futures_util::future::try_join(postings, positions).await?;
|
||||
} else {
|
||||
@@ -387,7 +384,6 @@ impl InvertedIndexReader {
|
||||
if with_positions {
|
||||
let positions = self
|
||||
.positions_file_slice
|
||||
.open()?
|
||||
.read_bytes_slice_async(positions_range);
|
||||
futures_util::future::try_join(postings, positions).await?;
|
||||
} else {
|
||||
@@ -482,7 +478,7 @@ impl InvertedIndexReader {
|
||||
pub async fn warm_postings_full(&self, with_positions: bool) -> io::Result<()> {
|
||||
self.postings_file_slice.read_bytes_async().await?;
|
||||
if with_positions {
|
||||
self.positions_file_slice.open()?.read_bytes_async().await?;
|
||||
self.positions_file_slice.read_bytes_async().await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
use std::io;
|
||||
|
||||
use crate::directory::{BufferedFileSlice, FileSlice};
|
||||
use crate::positions::PositionReader;
|
||||
use crate::postings::{BlockSegmentPostings, SegmentPostings, TermInfo};
|
||||
use crate::schema::IndexRecordOption;
|
||||
use crate::termdict::TermDictionary;
|
||||
|
||||
/// The inverted index reader is in charge of accessing
|
||||
/// the inverted index associated with a specific field.
|
||||
///
|
||||
/// This is optimized for merging in that it uses a buffered reader
|
||||
/// for the postings and positions files.
|
||||
/// This eliminates most disk I/O to these files during merging, without
|
||||
/// reading the entire file into memory at once.
|
||||
///
|
||||
/// NB: This is a copy/paste from [`InvertedIndexReader`] and trimmed
|
||||
/// down to only include the methods required by the merge process.
|
||||
pub(crate) struct MergeOptimizedInvertedIndexReader {
|
||||
termdict: TermDictionary,
|
||||
postings_reader: BufferedFileSlice,
|
||||
positions_reader: BufferedFileSlice,
|
||||
record_option: IndexRecordOption,
|
||||
}
|
||||
|
||||
impl MergeOptimizedInvertedIndexReader {
|
||||
pub(crate) fn new(
|
||||
termdict: TermDictionary,
|
||||
postings_file_slice: FileSlice,
|
||||
positions_file_slice: FileSlice,
|
||||
record_option: IndexRecordOption,
|
||||
) -> io::Result<MergeOptimizedInvertedIndexReader> {
|
||||
let (_, postings_body) = postings_file_slice.split(8);
|
||||
Ok(MergeOptimizedInvertedIndexReader {
|
||||
termdict,
|
||||
postings_reader: BufferedFileSlice::new_with_default_buffer_size(postings_body),
|
||||
positions_reader: BufferedFileSlice::new_with_default_buffer_size(positions_file_slice),
|
||||
record_option,
|
||||
})
|
||||
}
|
||||
|
||||
/// Creates an empty `InvertedIndexReader` object, which
|
||||
/// contains no terms at all.
|
||||
pub fn empty(record_option: IndexRecordOption) -> MergeOptimizedInvertedIndexReader {
|
||||
MergeOptimizedInvertedIndexReader {
|
||||
termdict: TermDictionary::empty(),
|
||||
postings_reader: BufferedFileSlice::empty(),
|
||||
positions_reader: BufferedFileSlice::empty(),
|
||||
record_option,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the term dictionary datastructure.
|
||||
pub fn terms(&self) -> &TermDictionary {
|
||||
&self.termdict
|
||||
}
|
||||
|
||||
/// Returns a block postings given a `term_info`.
|
||||
/// This method is for an advanced usage only.
|
||||
///
|
||||
/// Most users should prefer using [`Self::read_postings()`] instead.
|
||||
pub fn read_block_postings_from_terminfo(
|
||||
&self,
|
||||
term_info: &TermInfo,
|
||||
requested_option: IndexRecordOption,
|
||||
) -> io::Result<BlockSegmentPostings> {
|
||||
let postings_data = self.postings_reader.get_bytes(
|
||||
term_info.postings_range.start as u64..term_info.postings_range.end as u64,
|
||||
)?;
|
||||
BlockSegmentPostings::open(
|
||||
term_info.doc_freq,
|
||||
postings_data,
|
||||
self.record_option,
|
||||
requested_option,
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns a posting object given a `term_info`.
|
||||
/// This method is for an advanced usage only.
|
||||
///
|
||||
/// Most users should prefer using [`Self::read_postings()`] instead.
|
||||
pub fn read_postings_from_terminfo(
|
||||
&self,
|
||||
term_info: &TermInfo,
|
||||
option: IndexRecordOption,
|
||||
) -> io::Result<SegmentPostings> {
|
||||
let option = option.downgrade(self.record_option);
|
||||
|
||||
let block_postings = self.read_block_postings_from_terminfo(term_info, option)?;
|
||||
let position_reader = {
|
||||
if option.has_positions() {
|
||||
let positions_data = self.positions_reader.get_bytes(
|
||||
term_info.positions_range.start as u64..term_info.positions_range.end as u64,
|
||||
)?;
|
||||
let position_reader = PositionReader::open(positions_data)?;
|
||||
Some(position_reader)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
Ok(SegmentPostings::from_block_postings(
|
||||
block_postings,
|
||||
position_reader,
|
||||
))
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user