mirror of
https://github.com/quickwit-oss/tantivy.git
synced 2026-01-02 23:32:54 +00:00
Compare commits
23 Commits
bucket_id_
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
77505c3d03 | ||
|
|
735c588f4f | ||
|
|
242a1531bf | ||
|
|
6443b63177 | ||
|
|
4987495ee4 | ||
|
|
b11605f045 | ||
|
|
75d7989cc6 | ||
|
|
923f0508f2 | ||
|
|
e0b62e00ac | ||
|
|
ce97beb86f | ||
|
|
c0f21a45ae | ||
|
|
73657dff77 | ||
|
|
e3c9be1f92 | ||
|
|
ba61ed6ef3 | ||
|
|
d0e1600135 | ||
|
|
e9020d17d4 | ||
|
|
5ba0031f7d | ||
|
|
22dde8f9ae | ||
|
|
14cc24614e | ||
|
|
8a1079b2dc | ||
|
|
794ff1ffc9 | ||
|
|
c6912ce89a | ||
|
|
618e3bd11b |
4
.github/workflows/coverage.yml
vendored
4
.github/workflows/coverage.yml
vendored
@@ -15,11 +15,11 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Install Rust
|
||||
run: rustup toolchain install nightly-2024-07-01 --profile minimal --component llvm-tools-preview
|
||||
run: rustup toolchain install nightly-2025-12-01 --profile minimal --component llvm-tools-preview
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- uses: taiki-e/install-action@cargo-llvm-cov
|
||||
- name: Generate code coverage
|
||||
run: cargo +nightly-2024-07-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
|
||||
run: cargo +nightly-2025-12-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v3
|
||||
continue-on-error: true
|
||||
|
||||
30
.github/workflows/test.yml
vendored
30
.github/workflows/test.yml
vendored
@@ -39,11 +39,11 @@ jobs:
|
||||
|
||||
- name: Check Formatting
|
||||
run: cargo +nightly fmt --all -- --check
|
||||
|
||||
|
||||
- name: Check Stable Compilation
|
||||
run: cargo build --all-features
|
||||
|
||||
|
||||
|
||||
- name: Check Bench Compilation
|
||||
run: cargo +nightly bench --no-run --profile=dev --all-features
|
||||
|
||||
@@ -59,10 +59,10 @@ jobs:
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
features: [
|
||||
{ label: "all", flags: "mmap,stopwords,lz4-compression,zstd-compression,failpoints" },
|
||||
{ label: "quickwit", flags: "mmap,quickwit,failpoints" }
|
||||
]
|
||||
features:
|
||||
- { label: "all", flags: "mmap,stopwords,lz4-compression,zstd-compression,failpoints,stemmer" }
|
||||
- { label: "quickwit", flags: "mmap,quickwit,failpoints" }
|
||||
- { label: "none", flags: "" }
|
||||
|
||||
name: test-${{ matrix.features.label}}
|
||||
|
||||
@@ -80,7 +80,21 @@ jobs:
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
|
||||
- name: Run tests
|
||||
run: cargo +stable nextest run --features ${{ matrix.features.flags }} --verbose --workspace
|
||||
run: |
|
||||
# if matrix.feature.flags is empty then run on --lib to avoid compiling examples
|
||||
# (as most of them rely on mmap) otherwise run all
|
||||
if [ -z "${{ matrix.features.flags }}" ]; then
|
||||
cargo +stable nextest run --lib --no-default-features --verbose --workspace
|
||||
else
|
||||
cargo +stable nextest run --features ${{ matrix.features.flags }} --no-default-features --verbose --workspace
|
||||
fi
|
||||
|
||||
- name: Run doctests
|
||||
run: cargo +stable test --doc --features ${{ matrix.features.flags }} --verbose --workspace
|
||||
run: |
|
||||
# if matrix.feature.flags is empty then run on --lib to avoid compiling examples
|
||||
# (as most of them rely on mmap) otherwise run all
|
||||
if [ -z "${{ matrix.features.flags }}" ]; then
|
||||
echo "no doctest for no feature flag"
|
||||
else
|
||||
cargo +stable test --doc --features ${{ matrix.features.flags }} --verbose --workspace
|
||||
fi
|
||||
|
||||
21
Cargo.toml
21
Cargo.toml
@@ -37,7 +37,7 @@ fs4 = { version = "0.13.1", optional = true }
|
||||
levenshtein_automata = "0.2.1"
|
||||
uuid = { version = "1.0.0", features = ["v4", "serde"] }
|
||||
crossbeam-channel = "0.5.4"
|
||||
rust-stemmers = "1.2.0"
|
||||
rust-stemmers = { version = "1.2.0", optional = true }
|
||||
downcast-rs = "2.0.1"
|
||||
bitpacking = { version = "0.9.2", default-features = false, features = [
|
||||
"bitpacker4x",
|
||||
@@ -75,12 +75,12 @@ typetag = "0.2.21"
|
||||
winapi = "0.3.9"
|
||||
|
||||
[dev-dependencies]
|
||||
binggan = "0.14.0"
|
||||
binggan = "0.14.2"
|
||||
rand = "0.8.5"
|
||||
maplit = "1.0.2"
|
||||
matches = "0.1.9"
|
||||
pretty_assertions = "1.2.1"
|
||||
proptest = "1.0.0"
|
||||
proptest = "1.7.0"
|
||||
test-log = "0.2.10"
|
||||
futures = "0.3.21"
|
||||
paste = "1.0.11"
|
||||
@@ -113,7 +113,8 @@ debug-assertions = true
|
||||
overflow-checks = true
|
||||
|
||||
[features]
|
||||
default = ["mmap", "stopwords", "lz4-compression", "columnar-zstd-compression"]
|
||||
default = ["mmap", "stopwords", "lz4-compression", "columnar-zstd-compression", "stemmer"]
|
||||
stemmer = ["rust-stemmers"]
|
||||
mmap = ["fs4", "tempfile", "memmap2"]
|
||||
stopwords = []
|
||||
|
||||
@@ -173,6 +174,18 @@ harness = false
|
||||
name = "exists_json"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "range_query"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "and_or_queries"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "range_queries"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "bool_queries_with_range"
|
||||
harness = false
|
||||
|
||||
@@ -54,33 +54,33 @@ fn bench_agg(mut group: InputGroup<Index>) {
|
||||
register!(group, stats_f64);
|
||||
register!(group, extendedstats_f64);
|
||||
register!(group, percentiles_f64);
|
||||
register!(group, terms_7);
|
||||
register!(group, terms_few);
|
||||
register!(group, terms_all_unique);
|
||||
register!(group, terms_150_000);
|
||||
register!(group, terms_many);
|
||||
register!(group, terms_many_top_1000);
|
||||
register!(group, terms_many_order_by_term);
|
||||
register!(group, terms_many_with_top_hits);
|
||||
register!(group, terms_all_unique_with_avg_sub_agg);
|
||||
register!(group, terms_many_with_avg_sub_agg);
|
||||
register!(group, terms_few_with_avg_sub_agg);
|
||||
register!(group, terms_status_with_avg_sub_agg);
|
||||
register!(group, terms_status);
|
||||
register!(group, terms_few_with_histogram);
|
||||
register!(group, terms_status_with_histogram);
|
||||
register!(group, terms_zipf_1000);
|
||||
register!(group, terms_zipf_1000_with_histogram);
|
||||
register!(group, terms_zipf_1000_with_avg_sub_agg);
|
||||
|
||||
register!(group, terms_many_json_mixed_type_with_avg_sub_agg);
|
||||
|
||||
register!(group, cardinality_agg);
|
||||
register!(group, terms_status_with_cardinality_agg);
|
||||
register!(group, terms_few_with_cardinality_agg);
|
||||
|
||||
register!(group, range_agg);
|
||||
register!(group, range_agg_with_avg_sub_agg);
|
||||
register!(group, range_agg_with_term_agg_status);
|
||||
register!(group, range_agg_with_term_agg_few);
|
||||
register!(group, range_agg_with_term_agg_many);
|
||||
register!(group, histogram);
|
||||
register!(group, histogram_hard_bounds);
|
||||
register!(group, histogram_with_avg_sub_agg);
|
||||
register!(group, histogram_with_term_agg_status);
|
||||
register!(group, histogram_with_term_agg_few);
|
||||
register!(group, avg_and_range_with_avg_sub_agg);
|
||||
|
||||
// Filter aggregation benchmarks
|
||||
@@ -159,10 +159,10 @@ fn cardinality_agg(index: &Index) {
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn terms_status_with_cardinality_agg(index: &Index) {
|
||||
fn terms_few_with_cardinality_agg(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
"terms": { "field": "text_few_terms_status" },
|
||||
"terms": { "field": "text_few_terms" },
|
||||
"aggs": {
|
||||
"cardinality": {
|
||||
"cardinality": {
|
||||
@@ -175,7 +175,13 @@ fn terms_status_with_cardinality_agg(index: &Index) {
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_7(index: &Index) {
|
||||
fn terms_few(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": { "terms": { "field": "text_few_terms" } },
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn terms_status(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": { "terms": { "field": "text_few_terms_status" } },
|
||||
});
|
||||
@@ -188,7 +194,7 @@ fn terms_all_unique(index: &Index) {
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_150_000(index: &Index) {
|
||||
fn terms_many(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": { "terms": { "field": "text_many_terms" } },
|
||||
});
|
||||
@@ -247,6 +253,17 @@ fn terms_all_unique_with_avg_sub_agg(index: &Index) {
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn terms_few_with_histogram(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
"terms": { "field": "text_few_terms" },
|
||||
"aggs": {
|
||||
"histo": {"histogram": { "field": "score_f64", "interval": 10 }}
|
||||
}
|
||||
}
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn terms_status_with_histogram(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
@@ -259,18 +276,17 @@ fn terms_status_with_histogram(index: &Index) {
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_zipf_1000_with_histogram(index: &Index) {
|
||||
fn terms_few_with_avg_sub_agg(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
"terms": { "field": "text_1000_terms_zipf" },
|
||||
"terms": { "field": "text_few_terms" },
|
||||
"aggs": {
|
||||
"histo": {"histogram": { "field": "score_f64", "interval": 10 }}
|
||||
"average_f64": { "avg": { "field": "score_f64" } }
|
||||
}
|
||||
}
|
||||
},
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_status_with_avg_sub_agg(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
@@ -283,25 +299,6 @@ fn terms_status_with_avg_sub_agg(index: &Index) {
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_zipf_1000_with_avg_sub_agg(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
"terms": { "field": "text_1000_terms_zipf" },
|
||||
"aggs": {
|
||||
"average_f64": { "avg": { "field": "score_f64" } }
|
||||
}
|
||||
},
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_zipf_1000(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": { "terms": { "field": "text_1000_terms_zipf" } },
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
@@ -357,7 +354,7 @@ fn range_agg_with_avg_sub_agg(index: &Index) {
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn range_agg_with_term_agg_status(index: &Index) {
|
||||
fn range_agg_with_term_agg_few(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"rangef64": {
|
||||
"range": {
|
||||
@@ -372,7 +369,7 @@ fn range_agg_with_term_agg_status(index: &Index) {
|
||||
]
|
||||
},
|
||||
"aggs": {
|
||||
"my_texts": { "terms": { "field": "text_few_terms_status" } },
|
||||
"my_texts": { "terms": { "field": "text_few_terms" } },
|
||||
}
|
||||
},
|
||||
});
|
||||
@@ -428,12 +425,12 @@ fn histogram_with_avg_sub_agg(index: &Index) {
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn histogram_with_term_agg_status(index: &Index) {
|
||||
fn histogram_with_term_agg_few(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"rangef64": {
|
||||
"histogram": { "field": "score_f64", "interval": 10 },
|
||||
"aggs": {
|
||||
"my_texts": { "terms": { "field": "text_few_terms_status" } }
|
||||
"my_texts": { "terms": { "field": "text_few_terms" } }
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -478,13 +475,6 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector {
|
||||
}
|
||||
|
||||
fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
|
||||
// Flag to use existing index
|
||||
let reuse_index = std::env::var("REUSE_AGG_BENCH_INDEX").is_ok();
|
||||
if reuse_index && std::path::Path::new("agg_bench").exists() {
|
||||
return Index::open_in_dir("agg_bench");
|
||||
}
|
||||
// crreate dir
|
||||
std::fs::create_dir_all("agg_bench")?;
|
||||
let mut schema_builder = Schema::builder();
|
||||
let text_fieldtype = tantivy::schema::TextOptions::default()
|
||||
.set_indexing_options(
|
||||
@@ -496,44 +486,24 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
|
||||
let text_field_all_unique_terms =
|
||||
schema_builder.add_text_field("text_all_unique_terms", STRING | FAST);
|
||||
let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST);
|
||||
let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST);
|
||||
let text_field_few_terms = schema_builder.add_text_field("text_few_terms", STRING | FAST);
|
||||
let text_field_few_terms_status =
|
||||
schema_builder.add_text_field("text_few_terms_status", STRING | FAST);
|
||||
let text_field_1000_terms_zipf =
|
||||
schema_builder.add_text_field("text_1000_terms_zipf", STRING | FAST);
|
||||
let score_fieldtype = tantivy::schema::NumericOptions::default().set_fast();
|
||||
let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone());
|
||||
let score_field_f64 = schema_builder.add_f64_field("score_f64", score_fieldtype.clone());
|
||||
let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype);
|
||||
// use tmp dir
|
||||
let index = if reuse_index {
|
||||
Index::create_in_dir("agg_bench", schema_builder.build())?
|
||||
} else {
|
||||
Index::create_from_tempdir(schema_builder.build())?
|
||||
};
|
||||
// Approximate log proportions
|
||||
let status_field_data = [
|
||||
("INFO", 8000),
|
||||
("ERROR", 300),
|
||||
("WARN", 1200),
|
||||
("DEBUG", 500),
|
||||
("OK", 500),
|
||||
("CRITICAL", 20),
|
||||
("EMERGENCY", 1),
|
||||
];
|
||||
let log_level_distribution =
|
||||
WeightedIndex::new(status_field_data.iter().map(|item| item.1)).unwrap();
|
||||
let index = Index::create_from_tempdir(schema_builder.build())?;
|
||||
let few_terms_data = ["INFO", "ERROR", "WARN", "DEBUG"];
|
||||
// Approximate production log proportions: INFO dominant, WARN and DEBUG occasional, ERROR rare.
|
||||
let log_level_distribution = WeightedIndex::new([80u32, 3, 12, 5]).unwrap();
|
||||
|
||||
let lg_norm = rand_distr::LogNormal::new(2.996f64, 0.979f64).unwrap();
|
||||
|
||||
let many_terms_data = (0..150_000)
|
||||
.map(|num| format!("author{num}"))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Prepare 1000 unique terms sampled using a Zipf distribution.
|
||||
// Exponent ~1.1 approximates top-20 terms covering around ~20%.
|
||||
let terms_1000: Vec<String> = (1..=1000).map(|i| format!("term_{i}")).collect();
|
||||
let zipf_1000 = rand_distr::Zipf::new(1000, 1.1f64).unwrap();
|
||||
|
||||
{
|
||||
let mut rng = StdRng::from_seed([1u8; 32]);
|
||||
let mut index_writer = index.writer_with_num_threads(1, 200_000_000)?;
|
||||
@@ -543,12 +513,8 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
|
||||
index_writer.add_document(doc!())?;
|
||||
}
|
||||
if cardinality == Cardinality::Multivalued {
|
||||
let log_level_sample_a = status_field_data[log_level_distribution.sample(&mut rng)].0;
|
||||
let log_level_sample_b = status_field_data[log_level_distribution.sample(&mut rng)].0;
|
||||
let idx_a = zipf_1000.sample(&mut rng) as usize - 1;
|
||||
let idx_b = zipf_1000.sample(&mut rng) as usize - 1;
|
||||
let term_1000_a = &terms_1000[idx_a];
|
||||
let term_1000_b = &terms_1000[idx_b];
|
||||
let log_level_sample_a = few_terms_data[log_level_distribution.sample(&mut rng)];
|
||||
let log_level_sample_b = few_terms_data[log_level_distribution.sample(&mut rng)];
|
||||
index_writer.add_document(doc!(
|
||||
json_field => json!({"mixed_type": 10.0}),
|
||||
json_field => json!({"mixed_type": 10.0}),
|
||||
@@ -558,10 +524,10 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
|
||||
text_field_all_unique_terms => "coolo",
|
||||
text_field_many_terms => "cool",
|
||||
text_field_many_terms => "cool",
|
||||
text_field_few_terms => "cool",
|
||||
text_field_few_terms => "cool",
|
||||
text_field_few_terms_status => log_level_sample_a,
|
||||
text_field_few_terms_status => log_level_sample_b,
|
||||
text_field_1000_terms_zipf => term_1000_a.as_str(),
|
||||
text_field_1000_terms_zipf => term_1000_b.as_str(),
|
||||
score_field => 1u64,
|
||||
score_field => 1u64,
|
||||
score_field_f64 => lg_norm.sample(&mut rng),
|
||||
@@ -588,8 +554,8 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
|
||||
json_field => json,
|
||||
text_field_all_unique_terms => format!("unique_term_{}", rng.gen::<u64>()),
|
||||
text_field_many_terms => many_terms_data.choose(&mut rng).unwrap().to_string(),
|
||||
text_field_few_terms_status => status_field_data[log_level_distribution.sample(&mut rng)].0,
|
||||
text_field_1000_terms_zipf => terms_1000[zipf_1000.sample(&mut rng) as usize - 1].as_str(),
|
||||
text_field_few_terms => few_terms_data.choose(&mut rng).unwrap().to_string(),
|
||||
text_field_few_terms_status => few_terms_data[log_level_distribution.sample(&mut rng)],
|
||||
score_field => val as u64,
|
||||
score_field_f64 => lg_norm.sample(&mut rng),
|
||||
score_field_i64 => val as i64,
|
||||
@@ -641,7 +607,7 @@ fn filter_agg_all_query_with_sub_aggs(index: &Index) {
|
||||
"avg_score": { "avg": { "field": "score" } },
|
||||
"stats_score": { "stats": { "field": "score_f64" } },
|
||||
"terms_text": {
|
||||
"terms": { "field": "text_few_terms_status" }
|
||||
"terms": { "field": "text_few_terms" }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -657,7 +623,7 @@ fn filter_agg_term_query_with_sub_aggs(index: &Index) {
|
||||
"avg_score": { "avg": { "field": "score" } },
|
||||
"stats_score": { "stats": { "field": "score_f64" } },
|
||||
"terms_text": {
|
||||
"terms": { "field": "text_few_terms_status" }
|
||||
"terms": { "field": "text_few_terms" }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
288
benches/bool_queries_with_range.rs
Normal file
288
benches/bool_queries_with_range.rs
Normal file
@@ -0,0 +1,288 @@
|
||||
use binggan::{black_box, BenchGroup, BenchRunner};
|
||||
use rand::prelude::*;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::SeedableRng;
|
||||
use tantivy::collector::{Collector, Count, DocSetCollector, TopDocs};
|
||||
use tantivy::query::{Query, QueryParser};
|
||||
use tantivy::schema::{Schema, FAST, INDEXED, TEXT};
|
||||
use tantivy::{doc, Index, Order, ReloadPolicy, Searcher};
|
||||
|
||||
#[derive(Clone)]
|
||||
struct BenchIndex {
|
||||
#[allow(dead_code)]
|
||||
index: Index,
|
||||
searcher: Searcher,
|
||||
query_parser: QueryParser,
|
||||
}
|
||||
|
||||
fn build_shared_indices(num_docs: usize, p_title_a: f32, distribution: &str) -> BenchIndex {
|
||||
// Unified schema
|
||||
let mut schema_builder = Schema::builder();
|
||||
let f_title = schema_builder.add_text_field("title", TEXT);
|
||||
let f_num_rand = schema_builder.add_u64_field("num_rand", INDEXED);
|
||||
let f_num_asc = schema_builder.add_u64_field("num_asc", INDEXED);
|
||||
let f_num_rand_fast = schema_builder.add_u64_field("num_rand_fast", INDEXED | FAST);
|
||||
let f_num_asc_fast = schema_builder.add_u64_field("num_asc_fast", INDEXED | FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema.clone());
|
||||
|
||||
// Populate index with stable RNG for reproducibility.
|
||||
let mut rng = StdRng::from_seed([7u8; 32]);
|
||||
|
||||
{
|
||||
let mut writer = index.writer_with_num_threads(1, 4_000_000_000).unwrap();
|
||||
|
||||
match distribution {
|
||||
"dense" => {
|
||||
for doc_id in 0..num_docs {
|
||||
// Always add title to avoid empty documents
|
||||
let title_token = if rng.gen_bool(p_title_a as f64) {
|
||||
"a"
|
||||
} else {
|
||||
"b"
|
||||
};
|
||||
|
||||
let num_rand = rng.gen_range(0u64..1000u64);
|
||||
|
||||
let num_asc = (doc_id / 10000) as u64;
|
||||
|
||||
writer
|
||||
.add_document(doc!(
|
||||
f_title=>title_token,
|
||||
f_num_rand=>num_rand,
|
||||
f_num_asc=>num_asc,
|
||||
f_num_rand_fast=>num_rand,
|
||||
f_num_asc_fast=>num_asc,
|
||||
))
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
"sparse" => {
|
||||
for doc_id in 0..num_docs {
|
||||
// Always add title to avoid empty documents
|
||||
let title_token = if rng.gen_bool(p_title_a as f64) {
|
||||
"a"
|
||||
} else {
|
||||
"b"
|
||||
};
|
||||
|
||||
let num_rand = rng.gen_range(0u64..10000000u64);
|
||||
|
||||
let num_asc = doc_id as u64;
|
||||
|
||||
writer
|
||||
.add_document(doc!(
|
||||
f_title=>title_token,
|
||||
f_num_rand=>num_rand,
|
||||
f_num_asc=>num_asc,
|
||||
f_num_rand_fast=>num_rand,
|
||||
f_num_asc_fast=>num_asc,
|
||||
))
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
panic!("Unsupported distribution type");
|
||||
}
|
||||
}
|
||||
writer.commit().unwrap();
|
||||
}
|
||||
|
||||
// Prepare reader/searcher once.
|
||||
let reader = index
|
||||
.reader_builder()
|
||||
.reload_policy(ReloadPolicy::Manual)
|
||||
.try_into()
|
||||
.unwrap();
|
||||
let searcher = reader.searcher();
|
||||
|
||||
// Build query parser for title field
|
||||
let qp_title = QueryParser::for_index(&index, vec![f_title]);
|
||||
|
||||
BenchIndex {
|
||||
index,
|
||||
searcher,
|
||||
query_parser: qp_title,
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
// Prepare corpora with varying scenarios
|
||||
let scenarios = vec![
|
||||
(
|
||||
"dense and 99% a".to_string(),
|
||||
10_000_000,
|
||||
0.99,
|
||||
"dense",
|
||||
0,
|
||||
9,
|
||||
),
|
||||
(
|
||||
"dense and 99% a".to_string(),
|
||||
10_000_000,
|
||||
0.99,
|
||||
"dense",
|
||||
990,
|
||||
999,
|
||||
),
|
||||
(
|
||||
"sparse and 99% a".to_string(),
|
||||
10_000_000,
|
||||
0.99,
|
||||
"sparse",
|
||||
0,
|
||||
9,
|
||||
),
|
||||
(
|
||||
"sparse and 99% a".to_string(),
|
||||
10_000_000,
|
||||
0.99,
|
||||
"sparse",
|
||||
9_999_990,
|
||||
9_999_999,
|
||||
),
|
||||
];
|
||||
|
||||
let mut runner = BenchRunner::new();
|
||||
for (scenario_id, n, p_title_a, num_rand_distribution, range_low, range_high) in scenarios {
|
||||
// Build index for this scenario
|
||||
let bench_index = build_shared_indices(n, p_title_a, num_rand_distribution);
|
||||
|
||||
// Create benchmark group
|
||||
let mut group = runner.new_group();
|
||||
|
||||
// Now set the name (this moves scenario_id)
|
||||
group.set_name(scenario_id);
|
||||
|
||||
// Define all four field types
|
||||
let field_names = ["num_rand", "num_asc", "num_rand_fast", "num_asc_fast"];
|
||||
|
||||
// Define the three terms we want to test with
|
||||
let terms = ["a", "b", "z"];
|
||||
|
||||
// Generate all combinations of terms and field names
|
||||
let mut queries = Vec::new();
|
||||
for &term in &terms {
|
||||
for &field_name in &field_names {
|
||||
let query_str = format!(
|
||||
"{} AND {}:[{} TO {}]",
|
||||
term, field_name, range_low, range_high
|
||||
);
|
||||
queries.push((query_str, field_name.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
let query_str = format!(
|
||||
"{}:[{} TO {}] AND {}:[{} TO {}]",
|
||||
"num_rand_fast", range_low, range_high, "num_asc_fast", range_low, range_high
|
||||
);
|
||||
queries.push((query_str, "num_asc_fast".to_string()));
|
||||
|
||||
// Run all benchmark tasks for each query and its corresponding field name
|
||||
for (query_str, field_name) in queries {
|
||||
run_benchmark_tasks(&mut group, &bench_index, &query_str, &field_name);
|
||||
}
|
||||
|
||||
group.run();
|
||||
}
|
||||
}
|
||||
|
||||
/// Run all benchmark tasks for a given query string and field name
|
||||
fn run_benchmark_tasks(
|
||||
bench_group: &mut BenchGroup,
|
||||
bench_index: &BenchIndex,
|
||||
query_str: &str,
|
||||
field_name: &str,
|
||||
) {
|
||||
// Test count
|
||||
add_bench_task(bench_group, bench_index, query_str, Count, "count");
|
||||
|
||||
// Test all results
|
||||
add_bench_task(
|
||||
bench_group,
|
||||
bench_index,
|
||||
query_str,
|
||||
DocSetCollector,
|
||||
"all results",
|
||||
);
|
||||
|
||||
// Test top 100 by the field (if it's a FAST field)
|
||||
if field_name.ends_with("_fast") {
|
||||
// Ascending order
|
||||
{
|
||||
let collector_name = format!("top100_by_{}_asc", field_name);
|
||||
let field_name_owned = field_name.to_string();
|
||||
add_bench_task(
|
||||
bench_group,
|
||||
bench_index,
|
||||
query_str,
|
||||
TopDocs::with_limit(100).order_by_fast_field::<u64>(field_name_owned, Order::Asc),
|
||||
&collector_name,
|
||||
);
|
||||
}
|
||||
|
||||
// Descending order
|
||||
{
|
||||
let collector_name = format!("top100_by_{}_desc", field_name);
|
||||
let field_name_owned = field_name.to_string();
|
||||
add_bench_task(
|
||||
bench_group,
|
||||
bench_index,
|
||||
query_str,
|
||||
TopDocs::with_limit(100).order_by_fast_field::<u64>(field_name_owned, Order::Desc),
|
||||
&collector_name,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn add_bench_task<C: Collector + 'static>(
|
||||
bench_group: &mut BenchGroup,
|
||||
bench_index: &BenchIndex,
|
||||
query_str: &str,
|
||||
collector: C,
|
||||
collector_name: &str,
|
||||
) {
|
||||
let task_name = format!("{}_{}", query_str.replace(" ", "_"), collector_name);
|
||||
let query = bench_index.query_parser.parse_query(query_str).unwrap();
|
||||
let search_task = SearchTask {
|
||||
searcher: bench_index.searcher.clone(),
|
||||
collector,
|
||||
query,
|
||||
};
|
||||
bench_group.register(task_name, move |_| black_box(search_task.run()));
|
||||
}
|
||||
|
||||
struct SearchTask<C: Collector> {
|
||||
searcher: Searcher,
|
||||
collector: C,
|
||||
query: Box<dyn Query>,
|
||||
}
|
||||
|
||||
impl<C: Collector> SearchTask<C> {
|
||||
#[inline(never)]
|
||||
pub fn run(&self) -> usize {
|
||||
let result = self.searcher.search(&self.query, &self.collector).unwrap();
|
||||
if let Some(count) = (&result as &dyn std::any::Any).downcast_ref::<usize>() {
|
||||
*count
|
||||
} else if let Some(top_docs) = (&result as &dyn std::any::Any)
|
||||
.downcast_ref::<Vec<(Option<u64>, tantivy::DocAddress)>>()
|
||||
{
|
||||
top_docs.len()
|
||||
} else if let Some(top_docs) =
|
||||
(&result as &dyn std::any::Any).downcast_ref::<Vec<(u64, tantivy::DocAddress)>>()
|
||||
{
|
||||
top_docs.len()
|
||||
} else if let Some(doc_set) = (&result as &dyn std::any::Any)
|
||||
.downcast_ref::<std::collections::HashSet<tantivy::DocAddress>>()
|
||||
{
|
||||
doc_set.len()
|
||||
} else {
|
||||
eprintln!(
|
||||
"Unknown collector result type: {:?}",
|
||||
std::any::type_name::<C::Fruit>()
|
||||
);
|
||||
0
|
||||
}
|
||||
}
|
||||
}
|
||||
365
benches/range_queries.rs
Normal file
365
benches/range_queries.rs
Normal file
@@ -0,0 +1,365 @@
|
||||
use std::ops::Bound;
|
||||
|
||||
use binggan::{black_box, BenchGroup, BenchRunner};
|
||||
use rand::prelude::*;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::SeedableRng;
|
||||
use tantivy::collector::{Count, DocSetCollector, TopDocs};
|
||||
use tantivy::query::RangeQuery;
|
||||
use tantivy::schema::{Schema, FAST, INDEXED};
|
||||
use tantivy::{doc, Index, Order, ReloadPolicy, Searcher, Term};
|
||||
|
||||
#[derive(Clone)]
|
||||
struct BenchIndex {
|
||||
#[allow(dead_code)]
|
||||
index: Index,
|
||||
searcher: Searcher,
|
||||
}
|
||||
|
||||
fn build_shared_indices(num_docs: usize, distribution: &str) -> BenchIndex {
|
||||
// Schema with fast fields only
|
||||
let mut schema_builder = Schema::builder();
|
||||
let f_num_rand_fast = schema_builder.add_u64_field("num_rand_fast", INDEXED | FAST);
|
||||
let f_num_asc_fast = schema_builder.add_u64_field("num_asc_fast", INDEXED | FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema.clone());
|
||||
|
||||
// Populate index with stable RNG for reproducibility.
|
||||
let mut rng = StdRng::from_seed([7u8; 32]);
|
||||
|
||||
{
|
||||
let mut writer = index.writer_with_num_threads(1, 4_000_000_000).unwrap();
|
||||
|
||||
match distribution {
|
||||
"dense" => {
|
||||
for doc_id in 0..num_docs {
|
||||
let num_rand = rng.gen_range(0u64..1000u64);
|
||||
let num_asc = (doc_id / 10000) as u64;
|
||||
|
||||
writer
|
||||
.add_document(doc!(
|
||||
f_num_rand_fast=>num_rand,
|
||||
f_num_asc_fast=>num_asc,
|
||||
))
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
"sparse" => {
|
||||
for doc_id in 0..num_docs {
|
||||
let num_rand = rng.gen_range(0u64..10000000u64);
|
||||
let num_asc = doc_id as u64;
|
||||
|
||||
writer
|
||||
.add_document(doc!(
|
||||
f_num_rand_fast=>num_rand,
|
||||
f_num_asc_fast=>num_asc,
|
||||
))
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
panic!("Unsupported distribution type");
|
||||
}
|
||||
}
|
||||
writer.commit().unwrap();
|
||||
}
|
||||
|
||||
// Prepare reader/searcher once.
|
||||
let reader = index
|
||||
.reader_builder()
|
||||
.reload_policy(ReloadPolicy::Manual)
|
||||
.try_into()
|
||||
.unwrap();
|
||||
let searcher = reader.searcher();
|
||||
|
||||
BenchIndex { index, searcher }
|
||||
}
|
||||
|
||||
fn main() {
|
||||
// Prepare corpora with varying scenarios
|
||||
let scenarios = vec![
|
||||
// Dense distribution - random values in small range (0-999)
|
||||
(
|
||||
"dense_values_search_low_value_range".to_string(),
|
||||
10_000_000,
|
||||
"dense",
|
||||
0,
|
||||
9,
|
||||
),
|
||||
(
|
||||
"dense_values_search_high_value_range".to_string(),
|
||||
10_000_000,
|
||||
"dense",
|
||||
990,
|
||||
999,
|
||||
),
|
||||
(
|
||||
"dense_values_search_out_of_range".to_string(),
|
||||
10_000_000,
|
||||
"dense",
|
||||
1000,
|
||||
1002,
|
||||
),
|
||||
(
|
||||
"sparse_values_search_low_value_range".to_string(),
|
||||
10_000_000,
|
||||
"sparse",
|
||||
0,
|
||||
9,
|
||||
),
|
||||
(
|
||||
"sparse_values_search_high_value_range".to_string(),
|
||||
10_000_000,
|
||||
"sparse",
|
||||
9_999_990,
|
||||
9_999_999,
|
||||
),
|
||||
(
|
||||
"sparse_values_search_out_of_range".to_string(),
|
||||
10_000_000,
|
||||
"sparse",
|
||||
10_000_000,
|
||||
10_000_002,
|
||||
),
|
||||
];
|
||||
|
||||
let mut runner = BenchRunner::new();
|
||||
for (scenario_id, n, num_rand_distribution, range_low, range_high) in scenarios {
|
||||
// Build index for this scenario
|
||||
let bench_index = build_shared_indices(n, num_rand_distribution);
|
||||
|
||||
// Create benchmark group
|
||||
let mut group = runner.new_group();
|
||||
|
||||
// Now set the name (this moves scenario_id)
|
||||
group.set_name(scenario_id);
|
||||
|
||||
// Define fast field types
|
||||
let field_names = ["num_rand_fast", "num_asc_fast"];
|
||||
|
||||
// Generate range queries for fast fields
|
||||
for &field_name in &field_names {
|
||||
// Create the range query
|
||||
let field = bench_index.searcher.schema().get_field(field_name).unwrap();
|
||||
let lower_term = Term::from_field_u64(field, range_low);
|
||||
let upper_term = Term::from_field_u64(field, range_high);
|
||||
|
||||
let query = RangeQuery::new(Bound::Included(lower_term), Bound::Included(upper_term));
|
||||
|
||||
run_benchmark_tasks(
|
||||
&mut group,
|
||||
&bench_index,
|
||||
query,
|
||||
field_name,
|
||||
range_low,
|
||||
range_high,
|
||||
);
|
||||
}
|
||||
|
||||
group.run();
|
||||
}
|
||||
}
|
||||
|
||||
/// Run all benchmark tasks for a given range query and field name
|
||||
fn run_benchmark_tasks(
|
||||
bench_group: &mut BenchGroup,
|
||||
bench_index: &BenchIndex,
|
||||
query: RangeQuery,
|
||||
field_name: &str,
|
||||
range_low: u64,
|
||||
range_high: u64,
|
||||
) {
|
||||
// Test count
|
||||
add_bench_task_count(
|
||||
bench_group,
|
||||
bench_index,
|
||||
query.clone(),
|
||||
"count",
|
||||
field_name,
|
||||
range_low,
|
||||
range_high,
|
||||
);
|
||||
|
||||
// Test top 100 by the field (ascending order)
|
||||
{
|
||||
let collector_name = format!("top100_by_{}_asc", field_name);
|
||||
let field_name_owned = field_name.to_string();
|
||||
add_bench_task_top100_asc(
|
||||
bench_group,
|
||||
bench_index,
|
||||
query.clone(),
|
||||
&collector_name,
|
||||
field_name,
|
||||
range_low,
|
||||
range_high,
|
||||
field_name_owned,
|
||||
);
|
||||
}
|
||||
|
||||
// Test top 100 by the field (descending order)
|
||||
{
|
||||
let collector_name = format!("top100_by_{}_desc", field_name);
|
||||
let field_name_owned = field_name.to_string();
|
||||
add_bench_task_top100_desc(
|
||||
bench_group,
|
||||
bench_index,
|
||||
query,
|
||||
&collector_name,
|
||||
field_name,
|
||||
range_low,
|
||||
range_high,
|
||||
field_name_owned,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn add_bench_task_count(
|
||||
bench_group: &mut BenchGroup,
|
||||
bench_index: &BenchIndex,
|
||||
query: RangeQuery,
|
||||
collector_name: &str,
|
||||
field_name: &str,
|
||||
range_low: u64,
|
||||
range_high: u64,
|
||||
) {
|
||||
let task_name = format!(
|
||||
"range_{}_[{} TO {}]_{}",
|
||||
field_name, range_low, range_high, collector_name
|
||||
);
|
||||
|
||||
let search_task = CountSearchTask {
|
||||
searcher: bench_index.searcher.clone(),
|
||||
query,
|
||||
};
|
||||
bench_group.register(task_name, move |_| black_box(search_task.run()));
|
||||
}
|
||||
|
||||
fn add_bench_task_docset(
|
||||
bench_group: &mut BenchGroup,
|
||||
bench_index: &BenchIndex,
|
||||
query: RangeQuery,
|
||||
collector_name: &str,
|
||||
field_name: &str,
|
||||
range_low: u64,
|
||||
range_high: u64,
|
||||
) {
|
||||
let task_name = format!(
|
||||
"range_{}_[{} TO {}]_{}",
|
||||
field_name, range_low, range_high, collector_name
|
||||
);
|
||||
|
||||
let search_task = DocSetSearchTask {
|
||||
searcher: bench_index.searcher.clone(),
|
||||
query,
|
||||
};
|
||||
bench_group.register(task_name, move |_| black_box(search_task.run()));
|
||||
}
|
||||
|
||||
fn add_bench_task_top100_asc(
|
||||
bench_group: &mut BenchGroup,
|
||||
bench_index: &BenchIndex,
|
||||
query: RangeQuery,
|
||||
collector_name: &str,
|
||||
field_name: &str,
|
||||
range_low: u64,
|
||||
range_high: u64,
|
||||
field_name_owned: String,
|
||||
) {
|
||||
let task_name = format!(
|
||||
"range_{}_[{} TO {}]_{}",
|
||||
field_name, range_low, range_high, collector_name
|
||||
);
|
||||
|
||||
let search_task = Top100AscSearchTask {
|
||||
searcher: bench_index.searcher.clone(),
|
||||
query,
|
||||
field_name: field_name_owned,
|
||||
};
|
||||
bench_group.register(task_name, move |_| black_box(search_task.run()));
|
||||
}
|
||||
|
||||
fn add_bench_task_top100_desc(
|
||||
bench_group: &mut BenchGroup,
|
||||
bench_index: &BenchIndex,
|
||||
query: RangeQuery,
|
||||
collector_name: &str,
|
||||
field_name: &str,
|
||||
range_low: u64,
|
||||
range_high: u64,
|
||||
field_name_owned: String,
|
||||
) {
|
||||
let task_name = format!(
|
||||
"range_{}_[{} TO {}]_{}",
|
||||
field_name, range_low, range_high, collector_name
|
||||
);
|
||||
|
||||
let search_task = Top100DescSearchTask {
|
||||
searcher: bench_index.searcher.clone(),
|
||||
query,
|
||||
field_name: field_name_owned,
|
||||
};
|
||||
bench_group.register(task_name, move |_| black_box(search_task.run()));
|
||||
}
|
||||
|
||||
struct CountSearchTask {
|
||||
searcher: Searcher,
|
||||
query: RangeQuery,
|
||||
}
|
||||
|
||||
impl CountSearchTask {
|
||||
#[inline(never)]
|
||||
pub fn run(&self) -> usize {
|
||||
self.searcher.search(&self.query, &Count).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
struct DocSetSearchTask {
|
||||
searcher: Searcher,
|
||||
query: RangeQuery,
|
||||
}
|
||||
|
||||
impl DocSetSearchTask {
|
||||
#[inline(never)]
|
||||
pub fn run(&self) -> usize {
|
||||
let result = self.searcher.search(&self.query, &DocSetCollector).unwrap();
|
||||
result.len()
|
||||
}
|
||||
}
|
||||
|
||||
struct Top100AscSearchTask {
|
||||
searcher: Searcher,
|
||||
query: RangeQuery,
|
||||
field_name: String,
|
||||
}
|
||||
|
||||
impl Top100AscSearchTask {
|
||||
#[inline(never)]
|
||||
pub fn run(&self) -> usize {
|
||||
let collector =
|
||||
TopDocs::with_limit(100).order_by_fast_field::<u64>(&self.field_name, Order::Asc);
|
||||
let result = self.searcher.search(&self.query, &collector).unwrap();
|
||||
for (_score, doc_address) in &result {
|
||||
let _doc: tantivy::TantivyDocument = self.searcher.doc(*doc_address).unwrap();
|
||||
}
|
||||
result.len()
|
||||
}
|
||||
}
|
||||
|
||||
struct Top100DescSearchTask {
|
||||
searcher: Searcher,
|
||||
query: RangeQuery,
|
||||
field_name: String,
|
||||
}
|
||||
|
||||
impl Top100DescSearchTask {
|
||||
#[inline(never)]
|
||||
pub fn run(&self) -> usize {
|
||||
let collector =
|
||||
TopDocs::with_limit(100).order_by_fast_field::<u64>(&self.field_name, Order::Desc);
|
||||
let result = self.searcher.search(&self.query, &collector).unwrap();
|
||||
for (_score, doc_address) in &result {
|
||||
let _doc: tantivy::TantivyDocument = self.searcher.doc(*doc_address).unwrap();
|
||||
}
|
||||
result.len()
|
||||
}
|
||||
}
|
||||
260
benches/range_query.rs
Normal file
260
benches/range_query.rs
Normal file
@@ -0,0 +1,260 @@
|
||||
use std::fmt::Display;
|
||||
use std::net::Ipv6Addr;
|
||||
use std::ops::RangeInclusive;
|
||||
|
||||
use binggan::plugins::PeakMemAllocPlugin;
|
||||
use binggan::{black_box, BenchRunner, OutputValue, PeakMemAlloc, INSTRUMENTED_SYSTEM};
|
||||
use columnar::MonotonicallyMappableToU128;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use tantivy::collector::{Count, TopDocs};
|
||||
use tantivy::query::QueryParser;
|
||||
use tantivy::schema::*;
|
||||
use tantivy::{doc, Index};
|
||||
|
||||
#[global_allocator]
|
||||
pub static GLOBAL: &PeakMemAlloc<std::alloc::System> = &INSTRUMENTED_SYSTEM;
|
||||
|
||||
fn main() {
|
||||
bench_range_query();
|
||||
}
|
||||
|
||||
fn bench_range_query() {
|
||||
let index = get_index_0_to_100();
|
||||
let mut runner = BenchRunner::new();
|
||||
runner.add_plugin(PeakMemAllocPlugin::new(GLOBAL));
|
||||
|
||||
runner.set_name("range_query on u64");
|
||||
let field_name_and_descr: Vec<_> = vec![
|
||||
("id", "Single Valued Range Field"),
|
||||
("ids", "Multi Valued Range Field"),
|
||||
];
|
||||
let range_num_hits = vec![
|
||||
("90_percent", get_90_percent()),
|
||||
("10_percent", get_10_percent()),
|
||||
("1_percent", get_1_percent()),
|
||||
];
|
||||
|
||||
test_range(&mut runner, &index, &field_name_and_descr, range_num_hits);
|
||||
|
||||
runner.set_name("range_query on ip");
|
||||
let field_name_and_descr: Vec<_> = vec![
|
||||
("ip", "Single Valued Range Field"),
|
||||
("ips", "Multi Valued Range Field"),
|
||||
];
|
||||
let range_num_hits = vec![
|
||||
("90_percent", get_90_percent_ip()),
|
||||
("10_percent", get_10_percent_ip()),
|
||||
("1_percent", get_1_percent_ip()),
|
||||
];
|
||||
|
||||
test_range(&mut runner, &index, &field_name_and_descr, range_num_hits);
|
||||
}
|
||||
|
||||
fn test_range<T: Display>(
|
||||
runner: &mut BenchRunner,
|
||||
index: &Index,
|
||||
field_name_and_descr: &[(&str, &str)],
|
||||
range_num_hits: Vec<(&str, RangeInclusive<T>)>,
|
||||
) {
|
||||
for (field, suffix) in field_name_and_descr {
|
||||
let term_num_hits = vec![
|
||||
("", ""),
|
||||
("1_percent", "veryfew"),
|
||||
("10_percent", "few"),
|
||||
("90_percent", "most"),
|
||||
];
|
||||
let mut group = runner.new_group();
|
||||
group.set_name(suffix);
|
||||
// all intersect combinations
|
||||
for (range_name, range) in &range_num_hits {
|
||||
for (term_name, term) in &term_num_hits {
|
||||
let index = &index;
|
||||
let test_name = if term_name.is_empty() {
|
||||
format!("id_range_hit_{}", range_name)
|
||||
} else {
|
||||
format!(
|
||||
"id_range_hit_{}_intersect_with_term_{}",
|
||||
range_name, term_name
|
||||
)
|
||||
};
|
||||
group.register(test_name, move |_| {
|
||||
let query = if term_name.is_empty() {
|
||||
"".to_string()
|
||||
} else {
|
||||
format!("AND id_name:{}", term)
|
||||
};
|
||||
black_box(execute_query(field, range, &query, index));
|
||||
});
|
||||
}
|
||||
}
|
||||
group.run();
|
||||
}
|
||||
}
|
||||
|
||||
fn get_index_0_to_100() -> Index {
|
||||
let mut rng = StdRng::from_seed([1u8; 32]);
|
||||
let num_vals = 100_000;
|
||||
let docs: Vec<_> = (0..num_vals)
|
||||
.map(|_i| {
|
||||
let id_name = if rng.gen_bool(0.01) {
|
||||
"veryfew".to_string() // 1%
|
||||
} else if rng.gen_bool(0.1) {
|
||||
"few".to_string() // 9%
|
||||
} else {
|
||||
"most".to_string() // 90%
|
||||
};
|
||||
Doc {
|
||||
id_name,
|
||||
id: rng.gen_range(0..100),
|
||||
// Multiply by 1000, so that we create most buckets in the compact space
|
||||
// The benches depend on this range to select n-percent of elements with the
|
||||
// methods below.
|
||||
ip: Ipv6Addr::from_u128(rng.gen_range(0..100) * 1000),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
create_index_from_docs(&docs)
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Doc {
|
||||
pub id_name: String,
|
||||
pub id: u64,
|
||||
pub ip: Ipv6Addr,
|
||||
}
|
||||
|
||||
pub fn create_index_from_docs(docs: &[Doc]) -> Index {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let id_u64_field = schema_builder.add_u64_field("id", INDEXED | STORED | FAST);
|
||||
let ids_u64_field =
|
||||
schema_builder.add_u64_field("ids", NumericOptions::default().set_fast().set_indexed());
|
||||
|
||||
let id_f64_field = schema_builder.add_f64_field("id_f64", INDEXED | STORED | FAST);
|
||||
let ids_f64_field = schema_builder.add_f64_field(
|
||||
"ids_f64",
|
||||
NumericOptions::default().set_fast().set_indexed(),
|
||||
);
|
||||
|
||||
let id_i64_field = schema_builder.add_i64_field("id_i64", INDEXED | STORED | FAST);
|
||||
let ids_i64_field = schema_builder.add_i64_field(
|
||||
"ids_i64",
|
||||
NumericOptions::default().set_fast().set_indexed(),
|
||||
);
|
||||
|
||||
let text_field = schema_builder.add_text_field("id_name", STRING | STORED);
|
||||
let text_field2 = schema_builder.add_text_field("id_name_fast", STRING | STORED | FAST);
|
||||
|
||||
let ip_field = schema_builder.add_ip_addr_field("ip", FAST);
|
||||
let ips_field = schema_builder.add_ip_addr_field("ips", FAST);
|
||||
|
||||
let schema = schema_builder.build();
|
||||
|
||||
let index = Index::create_in_ram(schema);
|
||||
|
||||
{
|
||||
let mut index_writer = index.writer_with_num_threads(1, 50_000_000).unwrap();
|
||||
for doc in docs.iter() {
|
||||
index_writer
|
||||
.add_document(doc!(
|
||||
ids_i64_field => doc.id as i64,
|
||||
ids_i64_field => doc.id as i64,
|
||||
ids_f64_field => doc.id as f64,
|
||||
ids_f64_field => doc.id as f64,
|
||||
ids_u64_field => doc.id,
|
||||
ids_u64_field => doc.id,
|
||||
id_u64_field => doc.id,
|
||||
id_f64_field => doc.id as f64,
|
||||
id_i64_field => doc.id as i64,
|
||||
text_field => doc.id_name.to_string(),
|
||||
text_field2 => doc.id_name.to_string(),
|
||||
ips_field => doc.ip,
|
||||
ips_field => doc.ip,
|
||||
ip_field => doc.ip,
|
||||
))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
index_writer.commit().unwrap();
|
||||
}
|
||||
index
|
||||
}
|
||||
|
||||
fn get_90_percent() -> RangeInclusive<u64> {
|
||||
0..=90
|
||||
}
|
||||
|
||||
fn get_10_percent() -> RangeInclusive<u64> {
|
||||
0..=10
|
||||
}
|
||||
|
||||
fn get_1_percent() -> RangeInclusive<u64> {
|
||||
10..=10
|
||||
}
|
||||
|
||||
fn get_90_percent_ip() -> RangeInclusive<Ipv6Addr> {
|
||||
let start = Ipv6Addr::from_u128(0);
|
||||
let end = Ipv6Addr::from_u128(90 * 1000);
|
||||
start..=end
|
||||
}
|
||||
|
||||
fn get_10_percent_ip() -> RangeInclusive<Ipv6Addr> {
|
||||
let start = Ipv6Addr::from_u128(0);
|
||||
let end = Ipv6Addr::from_u128(10 * 1000);
|
||||
start..=end
|
||||
}
|
||||
|
||||
fn get_1_percent_ip() -> RangeInclusive<Ipv6Addr> {
|
||||
let start = Ipv6Addr::from_u128(10 * 1000);
|
||||
let end = Ipv6Addr::from_u128(10 * 1000);
|
||||
start..=end
|
||||
}
|
||||
|
||||
struct NumHits {
|
||||
count: usize,
|
||||
}
|
||||
impl OutputValue for NumHits {
|
||||
fn column_title() -> &'static str {
|
||||
"NumHits"
|
||||
}
|
||||
fn format(&self) -> Option<String> {
|
||||
Some(self.count.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
fn execute_query<T: Display>(
|
||||
field: &str,
|
||||
id_range: &RangeInclusive<T>,
|
||||
suffix: &str,
|
||||
index: &Index,
|
||||
) -> NumHits {
|
||||
let gen_query_inclusive = |from: &T, to: &T| {
|
||||
format!(
|
||||
"{}:[{} TO {}] {}",
|
||||
field,
|
||||
&from.to_string(),
|
||||
&to.to_string(),
|
||||
suffix
|
||||
)
|
||||
};
|
||||
|
||||
let query = gen_query_inclusive(id_range.start(), id_range.end());
|
||||
execute_query_(&query, index)
|
||||
}
|
||||
|
||||
fn execute_query_(query: &str, index: &Index) -> NumHits {
|
||||
let query_from_text = |text: &str| {
|
||||
QueryParser::for_index(index, vec![])
|
||||
.parse_query(text)
|
||||
.unwrap()
|
||||
};
|
||||
let query = query_from_text(query);
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
let num_hits = searcher
|
||||
.search(&query, &(TopDocs::with_limit(10).order_by_score(), Count))
|
||||
.unwrap()
|
||||
.1;
|
||||
NumHits { count: num_hits }
|
||||
}
|
||||
@@ -29,20 +29,12 @@ impl<T: PartialOrd + Copy + std::fmt::Debug + Send + Sync + 'static + Default>
|
||||
}
|
||||
}
|
||||
#[inline]
|
||||
pub fn fetch_block_with_missing(
|
||||
&mut self,
|
||||
docs: &[u32],
|
||||
accessor: &Column<T>,
|
||||
missing: Option<T>,
|
||||
) {
|
||||
pub fn fetch_block_with_missing(&mut self, docs: &[u32], accessor: &Column<T>, missing: T) {
|
||||
self.fetch_block(docs, accessor);
|
||||
// no missing values
|
||||
if accessor.index.get_cardinality().is_full() {
|
||||
return;
|
||||
}
|
||||
let Some(missing) = missing else {
|
||||
return;
|
||||
};
|
||||
|
||||
// We can compare docid_cache length with docs to find missing docs
|
||||
// For multi value columns we can't rely on the length and always need to scan
|
||||
|
||||
@@ -41,12 +41,6 @@ fn transform_range_before_linear_transformation(
|
||||
if range.is_empty() {
|
||||
return None;
|
||||
}
|
||||
if stats.min_value > *range.end() {
|
||||
return None;
|
||||
}
|
||||
if stats.max_value < *range.start() {
|
||||
return None;
|
||||
}
|
||||
let shifted_range =
|
||||
range.start().saturating_sub(stats.min_value)..=range.end().saturating_sub(stats.min_value);
|
||||
let start_before_gcd_multiplication: u64 = div_ceil(*shifted_range.start(), stats.gcd);
|
||||
|
||||
@@ -3,7 +3,8 @@ use std::sync::Arc;
|
||||
use std::{fmt, io};
|
||||
|
||||
use common::file_slice::FileSlice;
|
||||
use common::{ByteCount, DateTime, HasLen, OwnedBytes};
|
||||
use common::{ByteCount, DateTime, OwnedBytes};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::column::{BytesColumn, Column, StrColumn};
|
||||
use crate::column_values::{StrictlyMonotonicFn, monotonic_map_column};
|
||||
@@ -317,10 +318,89 @@ impl DynamicColumnHandle {
|
||||
}
|
||||
|
||||
pub fn num_bytes(&self) -> ByteCount {
|
||||
self.file_slice.len().into()
|
||||
self.file_slice.num_bytes()
|
||||
}
|
||||
|
||||
/// Legacy helper returning the column space usage.
|
||||
pub fn column_and_dictionary_num_bytes(&self) -> io::Result<ColumnSpaceUsage> {
|
||||
self.space_usage()
|
||||
}
|
||||
|
||||
/// Return the space usage of the column, optionally broken down by dictionary and column
|
||||
/// values.
|
||||
///
|
||||
/// For dictionary encoded columns (strings and bytes), this splits the total footprint into
|
||||
/// the dictionary and the remaining column data (including index and values).
|
||||
/// For all other column types, the dictionary size is `None` and the column size
|
||||
/// equals the total bytes.
|
||||
pub fn space_usage(&self) -> io::Result<ColumnSpaceUsage> {
|
||||
let total_num_bytes = self.num_bytes();
|
||||
let dynamic_column = self.open()?;
|
||||
let dictionary_num_bytes = match &dynamic_column {
|
||||
DynamicColumn::Bytes(bytes_column) => bytes_column.dictionary().num_bytes(),
|
||||
DynamicColumn::Str(str_column) => str_column.dictionary().num_bytes(),
|
||||
_ => {
|
||||
return Ok(ColumnSpaceUsage::new(self.num_bytes(), None));
|
||||
}
|
||||
};
|
||||
assert!(dictionary_num_bytes <= total_num_bytes);
|
||||
let column_num_bytes =
|
||||
ByteCount::from(total_num_bytes.get_bytes() - dictionary_num_bytes.get_bytes());
|
||||
Ok(ColumnSpaceUsage::new(
|
||||
column_num_bytes,
|
||||
Some(dictionary_num_bytes),
|
||||
))
|
||||
}
|
||||
|
||||
pub fn column_type(&self) -> ColumnType {
|
||||
self.column_type
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents space usage of a column.
|
||||
///
|
||||
/// `column_num_bytes` tracks the column payload (index, values and footer).
|
||||
/// For dictionary encoded columns, `dictionary_num_bytes` captures the dictionary footprint.
|
||||
/// [`ColumnSpaceUsage::total_num_bytes`] returns the sum of both parts.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ColumnSpaceUsage {
|
||||
column_num_bytes: ByteCount,
|
||||
dictionary_num_bytes: Option<ByteCount>,
|
||||
}
|
||||
|
||||
impl ColumnSpaceUsage {
|
||||
pub(crate) fn new(
|
||||
column_num_bytes: ByteCount,
|
||||
dictionary_num_bytes: Option<ByteCount>,
|
||||
) -> Self {
|
||||
ColumnSpaceUsage {
|
||||
column_num_bytes,
|
||||
dictionary_num_bytes,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn column_num_bytes(&self) -> ByteCount {
|
||||
self.column_num_bytes
|
||||
}
|
||||
|
||||
pub fn dictionary_num_bytes(&self) -> Option<ByteCount> {
|
||||
self.dictionary_num_bytes
|
||||
}
|
||||
|
||||
pub fn total_num_bytes(&self) -> ByteCount {
|
||||
self.column_num_bytes + self.dictionary_num_bytes.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Merge two space usage values by summing their components.
|
||||
pub fn merge(&self, other: &ColumnSpaceUsage) -> ColumnSpaceUsage {
|
||||
let dictionary_num_bytes = match (self.dictionary_num_bytes, other.dictionary_num_bytes) {
|
||||
(Some(lhs), Some(rhs)) => Some(lhs + rhs),
|
||||
(Some(val), None) | (None, Some(val)) => Some(val),
|
||||
(None, None) => None,
|
||||
};
|
||||
ColumnSpaceUsage {
|
||||
column_num_bytes: self.column_num_bytes + other.column_num_bytes,
|
||||
dictionary_num_bytes,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ pub use columnar::{
|
||||
use sstable::VoidSSTable;
|
||||
pub use value::{NumericalType, NumericalValue};
|
||||
|
||||
pub use self::dynamic_column::{DynamicColumn, DynamicColumnHandle};
|
||||
pub use self::dynamic_column::{ColumnSpaceUsage, DynamicColumn, DynamicColumnHandle};
|
||||
|
||||
pub type RowId = u32;
|
||||
pub type DocId = u32;
|
||||
|
||||
@@ -181,14 +181,6 @@ pub struct BitSet {
|
||||
len: u64,
|
||||
max_value: u32,
|
||||
}
|
||||
impl std::fmt::Debug for BitSet {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("BitSet")
|
||||
.field("len", &self.len)
|
||||
.field("max_value", &self.max_value)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
fn num_buckets(max_val: u32) -> u32 {
|
||||
max_val.div_ceil(64u32)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use columnar::{Column, ColumnBlockAccessor, ColumnType, StrColumn};
|
||||
use columnar::{Column, ColumnType, StrColumn};
|
||||
use common::BitSet;
|
||||
use rustc_hash::FxHashSet;
|
||||
use serde::Serialize;
|
||||
@@ -10,16 +10,16 @@ use crate::aggregation::accessor_helpers::{
|
||||
};
|
||||
use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations};
|
||||
use crate::aggregation::bucket::{
|
||||
build_segment_range_collector, FilterAggReqData, HistogramAggReqData, HistogramBounds,
|
||||
IncludeExcludeParam, MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector,
|
||||
SegmentHistogramCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
|
||||
FilterAggReqData, HistogramAggReqData, HistogramBounds, IncludeExcludeParam,
|
||||
MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector, SegmentHistogramCollector,
|
||||
SegmentRangeCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
|
||||
TermsAggregationInternal,
|
||||
};
|
||||
use crate::aggregation::metric::{
|
||||
build_segment_stats_collector, AverageAggregation, CardinalityAggReqData,
|
||||
CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation, MaxAggregation,
|
||||
MetricAggReqData, MinAggregation, SegmentCardinalityCollector, SegmentExtendedStatsCollector,
|
||||
SegmentPercentilesCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData,
|
||||
AverageAggregation, CardinalityAggReqData, CardinalityAggregationReq, CountAggregation,
|
||||
ExtendedStatsAggregation, MaxAggregation, MetricAggReqData, MinAggregation,
|
||||
SegmentCardinalityCollector, SegmentExtendedStatsCollector, SegmentPercentilesCollector,
|
||||
SegmentStatsCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData,
|
||||
TopHitsSegmentCollector,
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::{
|
||||
@@ -35,7 +35,6 @@ pub struct AggregationsSegmentCtx {
|
||||
/// Request data for each aggregation type.
|
||||
pub per_request: PerRequestAggSegCtx,
|
||||
pub context: AggContextParams,
|
||||
pub column_block_accessor: ColumnBlockAccessor<u64>,
|
||||
}
|
||||
|
||||
impl AggregationsSegmentCtx {
|
||||
@@ -108,14 +107,21 @@ impl AggregationsSegmentCtx {
|
||||
.as_deref()
|
||||
.expect("range_req_data slot is empty (taken)")
|
||||
}
|
||||
#[inline]
|
||||
pub(crate) fn get_filter_req_data(&self, idx: usize) -> &FilterAggReqData {
|
||||
self.per_request.filter_req_data[idx]
|
||||
.as_deref()
|
||||
.expect("filter_req_data slot is empty (taken)")
|
||||
}
|
||||
|
||||
// ---------- mutable getters ----------
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData {
|
||||
&mut self.per_request.stats_metric_req_data[idx]
|
||||
pub(crate) fn get_term_req_data_mut(&mut self, idx: usize) -> &mut TermsAggReqData {
|
||||
self.per_request.term_req_data[idx]
|
||||
.as_deref_mut()
|
||||
.expect("term_req_data slot is empty (taken)")
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn get_cardinality_req_data_mut(
|
||||
&mut self,
|
||||
@@ -123,7 +129,10 @@ impl AggregationsSegmentCtx {
|
||||
) -> &mut CardinalityAggReqData {
|
||||
&mut self.per_request.cardinality_req_data[idx]
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData {
|
||||
&mut self.per_request.stats_metric_req_data[idx]
|
||||
}
|
||||
#[inline]
|
||||
pub(crate) fn get_histogram_req_data_mut(&mut self, idx: usize) -> &mut HistogramAggReqData {
|
||||
self.per_request.histogram_req_data[idx]
|
||||
@@ -133,6 +142,21 @@ impl AggregationsSegmentCtx {
|
||||
|
||||
// ---------- take / put (terms, histogram, range) ----------
|
||||
|
||||
/// Move out the boxed Terms request at `idx`, leaving `None`.
|
||||
#[inline]
|
||||
pub(crate) fn take_term_req_data(&mut self, idx: usize) -> Box<TermsAggReqData> {
|
||||
self.per_request.term_req_data[idx]
|
||||
.take()
|
||||
.expect("term_req_data slot is empty (taken)")
|
||||
}
|
||||
|
||||
/// Put back a Terms request into an empty slot at `idx`.
|
||||
#[inline]
|
||||
pub(crate) fn put_back_term_req_data(&mut self, idx: usize, value: Box<TermsAggReqData>) {
|
||||
debug_assert!(self.per_request.term_req_data[idx].is_none());
|
||||
self.per_request.term_req_data[idx] = Some(value);
|
||||
}
|
||||
|
||||
/// Move out the boxed Histogram request at `idx`, leaving `None`.
|
||||
#[inline]
|
||||
pub(crate) fn take_histogram_req_data(&mut self, idx: usize) -> Box<HistogramAggReqData> {
|
||||
@@ -296,7 +320,6 @@ impl PerRequestAggSegCtx {
|
||||
|
||||
/// Convert the aggregation tree into a serializable struct representation.
|
||||
/// Each node contains: { name, kind, children }.
|
||||
#[allow(dead_code)]
|
||||
pub fn get_view_tree(&self) -> Vec<AggTreeViewNode> {
|
||||
fn node_to_view(node: &AggRefNode, pr: &PerRequestAggSegCtx) -> AggTreeViewNode {
|
||||
let mut children: Vec<AggTreeViewNode> =
|
||||
@@ -322,19 +345,12 @@ impl PerRequestAggSegCtx {
|
||||
pub(crate) fn build_segment_agg_collectors_root(
|
||||
req: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
|
||||
build_segment_agg_collectors_generic(req, &req.per_request.agg_tree.clone())
|
||||
build_segment_agg_collectors(req, &req.per_request.agg_tree.clone())
|
||||
}
|
||||
|
||||
pub(crate) fn build_segment_agg_collectors(
|
||||
req: &mut AggregationsSegmentCtx,
|
||||
nodes: &[AggRefNode],
|
||||
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
|
||||
build_segment_agg_collectors_generic(req, nodes)
|
||||
}
|
||||
|
||||
fn build_segment_agg_collectors_generic(
|
||||
req: &mut AggregationsSegmentCtx,
|
||||
nodes: &[AggRefNode],
|
||||
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
|
||||
let mut collectors = Vec::new();
|
||||
for node in nodes.iter() {
|
||||
@@ -372,8 +388,6 @@ pub(crate) fn build_segment_agg_collector(
|
||||
Ok(Box::new(SegmentCardinalityCollector::from_req(
|
||||
req_data.column_type,
|
||||
node.idx_in_req_data,
|
||||
req_data.accessor.clone(),
|
||||
req_data.missing_value_for_accessor,
|
||||
)))
|
||||
}
|
||||
AggKind::StatsKind(stats_type) => {
|
||||
@@ -384,21 +398,20 @@ pub(crate) fn build_segment_agg_collector(
|
||||
| StatsType::Count
|
||||
| StatsType::Max
|
||||
| StatsType::Min
|
||||
| StatsType::Stats => build_segment_stats_collector(req_data),
|
||||
StatsType::ExtendedStats(sigma) => Ok(Box::new(
|
||||
SegmentExtendedStatsCollector::from_req(req_data, sigma),
|
||||
)),
|
||||
StatsType::Percentiles => {
|
||||
let req_data = req.get_metric_req_data_mut(node.idx_in_req_data);
|
||||
Ok(Box::new(
|
||||
SegmentPercentilesCollector::from_req_and_validate(
|
||||
req_data.field_type,
|
||||
req_data.missing_u64,
|
||||
req_data.accessor.clone(),
|
||||
node.idx_in_req_data,
|
||||
),
|
||||
))
|
||||
| StatsType::Stats => Ok(Box::new(SegmentStatsCollector::from_req(
|
||||
node.idx_in_req_data,
|
||||
))),
|
||||
StatsType::ExtendedStats(sigma) => {
|
||||
Ok(Box::new(SegmentExtendedStatsCollector::from_req(
|
||||
req_data.field_type,
|
||||
sigma,
|
||||
node.idx_in_req_data,
|
||||
req_data.missing,
|
||||
)))
|
||||
}
|
||||
StatsType::Percentiles => Ok(Box::new(
|
||||
SegmentPercentilesCollector::from_req_and_validate(node.idx_in_req_data)?,
|
||||
)),
|
||||
}
|
||||
}
|
||||
AggKind::TopHits => {
|
||||
@@ -415,7 +428,9 @@ pub(crate) fn build_segment_agg_collector(
|
||||
AggKind::DateHistogram => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
|
||||
req, node,
|
||||
)?)),
|
||||
AggKind::Range => Ok(build_segment_range_collector(req, node)?),
|
||||
AggKind::Range => Ok(Box::new(SegmentRangeCollector::from_req_and_validate(
|
||||
req, node,
|
||||
)?)),
|
||||
AggKind::Filter => Ok(Box::new(SegmentFilterCollector::from_req_and_validate(
|
||||
req, node,
|
||||
)?)),
|
||||
@@ -478,7 +493,6 @@ pub(crate) fn build_aggregations_data_from_req(
|
||||
let mut data = AggregationsSegmentCtx {
|
||||
per_request: Default::default(),
|
||||
context,
|
||||
column_block_accessor: ColumnBlockAccessor::default(),
|
||||
};
|
||||
|
||||
for (name, agg) in aggs.iter() {
|
||||
@@ -507,9 +521,9 @@ fn build_nodes(
|
||||
let idx_in_req_data = data.push_range_req_data(RangeAggReqData {
|
||||
accessor,
|
||||
field_type,
|
||||
column_block_accessor: Default::default(),
|
||||
name: agg_name.to_string(),
|
||||
req: range_req.clone(),
|
||||
is_top_level,
|
||||
});
|
||||
let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?;
|
||||
Ok(vec![AggRefNode {
|
||||
@@ -527,7 +541,9 @@ fn build_nodes(
|
||||
let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData {
|
||||
accessor,
|
||||
field_type,
|
||||
column_block_accessor: Default::default(),
|
||||
name: agg_name.to_string(),
|
||||
sub_aggregation_blueprint: None,
|
||||
req: histo_req.clone(),
|
||||
is_date_histogram: false,
|
||||
bounds: HistogramBounds {
|
||||
@@ -552,7 +568,9 @@ fn build_nodes(
|
||||
let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData {
|
||||
accessor,
|
||||
field_type,
|
||||
column_block_accessor: Default::default(),
|
||||
name: agg_name.to_string(),
|
||||
sub_aggregation_blueprint: None,
|
||||
req: histo_req,
|
||||
is_date_histogram: true,
|
||||
bounds: HistogramBounds {
|
||||
@@ -632,6 +650,7 @@ fn build_nodes(
|
||||
let idx_in_req_data = data.push_metric_req_data(MetricAggReqData {
|
||||
accessor,
|
||||
field_type,
|
||||
column_block_accessor: Default::default(),
|
||||
name: agg_name.to_string(),
|
||||
collecting_for,
|
||||
missing: *missing,
|
||||
@@ -659,6 +678,7 @@ fn build_nodes(
|
||||
let idx_in_req_data = data.push_metric_req_data(MetricAggReqData {
|
||||
accessor,
|
||||
field_type,
|
||||
column_block_accessor: Default::default(),
|
||||
name: agg_name.to_string(),
|
||||
collecting_for: StatsType::Percentiles,
|
||||
missing: percentiles_req.missing,
|
||||
@@ -875,7 +895,7 @@ fn build_terms_or_cardinality_nodes(
|
||||
});
|
||||
}
|
||||
|
||||
// Add one node per accessor
|
||||
// Add one node per accessor to mirror previous behavior and allow per-type missing handling.
|
||||
for (accessor, column_type) in column_and_types {
|
||||
let missing_value_for_accessor = if use_special_missing_agg {
|
||||
None
|
||||
@@ -906,8 +926,11 @@ fn build_terms_or_cardinality_nodes(
|
||||
column_type,
|
||||
str_dict_column: str_dict_column.clone(),
|
||||
missing_value_for_accessor,
|
||||
column_block_accessor: Default::default(),
|
||||
name: agg_name.to_string(),
|
||||
req: TermsAggregationInternal::from_req(req),
|
||||
// Will be filled later when building collectors
|
||||
sub_aggregation_blueprint: None,
|
||||
sug_aggregations: sub_aggs.clone(),
|
||||
allowed_term_ids,
|
||||
is_top_level,
|
||||
@@ -920,6 +943,7 @@ fn build_terms_or_cardinality_nodes(
|
||||
column_type,
|
||||
str_dict_column: str_dict_column.clone(),
|
||||
missing_value_for_accessor,
|
||||
column_block_accessor: Default::default(),
|
||||
name: agg_name.to_string(),
|
||||
req: req.clone(),
|
||||
});
|
||||
|
||||
@@ -2,441 +2,15 @@ use serde_json::Value;
|
||||
|
||||
use crate::aggregation::agg_req::{Aggregation, Aggregations};
|
||||
use crate::aggregation::agg_result::AggregationResults;
|
||||
use crate::aggregation::buf_collector::DOC_BLOCK_SIZE;
|
||||
use crate::aggregation::collector::AggregationCollector;
|
||||
use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults;
|
||||
use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms};
|
||||
use crate::aggregation::DistributedAggregationCollector;
|
||||
use crate::docset::COLLECT_BLOCK_BUFFER_LEN;
|
||||
use crate::query::{AllQuery, TermQuery};
|
||||
use crate::schema::{IndexRecordOption, Schema, FAST};
|
||||
use crate::{Index, IndexWriter, Term};
|
||||
|
||||
// The following tests ensure that each bucket aggregation type correctly functions as a
|
||||
// sub-aggregation of another bucket aggregation in two scenarios:
|
||||
// 1) The parent has more buckets than the child sub-aggregation
|
||||
// 2) The child sub-aggregation has more buckets than the parent
|
||||
//
|
||||
// These scenarios exercise the bucket id mapping and sub-aggregation routing logic.
|
||||
|
||||
#[test]
|
||||
fn test_terms_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
|
||||
let index = get_test_index_2_segments(false)?;
|
||||
|
||||
// Case A: parent has more buckets than child
|
||||
// Parent: range with 4 buckets
|
||||
// Child: terms on text -> 2 buckets
|
||||
let agg_parent_more: Aggregations = serde_json::from_value(json!({
|
||||
"parent_range": {
|
||||
"range": {
|
||||
"field": "score",
|
||||
"ranges": [
|
||||
{"to": 3.0},
|
||||
{"from": 3.0, "to": 7.0},
|
||||
{"from": 7.0, "to": 20.0},
|
||||
{"from": 20.0}
|
||||
]
|
||||
},
|
||||
"aggs": {
|
||||
"child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}}
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
|
||||
// Exact expected structure and counts
|
||||
assert_eq!(
|
||||
res["parent_range"]["buckets"],
|
||||
json!([
|
||||
{
|
||||
"key": "*-3",
|
||||
"doc_count": 1,
|
||||
"to": 3.0,
|
||||
"child_terms": {
|
||||
"buckets": [
|
||||
{"doc_count": 1, "key": "cool"}
|
||||
],
|
||||
"sum_other_doc_count": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"key": "3-7",
|
||||
"doc_count": 3,
|
||||
"from": 3.0,
|
||||
"to": 7.0,
|
||||
"child_terms": {
|
||||
"buckets": [
|
||||
{"doc_count": 2, "key": "cool"},
|
||||
{"doc_count": 1, "key": "nohit"}
|
||||
],
|
||||
"sum_other_doc_count": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"key": "7-20",
|
||||
"doc_count": 3,
|
||||
"from": 7.0,
|
||||
"to": 20.0,
|
||||
"child_terms": {
|
||||
"buckets": [
|
||||
{"doc_count": 3, "key": "cool"}
|
||||
],
|
||||
"sum_other_doc_count": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"key": "20-*",
|
||||
"doc_count": 2,
|
||||
"from": 20.0,
|
||||
"child_terms": {
|
||||
"buckets": [
|
||||
{"doc_count": 1, "key": "cool"},
|
||||
{"doc_count": 1, "key": "nohit"}
|
||||
],
|
||||
"sum_other_doc_count": 0
|
||||
}
|
||||
}
|
||||
])
|
||||
);
|
||||
|
||||
// Case B: child has more buckets than parent
|
||||
// Parent: histogram on score with large interval -> 1 bucket
|
||||
// Child: terms on text -> 2 buckets (cool/nohit)
|
||||
let agg_child_more: Aggregations = serde_json::from_value(json!({
|
||||
"parent_hist": {
|
||||
"histogram": {"field": "score", "interval": 100.0},
|
||||
"aggs": {
|
||||
"child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}}
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
|
||||
assert_eq!(
|
||||
res["parent_hist"],
|
||||
json!({
|
||||
"buckets": [
|
||||
{
|
||||
"key": 0.0,
|
||||
"doc_count": 9,
|
||||
"child_terms": {
|
||||
"buckets": [
|
||||
{"doc_count": 7, "key": "cool"},
|
||||
{"doc_count": 2, "key": "nohit"}
|
||||
],
|
||||
"sum_other_doc_count": 0
|
||||
}
|
||||
}
|
||||
]
|
||||
})
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_range_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
|
||||
let index = get_test_index_2_segments(false)?;
|
||||
|
||||
// Case A: parent has more buckets than child
|
||||
// Parent: range with 5 buckets
|
||||
// Child: coarse range with 3 buckets
|
||||
let agg_parent_more: Aggregations = serde_json::from_value(json!({
|
||||
"parent_range": {
|
||||
"range": {
|
||||
"field": "score",
|
||||
"ranges": [
|
||||
{"to": 3.0},
|
||||
{"from": 3.0, "to": 7.0},
|
||||
{"from": 7.0, "to": 11.0},
|
||||
{"from": 11.0, "to": 20.0},
|
||||
{"from": 20.0}
|
||||
]
|
||||
},
|
||||
"aggs": {
|
||||
"child_range": {
|
||||
"range": {
|
||||
"field": "score",
|
||||
"ranges": [
|
||||
{"to": 3.0},
|
||||
{"from": 3.0, "to": 20.0}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
|
||||
assert_eq!(
|
||||
res["parent_range"]["buckets"],
|
||||
json!([
|
||||
{"key": "*-3", "doc_count": 1, "to": 3.0,
|
||||
"child_range": {"buckets": [
|
||||
{"key": "*-3", "doc_count": 1, "to": 3.0},
|
||||
{"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0},
|
||||
{"key": "20-*", "doc_count": 0, "from": 20.0}
|
||||
]}
|
||||
},
|
||||
{"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0,
|
||||
"child_range": {"buckets": [
|
||||
{"key": "*-3", "doc_count": 0, "to": 3.0},
|
||||
{"key": "3-20", "doc_count": 3, "from": 3.0, "to": 20.0},
|
||||
{"key": "20-*", "doc_count": 0, "from": 20.0}
|
||||
]}
|
||||
},
|
||||
{"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0,
|
||||
"child_range": {"buckets": [
|
||||
{"key": "*-3", "doc_count": 0, "to": 3.0},
|
||||
{"key": "3-20", "doc_count": 1, "from": 3.0, "to": 20.0},
|
||||
{"key": "20-*", "doc_count": 0, "from": 20.0}
|
||||
]}
|
||||
},
|
||||
{"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0,
|
||||
"child_range": {"buckets": [
|
||||
{"key": "*-3", "doc_count": 0, "to": 3.0},
|
||||
{"key": "3-20", "doc_count": 2, "from": 3.0, "to": 20.0},
|
||||
{"key": "20-*", "doc_count": 0, "from": 20.0}
|
||||
]}
|
||||
},
|
||||
{"key": "20-*", "doc_count": 2, "from": 20.0,
|
||||
"child_range": {"buckets": [
|
||||
{"key": "*-3", "doc_count": 0, "to": 3.0},
|
||||
{"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0},
|
||||
{"key": "20-*", "doc_count": 2, "from": 20.0}
|
||||
]}
|
||||
}
|
||||
])
|
||||
);
|
||||
|
||||
// Case B: child has more buckets than parent
|
||||
// Parent: terms on text (2 buckets)
|
||||
// Child: range with 4 buckets
|
||||
let agg_child_more: Aggregations = serde_json::from_value(json!({
|
||||
"parent_terms": {
|
||||
"terms": {"field": "text"},
|
||||
"aggs": {
|
||||
"child_range": {
|
||||
"range": {
|
||||
"field": "score",
|
||||
"ranges": [
|
||||
{"to": 3.0},
|
||||
{"from": 3.0, "to": 7.0},
|
||||
{"from": 7.0, "to": 20.0}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
|
||||
|
||||
assert_eq!(
|
||||
res["parent_terms"],
|
||||
json!({
|
||||
"buckets": [
|
||||
{
|
||||
"key": "cool",
|
||||
"doc_count": 7,
|
||||
"child_range": {
|
||||
"buckets": [
|
||||
{"key": "*-3", "doc_count": 1, "to": 3.0},
|
||||
{"key": "3-7", "doc_count": 2, "from": 3.0, "to": 7.0},
|
||||
{"key": "7-20", "doc_count": 3, "from": 7.0, "to": 20.0},
|
||||
{"key": "20-*", "doc_count": 1, "from": 20.0}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"key": "nohit",
|
||||
"doc_count": 2,
|
||||
"child_range": {
|
||||
"buckets": [
|
||||
{"key": "*-3", "doc_count": 0, "to": 3.0},
|
||||
{"key": "3-7", "doc_count": 1, "from": 3.0, "to": 7.0},
|
||||
{"key": "7-20", "doc_count": 0, "from": 7.0, "to": 20.0},
|
||||
{"key": "20-*", "doc_count": 1, "from": 20.0}
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"doc_count_error_upper_bound": 0,
|
||||
"sum_other_doc_count": 0
|
||||
})
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
|
||||
let index = get_test_index_2_segments(false)?;
|
||||
|
||||
// Case A: parent has more buckets than child
|
||||
// Parent: range with several ranges
|
||||
// Child: histogram with large interval (single bucket per parent)
|
||||
let agg_parent_more: Aggregations = serde_json::from_value(json!({
|
||||
"parent_range": {
|
||||
"range": {
|
||||
"field": "score",
|
||||
"ranges": [
|
||||
{"to": 3.0},
|
||||
{"from": 3.0, "to": 7.0},
|
||||
{"from": 7.0, "to": 11.0},
|
||||
{"from": 11.0, "to": 20.0},
|
||||
{"from": 20.0}
|
||||
]
|
||||
},
|
||||
"aggs": {
|
||||
"child_hist": {"histogram": {"field": "score", "interval": 100.0}}
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
|
||||
assert_eq!(
|
||||
res["parent_range"]["buckets"],
|
||||
json!([
|
||||
{"key": "*-3", "doc_count": 1, "to": 3.0,
|
||||
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]}
|
||||
},
|
||||
{"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0,
|
||||
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 3} ]}
|
||||
},
|
||||
{"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0,
|
||||
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]}
|
||||
},
|
||||
{"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0,
|
||||
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]}
|
||||
},
|
||||
{"key": "20-*", "doc_count": 2, "from": 20.0,
|
||||
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]}
|
||||
}
|
||||
])
|
||||
);
|
||||
|
||||
// Case B: child has more buckets than parent
|
||||
// Parent: terms on text -> 2 buckets
|
||||
// Child: histogram with small interval -> multiple buckets including empties
|
||||
let agg_child_more: Aggregations = serde_json::from_value(json!({
|
||||
"parent_terms": {
|
||||
"terms": {"field": "text"},
|
||||
"aggs": {
|
||||
"child_hist": {"histogram": {"field": "score", "interval": 10.0}}
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
|
||||
assert_eq!(
|
||||
res["parent_terms"],
|
||||
json!({
|
||||
"buckets": [
|
||||
{
|
||||
"key": "cool",
|
||||
"doc_count": 7,
|
||||
"child_hist": {
|
||||
"buckets": [
|
||||
{"key": 0.0, "doc_count": 4},
|
||||
{"key": 10.0, "doc_count": 2},
|
||||
{"key": 20.0, "doc_count": 0},
|
||||
{"key": 30.0, "doc_count": 0},
|
||||
{"key": 40.0, "doc_count": 1}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"key": "nohit",
|
||||
"doc_count": 2,
|
||||
"child_hist": {
|
||||
"buckets": [
|
||||
{"key": 0.0, "doc_count": 1},
|
||||
{"key": 10.0, "doc_count": 0},
|
||||
{"key": 20.0, "doc_count": 0},
|
||||
{"key": 30.0, "doc_count": 0},
|
||||
{"key": 40.0, "doc_count": 1}
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"doc_count_error_upper_bound": 0,
|
||||
"sum_other_doc_count": 0
|
||||
})
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_date_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
|
||||
let index = get_test_index_2_segments(false)?;
|
||||
|
||||
// Case A: parent has more buckets than child
|
||||
// Parent: range with several buckets
|
||||
// Child: date_histogram with 30d -> single bucket per parent
|
||||
let agg_parent_more: Aggregations = serde_json::from_value(json!({
|
||||
"parent_range": {
|
||||
"range": {
|
||||
"field": "score",
|
||||
"ranges": [
|
||||
{"to": 3.0},
|
||||
{"from": 3.0, "to": 7.0},
|
||||
{"from": 7.0, "to": 11.0},
|
||||
{"from": 11.0, "to": 20.0},
|
||||
{"from": 20.0}
|
||||
]
|
||||
},
|
||||
"aggs": {
|
||||
"child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "30d"}}
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
|
||||
let buckets = res["parent_range"]["buckets"].as_array().unwrap();
|
||||
// Verify each parent bucket has exactly one child date bucket with matching doc_count
|
||||
for bucket in buckets {
|
||||
let parent_count = bucket["doc_count"].as_u64().unwrap();
|
||||
let child_buckets = bucket["child_date_hist"]["buckets"].as_array().unwrap();
|
||||
assert_eq!(child_buckets.len(), 1);
|
||||
assert_eq!(child_buckets[0]["doc_count"], parent_count);
|
||||
}
|
||||
|
||||
// Case B: child has more buckets than parent
|
||||
// Parent: terms on text (2 buckets)
|
||||
// Child: date_histogram with 1d -> multiple buckets
|
||||
let agg_child_more: Aggregations = serde_json::from_value(json!({
|
||||
"parent_terms": {
|
||||
"terms": {"field": "text"},
|
||||
"aggs": {
|
||||
"child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "1d"}}
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
|
||||
let buckets = res["parent_terms"]["buckets"].as_array().unwrap();
|
||||
|
||||
// cool bucket
|
||||
assert_eq!(buckets[0]["key"], "cool");
|
||||
let cool_buckets = buckets[0]["child_date_hist"]["buckets"].as_array().unwrap();
|
||||
assert_eq!(cool_buckets.len(), 3);
|
||||
assert_eq!(cool_buckets[0]["doc_count"], 1); // day 0
|
||||
assert_eq!(cool_buckets[1]["doc_count"], 4); // day 1
|
||||
assert_eq!(cool_buckets[2]["doc_count"], 2); // day 2
|
||||
|
||||
// nohit bucket
|
||||
assert_eq!(buckets[1]["key"], "nohit");
|
||||
let nohit_buckets = buckets[1]["child_date_hist"]["buckets"].as_array().unwrap();
|
||||
assert_eq!(nohit_buckets.len(), 2);
|
||||
assert_eq!(nohit_buckets[0]["doc_count"], 1); // day 1
|
||||
assert_eq!(nohit_buckets[1]["doc_count"], 1); // day 2
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_avg_req(field_name: &str) -> Aggregation {
|
||||
serde_json::from_value(json!({
|
||||
"avg": {
|
||||
@@ -451,10 +25,6 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector {
|
||||
}
|
||||
|
||||
// *** EVERY BUCKET-TYPE SHOULD BE TESTED HERE ***
|
||||
// Note: The flushng part of these tests are outdated, since the buffering change after converting
|
||||
// the collection into one collector per request instead of per bucket.
|
||||
//
|
||||
// However they are useful as they test a complex aggregation requests.
|
||||
fn test_aggregation_flushing(
|
||||
merge_segments: bool,
|
||||
use_distributed_collector: bool,
|
||||
@@ -467,9 +37,8 @@ fn test_aggregation_flushing(
|
||||
|
||||
let reader = index.reader()?;
|
||||
|
||||
assert_eq!(COLLECT_BLOCK_BUFFER_LEN, 64);
|
||||
// In the tree we cache documents of COLLECT_BLOCK_BUFFER_LEN before passing them down as one
|
||||
// block.
|
||||
assert_eq!(DOC_BLOCK_SIZE, 64);
|
||||
// In the tree we cache Documents of DOC_BLOCK_SIZE, before passing them down as one block.
|
||||
//
|
||||
// Build a request so that on the first level we have one full cache, which is then flushed.
|
||||
// The same cache should have some residue docs at the end, which are flushed (Range 0-70)
|
||||
|
||||
@@ -6,12 +6,10 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use crate::aggregation::agg_data::{
|
||||
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
|
||||
};
|
||||
use crate::aggregation::cached_sub_aggs::CachedSubAggs;
|
||||
use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
|
||||
use crate::aggregation::BucketId;
|
||||
use crate::aggregation::segment_agg_result::{CollectorClone, SegmentAggregationCollector};
|
||||
use crate::docset::DocSet;
|
||||
use crate::query::{AllQuery, EnableScoring, Query, QueryParser};
|
||||
use crate::schema::Schema;
|
||||
@@ -412,9 +410,9 @@ impl FilterAggReqData {
|
||||
pub(crate) fn get_memory_consumption(&self) -> usize {
|
||||
// Estimate: name + segment reader reference + bitset + buffer capacity
|
||||
self.name.len()
|
||||
+ std::mem::size_of::<SegmentReader>()
|
||||
+ self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes)
|
||||
+ self.matching_docs_buffer.capacity() * std::mem::size_of::<DocId>()
|
||||
+ std::mem::size_of::<SegmentReader>()
|
||||
+ self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes)
|
||||
+ self.matching_docs_buffer.capacity() * std::mem::size_of::<DocId>()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -491,19 +489,12 @@ impl Debug for DocumentQueryEvaluator {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Copy)]
|
||||
struct DocCount {
|
||||
doc_count: u64,
|
||||
bucket_id: BucketId,
|
||||
}
|
||||
|
||||
/// Segment collector for filter aggregation
|
||||
pub struct SegmentFilterCollector {
|
||||
/// Document counts per parent bucket
|
||||
parent_buckets: Vec<DocCount>,
|
||||
/// Document count in this bucket
|
||||
doc_count: u64,
|
||||
/// Sub-aggregation collectors
|
||||
sub_aggregations: Option<CachedSubAggs<true>>,
|
||||
bucket_id_provider: BucketIdProvider,
|
||||
sub_aggregations: Option<Box<dyn SegmentAggregationCollector>>,
|
||||
/// Accessor index for this filter aggregation (to access FilterAggReqData)
|
||||
accessor_idx: usize,
|
||||
}
|
||||
@@ -520,13 +511,11 @@ impl SegmentFilterCollector {
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let sub_agg_collector = sub_agg_collector.map(CachedSubAggs::new);
|
||||
|
||||
Ok(SegmentFilterCollector {
|
||||
parent_buckets: Vec::new(),
|
||||
doc_count: 0,
|
||||
sub_aggregations: sub_agg_collector,
|
||||
accessor_idx: node.idx_in_req_data,
|
||||
bucket_id_provider: BucketIdProvider::default(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -534,41 +523,35 @@ impl SegmentFilterCollector {
|
||||
impl Debug for SegmentFilterCollector {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("SegmentFilterCollector")
|
||||
.field("buckets", &self.parent_buckets)
|
||||
.field("doc_count", &self.doc_count)
|
||||
.field("has_sub_aggs", &self.sub_aggregations.is_some())
|
||||
.field("accessor_idx", &self.accessor_idx)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl CollectorClone for SegmentFilterCollector {
|
||||
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector> {
|
||||
// For now, panic - this needs proper implementation with weight recreation
|
||||
panic!("SegmentFilterCollector cloning not yet implemented - requires weight recreation")
|
||||
}
|
||||
}
|
||||
|
||||
impl SegmentAggregationCollector for SegmentFilterCollector {
|
||||
fn add_intermediate_aggregation_result(
|
||||
&mut self,
|
||||
self: Box<Self>,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
parent_bucket_id: BucketId,
|
||||
) -> crate::Result<()> {
|
||||
let mut sub_results = IntermediateAggregationResults::default();
|
||||
let bucket_opt = self.parent_buckets.get(parent_bucket_id as usize);
|
||||
|
||||
if let Some(sub_aggs) = &mut self.sub_aggregations {
|
||||
sub_aggs
|
||||
.get_sub_agg_collector()
|
||||
.add_intermediate_aggregation_result(
|
||||
agg_data,
|
||||
&mut sub_results,
|
||||
// Here we create a new bucket ID for sub-aggregations if the bucket doesn't
|
||||
// exist, so that sub-aggregations can still produce results (e.g., zero doc
|
||||
// count)
|
||||
bucket_opt
|
||||
.map(|bucket| bucket.bucket_id)
|
||||
.unwrap_or(self.bucket_id_provider.next_bucket_id()),
|
||||
)?;
|
||||
if let Some(sub_aggs) = self.sub_aggregations {
|
||||
sub_aggs.add_intermediate_aggregation_result(agg_data, &mut sub_results)?;
|
||||
}
|
||||
|
||||
// Create the filter bucket result
|
||||
let filter_bucket_result = IntermediateBucketResult::Filter {
|
||||
doc_count: bucket_opt.map(|b| b.doc_count).unwrap_or(0),
|
||||
doc_count: self.doc_count,
|
||||
sub_aggregations: sub_results,
|
||||
};
|
||||
|
||||
@@ -587,17 +570,32 @@ impl SegmentAggregationCollector for SegmentFilterCollector {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn collect(
|
||||
fn collect(&mut self, doc: DocId, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
|
||||
// Access the evaluator from FilterAggReqData
|
||||
let req_data = agg_data.get_filter_req_data(self.accessor_idx);
|
||||
|
||||
// O(1) BitSet lookup to check if document matches filter
|
||||
if req_data.evaluator.matches_document(doc) {
|
||||
self.doc_count += 1;
|
||||
|
||||
// If we have sub-aggregations, collect on them for this filtered document
|
||||
if let Some(sub_aggs) = &mut self.sub_aggregations {
|
||||
sub_aggs.collect(doc, agg_data)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
parent_bucket_id: BucketId,
|
||||
docs: &[crate::DocId],
|
||||
docs: &[DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
if docs.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut bucket = self.parent_buckets[parent_bucket_id as usize];
|
||||
// Take the request data to avoid borrow checker issues with sub-aggregations
|
||||
let mut req = agg_data.take_filter_req_data(self.accessor_idx);
|
||||
|
||||
@@ -606,24 +604,18 @@ impl SegmentAggregationCollector for SegmentFilterCollector {
|
||||
req.evaluator
|
||||
.filter_batch(docs, &mut req.matching_docs_buffer);
|
||||
|
||||
bucket.doc_count += req.matching_docs_buffer.len() as u64;
|
||||
self.doc_count += req.matching_docs_buffer.len() as u64;
|
||||
|
||||
// Batch process sub-aggregations if we have matches
|
||||
if !req.matching_docs_buffer.is_empty() {
|
||||
if let Some(sub_aggs) = &mut self.sub_aggregations {
|
||||
for &doc_id in &req.matching_docs_buffer {
|
||||
sub_aggs.push(bucket.bucket_id, doc_id);
|
||||
}
|
||||
// Use collect_block for better sub-aggregation performance
|
||||
sub_aggs.collect_block(&req.matching_docs_buffer, agg_data)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Put the request data back
|
||||
agg_data.put_back_filter_req_data(self.accessor_idx, req);
|
||||
if let Some(sub_aggs) = &mut self.sub_aggregations {
|
||||
sub_aggs.check_flush_local(agg_data)?;
|
||||
}
|
||||
// put back bucket
|
||||
self.parent_buckets[parent_bucket_id as usize] = bucket;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -634,21 +626,6 @@ impl SegmentAggregationCollector for SegmentFilterCollector {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn prepare_max_bucket(
|
||||
&mut self,
|
||||
max_bucket: BucketId,
|
||||
_agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
while self.parent_buckets.len() <= max_bucket as usize {
|
||||
let bucket_id = self.bucket_id_provider.next_bucket_id();
|
||||
self.parent_buckets.push(DocCount {
|
||||
doc_count: 0,
|
||||
bucket_id,
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Intermediate result for filter aggregation
|
||||
@@ -1542,9 +1519,9 @@ mod tests {
|
||||
let searcher = reader.searcher();
|
||||
|
||||
let agg = json!({
|
||||
"test": {
|
||||
"filter": deserialized,
|
||||
"aggs": { "count": { "value_count": { "field": "brand" } } }
|
||||
"test": {
|
||||
"filter": deserialized,
|
||||
"aggs": { "count": { "value_count": { "field": "brand" } } }
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::cmp::Ordering;
|
||||
|
||||
use columnar::{Column, ColumnType};
|
||||
use columnar::{Column, ColumnBlockAccessor, ColumnType};
|
||||
use rustc_hash::FxHashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tantivy_bitpacker::minmax;
|
||||
@@ -8,14 +8,14 @@ use tantivy_bitpacker::minmax;
|
||||
use crate::aggregation::agg_data::{
|
||||
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
|
||||
};
|
||||
use crate::aggregation::agg_limits::MemoryConsumption;
|
||||
use crate::aggregation::agg_req::Aggregations;
|
||||
use crate::aggregation::agg_result::BucketEntry;
|
||||
use crate::aggregation::cached_sub_aggs::CachedSubAggs;
|
||||
use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
|
||||
IntermediateHistogramBucketEntry,
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::*;
|
||||
use crate::TantivyError;
|
||||
|
||||
@@ -26,8 +26,13 @@ pub struct HistogramAggReqData {
|
||||
pub accessor: Column<u64>,
|
||||
/// The field type of the fast field.
|
||||
pub field_type: ColumnType,
|
||||
/// The column block accessor to access the fast field values.
|
||||
pub column_block_accessor: ColumnBlockAccessor<u64>,
|
||||
/// The name of the aggregation.
|
||||
pub name: String,
|
||||
/// The sub aggregation blueprint, used to create sub aggregations for each bucket.
|
||||
/// Will be filled during initialization of the collector.
|
||||
pub sub_aggregation_blueprint: Option<Box<dyn SegmentAggregationCollector>>,
|
||||
/// The histogram aggregation request.
|
||||
pub req: HistogramAggregation,
|
||||
/// True if this is a date_histogram aggregation.
|
||||
@@ -252,24 +257,18 @@ impl HistogramBounds {
|
||||
pub(crate) struct SegmentHistogramBucketEntry {
|
||||
pub key: f64,
|
||||
pub doc_count: u64,
|
||||
pub bucket_id: BucketId,
|
||||
}
|
||||
|
||||
impl SegmentHistogramBucketEntry {
|
||||
pub(crate) fn into_intermediate_bucket_entry(
|
||||
self,
|
||||
sub_aggregation: &mut Option<CachedSubAggs>,
|
||||
sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<IntermediateHistogramBucketEntry> {
|
||||
let mut sub_aggregation_res = IntermediateAggregationResults::default();
|
||||
if let Some(sub_aggregation) = sub_aggregation {
|
||||
sub_aggregation
|
||||
.get_sub_agg_collector()
|
||||
.add_intermediate_aggregation_result(
|
||||
agg_data,
|
||||
&mut sub_aggregation_res,
|
||||
self.bucket_id,
|
||||
)?;
|
||||
.add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?;
|
||||
}
|
||||
Ok(IntermediateHistogramBucketEntry {
|
||||
key: self.key,
|
||||
@@ -279,38 +278,27 @@ impl SegmentHistogramBucketEntry {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
struct HistogramBuckets {
|
||||
pub buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
|
||||
}
|
||||
|
||||
/// The collector puts values from the fast field into the correct buckets and does a conversion to
|
||||
/// the correct datatype.
|
||||
#[derive(Debug)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct SegmentHistogramCollector {
|
||||
/// The buckets containing the aggregation data.
|
||||
/// One Histogram bucket per parent bucket id.
|
||||
parent_buckets: Vec<HistogramBuckets>,
|
||||
sub_agg: Option<CachedSubAggs>,
|
||||
buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
|
||||
sub_aggregations: FxHashMap<i64, Box<dyn SegmentAggregationCollector>>,
|
||||
accessor_idx: usize,
|
||||
bucket_id_provider: BucketIdProvider,
|
||||
}
|
||||
|
||||
impl SegmentAggregationCollector for SegmentHistogramCollector {
|
||||
fn add_intermediate_aggregation_result(
|
||||
&mut self,
|
||||
self: Box<Self>,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
parent_bucket_id: BucketId,
|
||||
) -> crate::Result<()> {
|
||||
let name = agg_data
|
||||
.get_histogram_req_data(self.accessor_idx)
|
||||
.name
|
||||
.clone();
|
||||
// TODO: avoid prepare_max_bucket here and handle empty buckets.
|
||||
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
|
||||
let histogram = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);
|
||||
let bucket = self.add_intermediate_bucket_result(agg_data, histogram)?;
|
||||
let bucket = self.into_intermediate_bucket_result(agg_data)?;
|
||||
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
|
||||
|
||||
Ok(())
|
||||
@@ -319,40 +307,44 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
|
||||
#[inline]
|
||||
fn collect(
|
||||
&mut self,
|
||||
parent_bucket_id: BucketId,
|
||||
doc: crate::DocId,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
self.collect_block(&[doc], agg_data)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let req = agg_data.take_histogram_req_data(self.accessor_idx);
|
||||
let mut req = agg_data.take_histogram_req_data(self.accessor_idx);
|
||||
let mem_pre = self.get_memory_consumption();
|
||||
let buckets = &mut self.parent_buckets[parent_bucket_id as usize].buckets;
|
||||
|
||||
let bounds = req.bounds;
|
||||
let interval = req.req.interval;
|
||||
let offset = req.offset;
|
||||
let get_bucket_pos = |val| get_bucket_pos_f64(val, interval, offset) as i64;
|
||||
|
||||
agg_data
|
||||
.column_block_accessor
|
||||
.fetch_block(docs, &req.accessor);
|
||||
for (doc, val) in agg_data
|
||||
req.column_block_accessor.fetch_block(docs, &req.accessor);
|
||||
for (doc, val) in req
|
||||
.column_block_accessor
|
||||
.iter_docid_vals(docs, &req.accessor)
|
||||
{
|
||||
let val = f64_from_fastfield_u64(val, req.field_type);
|
||||
let val = f64_from_fastfield_u64(val, &req.field_type);
|
||||
let bucket_pos = get_bucket_pos(val);
|
||||
if bounds.contains(val) {
|
||||
let bucket = buckets.entry(bucket_pos).or_insert_with(|| {
|
||||
let bucket = self.buckets.entry(bucket_pos).or_insert_with(|| {
|
||||
let key = get_bucket_key_from_pos(bucket_pos as f64, interval, offset);
|
||||
SegmentHistogramBucketEntry {
|
||||
key,
|
||||
doc_count: 0,
|
||||
bucket_id: self.bucket_id_provider.next_bucket_id(),
|
||||
}
|
||||
SegmentHistogramBucketEntry { key, doc_count: 0 }
|
||||
});
|
||||
bucket.doc_count += 1;
|
||||
if let Some(sub_agg) = &mut self.sub_agg {
|
||||
sub_agg.push(bucket.bucket_id, doc);
|
||||
if let Some(sub_aggregation_blueprint) = req.sub_aggregation_blueprint.as_ref() {
|
||||
self.sub_aggregations
|
||||
.entry(bucket_pos)
|
||||
.or_insert_with(|| sub_aggregation_blueprint.clone())
|
||||
.collect(doc, agg_data)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -366,30 +358,14 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
|
||||
.add_memory_consumed(mem_delta as u64)?;
|
||||
}
|
||||
|
||||
if let Some(sub_agg) = &mut self.sub_agg {
|
||||
sub_agg.check_flush_local(agg_data)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
|
||||
if let Some(sub_aggregation) = &mut self.sub_agg {
|
||||
for sub_aggregation in self.sub_aggregations.values_mut() {
|
||||
sub_aggregation.flush(agg_data)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn prepare_max_bucket(
|
||||
&mut self,
|
||||
max_bucket: BucketId,
|
||||
_agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
while self.parent_buckets.len() <= max_bucket as usize {
|
||||
self.parent_buckets.push(HistogramBuckets {
|
||||
buckets: FxHashMap::default(),
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -397,19 +373,22 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
|
||||
impl SegmentHistogramCollector {
|
||||
fn get_memory_consumption(&self) -> usize {
|
||||
let self_mem = std::mem::size_of::<Self>();
|
||||
let buckets_mem = self.parent_buckets.len() * std::mem::size_of::<HistogramBuckets>();
|
||||
self_mem + buckets_mem
|
||||
let sub_aggs_mem = self.sub_aggregations.memory_consumption();
|
||||
let buckets_mem = self.buckets.memory_consumption();
|
||||
self_mem + sub_aggs_mem + buckets_mem
|
||||
}
|
||||
/// Converts the collector result into a intermediate bucket result.
|
||||
fn add_intermediate_bucket_result(
|
||||
&mut self,
|
||||
pub fn into_intermediate_bucket_result(
|
||||
self,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
histogram: HistogramBuckets,
|
||||
) -> crate::Result<IntermediateBucketResult> {
|
||||
let mut buckets = Vec::with_capacity(histogram.buckets.len());
|
||||
let mut buckets = Vec::with_capacity(self.buckets.len());
|
||||
|
||||
for bucket in histogram.buckets.into_values() {
|
||||
let bucket_res = bucket.into_intermediate_bucket_entry(&mut self.sub_agg, agg_data);
|
||||
for (bucket_pos, bucket) in self.buckets {
|
||||
let bucket_res = bucket.into_intermediate_bucket_entry(
|
||||
self.sub_aggregations.get(&bucket_pos).cloned(),
|
||||
agg_data,
|
||||
);
|
||||
|
||||
buckets.push(bucket_res?);
|
||||
}
|
||||
@@ -429,7 +408,7 @@ impl SegmentHistogramCollector {
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
node: &AggRefNode,
|
||||
) -> crate::Result<Self> {
|
||||
let sub_agg = if !node.children.is_empty() {
|
||||
let blueprint = if !node.children.is_empty() {
|
||||
Some(build_segment_agg_collectors(agg_data, &node.children)?)
|
||||
} else {
|
||||
None
|
||||
@@ -444,13 +423,13 @@ impl SegmentHistogramCollector {
|
||||
max: f64::MAX,
|
||||
});
|
||||
req_data.offset = req_data.req.offset.unwrap_or(0.0);
|
||||
let sub_agg = sub_agg.map(CachedSubAggs::new);
|
||||
|
||||
req_data.sub_aggregation_blueprint = blueprint;
|
||||
|
||||
Ok(Self {
|
||||
parent_buckets: Default::default(),
|
||||
sub_agg,
|
||||
buckets: Default::default(),
|
||||
sub_aggregations: Default::default(),
|
||||
accessor_idx: node.idx_in_req_data,
|
||||
bucket_id_provider: BucketIdProvider::default(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,20 +1,18 @@
|
||||
use std::fmt::Debug;
|
||||
use std::ops::Range;
|
||||
|
||||
use columnar::{Column, ColumnType};
|
||||
use columnar::{Column, ColumnBlockAccessor, ColumnType};
|
||||
use rustc_hash::FxHashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::aggregation::agg_data::{
|
||||
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
|
||||
};
|
||||
use crate::aggregation::agg_limits::AggregationLimitsGuard;
|
||||
use crate::aggregation::cached_sub_aggs::CachedSubAggs;
|
||||
use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
|
||||
IntermediateRangeBucketEntry, IntermediateRangeBucketResult,
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::*;
|
||||
use crate::TantivyError;
|
||||
|
||||
@@ -25,12 +23,12 @@ pub struct RangeAggReqData {
|
||||
pub accessor: Column<u64>,
|
||||
/// The type of the fast field.
|
||||
pub field_type: ColumnType,
|
||||
/// The column block accessor to access the fast field values.
|
||||
pub column_block_accessor: ColumnBlockAccessor<u64>,
|
||||
/// The range aggregation request.
|
||||
pub req: RangeAggregation,
|
||||
/// The name of the aggregation.
|
||||
pub name: String,
|
||||
/// Whether this is a top-level aggregation.
|
||||
pub is_top_level: bool,
|
||||
}
|
||||
|
||||
impl RangeAggReqData {
|
||||
@@ -153,47 +151,19 @@ pub(crate) struct SegmentRangeAndBucketEntry {
|
||||
|
||||
/// The collector puts values from the fast field into the correct buckets and does a conversion to
|
||||
/// the correct datatype.
|
||||
pub struct SegmentRangeCollector<const LOWCARD: bool = false> {
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct SegmentRangeCollector {
|
||||
/// The buckets containing the aggregation data.
|
||||
/// One for each ParentBucketId
|
||||
parent_buckets: Vec<Vec<SegmentRangeAndBucketEntry>>,
|
||||
buckets: Vec<SegmentRangeAndBucketEntry>,
|
||||
column_type: ColumnType,
|
||||
pub(crate) accessor_idx: usize,
|
||||
sub_agg: Option<CachedSubAggs<LOWCARD>>,
|
||||
/// Here things get a bit weird. We need to assign unique bucket ids across all
|
||||
/// parent buckets. So we keep track of the next available bucket id here.
|
||||
/// This allows a kind of flattening of the bucket ids across all parent buckets.
|
||||
/// E.g. in nested aggregations:
|
||||
/// Term Agg -> Range aggregation -> Stats aggregation
|
||||
/// E.g. the Term Agg creates 3 buckets ["INFO", "ERROR", "WARN"], each of these has a Range
|
||||
/// aggregation with 4 buckets. The Range aggregation will create buckets with ids:
|
||||
/// - INFO: 0,1,2,3
|
||||
/// - ERROR: 4,5,6,7
|
||||
/// - WARN: 8,9,10,11
|
||||
///
|
||||
/// This allows the Stats aggregation to have unique bucket ids to refer to.
|
||||
bucket_id_provider: BucketIdProvider,
|
||||
limits: AggregationLimitsGuard,
|
||||
}
|
||||
|
||||
impl<const LOWCARD: bool> Debug for SegmentRangeCollector<LOWCARD> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("SegmentRangeCollector")
|
||||
.field("parent_buckets_len", &self.parent_buckets.len())
|
||||
.field("column_type", &self.column_type)
|
||||
.field("accessor_idx", &self.accessor_idx)
|
||||
.field("has_sub_agg", &self.sub_agg.is_some())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// TODO: Bad naming, there's also SegmentRangeAndBucketEntry
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct SegmentRangeBucketEntry {
|
||||
pub key: Key,
|
||||
pub doc_count: u64,
|
||||
// pub sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
|
||||
pub bucket_id: BucketId,
|
||||
pub sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
|
||||
/// The from range of the bucket. Equals `f64::MIN` when `None`.
|
||||
pub from: Option<f64>,
|
||||
/// The to range of the bucket. Equals `f64::MAX` when `None`. Open interval, `to` is not
|
||||
@@ -214,50 +184,48 @@ impl Debug for SegmentRangeBucketEntry {
|
||||
impl SegmentRangeBucketEntry {
|
||||
pub(crate) fn into_intermediate_bucket_entry(
|
||||
self,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<IntermediateRangeBucketEntry> {
|
||||
let sub_aggregation = IntermediateAggregationResults::default();
|
||||
let mut sub_aggregation_res = IntermediateAggregationResults::default();
|
||||
if let Some(sub_aggregation) = self.sub_aggregation {
|
||||
sub_aggregation
|
||||
.add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?
|
||||
} else {
|
||||
Default::default()
|
||||
};
|
||||
|
||||
Ok(IntermediateRangeBucketEntry {
|
||||
key: self.key.into(),
|
||||
doc_count: self.doc_count,
|
||||
sub_aggregation_res: sub_aggregation,
|
||||
sub_aggregation: sub_aggregation_res,
|
||||
from: self.from,
|
||||
to: self.to,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LOWCARD: bool> SegmentAggregationCollector for SegmentRangeCollector<LOWCARD> {
|
||||
impl SegmentAggregationCollector for SegmentRangeCollector {
|
||||
fn add_intermediate_aggregation_result(
|
||||
&mut self,
|
||||
self: Box<Self>,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
parent_bucket_id: BucketId,
|
||||
) -> crate::Result<()> {
|
||||
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
|
||||
let field_type = self.column_type;
|
||||
let name = agg_data
|
||||
.get_range_req_data(self.accessor_idx)
|
||||
.name
|
||||
.to_string();
|
||||
|
||||
let buckets = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);
|
||||
|
||||
let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = buckets
|
||||
let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = self
|
||||
.buckets
|
||||
.into_iter()
|
||||
.map(|range_bucket| {
|
||||
let bucket_id = range_bucket.bucket.bucket_id;
|
||||
let mut agg = range_bucket.bucket.into_intermediate_bucket_entry()?;
|
||||
if let Some(sub_aggregation) = &mut self.sub_agg {
|
||||
sub_aggregation
|
||||
.get_sub_agg_collector()
|
||||
.add_intermediate_aggregation_result(
|
||||
agg_data,
|
||||
&mut agg.sub_aggregation_res,
|
||||
bucket_id,
|
||||
)?;
|
||||
}
|
||||
Ok((range_to_string(&range_bucket.range, &field_type)?, agg))
|
||||
.map(move |range_bucket| {
|
||||
Ok((
|
||||
range_to_string(&range_bucket.range, &field_type)?,
|
||||
range_bucket
|
||||
.bucket
|
||||
.into_intermediate_bucket_entry(agg_data)?,
|
||||
))
|
||||
})
|
||||
.collect::<crate::Result<_>>()?;
|
||||
|
||||
@@ -274,114 +242,73 @@ impl<const LOWCARD: bool> SegmentAggregationCollector for SegmentRangeCollector<
|
||||
#[inline]
|
||||
fn collect(
|
||||
&mut self,
|
||||
parent_bucket_id: BucketId,
|
||||
doc: crate::DocId,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
self.collect_block(&[doc], agg_data)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let req = agg_data.take_range_req_data(self.accessor_idx);
|
||||
// Take request data to avoid borrow conflicts during sub-aggregation
|
||||
let mut req = agg_data.take_range_req_data(self.accessor_idx);
|
||||
|
||||
agg_data
|
||||
.column_block_accessor
|
||||
.fetch_block(docs, &req.accessor);
|
||||
req.column_block_accessor.fetch_block(docs, &req.accessor);
|
||||
|
||||
let buckets = &mut self.parent_buckets[parent_bucket_id as usize];
|
||||
|
||||
for (doc, val) in agg_data
|
||||
for (doc, val) in req
|
||||
.column_block_accessor
|
||||
.iter_docid_vals(docs, &req.accessor)
|
||||
{
|
||||
let bucket_pos = get_bucket_pos(val, buckets);
|
||||
let bucket = &mut buckets[bucket_pos];
|
||||
let bucket_pos = self.get_bucket_pos(val);
|
||||
let bucket = &mut self.buckets[bucket_pos];
|
||||
bucket.bucket.doc_count += 1;
|
||||
if let Some(sub_agg) = self.sub_agg.as_mut() {
|
||||
sub_agg.push(bucket.bucket.bucket_id, doc);
|
||||
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() {
|
||||
sub_agg.collect(doc, agg_data)?;
|
||||
}
|
||||
}
|
||||
|
||||
agg_data.put_back_range_req_data(self.accessor_idx, req);
|
||||
if let Some(sub_agg) = self.sub_agg.as_mut() {
|
||||
sub_agg.check_flush_local(agg_data)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
|
||||
if let Some(sub_agg) = self.sub_agg.as_mut() {
|
||||
sub_agg.flush(agg_data)?;
|
||||
for bucket in self.buckets.iter_mut() {
|
||||
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() {
|
||||
sub_agg.flush(agg_data)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn prepare_max_bucket(
|
||||
&mut self,
|
||||
max_bucket: BucketId,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
while self.parent_buckets.len() <= max_bucket as usize {
|
||||
let new_buckets = self.create_new_buckets(agg_data)?;
|
||||
self.parent_buckets.push(new_buckets);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
/// Build a concrete `SegmentRangeCollector` with either a Vec- or HashMap-backed
|
||||
/// bucket storage, depending on the column type and aggregation level.
|
||||
pub(crate) fn build_segment_range_collector(
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
node: &AggRefNode,
|
||||
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
|
||||
let accessor_idx = node.idx_in_req_data;
|
||||
let req_data = agg_data.get_range_req_data(node.idx_in_req_data);
|
||||
let field_type = req_data.field_type;
|
||||
|
||||
// TODO: A better metric instead of is_top_level would be the number of buckets expected.
|
||||
// E.g. If range agg is not top level, but the parent is a bucket agg with less than 10 buckets,
|
||||
// we can are still in low cardinality territory.
|
||||
let is_low_card = req_data.is_top_level && req_data.req.ranges.len() <= 64;
|
||||
|
||||
let sub_agg = if !node.children.is_empty() {
|
||||
Some(build_segment_agg_collectors(agg_data, &node.children)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if is_low_card {
|
||||
Ok(Box::new(SegmentRangeCollector {
|
||||
sub_agg: sub_agg.map(CachedSubAggs::<true>::new),
|
||||
column_type: field_type,
|
||||
accessor_idx,
|
||||
parent_buckets: Vec::new(),
|
||||
bucket_id_provider: BucketIdProvider::default(),
|
||||
limits: agg_data.context.limits.clone(),
|
||||
}))
|
||||
} else {
|
||||
Ok(Box::new(SegmentRangeCollector {
|
||||
sub_agg: sub_agg.map(CachedSubAggs::<false>::new),
|
||||
column_type: field_type,
|
||||
accessor_idx,
|
||||
parent_buckets: Vec::new(),
|
||||
bucket_id_provider: BucketIdProvider::default(),
|
||||
limits: agg_data.context.limits.clone(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LOWCARD: bool> SegmentRangeCollector<LOWCARD> {
|
||||
pub(crate) fn create_new_buckets(
|
||||
&mut self,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<Vec<SegmentRangeAndBucketEntry>> {
|
||||
let field_type = self.column_type;
|
||||
let req_data = agg_data.get_range_req_data(self.accessor_idx);
|
||||
impl SegmentRangeCollector {
|
||||
pub(crate) fn from_req_and_validate(
|
||||
req_data: &mut AggregationsSegmentCtx,
|
||||
node: &AggRefNode,
|
||||
) -> crate::Result<Self> {
|
||||
let accessor_idx = node.idx_in_req_data;
|
||||
let (field_type, ranges) = {
|
||||
let req_view = req_data.get_range_req_data(node.idx_in_req_data);
|
||||
(req_view.field_type, req_view.req.ranges.clone())
|
||||
};
|
||||
|
||||
// The range input on the request is f64.
|
||||
// We need to convert to u64 ranges, because we read the values as u64.
|
||||
// The mapping from the conversion is monotonic so ordering is preserved.
|
||||
let buckets: Vec<_> = extend_validate_ranges(&req_data.req.ranges, &field_type)?
|
||||
let sub_agg_prototype = if !node.children.is_empty() {
|
||||
Some(build_segment_agg_collectors(req_data, &node.children)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let buckets: Vec<_> = extend_validate_ranges(&ranges, &field_type)?
|
||||
.iter()
|
||||
.map(|range| {
|
||||
let bucket_id = self.bucket_id_provider.next_bucket_id();
|
||||
let key = range
|
||||
.key
|
||||
.clone()
|
||||
@@ -390,20 +317,20 @@ impl<const LOWCARD: bool> SegmentRangeCollector<LOWCARD> {
|
||||
let to = if range.range.end == u64::MAX {
|
||||
None
|
||||
} else {
|
||||
Some(f64_from_fastfield_u64(range.range.end, field_type))
|
||||
Some(f64_from_fastfield_u64(range.range.end, &field_type))
|
||||
};
|
||||
let from = if range.range.start == u64::MIN {
|
||||
None
|
||||
} else {
|
||||
Some(f64_from_fastfield_u64(range.range.start, field_type))
|
||||
Some(f64_from_fastfield_u64(range.range.start, &field_type))
|
||||
};
|
||||
// let sub_aggregation = sub_agg_prototype.clone();
|
||||
let sub_aggregation = sub_agg_prototype.clone();
|
||||
|
||||
Ok(SegmentRangeAndBucketEntry {
|
||||
range: range.range.clone(),
|
||||
bucket: SegmentRangeBucketEntry {
|
||||
doc_count: 0,
|
||||
bucket_id,
|
||||
sub_aggregation,
|
||||
key,
|
||||
from,
|
||||
to,
|
||||
@@ -412,19 +339,26 @@ impl<const LOWCARD: bool> SegmentRangeCollector<LOWCARD> {
|
||||
})
|
||||
.collect::<crate::Result<_>>()?;
|
||||
|
||||
self.limits.add_memory_consumed(
|
||||
req_data.context.limits.add_memory_consumed(
|
||||
buckets.len() as u64 * std::mem::size_of::<SegmentRangeAndBucketEntry>() as u64,
|
||||
)?;
|
||||
Ok(buckets)
|
||||
|
||||
Ok(SegmentRangeCollector {
|
||||
buckets,
|
||||
column_type: field_type,
|
||||
accessor_idx,
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn get_bucket_pos(&self, val: u64) -> usize {
|
||||
let pos = self
|
||||
.buckets
|
||||
.binary_search_by_key(&val, |probe| probe.range.start)
|
||||
.unwrap_or_else(|pos| pos - 1);
|
||||
debug_assert!(self.buckets[pos].range.contains(&val));
|
||||
pos
|
||||
}
|
||||
}
|
||||
#[inline]
|
||||
fn get_bucket_pos(val: u64, buckets: &[SegmentRangeAndBucketEntry]) -> usize {
|
||||
let pos = buckets
|
||||
.binary_search_by_key(&val, |probe| probe.range.start)
|
||||
.unwrap_or_else(|pos| pos - 1);
|
||||
debug_assert!(buckets[pos].range.contains(&val));
|
||||
pos
|
||||
}
|
||||
|
||||
/// Converts the user provided f64 range value to fast field value space.
|
||||
@@ -522,7 +456,7 @@ pub(crate) fn range_to_string(
|
||||
let val = i64::from_u64(val);
|
||||
format_date(val)
|
||||
} else {
|
||||
Ok(f64_from_fastfield_u64(val, *field_type).to_string())
|
||||
Ok(f64_from_fastfield_u64(val, field_type).to_string())
|
||||
}
|
||||
};
|
||||
|
||||
@@ -572,33 +506,30 @@ mod tests {
|
||||
let to = if range.range.end == u64::MAX {
|
||||
None
|
||||
} else {
|
||||
Some(f64_from_fastfield_u64(range.range.end, field_type))
|
||||
Some(f64_from_fastfield_u64(range.range.end, &field_type))
|
||||
};
|
||||
let from = if range.range.start == u64::MIN {
|
||||
None
|
||||
} else {
|
||||
Some(f64_from_fastfield_u64(range.range.start, field_type))
|
||||
Some(f64_from_fastfield_u64(range.range.start, &field_type))
|
||||
};
|
||||
SegmentRangeAndBucketEntry {
|
||||
range: range.range.clone(),
|
||||
bucket: SegmentRangeBucketEntry {
|
||||
doc_count: 0,
|
||||
sub_aggregation: None,
|
||||
key,
|
||||
from,
|
||||
to,
|
||||
bucket_id: 0,
|
||||
},
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
SegmentRangeCollector {
|
||||
parent_buckets: vec![buckets],
|
||||
buckets,
|
||||
column_type: field_type,
|
||||
accessor_idx: 0,
|
||||
sub_agg: None,
|
||||
bucket_id_provider: Default::default(),
|
||||
limits: AggregationLimitsGuard::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -845,7 +776,7 @@ mod tests {
|
||||
let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()];
|
||||
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
|
||||
|
||||
let buckets = collector.parent_buckets[0].clone();
|
||||
let buckets = collector.buckets;
|
||||
assert_eq!(buckets[0].range.start, u64::MIN);
|
||||
assert_eq!(buckets[0].range.end, 10f64.to_u64());
|
||||
assert_eq!(buckets[1].range.start, 10f64.to_u64());
|
||||
@@ -868,7 +799,7 @@ mod tests {
|
||||
];
|
||||
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
|
||||
|
||||
let buckets = collector.parent_buckets[0].clone();
|
||||
let buckets = collector.buckets;
|
||||
assert_eq!(buckets[0].range.start, u64::MIN);
|
||||
assert_eq!(buckets[0].range.end, 10f64.to_u64());
|
||||
assert_eq!(buckets[1].range.start, 10f64.to_u64());
|
||||
@@ -883,7 +814,7 @@ mod tests {
|
||||
let buckets = vec![(-10f64..-1f64).into()];
|
||||
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
|
||||
|
||||
let buckets = collector.parent_buckets[0].clone();
|
||||
let buckets = collector.buckets;
|
||||
assert_eq!(&buckets[0].bucket.key.to_string(), "*--10");
|
||||
assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "-1-*");
|
||||
}
|
||||
@@ -892,7 +823,7 @@ mod tests {
|
||||
let buckets = vec![(0f64..10f64).into()];
|
||||
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
|
||||
|
||||
let buckets = collector.parent_buckets[0].clone();
|
||||
let buckets = collector.buckets;
|
||||
assert_eq!(&buckets[0].bucket.key.to_string(), "*-0");
|
||||
assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "10-*");
|
||||
}
|
||||
@@ -901,7 +832,7 @@ mod tests {
|
||||
fn range_binary_search_test_u64() {
|
||||
let check_ranges = |ranges: Vec<RangeAggregationRange>| {
|
||||
let collector = get_collector_from_ranges(ranges, ColumnType::U64);
|
||||
let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);
|
||||
let search = |val: u64| collector.get_bucket_pos(val);
|
||||
|
||||
assert_eq!(search(u64::MIN), 0);
|
||||
assert_eq!(search(9), 0);
|
||||
@@ -947,7 +878,7 @@ mod tests {
|
||||
let ranges = vec![(10.0..100.0).into()];
|
||||
|
||||
let collector = get_collector_from_ranges(ranges, ColumnType::F64);
|
||||
let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);
|
||||
let search = |val: u64| collector.get_bucket_pos(val);
|
||||
|
||||
assert_eq!(search(u64::MIN), 0);
|
||||
assert_eq!(search(9f64.to_u64()), 0);
|
||||
@@ -959,3 +890,63 @@ mod tests {
|
||||
// the max value
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "unstable"))]
|
||||
mod bench {
|
||||
|
||||
use itertools::Itertools;
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::thread_rng;
|
||||
|
||||
use super::*;
|
||||
use crate::aggregation::bucket::range::tests::get_collector_from_ranges;
|
||||
|
||||
const TOTAL_DOCS: u64 = 1_000_000u64;
|
||||
const NUM_DOCS: u64 = 50_000u64;
|
||||
|
||||
fn get_collector_with_buckets(num_buckets: u64, num_docs: u64) -> SegmentRangeCollector {
|
||||
let bucket_size = num_docs / num_buckets;
|
||||
let mut buckets: Vec<RangeAggregationRange> = vec![];
|
||||
for i in 0..num_buckets {
|
||||
let bucket_start = (i * bucket_size) as f64;
|
||||
buckets.push((bucket_start..bucket_start + bucket_size as f64).into())
|
||||
}
|
||||
|
||||
get_collector_from_ranges(buckets, ColumnType::U64)
|
||||
}
|
||||
|
||||
fn get_rand_docs(total_docs: u64, num_docs_returned: u64) -> Vec<u64> {
|
||||
let mut rng = thread_rng();
|
||||
|
||||
let all_docs = (0..total_docs - 1).collect_vec();
|
||||
let mut vals = all_docs
|
||||
.as_slice()
|
||||
.choose_multiple(&mut rng, num_docs_returned as usize)
|
||||
.cloned()
|
||||
.collect_vec();
|
||||
vals.sort();
|
||||
vals
|
||||
}
|
||||
|
||||
fn bench_range_binary_search(b: &mut test::Bencher, num_buckets: u64) {
|
||||
let collector = get_collector_with_buckets(num_buckets, TOTAL_DOCS);
|
||||
let vals = get_rand_docs(TOTAL_DOCS, NUM_DOCS);
|
||||
b.iter(|| {
|
||||
let mut bucket_pos = 0;
|
||||
for val in &vals {
|
||||
bucket_pos = collector.get_bucket_pos(*val);
|
||||
}
|
||||
bucket_pos
|
||||
})
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_range_100_buckets(b: &mut test::Bencher) {
|
||||
bench_range_binary_search(b, 100)
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_range_10_buckets(b: &mut test::Bencher) {
|
||||
bench_range_binary_search(b, 10)
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,13 +5,11 @@ use crate::aggregation::agg_data::{
|
||||
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
|
||||
};
|
||||
use crate::aggregation::bucket::term_agg::TermsAggregation;
|
||||
use crate::aggregation::cached_sub_aggs::CachedSubAggs;
|
||||
use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
|
||||
IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult,
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
|
||||
use crate::aggregation::BucketId;
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
|
||||
/// Special aggregation to handle missing values for term aggregations.
|
||||
/// This missing aggregation will check multiple columns for existence.
|
||||
@@ -37,55 +35,41 @@ impl MissingTermAggReqData {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
struct MissingCount {
|
||||
missing_count: u32,
|
||||
bucket_id: BucketId,
|
||||
}
|
||||
|
||||
/// The specialized missing term aggregation.
|
||||
#[derive(Default, Debug)]
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct TermMissingAgg {
|
||||
missing_count: u32,
|
||||
accessor_idx: usize,
|
||||
sub_agg: Option<CachedSubAggs>,
|
||||
/// Idx = parent bucket id, Value = missing count for that bucket
|
||||
missing_count_per_bucket: Vec<MissingCount>,
|
||||
bucket_id_provider: BucketIdProvider,
|
||||
sub_agg: Option<Box<dyn SegmentAggregationCollector>>,
|
||||
}
|
||||
impl TermMissingAgg {
|
||||
pub(crate) fn new(
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
req_data: &mut AggregationsSegmentCtx,
|
||||
node: &AggRefNode,
|
||||
) -> crate::Result<Self> {
|
||||
let has_sub_aggregations = !node.children.is_empty();
|
||||
let accessor_idx = node.idx_in_req_data;
|
||||
let sub_agg = if has_sub_aggregations {
|
||||
let sub_aggregation = build_segment_agg_collectors(agg_data, &node.children)?;
|
||||
let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?;
|
||||
Some(sub_aggregation)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let sub_agg = sub_agg.map(CachedSubAggs::new);
|
||||
let bucket_id_provider = BucketIdProvider::default();
|
||||
|
||||
Ok(Self {
|
||||
accessor_idx,
|
||||
sub_agg,
|
||||
missing_count_per_bucket: Vec::new(),
|
||||
bucket_id_provider,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl SegmentAggregationCollector for TermMissingAgg {
|
||||
fn add_intermediate_aggregation_result(
|
||||
&mut self,
|
||||
self: Box<Self>,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
parent_bucket_id: BucketId,
|
||||
) -> crate::Result<()> {
|
||||
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
|
||||
let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
|
||||
let term_agg = &req_data.req;
|
||||
let missing = term_agg
|
||||
@@ -96,16 +80,13 @@ impl SegmentAggregationCollector for TermMissingAgg {
|
||||
let mut entries: FxHashMap<IntermediateKey, IntermediateTermBucketEntry> =
|
||||
Default::default();
|
||||
|
||||
let missing_count = &self.missing_count_per_bucket[parent_bucket_id as usize];
|
||||
let mut missing_entry = IntermediateTermBucketEntry {
|
||||
doc_count: missing_count.missing_count,
|
||||
doc_count: self.missing_count,
|
||||
sub_aggregation: Default::default(),
|
||||
};
|
||||
if let Some(sub_agg) = &mut self.sub_agg {
|
||||
if let Some(sub_agg) = self.sub_agg {
|
||||
let mut res = IntermediateAggregationResults::default();
|
||||
sub_agg
|
||||
.get_sub_agg_collector()
|
||||
.add_intermediate_aggregation_result(agg_data, &mut res, missing_count.bucket_id)?;
|
||||
sub_agg.add_intermediate_aggregation_result(agg_data, &mut res)?;
|
||||
missing_entry.sub_aggregation = res;
|
||||
}
|
||||
entries.insert(missing.into(), missing_entry);
|
||||
@@ -128,52 +109,30 @@ impl SegmentAggregationCollector for TermMissingAgg {
|
||||
|
||||
fn collect(
|
||||
&mut self,
|
||||
parent_bucket_id: BucketId,
|
||||
doc: crate::DocId,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
|
||||
let has_value = req_data
|
||||
.accessors
|
||||
.iter()
|
||||
.any(|(acc, _)| acc.index.has_value(doc));
|
||||
if !has_value {
|
||||
self.missing_count += 1;
|
||||
if let Some(sub_agg) = self.sub_agg.as_mut() {
|
||||
sub_agg.collect(doc, agg_data)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let bucket = &mut self.missing_count_per_bucket[parent_bucket_id as usize];
|
||||
let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
|
||||
|
||||
for doc in docs {
|
||||
let doc = *doc;
|
||||
let has_value = req_data
|
||||
.accessors
|
||||
.iter()
|
||||
.any(|(acc, _)| acc.index.has_value(doc));
|
||||
if !has_value {
|
||||
bucket.missing_count += 1;
|
||||
|
||||
if let Some(sub_agg) = self.sub_agg.as_mut() {
|
||||
sub_agg.push(bucket.bucket_id, doc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(sub_agg) = self.sub_agg.as_mut() {
|
||||
sub_agg.check_flush_local(agg_data)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn prepare_max_bucket(
|
||||
&mut self,
|
||||
max_bucket: BucketId,
|
||||
_agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
while self.missing_count_per_bucket.len() <= max_bucket as usize {
|
||||
let bucket_id = self.bucket_id_provider.next_bucket_id();
|
||||
self.missing_count_per_bucket.push(MissingCount {
|
||||
missing_count: 0,
|
||||
bucket_id,
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
|
||||
if let Some(sub_agg) = self.sub_agg.as_mut() {
|
||||
sub_agg.flush(agg_data)?;
|
||||
self.collect(*doc, agg_data)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
87
src/aggregation/buf_collector.rs
Normal file
87
src/aggregation/buf_collector.rs
Normal file
@@ -0,0 +1,87 @@
|
||||
use super::intermediate_agg_result::IntermediateAggregationResults;
|
||||
use super::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::agg_data::AggregationsSegmentCtx;
|
||||
use crate::DocId;
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) const DOC_BLOCK_SIZE: usize = 64;
|
||||
|
||||
#[cfg(not(test))]
|
||||
pub(crate) const DOC_BLOCK_SIZE: usize = 256;
|
||||
|
||||
pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE];
|
||||
|
||||
/// BufAggregationCollector buffers documents before calling collect_block().
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct BufAggregationCollector {
|
||||
pub(crate) collector: Box<dyn SegmentAggregationCollector>,
|
||||
staged_docs: DocBlock,
|
||||
num_staged_docs: usize,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for BufAggregationCollector {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
f.debug_struct("SegmentAggregationResultsCollector")
|
||||
.field("staged_docs", &&self.staged_docs[..self.num_staged_docs])
|
||||
.field("num_staged_docs", &self.num_staged_docs)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl BufAggregationCollector {
|
||||
pub fn new(collector: Box<dyn SegmentAggregationCollector>) -> Self {
|
||||
Self {
|
||||
collector,
|
||||
num_staged_docs: 0,
|
||||
staged_docs: [0; DOC_BLOCK_SIZE],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SegmentAggregationCollector for BufAggregationCollector {
|
||||
#[inline]
|
||||
fn add_intermediate_aggregation_result(
|
||||
self: Box<Self>,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
) -> crate::Result<()> {
|
||||
Box::new(self.collector).add_intermediate_aggregation_result(agg_data, results)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn collect(
|
||||
&mut self,
|
||||
doc: crate::DocId,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
self.staged_docs[self.num_staged_docs] = doc;
|
||||
self.num_staged_docs += 1;
|
||||
if self.num_staged_docs == self.staged_docs.len() {
|
||||
self.collector
|
||||
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?;
|
||||
self.num_staged_docs = 0;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
self.collector.collect_block(docs, agg_data)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
|
||||
self.collector
|
||||
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?;
|
||||
self.num_staged_docs = 0;
|
||||
|
||||
self.collector.flush(agg_data)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,185 +0,0 @@
|
||||
use super::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::agg_data::AggregationsSegmentCtx;
|
||||
use crate::aggregation::bucket::MAX_NUM_TERMS_FOR_VEC;
|
||||
use crate::aggregation::BucketId;
|
||||
use crate::DocId;
|
||||
|
||||
#[derive(Debug)]
|
||||
/// A cache for sub-aggregations, storing doc ids per bucket id.
|
||||
/// Depending on the cardinality of the parent aggregation, we use different
|
||||
/// storage strategies.
|
||||
///
|
||||
/// ## Low Cardinality
|
||||
/// Cardinality here refers to the number of unique flattened buckets that can be created
|
||||
/// by the parent aggregation.
|
||||
/// Flattened buckets are the result of combining all buckets per collector
|
||||
/// into a single list of buckets, where each bucket is identified by its BucketId.
|
||||
///
|
||||
/// ## Usage
|
||||
/// Since this is caching for sub-aggregations, it is only used by bucket
|
||||
/// aggregations.
|
||||
///
|
||||
/// TODO: consider using a more advanced data structure for high cardinality
|
||||
/// aggregations.
|
||||
/// What this datastructure does in general is to group docs by bucket id.
|
||||
pub(crate) struct CachedSubAggs<const LOWCARD: bool = false> {
|
||||
/// Only used when LOWCARD is true.
|
||||
/// Cache doc ids per bucket for sub-aggregations.
|
||||
///
|
||||
/// The outer Vec is indexed by BucketId.
|
||||
per_bucket_docs: Vec<Vec<DocId>>,
|
||||
/// Only used when LOWCARD is false.
|
||||
///
|
||||
/// This weird partitioning is used to do some cheap grouping on the bucket ids.
|
||||
/// bucket ids are dense, e.g. when we don't detect the cardinality as low cardinality,
|
||||
/// but there are just 16 bucket ids, each bucket id will go to its own partition.
|
||||
///
|
||||
/// We want to keep this cheap, because high cardinality aggregations can have a lot of
|
||||
/// buckets, and they may be nothing to group.
|
||||
partitions: [PartitionEntry; NUM_PARTITIONS],
|
||||
pub(crate) sub_agg_collector: Box<dyn SegmentAggregationCollector>,
|
||||
num_docs: usize,
|
||||
}
|
||||
|
||||
const FLUSH_THRESHOLD: usize = 2048;
|
||||
const NUM_PARTITIONS: usize = 16;
|
||||
|
||||
impl<const LOWCARD: bool> CachedSubAggs<LOWCARD> {
|
||||
pub fn get_sub_agg_collector(&mut self) -> &mut Box<dyn SegmentAggregationCollector> {
|
||||
&mut self.sub_agg_collector
|
||||
}
|
||||
|
||||
pub fn new(sub_agg: Box<dyn SegmentAggregationCollector>) -> Self {
|
||||
Self {
|
||||
per_bucket_docs: Vec::new(),
|
||||
num_docs: 0,
|
||||
sub_agg_collector: sub_agg,
|
||||
partitions: core::array::from_fn(|_| PartitionEntry::default()),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn clear(&mut self) {
|
||||
for v in &mut self.per_bucket_docs {
|
||||
v.clear();
|
||||
}
|
||||
for partition in &mut self.partitions {
|
||||
partition.clear();
|
||||
}
|
||||
self.num_docs = 0;
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
|
||||
if LOWCARD {
|
||||
// TODO: We could flush single buckets here
|
||||
let idx = bucket_id as usize;
|
||||
if self.per_bucket_docs.len() <= idx {
|
||||
self.per_bucket_docs.resize_with(idx + 1, Vec::new);
|
||||
}
|
||||
self.per_bucket_docs[idx].push(doc_id);
|
||||
} else {
|
||||
let idx = bucket_id % NUM_PARTITIONS as u32;
|
||||
let slot = &mut self.partitions[idx as usize];
|
||||
slot.bucket_ids.push(bucket_id);
|
||||
slot.docs.push(doc_id);
|
||||
}
|
||||
self.num_docs += 1;
|
||||
}
|
||||
|
||||
/// Check if we need to flush based on the number of documents cached.
|
||||
/// If so, flushes the cache to the provided aggregation collector.
|
||||
pub fn check_flush_local(
|
||||
&mut self,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
if self.num_docs >= FLUSH_THRESHOLD {
|
||||
self.flush_local(agg_data, false)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Note: this does _not_ flush the sub aggregations
|
||||
fn flush_local(
|
||||
&mut self,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
force: bool,
|
||||
) -> crate::Result<()> {
|
||||
if LOWCARD {
|
||||
// Pre-aggregated: call collect per bucket.
|
||||
let max_bucket = (self.per_bucket_docs.len() as BucketId).saturating_sub(1);
|
||||
self.sub_agg_collector
|
||||
.prepare_max_bucket(max_bucket, agg_data)?;
|
||||
// The threshold above which we flush buckets individually.
|
||||
// Note: We need to make sure that we don't lock ourselves into a situation where we hit
|
||||
// the FLUSH_THRESHOLD, but never flush any buckets. (except the final flush)
|
||||
let mut bucket_treshold = FLUSH_THRESHOLD / (self.per_bucket_docs.len().max(1) * 2);
|
||||
const _: () = {
|
||||
// MAX_NUM_TERMS_FOR_VEC == LOWCARD threshold
|
||||
let bucket_treshold = FLUSH_THRESHOLD / (MAX_NUM_TERMS_FOR_VEC as usize * 2);
|
||||
assert!(
|
||||
bucket_treshold > 0,
|
||||
"Bucket threshold must be greater than 0"
|
||||
);
|
||||
};
|
||||
if force {
|
||||
bucket_treshold = 0;
|
||||
}
|
||||
for (bucket_id, docs) in self
|
||||
.per_bucket_docs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, docs)| docs.len() > bucket_treshold)
|
||||
{
|
||||
self.sub_agg_collector
|
||||
.collect(bucket_id as BucketId, docs, agg_data)?;
|
||||
}
|
||||
} else {
|
||||
let mut max_bucket = 0u32;
|
||||
for partition in &self.partitions {
|
||||
if let Some(&local_max) = partition.bucket_ids.iter().max() {
|
||||
max_bucket = max_bucket.max(local_max);
|
||||
}
|
||||
}
|
||||
|
||||
self.sub_agg_collector
|
||||
.prepare_max_bucket(max_bucket, agg_data)?;
|
||||
|
||||
for slot in &self.partitions {
|
||||
if !slot.bucket_ids.is_empty() {
|
||||
// Reduce dynamic dispatch overhead by collecting a full partition in one call.
|
||||
self.sub_agg_collector.collect_multiple(
|
||||
&slot.bucket_ids,
|
||||
&slot.docs,
|
||||
agg_data,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
self.clear();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Note: this _does_ flush the sub aggregations
|
||||
pub fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
|
||||
if self.num_docs != 0 {
|
||||
self.flush_local(agg_data, true)?;
|
||||
}
|
||||
self.sub_agg_collector.flush(agg_data)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
struct PartitionEntry {
|
||||
bucket_ids: Vec<BucketId>,
|
||||
docs: Vec<DocId>,
|
||||
}
|
||||
|
||||
impl PartitionEntry {
|
||||
#[inline]
|
||||
fn clear(&mut self) {
|
||||
self.bucket_ids.clear();
|
||||
self.docs.clear();
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,9 @@
|
||||
use super::agg_req::Aggregations;
|
||||
use super::agg_result::AggregationResults;
|
||||
use super::cached_sub_aggs::CachedSubAggs;
|
||||
use super::buf_collector::BufAggregationCollector;
|
||||
use super::intermediate_agg_result::IntermediateAggregationResults;
|
||||
use super::segment_agg_result::SegmentAggregationCollector;
|
||||
use super::AggContextParams;
|
||||
// group buffering strategy is chosen explicitly by callers; no need to hash-group on the fly.
|
||||
use crate::aggregation::agg_data::{
|
||||
build_aggregations_data_from_req, build_segment_agg_collectors_root, AggregationsSegmentCtx,
|
||||
};
|
||||
@@ -136,7 +136,7 @@ fn merge_fruits(
|
||||
/// `AggregationSegmentCollector` does the aggregation collection on a segment.
|
||||
pub struct AggregationSegmentCollector {
|
||||
aggs_with_accessor: AggregationsSegmentCtx,
|
||||
agg_collector: CachedSubAggs<true>,
|
||||
agg_collector: BufAggregationCollector,
|
||||
error: Option<TantivyError>,
|
||||
}
|
||||
|
||||
@@ -151,10 +151,8 @@ impl AggregationSegmentCollector {
|
||||
) -> crate::Result<Self> {
|
||||
let mut agg_data =
|
||||
build_aggregations_data_from_req(agg, reader, segment_ordinal, context.clone())?;
|
||||
let mut result = CachedSubAggs::new(build_segment_agg_collectors_root(&mut agg_data)?);
|
||||
result
|
||||
.get_sub_agg_collector()
|
||||
.prepare_max_bucket(0, &agg_data)?; // prepare for bucket zero
|
||||
let result =
|
||||
BufAggregationCollector::new(build_segment_agg_collectors_root(&mut agg_data)?);
|
||||
|
||||
Ok(AggregationSegmentCollector {
|
||||
aggs_with_accessor: agg_data,
|
||||
@@ -172,31 +170,26 @@ impl SegmentCollector for AggregationSegmentCollector {
|
||||
if self.error.is_some() {
|
||||
return;
|
||||
}
|
||||
self.agg_collector.push(0, doc);
|
||||
match self
|
||||
if let Err(err) = self
|
||||
.agg_collector
|
||||
.check_flush_local(&mut self.aggs_with_accessor)
|
||||
.collect(doc, &mut self.aggs_with_accessor)
|
||||
{
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
self.error = Some(e);
|
||||
}
|
||||
self.error = Some(err);
|
||||
}
|
||||
}
|
||||
|
||||
/// The query pushes the documents to the collector via this method.
|
||||
///
|
||||
/// Only valid for Collectors that ignore docs
|
||||
fn collect_block(&mut self, docs: &[DocId]) {
|
||||
if self.error.is_some() {
|
||||
return;
|
||||
}
|
||||
|
||||
match self.agg_collector.get_sub_agg_collector().collect(
|
||||
0,
|
||||
docs,
|
||||
&mut self.aggs_with_accessor,
|
||||
) {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
self.error = Some(e);
|
||||
}
|
||||
if let Err(err) = self
|
||||
.agg_collector
|
||||
.collect_block(docs, &mut self.aggs_with_accessor)
|
||||
{
|
||||
self.error = Some(err);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -207,13 +200,10 @@ impl SegmentCollector for AggregationSegmentCollector {
|
||||
self.agg_collector.flush(&mut self.aggs_with_accessor)?;
|
||||
|
||||
let mut sub_aggregation_res = IntermediateAggregationResults::default();
|
||||
self.agg_collector
|
||||
.get_sub_agg_collector()
|
||||
.add_intermediate_aggregation_result(
|
||||
&self.aggs_with_accessor,
|
||||
&mut sub_aggregation_res,
|
||||
0,
|
||||
)?;
|
||||
Box::new(self.agg_collector).add_intermediate_aggregation_result(
|
||||
&self.aggs_with_accessor,
|
||||
&mut sub_aggregation_res,
|
||||
)?;
|
||||
|
||||
Ok(sub_aggregation_res)
|
||||
}
|
||||
|
||||
@@ -792,7 +792,7 @@ pub struct IntermediateRangeBucketEntry {
|
||||
/// The number of documents in the bucket.
|
||||
pub doc_count: u64,
|
||||
/// The sub_aggregation in this bucket.
|
||||
pub sub_aggregation_res: IntermediateAggregationResults,
|
||||
pub sub_aggregation: IntermediateAggregationResults,
|
||||
/// The from range of the bucket. Equals `f64::MIN` when `None`.
|
||||
pub from: Option<f64>,
|
||||
/// The to range of the bucket. Equals `f64::MAX` when `None`.
|
||||
@@ -811,7 +811,7 @@ impl IntermediateRangeBucketEntry {
|
||||
key: self.key.into(),
|
||||
doc_count: self.doc_count,
|
||||
sub_aggregation: self
|
||||
.sub_aggregation_res
|
||||
.sub_aggregation
|
||||
.into_final_result_internal(req, limits)?,
|
||||
to: self.to,
|
||||
from: self.from,
|
||||
@@ -857,8 +857,7 @@ impl MergeFruits for IntermediateTermBucketEntry {
|
||||
impl MergeFruits for IntermediateRangeBucketEntry {
|
||||
fn merge_fruits(&mut self, other: IntermediateRangeBucketEntry) -> crate::Result<()> {
|
||||
self.doc_count += other.doc_count;
|
||||
self.sub_aggregation_res
|
||||
.merge_fruits(other.sub_aggregation_res)?;
|
||||
self.sub_aggregation.merge_fruits(other.sub_aggregation)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -888,7 +887,7 @@ mod tests {
|
||||
IntermediateRangeBucketEntry {
|
||||
key: IntermediateKey::Str(key.to_string()),
|
||||
doc_count: *doc_count,
|
||||
sub_aggregation_res: Default::default(),
|
||||
sub_aggregation: Default::default(),
|
||||
from: None,
|
||||
to: None,
|
||||
},
|
||||
@@ -921,7 +920,7 @@ mod tests {
|
||||
doc_count: *doc_count,
|
||||
from: None,
|
||||
to: None,
|
||||
sub_aggregation_res: get_sub_test_tree(&[(
|
||||
sub_aggregation: get_sub_test_tree(&[(
|
||||
sub_aggregation_key.to_string(),
|
||||
*sub_aggregation_count,
|
||||
)]),
|
||||
|
||||
@@ -52,8 +52,10 @@ pub struct IntermediateAverage {
|
||||
|
||||
impl IntermediateAverage {
|
||||
/// Creates a new [`IntermediateAverage`] instance from a [`SegmentStatsCollector`].
|
||||
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
|
||||
Self { stats }
|
||||
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
|
||||
Self {
|
||||
stats: collector.stats,
|
||||
}
|
||||
}
|
||||
/// Merges the other intermediate result into self.
|
||||
pub fn merge_fruits(&mut self, other: IntermediateAverage) {
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{BuildHasher, Hasher};
|
||||
|
||||
use columnar::column_values::CompactSpaceU64Accessor;
|
||||
use columnar::{Column, ColumnType, Dictionary, StrColumn};
|
||||
use columnar::{Column, ColumnBlockAccessor, ColumnType, Dictionary, StrColumn};
|
||||
use common::f64_to_u64;
|
||||
use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};
|
||||
use rustc_hash::FxHashSet;
|
||||
@@ -106,6 +106,8 @@ pub struct CardinalityAggReqData {
|
||||
pub str_dict_column: Option<StrColumn>,
|
||||
/// The missing value normalized to the internal u64 representation of the field type.
|
||||
pub missing_value_for_accessor: Option<u64>,
|
||||
/// The column block accessor to access the fast field values.
|
||||
pub(crate) column_block_accessor: ColumnBlockAccessor<u64>,
|
||||
/// The name of the aggregation.
|
||||
pub name: String,
|
||||
/// The aggregation request.
|
||||
@@ -133,34 +135,45 @@ impl CardinalityAggregationReq {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub(crate) struct SegmentCardinalityCollector {
|
||||
buckets: Vec<SegmentCardinalityCollectorBucket>,
|
||||
accessor_idx: usize,
|
||||
/// The column accessor to access the fast field values.
|
||||
accessor: Column<u64>,
|
||||
/// The column_type of the field.
|
||||
column_type: ColumnType,
|
||||
/// The missing value normalized to the internal u64 representation of the field type.
|
||||
missing_value_for_accessor: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Default)]
|
||||
pub(crate) struct SegmentCardinalityCollectorBucket {
|
||||
cardinality: CardinalityCollector,
|
||||
entries: FxHashSet<u64>,
|
||||
accessor_idx: usize,
|
||||
}
|
||||
impl SegmentCardinalityCollectorBucket {
|
||||
pub fn new(column_type: ColumnType) -> Self {
|
||||
|
||||
impl SegmentCardinalityCollector {
|
||||
pub fn from_req(column_type: ColumnType, accessor_idx: usize) -> Self {
|
||||
Self {
|
||||
cardinality: CardinalityCollector::new(column_type as u8),
|
||||
entries: FxHashSet::default(),
|
||||
entries: Default::default(),
|
||||
accessor_idx,
|
||||
}
|
||||
}
|
||||
|
||||
fn fetch_block_with_field(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut CardinalityAggReqData,
|
||||
) {
|
||||
if let Some(missing) = agg_data.missing_value_for_accessor {
|
||||
agg_data.column_block_accessor.fetch_block_with_missing(
|
||||
docs,
|
||||
&agg_data.accessor,
|
||||
missing,
|
||||
);
|
||||
} else {
|
||||
agg_data
|
||||
.column_block_accessor
|
||||
.fetch_block(docs, &agg_data.accessor);
|
||||
}
|
||||
}
|
||||
|
||||
fn into_intermediate_metric_result(
|
||||
mut self,
|
||||
req_data: &CardinalityAggReqData,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<IntermediateMetricResult> {
|
||||
let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
|
||||
if req_data.column_type == ColumnType::Str {
|
||||
let fallback_dict = Dictionary::empty();
|
||||
let dict = req_data
|
||||
@@ -181,7 +194,6 @@ impl SegmentCardinalityCollectorBucket {
|
||||
term_ids.push(term_ord as u32);
|
||||
}
|
||||
}
|
||||
|
||||
term_ids.sort_unstable();
|
||||
dict.sorted_ords_to_term_cb(term_ids.iter().map(|term| *term as u64), |term| {
|
||||
self.cardinality.sketch.insert_any(&term);
|
||||
@@ -215,49 +227,16 @@ impl SegmentCardinalityCollectorBucket {
|
||||
}
|
||||
}
|
||||
|
||||
impl SegmentCardinalityCollector {
|
||||
pub fn from_req(
|
||||
column_type: ColumnType,
|
||||
accessor_idx: usize,
|
||||
accessor: Column<u64>,
|
||||
missing_value_for_accessor: Option<u64>,
|
||||
) -> Self {
|
||||
Self {
|
||||
buckets: vec![SegmentCardinalityCollectorBucket::new(column_type); 1],
|
||||
column_type,
|
||||
accessor_idx,
|
||||
accessor,
|
||||
missing_value_for_accessor,
|
||||
}
|
||||
}
|
||||
|
||||
fn fetch_block_with_field(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) {
|
||||
agg_data.column_block_accessor.fetch_block_with_missing(
|
||||
docs,
|
||||
&self.accessor,
|
||||
self.missing_value_for_accessor,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl SegmentAggregationCollector for SegmentCardinalityCollector {
|
||||
fn add_intermediate_aggregation_result(
|
||||
&mut self,
|
||||
self: Box<Self>,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
parent_bucket_id: BucketId,
|
||||
) -> crate::Result<()> {
|
||||
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
|
||||
let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
|
||||
let name = req_data.name.to_string();
|
||||
// take the bucket in buckets and replace it with a new empty one
|
||||
let bucket = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
|
||||
|
||||
let intermediate_result = bucket.into_intermediate_metric_result(req_data)?;
|
||||
let intermediate_result = self.into_intermediate_metric_result(agg_data)?;
|
||||
results.push(
|
||||
name,
|
||||
IntermediateAggregationResult::Metric(intermediate_result),
|
||||
@@ -268,20 +247,27 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
|
||||
|
||||
fn collect(
|
||||
&mut self,
|
||||
parent_bucket_id: BucketId,
|
||||
doc: crate::DocId,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
self.collect_block(&[doc], agg_data)
|
||||
}
|
||||
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
self.fetch_block_with_field(docs, agg_data);
|
||||
let bucket = &mut self.buckets[parent_bucket_id as usize];
|
||||
let req_data = agg_data.get_cardinality_req_data_mut(self.accessor_idx);
|
||||
self.fetch_block_with_field(docs, req_data);
|
||||
|
||||
let col_block_accessor = &agg_data.column_block_accessor;
|
||||
if self.column_type == ColumnType::Str {
|
||||
let col_block_accessor = &req_data.column_block_accessor;
|
||||
if req_data.column_type == ColumnType::Str {
|
||||
for term_ord in col_block_accessor.iter_vals() {
|
||||
bucket.entries.insert(term_ord);
|
||||
self.entries.insert(term_ord);
|
||||
}
|
||||
} else if self.column_type == ColumnType::IpAddr {
|
||||
let compact_space_accessor = self
|
||||
} else if req_data.column_type == ColumnType::IpAddr {
|
||||
let compact_space_accessor = req_data
|
||||
.accessor
|
||||
.values
|
||||
.clone()
|
||||
@@ -296,29 +282,16 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
|
||||
})?;
|
||||
for val in col_block_accessor.iter_vals() {
|
||||
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
|
||||
bucket.cardinality.sketch.insert_any(&val);
|
||||
self.cardinality.sketch.insert_any(&val);
|
||||
}
|
||||
} else {
|
||||
for val in col_block_accessor.iter_vals() {
|
||||
bucket.cardinality.sketch.insert_any(&val);
|
||||
self.cardinality.sketch.insert_any(&val);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn prepare_max_bucket(
|
||||
&mut self,
|
||||
max_bucket: BucketId,
|
||||
_agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
if max_bucket as usize >= self.buckets.len() {
|
||||
self.buckets.resize_with(max_bucket as usize + 1, || {
|
||||
SegmentCardinalityCollectorBucket::new(self.column_type)
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
|
||||
@@ -52,8 +52,10 @@ pub struct IntermediateCount {
|
||||
|
||||
impl IntermediateCount {
|
||||
/// Creates a new [`IntermediateCount`] instance from a [`SegmentStatsCollector`].
|
||||
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
|
||||
Self { stats }
|
||||
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
|
||||
Self {
|
||||
stats: collector.stats,
|
||||
}
|
||||
}
|
||||
/// Merges the other intermediate result into self.
|
||||
pub fn merge_fruits(&mut self, other: IntermediateCount) {
|
||||
|
||||
@@ -8,9 +8,10 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
|
||||
use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
|
||||
};
|
||||
use crate::aggregation::metric::MetricAggReqData;
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::*;
|
||||
use crate::TantivyError;
|
||||
use crate::{DocId, TantivyError};
|
||||
|
||||
/// A multi-value metric aggregation that computes a collection of extended statistics
|
||||
/// on numeric values that are extracted
|
||||
@@ -317,28 +318,51 @@ impl IntermediateExtendedStats {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub(crate) struct SegmentExtendedStatsCollector {
|
||||
name: String,
|
||||
missing: Option<u64>,
|
||||
field_type: ColumnType,
|
||||
accessor: columnar::Column<u64>,
|
||||
buckets: Vec<IntermediateExtendedStats>,
|
||||
sigma: Option<f64>,
|
||||
pub(crate) extended_stats: IntermediateExtendedStats,
|
||||
pub(crate) accessor_idx: usize,
|
||||
val_cache: Vec<u64>,
|
||||
}
|
||||
|
||||
impl SegmentExtendedStatsCollector {
|
||||
pub fn from_req(req: &MetricAggReqData, sigma: Option<f64>) -> Self {
|
||||
let missing = req
|
||||
.missing
|
||||
.and_then(|val| f64_to_fastfield_u64(val, &req.field_type));
|
||||
pub fn from_req(
|
||||
field_type: ColumnType,
|
||||
sigma: Option<f64>,
|
||||
accessor_idx: usize,
|
||||
missing: Option<f64>,
|
||||
) -> Self {
|
||||
let missing = missing.and_then(|val| f64_to_fastfield_u64(val, &field_type));
|
||||
Self {
|
||||
name: req.name.clone(),
|
||||
field_type: req.field_type,
|
||||
accessor: req.accessor.clone(),
|
||||
field_type,
|
||||
extended_stats: IntermediateExtendedStats::with_sigma(sigma),
|
||||
accessor_idx,
|
||||
missing,
|
||||
buckets: vec![IntermediateExtendedStats::with_sigma(sigma); 16],
|
||||
sigma,
|
||||
val_cache: Default::default(),
|
||||
}
|
||||
}
|
||||
#[inline]
|
||||
pub(crate) fn collect_block_with_field(
|
||||
&mut self,
|
||||
docs: &[DocId],
|
||||
req_data: &mut MetricAggReqData,
|
||||
) {
|
||||
if let Some(missing) = self.missing.as_ref() {
|
||||
req_data.column_block_accessor.fetch_block_with_missing(
|
||||
docs,
|
||||
&req_data.accessor,
|
||||
*missing,
|
||||
);
|
||||
} else {
|
||||
req_data
|
||||
.column_block_accessor
|
||||
.fetch_block(docs, &req_data.accessor);
|
||||
}
|
||||
for val in req_data.column_block_accessor.iter_vals() {
|
||||
let val1 = f64_from_fastfield_u64(val, &self.field_type);
|
||||
self.extended_stats.collect(val1);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -346,18 +370,15 @@ impl SegmentExtendedStatsCollector {
|
||||
impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
|
||||
#[inline]
|
||||
fn add_intermediate_aggregation_result(
|
||||
&mut self,
|
||||
self: Box<Self>,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
parent_bucket_id: BucketId,
|
||||
) -> crate::Result<()> {
|
||||
let name = self.name.clone();
|
||||
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
|
||||
let extended_stats = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
|
||||
let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone();
|
||||
results.push(
|
||||
name,
|
||||
IntermediateAggregationResult::Metric(IntermediateMetricResult::ExtendedStats(
|
||||
extended_stats,
|
||||
self.extended_stats,
|
||||
)),
|
||||
)?;
|
||||
|
||||
@@ -367,36 +388,39 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
|
||||
#[inline]
|
||||
fn collect(
|
||||
&mut self,
|
||||
parent_bucket_id: BucketId,
|
||||
docs: &[crate::DocId],
|
||||
doc: crate::DocId,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let mut extended_stats = self.buckets[parent_bucket_id as usize].clone();
|
||||
|
||||
agg_data
|
||||
.column_block_accessor
|
||||
.fetch_block_with_missing(docs, &self.accessor, self.missing);
|
||||
for val in agg_data.column_block_accessor.iter_vals() {
|
||||
let val1 = f64_from_fastfield_u64(val, self.field_type);
|
||||
extended_stats.collect(val1);
|
||||
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
|
||||
if let Some(missing) = self.missing {
|
||||
let mut has_val = false;
|
||||
for val in req_data.accessor.values_for_doc(doc) {
|
||||
let val1 = f64_from_fastfield_u64(val, &self.field_type);
|
||||
self.extended_stats.collect(val1);
|
||||
has_val = true;
|
||||
}
|
||||
if !has_val {
|
||||
self.extended_stats
|
||||
.collect(f64_from_fastfield_u64(missing, &self.field_type));
|
||||
}
|
||||
} else {
|
||||
for val in req_data.accessor.values_for_doc(doc) {
|
||||
let val1 = f64_from_fastfield_u64(val, &self.field_type);
|
||||
self.extended_stats.collect(val1);
|
||||
}
|
||||
}
|
||||
|
||||
// store back
|
||||
self.buckets[parent_bucket_id as usize] = extended_stats;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn prepare_max_bucket(
|
||||
#[inline]
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
max_bucket: BucketId,
|
||||
_agg_data: &AggregationsSegmentCtx,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
if self.buckets.len() <= max_bucket as usize {
|
||||
self.buckets.resize_with(max_bucket as usize + 1, || {
|
||||
IntermediateExtendedStats::with_sigma(self.sigma)
|
||||
});
|
||||
}
|
||||
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
|
||||
self.collect_block_with_field(docs, req_data);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,8 +52,10 @@ pub struct IntermediateMax {
|
||||
|
||||
impl IntermediateMax {
|
||||
/// Creates a new [`IntermediateMax`] instance from a [`SegmentStatsCollector`].
|
||||
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
|
||||
Self { stats }
|
||||
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
|
||||
Self {
|
||||
stats: collector.stats,
|
||||
}
|
||||
}
|
||||
/// Merges the other intermediate result into self.
|
||||
pub fn merge_fruits(&mut self, other: IntermediateMax) {
|
||||
|
||||
@@ -52,8 +52,10 @@ pub struct IntermediateMin {
|
||||
|
||||
impl IntermediateMin {
|
||||
/// Creates a new [`IntermediateMin`] instance from a [`SegmentStatsCollector`].
|
||||
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
|
||||
Self { stats }
|
||||
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
|
||||
Self {
|
||||
stats: collector.stats,
|
||||
}
|
||||
}
|
||||
/// Merges the other intermediate result into self.
|
||||
pub fn merge_fruits(&mut self, other: IntermediateMin) {
|
||||
|
||||
@@ -31,7 +31,7 @@ use std::collections::HashMap;
|
||||
|
||||
pub use average::*;
|
||||
pub use cardinality::*;
|
||||
use columnar::{Column, ColumnType};
|
||||
use columnar::{Column, ColumnBlockAccessor, ColumnType};
|
||||
pub use count::*;
|
||||
pub use extended_stats::*;
|
||||
pub use max::*;
|
||||
@@ -55,6 +55,8 @@ pub struct MetricAggReqData {
|
||||
pub field_type: ColumnType,
|
||||
/// The missing value normalized to the internal u64 representation of the field type.
|
||||
pub missing_u64: Option<u64>,
|
||||
/// The column block accessor to access the fast field values.
|
||||
pub column_block_accessor: ColumnBlockAccessor<u64>,
|
||||
/// The column accessor to access the fast field values.
|
||||
pub accessor: Column<u64>,
|
||||
/// Used when converting to intermediate result
|
||||
|
||||
@@ -7,9 +7,10 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
|
||||
use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
|
||||
};
|
||||
use crate::aggregation::metric::MetricAggReqData;
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::*;
|
||||
use crate::TantivyError;
|
||||
use crate::{DocId, TantivyError};
|
||||
|
||||
/// # Percentiles
|
||||
///
|
||||
@@ -130,16 +131,10 @@ impl PercentilesAggregationReq {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub(crate) struct SegmentPercentilesCollector {
|
||||
pub(crate) buckets: Vec<PercentilesCollector>,
|
||||
pub(crate) percentiles: PercentilesCollector,
|
||||
pub(crate) accessor_idx: usize,
|
||||
/// The type of the field.
|
||||
pub field_type: ColumnType,
|
||||
/// The missing value normalized to the internal u64 representation of the field type.
|
||||
pub missing_u64: Option<u64>,
|
||||
/// The column accessor to access the fast field values.
|
||||
pub accessor: Column<u64>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
@@ -234,18 +229,33 @@ impl PercentilesCollector {
|
||||
}
|
||||
|
||||
impl SegmentPercentilesCollector {
|
||||
pub fn from_req_and_validate(
|
||||
field_type: ColumnType,
|
||||
missing_u64: Option<u64>,
|
||||
accessor: Column<u64>,
|
||||
accessor_idx: usize,
|
||||
) -> Self {
|
||||
Self {
|
||||
buckets: Vec::with_capacity(64),
|
||||
field_type,
|
||||
missing_u64,
|
||||
accessor,
|
||||
pub fn from_req_and_validate(accessor_idx: usize) -> crate::Result<Self> {
|
||||
Ok(Self {
|
||||
percentiles: PercentilesCollector::new(),
|
||||
accessor_idx,
|
||||
})
|
||||
}
|
||||
#[inline]
|
||||
pub(crate) fn collect_block_with_field(
|
||||
&mut self,
|
||||
docs: &[DocId],
|
||||
req_data: &mut MetricAggReqData,
|
||||
) {
|
||||
if let Some(missing) = req_data.missing_u64.as_ref() {
|
||||
req_data.column_block_accessor.fetch_block_with_missing(
|
||||
docs,
|
||||
&req_data.accessor,
|
||||
*missing,
|
||||
);
|
||||
} else {
|
||||
req_data
|
||||
.column_block_accessor
|
||||
.fetch_block(docs, &req_data.accessor);
|
||||
}
|
||||
|
||||
for val in req_data.column_block_accessor.iter_vals() {
|
||||
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
|
||||
self.percentiles.collect(val1);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -253,18 +263,12 @@ impl SegmentPercentilesCollector {
|
||||
impl SegmentAggregationCollector for SegmentPercentilesCollector {
|
||||
#[inline]
|
||||
fn add_intermediate_aggregation_result(
|
||||
&mut self,
|
||||
self: Box<Self>,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
parent_bucket_id: BucketId,
|
||||
) -> crate::Result<()> {
|
||||
let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone();
|
||||
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
|
||||
// Swap collector with an empty one to avoid cloning
|
||||
let percentiles_collector = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
|
||||
|
||||
let intermediate_metric_result =
|
||||
IntermediateMetricResult::Percentiles(percentiles_collector);
|
||||
let intermediate_metric_result = IntermediateMetricResult::Percentiles(self.percentiles);
|
||||
|
||||
results.push(
|
||||
name,
|
||||
@@ -277,33 +281,40 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector {
|
||||
#[inline]
|
||||
fn collect(
|
||||
&mut self,
|
||||
parent_bucket_id: BucketId,
|
||||
docs: &[crate::DocId],
|
||||
doc: crate::DocId,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let percentiles = &mut self.buckets[parent_bucket_id as usize];
|
||||
agg_data.column_block_accessor.fetch_block_with_missing(
|
||||
docs,
|
||||
&self.accessor,
|
||||
self.missing_u64,
|
||||
);
|
||||
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
|
||||
|
||||
for val in agg_data.column_block_accessor.iter_vals() {
|
||||
let val1 = f64_from_fastfield_u64(val, self.field_type);
|
||||
percentiles.collect(val1);
|
||||
if let Some(missing) = req_data.missing_u64 {
|
||||
let mut has_val = false;
|
||||
for val in req_data.accessor.values_for_doc(doc) {
|
||||
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
|
||||
self.percentiles.collect(val1);
|
||||
has_val = true;
|
||||
}
|
||||
if !has_val {
|
||||
self.percentiles
|
||||
.collect(f64_from_fastfield_u64(missing, &req_data.field_type));
|
||||
}
|
||||
} else {
|
||||
for val in req_data.accessor.values_for_doc(doc) {
|
||||
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
|
||||
self.percentiles.collect(val1);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn prepare_max_bucket(
|
||||
#[inline]
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
max_bucket: BucketId,
|
||||
_agg_data: &AggregationsSegmentCtx,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
while self.buckets.len() <= max_bucket as usize {
|
||||
self.buckets.push(PercentilesCollector::new());
|
||||
}
|
||||
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
|
||||
self.collect_block_with_field(docs, req_data);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use columnar::{Column, ColumnType};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::*;
|
||||
@@ -8,9 +7,10 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
|
||||
use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
|
||||
};
|
||||
use crate::aggregation::metric::MetricAggReqData;
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::*;
|
||||
use crate::TantivyError;
|
||||
use crate::{DocId, TantivyError};
|
||||
|
||||
/// A multi-value metric aggregation that computes a collection of statistics on numeric values that
|
||||
/// are extracted from the aggregated documents.
|
||||
@@ -83,7 +83,7 @@ impl Stats {
|
||||
|
||||
/// Intermediate result of the stats aggregation that can be combined with other intermediate
|
||||
/// results.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct IntermediateStats {
|
||||
/// The number of extracted values.
|
||||
pub(crate) count: u64,
|
||||
@@ -187,75 +187,75 @@ pub enum StatsType {
|
||||
Percentiles,
|
||||
}
|
||||
|
||||
fn create_collector<const TYPE_ID: u8>(
|
||||
req: &MetricAggReqData,
|
||||
) -> Box<dyn SegmentAggregationCollector> {
|
||||
Box::new(SegmentStatsCollector::<TYPE_ID> {
|
||||
name: req.name.clone(),
|
||||
collecting_for: req.collecting_for,
|
||||
is_number_or_date_type: req.is_number_or_date_type,
|
||||
missing_u64: req.missing_u64,
|
||||
accessor: req.accessor.clone(),
|
||||
buckets: vec![IntermediateStats::default()],
|
||||
})
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct SegmentStatsCollector {
|
||||
pub(crate) stats: IntermediateStats,
|
||||
pub(crate) accessor_idx: usize,
|
||||
}
|
||||
|
||||
/// Build a concrete `SegmentStatsCollector` depending on the column type.
|
||||
pub(crate) fn build_segment_stats_collector(
|
||||
req: &MetricAggReqData,
|
||||
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
|
||||
match req.field_type {
|
||||
ColumnType::I64 => Ok(create_collector::<{ ColumnType::I64 as u8 }>(req)),
|
||||
ColumnType::U64 => Ok(create_collector::<{ ColumnType::U64 as u8 }>(req)),
|
||||
ColumnType::F64 => Ok(create_collector::<{ ColumnType::F64 as u8 }>(req)),
|
||||
ColumnType::Bool => Ok(create_collector::<{ ColumnType::Bool as u8 }>(req)),
|
||||
ColumnType::DateTime => Ok(create_collector::<{ ColumnType::DateTime as u8 }>(req)),
|
||||
ColumnType::Bytes => Ok(create_collector::<{ ColumnType::Bytes as u8 }>(req)),
|
||||
ColumnType::Str => Ok(create_collector::<{ ColumnType::Str as u8 }>(req)),
|
||||
ColumnType::IpAddr => Ok(create_collector::<{ ColumnType::IpAddr as u8 }>(req)),
|
||||
impl SegmentStatsCollector {
|
||||
pub fn from_req(accessor_idx: usize) -> Self {
|
||||
Self {
|
||||
stats: IntermediateStats::default(),
|
||||
accessor_idx,
|
||||
}
|
||||
}
|
||||
#[inline]
|
||||
pub(crate) fn collect_block_with_field(
|
||||
&mut self,
|
||||
docs: &[DocId],
|
||||
req_data: &mut MetricAggReqData,
|
||||
) {
|
||||
if let Some(missing) = req_data.missing_u64.as_ref() {
|
||||
req_data.column_block_accessor.fetch_block_with_missing(
|
||||
docs,
|
||||
&req_data.accessor,
|
||||
*missing,
|
||||
);
|
||||
} else {
|
||||
req_data
|
||||
.column_block_accessor
|
||||
.fetch_block(docs, &req_data.accessor);
|
||||
}
|
||||
if req_data.is_number_or_date_type {
|
||||
for val in req_data.column_block_accessor.iter_vals() {
|
||||
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
|
||||
self.stats.collect(val1);
|
||||
}
|
||||
} else {
|
||||
for _val in req_data.column_block_accessor.iter_vals() {
|
||||
// we ignore the value and simply record that we got something
|
||||
self.stats.collect(0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct SegmentStatsCollector<const COLUMN_TYPE_ID: u8> {
|
||||
pub(crate) missing_u64: Option<u64>,
|
||||
pub(crate) accessor: Column<u64>,
|
||||
pub(crate) is_number_or_date_type: bool,
|
||||
pub(crate) buckets: Vec<IntermediateStats>,
|
||||
pub(crate) name: String,
|
||||
pub(crate) collecting_for: StatsType,
|
||||
}
|
||||
|
||||
impl<const COLUMN_TYPE_ID: u8> SegmentAggregationCollector
|
||||
for SegmentStatsCollector<COLUMN_TYPE_ID>
|
||||
{
|
||||
impl SegmentAggregationCollector for SegmentStatsCollector {
|
||||
#[inline]
|
||||
fn add_intermediate_aggregation_result(
|
||||
&mut self,
|
||||
self: Box<Self>,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
parent_bucket_id: BucketId,
|
||||
) -> crate::Result<()> {
|
||||
let name = self.name.clone();
|
||||
let req = agg_data.get_metric_req_data(self.accessor_idx);
|
||||
let name = req.name.clone();
|
||||
|
||||
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
|
||||
let stats = self.buckets[parent_bucket_id as usize];
|
||||
let intermediate_metric_result = match self.collecting_for {
|
||||
let intermediate_metric_result = match req.collecting_for {
|
||||
StatsType::Average => {
|
||||
IntermediateMetricResult::Average(IntermediateAverage::from_stats(stats))
|
||||
IntermediateMetricResult::Average(IntermediateAverage::from_collector(*self))
|
||||
}
|
||||
StatsType::Count => {
|
||||
IntermediateMetricResult::Count(IntermediateCount::from_stats(stats))
|
||||
IntermediateMetricResult::Count(IntermediateCount::from_collector(*self))
|
||||
}
|
||||
StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_stats(stats)),
|
||||
StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_stats(stats)),
|
||||
StatsType::Stats => IntermediateMetricResult::Stats(stats),
|
||||
StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_stats(stats)),
|
||||
StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_collector(*self)),
|
||||
StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_collector(*self)),
|
||||
StatsType::Stats => IntermediateMetricResult::Stats(self.stats),
|
||||
StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_collector(*self)),
|
||||
_ => {
|
||||
return Err(TantivyError::InvalidArgument(format!(
|
||||
"Unsupported stats type for stats aggregation: {:?}",
|
||||
self.collecting_for
|
||||
req.collecting_for
|
||||
)))
|
||||
}
|
||||
};
|
||||
@@ -271,67 +271,41 @@ impl<const COLUMN_TYPE_ID: u8> SegmentAggregationCollector
|
||||
#[inline]
|
||||
fn collect(
|
||||
&mut self,
|
||||
parent_bucket_id: BucketId,
|
||||
doc: crate::DocId,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
|
||||
if let Some(missing) = req_data.missing_u64 {
|
||||
let mut has_val = false;
|
||||
for val in req_data.accessor.values_for_doc(doc) {
|
||||
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
|
||||
self.stats.collect(val1);
|
||||
has_val = true;
|
||||
}
|
||||
if !has_val {
|
||||
self.stats
|
||||
.collect(f64_from_fastfield_u64(missing, &req_data.field_type));
|
||||
}
|
||||
} else {
|
||||
for val in req_data.accessor.values_for_doc(doc) {
|
||||
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
|
||||
self.stats.collect(val1);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
// TODO: remove once we fetch all values for all bucket ids in one go
|
||||
if docs.len() == 1 && self.missing_u64.is_none() {
|
||||
collect_stats::<COLUMN_TYPE_ID>(
|
||||
&mut self.buckets[parent_bucket_id as usize],
|
||||
self.accessor.values_for_doc(docs[0]),
|
||||
self.is_number_or_date_type,
|
||||
)?;
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
agg_data.column_block_accessor.fetch_block_with_missing(
|
||||
docs,
|
||||
&self.accessor,
|
||||
self.missing_u64,
|
||||
);
|
||||
collect_stats::<COLUMN_TYPE_ID>(
|
||||
&mut self.buckets[parent_bucket_id as usize],
|
||||
agg_data.column_block_accessor.iter_vals(),
|
||||
self.is_number_or_date_type,
|
||||
)?;
|
||||
|
||||
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
|
||||
self.collect_block_with_field(docs, req_data);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn prepare_max_bucket(
|
||||
&mut self,
|
||||
max_bucket: BucketId,
|
||||
_agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let required_buckets = (max_bucket as usize) + 1;
|
||||
if self.buckets.len() < required_buckets {
|
||||
self.buckets
|
||||
.resize_with(required_buckets, IntermediateStats::default);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn collect_stats<const COLUMN_TYPE_ID: u8>(
|
||||
stats: &mut IntermediateStats,
|
||||
vals: impl Iterator<Item = u64>,
|
||||
is_number_or_date_type: bool,
|
||||
) -> crate::Result<()> {
|
||||
if is_number_or_date_type {
|
||||
for val in vals {
|
||||
let val1 = convert_to_f64::<COLUMN_TYPE_ID>(val);
|
||||
stats.collect(val1);
|
||||
}
|
||||
} else {
|
||||
for _val in vals {
|
||||
// we ignore the value and simply record that we got something
|
||||
stats.collect(0.0);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -52,8 +52,10 @@ pub struct IntermediateSum {
|
||||
|
||||
impl IntermediateSum {
|
||||
/// Creates a new [`IntermediateSum`] instance from a [`SegmentStatsCollector`].
|
||||
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
|
||||
Self { stats }
|
||||
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
|
||||
Self {
|
||||
stats: collector.stats,
|
||||
}
|
||||
}
|
||||
/// Merges the other intermediate result into self.
|
||||
pub fn merge_fruits(&mut self, other: IntermediateSum) {
|
||||
|
||||
@@ -15,11 +15,12 @@ use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateMetricResult,
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::{AggregationError, BucketId};
|
||||
use crate::aggregation::AggregationError;
|
||||
use crate::collector::sort_key::ReverseComparator;
|
||||
use crate::collector::TopNComputer;
|
||||
use crate::schema::OwnedValue;
|
||||
use crate::{DocAddress, DocId, SegmentOrdinal};
|
||||
// duplicate import removed; already imported above
|
||||
|
||||
/// Contains all information required by the TopHitsSegmentCollector to perform the
|
||||
/// top_hits aggregation on a segment.
|
||||
@@ -471,10 +472,7 @@ impl TopHitsTopNComputer {
|
||||
/// Create a new TopHitsCollector
|
||||
pub fn new(req: &TopHitsAggregationReq) -> Self {
|
||||
Self {
|
||||
top_n: TopNComputer::new_with_comparator(
|
||||
req.size + req.from.unwrap_or(0),
|
||||
ReverseComparator,
|
||||
),
|
||||
top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)),
|
||||
req: req.clone(),
|
||||
}
|
||||
}
|
||||
@@ -520,8 +518,7 @@ impl TopHitsTopNComputer {
|
||||
pub(crate) struct TopHitsSegmentCollector {
|
||||
segment_ordinal: SegmentOrdinal,
|
||||
accessor_idx: usize,
|
||||
buckets: Vec<TopNComputer<Vec<DocValueAndOrder>, DocAddress, ReverseComparator>>,
|
||||
num_hits: usize,
|
||||
top_n: TopNComputer<Vec<DocValueAndOrder>, DocAddress, ReverseComparator>,
|
||||
}
|
||||
|
||||
impl TopHitsSegmentCollector {
|
||||
@@ -530,29 +527,19 @@ impl TopHitsSegmentCollector {
|
||||
accessor_idx: usize,
|
||||
segment_ordinal: SegmentOrdinal,
|
||||
) -> Self {
|
||||
let num_hits = req.size + req.from.unwrap_or(0);
|
||||
Self {
|
||||
num_hits,
|
||||
top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)),
|
||||
segment_ordinal,
|
||||
accessor_idx,
|
||||
buckets: vec![TopNComputer::new_with_comparator(num_hits, ReverseComparator); 1],
|
||||
}
|
||||
}
|
||||
fn get_top_hits_computer(
|
||||
&mut self,
|
||||
parent_bucket_id: BucketId,
|
||||
fn into_top_hits_collector(
|
||||
self,
|
||||
value_accessors: &HashMap<String, Vec<DynamicColumn>>,
|
||||
req: &TopHitsAggregationReq,
|
||||
) -> TopHitsTopNComputer {
|
||||
if parent_bucket_id as usize >= self.buckets.len() {
|
||||
return TopHitsTopNComputer::new(req);
|
||||
}
|
||||
let top_n = std::mem::replace(
|
||||
&mut self.buckets[parent_bucket_id as usize],
|
||||
TopNComputer::new(0),
|
||||
);
|
||||
let mut top_hits_computer = TopHitsTopNComputer::new(req);
|
||||
let top_results = top_n.into_vec();
|
||||
let top_results = self.top_n.into_vec();
|
||||
|
||||
for res in top_results {
|
||||
let doc_value_fields = req.get_document_field_data(value_accessors, res.doc.doc_id);
|
||||
@@ -567,24 +554,54 @@ impl TopHitsSegmentCollector {
|
||||
|
||||
top_hits_computer
|
||||
}
|
||||
|
||||
/// TODO add a specialized variant for a single sort field
|
||||
fn collect_with(
|
||||
&mut self,
|
||||
doc_id: crate::DocId,
|
||||
req: &TopHitsAggregationReq,
|
||||
accessors: &[(Column<u64>, ColumnType)],
|
||||
) -> crate::Result<()> {
|
||||
let sorts: Vec<DocValueAndOrder> = req
|
||||
.sort
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, KeyOrder { order, .. })| {
|
||||
let order = *order;
|
||||
let value = accessors
|
||||
.get(idx)
|
||||
.expect("could not find field in accessors")
|
||||
.0
|
||||
.values_for_doc(doc_id)
|
||||
.next();
|
||||
DocValueAndOrder { value, order }
|
||||
})
|
||||
.collect();
|
||||
|
||||
self.top_n.push(
|
||||
sorts,
|
||||
DocAddress {
|
||||
segment_ord: self.segment_ordinal,
|
||||
doc_id,
|
||||
},
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl SegmentAggregationCollector for TopHitsSegmentCollector {
|
||||
fn add_intermediate_aggregation_result(
|
||||
&mut self,
|
||||
self: Box<Self>,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut crate::aggregation::intermediate_agg_result::IntermediateAggregationResults,
|
||||
parent_bucket_id: BucketId,
|
||||
) -> crate::Result<()> {
|
||||
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
|
||||
|
||||
let value_accessors = &req_data.value_accessors;
|
||||
|
||||
let intermediate_result = IntermediateMetricResult::TopHits(self.get_top_hits_computer(
|
||||
parent_bucket_id,
|
||||
value_accessors,
|
||||
&req_data.req,
|
||||
));
|
||||
let intermediate_result = IntermediateMetricResult::TopHits(
|
||||
self.into_top_hits_collector(value_accessors, &req_data.req),
|
||||
);
|
||||
results.push(
|
||||
req_data.name.to_string(),
|
||||
IntermediateAggregationResult::Metric(intermediate_result),
|
||||
@@ -594,55 +611,24 @@ impl SegmentAggregationCollector for TopHitsSegmentCollector {
|
||||
/// TODO: Consider a caching layer to reduce the call overhead
|
||||
fn collect(
|
||||
&mut self,
|
||||
parent_bucket_id: BucketId,
|
||||
docs: &[crate::DocId],
|
||||
doc_id: crate::DocId,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let top_n = &mut self.buckets[parent_bucket_id as usize];
|
||||
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
|
||||
let req = &req_data.req;
|
||||
let accessors = &req_data.accessors;
|
||||
for doc_id in docs {
|
||||
let doc_id = *doc_id;
|
||||
// TODO: this is terrible, a new vec is allocated for every doc
|
||||
// We can fetch blocks instead
|
||||
// We don't need to store the order for every value
|
||||
let sorts: Vec<DocValueAndOrder> = req
|
||||
.sort
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, KeyOrder { order, .. })| {
|
||||
let order = *order;
|
||||
let value = accessors
|
||||
.get(idx)
|
||||
.expect("could not find field in accessors")
|
||||
.0
|
||||
.values_for_doc(doc_id)
|
||||
.next();
|
||||
DocValueAndOrder { value, order }
|
||||
})
|
||||
.collect();
|
||||
|
||||
top_n.push(
|
||||
sorts,
|
||||
DocAddress {
|
||||
segment_ord: self.segment_ordinal,
|
||||
doc_id,
|
||||
},
|
||||
);
|
||||
}
|
||||
self.collect_with(doc_id, &req_data.req, &req_data.accessors)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn prepare_max_bucket(
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
max_bucket: BucketId,
|
||||
_agg_data: &AggregationsSegmentCtx,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
self.buckets.resize(
|
||||
(max_bucket as usize) + 1,
|
||||
TopNComputer::new_with_comparator(self.num_hits, ReverseComparator),
|
||||
);
|
||||
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
|
||||
// TODO: Consider getting fields with the column block accessor.
|
||||
for doc in docs {
|
||||
self.collect_with(*doc, &req_data.req, &req_data.accessors)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -760,7 +746,7 @@ mod tests {
|
||||
],
|
||||
"from": 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
@@ -889,7 +875,7 @@ mod tests {
|
||||
"mixed.*",
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
}))?;
|
||||
|
||||
let collector = AggregationCollector::from_aggs(d, Default::default());
|
||||
|
||||
@@ -133,7 +133,7 @@ mod agg_limits;
|
||||
pub mod agg_req;
|
||||
pub mod agg_result;
|
||||
pub mod bucket;
|
||||
pub(crate) mod cached_sub_aggs;
|
||||
mod buf_collector;
|
||||
mod collector;
|
||||
mod date;
|
||||
mod error;
|
||||
@@ -162,19 +162,6 @@ use serde::{Deserialize, Deserializer, Serialize};
|
||||
|
||||
use crate::tokenizer::TokenizerManager;
|
||||
|
||||
/// A bucket id is a dense identifier for a bucket within an aggregation.
|
||||
/// It is used to index into a Vec that hold per-bucket data.
|
||||
///
|
||||
/// For example, in a terms aggregation, each unique term will be assigned a incremental BucketId.
|
||||
/// This BucketId will be forwarded to sub-aggregations to identify the parent bucket.
|
||||
///
|
||||
/// This allows to have a single AggregationCollector instance per aggregation,
|
||||
/// that can handle multiple buckets efficiently.
|
||||
///
|
||||
/// The API to call sub-aggregations is therefore a &[(BucketId, &[DocId])].
|
||||
/// For that we'll need a buffer. One Vec per bucket aggregation is needed.
|
||||
pub type BucketId = u32;
|
||||
|
||||
/// Context parameters for aggregation execution
|
||||
///
|
||||
/// This struct holds shared resources needed during aggregation execution:
|
||||
@@ -348,37 +335,19 @@ impl Display for Key {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn convert_to_f64<const COLUMN_TYPE_ID: u8>(val: u64) -> f64 {
|
||||
if COLUMN_TYPE_ID == ColumnType::U64 as u8 {
|
||||
val as f64
|
||||
} else if COLUMN_TYPE_ID == ColumnType::I64 as u8
|
||||
|| COLUMN_TYPE_ID == ColumnType::DateTime as u8
|
||||
{
|
||||
i64::from_u64(val) as f64
|
||||
} else if COLUMN_TYPE_ID == ColumnType::F64 as u8 {
|
||||
f64::from_u64(val)
|
||||
} else if COLUMN_TYPE_ID == ColumnType::Bool as u8 {
|
||||
val as f64
|
||||
} else {
|
||||
panic!(
|
||||
"ColumnType ID {} cannot be converted to f64 metric",
|
||||
COLUMN_TYPE_ID
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Inverse of `to_fastfield_u64`. Used to convert to `f64` for metrics.
|
||||
///
|
||||
/// # Panics
|
||||
/// Only `u64`, `f64`, `date`, and `i64` are supported.
|
||||
pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: ColumnType) -> f64 {
|
||||
pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: &ColumnType) -> f64 {
|
||||
match field_type {
|
||||
ColumnType::U64 => convert_to_f64::<{ ColumnType::U64 as u8 }>(val),
|
||||
ColumnType::I64 => convert_to_f64::<{ ColumnType::I64 as u8 }>(val),
|
||||
ColumnType::F64 => convert_to_f64::<{ ColumnType::F64 as u8 }>(val),
|
||||
ColumnType::Bool => convert_to_f64::<{ ColumnType::Bool as u8 }>(val),
|
||||
ColumnType::DateTime => convert_to_f64::<{ ColumnType::DateTime as u8 }>(val),
|
||||
_ => panic!("unexpected type {field_type:?}. This should not happen"),
|
||||
ColumnType::U64 => val as f64,
|
||||
ColumnType::I64 | ColumnType::DateTime => i64::from_u64(val) as f64,
|
||||
ColumnType::F64 => f64::from_u64(val),
|
||||
ColumnType::Bool => val as f64,
|
||||
_ => {
|
||||
panic!("unexpected type {field_type:?}. This should not happen")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,67 +8,25 @@ use std::fmt::Debug;
|
||||
pub(crate) use super::agg_limits::AggregationLimitsGuard;
|
||||
use super::intermediate_agg_result::IntermediateAggregationResults;
|
||||
use crate::aggregation::agg_data::AggregationsSegmentCtx;
|
||||
use crate::aggregation::BucketId;
|
||||
|
||||
/// Monotonically increasing provider of BucketIds.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct BucketIdProvider(u32);
|
||||
impl BucketIdProvider {
|
||||
/// Get the next BucketId.
|
||||
pub fn next_bucket_id(&mut self) -> BucketId {
|
||||
let bucket_id = self.0;
|
||||
self.0 += 1;
|
||||
bucket_id
|
||||
}
|
||||
}
|
||||
|
||||
/// A SegmentAggregationCollector is used to collect aggregation results.
|
||||
pub trait SegmentAggregationCollector: Debug {
|
||||
pub trait SegmentAggregationCollector: CollectorClone + Debug {
|
||||
fn add_intermediate_aggregation_result(
|
||||
&mut self,
|
||||
self: Box<Self>,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
parent_bucket_id: BucketId,
|
||||
) -> crate::Result<()>;
|
||||
|
||||
/// Note: The caller needs to call `prepare_max_bucket` before calling `collect`.
|
||||
fn collect(
|
||||
&mut self,
|
||||
parent_bucket_id: BucketId,
|
||||
docs: &[crate::DocId],
|
||||
doc: crate::DocId,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()>;
|
||||
|
||||
/// Collect docs for multiple buckets in one call.
|
||||
/// Minimizes dynamic dispatch overhead when collecting many buckets.
|
||||
///
|
||||
/// Note: The caller needs to call `prepare_max_bucket` before calling `collect`.
|
||||
fn collect_multiple(
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
bucket_ids: &[BucketId],
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
debug_assert_eq!(bucket_ids.len(), docs.len());
|
||||
let mut start = 0;
|
||||
while start < bucket_ids.len() {
|
||||
let bucket_id = bucket_ids[start];
|
||||
let mut end = start + 1;
|
||||
while end < bucket_ids.len() && bucket_ids[end] == bucket_id {
|
||||
end += 1;
|
||||
}
|
||||
self.collect(bucket_id, &docs[start..end], agg_data)?;
|
||||
start = end;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Prepare the collector for collecting up to BucketId `max_bucket`.
|
||||
/// This is useful so we can split allocation ahead of time of collecting.
|
||||
fn prepare_max_bucket(
|
||||
&mut self,
|
||||
max_bucket: BucketId,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<()>;
|
||||
|
||||
/// Finalize method. Some Aggregator collect blocks of docs before calling `collect_block`.
|
||||
@@ -78,7 +36,26 @@ pub trait SegmentAggregationCollector: Debug {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
/// A helper trait to enable cloning of Box<dyn SegmentAggregationCollector>
|
||||
pub trait CollectorClone {
|
||||
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector>;
|
||||
}
|
||||
|
||||
impl<T> CollectorClone for T
|
||||
where T: 'static + SegmentAggregationCollector + Clone
|
||||
{
|
||||
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for Box<dyn SegmentAggregationCollector> {
|
||||
fn clone(&self) -> Box<dyn SegmentAggregationCollector> {
|
||||
self.clone_box()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
/// The GenericSegmentAggregationResultsCollector is the generic version of the collector, which
|
||||
/// can handle arbitrary complexity of sub-aggregations. Ideally we never have to pick this one
|
||||
/// and can provide specialized versions instead, that remove some of its overhead.
|
||||
@@ -96,13 +73,12 @@ impl Debug for GenericSegmentAggregationResultsCollector {
|
||||
|
||||
impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
|
||||
fn add_intermediate_aggregation_result(
|
||||
&mut self,
|
||||
self: Box<Self>,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
parent_bucket_id: BucketId,
|
||||
) -> crate::Result<()> {
|
||||
for agg in &mut self.aggs {
|
||||
agg.add_intermediate_aggregation_result(agg_data, results, parent_bucket_id)?;
|
||||
for agg in self.aggs {
|
||||
agg.add_intermediate_aggregation_result(agg_data, results)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -110,13 +86,23 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
|
||||
|
||||
fn collect(
|
||||
&mut self,
|
||||
parent_bucket_id: BucketId,
|
||||
doc: crate::DocId,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
self.collect_block(&[doc], agg_data)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
for collector in &mut self.aggs {
|
||||
collector.collect(parent_bucket_id, docs, agg_data)?;
|
||||
collector.collect_block(docs, agg_data)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -126,15 +112,4 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn prepare_max_bucket(
|
||||
&mut self,
|
||||
max_bucket: BucketId,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
for collector in &mut self.aggs {
|
||||
collector.prepare_max_bucket(max_bucket, agg_data)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,25 +1,48 @@
|
||||
mod order;
|
||||
mod sort_by_erased_type;
|
||||
mod sort_by_score;
|
||||
mod sort_by_static_fast_value;
|
||||
mod sort_by_string;
|
||||
mod sort_key_computer;
|
||||
|
||||
pub use order::*;
|
||||
pub use sort_by_erased_type::SortByErasedType;
|
||||
pub use sort_by_score::SortBySimilarityScore;
|
||||
pub use sort_by_static_fast_value::SortByStaticFastValue;
|
||||
pub use sort_by_string::SortByString;
|
||||
pub use sort_key_computer::{SegmentSortKeyComputer, SortKeyComputer};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
pub(crate) mod tests {
|
||||
|
||||
// By spec, regardless of whether ascending or descending order was requested, in presence of a
|
||||
// tie, we sort by ascending doc id/doc address.
|
||||
pub(crate) fn sort_hits<TSortKey: Ord, D: Ord>(
|
||||
hits: &mut [ComparableDoc<TSortKey, D>],
|
||||
order: Order,
|
||||
) {
|
||||
if order.is_asc() {
|
||||
hits.sort_by(|l, r| l.sort_key.cmp(&r.sort_key).then(l.doc.cmp(&r.doc)));
|
||||
} else {
|
||||
hits.sort_by(|l, r| {
|
||||
l.sort_key
|
||||
.cmp(&r.sort_key)
|
||||
.reverse() // This is descending
|
||||
.then(l.doc.cmp(&r.doc))
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::ops::Range;
|
||||
|
||||
use crate::collector::sort_key::{SortBySimilarityScore, SortByStaticFastValue, SortByString};
|
||||
use crate::collector::sort_key::{
|
||||
SortByErasedType, SortBySimilarityScore, SortByStaticFastValue, SortByString,
|
||||
};
|
||||
use crate::collector::{ComparableDoc, DocSetCollector, TopDocs};
|
||||
use crate::indexer::NoMergePolicy;
|
||||
use crate::query::{AllQuery, QueryParser};
|
||||
use crate::schema::{Schema, FAST, TEXT};
|
||||
use crate::schema::{OwnedValue, Schema, FAST, TEXT};
|
||||
use crate::{DocAddress, Document, Index, Order, Score, Searcher};
|
||||
|
||||
fn make_index() -> crate::Result<Index> {
|
||||
@@ -294,11 +317,9 @@ mod tests {
|
||||
(SortBySimilarityScore, score_order),
|
||||
(SortByString::for_field("city"), city_order),
|
||||
));
|
||||
Ok(searcher
|
||||
.search(&AllQuery, &top_collector)?
|
||||
.into_iter()
|
||||
.map(|(f, doc)| (f, ids[&doc]))
|
||||
.collect())
|
||||
let results: Vec<((Score, Option<String>), DocAddress)> =
|
||||
searcher.search(&AllQuery, &top_collector)?;
|
||||
Ok(results.into_iter().map(|(f, doc)| (f, ids[&doc])).collect())
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
@@ -323,6 +344,51 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_order_by_score_then_owned_value() -> crate::Result<()> {
|
||||
let index = make_index()?;
|
||||
|
||||
type SortKey = (Score, OwnedValue);
|
||||
|
||||
fn query(
|
||||
index: &Index,
|
||||
score_order: Order,
|
||||
city_order: Order,
|
||||
) -> crate::Result<Vec<(SortKey, u64)>> {
|
||||
let searcher = index.reader()?.searcher();
|
||||
let ids = id_mapping(&searcher);
|
||||
|
||||
let top_collector = TopDocs::with_limit(4).order_by::<(Score, OwnedValue)>((
|
||||
(SortBySimilarityScore, score_order),
|
||||
(SortByErasedType::for_field("city"), city_order),
|
||||
));
|
||||
let results: Vec<((Score, OwnedValue), DocAddress)> =
|
||||
searcher.search(&AllQuery, &top_collector)?;
|
||||
Ok(results.into_iter().map(|(f, doc)| (f, ids[&doc])).collect())
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
&query(&index, Order::Asc, Order::Asc)?,
|
||||
&[
|
||||
((1.0, OwnedValue::Str("austin".to_owned())), 0),
|
||||
((1.0, OwnedValue::Str("greenville".to_owned())), 1),
|
||||
((1.0, OwnedValue::Str("tokyo".to_owned())), 2),
|
||||
((1.0, OwnedValue::Null), 3),
|
||||
]
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
&query(&index, Order::Asc, Order::Desc)?,
|
||||
&[
|
||||
((1.0, OwnedValue::Str("tokyo".to_owned())), 2),
|
||||
((1.0, OwnedValue::Str("greenville".to_owned())), 1),
|
||||
((1.0, OwnedValue::Str("austin".to_owned())), 0),
|
||||
((1.0, OwnedValue::Null), 3),
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
use proptest::prelude::*;
|
||||
|
||||
proptest! {
|
||||
@@ -372,15 +438,10 @@ mod tests {
|
||||
|
||||
// Using the TopDocs collector should always be equivalent to sorting, skipping the
|
||||
// offset, and then taking the limit.
|
||||
let sorted_docs: Vec<_> = if order.is_desc() {
|
||||
let mut comparable_docs: Vec<ComparableDoc<_, _, true>> =
|
||||
let sorted_docs: Vec<_> = {
|
||||
let mut comparable_docs: Vec<ComparableDoc<_, _>> =
|
||||
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
|
||||
comparable_docs.sort();
|
||||
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
|
||||
} else {
|
||||
let mut comparable_docs: Vec<ComparableDoc<_, _, false>> =
|
||||
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
|
||||
comparable_docs.sort();
|
||||
sort_hits(&mut comparable_docs, order);
|
||||
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
|
||||
};
|
||||
let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::<Vec<_>>();
|
||||
|
||||
@@ -1,36 +1,116 @@
|
||||
use std::cmp::Ordering;
|
||||
|
||||
use columnar::MonotonicallyMappableToU64;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
|
||||
use crate::schema::Schema;
|
||||
use crate::schema::{OwnedValue, Schema};
|
||||
use crate::{DocId, Order, Score};
|
||||
|
||||
fn compare_owned_value<const NULLS_FIRST: bool>(lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
|
||||
match (lhs, rhs) {
|
||||
(OwnedValue::Null, OwnedValue::Null) => Ordering::Equal,
|
||||
(OwnedValue::Null, _) => {
|
||||
if NULLS_FIRST {
|
||||
Ordering::Less
|
||||
} else {
|
||||
Ordering::Greater
|
||||
}
|
||||
}
|
||||
(_, OwnedValue::Null) => {
|
||||
if NULLS_FIRST {
|
||||
Ordering::Greater
|
||||
} else {
|
||||
Ordering::Less
|
||||
}
|
||||
}
|
||||
(OwnedValue::Str(a), OwnedValue::Str(b)) => a.cmp(b),
|
||||
(OwnedValue::PreTokStr(a), OwnedValue::PreTokStr(b)) => a.cmp(b),
|
||||
(OwnedValue::U64(a), OwnedValue::U64(b)) => a.cmp(b),
|
||||
(OwnedValue::I64(a), OwnedValue::I64(b)) => a.cmp(b),
|
||||
(OwnedValue::F64(a), OwnedValue::F64(b)) => a.to_u64().cmp(&b.to_u64()),
|
||||
(OwnedValue::Bool(a), OwnedValue::Bool(b)) => a.cmp(b),
|
||||
(OwnedValue::Date(a), OwnedValue::Date(b)) => a.cmp(b),
|
||||
(OwnedValue::Facet(a), OwnedValue::Facet(b)) => a.cmp(b),
|
||||
(OwnedValue::Bytes(a), OwnedValue::Bytes(b)) => a.cmp(b),
|
||||
(OwnedValue::IpAddr(a), OwnedValue::IpAddr(b)) => a.cmp(b),
|
||||
(OwnedValue::U64(a), OwnedValue::I64(b)) => {
|
||||
if *b < 0 {
|
||||
Ordering::Greater
|
||||
} else {
|
||||
a.cmp(&(*b as u64))
|
||||
}
|
||||
}
|
||||
(OwnedValue::I64(a), OwnedValue::U64(b)) => {
|
||||
if *a < 0 {
|
||||
Ordering::Less
|
||||
} else {
|
||||
(*a as u64).cmp(b)
|
||||
}
|
||||
}
|
||||
(OwnedValue::U64(a), OwnedValue::F64(b)) => (*a as f64).to_u64().cmp(&b.to_u64()),
|
||||
(OwnedValue::F64(a), OwnedValue::U64(b)) => a.to_u64().cmp(&(*b as f64).to_u64()),
|
||||
(OwnedValue::I64(a), OwnedValue::F64(b)) => (*a as f64).to_u64().cmp(&b.to_u64()),
|
||||
(OwnedValue::F64(a), OwnedValue::I64(b)) => a.to_u64().cmp(&(*b as f64).to_u64()),
|
||||
(a, b) => {
|
||||
let ord = a.discriminant_value().cmp(&b.discriminant_value());
|
||||
// If the discriminant is equal, it's because a new type was added, but hasn't been
|
||||
// included in this `match` statement.
|
||||
assert!(
|
||||
ord != Ordering::Equal,
|
||||
"Unimplemented comparison for type of {a:?}, {b:?}"
|
||||
);
|
||||
ord
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Comparator trait defining the order in which documents should be ordered.
|
||||
pub trait Comparator<T>: Send + Sync + std::fmt::Debug + Default {
|
||||
/// Return the order between two values.
|
||||
fn compare(&self, lhs: &T, rhs: &T) -> Ordering;
|
||||
}
|
||||
|
||||
/// With the natural comparator, the top k collector will return
|
||||
/// the top documents in decreasing order.
|
||||
/// Compare values naturally (e.g. 1 < 2).
|
||||
///
|
||||
/// When used with `TopDocs`, which reverses the order, this results in a
|
||||
/// "Descending" sort (Greatest values first).
|
||||
///
|
||||
/// `None` (or Null for `OwnedValue`) values are considered to be smaller than any other value,
|
||||
/// and will therefore appear last in a descending sort (e.g. `[Some(20), Some(10), None]`).
|
||||
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct NaturalComparator;
|
||||
|
||||
impl<T: PartialOrd> Comparator<T> for NaturalComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
|
||||
lhs.partial_cmp(rhs).unwrap()
|
||||
lhs.partial_cmp(rhs).unwrap_or(Ordering::Equal)
|
||||
}
|
||||
}
|
||||
|
||||
/// Sorts document in reverse order.
|
||||
/// A (partial) implementation of comparison for OwnedValue.
|
||||
///
|
||||
/// If the sort key is None, it will considered as the lowest value, and will therefore appear
|
||||
/// first.
|
||||
/// Intended for use within columns of homogenous types, and so will panic for OwnedValues with
|
||||
/// mismatched types. The one exception is Null, for which we do define all comparisons.
|
||||
impl Comparator<OwnedValue> for NaturalComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
|
||||
compare_owned_value::</* NULLS_FIRST= */ true>(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
/// Compare values in reverse (e.g. 2 < 1).
|
||||
///
|
||||
/// When used with `TopDocs`, which reverses the order, this results in an
|
||||
/// "Ascending" sort (Smallest values first).
|
||||
///
|
||||
/// `None` is considered smaller than `Some` in the underlying comparator, but because the
|
||||
/// comparison is reversed, `None` is effectively treated as the lowest value in the resulting
|
||||
/// Ascending sort (e.g. `[None, Some(10), Some(20)]`).
|
||||
///
|
||||
/// The ReverseComparator does not necessarily imply that the sort order is reversed compared
|
||||
/// to the NaturalComparator. In presence of a tie, both version will retain the higher doc ids.
|
||||
/// to the NaturalComparator. In presence of a tie on the sort key, documents will always be
|
||||
/// sorted by ascending `DocId`/`DocAddress` in TopN results, regardless of the sort key's order.
|
||||
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct ReverseComparator;
|
||||
|
||||
@@ -43,11 +123,15 @@ where NaturalComparator: Comparator<T>
|
||||
}
|
||||
}
|
||||
|
||||
/// Sorts document in reverse order, but considers None as having the lowest value.
|
||||
/// Compare values in reverse, but treating `None` as lower than `Some`.
|
||||
///
|
||||
/// When used with `TopDocs`, which reverses the order, this results in an
|
||||
/// "Ascending" sort (Smallest values first), but with `None` values appearing last
|
||||
/// (e.g. `[Some(10), Some(20), None]`).
|
||||
///
|
||||
/// This is usually what is wanted when sorting by a field in an ascending order.
|
||||
/// For instance, in a e-commerce website, if I sort by price ascending, I most likely want the
|
||||
/// cheapest items first, and the items without a price at last.
|
||||
/// For instance, in an e-commerce website, if sorting by price ascending,
|
||||
/// the cheapest items would appear first, and items without a price would appear last.
|
||||
#[derive(Debug, Copy, Clone, Default)]
|
||||
pub struct ReverseNoneIsLowerComparator;
|
||||
|
||||
@@ -107,6 +191,84 @@ impl Comparator<String> for ReverseNoneIsLowerComparator {
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<OwnedValue> for ReverseNoneIsLowerComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
|
||||
compare_owned_value::</* NULLS_FIRST= */ false>(rhs, lhs)
|
||||
}
|
||||
}
|
||||
|
||||
/// Compare values naturally, but treating `None` as higher than `Some`.
|
||||
///
|
||||
/// When used with `TopDocs`, which reverses the order, this results in a
|
||||
/// "Descending" sort (Greatest values first), but with `None` values appearing first
|
||||
/// (e.g. `[None, Some(20), Some(10)]`).
|
||||
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct NaturalNoneIsHigherComparator;
|
||||
|
||||
impl<T> Comparator<Option<T>> for NaturalNoneIsHigherComparator
|
||||
where NaturalComparator: Comparator<T>
|
||||
{
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs_opt: &Option<T>, rhs_opt: &Option<T>) -> Ordering {
|
||||
match (lhs_opt, rhs_opt) {
|
||||
(None, None) => Ordering::Equal,
|
||||
(None, Some(_)) => Ordering::Greater,
|
||||
(Some(_), None) => Ordering::Less,
|
||||
(Some(lhs), Some(rhs)) => NaturalComparator.compare(lhs, rhs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<u32> for NaturalNoneIsHigherComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &u32, rhs: &u32) -> Ordering {
|
||||
NaturalComparator.compare(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<u64> for NaturalNoneIsHigherComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &u64, rhs: &u64) -> Ordering {
|
||||
NaturalComparator.compare(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<f64> for NaturalNoneIsHigherComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &f64, rhs: &f64) -> Ordering {
|
||||
NaturalComparator.compare(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<f32> for NaturalNoneIsHigherComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &f32, rhs: &f32) -> Ordering {
|
||||
NaturalComparator.compare(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<i64> for NaturalNoneIsHigherComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &i64, rhs: &i64) -> Ordering {
|
||||
NaturalComparator.compare(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<String> for NaturalNoneIsHigherComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &String, rhs: &String) -> Ordering {
|
||||
NaturalComparator.compare(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<OwnedValue> for NaturalNoneIsHigherComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
|
||||
compare_owned_value::</* NULLS_FIRST= */ false>(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
/// An enum representing the different sort orders.
|
||||
#[derive(Debug, Clone, Copy, Eq, PartialEq, Default)]
|
||||
pub enum ComparatorEnum {
|
||||
@@ -115,8 +277,10 @@ pub enum ComparatorEnum {
|
||||
Natural,
|
||||
/// Reverse order (See [ReverseComparator])
|
||||
Reverse,
|
||||
/// Reverse order by treating None as the lowest value.(See [ReverseNoneLowerComparator])
|
||||
/// Reverse order by treating None as the lowest value. (See [ReverseNoneLowerComparator])
|
||||
ReverseNoneLower,
|
||||
/// Natural order but treating None as the highest value. (See [NaturalNoneIsHigherComparator])
|
||||
NaturalNoneHigher,
|
||||
}
|
||||
|
||||
impl From<Order> for ComparatorEnum {
|
||||
@@ -133,6 +297,7 @@ where
|
||||
ReverseNoneIsLowerComparator: Comparator<T>,
|
||||
NaturalComparator: Comparator<T>,
|
||||
ReverseComparator: Comparator<T>,
|
||||
NaturalNoneIsHigherComparator: Comparator<T>,
|
||||
{
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
|
||||
@@ -140,6 +305,7 @@ where
|
||||
ComparatorEnum::Natural => NaturalComparator.compare(lhs, rhs),
|
||||
ComparatorEnum::Reverse => ReverseComparator.compare(lhs, rhs),
|
||||
ComparatorEnum::ReverseNoneLower => ReverseNoneIsLowerComparator.compare(lhs, rhs),
|
||||
ComparatorEnum::NaturalNoneHigher => NaturalNoneIsHigherComparator.compare(lhs, rhs),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -322,11 +488,12 @@ impl<TSegmentSortKeyComputer, TSegmentSortKey, TComparator> SegmentSortKeyComput
|
||||
for SegmentSortKeyComputerWithComparator<TSegmentSortKeyComputer, TComparator>
|
||||
where
|
||||
TSegmentSortKeyComputer: SegmentSortKeyComputer<SegmentSortKey = TSegmentSortKey>,
|
||||
TSegmentSortKey: PartialOrd + Clone + 'static + Sync + Send,
|
||||
TSegmentSortKey: Clone + 'static + Sync + Send,
|
||||
TComparator: Comparator<TSegmentSortKey> + 'static + Sync + Send,
|
||||
{
|
||||
type SortKey = TSegmentSortKeyComputer::SortKey;
|
||||
type SegmentSortKey = TSegmentSortKey;
|
||||
type SegmentComparator = TComparator;
|
||||
|
||||
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
|
||||
self.segment_sort_key_computer.segment_sort_key(doc, score)
|
||||
@@ -346,3 +513,55 @@ where
|
||||
.convert_segment_sort_key(sort_key)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::schema::OwnedValue;
|
||||
|
||||
#[test]
|
||||
fn test_natural_none_is_higher() {
|
||||
let comp = NaturalNoneIsHigherComparator;
|
||||
let null = None;
|
||||
let v1 = Some(1_u64);
|
||||
let v2 = Some(2_u64);
|
||||
|
||||
// NaturalNoneIsGreaterComparator logic:
|
||||
// 1. Delegates to NaturalComparator for non-nulls.
|
||||
// NaturalComparator compare(2, 1) -> 2.cmp(1) -> Greater.
|
||||
assert_eq!(comp.compare(&v2, &v1), Ordering::Greater);
|
||||
|
||||
// 2. Treats None (Null) as Greater than any value.
|
||||
// compare(None, Some(2)) should be Greater.
|
||||
assert_eq!(comp.compare(&null, &v2), Ordering::Greater);
|
||||
|
||||
// compare(Some(1), None) should be Less.
|
||||
assert_eq!(comp.compare(&v1, &null), Ordering::Less);
|
||||
|
||||
// compare(None, None) should be Equal.
|
||||
assert_eq!(comp.compare(&null, &null), Ordering::Equal);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mixed_ownedvalue_compare() {
|
||||
let u = OwnedValue::U64(10);
|
||||
let i = OwnedValue::I64(10);
|
||||
let f = OwnedValue::F64(10.0);
|
||||
|
||||
let nc = NaturalComparator;
|
||||
assert_eq!(nc.compare(&u, &i), Ordering::Equal);
|
||||
assert_eq!(nc.compare(&u, &f), Ordering::Equal);
|
||||
assert_eq!(nc.compare(&i, &f), Ordering::Equal);
|
||||
|
||||
let u2 = OwnedValue::U64(11);
|
||||
assert_eq!(nc.compare(&u2, &f), Ordering::Greater);
|
||||
|
||||
let s = OwnedValue::Str("a".to_string());
|
||||
// Str < U64
|
||||
assert_eq!(nc.compare(&s, &u), Ordering::Less);
|
||||
// Str < I64
|
||||
assert_eq!(nc.compare(&s, &i), Ordering::Less);
|
||||
// Str < F64
|
||||
assert_eq!(nc.compare(&s, &f), Ordering::Less);
|
||||
}
|
||||
}
|
||||
|
||||
361
src/collector/sort_key/sort_by_erased_type.rs
Normal file
361
src/collector/sort_key/sort_by_erased_type.rs
Normal file
@@ -0,0 +1,361 @@
|
||||
use columnar::{ColumnType, MonotonicallyMappableToU64};
|
||||
|
||||
use crate::collector::sort_key::{
|
||||
NaturalComparator, SortBySimilarityScore, SortByStaticFastValue, SortByString,
|
||||
};
|
||||
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
|
||||
use crate::fastfield::FastFieldNotAvailableError;
|
||||
use crate::schema::OwnedValue;
|
||||
use crate::{DateTime, DocId, Score};
|
||||
|
||||
/// Sort by the boxed / OwnedValue representation of either a fast field, or of the score.
|
||||
///
|
||||
/// Using the OwnedValue representation allows for type erasure, and can be useful when sort orders
|
||||
/// are not known until runtime. But it comes with a performance cost: wherever possible, prefer to
|
||||
/// use a SortKeyComputer implementation with a known-type at compile time.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum SortByErasedType {
|
||||
/// Sort by a fast field
|
||||
Field(String),
|
||||
/// Sort by score
|
||||
Score,
|
||||
}
|
||||
|
||||
impl SortByErasedType {
|
||||
/// Creates a new sort key computer which will sort by the given fast field column, with type
|
||||
/// erasure.
|
||||
pub fn for_field(column_name: impl ToString) -> Self {
|
||||
Self::Field(column_name.to_string())
|
||||
}
|
||||
|
||||
/// Creates a new sort key computer which will sort by score, with type erasure.
|
||||
pub fn for_score() -> Self {
|
||||
Self::Score
|
||||
}
|
||||
}
|
||||
|
||||
trait ErasedSegmentSortKeyComputer: Send + Sync {
|
||||
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64>;
|
||||
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue;
|
||||
}
|
||||
|
||||
struct ErasedSegmentSortKeyComputerWrapper<C, F> {
|
||||
inner: C,
|
||||
converter: F,
|
||||
}
|
||||
|
||||
impl<C, F> ErasedSegmentSortKeyComputer for ErasedSegmentSortKeyComputerWrapper<C, F>
|
||||
where
|
||||
C: SegmentSortKeyComputer<SegmentSortKey = Option<u64>> + Send + Sync,
|
||||
F: Fn(C::SortKey) -> OwnedValue + Send + Sync + 'static,
|
||||
{
|
||||
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64> {
|
||||
self.inner.segment_sort_key(doc, score)
|
||||
}
|
||||
|
||||
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue {
|
||||
let val = self.inner.convert_segment_sort_key(sort_key);
|
||||
(self.converter)(val)
|
||||
}
|
||||
}
|
||||
|
||||
struct ScoreSegmentSortKeyComputer {
|
||||
segment_computer: SortBySimilarityScore,
|
||||
}
|
||||
|
||||
impl ErasedSegmentSortKeyComputer for ScoreSegmentSortKeyComputer {
|
||||
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64> {
|
||||
let score_value: f64 = self.segment_computer.segment_sort_key(doc, score).into();
|
||||
Some(score_value.to_u64())
|
||||
}
|
||||
|
||||
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue {
|
||||
let score_value: u64 = sort_key.expect("This implementation always produces a score.");
|
||||
OwnedValue::F64(f64::from_u64(score_value))
|
||||
}
|
||||
}
|
||||
|
||||
impl SortKeyComputer for SortByErasedType {
|
||||
type SortKey = OwnedValue;
|
||||
type Child = ErasedColumnSegmentSortKeyComputer;
|
||||
type Comparator = NaturalComparator;
|
||||
|
||||
fn requires_scoring(&self) -> bool {
|
||||
matches!(self, Self::Score)
|
||||
}
|
||||
|
||||
fn segment_sort_key_computer(
|
||||
&self,
|
||||
segment_reader: &crate::SegmentReader,
|
||||
) -> crate::Result<Self::Child> {
|
||||
let inner: Box<dyn ErasedSegmentSortKeyComputer> = match self {
|
||||
Self::Field(column_name) => {
|
||||
let fast_fields = segment_reader.fast_fields();
|
||||
// TODO: We currently double-open the column to avoid relying on the implementation
|
||||
// details of `SortByString` or `SortByStaticFastValue`. Once
|
||||
// https://github.com/quickwit-oss/tantivy/issues/2776 is resolved, we should
|
||||
// consider directly constructing the appropriate `SegmentSortKeyComputer` type for
|
||||
// the column that we open here.
|
||||
let (_column, column_type) =
|
||||
fast_fields.u64_lenient(column_name)?.ok_or_else(|| {
|
||||
FastFieldNotAvailableError {
|
||||
field_name: column_name.to_owned(),
|
||||
}
|
||||
})?;
|
||||
|
||||
match column_type {
|
||||
ColumnType::Str => {
|
||||
let computer = SortByString::for_field(column_name);
|
||||
let inner = computer.segment_sort_key_computer(segment_reader)?;
|
||||
Box::new(ErasedSegmentSortKeyComputerWrapper {
|
||||
inner,
|
||||
converter: |val: Option<String>| {
|
||||
val.map(OwnedValue::Str).unwrap_or(OwnedValue::Null)
|
||||
},
|
||||
})
|
||||
}
|
||||
ColumnType::U64 => {
|
||||
let computer = SortByStaticFastValue::<u64>::for_field(column_name);
|
||||
let inner = computer.segment_sort_key_computer(segment_reader)?;
|
||||
Box::new(ErasedSegmentSortKeyComputerWrapper {
|
||||
inner,
|
||||
converter: |val: Option<u64>| {
|
||||
val.map(OwnedValue::U64).unwrap_or(OwnedValue::Null)
|
||||
},
|
||||
})
|
||||
}
|
||||
ColumnType::I64 => {
|
||||
let computer = SortByStaticFastValue::<i64>::for_field(column_name);
|
||||
let inner = computer.segment_sort_key_computer(segment_reader)?;
|
||||
Box::new(ErasedSegmentSortKeyComputerWrapper {
|
||||
inner,
|
||||
converter: |val: Option<i64>| {
|
||||
val.map(OwnedValue::I64).unwrap_or(OwnedValue::Null)
|
||||
},
|
||||
})
|
||||
}
|
||||
ColumnType::F64 => {
|
||||
let computer = SortByStaticFastValue::<f64>::for_field(column_name);
|
||||
let inner = computer.segment_sort_key_computer(segment_reader)?;
|
||||
Box::new(ErasedSegmentSortKeyComputerWrapper {
|
||||
inner,
|
||||
converter: |val: Option<f64>| {
|
||||
val.map(OwnedValue::F64).unwrap_or(OwnedValue::Null)
|
||||
},
|
||||
})
|
||||
}
|
||||
ColumnType::Bool => {
|
||||
let computer = SortByStaticFastValue::<bool>::for_field(column_name);
|
||||
let inner = computer.segment_sort_key_computer(segment_reader)?;
|
||||
Box::new(ErasedSegmentSortKeyComputerWrapper {
|
||||
inner,
|
||||
converter: |val: Option<bool>| {
|
||||
val.map(OwnedValue::Bool).unwrap_or(OwnedValue::Null)
|
||||
},
|
||||
})
|
||||
}
|
||||
ColumnType::DateTime => {
|
||||
let computer = SortByStaticFastValue::<DateTime>::for_field(column_name);
|
||||
let inner = computer.segment_sort_key_computer(segment_reader)?;
|
||||
Box::new(ErasedSegmentSortKeyComputerWrapper {
|
||||
inner,
|
||||
converter: |val: Option<DateTime>| {
|
||||
val.map(OwnedValue::Date).unwrap_or(OwnedValue::Null)
|
||||
},
|
||||
})
|
||||
}
|
||||
column_type => {
|
||||
return Err(crate::TantivyError::SchemaError(format!(
|
||||
"Field `{}` is of type {column_type:?}, which is not supported for \
|
||||
sorting by owned value yet.",
|
||||
column_name
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
Self::Score => Box::new(ScoreSegmentSortKeyComputer {
|
||||
segment_computer: SortBySimilarityScore,
|
||||
}),
|
||||
};
|
||||
Ok(ErasedColumnSegmentSortKeyComputer { inner })
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ErasedColumnSegmentSortKeyComputer {
|
||||
inner: Box<dyn ErasedSegmentSortKeyComputer>,
|
||||
}
|
||||
|
||||
impl SegmentSortKeyComputer for ErasedColumnSegmentSortKeyComputer {
|
||||
type SortKey = OwnedValue;
|
||||
type SegmentSortKey = Option<u64>;
|
||||
type SegmentComparator = NaturalComparator;
|
||||
|
||||
#[inline(always)]
|
||||
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64> {
|
||||
self.inner.segment_sort_key(doc, score)
|
||||
}
|
||||
|
||||
fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> OwnedValue {
|
||||
self.inner.convert_segment_sort_key(segment_sort_key)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::collector::sort_key::{ComparatorEnum, SortByErasedType};
|
||||
use crate::collector::TopDocs;
|
||||
use crate::query::AllQuery;
|
||||
use crate::schema::{OwnedValue, Schema, FAST, TEXT};
|
||||
use crate::Index;
|
||||
|
||||
#[test]
|
||||
fn test_sort_by_owned_u64() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let id_field = schema_builder.add_u64_field("id", FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut writer = index.writer_for_tests().unwrap();
|
||||
writer.add_document(doc!(id_field => 10u64)).unwrap();
|
||||
writer.add_document(doc!(id_field => 2u64)).unwrap();
|
||||
writer.add_document(doc!()).unwrap();
|
||||
writer.commit().unwrap();
|
||||
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
|
||||
let collector = TopDocs::with_limit(10)
|
||||
.order_by((SortByErasedType::for_field("id"), ComparatorEnum::Natural));
|
||||
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
|
||||
|
||||
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
|
||||
|
||||
assert_eq!(
|
||||
values,
|
||||
vec![OwnedValue::U64(10), OwnedValue::U64(2), OwnedValue::Null]
|
||||
);
|
||||
|
||||
let collector = TopDocs::with_limit(10).order_by((
|
||||
SortByErasedType::for_field("id"),
|
||||
ComparatorEnum::ReverseNoneLower,
|
||||
));
|
||||
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
|
||||
|
||||
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
|
||||
|
||||
assert_eq!(
|
||||
values,
|
||||
vec![OwnedValue::U64(2), OwnedValue::U64(10), OwnedValue::Null]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sort_by_owned_string() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let city_field = schema_builder.add_text_field("city", FAST | TEXT);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut writer = index.writer_for_tests().unwrap();
|
||||
writer.add_document(doc!(city_field => "tokyo")).unwrap();
|
||||
writer.add_document(doc!(city_field => "austin")).unwrap();
|
||||
writer.add_document(doc!()).unwrap();
|
||||
writer.commit().unwrap();
|
||||
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
|
||||
let collector = TopDocs::with_limit(10).order_by((
|
||||
SortByErasedType::for_field("city"),
|
||||
ComparatorEnum::ReverseNoneLower,
|
||||
));
|
||||
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
|
||||
|
||||
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
|
||||
|
||||
assert_eq!(
|
||||
values,
|
||||
vec![
|
||||
OwnedValue::Str("austin".to_string()),
|
||||
OwnedValue::Str("tokyo".to_string()),
|
||||
OwnedValue::Null
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sort_by_owned_reverse() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let id_field = schema_builder.add_u64_field("id", FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut writer = index.writer_for_tests().unwrap();
|
||||
writer.add_document(doc!(id_field => 10u64)).unwrap();
|
||||
writer.add_document(doc!(id_field => 2u64)).unwrap();
|
||||
writer.add_document(doc!()).unwrap();
|
||||
writer.commit().unwrap();
|
||||
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
|
||||
let collector = TopDocs::with_limit(10)
|
||||
.order_by((SortByErasedType::for_field("id"), ComparatorEnum::Reverse));
|
||||
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
|
||||
|
||||
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
|
||||
|
||||
assert_eq!(
|
||||
values,
|
||||
vec![OwnedValue::Null, OwnedValue::U64(2), OwnedValue::U64(10)]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sort_by_owned_score() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let body_field = schema_builder.add_text_field("body", TEXT);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut writer = index.writer_for_tests().unwrap();
|
||||
writer.add_document(doc!(body_field => "a a")).unwrap();
|
||||
writer.add_document(doc!(body_field => "a")).unwrap();
|
||||
writer.commit().unwrap();
|
||||
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
let query_parser = crate::query::QueryParser::for_index(&index, vec![body_field]);
|
||||
let query = query_parser.parse_query("a").unwrap();
|
||||
|
||||
// Sort by score descending (Natural)
|
||||
let collector = TopDocs::with_limit(10)
|
||||
.order_by((SortByErasedType::for_score(), ComparatorEnum::Natural));
|
||||
let top_docs = searcher.search(&query, &collector).unwrap();
|
||||
|
||||
let values: Vec<f64> = top_docs
|
||||
.into_iter()
|
||||
.map(|(key, _)| match key {
|
||||
OwnedValue::F64(val) => val,
|
||||
_ => panic!("Wrong type {key:?}"),
|
||||
})
|
||||
.collect();
|
||||
|
||||
assert_eq!(values.len(), 2);
|
||||
assert!(values[0] > values[1]);
|
||||
|
||||
// Sort by score ascending (ReverseNoneLower)
|
||||
let collector = TopDocs::with_limit(10).order_by((
|
||||
SortByErasedType::for_score(),
|
||||
ComparatorEnum::ReverseNoneLower,
|
||||
));
|
||||
let top_docs = searcher.search(&query, &collector).unwrap();
|
||||
|
||||
let values: Vec<f64> = top_docs
|
||||
.into_iter()
|
||||
.map(|(key, _)| match key {
|
||||
OwnedValue::F64(val) => val,
|
||||
_ => panic!("Wrong type {key:?}"),
|
||||
})
|
||||
.collect();
|
||||
|
||||
assert_eq!(values.len(), 2);
|
||||
assert!(values[0] < values[1]);
|
||||
}
|
||||
}
|
||||
@@ -63,8 +63,8 @@ impl SortKeyComputer for SortBySimilarityScore {
|
||||
|
||||
impl SegmentSortKeyComputer for SortBySimilarityScore {
|
||||
type SortKey = Score;
|
||||
|
||||
type SegmentSortKey = Score;
|
||||
type SegmentComparator = NaturalComparator;
|
||||
|
||||
#[inline(always)]
|
||||
fn segment_sort_key(&mut self, _doc: DocId, score: Score) -> Score {
|
||||
|
||||
@@ -34,9 +34,7 @@ impl<T: FastValue> SortByStaticFastValue<T> {
|
||||
|
||||
impl<T: FastValue> SortKeyComputer for SortByStaticFastValue<T> {
|
||||
type Child = SortByFastValueSegmentSortKeyComputer<T>;
|
||||
|
||||
type SortKey = Option<T>;
|
||||
|
||||
type Comparator = NaturalComparator;
|
||||
|
||||
fn check_schema(&self, schema: &crate::schema::Schema) -> crate::Result<()> {
|
||||
@@ -84,8 +82,8 @@ pub struct SortByFastValueSegmentSortKeyComputer<T> {
|
||||
|
||||
impl<T: FastValue> SegmentSortKeyComputer for SortByFastValueSegmentSortKeyComputer<T> {
|
||||
type SortKey = Option<T>;
|
||||
|
||||
type SegmentSortKey = Option<u64>;
|
||||
type SegmentComparator = NaturalComparator;
|
||||
|
||||
#[inline(always)]
|
||||
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Self::SegmentSortKey {
|
||||
|
||||
@@ -30,9 +30,7 @@ impl SortByString {
|
||||
|
||||
impl SortKeyComputer for SortByString {
|
||||
type SortKey = Option<String>;
|
||||
|
||||
type Child = ByStringColumnSegmentSortKeyComputer;
|
||||
|
||||
type Comparator = NaturalComparator;
|
||||
|
||||
fn segment_sort_key_computer(
|
||||
@@ -50,8 +48,8 @@ pub struct ByStringColumnSegmentSortKeyComputer {
|
||||
|
||||
impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
|
||||
type SortKey = Option<String>;
|
||||
|
||||
type SegmentSortKey = Option<TermOrdinal>;
|
||||
type SegmentComparator = NaturalComparator;
|
||||
|
||||
#[inline(always)]
|
||||
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Option<TermOrdinal> {
|
||||
@@ -60,6 +58,8 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
|
||||
}
|
||||
|
||||
fn convert_segment_sort_key(&self, term_ord_opt: Option<TermOrdinal>) -> Option<String> {
|
||||
// TODO: Individual lookups to the dictionary like this are very likely to repeatedly
|
||||
// decompress the same blocks. See https://github.com/quickwit-oss/tantivy/issues/2776
|
||||
let term_ord = term_ord_opt?;
|
||||
let str_column = self.str_column_opt.as_ref()?;
|
||||
let mut bytes = Vec::new();
|
||||
|
||||
@@ -12,13 +12,21 @@ use crate::{DocAddress, DocId, Result, Score, SegmentReader};
|
||||
/// It is the segment local version of the [`SortKeyComputer`].
|
||||
pub trait SegmentSortKeyComputer: 'static {
|
||||
/// The final score being emitted.
|
||||
type SortKey: 'static + PartialOrd + Send + Sync + Clone;
|
||||
type SortKey: 'static + Send + Sync + Clone;
|
||||
|
||||
/// Sort key used by at the segment level by the `SegmentSortKeyComputer`.
|
||||
///
|
||||
/// It is typically small like a `u64`, and is meant to be converted
|
||||
/// to the final score at the end of the collection of the segment.
|
||||
type SegmentSortKey: 'static + PartialOrd + Clone + Send + Sync + Clone;
|
||||
type SegmentSortKey: 'static + Clone + Send + Sync + Clone;
|
||||
|
||||
/// Comparator type.
|
||||
type SegmentComparator: Comparator<Self::SegmentSortKey> + 'static;
|
||||
|
||||
/// Returns the segment sort key comparator.
|
||||
fn segment_comparator(&self) -> Self::SegmentComparator {
|
||||
Self::SegmentComparator::default()
|
||||
}
|
||||
|
||||
/// Computes the sort key for the given document and score.
|
||||
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey;
|
||||
@@ -47,7 +55,7 @@ pub trait SegmentSortKeyComputer: 'static {
|
||||
left: &Self::SegmentSortKey,
|
||||
right: &Self::SegmentSortKey,
|
||||
) -> Ordering {
|
||||
NaturalComparator.compare(left, right)
|
||||
self.segment_comparator().compare(left, right)
|
||||
}
|
||||
|
||||
/// Implementing this method makes it possible to avoid computing
|
||||
@@ -81,7 +89,7 @@ pub trait SegmentSortKeyComputer: 'static {
|
||||
/// the sort key at a segment scale.
|
||||
pub trait SortKeyComputer: Sync {
|
||||
/// The sort key type.
|
||||
type SortKey: 'static + Send + Sync + PartialOrd + Clone + std::fmt::Debug;
|
||||
type SortKey: 'static + Send + Sync + Clone + std::fmt::Debug;
|
||||
/// Type of the associated [`SegmentSortKeyComputer`].
|
||||
type Child: SegmentSortKeyComputer<SortKey = Self::SortKey>;
|
||||
/// Comparator type.
|
||||
@@ -136,10 +144,7 @@ where
|
||||
HeadSortKeyComputer: SortKeyComputer,
|
||||
TailSortKeyComputer: SortKeyComputer,
|
||||
{
|
||||
type SortKey = (
|
||||
<HeadSortKeyComputer::Child as SegmentSortKeyComputer>::SortKey,
|
||||
<TailSortKeyComputer::Child as SegmentSortKeyComputer>::SortKey,
|
||||
);
|
||||
type SortKey = (HeadSortKeyComputer::SortKey, TailSortKeyComputer::SortKey);
|
||||
type Child = (HeadSortKeyComputer::Child, TailSortKeyComputer::Child);
|
||||
|
||||
type Comparator = (
|
||||
@@ -188,6 +193,11 @@ where
|
||||
TailSegmentSortKeyComputer::SegmentSortKey,
|
||||
);
|
||||
|
||||
type SegmentComparator = (
|
||||
HeadSegmentSortKeyComputer::SegmentComparator,
|
||||
TailSegmentSortKeyComputer::SegmentComparator,
|
||||
);
|
||||
|
||||
/// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on
|
||||
/// its ordering.
|
||||
///
|
||||
@@ -269,11 +279,12 @@ impl<T, PreviousScore, NewScore> SegmentSortKeyComputer
|
||||
for MappedSegmentSortKeyComputer<T, PreviousScore, NewScore>
|
||||
where
|
||||
T: SegmentSortKeyComputer<SortKey = PreviousScore>,
|
||||
PreviousScore: 'static + Clone + Send + Sync + PartialOrd,
|
||||
NewScore: 'static + Clone + Send + Sync + PartialOrd,
|
||||
PreviousScore: 'static + Clone + Send + Sync,
|
||||
NewScore: 'static + Clone + Send + Sync,
|
||||
{
|
||||
type SortKey = NewScore;
|
||||
type SegmentSortKey = T::SegmentSortKey;
|
||||
type SegmentComparator = T::SegmentComparator;
|
||||
|
||||
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
|
||||
self.sort_key_computer.segment_sort_key(doc, score)
|
||||
@@ -463,6 +474,7 @@ where
|
||||
{
|
||||
type SortKey = TSortKey;
|
||||
type SegmentSortKey = TSortKey;
|
||||
type SegmentComparator = NaturalComparator;
|
||||
|
||||
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> TSortKey {
|
||||
(self)(doc)
|
||||
|
||||
@@ -1,64 +1,22 @@
|
||||
use std::cmp::Ordering;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Contains a feature (field, score, etc.) of a document along with the document address.
|
||||
///
|
||||
/// It guarantees stable sorting: in case of a tie on the feature, the document
|
||||
/// address is used.
|
||||
///
|
||||
/// The REVERSE_ORDER generic parameter controls whether the by-feature order
|
||||
/// should be reversed, which is useful for achieving for example largest-first
|
||||
/// semantics without having to wrap the feature in a `Reverse`.
|
||||
#[derive(Clone, Default, Serialize, Deserialize)]
|
||||
pub struct ComparableDoc<T, D, const REVERSE_ORDER: bool = false> {
|
||||
/// Used only by TopNComputer, which implements the actual comparison via a `Comparator`.
|
||||
#[derive(Clone, Default, Eq, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ComparableDoc<T, D> {
|
||||
/// The feature of the document. In practice, this is
|
||||
/// is any type that implements `PartialOrd`.
|
||||
/// is a type which can be compared with a `Comparator<T>`.
|
||||
pub sort_key: T,
|
||||
/// The document address. In practice, this is any
|
||||
/// type that implements `PartialOrd`, and is guaranteed
|
||||
/// to be unique for each document.
|
||||
/// The document address. In practice, this is either a `DocId` or `DocAddress`.
|
||||
pub doc: D,
|
||||
}
|
||||
impl<T: std::fmt::Debug, D: std::fmt::Debug, const R: bool> std::fmt::Debug
|
||||
for ComparableDoc<T, D, R>
|
||||
{
|
||||
|
||||
impl<T: std::fmt::Debug, D: std::fmt::Debug> std::fmt::Debug for ComparableDoc<T, D> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
f.debug_struct(format!("ComparableDoc<_, _ {R}").as_str())
|
||||
f.debug_struct("ComparableDoc")
|
||||
.field("feature", &self.sort_key)
|
||||
.field("doc", &self.doc)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialOrd for ComparableDoc<T, D, R> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: PartialOrd, D: PartialOrd, const R: bool> Ord for ComparableDoc<T, D, R> {
|
||||
#[inline]
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
let by_feature = self
|
||||
.sort_key
|
||||
.partial_cmp(&other.sort_key)
|
||||
.map(|ord| if R { ord.reverse() } else { ord })
|
||||
.unwrap_or(Ordering::Equal);
|
||||
|
||||
let lazy_by_doc_address = || self.doc.partial_cmp(&other.doc).unwrap_or(Ordering::Equal);
|
||||
|
||||
// In case of a tie on the feature, we sort by ascending
|
||||
// `DocAddress` in order to ensure a stable sorting of the
|
||||
// documents.
|
||||
by_feature.then_with(lazy_by_doc_address)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialEq for ComparableDoc<T, D, R> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.cmp(other) == Ordering::Equal
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: PartialOrd, D: PartialOrd, const R: bool> Eq for ComparableDoc<T, D, R> {}
|
||||
|
||||
@@ -23,10 +23,9 @@ use crate::{DocAddress, DocId, Order, Score, SegmentReader};
|
||||
/// The theoretical complexity for collecting the top `K` out of `N` documents
|
||||
/// is `O(N + K)`.
|
||||
///
|
||||
/// This collector does not guarantee a stable sorting in case of a tie on the
|
||||
/// document score, for stable sorting `PartialOrd` needs to resolve on other fields
|
||||
/// like docid in case of score equality.
|
||||
/// Only then, it is suitable for pagination.
|
||||
/// This collector guarantees a stable sorting in case of a tie on the
|
||||
/// document score/sort key: The document address (`DocAddress`) is used as a tie breaker.
|
||||
/// In case of a tie on the sort key, documents are always sorted by ascending `DocAddress`.
|
||||
///
|
||||
/// ```rust
|
||||
/// use tantivy::collector::TopDocs;
|
||||
@@ -325,7 +324,7 @@ impl TopDocs {
|
||||
sort_key_computer: impl SortKeyComputer<SortKey = TSortKey> + Send + 'static,
|
||||
) -> impl Collector<Fruit = Vec<(TSortKey, DocAddress)>>
|
||||
where
|
||||
TSortKey: 'static + Clone + Send + Sync + PartialOrd + std::fmt::Debug,
|
||||
TSortKey: 'static + Clone + Send + Sync + std::fmt::Debug,
|
||||
{
|
||||
TopBySortKeyCollector::new(sort_key_computer, self.doc_range())
|
||||
}
|
||||
@@ -446,7 +445,7 @@ where
|
||||
F: 'static + Send + Sync + Fn(&SegmentReader) -> TTweakScoreSortKeyFn,
|
||||
TTweakScoreSortKeyFn: 'static + Fn(DocId, Score) -> TSortKey,
|
||||
TweakScoreSegmentSortKeyComputer<TTweakScoreSortKeyFn>:
|
||||
SegmentSortKeyComputer<SortKey = TSortKey>,
|
||||
SegmentSortKeyComputer<SortKey = TSortKey, SegmentSortKey = TSortKey>,
|
||||
TSortKey: 'static + PartialOrd + Clone + Send + Sync + std::fmt::Debug,
|
||||
{
|
||||
type SortKey = TSortKey;
|
||||
@@ -481,6 +480,7 @@ where
|
||||
{
|
||||
type SortKey = TSortKey;
|
||||
type SegmentSortKey = TSortKey;
|
||||
type SegmentComparator = NaturalComparator;
|
||||
|
||||
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> TSortKey {
|
||||
(self.sort_key_fn)(doc, score)
|
||||
@@ -500,8 +500,13 @@ where
|
||||
///
|
||||
/// For TopN == 0, it will be relative expensive.
|
||||
///
|
||||
/// When using the natural comparator, the top N computer returns the top N elements in
|
||||
/// descending order, as expected for a top N.
|
||||
/// The TopNComputer will tiebreak by using ascending `D` (DocId or DocAddress):
|
||||
/// i.e., in case of a tie on the sort key, the `DocId|DocAddress` are always sorted in
|
||||
/// ascending order, regardless of the `Comparator` used for the `Score` type.
|
||||
///
|
||||
/// NOTE: Items must be `push`ed to the TopNComputer in ascending `DocId|DocAddress` order, as the
|
||||
/// threshold used to eliminate docs does not include the `DocId` or `DocAddress`: this provides
|
||||
/// the ascending `DocId|DocAddress` tie-breaking behavior without additional comparisons.
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(from = "TopNComputerDeser<Score, D, C>")]
|
||||
pub struct TopNComputer<Score, D, C> {
|
||||
@@ -580,6 +585,18 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn compare_for_top_k<TSortKey, D: Ord, C: Comparator<TSortKey>>(
|
||||
c: &C,
|
||||
lhs: &ComparableDoc<TSortKey, D>,
|
||||
rhs: &ComparableDoc<TSortKey, D>,
|
||||
) -> std::cmp::Ordering {
|
||||
c.compare(&lhs.sort_key, &rhs.sort_key)
|
||||
.reverse() // Reverse here because we want top K.
|
||||
.then_with(|| lhs.doc.cmp(&rhs.doc)) // Regardless of asc/desc, in presence of a tie, we
|
||||
// sort by doc id
|
||||
}
|
||||
|
||||
impl<TSortKey, D, C> TopNComputer<TSortKey, D, C>
|
||||
where
|
||||
D: Ord,
|
||||
@@ -600,10 +617,13 @@ where
|
||||
|
||||
/// Push a new document to the top n.
|
||||
/// If the document is below the current threshold, it will be ignored.
|
||||
///
|
||||
/// NOTE: `push` must be called in ascending `DocId`/`DocAddress` order.
|
||||
#[inline]
|
||||
pub fn push(&mut self, sort_key: TSortKey, doc: D) {
|
||||
if let Some(last_median) = &self.threshold {
|
||||
if self.comparator.compare(&sort_key, last_median) == Ordering::Less {
|
||||
// See the struct docs for an explanation of why this comparison is strict.
|
||||
if self.comparator.compare(&sort_key, last_median) != Ordering::Greater {
|
||||
return;
|
||||
}
|
||||
}
|
||||
@@ -629,9 +649,7 @@ where
|
||||
fn truncate_top_n(&mut self) -> TSortKey {
|
||||
// Use select_nth_unstable to find the top nth score
|
||||
let (_, median_el, _) = self.buffer.select_nth_unstable_by(self.top_n, |lhs, rhs| {
|
||||
self.comparator
|
||||
.compare(&rhs.sort_key, &lhs.sort_key)
|
||||
.then_with(|| lhs.doc.cmp(&rhs.doc))
|
||||
compare_for_top_k(&self.comparator, lhs, rhs)
|
||||
});
|
||||
|
||||
let median_score = median_el.sort_key.clone();
|
||||
@@ -646,11 +664,8 @@ where
|
||||
if self.buffer.len() > self.top_n {
|
||||
self.truncate_top_n();
|
||||
}
|
||||
self.buffer.sort_unstable_by(|left, right| {
|
||||
self.comparator
|
||||
.compare(&right.sort_key, &left.sort_key)
|
||||
.then_with(|| left.doc.cmp(&right.doc))
|
||||
});
|
||||
self.buffer
|
||||
.sort_unstable_by(|lhs, rhs| compare_for_top_k(&self.comparator, lhs, rhs));
|
||||
self.buffer
|
||||
}
|
||||
|
||||
@@ -755,6 +770,33 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_topn_computer_duplicates() {
|
||||
let mut computer: TopNComputer<u32, u32, NaturalComparator> =
|
||||
TopNComputer::new_with_comparator(2, NaturalComparator);
|
||||
|
||||
computer.push(1u32, 1u32);
|
||||
computer.push(1u32, 2u32);
|
||||
computer.push(1u32, 3u32);
|
||||
computer.push(1u32, 4u32);
|
||||
computer.push(1u32, 5u32);
|
||||
|
||||
// In the presence of duplicates, DocIds are always ascending order.
|
||||
assert_eq!(
|
||||
computer.into_sorted_vec(),
|
||||
&[
|
||||
ComparableDoc {
|
||||
sort_key: 1u32,
|
||||
doc: 1u32,
|
||||
},
|
||||
ComparableDoc {
|
||||
sort_key: 1u32,
|
||||
doc: 2u32,
|
||||
}
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_topn_computer_no_panic() {
|
||||
for top_n in 0..10 {
|
||||
@@ -772,14 +814,17 @@ mod tests {
|
||||
#[test]
|
||||
fn test_topn_computer_asc_prop(
|
||||
limit in 0..10_usize,
|
||||
docs in proptest::collection::vec((0..100_u64, 0..100_u64), 0..100_usize),
|
||||
mut docs in proptest::collection::vec((0..100_u64, 0..100_u64), 0..100_usize),
|
||||
) {
|
||||
// NB: TopNComputer must receive inputs in ascending DocId order.
|
||||
docs.sort_by_key(|(_, doc_id)| *doc_id);
|
||||
let mut computer: TopNComputer<_, _, ReverseComparator> = TopNComputer::new_with_comparator(limit, ReverseComparator);
|
||||
for (feature, doc) in &docs {
|
||||
computer.push(*feature, *doc);
|
||||
}
|
||||
let mut comparable_docs: Vec<ComparableDoc<u64, u64>> = docs.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc }).collect::<Vec<_>>();
|
||||
comparable_docs.sort();
|
||||
let mut comparable_docs: Vec<ComparableDoc<u64, u64>> =
|
||||
docs.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc }).collect();
|
||||
crate::collector::sort_key::tests::sort_hits(&mut comparable_docs, Order::Asc);
|
||||
comparable_docs.truncate(limit);
|
||||
prop_assert_eq!(
|
||||
computer.into_sorted_vec(),
|
||||
@@ -1406,15 +1451,10 @@ mod tests {
|
||||
|
||||
// Using the TopDocs collector should always be equivalent to sorting, skipping the
|
||||
// offset, and then taking the limit.
|
||||
let sorted_docs: Vec<_> = if order.is_desc() {
|
||||
let mut comparable_docs: Vec<ComparableDoc<_, _, true>> =
|
||||
let sorted_docs: Vec<_> = {
|
||||
let mut comparable_docs: Vec<ComparableDoc<_, _>> =
|
||||
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
|
||||
comparable_docs.sort();
|
||||
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
|
||||
} else {
|
||||
let mut comparable_docs: Vec<ComparableDoc<_, _, false>> =
|
||||
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
|
||||
comparable_docs.sort();
|
||||
crate::collector::sort_key::tests::sort_hits(&mut comparable_docs, order);
|
||||
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
|
||||
};
|
||||
let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::<Vec<_>>();
|
||||
|
||||
@@ -48,15 +48,7 @@ impl Executor {
|
||||
F: Sized + Sync + Fn(A) -> crate::Result<R>,
|
||||
{
|
||||
match self {
|
||||
Executor::SingleThread => {
|
||||
// Avoid `collect`, since the stacktrace is blown up by it, which makes profiling
|
||||
// harder.
|
||||
let mut result = Vec::with_capacity(args.size_hint().0);
|
||||
for arg in args {
|
||||
result.push(f(arg)?);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
Executor::SingleThread => args.map(f).collect::<crate::Result<_>>(),
|
||||
Executor::ThreadPool(pool) => {
|
||||
let args: Vec<A> = args.collect();
|
||||
let num_fruits = args.len();
|
||||
|
||||
@@ -406,7 +406,7 @@ mod tests {
|
||||
let mut term = Term::from_field_json_path(field, "color", false);
|
||||
term.append_type_and_str("red");
|
||||
|
||||
assert_eq!(term.serialized_term(), b"\x00\x00\x00\x01jcolor\x00sred")
|
||||
assert_eq!(term.serialized_value_bytes(), b"color\x00sred".to_vec())
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -416,8 +416,8 @@ mod tests {
|
||||
term.append_type_and_fast_value(-4i64);
|
||||
|
||||
assert_eq!(
|
||||
term.serialized_term(),
|
||||
b"\x00\x00\x00\x01jcolor\x00i\x7f\xff\xff\xff\xff\xff\xff\xfc"
|
||||
term.serialized_value_bytes(),
|
||||
b"color\x00i\x7f\xff\xff\xff\xff\xff\xff\xfc".to_vec()
|
||||
)
|
||||
}
|
||||
|
||||
@@ -428,8 +428,8 @@ mod tests {
|
||||
term.append_type_and_fast_value(4u64);
|
||||
|
||||
assert_eq!(
|
||||
term.serialized_term(),
|
||||
b"\x00\x00\x00\x01jcolor\x00u\x00\x00\x00\x00\x00\x00\x00\x04"
|
||||
term.serialized_value_bytes(),
|
||||
b"color\x00u\x00\x00\x00\x00\x00\x00\x00\x04".to_vec()
|
||||
)
|
||||
}
|
||||
|
||||
@@ -439,8 +439,8 @@ mod tests {
|
||||
let mut term = Term::from_field_json_path(field, "color", false);
|
||||
term.append_type_and_fast_value(4.0f64);
|
||||
assert_eq!(
|
||||
term.serialized_term(),
|
||||
b"\x00\x00\x00\x01jcolor\x00f\xc0\x10\x00\x00\x00\x00\x00\x00"
|
||||
term.serialized_value_bytes(),
|
||||
b"color\x00f\xc0\x10\x00\x00\x00\x00\x00\x00".to_vec()
|
||||
)
|
||||
}
|
||||
|
||||
@@ -450,8 +450,8 @@ mod tests {
|
||||
let mut term = Term::from_field_json_path(field, "color", false);
|
||||
term.append_type_and_fast_value(true);
|
||||
assert_eq!(
|
||||
term.serialized_term(),
|
||||
b"\x00\x00\x00\x01jcolor\x00o\x00\x00\x00\x00\x00\x00\x00\x01"
|
||||
term.serialized_value_bytes(),
|
||||
b"color\x00o\x00\x00\x00\x00\x00\x00\x00\x01".to_vec()
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ use std::ops::Range;
|
||||
use common::{BinarySerializable, CountingWriter, HasLen, VInt};
|
||||
|
||||
use crate::directory::{FileSlice, TerminatingWrite, WritePtr};
|
||||
use crate::schema::Field;
|
||||
use crate::schema::{Field, Schema};
|
||||
use crate::space_usage::{FieldUsage, PerFieldSpaceUsage};
|
||||
|
||||
#[derive(Eq, PartialEq, Hash, Copy, Ord, PartialOrd, Clone, Debug)]
|
||||
@@ -167,10 +167,11 @@ impl CompositeFile {
|
||||
.map(|byte_range| self.data.slice(byte_range.clone()))
|
||||
}
|
||||
|
||||
pub fn space_usage(&self) -> PerFieldSpaceUsage {
|
||||
pub fn space_usage(&self, schema: &Schema) -> PerFieldSpaceUsage {
|
||||
let mut fields = Vec::new();
|
||||
for (&field_addr, byte_range) in &self.offsets_index {
|
||||
let mut field_usage = FieldUsage::empty(field_addr.field);
|
||||
let field_name = schema.get_field_name(field_addr.field).to_string();
|
||||
let mut field_usage = FieldUsage::empty(field_name);
|
||||
field_usage.add_field_idx(field_addr.idx, byte_range.len().into());
|
||||
fields.push(field_usage);
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
mod file_watcher;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::fs::{self, File, OpenOptions};
|
||||
@@ -7,6 +9,7 @@ use std::path::{Path, PathBuf};
|
||||
use std::sync::{Arc, RwLock, Weak};
|
||||
|
||||
use common::StableDeref;
|
||||
use file_watcher::FileWatcher;
|
||||
use fs4::fs_std::FileExt;
|
||||
#[cfg(all(feature = "mmap", unix))]
|
||||
pub use memmap2::Advice;
|
||||
@@ -18,7 +21,6 @@ use crate::core::META_FILEPATH;
|
||||
use crate::directory::error::{
|
||||
DeleteError, LockError, OpenDirectoryError, OpenReadError, OpenWriteError,
|
||||
};
|
||||
use crate::directory::file_watcher::FileWatcher;
|
||||
use crate::directory::{
|
||||
AntiCallToken, Directory, DirectoryLock, FileHandle, Lock, OwnedBytes, TerminatingWrite,
|
||||
WatchCallback, WatchHandle, WritePtr,
|
||||
@@ -5,7 +5,6 @@ mod mmap_directory;
|
||||
|
||||
mod directory;
|
||||
mod directory_lock;
|
||||
mod file_watcher;
|
||||
pub mod footer;
|
||||
mod managed_directory;
|
||||
mod ram_directory;
|
||||
|
||||
@@ -40,6 +40,8 @@ pub trait DocSet: Send {
|
||||
/// of `DocSet` should support it.
|
||||
///
|
||||
/// Calling `seek(TERMINATED)` is also legal and is the normal way to consume a `DocSet`.
|
||||
///
|
||||
/// `target` has to be larger or equal to `.doc()` when calling `seek`.
|
||||
fn seek(&mut self, target: DocId) -> DocId {
|
||||
let mut doc = self.doc();
|
||||
debug_assert!(doc <= target);
|
||||
@@ -49,6 +51,33 @@ pub trait DocSet: Send {
|
||||
doc
|
||||
}
|
||||
|
||||
/// Seeks to the target if possible and returns true if the target is in the DocSet.
|
||||
///
|
||||
/// DocSets that already have an efficient `seek` method don't need to implement
|
||||
/// `seek_into_the_danger_zone`. All wrapper DocSets should forward
|
||||
/// `seek_into_the_danger_zone` to the underlying DocSet.
|
||||
///
|
||||
/// ## API Behaviour
|
||||
/// If `seek_into_the_danger_zone` is returning true, a call to `doc()` has to return target.
|
||||
/// If `seek_into_the_danger_zone` is returning false, a call to `doc()` may return any doc
|
||||
/// between the last doc that matched and target or a doc that is a valid next hit after
|
||||
/// target. The DocSet is considered to be in an invalid state until
|
||||
/// `seek_into_the_danger_zone` returns true again.
|
||||
///
|
||||
/// `target` needs to be equal or larger than `doc` when in a valid state.
|
||||
///
|
||||
/// Consecutive calls are not allowed to have decreasing `target` values.
|
||||
///
|
||||
/// # Warning
|
||||
/// This is an advanced API used by intersection. The API contract is tricky, avoid using it.
|
||||
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
|
||||
let current_doc = self.doc();
|
||||
if current_doc < target {
|
||||
self.seek(target);
|
||||
}
|
||||
self.doc() == target
|
||||
}
|
||||
|
||||
/// Fills a given mutable buffer with the next doc ids from the
|
||||
/// `DocSet`
|
||||
///
|
||||
@@ -94,6 +123,15 @@ pub trait DocSet: Send {
|
||||
/// which would be the number of documents in the DocSet.
|
||||
///
|
||||
/// By default this returns `size_hint()`.
|
||||
///
|
||||
/// DocSets may have vastly different cost depending on their type,
|
||||
/// e.g. an intersection with 10 hits is much cheaper than
|
||||
/// a phrase search with 10 hits, since it needs to load positions.
|
||||
///
|
||||
/// ### Future Work
|
||||
/// We may want to differentiate `DocSet` costs more more granular, e.g.
|
||||
/// creation_cost, advance_cost, seek_cost on to get a good estimation
|
||||
/// what query types to choose.
|
||||
fn cost(&self) -> u64 {
|
||||
self.size_hint() as u64
|
||||
}
|
||||
@@ -137,6 +175,10 @@ impl DocSet for &mut dyn DocSet {
|
||||
(**self).seek(target)
|
||||
}
|
||||
|
||||
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
|
||||
(**self).seek_into_the_danger_zone(target)
|
||||
}
|
||||
|
||||
fn doc(&self) -> u32 {
|
||||
(**self).doc()
|
||||
}
|
||||
@@ -169,6 +211,11 @@ impl<TDocSet: DocSet + ?Sized> DocSet for Box<TDocSet> {
|
||||
unboxed.seek(target)
|
||||
}
|
||||
|
||||
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
|
||||
let unboxed: &mut TDocSet = self.borrow_mut();
|
||||
unboxed.seek_into_the_danger_zone(target)
|
||||
}
|
||||
|
||||
fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize {
|
||||
let unboxed: &mut TDocSet = self.borrow_mut();
|
||||
unboxed.fill_buffer(buffer)
|
||||
|
||||
@@ -8,7 +8,7 @@ use columnar::{
|
||||
};
|
||||
use common::ByteCount;
|
||||
|
||||
use crate::core::json_utils::encode_column_name;
|
||||
use crate::core::json_utils::{encode_column_name, json_path_sep_to_dot};
|
||||
use crate::directory::FileSlice;
|
||||
use crate::schema::{Field, FieldEntry, FieldType, Schema};
|
||||
use crate::space_usage::{FieldUsage, PerFieldSpaceUsage};
|
||||
@@ -39,19 +39,15 @@ impl FastFieldReaders {
|
||||
self.resolve_column_name_given_default_field(column_name, default_field_opt)
|
||||
}
|
||||
|
||||
pub(crate) fn space_usage(&self, schema: &Schema) -> io::Result<PerFieldSpaceUsage> {
|
||||
pub(crate) fn space_usage(&self) -> io::Result<PerFieldSpaceUsage> {
|
||||
let mut per_field_usages: Vec<FieldUsage> = Default::default();
|
||||
for (field, field_entry) in schema.fields() {
|
||||
let column_handles = self.columnar.read_columns(field_entry.name())?;
|
||||
let num_bytes: ByteCount = column_handles
|
||||
.iter()
|
||||
.map(|column_handle| column_handle.num_bytes())
|
||||
.sum();
|
||||
let mut field_usage = FieldUsage::empty(field);
|
||||
field_usage.add_field_idx(0, num_bytes);
|
||||
for (mut field_name, column_handle) in self.columnar.iter_columns()? {
|
||||
json_path_sep_to_dot(&mut field_name);
|
||||
let space_usage = column_handle.space_usage()?;
|
||||
let mut field_usage = FieldUsage::empty(field_name);
|
||||
field_usage.set_column_usage(space_usage);
|
||||
per_field_usages.push(field_usage);
|
||||
}
|
||||
// TODO fix space usage for JSON fields.
|
||||
Ok(PerFieldSpaceUsage::new(per_field_usages))
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::sync::Arc;
|
||||
|
||||
use super::{fieldnorm_to_id, id_to_fieldnorm};
|
||||
use crate::directory::{CompositeFile, FileSlice, OwnedBytes};
|
||||
use crate::schema::Field;
|
||||
use crate::schema::{Field, Schema};
|
||||
use crate::space_usage::PerFieldSpaceUsage;
|
||||
use crate::DocId;
|
||||
|
||||
@@ -37,8 +37,8 @@ impl FieldNormReaders {
|
||||
}
|
||||
|
||||
/// Return a break down of the space usage per field.
|
||||
pub fn space_usage(&self) -> PerFieldSpaceUsage {
|
||||
self.data.space_usage()
|
||||
pub fn space_usage(&self, schema: &Schema) -> PerFieldSpaceUsage {
|
||||
self.data.space_usage(schema)
|
||||
}
|
||||
|
||||
/// Returns a handle to inner file
|
||||
|
||||
@@ -13,9 +13,9 @@ use crate::store::Compressor;
|
||||
use crate::{Inventory, Opstamp, TrackedObject};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
struct DeleteMeta {
|
||||
pub struct DeleteMeta {
|
||||
num_deleted_docs: u32,
|
||||
opstamp: Opstamp,
|
||||
pub opstamp: Opstamp,
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
@@ -213,7 +213,7 @@ impl SegmentMeta {
|
||||
struct InnerSegmentMeta {
|
||||
segment_id: SegmentId,
|
||||
max_doc: u32,
|
||||
deletes: Option<DeleteMeta>,
|
||||
pub deletes: Option<DeleteMeta>,
|
||||
/// If you want to avoid the SegmentComponent::TempStore file to be covered by
|
||||
/// garbage collection and deleted, set this to true. This is used during merge.
|
||||
#[serde(skip)]
|
||||
@@ -404,7 +404,10 @@ mod tests {
|
||||
schema_builder.build()
|
||||
};
|
||||
let index_metas = IndexMeta {
|
||||
index_settings: IndexSettings::default(),
|
||||
index_settings: IndexSettings {
|
||||
docstore_compression: Compressor::None,
|
||||
..Default::default()
|
||||
},
|
||||
segments: Vec::new(),
|
||||
schema,
|
||||
opstamp: 0u64,
|
||||
@@ -413,7 +416,7 @@ mod tests {
|
||||
let json = serde_json::ser::to_string(&index_metas).expect("serialization failed");
|
||||
assert_eq!(
|
||||
json,
|
||||
r#"{"index_settings":{"docstore_compression":"lz4","docstore_blocksize":16384},"segments":[],"schema":[{"name":"text","type":"text","options":{"indexing":{"record":"position","fieldnorms":true,"tokenizer":"default"},"stored":false,"fast":false}}],"opstamp":0}"#
|
||||
r#"{"index_settings":{"docstore_compression":"none","docstore_blocksize":16384},"segments":[],"schema":[{"name":"text","type":"text","options":{"indexing":{"record":"position","fieldnorms":true,"tokenizer":"default"},"stored":false,"fast":false}}],"opstamp":0}"#
|
||||
);
|
||||
|
||||
let deser_meta: UntrackedIndexMeta = serde_json::from_str(&json).unwrap();
|
||||
@@ -494,6 +497,8 @@ mod tests {
|
||||
#[test]
|
||||
#[cfg(feature = "lz4-compression")]
|
||||
fn test_index_settings_default() {
|
||||
use crate::store::Compressor;
|
||||
|
||||
let mut index_settings = IndexSettings::default();
|
||||
assert_eq!(
|
||||
index_settings,
|
||||
|
||||
@@ -46,7 +46,7 @@ impl Segment {
|
||||
///
|
||||
/// This method is only used when updating `max_doc` from 0
|
||||
/// as we finalize a fresh new segment.
|
||||
pub(crate) fn with_max_doc(self, max_doc: u32) -> Segment {
|
||||
pub fn with_max_doc(self, max_doc: u32) -> Segment {
|
||||
Segment {
|
||||
index: self.index,
|
||||
meta: self.meta.with_max_doc(max_doc),
|
||||
|
||||
@@ -455,11 +455,11 @@ impl SegmentReader {
|
||||
pub fn space_usage(&self) -> io::Result<SegmentSpaceUsage> {
|
||||
Ok(SegmentSpaceUsage::new(
|
||||
self.num_docs(),
|
||||
self.termdict_composite.space_usage(),
|
||||
self.postings_composite.space_usage(),
|
||||
self.positions_composite.space_usage(),
|
||||
self.fast_fields_readers.space_usage(self.schema())?,
|
||||
self.fieldnorm_readers.space_usage(),
|
||||
self.termdict_composite.space_usage(self.schema()),
|
||||
self.postings_composite.space_usage(self.schema()),
|
||||
self.positions_composite.space_usage(self.schema()),
|
||||
self.fast_fields_readers.space_usage()?,
|
||||
self.fieldnorm_readers.space_usage(self.schema()),
|
||||
self.get_store_reader(0)?.space_usage(),
|
||||
self.alive_bitset_opt
|
||||
.as_ref()
|
||||
|
||||
@@ -4,38 +4,37 @@ use std::sync::{Arc, RwLock, Weak};
|
||||
use super::operation::DeleteOperation;
|
||||
use crate::Opstamp;
|
||||
|
||||
// The DeleteQueue is similar in conceptually to a multiple
|
||||
// consumer single producer broadcast channel.
|
||||
//
|
||||
// All consumer will receive all messages.
|
||||
//
|
||||
// Consumer of the delete queue are holding a `DeleteCursor`,
|
||||
// which points to a specific place of the `DeleteQueue`.
|
||||
//
|
||||
// New consumer can be created in two ways
|
||||
// - calling `delete_queue.cursor()` returns a cursor, that will include all future delete operation
|
||||
// (and some or none of the past operations... The client is in charge of checking the opstamps.).
|
||||
// - cloning an existing cursor returns a new cursor, that is at the exact same position, and can
|
||||
// now advance independently from the original cursor.
|
||||
/// The DeleteQueue is similar in conceptually to a multiple
|
||||
/// consumer single producer broadcast channel.
|
||||
///
|
||||
/// All consumer will receive all messages.
|
||||
///
|
||||
/// Consumer of the delete queue are holding a `DeleteCursor`,
|
||||
/// which points to a specific place of the `DeleteQueue`.
|
||||
///
|
||||
/// New consumer can be created in two ways
|
||||
/// - calling `delete_queue.cursor()` returns a cursor, that will include all future delete
|
||||
/// operation (and some or none of the past operations... The client is in charge of checking the
|
||||
/// opstamps.).
|
||||
/// - cloning an existing cursor returns a new cursor, that is at the exact same position, and can
|
||||
/// now advance independently from the original cursor.
|
||||
#[derive(Default)]
|
||||
struct InnerDeleteQueue {
|
||||
writer: Vec<DeleteOperation>,
|
||||
last_block: Weak<Block>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
/// The delete queue is a linked list storing delete operations.
|
||||
///
|
||||
/// Several consumers can hold a reference to it. Delete operations
|
||||
/// get dropped/gc'ed when no more consumers are holding a reference
|
||||
/// to them.
|
||||
#[derive(Clone, Default)]
|
||||
pub struct DeleteQueue {
|
||||
inner: Arc<RwLock<InnerDeleteQueue>>,
|
||||
}
|
||||
|
||||
impl DeleteQueue {
|
||||
// Creates a new delete queue.
|
||||
pub fn new() -> DeleteQueue {
|
||||
DeleteQueue {
|
||||
inner: Arc::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_last_block(&self) -> Arc<Block> {
|
||||
{
|
||||
// try get the last block with simply acquiring the read lock.
|
||||
@@ -58,10 +57,10 @@ impl DeleteQueue {
|
||||
block
|
||||
}
|
||||
|
||||
// Creates a new cursor that makes it possible to
|
||||
// consume future delete operations.
|
||||
//
|
||||
// Past delete operations are not accessible.
|
||||
/// Creates a new cursor that makes it possible to
|
||||
/// consume future delete operations.
|
||||
///
|
||||
/// Past delete operations are not accessible.
|
||||
pub fn cursor(&self) -> DeleteCursor {
|
||||
let last_block = self.get_last_block();
|
||||
let operations_len = last_block.operations.len();
|
||||
@@ -71,7 +70,7 @@ impl DeleteQueue {
|
||||
}
|
||||
}
|
||||
|
||||
// Appends a new delete operations.
|
||||
/// Appends a new delete operations.
|
||||
pub fn push(&self, delete_operation: DeleteOperation) {
|
||||
self.inner
|
||||
.write()
|
||||
@@ -169,6 +168,7 @@ struct Block {
|
||||
next: NextBlock,
|
||||
}
|
||||
|
||||
/// As we process delete operations, keeps track of our position.
|
||||
#[derive(Clone)]
|
||||
pub struct DeleteCursor {
|
||||
block: Arc<Block>,
|
||||
@@ -261,7 +261,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_deletequeue() {
|
||||
let delete_queue = DeleteQueue::new();
|
||||
let delete_queue = DeleteQueue::default();
|
||||
|
||||
let make_op = |i: usize| DeleteOperation {
|
||||
opstamp: i as u64,
|
||||
|
||||
@@ -128,7 +128,7 @@ fn compute_deleted_bitset(
|
||||
/// is `==` target_opstamp.
|
||||
/// For instance, there was no delete operation between the state of the `segment_entry` and
|
||||
/// the `target_opstamp`, `segment_entry` is not updated.
|
||||
pub(crate) fn advance_deletes(
|
||||
pub fn advance_deletes(
|
||||
mut segment: Segment,
|
||||
segment_entry: &mut SegmentEntry,
|
||||
target_opstamp: Opstamp,
|
||||
@@ -303,7 +303,7 @@ impl<D: Document> IndexWriter<D> {
|
||||
let (document_sender, document_receiver) =
|
||||
crossbeam_channel::bounded(PIPELINE_MAX_SIZE_IN_DOCS);
|
||||
|
||||
let delete_queue = DeleteQueue::new();
|
||||
let delete_queue = DeleteQueue::default();
|
||||
|
||||
let current_opstamp = index.load_metas()?.opstamp;
|
||||
|
||||
|
||||
@@ -3,21 +3,21 @@ use std::net::Ipv6Addr;
|
||||
use columnar::MonotonicallyMappableToU128;
|
||||
|
||||
use crate::fastfield::FastValue;
|
||||
use crate::schema::{Field, Type};
|
||||
use crate::schema::Field;
|
||||
|
||||
/// Term represents the value that the token can take.
|
||||
/// It's a serialized representation over different types.
|
||||
/// IndexingTerm is used to represent a term during indexing.
|
||||
/// It's a serialized representation over field and value.
|
||||
///
|
||||
/// It actually wraps a `Vec<u8>`. The first 5 bytes are metadata.
|
||||
/// 4 bytes are the field id, and the last byte is the type.
|
||||
/// It actually wraps a `Vec<u8>`. The first 4 bytes are the field.
|
||||
///
|
||||
/// The serialized value `ValueBytes` is considered everything after the 4 first bytes (term id).
|
||||
/// We serialize the field, because we index everything in a single
|
||||
/// global term dictionary during indexing.
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct IndexingTerm<B = Vec<u8>>(B)
|
||||
where B: AsRef<[u8]>;
|
||||
|
||||
/// The number of bytes used as metadata by `Term`.
|
||||
const TERM_METADATA_LENGTH: usize = 5;
|
||||
const TERM_METADATA_LENGTH: usize = 4;
|
||||
|
||||
impl IndexingTerm {
|
||||
/// Create a new Term with a buffer with a given capacity.
|
||||
@@ -31,10 +31,9 @@ impl IndexingTerm {
|
||||
/// Use `clear_with_field_and_type` in that case.
|
||||
///
|
||||
/// Sets field and the type.
|
||||
pub(crate) fn set_field_and_type(&mut self, field: Field, typ: Type) {
|
||||
pub(crate) fn set_field(&mut self, field: Field) {
|
||||
assert!(self.is_empty());
|
||||
self.0[0..4].clone_from_slice(field.field_id().to_be_bytes().as_ref());
|
||||
self.0[4] = typ.to_code();
|
||||
}
|
||||
|
||||
/// Is empty if there are no value bytes.
|
||||
@@ -42,10 +41,10 @@ impl IndexingTerm {
|
||||
self.0.len() == TERM_METADATA_LENGTH
|
||||
}
|
||||
|
||||
/// Removes the value_bytes and set the field and type code.
|
||||
pub(crate) fn clear_with_field_and_type(&mut self, typ: Type, field: Field) {
|
||||
/// Removes the value_bytes and set the field
|
||||
pub(crate) fn clear_with_field(&mut self, field: Field) {
|
||||
self.truncate_value_bytes(0);
|
||||
self.set_field_and_type(field, typ);
|
||||
self.set_field(field);
|
||||
}
|
||||
|
||||
/// Sets a u64 value in the term.
|
||||
@@ -122,6 +121,23 @@ impl IndexingTerm {
|
||||
impl<B> IndexingTerm<B>
|
||||
where B: AsRef<[u8]>
|
||||
{
|
||||
/// Wraps serialized term bytes.
|
||||
///
|
||||
/// The input buffer is expected to be the concatenation of the big endian encoded field id
|
||||
/// followed by the serialized value bytes (type tag + payload).
|
||||
#[inline]
|
||||
pub fn wrap(serialized_term: B) -> IndexingTerm<B> {
|
||||
debug_assert!(serialized_term.as_ref().len() >= TERM_METADATA_LENGTH);
|
||||
IndexingTerm(serialized_term)
|
||||
}
|
||||
|
||||
/// Returns the field this term belongs to.
|
||||
#[inline]
|
||||
pub fn field(&self) -> Field {
|
||||
let field_id_bytes: [u8; 4] = self.0.as_ref()[..4].try_into().unwrap();
|
||||
Field::from_field_id(u32::from_be_bytes(field_id_bytes))
|
||||
}
|
||||
|
||||
/// Returns the serialized representation of Term.
|
||||
/// This includes field_id, value type and value.
|
||||
///
|
||||
@@ -136,6 +152,7 @@ where B: AsRef<[u8]>
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use super::IndexingTerm;
|
||||
use crate::schema::*;
|
||||
|
||||
#[test]
|
||||
@@ -143,42 +160,55 @@ mod tests {
|
||||
let mut schema_builder = Schema::builder();
|
||||
schema_builder.add_text_field("text", STRING);
|
||||
let title_field = schema_builder.add_text_field("title", STRING);
|
||||
let term = Term::from_field_text(title_field, "test");
|
||||
let mut term = IndexingTerm::with_capacity(0);
|
||||
term.set_field(title_field);
|
||||
term.set_bytes(b"test");
|
||||
assert_eq!(term.field(), title_field);
|
||||
assert_eq!(term.typ(), Type::Str);
|
||||
assert_eq!(term.value().as_str(), Some("test"))
|
||||
assert_eq!(term.serialized_term(), b"\x00\x00\x00\x01test".to_vec())
|
||||
}
|
||||
|
||||
/// Size (in bytes) of the buffer of a fast value (u64, i64, f64, or date) term.
|
||||
/// <field> + <type byte> + <value len>
|
||||
///
|
||||
/// - <field> is a big endian encoded u32 field id
|
||||
/// - <type_byte>'s most significant bit expresses whether the term is a json term or not The
|
||||
/// remaining 7 bits are used to encode the type of the value. If this is a JSON term, the
|
||||
/// type is the type of the leaf of the json.
|
||||
/// - <value> is, if this is not the json term, a binary representation specific to the type.
|
||||
/// If it is a JSON Term, then it is prepended with the path that leads to this leaf value.
|
||||
const FAST_VALUE_TERM_LEN: usize = 4 + 1 + 8;
|
||||
const FAST_VALUE_TERM_LEN: usize = 4 + 8;
|
||||
|
||||
#[test]
|
||||
pub fn test_term_u64() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let count_field = schema_builder.add_u64_field("count", INDEXED);
|
||||
let term = Term::from_field_u64(count_field, 983u64);
|
||||
let mut term = IndexingTerm::with_capacity(0);
|
||||
term.set_field(count_field);
|
||||
term.set_u64(983u64);
|
||||
assert_eq!(term.field(), count_field);
|
||||
assert_eq!(term.typ(), Type::U64);
|
||||
assert_eq!(term.serialized_term().len(), FAST_VALUE_TERM_LEN);
|
||||
assert_eq!(term.value().as_u64(), Some(983u64))
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_term_bool() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let bool_field = schema_builder.add_bool_field("bool", INDEXED);
|
||||
let term = Term::from_field_bool(bool_field, true);
|
||||
let term = {
|
||||
let mut term = IndexingTerm::with_capacity(0);
|
||||
term.set_field(bool_field);
|
||||
term.set_bool(true);
|
||||
term
|
||||
};
|
||||
assert_eq!(term.field(), bool_field);
|
||||
assert_eq!(term.typ(), Type::Bool);
|
||||
assert_eq!(term.serialized_term().len(), FAST_VALUE_TERM_LEN);
|
||||
assert_eq!(term.value().as_bool(), Some(true))
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn indexing_term_wrap_extracts_field() {
|
||||
let field = Field::from_field_id(7u32);
|
||||
let mut term = IndexingTerm::with_capacity(0);
|
||||
term.set_field(field);
|
||||
term.append_bytes(b"abc");
|
||||
|
||||
let wrapped = IndexingTerm::wrap(term.serialized_term());
|
||||
assert_eq!(wrapped.field(), field);
|
||||
assert_eq!(wrapped.serialized_term(), term.serialized_term());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
//! `IndexWriter` is the main entry point for that, which created from
|
||||
//! [`Index::writer`](crate::Index::writer).
|
||||
|
||||
/// Delete queue implementation for broadcasting delete operations to consumers.
|
||||
pub(crate) mod delete_queue;
|
||||
pub(crate) mod path_to_unordered_id;
|
||||
|
||||
@@ -32,12 +33,11 @@ mod stamper;
|
||||
use crossbeam_channel as channel;
|
||||
use smallvec::SmallVec;
|
||||
|
||||
pub use self::index_writer::{IndexWriter, IndexWriterOptions};
|
||||
pub use self::index_writer::{advance_deletes, IndexWriter, IndexWriterOptions};
|
||||
pub use self::log_merge_policy::LogMergePolicy;
|
||||
pub use self::merge_operation::MergeOperation;
|
||||
pub use self::merge_policy::{MergeCandidate, MergePolicy, NoMergePolicy};
|
||||
use self::operation::AddOperation;
|
||||
pub use self::operation::UserOperation;
|
||||
pub use self::operation::{AddOperation, DeleteOperation, UserOperation};
|
||||
pub use self::prepared_commit::PreparedCommit;
|
||||
pub use self::segment_entry::SegmentEntry;
|
||||
pub(crate) use self::segment_serializer::SegmentSerializer;
|
||||
|
||||
@@ -5,14 +5,20 @@ use crate::Opstamp;
|
||||
|
||||
/// Timestamped Delete operation.
|
||||
pub struct DeleteOperation {
|
||||
/// Operation stamp.
|
||||
/// It is used to check whether the delete operation
|
||||
/// applies to an added document operation.
|
||||
pub opstamp: Opstamp,
|
||||
/// Weight is used to define the set of documents to be deleted.
|
||||
pub target: Box<dyn Weight>,
|
||||
}
|
||||
|
||||
/// Timestamped Add operation.
|
||||
#[derive(Eq, PartialEq, Debug)]
|
||||
pub struct AddOperation<D: Document = TantivyDocument> {
|
||||
/// Operation stamp.
|
||||
pub opstamp: Opstamp,
|
||||
/// Document to be added.
|
||||
pub document: D,
|
||||
}
|
||||
|
||||
|
||||
@@ -117,7 +117,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_segment_register() {
|
||||
let inventory = SegmentMetaInventory::default();
|
||||
let delete_queue = DeleteQueue::new();
|
||||
let delete_queue = DeleteQueue::default();
|
||||
|
||||
let mut segment_register = SegmentRegister::default();
|
||||
let segment_id_a = SegmentId::generate_random();
|
||||
|
||||
@@ -171,7 +171,7 @@ impl SegmentWriter {
|
||||
let (term_buffer, ctx) = (&mut self.term_buffer, &mut self.ctx);
|
||||
let postings_writer: &mut dyn PostingsWriter =
|
||||
self.per_field_postings_writers.get_for_field_mut(field);
|
||||
term_buffer.clear_with_field_and_type(field_entry.field_type().value_type(), field);
|
||||
term_buffer.clear_with_field(field);
|
||||
|
||||
match field_entry.field_type() {
|
||||
FieldType::Facet(_) => {
|
||||
@@ -421,10 +421,9 @@ fn remap_and_write(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::BTreeMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::path::Path;
|
||||
|
||||
use columnar::ColumnType;
|
||||
use tempfile::TempDir;
|
||||
|
||||
use crate::collector::{Count, TopDocs};
|
||||
use crate::directory::RamDirectory;
|
||||
@@ -1067,10 +1066,7 @@ mod tests {
|
||||
let mut schema_builder = Schema::builder();
|
||||
schema_builder.add_text_field("title", text_options);
|
||||
let schema = schema_builder.build();
|
||||
let tempdir = TempDir::new().unwrap();
|
||||
let tempdir_path = PathBuf::from(tempdir.path());
|
||||
Index::create_in_dir(&tempdir_path, schema).unwrap();
|
||||
let index = Index::open_in_dir(tempdir_path).unwrap();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let schema = index.schema();
|
||||
let mut index_writer = index.writer(50_000_000).unwrap();
|
||||
let title = schema.get_field("title").unwrap();
|
||||
|
||||
18
src/lib.rs
18
src/lib.rs
@@ -17,6 +17,7 @@
|
||||
//!
|
||||
//! ```rust
|
||||
//! # use std::path::Path;
|
||||
//! # use std::fs;
|
||||
//! # use tempfile::TempDir;
|
||||
//! # use tantivy::collector::TopDocs;
|
||||
//! # use tantivy::query::QueryParser;
|
||||
@@ -27,8 +28,11 @@
|
||||
//! # // Let's create a temporary directory for the
|
||||
//! # // sake of this example
|
||||
//! # if let Ok(dir) = TempDir::new() {
|
||||
//! # run_example(dir.path()).unwrap();
|
||||
//! # dir.close().unwrap();
|
||||
//! # let index_path = dir.path().join("index");
|
||||
//! # // In case the directory already exists, we remove it
|
||||
//! # let _ = fs::remove_dir_all(&index_path);
|
||||
//! # fs::create_dir_all(&index_path).unwrap();
|
||||
//! # run_example(&index_path).unwrap();
|
||||
//! # }
|
||||
//! # }
|
||||
//! #
|
||||
@@ -203,6 +207,7 @@ mod docset;
|
||||
mod reader;
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(feature = "mmap")]
|
||||
mod compat_tests;
|
||||
|
||||
pub use self::reader::{IndexReader, IndexReaderBuilder, ReloadPolicy, Warmer};
|
||||
@@ -216,9 +221,7 @@ use once_cell::sync::Lazy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub use self::docset::{DocSet, COLLECT_BLOCK_BUFFER_LEN, TERMINATED};
|
||||
#[doc(hidden)]
|
||||
pub use crate::core::json_utils;
|
||||
pub use crate::core::{Executor, Searcher, SearcherGeneration};
|
||||
pub use crate::core::{json_utils, Executor, Searcher, SearcherGeneration};
|
||||
pub use crate::directory::Directory;
|
||||
pub use crate::index::{
|
||||
Index, IndexBuilder, IndexMeta, IndexSettings, InvertedIndexReader, Order, Segment,
|
||||
@@ -1172,12 +1175,11 @@ pub mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_validate_checksum() -> crate::Result<()> {
|
||||
let index_path = tempfile::tempdir().expect("dir");
|
||||
let mut builder = Schema::builder();
|
||||
let body = builder.add_text_field("body", TEXT | STORED);
|
||||
let schema = builder.build();
|
||||
let index = Index::create_in_dir(&index_path, schema)?;
|
||||
let mut writer: IndexWriter = index.writer(50_000_000)?;
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut writer: IndexWriter = index.writer_for_tests()?;
|
||||
writer.set_merge_policy(Box::new(NoMergePolicy));
|
||||
for _ in 0..5000 {
|
||||
writer.add_document(doc!(body => "foo"))?;
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
use bitpacking::{BitPacker, BitPacker4x};
|
||||
use common::FixedSize;
|
||||
|
||||
pub const COMPRESSION_BLOCK_SIZE: usize = BitPacker4x::BLOCK_LEN;
|
||||
const COMPRESSED_BLOCK_MAX_SIZE: usize = COMPRESSION_BLOCK_SIZE * u32::SIZE_IN_BYTES;
|
||||
// in vint encoding, each byte stores 7 bits of data, so we need at most 32 / 7 = 4.57 bytes to
|
||||
// store a u32 in the worst case, rounding up to 5 bytes total
|
||||
const MAX_VINT_SIZE: usize = 5;
|
||||
const COMPRESSED_BLOCK_MAX_SIZE: usize = COMPRESSION_BLOCK_SIZE * MAX_VINT_SIZE;
|
||||
|
||||
mod vint;
|
||||
|
||||
/// Returns the size in bytes of a compressed block, given `num_bits`.
|
||||
#[inline]
|
||||
pub fn compressed_block_size(num_bits: u8) -> usize {
|
||||
(num_bits as usize) * COMPRESSION_BLOCK_SIZE / 8
|
||||
}
|
||||
@@ -267,7 +270,6 @@ impl VIntDecoder for BlockDecoder {
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::TERMINATED;
|
||||
|
||||
@@ -372,6 +374,13 @@ pub(crate) mod tests {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compress_vint_unsorted_does_not_overflow() {
|
||||
let mut encoder = BlockEncoder::new();
|
||||
let input: Vec<u32> = vec![u32::MAX; COMPRESSION_BLOCK_SIZE];
|
||||
encoder.compress_vint_unsorted(&input);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "unstable"))]
|
||||
|
||||
@@ -8,7 +8,7 @@ use crate::indexer::path_to_unordered_id::OrderedPathId;
|
||||
use crate::postings::postings_writer::SpecializedPostingsWriter;
|
||||
use crate::postings::recorder::{BufferLender, DocIdRecorder, Recorder};
|
||||
use crate::postings::{FieldSerializer, IndexingContext, IndexingPosition, PostingsWriter};
|
||||
use crate::schema::{Field, Type, ValueBytes};
|
||||
use crate::schema::{Field, Type};
|
||||
use crate::tokenizer::TokenStream;
|
||||
use crate::DocId;
|
||||
|
||||
@@ -79,8 +79,7 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
|
||||
term_buffer.truncate(term_path_len);
|
||||
term_buffer.append_bytes(term);
|
||||
|
||||
let json_value = ValueBytes::wrap(term);
|
||||
let typ = json_value.typ();
|
||||
let typ = Type::from_code(term[0]).expect("Invalid type code in JSON term");
|
||||
if typ == Type::Str {
|
||||
SpecializedPostingsWriter::<Rec>::serialize_one_term(
|
||||
term_buffer.as_bytes(),
|
||||
@@ -107,6 +106,8 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to build the JSON term bytes that land in the term dictionary.
|
||||
/// Format: `[json path utf8][JSON_END_OF_PATH][type tag][payload]`
|
||||
struct JsonTermSerializer(Vec<u8>);
|
||||
impl JsonTermSerializer {
|
||||
/// Appends a JSON path to the Term.
|
||||
|
||||
@@ -527,6 +527,7 @@ pub(crate) mod tests {
|
||||
}
|
||||
|
||||
impl<TScorer: Scorer> Scorer for UnoptimizedDocSet<TScorer> {
|
||||
#[inline]
|
||||
fn score(&mut self) -> Score {
|
||||
self.0.score()
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ use crate::postings::recorder::{BufferLender, Recorder};
|
||||
use crate::postings::{
|
||||
FieldSerializer, IndexingContext, InvertedIndexSerializer, PerFieldPostingsWriter,
|
||||
};
|
||||
use crate::schema::{Field, Schema, Term, Type};
|
||||
use crate::schema::{Field, Schema, Type};
|
||||
use crate::tokenizer::{Token, TokenStream, MAX_TOKEN_LEN};
|
||||
use crate::DocId;
|
||||
|
||||
@@ -59,14 +59,14 @@ pub(crate) fn serialize_postings(
|
||||
let mut term_offsets: Vec<(Field, OrderedPathId, &[u8], Addr)> =
|
||||
Vec::with_capacity(ctx.term_index.len());
|
||||
term_offsets.extend(ctx.term_index.iter().map(|(key, addr)| {
|
||||
let field = Term::wrap(key).field();
|
||||
let field = IndexingTerm::wrap(key).field();
|
||||
if schema.get_field_entry(field).field_type().value_type() == Type::Json {
|
||||
let byte_range_path = 5..5 + 4;
|
||||
let byte_range_path = 4..4 + 4;
|
||||
let unordered_id = u32::from_be_bytes(key[byte_range_path.clone()].try_into().unwrap());
|
||||
let path_id = unordered_id_to_ordered_id[unordered_id as usize];
|
||||
(field, path_id, &key[byte_range_path.end..], addr)
|
||||
} else {
|
||||
(field, 0.into(), &key[5..], addr)
|
||||
(field, 0.into(), &key[4..], addr)
|
||||
}
|
||||
}));
|
||||
// Sort by field, path, and term
|
||||
|
||||
@@ -6,17 +6,21 @@ use crate::{DocId, Score, TERMINATED};
|
||||
|
||||
// doc num bits uses the following encoding:
|
||||
// given 0b a b cdefgh
|
||||
// |1|2| 3 |
|
||||
// |1|2|3| 4 |
|
||||
// - 1: unused
|
||||
// - 2: is delta-1 encoded. 0 if not, 1, if yes
|
||||
// - 3: a 6 bit number in 0..=32, the actual bitwidth
|
||||
// - 3: unused
|
||||
// - 4: a 5 bit number in 0..32, the actual bitwidth. Bitpacking could in theory say this is 32
|
||||
// (requiring a 6th bit), but the biggest doc_id we can want to encode is TERMINATED-1, which can
|
||||
// be represented on 31b without delta encoding.
|
||||
fn encode_bitwidth(bitwidth: u8, delta_1: bool) -> u8 {
|
||||
assert!(bitwidth < 32);
|
||||
bitwidth | ((delta_1 as u8) << 6)
|
||||
}
|
||||
|
||||
fn decode_bitwidth(raw_bitwidth: u8) -> (u8, bool) {
|
||||
let delta_1 = ((raw_bitwidth >> 6) & 1) != 0;
|
||||
let bitwidth = raw_bitwidth & 0x3f;
|
||||
let bitwidth = raw_bitwidth & 0x1f;
|
||||
(bitwidth, delta_1)
|
||||
}
|
||||
|
||||
@@ -430,7 +434,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_encode_decode_bitwidth() {
|
||||
for bitwidth in 0..=32 {
|
||||
for bitwidth in 0..32 {
|
||||
for delta_1 in [false, true] {
|
||||
assert_eq!(
|
||||
(bitwidth, delta_1),
|
||||
|
||||
@@ -23,7 +23,11 @@ pub struct AllWeight;
|
||||
impl Weight for AllWeight {
|
||||
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
|
||||
let all_scorer = AllScorer::new(reader.max_doc());
|
||||
Ok(Box::new(BoostScorer::new(all_scorer, boost)))
|
||||
if boost != 1.0 {
|
||||
Ok(Box::new(BoostScorer::new(all_scorer, boost)))
|
||||
} else {
|
||||
Ok(Box::new(all_scorer))
|
||||
}
|
||||
}
|
||||
|
||||
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
|
||||
@@ -58,6 +62,15 @@ impl DocSet for AllScorer {
|
||||
self.doc
|
||||
}
|
||||
|
||||
fn seek(&mut self, target: DocId) -> DocId {
|
||||
debug_assert!(target >= self.doc);
|
||||
self.doc = target;
|
||||
if self.doc >= self.max_doc {
|
||||
self.doc = TERMINATED;
|
||||
}
|
||||
self.doc
|
||||
}
|
||||
|
||||
fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize {
|
||||
if self.doc() == TERMINATED {
|
||||
return 0;
|
||||
@@ -92,6 +105,7 @@ impl DocSet for AllScorer {
|
||||
}
|
||||
|
||||
impl Scorer for AllScorer {
|
||||
#[inline]
|
||||
fn score(&mut self) -> Score {
|
||||
1.0
|
||||
}
|
||||
|
||||
@@ -483,7 +483,7 @@ mod tests {
|
||||
let checkpoints_for_each_pruning =
|
||||
compute_checkpoints_for_each_pruning(term_scorers.clone(), top_k);
|
||||
let checkpoints_manual =
|
||||
compute_checkpoints_manual(term_scorers.clone(), top_k, 100_000);
|
||||
compute_checkpoints_manual(term_scorers.clone(), top_k, max_doc as u32);
|
||||
assert_eq!(checkpoints_for_each_pruning.len(), checkpoints_manual.len());
|
||||
for (&(left_doc, left_score), &(right_doc, right_score)) in checkpoints_for_each_pruning
|
||||
.iter()
|
||||
|
||||
@@ -97,6 +97,65 @@ fn into_box_scorer<TScoreCombiner: ScoreCombiner>(
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the effective MUST scorer, accounting for removed AllScorers.
|
||||
///
|
||||
/// When AllScorer instances are removed from must_scorers as an optimization,
|
||||
/// we must restore the "match all" semantics if the list becomes empty.
|
||||
fn effective_must_scorer(
|
||||
must_scorers: Vec<Box<dyn Scorer>>,
|
||||
removed_all_scorer_count: usize,
|
||||
max_doc: DocId,
|
||||
num_docs: u32,
|
||||
) -> Option<Box<dyn Scorer>> {
|
||||
if must_scorers.is_empty() {
|
||||
if removed_all_scorer_count > 0 {
|
||||
// Had AllScorer(s) only - all docs match
|
||||
Some(Box::new(AllScorer::new(max_doc)))
|
||||
} else {
|
||||
// No MUST constraint at all
|
||||
None
|
||||
}
|
||||
} else {
|
||||
Some(intersect_scorers(must_scorers, num_docs))
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a SHOULD scorer with AllScorer union if any were removed.
|
||||
///
|
||||
/// For union semantics (OR): if any SHOULD clause was an AllScorer, the result
|
||||
/// should include all documents. We restore this by unioning with AllScorer.
|
||||
///
|
||||
/// When `scoring_enabled` is false, we can just return AllScorer alone since
|
||||
/// we don't need score contributions from the should_scorer.
|
||||
fn effective_should_scorer_for_union<TScoreCombiner: ScoreCombiner>(
|
||||
should_scorer: SpecializedScorer,
|
||||
removed_all_scorer_count: usize,
|
||||
max_doc: DocId,
|
||||
num_docs: u32,
|
||||
score_combiner_fn: impl Fn() -> TScoreCombiner,
|
||||
scoring_enabled: bool,
|
||||
) -> SpecializedScorer {
|
||||
if removed_all_scorer_count > 0 {
|
||||
if scoring_enabled {
|
||||
// Need to union to get score contributions from both
|
||||
let all_scorers: Vec<Box<dyn Scorer>> = vec![
|
||||
into_box_scorer(should_scorer, &score_combiner_fn, num_docs),
|
||||
Box::new(AllScorer::new(max_doc)),
|
||||
];
|
||||
SpecializedScorer::Other(Box::new(BufferedUnionScorer::build(
|
||||
all_scorers,
|
||||
score_combiner_fn,
|
||||
num_docs,
|
||||
)))
|
||||
} else {
|
||||
// Scoring disabled - AllScorer alone is sufficient
|
||||
SpecializedScorer::Other(Box::new(AllScorer::new(max_doc)))
|
||||
}
|
||||
} else {
|
||||
should_scorer
|
||||
}
|
||||
}
|
||||
|
||||
enum ShouldScorersCombinationMethod {
|
||||
// Should scorers are irrelevant.
|
||||
Ignored,
|
||||
@@ -193,18 +252,18 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
|
||||
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
|
||||
}
|
||||
|
||||
let minimum_number_should_match = self
|
||||
let effective_minimum_number_should_match = self
|
||||
.minimum_number_should_match
|
||||
.saturating_sub(should_special_scorer_counts.num_all_scorers);
|
||||
|
||||
let should_scorers: ShouldScorersCombinationMethod = {
|
||||
let num_of_should_scorers = should_scorers.len();
|
||||
if minimum_number_should_match > num_of_should_scorers {
|
||||
if effective_minimum_number_should_match > num_of_should_scorers {
|
||||
// We don't have enough scorers to satisfy the minimum number of should matches.
|
||||
// The request will match no documents.
|
||||
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
|
||||
}
|
||||
match minimum_number_should_match {
|
||||
match effective_minimum_number_should_match {
|
||||
0 if num_of_should_scorers == 0 => ShouldScorersCombinationMethod::Ignored,
|
||||
0 => ShouldScorersCombinationMethod::Optional(scorer_union(
|
||||
should_scorers,
|
||||
@@ -226,7 +285,7 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
|
||||
scorer_disjunction(
|
||||
should_scorers,
|
||||
score_combiner_fn(),
|
||||
self.minimum_number_should_match,
|
||||
effective_minimum_number_should_match,
|
||||
),
|
||||
)),
|
||||
}
|
||||
@@ -246,53 +305,78 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
|
||||
|
||||
let include_scorer = match (should_scorers, must_scorers) {
|
||||
(ShouldScorersCombinationMethod::Ignored, must_scorers) => {
|
||||
let boxed_scorer: Box<dyn Scorer> = if must_scorers.is_empty() {
|
||||
// We do not have any should scorers, nor all scorers.
|
||||
// There are still two cases here.
|
||||
//
|
||||
// If this follows the removal of some AllScorers in the should/must clauses,
|
||||
// then we match all documents.
|
||||
//
|
||||
// Otherwise, it is really just an EmptyScorer.
|
||||
if must_special_scorer_counts.num_all_scorers
|
||||
+ should_special_scorer_counts.num_all_scorers
|
||||
> 0
|
||||
{
|
||||
Box::new(AllScorer::new(reader.max_doc()))
|
||||
} else {
|
||||
Box::new(EmptyScorer)
|
||||
}
|
||||
} else {
|
||||
intersect_scorers(must_scorers, num_docs)
|
||||
};
|
||||
// No SHOULD clauses (or they were absorbed into MUST).
|
||||
// Result depends entirely on MUST + any removed AllScorers.
|
||||
let combined_all_scorer_count = must_special_scorer_counts.num_all_scorers
|
||||
+ should_special_scorer_counts.num_all_scorers;
|
||||
let boxed_scorer: Box<dyn Scorer> = effective_must_scorer(
|
||||
must_scorers,
|
||||
combined_all_scorer_count,
|
||||
reader.max_doc(),
|
||||
num_docs,
|
||||
)
|
||||
.unwrap_or_else(|| Box::new(EmptyScorer));
|
||||
SpecializedScorer::Other(boxed_scorer)
|
||||
}
|
||||
(ShouldScorersCombinationMethod::Optional(should_scorer), must_scorers) => {
|
||||
if must_scorers.is_empty() && must_special_scorer_counts.num_all_scorers == 0 {
|
||||
// Optional options are promoted to required if no must scorers exists.
|
||||
should_scorer
|
||||
} else {
|
||||
let must_scorer = intersect_scorers(must_scorers, num_docs);
|
||||
if self.scoring_enabled {
|
||||
SpecializedScorer::Other(Box::new(RequiredOptionalScorer::<
|
||||
_,
|
||||
_,
|
||||
TScoreCombiner,
|
||||
>::new(
|
||||
must_scorer,
|
||||
into_box_scorer(should_scorer, &score_combiner_fn, num_docs),
|
||||
)))
|
||||
} else {
|
||||
SpecializedScorer::Other(must_scorer)
|
||||
// Optional SHOULD: contributes to scoring but not required for matching.
|
||||
match effective_must_scorer(
|
||||
must_scorers,
|
||||
must_special_scorer_counts.num_all_scorers,
|
||||
reader.max_doc(),
|
||||
num_docs,
|
||||
) {
|
||||
None => {
|
||||
// No MUST constraint: promote SHOULD to required.
|
||||
// Must preserve any removed AllScorers from SHOULD via union.
|
||||
effective_should_scorer_for_union(
|
||||
should_scorer,
|
||||
should_special_scorer_counts.num_all_scorers,
|
||||
reader.max_doc(),
|
||||
num_docs,
|
||||
&score_combiner_fn,
|
||||
self.scoring_enabled,
|
||||
)
|
||||
}
|
||||
Some(must_scorer) => {
|
||||
// Has MUST constraint: SHOULD only affects scoring.
|
||||
if self.scoring_enabled {
|
||||
SpecializedScorer::Other(Box::new(RequiredOptionalScorer::<
|
||||
_,
|
||||
_,
|
||||
TScoreCombiner,
|
||||
>::new(
|
||||
must_scorer,
|
||||
into_box_scorer(should_scorer, &score_combiner_fn, num_docs),
|
||||
)))
|
||||
} else {
|
||||
SpecializedScorer::Other(must_scorer)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
(ShouldScorersCombinationMethod::Required(should_scorer), mut must_scorers) => {
|
||||
if must_scorers.is_empty() {
|
||||
should_scorer
|
||||
} else {
|
||||
must_scorers.push(into_box_scorer(should_scorer, &score_combiner_fn, num_docs));
|
||||
SpecializedScorer::Other(intersect_scorers(must_scorers, num_docs))
|
||||
(ShouldScorersCombinationMethod::Required(should_scorer), must_scorers) => {
|
||||
// Required SHOULD: at least `minimum_number_should_match` must match.
|
||||
// Semantics: (MUST constraint) AND (SHOULD constraint)
|
||||
match effective_must_scorer(
|
||||
must_scorers,
|
||||
must_special_scorer_counts.num_all_scorers,
|
||||
reader.max_doc(),
|
||||
num_docs,
|
||||
) {
|
||||
None => {
|
||||
// No MUST constraint: SHOULD alone determines matching.
|
||||
should_scorer
|
||||
}
|
||||
Some(must_scorer) => {
|
||||
// Has MUST constraint: intersect MUST with SHOULD.
|
||||
let should_boxed =
|
||||
into_box_scorer(should_scorer, &score_combiner_fn, num_docs);
|
||||
SpecializedScorer::Other(intersect_scorers(
|
||||
vec![must_scorer, should_boxed],
|
||||
num_docs,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -9,12 +9,14 @@ pub use self::boolean_weight::BooleanWeight;
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use std::ops::Bound;
|
||||
|
||||
use super::*;
|
||||
use crate::collector::tests::TEST_COLLECTOR_WITH_SCORE;
|
||||
use crate::collector::TopDocs;
|
||||
use crate::collector::{Count, TopDocs};
|
||||
use crate::query::term_query::TermScorer;
|
||||
use crate::query::{
|
||||
AllScorer, EmptyScorer, EnableScoring, Intersection, Occur, Query, QueryParser,
|
||||
AllScorer, EmptyScorer, EnableScoring, Intersection, Occur, Query, QueryParser, RangeQuery,
|
||||
RequiredOptionalScorer, Scorer, SumCombiner, TermQuery,
|
||||
};
|
||||
use crate::schema::*;
|
||||
@@ -374,4 +376,466 @@ mod tests {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_min_should_match_with_all_query() -> crate::Result<()> {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let text_field = schema_builder.add_text_field("text", TEXT);
|
||||
let num_field =
|
||||
schema_builder.add_i64_field("num", NumericOptions::default().set_fast().set_indexed());
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut index_writer: IndexWriter = index.writer_for_tests()?;
|
||||
|
||||
index_writer.add_document(doc!(text_field => "apple", num_field => 10i64))?;
|
||||
index_writer.add_document(doc!(text_field => "banana", num_field => 20i64))?;
|
||||
index_writer.commit()?;
|
||||
|
||||
let searcher = index.reader()?.searcher();
|
||||
|
||||
let effective_all_match_query: Box<dyn Query> = Box::new(RangeQuery::new(
|
||||
Bound::Excluded(Term::from_field_i64(num_field, 0)),
|
||||
Bound::Unbounded,
|
||||
));
|
||||
let term_query: Box<dyn Query> = Box::new(TermQuery::new(
|
||||
Term::from_field_text(text_field, "apple"),
|
||||
IndexRecordOption::Basic,
|
||||
));
|
||||
|
||||
// in some previous version, we would remove the 2 all_match, but then say we need *4*
|
||||
// matches out of the 3 term queries, which matches nothing.
|
||||
let mut bool_query = BooleanQuery::new(vec![
|
||||
(Occur::Should, effective_all_match_query.box_clone()),
|
||||
(Occur::Should, effective_all_match_query.box_clone()),
|
||||
(Occur::Should, term_query.box_clone()),
|
||||
(Occur::Should, term_query.box_clone()),
|
||||
(Occur::Should, term_query.box_clone()),
|
||||
]);
|
||||
bool_query.set_minimum_number_should_match(4);
|
||||
let count = searcher.search(&bool_query, &Count)?;
|
||||
assert_eq!(count, 1);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// AllScorer Preservation Regression Tests
|
||||
// =========================================================================
|
||||
//
|
||||
// These tests verify the fix for a bug where AllScorer instances (produced by
|
||||
// queries matching all documents, such as range queries covering all values)
|
||||
// were incorrectly removed from Boolean query processing, causing documents
|
||||
// to be unexpectedly excluded from results.
|
||||
//
|
||||
// The bug manifested in several scenarios:
|
||||
// 1. SHOULD + SHOULD where one clause is AllScorer
|
||||
// 2. MUST (AllScorer) + SHOULD
|
||||
// 3. Range queries in Boolean clauses when all documents match the range
|
||||
|
||||
/// Regression test: SHOULD clause with AllScorer combined with other SHOULD clauses.
|
||||
///
|
||||
/// When a SHOULD clause produces an AllScorer (e.g., from a range query matching
|
||||
/// all documents), the Boolean query should still match all documents.
|
||||
///
|
||||
/// Bug before fix: AllScorer was removed during optimization, leaving only the
|
||||
/// other SHOULD clauses, which incorrectly excluded documents.
|
||||
#[test]
|
||||
pub fn test_should_with_all_scorer_regression() -> crate::Result<()> {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let text_field = schema_builder.add_text_field("text", TEXT);
|
||||
let num_field =
|
||||
schema_builder.add_i64_field("num", NumericOptions::default().set_fast().set_indexed());
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut index_writer: IndexWriter = index.writer_for_tests()?;
|
||||
|
||||
// All docs have num > 0, so range query will return AllScorer
|
||||
index_writer.add_document(doc!(text_field => "hello", num_field => 10i64))?;
|
||||
index_writer.add_document(doc!(text_field => "world", num_field => 20i64))?;
|
||||
index_writer.add_document(doc!(text_field => "hello world", num_field => 30i64))?;
|
||||
index_writer.add_document(doc!(text_field => "foo", num_field => 40i64))?;
|
||||
index_writer.add_document(doc!(text_field => "bar", num_field => 50i64))?;
|
||||
index_writer.add_document(doc!(text_field => "baz", num_field => 60i64))?;
|
||||
index_writer.commit()?;
|
||||
|
||||
let searcher = index.reader()?.searcher();
|
||||
|
||||
// Range query matching all docs (returns AllScorer)
|
||||
let all_match_query: Box<dyn Query> = Box::new(RangeQuery::new(
|
||||
Bound::Excluded(Term::from_field_i64(num_field, 0)),
|
||||
Bound::Unbounded,
|
||||
));
|
||||
let term_query: Box<dyn Query> = Box::new(TermQuery::new(
|
||||
Term::from_field_text(text_field, "hello"),
|
||||
IndexRecordOption::Basic,
|
||||
));
|
||||
|
||||
// Verify range matches all 6 docs
|
||||
assert_eq!(searcher.search(all_match_query.as_ref(), &Count)?, 6);
|
||||
|
||||
// RangeQuery(all) OR TermQuery should match all 6 docs
|
||||
let bool_query = BooleanQuery::new(vec![
|
||||
(Occur::Should, all_match_query.box_clone()),
|
||||
(Occur::Should, term_query.box_clone()),
|
||||
]);
|
||||
let count = searcher.search(&bool_query, &Count)?;
|
||||
assert_eq!(count, 6, "SHOULD with AllScorer should match all docs");
|
||||
|
||||
// Order should not matter
|
||||
let bool_query_reversed = BooleanQuery::new(vec![
|
||||
(Occur::Should, term_query.box_clone()),
|
||||
(Occur::Should, all_match_query.box_clone()),
|
||||
]);
|
||||
let count_reversed = searcher.search(&bool_query_reversed, &Count)?;
|
||||
assert_eq!(
|
||||
count_reversed, 6,
|
||||
"Order of SHOULD clauses should not matter"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Regression test: MUST clause with AllScorer combined with SHOULD clause.
|
||||
///
|
||||
/// When MUST contains an AllScorer, all documents satisfy the MUST constraint.
|
||||
/// The SHOULD clause should only affect scoring, not filtering.
|
||||
///
|
||||
/// Bug before fix: AllScorer was removed, leaving an empty must_scorers vector.
|
||||
/// intersect_scorers([]) incorrectly returned EmptyScorer, matching 0 documents.
|
||||
#[test]
|
||||
pub fn test_must_all_with_should_regression() -> crate::Result<()> {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let text_field = schema_builder.add_text_field("text", TEXT);
|
||||
let num_field =
|
||||
schema_builder.add_i64_field("num", NumericOptions::default().set_fast().set_indexed());
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut index_writer: IndexWriter = index.writer_for_tests()?;
|
||||
|
||||
// All docs have num > 0, so range query will return AllScorer
|
||||
index_writer.add_document(doc!(text_field => "apple", num_field => 10i64))?;
|
||||
index_writer.add_document(doc!(text_field => "banana", num_field => 20i64))?;
|
||||
index_writer.add_document(doc!(text_field => "cherry", num_field => 30i64))?;
|
||||
index_writer.add_document(doc!(text_field => "date", num_field => 40i64))?;
|
||||
index_writer.commit()?;
|
||||
|
||||
let searcher = index.reader()?.searcher();
|
||||
|
||||
// Range query matching all docs (returns AllScorer)
|
||||
let all_match_query: Box<dyn Query> = Box::new(RangeQuery::new(
|
||||
Bound::Excluded(Term::from_field_i64(num_field, 0)),
|
||||
Bound::Unbounded,
|
||||
));
|
||||
let term_query: Box<dyn Query> = Box::new(TermQuery::new(
|
||||
Term::from_field_text(text_field, "apple"),
|
||||
IndexRecordOption::Basic,
|
||||
));
|
||||
|
||||
// Verify range matches all 4 docs
|
||||
assert_eq!(searcher.search(all_match_query.as_ref(), &Count)?, 4);
|
||||
|
||||
// MUST(range matching all) AND SHOULD(term) should match all 4 docs
|
||||
let bool_query = BooleanQuery::new(vec![
|
||||
(Occur::Must, all_match_query.box_clone()),
|
||||
(Occur::Should, term_query.box_clone()),
|
||||
]);
|
||||
let count = searcher.search(&bool_query, &Count)?;
|
||||
assert_eq!(count, 4, "MUST AllScorer + SHOULD should match all docs");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Regression test: Range queries in Boolean clauses when all documents match.
|
||||
///
|
||||
/// Range queries can return AllScorer as an optimization when all indexed values
|
||||
/// fall within the range. This test ensures such queries work correctly in
|
||||
/// Boolean combinations.
|
||||
///
|
||||
/// This is the most common real-world manifestation of the bug, occurring in
|
||||
/// queries like: (age > 50 OR name = 'Alice') AND status = 'active'
|
||||
/// when all documents have age > 50.
|
||||
#[test]
|
||||
pub fn test_range_query_all_match_in_boolean() -> crate::Result<()> {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let name_field = schema_builder.add_text_field("name", TEXT);
|
||||
let age_field =
|
||||
schema_builder.add_i64_field("age", NumericOptions::default().set_fast().set_indexed());
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut index_writer: IndexWriter = index.writer_for_tests()?;
|
||||
|
||||
// All documents have age > 50, so range query will return AllScorer
|
||||
index_writer.add_document(doc!(name_field => "alice", age_field => 55_i64))?;
|
||||
index_writer.add_document(doc!(name_field => "bob", age_field => 60_i64))?;
|
||||
index_writer.add_document(doc!(name_field => "charlie", age_field => 70_i64))?;
|
||||
index_writer.add_document(doc!(name_field => "diana", age_field => 80_i64))?;
|
||||
index_writer.commit()?;
|
||||
|
||||
let searcher = index.reader()?.searcher();
|
||||
|
||||
let range_query: Box<dyn Query> = Box::new(RangeQuery::new(
|
||||
Bound::Excluded(Term::from_field_i64(age_field, 50)),
|
||||
Bound::Unbounded,
|
||||
));
|
||||
let term_query: Box<dyn Query> = Box::new(TermQuery::new(
|
||||
Term::from_field_text(name_field, "alice"),
|
||||
IndexRecordOption::Basic,
|
||||
));
|
||||
|
||||
// Verify preconditions
|
||||
assert_eq!(searcher.search(range_query.as_ref(), &Count)?, 4);
|
||||
assert_eq!(searcher.search(term_query.as_ref(), &Count)?, 1);
|
||||
|
||||
// SHOULD(range) OR SHOULD(term): range matches all, so result is 4
|
||||
let should_query = BooleanQuery::new(vec![
|
||||
(Occur::Should, range_query.box_clone()),
|
||||
(Occur::Should, term_query.box_clone()),
|
||||
]);
|
||||
assert_eq!(
|
||||
searcher.search(&should_query, &Count)?,
|
||||
4,
|
||||
"SHOULD range OR term should match all"
|
||||
);
|
||||
|
||||
// MUST(range) AND SHOULD(term): range matches all, term is optional
|
||||
let must_should_query = BooleanQuery::new(vec![
|
||||
(Occur::Must, range_query.box_clone()),
|
||||
(Occur::Should, term_query.box_clone()),
|
||||
]);
|
||||
assert_eq!(
|
||||
searcher.search(&must_should_query, &Count)?,
|
||||
4,
|
||||
"MUST range + SHOULD term should match all"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Test multiple AllScorer instances in different clause types.
|
||||
///
|
||||
/// Verifies correct behavior when AllScorers appear in multiple positions.
|
||||
#[test]
|
||||
pub fn test_multiple_all_scorers() -> crate::Result<()> {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let text_field = schema_builder.add_text_field("text", TEXT);
|
||||
let num_field =
|
||||
schema_builder.add_i64_field("num", NumericOptions::default().set_fast().set_indexed());
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut index_writer: IndexWriter = index.writer_for_tests()?;
|
||||
|
||||
// All docs have num > 0, so range queries will return AllScorer
|
||||
index_writer.add_document(doc!(text_field => "doc1", num_field => 10i64))?;
|
||||
index_writer.add_document(doc!(text_field => "doc2", num_field => 20i64))?;
|
||||
index_writer.add_document(doc!(text_field => "doc3", num_field => 30i64))?;
|
||||
index_writer.commit()?;
|
||||
|
||||
let searcher = index.reader()?.searcher();
|
||||
|
||||
// Two different range queries that both match all docs (return AllScorer)
|
||||
let all_query1: Box<dyn Query> = Box::new(RangeQuery::new(
|
||||
Bound::Excluded(Term::from_field_i64(num_field, 0)),
|
||||
Bound::Unbounded,
|
||||
));
|
||||
let all_query2: Box<dyn Query> = Box::new(RangeQuery::new(
|
||||
Bound::Excluded(Term::from_field_i64(num_field, 5)),
|
||||
Bound::Unbounded,
|
||||
));
|
||||
let term_query: Box<dyn Query> = Box::new(TermQuery::new(
|
||||
Term::from_field_text(text_field, "doc1"),
|
||||
IndexRecordOption::Basic,
|
||||
));
|
||||
|
||||
// Multiple AllScorers in SHOULD
|
||||
let multi_all_should = BooleanQuery::new(vec![
|
||||
(Occur::Should, all_query1.box_clone()),
|
||||
(Occur::Should, all_query2.box_clone()),
|
||||
(Occur::Should, term_query.box_clone()),
|
||||
]);
|
||||
assert_eq!(
|
||||
searcher.search(&multi_all_should, &Count)?,
|
||||
3,
|
||||
"Multiple AllScorers in SHOULD"
|
||||
);
|
||||
|
||||
// AllScorer in both MUST and SHOULD
|
||||
let all_must_and_should = BooleanQuery::new(vec![
|
||||
(Occur::Must, all_query1.box_clone()),
|
||||
(Occur::Should, all_query2.box_clone()),
|
||||
]);
|
||||
assert_eq!(
|
||||
searcher.search(&all_must_and_should, &Count)?,
|
||||
3,
|
||||
"AllScorer in both MUST and SHOULD"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// A proptest which generates arbitrary permutations of a simple boolean AST, and then matches
|
||||
/// the result against an index which contains all permutations of documents with N fields.
|
||||
#[cfg(test)]
|
||||
mod proptest_boolean_query {
|
||||
use std::collections::{BTreeMap, HashSet};
|
||||
use std::ops::{Bound, Range};
|
||||
|
||||
use proptest::collection::vec;
|
||||
use proptest::prelude::*;
|
||||
|
||||
use crate::collector::DocSetCollector;
|
||||
use crate::query::{AllQuery, BooleanQuery, Occur, Query, RangeQuery, TermQuery};
|
||||
use crate::schema::{Field, NumericOptions, OwnedValue, Schema, TEXT};
|
||||
use crate::{DocId, Index, Term};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum BooleanQueryAST {
|
||||
/// Matches all documents via AllQuery (wraps AllScorer in BoostScorer)
|
||||
All,
|
||||
/// Matches all documents via RangeQuery (returns bare AllScorer)
|
||||
/// This is the actual trigger for the AllScorer preservation bug
|
||||
RangeAll,
|
||||
/// Matches documents where the field has value "true"
|
||||
Leaf {
|
||||
field_idx: usize,
|
||||
},
|
||||
Union(Vec<BooleanQueryAST>),
|
||||
Intersection(Vec<BooleanQueryAST>),
|
||||
}
|
||||
|
||||
impl BooleanQueryAST {
|
||||
fn matches(&self, doc_id: DocId) -> bool {
|
||||
match self {
|
||||
BooleanQueryAST::All => true,
|
||||
BooleanQueryAST::RangeAll => true,
|
||||
BooleanQueryAST::Leaf { field_idx } => Self::matches_field(doc_id, *field_idx),
|
||||
BooleanQueryAST::Union(children) => {
|
||||
children.iter().any(|child| child.matches(doc_id))
|
||||
}
|
||||
BooleanQueryAST::Intersection(children) => {
|
||||
children.iter().all(|child| child.matches(doc_id))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn matches_field(doc_id: DocId, field_idx: usize) -> bool {
|
||||
((doc_id as usize) >> field_idx) & 1 == 1
|
||||
}
|
||||
|
||||
fn to_query(&self, fields: &[Field], range_field: Field) -> Box<dyn Query> {
|
||||
match self {
|
||||
BooleanQueryAST::All => Box::new(AllQuery),
|
||||
BooleanQueryAST::RangeAll => {
|
||||
// Range query that matches all docs (all have value >= 0)
|
||||
// This returns bare AllScorer, triggering the bug we fixed
|
||||
Box::new(RangeQuery::new(
|
||||
Bound::Included(Term::from_field_i64(range_field, 0)),
|
||||
Bound::Unbounded,
|
||||
))
|
||||
}
|
||||
BooleanQueryAST::Leaf { field_idx } => Box::new(TermQuery::new(
|
||||
Term::from_field_text(fields[*field_idx], "true"),
|
||||
crate::schema::IndexRecordOption::Basic,
|
||||
)),
|
||||
BooleanQueryAST::Union(children) => {
|
||||
let sub_queries = children
|
||||
.iter()
|
||||
.map(|child| (Occur::Should, child.to_query(fields, range_field)))
|
||||
.collect();
|
||||
Box::new(BooleanQuery::new(sub_queries))
|
||||
}
|
||||
BooleanQueryAST::Intersection(children) => {
|
||||
let sub_queries = children
|
||||
.iter()
|
||||
.map(|child| (Occur::Must, child.to_query(fields, range_field)))
|
||||
.collect();
|
||||
Box::new(BooleanQuery::new(sub_queries))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn doc_ids(num_docs: usize, num_fields: usize) -> Range<DocId> {
|
||||
let permutations = 1 << num_fields;
|
||||
let copies = (num_docs as f32 / permutations as f32).ceil() as u32;
|
||||
0..(permutations * copies)
|
||||
}
|
||||
|
||||
fn create_index_with_boolean_permutations(
|
||||
num_docs: usize,
|
||||
num_fields: usize,
|
||||
) -> (Index, Vec<Field>, Field) {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let fields: Vec<Field> = (0..num_fields)
|
||||
.map(|i| schema_builder.add_text_field(&format!("field_{}", i), TEXT))
|
||||
.collect();
|
||||
// Add a numeric field for RangeQuery tests - all docs have value = doc_id
|
||||
let range_field = schema_builder.add_i64_field(
|
||||
"range_field",
|
||||
NumericOptions::default().set_fast().set_indexed(),
|
||||
);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut writer = index.writer_for_tests().unwrap();
|
||||
|
||||
for doc_id in doc_ids(num_docs, num_fields) {
|
||||
let mut doc: BTreeMap<_, OwnedValue> = BTreeMap::default();
|
||||
for (field_idx, &field) in fields.iter().enumerate() {
|
||||
if (doc_id >> field_idx) & 1 == 1 {
|
||||
doc.insert(field, "true".into());
|
||||
}
|
||||
}
|
||||
// All docs have non-negative values, so RangeQuery(>=0) matches all
|
||||
doc.insert(range_field, (doc_id as i64).into());
|
||||
writer.add_document(doc).unwrap();
|
||||
}
|
||||
writer.commit().unwrap();
|
||||
(index, fields, range_field)
|
||||
}
|
||||
|
||||
fn arb_boolean_query_ast(num_fields: usize) -> impl Strategy<Value = BooleanQueryAST> {
|
||||
// Leaf strategies: term queries, AllQuery, and RangeQuery matching all docs
|
||||
let leaf = prop_oneof![
|
||||
(0..num_fields).prop_map(|field_idx| BooleanQueryAST::Leaf { field_idx }),
|
||||
Just(BooleanQueryAST::All),
|
||||
Just(BooleanQueryAST::RangeAll),
|
||||
];
|
||||
leaf.prop_recursive(
|
||||
8, // 8 levels of recursion
|
||||
256, // 256 nodes max
|
||||
10, // 10 items per collection
|
||||
|inner| {
|
||||
prop_oneof![
|
||||
vec(inner.clone(), 1..10).prop_map(BooleanQueryAST::Union),
|
||||
vec(inner, 1..10).prop_map(BooleanQueryAST::Intersection),
|
||||
]
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proptest_boolean_query() {
|
||||
// In the presence of optimizations around buffering, it can take large numbers of
|
||||
// documents to uncover some issues.
|
||||
let num_fields = 8;
|
||||
let num_docs = 1 << num_fields;
|
||||
let (index, fields, range_field) =
|
||||
create_index_with_boolean_permutations(num_docs, num_fields);
|
||||
let searcher = index.reader().unwrap().searcher();
|
||||
proptest!(|(ast in arb_boolean_query_ast(num_fields))| {
|
||||
let query = ast.to_query(&fields, range_field);
|
||||
|
||||
let mut matching_docs = HashSet::new();
|
||||
for doc_id in doc_ids(num_docs, num_fields) {
|
||||
if ast.matches(doc_id as DocId) {
|
||||
matching_docs.insert(doc_id as DocId);
|
||||
}
|
||||
}
|
||||
|
||||
let doc_addresses = searcher.search(&*query, &DocSetCollector).unwrap();
|
||||
let result_docs: HashSet<DocId> =
|
||||
doc_addresses.into_iter().map(|doc_address| doc_address.doc_id).collect();
|
||||
prop_assert_eq!(result_docs, matching_docs);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,6 +104,9 @@ impl<S: Scorer> DocSet for BoostScorer<S> {
|
||||
fn seek(&mut self, target: DocId) -> DocId {
|
||||
self.underlying.seek(target)
|
||||
}
|
||||
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
|
||||
self.underlying.seek_into_the_danger_zone(target)
|
||||
}
|
||||
|
||||
fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize {
|
||||
self.underlying.fill_buffer(buffer)
|
||||
@@ -131,6 +134,7 @@ impl<S: Scorer> DocSet for BoostScorer<S> {
|
||||
}
|
||||
|
||||
impl<S: Scorer> Scorer for BoostScorer<S> {
|
||||
#[inline]
|
||||
fn score(&mut self) -> Score {
|
||||
self.underlying.score() * self.boost
|
||||
}
|
||||
|
||||
@@ -137,6 +137,7 @@ impl<TDocSet: DocSet> DocSet for ConstScorer<TDocSet> {
|
||||
}
|
||||
|
||||
impl<TDocSet: DocSet + 'static> Scorer for ConstScorer<TDocSet> {
|
||||
#[inline]
|
||||
fn score(&mut self) -> Score {
|
||||
self.score
|
||||
}
|
||||
|
||||
@@ -62,6 +62,16 @@ impl<T: Scorer> DocSet for ScorerWrapper<T> {
|
||||
self.current_doc = doc_id;
|
||||
doc_id
|
||||
}
|
||||
fn seek(&mut self, target: DocId) -> DocId {
|
||||
let doc_id = self.scorer.seek(target);
|
||||
self.current_doc = doc_id;
|
||||
doc_id
|
||||
}
|
||||
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
|
||||
let found = self.scorer.seek_into_the_danger_zone(target);
|
||||
self.current_doc = self.scorer.doc();
|
||||
found
|
||||
}
|
||||
|
||||
fn doc(&self) -> DocId {
|
||||
self.current_doc
|
||||
@@ -163,6 +173,7 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> DocSet
|
||||
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Scorer
|
||||
for Disjunction<TScorer, TScoreCombiner>
|
||||
{
|
||||
#[inline]
|
||||
fn score(&mut self) -> Score {
|
||||
self.current_score
|
||||
}
|
||||
@@ -297,6 +308,7 @@ mod tests {
|
||||
}
|
||||
|
||||
impl Scorer for DummyScorer {
|
||||
#[inline]
|
||||
fn score(&mut self) -> Score {
|
||||
self.foo.get(self.cursor).map(|x| x.1).unwrap_or(0.0)
|
||||
}
|
||||
|
||||
@@ -55,6 +55,7 @@ impl DocSet for EmptyScorer {
|
||||
}
|
||||
|
||||
impl Scorer for EmptyScorer {
|
||||
#[inline]
|
||||
fn score(&mut self) -> Score {
|
||||
0.0
|
||||
}
|
||||
|
||||
@@ -84,6 +84,7 @@ where
|
||||
TScorer: Scorer,
|
||||
TDocSetExclude: DocSet + 'static,
|
||||
{
|
||||
#[inline]
|
||||
fn score(&mut self) -> Score {
|
||||
self.underlying_docset.score()
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use super::size_hint::estimate_intersection;
|
||||
use crate::docset::{DocSet, TERMINATED};
|
||||
use crate::query::size_hint::estimate_intersection;
|
||||
use crate::query::term_query::TermScorer;
|
||||
use crate::query::{EmptyScorer, Scorer};
|
||||
use crate::{DocId, Score};
|
||||
@@ -12,6 +12,9 @@ use crate::{DocId, Score};
|
||||
/// For better performance, the function uses a
|
||||
/// specialized implementation if the two
|
||||
/// shortest scorers are `TermScorer`s.
|
||||
///
|
||||
/// num_docs_segment is the number of documents in the segment. It is used for estimating the
|
||||
/// `size_hint` of the intersection.
|
||||
pub fn intersect_scorers(
|
||||
mut scorers: Vec<Box<dyn Scorer>>,
|
||||
num_docs_segment: u32,
|
||||
@@ -102,35 +105,48 @@ impl<TDocSet: DocSet> Intersection<TDocSet, TDocSet> {
|
||||
}
|
||||
|
||||
impl<TDocSet: DocSet, TOtherDocSet: DocSet> DocSet for Intersection<TDocSet, TOtherDocSet> {
|
||||
#[inline]
|
||||
fn advance(&mut self) -> DocId {
|
||||
let (left, right) = (&mut self.left, &mut self.right);
|
||||
let mut candidate = left.advance();
|
||||
if candidate == TERMINATED {
|
||||
return TERMINATED;
|
||||
}
|
||||
|
||||
'outer: loop {
|
||||
loop {
|
||||
// In the first part we look for a document in the intersection
|
||||
// of the two rarest `DocSet` in the intersection.
|
||||
|
||||
loop {
|
||||
let right_doc = right.seek(candidate);
|
||||
candidate = left.seek(right_doc);
|
||||
if candidate == right_doc {
|
||||
if right.seek_into_the_danger_zone(candidate) {
|
||||
break;
|
||||
}
|
||||
let right_doc = right.doc();
|
||||
// TODO: Think about which value would make sense here
|
||||
// It depends on the DocSet implementation, when a seek would outweigh an advance.
|
||||
if right_doc > candidate.wrapping_add(100) {
|
||||
candidate = left.seek(right_doc);
|
||||
} else {
|
||||
candidate = left.advance();
|
||||
}
|
||||
if candidate == TERMINATED {
|
||||
return TERMINATED;
|
||||
}
|
||||
}
|
||||
|
||||
debug_assert_eq!(left.doc(), right.doc());
|
||||
// test the remaining scorers;
|
||||
for docset in self.others.iter_mut() {
|
||||
let seek_doc = docset.seek(candidate);
|
||||
if seek_doc > candidate {
|
||||
candidate = left.seek(seek_doc);
|
||||
continue 'outer;
|
||||
}
|
||||
// test the remaining scorers
|
||||
if self
|
||||
.others
|
||||
.iter_mut()
|
||||
.all(|docset| docset.seek_into_the_danger_zone(candidate))
|
||||
{
|
||||
debug_assert_eq!(candidate, self.left.doc());
|
||||
debug_assert_eq!(candidate, self.right.doc());
|
||||
debug_assert!(self.others.iter().all(|docset| docset.doc() == candidate));
|
||||
return candidate;
|
||||
}
|
||||
debug_assert_eq!(candidate, self.left.doc());
|
||||
debug_assert_eq!(candidate, self.right.doc());
|
||||
debug_assert!(self.others.iter().all(|docset| docset.doc() == candidate));
|
||||
return candidate;
|
||||
candidate = left.advance();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -146,6 +162,20 @@ impl<TDocSet: DocSet, TOtherDocSet: DocSet> DocSet for Intersection<TDocSet, TOt
|
||||
doc
|
||||
}
|
||||
|
||||
/// Seeks to the target if necessary and checks if the target is an exact match.
|
||||
///
|
||||
/// Some implementations may choose to advance past the target if beneficial for performance.
|
||||
/// The return value is `true` if the target is in the docset, and `false` otherwise.
|
||||
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
|
||||
self.left.seek_into_the_danger_zone(target)
|
||||
&& self.right.seek_into_the_danger_zone(target)
|
||||
&& self
|
||||
.others
|
||||
.iter_mut()
|
||||
.all(|docset| docset.seek_into_the_danger_zone(target))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn doc(&self) -> DocId {
|
||||
self.left.doc()
|
||||
}
|
||||
@@ -172,6 +202,7 @@ where
|
||||
TScorer: Scorer,
|
||||
TOtherScorer: Scorer,
|
||||
{
|
||||
#[inline]
|
||||
fn score(&mut self) -> Score {
|
||||
self.left.score()
|
||||
+ self.right.score()
|
||||
@@ -181,6 +212,8 @@ where
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use proptest::prelude::*;
|
||||
|
||||
use super::Intersection;
|
||||
use crate::docset::{DocSet, TERMINATED};
|
||||
use crate::postings::tests::test_skip_against_unoptimized;
|
||||
@@ -270,4 +303,38 @@ mod tests {
|
||||
let intersection = Intersection::new(vec![a, b, c], 10);
|
||||
assert_eq!(intersection.doc(), TERMINATED);
|
||||
}
|
||||
|
||||
// Strategy to generate sorted and deduplicated vectors of u32 document IDs
|
||||
fn sorted_deduped_vec(max_val: u32, max_size: usize) -> impl Strategy<Value = Vec<u32>> {
|
||||
prop::collection::vec(0..max_val, 0..max_size).prop_map(|mut vec| {
|
||||
vec.sort();
|
||||
vec.dedup();
|
||||
vec
|
||||
})
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn prop_test_intersection_consistency(
|
||||
a in sorted_deduped_vec(100, 10),
|
||||
b in sorted_deduped_vec(100, 10),
|
||||
num_docs in 100u32..500u32
|
||||
) {
|
||||
let left = VecDocSet::from(a.clone());
|
||||
let right = VecDocSet::from(b.clone());
|
||||
let mut intersection = Intersection::new(vec![left, right], num_docs);
|
||||
|
||||
let expected: Vec<u32> = a.iter()
|
||||
.cloned()
|
||||
.filter(|doc| b.contains(doc))
|
||||
.collect();
|
||||
|
||||
for expected_doc in expected {
|
||||
assert_eq!(intersection.doc(), expected_doc);
|
||||
intersection.advance();
|
||||
}
|
||||
assert_eq!(intersection.doc(), TERMINATED);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,9 +70,83 @@ pub use self::weight::Weight;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::collector::TopDocs;
|
||||
use crate::query::phrase_query::tests::create_index;
|
||||
use crate::query::QueryParser;
|
||||
use crate::schema::{Schema, TEXT};
|
||||
use crate::{Index, Term};
|
||||
use crate::{DocAddress, Index, Term};
|
||||
|
||||
#[test]
|
||||
pub fn test_mixed_intersection_and_union() -> crate::Result<()> {
|
||||
let index = create_index(&["a b", "a c", "a b c", "b"])?;
|
||||
let schema = index.schema();
|
||||
let text_field = schema.get_field("text").unwrap();
|
||||
let searcher = index.reader()?.searcher();
|
||||
|
||||
let do_search = |term: &str| {
|
||||
let query = QueryParser::for_index(&index, vec![text_field])
|
||||
.parse_query(term)
|
||||
.unwrap();
|
||||
let top_docs: Vec<(f32, DocAddress)> = searcher
|
||||
.search(&query, &TopDocs::with_limit(10).order_by_score())
|
||||
.unwrap();
|
||||
|
||||
top_docs.iter().map(|el| el.1.doc_id).collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
assert_eq!(do_search("a AND b"), vec![0, 2]);
|
||||
assert_eq!(do_search("(a OR b) AND C"), vec![2, 1]);
|
||||
// The intersection code has special code for more than 2 intersections
|
||||
// left, right + others
|
||||
// The will place the union in the "others" insersection to that seek_into_the_danger_zone
|
||||
// is called
|
||||
assert_eq!(
|
||||
do_search("(a OR b) AND (c OR a) AND (b OR c)"),
|
||||
vec![2, 1, 0]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_mixed_intersection_and_union_with_skip() -> crate::Result<()> {
|
||||
// Test 4096 skip in BufferedUnionScorer
|
||||
let mut data: Vec<&str> = Vec::new();
|
||||
data.push("a b");
|
||||
let zz_data = vec!["z z"; 5000];
|
||||
data.extend_from_slice(&zz_data);
|
||||
data.extend_from_slice(&["a c"]);
|
||||
data.extend_from_slice(&zz_data);
|
||||
data.extend_from_slice(&["a b c", "b"]);
|
||||
let index = create_index(&data)?;
|
||||
let schema = index.schema();
|
||||
let text_field = schema.get_field("text").unwrap();
|
||||
let searcher = index.reader()?.searcher();
|
||||
|
||||
let do_search = |term: &str| {
|
||||
let query = QueryParser::for_index(&index, vec![text_field])
|
||||
.parse_query(term)
|
||||
.unwrap();
|
||||
let top_docs: Vec<(f32, DocAddress)> = searcher
|
||||
.search(&query, &TopDocs::with_limit(10).order_by_score())
|
||||
.unwrap();
|
||||
|
||||
top_docs.iter().map(|el| el.1.doc_id).collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
assert_eq!(do_search("a AND b"), vec![0, 10002]);
|
||||
assert_eq!(do_search("(a OR b) AND C"), vec![10002, 5001]);
|
||||
// The intersection code has special code for more than 2 intersections
|
||||
// left, right + others
|
||||
// The will place the union in the "others" insersection to that seek_into_the_danger_zone
|
||||
// is called
|
||||
assert_eq!(
|
||||
do_search("(a OR b) AND (c OR a) AND (b OR c)"),
|
||||
vec![10002, 5001, 0]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_query_terms() {
|
||||
|
||||
@@ -81,6 +81,7 @@ impl<TPostings: Postings> DocSet for PhraseKind<TPostings> {
|
||||
}
|
||||
|
||||
impl<TPostings: Postings> Scorer for PhraseKind<TPostings> {
|
||||
#[inline]
|
||||
fn score(&mut self) -> Score {
|
||||
match self {
|
||||
PhraseKind::SinglePrefix { positions, .. } => {
|
||||
@@ -193,6 +194,14 @@ impl<TPostings: Postings> DocSet for PhrasePrefixScorer<TPostings> {
|
||||
self.advance()
|
||||
}
|
||||
|
||||
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
|
||||
if self.phrase_scorer.seek_into_the_danger_zone(target) {
|
||||
self.matches_prefix()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn doc(&self) -> DocId {
|
||||
self.phrase_scorer.doc()
|
||||
}
|
||||
@@ -207,6 +216,7 @@ impl<TPostings: Postings> DocSet for PhrasePrefixScorer<TPostings> {
|
||||
}
|
||||
|
||||
impl<TPostings: Postings> Scorer for PhrasePrefixScorer<TPostings> {
|
||||
#[inline]
|
||||
fn score(&mut self) -> Score {
|
||||
// TODO modify score??
|
||||
self.phrase_scorer.score()
|
||||
|
||||
@@ -382,8 +382,9 @@ impl<TPostings: Postings> PhraseScorer<TPostings> {
|
||||
PostingsWithOffset::new(postings, (max_offset - offset) as u32)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let intersection_docset = Intersection::new(postings_with_offsets, num_docs);
|
||||
let mut scorer = PhraseScorer {
|
||||
intersection_docset: Intersection::new(postings_with_offsets, num_docs),
|
||||
intersection_docset,
|
||||
num_terms: num_docsets,
|
||||
left_positions: Vec::with_capacity(100),
|
||||
right_positions: Vec::with_capacity(100),
|
||||
@@ -529,25 +530,40 @@ impl<TPostings: Postings> DocSet for PhraseScorer<TPostings> {
|
||||
self.advance()
|
||||
}
|
||||
|
||||
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
|
||||
debug_assert!(target >= self.doc());
|
||||
if self.intersection_docset.seek_into_the_danger_zone(target) && self.phrase_match() {
|
||||
return true;
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
fn doc(&self) -> DocId {
|
||||
self.intersection_docset.doc()
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> u32 {
|
||||
self.intersection_docset.size_hint()
|
||||
// We adjust the intersection estimate, since actual phrase hits are much lower than where
|
||||
// the all appear.
|
||||
// The estimate should depend on average field length, e.g. if the field is really short
|
||||
// a phrase hit is more likely
|
||||
self.intersection_docset.size_hint() / (10 * self.num_terms as u32)
|
||||
}
|
||||
|
||||
/// Returns a best-effort hint of the
|
||||
/// cost to drive the docset.
|
||||
fn cost(&self) -> u64 {
|
||||
// Evaluating phrase matches is generally more expensive than simple term matches,
|
||||
// as it requires loading and comparing positions. Use a conservative multiplier
|
||||
// based on the number of terms.
|
||||
// While determing a potential hit is cheap for phrases, evaluating an actual hit is
|
||||
// expensive since it requires to load positions for a doc and check if they are next to
|
||||
// each other.
|
||||
// So the cost estimation would be the number of times we need to check if a doc is a hit *
|
||||
// 10 * self.num_terms.
|
||||
self.intersection_docset.size_hint() as u64 * 10 * self.num_terms as u64
|
||||
}
|
||||
}
|
||||
|
||||
impl<TPostings: Postings> Scorer for PhraseScorer<TPostings> {
|
||||
#[inline]
|
||||
fn score(&mut self) -> Score {
|
||||
let doc = self.doc();
|
||||
let fieldnorm_id = self.fieldnorm_reader.fieldnorm_id(doc);
|
||||
|
||||
@@ -62,6 +62,17 @@ pub(crate) struct RangeDocSet<T> {
|
||||
const DEFAULT_FETCH_HORIZON: u32 = 128;
|
||||
impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> RangeDocSet<T> {
|
||||
pub(crate) fn new(value_range: RangeInclusive<T>, column: Column<T>) -> Self {
|
||||
if *value_range.start() > column.max_value() || *value_range.end() < column.min_value() {
|
||||
return Self {
|
||||
value_range,
|
||||
column,
|
||||
loaded_docs: VecCursor::new(),
|
||||
next_fetch_start: TERMINATED,
|
||||
fetch_horizon: DEFAULT_FETCH_HORIZON,
|
||||
last_seek_pos_opt: None,
|
||||
};
|
||||
}
|
||||
|
||||
let mut range_docset = Self {
|
||||
value_range,
|
||||
column,
|
||||
@@ -81,6 +92,9 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> RangeDocSet<T> {
|
||||
|
||||
/// Returns true if more data could be fetched
|
||||
fn fetch_block(&mut self) {
|
||||
if self.next_fetch_start >= self.column.num_docs() {
|
||||
return;
|
||||
}
|
||||
const MAX_HORIZON: u32 = 100_000;
|
||||
while self.loaded_docs.is_empty() {
|
||||
let finished_to_end = self.fetch_horizon(self.fetch_horizon);
|
||||
@@ -105,10 +119,10 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> RangeDocSet<T> {
|
||||
fn fetch_horizon(&mut self, horizon: u32) -> bool {
|
||||
let mut finished_to_end = false;
|
||||
|
||||
let limit = self.column.num_docs();
|
||||
let mut end = self.next_fetch_start + horizon;
|
||||
if end >= limit {
|
||||
end = limit;
|
||||
let num_docs = self.column.num_docs();
|
||||
let mut fetch_end = self.next_fetch_start + horizon;
|
||||
if fetch_end >= num_docs {
|
||||
fetch_end = num_docs;
|
||||
finished_to_end = true;
|
||||
}
|
||||
|
||||
@@ -116,7 +130,7 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> RangeDocSet<T> {
|
||||
let doc_buffer: &mut Vec<DocId> = self.loaded_docs.get_cleared_data();
|
||||
self.column.get_docids_for_value_range(
|
||||
self.value_range.clone(),
|
||||
self.next_fetch_start..end,
|
||||
self.next_fetch_start..fetch_end,
|
||||
doc_buffer,
|
||||
);
|
||||
if let Some(last_doc) = last_doc {
|
||||
@@ -124,7 +138,7 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> RangeDocSet<T> {
|
||||
self.loaded_docs.next();
|
||||
}
|
||||
}
|
||||
self.next_fetch_start = end;
|
||||
self.next_fetch_start = fetch_end;
|
||||
|
||||
finished_to_end
|
||||
}
|
||||
@@ -136,9 +150,6 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> DocSet for RangeDocSe
|
||||
if let Some(docid) = self.loaded_docs.next() {
|
||||
return docid;
|
||||
}
|
||||
if self.next_fetch_start >= self.column.num_docs() {
|
||||
return TERMINATED;
|
||||
}
|
||||
self.fetch_block();
|
||||
self.loaded_docs.current().unwrap_or(TERMINATED)
|
||||
}
|
||||
@@ -174,15 +185,25 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> DocSet for RangeDocSe
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> u32 {
|
||||
self.column.num_docs()
|
||||
// TODO: Implement a better size hint
|
||||
self.column.num_docs() / 10
|
||||
}
|
||||
|
||||
/// Returns a best-effort hint of the
|
||||
/// cost to drive the docset.
|
||||
fn cost(&self) -> u64 {
|
||||
// Advancing the docset is relatively expensive since it scans the column.
|
||||
// Keep cost relative to a term query driver; use num_docs as baseline.
|
||||
self.column.num_docs() as u64
|
||||
// Advancing the docset is pretty expensive since it scans the whole column, there is no
|
||||
// index currently (will change with an kd-tree)
|
||||
// Since we use SIMD to scan the fast field range query we lower the cost a little bit,
|
||||
// assuming that we hit 10% of the docs like in size_hint.
|
||||
//
|
||||
// If we would return a cost higher than num_docs, we would never choose ff range query as
|
||||
// the driver in a DocSet, when intersecting a term query with a fast field. But
|
||||
// it's the faster choice when the term query has a lot of docids and the range
|
||||
// query has not.
|
||||
//
|
||||
// Ideally this would take the fast field codec into account
|
||||
(self.column.num_docs() as f64 * 0.8) as u64
|
||||
}
|
||||
}
|
||||
|
||||
@@ -236,4 +257,52 @@ mod tests {
|
||||
let count = searcher.search(&query, &Count).unwrap();
|
||||
assert_eq!(count, 500);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn range_query_no_overlap_optimization() {
|
||||
let mut schema_builder = schema::SchemaBuilder::new();
|
||||
let id_field = schema_builder.add_text_field("id", schema::STRING);
|
||||
let value_field = schema_builder.add_u64_field("value", schema::FAST | schema::INDEXED);
|
||||
|
||||
let dir = RamDirectory::default();
|
||||
let index = IndexBuilder::new()
|
||||
.schema(schema_builder.build())
|
||||
.open_or_create(dir)
|
||||
.unwrap();
|
||||
|
||||
{
|
||||
let mut writer = index.writer(15_000_000).unwrap();
|
||||
|
||||
// Add documents with values in the range [10, 20]
|
||||
for i in 0..100 {
|
||||
let mut doc = TantivyDocument::new();
|
||||
doc.add_text(id_field, format!("doc{i}"));
|
||||
doc.add_u64(value_field, 10 + (i % 11) as u64); // values in range 10-20
|
||||
|
||||
writer.add_document(doc).unwrap();
|
||||
}
|
||||
writer.commit().unwrap();
|
||||
}
|
||||
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
|
||||
// Test a range query [100, 200] that has no overlap with data range [10, 20]
|
||||
let query = RangeQuery::new(
|
||||
Bound::Included(Term::from_field_u64(value_field, 100)),
|
||||
Bound::Included(Term::from_field_u64(value_field, 200)),
|
||||
);
|
||||
|
||||
let count = searcher.search(&query, &Count).unwrap();
|
||||
assert_eq!(count, 0); // should return 0 results since there's no overlap
|
||||
|
||||
// Test another non-overlapping range: [0, 5] while data range is [10, 20]
|
||||
let query2 = RangeQuery::new(
|
||||
Bound::Included(Term::from_field_u64(value_field, 0)),
|
||||
Bound::Included(Term::from_field_u64(value_field, 5)),
|
||||
);
|
||||
|
||||
let count2 = searcher.search(&query2, &Count).unwrap();
|
||||
assert_eq!(count2, 0); // should return 0 results since there's no overlap
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1598,449 +1598,3 @@ pub(crate) mod ip_range_tests {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "unstable"))]
|
||||
mod bench {
|
||||
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use test::Bencher;
|
||||
|
||||
use super::tests::*;
|
||||
use super::*;
|
||||
use crate::collector::Count;
|
||||
use crate::query::QueryParser;
|
||||
use crate::Index;
|
||||
|
||||
fn get_index_0_to_100() -> Index {
|
||||
let mut rng = StdRng::from_seed([1u8; 32]);
|
||||
let num_vals = 100_000;
|
||||
let docs: Vec<_> = (0..num_vals)
|
||||
.map(|_i| {
|
||||
let id_name = if rng.gen_bool(0.01) {
|
||||
"veryfew".to_string() // 1%
|
||||
} else if rng.gen_bool(0.1) {
|
||||
"few".to_string() // 9%
|
||||
} else {
|
||||
"many".to_string() // 90%
|
||||
};
|
||||
Doc {
|
||||
id_name,
|
||||
id: rng.gen_range(0..100),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
create_index_from_docs(&docs, false)
|
||||
}
|
||||
|
||||
fn get_90_percent() -> RangeInclusive<u64> {
|
||||
0..=90
|
||||
}
|
||||
|
||||
fn get_10_percent() -> RangeInclusive<u64> {
|
||||
0..=10
|
||||
}
|
||||
|
||||
fn get_1_percent() -> RangeInclusive<u64> {
|
||||
10..=10
|
||||
}
|
||||
|
||||
fn execute_query(
|
||||
field: &str,
|
||||
id_range: RangeInclusive<u64>,
|
||||
suffix: &str,
|
||||
index: &Index,
|
||||
) -> usize {
|
||||
let gen_query_inclusive = |from: &u64, to: &u64| {
|
||||
format!(
|
||||
"{}:[{} TO {}] {}",
|
||||
field,
|
||||
&from.to_string(),
|
||||
&to.to_string(),
|
||||
suffix
|
||||
)
|
||||
};
|
||||
|
||||
let query = gen_query_inclusive(id_range.start(), id_range.end());
|
||||
let query_from_text = |text: &str| {
|
||||
QueryParser::for_index(index, vec![])
|
||||
.parse_query(text)
|
||||
.unwrap()
|
||||
};
|
||||
let query = query_from_text(&query);
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
searcher.search(&query, &(Count)).unwrap()
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_id_range_hit_90_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
bench.iter(|| execute_query("id", get_90_percent(), "", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_id_range_hit_10_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
bench.iter(|| execute_query("id", get_10_percent(), "", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_id_range_hit_1_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
bench.iter(|| execute_query("id", get_1_percent(), "", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_id_range_hit_10_percent_intersect_with_10_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
bench.iter(|| execute_query("id", get_10_percent(), "AND id_name:few", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_id_range_hit_1_percent_intersect_with_10_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
bench.iter(|| execute_query("id", get_1_percent(), "AND id_name:few", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_id_range_hit_1_percent_intersect_with_90_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
bench.iter(|| execute_query("id", get_1_percent(), "AND id_name:many", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_id_range_hit_1_percent_intersect_with_1_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
bench.iter(|| execute_query("id", get_1_percent(), "AND id_name:veryfew", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_id_range_hit_10_percent_intersect_with_90_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
bench.iter(|| execute_query("id", get_10_percent(), "AND id_name:many", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_id_range_hit_90_percent_intersect_with_90_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
bench.iter(|| execute_query("id", get_90_percent(), "AND id_name:many", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_id_range_hit_90_percent_intersect_with_10_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
bench.iter(|| execute_query("id", get_90_percent(), "AND id_name:few", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_id_range_hit_90_percent_intersect_with_1_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
bench.iter(|| execute_query("id", get_90_percent(), "AND id_name:veryfew", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_id_range_hit_90_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
bench.iter(|| execute_query("ids", get_90_percent(), "", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_id_range_hit_10_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
bench.iter(|| execute_query("ids", get_10_percent(), "", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_id_range_hit_1_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
bench.iter(|| execute_query("ids", get_1_percent(), "", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_id_range_hit_10_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
bench.iter(|| execute_query("ids", get_10_percent(), "AND id_name:few", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_id_range_hit_1_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
bench.iter(|| execute_query("ids", get_1_percent(), "AND id_name:few", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_id_range_hit_1_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
bench.iter(|| execute_query("ids", get_1_percent(), "AND id_name:many", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_id_range_hit_1_percent_intersect_with_1_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
bench.iter(|| execute_query("ids", get_1_percent(), "AND id_name:veryfew", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_id_range_hit_10_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
bench.iter(|| execute_query("ids", get_10_percent(), "AND id_name:many", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_id_range_hit_90_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
bench.iter(|| execute_query("ids", get_90_percent(), "AND id_name:many", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_id_range_hit_90_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
bench.iter(|| execute_query("ids", get_90_percent(), "AND id_name:few", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_id_range_hit_90_percent_intersect_with_1_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
bench.iter(|| execute_query("ids", get_90_percent(), "AND id_name:veryfew", &index));
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "unstable"))]
|
||||
mod bench_ip {
|
||||
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use test::Bencher;
|
||||
|
||||
use super::ip_range_tests::*;
|
||||
use super::*;
|
||||
use crate::collector::Count;
|
||||
use crate::query::QueryParser;
|
||||
use crate::Index;
|
||||
|
||||
fn get_index_0_to_100() -> Index {
|
||||
let mut rng = StdRng::from_seed([1u8; 32]);
|
||||
let num_vals = 100_000;
|
||||
let docs: Vec<_> = (0..num_vals)
|
||||
.map(|_i| {
|
||||
let id = if rng.gen_bool(0.01) {
|
||||
"veryfew".to_string() // 1%
|
||||
} else if rng.gen_bool(0.1) {
|
||||
"few".to_string() // 9%
|
||||
} else {
|
||||
"many".to_string() // 90%
|
||||
};
|
||||
Doc {
|
||||
id,
|
||||
// Multiply by 1000, so that we create many buckets in the compact space
|
||||
// The benches depend on this range to select n-percent of elements with the
|
||||
// methods below.
|
||||
ip: Ipv6Addr::from_u128(rng.gen_range(0..100) * 1000),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
create_index_from_ip_docs(&docs)
|
||||
}
|
||||
|
||||
fn get_90_percent() -> RangeInclusive<Ipv6Addr> {
|
||||
let start = Ipv6Addr::from_u128(0);
|
||||
let end = Ipv6Addr::from_u128(90 * 1000);
|
||||
start..=end
|
||||
}
|
||||
|
||||
fn get_10_percent() -> RangeInclusive<Ipv6Addr> {
|
||||
let start = Ipv6Addr::from_u128(0);
|
||||
let end = Ipv6Addr::from_u128(10 * 1000);
|
||||
start..=end
|
||||
}
|
||||
|
||||
fn get_1_percent() -> RangeInclusive<Ipv6Addr> {
|
||||
let start = Ipv6Addr::from_u128(10 * 1000);
|
||||
let end = Ipv6Addr::from_u128(10 * 1000);
|
||||
start..=end
|
||||
}
|
||||
|
||||
fn execute_query(
|
||||
field: &str,
|
||||
ip_range: RangeInclusive<Ipv6Addr>,
|
||||
suffix: &str,
|
||||
index: &Index,
|
||||
) -> usize {
|
||||
let gen_query_inclusive = |from: &Ipv6Addr, to: &Ipv6Addr| {
|
||||
format!(
|
||||
"{}:[{} TO {}] {}",
|
||||
field,
|
||||
&from.to_string(),
|
||||
&to.to_string(),
|
||||
suffix
|
||||
)
|
||||
};
|
||||
|
||||
let query = gen_query_inclusive(ip_range.start(), ip_range.end());
|
||||
let query_from_text = |text: &str| {
|
||||
QueryParser::for_index(index, vec![])
|
||||
.parse_query(text)
|
||||
.unwrap()
|
||||
};
|
||||
let query = query_from_text(&query);
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
searcher.search(&query, &(Count)).unwrap()
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_90_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| execute_query("ip", get_90_percent(), "", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_10_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| execute_query("ip", get_10_percent(), "", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_1_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| execute_query("ip", get_1_percent(), "", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_10_percent_intersect_with_10_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| execute_query("ip", get_10_percent(), "AND id:few", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_1_percent_intersect_with_10_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| execute_query("ip", get_1_percent(), "AND id:few", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_1_percent_intersect_with_90_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| execute_query("ip", get_1_percent(), "AND id:many", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_1_percent_intersect_with_1_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| execute_query("ip", get_1_percent(), "AND id:veryfew", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_10_percent_intersect_with_90_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| execute_query("ip", get_10_percent(), "AND id:many", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_90_percent_intersect_with_90_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| execute_query("ip", get_90_percent(), "AND id:many", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_90_percent_intersect_with_10_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| execute_query("ip", get_90_percent(), "AND id:few", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_90_percent_intersect_with_1_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| execute_query("ip", get_90_percent(), "AND id:veryfew", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_90_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| execute_query("ips", get_90_percent(), "", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_10_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| execute_query("ips", get_10_percent(), "", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_1_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| execute_query("ips", get_1_percent(), "", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_10_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| execute_query("ips", get_10_percent(), "AND id:few", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_1_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| execute_query("ips", get_1_percent(), "AND id:few", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_1_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
bench.iter(|| execute_query("ips", get_1_percent(), "AND id:many", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_1_percent_intersect_with_1_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| execute_query("ips", get_1_percent(), "AND id:veryfew", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_10_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| execute_query("ips", get_10_percent(), "AND id:many", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_90_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| execute_query("ips", get_90_percent(), "AND id:many", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_90_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| execute_query("ips", get_90_percent(), "AND id:few", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_90_percent_intersect_with_1_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| execute_query("ips", get_90_percent(), "AND id:veryfew", &index));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,6 +56,11 @@ where
|
||||
self.req_scorer.seek(target)
|
||||
}
|
||||
|
||||
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
|
||||
self.score_cache = None;
|
||||
self.req_scorer.seek_into_the_danger_zone(target)
|
||||
}
|
||||
|
||||
fn doc(&self) -> DocId {
|
||||
self.req_scorer.doc()
|
||||
}
|
||||
@@ -76,6 +81,7 @@ where
|
||||
TOptScorer: Scorer,
|
||||
TScoreCombiner: ScoreCombiner,
|
||||
{
|
||||
#[inline]
|
||||
fn score(&mut self) -> Score {
|
||||
if let Some(score) = self.score_cache {
|
||||
return score;
|
||||
|
||||
@@ -29,6 +29,7 @@ impl ScoreCombiner for DoNothingCombiner {
|
||||
|
||||
fn clear(&mut self) {}
|
||||
|
||||
#[inline]
|
||||
fn score(&self) -> Score {
|
||||
1.0
|
||||
}
|
||||
@@ -49,6 +50,7 @@ impl ScoreCombiner for SumCombiner {
|
||||
self.score = 0.0;
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn score(&self) -> Score {
|
||||
self.score
|
||||
}
|
||||
@@ -86,6 +88,7 @@ impl ScoreCombiner for DisjunctionMaxCombiner {
|
||||
self.sum = 0.0;
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn score(&self) -> Score {
|
||||
self.max + (self.sum - self.max) * self.tie_breaker
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ pub trait Scorer: downcast_rs::Downcast + DocSet + 'static {
|
||||
impl_downcast!(Scorer);
|
||||
|
||||
impl Scorer for Box<dyn Scorer> {
|
||||
#[inline]
|
||||
fn score(&mut self) -> Score {
|
||||
self.deref_mut().score()
|
||||
}
|
||||
|
||||
@@ -98,14 +98,17 @@ impl TermScorer {
|
||||
}
|
||||
|
||||
impl DocSet for TermScorer {
|
||||
#[inline]
|
||||
fn advance(&mut self) -> DocId {
|
||||
self.postings.advance()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn seek(&mut self, target: DocId) -> DocId {
|
||||
self.postings.seek(target)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn doc(&self) -> DocId {
|
||||
self.postings.doc()
|
||||
}
|
||||
@@ -116,6 +119,7 @@ impl DocSet for TermScorer {
|
||||
}
|
||||
|
||||
impl Scorer for TermScorer {
|
||||
#[inline]
|
||||
fn score(&mut self) -> Score {
|
||||
let fieldnorm_id = self.fieldnorm_id();
|
||||
let term_freq = self.term_freq();
|
||||
|
||||
@@ -15,7 +15,7 @@ const HORIZON: u32 = 64u32 * 64u32;
|
||||
// This function is similar except that it does is not unstable, and
|
||||
// it does not keep the original vector ordering.
|
||||
//
|
||||
// Also, it does not "yield" any elements.
|
||||
// Elements are dropped and not yielded.
|
||||
fn unordered_drain_filter<T, P>(v: &mut Vec<T>, mut predicate: P)
|
||||
where P: FnMut(&mut T) -> bool {
|
||||
let mut i = 0;
|
||||
@@ -128,6 +128,7 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn advance_buffered(&mut self) -> bool {
|
||||
while self.bucket_idx < HORIZON_NUM_TINYBITSETS {
|
||||
if let Some(val) = self.bitsets[self.bucket_idx].pop_lowest() {
|
||||
@@ -143,6 +144,12 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
fn is_in_horizon(&self, target: DocId) -> bool {
|
||||
// wrapping_sub, because target may be < window_start_doc
|
||||
let gap = target.wrapping_sub(self.window_start_doc);
|
||||
gap < HORIZON
|
||||
}
|
||||
}
|
||||
|
||||
impl<TScorer, TScoreCombiner> DocSet for BufferedUnionScorer<TScorer, TScoreCombiner>
|
||||
@@ -150,6 +157,7 @@ where
|
||||
TScorer: Scorer,
|
||||
TScoreCombiner: ScoreCombiner,
|
||||
{
|
||||
#[inline]
|
||||
fn advance(&mut self) -> DocId {
|
||||
if self.advance_buffered() {
|
||||
return self.doc;
|
||||
@@ -217,8 +225,29 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
// TODO Also implement `count` with deletes efficiently.
|
||||
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
|
||||
if self.is_in_horizon(target) {
|
||||
// Our value is within the buffered horizon and the docset may already have been
|
||||
// processed and removed, so we need to use seek, which uses the regular advance.
|
||||
self.seek(target) == target
|
||||
} else {
|
||||
// The docsets are not in the buffered range, so we can use seek_into_the_danger_zone
|
||||
// of the underlying docsets
|
||||
let is_hit = self
|
||||
.docsets
|
||||
.iter_mut()
|
||||
.any(|docset| docset.seek_into_the_danger_zone(target));
|
||||
|
||||
// The API requires the DocSet to be in a valid state when `seek_into_the_danger_zone`
|
||||
// returns true.
|
||||
if is_hit {
|
||||
self.seek(target);
|
||||
}
|
||||
is_hit
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn doc(&self) -> DocId {
|
||||
self.doc
|
||||
}
|
||||
@@ -231,6 +260,7 @@ where
|
||||
self.docsets.iter().map(|docset| docset.cost()).sum()
|
||||
}
|
||||
|
||||
// TODO Also implement `count` with deletes efficiently.
|
||||
fn count_including_deleted(&mut self) -> u32 {
|
||||
if self.doc == TERMINATED {
|
||||
return 0;
|
||||
@@ -259,6 +289,7 @@ where
|
||||
TScoreCombiner: ScoreCombiner,
|
||||
TScorer: Scorer,
|
||||
{
|
||||
#[inline]
|
||||
fn score(&mut self) -> Score {
|
||||
self.score
|
||||
}
|
||||
|
||||
@@ -92,6 +92,7 @@ impl<TDocSet: DocSet> DocSet for SimpleUnion<TDocSet> {
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> u32 {
|
||||
// TODO: use estimate_union
|
||||
self.docsets
|
||||
.iter()
|
||||
.map(|docset| docset.size_hint())
|
||||
|
||||
@@ -58,6 +58,31 @@ impl AsRef<OwnedValue> for OwnedValue {
|
||||
}
|
||||
}
|
||||
|
||||
impl OwnedValue {
|
||||
/// Returns a u8 discriminant value for the `OwnedValue` variant.
|
||||
///
|
||||
/// This can be used to sort `OwnedValue` instances by their type.
|
||||
pub fn discriminant_value(&self) -> u8 {
|
||||
match self {
|
||||
OwnedValue::Null => 0,
|
||||
OwnedValue::Str(_) => 1,
|
||||
OwnedValue::PreTokStr(_) => 2,
|
||||
// It is key to make sure U64, I64, F64 are grouped together in there, otherwise we
|
||||
// might be breaking transivity.
|
||||
OwnedValue::U64(_) => 3,
|
||||
OwnedValue::I64(_) => 4,
|
||||
OwnedValue::F64(_) => 5,
|
||||
OwnedValue::Bool(_) => 6,
|
||||
OwnedValue::Date(_) => 7,
|
||||
OwnedValue::Facet(_) => 8,
|
||||
OwnedValue::Bytes(_) => 9,
|
||||
OwnedValue::Array(_) => 10,
|
||||
OwnedValue::Object(_) => 11,
|
||||
OwnedValue::IpAddr(_) => 12,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Value<'a> for &'a OwnedValue {
|
||||
type ArrayIter = std::slice::Iter<'a, OwnedValue>;
|
||||
type ObjectIter = ObjectMapIter<'a>;
|
||||
|
||||
@@ -98,6 +98,10 @@
|
||||
//! make it possible to access the value given the doc id rapidly. This is useful if the value
|
||||
//! of the field is required during scoring or collection for instance.
|
||||
//!
|
||||
//! Some queries may leverage Fast fields when run on a field that is not indexed. This can be
|
||||
//! handy if that kind of request is infrequent, however note that searching on a Fast field is
|
||||
//! generally much slower than searching in an index.
|
||||
//!
|
||||
//! ```
|
||||
//! use tantivy::schema::*;
|
||||
//! let mut schema_builder = Schema::builder();
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::hash::Hash;
|
||||
use std::net::Ipv6Addr;
|
||||
use std::{fmt, str};
|
||||
|
||||
use columnar::MonotonicallyMappableToU128;
|
||||
use common::json_path_writer::{JSON_END_OF_PATH, JSON_PATH_SEGMENT_SEP_STR};
|
||||
use common::JsonPathWriter;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::date_time_options::DATE_TIME_PRECISION_INDEXED;
|
||||
use super::{Field, Schema};
|
||||
@@ -16,23 +17,54 @@ use crate::DateTime;
|
||||
/// Term represents the value that the token can take.
|
||||
/// It's a serialized representation over different types.
|
||||
///
|
||||
/// It actually wraps a `Vec<u8>`. The first 5 bytes are metadata.
|
||||
/// 4 bytes are the field id, and the last byte is the type.
|
||||
///
|
||||
/// The serialized value `ValueBytes` is considered everything after the 4 first bytes (term id).
|
||||
#[derive(Clone)]
|
||||
pub struct Term<B = Vec<u8>>(B)
|
||||
where B: AsRef<[u8]>;
|
||||
/// A term is composed of Field and the serialized value bytes.
|
||||
/// The serialized value bytes themselves start with a one byte type tag followed by the payload.
|
||||
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
|
||||
pub struct Term {
|
||||
field: Field,
|
||||
serialized_value_bytes: Vec<u8>,
|
||||
}
|
||||
|
||||
/// The number of bytes used as metadata by `Term`.
|
||||
const TERM_METADATA_LENGTH: usize = 5;
|
||||
/// The number of bytes used as metadata when serializing a term.
|
||||
const TERM_TYPE_TAG_LEN: usize = 1;
|
||||
|
||||
impl Term {
|
||||
/// Takes a serialized term and wraps it as a Term.
|
||||
/// First 4 bytes are the field id
|
||||
#[deprecated(
|
||||
note = "we want to avoid working on the serialized representation directly, replace with \
|
||||
typed API calls (add more if needed) or use serde to serialize/deserialize"
|
||||
)]
|
||||
pub fn wrap(serialized: &[u8]) -> Term {
|
||||
let field_id_bytes: [u8; 4] = serialized[0..4].try_into().unwrap();
|
||||
let field_id = u32::from_be_bytes(field_id_bytes);
|
||||
Term {
|
||||
field: Field::from_field_id(field_id),
|
||||
serialized_value_bytes: serialized[4..].to_vec(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the serialized representation of the term.
|
||||
/// First 4 bytes are the field id
|
||||
#[deprecated(
|
||||
note = "we want to avoid working on the serialized representation directly, replace with \
|
||||
typed API calls (add more if needed) or use serde to serialize/deserialize"
|
||||
)]
|
||||
pub fn serialized_term(&self) -> Vec<u8> {
|
||||
let mut serialized = Vec::with_capacity(4 + self.serialized_value_bytes.len());
|
||||
serialized.extend(self.field.field_id().to_be_bytes().as_ref());
|
||||
serialized.extend_from_slice(&self.serialized_value_bytes);
|
||||
serialized
|
||||
}
|
||||
|
||||
/// Create a new Term with a buffer with a given capacity.
|
||||
pub fn with_capacity(capacity: usize) -> Term {
|
||||
let mut data = Vec::with_capacity(TERM_METADATA_LENGTH + capacity);
|
||||
data.resize(TERM_METADATA_LENGTH, 0u8);
|
||||
Term(data)
|
||||
let mut data = Vec::with_capacity(TERM_TYPE_TAG_LEN + capacity);
|
||||
data.resize(TERM_TYPE_TAG_LEN, 0u8);
|
||||
Term {
|
||||
field: Field::from_field_id(0u32),
|
||||
serialized_value_bytes: data,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a term from a json path.
|
||||
@@ -89,7 +121,7 @@ impl Term {
|
||||
fn with_bytes_and_field_and_payload(typ: Type, field: Field, bytes: &[u8]) -> Term {
|
||||
let mut term = Self::with_capacity(bytes.len());
|
||||
term.set_field_and_type(field, typ);
|
||||
term.0.extend_from_slice(bytes);
|
||||
term.serialized_value_bytes.extend_from_slice(bytes);
|
||||
term
|
||||
}
|
||||
|
||||
@@ -105,13 +137,13 @@ impl Term {
|
||||
/// Sets field and the type.
|
||||
pub(crate) fn set_field_and_type(&mut self, field: Field, typ: Type) {
|
||||
assert!(self.is_empty());
|
||||
self.0[0..4].clone_from_slice(field.field_id().to_be_bytes().as_ref());
|
||||
self.0[4] = typ.to_code();
|
||||
self.field = field;
|
||||
self.serialized_value_bytes[0] = typ.to_code();
|
||||
}
|
||||
|
||||
/// Is empty if there are no value bytes.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.0.len() == TERM_METADATA_LENGTH
|
||||
self.serialized_value_bytes.len() == TERM_TYPE_TAG_LEN
|
||||
}
|
||||
|
||||
/// Builds a term given a field, and a `Ipv6Addr`-value
|
||||
@@ -177,7 +209,7 @@ impl Term {
|
||||
/// Removes the value_bytes and set the type code.
|
||||
pub fn clear_with_type(&mut self, typ: Type) {
|
||||
self.truncate_value_bytes(0);
|
||||
self.0[4] = typ.to_code();
|
||||
self.serialized_value_bytes[0] = typ.to_code();
|
||||
}
|
||||
|
||||
/// Append a type marker + fast value to a term.
|
||||
@@ -185,9 +217,10 @@ impl Term {
|
||||
///
|
||||
/// It will not clear existing bytes.
|
||||
pub fn append_type_and_fast_value<T: FastValue>(&mut self, val: T) {
|
||||
self.0.push(T::to_type().to_code());
|
||||
self.serialized_value_bytes.push(T::to_type().to_code());
|
||||
let value = val.to_u64();
|
||||
self.0.extend(value.to_be_bytes().as_ref());
|
||||
self.serialized_value_bytes
|
||||
.extend(value.to_be_bytes().as_ref());
|
||||
}
|
||||
|
||||
/// Append a string type marker + string to a term.
|
||||
@@ -195,24 +228,25 @@ impl Term {
|
||||
///
|
||||
/// It will not clear existing bytes.
|
||||
pub fn append_type_and_str(&mut self, val: &str) {
|
||||
self.0.push(Type::Str.to_code());
|
||||
self.0.extend(val.as_bytes().as_ref());
|
||||
self.serialized_value_bytes.push(Type::Str.to_code());
|
||||
self.serialized_value_bytes.extend(val.as_bytes().as_ref());
|
||||
}
|
||||
|
||||
/// Sets the value of a `Bytes` field.
|
||||
pub fn set_bytes(&mut self, bytes: &[u8]) {
|
||||
self.truncate_value_bytes(0);
|
||||
self.0.extend(bytes);
|
||||
self.serialized_value_bytes.extend(bytes);
|
||||
}
|
||||
|
||||
/// Truncates the value bytes of the term. Value and field type stays the same.
|
||||
pub fn truncate_value_bytes(&mut self, len: usize) {
|
||||
self.0.truncate(len + TERM_METADATA_LENGTH);
|
||||
self.serialized_value_bytes
|
||||
.truncate(len + TERM_TYPE_TAG_LEN);
|
||||
}
|
||||
|
||||
/// The length of the bytes.
|
||||
pub fn len_bytes(&self) -> usize {
|
||||
self.0.len() - TERM_METADATA_LENGTH
|
||||
self.serialized_value_bytes.len() - TERM_TYPE_TAG_LEN
|
||||
}
|
||||
|
||||
/// Appends value bytes to the Term.
|
||||
@@ -220,18 +254,9 @@ impl Term {
|
||||
/// This function returns the segment that has just been added.
|
||||
#[inline]
|
||||
pub fn append_bytes(&mut self, bytes: &[u8]) -> &mut [u8] {
|
||||
let len_before = self.0.len();
|
||||
self.0.extend_from_slice(bytes);
|
||||
&mut self.0[len_before..]
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> Term<B>
|
||||
where B: AsRef<[u8]>
|
||||
{
|
||||
/// Wraps a object holding bytes
|
||||
pub fn wrap(data: B) -> Term<B> {
|
||||
Term(data)
|
||||
let len_before = self.serialized_value_bytes.len();
|
||||
self.serialized_value_bytes.extend_from_slice(bytes);
|
||||
&mut self.serialized_value_bytes[len_before..]
|
||||
}
|
||||
|
||||
/// Return the type of the term.
|
||||
@@ -241,8 +266,7 @@ where B: AsRef<[u8]>
|
||||
|
||||
/// Returns the field.
|
||||
pub fn field(&self) -> Field {
|
||||
let field_id_bytes: [u8; 4] = (&self.0.as_ref()[..4]).try_into().unwrap();
|
||||
Field::from_field_id(u32::from_be_bytes(field_id_bytes))
|
||||
self.field
|
||||
}
|
||||
|
||||
/// Returns the serialized representation of the value.
|
||||
@@ -252,23 +276,13 @@ where B: AsRef<[u8]>
|
||||
/// If the term is a u64, its value is encoded according
|
||||
/// to `byteorder::BigEndian`.
|
||||
pub fn serialized_value_bytes(&self) -> &[u8] {
|
||||
&self.0.as_ref()[TERM_METADATA_LENGTH..]
|
||||
&self.serialized_value_bytes[TERM_TYPE_TAG_LEN..]
|
||||
}
|
||||
|
||||
/// Returns the value of the term.
|
||||
/// address or JSON path + value. (this does not include the field.)
|
||||
pub fn value(&self) -> ValueBytes<&[u8]> {
|
||||
ValueBytes::wrap(&self.0.as_ref()[4..])
|
||||
}
|
||||
|
||||
/// Returns the serialized representation of Term.
|
||||
/// This includes field_id, value type and value.
|
||||
///
|
||||
/// Do NOT rely on this byte representation in the index.
|
||||
/// This value is likely to change in the future.
|
||||
#[inline]
|
||||
pub fn serialized_term(&self) -> &[u8] {
|
||||
self.0.as_ref()
|
||||
ValueBytes::wrap(self.serialized_value_bytes.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -452,10 +466,7 @@ where B: AsRef<[u8]>
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the serialized representation of Term.
|
||||
///
|
||||
/// Do NOT rely on this byte representation in the index.
|
||||
/// This value is likely to change in the future.
|
||||
/// Returns the serialized representation of the value bytes including the type tag.
|
||||
pub fn as_serialized(&self) -> &[u8] {
|
||||
self.0.as_ref()
|
||||
}
|
||||
@@ -508,40 +519,6 @@ where B: AsRef<[u8]>
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> Ord for Term<B>
|
||||
where B: AsRef<[u8]>
|
||||
{
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
self.serialized_term().cmp(other.serialized_term())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> PartialOrd for Term<B>
|
||||
where B: AsRef<[u8]>
|
||||
{
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> PartialEq for Term<B>
|
||||
where B: AsRef<[u8]>
|
||||
{
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.serialized_term() == other.serialized_term()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> Eq for Term<B> where B: AsRef<[u8]> {}
|
||||
|
||||
impl<B> Hash for Term<B>
|
||||
where B: AsRef<[u8]>
|
||||
{
|
||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||
self.0.as_ref().hash(state)
|
||||
}
|
||||
}
|
||||
|
||||
fn write_opt<T: std::fmt::Debug>(f: &mut fmt::Formatter, val_opt: Option<T>) -> fmt::Result {
|
||||
if let Some(val) = val_opt {
|
||||
write!(f, "{val:?}")?;
|
||||
@@ -549,13 +526,11 @@ fn write_opt<T: std::fmt::Debug>(f: &mut fmt::Formatter, val_opt: Option<T>) ->
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl<B> fmt::Debug for Term<B>
|
||||
where B: AsRef<[u8]>
|
||||
{
|
||||
impl fmt::Debug for Term {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
let field_id = self.field().field_id();
|
||||
let field_id = self.field.field_id();
|
||||
write!(f, "Term(field={field_id}, ")?;
|
||||
let value_bytes = ValueBytes::wrap(&self.0.as_ref()[4..]);
|
||||
let value_bytes = ValueBytes::wrap(&self.serialized_value_bytes);
|
||||
value_bytes.debug_value_bytes(f)?;
|
||||
write!(f, ")",)?;
|
||||
Ok(())
|
||||
@@ -578,17 +553,6 @@ mod tests {
|
||||
assert_eq!(term.value().as_str(), Some("test"))
|
||||
}
|
||||
|
||||
/// Size (in bytes) of the buffer of a fast value (u64, i64, f64, or date) term.
|
||||
/// <field> + <type byte> + <value len>
|
||||
///
|
||||
/// - <field> is a big endian encoded u32 field id
|
||||
/// - <type_byte>'s most significant bit expresses whether the term is a json term or not The
|
||||
/// remaining 7 bits are used to encode the type of the value. If this is a JSON term, the
|
||||
/// type is the type of the leaf of the json.
|
||||
/// - <value> is, if this is not the json term, a binary representation specific to the type.
|
||||
/// If it is a JSON Term, then it is prepended with the path that leads to this leaf value.
|
||||
const FAST_VALUE_TERM_LEN: usize = 4 + 1 + 8;
|
||||
|
||||
#[test]
|
||||
pub fn test_term_u64() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
@@ -596,7 +560,7 @@ mod tests {
|
||||
let term = Term::from_field_u64(count_field, 983u64);
|
||||
assert_eq!(term.field(), count_field);
|
||||
assert_eq!(term.typ(), Type::U64);
|
||||
assert_eq!(term.serialized_term().len(), FAST_VALUE_TERM_LEN);
|
||||
assert_eq!(term.serialized_value_bytes().len(), 8);
|
||||
assert_eq!(term.value().as_u64(), Some(983u64))
|
||||
}
|
||||
|
||||
@@ -607,7 +571,7 @@ mod tests {
|
||||
let term = Term::from_field_bool(bool_field, true);
|
||||
assert_eq!(term.field(), bool_field);
|
||||
assert_eq!(term.typ(), Type::Bool);
|
||||
assert_eq!(term.serialized_term().len(), FAST_VALUE_TERM_LEN);
|
||||
assert_eq!(term.serialized_value_bytes().len(), 8);
|
||||
assert_eq!(term.value().as_bool(), Some(true))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -483,7 +483,7 @@ mod tests {
|
||||
|
||||
use super::{collapse_overlapped_ranges, search_fragments, select_best_fragment_combination};
|
||||
use crate::query::QueryParser;
|
||||
use crate::schema::{IndexRecordOption, Schema, TextFieldIndexing, TextOptions, TEXT};
|
||||
use crate::schema::{Schema, TEXT};
|
||||
use crate::snippet::SnippetGenerator;
|
||||
use crate::tokenizer::{NgramTokenizer, SimpleTokenizer};
|
||||
use crate::Index;
|
||||
@@ -727,8 +727,10 @@ Survey in 2016, 2017, and 2018."#;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "stemmer")]
|
||||
#[test]
|
||||
fn test_snippet_generator() -> crate::Result<()> {
|
||||
use crate::schema::{IndexRecordOption, TextFieldIndexing, TextOptions};
|
||||
let mut schema_builder = Schema::builder();
|
||||
let text_options = TextOptions::default().set_indexing_options(
|
||||
TextFieldIndexing::default()
|
||||
|
||||
@@ -7,13 +7,14 @@
|
||||
//! storage-level details into consideration. For example, if your file system block size is 4096
|
||||
//! bytes, we can under-count actual resultant space usage by up to 4095 bytes per file.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::collections::btree_map::Entry;
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use columnar::ColumnSpaceUsage;
|
||||
use common::ByteCount;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::index::SegmentComponent;
|
||||
use crate::schema::Field;
|
||||
|
||||
/// Enum containing any of the possible space usage results for segment components.
|
||||
pub enum ComponentSpaceUsage {
|
||||
@@ -212,17 +213,26 @@ impl StoreSpaceUsage {
|
||||
/// Multiple indexes are used to handle variable length things, where
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct PerFieldSpaceUsage {
|
||||
fields: HashMap<Field, FieldUsage>,
|
||||
fields: BTreeMap<String, FieldUsage>,
|
||||
total: ByteCount,
|
||||
}
|
||||
|
||||
impl PerFieldSpaceUsage {
|
||||
pub(crate) fn new(fields: Vec<FieldUsage>) -> PerFieldSpaceUsage {
|
||||
let total = fields.iter().map(FieldUsage::total).sum();
|
||||
let field_usage_map: HashMap<Field, FieldUsage> = fields
|
||||
.into_iter()
|
||||
.map(|field_usage| (field_usage.field(), field_usage))
|
||||
.collect();
|
||||
let mut total = ByteCount::default();
|
||||
let mut field_usage_map: BTreeMap<String, FieldUsage> = BTreeMap::new();
|
||||
for field_usage in fields {
|
||||
total += field_usage.total();
|
||||
let field_name = field_usage.field_name().to_string();
|
||||
match field_usage_map.entry(field_name) {
|
||||
Entry::Vacant(entry) => {
|
||||
entry.insert(field_usage);
|
||||
}
|
||||
Entry::Occupied(mut entry) => {
|
||||
entry.get_mut().merge(field_usage);
|
||||
}
|
||||
}
|
||||
}
|
||||
PerFieldSpaceUsage {
|
||||
fields: field_usage_map,
|
||||
total,
|
||||
@@ -230,8 +240,8 @@ impl PerFieldSpaceUsage {
|
||||
}
|
||||
|
||||
/// Per field space usage
|
||||
pub fn fields(&self) -> impl Iterator<Item = (&Field, &FieldUsage)> {
|
||||
self.fields.iter()
|
||||
pub fn fields(&self) -> impl Iterator<Item = &FieldUsage> {
|
||||
self.fields.values()
|
||||
}
|
||||
|
||||
/// Bytes used by the represented file
|
||||
@@ -246,20 +256,23 @@ impl PerFieldSpaceUsage {
|
||||
/// See documentation for [`PerFieldSpaceUsage`] for slightly more information.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct FieldUsage {
|
||||
field: Field,
|
||||
field_name: String,
|
||||
num_bytes: ByteCount,
|
||||
/// A field can be composed of more than one piece.
|
||||
/// These pieces are indexed by arbitrary numbers starting at zero.
|
||||
/// `self.num_bytes` includes all of `self.sub_num_bytes`.
|
||||
sub_num_bytes: Vec<Option<ByteCount>>,
|
||||
/// Space usage of the column for fast fields, if relevant.
|
||||
column_space_usage: Option<ColumnSpaceUsage>,
|
||||
}
|
||||
|
||||
impl FieldUsage {
|
||||
pub(crate) fn empty(field: Field) -> FieldUsage {
|
||||
pub(crate) fn empty(field_name: impl Into<String>) -> FieldUsage {
|
||||
FieldUsage {
|
||||
field,
|
||||
field_name: field_name.into(),
|
||||
num_bytes: Default::default(),
|
||||
sub_num_bytes: Vec::new(),
|
||||
column_space_usage: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -272,9 +285,14 @@ impl FieldUsage {
|
||||
self.num_bytes += size
|
||||
}
|
||||
|
||||
pub(crate) fn set_column_usage(&mut self, column_space_usage: ColumnSpaceUsage) {
|
||||
self.num_bytes += column_space_usage.total_num_bytes();
|
||||
self.column_space_usage = Some(column_space_usage);
|
||||
}
|
||||
|
||||
/// Field
|
||||
pub fn field(&self) -> Field {
|
||||
self.field
|
||||
pub fn field_name(&self) -> &str {
|
||||
&self.field_name
|
||||
}
|
||||
|
||||
/// Space usage for each index
|
||||
@@ -282,16 +300,64 @@ impl FieldUsage {
|
||||
&self.sub_num_bytes[..]
|
||||
}
|
||||
|
||||
/// Returns the number of bytes used by the column payload, if the field is columnar.
|
||||
pub fn column_num_bytes(&self) -> Option<ByteCount> {
|
||||
self.column_space_usage
|
||||
.as_ref()
|
||||
.map(ColumnSpaceUsage::column_num_bytes)
|
||||
}
|
||||
|
||||
/// Returns the number of bytes used by the dictionary for dictionary-encoded columns.
|
||||
pub fn dictionary_num_bytes(&self) -> Option<ByteCount> {
|
||||
self.column_space_usage
|
||||
.as_ref()
|
||||
.and_then(ColumnSpaceUsage::dictionary_num_bytes)
|
||||
}
|
||||
|
||||
/// Returns the space usage of the column, if any.
|
||||
pub fn column_space_usage(&self) -> Option<&ColumnSpaceUsage> {
|
||||
self.column_space_usage.as_ref()
|
||||
}
|
||||
|
||||
/// Total bytes used for this field in this context
|
||||
pub fn total(&self) -> ByteCount {
|
||||
self.num_bytes
|
||||
}
|
||||
|
||||
fn merge(&mut self, other: FieldUsage) {
|
||||
assert_eq!(self.field_name, other.field_name);
|
||||
self.num_bytes += other.num_bytes;
|
||||
if other.sub_num_bytes.len() > self.sub_num_bytes.len() {
|
||||
self.sub_num_bytes.resize(other.sub_num_bytes.len(), None);
|
||||
}
|
||||
for (idx, num_bytes_opt) in other.sub_num_bytes.into_iter().enumerate() {
|
||||
if let Some(num_bytes) = num_bytes_opt {
|
||||
match self.sub_num_bytes[idx] {
|
||||
Some(existing) => self.sub_num_bytes[idx] = Some(existing + num_bytes),
|
||||
None => self.sub_num_bytes[idx] = Some(num_bytes),
|
||||
}
|
||||
}
|
||||
}
|
||||
self.column_space_usage =
|
||||
merge_column_space_usage(self.column_space_usage.take(), other.column_space_usage);
|
||||
}
|
||||
}
|
||||
|
||||
fn merge_column_space_usage(
|
||||
left: Option<ColumnSpaceUsage>,
|
||||
right: Option<ColumnSpaceUsage>,
|
||||
) -> Option<ColumnSpaceUsage> {
|
||||
match (left, right) {
|
||||
(Some(lhs), Some(rhs)) => Some(lhs.merge(&rhs)),
|
||||
(Some(space), None) | (None, Some(space)) => Some(space),
|
||||
(None, None) => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::index::Index;
|
||||
use crate::schema::{Field, Schema, FAST, INDEXED, STORED, TEXT};
|
||||
use crate::schema::{Schema, FAST, INDEXED, STORED, TEXT};
|
||||
use crate::space_usage::PerFieldSpaceUsage;
|
||||
use crate::{IndexWriter, Term};
|
||||
|
||||
@@ -307,17 +373,17 @@ mod test {
|
||||
|
||||
fn expect_single_field(
|
||||
field_space: &PerFieldSpaceUsage,
|
||||
field: &Field,
|
||||
field: &str,
|
||||
min_size: u64,
|
||||
max_size: u64,
|
||||
) {
|
||||
assert!(field_space.total() >= min_size);
|
||||
assert!(field_space.total() <= max_size);
|
||||
assert_eq!(
|
||||
vec![(field, field_space.total())],
|
||||
vec![(field.to_string(), field_space.total())],
|
||||
field_space
|
||||
.fields()
|
||||
.map(|(x, y)| (x, y.total()))
|
||||
.map(|usage| (usage.field_name().to_string(), usage.total()))
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
@@ -327,6 +393,7 @@ mod test {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let name = schema_builder.add_u64_field("name", FAST | INDEXED);
|
||||
let schema = schema_builder.build();
|
||||
let field_name = schema.get_field_name(name).to_string();
|
||||
let index = Index::create_in_ram(schema);
|
||||
|
||||
{
|
||||
@@ -349,11 +416,11 @@ mod test {
|
||||
|
||||
assert_eq!(4, segment.num_docs());
|
||||
|
||||
expect_single_field(segment.termdict(), &name, 1, 512);
|
||||
expect_single_field(segment.postings(), &name, 1, 512);
|
||||
expect_single_field(segment.termdict(), &field_name, 1, 512);
|
||||
expect_single_field(segment.postings(), &field_name, 1, 512);
|
||||
assert_eq!(segment.positions().total(), 0);
|
||||
expect_single_field(segment.fast_fields(), &name, 1, 512);
|
||||
expect_single_field(segment.fieldnorms(), &name, 1, 512);
|
||||
expect_single_field(segment.fast_fields(), &field_name, 1, 512);
|
||||
expect_single_field(segment.fieldnorms(), &field_name, 1, 512);
|
||||
// TODO: understand why the following fails
|
||||
// assert_eq!(0, segment.store().total());
|
||||
assert_eq!(segment.deletes(), 0);
|
||||
@@ -365,6 +432,7 @@ mod test {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let name = schema_builder.add_text_field("name", TEXT);
|
||||
let schema = schema_builder.build();
|
||||
let field_name = schema.get_field_name(name).to_string();
|
||||
let index = Index::create_in_ram(schema);
|
||||
|
||||
{
|
||||
@@ -389,11 +457,11 @@ mod test {
|
||||
|
||||
assert_eq!(4, segment.num_docs());
|
||||
|
||||
expect_single_field(segment.termdict(), &name, 1, 512);
|
||||
expect_single_field(segment.postings(), &name, 1, 512);
|
||||
expect_single_field(segment.positions(), &name, 1, 512);
|
||||
expect_single_field(segment.termdict(), &field_name, 1, 512);
|
||||
expect_single_field(segment.postings(), &field_name, 1, 512);
|
||||
expect_single_field(segment.positions(), &field_name, 1, 512);
|
||||
assert_eq!(segment.fast_fields().total(), 0);
|
||||
expect_single_field(segment.fieldnorms(), &name, 1, 512);
|
||||
expect_single_field(segment.fieldnorms(), &field_name, 1, 512);
|
||||
// TODO: understand why the following fails
|
||||
// assert_eq!(0, segment.store().total());
|
||||
assert_eq!(segment.deletes(), 0);
|
||||
@@ -429,10 +497,15 @@ mod test {
|
||||
assert_eq!(4, segment.num_docs());
|
||||
|
||||
assert_eq!(segment.termdict().total(), 0);
|
||||
assert!(segment.termdict().fields().next().is_none());
|
||||
assert_eq!(segment.postings().total(), 0);
|
||||
assert!(segment.postings().fields().next().is_none());
|
||||
assert_eq!(segment.positions().total(), 0);
|
||||
assert!(segment.positions().fields().next().is_none());
|
||||
assert_eq!(segment.fast_fields().total(), 0);
|
||||
assert!(segment.fast_fields().fields().next().is_none());
|
||||
assert_eq!(segment.fieldnorms().total(), 0);
|
||||
assert!(segment.fieldnorms().fields().next().is_none());
|
||||
assert!(segment.store().total() > 0);
|
||||
assert!(segment.store().total() < 512);
|
||||
assert_eq!(segment.deletes(), 0);
|
||||
@@ -444,6 +517,7 @@ mod test {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let name = schema_builder.add_u64_field("name", INDEXED);
|
||||
let schema = schema_builder.build();
|
||||
let field_name = schema.get_field_name(name).to_string();
|
||||
let index = Index::create_in_ram(schema);
|
||||
|
||||
{
|
||||
@@ -474,11 +548,11 @@ mod test {
|
||||
|
||||
assert_eq!(2, segment_space_usage.num_docs());
|
||||
|
||||
expect_single_field(segment_space_usage.termdict(), &name, 1, 512);
|
||||
expect_single_field(segment_space_usage.postings(), &name, 1, 512);
|
||||
expect_single_field(segment_space_usage.termdict(), &field_name, 1, 512);
|
||||
expect_single_field(segment_space_usage.postings(), &field_name, 1, 512);
|
||||
assert_eq!(segment_space_usage.positions().total(), 0u64);
|
||||
assert_eq!(segment_space_usage.fast_fields().total(), 0u64);
|
||||
expect_single_field(segment_space_usage.fieldnorms(), &name, 1, 512);
|
||||
expect_single_field(segment_space_usage.fieldnorms(), &field_name, 1, 512);
|
||||
assert!(segment_space_usage.deletes() > 0);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -102,6 +102,7 @@ pub(crate) mod tests {
|
||||
}
|
||||
|
||||
const NUM_DOCS: usize = 1_000;
|
||||
|
||||
#[test]
|
||||
fn test_doc_store_iter_with_delete_bug_1077() -> crate::Result<()> {
|
||||
// this will cover deletion of the first element in a checkpoint
|
||||
@@ -113,7 +114,7 @@ pub(crate) mod tests {
|
||||
let directory = RamDirectory::create();
|
||||
let store_wrt = directory.open_write(path)?;
|
||||
let schema =
|
||||
write_lorem_ipsum_store(store_wrt, NUM_DOCS, Compressor::Lz4, BLOCK_SIZE, true);
|
||||
write_lorem_ipsum_store(store_wrt, NUM_DOCS, Compressor::default(), BLOCK_SIZE, true);
|
||||
let field_title = schema.get_field("title").unwrap();
|
||||
let store_file = directory.open_read(path)?;
|
||||
let store = StoreReader::open(store_file, 10)?;
|
||||
|
||||
@@ -465,7 +465,7 @@ mod tests {
|
||||
let directory = RamDirectory::create();
|
||||
let path = Path::new("store");
|
||||
let writer = directory.open_write(path)?;
|
||||
let schema = write_lorem_ipsum_store(writer, 500, Compressor::default(), BLOCK_SIZE, true);
|
||||
let schema = write_lorem_ipsum_store(writer, 500, Compressor::None, BLOCK_SIZE, true);
|
||||
let title = schema.get_field("title").unwrap();
|
||||
let store_file = directory.open_read(path)?;
|
||||
let store = StoreReader::open(store_file, DOCSTORE_CACHE_CAPACITY)?;
|
||||
@@ -499,7 +499,7 @@ mod tests {
|
||||
assert_eq!(store.cache_stats().cache_hits, 1);
|
||||
assert_eq!(store.cache_stats().cache_misses, 2);
|
||||
|
||||
assert_eq!(store.cache.peek_lru(), Some(11207));
|
||||
assert_eq!(store.cache.peek_lru(), Some(232206));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -132,13 +132,14 @@ mod regex_tokenizer;
|
||||
mod remove_long;
|
||||
mod simple_tokenizer;
|
||||
mod split_compound_words;
|
||||
mod stemmer;
|
||||
mod stop_word_filter;
|
||||
mod tokenized_string;
|
||||
mod tokenizer;
|
||||
mod tokenizer_manager;
|
||||
mod whitespace_tokenizer;
|
||||
|
||||
#[cfg(feature = "stemmer")]
|
||||
mod stemmer;
|
||||
pub use tokenizer_api::{BoxTokenStream, Token, TokenFilter, TokenStream, Tokenizer};
|
||||
|
||||
pub use self::alphanum_only::AlphaNumOnlyFilter;
|
||||
@@ -151,6 +152,7 @@ pub use self::regex_tokenizer::RegexTokenizer;
|
||||
pub use self::remove_long::RemoveLongFilter;
|
||||
pub use self::simple_tokenizer::{SimpleTokenStream, SimpleTokenizer};
|
||||
pub use self::split_compound_words::SplitCompoundWords;
|
||||
#[cfg(feature = "stemmer")]
|
||||
pub use self::stemmer::{Language, Stemmer};
|
||||
pub use self::stop_word_filter::StopWordFilter;
|
||||
pub use self::tokenized_string::{PreTokenizedStream, PreTokenizedString};
|
||||
@@ -167,10 +169,7 @@ pub const MAX_TOKEN_LEN: usize = u16::MAX as usize - 5;
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use super::{
|
||||
Language, LowerCaser, RemoveLongFilter, SimpleTokenizer, Stemmer, Token, TokenizerManager,
|
||||
};
|
||||
use crate::tokenizer::TextAnalyzer;
|
||||
use super::{Token, TokenizerManager};
|
||||
|
||||
/// This is a function that can be used in tests and doc tests
|
||||
/// to assert a token's correctness.
|
||||
@@ -205,59 +204,15 @@ pub(crate) mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_en_tokenizer() {
|
||||
fn test_tokenizer_does_not_exist() {
|
||||
let tokenizer_manager = TokenizerManager::default();
|
||||
assert!(tokenizer_manager.get("en_doesnotexist").is_none());
|
||||
let mut en_tokenizer = tokenizer_manager.get("en_stem").unwrap();
|
||||
let mut tokens: Vec<Token> = vec![];
|
||||
{
|
||||
let mut add_token = |token: &Token| {
|
||||
tokens.push(token.clone());
|
||||
};
|
||||
en_tokenizer
|
||||
.token_stream("Hello, happy tax payer!")
|
||||
.process(&mut add_token);
|
||||
}
|
||||
|
||||
assert_eq!(tokens.len(), 4);
|
||||
assert_token(&tokens[0], 0, "hello", 0, 5);
|
||||
assert_token(&tokens[1], 1, "happi", 7, 12);
|
||||
assert_token(&tokens[2], 2, "tax", 13, 16);
|
||||
assert_token(&tokens[3], 3, "payer", 17, 22);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_non_en_tokenizer() {
|
||||
let tokenizer_manager = TokenizerManager::default();
|
||||
tokenizer_manager.register(
|
||||
"el_stem",
|
||||
TextAnalyzer::builder(SimpleTokenizer::default())
|
||||
.filter(RemoveLongFilter::limit(40))
|
||||
.filter(LowerCaser)
|
||||
.filter(Stemmer::new(Language::Greek))
|
||||
.build(),
|
||||
);
|
||||
let mut en_tokenizer = tokenizer_manager.get("el_stem").unwrap();
|
||||
let mut tokens: Vec<Token> = vec![];
|
||||
{
|
||||
let mut add_token = |token: &Token| {
|
||||
tokens.push(token.clone());
|
||||
};
|
||||
en_tokenizer
|
||||
.token_stream("Καλημέρα, χαρούμενε φορολογούμενε!")
|
||||
.process(&mut add_token);
|
||||
}
|
||||
|
||||
assert_eq!(tokens.len(), 3);
|
||||
assert_token(&tokens[0], 0, "καλημερ", 0, 16);
|
||||
assert_token(&tokens[1], 1, "χαρουμεν", 18, 36);
|
||||
assert_token(&tokens[2], 2, "φορολογουμεν", 37, 63);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tokenizer_empty() {
|
||||
let tokenizer_manager = TokenizerManager::default();
|
||||
let mut en_tokenizer = tokenizer_manager.get("en_stem").unwrap();
|
||||
let mut en_tokenizer = tokenizer_manager.get("default").unwrap();
|
||||
{
|
||||
let mut tokens: Vec<Token> = vec![];
|
||||
{
|
||||
|
||||
@@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
|
||||
use super::{Token, TokenFilter, TokenStream, Tokenizer};
|
||||
|
||||
/// Available stemmer languages.
|
||||
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Copy, Clone)]
|
||||
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Copy, Clone, Hash)]
|
||||
#[allow(missing_docs)]
|
||||
pub enum Language {
|
||||
Arabic,
|
||||
@@ -142,3 +142,60 @@ impl<T: TokenStream> TokenStream for StemmerTokenStream<T> {
|
||||
self.tail.token_mut()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use tokenizer_api::Token;
|
||||
|
||||
use super::*;
|
||||
use crate::tokenizer::tests::assert_token;
|
||||
use crate::tokenizer::{LowerCaser, SimpleTokenizer, TextAnalyzer, TokenizerManager};
|
||||
|
||||
#[test]
|
||||
fn test_en_stem() {
|
||||
let tokenizer_manager = TokenizerManager::default();
|
||||
let mut en_tokenizer = tokenizer_manager.get("en_stem").unwrap();
|
||||
let mut tokens: Vec<Token> = vec![];
|
||||
{
|
||||
let mut add_token = |token: &Token| {
|
||||
tokens.push(token.clone());
|
||||
};
|
||||
en_tokenizer
|
||||
.token_stream("Dogs are the bests!")
|
||||
.process(&mut add_token);
|
||||
}
|
||||
|
||||
assert_eq!(tokens.len(), 4);
|
||||
assert_token(&tokens[0], 0, "dog", 0, 4);
|
||||
assert_token(&tokens[1], 1, "are", 5, 8);
|
||||
assert_token(&tokens[2], 2, "the", 9, 12);
|
||||
assert_token(&tokens[3], 3, "best", 13, 18);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_non_en_stem() {
|
||||
let tokenizer_manager = TokenizerManager::default();
|
||||
tokenizer_manager.register(
|
||||
"el_stem",
|
||||
TextAnalyzer::builder(SimpleTokenizer::default())
|
||||
.filter(LowerCaser)
|
||||
.filter(Stemmer::new(Language::Greek))
|
||||
.build(),
|
||||
);
|
||||
let mut el_tokenizer = tokenizer_manager.get("el_stem").unwrap();
|
||||
let mut tokens: Vec<Token> = vec![];
|
||||
{
|
||||
let mut add_token = |token: &Token| {
|
||||
tokens.push(token.clone());
|
||||
};
|
||||
el_tokenizer
|
||||
.token_stream("Καλημέρα, χαρούμενε φορολογούμενε!")
|
||||
.process(&mut add_token);
|
||||
}
|
||||
|
||||
assert_eq!(tokens.len(), 3);
|
||||
assert_token(&tokens[0], 0, "καλημερ", 0, 16);
|
||||
assert_token(&tokens[1], 1, "χαρουμεν", 18, 36);
|
||||
assert_token(&tokens[2], 2, "φορολογουμεν", 37, 63);
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user