diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 7201897d7..95167ba41 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -15,11 +15,11 @@ jobs: steps: - uses: actions/checkout@v4 - name: Install Rust - run: rustup toolchain install nightly-2024-07-01 --profile minimal --component llvm-tools-preview + run: rustup toolchain install nightly-2025-12-01 --profile minimal --component llvm-tools-preview - uses: Swatinem/rust-cache@v2 - uses: taiki-e/install-action@cargo-llvm-cov - name: Generate code coverage - run: cargo +nightly-2024-07-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info + run: cargo +nightly-2025-12-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 continue-on-error: true diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 13080f11d..3a6ba2df9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -39,11 +39,11 @@ jobs: - name: Check Formatting run: cargo +nightly fmt --all -- --check - + - name: Check Stable Compilation run: cargo build --all-features - + - name: Check Bench Compilation run: cargo +nightly bench --no-run --profile=dev --all-features @@ -59,10 +59,10 @@ jobs: strategy: matrix: - features: [ - { label: "all", flags: "mmap,stopwords,lz4-compression,zstd-compression,failpoints" }, - { label: "quickwit", flags: "mmap,quickwit,failpoints" } - ] + features: + - { label: "all", flags: "mmap,stopwords,lz4-compression,zstd-compression,failpoints,stemmer" } + - { label: "quickwit", flags: "mmap,quickwit,failpoints" } + - { label: "none", flags: "" } name: test-${{ matrix.features.label}} @@ -80,7 +80,21 @@ jobs: - uses: Swatinem/rust-cache@v2 - name: Run tests - run: cargo +stable nextest run --features ${{ matrix.features.flags }} --verbose --workspace + run: | + # if matrix.features.flags is empty then run on
--lib to avoid compiling examples + # (as most of them rely on mmap) otherwise run all + if [ -z "${{ matrix.features.flags }}" ]; then + cargo +stable nextest run --lib --no-default-features --verbose --workspace + else + cargo +stable nextest run --features ${{ matrix.features.flags }} --no-default-features --verbose --workspace + fi - name: Run doctests - run: cargo +stable test --doc --features ${{ matrix.features.flags }} --verbose --workspace + run: | + # if matrix.features.flags is empty then skip the doctests + # (presumably most of them rely on mmap) otherwise run all + if [ -z "${{ matrix.features.flags }}" ]; then + echo "no doctest for no feature flag" + else + cargo +stable test --doc --features ${{ matrix.features.flags }} --verbose --workspace + fi diff --git a/Cargo.toml b/Cargo.toml index 32d7bd990..476117656 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ regex = { version = "1.5.5", default-features = false, features = [ aho-corasick = "1.0" tantivy-fst = "0.5" memmap2 = { version = "0.9.0", optional = true } -lz4_flex = { version = "0.11", default-features = false, optional = true } +lz4_flex = { version = "0.12", default-features = false, optional = true } zstd = { version = "0.13", optional = true, default-features = false } tempfile = { version = "3.12.0", optional = true } log = "0.4.16" @@ -37,9 +37,9 @@ fs4 = { version = "0.13.1", optional = true } levenshtein_automata = "0.2.1" uuid = { version = "1.0.0", features = ["v4", "serde"] } crossbeam-channel = "0.5.4" -rust-stemmers = "1.2.0" +rust-stemmers = { version = "1.2.0", optional = true } downcast-rs = "2.0.1" -bitpacking = { version = "0.9.2", default-features = false, features = [ +bitpacking = { version = "0.9.3", default-features = false, features = [ "bitpacker4x", ] } census = "0.4.2" @@ -50,7 +50,7 @@ fail = { version = "0.5.0", optional = true } time = { version = "0.3.35", features = ["serde-well-known"] } smallvec = "1.8.0" rayon = "1.5.2" -lru =
"0.12.0" +lru = "0.16.3" fastdivide = "0.4.0" itertools = "0.14.0" measure_time = "0.9.0" @@ -75,17 +75,17 @@ typetag = "0.2.21" winapi = "0.3.9" [dev-dependencies] -binggan = "0.14.0" -rand = "0.8.5" +binggan = "0.14.2" +rand = "0.9" maplit = "1.0.2" matches = "0.1.9" pretty_assertions = "1.2.1" -proptest = "1.0.0" +proptest = "1.7.0" test-log = "0.2.10" futures = "0.3.21" paste = "1.0.11" more-asserts = "0.3.1" -rand_distr = "0.4.3" +rand_distr = "0.5" time = { version = "0.3.10", features = ["serde-well-known", "macros"] } postcard = { version = "1.0.4", features = [ "use-std", @@ -113,7 +113,8 @@ debug-assertions = true overflow-checks = true [features] -default = ["mmap", "stopwords", "lz4-compression", "columnar-zstd-compression"] +default = ["mmap", "stopwords", "lz4-compression", "columnar-zstd-compression", "stemmer"] +stemmer = ["rust-stemmers"] mmap = ["fs4", "tempfile", "memmap2"] stopwords = [] @@ -173,6 +174,18 @@ harness = false name = "exists_json" harness = false +[[bench]] +name = "range_query" +harness = false + [[bench]] name = "and_or_queries" harness = false + +[[bench]] +name = "range_queries" +harness = false + +[[bench]] +name = "bool_queries_with_range" +harness = false diff --git a/benches/agg_bench.rs b/benches/agg_bench.rs index a4115b604..9313cca7a 100644 --- a/benches/agg_bench.rs +++ b/benches/agg_bench.rs @@ -1,8 +1,8 @@ use binggan::plugins::PeakMemAllocPlugin; use binggan::{black_box, InputGroup, PeakMemAlloc, INSTRUMENTED_SYSTEM}; -use rand::distributions::WeightedIndex; -use rand::prelude::SliceRandom; +use rand::distr::weighted::WeightedIndex; use rand::rngs::StdRng; +use rand::seq::IndexedRandom; use rand::{Rng, SeedableRng}; use rand_distr::Distribution; use serde_json::json; @@ -54,33 +54,33 @@ fn bench_agg(mut group: InputGroup) { register!(group, stats_f64); register!(group, extendedstats_f64); register!(group, percentiles_f64); - register!(group, terms_few); + register!(group, terms_7); register!(group, terms_all_unique); 
- register!(group, terms_many); + register!(group, terms_150_000); register!(group, terms_many_top_1000); register!(group, terms_many_order_by_term); register!(group, terms_many_with_top_hits); register!(group, terms_all_unique_with_avg_sub_agg); register!(group, terms_many_with_avg_sub_agg); - register!(group, terms_few_with_avg_sub_agg); register!(group, terms_status_with_avg_sub_agg); - register!(group, terms_status); - register!(group, terms_few_with_histogram); register!(group, terms_status_with_histogram); + register!(group, terms_zipf_1000); + register!(group, terms_zipf_1000_with_histogram); + register!(group, terms_zipf_1000_with_avg_sub_agg); register!(group, terms_many_json_mixed_type_with_avg_sub_agg); register!(group, cardinality_agg); - register!(group, terms_few_with_cardinality_agg); + register!(group, terms_status_with_cardinality_agg); register!(group, range_agg); register!(group, range_agg_with_avg_sub_agg); - register!(group, range_agg_with_term_agg_few); + register!(group, range_agg_with_term_agg_status); register!(group, range_agg_with_term_agg_many); register!(group, histogram); register!(group, histogram_hard_bounds); register!(group, histogram_with_avg_sub_agg); - register!(group, histogram_with_term_agg_few); + register!(group, histogram_with_term_agg_status); register!(group, avg_and_range_with_avg_sub_agg); // Filter aggregation benchmarks @@ -159,10 +159,10 @@ fn cardinality_agg(index: &Index) { }); execute_agg(index, agg_req); } -fn terms_few_with_cardinality_agg(index: &Index) { +fn terms_status_with_cardinality_agg(index: &Index) { let agg_req = json!({ "my_texts": { - "terms": { "field": "text_few_terms" }, + "terms": { "field": "text_few_terms_status" }, "aggs": { "cardinality": { "cardinality": { @@ -175,13 +175,7 @@ fn terms_few_with_cardinality_agg(index: &Index) { execute_agg(index, agg_req); } -fn terms_few(index: &Index) { - let agg_req = json!({ - "my_texts": { "terms": { "field": "text_few_terms" } }, - }); - 
execute_agg(index, agg_req); -} -fn terms_status(index: &Index) { +fn terms_7(index: &Index) { let agg_req = json!({ "my_texts": { "terms": { "field": "text_few_terms_status" } }, }); @@ -194,7 +188,7 @@ fn terms_all_unique(index: &Index) { execute_agg(index, agg_req); } -fn terms_many(index: &Index) { +fn terms_150_000(index: &Index) { let agg_req = json!({ "my_texts": { "terms": { "field": "text_many_terms" } }, }); @@ -253,17 +247,6 @@ fn terms_all_unique_with_avg_sub_agg(index: &Index) { }); execute_agg(index, agg_req); } -fn terms_few_with_histogram(index: &Index) { - let agg_req = json!({ - "my_texts": { - "terms": { "field": "text_few_terms" }, - "aggs": { - "histo": {"histogram": { "field": "score_f64", "interval": 10 }} - } - } - }); - execute_agg(index, agg_req); -} fn terms_status_with_histogram(index: &Index) { let agg_req = json!({ "my_texts": { @@ -276,17 +259,18 @@ fn terms_status_with_histogram(index: &Index) { execute_agg(index, agg_req); } -fn terms_few_with_avg_sub_agg(index: &Index) { +fn terms_zipf_1000_with_histogram(index: &Index) { let agg_req = json!({ "my_texts": { - "terms": { "field": "text_few_terms" }, + "terms": { "field": "text_1000_terms_zipf" }, "aggs": { - "average_f64": { "avg": { "field": "score_f64" } } + "histo": {"histogram": { "field": "score_f64", "interval": 10 }} } - }, + } }); execute_agg(index, agg_req); } + fn terms_status_with_avg_sub_agg(index: &Index) { let agg_req = json!({ "my_texts": { @@ -299,6 +283,25 @@ fn terms_status_with_avg_sub_agg(index: &Index) { execute_agg(index, agg_req); } +fn terms_zipf_1000_with_avg_sub_agg(index: &Index) { + let agg_req = json!({ + "my_texts": { + "terms": { "field": "text_1000_terms_zipf" }, + "aggs": { + "average_f64": { "avg": { "field": "score_f64" } } + } + }, + }); + execute_agg(index, agg_req); +} + +fn terms_zipf_1000(index: &Index) { + let agg_req = json!({ + "my_texts": { "terms": { "field": "text_1000_terms_zipf" } }, + }); + execute_agg(index, agg_req); +} + fn 
terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) { let agg_req = json!({ "my_texts": { @@ -354,7 +357,7 @@ fn range_agg_with_avg_sub_agg(index: &Index) { execute_agg(index, agg_req); } -fn range_agg_with_term_agg_few(index: &Index) { +fn range_agg_with_term_agg_status(index: &Index) { let agg_req = json!({ "rangef64": { "range": { @@ -369,7 +372,7 @@ fn range_agg_with_term_agg_few(index: &Index) { ] }, "aggs": { - "my_texts": { "terms": { "field": "text_few_terms" } }, + "my_texts": { "terms": { "field": "text_few_terms_status" } }, } }, }); @@ -425,12 +428,12 @@ fn histogram_with_avg_sub_agg(index: &Index) { }); execute_agg(index, agg_req); } -fn histogram_with_term_agg_few(index: &Index) { +fn histogram_with_term_agg_status(index: &Index) { let agg_req = json!({ "rangef64": { "histogram": { "field": "score_f64", "interval": 10 }, "aggs": { - "my_texts": { "terms": { "field": "text_few_terms" } } + "my_texts": { "terms": { "field": "text_few_terms_status" } } } } }); @@ -475,6 +478,13 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector { } fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result { + // Flag to use existing index + let reuse_index = std::env::var("REUSE_AGG_BENCH_INDEX").is_ok(); + if reuse_index && std::path::Path::new("agg_bench").exists() { + return Index::open_in_dir("agg_bench"); + } + // create dir + std::fs::create_dir_all("agg_bench")?; let mut schema_builder = Schema::builder(); let text_fieldtype = tantivy::schema::TextOptions::default() .set_indexing_options( @@ -486,24 +496,44 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result { let text_field_all_unique_terms = schema_builder.add_text_field("text_all_unique_terms", STRING | FAST); let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST); - let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST); - let text_field_few_terms =
schema_builder.add_text_field("text_few_terms", STRING | FAST); let text_field_few_terms_status = schema_builder.add_text_field("text_few_terms_status", STRING | FAST); + let text_field_1000_terms_zipf = + schema_builder.add_text_field("text_1000_terms_zipf", STRING | FAST); let score_fieldtype = tantivy::schema::NumericOptions::default().set_fast(); let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone()); let score_field_f64 = schema_builder.add_f64_field("score_f64", score_fieldtype.clone()); let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype); - let index = Index::create_from_tempdir(schema_builder.build())?; - let few_terms_data = ["INFO", "ERROR", "WARN", "DEBUG"]; - // Approximate production log proportions: INFO dominant, WARN and DEBUG occasional, ERROR rare. - let log_level_distribution = WeightedIndex::new([80u32, 3, 12, 5]).unwrap(); + // use tmp dir + let index = if reuse_index { + Index::create_in_dir("agg_bench", schema_builder.build())? + } else { + Index::create_from_tempdir(schema_builder.build())? + }; + // Approximate log proportions + let status_field_data = [ + ("INFO", 8000), + ("ERROR", 300), + ("WARN", 1200), + ("DEBUG", 500), + ("OK", 500), + ("CRITICAL", 20), + ("EMERGENCY", 1), + ]; + let log_level_distribution = + WeightedIndex::new(status_field_data.iter().map(|item| item.1)).unwrap(); let lg_norm = rand_distr::LogNormal::new(2.996f64, 0.979f64).unwrap(); let many_terms_data = (0..150_000) .map(|num| format!("author{num}")) .collect::>(); + + // Prepare 1000 unique terms sampled using a Zipf distribution. + // Exponent ~1.1 approximates top-20 terms covering around ~20%. 
+ let terms_1000: Vec = (1..=1000).map(|i| format!("term_{i}")).collect(); + let zipf_1000 = rand_distr::Zipf::new(1000.0, 1.1f64).unwrap(); + { let mut rng = StdRng::from_seed([1u8; 32]); let mut index_writer = index.writer_with_num_threads(1, 200_000_000)?; @@ -513,8 +543,12 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result { index_writer.add_document(doc!())?; } if cardinality == Cardinality::Multivalued { - let log_level_sample_a = few_terms_data[log_level_distribution.sample(&mut rng)]; - let log_level_sample_b = few_terms_data[log_level_distribution.sample(&mut rng)]; + let log_level_sample_a = status_field_data[log_level_distribution.sample(&mut rng)].0; + let log_level_sample_b = status_field_data[log_level_distribution.sample(&mut rng)].0; + let idx_a = zipf_1000.sample(&mut rng) as usize - 1; + let idx_b = zipf_1000.sample(&mut rng) as usize - 1; + let term_1000_a = &terms_1000[idx_a]; + let term_1000_b = &terms_1000[idx_b]; index_writer.add_document(doc!( json_field => json!({"mixed_type": 10.0}), json_field => json!({"mixed_type": 10.0}), @@ -524,10 +558,10 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result { text_field_all_unique_terms => "coolo", text_field_many_terms => "cool", text_field_many_terms => "cool", - text_field_few_terms => "cool", - text_field_few_terms => "cool", text_field_few_terms_status => log_level_sample_a, text_field_few_terms_status => log_level_sample_b, + text_field_1000_terms_zipf => term_1000_a.as_str(), + text_field_1000_terms_zipf => term_1000_b.as_str(), score_field => 1u64, score_field => 1u64, score_field_f64 => lg_norm.sample(&mut rng), @@ -542,8 +576,8 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result { } let _val_max = 1_000_000.0; for _ in 0..doc_with_value { - let val: f64 = rng.gen_range(0.0..1_000_000.0); - let json = if rng.gen_bool(0.1) { + let val: f64 = rng.random_range(0.0..1_000_000.0); + let json = if rng.random_bool(0.1) { // 10% are numeric 
values json!({ "mixed_type": val }) } else { @@ -552,10 +586,10 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result { index_writer.add_document(doc!( text_field => "cool", json_field => json, - text_field_all_unique_terms => format!("unique_term_{}", rng.gen::()), + text_field_all_unique_terms => format!("unique_term_{}", rng.random::()), text_field_many_terms => many_terms_data.choose(&mut rng).unwrap().to_string(), - text_field_few_terms => few_terms_data.choose(&mut rng).unwrap().to_string(), - text_field_few_terms_status => few_terms_data[log_level_distribution.sample(&mut rng)], + text_field_few_terms_status => status_field_data[log_level_distribution.sample(&mut rng)].0, + text_field_1000_terms_zipf => terms_1000[zipf_1000.sample(&mut rng) as usize - 1].as_str(), score_field => val as u64, score_field_f64 => lg_norm.sample(&mut rng), score_field_i64 => val as i64, @@ -607,7 +641,7 @@ fn filter_agg_all_query_with_sub_aggs(index: &Index) { "avg_score": { "avg": { "field": "score" } }, "stats_score": { "stats": { "field": "score_f64" } }, "terms_text": { - "terms": { "field": "text_few_terms" } + "terms": { "field": "text_few_terms_status" } } } } @@ -623,7 +657,7 @@ fn filter_agg_term_query_with_sub_aggs(index: &Index) { "avg_score": { "avg": { "field": "score" } }, "stats_score": { "stats": { "field": "score_f64" } }, "terms_text": { - "terms": { "field": "text_few_terms" } + "terms": { "field": "text_few_terms_status" } } } } diff --git a/benches/and_or_queries.rs b/benches/and_or_queries.rs index 805061c18..5dd213685 100644 --- a/benches/and_or_queries.rs +++ b/benches/and_or_queries.rs @@ -55,29 +55,29 @@ fn build_shared_indices(num_docs: usize, p_a: f32, p_b: f32, p_c: f32) -> (Bench { let mut writer = index.writer_with_num_threads(1, 500_000_000).unwrap(); for _ in 0..num_docs { - let has_a = rng.gen_bool(p_a as f64); - let has_b = rng.gen_bool(p_b as f64); - let has_c = rng.gen_bool(p_c as f64); - let score = 
rng.gen_range(0u64..100u64); - let score2 = rng.gen_range(0u64..100_000u64); + let has_a = rng.random_bool(p_a as f64); + let has_b = rng.random_bool(p_b as f64); + let has_c = rng.random_bool(p_c as f64); + let score = rng.random_range(0u64..100u64); + let score2 = rng.random_range(0u64..100_000u64); let mut title_tokens: Vec<&str> = Vec::new(); let mut body_tokens: Vec<&str> = Vec::new(); if has_a { - if rng.gen_bool(0.1) { + if rng.random_bool(0.1) { title_tokens.push("a"); } else { body_tokens.push("a"); } } if has_b { - if rng.gen_bool(0.1) { + if rng.random_bool(0.1) { title_tokens.push("b"); } else { body_tokens.push("b"); } } if has_c { - if rng.gen_bool(0.1) { + if rng.random_bool(0.1) { title_tokens.push("c"); } else { body_tokens.push("c"); diff --git a/benches/bool_queries_with_range.rs b/benches/bool_queries_with_range.rs new file mode 100644 index 000000000..9b2849300 --- /dev/null +++ b/benches/bool_queries_with_range.rs @@ -0,0 +1,288 @@ +use binggan::{black_box, BenchGroup, BenchRunner}; +use rand::prelude::*; +use rand::rngs::StdRng; +use rand::SeedableRng; +use tantivy::collector::{Collector, Count, DocSetCollector, TopDocs}; +use tantivy::query::{Query, QueryParser}; +use tantivy::schema::{Schema, FAST, INDEXED, TEXT}; +use tantivy::{doc, Index, Order, ReloadPolicy, Searcher}; + +#[derive(Clone)] +struct BenchIndex { + #[allow(dead_code)] + index: Index, + searcher: Searcher, + query_parser: QueryParser, +} + +fn build_shared_indices(num_docs: usize, p_title_a: f32, distribution: &str) -> BenchIndex { + // Unified schema + let mut schema_builder = Schema::builder(); + let f_title = schema_builder.add_text_field("title", TEXT); + let f_num_rand = schema_builder.add_u64_field("num_rand", INDEXED); + let f_num_asc = schema_builder.add_u64_field("num_asc", INDEXED); + let f_num_rand_fast = schema_builder.add_u64_field("num_rand_fast", INDEXED | FAST); + let f_num_asc_fast = schema_builder.add_u64_field("num_asc_fast", INDEXED | FAST); + let schema = 
schema_builder.build(); + let index = Index::create_in_ram(schema.clone()); + + // Populate index with stable RNG for reproducibility. + let mut rng = StdRng::from_seed([7u8; 32]); + + { + let mut writer = index.writer_with_num_threads(1, 4_000_000_000).unwrap(); + + match distribution { + "dense" => { + for doc_id in 0..num_docs { + // Always add title to avoid empty documents + let title_token = if rng.random_bool(p_title_a as f64) { + "a" + } else { + "b" + }; + + let num_rand = rng.random_range(0u64..1000u64); + + let num_asc = (doc_id / 10000) as u64; + + writer + .add_document(doc!( + f_title=>title_token, + f_num_rand=>num_rand, + f_num_asc=>num_asc, + f_num_rand_fast=>num_rand, + f_num_asc_fast=>num_asc, + )) + .unwrap(); + } + } + "sparse" => { + for doc_id in 0..num_docs { + // Always add title to avoid empty documents + let title_token = if rng.random_bool(p_title_a as f64) { + "a" + } else { + "b" + }; + + let num_rand = rng.random_range(0u64..10000000u64); + + let num_asc = doc_id as u64; + + writer + .add_document(doc!( + f_title=>title_token, + f_num_rand=>num_rand, + f_num_asc=>num_asc, + f_num_rand_fast=>num_rand, + f_num_asc_fast=>num_asc, + )) + .unwrap(); + } + } + _ => { + panic!("Unsupported distribution type"); + } + } + writer.commit().unwrap(); + } + + // Prepare reader/searcher once. 
+ let reader = index + .reader_builder() + .reload_policy(ReloadPolicy::Manual) + .try_into() + .unwrap(); + let searcher = reader.searcher(); + + // Build query parser for title field + let qp_title = QueryParser::for_index(&index, vec![f_title]); + + BenchIndex { + index, + searcher, + query_parser: qp_title, + } +} + +fn main() { + // Prepare corpora with varying scenarios + let scenarios = vec![ + ( + "dense and 99% a".to_string(), + 10_000_000, + 0.99, + "dense", + 0, + 9, + ), + ( + "dense and 99% a".to_string(), + 10_000_000, + 0.99, + "dense", + 990, + 999, + ), + ( + "sparse and 99% a".to_string(), + 10_000_000, + 0.99, + "sparse", + 0, + 9, + ), + ( + "sparse and 99% a".to_string(), + 10_000_000, + 0.99, + "sparse", + 9_999_990, + 9_999_999, + ), + ]; + + let mut runner = BenchRunner::new(); + for (scenario_id, n, p_title_a, num_rand_distribution, range_low, range_high) in scenarios { + // Build index for this scenario + let bench_index = build_shared_indices(n, p_title_a, num_rand_distribution); + + // Create benchmark group + let mut group = runner.new_group(); + + // Now set the name (this moves scenario_id) + group.set_name(scenario_id); + + // Define all four field types + let field_names = ["num_rand", "num_asc", "num_rand_fast", "num_asc_fast"]; + + // Define the three terms we want to test with + let terms = ["a", "b", "z"]; + + // Generate all combinations of terms and field names + let mut queries = Vec::new(); + for &term in &terms { + for &field_name in &field_names { + let query_str = format!( + "{} AND {}:[{} TO {}]", + term, field_name, range_low, range_high + ); + queries.push((query_str, field_name.to_string())); + } + } + + let query_str = format!( + "{}:[{} TO {}] AND {}:[{} TO {}]", + "num_rand_fast", range_low, range_high, "num_asc_fast", range_low, range_high + ); + queries.push((query_str, "num_asc_fast".to_string())); + + // Run all benchmark tasks for each query and its corresponding field name + for (query_str, field_name) in 
queries { + run_benchmark_tasks(&mut group, &bench_index, &query_str, &field_name); + } + + group.run(); + } +} + +/// Run all benchmark tasks for a given query string and field name +fn run_benchmark_tasks( + bench_group: &mut BenchGroup, + bench_index: &BenchIndex, + query_str: &str, + field_name: &str, +) { + // Test count + add_bench_task(bench_group, bench_index, query_str, Count, "count"); + + // Test all results + add_bench_task( + bench_group, + bench_index, + query_str, + DocSetCollector, + "all results", + ); + + // Test top 100 by the field (if it's a FAST field) + if field_name.ends_with("_fast") { + // Ascending order + { + let collector_name = format!("top100_by_{}_asc", field_name); + let field_name_owned = field_name.to_string(); + add_bench_task( + bench_group, + bench_index, + query_str, + TopDocs::with_limit(100).order_by_fast_field::(field_name_owned, Order::Asc), + &collector_name, + ); + } + + // Descending order + { + let collector_name = format!("top100_by_{}_desc", field_name); + let field_name_owned = field_name.to_string(); + add_bench_task( + bench_group, + bench_index, + query_str, + TopDocs::with_limit(100).order_by_fast_field::(field_name_owned, Order::Desc), + &collector_name, + ); + } + } +} + +fn add_bench_task( + bench_group: &mut BenchGroup, + bench_index: &BenchIndex, + query_str: &str, + collector: C, + collector_name: &str, +) { + let task_name = format!("{}_{}", query_str.replace(" ", "_"), collector_name); + let query = bench_index.query_parser.parse_query(query_str).unwrap(); + let search_task = SearchTask { + searcher: bench_index.searcher.clone(), + collector, + query, + }; + bench_group.register(task_name, move |_| black_box(search_task.run())); +} + +struct SearchTask { + searcher: Searcher, + collector: C, + query: Box, +} + +impl SearchTask { + #[inline(never)] + pub fn run(&self) -> usize { + let result = self.searcher.search(&self.query, &self.collector).unwrap(); + if let Some(count) = (&result as &dyn 
std::any::Any).downcast_ref::() { + *count + } else if let Some(top_docs) = (&result as &dyn std::any::Any) + .downcast_ref::, tantivy::DocAddress)>>() + { + top_docs.len() + } else if let Some(top_docs) = + (&result as &dyn std::any::Any).downcast_ref::>() + { + top_docs.len() + } else if let Some(doc_set) = (&result as &dyn std::any::Any) + .downcast_ref::>() + { + doc_set.len() + } else { + eprintln!( + "Unknown collector result type: {:?}", + std::any::type_name::() + ); + 0 + } + } +} diff --git a/benches/range_queries.rs b/benches/range_queries.rs new file mode 100644 index 000000000..c8095a01b --- /dev/null +++ b/benches/range_queries.rs @@ -0,0 +1,365 @@ +use std::ops::Bound; + +use binggan::{black_box, BenchGroup, BenchRunner}; +use rand::prelude::*; +use rand::rngs::StdRng; +use rand::SeedableRng; +use tantivy::collector::{Count, DocSetCollector, TopDocs}; +use tantivy::query::RangeQuery; +use tantivy::schema::{Schema, FAST, INDEXED}; +use tantivy::{doc, Index, Order, ReloadPolicy, Searcher, Term}; + +#[derive(Clone)] +struct BenchIndex { + #[allow(dead_code)] + index: Index, + searcher: Searcher, +} + +fn build_shared_indices(num_docs: usize, distribution: &str) -> BenchIndex { + // Schema with fast fields only + let mut schema_builder = Schema::builder(); + let f_num_rand_fast = schema_builder.add_u64_field("num_rand_fast", INDEXED | FAST); + let f_num_asc_fast = schema_builder.add_u64_field("num_asc_fast", INDEXED | FAST); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema.clone()); + + // Populate index with stable RNG for reproducibility. 
+ let mut rng = StdRng::from_seed([7u8; 32]); + + { + let mut writer = index.writer_with_num_threads(1, 4_000_000_000).unwrap(); + + match distribution { + "dense" => { + for doc_id in 0..num_docs { + let num_rand = rng.random_range(0u64..1000u64); + let num_asc = (doc_id / 10000) as u64; + + writer + .add_document(doc!( + f_num_rand_fast=>num_rand, + f_num_asc_fast=>num_asc, + )) + .unwrap(); + } + } + "sparse" => { + for doc_id in 0..num_docs { + let num_rand = rng.random_range(0u64..10000000u64); + let num_asc = doc_id as u64; + + writer + .add_document(doc!( + f_num_rand_fast=>num_rand, + f_num_asc_fast=>num_asc, + )) + .unwrap(); + } + } + _ => { + panic!("Unsupported distribution type"); + } + } + writer.commit().unwrap(); + } + + // Prepare reader/searcher once. + let reader = index + .reader_builder() + .reload_policy(ReloadPolicy::Manual) + .try_into() + .unwrap(); + let searcher = reader.searcher(); + + BenchIndex { index, searcher } +} + +fn main() { + // Prepare corpora with varying scenarios + let scenarios = vec![ + // Dense distribution - random values in small range (0-999) + ( + "dense_values_search_low_value_range".to_string(), + 10_000_000, + "dense", + 0, + 9, + ), + ( + "dense_values_search_high_value_range".to_string(), + 10_000_000, + "dense", + 990, + 999, + ), + ( + "dense_values_search_out_of_range".to_string(), + 10_000_000, + "dense", + 1000, + 1002, + ), + ( + "sparse_values_search_low_value_range".to_string(), + 10_000_000, + "sparse", + 0, + 9, + ), + ( + "sparse_values_search_high_value_range".to_string(), + 10_000_000, + "sparse", + 9_999_990, + 9_999_999, + ), + ( + "sparse_values_search_out_of_range".to_string(), + 10_000_000, + "sparse", + 10_000_000, + 10_000_002, + ), + ]; + + let mut runner = BenchRunner::new(); + for (scenario_id, n, num_rand_distribution, range_low, range_high) in scenarios { + // Build index for this scenario + let bench_index = build_shared_indices(n, num_rand_distribution); + + // Create benchmark group + 
let mut group = runner.new_group(); + + // Now set the name (this moves scenario_id) + group.set_name(scenario_id); + + // Define fast field types + let field_names = ["num_rand_fast", "num_asc_fast"]; + + // Generate range queries for fast fields + for &field_name in &field_names { + // Create the range query + let field = bench_index.searcher.schema().get_field(field_name).unwrap(); + let lower_term = Term::from_field_u64(field, range_low); + let upper_term = Term::from_field_u64(field, range_high); + + let query = RangeQuery::new(Bound::Included(lower_term), Bound::Included(upper_term)); + + run_benchmark_tasks( + &mut group, + &bench_index, + query, + field_name, + range_low, + range_high, + ); + } + + group.run(); + } +} + +/// Run all benchmark tasks for a given range query and field name +fn run_benchmark_tasks( + bench_group: &mut BenchGroup, + bench_index: &BenchIndex, + query: RangeQuery, + field_name: &str, + range_low: u64, + range_high: u64, +) { + // Test count + add_bench_task_count( + bench_group, + bench_index, + query.clone(), + "count", + field_name, + range_low, + range_high, + ); + + // Test top 100 by the field (ascending order) + { + let collector_name = format!("top100_by_{}_asc", field_name); + let field_name_owned = field_name.to_string(); + add_bench_task_top100_asc( + bench_group, + bench_index, + query.clone(), + &collector_name, + field_name, + range_low, + range_high, + field_name_owned, + ); + } + + // Test top 100 by the field (descending order) + { + let collector_name = format!("top100_by_{}_desc", field_name); + let field_name_owned = field_name.to_string(); + add_bench_task_top100_desc( + bench_group, + bench_index, + query, + &collector_name, + field_name, + range_low, + range_high, + field_name_owned, + ); + } +} + +fn add_bench_task_count( + bench_group: &mut BenchGroup, + bench_index: &BenchIndex, + query: RangeQuery, + collector_name: &str, + field_name: &str, + range_low: u64, + range_high: u64, +) { + let task_name = 
format!( + "range_{}_[{} TO {}]_{}", + field_name, range_low, range_high, collector_name + ); + + let search_task = CountSearchTask { + searcher: bench_index.searcher.clone(), + query, + }; + bench_group.register(task_name, move |_| black_box(search_task.run())); +} + +fn add_bench_task_docset( + bench_group: &mut BenchGroup, + bench_index: &BenchIndex, + query: RangeQuery, + collector_name: &str, + field_name: &str, + range_low: u64, + range_high: u64, +) { + let task_name = format!( + "range_{}_[{} TO {}]_{}", + field_name, range_low, range_high, collector_name + ); + + let search_task = DocSetSearchTask { + searcher: bench_index.searcher.clone(), + query, + }; + bench_group.register(task_name, move |_| black_box(search_task.run())); +} + +fn add_bench_task_top100_asc( + bench_group: &mut BenchGroup, + bench_index: &BenchIndex, + query: RangeQuery, + collector_name: &str, + field_name: &str, + range_low: u64, + range_high: u64, + field_name_owned: String, +) { + let task_name = format!( + "range_{}_[{} TO {}]_{}", + field_name, range_low, range_high, collector_name + ); + + let search_task = Top100AscSearchTask { + searcher: bench_index.searcher.clone(), + query, + field_name: field_name_owned, + }; + bench_group.register(task_name, move |_| black_box(search_task.run())); +} + +fn add_bench_task_top100_desc( + bench_group: &mut BenchGroup, + bench_index: &BenchIndex, + query: RangeQuery, + collector_name: &str, + field_name: &str, + range_low: u64, + range_high: u64, + field_name_owned: String, +) { + let task_name = format!( + "range_{}_[{} TO {}]_{}", + field_name, range_low, range_high, collector_name + ); + + let search_task = Top100DescSearchTask { + searcher: bench_index.searcher.clone(), + query, + field_name: field_name_owned, + }; + bench_group.register(task_name, move |_| black_box(search_task.run())); +} + +struct CountSearchTask { + searcher: Searcher, + query: RangeQuery, +} + +impl CountSearchTask { + #[inline(never)] + pub fn run(&self) -> usize { 
+ self.searcher.search(&self.query, &Count).unwrap() + } +} + +struct DocSetSearchTask { + searcher: Searcher, + query: RangeQuery, +} + +impl DocSetSearchTask { + #[inline(never)] + pub fn run(&self) -> usize { + let result = self.searcher.search(&self.query, &DocSetCollector).unwrap(); + result.len() + } +} + +struct Top100AscSearchTask { + searcher: Searcher, + query: RangeQuery, + field_name: String, +} + +impl Top100AscSearchTask { + #[inline(never)] + pub fn run(&self) -> usize { + let collector = + TopDocs::with_limit(100).order_by_fast_field::(&self.field_name, Order::Asc); + let result = self.searcher.search(&self.query, &collector).unwrap(); + for (_score, doc_address) in &result { + let _doc: tantivy::TantivyDocument = self.searcher.doc(*doc_address).unwrap(); + } + result.len() + } +} + +struct Top100DescSearchTask { + searcher: Searcher, + query: RangeQuery, + field_name: String, +} + +impl Top100DescSearchTask { + #[inline(never)] + pub fn run(&self) -> usize { + let collector = + TopDocs::with_limit(100).order_by_fast_field::(&self.field_name, Order::Desc); + let result = self.searcher.search(&self.query, &collector).unwrap(); + for (_score, doc_address) in &result { + let _doc: tantivy::TantivyDocument = self.searcher.doc(*doc_address).unwrap(); + } + result.len() + } +} diff --git a/benches/range_query.rs b/benches/range_query.rs new file mode 100644 index 000000000..e0feddd66 --- /dev/null +++ b/benches/range_query.rs @@ -0,0 +1,260 @@ +use std::fmt::Display; +use std::net::Ipv6Addr; +use std::ops::RangeInclusive; + +use binggan::plugins::PeakMemAllocPlugin; +use binggan::{black_box, BenchRunner, OutputValue, PeakMemAlloc, INSTRUMENTED_SYSTEM}; +use columnar::MonotonicallyMappableToU128; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use tantivy::collector::{Count, TopDocs}; +use tantivy::query::QueryParser; +use tantivy::schema::*; +use tantivy::{doc, Index}; + +#[global_allocator] +pub static GLOBAL: &PeakMemAlloc = 
&INSTRUMENTED_SYSTEM; + +fn main() { + bench_range_query(); +} + +fn bench_range_query() { + let index = get_index_0_to_100(); + let mut runner = BenchRunner::new(); + runner.add_plugin(PeakMemAllocPlugin::new(GLOBAL)); + + runner.set_name("range_query on u64"); + let field_name_and_descr: Vec<_> = vec![ + ("id", "Single Valued Range Field"), + ("ids", "Multi Valued Range Field"), + ]; + let range_num_hits = vec![ + ("90_percent", get_90_percent()), + ("10_percent", get_10_percent()), + ("1_percent", get_1_percent()), + ]; + + test_range(&mut runner, &index, &field_name_and_descr, range_num_hits); + + runner.set_name("range_query on ip"); + let field_name_and_descr: Vec<_> = vec![ + ("ip", "Single Valued Range Field"), + ("ips", "Multi Valued Range Field"), + ]; + let range_num_hits = vec![ + ("90_percent", get_90_percent_ip()), + ("10_percent", get_10_percent_ip()), + ("1_percent", get_1_percent_ip()), + ]; + + test_range(&mut runner, &index, &field_name_and_descr, range_num_hits); +} + +fn test_range( + runner: &mut BenchRunner, + index: &Index, + field_name_and_descr: &[(&str, &str)], + range_num_hits: Vec<(&str, RangeInclusive)>, +) { + for (field, suffix) in field_name_and_descr { + let term_num_hits = vec![ + ("", ""), + ("1_percent", "veryfew"), + ("10_percent", "few"), + ("90_percent", "most"), + ]; + let mut group = runner.new_group(); + group.set_name(suffix); + // all intersect combinations + for (range_name, range) in &range_num_hits { + for (term_name, term) in &term_num_hits { + let index = &index; + let test_name = if term_name.is_empty() { + format!("id_range_hit_{}", range_name) + } else { + format!( + "id_range_hit_{}_intersect_with_term_{}", + range_name, term_name + ) + }; + group.register(test_name, move |_| { + let query = if term_name.is_empty() { + "".to_string() + } else { + format!("AND id_name:{}", term) + }; + black_box(execute_query(field, range, &query, index)); + }); + } + } + group.run(); + } +} + +fn get_index_0_to_100() -> Index { 
+ let mut rng = StdRng::from_seed([1u8; 32]); + let num_vals = 100_000; + let docs: Vec<_> = (0..num_vals) + .map(|_i| { + let id_name = if rng.random_bool(0.01) { + "veryfew".to_string() // 1% + } else if rng.random_bool(0.1) { + "few".to_string() // 9% + } else { + "most".to_string() // 90% + }; + Doc { + id_name, + id: rng.random_range(0..100), + // Multiply by 1000, so that we create most buckets in the compact space + // The benches depend on this range to select n-percent of elements with the + // methods below. + ip: Ipv6Addr::from_u128(rng.random_range(0..100) * 1000), + } + }) + .collect(); + + create_index_from_docs(&docs) +} + +#[derive(Clone, Debug)] +pub struct Doc { + pub id_name: String, + pub id: u64, + pub ip: Ipv6Addr, +} + +pub fn create_index_from_docs(docs: &[Doc]) -> Index { + let mut schema_builder = Schema::builder(); + let id_u64_field = schema_builder.add_u64_field("id", INDEXED | STORED | FAST); + let ids_u64_field = + schema_builder.add_u64_field("ids", NumericOptions::default().set_fast().set_indexed()); + + let id_f64_field = schema_builder.add_f64_field("id_f64", INDEXED | STORED | FAST); + let ids_f64_field = schema_builder.add_f64_field( + "ids_f64", + NumericOptions::default().set_fast().set_indexed(), + ); + + let id_i64_field = schema_builder.add_i64_field("id_i64", INDEXED | STORED | FAST); + let ids_i64_field = schema_builder.add_i64_field( + "ids_i64", + NumericOptions::default().set_fast().set_indexed(), + ); + + let text_field = schema_builder.add_text_field("id_name", STRING | STORED); + let text_field2 = schema_builder.add_text_field("id_name_fast", STRING | STORED | FAST); + + let ip_field = schema_builder.add_ip_addr_field("ip", FAST); + let ips_field = schema_builder.add_ip_addr_field("ips", FAST); + + let schema = schema_builder.build(); + + let index = Index::create_in_ram(schema); + + { + let mut index_writer = index.writer_with_num_threads(1, 50_000_000).unwrap(); + for doc in docs.iter() { + index_writer + 
.add_document(doc!( + ids_i64_field => doc.id as i64, + ids_i64_field => doc.id as i64, + ids_f64_field => doc.id as f64, + ids_f64_field => doc.id as f64, + ids_u64_field => doc.id, + ids_u64_field => doc.id, + id_u64_field => doc.id, + id_f64_field => doc.id as f64, + id_i64_field => doc.id as i64, + text_field => doc.id_name.to_string(), + text_field2 => doc.id_name.to_string(), + ips_field => doc.ip, + ips_field => doc.ip, + ip_field => doc.ip, + )) + .unwrap(); + } + + index_writer.commit().unwrap(); + } + index +} + +fn get_90_percent() -> RangeInclusive { + 0..=90 +} + +fn get_10_percent() -> RangeInclusive { + 0..=10 +} + +fn get_1_percent() -> RangeInclusive { + 10..=10 +} + +fn get_90_percent_ip() -> RangeInclusive { + let start = Ipv6Addr::from_u128(0); + let end = Ipv6Addr::from_u128(90 * 1000); + start..=end +} + +fn get_10_percent_ip() -> RangeInclusive { + let start = Ipv6Addr::from_u128(0); + let end = Ipv6Addr::from_u128(10 * 1000); + start..=end +} + +fn get_1_percent_ip() -> RangeInclusive { + let start = Ipv6Addr::from_u128(10 * 1000); + let end = Ipv6Addr::from_u128(10 * 1000); + start..=end +} + +struct NumHits { + count: usize, +} +impl OutputValue for NumHits { + fn column_title() -> &'static str { + "NumHits" + } + fn format(&self) -> Option { + Some(self.count.to_string()) + } +} + +fn execute_query( + field: &str, + id_range: &RangeInclusive, + suffix: &str, + index: &Index, +) -> NumHits { + let gen_query_inclusive = |from: &T, to: &T| { + format!( + "{}:[{} TO {}] {}", + field, + &from.to_string(), + &to.to_string(), + suffix + ) + }; + + let query = gen_query_inclusive(id_range.start(), id_range.end()); + execute_query_(&query, index) +} + +fn execute_query_(query: &str, index: &Index) -> NumHits { + let query_from_text = |text: &str| { + QueryParser::for_index(index, vec![]) + .parse_query(text) + .unwrap() + }; + let query = query_from_text(query); + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + let 
num_hits = searcher + .search(&query, &(TopDocs::with_limit(10).order_by_score(), Count)) + .unwrap() + .1; + NumHits { count: num_hits } +} diff --git a/bitpacker/Cargo.toml b/bitpacker/Cargo.toml index 3b2a3e15e..945bd0082 100644 --- a/bitpacker/Cargo.toml +++ b/bitpacker/Cargo.toml @@ -18,5 +18,5 @@ homepage = "https://github.com/quickwit-oss/tantivy" bitpacking = { version = "0.9.2", default-features = false, features = ["bitpacker1x"] } [dev-dependencies] -rand = "0.8" +rand = "0.9" proptest = "1" diff --git a/bitpacker/benches/bench.rs b/bitpacker/benches/bench.rs index 7544687c2..12bfeb53e 100644 --- a/bitpacker/benches/bench.rs +++ b/bitpacker/benches/bench.rs @@ -4,8 +4,8 @@ extern crate test; #[cfg(test)] mod tests { + use rand::rng; use rand::seq::IteratorRandom; - use rand::thread_rng; use tantivy_bitpacker::{BitPacker, BitUnpacker, BlockedBitpacker}; use test::Bencher; @@ -27,7 +27,7 @@ mod tests { let num_els = 1_000_000u32; let bit_unpacker = BitUnpacker::new(bit_width); let data = create_bitpacked_data(bit_width, num_els); - let idxs: Vec = (0..num_els).choose_multiple(&mut thread_rng(), 100_000); + let idxs: Vec = (0..num_els).choose_multiple(&mut rng(), 100_000); b.iter(|| { let mut out = 0u64; for &idx in &idxs { diff --git a/columnar/Cargo.toml b/columnar/Cargo.toml index 9eeafe2d0..b91ab36ff 100644 --- a/columnar/Cargo.toml +++ b/columnar/Cargo.toml @@ -22,7 +22,7 @@ downcast-rs = "2.0.1" [dev-dependencies] proptest = "1" more-asserts = "0.3.1" -rand = "0.8" +rand = "0.9" binggan = "0.14.0" [[bench]] diff --git a/columnar/benches/bench_column_values_get.rs b/columnar/benches/bench_column_values_get.rs index d486b0dde..f2c1674ef 100644 --- a/columnar/benches/bench_column_values_get.rs +++ b/columnar/benches/bench_column_values_get.rs @@ -9,7 +9,7 @@ use tantivy_columnar::column_values::{CodecType, serialize_and_load_u64_based_co fn get_data() -> Vec { let mut rng = StdRng::seed_from_u64(2u64); let mut data: Vec<_> = (100..55_000_u64) - 
.map(|num| num + rng.r#gen::() as u64) + .map(|num| num + rng.random::() as u64) .collect(); data.push(99_000); data.insert(1000, 2000); diff --git a/columnar/benches/bench_create_column_values.rs b/columnar/benches/bench_create_column_values.rs index aa04e0661..339dbb199 100644 --- a/columnar/benches/bench_create_column_values.rs +++ b/columnar/benches/bench_create_column_values.rs @@ -6,7 +6,7 @@ use tantivy_columnar::column_values::{CodecType, serialize_u64_based_column_valu fn get_data() -> Vec { let mut rng = StdRng::seed_from_u64(2u64); let mut data: Vec<_> = (100..55_000_u64) - .map(|num| num + rng.r#gen::() as u64) + .map(|num| num + rng.random::() as u64) .collect(); data.push(99_000); data.insert(1000, 2000); diff --git a/columnar/benches/bench_optional_index.rs b/columnar/benches/bench_optional_index.rs index c157f1455..03ff1df97 100644 --- a/columnar/benches/bench_optional_index.rs +++ b/columnar/benches/bench_optional_index.rs @@ -8,7 +8,7 @@ const TOTAL_NUM_VALUES: u32 = 1_000_000; fn gen_optional_index(fill_ratio: f64) -> OptionalIndex { let mut rng: StdRng = StdRng::from_seed([1u8; 32]); let vals: Vec = (0..TOTAL_NUM_VALUES) - .map(|_| rng.gen_bool(fill_ratio)) + .map(|_| rng.random_bool(fill_ratio)) .enumerate() .filter(|(_pos, val)| *val) .map(|(pos, _)| pos as u32) @@ -25,7 +25,7 @@ fn random_range_iterator( let mut rng: StdRng = StdRng::from_seed([1u8; 32]); let mut current = start; std::iter::from_fn(move || { - current += rng.gen_range(avg_step_size - avg_deviation..=avg_step_size + avg_deviation); + current += rng.random_range(avg_step_size - avg_deviation..=avg_step_size + avg_deviation); if current >= end { None } else { Some(current) } }) } diff --git a/columnar/benches/bench_values_u128.rs b/columnar/benches/bench_values_u128.rs index e0b4f0a1f..09173c678 100644 --- a/columnar/benches/bench_values_u128.rs +++ b/columnar/benches/bench_values_u128.rs @@ -39,7 +39,7 @@ fn get_data_50percent_item() -> Vec { let mut data = vec![]; for _ in 
0..300_000 { - let val = rng.gen_range(1..=100); + let val = rng.random_range(1..=100); data.push(val); } data.push(SINGLE_ITEM); diff --git a/columnar/benches/bench_values_u64.rs b/columnar/benches/bench_values_u64.rs index 36711c776..f0419d8c6 100644 --- a/columnar/benches/bench_values_u64.rs +++ b/columnar/benches/bench_values_u64.rs @@ -34,7 +34,7 @@ fn get_data_50percent_item() -> Vec { let mut data = vec![]; for _ in 0..300_000 { - let val = rng.gen_range(1..=100); + let val = rng.random_range(1..=100); data.push(val); } data.push(SINGLE_ITEM); diff --git a/columnar/src/block_accessor.rs b/columnar/src/block_accessor.rs index 6bd24ba3b..9926553a8 100644 --- a/columnar/src/block_accessor.rs +++ b/columnar/src/block_accessor.rs @@ -29,12 +29,20 @@ impl } } #[inline] - pub fn fetch_block_with_missing(&mut self, docs: &[u32], accessor: &Column, missing: T) { + pub fn fetch_block_with_missing( + &mut self, + docs: &[u32], + accessor: &Column, + missing: Option, + ) { self.fetch_block(docs, accessor); // no missing values if accessor.index.get_cardinality().is_full() { return; } + let Some(missing) = missing else { + return; + }; // We can compare docid_cache length with docs to find missing docs // For multi value columns we can't rely on the length and always need to scan diff --git a/columnar/src/column/mod.rs b/columnar/src/column/mod.rs index cc2938bb8..f6a50b45f 100644 --- a/columnar/src/column/mod.rs +++ b/columnar/src/column/mod.rs @@ -85,8 +85,8 @@ impl Column { } #[inline] - pub fn first(&self, row_id: RowId) -> Option { - self.values_for_doc(row_id).next() + pub fn first(&self, doc_id: DocId) -> Option { + self.values_for_doc(doc_id).next() } /// Load the first value for each docid in the provided slice. 
diff --git a/columnar/src/column_values/u64_based/bitpacked.rs b/columnar/src/column_values/u64_based/bitpacked.rs index fde012937..71319cbec 100644 --- a/columnar/src/column_values/u64_based/bitpacked.rs +++ b/columnar/src/column_values/u64_based/bitpacked.rs @@ -41,12 +41,6 @@ fn transform_range_before_linear_transformation( if range.is_empty() { return None; } - if stats.min_value > *range.end() { - return None; - } - if stats.max_value < *range.start() { - return None; - } let shifted_range = range.start().saturating_sub(stats.min_value)..=range.end().saturating_sub(stats.min_value); let start_before_gcd_multiplication: u64 = div_ceil(*shifted_range.start(), stats.gcd); diff --git a/columnar/src/column_values/u64_based/linear.rs b/columnar/src/column_values/u64_based/linear.rs index dbfa13a4c..7caf3bdfb 100644 --- a/columnar/src/column_values/u64_based/linear.rs +++ b/columnar/src/column_values/u64_based/linear.rs @@ -268,7 +268,7 @@ mod tests { #[test] fn linear_interpol_fast_field_rand() { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..50 { let mut data = (0..10_000).map(|_| rng.next_u64()).collect::>(); create_and_validate::(&data, "random"); diff --git a/columnar/src/column_values/u64_based/tests.rs b/columnar/src/column_values/u64_based/tests.rs index 6b2697263..ff5b7051a 100644 --- a/columnar/src/column_values/u64_based/tests.rs +++ b/columnar/src/column_values/u64_based/tests.rs @@ -122,7 +122,7 @@ pub(crate) fn create_and_validate( assert_eq!(vals, buffer); if !vals.is_empty() { - let test_rand_idx = rand::thread_rng().gen_range(0..=vals.len() - 1); + let test_rand_idx = rand::rng().random_range(0..=vals.len() - 1); let expected_positions: Vec = vals .iter() .enumerate() diff --git a/columnar/src/tests.rs b/columnar/src/tests.rs index 5fa537466..5c4a9366c 100644 --- a/columnar/src/tests.rs +++ b/columnar/src/tests.rs @@ -60,7 +60,7 @@ fn test_dataframe_writer_bool() { let DynamicColumn::Bool(bool_col) = dyn_bool_col else { 
panic!(); }; - let vals: Vec> = (0..5).map(|row_id| bool_col.first(row_id)).collect(); + let vals: Vec> = (0..5).map(|doc_id| bool_col.first(doc_id)).collect(); assert_eq!(&vals, &[None, Some(false), None, Some(true), None,]); } @@ -108,7 +108,7 @@ fn test_dataframe_writer_ip_addr() { let DynamicColumn::IpAddr(ip_col) = dyn_bool_col else { panic!(); }; - let vals: Vec> = (0..5).map(|row_id| ip_col.first(row_id)).collect(); + let vals: Vec> = (0..5).map(|doc_id| ip_col.first(doc_id)).collect(); assert_eq!( &vals, &[ @@ -169,7 +169,7 @@ fn test_dictionary_encoded_str() { let DynamicColumn::Str(str_col) = col_handles[0].open().unwrap() else { panic!(); }; - let index: Vec> = (0..5).map(|row_id| str_col.ords().first(row_id)).collect(); + let index: Vec> = (0..5).map(|doc_id| str_col.ords().first(doc_id)).collect(); assert_eq!(index, &[None, Some(0), None, Some(2), Some(1)]); assert_eq!(str_col.num_rows(), 5); let mut term_buffer = String::new(); @@ -204,7 +204,7 @@ fn test_dictionary_encoded_bytes() { panic!(); }; let index: Vec> = (0..5) - .map(|row_id| bytes_col.ords().first(row_id)) + .map(|doc_id| bytes_col.ords().first(doc_id)) .collect(); assert_eq!(index, &[None, Some(0), None, Some(2), Some(1)]); assert_eq!(bytes_col.num_rows(), 5); diff --git a/common/Cargo.toml b/common/Cargo.toml index 206329d39..e5e922869 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -21,5 +21,5 @@ serde = { version = "1.0.136", features = ["derive"] } [dev-dependencies] binggan = "0.14.0" proptest = "1.0.0" -rand = "0.8.4" +rand = "0.9" diff --git a/common/benches/bench.rs b/common/benches/bench.rs index 81260e116..a0b1f9451 100644 --- a/common/benches/bench.rs +++ b/common/benches/bench.rs @@ -1,6 +1,6 @@ use binggan::{BenchRunner, black_box}; +use rand::rng; use rand::seq::IteratorRandom; -use rand::thread_rng; use tantivy_common::{BitSet, TinySet, serialize_vint_u32}; fn bench_vint() { @@ -17,7 +17,7 @@ fn bench_vint() { black_box(out); }); - let vals: Vec = 
(0..20_000).choose_multiple(&mut thread_rng(), 100_000); + let vals: Vec = (0..20_000).choose_multiple(&mut rng(), 100_000); runner.bench_function("bench_vint_rand", move |_| { let mut out = 0u64; for val in vals.iter().cloned() { diff --git a/common/src/bitset.rs b/common/src/bitset.rs index 8e98e6780..e005ca40b 100644 --- a/common/src/bitset.rs +++ b/common/src/bitset.rs @@ -181,6 +181,14 @@ pub struct BitSet { len: u64, max_value: u32, } +impl std::fmt::Debug for BitSet { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("BitSet") + .field("len", &self.len) + .field("max_value", &self.max_value) + .finish() + } +} fn num_buckets(max_val: u32) -> u32 { max_val.div_ceil(64u32) @@ -408,7 +416,7 @@ mod tests { use std::collections::HashSet; use ownedbytes::OwnedBytes; - use rand::distributions::Bernoulli; + use rand::distr::Bernoulli; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; diff --git a/src/aggregation/agg_data.rs b/src/aggregation/agg_data.rs index 40dcf3fe6..156b014f3 100644 --- a/src/aggregation/agg_data.rs +++ b/src/aggregation/agg_data.rs @@ -1,4 +1,4 @@ -use columnar::{Column, ColumnType, StrColumn}; +use columnar::{Column, ColumnBlockAccessor, ColumnType, StrColumn}; use common::BitSet; use rustc_hash::FxHashSet; use serde::Serialize; @@ -10,16 +10,16 @@ use crate::aggregation::accessor_helpers::{ }; use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations}; use crate::aggregation::bucket::{ - FilterAggReqData, HistogramAggReqData, HistogramBounds, IncludeExcludeParam, - MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector, SegmentHistogramCollector, - SegmentRangeCollector, TermMissingAgg, TermsAggReqData, TermsAggregation, + build_segment_filter_collector, build_segment_range_collector, FilterAggReqData, + HistogramAggReqData, HistogramBounds, IncludeExcludeParam, MissingTermAggReqData, + RangeAggReqData, SegmentHistogramCollector, TermMissingAgg, TermsAggReqData, TermsAggregation, 
TermsAggregationInternal, }; use crate::aggregation::metric::{ - AverageAggregation, CardinalityAggReqData, CardinalityAggregationReq, CountAggregation, - ExtendedStatsAggregation, MaxAggregation, MetricAggReqData, MinAggregation, - SegmentCardinalityCollector, SegmentExtendedStatsCollector, SegmentPercentilesCollector, - SegmentStatsCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData, + build_segment_stats_collector, AverageAggregation, CardinalityAggReqData, + CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation, MaxAggregation, + MetricAggReqData, MinAggregation, SegmentCardinalityCollector, SegmentExtendedStatsCollector, + SegmentPercentilesCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData, TopHitsSegmentCollector, }; use crate::aggregation::segment_agg_result::{ @@ -35,6 +35,7 @@ pub struct AggregationsSegmentCtx { /// Request data for each aggregation type. pub per_request: PerRequestAggSegCtx, pub context: AggContextParams, + pub column_block_accessor: ColumnBlockAccessor, } impl AggregationsSegmentCtx { @@ -107,21 +108,14 @@ impl AggregationsSegmentCtx { .as_deref() .expect("range_req_data slot is empty (taken)") } - #[inline] - pub(crate) fn get_filter_req_data(&self, idx: usize) -> &FilterAggReqData { - self.per_request.filter_req_data[idx] - .as_deref() - .expect("filter_req_data slot is empty (taken)") - } // ---------- mutable getters ---------- #[inline] - pub(crate) fn get_term_req_data_mut(&mut self, idx: usize) -> &mut TermsAggReqData { - self.per_request.term_req_data[idx] - .as_deref_mut() - .expect("term_req_data slot is empty (taken)") + pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData { + &mut self.per_request.stats_metric_req_data[idx] } + #[inline] pub(crate) fn get_cardinality_req_data_mut( &mut self, @@ -129,10 +123,7 @@ impl AggregationsSegmentCtx { ) -> &mut CardinalityAggReqData { &mut self.per_request.cardinality_req_data[idx] } - 
#[inline] - pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData { - &mut self.per_request.stats_metric_req_data[idx] - } + #[inline] pub(crate) fn get_histogram_req_data_mut(&mut self, idx: usize) -> &mut HistogramAggReqData { self.per_request.histogram_req_data[idx] @@ -142,21 +133,6 @@ impl AggregationsSegmentCtx { // ---------- take / put (terms, histogram, range) ---------- - /// Move out the boxed Terms request at `idx`, leaving `None`. - #[inline] - pub(crate) fn take_term_req_data(&mut self, idx: usize) -> Box { - self.per_request.term_req_data[idx] - .take() - .expect("term_req_data slot is empty (taken)") - } - - /// Put back a Terms request into an empty slot at `idx`. - #[inline] - pub(crate) fn put_back_term_req_data(&mut self, idx: usize, value: Box) { - debug_assert!(self.per_request.term_req_data[idx].is_none()); - self.per_request.term_req_data[idx] = Some(value); - } - /// Move out the boxed Histogram request at `idx`, leaving `None`. #[inline] pub(crate) fn take_histogram_req_data(&mut self, idx: usize) -> Box { @@ -320,6 +296,7 @@ impl PerRequestAggSegCtx { /// Convert the aggregation tree into a serializable struct representation. /// Each node contains: { name, kind, children }. 
+ #[allow(dead_code)] pub fn get_view_tree(&self) -> Vec { fn node_to_view(node: &AggRefNode, pr: &PerRequestAggSegCtx) -> AggTreeViewNode { let mut children: Vec = @@ -345,12 +322,19 @@ impl PerRequestAggSegCtx { pub(crate) fn build_segment_agg_collectors_root( req: &mut AggregationsSegmentCtx, ) -> crate::Result> { - build_segment_agg_collectors(req, &req.per_request.agg_tree.clone()) + build_segment_agg_collectors_generic(req, &req.per_request.agg_tree.clone()) } pub(crate) fn build_segment_agg_collectors( req: &mut AggregationsSegmentCtx, nodes: &[AggRefNode], +) -> crate::Result> { + build_segment_agg_collectors_generic(req, nodes) +} + +fn build_segment_agg_collectors_generic( + req: &mut AggregationsSegmentCtx, + nodes: &[AggRefNode], ) -> crate::Result> { let mut collectors = Vec::new(); for node in nodes.iter() { @@ -388,6 +372,8 @@ pub(crate) fn build_segment_agg_collector( Ok(Box::new(SegmentCardinalityCollector::from_req( req_data.column_type, node.idx_in_req_data, + req_data.accessor.clone(), + req_data.missing_value_for_accessor, ))) } AggKind::StatsKind(stats_type) => { @@ -398,20 +384,21 @@ pub(crate) fn build_segment_agg_collector( | StatsType::Count | StatsType::Max | StatsType::Min - | StatsType::Stats => Ok(Box::new(SegmentStatsCollector::from_req( - node.idx_in_req_data, - ))), - StatsType::ExtendedStats(sigma) => { - Ok(Box::new(SegmentExtendedStatsCollector::from_req( - req_data.field_type, - sigma, - node.idx_in_req_data, - req_data.missing, - ))) - } - StatsType::Percentiles => Ok(Box::new( - SegmentPercentilesCollector::from_req_and_validate(node.idx_in_req_data)?, + | StatsType::Stats => build_segment_stats_collector(req_data), + StatsType::ExtendedStats(sigma) => Ok(Box::new( + SegmentExtendedStatsCollector::from_req(req_data, sigma), )), + StatsType::Percentiles => { + let req_data = req.get_metric_req_data_mut(node.idx_in_req_data); + Ok(Box::new( + SegmentPercentilesCollector::from_req_and_validate( + req_data.field_type, + 
req_data.missing_u64, + req_data.accessor.clone(), + node.idx_in_req_data, + ), + )) + } } } AggKind::TopHits => { @@ -428,12 +415,8 @@ pub(crate) fn build_segment_agg_collector( AggKind::DateHistogram => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate( req, node, )?)), - AggKind::Range => Ok(Box::new(SegmentRangeCollector::from_req_and_validate( - req, node, - )?)), - AggKind::Filter => Ok(Box::new(SegmentFilterCollector::from_req_and_validate( - req, node, - )?)), + AggKind::Range => Ok(build_segment_range_collector(req, node)?), + AggKind::Filter => build_segment_filter_collector(req, node), } } @@ -493,6 +476,7 @@ pub(crate) fn build_aggregations_data_from_req( let mut data = AggregationsSegmentCtx { per_request: Default::default(), context, + column_block_accessor: ColumnBlockAccessor::default(), }; for (name, agg) in aggs.iter() { @@ -521,9 +505,9 @@ fn build_nodes( let idx_in_req_data = data.push_range_req_data(RangeAggReqData { accessor, field_type, - column_block_accessor: Default::default(), name: agg_name.to_string(), req: range_req.clone(), + is_top_level, }); let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?; Ok(vec![AggRefNode { @@ -541,9 +525,7 @@ fn build_nodes( let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData { accessor, field_type, - column_block_accessor: Default::default(), name: agg_name.to_string(), - sub_aggregation_blueprint: None, req: histo_req.clone(), is_date_histogram: false, bounds: HistogramBounds { @@ -568,9 +550,7 @@ fn build_nodes( let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData { accessor, field_type, - column_block_accessor: Default::default(), name: agg_name.to_string(), - sub_aggregation_blueprint: None, req: histo_req, is_date_histogram: true, bounds: HistogramBounds { @@ -650,7 +630,6 @@ fn build_nodes( let idx_in_req_data = data.push_metric_req_data(MetricAggReqData { accessor, field_type, - column_block_accessor: Default::default(), 
name: agg_name.to_string(), collecting_for, missing: *missing, @@ -678,7 +657,6 @@ fn build_nodes( let idx_in_req_data = data.push_metric_req_data(MetricAggReqData { accessor, field_type, - column_block_accessor: Default::default(), name: agg_name.to_string(), collecting_for: StatsType::Percentiles, missing: percentiles_req.missing, @@ -753,6 +731,7 @@ fn build_nodes( segment_reader: reader.clone(), evaluator, matching_docs_buffer, + is_top_level, }); let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?; Ok(vec![AggRefNode { @@ -895,7 +874,7 @@ fn build_terms_or_cardinality_nodes( }); } - // Add one node per accessor to mirror previous behavior and allow per-type missing handling. + // Add one node per accessor for (accessor, column_type) in column_and_types { let missing_value_for_accessor = if use_special_missing_agg { None @@ -926,11 +905,8 @@ fn build_terms_or_cardinality_nodes( column_type, str_dict_column: str_dict_column.clone(), missing_value_for_accessor, - column_block_accessor: Default::default(), name: agg_name.to_string(), req: TermsAggregationInternal::from_req(req), - // Will be filled later when building collectors - sub_aggregation_blueprint: None, sug_aggregations: sub_aggs.clone(), allowed_term_ids, is_top_level, @@ -943,7 +919,6 @@ fn build_terms_or_cardinality_nodes( column_type, str_dict_column: str_dict_column.clone(), missing_value_for_accessor, - column_block_accessor: Default::default(), name: agg_name.to_string(), req: req.clone(), }); diff --git a/src/aggregation/agg_tests.rs b/src/aggregation/agg_tests.rs index 2d5ffa769..ba662116d 100644 --- a/src/aggregation/agg_tests.rs +++ b/src/aggregation/agg_tests.rs @@ -2,15 +2,441 @@ use serde_json::Value; use crate::aggregation::agg_req::{Aggregation, Aggregations}; use crate::aggregation::agg_result::AggregationResults; -use crate::aggregation::buf_collector::DOC_BLOCK_SIZE; use crate::aggregation::collector::AggregationCollector; use 
crate::aggregation::intermediate_agg_result::IntermediateAggregationResults; use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms}; use crate::aggregation::DistributedAggregationCollector; +use crate::docset::COLLECT_BLOCK_BUFFER_LEN; use crate::query::{AllQuery, TermQuery}; use crate::schema::{IndexRecordOption, Schema, FAST}; use crate::{Index, IndexWriter, Term}; +// The following tests ensure that each bucket aggregation type correctly functions as a +// sub-aggregation of another bucket aggregation in two scenarios: +// 1) The parent has more buckets than the child sub-aggregation +// 2) The child sub-aggregation has more buckets than the parent +// +// These scenarios exercise the bucket id mapping and sub-aggregation routing logic. + +#[test] +fn test_terms_as_subagg_parent_more_vs_child_more() -> crate::Result<()> { + let index = get_test_index_2_segments(false)?; + + // Case A: parent has more buckets than child + // Parent: range with 4 buckets + // Child: terms on text -> 2 buckets + let agg_parent_more: Aggregations = serde_json::from_value(json!({ + "parent_range": { + "range": { + "field": "score", + "ranges": [ + {"to": 3.0}, + {"from": 3.0, "to": 7.0}, + {"from": 7.0, "to": 20.0}, + {"from": 20.0} + ] + }, + "aggs": { + "child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}} + } + } + })) + .unwrap(); + + let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?; + // Exact expected structure and counts + assert_eq!( + res["parent_range"]["buckets"], + json!([ + { + "key": "*-3", + "doc_count": 1, + "to": 3.0, + "child_terms": { + "buckets": [ + {"doc_count": 1, "key": "cool"} + ], + "sum_other_doc_count": 0 + } + }, + { + "key": "3-7", + "doc_count": 3, + "from": 3.0, + "to": 7.0, + "child_terms": { + "buckets": [ + {"doc_count": 2, "key": "cool"}, + {"doc_count": 1, "key": "nohit"} + ], + "sum_other_doc_count": 0 + } + }, + { + "key": "7-20", + "doc_count": 3, + "from": 
7.0, + "to": 20.0, + "child_terms": { + "buckets": [ + {"doc_count": 3, "key": "cool"} + ], + "sum_other_doc_count": 0 + } + }, + { + "key": "20-*", + "doc_count": 2, + "from": 20.0, + "child_terms": { + "buckets": [ + {"doc_count": 1, "key": "cool"}, + {"doc_count": 1, "key": "nohit"} + ], + "sum_other_doc_count": 0 + } + } + ]) + ); + + // Case B: child has more buckets than parent + // Parent: histogram on score with large interval -> 1 bucket + // Child: terms on text -> 2 buckets (cool/nohit) + let agg_child_more: Aggregations = serde_json::from_value(json!({ + "parent_hist": { + "histogram": {"field": "score", "interval": 100.0}, + "aggs": { + "child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}} + } + } + })) + .unwrap(); + + let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?; + assert_eq!( + res["parent_hist"], + json!({ + "buckets": [ + { + "key": 0.0, + "doc_count": 9, + "child_terms": { + "buckets": [ + {"doc_count": 7, "key": "cool"}, + {"doc_count": 2, "key": "nohit"} + ], + "sum_other_doc_count": 0 + } + } + ] + }) + ); + + Ok(()) +} + +#[test] +fn test_range_as_subagg_parent_more_vs_child_more() -> crate::Result<()> { + let index = get_test_index_2_segments(false)?; + + // Case A: parent has more buckets than child + // Parent: range with 5 buckets + // Child: coarse range with 3 buckets + let agg_parent_more: Aggregations = serde_json::from_value(json!({ + "parent_range": { + "range": { + "field": "score", + "ranges": [ + {"to": 3.0}, + {"from": 3.0, "to": 7.0}, + {"from": 7.0, "to": 11.0}, + {"from": 11.0, "to": 20.0}, + {"from": 20.0} + ] + }, + "aggs": { + "child_range": { + "range": { + "field": "score", + "ranges": [ + {"to": 3.0}, + {"from": 3.0, "to": 20.0} + ] + } + } + } + } + })) + .unwrap(); + let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?; + assert_eq!( + res["parent_range"]["buckets"], + json!([ + {"key": "*-3", "doc_count": 1, "to": 3.0, + "child_range": {"buckets": [ 
+ {"key": "*-3", "doc_count": 1, "to": 3.0}, + {"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0}, + {"key": "20-*", "doc_count": 0, "from": 20.0} + ]} + }, + {"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0, + "child_range": {"buckets": [ + {"key": "*-3", "doc_count": 0, "to": 3.0}, + {"key": "3-20", "doc_count": 3, "from": 3.0, "to": 20.0}, + {"key": "20-*", "doc_count": 0, "from": 20.0} + ]} + }, + {"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0, + "child_range": {"buckets": [ + {"key": "*-3", "doc_count": 0, "to": 3.0}, + {"key": "3-20", "doc_count": 1, "from": 3.0, "to": 20.0}, + {"key": "20-*", "doc_count": 0, "from": 20.0} + ]} + }, + {"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0, + "child_range": {"buckets": [ + {"key": "*-3", "doc_count": 0, "to": 3.0}, + {"key": "3-20", "doc_count": 2, "from": 3.0, "to": 20.0}, + {"key": "20-*", "doc_count": 0, "from": 20.0} + ]} + }, + {"key": "20-*", "doc_count": 2, "from": 20.0, + "child_range": {"buckets": [ + {"key": "*-3", "doc_count": 0, "to": 3.0}, + {"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0}, + {"key": "20-*", "doc_count": 2, "from": 20.0} + ]} + } + ]) + ); + + // Case B: child has more buckets than parent + // Parent: terms on text (2 buckets) + // Child: range with 4 buckets + let agg_child_more: Aggregations = serde_json::from_value(json!({ + "parent_terms": { + "terms": {"field": "text"}, + "aggs": { + "child_range": { + "range": { + "field": "score", + "ranges": [ + {"to": 3.0}, + {"from": 3.0, "to": 7.0}, + {"from": 7.0, "to": 20.0} + ] + } + } + } + } + })) + .unwrap(); + let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?; + + assert_eq!( + res["parent_terms"], + json!({ + "buckets": [ + { + "key": "cool", + "doc_count": 7, + "child_range": { + "buckets": [ + {"key": "*-3", "doc_count": 1, "to": 3.0}, + {"key": "3-7", "doc_count": 2, "from": 3.0, "to": 7.0}, + {"key": "7-20", "doc_count": 3, "from": 7.0, "to": 20.0}, + {"key": "20-*", 
"doc_count": 1, "from": 20.0} + ] + } + }, + { + "key": "nohit", + "doc_count": 2, + "child_range": { + "buckets": [ + {"key": "*-3", "doc_count": 0, "to": 3.0}, + {"key": "3-7", "doc_count": 1, "from": 3.0, "to": 7.0}, + {"key": "7-20", "doc_count": 0, "from": 7.0, "to": 20.0}, + {"key": "20-*", "doc_count": 1, "from": 20.0} + ] + } + } + ], + "doc_count_error_upper_bound": 0, + "sum_other_doc_count": 0 + }) + ); + + Ok(()) +} + +#[test] +fn test_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> { + let index = get_test_index_2_segments(false)?; + + // Case A: parent has more buckets than child + // Parent: range with several ranges + // Child: histogram with large interval (single bucket per parent) + let agg_parent_more: Aggregations = serde_json::from_value(json!({ + "parent_range": { + "range": { + "field": "score", + "ranges": [ + {"to": 3.0}, + {"from": 3.0, "to": 7.0}, + {"from": 7.0, "to": 11.0}, + {"from": 11.0, "to": 20.0}, + {"from": 20.0} + ] + }, + "aggs": { + "child_hist": {"histogram": {"field": "score", "interval": 100.0}} + } + } + })) + .unwrap(); + let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?; + assert_eq!( + res["parent_range"]["buckets"], + json!([ + {"key": "*-3", "doc_count": 1, "to": 3.0, + "child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]} + }, + {"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0, + "child_hist": {"buckets": [ {"key": 0.0, "doc_count": 3} ]} + }, + {"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0, + "child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]} + }, + {"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0, + "child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]} + }, + {"key": "20-*", "doc_count": 2, "from": 20.0, + "child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]} + } + ]) + ); + + // Case B: child has more buckets than parent + // Parent: terms on text -> 2 buckets + // Child: histogram with small interval -> multiple 
buckets including empties + let agg_child_more: Aggregations = serde_json::from_value(json!({ + "parent_terms": { + "terms": {"field": "text"}, + "aggs": { + "child_hist": {"histogram": {"field": "score", "interval": 10.0}} + } + } + })) + .unwrap(); + let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?; + assert_eq!( + res["parent_terms"], + json!({ + "buckets": [ + { + "key": "cool", + "doc_count": 7, + "child_hist": { + "buckets": [ + {"key": 0.0, "doc_count": 4}, + {"key": 10.0, "doc_count": 2}, + {"key": 20.0, "doc_count": 0}, + {"key": 30.0, "doc_count": 0}, + {"key": 40.0, "doc_count": 1} + ] + } + }, + { + "key": "nohit", + "doc_count": 2, + "child_hist": { + "buckets": [ + {"key": 0.0, "doc_count": 1}, + {"key": 10.0, "doc_count": 0}, + {"key": 20.0, "doc_count": 0}, + {"key": 30.0, "doc_count": 0}, + {"key": 40.0, "doc_count": 1} + ] + } + } + ], + "doc_count_error_upper_bound": 0, + "sum_other_doc_count": 0 + }) + ); + + Ok(()) +} + +#[test] +fn test_date_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> { + let index = get_test_index_2_segments(false)?; + + // Case A: parent has more buckets than child + // Parent: range with several buckets + // Child: date_histogram with 30d -> single bucket per parent + let agg_parent_more: Aggregations = serde_json::from_value(json!({ + "parent_range": { + "range": { + "field": "score", + "ranges": [ + {"to": 3.0}, + {"from": 3.0, "to": 7.0}, + {"from": 7.0, "to": 11.0}, + {"from": 11.0, "to": 20.0}, + {"from": 20.0} + ] + }, + "aggs": { + "child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "30d"}} + } + } + })) + .unwrap(); + let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?; + let buckets = res["parent_range"]["buckets"].as_array().unwrap(); + // Verify each parent bucket has exactly one child date bucket with matching doc_count + for bucket in buckets { + let parent_count = bucket["doc_count"].as_u64().unwrap(); + let 
child_buckets = bucket["child_date_hist"]["buckets"].as_array().unwrap(); + assert_eq!(child_buckets.len(), 1); + assert_eq!(child_buckets[0]["doc_count"], parent_count); + } + + // Case B: child has more buckets than parent + // Parent: terms on text (2 buckets) + // Child: date_histogram with 1d -> multiple buckets + let agg_child_more: Aggregations = serde_json::from_value(json!({ + "parent_terms": { + "terms": {"field": "text"}, + "aggs": { + "child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "1d"}} + } + } + })) + .unwrap(); + let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?; + let buckets = res["parent_terms"]["buckets"].as_array().unwrap(); + + // cool bucket + assert_eq!(buckets[0]["key"], "cool"); + let cool_buckets = buckets[0]["child_date_hist"]["buckets"].as_array().unwrap(); + assert_eq!(cool_buckets.len(), 3); + assert_eq!(cool_buckets[0]["doc_count"], 1); // day 0 + assert_eq!(cool_buckets[1]["doc_count"], 4); // day 1 + assert_eq!(cool_buckets[2]["doc_count"], 2); // day 2 + + // nohit bucket + assert_eq!(buckets[1]["key"], "nohit"); + let nohit_buckets = buckets[1]["child_date_hist"]["buckets"].as_array().unwrap(); + assert_eq!(nohit_buckets.len(), 2); + assert_eq!(nohit_buckets[0]["doc_count"], 1); // day 1 + assert_eq!(nohit_buckets[1]["doc_count"], 1); // day 2 + + Ok(()) +} + fn get_avg_req(field_name: &str) -> Aggregation { serde_json::from_value(json!({ "avg": { @@ -25,6 +451,10 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector { } // *** EVERY BUCKET-TYPE SHOULD BE TESTED HERE *** +// Note: The flushing part of these tests is outdated, since the buffering changed after converting +// the collection into one collector per request instead of per bucket. +// +// However, they are still useful because they test complex aggregation requests.
fn test_aggregation_flushing( merge_segments: bool, use_distributed_collector: bool, @@ -37,8 +467,9 @@ fn test_aggregation_flushing( let reader = index.reader()?; - assert_eq!(DOC_BLOCK_SIZE, 64); - // In the tree we cache Documents of DOC_BLOCK_SIZE, before passing them down as one block. + assert_eq!(COLLECT_BLOCK_BUFFER_LEN, 64); + // In the tree we cache documents of COLLECT_BLOCK_BUFFER_LEN before passing them down as one + // block. // // Build a request so that on the first level we have one full cache, which is then flushed. // The same cache should have some residue docs at the end, which are flushed (Range 0-70) diff --git a/src/aggregation/bucket/filter.rs b/src/aggregation/bucket/filter.rs index 18f2a692a..73518238a 100644 --- a/src/aggregation/bucket/filter.rs +++ b/src/aggregation/bucket/filter.rs @@ -6,10 +6,14 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer}; use crate::aggregation::agg_data::{ build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx, }; +use crate::aggregation::cached_sub_aggs::{ + CachedSubAggs, HighCardSubAggCache, LowCardSubAggCache, SubAggCache, +}; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, }; -use crate::aggregation::segment_agg_result::{CollectorClone, SegmentAggregationCollector}; +use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector}; +use crate::aggregation::BucketId; use crate::docset::DocSet; use crate::query::{AllQuery, EnableScoring, Query, QueryParser}; use crate::schema::Schema; @@ -404,15 +408,18 @@ pub struct FilterAggReqData { pub evaluator: DocumentQueryEvaluator, /// Reusable buffer for matching documents to minimize allocations during collection pub matching_docs_buffer: Vec, + /// True if this filter aggregation is at the top level of the aggregation tree (not nested). 
+ pub is_top_level: bool, } impl FilterAggReqData { pub(crate) fn get_memory_consumption(&self) -> usize { // Estimate: name + segment reader reference + bitset + buffer capacity self.name.len() - + std::mem::size_of::() - + self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes) - + self.matching_docs_buffer.capacity() * std::mem::size_of::() + + std::mem::size_of::() + + self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes) + + self.matching_docs_buffer.capacity() * std::mem::size_of::() + + std::mem::size_of::() } } @@ -489,17 +496,24 @@ impl Debug for DocumentQueryEvaluator { } } -/// Segment collector for filter aggregation -pub struct SegmentFilterCollector { - /// Document count in this bucket +#[derive(Debug, Clone, PartialEq, Copy)] +struct DocCount { doc_count: u64, + bucket_id: BucketId, +} + +/// Segment collector for filter aggregation +pub struct SegmentFilterCollector { + /// Document counts per parent bucket + parent_buckets: Vec, /// Sub-aggregation collectors - sub_aggregations: Option>, + sub_aggregations: Option>, + bucket_id_provider: BucketIdProvider, /// Accessor index for this filter aggregation (to access FilterAggReqData) accessor_idx: usize, } -impl SegmentFilterCollector { +impl SegmentFilterCollector { /// Create a new filter segment collector following the new agg_data pattern pub(crate) fn from_req_and_validate( req: &mut AggregationsSegmentCtx, @@ -511,47 +525,75 @@ impl SegmentFilterCollector { } else { None }; + let sub_agg_collector = sub_agg_collector.map(CachedSubAggs::new); Ok(SegmentFilterCollector { - doc_count: 0, + parent_buckets: Vec::new(), sub_aggregations: sub_agg_collector, accessor_idx: node.idx_in_req_data, + bucket_id_provider: BucketIdProvider::default(), }) } } -impl Debug for SegmentFilterCollector { +pub(crate) fn build_segment_filter_collector( + req: &mut AggregationsSegmentCtx, + node: &AggRefNode, +) -> crate::Result> { + let is_top_level = 
req.per_request.filter_req_data[node.idx_in_req_data] + .as_ref() + .expect("filter_req_data slot is empty") + .is_top_level; + + if is_top_level { + Ok(Box::new( + SegmentFilterCollector::::from_req_and_validate(req, node)?, + )) + } else { + Ok(Box::new( + SegmentFilterCollector::::from_req_and_validate(req, node)?, + )) + } +} + +impl Debug for SegmentFilterCollector { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SegmentFilterCollector") - .field("doc_count", &self.doc_count) + .field("buckets", &self.parent_buckets) .field("has_sub_aggs", &self.sub_aggregations.is_some()) .field("accessor_idx", &self.accessor_idx) .finish() } } -impl CollectorClone for SegmentFilterCollector { - fn clone_box(&self) -> Box { - // For now, panic - this needs proper implementation with weight recreation - panic!("SegmentFilterCollector cloning not yet implemented - requires weight recreation") - } -} - -impl SegmentAggregationCollector for SegmentFilterCollector { +impl SegmentAggregationCollector for SegmentFilterCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { let mut sub_results = IntermediateAggregationResults::default(); + let bucket_opt = self.parent_buckets.get(parent_bucket_id as usize); - if let Some(sub_aggs) = self.sub_aggregations { - sub_aggs.add_intermediate_aggregation_result(agg_data, &mut sub_results)?; + if let Some(sub_aggs) = &mut self.sub_aggregations { + sub_aggs + .get_sub_agg_collector() + .add_intermediate_aggregation_result( + agg_data, + &mut sub_results, + // Here we create a new bucket ID for sub-aggregations if the bucket doesn't + // exist, so that sub-aggregations can still produce results (e.g., zero doc + // count) + bucket_opt + .map(|bucket| bucket.bucket_id) + .unwrap_or(self.bucket_id_provider.next_bucket_id()), + )?; } // Create the filter 
bucket result let filter_bucket_result = IntermediateBucketResult::Filter { - doc_count: self.doc_count, + doc_count: bucket_opt.map(|b| b.doc_count).unwrap_or(0), sub_aggregations: sub_results, }; @@ -570,32 +612,17 @@ impl SegmentAggregationCollector for SegmentFilterCollector { Ok(()) } - fn collect(&mut self, doc: DocId, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - // Access the evaluator from FilterAggReqData - let req_data = agg_data.get_filter_req_data(self.accessor_idx); - - // O(1) BitSet lookup to check if document matches filter - if req_data.evaluator.matches_document(doc) { - self.doc_count += 1; - - // If we have sub-aggregations, collect on them for this filtered document - if let Some(sub_aggs) = &mut self.sub_aggregations { - sub_aggs.collect(doc, agg_data)?; - } - } - Ok(()) - } - - #[inline] - fn collect_block( + fn collect( &mut self, - docs: &[DocId], + parent_bucket_id: BucketId, + docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { if docs.is_empty() { return Ok(()); } + let mut bucket = self.parent_buckets[parent_bucket_id as usize]; // Take the request data to avoid borrow checker issues with sub-aggregations let mut req = agg_data.take_filter_req_data(self.accessor_idx); @@ -604,18 +631,24 @@ impl SegmentAggregationCollector for SegmentFilterCollector { req.evaluator .filter_batch(docs, &mut req.matching_docs_buffer); - self.doc_count += req.matching_docs_buffer.len() as u64; + bucket.doc_count += req.matching_docs_buffer.len() as u64; // Batch process sub-aggregations if we have matches if !req.matching_docs_buffer.is_empty() { if let Some(sub_aggs) = &mut self.sub_aggregations { - // Use collect_block for better sub-aggregation performance - sub_aggs.collect_block(&req.matching_docs_buffer, agg_data)?; + for &doc_id in &req.matching_docs_buffer { + sub_aggs.push(bucket.bucket_id, doc_id); + } } } // Put the request data back agg_data.put_back_filter_req_data(self.accessor_idx, req); + 
if let Some(sub_aggs) = &mut self.sub_aggregations { + sub_aggs.check_flush_local(agg_data)?; + } + // put back bucket + self.parent_buckets[parent_bucket_id as usize] = bucket; Ok(()) } @@ -626,6 +659,21 @@ impl SegmentAggregationCollector for SegmentFilterCollector { } Ok(()) } + + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + while self.parent_buckets.len() <= max_bucket as usize { + let bucket_id = self.bucket_id_provider.next_bucket_id(); + self.parent_buckets.push(DocCount { + doc_count: 0, + bucket_id, + }); + } + Ok(()) + } } /// Intermediate result for filter aggregation @@ -1519,9 +1567,9 @@ mod tests { let searcher = reader.searcher(); let agg = json!({ - "test": { - "filter": deserialized, - "aggs": { "count": { "value_count": { "field": "brand" } } } + "test": { + "filter": deserialized, + "aggs": { "count": { "value_count": { "field": "brand" } } } } }); diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs index 36c0fe57e..adf7936c6 100644 --- a/src/aggregation/bucket/histogram/histogram.rs +++ b/src/aggregation/bucket/histogram/histogram.rs @@ -1,6 +1,6 @@ use std::cmp::Ordering; -use columnar::{Column, ColumnBlockAccessor, ColumnType}; +use columnar::{Column, ColumnType}; use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; use tantivy_bitpacker::minmax; @@ -8,14 +8,14 @@ use tantivy_bitpacker::minmax; use crate::aggregation::agg_data::{ build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx, }; -use crate::aggregation::agg_limits::MemoryConsumption; use crate::aggregation::agg_req::Aggregations; use crate::aggregation::agg_result::BucketEntry; +use crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardCachedSubAggs}; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry, 
}; -use crate::aggregation::segment_agg_result::SegmentAggregationCollector; +use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector}; use crate::aggregation::*; use crate::TantivyError; @@ -26,13 +26,8 @@ pub struct HistogramAggReqData { pub accessor: Column, /// The field type of the fast field. pub field_type: ColumnType, - /// The column block accessor to access the fast field values. - pub column_block_accessor: ColumnBlockAccessor, /// The name of the aggregation. pub name: String, - /// The sub aggregation blueprint, used to create sub aggregations for each bucket. - /// Will be filled during initialization of the collector. - pub sub_aggregation_blueprint: Option>, /// The histogram aggregation request. pub req: HistogramAggregation, /// True if this is a date_histogram aggregation. @@ -257,18 +252,24 @@ impl HistogramBounds { pub(crate) struct SegmentHistogramBucketEntry { pub key: f64, pub doc_count: u64, + pub bucket_id: BucketId, } impl SegmentHistogramBucketEntry { pub(crate) fn into_intermediate_bucket_entry( self, - sub_aggregation: Option>, + sub_aggregation: &mut Option, agg_data: &AggregationsSegmentCtx, ) -> crate::Result { let mut sub_aggregation_res = IntermediateAggregationResults::default(); if let Some(sub_aggregation) = sub_aggregation { sub_aggregation - .add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?; + .get_sub_agg_collector() + .add_intermediate_aggregation_result( + agg_data, + &mut sub_aggregation_res, + self.bucket_id, + )?; } Ok(IntermediateHistogramBucketEntry { key: self.key, @@ -278,27 +279,38 @@ impl SegmentHistogramBucketEntry { } } +#[derive(Clone, Debug, Default)] +struct HistogramBuckets { + pub buckets: FxHashMap, +} + /// The collector puts values from the fast field into the correct buckets and does a conversion to /// the correct datatype. 
-#[derive(Clone, Debug)] +#[derive(Debug)] pub struct SegmentHistogramCollector { /// The buckets containing the aggregation data. - buckets: FxHashMap, - sub_aggregations: FxHashMap>, + /// One Histogram bucket per parent bucket id. + parent_buckets: Vec, + sub_agg: Option, accessor_idx: usize, + bucket_id_provider: BucketIdProvider, } impl SegmentAggregationCollector for SegmentHistogramCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { let name = agg_data .get_histogram_req_data(self.accessor_idx) .name .clone(); - let bucket = self.into_intermediate_bucket_result(agg_data)?; + // TODO: avoid prepare_max_bucket here and handle empty buckets. + self.prepare_max_bucket(parent_bucket_id, agg_data)?; + let histogram = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]); + let bucket = self.add_intermediate_bucket_result(agg_data, histogram)?; results.push(name, IntermediateAggregationResult::Bucket(bucket))?; Ok(()) @@ -307,44 +319,40 @@ impl SegmentAggregationCollector for SegmentHistogramCollector { #[inline] fn collect( &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.collect_block(&[doc], agg_data) - } - - #[inline] - fn collect_block( - &mut self, + parent_bucket_id: BucketId, docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let mut req = agg_data.take_histogram_req_data(self.accessor_idx); + let req = agg_data.take_histogram_req_data(self.accessor_idx); let mem_pre = self.get_memory_consumption(); + let buckets = &mut self.parent_buckets[parent_bucket_id as usize].buckets; let bounds = req.bounds; let interval = req.req.interval; let offset = req.offset; let get_bucket_pos = |val| get_bucket_pos_f64(val, interval, offset) as i64; - req.column_block_accessor.fetch_block(docs, 
&req.accessor); - for (doc, val) in req + agg_data + .column_block_accessor + .fetch_block(docs, &req.accessor); + for (doc, val) in agg_data .column_block_accessor .iter_docid_vals(docs, &req.accessor) { - let val = f64_from_fastfield_u64(val, &req.field_type); + let val = f64_from_fastfield_u64(val, req.field_type); let bucket_pos = get_bucket_pos(val); if bounds.contains(val) { - let bucket = self.buckets.entry(bucket_pos).or_insert_with(|| { + let bucket = buckets.entry(bucket_pos).or_insert_with(|| { let key = get_bucket_key_from_pos(bucket_pos as f64, interval, offset); - SegmentHistogramBucketEntry { key, doc_count: 0 } + SegmentHistogramBucketEntry { + key, + doc_count: 0, + bucket_id: self.bucket_id_provider.next_bucket_id(), + } }); bucket.doc_count += 1; - if let Some(sub_aggregation_blueprint) = req.sub_aggregation_blueprint.as_ref() { - self.sub_aggregations - .entry(bucket_pos) - .or_insert_with(|| sub_aggregation_blueprint.clone()) - .collect(doc, agg_data)?; + if let Some(sub_agg) = &mut self.sub_agg { + sub_agg.push(bucket.bucket_id, doc); } } } @@ -358,14 +366,30 @@ impl SegmentAggregationCollector for SegmentHistogramCollector { .add_memory_consumed(mem_delta as u64)?; } + if let Some(sub_agg) = &mut self.sub_agg { + sub_agg.check_flush_local(agg_data)?; + } + Ok(()) } fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - for sub_aggregation in self.sub_aggregations.values_mut() { + if let Some(sub_aggregation) = &mut self.sub_agg { sub_aggregation.flush(agg_data)?; } + Ok(()) + } + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + while self.parent_buckets.len() <= max_bucket as usize { + self.parent_buckets.push(HistogramBuckets { + buckets: FxHashMap::default(), + }); + } Ok(()) } } @@ -373,22 +397,19 @@ impl SegmentAggregationCollector for SegmentHistogramCollector { impl SegmentHistogramCollector { fn get_memory_consumption(&self) -> 
usize { let self_mem = std::mem::size_of::(); - let sub_aggs_mem = self.sub_aggregations.memory_consumption(); - let buckets_mem = self.buckets.memory_consumption(); - self_mem + sub_aggs_mem + buckets_mem + let buckets_mem = self.parent_buckets.len() * std::mem::size_of::(); + self_mem + buckets_mem } /// Converts the collector result into a intermediate bucket result. - pub fn into_intermediate_bucket_result( - self, + fn add_intermediate_bucket_result( + &mut self, agg_data: &AggregationsSegmentCtx, + histogram: HistogramBuckets, ) -> crate::Result { - let mut buckets = Vec::with_capacity(self.buckets.len()); + let mut buckets = Vec::with_capacity(histogram.buckets.len()); - for (bucket_pos, bucket) in self.buckets { - let bucket_res = bucket.into_intermediate_bucket_entry( - self.sub_aggregations.get(&bucket_pos).cloned(), - agg_data, - ); + for bucket in histogram.buckets.into_values() { + let bucket_res = bucket.into_intermediate_bucket_entry(&mut self.sub_agg, agg_data); buckets.push(bucket_res?); } @@ -408,7 +429,7 @@ impl SegmentHistogramCollector { agg_data: &mut AggregationsSegmentCtx, node: &AggRefNode, ) -> crate::Result { - let blueprint = if !node.children.is_empty() { + let sub_agg = if !node.children.is_empty() { Some(build_segment_agg_collectors(agg_data, &node.children)?) 
} else { None @@ -423,13 +444,13 @@ impl SegmentHistogramCollector { max: f64::MAX, }); req_data.offset = req_data.req.offset.unwrap_or(0.0); - - req_data.sub_aggregation_blueprint = blueprint; + let sub_agg = sub_agg.map(CachedSubAggs::new); Ok(Self { - buckets: Default::default(), - sub_aggregations: Default::default(), + parent_buckets: Default::default(), + sub_agg, accessor_idx: node.idx_in_req_data, + bucket_id_provider: BucketIdProvider::default(), }) } } diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs index c26872e9b..46e0065ce 100644 --- a/src/aggregation/bucket/range.rs +++ b/src/aggregation/bucket/range.rs @@ -1,18 +1,22 @@ use std::fmt::Debug; use std::ops::Range; -use columnar::{Column, ColumnBlockAccessor, ColumnType}; +use columnar::{Column, ColumnType}; use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; use crate::aggregation::agg_data::{ build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx, }; +use crate::aggregation::agg_limits::AggregationLimitsGuard; +use crate::aggregation::cached_sub_aggs::{ + CachedSubAggs, HighCardSubAggCache, LowCardCachedSubAggs, LowCardSubAggCache, SubAggCache, +}; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, IntermediateRangeBucketEntry, IntermediateRangeBucketResult, }; -use crate::aggregation::segment_agg_result::SegmentAggregationCollector; +use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector}; use crate::aggregation::*; use crate::TantivyError; @@ -23,12 +27,12 @@ pub struct RangeAggReqData { pub accessor: Column, /// The type of the fast field. pub field_type: ColumnType, - /// The column block accessor to access the fast field values. - pub column_block_accessor: ColumnBlockAccessor, /// The range aggregation request. pub req: RangeAggregation, /// The name of the aggregation. 
pub name: String, + /// Whether this is a top-level aggregation. + pub is_top_level: bool, } impl RangeAggReqData { @@ -151,19 +155,47 @@ pub(crate) struct SegmentRangeAndBucketEntry { /// The collector puts values from the fast field into the correct buckets and does a conversion to /// the correct datatype. -#[derive(Clone, Debug)] -pub struct SegmentRangeCollector { +pub struct SegmentRangeCollector { /// The buckets containing the aggregation data. - buckets: Vec, + /// One for each ParentBucketId + parent_buckets: Vec>, column_type: ColumnType, pub(crate) accessor_idx: usize, + sub_agg: Option>, + /// Here things get a bit weird. We need to assign unique bucket ids across all + /// parent buckets. So we keep track of the next available bucket id here. + /// This allows a kind of flattening of the bucket ids across all parent buckets. + /// E.g. in nested aggregations: + /// Term Agg -> Range aggregation -> Stats aggregation + /// E.g. the Term Agg creates 3 buckets ["INFO", "ERROR", "WARN"], each of these has a Range + /// aggregation with 4 buckets. The Range aggregation will create buckets with ids: + /// - INFO: 0,1,2,3 + /// - ERROR: 4,5,6,7 + /// - WARN: 8,9,10,11 + /// + /// This allows the Stats aggregation to have unique bucket ids to refer to. 
+ bucket_id_provider: BucketIdProvider, + limits: AggregationLimitsGuard, } +impl Debug for SegmentRangeCollector { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SegmentRangeCollector") + .field("parent_buckets_len", &self.parent_buckets.len()) + .field("column_type", &self.column_type) + .field("accessor_idx", &self.accessor_idx) + .field("has_sub_agg", &self.sub_agg.is_some()) + .finish() + } +} + +/// TODO: Bad naming, there's also SegmentRangeAndBucketEntry #[derive(Clone)] pub(crate) struct SegmentRangeBucketEntry { pub key: Key, pub doc_count: u64, - pub sub_aggregation: Option>, + // pub sub_aggregation: Option>, + pub bucket_id: BucketId, /// The from range of the bucket. Equals `f64::MIN` when `None`. pub from: Option, /// The to range of the bucket. Equals `f64::MAX` when `None`. Open interval, `to` is not @@ -184,48 +216,50 @@ impl Debug for SegmentRangeBucketEntry { impl SegmentRangeBucketEntry { pub(crate) fn into_intermediate_bucket_entry( self, - agg_data: &AggregationsSegmentCtx, ) -> crate::Result { - let mut sub_aggregation_res = IntermediateAggregationResults::default(); - if let Some(sub_aggregation) = self.sub_aggregation { - sub_aggregation - .add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)? 
- } else { - Default::default() - }; + let sub_aggregation = IntermediateAggregationResults::default(); Ok(IntermediateRangeBucketEntry { key: self.key.into(), doc_count: self.doc_count, - sub_aggregation: sub_aggregation_res, + sub_aggregation_res: sub_aggregation, from: self.from, to: self.to, }) } } -impl SegmentAggregationCollector for SegmentRangeCollector { +impl SegmentAggregationCollector for SegmentRangeCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { + self.prepare_max_bucket(parent_bucket_id, agg_data)?; let field_type = self.column_type; let name = agg_data .get_range_req_data(self.accessor_idx) .name .to_string(); - let buckets: FxHashMap = self - .buckets + let buckets = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]); + + let buckets: FxHashMap = buckets .into_iter() - .map(move |range_bucket| { - Ok(( - range_to_string(&range_bucket.range, &field_type)?, - range_bucket - .bucket - .into_intermediate_bucket_entry(agg_data)?, - )) + .map(|range_bucket| { + let bucket_id = range_bucket.bucket.bucket_id; + let mut agg = range_bucket.bucket.into_intermediate_bucket_entry()?; + if let Some(sub_aggregation) = &mut self.sub_agg { + sub_aggregation + .get_sub_agg_collector() + .add_intermediate_aggregation_result( + agg_data, + &mut agg.sub_aggregation_res, + bucket_id, + )?; + } + Ok((range_to_string(&range_bucket.range, &field_type)?, agg)) }) .collect::>()?; @@ -242,73 +276,114 @@ impl SegmentAggregationCollector for SegmentRangeCollector { #[inline] fn collect( &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.collect_block(&[doc], agg_data) - } - - #[inline] - fn collect_block( - &mut self, + parent_bucket_id: BucketId, docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - // Take 
request data to avoid borrow conflicts during sub-aggregation - let mut req = agg_data.take_range_req_data(self.accessor_idx); + let req = agg_data.take_range_req_data(self.accessor_idx); - req.column_block_accessor.fetch_block(docs, &req.accessor); + agg_data + .column_block_accessor + .fetch_block(docs, &req.accessor); - for (doc, val) in req + let buckets = &mut self.parent_buckets[parent_bucket_id as usize]; + + for (doc, val) in agg_data .column_block_accessor .iter_docid_vals(docs, &req.accessor) { - let bucket_pos = self.get_bucket_pos(val); - let bucket = &mut self.buckets[bucket_pos]; + let bucket_pos = get_bucket_pos(val, buckets); + let bucket = &mut buckets[bucket_pos]; bucket.bucket.doc_count += 1; - if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() { - sub_agg.collect(doc, agg_data)?; + if let Some(sub_agg) = self.sub_agg.as_mut() { + sub_agg.push(bucket.bucket.bucket_id, doc); } } agg_data.put_back_range_req_data(self.accessor_idx, req); + if let Some(sub_agg) = self.sub_agg.as_mut() { + sub_agg.check_flush_local(agg_data)?; + } Ok(()) } fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - for bucket in self.buckets.iter_mut() { - if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() { - sub_agg.flush(agg_data)?; - } + if let Some(sub_agg) = self.sub_agg.as_mut() { + sub_agg.flush(agg_data)?; } Ok(()) } + + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + while self.parent_buckets.len() <= max_bucket as usize { + let new_buckets = self.create_new_buckets(agg_data)?; + self.parent_buckets.push(new_buckets); + } + + Ok(()) + } +} +/// Build a concrete `SegmentRangeCollector` with either a Vec- or HashMap-backed +/// bucket storage, depending on the column type and aggregation level. 
+pub(crate) fn build_segment_range_collector( + agg_data: &mut AggregationsSegmentCtx, + node: &AggRefNode, +) -> crate::Result> { + let accessor_idx = node.idx_in_req_data; + let req_data = agg_data.get_range_req_data(node.idx_in_req_data); + let field_type = req_data.field_type; + + // TODO: A better metric instead of is_top_level would be the number of buckets expected. + // E.g. If range agg is not top level, but the parent is a bucket agg with less than 10 buckets, + // we are still in low cardinality territory. + let is_low_card = req_data.is_top_level && req_data.req.ranges.len() <= 64; + + let sub_agg = if !node.children.is_empty() { + Some(build_segment_agg_collectors(agg_data, &node.children)?) + } else { + None + }; + + if is_low_card { + Ok(Box::new(SegmentRangeCollector:: { + sub_agg: sub_agg.map(LowCardCachedSubAggs::new), + column_type: field_type, + accessor_idx, + parent_buckets: Vec::new(), + bucket_id_provider: BucketIdProvider::default(), + limits: agg_data.context.limits.clone(), + })) + } else { + Ok(Box::new(SegmentRangeCollector:: { + sub_agg: sub_agg.map(CachedSubAggs::new), + column_type: field_type, + accessor_idx, + parent_buckets: Vec::new(), + bucket_id_provider: BucketIdProvider::default(), + limits: agg_data.context.limits.clone(), + })) + } } -impl SegmentRangeCollector { - pub(crate) fn from_req_and_validate( - req_data: &mut AggregationsSegmentCtx, - node: &AggRefNode, - ) -> crate::Result { - let accessor_idx = node.idx_in_req_data; - let (field_type, ranges) = { - let req_view = req_data.get_range_req_data(node.idx_in_req_data); - (req_view.field_type, req_view.req.ranges.clone()) - }; - +impl SegmentRangeCollector { + pub(crate) fn create_new_buckets( + &mut self, + agg_data: &AggregationsSegmentCtx, + ) -> crate::Result> { + let field_type = self.column_type; + let req_data = agg_data.get_range_req_data(self.accessor_idx); // The range input on the request is f64.
// We need to convert to u64 ranges, because we read the values as u64. // The mapping from the conversion is monotonic so ordering is preserved. - let sub_agg_prototype = if !node.children.is_empty() { - Some(build_segment_agg_collectors(req_data, &node.children)?) - } else { - None - }; - - let buckets: Vec<_> = extend_validate_ranges(&ranges, &field_type)? + let buckets: Vec<_> = extend_validate_ranges(&req_data.req.ranges, &field_type)? .iter() .map(|range| { + let bucket_id = self.bucket_id_provider.next_bucket_id(); let key = range .key .clone() @@ -317,20 +392,20 @@ impl SegmentRangeCollector { let to = if range.range.end == u64::MAX { None } else { - Some(f64_from_fastfield_u64(range.range.end, &field_type)) + Some(f64_from_fastfield_u64(range.range.end, field_type)) }; let from = if range.range.start == u64::MIN { None } else { - Some(f64_from_fastfield_u64(range.range.start, &field_type)) + Some(f64_from_fastfield_u64(range.range.start, field_type)) }; - let sub_aggregation = sub_agg_prototype.clone(); + // let sub_aggregation = sub_agg_prototype.clone(); Ok(SegmentRangeAndBucketEntry { range: range.range.clone(), bucket: SegmentRangeBucketEntry { doc_count: 0, - sub_aggregation, + bucket_id, key, from, to, @@ -339,27 +414,20 @@ impl SegmentRangeCollector { }) .collect::>()?; - req_data.context.limits.add_memory_consumed( + self.limits.add_memory_consumed( buckets.len() as u64 * std::mem::size_of::() as u64, )?; - - Ok(SegmentRangeCollector { - buckets, - column_type: field_type, - accessor_idx, - }) - } - - #[inline] - fn get_bucket_pos(&self, val: u64) -> usize { - let pos = self - .buckets - .binary_search_by_key(&val, |probe| probe.range.start) - .unwrap_or_else(|pos| pos - 1); - debug_assert!(self.buckets[pos].range.contains(&val)); - pos + Ok(buckets) } } +#[inline] +fn get_bucket_pos(val: u64, buckets: &[SegmentRangeAndBucketEntry]) -> usize { + let pos = buckets + .binary_search_by_key(&val, |probe| probe.range.start) + .unwrap_or_else(|pos| pos - 
1); + debug_assert!(buckets[pos].range.contains(&val)); + pos +} /// Converts the user provided f64 range value to fast field value space. /// @@ -456,7 +524,7 @@ pub(crate) fn range_to_string( let val = i64::from_u64(val); format_date(val) } else { - Ok(f64_from_fastfield_u64(val, field_type).to_string()) + Ok(f64_from_fastfield_u64(val, *field_type).to_string()) } }; @@ -486,7 +554,7 @@ mod tests { pub fn get_collector_from_ranges( ranges: Vec, field_type: ColumnType, - ) -> SegmentRangeCollector { + ) -> SegmentRangeCollector { let req = RangeAggregation { field: "dummy".to_string(), ranges, @@ -506,30 +574,33 @@ mod tests { let to = if range.range.end == u64::MAX { None } else { - Some(f64_from_fastfield_u64(range.range.end, &field_type)) + Some(f64_from_fastfield_u64(range.range.end, field_type)) }; let from = if range.range.start == u64::MIN { None } else { - Some(f64_from_fastfield_u64(range.range.start, &field_type)) + Some(f64_from_fastfield_u64(range.range.start, field_type)) }; SegmentRangeAndBucketEntry { range: range.range.clone(), bucket: SegmentRangeBucketEntry { doc_count: 0, - sub_aggregation: None, key, from, to, + bucket_id: 0, }, } }) .collect(); SegmentRangeCollector { - buckets, + parent_buckets: vec![buckets], column_type: field_type, accessor_idx: 0, + sub_agg: None, + bucket_id_provider: Default::default(), + limits: AggregationLimitsGuard::default(), } } @@ -776,7 +847,7 @@ mod tests { let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()]; let collector = get_collector_from_ranges(buckets, ColumnType::F64); - let buckets = collector.buckets; + let buckets = collector.parent_buckets[0].clone(); assert_eq!(buckets[0].range.start, u64::MIN); assert_eq!(buckets[0].range.end, 10f64.to_u64()); assert_eq!(buckets[1].range.start, 10f64.to_u64()); @@ -799,7 +870,7 @@ mod tests { ]; let collector = get_collector_from_ranges(buckets, ColumnType::F64); - let buckets = collector.buckets; + let buckets = collector.parent_buckets[0].clone(); 
assert_eq!(buckets[0].range.start, u64::MIN); assert_eq!(buckets[0].range.end, 10f64.to_u64()); assert_eq!(buckets[1].range.start, 10f64.to_u64()); @@ -814,7 +885,7 @@ mod tests { let buckets = vec![(-10f64..-1f64).into()]; let collector = get_collector_from_ranges(buckets, ColumnType::F64); - let buckets = collector.buckets; + let buckets = collector.parent_buckets[0].clone(); assert_eq!(&buckets[0].bucket.key.to_string(), "*--10"); assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "-1-*"); } @@ -823,7 +894,7 @@ mod tests { let buckets = vec![(0f64..10f64).into()]; let collector = get_collector_from_ranges(buckets, ColumnType::F64); - let buckets = collector.buckets; + let buckets = collector.parent_buckets[0].clone(); assert_eq!(&buckets[0].bucket.key.to_string(), "*-0"); assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "10-*"); } @@ -832,7 +903,7 @@ mod tests { fn range_binary_search_test_u64() { let check_ranges = |ranges: Vec| { let collector = get_collector_from_ranges(ranges, ColumnType::U64); - let search = |val: u64| collector.get_bucket_pos(val); + let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]); assert_eq!(search(u64::MIN), 0); assert_eq!(search(9), 0); @@ -878,7 +949,7 @@ mod tests { let ranges = vec![(10.0..100.0).into()]; let collector = get_collector_from_ranges(ranges, ColumnType::F64); - let search = |val: u64| collector.get_bucket_pos(val); + let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]); assert_eq!(search(u64::MIN), 0); assert_eq!(search(9f64.to_u64()), 0); @@ -890,63 +961,3 @@ mod tests { // the max value } } - -#[cfg(all(test, feature = "unstable"))] -mod bench { - - use itertools::Itertools; - use rand::seq::SliceRandom; - use rand::thread_rng; - - use super::*; - use crate::aggregation::bucket::range::tests::get_collector_from_ranges; - - const TOTAL_DOCS: u64 = 1_000_000u64; - const NUM_DOCS: u64 = 50_000u64; - - fn get_collector_with_buckets(num_buckets: u64, 
num_docs: u64) -> SegmentRangeCollector { - let bucket_size = num_docs / num_buckets; - let mut buckets: Vec = vec![]; - for i in 0..num_buckets { - let bucket_start = (i * bucket_size) as f64; - buckets.push((bucket_start..bucket_start + bucket_size as f64).into()) - } - - get_collector_from_ranges(buckets, ColumnType::U64) - } - - fn get_rand_docs(total_docs: u64, num_docs_returned: u64) -> Vec { - let mut rng = thread_rng(); - - let all_docs = (0..total_docs - 1).collect_vec(); - let mut vals = all_docs - .as_slice() - .choose_multiple(&mut rng, num_docs_returned as usize) - .cloned() - .collect_vec(); - vals.sort(); - vals - } - - fn bench_range_binary_search(b: &mut test::Bencher, num_buckets: u64) { - let collector = get_collector_with_buckets(num_buckets, TOTAL_DOCS); - let vals = get_rand_docs(TOTAL_DOCS, NUM_DOCS); - b.iter(|| { - let mut bucket_pos = 0; - for val in &vals { - bucket_pos = collector.get_bucket_pos(*val); - } - bucket_pos - }) - } - - #[bench] - fn bench_range_100_buckets(b: &mut test::Bencher) { - bench_range_binary_search(b, 100) - } - - #[bench] - fn bench_range_10_buckets(b: &mut test::Bencher) { - bench_range_binary_search(b, 10) - } -} diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index d87cd0078..ed2793bd1 100644 --- a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg.rs @@ -4,10 +4,10 @@ use std::net::Ipv6Addr; use columnar::column_values::CompactSpaceU64Accessor; use columnar::{ - Column, ColumnBlockAccessor, ColumnType, Dictionary, MonotonicallyMappableToU128, - MonotonicallyMappableToU64, NumericalValue, StrColumn, + Column, ColumnType, Dictionary, MonotonicallyMappableToU128, MonotonicallyMappableToU64, + NumericalValue, StrColumn, }; -use common::BitSet; +use common::{BitSet, TinySet}; use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; @@ -17,18 +17,21 @@ use crate::aggregation::agg_data::{ }; use crate::aggregation::agg_limits::MemoryConsumption; use 
crate::aggregation::agg_req::Aggregations; -use crate::aggregation::buf_collector::BufAggregationCollector; +use crate::aggregation::cached_sub_aggs::{ + CachedSubAggs, HighCardSubAggCache, LowCardCachedSubAggs, LowCardSubAggCache, SubAggCache, +}; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult, }; -use crate::aggregation::segment_agg_result::SegmentAggregationCollector; -use crate::aggregation::{format_date, Key}; +use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector}; +use crate::aggregation::{format_date, BucketId, Key}; use crate::error::DataCorruption; use crate::TantivyError; /// Contains all information required by the SegmentTermCollector to perform the /// terms aggregation on a segment. +#[derive(Debug, Clone)] pub struct TermsAggReqData { /// The column accessor to access the fast field values. pub accessor: Column, @@ -38,10 +41,6 @@ pub struct TermsAggReqData { pub str_dict_column: Option, /// The missing value as u64 value. pub missing_value_for_accessor: Option, - /// The column block accessor to access the fast field values. - pub column_block_accessor: ColumnBlockAccessor, - /// Note: sub_aggregation_blueprint is filled later when building collectors - pub sub_aggregation_blueprint: Option>, /// Used to build the correct nested result when we have an empty result. pub sug_aggregations: Aggregations, /// The name of the aggregation. @@ -257,9 +256,9 @@ pub struct TermsAggregation { /// Internally, `missing` requires some specialized handling in some scenarios. /// /// Simple Case: - /// In the simplest case, we can just put the missing value in the termmap use that. In case of - /// text we put a special u64::MAX and replace it at the end with the actual missing value, - /// when loading the text. 
+ /// In the simplest case, we can just put the missing value in the termmap and use that. In + /// case of text we put a special u64::MAX and replace it at the end with the actual + /// missing value, when loading the text. /// Special Case 1: /// If we have multiple columns on one field, we need to have a union on the indices on both /// columns, to find docids without a value. That requires a special missing aggregation. @@ -334,85 +333,9 @@ impl TermsAggregationInternal { } } -impl<'a> From<&'a dyn SegmentAggregationCollector> for BufAggregationCollector { - #[inline(always)] - fn from(sub_agg_blueprint_opt: &'a dyn SegmentAggregationCollector) -> Self { - let sub_agg = sub_agg_blueprint_opt.clone_box(); - BufAggregationCollector::new(sub_agg) - } -} - -#[derive(Debug, Clone)] -struct BoxedAggregation(Box); - -impl<'a> From<&'a dyn SegmentAggregationCollector> for BoxedAggregation { - #[inline(always)] - fn from(sub_agg_blueprint: &'a dyn SegmentAggregationCollector) -> Self { - BoxedAggregation(sub_agg_blueprint.clone_box()) - } -} - -impl SegmentAggregationCollector for BoxedAggregation { - #[inline(always)] - fn add_intermediate_aggregation_result( - self: Box, - agg_data: &AggregationsSegmentCtx, - results: &mut IntermediateAggregationResults, - ) -> crate::Result<()> { - self.0 - .add_intermediate_aggregation_result(agg_data, results) - } - - #[inline(always)] - fn collect( - &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.0.collect(doc, agg_data) - } - - #[inline(always)] - fn collect_block( - &mut self, - docs: &[crate::DocId], - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.0.collect_block(docs, agg_data) - } -} - -#[derive(Debug, Clone, Copy)] -struct NoSubAgg; - -impl SegmentAggregationCollector for NoSubAgg { - #[inline(always)] - fn add_intermediate_aggregation_result( - self: Box, - _agg_data: &AggregationsSegmentCtx, - _results: &mut 
IntermediateAggregationResults, - ) -> crate::Result<()> { - Ok(()) - } - - #[inline(always)] - fn collect( - &mut self, - _doc: crate::DocId, - _agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - Ok(()) - } - - #[inline(always)] - fn collect_block( - &mut self, - _docs: &[crate::DocId], - _agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - Ok(()) - } -} +/// The threshold for maximum number of terms to use a Vec-backed bucket storage. +/// TODO: Benchmark to validate the threshold +pub const MAX_NUM_TERMS_FOR_VEC: u64 = 100; /// Build a concrete `SegmentTermCollector` with either a Vec- or HashMap-backed /// bucket storage, depending on the column type and aggregation level. @@ -420,11 +343,8 @@ pub(crate) fn build_segment_term_collector( req_data: &mut AggregationsSegmentCtx, node: &AggRefNode, ) -> crate::Result> { - let accessor_idx = node.idx_in_req_data; - let column_type = { - let terms_req_data = req_data.get_term_req_data(accessor_idx); - terms_req_data.column_type - }; + let terms_req_data = req_data.get_term_req_data(node.idx_in_req_data).clone(); + let column_type = terms_req_data.column_type; if column_type == ColumnType::Bytes { return Err(TantivyError::InvalidArgument(format!( @@ -434,7 +354,6 @@ pub(crate) fn build_segment_term_collector( // Validate sub aggregation exists when ordering by sub-aggregation. { - let terms_req_data = req_data.get_term_req_data(accessor_idx); if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target { let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name); @@ -450,127 +369,115 @@ pub(crate) fn build_segment_term_collector( // Build sub-aggregation blueprint if there are children.
let has_sub_aggregations = !node.children.is_empty(); - let blueprint = if has_sub_aggregations { - let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?; - Some(sub_aggregation) - } else { - None - }; - { - let terms_req_data_mut = req_data.get_term_req_data_mut(accessor_idx); - terms_req_data_mut.sub_aggregation_blueprint = blueprint; - } - - // Decide whether to use a Vec-backed or HashMap-backed bucket storage. - let terms_req_data = req_data.get_term_req_data(accessor_idx); // TODO: A better metric instead of is_top_level would be the number of buckets expected. // E.g. If term agg is not top level, but the parent is a bucket agg with less than 10 buckets, // we can still use Vec. - let can_use_vec = terms_req_data.is_top_level; - - // TODO: Benchmark to validate the threshold - const MAX_NUM_TERMS_FOR_VEC: usize = 100; + let is_top_level = terms_req_data.is_top_level; // Let's see if we can use a vec to aggregate our data // instead of a hashmap. let col_max_value = terms_req_data.accessor.max_value(); - let max_term: usize = - col_max_value.max(terms_req_data.missing_value_for_accessor.unwrap_or(0u64)) as usize; + let max_term_id: u64 = + col_max_value.max(terms_req_data.missing_value_for_accessor.unwrap_or(0u64)); - // - use a Vec instead of a hashmap for our aggregation. 
- // - buffer aggregation of our child aggregations (in any) - #[allow(clippy::collapsible_else_if)] - if can_use_vec && max_term < MAX_NUM_TERMS_FOR_VEC { - if has_sub_aggregations { - let sub_agg_blueprint = &req_data - .get_term_req_data_mut(accessor_idx) - .sub_aggregation_blueprint - .as_ref() - .ok_or_else(|| { - // Handle the error case here - // For example, return an error message or a default value - TantivyError::InternalError("Sub-aggregation blueprint not found".to_string()) - })?; - let term_buckets = VecTermBuckets::new(max_term + 1, || { - let collector_clone = sub_agg_blueprint.clone_box(); - BufAggregationCollector::new(collector_clone) - }); - let collector = SegmentTermCollector { - term_buckets, - accessor_idx, - }; - Ok(Box::new(collector)) - } else { - let term_buckets = VecTermBuckets::new(max_term + 1, || NoSubAgg); - let collector = SegmentTermCollector { - term_buckets, - accessor_idx, - }; - Ok(Box::new(collector)) - } + let sub_agg_collector = if has_sub_aggregations { + Some(build_segment_agg_collectors(req_data, &node.children)?) } else { - if has_sub_aggregations { - let term_buckets: HashMapTermBuckets = HashMapTermBuckets::default(); - let collector: SegmentTermCollector> = - SegmentTermCollector { - term_buckets, - accessor_idx, - }; - Ok(Box::new(collector)) - } else { - let term_buckets: HashMapTermBuckets = HashMapTermBuckets::default(); - let collector: SegmentTermCollector> = - SegmentTermCollector { - term_buckets, - accessor_idx, - }; - Ok(Box::new(collector)) - } + None + }; + + let mut bucket_id_provider = BucketIdProvider::default(); + // Decide which bucket storage is best suited for this aggregation. 
+ if is_top_level && max_term_id < MAX_NUM_TERMS_FOR_VEC && !has_sub_aggregations { + let term_buckets = VecTermBucketsNoAgg::new(max_term_id + 1, &mut bucket_id_provider); + let collector: SegmentTermCollector<_, HighCardSubAggCache> = SegmentTermCollector { + parent_buckets: vec![term_buckets], + sub_agg: None, + bucket_id_provider, + max_term_id, + terms_req_data, + }; + Ok(Box::new(collector)) + } else if is_top_level && max_term_id < MAX_NUM_TERMS_FOR_VEC { + let term_buckets = VecTermBuckets::new(max_term_id + 1, &mut bucket_id_provider); + let sub_agg = sub_agg_collector.map(LowCardCachedSubAggs::new); + let collector: SegmentTermCollector<_, LowCardSubAggCache> = SegmentTermCollector { + parent_buckets: vec![term_buckets], + sub_agg, + bucket_id_provider, + max_term_id, + terms_req_data, + }; + Ok(Box::new(collector)) + } else if max_term_id < 8_000_000 && is_top_level { + let term_buckets: PagedTermMap = + PagedTermMap::new(max_term_id + 1, &mut bucket_id_provider); + // Build sub-aggregation blueprint (flat pairs) + let sub_agg = sub_agg_collector.map(CachedSubAggs::new); + let collector: SegmentTermCollector = + SegmentTermCollector { + parent_buckets: vec![term_buckets], + sub_agg, + bucket_id_provider, + max_term_id, + terms_req_data, + }; + Ok(Box::new(collector)) + } else { + let term_buckets: HashMapTermBuckets = HashMapTermBuckets::default(); + // Build sub-aggregation blueprint (flat pairs) + let sub_agg = sub_agg_collector.map(CachedSubAggs::new); + let collector: SegmentTermCollector = + SegmentTermCollector { + parent_buckets: vec![term_buckets], + sub_agg, + bucket_id_provider, + max_term_id, + terms_req_data, + }; + Ok(Box::new(collector)) } } -#[derive(Debug, Clone)] -struct Bucket { +#[derive(Debug, Clone, Copy, Default)] +struct Bucket { pub count: u32, - pub sub_agg: SubAgg, + pub bucket_id: BucketId, } -impl Bucket { +impl Bucket { #[inline(always)] - fn new(sub_agg: SubAgg) -> Self { - Self { count: 0, sub_agg } + fn new(bucket_id: 
BucketId) -> Self { + Self { + count: 0, + bucket_id, + } } } /// Abstraction over the storage used for term buckets (counts only). trait TermAggregationMap: Clone + Debug + 'static { - type SubAggregation: SegmentAggregationCollector + Debug + Clone + 'static; + /// Create a new instance with a strict upper bound on term ids. + fn new(max_term_id: u64, bucket_id_provider: &mut BucketIdProvider) -> Self; /// Estimate the memory consumption of this struct in bytes. fn get_memory_consumption(&self) -> usize; - /// Returns the bucket associated to a given term_id. - fn term_entry( - &mut self, - term_id: u64, - blue_print: &dyn SegmentAggregationCollector, - ) -> &mut Bucket; - - /// If the tree of aggregations contains buffered aggregations, flush them. - fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()>; + /// Increments the count and returns the bucket_id associated to a given term_id. + fn term_entry(&mut self, term_id: u64, bucket_id_provider: &mut BucketIdProvider) -> BucketId; /// Returns the term aggregation as a vector of (term_id, bucket) pairs, /// in any order. 
- fn into_vec(self) -> Vec<(u64, Bucket)>; + fn into_vec(self) -> Vec<(u64, Bucket)>; } #[derive(Clone, Debug)] -struct HashMapTermBuckets { - bucket_map: FxHashMap>, +struct HashMapTermBuckets { + bucket_map: FxHashMap, } -impl Default for HashMapTermBuckets { +impl Default for HashMapTermBuckets { #[inline(always)] fn default() -> Self { Self { @@ -579,83 +486,188 @@ impl Default for HashMapTermBuckets { } } -impl< - SubAgg: Debug - + Clone - + SegmentAggregationCollector - + for<'a> From<&'a dyn SegmentAggregationCollector> - + 'static, - > TermAggregationMap for HashMapTermBuckets -{ - type SubAggregation = SubAgg; +const PAGE_SHIFT: usize = 10; +const PAGE_SIZE: usize = 1 << PAGE_SHIFT; // 1024 +const PAGE_MASK: usize = PAGE_SIZE - 1; +const BITMASK_LEN: usize = PAGE_SIZE / 64; +#[derive(Clone, Debug)] +struct Page { + /// Bitmask indicating which offsets are present. + /// It is chunked into TinySet words. + presence: [TinySet; BITMASK_LEN], + data: [Bucket; PAGE_SIZE], +} + +impl Page { + fn new() -> Self { + Self { + presence: [TinySet::empty(); BITMASK_LEN], + data: [Bucket::default(); PAGE_SIZE], + } + } + + #[inline] + fn is_set(&self, offset: usize) -> bool { + let bucket_idx = offset / 64; + let bit_idx = offset % 64; + self.presence[bucket_idx].contains(bit_idx as u32) + } + + #[inline] + fn set_present(&mut self, offset: usize) { + let bucket_idx = offset / 64; + let bit_idx = offset % 64; + self.presence[bucket_idx].insert_mut(bit_idx as u32); + } + + // Flattened iteration logic + fn collect_items(&self, base_term_id: u64, result: &mut Vec<(u64, Bucket)>) { + for (bucket_pos, &tiny_set) in self.presence.iter().enumerate() { + let base_offset = bucket_pos * 64; + + for bit in tiny_set.into_iter() { + let offset = base_offset + bit as usize; + result.push((base_term_id + offset as u64, self.data[offset])); + } + } + } +} + +/// A paged term map implementation for moderate sized term id sets. 
+/// Uses a fixed size vector of pages, each page containing a fixed size array of buckets. +/// +/// Each page covers a range of term ids. Pages are allocated on demand. +/// This implementation is more memory efficient than a full Vec for high cardinality term id sets, +/// +/// It has a fixed cost of `num_pages * 8 bytes` for the page directory. +/// For 1 million terms, this is 8 * 1024 = 8KB. +/// +/// Note that for nested aggregations we create one TermAggregationMap per parent bucket. +/// For example, with 100 parent buckets and 1 million terms, this is 800KB overhead for the page +/// directories only. Therefore, this implementation is only enabled for top-level aggregations +/// TODO: pass expected number of buckets from parent instead of strict is_top_level flag. +#[derive(Clone, Debug, Default)] +struct PagedTermMap { + // Fixed size vector based on max_term_id + pages: Vec>>, + mem_usage: usize, +} + +impl PagedTermMap {} + +impl TermAggregationMap for PagedTermMap { + #[inline] + fn get_memory_consumption(&self) -> usize { + self.mem_usage + std::mem::size_of::() + } + + #[inline] + fn term_entry(&mut self, term_id: u64, bucket_id_provider: &mut BucketIdProvider) -> BucketId { + let term_id = term_id as usize; + let page_idx = term_id >> PAGE_SHIFT; + let offset = term_id & PAGE_MASK; + + // This panics if term_id > max_term_id + let page = match &mut self.pages[page_idx] { + Some(p) => p, + None => { + let new_page = Box::new(Page::new()); + self.mem_usage += std::mem::size_of::(); + self.pages[page_idx] = Some(new_page); + self.pages[page_idx].as_mut().unwrap() + } + }; + + if page.is_set(offset) { + let bucket = &mut page.data[offset]; + bucket.count += 1; + bucket.bucket_id + } else { + let new_id = bucket_id_provider.next_bucket_id(); + page.data[offset] = Bucket { + count: 1, + bucket_id: new_id, + }; + page.set_present(offset); + new_id + } + } + + fn into_vec(self) -> Vec<(u64, Bucket)> { + // estimate 16 entries per non-empty page + let 
estimated_count = self.pages.iter().filter(|p| p.is_some()).count() * 16; + let mut result = Vec::with_capacity(estimated_count); + + for (i, page_opt) in self.pages.into_iter().enumerate() { + if let Some(page) = page_opt { + let base_term_id = (i << PAGE_SHIFT) as u64; + page.collect_items(base_term_id, &mut result); + } + } + result + } + + /// Initialize with a strict upper bound. + /// Panics if you try to insert a term_id > max_term_id. + fn new(max_term_id: u64, _bucket_id_provider: &mut BucketIdProvider) -> Self { + let max_page_idx = (max_term_id as usize) >> PAGE_SHIFT; + let num_pages = max_page_idx + 1; + + // Pre-allocate the directory (pointers only, not the heavy pages) + // Memory cost: num_pages * 8 bytes + let pages = vec![None; num_pages]; + + let mem_usage = pages.capacity() * std::mem::size_of::>>(); + + Self { pages, mem_usage } + } +} + +impl TermAggregationMap for HashMapTermBuckets { #[inline] fn get_memory_consumption(&self) -> usize { self.bucket_map.memory_consumption() } #[inline(always)] - fn term_entry( - &mut self, - term_id: u64, - sub_agg_blueprint: &dyn SegmentAggregationCollector, - ) -> &mut Bucket { - self.bucket_map + fn term_entry(&mut self, term_id: u64, bucket_id_provider: &mut BucketIdProvider) -> BucketId { + let bucket = self + .bucket_map .entry(term_id) - .or_insert_with(|| Bucket::new(SubAgg::from(sub_agg_blueprint))) + .or_insert_with(|| Bucket::new(bucket_id_provider.next_bucket_id())); + bucket.count += 1; + bucket.bucket_id } - #[inline(always)] - fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - for bucket in self.bucket_map.values_mut() { - bucket.sub_agg.flush(agg_data)?; - } - Ok(()) - } - - fn into_vec(self) -> Vec<(u64, Bucket)> { + fn into_vec(self) -> Vec<(u64, Bucket)> { self.bucket_map.into_iter().collect() } + + #[inline] + fn new(_max_term_id: u64, _bucket_id_provider: &mut BucketIdProvider) -> Self { + Self::default() + } } /// An optimized term map implementation for 
a compact set of term ordinals. #[derive(Clone, Debug)] -struct VecTermBuckets { - buckets: Vec>, +struct VecTermBucketsNoAgg { + buckets: Vec, } -impl VecTermBuckets { - fn new(num_terms: usize, item_factory_fn: impl Fn() -> SubAgg) -> Self { - VecTermBuckets { - buckets: std::iter::repeat_with(item_factory_fn) - .map(Bucket::new) - .take(num_terms) - .collect(), - } - } -} - -impl TermAggregationMap - for VecTermBuckets -{ - type SubAggregation = SubAgg; - +impl TermAggregationMap for VecTermBucketsNoAgg { /// Estimate the memory consumption of this struct in bytes. fn get_memory_consumption(&self) -> usize { // We do not include `std::mem::size_of::()` // It is already measure by the parent aggregation. // - // The root aggregation mem size is not measure but we do not care. - self.buckets.capacity() * std::mem::size_of::>() + self.buckets.capacity() * std::mem::size_of::() } /// Add an occurrence of the given term id. #[inline(always)] - fn term_entry( - &mut self, - term_id: u64, - _sub_agg_blueprint: &dyn SegmentAggregationCollector, - ) -> &mut Bucket { + fn term_entry(&mut self, term_id: u64, _bucket_id_provider: &mut BucketIdProvider) -> BucketId { let term_id_usize = term_id as usize; debug_assert!( term_id_usize < self.buckets.len(), @@ -663,20 +675,69 @@ impl TermAggregat term_id, self.buckets.len() ); - unsafe { self.buckets.get_unchecked_mut(term_id_usize) } + let count = unsafe { self.buckets.get_unchecked_mut(term_id_usize) }; + *count += 1; + 0 // unused } - #[inline(always)] - fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - for bucket in &mut self.buckets { - if bucket.count > 0 { - bucket.sub_agg.flush(agg_data)?; - } + fn into_vec(self) -> Vec<(u64, Bucket)> { + self.buckets + .into_iter() + .enumerate() + .filter(|(_term_id, count)| *count > 0) + .map(|(term_id, count)| { + ( + term_id as u64, + Bucket { + count, + bucket_id: 0, // unused, there are no sub-aggregations + }, + ) + }) + .collect() + } + + fn 
new(num_terms: u64, _bucket_id_provider: &mut BucketIdProvider) -> Self { + Self { + buckets: std::iter::repeat_with(|| 0) + .take(num_terms as usize) + .collect(), } - Ok(()) + } +} + +/// An optimized term map implementation for a compact set of term ordinals. +#[derive(Clone, Debug)] +struct VecTermBuckets { + buckets: Vec, +} + +impl TermAggregationMap for VecTermBuckets { + /// Estimate the memory consumption of this struct in bytes. + fn get_memory_consumption(&self) -> usize { + // We do not include `std::mem::size_of::()` + // It is already measure by the parent aggregation. + // + // The root aggregation mem size is not measure but we do not care. + self.buckets.capacity() * std::mem::size_of::() } - fn into_vec(self) -> Vec<(u64, Bucket)> { + /// Add an occurrence of the given term id. + #[inline(always)] + fn term_entry(&mut self, term_id: u64, _bucket_id_provider: &mut BucketIdProvider) -> BucketId { + let term_id_usize = term_id as usize; + debug_assert!( + term_id_usize < self.buckets.len(), + "term_id {} out of bounds for VecTermBuckets (len={})", + term_id, + self.buckets.len() + ); + let bucket = unsafe { self.buckets.get_unchecked_mut(term_id_usize) }; + bucket.count += 1; + bucket.bucket_id + } + + fn into_vec(self) -> Vec<(u64, Bucket)> { self.buckets .into_iter() .enumerate() @@ -684,22 +745,26 @@ impl TermAggregat .map(|(term_id, bucket)| (term_id as u64, bucket)) .collect() } -} -impl<'a> From<&'a dyn SegmentAggregationCollector> for NoSubAgg { - #[inline(always)] - fn from(_: &'a dyn SegmentAggregationCollector) -> Self { - Self + fn new(num_terms: u64, bucket_id_provider: &mut BucketIdProvider) -> Self { + VecTermBuckets { + buckets: std::iter::repeat_with(|| Bucket::new(bucket_id_provider.next_bucket_id())) + .take(num_terms as usize) + .collect(), + } } } /// The collector puts values from the fast field into the correct buckets and does a conversion to /// the correct datatype. 
-#[derive(Clone, Debug)] -struct SegmentTermCollector { +#[derive(Debug)] +struct SegmentTermCollector { /// The buckets containing the aggregation data. - term_buckets: TermMap, - accessor_idx: usize, + parent_buckets: Vec, + sub_agg: Option>, + bucket_id_provider: BucketIdProvider, + max_term_id: u64, + terms_req_data: TermsAggReqData, } pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) { @@ -707,18 +772,26 @@ pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) { (agg_name, agg_property) } -impl SegmentAggregationCollector for SegmentTermCollector -where - TermMap: TermAggregationMap, - TermMap::SubAggregation: for<'a> From<&'a dyn SegmentAggregationCollector>, +impl SegmentAggregationCollector + for SegmentTermCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + bucket: BucketId, ) -> crate::Result<()> { - let name = agg_data.get_term_req_data(self.accessor_idx).name.clone(); - let bucket = self.into_intermediate_bucket_result(agg_data)?; + // TODO: avoid prepare_max_bucket here and handle empty buckets. 
+ self.prepare_max_bucket(bucket, agg_data)?; + let bucket = std::mem::replace( + &mut self.parent_buckets[bucket as usize], + TermMap::new(0, &mut self.bucket_id_provider), + ); + let term_req = &self.terms_req_data; + let name = term_req.name.clone(); + + let bucket = + Self::into_intermediate_bucket_result(term_req, &mut self.sub_agg, bucket, agg_data)?; results.push(name, IntermediateAggregationResult::Bucket(bucket))?; Ok(()) } @@ -726,65 +799,49 @@ where #[inline] fn collect( &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.collect_block(&[doc], agg_data) - } - - #[inline] - fn collect_block( - &mut self, + parent_bucket_id: BucketId, docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let mut req_data = agg_data.take_term_req_data(self.accessor_idx); - let mem_pre = self.get_memory_consumption(); - if let Some(missing) = req_data.missing_value_for_accessor { - req_data.column_block_accessor.fetch_block_with_missing( - docs, - &req_data.accessor, - missing, - ); - } else { - req_data - .column_block_accessor - .fetch_block(docs, &req_data.accessor); - } + let req_data = &mut self.terms_req_data; - if std::any::TypeId::of::() == std::any::TypeId::of::() { - for term_id in req_data.column_block_accessor.iter_vals() { - if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() { - if !allowed_bs.contains(term_id as u32) { - continue; - } - } - let bucket = self.term_buckets.term_entry(term_id, &NoSubAgg); - bucket.count += 1; + agg_data.column_block_accessor.fetch_block_with_missing( + docs, + &req_data.accessor, + req_data.missing_value_for_accessor, + ); + + if let Some(sub_agg) = &mut self.sub_agg { + let term_buckets = &mut self.parent_buckets[parent_bucket_id as usize]; + let it = agg_data + .column_block_accessor + .iter_docid_vals(docs, &req_data.accessor); + if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() { + let it = it.filter(move |&(_doc, 
term_id)| allowed_bs.contains(term_id as u32)); + Self::collect_terms_with_docs( + it, + term_buckets, + &mut self.bucket_id_provider, + sub_agg, + ); + } else { + Self::collect_terms_with_docs( + it, + term_buckets, + &mut self.bucket_id_provider, + sub_agg, + ); } } else { - let Some(sub_aggregation_blueprint) = req_data.sub_aggregation_blueprint.as_deref() - else { - return Err(TantivyError::InternalError( - "Could not find sub-aggregation blueprint".to_string(), - )); - }; - for (doc, term_id) in req_data - .column_block_accessor - .iter_docid_vals(docs, &req_data.accessor) - { - if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() { - if !allowed_bs.contains(term_id as u32) { - continue; - } - } - let bucket = self - .term_buckets - .term_entry(term_id, sub_aggregation_blueprint); - bucket.count += 1; - bucket.sub_agg.collect(doc, agg_data)?; + let term_buckets = &mut self.parent_buckets[parent_bucket_id as usize]; + let it = agg_data.column_block_accessor.iter_vals(); + if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() { + let it = it.filter(move |&term_id| allowed_bs.contains(term_id as u32)); + Self::collect_terms(it, term_buckets, &mut self.bucket_id_provider); + } else { + Self::collect_terms(it, term_buckets, &mut self.bucket_id_provider); } } @@ -795,14 +852,31 @@ where .limits .add_memory_consumed(mem_delta as u64)?; } - agg_data.put_back_term_req_data(self.accessor_idx, req_data); + if let Some(sub_agg) = &mut self.sub_agg { + sub_agg.check_flush_local(agg_data)?; + } Ok(()) } - #[inline(always)] + #[inline] fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - self.term_buckets.flush(agg_data)?; + if let Some(sub_agg) = &mut self.sub_agg { + sub_agg.flush(agg_data)?; + } + Ok(()) + } + + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + while self.parent_buckets.len() <= max_bucket as usize { + let term_buckets: TermMap = + 
TermMap::new(self.max_term_id, &mut self.bucket_id_provider); + self.parent_buckets.push(term_buckets); + } Ok(()) } } @@ -831,20 +905,26 @@ fn extract_missing_value( Some((key, bucket)) } -impl SegmentTermCollector -where TermMap: TermAggregationMap +impl SegmentTermCollector +where + TermMap: TermAggregationMap, + C: SubAggCache, { fn get_memory_consumption(&self) -> usize { - self.term_buckets.get_memory_consumption() + self.parent_buckets + .iter() + .map(|b| b.get_memory_consumption()) + .sum() } #[inline] pub(crate) fn into_intermediate_bucket_result( - self, + term_req: &TermsAggReqData, + sub_agg: &mut Option>, + term_buckets: TermMap, agg_data: &AggregationsSegmentCtx, ) -> crate::Result { - let term_req = agg_data.get_term_req_data(self.accessor_idx); - let mut entries: Vec<(u64, Bucket)> = self.term_buckets.into_vec(); + let mut entries: Vec<(u64, Bucket)> = term_buckets.into_vec(); let order_by_sub_aggregation = matches!(term_req.req.order.target, OrderTarget::SubAggregation(_)); @@ -884,23 +964,28 @@ where TermMap: TermAggregationMap dict.reserve(entries.len()); let into_intermediate_bucket_entry = - |bucket: Bucket| -> crate::Result { - let intermediate_entry = if term_req.sub_aggregation_blueprint.as_ref().is_some() { + |bucket: Bucket, + sub_agg: &mut Option>| + -> crate::Result { + if let Some(sub_agg) = sub_agg { let mut sub_aggregation_res = IntermediateAggregationResults::default(); - // TODO remove box new - Box::new(bucket.sub_agg) - .add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?; - IntermediateTermBucketEntry { + sub_agg + .get_sub_agg_collector() + .add_intermediate_aggregation_result( + agg_data, + &mut sub_aggregation_res, + bucket.bucket_id, + )?; + Ok(IntermediateTermBucketEntry { doc_count: bucket.count, sub_aggregation: sub_aggregation_res, - } + }) } else { - IntermediateTermBucketEntry { + Ok(IntermediateTermBucketEntry { doc_count: bucket.count, sub_aggregation: Default::default(), - } - }; - 
Ok(intermediate_entry) + }) + } }; if term_req.column_type == ColumnType::Str { @@ -913,21 +998,20 @@ where TermMap: TermAggregationMap if let Some((intermediate_key, bucket)) = extract_missing_value(&mut entries, term_req) { - let intermediate_entry = into_intermediate_bucket_entry(bucket)?; + let intermediate_entry = into_intermediate_bucket_entry(bucket, sub_agg)?; dict.insert(intermediate_key, intermediate_entry); } // Sort by term ord entries.sort_unstable_by_key(|bucket| bucket.0); - let (term_ids, buckets): (Vec, Vec>) = - entries.into_iter().unzip(); + let (term_ids, buckets): (Vec, Vec) = entries.into_iter().unzip(); let mut buckets_it = buckets.into_iter(); term_dict.sorted_ords_to_term_cb(term_ids.into_iter(), |term| { let bucket = buckets_it.next().unwrap(); let intermediate_entry = - into_intermediate_bucket_entry(bucket).map_err(io::Error::other)?; + into_intermediate_bucket_entry(bucket, sub_agg).map_err(io::Error::other)?; dict.insert( IntermediateKey::Str( String::from_utf8(term.to_vec()).expect("could not convert to String"), @@ -969,14 +1053,14 @@ where TermMap: TermAggregationMap } } else if term_req.column_type == ColumnType::DateTime { for (val, doc_count) in entries { - let intermediate_entry = into_intermediate_bucket_entry(doc_count)?; + let intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?; let val = i64::from_u64(val); let date = format_date(val)?; dict.insert(IntermediateKey::Str(date), intermediate_entry); } } else if term_req.column_type == ColumnType::Bool { for (val, doc_count) in entries { - let intermediate_entry = into_intermediate_bucket_entry(doc_count)?; + let intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?; let val = bool::from_u64(val); dict.insert(IntermediateKey::Bool(val), intermediate_entry); } @@ -996,14 +1080,14 @@ where TermMap: TermAggregationMap })?; for (val, doc_count) in entries { - let intermediate_entry = into_intermediate_bucket_entry(doc_count)?; + let 
intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?; let val: u128 = compact_space_accessor.compact_to_u128(val as u32); let val = Ipv6Addr::from_u128(val); dict.insert(IntermediateKey::IpAddr(val), intermediate_entry); } } else { for (val, doc_count) in entries { - let intermediate_entry = into_intermediate_bucket_entry(doc_count)?; + let intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?; if term_req.column_type == ColumnType::U64 { dict.insert(IntermediateKey::U64(val), intermediate_entry); } else if term_req.column_type == ColumnType::I64 { @@ -1037,6 +1121,32 @@ where TermMap: TermAggregationMap } } +impl SegmentTermCollector { + #[inline] + fn collect_terms_with_docs( + iter: impl Iterator, + term_buckets: &mut TermMap, + bucket_id_provider: &mut BucketIdProvider, + sub_agg: &mut CachedSubAggs, + ) { + for (doc, term_id) in iter { + let bucket_id = term_buckets.term_entry(term_id, bucket_id_provider); + sub_agg.push(bucket_id, doc); + } + } + + #[inline] + fn collect_terms( + iter: impl Iterator, + term_buckets: &mut TermMap, + bucket_id_provider: &mut BucketIdProvider, + ) { + for term_id in iter { + term_buckets.term_entry(term_id, bucket_id_provider); + } + } +} + pub(crate) trait GetDocCount { fn doc_count(&self) -> u64; } @@ -1047,7 +1157,7 @@ impl GetDocCount for (String, IntermediateTermBucketEntry) { } } -impl GetDocCount for (u64, Bucket) { +impl GetDocCount for (u64, Bucket) { fn doc_count(&self) -> u64 { self.1.count as u64 } @@ -1079,8 +1189,10 @@ mod tests { use common::DateTime; use time::{Date, Month}; + use super::{PagedTermMap, TermAggregationMap, PAGE_SIZE}; use crate::aggregation::agg_req::Aggregations; use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults; + use crate::aggregation::segment_agg_result::BucketIdProvider; use crate::aggregation::tests::{ exec_request, exec_request_with_query, exec_request_with_query_and_memory_limit, get_test_index_from_terms, 
get_test_index_from_values_and_terms, @@ -1091,6 +1203,43 @@ mod tests { use crate::schema::{IntoIpv6Addr, Schema, FAST, STRING}; use crate::{Index, IndexWriter}; + #[test] + fn paged_term_map_reuses_buckets_and_counts() { + let mut bucket_id_provider = BucketIdProvider::default(); + let mut map = PagedTermMap::new((PAGE_SIZE * 2) as u64, &mut bucket_id_provider); + + let bucket_first = map.term_entry(5, &mut bucket_id_provider); + let bucket_second_page = map.term_entry((PAGE_SIZE + 7) as u64, &mut bucket_id_provider); + + // Reinsertions should increment counts and reuse bucket ids + assert_eq!(map.term_entry(5, &mut bucket_id_provider), bucket_first); + assert_eq!( + map.term_entry((PAGE_SIZE + 7) as u64, &mut bucket_id_provider), + bucket_second_page + ); + + // High offset exercises the TinySet presence word boundaries. + let bucket_high_bit = map.term_entry(63, &mut bucket_id_provider); + + let mut entries = map.into_vec(); + entries.sort_by_key(|(term_id, _)| *term_id); + + let expected = vec![ + (5u64, bucket_first, 2u32), + (63u64, bucket_high_bit, 1u32), + ((PAGE_SIZE + 7) as u64, bucket_second_page, 2u32), + ]; + + assert_eq!(entries.len(), expected.len()); + for ((term_id, bucket), (expected_term, expected_bucket_id, expected_count)) in + entries.into_iter().zip(expected) + { + assert_eq!(term_id, expected_term); + assert_eq!(bucket.bucket_id, expected_bucket_id); + assert_eq!(bucket.count, expected_count); + } + } + #[test] fn terms_aggregation_test_single_segment() -> crate::Result<()> { terms_aggregation_test_merge_segment(true) diff --git a/src/aggregation/bucket/term_missing_agg.rs b/src/aggregation/bucket/term_missing_agg.rs index 2baa7bbc8..e246ccfc3 100644 --- a/src/aggregation/bucket/term_missing_agg.rs +++ b/src/aggregation/bucket/term_missing_agg.rs @@ -5,11 +5,13 @@ use crate::aggregation::agg_data::{ build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx, }; use crate::aggregation::bucket::term_agg::TermsAggregation; +use 
crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardCachedSubAggs}; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult, }; -use crate::aggregation::segment_agg_result::SegmentAggregationCollector; +use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector}; +use crate::aggregation::BucketId; /// Special aggregation to handle missing values for term aggregations. /// This missing aggregation will check multiple columns for existence. @@ -35,41 +37,55 @@ impl MissingTermAggReqData { } } -/// The specialized missing term aggregation. #[derive(Default, Debug, Clone)] -pub struct TermMissingAgg { +struct MissingCount { missing_count: u32, + bucket_id: BucketId, +} + +/// The specialized missing term aggregation. +#[derive(Default, Debug)] +pub struct TermMissingAgg { accessor_idx: usize, - sub_agg: Option>, + sub_agg: Option, + /// Idx = parent bucket id, Value = missing count for that bucket + missing_count_per_bucket: Vec, + bucket_id_provider: BucketIdProvider, } impl TermMissingAgg { pub(crate) fn new( - req_data: &mut AggregationsSegmentCtx, + agg_data: &mut AggregationsSegmentCtx, node: &AggRefNode, ) -> crate::Result { let has_sub_aggregations = !node.children.is_empty(); let accessor_idx = node.idx_in_req_data; let sub_agg = if has_sub_aggregations { - let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?; + let sub_aggregation = build_segment_agg_collectors(agg_data, &node.children)?; Some(sub_aggregation) } else { None }; + let sub_agg = sub_agg.map(CachedSubAggs::new); + let bucket_id_provider = BucketIdProvider::default(); + Ok(Self { accessor_idx, sub_agg, - ..Default::default() + missing_count_per_bucket: Vec::new(), + bucket_id_provider, }) } } impl SegmentAggregationCollector for TermMissingAgg { fn 
add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { + self.prepare_max_bucket(parent_bucket_id, agg_data)?; let req_data = agg_data.get_missing_term_req_data(self.accessor_idx); let term_agg = &req_data.req; let missing = term_agg @@ -80,13 +96,16 @@ impl SegmentAggregationCollector for TermMissingAgg { let mut entries: FxHashMap = Default::default(); + let missing_count = &self.missing_count_per_bucket[parent_bucket_id as usize]; let mut missing_entry = IntermediateTermBucketEntry { - doc_count: self.missing_count, + doc_count: missing_count.missing_count, sub_aggregation: Default::default(), }; - if let Some(sub_agg) = self.sub_agg { + if let Some(sub_agg) = &mut self.sub_agg { let mut res = IntermediateAggregationResults::default(); - sub_agg.add_intermediate_aggregation_result(agg_data, &mut res)?; + sub_agg + .get_sub_agg_collector() + .add_intermediate_aggregation_result(agg_data, &mut res, missing_count.bucket_id)?; missing_entry.sub_aggregation = res; } entries.insert(missing.into(), missing_entry); @@ -109,30 +128,52 @@ impl SegmentAggregationCollector for TermMissingAgg { fn collect( &mut self, - doc: crate::DocId, + parent_bucket_id: BucketId, + docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { + let bucket = &mut self.missing_count_per_bucket[parent_bucket_id as usize]; let req_data = agg_data.get_missing_term_req_data(self.accessor_idx); - let has_value = req_data - .accessors - .iter() - .any(|(acc, _)| acc.index.has_value(doc)); - if !has_value { - self.missing_count += 1; - if let Some(sub_agg) = self.sub_agg.as_mut() { - sub_agg.collect(doc, agg_data)?; + + for doc in docs { + let doc = *doc; + let has_value = req_data + .accessors + .iter() + .any(|(acc, _)| acc.index.has_value(doc)); + if !has_value { + bucket.missing_count += 1; + + if let Some(sub_agg) = 
self.sub_agg.as_mut() { + sub_agg.push(bucket.bucket_id, doc); + } } } + + if let Some(sub_agg) = self.sub_agg.as_mut() { + sub_agg.check_flush_local(agg_data)?; + } Ok(()) } - fn collect_block( + fn prepare_max_bucket( &mut self, - docs: &[crate::DocId], - agg_data: &mut AggregationsSegmentCtx, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, ) -> crate::Result<()> { - for doc in docs { - self.collect(*doc, agg_data)?; + while self.missing_count_per_bucket.len() <= max_bucket as usize { + let bucket_id = self.bucket_id_provider.next_bucket_id(); + self.missing_count_per_bucket.push(MissingCount { + missing_count: 0, + bucket_id, + }); + } + Ok(()) + } + + fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { + if let Some(sub_agg) = self.sub_agg.as_mut() { + sub_agg.flush(agg_data)?; } Ok(()) } diff --git a/src/aggregation/buf_collector.rs b/src/aggregation/buf_collector.rs deleted file mode 100644 index 17bc1ed35..000000000 --- a/src/aggregation/buf_collector.rs +++ /dev/null @@ -1,87 +0,0 @@ -use super::intermediate_agg_result::IntermediateAggregationResults; -use super::segment_agg_result::SegmentAggregationCollector; -use crate::aggregation::agg_data::AggregationsSegmentCtx; -use crate::DocId; - -#[cfg(test)] -pub(crate) const DOC_BLOCK_SIZE: usize = 64; - -#[cfg(not(test))] -pub(crate) const DOC_BLOCK_SIZE: usize = 256; - -pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE]; - -/// BufAggregationCollector buffers documents before calling collect_block(). 
-#[derive(Clone)] -pub(crate) struct BufAggregationCollector { - pub(crate) collector: Box, - staged_docs: DocBlock, - num_staged_docs: usize, -} - -impl std::fmt::Debug for BufAggregationCollector { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("SegmentAggregationResultsCollector") - .field("staged_docs", &&self.staged_docs[..self.num_staged_docs]) - .field("num_staged_docs", &self.num_staged_docs) - .finish() - } -} - -impl BufAggregationCollector { - pub fn new(collector: Box) -> Self { - Self { - collector, - num_staged_docs: 0, - staged_docs: [0; DOC_BLOCK_SIZE], - } - } -} - -impl SegmentAggregationCollector for BufAggregationCollector { - #[inline] - fn add_intermediate_aggregation_result( - self: Box, - agg_data: &AggregationsSegmentCtx, - results: &mut IntermediateAggregationResults, - ) -> crate::Result<()> { - Box::new(self.collector).add_intermediate_aggregation_result(agg_data, results) - } - - #[inline] - fn collect( - &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.staged_docs[self.num_staged_docs] = doc; - self.num_staged_docs += 1; - if self.num_staged_docs == self.staged_docs.len() { - self.collector - .collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?; - self.num_staged_docs = 0; - } - Ok(()) - } - - #[inline] - fn collect_block( - &mut self, - docs: &[crate::DocId], - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.collector.collect_block(docs, agg_data)?; - Ok(()) - } - - #[inline] - fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - self.collector - .collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?; - self.num_staged_docs = 0; - - self.collector.flush(agg_data)?; - - Ok(()) - } -} diff --git a/src/aggregation/cached_sub_aggs.rs b/src/aggregation/cached_sub_aggs.rs new file mode 100644 index 000000000..f97da31ab --- /dev/null +++ 
b/src/aggregation/cached_sub_aggs.rs @@ -0,0 +1,245 @@ +use std::fmt::Debug; + +use super::segment_agg_result::SegmentAggregationCollector; +use crate::aggregation::agg_data::AggregationsSegmentCtx; +use crate::aggregation::bucket::MAX_NUM_TERMS_FOR_VEC; +use crate::aggregation::BucketId; +use crate::DocId; + +/// A cache for sub-aggregations, storing doc ids per bucket id. +/// Depending on the cardinality of the parent aggregation, we use different +/// storage strategies. +/// +/// ## Low Cardinality +/// Cardinality here refers to the number of unique flattened buckets that can be created +/// by the parent aggregation. +/// Flattened buckets are the result of combining all buckets per collector +/// into a single list of buckets, where each bucket is identified by its BucketId. +/// +/// ## Usage +/// Since this is caching for sub-aggregations, it is only used by bucket +/// aggregations. +/// +/// TODO: consider using a more advanced data structure for high cardinality +/// aggregations. +/// What this datastructure does in general is to group docs by bucket id. +#[derive(Debug)] +pub(crate) struct CachedSubAggs { + cache: C, + sub_agg_collector: Box, + num_docs: usize, +} + +pub type LowCardCachedSubAggs = CachedSubAggs; +pub type HighCardCachedSubAggs = CachedSubAggs; + +const FLUSH_THRESHOLD: usize = 2048; + +/// A trait for caching sub-aggregation doc ids per bucket id. +/// Different implementations can be used depending on the cardinality +/// of the parent aggregation. 
+pub trait SubAggCache: Debug { + fn new() -> Self; + fn push(&mut self, bucket_id: BucketId, doc_id: DocId); + fn flush_local( + &mut self, + sub_agg: &mut Box, + agg_data: &mut AggregationsSegmentCtx, + force: bool, + ) -> crate::Result<()>; +} + +impl CachedSubAggs { + pub fn new(sub_agg: Box) -> Self { + Self { + cache: Backend::new(), + sub_agg_collector: sub_agg, + num_docs: 0, + } + } + + pub fn get_sub_agg_collector(&mut self) -> &mut Box { + &mut self.sub_agg_collector + } + + #[inline] + pub fn push(&mut self, bucket_id: BucketId, doc_id: DocId) { + self.cache.push(bucket_id, doc_id); + self.num_docs += 1; + } + + /// Check if we need to flush based on the number of documents cached. + /// If so, flushes the cache to the provided aggregation collector. + pub fn check_flush_local( + &mut self, + agg_data: &mut AggregationsSegmentCtx, + ) -> crate::Result<()> { + if self.num_docs >= FLUSH_THRESHOLD { + self.cache + .flush_local(&mut self.sub_agg_collector, agg_data, false)?; + self.num_docs = 0; + } + Ok(()) + } + + /// Note: this _does_ flush the sub aggregations. + pub fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { + if self.num_docs != 0 { + self.cache + .flush_local(&mut self.sub_agg_collector, agg_data, true)?; + self.num_docs = 0; + } + self.sub_agg_collector.flush(agg_data)?; + Ok(()) + } +} + +/// Number of partitions for high cardinality sub-aggregation cache. +const NUM_PARTITIONS: usize = 16; + +#[derive(Debug)] +pub(crate) struct HighCardSubAggCache { + /// This weird partitioning is used to do some cheap grouping on the bucket ids. + /// bucket ids are dense, e.g. when we don't detect the cardinality as low cardinality, + /// but there are just 16 bucket ids, each bucket id will go to its own partition. + /// + /// We want to keep this cheap, because high cardinality aggregations can have a lot of + /// buckets, and there may be nothing to group. 
+ partitions: Box<[PartitionEntry; NUM_PARTITIONS]>, +} + +impl HighCardSubAggCache { + #[inline] + fn clear(&mut self) { + for partition in self.partitions.iter_mut() { + partition.clear(); + } + } +} + +#[derive(Debug, Clone, Default)] +struct PartitionEntry { + bucket_ids: Vec, + docs: Vec, +} + +impl PartitionEntry { + #[inline] + fn clear(&mut self) { + self.bucket_ids.clear(); + self.docs.clear(); + } +} + +impl SubAggCache for HighCardSubAggCache { + fn new() -> Self { + Self { + partitions: Box::new(core::array::from_fn(|_| PartitionEntry::default())), + } + } + + fn push(&mut self, bucket_id: BucketId, doc_id: DocId) { + let idx = bucket_id % NUM_PARTITIONS as u32; + let slot = &mut self.partitions[idx as usize]; + slot.bucket_ids.push(bucket_id); + slot.docs.push(doc_id); + } + + fn flush_local( + &mut self, + sub_agg: &mut Box, + agg_data: &mut AggregationsSegmentCtx, + _force: bool, + ) -> crate::Result<()> { + let mut max_bucket = 0u32; + for partition in self.partitions.iter() { + if let Some(&local_max) = partition.bucket_ids.iter().max() { + max_bucket = max_bucket.max(local_max); + } + } + + sub_agg.prepare_max_bucket(max_bucket, agg_data)?; + + for slot in self.partitions.iter() { + if !slot.bucket_ids.is_empty() { + // Reduce dynamic dispatch overhead by collecting a full partition in one call. + sub_agg.collect_multiple(&slot.bucket_ids, &slot.docs, agg_data)?; + } + } + + self.clear(); + Ok(()) + } +} + +#[derive(Debug)] +pub(crate) struct LowCardSubAggCache { + /// Cache doc ids per bucket for sub-aggregations. + /// + /// The outer Vec is indexed by BucketId. 
+ per_bucket_docs: Vec>, +} + +impl LowCardSubAggCache { + #[inline] + fn clear(&mut self) { + for v in &mut self.per_bucket_docs { + v.clear(); + } + } +} + +impl SubAggCache for LowCardSubAggCache { + fn new() -> Self { + Self { + per_bucket_docs: Vec::new(), + } + } + + fn push(&mut self, bucket_id: BucketId, doc_id: DocId) { + let idx = bucket_id as usize; + if self.per_bucket_docs.len() <= idx { + self.per_bucket_docs.resize_with(idx + 1, Vec::new); + } + self.per_bucket_docs[idx].push(doc_id); + } + + fn flush_local( + &mut self, + sub_agg: &mut Box, + agg_data: &mut AggregationsSegmentCtx, + force: bool, + ) -> crate::Result<()> { + // Pre-aggregated: call collect per bucket. + let max_bucket = (self.per_bucket_docs.len() as BucketId).saturating_sub(1); + sub_agg.prepare_max_bucket(max_bucket, agg_data)?; + // The threshold above which we flush buckets individually. + // Note: We need to make sure that we don't lock ourselves into a situation where we hit + // the FLUSH_THRESHOLD, but never flush any buckets. (except the final flush) + let mut bucket_treshold = FLUSH_THRESHOLD / (self.per_bucket_docs.len().max(1) * 2); + const _: () = { + // MAX_NUM_TERMS_FOR_VEC threshold is used for term aggregations + // Note: There may be other flexible values, for other aggregations, but we can use the + // const value here as a upper bound. 
(better than nothing) + let bucket_treshold_limit = FLUSH_THRESHOLD / (MAX_NUM_TERMS_FOR_VEC as usize * 2); + assert!( + bucket_treshold_limit > 0, + "Bucket threshold must be greater than 0" + ); + }; + if force { + bucket_treshold = 0; + } + for (bucket_id, docs) in self + .per_bucket_docs + .iter() + .enumerate() + .filter(|(_, docs)| docs.len() > bucket_treshold) + { + sub_agg.collect(bucket_id as BucketId, docs, agg_data)?; + } + + self.clear(); + Ok(()) + } +} diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs index 4c4c2c7f1..59e9c677d 100644 --- a/src/aggregation/collector.rs +++ b/src/aggregation/collector.rs @@ -1,9 +1,9 @@ use super::agg_req::Aggregations; use super::agg_result::AggregationResults; -use super::buf_collector::BufAggregationCollector; +use super::cached_sub_aggs::LowCardCachedSubAggs; use super::intermediate_agg_result::IntermediateAggregationResults; -use super::segment_agg_result::SegmentAggregationCollector; use super::AggContextParams; +// group buffering strategy is chosen explicitly by callers; no need to hash-group on the fly. use crate::aggregation::agg_data::{ build_aggregations_data_from_req, build_segment_agg_collectors_root, AggregationsSegmentCtx, }; @@ -136,7 +136,7 @@ fn merge_fruits( /// `AggregationSegmentCollector` does the aggregation collection on a segment. 
pub struct AggregationSegmentCollector { aggs_with_accessor: AggregationsSegmentCtx, - agg_collector: BufAggregationCollector, + agg_collector: LowCardCachedSubAggs, error: Option, } @@ -151,8 +151,11 @@ impl AggregationSegmentCollector { ) -> crate::Result { let mut agg_data = build_aggregations_data_from_req(agg, reader, segment_ordinal, context.clone())?; - let result = - BufAggregationCollector::new(build_segment_agg_collectors_root(&mut agg_data)?); + let mut result = + LowCardCachedSubAggs::new(build_segment_agg_collectors_root(&mut agg_data)?); + result + .get_sub_agg_collector() + .prepare_max_bucket(0, &agg_data)?; // prepare for bucket zero Ok(AggregationSegmentCollector { aggs_with_accessor: agg_data, @@ -170,26 +173,31 @@ impl SegmentCollector for AggregationSegmentCollector { if self.error.is_some() { return; } - if let Err(err) = self + self.agg_collector.push(0, doc); + match self .agg_collector - .collect(doc, &mut self.aggs_with_accessor) + .check_flush_local(&mut self.aggs_with_accessor) { - self.error = Some(err); + Ok(_) => {} + Err(e) => { + self.error = Some(e); + } } } - - /// The query pushes the documents to the collector via this method. 
- /// - /// Only valid for Collectors that ignore docs fn collect_block(&mut self, docs: &[DocId]) { if self.error.is_some() { return; } - if let Err(err) = self - .agg_collector - .collect_block(docs, &mut self.aggs_with_accessor) - { - self.error = Some(err); + + match self.agg_collector.get_sub_agg_collector().collect( + 0, + docs, + &mut self.aggs_with_accessor, + ) { + Ok(_) => {} + Err(e) => { + self.error = Some(e); + } } } @@ -200,10 +208,13 @@ impl SegmentCollector for AggregationSegmentCollector { self.agg_collector.flush(&mut self.aggs_with_accessor)?; let mut sub_aggregation_res = IntermediateAggregationResults::default(); - Box::new(self.agg_collector).add_intermediate_aggregation_result( - &self.aggs_with_accessor, - &mut sub_aggregation_res, - )?; + self.agg_collector + .get_sub_agg_collector() + .add_intermediate_aggregation_result( + &self.aggs_with_accessor, + &mut sub_aggregation_res, + 0, + )?; Ok(sub_aggregation_res) } diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index 104131461..b20e8a042 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -792,7 +792,7 @@ pub struct IntermediateRangeBucketEntry { /// The number of documents in the bucket. pub doc_count: u64, /// The sub_aggregation in this bucket. - pub sub_aggregation: IntermediateAggregationResults, + pub sub_aggregation_res: IntermediateAggregationResults, /// The from range of the bucket. Equals `f64::MIN` when `None`. pub from: Option, /// The to range of the bucket. Equals `f64::MAX` when `None`. 
@@ -811,7 +811,7 @@ impl IntermediateRangeBucketEntry { key: self.key.into(), doc_count: self.doc_count, sub_aggregation: self - .sub_aggregation + .sub_aggregation_res .into_final_result_internal(req, limits)?, to: self.to, from: self.from, @@ -857,7 +857,8 @@ impl MergeFruits for IntermediateTermBucketEntry { impl MergeFruits for IntermediateRangeBucketEntry { fn merge_fruits(&mut self, other: IntermediateRangeBucketEntry) -> crate::Result<()> { self.doc_count += other.doc_count; - self.sub_aggregation.merge_fruits(other.sub_aggregation)?; + self.sub_aggregation_res + .merge_fruits(other.sub_aggregation_res)?; Ok(()) } } @@ -887,7 +888,7 @@ mod tests { IntermediateRangeBucketEntry { key: IntermediateKey::Str(key.to_string()), doc_count: *doc_count, - sub_aggregation: Default::default(), + sub_aggregation_res: Default::default(), from: None, to: None, }, @@ -920,7 +921,7 @@ mod tests { doc_count: *doc_count, from: None, to: None, - sub_aggregation: get_sub_test_tree(&[( + sub_aggregation_res: get_sub_test_tree(&[( sub_aggregation_key.to_string(), *sub_aggregation_count, )]), diff --git a/src/aggregation/metric/average.rs b/src/aggregation/metric/average.rs index e707f2b00..57f694984 100644 --- a/src/aggregation/metric/average.rs +++ b/src/aggregation/metric/average.rs @@ -52,10 +52,8 @@ pub struct IntermediateAverage { impl IntermediateAverage { /// Creates a new [`IntermediateAverage`] instance from a [`SegmentStatsCollector`]. - pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { - Self { - stats: collector.stats, - } + pub(crate) fn from_stats(stats: IntermediateStats) -> Self { + Self { stats } } /// Merges the other intermediate result into self. 
pub fn merge_fruits(&mut self, other: IntermediateAverage) { diff --git a/src/aggregation/metric/cardinality.rs b/src/aggregation/metric/cardinality.rs index 8f3bdd3e5..c184848d8 100644 --- a/src/aggregation/metric/cardinality.rs +++ b/src/aggregation/metric/cardinality.rs @@ -2,7 +2,7 @@ use std::collections::hash_map::DefaultHasher; use std::hash::{BuildHasher, Hasher}; use columnar::column_values::CompactSpaceU64Accessor; -use columnar::{Column, ColumnBlockAccessor, ColumnType, Dictionary, StrColumn}; +use columnar::{Column, ColumnType, Dictionary, StrColumn}; use common::f64_to_u64; use hyperloglogplus::{HyperLogLog, HyperLogLogPlus}; use rustc_hash::FxHashSet; @@ -106,8 +106,6 @@ pub struct CardinalityAggReqData { pub str_dict_column: Option, /// The missing value normalized to the internal u64 representation of the field type. pub missing_value_for_accessor: Option, - /// The column block accessor to access the fast field values. - pub(crate) column_block_accessor: ColumnBlockAccessor, /// The name of the aggregation. pub name: String, /// The aggregation request. @@ -135,45 +133,34 @@ impl CardinalityAggregationReq { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] pub(crate) struct SegmentCardinalityCollector { - cardinality: CardinalityCollector, - entries: FxHashSet, + buckets: Vec, accessor_idx: usize, + /// The column accessor to access the fast field values. + accessor: Column, + /// The column_type of the field. + column_type: ColumnType, + /// The missing value normalized to the internal u64 representation of the field type. 
+ missing_value_for_accessor: Option, } -impl SegmentCardinalityCollector { - pub fn from_req(column_type: ColumnType, accessor_idx: usize) -> Self { +#[derive(Clone, Debug, PartialEq, Default)] +pub(crate) struct SegmentCardinalityCollectorBucket { + cardinality: CardinalityCollector, + entries: FxHashSet, +} +impl SegmentCardinalityCollectorBucket { + pub fn new(column_type: ColumnType) -> Self { Self { cardinality: CardinalityCollector::new(column_type as u8), - entries: Default::default(), - accessor_idx, + entries: FxHashSet::default(), } } - - fn fetch_block_with_field( - &mut self, - docs: &[crate::DocId], - agg_data: &mut CardinalityAggReqData, - ) { - if let Some(missing) = agg_data.missing_value_for_accessor { - agg_data.column_block_accessor.fetch_block_with_missing( - docs, - &agg_data.accessor, - missing, - ); - } else { - agg_data - .column_block_accessor - .fetch_block(docs, &agg_data.accessor); - } - } - fn into_intermediate_metric_result( mut self, - agg_data: &AggregationsSegmentCtx, + req_data: &CardinalityAggReqData, ) -> crate::Result { - let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx); if req_data.column_type == ColumnType::Str { let fallback_dict = Dictionary::empty(); let dict = req_data @@ -194,6 +181,7 @@ impl SegmentCardinalityCollector { term_ids.push(term_ord as u32); } } + term_ids.sort_unstable(); dict.sorted_ords_to_term_cb(term_ids.iter().map(|term| *term as u64), |term| { self.cardinality.sketch.insert_any(&term); @@ -227,16 +215,49 @@ impl SegmentCardinalityCollector { } } +impl SegmentCardinalityCollector { + pub fn from_req( + column_type: ColumnType, + accessor_idx: usize, + accessor: Column, + missing_value_for_accessor: Option, + ) -> Self { + Self { + buckets: vec![SegmentCardinalityCollectorBucket::new(column_type); 1], + column_type, + accessor_idx, + accessor, + missing_value_for_accessor, + } + } + + fn fetch_block_with_field( + &mut self, + docs: &[crate::DocId], + agg_data: &mut 
AggregationsSegmentCtx, + ) { + agg_data.column_block_accessor.fetch_block_with_missing( + docs, + &self.accessor, + self.missing_value_for_accessor, + ); + } +} + impl SegmentAggregationCollector for SegmentCardinalityCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { + self.prepare_max_bucket(parent_bucket_id, agg_data)?; let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx); let name = req_data.name.to_string(); + // take the bucket in buckets and replace it with a new empty one + let bucket = std::mem::take(&mut self.buckets[parent_bucket_id as usize]); - let intermediate_result = self.into_intermediate_metric_result(agg_data)?; + let intermediate_result = bucket.into_intermediate_metric_result(req_data)?; results.push( name, IntermediateAggregationResult::Metric(intermediate_result), @@ -247,27 +268,20 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector { fn collect( &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.collect_block(&[doc], agg_data) - } - - fn collect_block( - &mut self, + parent_bucket_id: BucketId, docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let req_data = agg_data.get_cardinality_req_data_mut(self.accessor_idx); - self.fetch_block_with_field(docs, req_data); + self.fetch_block_with_field(docs, agg_data); + let bucket = &mut self.buckets[parent_bucket_id as usize]; - let col_block_accessor = &req_data.column_block_accessor; - if req_data.column_type == ColumnType::Str { + let col_block_accessor = &agg_data.column_block_accessor; + if self.column_type == ColumnType::Str { for term_ord in col_block_accessor.iter_vals() { - self.entries.insert(term_ord); + bucket.entries.insert(term_ord); } - } else if req_data.column_type == ColumnType::IpAddr { - let 
compact_space_accessor = req_data + } else if self.column_type == ColumnType::IpAddr { + let compact_space_accessor = self .accessor .values .clone() @@ -282,16 +296,29 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector { })?; for val in col_block_accessor.iter_vals() { let val: u128 = compact_space_accessor.compact_to_u128(val as u32); - self.cardinality.sketch.insert_any(&val); + bucket.cardinality.sketch.insert_any(&val); } } else { for val in col_block_accessor.iter_vals() { - self.cardinality.sketch.insert_any(&val); + bucket.cardinality.sketch.insert_any(&val); } } Ok(()) } + + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + if max_bucket as usize >= self.buckets.len() { + self.buckets.resize_with(max_bucket as usize + 1, || { + SegmentCardinalityCollectorBucket::new(self.column_type) + }); + } + Ok(()) + } } #[derive(Clone, Debug, Serialize, Deserialize)] diff --git a/src/aggregation/metric/count.rs b/src/aggregation/metric/count.rs index ac550a38f..b28ced047 100644 --- a/src/aggregation/metric/count.rs +++ b/src/aggregation/metric/count.rs @@ -52,10 +52,8 @@ pub struct IntermediateCount { impl IntermediateCount { /// Creates a new [`IntermediateCount`] instance from a [`SegmentStatsCollector`]. - pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { - Self { - stats: collector.stats, - } + pub(crate) fn from_stats(stats: IntermediateStats) -> Self { + Self { stats } } /// Merges the other intermediate result into self. 
pub fn merge_fruits(&mut self, other: IntermediateCount) { diff --git a/src/aggregation/metric/extended_stats.rs b/src/aggregation/metric/extended_stats.rs index d7302e5f5..e71426790 100644 --- a/src/aggregation/metric/extended_stats.rs +++ b/src/aggregation/metric/extended_stats.rs @@ -8,10 +8,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, }; -use crate::aggregation::metric::MetricAggReqData; use crate::aggregation::segment_agg_result::SegmentAggregationCollector; use crate::aggregation::*; -use crate::{DocId, TantivyError}; +use crate::TantivyError; /// A multi-value metric aggregation that computes a collection of extended statistics /// on numeric values that are extracted @@ -318,51 +317,28 @@ impl IntermediateExtendedStats { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] pub(crate) struct SegmentExtendedStatsCollector { + name: String, missing: Option, field_type: ColumnType, - pub(crate) extended_stats: IntermediateExtendedStats, - pub(crate) accessor_idx: usize, - val_cache: Vec, + accessor: columnar::Column, + buckets: Vec, + sigma: Option, } impl SegmentExtendedStatsCollector { - pub fn from_req( - field_type: ColumnType, - sigma: Option, - accessor_idx: usize, - missing: Option, - ) -> Self { - let missing = missing.and_then(|val| f64_to_fastfield_u64(val, &field_type)); + pub fn from_req(req: &MetricAggReqData, sigma: Option) -> Self { + let missing = req + .missing + .and_then(|val| f64_to_fastfield_u64(val, &req.field_type)); Self { - field_type, - extended_stats: IntermediateExtendedStats::with_sigma(sigma), - accessor_idx, + name: req.name.clone(), + field_type: req.field_type, + accessor: req.accessor.clone(), missing, - val_cache: Default::default(), - } - } - #[inline] - pub(crate) fn collect_block_with_field( - &mut self, - docs: &[DocId], - req_data: &mut MetricAggReqData, 
- ) { - if let Some(missing) = self.missing.as_ref() { - req_data.column_block_accessor.fetch_block_with_missing( - docs, - &req_data.accessor, - *missing, - ); - } else { - req_data - .column_block_accessor - .fetch_block(docs, &req_data.accessor); - } - for val in req_data.column_block_accessor.iter_vals() { - let val1 = f64_from_fastfield_u64(val, &self.field_type); - self.extended_stats.collect(val1); + buckets: vec![IntermediateExtendedStats::with_sigma(sigma); 16], + sigma, } } } @@ -370,15 +346,18 @@ impl SegmentExtendedStatsCollector { impl SegmentAggregationCollector for SegmentExtendedStatsCollector { #[inline] fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { - let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone(); + let name = self.name.clone(); + self.prepare_max_bucket(parent_bucket_id, agg_data)?; + let extended_stats = std::mem::take(&mut self.buckets[parent_bucket_id as usize]); results.push( name, IntermediateAggregationResult::Metric(IntermediateMetricResult::ExtendedStats( - self.extended_stats, + extended_stats, )), )?; @@ -388,39 +367,36 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector { #[inline] fn collect( &mut self, - doc: crate::DocId, + parent_bucket_id: BucketId, + docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let req_data = agg_data.get_metric_req_data(self.accessor_idx); - if let Some(missing) = self.missing { - let mut has_val = false; - for val in req_data.accessor.values_for_doc(doc) { - let val1 = f64_from_fastfield_u64(val, &self.field_type); - self.extended_stats.collect(val1); - has_val = true; - } - if !has_val { - self.extended_stats - .collect(f64_from_fastfield_u64(missing, &self.field_type)); - } - } else { - for val in req_data.accessor.values_for_doc(doc) { - let val1 = 
f64_from_fastfield_u64(val, &self.field_type); - self.extended_stats.collect(val1); - } + let mut extended_stats = self.buckets[parent_bucket_id as usize].clone(); + + agg_data + .column_block_accessor + .fetch_block_with_missing(docs, &self.accessor, self.missing); + for val in agg_data.column_block_accessor.iter_vals() { + let val1 = f64_from_fastfield_u64(val, self.field_type); + extended_stats.collect(val1); } + // store back + self.buckets[parent_bucket_id as usize] = extended_stats; + Ok(()) } - #[inline] - fn collect_block( + fn prepare_max_bucket( &mut self, - docs: &[crate::DocId], - agg_data: &mut AggregationsSegmentCtx, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, ) -> crate::Result<()> { - let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx); - self.collect_block_with_field(docs, req_data); + if self.buckets.len() <= max_bucket as usize { + self.buckets.resize_with(max_bucket as usize + 1, || { + IntermediateExtendedStats::with_sigma(self.sigma) + }); + } Ok(()) } } diff --git a/src/aggregation/metric/max.rs b/src/aggregation/metric/max.rs index 89c6e4458..59af7e2de 100644 --- a/src/aggregation/metric/max.rs +++ b/src/aggregation/metric/max.rs @@ -52,10 +52,8 @@ pub struct IntermediateMax { impl IntermediateMax { /// Creates a new [`IntermediateMax`] instance from a [`SegmentStatsCollector`]. - pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { - Self { - stats: collector.stats, - } + pub(crate) fn from_stats(stats: IntermediateStats) -> Self { + Self { stats } } /// Merges the other intermediate result into self. pub fn merge_fruits(&mut self, other: IntermediateMax) { diff --git a/src/aggregation/metric/min.rs b/src/aggregation/metric/min.rs index 61fd2ecd2..ecf2fcafc 100644 --- a/src/aggregation/metric/min.rs +++ b/src/aggregation/metric/min.rs @@ -52,10 +52,8 @@ pub struct IntermediateMin { impl IntermediateMin { /// Creates a new [`IntermediateMin`] instance from a [`SegmentStatsCollector`]. 
- pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { - Self { - stats: collector.stats, - } + pub(crate) fn from_stats(stats: IntermediateStats) -> Self { + Self { stats } } /// Merges the other intermediate result into self. pub fn merge_fruits(&mut self, other: IntermediateMin) { diff --git a/src/aggregation/metric/mod.rs b/src/aggregation/metric/mod.rs index 3537af8a6..d3a448a38 100644 --- a/src/aggregation/metric/mod.rs +++ b/src/aggregation/metric/mod.rs @@ -31,7 +31,7 @@ use std::collections::HashMap; pub use average::*; pub use cardinality::*; -use columnar::{Column, ColumnBlockAccessor, ColumnType}; +use columnar::{Column, ColumnType}; pub use count::*; pub use extended_stats::*; pub use max::*; @@ -55,8 +55,6 @@ pub struct MetricAggReqData { pub field_type: ColumnType, /// The missing value normalized to the internal u64 representation of the field type. pub missing_u64: Option, - /// The column block accessor to access the fast field values. - pub column_block_accessor: ColumnBlockAccessor, /// The column accessor to access the fast field values. 
pub accessor: Column, /// Used when converting to intermediate result diff --git a/src/aggregation/metric/percentiles.rs b/src/aggregation/metric/percentiles.rs index c846e2187..ff9de45f1 100644 --- a/src/aggregation/metric/percentiles.rs +++ b/src/aggregation/metric/percentiles.rs @@ -7,10 +7,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, }; -use crate::aggregation::metric::MetricAggReqData; use crate::aggregation::segment_agg_result::SegmentAggregationCollector; use crate::aggregation::*; -use crate::{DocId, TantivyError}; +use crate::TantivyError; /// # Percentiles /// @@ -131,10 +130,16 @@ impl PercentilesAggregationReq { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] pub(crate) struct SegmentPercentilesCollector { - pub(crate) percentiles: PercentilesCollector, + pub(crate) buckets: Vec, pub(crate) accessor_idx: usize, + /// The type of the field. + pub field_type: ColumnType, + /// The missing value normalized to the internal u64 representation of the field type. + pub missing_u64: Option, + /// The column accessor to access the fast field values. 
+ pub accessor: Column, } #[derive(Clone, Serialize, Deserialize)] @@ -229,33 +234,18 @@ impl PercentilesCollector { } impl SegmentPercentilesCollector { - pub fn from_req_and_validate(accessor_idx: usize) -> crate::Result { - Ok(Self { - percentiles: PercentilesCollector::new(), + pub fn from_req_and_validate( + field_type: ColumnType, + missing_u64: Option, + accessor: Column, + accessor_idx: usize, + ) -> Self { + Self { + buckets: Vec::with_capacity(64), + field_type, + missing_u64, + accessor, accessor_idx, - }) - } - #[inline] - pub(crate) fn collect_block_with_field( - &mut self, - docs: &[DocId], - req_data: &mut MetricAggReqData, - ) { - if let Some(missing) = req_data.missing_u64.as_ref() { - req_data.column_block_accessor.fetch_block_with_missing( - docs, - &req_data.accessor, - *missing, - ); - } else { - req_data - .column_block_accessor - .fetch_block(docs, &req_data.accessor); - } - - for val in req_data.column_block_accessor.iter_vals() { - let val1 = f64_from_fastfield_u64(val, &req_data.field_type); - self.percentiles.collect(val1); } } } @@ -263,12 +253,18 @@ impl SegmentPercentilesCollector { impl SegmentAggregationCollector for SegmentPercentilesCollector { #[inline] fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone(); - let intermediate_metric_result = IntermediateMetricResult::Percentiles(self.percentiles); + self.prepare_max_bucket(parent_bucket_id, agg_data)?; + // Swap collector with an empty one to avoid cloning + let percentiles_collector = std::mem::take(&mut self.buckets[parent_bucket_id as usize]); + + let intermediate_metric_result = + IntermediateMetricResult::Percentiles(percentiles_collector); results.push( name, @@ -281,40 +277,33 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector { #[inline] fn 
collect( &mut self, - doc: crate::DocId, + parent_bucket_id: BucketId, + docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let req_data = agg_data.get_metric_req_data(self.accessor_idx); + let percentiles = &mut self.buckets[parent_bucket_id as usize]; + agg_data.column_block_accessor.fetch_block_with_missing( + docs, + &self.accessor, + self.missing_u64, + ); - if let Some(missing) = req_data.missing_u64 { - let mut has_val = false; - for val in req_data.accessor.values_for_doc(doc) { - let val1 = f64_from_fastfield_u64(val, &req_data.field_type); - self.percentiles.collect(val1); - has_val = true; - } - if !has_val { - self.percentiles - .collect(f64_from_fastfield_u64(missing, &req_data.field_type)); - } - } else { - for val in req_data.accessor.values_for_doc(doc) { - let val1 = f64_from_fastfield_u64(val, &req_data.field_type); - self.percentiles.collect(val1); - } + for val in agg_data.column_block_accessor.iter_vals() { + let val1 = f64_from_fastfield_u64(val, self.field_type); + percentiles.collect(val1); } Ok(()) } - #[inline] - fn collect_block( + fn prepare_max_bucket( &mut self, - docs: &[crate::DocId], - agg_data: &mut AggregationsSegmentCtx, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, ) -> crate::Result<()> { - let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx); - self.collect_block_with_field(docs, req_data); + while self.buckets.len() <= max_bucket as usize { + self.buckets.push(PercentilesCollector::new()); + } Ok(()) } } diff --git a/src/aggregation/metric/stats.rs b/src/aggregation/metric/stats.rs index 56715fdea..c43a6a259 100644 --- a/src/aggregation/metric/stats.rs +++ b/src/aggregation/metric/stats.rs @@ -1,5 +1,6 @@ use std::fmt::Debug; +use columnar::{Column, ColumnType}; use serde::{Deserialize, Serialize}; use super::*; @@ -7,10 +8,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx; use crate::aggregation::intermediate_agg_result::{ 
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, }; -use crate::aggregation::metric::MetricAggReqData; use crate::aggregation::segment_agg_result::SegmentAggregationCollector; use crate::aggregation::*; -use crate::{DocId, TantivyError}; +use crate::TantivyError; /// A multi-value metric aggregation that computes a collection of statistics on numeric values that /// are extracted from the aggregated documents. @@ -83,7 +83,7 @@ impl Stats { /// Intermediate result of the stats aggregation that can be combined with other intermediate /// results. -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)] pub struct IntermediateStats { /// The number of extracted values. pub(crate) count: u64, @@ -187,75 +187,75 @@ pub enum StatsType { Percentiles, } +fn create_collector( + req: &MetricAggReqData, +) -> Box { + Box::new(SegmentStatsCollector:: { + name: req.name.clone(), + collecting_for: req.collecting_for, + is_number_or_date_type: req.is_number_or_date_type, + missing_u64: req.missing_u64, + accessor: req.accessor.clone(), + buckets: vec![IntermediateStats::default()], + }) +} + +/// Build a concrete `SegmentStatsCollector` depending on the column type. 
+pub(crate) fn build_segment_stats_collector( + req: &MetricAggReqData, +) -> crate::Result> { + match req.field_type { + ColumnType::I64 => Ok(create_collector::<{ ColumnType::I64 as u8 }>(req)), + ColumnType::U64 => Ok(create_collector::<{ ColumnType::U64 as u8 }>(req)), + ColumnType::F64 => Ok(create_collector::<{ ColumnType::F64 as u8 }>(req)), + ColumnType::Bool => Ok(create_collector::<{ ColumnType::Bool as u8 }>(req)), + ColumnType::DateTime => Ok(create_collector::<{ ColumnType::DateTime as u8 }>(req)), + ColumnType::Bytes => Ok(create_collector::<{ ColumnType::Bytes as u8 }>(req)), + ColumnType::Str => Ok(create_collector::<{ ColumnType::Str as u8 }>(req)), + ColumnType::IpAddr => Ok(create_collector::<{ ColumnType::IpAddr as u8 }>(req)), + } +} + +#[repr(C)] #[derive(Clone, Debug)] -pub(crate) struct SegmentStatsCollector { - pub(crate) stats: IntermediateStats, - pub(crate) accessor_idx: usize, +pub(crate) struct SegmentStatsCollector { + pub(crate) missing_u64: Option, + pub(crate) accessor: Column, + pub(crate) is_number_or_date_type: bool, + pub(crate) buckets: Vec, + pub(crate) name: String, + pub(crate) collecting_for: StatsType, } -impl SegmentStatsCollector { - pub fn from_req(accessor_idx: usize) -> Self { - Self { - stats: IntermediateStats::default(), - accessor_idx, - } - } - #[inline] - pub(crate) fn collect_block_with_field( - &mut self, - docs: &[DocId], - req_data: &mut MetricAggReqData, - ) { - if let Some(missing) = req_data.missing_u64.as_ref() { - req_data.column_block_accessor.fetch_block_with_missing( - docs, - &req_data.accessor, - *missing, - ); - } else { - req_data - .column_block_accessor - .fetch_block(docs, &req_data.accessor); - } - if req_data.is_number_or_date_type { - for val in req_data.column_block_accessor.iter_vals() { - let val1 = f64_from_fastfield_u64(val, &req_data.field_type); - self.stats.collect(val1); - } - } else { - for _val in req_data.column_block_accessor.iter_vals() { - // we ignore the value and simply 
record that we got something - self.stats.collect(0.0); - } - } - } -} - -impl SegmentAggregationCollector for SegmentStatsCollector { +impl SegmentAggregationCollector + for SegmentStatsCollector +{ #[inline] fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { - let req = agg_data.get_metric_req_data(self.accessor_idx); - let name = req.name.clone(); + let name = self.name.clone(); - let intermediate_metric_result = match req.collecting_for { + self.prepare_max_bucket(parent_bucket_id, agg_data)?; + let stats = self.buckets[parent_bucket_id as usize]; + let intermediate_metric_result = match self.collecting_for { StatsType::Average => { - IntermediateMetricResult::Average(IntermediateAverage::from_collector(*self)) + IntermediateMetricResult::Average(IntermediateAverage::from_stats(stats)) } StatsType::Count => { - IntermediateMetricResult::Count(IntermediateCount::from_collector(*self)) + IntermediateMetricResult::Count(IntermediateCount::from_stats(stats)) } - StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_collector(*self)), - StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_collector(*self)), - StatsType::Stats => IntermediateMetricResult::Stats(self.stats), - StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_collector(*self)), + StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_stats(stats)), + StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_stats(stats)), + StatsType::Stats => IntermediateMetricResult::Stats(stats), + StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_stats(stats)), _ => { return Err(TantivyError::InvalidArgument(format!( "Unsupported stats type for stats aggregation: {:?}", - req.collecting_for + self.collecting_for ))) } }; @@ -271,41 +271,67 @@ impl 
SegmentAggregationCollector for SegmentStatsCollector { #[inline] fn collect( &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - let req_data = agg_data.get_metric_req_data(self.accessor_idx); - if let Some(missing) = req_data.missing_u64 { - let mut has_val = false; - for val in req_data.accessor.values_for_doc(doc) { - let val1 = f64_from_fastfield_u64(val, &req_data.field_type); - self.stats.collect(val1); - has_val = true; - } - if !has_val { - self.stats - .collect(f64_from_fastfield_u64(missing, &req_data.field_type)); - } - } else { - for val in req_data.accessor.values_for_doc(doc) { - let val1 = f64_from_fastfield_u64(val, &req_data.field_type); - self.stats.collect(val1); - } - } - - Ok(()) - } - - #[inline] - fn collect_block( - &mut self, + parent_bucket_id: BucketId, docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx); - self.collect_block_with_field(docs, req_data); + // TODO: remove once we fetch all values for all bucket ids in one go + if docs.len() == 1 && self.missing_u64.is_none() { + collect_stats::( + &mut self.buckets[parent_bucket_id as usize], + self.accessor.values_for_doc(docs[0]), + self.is_number_or_date_type, + )?; + + return Ok(()); + } + agg_data.column_block_accessor.fetch_block_with_missing( + docs, + &self.accessor, + self.missing_u64, + ); + collect_stats::( + &mut self.buckets[parent_bucket_id as usize], + agg_data.column_block_accessor.iter_vals(), + self.is_number_or_date_type, + )?; + Ok(()) } + + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + let required_buckets = (max_bucket as usize) + 1; + if self.buckets.len() < required_buckets { + self.buckets + .resize_with(required_buckets, IntermediateStats::default); + } + Ok(()) + } +} + +#[inline] +fn collect_stats( + stats: &mut IntermediateStats, 
+ vals: impl Iterator, + is_number_or_date_type: bool, +) -> crate::Result<()> { + if is_number_or_date_type { + for val in vals { + let val1 = convert_to_f64::(val); + stats.collect(val1); + } + } else { + for _val in vals { + // we ignore the value and simply record that we got something + stats.collect(0.0); + } + } + + Ok(()) } #[cfg(test)] diff --git a/src/aggregation/metric/sum.rs b/src/aggregation/metric/sum.rs index 86f661679..2487c4e9d 100644 --- a/src/aggregation/metric/sum.rs +++ b/src/aggregation/metric/sum.rs @@ -52,10 +52,8 @@ pub struct IntermediateSum { impl IntermediateSum { /// Creates a new [`IntermediateSum`] instance from a [`SegmentStatsCollector`]. - pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { - Self { - stats: collector.stats, - } + pub(crate) fn from_stats(stats: IntermediateStats) -> Self { + Self { stats } } /// Merges the other intermediate result into self. pub fn merge_fruits(&mut self, other: IntermediateSum) { diff --git a/src/aggregation/metric/top_hits.rs b/src/aggregation/metric/top_hits.rs index 6a8bdf826..54e5a5ced 100644 --- a/src/aggregation/metric/top_hits.rs +++ b/src/aggregation/metric/top_hits.rs @@ -15,12 +15,11 @@ use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateMetricResult, }; use crate::aggregation::segment_agg_result::SegmentAggregationCollector; -use crate::aggregation::AggregationError; +use crate::aggregation::{AggregationError, BucketId}; use crate::collector::sort_key::ReverseComparator; use crate::collector::TopNComputer; use crate::schema::OwnedValue; use crate::{DocAddress, DocId, SegmentOrdinal}; -// duplicate import removed; already imported above /// Contains all information required by the TopHitsSegmentCollector to perform the /// top_hits aggregation on a segment. 
@@ -472,7 +471,10 @@ impl TopHitsTopNComputer { /// Create a new TopHitsCollector pub fn new(req: &TopHitsAggregationReq) -> Self { Self { - top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)), + top_n: TopNComputer::new_with_comparator( + req.size + req.from.unwrap_or(0), + ReverseComparator, + ), req: req.clone(), } } @@ -518,7 +520,8 @@ impl TopHitsTopNComputer { pub(crate) struct TopHitsSegmentCollector { segment_ordinal: SegmentOrdinal, accessor_idx: usize, - top_n: TopNComputer, DocAddress, ReverseComparator>, + buckets: Vec, DocAddress, ReverseComparator>>, + num_hits: usize, } impl TopHitsSegmentCollector { @@ -527,19 +530,29 @@ impl TopHitsSegmentCollector { accessor_idx: usize, segment_ordinal: SegmentOrdinal, ) -> Self { + let num_hits = req.size + req.from.unwrap_or(0); Self { - top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)), + num_hits, segment_ordinal, accessor_idx, + buckets: vec![TopNComputer::new_with_comparator(num_hits, ReverseComparator); 1], } } - fn into_top_hits_collector( - self, + fn get_top_hits_computer( + &mut self, + parent_bucket_id: BucketId, value_accessors: &HashMap>, req: &TopHitsAggregationReq, ) -> TopHitsTopNComputer { + if parent_bucket_id as usize >= self.buckets.len() { + return TopHitsTopNComputer::new(req); + } + let top_n = std::mem::replace( + &mut self.buckets[parent_bucket_id as usize], + TopNComputer::new(0), + ); let mut top_hits_computer = TopHitsTopNComputer::new(req); - let top_results = self.top_n.into_vec(); + let top_results = top_n.into_vec(); for res in top_results { let doc_value_fields = req.get_document_field_data(value_accessors, res.doc.doc_id); @@ -554,54 +567,24 @@ impl TopHitsSegmentCollector { top_hits_computer } - - /// TODO add a specialized variant for a single sort field - fn collect_with( - &mut self, - doc_id: crate::DocId, - req: &TopHitsAggregationReq, - accessors: &[(Column, ColumnType)], - ) -> crate::Result<()> { - let sorts: Vec = req - .sort - .iter() - .enumerate() - 
.map(|(idx, KeyOrder { order, .. })| { - let order = *order; - let value = accessors - .get(idx) - .expect("could not find field in accessors") - .0 - .values_for_doc(doc_id) - .next(); - DocValueAndOrder { value, order } - }) - .collect(); - - self.top_n.push( - sorts, - DocAddress { - segment_ord: self.segment_ordinal, - doc_id, - }, - ); - Ok(()) - } } impl SegmentAggregationCollector for TopHitsSegmentCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut crate::aggregation::intermediate_agg_result::IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { let req_data = agg_data.get_top_hits_req_data(self.accessor_idx); let value_accessors = &req_data.value_accessors; - let intermediate_result = IntermediateMetricResult::TopHits( - self.into_top_hits_collector(value_accessors, &req_data.req), - ); + let intermediate_result = IntermediateMetricResult::TopHits(self.get_top_hits_computer( + parent_bucket_id, + value_accessors, + &req_data.req, + )); results.push( req_data.name.to_string(), IntermediateAggregationResult::Metric(intermediate_result), @@ -611,26 +594,56 @@ impl SegmentAggregationCollector for TopHitsSegmentCollector { /// TODO: Consider a caching layer to reduce the call overhead fn collect( &mut self, - doc_id: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - let req_data = agg_data.get_top_hits_req_data(self.accessor_idx); - self.collect_with(doc_id, &req_data.req, &req_data.accessors)?; - Ok(()) - } - - fn collect_block( - &mut self, + parent_bucket_id: BucketId, docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { + let top_n = &mut self.buckets[parent_bucket_id as usize]; let req_data = agg_data.get_top_hits_req_data(self.accessor_idx); - // TODO: Consider getting fields with the column block accessor. 
- for doc in docs { - self.collect_with(*doc, &req_data.req, &req_data.accessors)?; + let req = &req_data.req; + let accessors = &req_data.accessors; + for &doc_id in docs { + // TODO: this is terrible, a new vec is allocated for every doc + // We can fetch blocks instead + // We don't need to store the order for every value + let sorts: Vec = req + .sort + .iter() + .enumerate() + .map(|(idx, KeyOrder { order, .. })| { + let order = *order; + let value = accessors + .get(idx) + .expect("could not find field in accessors") + .0 + .values_for_doc(doc_id) + .next(); + DocValueAndOrder { value, order } + }) + .collect(); + + top_n.push( + sorts, + DocAddress { + segment_ord: self.segment_ordinal, + doc_id, + }, + ); } Ok(()) } + + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + self.buckets.resize( + (max_bucket as usize) + 1, + TopNComputer::new_with_comparator(self.num_hits, ReverseComparator), + ); + Ok(()) + } } #[cfg(test)] @@ -746,7 +759,7 @@ mod tests { ], "from": 0, } - } + } })) .unwrap(); @@ -875,7 +888,7 @@ mod tests { "mixed.*", ], } - } + } }))?; let collector = AggregationCollector::from_aggs(d, Default::default()); diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index ddf60ea4c..b4a080d6a 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -133,7 +133,7 @@ mod agg_limits; pub mod agg_req; pub mod agg_result; pub mod bucket; -mod buf_collector; +pub(crate) mod cached_sub_aggs; mod collector; mod date; mod error; @@ -162,6 +162,19 @@ use serde::{Deserialize, Deserializer, Serialize}; use crate::tokenizer::TokenizerManager; +/// A bucket id is a dense identifier for a bucket within an aggregation. +/// It is used to index into a Vec that hold per-bucket data. +/// +/// For example, in a terms aggregation, each unique term will be assigned a incremental BucketId. +/// This BucketId will be forwarded to sub-aggregations to identify the parent bucket. 
+/// +/// This allows to have a single AggregationCollector instance per aggregation, +/// that can handle multiple buckets efficiently. +/// +/// The API to call sub-aggregations is therefore a &[(BucketId, &[DocId])]. +/// For that we'll need a buffer. One Vec per bucket aggregation is needed. +pub type BucketId = u32; + /// Context parameters for aggregation execution /// /// This struct holds shared resources needed during aggregation execution: @@ -335,19 +348,37 @@ impl Display for Key { } } +pub(crate) fn convert_to_f64(val: u64) -> f64 { + if COLUMN_TYPE_ID == ColumnType::U64 as u8 { + val as f64 + } else if COLUMN_TYPE_ID == ColumnType::I64 as u8 + || COLUMN_TYPE_ID == ColumnType::DateTime as u8 + { + i64::from_u64(val) as f64 + } else if COLUMN_TYPE_ID == ColumnType::F64 as u8 { + f64::from_u64(val) + } else if COLUMN_TYPE_ID == ColumnType::Bool as u8 { + val as f64 + } else { + panic!( + "ColumnType ID {} cannot be converted to f64 metric", + COLUMN_TYPE_ID + ) + } +} + /// Inverse of `to_fastfield_u64`. Used to convert to `f64` for metrics. /// /// # Panics /// Only `u64`, `f64`, `date`, and `i64` are supported. -pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: &ColumnType) -> f64 { +pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: ColumnType) -> f64 { match field_type { - ColumnType::U64 => val as f64, - ColumnType::I64 | ColumnType::DateTime => i64::from_u64(val) as f64, - ColumnType::F64 => f64::from_u64(val), - ColumnType::Bool => val as f64, - _ => { - panic!("unexpected type {field_type:?}. This should not happen") - } + ColumnType::U64 => convert_to_f64::<{ ColumnType::U64 as u8 }>(val), + ColumnType::I64 => convert_to_f64::<{ ColumnType::I64 as u8 }>(val), + ColumnType::F64 => convert_to_f64::<{ ColumnType::F64 as u8 }>(val), + ColumnType::Bool => convert_to_f64::<{ ColumnType::Bool as u8 }>(val), + ColumnType::DateTime => convert_to_f64::<{ ColumnType::DateTime as u8 }>(val), + _ => panic!("unexpected type {field_type:?}. 
This should not happen"), } } diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index 5cc2650b6..7bd13f1cd 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -8,25 +8,67 @@ use std::fmt::Debug; pub(crate) use super::agg_limits::AggregationLimitsGuard; use super::intermediate_agg_result::IntermediateAggregationResults; use crate::aggregation::agg_data::AggregationsSegmentCtx; +use crate::aggregation::BucketId; + +/// Monotonically increasing provider of BucketIds. +#[derive(Debug, Clone, Default)] +pub struct BucketIdProvider(u32); +impl BucketIdProvider { + /// Get the next BucketId. + pub fn next_bucket_id(&mut self) -> BucketId { + let bucket_id = self.0; + self.0 += 1; + bucket_id + } +} /// A SegmentAggregationCollector is used to collect aggregation results. -pub trait SegmentAggregationCollector: CollectorClone + Debug { +pub trait SegmentAggregationCollector: Debug { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()>; + /// Note: The caller needs to call `prepare_max_bucket` before calling `collect`. fn collect( &mut self, - doc: crate::DocId, + parent_bucket_id: BucketId, + docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()>; - fn collect_block( + /// Collect docs for multiple buckets in one call. + /// Minimizes dynamic dispatch overhead when collecting many buckets. + /// + /// Note: The caller needs to call `prepare_max_bucket` before calling `collect`. 
+ fn collect_multiple( &mut self, + bucket_ids: &[BucketId], docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, + ) -> crate::Result<()> { + debug_assert_eq!(bucket_ids.len(), docs.len()); + let mut start = 0; + while start < bucket_ids.len() { + let bucket_id = bucket_ids[start]; + let mut end = start + 1; + while end < bucket_ids.len() && bucket_ids[end] == bucket_id { + end += 1; + } + self.collect(bucket_id, &docs[start..end], agg_data)?; + start = end; + } + Ok(()) + } + + /// Prepare the collector for collecting up to BucketId `max_bucket`. + /// This is useful so we can split allocation ahead of time of collecting. + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + agg_data: &AggregationsSegmentCtx, ) -> crate::Result<()>; /// Finalize method. Some Aggregator collect blocks of docs before calling `collect_block`. @@ -36,26 +78,7 @@ pub trait SegmentAggregationCollector: CollectorClone + Debug { } } -/// A helper trait to enable cloning of Box -pub trait CollectorClone { - fn clone_box(&self) -> Box; -} - -impl CollectorClone for T -where T: 'static + SegmentAggregationCollector + Clone -{ - fn clone_box(&self) -> Box { - Box::new(self.clone()) - } -} - -impl Clone for Box { - fn clone(&self) -> Box { - self.clone_box() - } -} - -#[derive(Clone, Default)] +#[derive(Default)] /// The GenericSegmentAggregationResultsCollector is the generic version of the collector, which /// can handle arbitrary complexity of sub-aggregations. Ideally we never have to pick this one /// and can provide specialized versions instead, that remove some of its overhead. 
@@ -73,12 +96,13 @@ impl Debug for GenericSegmentAggregationResultsCollector { impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { - for agg in self.aggs { - agg.add_intermediate_aggregation_result(agg_data, results)?; + for agg in &mut self.aggs { + agg.add_intermediate_aggregation_result(agg_data, results, parent_bucket_id)?; } Ok(()) @@ -86,23 +110,13 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector { fn collect( &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.collect_block(&[doc], agg_data)?; - - Ok(()) - } - - fn collect_block( - &mut self, + parent_bucket_id: BucketId, docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { for collector in &mut self.aggs { - collector.collect_block(docs, agg_data)?; + collector.collect(parent_bucket_id, docs, agg_data)?; } - Ok(()) } @@ -112,4 +126,15 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector { } Ok(()) } + + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + for collector in &mut self.aggs { + collector.prepare_max_bucket(max_bucket, agg_data)?; + } + Ok(()) + } } diff --git a/src/collector/facet_collector.rs b/src/collector/facet_collector.rs index a94ec03e8..6eb2c3ee7 100644 --- a/src/collector/facet_collector.rs +++ b/src/collector/facet_collector.rs @@ -486,9 +486,9 @@ mod tests { use std::collections::BTreeSet; use columnar::Dictionary; - use rand::distributions::Uniform; + use rand::distr::Uniform; use rand::prelude::SliceRandom; - use rand::{thread_rng, Rng}; + use rand::{rng, Rng}; use super::{FacetCollector, FacetCounts}; use 
crate::collector::facet_collector::compress_mapping; @@ -731,7 +731,7 @@ mod tests { let schema = schema_builder.build(); let index = Index::create_in_ram(schema); - let uniform = Uniform::new_inclusive(1, 100_000); + let uniform = Uniform::new_inclusive(1, 100_000).unwrap(); let mut docs: Vec = vec![("a", 10), ("b", 100), ("c", 7), ("d", 12), ("e", 21)] .into_iter() @@ -741,14 +741,11 @@ mod tests { std::iter::repeat_n(doc, count) }) .map(|mut doc| { - doc.add_facet( - facet_field, - &format!("/facet/{}", thread_rng().sample(uniform)), - ); + doc.add_facet(facet_field, &format!("/facet/{}", rng().sample(uniform))); doc }) .collect(); - docs[..].shuffle(&mut thread_rng()); + docs[..].shuffle(&mut rng()); let mut index_writer: IndexWriter = index.writer_for_tests().unwrap(); for doc in docs { @@ -822,8 +819,8 @@ mod tests { #[cfg(all(test, feature = "unstable"))] mod bench { + use rand::rng; use rand::seq::SliceRandom; - use rand::thread_rng; use test::Bencher; use crate::collector::FacetCollector; @@ -846,7 +843,7 @@ mod bench { } } // 40425 docs - docs[..].shuffle(&mut thread_rng()); + docs[..].shuffle(&mut rng()); let mut index_writer: IndexWriter = index.writer_for_tests().unwrap(); for doc in docs { diff --git a/src/collector/sort_key/mod.rs b/src/collector/sort_key/mod.rs index a66115633..391873298 100644 --- a/src/collector/sort_key/mod.rs +++ b/src/collector/sort_key/mod.rs @@ -1,25 +1,48 @@ mod order; +mod sort_by_erased_type; mod sort_by_score; mod sort_by_static_fast_value; mod sort_by_string; mod sort_key_computer; pub use order::*; +pub use sort_by_erased_type::SortByErasedType; pub use sort_by_score::SortBySimilarityScore; pub use sort_by_static_fast_value::SortByStaticFastValue; pub use sort_by_string::SortByString; pub use sort_key_computer::{SegmentSortKeyComputer, SortKeyComputer}; #[cfg(test)] -mod tests { +pub(crate) mod tests { + + // By spec, regardless of whether ascending or descending order was requested, in presence of a + // tie, we sort 
by ascending doc id/doc address. + pub(crate) fn sort_hits( + hits: &mut [ComparableDoc], + order: Order, + ) { + if order.is_asc() { + hits.sort_by(|l, r| l.sort_key.cmp(&r.sort_key).then(l.doc.cmp(&r.doc))); + } else { + hits.sort_by(|l, r| { + l.sort_key + .cmp(&r.sort_key) + .reverse() // This is descending + .then(l.doc.cmp(&r.doc)) + }); + } + } + use std::collections::HashMap; use std::ops::Range; - use crate::collector::sort_key::{SortBySimilarityScore, SortByStaticFastValue, SortByString}; + use crate::collector::sort_key::{ + SortByErasedType, SortBySimilarityScore, SortByStaticFastValue, SortByString, + }; use crate::collector::{ComparableDoc, DocSetCollector, TopDocs}; use crate::indexer::NoMergePolicy; use crate::query::{AllQuery, QueryParser}; - use crate::schema::{Schema, FAST, TEXT}; + use crate::schema::{OwnedValue, Schema, FAST, TEXT}; use crate::{DocAddress, Document, Index, Order, Score, Searcher}; fn make_index() -> crate::Result { @@ -294,11 +317,9 @@ mod tests { (SortBySimilarityScore, score_order), (SortByString::for_field("city"), city_order), )); - Ok(searcher - .search(&AllQuery, &top_collector)? 
- .into_iter() - .map(|(f, doc)| (f, ids[&doc])) - .collect()) + let results: Vec<((Score, Option), DocAddress)> = + searcher.search(&AllQuery, &top_collector)?; + Ok(results.into_iter().map(|(f, doc)| (f, ids[&doc])).collect()) } assert_eq!( @@ -323,6 +344,51 @@ mod tests { Ok(()) } + #[test] + fn test_order_by_score_then_owned_value() -> crate::Result<()> { + let index = make_index()?; + + type SortKey = (Score, OwnedValue); + + fn query( + index: &Index, + score_order: Order, + city_order: Order, + ) -> crate::Result> { + let searcher = index.reader()?.searcher(); + let ids = id_mapping(&searcher); + + let top_collector = TopDocs::with_limit(4).order_by::<(Score, OwnedValue)>(( + (SortBySimilarityScore, score_order), + (SortByErasedType::for_field("city"), city_order), + )); + let results: Vec<((Score, OwnedValue), DocAddress)> = + searcher.search(&AllQuery, &top_collector)?; + Ok(results.into_iter().map(|(f, doc)| (f, ids[&doc])).collect()) + } + + assert_eq!( + &query(&index, Order::Asc, Order::Asc)?, + &[ + ((1.0, OwnedValue::Str("austin".to_owned())), 0), + ((1.0, OwnedValue::Str("greenville".to_owned())), 1), + ((1.0, OwnedValue::Str("tokyo".to_owned())), 2), + ((1.0, OwnedValue::Null), 3), + ] + ); + + assert_eq!( + &query(&index, Order::Asc, Order::Desc)?, + &[ + ((1.0, OwnedValue::Str("tokyo".to_owned())), 2), + ((1.0, OwnedValue::Str("greenville".to_owned())), 1), + ((1.0, OwnedValue::Str("austin".to_owned())), 0), + ((1.0, OwnedValue::Null), 3), + ] + ); + Ok(()) + } + use proptest::prelude::*; proptest! { @@ -372,15 +438,10 @@ mod tests { // Using the TopDocs collector should always be equivalent to sorting, skipping the // offset, and then taking the limit. 
- let sorted_docs: Vec<_> = if order.is_desc() { - let mut comparable_docs: Vec> = + let sorted_docs: Vec<_> = { + let mut comparable_docs: Vec> = all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect(); - comparable_docs.sort(); - comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect() - } else { - let mut comparable_docs: Vec> = - all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect(); - comparable_docs.sort(); + sort_hits(&mut comparable_docs, order); comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect() }; let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::>(); diff --git a/src/collector/sort_key/order.rs b/src/collector/sort_key/order.rs index 923d5cb8e..3cac357ad 100644 --- a/src/collector/sort_key/order.rs +++ b/src/collector/sort_key/order.rs @@ -1,36 +1,116 @@ use std::cmp::Ordering; +use columnar::MonotonicallyMappableToU64; use serde::{Deserialize, Serialize}; use crate::collector::{SegmentSortKeyComputer, SortKeyComputer}; -use crate::schema::Schema; +use crate::schema::{OwnedValue, Schema}; use crate::{DocId, Order, Score}; +fn compare_owned_value(lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering { + match (lhs, rhs) { + (OwnedValue::Null, OwnedValue::Null) => Ordering::Equal, + (OwnedValue::Null, _) => { + if NULLS_FIRST { + Ordering::Less + } else { + Ordering::Greater + } + } + (_, OwnedValue::Null) => { + if NULLS_FIRST { + Ordering::Greater + } else { + Ordering::Less + } + } + (OwnedValue::Str(a), OwnedValue::Str(b)) => a.cmp(b), + (OwnedValue::PreTokStr(a), OwnedValue::PreTokStr(b)) => a.cmp(b), + (OwnedValue::U64(a), OwnedValue::U64(b)) => a.cmp(b), + (OwnedValue::I64(a), OwnedValue::I64(b)) => a.cmp(b), + (OwnedValue::F64(a), OwnedValue::F64(b)) => a.to_u64().cmp(&b.to_u64()), + (OwnedValue::Bool(a), OwnedValue::Bool(b)) => a.cmp(b), + (OwnedValue::Date(a), OwnedValue::Date(b)) => a.cmp(b), + (OwnedValue::Facet(a), 
OwnedValue::Facet(b)) => a.cmp(b), + (OwnedValue::Bytes(a), OwnedValue::Bytes(b)) => a.cmp(b), + (OwnedValue::IpAddr(a), OwnedValue::IpAddr(b)) => a.cmp(b), + (OwnedValue::U64(a), OwnedValue::I64(b)) => { + if *b < 0 { + Ordering::Greater + } else { + a.cmp(&(*b as u64)) + } + } + (OwnedValue::I64(a), OwnedValue::U64(b)) => { + if *a < 0 { + Ordering::Less + } else { + (*a as u64).cmp(b) + } + } + (OwnedValue::U64(a), OwnedValue::F64(b)) => (*a as f64).to_u64().cmp(&b.to_u64()), + (OwnedValue::F64(a), OwnedValue::U64(b)) => a.to_u64().cmp(&(*b as f64).to_u64()), + (OwnedValue::I64(a), OwnedValue::F64(b)) => (*a as f64).to_u64().cmp(&b.to_u64()), + (OwnedValue::F64(a), OwnedValue::I64(b)) => a.to_u64().cmp(&(*b as f64).to_u64()), + (a, b) => { + let ord = a.discriminant_value().cmp(&b.discriminant_value()); + // If the discriminant is equal, it's because a new type was added, but hasn't been + // included in this `match` statement. + assert!( + ord != Ordering::Equal, + "Unimplemented comparison for type of {a:?}, {b:?}" + ); + ord + } + } +} + /// Comparator trait defining the order in which documents should be ordered. pub trait Comparator: Send + Sync + std::fmt::Debug + Default { /// Return the order between two values. fn compare(&self, lhs: &T, rhs: &T) -> Ordering; } -/// With the natural comparator, the top k collector will return -/// the top documents in decreasing order. +/// Compare values naturally (e.g. 1 < 2). +/// +/// When used with `TopDocs`, which reverses the order, this results in a +/// "Descending" sort (Greatest values first). +/// +/// `None` (or Null for `OwnedValue`) values are considered to be smaller than any other value, +/// and will therefore appear last in a descending sort (e.g. `[Some(20), Some(10), None]`). 
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)] pub struct NaturalComparator; impl Comparator for NaturalComparator { #[inline(always)] fn compare(&self, lhs: &T, rhs: &T) -> Ordering { - lhs.partial_cmp(rhs).unwrap() + lhs.partial_cmp(rhs).unwrap_or(Ordering::Equal) } } -/// Sorts document in reverse order. +/// A (partial) implementation of comparison for OwnedValue. /// -/// If the sort key is None, it will considered as the lowest value, and will therefore appear -/// first. +/// Intended for use within columns of homogenous types, and so will panic for OwnedValues with +/// mismatched types. The one exception is Null, for which we do define all comparisons. +impl Comparator for NaturalComparator { + #[inline(always)] + fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering { + compare_owned_value::(lhs, rhs) + } +} + +/// Compare values in reverse (e.g. 2 < 1). +/// +/// When used with `TopDocs`, which reverses the order, this results in an +/// "Ascending" sort (Smallest values first). +/// +/// `None` is considered smaller than `Some` in the underlying comparator, but because the +/// comparison is reversed, `None` is effectively treated as the lowest value in the resulting +/// Ascending sort (e.g. `[None, Some(10), Some(20)]`). /// /// The ReverseComparator does not necessarily imply that the sort order is reversed compared -/// to the NaturalComparator. In presence of a tie, both version will retain the higher doc ids. +/// to the NaturalComparator. In presence of a tie on the sort key, documents will always be +/// sorted by ascending `DocId`/`DocAddress` in TopN results, regardless of the sort key's order. #[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)] pub struct ReverseComparator; @@ -43,11 +123,15 @@ where NaturalComparator: Comparator } } -/// Sorts document in reverse order, but considers None as having the lowest value. +/// Compare values in reverse, but treating `None` as lower than `Some`. 
+/// +/// When used with `TopDocs`, which reverses the order, this results in an +/// "Ascending" sort (Smallest values first), but with `None` values appearing last +/// (e.g. `[Some(10), Some(20), None]`). /// /// This is usually what is wanted when sorting by a field in an ascending order. -/// For instance, in a e-commerce website, if I sort by price ascending, I most likely want the -/// cheapest items first, and the items without a price at last. +/// For instance, in an e-commerce website, if sorting by price ascending, +/// the cheapest items would appear first, and items without a price would appear last. #[derive(Debug, Copy, Clone, Default)] pub struct ReverseNoneIsLowerComparator; @@ -107,6 +191,84 @@ impl Comparator for ReverseNoneIsLowerComparator { } } +impl Comparator for ReverseNoneIsLowerComparator { + #[inline(always)] + fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering { + compare_owned_value::(rhs, lhs) + } +} + +/// Compare values naturally, but treating `None` as higher than `Some`. +/// +/// When used with `TopDocs`, which reverses the order, this results in a +/// "Descending" sort (Greatest values first), but with `None` values appearing first +/// (e.g. `[None, Some(20), Some(10)]`). 
+#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)] +pub struct NaturalNoneIsHigherComparator; + +impl Comparator> for NaturalNoneIsHigherComparator +where NaturalComparator: Comparator +{ + #[inline(always)] + fn compare(&self, lhs_opt: &Option, rhs_opt: &Option) -> Ordering { + match (lhs_opt, rhs_opt) { + (None, None) => Ordering::Equal, + (None, Some(_)) => Ordering::Greater, + (Some(_), None) => Ordering::Less, + (Some(lhs), Some(rhs)) => NaturalComparator.compare(lhs, rhs), + } + } +} + +impl Comparator for NaturalNoneIsHigherComparator { + #[inline(always)] + fn compare(&self, lhs: &u32, rhs: &u32) -> Ordering { + NaturalComparator.compare(lhs, rhs) + } +} + +impl Comparator for NaturalNoneIsHigherComparator { + #[inline(always)] + fn compare(&self, lhs: &u64, rhs: &u64) -> Ordering { + NaturalComparator.compare(lhs, rhs) + } +} + +impl Comparator for NaturalNoneIsHigherComparator { + #[inline(always)] + fn compare(&self, lhs: &f64, rhs: &f64) -> Ordering { + NaturalComparator.compare(lhs, rhs) + } +} + +impl Comparator for NaturalNoneIsHigherComparator { + #[inline(always)] + fn compare(&self, lhs: &f32, rhs: &f32) -> Ordering { + NaturalComparator.compare(lhs, rhs) + } +} + +impl Comparator for NaturalNoneIsHigherComparator { + #[inline(always)] + fn compare(&self, lhs: &i64, rhs: &i64) -> Ordering { + NaturalComparator.compare(lhs, rhs) + } +} + +impl Comparator for NaturalNoneIsHigherComparator { + #[inline(always)] + fn compare(&self, lhs: &String, rhs: &String) -> Ordering { + NaturalComparator.compare(lhs, rhs) + } +} + +impl Comparator for NaturalNoneIsHigherComparator { + #[inline(always)] + fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering { + compare_owned_value::(lhs, rhs) + } +} + /// An enum representing the different sort orders. 
#[derive(Debug, Clone, Copy, Eq, PartialEq, Default)] pub enum ComparatorEnum { @@ -115,8 +277,10 @@ pub enum ComparatorEnum { Natural, /// Reverse order (See [ReverseComparator]) Reverse, - /// Reverse order by treating None as the lowest value.(See [ReverseNoneLowerComparator]) + /// Reverse order by treating None as the lowest value. (See [ReverseNoneLowerComparator]) ReverseNoneLower, + /// Natural order but treating None as the highest value. (See [NaturalNoneIsHigherComparator]) + NaturalNoneHigher, } impl From for ComparatorEnum { @@ -133,6 +297,7 @@ where ReverseNoneIsLowerComparator: Comparator, NaturalComparator: Comparator, ReverseComparator: Comparator, + NaturalNoneIsHigherComparator: Comparator, { #[inline(always)] fn compare(&self, lhs: &T, rhs: &T) -> Ordering { @@ -140,6 +305,7 @@ where ComparatorEnum::Natural => NaturalComparator.compare(lhs, rhs), ComparatorEnum::Reverse => ReverseComparator.compare(lhs, rhs), ComparatorEnum::ReverseNoneLower => ReverseNoneIsLowerComparator.compare(lhs, rhs), + ComparatorEnum::NaturalNoneHigher => NaturalNoneIsHigherComparator.compare(lhs, rhs), } } } @@ -322,11 +488,12 @@ impl SegmentSortKeyComput for SegmentSortKeyComputerWithComparator where TSegmentSortKeyComputer: SegmentSortKeyComputer, - TSegmentSortKey: PartialOrd + Clone + 'static + Sync + Send, + TSegmentSortKey: Clone + 'static + Sync + Send, TComparator: Comparator + 'static + Sync + Send, { type SortKey = TSegmentSortKeyComputer::SortKey; type SegmentSortKey = TSegmentSortKey; + type SegmentComparator = TComparator; fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey { self.segment_sort_key_computer.segment_sort_key(doc, score) @@ -346,3 +513,55 @@ where .convert_segment_sort_key(sort_key) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::schema::OwnedValue; + + #[test] + fn test_natural_none_is_higher() { + let comp = NaturalNoneIsHigherComparator; + let null = None; + let v1 = Some(1_u64); + let v2 = 
Some(2_u64); + + // NaturalNoneIsGreaterComparator logic: + // 1. Delegates to NaturalComparator for non-nulls. + // NaturalComparator compare(2, 1) -> 2.cmp(1) -> Greater. + assert_eq!(comp.compare(&v2, &v1), Ordering::Greater); + + // 2. Treats None (Null) as Greater than any value. + // compare(None, Some(2)) should be Greater. + assert_eq!(comp.compare(&null, &v2), Ordering::Greater); + + // compare(Some(1), None) should be Less. + assert_eq!(comp.compare(&v1, &null), Ordering::Less); + + // compare(None, None) should be Equal. + assert_eq!(comp.compare(&null, &null), Ordering::Equal); + } + + #[test] + fn test_mixed_ownedvalue_compare() { + let u = OwnedValue::U64(10); + let i = OwnedValue::I64(10); + let f = OwnedValue::F64(10.0); + + let nc = NaturalComparator; + assert_eq!(nc.compare(&u, &i), Ordering::Equal); + assert_eq!(nc.compare(&u, &f), Ordering::Equal); + assert_eq!(nc.compare(&i, &f), Ordering::Equal); + + let u2 = OwnedValue::U64(11); + assert_eq!(nc.compare(&u2, &f), Ordering::Greater); + + let s = OwnedValue::Str("a".to_string()); + // Str < U64 + assert_eq!(nc.compare(&s, &u), Ordering::Less); + // Str < I64 + assert_eq!(nc.compare(&s, &i), Ordering::Less); + // Str < F64 + assert_eq!(nc.compare(&s, &f), Ordering::Less); + } +} diff --git a/src/collector/sort_key/sort_by_erased_type.rs b/src/collector/sort_key/sort_by_erased_type.rs new file mode 100644 index 000000000..d15dd130c --- /dev/null +++ b/src/collector/sort_key/sort_by_erased_type.rs @@ -0,0 +1,361 @@ +use columnar::{ColumnType, MonotonicallyMappableToU64}; + +use crate::collector::sort_key::{ + NaturalComparator, SortBySimilarityScore, SortByStaticFastValue, SortByString, +}; +use crate::collector::{SegmentSortKeyComputer, SortKeyComputer}; +use crate::fastfield::FastFieldNotAvailableError; +use crate::schema::OwnedValue; +use crate::{DateTime, DocId, Score}; + +/// Sort by the boxed / OwnedValue representation of either a fast field, or of the score. 
+/// +/// Using the OwnedValue representation allows for type erasure, and can be useful when sort orders +/// are not known until runtime. But it comes with a performance cost: wherever possible, prefer to +/// use a SortKeyComputer implementation with a known-type at compile time. +#[derive(Debug, Clone)] +pub enum SortByErasedType { + /// Sort by a fast field + Field(String), + /// Sort by score + Score, +} + +impl SortByErasedType { + /// Creates a new sort key computer which will sort by the given fast field column, with type + /// erasure. + pub fn for_field(column_name: impl ToString) -> Self { + Self::Field(column_name.to_string()) + } + + /// Creates a new sort key computer which will sort by score, with type erasure. + pub fn for_score() -> Self { + Self::Score + } +} + +trait ErasedSegmentSortKeyComputer: Send + Sync { + fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option; + fn convert_segment_sort_key(&self, sort_key: Option) -> OwnedValue; +} + +struct ErasedSegmentSortKeyComputerWrapper { + inner: C, + converter: F, +} + +impl ErasedSegmentSortKeyComputer for ErasedSegmentSortKeyComputerWrapper +where + C: SegmentSortKeyComputer> + Send + Sync, + F: Fn(C::SortKey) -> OwnedValue + Send + Sync + 'static, +{ + fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option { + self.inner.segment_sort_key(doc, score) + } + + fn convert_segment_sort_key(&self, sort_key: Option) -> OwnedValue { + let val = self.inner.convert_segment_sort_key(sort_key); + (self.converter)(val) + } +} + +struct ScoreSegmentSortKeyComputer { + segment_computer: SortBySimilarityScore, +} + +impl ErasedSegmentSortKeyComputer for ScoreSegmentSortKeyComputer { + fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option { + let score_value: f64 = self.segment_computer.segment_sort_key(doc, score).into(); + Some(score_value.to_u64()) + } + + fn convert_segment_sort_key(&self, sort_key: Option) -> OwnedValue { + let score_value: u64 = sort_key.expect("This 
implementation always produces a score."); + OwnedValue::F64(f64::from_u64(score_value)) + } +} + +impl SortKeyComputer for SortByErasedType { + type SortKey = OwnedValue; + type Child = ErasedColumnSegmentSortKeyComputer; + type Comparator = NaturalComparator; + + fn requires_scoring(&self) -> bool { + matches!(self, Self::Score) + } + + fn segment_sort_key_computer( + &self, + segment_reader: &crate::SegmentReader, + ) -> crate::Result { + let inner: Box = match self { + Self::Field(column_name) => { + let fast_fields = segment_reader.fast_fields(); + // TODO: We currently double-open the column to avoid relying on the implementation + // details of `SortByString` or `SortByStaticFastValue`. Once + // https://github.com/quickwit-oss/tantivy/issues/2776 is resolved, we should + // consider directly constructing the appropriate `SegmentSortKeyComputer` type for + // the column that we open here. + let (_column, column_type) = + fast_fields.u64_lenient(column_name)?.ok_or_else(|| { + FastFieldNotAvailableError { + field_name: column_name.to_owned(), + } + })?; + + match column_type { + ColumnType::Str => { + let computer = SortByString::for_field(column_name); + let inner = computer.segment_sort_key_computer(segment_reader)?; + Box::new(ErasedSegmentSortKeyComputerWrapper { + inner, + converter: |val: Option| { + val.map(OwnedValue::Str).unwrap_or(OwnedValue::Null) + }, + }) + } + ColumnType::U64 => { + let computer = SortByStaticFastValue::::for_field(column_name); + let inner = computer.segment_sort_key_computer(segment_reader)?; + Box::new(ErasedSegmentSortKeyComputerWrapper { + inner, + converter: |val: Option| { + val.map(OwnedValue::U64).unwrap_or(OwnedValue::Null) + }, + }) + } + ColumnType::I64 => { + let computer = SortByStaticFastValue::::for_field(column_name); + let inner = computer.segment_sort_key_computer(segment_reader)?; + Box::new(ErasedSegmentSortKeyComputerWrapper { + inner, + converter: |val: Option| { + 
val.map(OwnedValue::I64).unwrap_or(OwnedValue::Null) + }, + }) + } + ColumnType::F64 => { + let computer = SortByStaticFastValue::::for_field(column_name); + let inner = computer.segment_sort_key_computer(segment_reader)?; + Box::new(ErasedSegmentSortKeyComputerWrapper { + inner, + converter: |val: Option| { + val.map(OwnedValue::F64).unwrap_or(OwnedValue::Null) + }, + }) + } + ColumnType::Bool => { + let computer = SortByStaticFastValue::::for_field(column_name); + let inner = computer.segment_sort_key_computer(segment_reader)?; + Box::new(ErasedSegmentSortKeyComputerWrapper { + inner, + converter: |val: Option| { + val.map(OwnedValue::Bool).unwrap_or(OwnedValue::Null) + }, + }) + } + ColumnType::DateTime => { + let computer = SortByStaticFastValue::::for_field(column_name); + let inner = computer.segment_sort_key_computer(segment_reader)?; + Box::new(ErasedSegmentSortKeyComputerWrapper { + inner, + converter: |val: Option| { + val.map(OwnedValue::Date).unwrap_or(OwnedValue::Null) + }, + }) + } + column_type => { + return Err(crate::TantivyError::SchemaError(format!( + "Field `{}` is of type {column_type:?}, which is not supported for \ + sorting by owned value yet.", + column_name + ))) + } + } + } + Self::Score => Box::new(ScoreSegmentSortKeyComputer { + segment_computer: SortBySimilarityScore, + }), + }; + Ok(ErasedColumnSegmentSortKeyComputer { inner }) + } +} + +pub struct ErasedColumnSegmentSortKeyComputer { + inner: Box, +} + +impl SegmentSortKeyComputer for ErasedColumnSegmentSortKeyComputer { + type SortKey = OwnedValue; + type SegmentSortKey = Option; + type SegmentComparator = NaturalComparator; + + #[inline(always)] + fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option { + self.inner.segment_sort_key(doc, score) + } + + fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> OwnedValue { + self.inner.convert_segment_sort_key(segment_sort_key) + } +} + +#[cfg(test)] +mod tests { + use 
crate::collector::sort_key::{ComparatorEnum, SortByErasedType}; + use crate::collector::TopDocs; + use crate::query::AllQuery; + use crate::schema::{OwnedValue, Schema, FAST, TEXT}; + use crate::Index; + + #[test] + fn test_sort_by_owned_u64() { + let mut schema_builder = Schema::builder(); + let id_field = schema_builder.add_u64_field("id", FAST); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut writer = index.writer_for_tests().unwrap(); + writer.add_document(doc!(id_field => 10u64)).unwrap(); + writer.add_document(doc!(id_field => 2u64)).unwrap(); + writer.add_document(doc!()).unwrap(); + writer.commit().unwrap(); + + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + + let collector = TopDocs::with_limit(10) + .order_by((SortByErasedType::for_field("id"), ComparatorEnum::Natural)); + let top_docs = searcher.search(&AllQuery, &collector).unwrap(); + + let values: Vec = top_docs.into_iter().map(|(key, _)| key).collect(); + + assert_eq!( + values, + vec![OwnedValue::U64(10), OwnedValue::U64(2), OwnedValue::Null] + ); + + let collector = TopDocs::with_limit(10).order_by(( + SortByErasedType::for_field("id"), + ComparatorEnum::ReverseNoneLower, + )); + let top_docs = searcher.search(&AllQuery, &collector).unwrap(); + + let values: Vec = top_docs.into_iter().map(|(key, _)| key).collect(); + + assert_eq!( + values, + vec![OwnedValue::U64(2), OwnedValue::U64(10), OwnedValue::Null] + ); + } + + #[test] + fn test_sort_by_owned_string() { + let mut schema_builder = Schema::builder(); + let city_field = schema_builder.add_text_field("city", FAST | TEXT); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut writer = index.writer_for_tests().unwrap(); + writer.add_document(doc!(city_field => "tokyo")).unwrap(); + writer.add_document(doc!(city_field => "austin")).unwrap(); + writer.add_document(doc!()).unwrap(); + writer.commit().unwrap(); + + let reader = 
index.reader().unwrap(); + let searcher = reader.searcher(); + + let collector = TopDocs::with_limit(10).order_by(( + SortByErasedType::for_field("city"), + ComparatorEnum::ReverseNoneLower, + )); + let top_docs = searcher.search(&AllQuery, &collector).unwrap(); + + let values: Vec = top_docs.into_iter().map(|(key, _)| key).collect(); + + assert_eq!( + values, + vec![ + OwnedValue::Str("austin".to_string()), + OwnedValue::Str("tokyo".to_string()), + OwnedValue::Null + ] + ); + } + + #[test] + fn test_sort_by_owned_reverse() { + let mut schema_builder = Schema::builder(); + let id_field = schema_builder.add_u64_field("id", FAST); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut writer = index.writer_for_tests().unwrap(); + writer.add_document(doc!(id_field => 10u64)).unwrap(); + writer.add_document(doc!(id_field => 2u64)).unwrap(); + writer.add_document(doc!()).unwrap(); + writer.commit().unwrap(); + + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + + let collector = TopDocs::with_limit(10) + .order_by((SortByErasedType::for_field("id"), ComparatorEnum::Reverse)); + let top_docs = searcher.search(&AllQuery, &collector).unwrap(); + + let values: Vec = top_docs.into_iter().map(|(key, _)| key).collect(); + + assert_eq!( + values, + vec![OwnedValue::Null, OwnedValue::U64(2), OwnedValue::U64(10)] + ); + } + + #[test] + fn test_sort_by_owned_score() { + let mut schema_builder = Schema::builder(); + let body_field = schema_builder.add_text_field("body", TEXT); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut writer = index.writer_for_tests().unwrap(); + writer.add_document(doc!(body_field => "a a")).unwrap(); + writer.add_document(doc!(body_field => "a")).unwrap(); + writer.commit().unwrap(); + + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + let query_parser = crate::query::QueryParser::for_index(&index, vec![body_field]); + let 
query = query_parser.parse_query("a").unwrap(); + + // Sort by score descending (Natural) + let collector = TopDocs::with_limit(10) + .order_by((SortByErasedType::for_score(), ComparatorEnum::Natural)); + let top_docs = searcher.search(&query, &collector).unwrap(); + + let values: Vec = top_docs + .into_iter() + .map(|(key, _)| match key { + OwnedValue::F64(val) => val, + _ => panic!("Wrong type {key:?}"), + }) + .collect(); + + assert_eq!(values.len(), 2); + assert!(values[0] > values[1]); + + // Sort by score ascending (ReverseNoneLower) + let collector = TopDocs::with_limit(10).order_by(( + SortByErasedType::for_score(), + ComparatorEnum::ReverseNoneLower, + )); + let top_docs = searcher.search(&query, &collector).unwrap(); + + let values: Vec = top_docs + .into_iter() + .map(|(key, _)| match key { + OwnedValue::F64(val) => val, + _ => panic!("Wrong type {key:?}"), + }) + .collect(); + + assert_eq!(values.len(), 2); + assert!(values[0] < values[1]); + } +} diff --git a/src/collector/sort_key/sort_by_score.rs b/src/collector/sort_key/sort_by_score.rs index df8b0dd75..a23660e56 100644 --- a/src/collector/sort_key/sort_by_score.rs +++ b/src/collector/sort_key/sort_by_score.rs @@ -63,8 +63,8 @@ impl SortKeyComputer for SortBySimilarityScore { impl SegmentSortKeyComputer for SortBySimilarityScore { type SortKey = Score; - type SegmentSortKey = Score; + type SegmentComparator = NaturalComparator; #[inline(always)] fn segment_sort_key(&mut self, _doc: DocId, score: Score) -> Score { diff --git a/src/collector/sort_key/sort_by_static_fast_value.rs b/src/collector/sort_key/sort_by_static_fast_value.rs index b38b8b034..44a4e1d8d 100644 --- a/src/collector/sort_key/sort_by_static_fast_value.rs +++ b/src/collector/sort_key/sort_by_static_fast_value.rs @@ -34,9 +34,7 @@ impl SortByStaticFastValue { impl SortKeyComputer for SortByStaticFastValue { type Child = SortByFastValueSegmentSortKeyComputer; - type SortKey = Option; - type Comparator = NaturalComparator; fn 
check_schema(&self, schema: &crate::schema::Schema) -> crate::Result<()> { @@ -84,8 +82,8 @@ pub struct SortByFastValueSegmentSortKeyComputer { impl SegmentSortKeyComputer for SortByFastValueSegmentSortKeyComputer { type SortKey = Option; - type SegmentSortKey = Option; + type SegmentComparator = NaturalComparator; #[inline(always)] fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Self::SegmentSortKey { diff --git a/src/collector/sort_key/sort_by_string.rs b/src/collector/sort_key/sort_by_string.rs index 41ef22e9b..2dd0b4592 100644 --- a/src/collector/sort_key/sort_by_string.rs +++ b/src/collector/sort_key/sort_by_string.rs @@ -30,9 +30,7 @@ impl SortByString { impl SortKeyComputer for SortByString { type SortKey = Option; - type Child = ByStringColumnSegmentSortKeyComputer; - type Comparator = NaturalComparator; fn segment_sort_key_computer( @@ -50,8 +48,8 @@ pub struct ByStringColumnSegmentSortKeyComputer { impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer { type SortKey = Option; - type SegmentSortKey = Option; + type SegmentComparator = NaturalComparator; #[inline(always)] fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Option { @@ -60,6 +58,8 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer { } fn convert_segment_sort_key(&self, term_ord_opt: Option) -> Option { + // TODO: Individual lookups to the dictionary like this are very likely to repeatedly + // decompress the same blocks. 
See https://github.com/quickwit-oss/tantivy/issues/2776 let term_ord = term_ord_opt?; let str_column = self.str_column_opt.as_ref()?; let mut bytes = Vec::new(); diff --git a/src/collector/sort_key/sort_key_computer.rs b/src/collector/sort_key/sort_key_computer.rs index d56fa7cd0..6aab919a9 100644 --- a/src/collector/sort_key/sort_key_computer.rs +++ b/src/collector/sort_key/sort_key_computer.rs @@ -12,13 +12,21 @@ use crate::{DocAddress, DocId, Result, Score, SegmentReader}; /// It is the segment local version of the [`SortKeyComputer`]. pub trait SegmentSortKeyComputer: 'static { /// The final score being emitted. - type SortKey: 'static + PartialOrd + Send + Sync + Clone; + type SortKey: 'static + Send + Sync + Clone; /// Sort key used by at the segment level by the `SegmentSortKeyComputer`. /// /// It is typically small like a `u64`, and is meant to be converted /// to the final score at the end of the collection of the segment. - type SegmentSortKey: 'static + PartialOrd + Clone + Send + Sync + Clone; + type SegmentSortKey: 'static + Clone + Send + Sync + Clone; + + /// Comparator type. + type SegmentComparator: Comparator + 'static; + + /// Returns the segment sort key comparator. + fn segment_comparator(&self) -> Self::SegmentComparator { + Self::SegmentComparator::default() + } /// Computes the sort key for the given document and score. fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey; @@ -47,7 +55,7 @@ pub trait SegmentSortKeyComputer: 'static { left: &Self::SegmentSortKey, right: &Self::SegmentSortKey, ) -> Ordering { - NaturalComparator.compare(left, right) + self.segment_comparator().compare(left, right) } /// Implementing this method makes it possible to avoid computing @@ -81,7 +89,7 @@ pub trait SegmentSortKeyComputer: 'static { /// the sort key at a segment scale. pub trait SortKeyComputer: Sync { /// The sort key type. 
- type SortKey: 'static + Send + Sync + PartialOrd + Clone + std::fmt::Debug; + type SortKey: 'static + Send + Sync + Clone + std::fmt::Debug; /// Type of the associated [`SegmentSortKeyComputer`]. type Child: SegmentSortKeyComputer; /// Comparator type. @@ -136,10 +144,7 @@ where HeadSortKeyComputer: SortKeyComputer, TailSortKeyComputer: SortKeyComputer, { - type SortKey = ( - ::SortKey, - ::SortKey, - ); + type SortKey = (HeadSortKeyComputer::SortKey, TailSortKeyComputer::SortKey); type Child = (HeadSortKeyComputer::Child, TailSortKeyComputer::Child); type Comparator = ( @@ -188,6 +193,11 @@ where TailSegmentSortKeyComputer::SegmentSortKey, ); + type SegmentComparator = ( + HeadSegmentSortKeyComputer::SegmentComparator, + TailSegmentSortKeyComputer::SegmentComparator, + ); + /// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on /// its ordering. /// @@ -269,11 +279,12 @@ impl SegmentSortKeyComputer for MappedSegmentSortKeyComputer where T: SegmentSortKeyComputer, - PreviousScore: 'static + Clone + Send + Sync + PartialOrd, - NewScore: 'static + Clone + Send + Sync + PartialOrd, + PreviousScore: 'static + Clone + Send + Sync, + NewScore: 'static + Clone + Send + Sync, { type SortKey = NewScore; type SegmentSortKey = T::SegmentSortKey; + type SegmentComparator = T::SegmentComparator; fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey { self.sort_key_computer.segment_sort_key(doc, score) @@ -463,6 +474,7 @@ where { type SortKey = TSortKey; type SegmentSortKey = TSortKey; + type SegmentComparator = NaturalComparator; fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> TSortKey { (self)(doc) diff --git a/src/collector/sort_key_top_collector.rs b/src/collector/sort_key_top_collector.rs index 3ca27fc75..9ca47581b 100644 --- a/src/collector/sort_key_top_collector.rs +++ b/src/collector/sort_key_top_collector.rs @@ -160,7 +160,7 @@ mod tests { expected: &[(crate::Score, usize)], ) { let mut vals: 
Vec<(crate::Score, usize)> = (0..10).map(|val| (val as f32, val)).collect(); - vals.shuffle(&mut rand::thread_rng()); + vals.shuffle(&mut rand::rng()); let vals_merged = merge_top_k(vals.into_iter(), doc_range, ComparatorEnum::from(order)); assert_eq!(&vals_merged, expected); } diff --git a/src/collector/top_collector.rs b/src/collector/top_collector.rs index 6981c86c9..1990b3837 100644 --- a/src/collector/top_collector.rs +++ b/src/collector/top_collector.rs @@ -1,64 +1,22 @@ -use std::cmp::Ordering; - use serde::{Deserialize, Serialize}; /// Contains a feature (field, score, etc.) of a document along with the document address. /// -/// It guarantees stable sorting: in case of a tie on the feature, the document -/// address is used. -/// -/// The REVERSE_ORDER generic parameter controls whether the by-feature order -/// should be reversed, which is useful for achieving for example largest-first -/// semantics without having to wrap the feature in a `Reverse`. -#[derive(Clone, Default, Serialize, Deserialize)] -pub struct ComparableDoc { +/// Used only by TopNComputer, which implements the actual comparison via a `Comparator`. +#[derive(Clone, Default, Eq, PartialEq, Serialize, Deserialize)] +pub struct ComparableDoc { /// The feature of the document. In practice, this is - /// is any type that implements `PartialOrd`. + /// is a type which can be compared with a `Comparator`. pub sort_key: T, - /// The document address. In practice, this is any - /// type that implements `PartialOrd`, and is guaranteed - /// to be unique for each document. + /// The document address. In practice, this is either a `DocId` or `DocAddress`. 
pub doc: D, } -impl std::fmt::Debug - for ComparableDoc -{ + +impl std::fmt::Debug for ComparableDoc { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct(format!("ComparableDoc<_, _ {R}").as_str()) + f.debug_struct("ComparableDoc") .field("feature", &self.sort_key) .field("doc", &self.doc) .finish() } } - -impl PartialOrd for ComparableDoc { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for ComparableDoc { - #[inline] - fn cmp(&self, other: &Self) -> Ordering { - let by_feature = self - .sort_key - .partial_cmp(&other.sort_key) - .map(|ord| if R { ord.reverse() } else { ord }) - .unwrap_or(Ordering::Equal); - - let lazy_by_doc_address = || self.doc.partial_cmp(&other.doc).unwrap_or(Ordering::Equal); - - // In case of a tie on the feature, we sort by ascending - // `DocAddress` in order to ensure a stable sorting of the - // documents. - by_feature.then_with(lazy_by_doc_address) - } -} - -impl PartialEq for ComparableDoc { - fn eq(&self, other: &Self) -> bool { - self.cmp(other) == Ordering::Equal - } -} - -impl Eq for ComparableDoc {} diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 78c344dbe..0ce1c611a 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -23,10 +23,9 @@ use crate::{DocAddress, DocId, Order, Score, SegmentReader}; /// The theoretical complexity for collecting the top `K` out of `N` documents /// is `O(N + K)`. /// -/// This collector does not guarantee a stable sorting in case of a tie on the -/// document score, for stable sorting `PartialOrd` needs to resolve on other fields -/// like docid in case of score equality. -/// Only then, it is suitable for pagination. +/// This collector guarantees a stable sorting in case of a tie on the +/// document score/sort key: The document address (`DocAddress`) is used as a tie breaker. 
+/// In case of a tie on the sort key, documents are always sorted by ascending `DocAddress`. /// /// ```rust /// use tantivy::collector::TopDocs; @@ -325,7 +324,7 @@ impl TopDocs { sort_key_computer: impl SortKeyComputer + Send + 'static, ) -> impl Collector> where - TSortKey: 'static + Clone + Send + Sync + PartialOrd + std::fmt::Debug, + TSortKey: 'static + Clone + Send + Sync + std::fmt::Debug, { TopBySortKeyCollector::new(sort_key_computer, self.doc_range()) } @@ -446,7 +445,7 @@ where F: 'static + Send + Sync + Fn(&SegmentReader) -> TTweakScoreSortKeyFn, TTweakScoreSortKeyFn: 'static + Fn(DocId, Score) -> TSortKey, TweakScoreSegmentSortKeyComputer: - SegmentSortKeyComputer, + SegmentSortKeyComputer, TSortKey: 'static + PartialOrd + Clone + Send + Sync + std::fmt::Debug, { type SortKey = TSortKey; @@ -481,6 +480,7 @@ where { type SortKey = TSortKey; type SegmentSortKey = TSortKey; + type SegmentComparator = NaturalComparator; fn segment_sort_key(&mut self, doc: DocId, score: Score) -> TSortKey { (self.sort_key_fn)(doc, score) @@ -500,8 +500,13 @@ where /// /// For TopN == 0, it will be relative expensive. /// -/// When using the natural comparator, the top N computer returns the top N elements in -/// descending order, as expected for a top N. +/// The TopNComputer will tiebreak by using ascending `D` (DocId or DocAddress): +/// i.e., in case of a tie on the sort key, the `DocId|DocAddress` are always sorted in +/// ascending order, regardless of the `Comparator` used for the `Score` type. +/// +/// NOTE: Items must be `push`ed to the TopNComputer in ascending `DocId|DocAddress` order, as the +/// threshold used to eliminate docs does not include the `DocId` or `DocAddress`: this provides +/// the ascending `DocId|DocAddress` tie-breaking behavior without additional comparisons. 
#[derive(Serialize, Deserialize)] #[serde(from = "TopNComputerDeser")] pub struct TopNComputer { @@ -580,6 +585,18 @@ where } } +#[inline(always)] +fn compare_for_top_k>( + c: &C, + lhs: &ComparableDoc, + rhs: &ComparableDoc, +) -> std::cmp::Ordering { + c.compare(&lhs.sort_key, &rhs.sort_key) + .reverse() // Reverse here because we want top K. + .then_with(|| lhs.doc.cmp(&rhs.doc)) // Regardless of asc/desc, in presence of a tie, we + // sort by doc id +} + impl TopNComputer where D: Ord, @@ -600,10 +617,13 @@ where /// Push a new document to the top n. /// If the document is below the current threshold, it will be ignored. + /// + /// NOTE: `push` must be called in ascending `DocId`/`DocAddress` order. #[inline] pub fn push(&mut self, sort_key: TSortKey, doc: D) { if let Some(last_median) = &self.threshold { - if self.comparator.compare(&sort_key, last_median) == Ordering::Less { + // See the struct docs for an explanation of why this comparison is strict. + if self.comparator.compare(&sort_key, last_median) != Ordering::Greater { return; } } @@ -629,9 +649,7 @@ where fn truncate_top_n(&mut self) -> TSortKey { // Use select_nth_unstable to find the top nth score let (_, median_el, _) = self.buffer.select_nth_unstable_by(self.top_n, |lhs, rhs| { - self.comparator - .compare(&rhs.sort_key, &lhs.sort_key) - .then_with(|| lhs.doc.cmp(&rhs.doc)) + compare_for_top_k(&self.comparator, lhs, rhs) }); let median_score = median_el.sort_key.clone(); @@ -646,11 +664,8 @@ where if self.buffer.len() > self.top_n { self.truncate_top_n(); } - self.buffer.sort_unstable_by(|left, right| { - self.comparator - .compare(&right.sort_key, &left.sort_key) - .then_with(|| left.doc.cmp(&right.doc)) - }); + self.buffer + .sort_unstable_by(|lhs, rhs| compare_for_top_k(&self.comparator, lhs, rhs)); self.buffer } @@ -755,6 +770,33 @@ mod tests { ); } + #[test] + fn test_topn_computer_duplicates() { + let mut computer: TopNComputer = + TopNComputer::new_with_comparator(2, NaturalComparator); + 
+ computer.push(1u32, 1u32); + computer.push(1u32, 2u32); + computer.push(1u32, 3u32); + computer.push(1u32, 4u32); + computer.push(1u32, 5u32); + + // In the presence of duplicates, DocIds are always ascending order. + assert_eq!( + computer.into_sorted_vec(), + &[ + ComparableDoc { + sort_key: 1u32, + doc: 1u32, + }, + ComparableDoc { + sort_key: 1u32, + doc: 2u32, + } + ] + ); + } + #[test] fn test_topn_computer_no_panic() { for top_n in 0..10 { @@ -772,14 +814,17 @@ mod tests { #[test] fn test_topn_computer_asc_prop( limit in 0..10_usize, - docs in proptest::collection::vec((0..100_u64, 0..100_u64), 0..100_usize), + mut docs in proptest::collection::vec((0..100_u64, 0..100_u64), 0..100_usize), ) { + // NB: TopNComputer must receive inputs in ascending DocId order. + docs.sort_by_key(|(_, doc_id)| *doc_id); let mut computer: TopNComputer<_, _, ReverseComparator> = TopNComputer::new_with_comparator(limit, ReverseComparator); for (feature, doc) in &docs { computer.push(*feature, *doc); } - let mut comparable_docs: Vec> = docs.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc }).collect::>(); - comparable_docs.sort(); + let mut comparable_docs: Vec> = + docs.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc }).collect(); + crate::collector::sort_key::tests::sort_hits(&mut comparable_docs, Order::Asc); comparable_docs.truncate(limit); prop_assert_eq!( computer.into_sorted_vec(), @@ -1406,15 +1451,10 @@ mod tests { // Using the TopDocs collector should always be equivalent to sorting, skipping the // offset, and then taking the limit. 
- let sorted_docs: Vec<_> = if order.is_desc() { - let mut comparable_docs: Vec> = + let sorted_docs: Vec<_> = { + let mut comparable_docs: Vec> = all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect(); - comparable_docs.sort(); - comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect() - } else { - let mut comparable_docs: Vec> = - all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect(); - comparable_docs.sort(); + crate::collector::sort_key::tests::sort_hits(&mut comparable_docs, order); comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect() }; let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::>(); diff --git a/src/core/executor.rs b/src/core/executor.rs index 8cc7e0026..f11644599 100644 --- a/src/core/executor.rs +++ b/src/core/executor.rs @@ -48,7 +48,15 @@ impl Executor { F: Sized + Sync + Fn(A) -> crate::Result, { match self { - Executor::SingleThread => args.map(f).collect::>(), + Executor::SingleThread => { + // Avoid `collect`, since the stacktrace is blown up by it, which makes profiling + // harder. 
+ let mut result = Vec::with_capacity(args.size_hint().0); + for arg in args { + result.push(f(arg)?); + } + Ok(result) + } Executor::ThreadPool(pool) => { let args: Vec = args.collect(); let num_fruits = args.len(); diff --git a/src/directory/file_watcher.rs b/src/directory/mmap_directory/file_watcher.rs similarity index 100% rename from src/directory/file_watcher.rs rename to src/directory/mmap_directory/file_watcher.rs diff --git a/src/directory/mmap_directory.rs b/src/directory/mmap_directory/mod.rs similarity index 99% rename from src/directory/mmap_directory.rs rename to src/directory/mmap_directory/mod.rs index f4785ef72..60ef82b30 100644 --- a/src/directory/mmap_directory.rs +++ b/src/directory/mmap_directory/mod.rs @@ -1,3 +1,5 @@ +mod file_watcher; + use std::collections::HashMap; use std::fmt; use std::fs::{self, File, OpenOptions}; @@ -7,6 +9,7 @@ use std::path::{Path, PathBuf}; use std::sync::{Arc, RwLock, Weak}; use common::StableDeref; +use file_watcher::FileWatcher; use fs4::fs_std::FileExt; #[cfg(all(feature = "mmap", unix))] pub use memmap2::Advice; @@ -18,7 +21,6 @@ use crate::core::META_FILEPATH; use crate::directory::error::{ DeleteError, LockError, OpenDirectoryError, OpenReadError, OpenWriteError, }; -use crate::directory::file_watcher::FileWatcher; use crate::directory::{ AntiCallToken, Directory, DirectoryLock, FileHandle, Lock, OwnedBytes, TerminatingWrite, WatchCallback, WatchHandle, WritePtr, diff --git a/src/directory/mod.rs b/src/directory/mod.rs index 7fab7e051..d4494d307 100644 --- a/src/directory/mod.rs +++ b/src/directory/mod.rs @@ -5,7 +5,6 @@ mod mmap_directory; mod directory; mod directory_lock; -mod file_watcher; pub mod footer; mod managed_directory; mod ram_directory; diff --git a/src/docset.rs b/src/docset.rs index 7de138da6..01ea1125a 100644 --- a/src/docset.rs +++ b/src/docset.rs @@ -40,6 +40,8 @@ pub trait DocSet: Send { /// of `DocSet` should support it. 
/// /// Calling `seek(TERMINATED)` is also legal and is the normal way to consume a `DocSet`. + /// + /// `target` has to be larger or equal to `.doc()` when calling `seek`. fn seek(&mut self, target: DocId) -> DocId { let mut doc = self.doc(); debug_assert!(doc <= target); @@ -49,6 +51,33 @@ pub trait DocSet: Send { doc } + /// Seeks to the target if possible and returns true if the target is in the DocSet. + /// + /// DocSets that already have an efficient `seek` method don't need to implement + /// `seek_into_the_danger_zone`. All wrapper DocSets should forward + /// `seek_into_the_danger_zone` to the underlying DocSet. + /// + /// ## API Behaviour + /// If `seek_into_the_danger_zone` is returning true, a call to `doc()` has to return target. + /// If `seek_into_the_danger_zone` is returning false, a call to `doc()` may return any doc + /// between the last doc that matched and target or a doc that is a valid next hit after + /// target. The DocSet is considered to be in an invalid state until + /// `seek_into_the_danger_zone` returns true again. + /// + /// `target` needs to be equal or larger than `doc` when in a valid state. + /// + /// Consecutive calls are not allowed to have decreasing `target` values. + /// + /// # Warning + /// This is an advanced API used by intersection. The API contract is tricky, avoid using it. + fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { + let current_doc = self.doc(); + if current_doc < target { + self.seek(target); + } + self.doc() == target + } + /// Fills a given mutable buffer with the next doc ids from the /// `DocSet` /// @@ -94,6 +123,15 @@ pub trait DocSet: Send { /// which would be the number of documents in the DocSet. /// /// By default this returns `size_hint()`. + /// + /// DocSets may have vastly different cost depending on their type, + /// e.g. an intersection with 10 hits is much cheaper than + /// a phrase search with 10 hits, since it needs to load positions. 
+ /// + /// ### Future Work + /// We may want to differentiate `DocSet` costs more more granular, e.g. + /// creation_cost, advance_cost, seek_cost on to get a good estimation + /// what query types to choose. fn cost(&self) -> u64 { self.size_hint() as u64 } @@ -137,6 +175,10 @@ impl DocSet for &mut dyn DocSet { (**self).seek(target) } + fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { + (**self).seek_into_the_danger_zone(target) + } + fn doc(&self) -> u32 { (**self).doc() } @@ -169,6 +211,11 @@ impl DocSet for Box { unboxed.seek(target) } + fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { + let unboxed: &mut TDocSet = self.borrow_mut(); + unboxed.seek_into_the_danger_zone(target) + } + fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize { let unboxed: &mut TDocSet = self.borrow_mut(); unboxed.fill_buffer(buffer) diff --git a/src/fastfield/alive_bitset.rs b/src/fastfield/alive_bitset.rs index 11d7463c7..bbdc82a45 100644 --- a/src/fastfield/alive_bitset.rs +++ b/src/fastfield/alive_bitset.rs @@ -162,7 +162,7 @@ mod tests { mod bench { use rand::prelude::IteratorRandom; - use rand::thread_rng; + use rand::rng; use test::Bencher; use super::AliveBitSet; @@ -176,7 +176,7 @@ mod bench { } fn remove_rand(raw: &mut Vec) { - let i = (0..raw.len()).choose(&mut thread_rng()).unwrap(); + let i = (0..raw.len()).choose(&mut rng()).unwrap(); raw.remove(i); } diff --git a/src/fastfield/mod.rs b/src/fastfield/mod.rs index 726b9b76a..aca53c212 100644 --- a/src/fastfield/mod.rs +++ b/src/fastfield/mod.rs @@ -879,7 +879,7 @@ mod tests { const ONE_HOUR_IN_MICROSECS: i64 = 3_600 * 1_000_000; let times: Vec = std::iter::repeat_with(|| { // +- One hour. 
- let t = T0 + rng.gen_range(-ONE_HOUR_IN_MICROSECS..ONE_HOUR_IN_MICROSECS); + let t = T0 + rng.random_range(-ONE_HOUR_IN_MICROSECS..ONE_HOUR_IN_MICROSECS); DateTime::from_timestamp_micros(t) }) .take(1_000) diff --git a/src/functional_test.rs b/src/functional_test.rs index 1548d8096..9606bb7a7 100644 --- a/src/functional_test.rs +++ b/src/functional_test.rs @@ -1,6 +1,6 @@ use std::collections::HashSet; -use rand::{thread_rng, Rng}; +use rand::{rng, Rng}; use crate::indexer::index_writer::MEMORY_BUDGET_NUM_BYTES_MIN; use crate::schema::*; @@ -29,7 +29,7 @@ fn test_functional_store() -> crate::Result<()> { let index = Index::create_in_ram(schema); let reader = index.reader()?; - let mut rng = thread_rng(); + let mut rng = rng(); let mut index_writer: IndexWriter = index.writer_with_num_threads(3, 3 * MEMORY_BUDGET_NUM_BYTES_MIN)?; @@ -38,9 +38,9 @@ fn test_functional_store() -> crate::Result<()> { let mut doc_id = 0u64; for _iteration in 0..get_num_iterations() { - let num_docs: usize = rng.gen_range(0..4); + let num_docs: usize = rng.random_range(0..4); if !doc_set.is_empty() { - let doc_to_remove_id = rng.gen_range(0..doc_set.len()); + let doc_to_remove_id = rng.random_range(0..doc_set.len()); let removed_doc_id = doc_set.swap_remove(doc_to_remove_id); index_writer.delete_term(Term::from_field_u64(id_field, removed_doc_id)); } @@ -70,10 +70,10 @@ const LOREM: &str = "Doc Lorem ipsum dolor sit amet, consectetur adipiscing elit cillum dolore eu fugiat nulla pariatur. 
Excepteur sint occaecat cupidatat \ non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."; fn get_text() -> String { - use rand::seq::SliceRandom; - let mut rng = thread_rng(); + use rand::seq::IndexedRandom; + let mut rng = rng(); let tokens: Vec<_> = LOREM.split(' ').collect(); - let random_val = rng.gen_range(0..20); + let random_val = rng.random_range(0..20); (0..random_val) .map(|_| tokens.choose(&mut rng).unwrap()) @@ -101,7 +101,7 @@ fn test_functional_indexing_unsorted() -> crate::Result<()> { let index = Index::create_from_tempdir(schema)?; let reader = index.reader()?; - let mut rng = thread_rng(); + let mut rng = rng(); let mut index_writer: IndexWriter = index.writer_with_num_threads(3, 3 * MEMORY_BUDGET_NUM_BYTES_MIN)?; @@ -110,7 +110,7 @@ fn test_functional_indexing_unsorted() -> crate::Result<()> { let mut uncommitted_docs: HashSet = HashSet::new(); for _ in 0..get_num_iterations() { - let random_val = rng.gen_range(0..20); + let random_val = rng.random_range(0..20); if random_val == 0 { index_writer.commit()?; committed_docs.extend(&uncommitted_docs); diff --git a/src/index/index_meta.rs b/src/index/index_meta.rs index 86eaa35d6..d06d706c4 100644 --- a/src/index/index_meta.rs +++ b/src/index/index_meta.rs @@ -13,9 +13,9 @@ use crate::store::Compressor; use crate::{Inventory, Opstamp, TrackedObject}; #[derive(Clone, Debug, Serialize, Deserialize)] -struct DeleteMeta { +pub struct DeleteMeta { num_deleted_docs: u32, - opstamp: Opstamp, + pub opstamp: Opstamp, } #[derive(Clone, Default)] @@ -213,7 +213,7 @@ impl SegmentMeta { struct InnerSegmentMeta { segment_id: SegmentId, max_doc: u32, - deletes: Option, + pub deletes: Option, /// If you want to avoid the SegmentComponent::TempStore file to be covered by /// garbage collection and deleted, set this to true. This is used during merge. 
#[serde(skip)] @@ -404,7 +404,10 @@ mod tests { schema_builder.build() }; let index_metas = IndexMeta { - index_settings: IndexSettings::default(), + index_settings: IndexSettings { + docstore_compression: Compressor::None, + ..Default::default() + }, segments: Vec::new(), schema, opstamp: 0u64, @@ -413,7 +416,7 @@ mod tests { let json = serde_json::ser::to_string(&index_metas).expect("serialization failed"); assert_eq!( json, - r#"{"index_settings":{"docstore_compression":"lz4","docstore_blocksize":16384},"segments":[],"schema":[{"name":"text","type":"text","options":{"indexing":{"record":"position","fieldnorms":true,"tokenizer":"default"},"stored":false,"fast":false}}],"opstamp":0}"# + r#"{"index_settings":{"docstore_compression":"none","docstore_blocksize":16384},"segments":[],"schema":[{"name":"text","type":"text","options":{"indexing":{"record":"position","fieldnorms":true,"tokenizer":"default"},"stored":false,"fast":false}}],"opstamp":0}"# ); let deser_meta: UntrackedIndexMeta = serde_json::from_str(&json).unwrap(); @@ -494,6 +497,8 @@ mod tests { #[test] #[cfg(feature = "lz4-compression")] fn test_index_settings_default() { + use crate::store::Compressor; + let mut index_settings = IndexSettings::default(); assert_eq!( index_settings, diff --git a/src/index/segment.rs b/src/index/segment.rs index 4c9382cb0..fcd32a1ff 100644 --- a/src/index/segment.rs +++ b/src/index/segment.rs @@ -46,7 +46,7 @@ impl Segment { /// /// This method is only used when updating `max_doc` from 0 /// as we finalize a fresh new segment. 
- pub(crate) fn with_max_doc(self, max_doc: u32) -> Segment { + pub fn with_max_doc(self, max_doc: u32) -> Segment { Segment { index: self.index, meta: self.meta.with_max_doc(max_doc), diff --git a/src/indexer/delete_queue.rs b/src/indexer/delete_queue.rs index 3aa9f0d85..1a269caed 100644 --- a/src/indexer/delete_queue.rs +++ b/src/indexer/delete_queue.rs @@ -4,38 +4,37 @@ use std::sync::{Arc, RwLock, Weak}; use super::operation::DeleteOperation; use crate::Opstamp; -// The DeleteQueue is similar in conceptually to a multiple -// consumer single producer broadcast channel. -// -// All consumer will receive all messages. -// -// Consumer of the delete queue are holding a `DeleteCursor`, -// which points to a specific place of the `DeleteQueue`. -// -// New consumer can be created in two ways -// - calling `delete_queue.cursor()` returns a cursor, that will include all future delete operation -// (and some or none of the past operations... The client is in charge of checking the opstamps.). -// - cloning an existing cursor returns a new cursor, that is at the exact same position, and can -// now advance independently from the original cursor. +/// The DeleteQueue is similar in conceptually to a multiple +/// consumer single producer broadcast channel. +/// +/// All consumer will receive all messages. +/// +/// Consumer of the delete queue are holding a `DeleteCursor`, +/// which points to a specific place of the `DeleteQueue`. +/// +/// New consumer can be created in two ways +/// - calling `delete_queue.cursor()` returns a cursor, that will include all future delete +/// operation (and some or none of the past operations... The client is in charge of checking the +/// opstamps.). +/// - cloning an existing cursor returns a new cursor, that is at the exact same position, and can +/// now advance independently from the original cursor. 
#[derive(Default)] struct InnerDeleteQueue { writer: Vec, last_block: Weak, } -#[derive(Clone)] +/// The delete queue is a linked list storing delete operations. +/// +/// Several consumers can hold a reference to it. Delete operations +/// get dropped/gc'ed when no more consumers are holding a reference +/// to them. +#[derive(Clone, Default)] pub struct DeleteQueue { inner: Arc>, } impl DeleteQueue { - // Creates a new delete queue. - pub fn new() -> DeleteQueue { - DeleteQueue { - inner: Arc::default(), - } - } - fn get_last_block(&self) -> Arc { { // try get the last block with simply acquiring the read lock. @@ -58,10 +57,10 @@ impl DeleteQueue { block } - // Creates a new cursor that makes it possible to - // consume future delete operations. - // - // Past delete operations are not accessible. + /// Creates a new cursor that makes it possible to + /// consume future delete operations. + /// + /// Past delete operations are not accessible. pub fn cursor(&self) -> DeleteCursor { let last_block = self.get_last_block(); let operations_len = last_block.operations.len(); @@ -71,7 +70,7 @@ impl DeleteQueue { } } - // Appends a new delete operations. + /// Appends a new delete operations. pub fn push(&self, delete_operation: DeleteOperation) { self.inner .write() @@ -169,6 +168,7 @@ struct Block { next: NextBlock, } +/// As we process delete operations, keeps track of our position. #[derive(Clone)] pub struct DeleteCursor { block: Arc, @@ -261,7 +261,7 @@ mod tests { #[test] fn test_deletequeue() { - let delete_queue = DeleteQueue::new(); + let delete_queue = DeleteQueue::default(); let make_op = |i: usize| DeleteOperation { opstamp: i as u64, diff --git a/src/indexer/index_writer.rs b/src/indexer/index_writer.rs index 1ba92d6de..1e07dd210 100644 --- a/src/indexer/index_writer.rs +++ b/src/indexer/index_writer.rs @@ -128,7 +128,7 @@ fn compute_deleted_bitset( /// is `==` target_opstamp. 
/// For instance, there was no delete operation between the state of the `segment_entry` and /// the `target_opstamp`, `segment_entry` is not updated. -pub(crate) fn advance_deletes( +pub fn advance_deletes( mut segment: Segment, segment_entry: &mut SegmentEntry, target_opstamp: Opstamp, @@ -303,7 +303,7 @@ impl IndexWriter { let (document_sender, document_receiver) = crossbeam_channel::bounded(PIPELINE_MAX_SIZE_IN_DOCS); - let delete_queue = DeleteQueue::new(); + let delete_queue = DeleteQueue::default(); let current_opstamp = index.load_metas()?.opstamp; diff --git a/src/indexer/mod.rs b/src/indexer/mod.rs index 2d86aa461..53cc57034 100644 --- a/src/indexer/mod.rs +++ b/src/indexer/mod.rs @@ -4,6 +4,7 @@ //! `IndexWriter` is the main entry point for that, which created from //! [`Index::writer`](crate::Index::writer). +/// Delete queue implementation for broadcasting delete operations to consumers. pub(crate) mod delete_queue; pub(crate) mod path_to_unordered_id; @@ -32,12 +33,11 @@ mod stamper; use crossbeam_channel as channel; use smallvec::SmallVec; -pub use self::index_writer::{IndexWriter, IndexWriterOptions}; +pub use self::index_writer::{advance_deletes, IndexWriter, IndexWriterOptions}; pub use self::log_merge_policy::LogMergePolicy; pub use self::merge_operation::MergeOperation; pub use self::merge_policy::{MergeCandidate, MergePolicy, NoMergePolicy}; -use self::operation::AddOperation; -pub use self::operation::UserOperation; +pub use self::operation::{AddOperation, DeleteOperation, UserOperation}; pub use self::prepared_commit::PreparedCommit; pub use self::segment_entry::SegmentEntry; pub(crate) use self::segment_serializer::SegmentSerializer; diff --git a/src/indexer/operation.rs b/src/indexer/operation.rs index 69bffec17..9316f6fa7 100644 --- a/src/indexer/operation.rs +++ b/src/indexer/operation.rs @@ -5,14 +5,20 @@ use crate::Opstamp; /// Timestamped Delete operation. pub struct DeleteOperation { + /// Operation stamp. 
+ /// It is used to check whether the delete operation + /// applies to an added document operation. pub opstamp: Opstamp, + /// Weight is used to define the set of documents to be deleted. pub target: Box, } /// Timestamped Add operation. #[derive(Eq, PartialEq, Debug)] pub struct AddOperation { + /// Operation stamp. pub opstamp: Opstamp, + /// Document to be added. pub document: D, } diff --git a/src/indexer/segment_register.rs b/src/indexer/segment_register.rs index 0e7046310..fa7bfafa4 100644 --- a/src/indexer/segment_register.rs +++ b/src/indexer/segment_register.rs @@ -117,7 +117,7 @@ mod tests { #[test] fn test_segment_register() { let inventory = SegmentMetaInventory::default(); - let delete_queue = DeleteQueue::new(); + let delete_queue = DeleteQueue::default(); let mut segment_register = SegmentRegister::default(); let segment_id_a = SegmentId::generate_random(); diff --git a/src/indexer/segment_writer.rs b/src/indexer/segment_writer.rs index 72152cffa..94e3f0de2 100644 --- a/src/indexer/segment_writer.rs +++ b/src/indexer/segment_writer.rs @@ -421,10 +421,9 @@ fn remap_and_write( #[cfg(test)] mod tests { use std::collections::BTreeMap; - use std::path::{Path, PathBuf}; + use std::path::Path; use columnar::ColumnType; - use tempfile::TempDir; use crate::collector::{Count, TopDocs}; use crate::directory::RamDirectory; @@ -1067,10 +1066,7 @@ mod tests { let mut schema_builder = Schema::builder(); schema_builder.add_text_field("title", text_options); let schema = schema_builder.build(); - let tempdir = TempDir::new().unwrap(); - let tempdir_path = PathBuf::from(tempdir.path()); - Index::create_in_dir(&tempdir_path, schema).unwrap(); - let index = Index::open_in_dir(tempdir_path).unwrap(); + let index = Index::create_in_ram(schema); let schema = index.schema(); let mut index_writer = index.writer(50_000_000).unwrap(); let title = schema.get_field("title").unwrap(); diff --git a/src/lib.rs b/src/lib.rs index 22eab343a..2747fe8ef 100644 --- a/src/lib.rs +++ 
b/src/lib.rs @@ -17,6 +17,7 @@ //! //! ```rust //! # use std::path::Path; +//! # use std::fs; //! # use tempfile::TempDir; //! # use tantivy::collector::TopDocs; //! # use tantivy::query::QueryParser; @@ -27,8 +28,11 @@ //! # // Let's create a temporary directory for the //! # // sake of this example //! # if let Ok(dir) = TempDir::new() { -//! # run_example(dir.path()).unwrap(); -//! # dir.close().unwrap(); +//! # let index_path = dir.path().join("index"); +//! # // In case the directory already exists, we remove it +//! # let _ = fs::remove_dir_all(&index_path); +//! # fs::create_dir_all(&index_path).unwrap(); +//! # run_example(&index_path).unwrap(); //! # } //! # } //! # @@ -203,6 +207,7 @@ mod docset; mod reader; #[cfg(test)] +#[cfg(feature = "mmap")] mod compat_tests; pub use self::reader::{IndexReader, IndexReaderBuilder, ReloadPolicy, Warmer}; @@ -372,7 +377,7 @@ pub mod tests { use common::{BinarySerializable, FixedSize}; use query_grammar::{UserInputAst, UserInputLeaf, UserInputLiteral}; - use rand::distributions::{Bernoulli, Uniform}; + use rand::distr::{Bernoulli, Uniform}; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use time::OffsetDateTime; @@ -423,7 +428,7 @@ pub mod tests { pub fn generate_nonunique_unsorted(max_value: u32, n_elems: usize) -> Vec { let seed: [u8; 32] = [1; 32]; StdRng::from_seed(seed) - .sample_iter(&Uniform::new(0u32, max_value)) + .sample_iter(&Uniform::new(0u32, max_value).unwrap()) .take(n_elems) .collect::>() } @@ -1170,12 +1175,11 @@ pub mod tests { #[test] fn test_validate_checksum() -> crate::Result<()> { - let index_path = tempfile::tempdir().expect("dir"); let mut builder = Schema::builder(); let body = builder.add_text_field("body", TEXT | STORED); let schema = builder.build(); - let index = Index::create_in_dir(&index_path, schema)?; - let mut writer: IndexWriter = index.writer(50_000_000)?; + let index = Index::create_in_ram(schema); + let mut writer: IndexWriter = index.writer_for_tests()?; 
writer.set_merge_policy(Box::new(NoMergePolicy)); for _ in 0..5000 { writer.add_document(doc!(body => "foo"))?; diff --git a/src/postings/compression/mod.rs b/src/postings/compression/mod.rs index 487da620c..0ddf7e3df 100644 --- a/src/postings/compression/mod.rs +++ b/src/postings/compression/mod.rs @@ -1,12 +1,15 @@ use bitpacking::{BitPacker, BitPacker4x}; -use common::FixedSize; pub const COMPRESSION_BLOCK_SIZE: usize = BitPacker4x::BLOCK_LEN; -const COMPRESSED_BLOCK_MAX_SIZE: usize = COMPRESSION_BLOCK_SIZE * u32::SIZE_IN_BYTES; +// in vint encoding, each byte stores 7 bits of data, so we need at most 32 / 7 = 4.57 bytes to +// store a u32 in the worst case, rounding up to 5 bytes total +const MAX_VINT_SIZE: usize = 5; +const COMPRESSED_BLOCK_MAX_SIZE: usize = COMPRESSION_BLOCK_SIZE * MAX_VINT_SIZE; mod vint; /// Returns the size in bytes of a compressed block, given `num_bits`. +#[inline] pub fn compressed_block_size(num_bits: u8) -> usize { (num_bits as usize) * COMPRESSION_BLOCK_SIZE / 8 } @@ -267,7 +270,6 @@ impl VIntDecoder for BlockDecoder { #[cfg(test)] pub(crate) mod tests { - use super::*; use crate::TERMINATED; @@ -372,6 +374,13 @@ pub(crate) mod tests { } } } + + #[test] + fn test_compress_vint_unsorted_does_not_overflow() { + let mut encoder = BlockEncoder::new(); + let input: Vec = vec![u32::MAX; COMPRESSION_BLOCK_SIZE]; + encoder.compress_vint_unsorted(&input); + } } #[cfg(all(test, feature = "unstable"))] @@ -388,7 +397,10 @@ mod bench { let mut seed: [u8; 32] = [0; 32]; seed[31] = seed_val; let mut rng = StdRng::from_seed(seed); - (0u32..).filter(|_| rng.gen_bool(ratio)).take(n).collect() + (0u32..) 
+ .filter(|_| rng.random_bool(ratio)) + .take(n) + .collect() } pub fn generate_array(n: usize, ratio: f64) -> Vec { diff --git a/src/postings/mod.rs b/src/postings/mod.rs index efc0e069d..d60ad597d 100644 --- a/src/postings/mod.rs +++ b/src/postings/mod.rs @@ -527,6 +527,7 @@ pub(crate) mod tests { } impl Scorer for UnoptimizedDocSet { + #[inline] fn score(&mut self) -> Score { self.0.score() } @@ -603,13 +604,13 @@ mod bench { let mut index_writer: IndexWriter = index.writer_for_tests().unwrap(); for _ in 0..posting_list_size { let mut doc = TantivyDocument::default(); - if rng.gen_bool(1f64 / 15f64) { + if rng.random_bool(1f64 / 15f64) { doc.add_text(text_field, "a"); } - if rng.gen_bool(1f64 / 10f64) { + if rng.random_bool(1f64 / 10f64) { doc.add_text(text_field, "b"); } - if rng.gen_bool(1f64 / 5f64) { + if rng.random_bool(1f64 / 5f64) { doc.add_text(text_field, "c"); } doc.add_text(text_field, "d"); diff --git a/src/postings/segment_postings.rs b/src/postings/segment_postings.rs index d9ba33eb2..e9046bd3c 100644 --- a/src/postings/segment_postings.rs +++ b/src/postings/segment_postings.rs @@ -70,13 +70,13 @@ impl SegmentPostings { let mut buffer = Vec::new(); { let mut postings_serializer = - PostingsSerializer::new(&mut buffer, 0.0, IndexRecordOption::Basic, None); + PostingsSerializer::new(0.0, IndexRecordOption::Basic, None); postings_serializer.new_term(docs.len() as u32, false); for &doc in docs { postings_serializer.write_doc(doc, 1u32); } postings_serializer - .close_term(docs.len() as u32) + .close_term(docs.len() as u32, &mut buffer) .expect("In memory Serialization should never fail."); } let block_segment_postings = BlockSegmentPostings::open( @@ -115,7 +115,6 @@ impl SegmentPostings { }) .unwrap_or(0.0); let mut postings_serializer = PostingsSerializer::new( - &mut buffer, average_field_norm, IndexRecordOption::WithFreqs, fieldnorm_reader, @@ -125,7 +124,7 @@ impl SegmentPostings { postings_serializer.write_doc(doc, tf); } postings_serializer - 
.close_term(doc_and_tfs.len() as u32) + .close_term(doc_and_tfs.len() as u32, &mut buffer) .unwrap(); let block_segment_postings = BlockSegmentPostings::open( doc_and_tfs.len() as u32, diff --git a/src/postings/serializer.rs b/src/postings/serializer.rs index c0ee8483c..08c3c7542 100644 --- a/src/postings/serializer.rs +++ b/src/postings/serializer.rs @@ -104,10 +104,12 @@ impl InvertedIndexSerializer { /// the serialization of a specific field. pub struct FieldSerializer<'a> { term_dictionary_builder: TermDictionaryBuilder<&'a mut CountingWriter>, - postings_serializer: PostingsSerializer<&'a mut CountingWriter>, + postings_serializer: PostingsSerializer, positions_serializer_opt: Option>>, current_term_info: TermInfo, term_open: bool, + postings_write: &'a mut CountingWriter, + postings_start_offset: u64, } impl<'a> FieldSerializer<'a> { @@ -128,27 +130,30 @@ impl<'a> FieldSerializer<'a> { .as_ref() .map(|ff_reader| total_num_tokens as Score / ff_reader.num_docs() as Score) .unwrap_or(0.0); - let postings_serializer = PostingsSerializer::new( - postings_write, - average_fieldnorm, - index_record_option, - fieldnorm_reader, - ); + let postings_serializer = + PostingsSerializer::new(average_fieldnorm, index_record_option, fieldnorm_reader); let positions_serializer_opt = if index_record_option.has_positions() { Some(PositionSerializer::new(positions_write)) } else { None }; + let postings_start_offset = postings_write.written_bytes(); Ok(FieldSerializer { term_dictionary_builder, postings_serializer, positions_serializer_opt, current_term_info: TermInfo::default(), term_open: false, + postings_write, + postings_start_offset, }) } + fn postings_offset(&self) -> usize { + (self.postings_write.written_bytes() - self.postings_start_offset) as usize + } + fn current_term_info(&self) -> TermInfo { let positions_start = if let Some(positions_serializer) = self.positions_serializer_opt.as_ref() { @@ -156,7 +161,7 @@ impl<'a> FieldSerializer<'a> { } else { 0u64 } as usize; 
- let addr = self.postings_serializer.written_bytes() as usize; + let addr = self.postings_offset(); TermInfo { doc_freq: 0, postings_range: addr..addr, @@ -213,21 +218,22 @@ impl<'a> FieldSerializer<'a> { crate::fail_point!("FieldSerializer::close_term", |msg: Option| { Err(io::Error::new(io::ErrorKind::Other, format!("{msg:?}"))) }); - if self.term_open { - self.postings_serializer - .close_term(self.current_term_info.doc_freq)?; - self.current_term_info.postings_range.end = - self.postings_serializer.written_bytes() as usize; - if let Some(positions_serializer) = self.positions_serializer_opt.as_mut() { - positions_serializer.close_term()?; - self.current_term_info.positions_range.end = - positions_serializer.written_bytes() as usize; - } - self.term_dictionary_builder - .insert_value(&self.current_term_info)?; - self.term_open = false; + if !self.term_open { + return Ok(()); + }; + + self.postings_serializer + .close_term(self.current_term_info.doc_freq, self.postings_write)?; + self.current_term_info.postings_range.end = self.postings_offset(); + if let Some(positions_serializer) = self.positions_serializer_opt.as_mut() { + positions_serializer.close_term()?; + self.current_term_info.positions_range.end = + positions_serializer.written_bytes() as usize; } + self.term_dictionary_builder + .insert_value(&self.current_term_info)?; + self.term_open = false; Ok(()) } @@ -237,7 +243,7 @@ impl<'a> FieldSerializer<'a> { if let Some(positions_serializer) = self.positions_serializer_opt { positions_serializer.close()?; } - self.postings_serializer.close()?; + self.postings_write.flush()?; self.term_dictionary_builder.finish()?; Ok(()) } @@ -291,8 +297,7 @@ impl Block { } } -pub struct PostingsSerializer { - output_write: CountingWriter, +pub struct PostingsSerializer { last_doc_id_encoded: u32, block_encoder: BlockEncoder, @@ -310,16 +315,13 @@ pub struct PostingsSerializer { term_has_freq: bool, } -impl PostingsSerializer { +impl PostingsSerializer { pub fn new( - 
write: W, avg_fieldnorm: Score, mode: IndexRecordOption, fieldnorm_reader: Option, - ) -> PostingsSerializer { + ) -> PostingsSerializer { PostingsSerializer { - output_write: CountingWriter::wrap(write), - block_encoder: BlockEncoder::new(), block: Box::new(Block::new()), @@ -422,11 +424,11 @@ impl PostingsSerializer { } } - fn close(mut self) -> io::Result<()> { - self.postings_write.flush() - } - - pub fn close_term(&mut self, doc_freq: u32) -> io::Result<()> { + pub fn close_term( + &mut self, + doc_freq: u32, + output_write: &mut impl std::io::Write, + ) -> io::Result<()> { if !self.block.is_empty() { // we have doc ids waiting to be written // this happens when the number of doc ids is @@ -451,26 +453,16 @@ impl PostingsSerializer { } if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 { let skip_data = self.skip_write.data(); - VInt(skip_data.len() as u64).serialize(&mut self.output_write)?; - self.output_write.write_all(skip_data)?; + VInt(skip_data.len() as u64).serialize(output_write)?; + output_write.write_all(skip_data)?; } - self.output_write.write_all(&self.postings_write[..])?; + output_write.write_all(&self.postings_write[..])?; self.skip_write.clear(); self.postings_write.clear(); self.bm25_weight = None; Ok(()) } - /// Returns the number of bytes written in the postings write object - /// at this point. - /// When called before writing the postings of a term, this value is used as - /// start offset. - /// When called after writing the postings of a term, this value is used as a - /// end offset. 
- fn written_bytes(&self) -> u64 { - self.output_write.written_bytes() - } - fn clear(&mut self) { self.block.clear(); self.last_doc_id_encoded = 0; diff --git a/src/postings/skip.rs b/src/postings/skip.rs index c36690444..dd762ca46 100644 --- a/src/postings/skip.rs +++ b/src/postings/skip.rs @@ -6,17 +6,21 @@ use crate::{DocId, Score, TERMINATED}; // doc num bits uses the following encoding: // given 0b a b cdefgh -// |1|2| 3 | +// |1|2|3| 4 | // - 1: unused // - 2: is delta-1 encoded. 0 if not, 1, if yes -// - 3: a 6 bit number in 0..=32, the actual bitwidth +// - 3: unused +// - 4: a 5 bit number in 0..32, the actual bitwidth. Bitpacking could in theory say this is 32 +// (requiring a 6th bit), but the biggest doc_id we can want to encode is TERMINATED-1, which can +// be represented on 31b without delta encoding. fn encode_bitwidth(bitwidth: u8, delta_1: bool) -> u8 { + assert!(bitwidth < 32); bitwidth | ((delta_1 as u8) << 6) } fn decode_bitwidth(raw_bitwidth: u8) -> (u8, bool) { let delta_1 = ((raw_bitwidth >> 6) & 1) != 0; - let bitwidth = raw_bitwidth & 0x3f; + let bitwidth = raw_bitwidth & 0x1f; (bitwidth, delta_1) } @@ -430,7 +434,7 @@ mod tests { #[test] fn test_encode_decode_bitwidth() { - for bitwidth in 0..=32 { + for bitwidth in 0..32 { for delta_1 in [false, true] { assert_eq!( (bitwidth, delta_1), diff --git a/src/query/all_query.rs b/src/query/all_query.rs index 11172f9ed..5431a3a1b 100644 --- a/src/query/all_query.rs +++ b/src/query/all_query.rs @@ -23,7 +23,11 @@ pub struct AllWeight; impl Weight for AllWeight { fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result> { let all_scorer = AllScorer::new(reader.max_doc()); - Ok(Box::new(BoostScorer::new(all_scorer, boost))) + if boost != 1.0 { + Ok(Box::new(BoostScorer::new(all_scorer, boost))) + } else { + Ok(Box::new(all_scorer)) + } } fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result { @@ -58,6 +62,15 @@ impl DocSet for AllScorer { self.doc } + fn seek(&mut 
self, target: DocId) -> DocId { + debug_assert!(target >= self.doc); + self.doc = target; + if self.doc >= self.max_doc { + self.doc = TERMINATED; + } + self.doc + } + fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize { if self.doc() == TERMINATED { return 0; @@ -92,6 +105,7 @@ impl DocSet for AllScorer { } impl Scorer for AllScorer { + #[inline] fn score(&mut self) -> Score { 1.0 } diff --git a/src/query/boolean_query/block_wand.rs b/src/query/boolean_query/block_wand.rs index c6710b09c..6b2f2d6e3 100644 --- a/src/query/boolean_query/block_wand.rs +++ b/src/query/boolean_query/block_wand.rs @@ -483,7 +483,7 @@ mod tests { let checkpoints_for_each_pruning = compute_checkpoints_for_each_pruning(term_scorers.clone(), top_k); let checkpoints_manual = - compute_checkpoints_manual(term_scorers.clone(), top_k, 100_000); + compute_checkpoints_manual(term_scorers.clone(), top_k, max_doc as u32); assert_eq!(checkpoints_for_each_pruning.len(), checkpoints_manual.len()); for (&(left_doc, left_score), &(right_doc, right_score)) in checkpoints_for_each_pruning .iter() diff --git a/src/query/boolean_query/boolean_weight.rs b/src/query/boolean_query/boolean_weight.rs index 9e8cedf2e..c46e9b0b1 100644 --- a/src/query/boolean_query/boolean_weight.rs +++ b/src/query/boolean_query/boolean_weight.rs @@ -97,6 +97,65 @@ fn into_box_scorer( } } +/// Returns the effective MUST scorer, accounting for removed AllScorers. +/// +/// When AllScorer instances are removed from must_scorers as an optimization, +/// we must restore the "match all" semantics if the list becomes empty. 
+fn effective_must_scorer( + must_scorers: Vec>, + removed_all_scorer_count: usize, + max_doc: DocId, + num_docs: u32, +) -> Option> { + if must_scorers.is_empty() { + if removed_all_scorer_count > 0 { + // Had AllScorer(s) only - all docs match + Some(Box::new(AllScorer::new(max_doc))) + } else { + // No MUST constraint at all + None + } + } else { + Some(intersect_scorers(must_scorers, num_docs)) + } +} + +/// Returns a SHOULD scorer with AllScorer union if any were removed. +/// +/// For union semantics (OR): if any SHOULD clause was an AllScorer, the result +/// should include all documents. We restore this by unioning with AllScorer. +/// +/// When `scoring_enabled` is false, we can just return AllScorer alone since +/// we don't need score contributions from the should_scorer. +fn effective_should_scorer_for_union( + should_scorer: SpecializedScorer, + removed_all_scorer_count: usize, + max_doc: DocId, + num_docs: u32, + score_combiner_fn: impl Fn() -> TScoreCombiner, + scoring_enabled: bool, +) -> SpecializedScorer { + if removed_all_scorer_count > 0 { + if scoring_enabled { + // Need to union to get score contributions from both + let all_scorers: Vec> = vec![ + into_box_scorer(should_scorer, &score_combiner_fn, num_docs), + Box::new(AllScorer::new(max_doc)), + ]; + SpecializedScorer::Other(Box::new(BufferedUnionScorer::build( + all_scorers, + score_combiner_fn, + num_docs, + ))) + } else { + // Scoring disabled - AllScorer alone is sufficient + SpecializedScorer::Other(Box::new(AllScorer::new(max_doc))) + } + } else { + should_scorer + } +} + enum ShouldScorersCombinationMethod { // Should scorers are irrelevant. 
Ignored, @@ -193,18 +252,18 @@ impl BooleanWeight { return Ok(SpecializedScorer::Other(Box::new(EmptyScorer))); } - let minimum_number_should_match = self + let effective_minimum_number_should_match = self .minimum_number_should_match .saturating_sub(should_special_scorer_counts.num_all_scorers); let should_scorers: ShouldScorersCombinationMethod = { let num_of_should_scorers = should_scorers.len(); - if minimum_number_should_match > num_of_should_scorers { + if effective_minimum_number_should_match > num_of_should_scorers { // We don't have enough scorers to satisfy the minimum number of should matches. // The request will match no documents. return Ok(SpecializedScorer::Other(Box::new(EmptyScorer))); } - match minimum_number_should_match { + match effective_minimum_number_should_match { 0 if num_of_should_scorers == 0 => ShouldScorersCombinationMethod::Ignored, 0 => ShouldScorersCombinationMethod::Optional(scorer_union( should_scorers, @@ -226,7 +285,7 @@ impl BooleanWeight { scorer_disjunction( should_scorers, score_combiner_fn(), - self.minimum_number_should_match, + effective_minimum_number_should_match, ), )), } @@ -246,53 +305,78 @@ impl BooleanWeight { let include_scorer = match (should_scorers, must_scorers) { (ShouldScorersCombinationMethod::Ignored, must_scorers) => { - let boxed_scorer: Box = if must_scorers.is_empty() { - // We do not have any should scorers, nor all scorers. - // There are still two cases here. - // - // If this follows the removal of some AllScorers in the should/must clauses, - // then we match all documents. - // - // Otherwise, it is really just an EmptyScorer. - if must_special_scorer_counts.num_all_scorers - + should_special_scorer_counts.num_all_scorers - > 0 - { - Box::new(AllScorer::new(reader.max_doc())) - } else { - Box::new(EmptyScorer) - } - } else { - intersect_scorers(must_scorers, num_docs) - }; + // No SHOULD clauses (or they were absorbed into MUST). + // Result depends entirely on MUST + any removed AllScorers. 
+ let combined_all_scorer_count = must_special_scorer_counts.num_all_scorers + + should_special_scorer_counts.num_all_scorers; + let boxed_scorer: Box = effective_must_scorer( + must_scorers, + combined_all_scorer_count, + reader.max_doc(), + num_docs, + ) + .unwrap_or_else(|| Box::new(EmptyScorer)); SpecializedScorer::Other(boxed_scorer) } (ShouldScorersCombinationMethod::Optional(should_scorer), must_scorers) => { - if must_scorers.is_empty() && must_special_scorer_counts.num_all_scorers == 0 { - // Optional options are promoted to required if no must scorers exists. - should_scorer - } else { - let must_scorer = intersect_scorers(must_scorers, num_docs); - if self.scoring_enabled { - SpecializedScorer::Other(Box::new(RequiredOptionalScorer::< - _, - _, - TScoreCombiner, - >::new( - must_scorer, - into_box_scorer(should_scorer, &score_combiner_fn, num_docs), - ))) - } else { - SpecializedScorer::Other(must_scorer) + // Optional SHOULD: contributes to scoring but not required for matching. + match effective_must_scorer( + must_scorers, + must_special_scorer_counts.num_all_scorers, + reader.max_doc(), + num_docs, + ) { + None => { + // No MUST constraint: promote SHOULD to required. + // Must preserve any removed AllScorers from SHOULD via union. + effective_should_scorer_for_union( + should_scorer, + should_special_scorer_counts.num_all_scorers, + reader.max_doc(), + num_docs, + &score_combiner_fn, + self.scoring_enabled, + ) + } + Some(must_scorer) => { + // Has MUST constraint: SHOULD only affects scoring. 
+ if self.scoring_enabled { + SpecializedScorer::Other(Box::new(RequiredOptionalScorer::< + _, + _, + TScoreCombiner, + >::new( + must_scorer, + into_box_scorer(should_scorer, &score_combiner_fn, num_docs), + ))) + } else { + SpecializedScorer::Other(must_scorer) + } } } } - (ShouldScorersCombinationMethod::Required(should_scorer), mut must_scorers) => { - if must_scorers.is_empty() { - should_scorer - } else { - must_scorers.push(into_box_scorer(should_scorer, &score_combiner_fn, num_docs)); - SpecializedScorer::Other(intersect_scorers(must_scorers, num_docs)) + (ShouldScorersCombinationMethod::Required(should_scorer), must_scorers) => { + // Required SHOULD: at least `minimum_number_should_match` must match. + // Semantics: (MUST constraint) AND (SHOULD constraint) + match effective_must_scorer( + must_scorers, + must_special_scorer_counts.num_all_scorers, + reader.max_doc(), + num_docs, + ) { + None => { + // No MUST constraint: SHOULD alone determines matching. + should_scorer + } + Some(must_scorer) => { + // Has MUST constraint: intersect MUST with SHOULD. 
+ let should_boxed = + into_box_scorer(should_scorer, &score_combiner_fn, num_docs); + SpecializedScorer::Other(intersect_scorers( + vec![must_scorer, should_boxed], + num_docs, + )) + } } } }; diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs index 0ddc5a26c..681881c11 100644 --- a/src/query/boolean_query/mod.rs +++ b/src/query/boolean_query/mod.rs @@ -9,12 +9,14 @@ pub use self::boolean_weight::BooleanWeight; #[cfg(test)] mod tests { + use std::ops::Bound; + use super::*; use crate::collector::tests::TEST_COLLECTOR_WITH_SCORE; - use crate::collector::TopDocs; + use crate::collector::{Count, TopDocs}; use crate::query::term_query::TermScorer; use crate::query::{ - AllScorer, EmptyScorer, EnableScoring, Intersection, Occur, Query, QueryParser, + AllScorer, EmptyScorer, EnableScoring, Intersection, Occur, Query, QueryParser, RangeQuery, RequiredOptionalScorer, Scorer, SumCombiner, TermQuery, }; use crate::schema::*; @@ -374,4 +376,466 @@ mod tests { } Ok(()) } + + #[test] + pub fn test_min_should_match_with_all_query() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let text_field = schema_builder.add_text_field("text", TEXT); + let num_field = + schema_builder.add_i64_field("num", NumericOptions::default().set_fast().set_indexed()); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut index_writer: IndexWriter = index.writer_for_tests()?; + + index_writer.add_document(doc!(text_field => "apple", num_field => 10i64))?; + index_writer.add_document(doc!(text_field => "banana", num_field => 20i64))?; + index_writer.commit()?; + + let searcher = index.reader()?.searcher(); + + let effective_all_match_query: Box = Box::new(RangeQuery::new( + Bound::Excluded(Term::from_field_i64(num_field, 0)), + Bound::Unbounded, + )); + let term_query: Box = Box::new(TermQuery::new( + Term::from_field_text(text_field, "apple"), + IndexRecordOption::Basic, + )); + + // in some previous version, we 
would remove the 2 all_match, but then say we need *4* + // matches out of the 3 term queries, which matches nothing. + let mut bool_query = BooleanQuery::new(vec![ + (Occur::Should, effective_all_match_query.box_clone()), + (Occur::Should, effective_all_match_query.box_clone()), + (Occur::Should, term_query.box_clone()), + (Occur::Should, term_query.box_clone()), + (Occur::Should, term_query.box_clone()), + ]); + bool_query.set_minimum_number_should_match(4); + let count = searcher.search(&bool_query, &Count)?; + assert_eq!(count, 1); + + Ok(()) + } + + // ========================================================================= + // AllScorer Preservation Regression Tests + // ========================================================================= + // + // These tests verify the fix for a bug where AllScorer instances (produced by + // queries matching all documents, such as range queries covering all values) + // were incorrectly removed from Boolean query processing, causing documents + // to be unexpectedly excluded from results. + // + // The bug manifested in several scenarios: + // 1. SHOULD + SHOULD where one clause is AllScorer + // 2. MUST (AllScorer) + SHOULD + // 3. Range queries in Boolean clauses when all documents match the range + + /// Regression test: SHOULD clause with AllScorer combined with other SHOULD clauses. + /// + /// When a SHOULD clause produces an AllScorer (e.g., from a range query matching + /// all documents), the Boolean query should still match all documents. + /// + /// Bug before fix: AllScorer was removed during optimization, leaving only the + /// other SHOULD clauses, which incorrectly excluded documents. 
+ #[test] + pub fn test_should_with_all_scorer_regression() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let text_field = schema_builder.add_text_field("text", TEXT); + let num_field = + schema_builder.add_i64_field("num", NumericOptions::default().set_fast().set_indexed()); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut index_writer: IndexWriter = index.writer_for_tests()?; + + // All docs have num > 0, so range query will return AllScorer + index_writer.add_document(doc!(text_field => "hello", num_field => 10i64))?; + index_writer.add_document(doc!(text_field => "world", num_field => 20i64))?; + index_writer.add_document(doc!(text_field => "hello world", num_field => 30i64))?; + index_writer.add_document(doc!(text_field => "foo", num_field => 40i64))?; + index_writer.add_document(doc!(text_field => "bar", num_field => 50i64))?; + index_writer.add_document(doc!(text_field => "baz", num_field => 60i64))?; + index_writer.commit()?; + + let searcher = index.reader()?.searcher(); + + // Range query matching all docs (returns AllScorer) + let all_match_query: Box = Box::new(RangeQuery::new( + Bound::Excluded(Term::from_field_i64(num_field, 0)), + Bound::Unbounded, + )); + let term_query: Box = Box::new(TermQuery::new( + Term::from_field_text(text_field, "hello"), + IndexRecordOption::Basic, + )); + + // Verify range matches all 6 docs + assert_eq!(searcher.search(all_match_query.as_ref(), &Count)?, 6); + + // RangeQuery(all) OR TermQuery should match all 6 docs + let bool_query = BooleanQuery::new(vec![ + (Occur::Should, all_match_query.box_clone()), + (Occur::Should, term_query.box_clone()), + ]); + let count = searcher.search(&bool_query, &Count)?; + assert_eq!(count, 6, "SHOULD with AllScorer should match all docs"); + + // Order should not matter + let bool_query_reversed = BooleanQuery::new(vec![ + (Occur::Should, term_query.box_clone()), + (Occur::Should, all_match_query.box_clone()), + ]); + 
let count_reversed = searcher.search(&bool_query_reversed, &Count)?; + assert_eq!( + count_reversed, 6, + "Order of SHOULD clauses should not matter" + ); + + Ok(()) + } + + /// Regression test: MUST clause with AllScorer combined with SHOULD clause. + /// + /// When MUST contains an AllScorer, all documents satisfy the MUST constraint. + /// The SHOULD clause should only affect scoring, not filtering. + /// + /// Bug before fix: AllScorer was removed, leaving an empty must_scorers vector. + /// intersect_scorers([]) incorrectly returned EmptyScorer, matching 0 documents. + #[test] + pub fn test_must_all_with_should_regression() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let text_field = schema_builder.add_text_field("text", TEXT); + let num_field = + schema_builder.add_i64_field("num", NumericOptions::default().set_fast().set_indexed()); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut index_writer: IndexWriter = index.writer_for_tests()?; + + // All docs have num > 0, so range query will return AllScorer + index_writer.add_document(doc!(text_field => "apple", num_field => 10i64))?; + index_writer.add_document(doc!(text_field => "banana", num_field => 20i64))?; + index_writer.add_document(doc!(text_field => "cherry", num_field => 30i64))?; + index_writer.add_document(doc!(text_field => "date", num_field => 40i64))?; + index_writer.commit()?; + + let searcher = index.reader()?.searcher(); + + // Range query matching all docs (returns AllScorer) + let all_match_query: Box = Box::new(RangeQuery::new( + Bound::Excluded(Term::from_field_i64(num_field, 0)), + Bound::Unbounded, + )); + let term_query: Box = Box::new(TermQuery::new( + Term::from_field_text(text_field, "apple"), + IndexRecordOption::Basic, + )); + + // Verify range matches all 4 docs + assert_eq!(searcher.search(all_match_query.as_ref(), &Count)?, 4); + + // MUST(range matching all) AND SHOULD(term) should match all 4 docs + let 
bool_query = BooleanQuery::new(vec![ + (Occur::Must, all_match_query.box_clone()), + (Occur::Should, term_query.box_clone()), + ]); + let count = searcher.search(&bool_query, &Count)?; + assert_eq!(count, 4, "MUST AllScorer + SHOULD should match all docs"); + + Ok(()) + } + + /// Regression test: Range queries in Boolean clauses when all documents match. + /// + /// Range queries can return AllScorer as an optimization when all indexed values + /// fall within the range. This test ensures such queries work correctly in + /// Boolean combinations. + /// + /// This is the most common real-world manifestation of the bug, occurring in + /// queries like: (age > 50 OR name = 'Alice') AND status = 'active' + /// when all documents have age > 50. + #[test] + pub fn test_range_query_all_match_in_boolean() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let name_field = schema_builder.add_text_field("name", TEXT); + let age_field = + schema_builder.add_i64_field("age", NumericOptions::default().set_fast().set_indexed()); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut index_writer: IndexWriter = index.writer_for_tests()?; + + // All documents have age > 50, so range query will return AllScorer + index_writer.add_document(doc!(name_field => "alice", age_field => 55_i64))?; + index_writer.add_document(doc!(name_field => "bob", age_field => 60_i64))?; + index_writer.add_document(doc!(name_field => "charlie", age_field => 70_i64))?; + index_writer.add_document(doc!(name_field => "diana", age_field => 80_i64))?; + index_writer.commit()?; + + let searcher = index.reader()?.searcher(); + + let range_query: Box = Box::new(RangeQuery::new( + Bound::Excluded(Term::from_field_i64(age_field, 50)), + Bound::Unbounded, + )); + let term_query: Box = Box::new(TermQuery::new( + Term::from_field_text(name_field, "alice"), + IndexRecordOption::Basic, + )); + + // Verify preconditions + 
assert_eq!(searcher.search(range_query.as_ref(), &Count)?, 4); + assert_eq!(searcher.search(term_query.as_ref(), &Count)?, 1); + + // SHOULD(range) OR SHOULD(term): range matches all, so result is 4 + let should_query = BooleanQuery::new(vec![ + (Occur::Should, range_query.box_clone()), + (Occur::Should, term_query.box_clone()), + ]); + assert_eq!( + searcher.search(&should_query, &Count)?, + 4, + "SHOULD range OR term should match all" + ); + + // MUST(range) AND SHOULD(term): range matches all, term is optional + let must_should_query = BooleanQuery::new(vec![ + (Occur::Must, range_query.box_clone()), + (Occur::Should, term_query.box_clone()), + ]); + assert_eq!( + searcher.search(&must_should_query, &Count)?, + 4, + "MUST range + SHOULD term should match all" + ); + + Ok(()) + } + + /// Test multiple AllScorer instances in different clause types. + /// + /// Verifies correct behavior when AllScorers appear in multiple positions. + #[test] + pub fn test_multiple_all_scorers() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let text_field = schema_builder.add_text_field("text", TEXT); + let num_field = + schema_builder.add_i64_field("num", NumericOptions::default().set_fast().set_indexed()); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut index_writer: IndexWriter = index.writer_for_tests()?; + + // All docs have num > 0, so range queries will return AllScorer + index_writer.add_document(doc!(text_field => "doc1", num_field => 10i64))?; + index_writer.add_document(doc!(text_field => "doc2", num_field => 20i64))?; + index_writer.add_document(doc!(text_field => "doc3", num_field => 30i64))?; + index_writer.commit()?; + + let searcher = index.reader()?.searcher(); + + // Two different range queries that both match all docs (return AllScorer) + let all_query1: Box = Box::new(RangeQuery::new( + Bound::Excluded(Term::from_field_i64(num_field, 0)), + Bound::Unbounded, + )); + let all_query2: Box = 
Box::new(RangeQuery::new( + Bound::Excluded(Term::from_field_i64(num_field, 5)), + Bound::Unbounded, + )); + let term_query: Box = Box::new(TermQuery::new( + Term::from_field_text(text_field, "doc1"), + IndexRecordOption::Basic, + )); + + // Multiple AllScorers in SHOULD + let multi_all_should = BooleanQuery::new(vec![ + (Occur::Should, all_query1.box_clone()), + (Occur::Should, all_query2.box_clone()), + (Occur::Should, term_query.box_clone()), + ]); + assert_eq!( + searcher.search(&multi_all_should, &Count)?, + 3, + "Multiple AllScorers in SHOULD" + ); + + // AllScorer in both MUST and SHOULD + let all_must_and_should = BooleanQuery::new(vec![ + (Occur::Must, all_query1.box_clone()), + (Occur::Should, all_query2.box_clone()), + ]); + assert_eq!( + searcher.search(&all_must_and_should, &Count)?, + 3, + "AllScorer in both MUST and SHOULD" + ); + + Ok(()) + } +} + +/// A proptest which generates arbitrary permutations of a simple boolean AST, and then matches +/// the result against an index which contains all permutations of documents with N fields. 
+#[cfg(test)] +mod proptest_boolean_query { + use std::collections::{BTreeMap, HashSet}; + use std::ops::{Bound, Range}; + + use proptest::collection::vec; + use proptest::prelude::*; + + use crate::collector::DocSetCollector; + use crate::query::{AllQuery, BooleanQuery, Occur, Query, RangeQuery, TermQuery}; + use crate::schema::{Field, NumericOptions, OwnedValue, Schema, TEXT}; + use crate::{DocId, Index, Term}; + + #[derive(Debug, Clone)] + enum BooleanQueryAST { + /// Matches all documents via AllQuery (wraps AllScorer in BoostScorer) + All, + /// Matches all documents via RangeQuery (returns bare AllScorer) + /// This is the actual trigger for the AllScorer preservation bug + RangeAll, + /// Matches documents where the field has value "true" + Leaf { + field_idx: usize, + }, + Union(Vec), + Intersection(Vec), + } + + impl BooleanQueryAST { + fn matches(&self, doc_id: DocId) -> bool { + match self { + BooleanQueryAST::All => true, + BooleanQueryAST::RangeAll => true, + BooleanQueryAST::Leaf { field_idx } => Self::matches_field(doc_id, *field_idx), + BooleanQueryAST::Union(children) => { + children.iter().any(|child| child.matches(doc_id)) + } + BooleanQueryAST::Intersection(children) => { + children.iter().all(|child| child.matches(doc_id)) + } + } + } + + fn matches_field(doc_id: DocId, field_idx: usize) -> bool { + ((doc_id as usize) >> field_idx) & 1 == 1 + } + + fn to_query(&self, fields: &[Field], range_field: Field) -> Box { + match self { + BooleanQueryAST::All => Box::new(AllQuery), + BooleanQueryAST::RangeAll => { + // Range query that matches all docs (all have value >= 0) + // This returns bare AllScorer, triggering the bug we fixed + Box::new(RangeQuery::new( + Bound::Included(Term::from_field_i64(range_field, 0)), + Bound::Unbounded, + )) + } + BooleanQueryAST::Leaf { field_idx } => Box::new(TermQuery::new( + Term::from_field_text(fields[*field_idx], "true"), + crate::schema::IndexRecordOption::Basic, + )), + BooleanQueryAST::Union(children) => { + 
let sub_queries = children + .iter() + .map(|child| (Occur::Should, child.to_query(fields, range_field))) + .collect(); + Box::new(BooleanQuery::new(sub_queries)) + } + BooleanQueryAST::Intersection(children) => { + let sub_queries = children + .iter() + .map(|child| (Occur::Must, child.to_query(fields, range_field))) + .collect(); + Box::new(BooleanQuery::new(sub_queries)) + } + } + } + } + + fn doc_ids(num_docs: usize, num_fields: usize) -> Range { + let permutations = 1 << num_fields; + let copies = (num_docs as f32 / permutations as f32).ceil() as u32; + 0..(permutations * copies) + } + + fn create_index_with_boolean_permutations( + num_docs: usize, + num_fields: usize, + ) -> (Index, Vec, Field) { + let mut schema_builder = Schema::builder(); + let fields: Vec = (0..num_fields) + .map(|i| schema_builder.add_text_field(&format!("field_{}", i), TEXT)) + .collect(); + // Add a numeric field for RangeQuery tests - all docs have value = doc_id + let range_field = schema_builder.add_i64_field( + "range_field", + NumericOptions::default().set_fast().set_indexed(), + ); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut writer = index.writer_for_tests().unwrap(); + + for doc_id in doc_ids(num_docs, num_fields) { + let mut doc: BTreeMap<_, OwnedValue> = BTreeMap::default(); + for (field_idx, &field) in fields.iter().enumerate() { + if (doc_id >> field_idx) & 1 == 1 { + doc.insert(field, "true".into()); + } + } + // All docs have non-negative values, so RangeQuery(>=0) matches all + doc.insert(range_field, (doc_id as i64).into()); + writer.add_document(doc).unwrap(); + } + writer.commit().unwrap(); + (index, fields, range_field) + } + + fn arb_boolean_query_ast(num_fields: usize) -> impl Strategy { + // Leaf strategies: term queries, AllQuery, and RangeQuery matching all docs + let leaf = prop_oneof![ + (0..num_fields).prop_map(|field_idx| BooleanQueryAST::Leaf { field_idx }), + Just(BooleanQueryAST::All), + 
Just(BooleanQueryAST::RangeAll), + ]; + leaf.prop_recursive( + 8, // 8 levels of recursion + 256, // 256 nodes max + 10, // 10 items per collection + |inner| { + prop_oneof![ + vec(inner.clone(), 1..10).prop_map(BooleanQueryAST::Union), + vec(inner, 1..10).prop_map(BooleanQueryAST::Intersection), + ] + }, + ) + } + + #[test] + fn proptest_boolean_query() { + // In the presence of optimizations around buffering, it can take large numbers of + // documents to uncover some issues. + let num_fields = 8; + let num_docs = 1 << num_fields; + let (index, fields, range_field) = + create_index_with_boolean_permutations(num_docs, num_fields); + let searcher = index.reader().unwrap().searcher(); + proptest!(|(ast in arb_boolean_query_ast(num_fields))| { + let query = ast.to_query(&fields, range_field); + + let mut matching_docs = HashSet::new(); + for doc_id in doc_ids(num_docs, num_fields) { + if ast.matches(doc_id as DocId) { + matching_docs.insert(doc_id as DocId); + } + } + + let doc_addresses = searcher.search(&*query, &DocSetCollector).unwrap(); + let result_docs: HashSet = + doc_addresses.into_iter().map(|doc_address| doc_address.doc_id).collect(); + prop_assert_eq!(result_docs, matching_docs); + }); + } } diff --git a/src/query/boost_query.rs b/src/query/boost_query.rs index 06678287f..cc4c10f7a 100644 --- a/src/query/boost_query.rs +++ b/src/query/boost_query.rs @@ -104,6 +104,9 @@ impl DocSet for BoostScorer { fn seek(&mut self, target: DocId) -> DocId { self.underlying.seek(target) } + fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { + self.underlying.seek_into_the_danger_zone(target) + } fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize { self.underlying.fill_buffer(buffer) @@ -131,6 +134,7 @@ impl DocSet for BoostScorer { } impl Scorer for BoostScorer { + #[inline] fn score(&mut self) -> Score { self.underlying.score() * self.boost } diff --git a/src/query/const_score_query.rs b/src/query/const_score_query.rs index 
570c7feca..d07e6a96f 100644 --- a/src/query/const_score_query.rs +++ b/src/query/const_score_query.rs @@ -137,6 +137,7 @@ impl DocSet for ConstScorer { } impl Scorer for ConstScorer { + #[inline] fn score(&mut self) -> Score { self.score } diff --git a/src/query/disjunction.rs b/src/query/disjunction.rs index 910e207df..ca7eab20d 100644 --- a/src/query/disjunction.rs +++ b/src/query/disjunction.rs @@ -62,6 +62,16 @@ impl DocSet for ScorerWrapper { self.current_doc = doc_id; doc_id } + fn seek(&mut self, target: DocId) -> DocId { + let doc_id = self.scorer.seek(target); + self.current_doc = doc_id; + doc_id + } + fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { + let found = self.scorer.seek_into_the_danger_zone(target); + self.current_doc = self.scorer.doc(); + found + } fn doc(&self) -> DocId { self.current_doc @@ -163,6 +173,7 @@ impl DocSet impl Scorer for Disjunction { + #[inline] fn score(&mut self) -> Score { self.current_score } @@ -297,6 +308,7 @@ mod tests { } impl Scorer for DummyScorer { + #[inline] fn score(&mut self) -> Score { self.foo.get(self.cursor).map(|x| x.1).unwrap_or(0.0) } diff --git a/src/query/empty_query.rs b/src/query/empty_query.rs index 86ff84c08..2fa1772bd 100644 --- a/src/query/empty_query.rs +++ b/src/query/empty_query.rs @@ -55,6 +55,7 @@ impl DocSet for EmptyScorer { } impl Scorer for EmptyScorer { + #[inline] fn score(&mut self) -> Score { 0.0 } diff --git a/src/query/exclude.rs b/src/query/exclude.rs index 0b13e66e0..15e609c1e 100644 --- a/src/query/exclude.rs +++ b/src/query/exclude.rs @@ -84,6 +84,7 @@ where TScorer: Scorer, TDocSetExclude: DocSet + 'static, { + #[inline] fn score(&mut self) -> Score { self.underlying_docset.score() } diff --git a/src/query/intersection.rs b/src/query/intersection.rs index 10e257c43..d536dcf05 100644 --- a/src/query/intersection.rs +++ b/src/query/intersection.rs @@ -1,5 +1,5 @@ +use super::size_hint::estimate_intersection; use crate::docset::{DocSet, TERMINATED}; -use 
crate::query::size_hint::estimate_intersection; use crate::query::term_query::TermScorer; use crate::query::{EmptyScorer, Scorer}; use crate::{DocId, Score}; @@ -12,6 +12,9 @@ use crate::{DocId, Score}; /// For better performance, the function uses a /// specialized implementation if the two /// shortest scorers are `TermScorer`s. +/// +/// num_docs_segment is the number of documents in the segment. It is used for estimating the +/// `size_hint` of the intersection. pub fn intersect_scorers( mut scorers: Vec>, num_docs_segment: u32, @@ -102,35 +105,48 @@ impl Intersection { } impl DocSet for Intersection { + #[inline] fn advance(&mut self) -> DocId { let (left, right) = (&mut self.left, &mut self.right); let mut candidate = left.advance(); + if candidate == TERMINATED { + return TERMINATED; + } - 'outer: loop { + loop { // In the first part we look for a document in the intersection // of the two rarest `DocSet` in the intersection. loop { - let right_doc = right.seek(candidate); - candidate = left.seek(right_doc); - if candidate == right_doc { + if right.seek_into_the_danger_zone(candidate) { break; } + let right_doc = right.doc(); + // TODO: Think about which value would make sense here + // It depends on the DocSet implementation, when a seek would outweigh an advance. 
+ if right_doc > candidate.wrapping_add(100) { + candidate = left.seek(right_doc); + } else { + candidate = left.advance(); + } + if candidate == TERMINATED { + return TERMINATED; + } } debug_assert_eq!(left.doc(), right.doc()); - // test the remaining scorers; - for docset in self.others.iter_mut() { - let seek_doc = docset.seek(candidate); - if seek_doc > candidate { - candidate = left.seek(seek_doc); - continue 'outer; - } + // test the remaining scorers + if self + .others + .iter_mut() + .all(|docset| docset.seek_into_the_danger_zone(candidate)) + { + debug_assert_eq!(candidate, self.left.doc()); + debug_assert_eq!(candidate, self.right.doc()); + debug_assert!(self.others.iter().all(|docset| docset.doc() == candidate)); + return candidate; } - debug_assert_eq!(candidate, self.left.doc()); - debug_assert_eq!(candidate, self.right.doc()); - debug_assert!(self.others.iter().all(|docset| docset.doc() == candidate)); - return candidate; + candidate = left.advance(); } } @@ -146,6 +162,20 @@ impl DocSet for Intersection bool { + self.left.seek_into_the_danger_zone(target) + && self.right.seek_into_the_danger_zone(target) + && self + .others + .iter_mut() + .all(|docset| docset.seek_into_the_danger_zone(target)) + } + + #[inline] fn doc(&self) -> DocId { self.left.doc() } @@ -172,6 +202,7 @@ where TScorer: Scorer, TOtherScorer: Scorer, { + #[inline] fn score(&mut self) -> Score { self.left.score() + self.right.score() @@ -181,6 +212,8 @@ where #[cfg(test)] mod tests { + use proptest::prelude::*; + use super::Intersection; use crate::docset::{DocSet, TERMINATED}; use crate::postings::tests::test_skip_against_unoptimized; @@ -270,4 +303,38 @@ mod tests { let intersection = Intersection::new(vec![a, b, c], 10); assert_eq!(intersection.doc(), TERMINATED); } + + // Strategy to generate sorted and deduplicated vectors of u32 document IDs + fn sorted_deduped_vec(max_val: u32, max_size: usize) -> impl Strategy> { + prop::collection::vec(0..max_val, 0..max_size).prop_map(|mut 
vec| { + vec.sort(); + vec.dedup(); + vec + }) + } + + proptest! { + #[test] + fn prop_test_intersection_consistency( + a in sorted_deduped_vec(100, 10), + b in sorted_deduped_vec(100, 10), + num_docs in 100u32..500u32 + ) { + let left = VecDocSet::from(a.clone()); + let right = VecDocSet::from(b.clone()); + let mut intersection = Intersection::new(vec![left, right], num_docs); + + let expected: Vec = a.iter() + .cloned() + .filter(|doc| b.contains(doc)) + .collect(); + + for expected_doc in expected { + assert_eq!(intersection.doc(), expected_doc); + intersection.advance(); + } + assert_eq!(intersection.doc(), TERMINATED); + } + + } } diff --git a/src/query/mod.rs b/src/query/mod.rs index d609a0402..0bc865921 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -70,9 +70,83 @@ pub use self::weight::Weight; #[cfg(test)] mod tests { + use crate::collector::TopDocs; + use crate::query::phrase_query::tests::create_index; use crate::query::QueryParser; use crate::schema::{Schema, TEXT}; - use crate::{Index, Term}; + use crate::{DocAddress, Index, Term}; + + #[test] + pub fn test_mixed_intersection_and_union() -> crate::Result<()> { + let index = create_index(&["a b", "a c", "a b c", "b"])?; + let schema = index.schema(); + let text_field = schema.get_field("text").unwrap(); + let searcher = index.reader()?.searcher(); + + let do_search = |term: &str| { + let query = QueryParser::for_index(&index, vec![text_field]) + .parse_query(term) + .unwrap(); + let top_docs: Vec<(f32, DocAddress)> = searcher + .search(&query, &TopDocs::with_limit(10).order_by_score()) + .unwrap(); + + top_docs.iter().map(|el| el.1.doc_id).collect::>() + }; + + assert_eq!(do_search("a AND b"), vec![0, 2]); + assert_eq!(do_search("(a OR b) AND C"), vec![2, 1]); + // The intersection code has special code for more than 2 intersections + // left, right + others + // This will place the union in the "others" intersection so that seek_into_the_danger_zone + // is called + assert_eq!( + do_search("(a 
OR b) AND (c OR a) AND (b OR c)"), + vec![2, 1, 0] + ); + + Ok(()) + } + + #[test] + pub fn test_mixed_intersection_and_union_with_skip() -> crate::Result<()> { + // Test 4096 skip in BufferedUnionScorer + let mut data: Vec<&str> = Vec::new(); + data.push("a b"); + let zz_data = vec!["z z"; 5000]; + data.extend_from_slice(&zz_data); + data.extend_from_slice(&["a c"]); + data.extend_from_slice(&zz_data); + data.extend_from_slice(&["a b c", "b"]); + let index = create_index(&data)?; + let schema = index.schema(); + let text_field = schema.get_field("text").unwrap(); + let searcher = index.reader()?.searcher(); + + let do_search = |term: &str| { + let query = QueryParser::for_index(&index, vec![text_field]) + .parse_query(term) + .unwrap(); + let top_docs: Vec<(f32, DocAddress)> = searcher + .search(&query, &TopDocs::with_limit(10).order_by_score()) + .unwrap(); + + top_docs.iter().map(|el| el.1.doc_id).collect::>() + }; + + assert_eq!(do_search("a AND b"), vec![0, 10002]); + assert_eq!(do_search("(a OR b) AND C"), vec![10002, 5001]); + // The intersection code has special code for more than 2 intersections + // left, right + others + // This will place the union in the "others" intersection so that seek_into_the_danger_zone + // is called + assert_eq!( + do_search("(a OR b) AND (c OR a) AND (b OR c)"), + vec![10002, 5001, 0] + ); + + Ok(()) + } #[test] fn test_query_terms() { diff --git a/src/query/phrase_prefix_query/phrase_prefix_scorer.rs b/src/query/phrase_prefix_query/phrase_prefix_scorer.rs index 14933f3ae..8b03089fa 100644 --- a/src/query/phrase_prefix_query/phrase_prefix_scorer.rs +++ b/src/query/phrase_prefix_query/phrase_prefix_scorer.rs @@ -81,6 +81,7 @@ impl DocSet for PhraseKind { } impl Scorer for PhraseKind { + #[inline] fn score(&mut self) -> Score { match self { PhraseKind::SinglePrefix { positions, .. 
} => { @@ -193,6 +194,14 @@ impl DocSet for PhrasePrefixScorer { self.advance() } + fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { + if self.phrase_scorer.seek_into_the_danger_zone(target) { + self.matches_prefix() + } else { + false + } + } + fn doc(&self) -> DocId { self.phrase_scorer.doc() } @@ -207,6 +216,7 @@ impl DocSet for PhrasePrefixScorer { } impl Scorer for PhrasePrefixScorer { + #[inline] fn score(&mut self) -> Score { // TODO modify score?? self.phrase_scorer.score() diff --git a/src/query/phrase_query/phrase_scorer.rs b/src/query/phrase_query/phrase_scorer.rs index 12a94dce3..108783b40 100644 --- a/src/query/phrase_query/phrase_scorer.rs +++ b/src/query/phrase_query/phrase_scorer.rs @@ -382,8 +382,9 @@ impl PhraseScorer { PostingsWithOffset::new(postings, (max_offset - offset) as u32) }) .collect::>(); + let intersection_docset = Intersection::new(postings_with_offsets, num_docs); let mut scorer = PhraseScorer { - intersection_docset: Intersection::new(postings_with_offsets, num_docs), + intersection_docset, num_terms: num_docsets, left_positions: Vec::with_capacity(100), right_positions: Vec::with_capacity(100), @@ -529,25 +530,40 @@ impl DocSet for PhraseScorer { self.advance() } + fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { + debug_assert!(target >= self.doc()); + if self.intersection_docset.seek_into_the_danger_zone(target) && self.phrase_match() { + return true; + } + false + } + fn doc(&self) -> DocId { self.intersection_docset.doc() } fn size_hint(&self) -> u32 { - self.intersection_docset.size_hint() + // We adjust the intersection estimate, since actual phrase hits are much lower than where + // they all appear. + // The estimate should depend on average field length, e.g. if the field is really short + // a phrase hit is more likely + self.intersection_docset.size_hint() / (10 * self.num_terms as u32) } /// Returns a best-effort hint of the /// cost to drive the docset. 
fn cost(&self) -> u64 { - // Evaluating phrase matches is generally more expensive than simple term matches, - // as it requires loading and comparing positions. Use a conservative multiplier - // based on the number of terms. + // While determing a potential hit is cheap for phrases, evaluating an actual hit is + // expensive since it requires to load positions for a doc and check if they are next to + // each other. + // So the cost estimation would be the number of times we need to check if a doc is a hit * + // 10 * self.num_terms. self.intersection_docset.size_hint() as u64 * 10 * self.num_terms as u64 } } impl Scorer for PhraseScorer { + #[inline] fn score(&mut self) -> Score { let doc = self.doc(); let fieldnorm_id = self.fieldnorm_reader.fieldnorm_id(doc); diff --git a/src/query/phrase_query/regex_phrase_weight.rs b/src/query/phrase_query/regex_phrase_weight.rs index 4e850d2e2..9cefc555a 100644 --- a/src/query/phrase_query/regex_phrase_weight.rs +++ b/src/query/phrase_query/regex_phrase_weight.rs @@ -311,7 +311,7 @@ mod tests { #![proptest_config(ProptestConfig::with_cases(50))] #[test] fn test_phrase_regex_with_random_strings(mut random_strings in proptest::collection::vec("[c-z ]{0,10}", 1..100), num_occurrences in 1..150_usize) { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); // Insert "aaa ccc" the specified number of times into the list for _ in 0..num_occurrences { diff --git a/src/query/range_query/fast_field_range_doc_set.rs b/src/query/range_query/fast_field_range_doc_set.rs index dd4b8fe68..24d2b1fe3 100644 --- a/src/query/range_query/fast_field_range_doc_set.rs +++ b/src/query/range_query/fast_field_range_doc_set.rs @@ -62,6 +62,17 @@ pub(crate) struct RangeDocSet { const DEFAULT_FETCH_HORIZON: u32 = 128; impl RangeDocSet { pub(crate) fn new(value_range: RangeInclusive, column: Column) -> Self { + if *value_range.start() > column.max_value() || *value_range.end() < column.min_value() { + return Self { + value_range, + column, + 
loaded_docs: VecCursor::new(), + next_fetch_start: TERMINATED, + fetch_horizon: DEFAULT_FETCH_HORIZON, + last_seek_pos_opt: None, + }; + } + let mut range_docset = Self { value_range, column, @@ -81,6 +92,9 @@ impl RangeDocSet { /// Returns true if more data could be fetched fn fetch_block(&mut self) { + if self.next_fetch_start >= self.column.num_docs() { + return; + } const MAX_HORIZON: u32 = 100_000; while self.loaded_docs.is_empty() { let finished_to_end = self.fetch_horizon(self.fetch_horizon); @@ -105,10 +119,10 @@ impl RangeDocSet { fn fetch_horizon(&mut self, horizon: u32) -> bool { let mut finished_to_end = false; - let limit = self.column.num_docs(); - let mut end = self.next_fetch_start + horizon; - if end >= limit { - end = limit; + let num_docs = self.column.num_docs(); + let mut fetch_end = self.next_fetch_start + horizon; + if fetch_end >= num_docs { + fetch_end = num_docs; finished_to_end = true; } @@ -116,7 +130,7 @@ impl RangeDocSet { let doc_buffer: &mut Vec = self.loaded_docs.get_cleared_data(); self.column.get_docids_for_value_range( self.value_range.clone(), - self.next_fetch_start..end, + self.next_fetch_start..fetch_end, doc_buffer, ); if let Some(last_doc) = last_doc { @@ -124,7 +138,7 @@ impl RangeDocSet { self.loaded_docs.next(); } } - self.next_fetch_start = end; + self.next_fetch_start = fetch_end; finished_to_end } @@ -136,9 +150,6 @@ impl DocSet for RangeDocSe if let Some(docid) = self.loaded_docs.next() { return docid; } - if self.next_fetch_start >= self.column.num_docs() { - return TERMINATED; - } self.fetch_block(); self.loaded_docs.current().unwrap_or(TERMINATED) } @@ -174,15 +185,25 @@ impl DocSet for RangeDocSe } fn size_hint(&self) -> u32 { - self.column.num_docs() + // TODO: Implement a better size hint + self.column.num_docs() / 10 } /// Returns a best-effort hint of the /// cost to drive the docset. fn cost(&self) -> u64 { - // Advancing the docset is relatively expensive since it scans the column. 
- // Keep cost relative to a term query driver; use num_docs as baseline. - self.column.num_docs() as u64 + // Advancing the docset is pretty expensive since it scans the whole column, there is no + // index currently (will change with an kd-tree) + // Since we use SIMD to scan the fast field range query we lower the cost a little bit, + // assuming that we hit 10% of the docs like in size_hint. + // + // If we would return a cost higher than num_docs, we would never choose ff range query as + // the driver in a DocSet, when intersecting a term query with a fast field. But + // it's the faster choice when the term query has a lot of docids and the range + // query has not. + // + // Ideally this would take the fast field codec into account + (self.column.num_docs() as f64 * 0.8) as u64 } } @@ -236,4 +257,52 @@ mod tests { let count = searcher.search(&query, &Count).unwrap(); assert_eq!(count, 500); } + + #[test] + fn range_query_no_overlap_optimization() { + let mut schema_builder = schema::SchemaBuilder::new(); + let id_field = schema_builder.add_text_field("id", schema::STRING); + let value_field = schema_builder.add_u64_field("value", schema::FAST | schema::INDEXED); + + let dir = RamDirectory::default(); + let index = IndexBuilder::new() + .schema(schema_builder.build()) + .open_or_create(dir) + .unwrap(); + + { + let mut writer = index.writer(15_000_000).unwrap(); + + // Add documents with values in the range [10, 20] + for i in 0..100 { + let mut doc = TantivyDocument::new(); + doc.add_text(id_field, format!("doc{i}")); + doc.add_u64(value_field, 10 + (i % 11) as u64); // values in range 10-20 + + writer.add_document(doc).unwrap(); + } + writer.commit().unwrap(); + } + + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + + // Test a range query [100, 200] that has no overlap with data range [10, 20] + let query = RangeQuery::new( + Bound::Included(Term::from_field_u64(value_field, 100)), + 
Bound::Included(Term::from_field_u64(value_field, 200)), + ); + + let count = searcher.search(&query, &Count).unwrap(); + assert_eq!(count, 0); // should return 0 results since there's no overlap + + // Test another non-overlapping range: [0, 5] while data range is [10, 20] + let query2 = RangeQuery::new( + Bound::Included(Term::from_field_u64(value_field, 0)), + Bound::Included(Term::from_field_u64(value_field, 5)), + ); + + let count2 = searcher.search(&query2, &Count).unwrap(); + assert_eq!(count2, 0); // should return 0 results since there's no overlap + } } diff --git a/src/query/range_query/range_query.rs b/src/query/range_query/range_query.rs index 1893a06a5..a597c8dca 100644 --- a/src/query/range_query/range_query.rs +++ b/src/query/range_query/range_query.rs @@ -429,7 +429,7 @@ mod tests { docs.push(doc); } - docs.shuffle(&mut rand::thread_rng()); + docs.shuffle(&mut rand::rng()); let mut docs_it = docs.into_iter(); for doc in (&mut docs_it).take(50) { index_writer.add_document(doc)?; diff --git a/src/query/range_query/range_query_fastfield.rs b/src/query/range_query/range_query_fastfield.rs index 54cf0cad5..68da73c92 100644 --- a/src/query/range_query/range_query_fastfield.rs +++ b/src/query/range_query/range_query_fastfield.rs @@ -491,7 +491,7 @@ mod tests { use common::DateTime; use proptest::prelude::*; use rand::rngs::StdRng; - use rand::seq::SliceRandom; + use rand::seq::IndexedRandom; use rand::SeedableRng; use time::format_description::well_known::Rfc3339; use time::OffsetDateTime; @@ -1598,449 +1598,3 @@ pub(crate) mod ip_range_tests { Ok(()) } } - -#[cfg(all(test, feature = "unstable"))] -mod bench { - - use rand::rngs::StdRng; - use rand::{Rng, SeedableRng}; - use test::Bencher; - - use super::tests::*; - use super::*; - use crate::collector::Count; - use crate::query::QueryParser; - use crate::Index; - - fn get_index_0_to_100() -> Index { - let mut rng = StdRng::from_seed([1u8; 32]); - let num_vals = 100_000; - let docs: Vec<_> = (0..num_vals) 
- .map(|_i| { - let id_name = if rng.gen_bool(0.01) { - "veryfew".to_string() // 1% - } else if rng.gen_bool(0.1) { - "few".to_string() // 9% - } else { - "many".to_string() // 90% - }; - Doc { - id_name, - id: rng.gen_range(0..100), - } - }) - .collect(); - - create_index_from_docs(&docs, false) - } - - fn get_90_percent() -> RangeInclusive { - 0..=90 - } - - fn get_10_percent() -> RangeInclusive { - 0..=10 - } - - fn get_1_percent() -> RangeInclusive { - 10..=10 - } - - fn execute_query( - field: &str, - id_range: RangeInclusive, - suffix: &str, - index: &Index, - ) -> usize { - let gen_query_inclusive = |from: &u64, to: &u64| { - format!( - "{}:[{} TO {}] {}", - field, - &from.to_string(), - &to.to_string(), - suffix - ) - }; - - let query = gen_query_inclusive(id_range.start(), id_range.end()); - let query_from_text = |text: &str| { - QueryParser::for_index(index, vec![]) - .parse_query(text) - .unwrap() - }; - let query = query_from_text(&query); - let reader = index.reader().unwrap(); - let searcher = reader.searcher(); - searcher.search(&query, &(Count)).unwrap() - } - - #[bench] - fn bench_id_range_hit_90_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("id", get_90_percent(), "", &index)); - } - - #[bench] - fn bench_id_range_hit_10_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("id", get_10_percent(), "", &index)); - } - - #[bench] - fn bench_id_range_hit_1_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("id", get_1_percent(), "", &index)); - } - - #[bench] - fn bench_id_range_hit_10_percent_intersect_with_10_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("id", get_10_percent(), "AND id_name:few", &index)); - } - - #[bench] - fn bench_id_range_hit_1_percent_intersect_with_10_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| 
execute_query("id", get_1_percent(), "AND id_name:few", &index)); - } - - #[bench] - fn bench_id_range_hit_1_percent_intersect_with_90_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("id", get_1_percent(), "AND id_name:many", &index)); - } - - #[bench] - fn bench_id_range_hit_1_percent_intersect_with_1_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("id", get_1_percent(), "AND id_name:veryfew", &index)); - } - - #[bench] - fn bench_id_range_hit_10_percent_intersect_with_90_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("id", get_10_percent(), "AND id_name:many", &index)); - } - - #[bench] - fn bench_id_range_hit_90_percent_intersect_with_90_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("id", get_90_percent(), "AND id_name:many", &index)); - } - - #[bench] - fn bench_id_range_hit_90_percent_intersect_with_10_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("id", get_90_percent(), "AND id_name:few", &index)); - } - - #[bench] - fn bench_id_range_hit_90_percent_intersect_with_1_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("id", get_90_percent(), "AND id_name:veryfew", &index)); - } - - #[bench] - fn bench_id_range_hit_90_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("ids", get_90_percent(), "", &index)); - } - - #[bench] - fn bench_id_range_hit_10_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("ids", get_10_percent(), "", &index)); - } - - #[bench] - fn bench_id_range_hit_1_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("ids", get_1_percent(), "", &index)); - } - - #[bench] - fn 
bench_id_range_hit_10_percent_intersect_with_10_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("ids", get_10_percent(), "AND id_name:few", &index)); - } - - #[bench] - fn bench_id_range_hit_1_percent_intersect_with_10_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("ids", get_1_percent(), "AND id_name:few", &index)); - } - - #[bench] - fn bench_id_range_hit_1_percent_intersect_with_90_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("ids", get_1_percent(), "AND id_name:many", &index)); - } - - #[bench] - fn bench_id_range_hit_1_percent_intersect_with_1_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("ids", get_1_percent(), "AND id_name:veryfew", &index)); - } - - #[bench] - fn bench_id_range_hit_10_percent_intersect_with_90_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("ids", get_10_percent(), "AND id_name:many", &index)); - } - - #[bench] - fn bench_id_range_hit_90_percent_intersect_with_90_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("ids", get_90_percent(), "AND id_name:many", &index)); - } - - #[bench] - fn bench_id_range_hit_90_percent_intersect_with_10_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("ids", get_90_percent(), "AND id_name:few", &index)); - } - - #[bench] - fn bench_id_range_hit_90_percent_intersect_with_1_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("ids", get_90_percent(), "AND id_name:veryfew", &index)); - } -} - -#[cfg(all(test, feature = "unstable"))] -mod bench_ip { - - use rand::rngs::StdRng; - use rand::{Rng, SeedableRng}; - use test::Bencher; - - use super::ip_range_tests::*; - use super::*; - use 
crate::collector::Count; - use crate::query::QueryParser; - use crate::Index; - - fn get_index_0_to_100() -> Index { - let mut rng = StdRng::from_seed([1u8; 32]); - let num_vals = 100_000; - let docs: Vec<_> = (0..num_vals) - .map(|_i| { - let id = if rng.gen_bool(0.01) { - "veryfew".to_string() // 1% - } else if rng.gen_bool(0.1) { - "few".to_string() // 9% - } else { - "many".to_string() // 90% - }; - Doc { - id, - // Multiply by 1000, so that we create many buckets in the compact space - // The benches depend on this range to select n-percent of elements with the - // methods below. - ip: Ipv6Addr::from_u128(rng.gen_range(0..100) * 1000), - } - }) - .collect(); - - create_index_from_ip_docs(&docs) - } - - fn get_90_percent() -> RangeInclusive { - let start = Ipv6Addr::from_u128(0); - let end = Ipv6Addr::from_u128(90 * 1000); - start..=end - } - - fn get_10_percent() -> RangeInclusive { - let start = Ipv6Addr::from_u128(0); - let end = Ipv6Addr::from_u128(10 * 1000); - start..=end - } - - fn get_1_percent() -> RangeInclusive { - let start = Ipv6Addr::from_u128(10 * 1000); - let end = Ipv6Addr::from_u128(10 * 1000); - start..=end - } - - fn execute_query( - field: &str, - ip_range: RangeInclusive, - suffix: &str, - index: &Index, - ) -> usize { - let gen_query_inclusive = |from: &Ipv6Addr, to: &Ipv6Addr| { - format!( - "{}:[{} TO {}] {}", - field, - &from.to_string(), - &to.to_string(), - suffix - ) - }; - - let query = gen_query_inclusive(ip_range.start(), ip_range.end()); - let query_from_text = |text: &str| { - QueryParser::for_index(index, vec![]) - .parse_query(text) - .unwrap() - }; - let query = query_from_text(&query); - let reader = index.reader().unwrap(); - let searcher = reader.searcher(); - searcher.search(&query, &(Count)).unwrap() - } - - #[bench] - fn bench_ip_range_hit_90_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ip", get_90_percent(), "", &index)); - } - - #[bench] - fn 
bench_ip_range_hit_10_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ip", get_10_percent(), "", &index)); - } - - #[bench] - fn bench_ip_range_hit_1_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ip", get_1_percent(), "", &index)); - } - - #[bench] - fn bench_ip_range_hit_10_percent_intersect_with_10_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ip", get_10_percent(), "AND id:few", &index)); - } - - #[bench] - fn bench_ip_range_hit_1_percent_intersect_with_10_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ip", get_1_percent(), "AND id:few", &index)); - } - - #[bench] - fn bench_ip_range_hit_1_percent_intersect_with_90_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ip", get_1_percent(), "AND id:many", &index)); - } - - #[bench] - fn bench_ip_range_hit_1_percent_intersect_with_1_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ip", get_1_percent(), "AND id:veryfew", &index)); - } - - #[bench] - fn bench_ip_range_hit_10_percent_intersect_with_90_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ip", get_10_percent(), "AND id:many", &index)); - } - - #[bench] - fn bench_ip_range_hit_90_percent_intersect_with_90_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ip", get_90_percent(), "AND id:many", &index)); - } - - #[bench] - fn bench_ip_range_hit_90_percent_intersect_with_10_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ip", get_90_percent(), "AND id:few", &index)); - } - - #[bench] - fn bench_ip_range_hit_90_percent_intersect_with_1_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - 
bench.iter(|| execute_query("ip", get_90_percent(), "AND id:veryfew", &index)); - } - - #[bench] - fn bench_ip_range_hit_90_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ips", get_90_percent(), "", &index)); - } - - #[bench] - fn bench_ip_range_hit_10_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ips", get_10_percent(), "", &index)); - } - - #[bench] - fn bench_ip_range_hit_1_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ips", get_1_percent(), "", &index)); - } - - #[bench] - fn bench_ip_range_hit_10_percent_intersect_with_10_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ips", get_10_percent(), "AND id:few", &index)); - } - - #[bench] - fn bench_ip_range_hit_1_percent_intersect_with_10_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ips", get_1_percent(), "AND id:few", &index)); - } - - #[bench] - fn bench_ip_range_hit_1_percent_intersect_with_90_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("ips", get_1_percent(), "AND id:many", &index)); - } - - #[bench] - fn bench_ip_range_hit_1_percent_intersect_with_1_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ips", get_1_percent(), "AND id:veryfew", &index)); - } - - #[bench] - fn bench_ip_range_hit_10_percent_intersect_with_90_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ips", get_10_percent(), "AND id:many", &index)); - } - - #[bench] - fn bench_ip_range_hit_90_percent_intersect_with_90_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ips", get_90_percent(), "AND id:many", &index)); - } - - 
#[bench] - fn bench_ip_range_hit_90_percent_intersect_with_10_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ips", get_90_percent(), "AND id:few", &index)); - } - - #[bench] - fn bench_ip_range_hit_90_percent_intersect_with_1_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ips", get_90_percent(), "AND id:veryfew", &index)); - } -} diff --git a/src/query/reqopt_scorer.rs b/src/query/reqopt_scorer.rs index be9e14692..bed99f5b7 100644 --- a/src/query/reqopt_scorer.rs +++ b/src/query/reqopt_scorer.rs @@ -56,6 +56,11 @@ where self.req_scorer.seek(target) } + fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { + self.score_cache = None; + self.req_scorer.seek_into_the_danger_zone(target) + } + fn doc(&self) -> DocId { self.req_scorer.doc() } @@ -76,6 +81,7 @@ where TOptScorer: Scorer, TScoreCombiner: ScoreCombiner, { + #[inline] fn score(&mut self) -> Score { if let Some(score) = self.score_cache { return score; diff --git a/src/query/score_combiner.rs b/src/query/score_combiner.rs index a49f8b104..2fe760c3d 100644 --- a/src/query/score_combiner.rs +++ b/src/query/score_combiner.rs @@ -29,6 +29,7 @@ impl ScoreCombiner for DoNothingCombiner { fn clear(&mut self) {} + #[inline] fn score(&self) -> Score { 1.0 } @@ -49,6 +50,7 @@ impl ScoreCombiner for SumCombiner { self.score = 0.0; } + #[inline] fn score(&self) -> Score { self.score } @@ -86,6 +88,7 @@ impl ScoreCombiner for DisjunctionMaxCombiner { self.sum = 0.0; } + #[inline] fn score(&self) -> Score { self.max + (self.sum - self.max) * self.tie_breaker } diff --git a/src/query/scorer.rs b/src/query/scorer.rs index 69448042f..e91fc2fbc 100644 --- a/src/query/scorer.rs +++ b/src/query/scorer.rs @@ -18,6 +18,7 @@ pub trait Scorer: downcast_rs::Downcast + DocSet + 'static { impl_downcast!(Scorer); impl Scorer for Box { + #[inline] fn score(&mut self) -> Score { self.deref_mut().score() } diff 
--git a/src/query/term_query/term_scorer.rs b/src/query/term_query/term_scorer.rs index 5c020febd..6c7c5b17a 100644 --- a/src/query/term_query/term_scorer.rs +++ b/src/query/term_query/term_scorer.rs @@ -98,14 +98,17 @@ impl TermScorer { } impl DocSet for TermScorer { + #[inline] fn advance(&mut self) -> DocId { self.postings.advance() } + #[inline] fn seek(&mut self, target: DocId) -> DocId { self.postings.seek(target) } + #[inline] fn doc(&self) -> DocId { self.postings.doc() } @@ -116,6 +119,7 @@ impl DocSet for TermScorer { } impl Scorer for TermScorer { + #[inline] fn score(&mut self) -> Score { let fieldnorm_id = self.fieldnorm_id(); let term_freq = self.term_freq(); @@ -300,10 +304,10 @@ mod tests { let mut writer: IndexWriter = index.writer_with_num_threads(3, 3 * MEMORY_BUDGET_NUM_BYTES_MIN)?; use rand::Rng; - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); writer.set_merge_policy(Box::new(NoMergePolicy)); for _ in 0..3_000 { - let term_freq = rng.gen_range(1..10000); + let term_freq = rng.random_range(1..10000); let words: Vec<&str> = std::iter::repeat_n("bbbb", term_freq).collect(); let text = words.join(" "); writer.add_document(doc!(text_field=>text))?; diff --git a/src/query/union/buffered_union.rs b/src/query/union/buffered_union.rs index 3c726b8a7..ee554e357 100644 --- a/src/query/union/buffered_union.rs +++ b/src/query/union/buffered_union.rs @@ -15,7 +15,7 @@ const HORIZON: u32 = 64u32 * 64u32; // This function is similar except that it does is not unstable, and // it does not keep the original vector ordering. // -// Also, it does not "yield" any elements. +// Elements are dropped and not yielded. 
fn unordered_drain_filter(v: &mut Vec, mut predicate: P) where P: FnMut(&mut T) -> bool { let mut i = 0; @@ -128,6 +128,7 @@ impl BufferedUnionScorer bool { while self.bucket_idx < HORIZON_NUM_TINYBITSETS { if let Some(val) = self.bitsets[self.bucket_idx].pop_lowest() { @@ -143,6 +144,12 @@ impl BufferedUnionScorer bool { + // wrapping_sub, because target may be < window_start_doc + let gap = target.wrapping_sub(self.window_start_doc); + gap < HORIZON + } } impl DocSet for BufferedUnionScorer @@ -150,6 +157,7 @@ where TScorer: Scorer, TScoreCombiner: ScoreCombiner, { + #[inline] fn advance(&mut self) -> DocId { if self.advance_buffered() { return self.doc; @@ -217,8 +225,29 @@ where } } - // TODO Also implement `count` with deletes efficiently. + fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { + if self.is_in_horizon(target) { + // Our value is within the buffered horizon and the docset may already have been + // processed and removed, so we need to use seek, which uses the regular advance. + self.seek(target) == target + } else { + // The docsets are not in the buffered range, so we can use seek_into_the_danger_zone + // of the underlying docsets + let is_hit = self + .docsets + .iter_mut() + .any(|docset| docset.seek_into_the_danger_zone(target)); + // The API requires the DocSet to be in a valid state when `seek_into_the_danger_zone` + // returns true. + if is_hit { + self.seek(target); + } + is_hit + } + } + + #[inline] fn doc(&self) -> DocId { self.doc } @@ -231,6 +260,7 @@ where self.docsets.iter().map(|docset| docset.cost()).sum() } + // TODO Also implement `count` with deletes efficiently. 
fn count_including_deleted(&mut self) -> u32 { if self.doc == TERMINATED { return 0; @@ -259,6 +289,7 @@ where TScoreCombiner: ScoreCombiner, TScorer: Scorer, { + #[inline] fn score(&mut self) -> Score { self.score } diff --git a/src/query/union/simple_union.rs b/src/query/union/simple_union.rs index 61cbb94b6..b153a7f22 100644 --- a/src/query/union/simple_union.rs +++ b/src/query/union/simple_union.rs @@ -92,6 +92,7 @@ impl DocSet for SimpleUnion { } fn size_hint(&self) -> u32 { + // TODO: use estimate_union self.docsets .iter() .map(|docset| docset.size_hint()) diff --git a/src/schema/document/owned_value.rs b/src/schema/document/owned_value.rs index 9fbf1f8c2..49a6b1ac7 100644 --- a/src/schema/document/owned_value.rs +++ b/src/schema/document/owned_value.rs @@ -58,6 +58,31 @@ impl AsRef for OwnedValue { } } +impl OwnedValue { + /// Returns a u8 discriminant value for the `OwnedValue` variant. + /// + /// This can be used to sort `OwnedValue` instances by their type. + pub fn discriminant_value(&self) -> u8 { + match self { + OwnedValue::Null => 0, + OwnedValue::Str(_) => 1, + OwnedValue::PreTokStr(_) => 2, + // It is key to make sure U64, I64, F64 are grouped together in there, otherwise we + // might be breaking transivity. + OwnedValue::U64(_) => 3, + OwnedValue::I64(_) => 4, + OwnedValue::F64(_) => 5, + OwnedValue::Bool(_) => 6, + OwnedValue::Date(_) => 7, + OwnedValue::Facet(_) => 8, + OwnedValue::Bytes(_) => 9, + OwnedValue::Array(_) => 10, + OwnedValue::Object(_) => 11, + OwnedValue::IpAddr(_) => 12, + } + } +} + impl<'a> Value<'a> for &'a OwnedValue { type ArrayIter = std::slice::Iter<'a, OwnedValue>; type ObjectIter = ObjectMapIter<'a>; diff --git a/src/schema/mod.rs b/src/schema/mod.rs index 1cd4b7243..c8af359d9 100644 --- a/src/schema/mod.rs +++ b/src/schema/mod.rs @@ -98,6 +98,10 @@ //! make it possible to access the value given the doc id rapidly. This is useful if the value //! of the field is required during scoring or collection for instance. //! 
+//! Some queries may leverage Fast fields when run on a field that is not indexed. This can be +//! handy if that kind of request is infrequent, however note that searching on a Fast field is +//! generally much slower than searching in an index. +//! //! ``` //! use tantivy::schema::*; //! let mut schema_builder = Schema::builder(); diff --git a/src/snippet/mod.rs b/src/snippet/mod.rs index 020e6b588..ee61b534a 100644 --- a/src/snippet/mod.rs +++ b/src/snippet/mod.rs @@ -483,7 +483,7 @@ mod tests { use super::{collapse_overlapped_ranges, search_fragments, select_best_fragment_combination}; use crate::query::QueryParser; - use crate::schema::{IndexRecordOption, Schema, TextFieldIndexing, TextOptions, TEXT}; + use crate::schema::{Schema, TEXT}; use crate::snippet::SnippetGenerator; use crate::tokenizer::{NgramTokenizer, SimpleTokenizer}; use crate::Index; @@ -727,8 +727,10 @@ Survey in 2016, 2017, and 2018."#; Ok(()) } + #[cfg(feature = "stemmer")] #[test] fn test_snippet_generator() -> crate::Result<()> { + use crate::schema::{IndexRecordOption, TextFieldIndexing, TextOptions}; let mut schema_builder = Schema::builder(); let text_options = TextOptions::default().set_indexing_options( TextFieldIndexing::default() diff --git a/src/store/mod.rs b/src/store/mod.rs index 582643515..cccf4d8f9 100644 --- a/src/store/mod.rs +++ b/src/store/mod.rs @@ -102,6 +102,7 @@ pub(crate) mod tests { } const NUM_DOCS: usize = 1_000; + #[test] fn test_doc_store_iter_with_delete_bug_1077() -> crate::Result<()> { // this will cover deletion of the first element in a checkpoint @@ -113,7 +114,7 @@ pub(crate) mod tests { let directory = RamDirectory::create(); let store_wrt = directory.open_write(path)?; let schema = - write_lorem_ipsum_store(store_wrt, NUM_DOCS, Compressor::Lz4, BLOCK_SIZE, true); + write_lorem_ipsum_store(store_wrt, NUM_DOCS, Compressor::default(), BLOCK_SIZE, true); let field_title = schema.get_field("title").unwrap(); let store_file = directory.open_read(path)?; let 
store = StoreReader::open(store_file, 10)?; diff --git a/src/store/reader.rs b/src/store/reader.rs index fb1533988..a4105abec 100644 --- a/src/store/reader.rs +++ b/src/store/reader.rs @@ -465,7 +465,7 @@ mod tests { let directory = RamDirectory::create(); let path = Path::new("store"); let writer = directory.open_write(path)?; - let schema = write_lorem_ipsum_store(writer, 500, Compressor::default(), BLOCK_SIZE, true); + let schema = write_lorem_ipsum_store(writer, 500, Compressor::None, BLOCK_SIZE, true); let title = schema.get_field("title").unwrap(); let store_file = directory.open_read(path)?; let store = StoreReader::open(store_file, DOCSTORE_CACHE_CAPACITY)?; @@ -499,7 +499,7 @@ mod tests { assert_eq!(store.cache_stats().cache_hits, 1); assert_eq!(store.cache_stats().cache_misses, 2); - assert_eq!(store.cache.peek_lru(), Some(11207)); + assert_eq!(store.cache.peek_lru(), Some(232206)); Ok(()) } diff --git a/src/termdict/fst_termdict/merger.rs b/src/termdict/fst_termdict/merger.rs index e8a064deb..43147a5ae 100644 --- a/src/termdict/fst_termdict/merger.rs +++ b/src/termdict/fst_termdict/merger.rs @@ -95,7 +95,7 @@ impl<'a> TermMerger<'a> { #[cfg(all(test, feature = "unstable"))] mod bench { use rand::distributions::Alphanumeric; - use rand::{thread_rng, Rng}; + use rand::{rng, Rng}; use test::{self, Bencher}; use super::TermMerger; @@ -117,9 +117,9 @@ mod bench { let buffer: Vec = { let mut terms = vec![]; for _i in 0..num_terms { - let rand_string: String = thread_rng() + let rand_string: String = rng() .sample_iter(&Alphanumeric) - .take(thread_rng().gen_range(30..42)) + .take(rng().random_range(30..42)) .map(char::from) .collect(); terms.push(rand_string); diff --git a/src/tokenizer/mod.rs b/src/tokenizer/mod.rs index 5a5435562..31c518fd4 100644 --- a/src/tokenizer/mod.rs +++ b/src/tokenizer/mod.rs @@ -132,13 +132,14 @@ mod regex_tokenizer; mod remove_long; mod simple_tokenizer; mod split_compound_words; -mod stemmer; mod stop_word_filter; mod 
tokenized_string; mod tokenizer; mod tokenizer_manager; mod whitespace_tokenizer; +#[cfg(feature = "stemmer")] +mod stemmer; pub use tokenizer_api::{BoxTokenStream, Token, TokenFilter, TokenStream, Tokenizer}; pub use self::alphanum_only::AlphaNumOnlyFilter; @@ -151,6 +152,7 @@ pub use self::regex_tokenizer::RegexTokenizer; pub use self::remove_long::RemoveLongFilter; pub use self::simple_tokenizer::{SimpleTokenStream, SimpleTokenizer}; pub use self::split_compound_words::SplitCompoundWords; +#[cfg(feature = "stemmer")] pub use self::stemmer::{Language, Stemmer}; pub use self::stop_word_filter::StopWordFilter; pub use self::tokenized_string::{PreTokenizedStream, PreTokenizedString}; @@ -167,10 +169,7 @@ pub const MAX_TOKEN_LEN: usize = u16::MAX as usize - 5; #[cfg(test)] pub(crate) mod tests { - use super::{ - Language, LowerCaser, RemoveLongFilter, SimpleTokenizer, Stemmer, Token, TokenizerManager, - }; - use crate::tokenizer::TextAnalyzer; + use super::{Token, TokenizerManager}; /// This is a function that can be used in tests and doc tests /// to assert a token's correctness. 
@@ -205,59 +204,15 @@ pub(crate) mod tests { } #[test] - fn test_en_tokenizer() { + fn test_tokenizer_does_not_exist() { let tokenizer_manager = TokenizerManager::default(); assert!(tokenizer_manager.get("en_doesnotexist").is_none()); - let mut en_tokenizer = tokenizer_manager.get("en_stem").unwrap(); - let mut tokens: Vec<Token> = vec![]; - { - let mut add_token = |token: &Token| { - tokens.push(token.clone()); - }; - en_tokenizer - .token_stream("Hello, happy tax payer!") - .process(&mut add_token); - } - - assert_eq!(tokens.len(), 4); - assert_token(&tokens[0], 0, "hello", 0, 5); - assert_token(&tokens[1], 1, "happi", 7, 12); - assert_token(&tokens[2], 2, "tax", 13, 16); - assert_token(&tokens[3], 3, "payer", 17, 22); - } - - #[test] - fn test_non_en_tokenizer() { - let tokenizer_manager = TokenizerManager::default(); - tokenizer_manager.register( - "el_stem", - TextAnalyzer::builder(SimpleTokenizer::default()) - .filter(RemoveLongFilter::limit(40)) - .filter(LowerCaser) - .filter(Stemmer::new(Language::Greek)) - .build(), - ); - let mut en_tokenizer = tokenizer_manager.get("el_stem").unwrap(); - let mut tokens: Vec<Token> = vec![]; - { - let mut add_token = |token: &Token| { - tokens.push(token.clone()); - }; - en_tokenizer - .token_stream("Καλημέρα, χαρούμενε φορολογούμενε!") - .process(&mut add_token); - } - - assert_eq!(tokens.len(), 3); - assert_token(&tokens[0], 0, "καλημερ", 0, 16); - assert_token(&tokens[1], 1, "χαρουμεν", 18, 36); - assert_token(&tokens[2], 2, "φορολογουμεν", 37, 63); } #[test] fn test_tokenizer_empty() { let tokenizer_manager = TokenizerManager::default(); - let mut en_tokenizer = tokenizer_manager.get("en_stem").unwrap(); + let mut en_tokenizer = tokenizer_manager.get("default").unwrap(); { let mut tokens: Vec<Token> = vec![]; { diff --git a/src/tokenizer/stemmer.rs b/src/tokenizer/stemmer.rs index fc87440ce..764efc2ee 100644 --- a/src/tokenizer/stemmer.rs +++ b/src/tokenizer/stemmer.rs @@ -142,3 +142,60 @@ impl TokenStream for StemmerTokenStream { 
self.tail.token_mut() } } + +#[cfg(test)] +mod tests { + use tokenizer_api::Token; + + use super::*; + use crate::tokenizer::tests::assert_token; + use crate::tokenizer::{LowerCaser, SimpleTokenizer, TextAnalyzer, TokenizerManager}; + + #[test] + fn test_en_stem() { + let tokenizer_manager = TokenizerManager::default(); + let mut en_tokenizer = tokenizer_manager.get("en_stem").unwrap(); + let mut tokens: Vec<Token> = vec![]; + { + let mut add_token = |token: &Token| { + tokens.push(token.clone()); + }; + en_tokenizer + .token_stream("Dogs are the bests!") + .process(&mut add_token); + } + + assert_eq!(tokens.len(), 4); + assert_token(&tokens[0], 0, "dog", 0, 4); + assert_token(&tokens[1], 1, "are", 5, 8); + assert_token(&tokens[2], 2, "the", 9, 12); + assert_token(&tokens[3], 3, "best", 13, 18); + } + + #[test] + fn test_non_en_stem() { + let tokenizer_manager = TokenizerManager::default(); + tokenizer_manager.register( + "el_stem", + TextAnalyzer::builder(SimpleTokenizer::default()) + .filter(LowerCaser) + .filter(Stemmer::new(Language::Greek)) + .build(), + ); + let mut el_tokenizer = tokenizer_manager.get("el_stem").unwrap(); + let mut tokens: Vec<Token> = vec![]; + { + let mut add_token = |token: &Token| { + tokens.push(token.clone()); + }; + el_tokenizer + .token_stream("Καλημέρα, χαρούμενε φορολογούμενε!") + .process(&mut add_token); + } + + assert_eq!(tokens.len(), 3); + assert_token(&tokens[0], 0, "καλημερ", 0, 16); + assert_token(&tokens[1], 1, "χαρουμεν", 18, 36); + assert_token(&tokens[2], 2, "φορολογουμεν", 37, 63); + } +} diff --git a/src/tokenizer/tokenizer_manager.rs b/src/tokenizer/tokenizer_manager.rs index a0bdbcc0c..8bdbba7bd 100644 --- a/src/tokenizer/tokenizer_manager.rs +++ b/src/tokenizer/tokenizer_manager.rs @@ -1,10 +1,9 @@ use std::collections::HashMap; use std::sync::{Arc, RwLock}; -use crate::tokenizer::stemmer::Language; use crate::tokenizer::tokenizer::TextAnalyzer; use crate::tokenizer::{ - LowerCaser, RawTokenizer, RemoveLongFilter, 
SimpleTokenizer, Stemmer, WhitespaceTokenizer, + LowerCaser, RawTokenizer, RemoveLongFilter, SimpleTokenizer, WhitespaceTokenizer, }; /// The tokenizer manager serves as a store for @@ -64,14 +63,18 @@ impl Default for TokenizerManager { .filter(LowerCaser) .build(), ); - manager.register( - "en_stem", - TextAnalyzer::builder(SimpleTokenizer::default()) - .filter(RemoveLongFilter::limit(40)) - .filter(LowerCaser) - .filter(Stemmer::new(Language::English)) - .build(), - ); + #[cfg(feature = "stemmer")] + { + use crate::tokenizer::stemmer::{Language, Stemmer}; + manager.register( + "en_stem", + TextAnalyzer::builder(SimpleTokenizer::default()) + .filter(RemoveLongFilter::limit(40)) + .filter(LowerCaser) // The stemmer does not lowercase + .filter(Stemmer::new(Language::English)) + .build(), + ); + } manager.register("whitespace", WhitespaceTokenizer::default()); manager } diff --git a/sstable/Cargo.toml b/sstable/Cargo.toml index 7b353cece..813692e26 100644 --- a/sstable/Cargo.toml +++ b/sstable/Cargo.toml @@ -25,7 +25,7 @@ zstd-compression = ["zstd"] proptest = "1" criterion = { version = "0.5", default-features = false } names = "0.14" -rand = "0.8" +rand = "0.9" [[bench]] name = "stream_bench" diff --git a/sstable/benches/stream_bench.rs b/sstable/benches/stream_bench.rs index cffe41e26..70dcdd8e3 100644 --- a/sstable/benches/stream_bench.rs +++ b/sstable/benches/stream_bench.rs @@ -10,9 +10,9 @@ use tantivy_sstable::{Dictionary, MonotonicU64SSTable}; const CHARSET: &[u8] = b"abcdefghij"; fn generate_key(rng: &mut impl Rng) -> String { - let len = rng.gen_range(3..12); + let len = rng.random_range(3..12); std::iter::from_fn(|| { - let idx = rng.gen_range(0..CHARSET.len()); + let idx = rng.random_range(0..CHARSET.len()); Some(CHARSET[idx] as char) }) .take(len) diff --git a/stacker/Cargo.toml b/stacker/Cargo.toml index c78c23051..81388bdfd 100644 --- a/stacker/Cargo.toml +++ b/stacker/Cargo.toml @@ -11,7 +11,6 @@ description = "term hashmap used for indexing" 
murmurhash32 = "0.3" common = { version = "0.10", path = "../common/", package = "tantivy-common" } ahash = { version = "0.8.11", default-features = false, optional = true } -rand_distr = "0.4.3" [[bench]] @@ -24,11 +23,12 @@ name = "hashmap" path = "example/hashmap.rs" [dev-dependencies] -rand = "0.8.5" +rand = "0.9" zipf = "7.0.0" rustc-hash = "2.1.0" proptest = "1.2.0" binggan = { version = "0.14.0" } +rand_distr = "0.5" [features] compare_hash_only = ["ahash"] # Compare hash only, not the key in the Hashmap diff --git a/stacker/benches/bench.rs b/stacker/benches/bench.rs index ed5ea5eeb..03f801308 100644 --- a/stacker/benches/bench.rs +++ b/stacker/benches/bench.rs @@ -90,10 +90,10 @@ fn bench_vint() { } // benchmark zipfs distribution numbers { - use rand::distributions::Distribution; + use rand::distr::Distribution; use rand::rngs::StdRng; let mut rng = StdRng::from_seed([3u8; 32]); - let zipf = zipf::ZipfDistribution::new(10_000, 1.03).unwrap(); + let zipf = rand_distr::Zipf::new(10_000.0f64, 1.03).unwrap(); let numbers: Vec<[u8; 8]> = (0..num_numbers) .map(|_| zipf.sample(&mut rng).to_le_bytes()) .collect(); diff --git a/stacker/fuzz_test/Cargo.toml b/stacker/fuzz_test/Cargo.toml index 02478c95b..f71b36a37 100644 --- a/stacker/fuzz_test/Cargo.toml +++ b/stacker/fuzz_test/Cargo.toml @@ -7,8 +7,8 @@ edition = "2021" [dependencies] ahash = "0.8.7" -rand = "0.8.5" -rand_distr = "0.4.3" +rand = "0.9" +rand_distr = "0.5" tantivy-stacker = { version = "0.2.0", path = ".." 
} [workspace] diff --git a/stacker/fuzz_test/src/main.rs b/stacker/fuzz_test/src/main.rs index 2367ddc33..efe72d921 100644 --- a/stacker/fuzz_test/src/main.rs +++ b/stacker/fuzz_test/src/main.rs @@ -14,7 +14,7 @@ fn test_with_seed(seed: u64) { let mut hash_map = AHashMap::new(); let mut arena_hashmap = ArenaHashMap::default(); let mut rng = StdRng::seed_from_u64(seed); - let key_count = rng.gen_range(1_000..=1_000_000); + let key_count = rng.random_range(1_000..=1_000_000); let exp = Exp::new(0.05).unwrap(); for _ in 0..key_count { diff --git a/stacker/src/expull.rs b/stacker/src/expull.rs index 5fd00db3d..3b6353b38 100644 --- a/stacker/src/expull.rs +++ b/stacker/src/expull.rs @@ -5,7 +5,7 @@ use common::serialize_vint_u32; use crate::fastcpy::fast_short_slice_copy; use crate::{Addr, MemoryArena}; -const FIRST_BLOCK_NUM: u16 = 2; +const FIRST_BLOCK_NUM: u32 = 2; /// An exponential unrolled link. /// @@ -33,8 +33,8 @@ pub struct ExpUnrolledLinkedList { // u16, since the max size of each block is (1< { } } -// The block size is 2^block_num + 2, but max 2^15= 32k -// Initial size is 8, for the first block => block_num == 1 +// The block size is 2^block_num, but max 2^15 = 32KB +// Initial size is 8 bytes (2^3): block_num starts at FIRST_BLOCK_NUM (2) and is 3 for the first allocated block +// Block size caps at 32KB (2^15) regardless of how high block_num goes #[inline] -fn get_block_size(block_num: u16) -> u16 { - 1 << block_num.min(15) +fn get_block_size(block_num: u32) -> u16 { + // Cap at 15 to prevent block sizes > 32KB + // block_num can now be much larger than 15, but block size maxes out + let exp: u32 = block_num.min(15u32); + (1u32 << exp) as u16 } impl ExpUnrolledLinkedList { + #[inline(always)] pub fn increment_num_blocks(&mut self) { - self.block_num += 1; + // Add overflow check as a safety measure + // With u32, we can handle up to ~4 billion blocks before overflow + // At 32KB per block (max size), that's 128 TB of data + self.block_num = self + .block_num + .checked_add(1) + 
.expect("ExpUnrolledLinkedList block count overflow - exceeded 4 billion blocks"); } #[inline] @@ -132,9 +143,26 @@ impl ExpUnrolledLinkedList { if addr.is_null() { return; } - let last_block_len = get_block_size(self.block_num) as usize - self.remaining_cap as usize; - // Full Blocks + // Calculate last block length with bounds checking to prevent underflow + let block_size = get_block_size(self.block_num) as usize; + let last_block_len = block_size.saturating_sub(self.remaining_cap as usize); + + // Safety check: if remaining_cap > block_size, the metadata is corrupted + assert!( + self.remaining_cap as usize <= block_size, + "ExpUnrolledLinkedList metadata corruption detected: remaining_cap ({}) > block_size \ + ({}). This indicates a serious bug, please report! (block_num={}, head={:?}, \ + tail={:?})", + self.remaining_cap, + block_size, + self.block_num, + self.head, + self.tail + ); + + // Full Blocks (iterate through all blocks except the last one) + // Note: Blocks are numbered starting from FIRST_BLOCK_NUM+1 (=3) after first allocation for block_num in FIRST_BLOCK_NUM + 1..self.block_num { let cap = get_block_size(block_num) as usize; let data = arena.slice(addr, cap); @@ -259,6 +287,180 @@ mod tests { assert_eq!(&vec1[..], &res1[..]); assert_eq!(&vec2[..], &res2[..]); } + + // Tests for u32 block_num fix (issue with large arrays) + + #[test] + fn test_block_num_exceeds_u16_max() { + // Test that we can handle more than 65,535 blocks (old u16 limit) + let mut eull = ExpUnrolledLinkedList::default(); + + // Simulate allocating 70,000 blocks (exceeds u16::MAX of 65,535) + for _ in 0..70_000 { + eull.increment_num_blocks(); + } + + // Verify block_num is correct + assert_eq!(eull.block_num, FIRST_BLOCK_NUM + 70_000); + + // Verify we can still get block size (should be capped at 32KB) + let block_size = get_block_size(eull.block_num); + assert_eq!(block_size, 1 << 15); // 32KB max + } + + #[test] + #[allow(clippy::needless_range_loop)] + fn 
test_large_dataset_simulation() { + // Simulate the scenario: large arrays requiring many blocks + // We write enough data to require thousands of blocks + let mut arena = MemoryArena::default(); + let mut eull = ExpUnrolledLinkedList::default(); + + // Write 100 MB of data (this will require roughly 3,000 blocks at 32KB each) + // This is enough to validate the system works with large datasets + // but not so much that the test is slow + let bytes_per_write = 10_000; + let num_writes = 10_000; // 10k * 10k = 100 MB + + let data: Vec<u8> = (0..bytes_per_write).map(|i| (i % 256) as u8).collect(); + for _ in 0..num_writes { + eull.writer(&mut arena).extend_from_slice(&data); + } + + // Verify we allocated many blocks (should be in the thousands) + assert!( + eull.block_num > 1000, + "block_num ({}) should be > 1000 for this much data", + eull.block_num + ); + + // Verify we can read back correctly + let mut buffer = Vec::new(); + eull.read_to_end(&arena, &mut buffer); + assert_eq!(buffer.len(), bytes_per_write * num_writes); + + // Verify data integrity on a sample + for i in 0..bytes_per_write { + assert_eq!(buffer[i], (i % 256) as u8); + } + } + + #[test] + fn test_get_block_size_with_large_block_num() { + // Test that get_block_size handles large u32 values correctly + + // Small block numbers (under 15) + assert_eq!(get_block_size(2), 4); // 2^2 = 4 + assert_eq!(get_block_size(3), 8); // 2^3 = 8 + assert_eq!(get_block_size(10), 1024); // 2^10 = 1KB + + // At the cap (15) + assert_eq!(get_block_size(15), 32768); // 2^15 = 32KB + + // Beyond the cap (should stay at 32KB) + assert_eq!(get_block_size(16), 32768); + assert_eq!(get_block_size(100), 32768); + assert_eq!(get_block_size(65_536), 32768); // Old u16::MAX + 1 + assert_eq!(get_block_size(100_000), 32768); + assert_eq!(get_block_size(1_000_000), 32768); + } + + #[test] + fn test_increment_blocks_near_u16_boundary() { + // Test incrementing around the old u16::MAX boundary + let mut eull = 
ExpUnrolledLinkedList::default(); + + // Bring block_num up to the old u16::MAX (FIRST_BLOCK_NUM + 65,533 = 65,535) + for _ in 0..65_533 { + eull.increment_num_blocks(); + } + assert_eq!(eull.block_num, FIRST_BLOCK_NUM + 65_533); + + // Cross the old u16::MAX boundary (this would have overflowed before) + eull.increment_num_blocks(); // block_num = 65,536 (would overflow u16) + eull.increment_num_blocks(); // block_num = 65,537 + eull.increment_num_blocks(); // block_num = 65,538 + eull.increment_num_blocks(); // block_num = 65,539 + + // Verify we're past the old limit + assert_eq!(eull.block_num, FIRST_BLOCK_NUM + 65_537); + } + + #[test] + fn test_write_and_read_with_many_blocks() { + // Test that write/read works correctly with many blocks + let mut arena = MemoryArena::default(); + let mut eull = ExpUnrolledLinkedList::default(); + + // Write data that will span many blocks + let test_data: Vec<u8> = (0..50_000).map(|i| (i % 256) as u8).collect(); + eull.writer(&mut arena).extend_from_slice(&test_data); + + // Read it back + let mut buffer = Vec::new(); + eull.read_to_end(&arena, &mut buffer); + + // Verify data integrity + assert_eq!(buffer.len(), test_data.len()); + assert_eq!(&buffer[..], &test_data[..]); + } + + #[test] + fn test_multiple_eull_with_large_block_counts() { + // Test multiple ExpUnrolledLinkedLists with high block counts + // (simulates parallel columnar writes) + let mut arena = MemoryArena::default(); + let mut eull1 = ExpUnrolledLinkedList::default(); + let mut eull2 = ExpUnrolledLinkedList::default(); + + // Write different data to each + for i in 0..10_000u32 { + eull1.writer(&mut arena).write_u32_vint(i); + eull2.writer(&mut arena).write_u32_vint(i * 2); + } + + // Read back and verify + let mut buf1 = Vec::new(); + let mut buf2 = Vec::new(); + eull1.read_to_end(&arena, &mut buf1); + eull2.read_to_end(&arena, &mut buf2); + + // Deserialize and check + let mut cursor1 = &buf1[..]; + let mut cursor2 = &buf2[..]; + for i in 0..10_000u32 { + assert_eq!(read_u32_vint(&mut cursor1), i); + assert_eq!(read_u32_vint(&mut cursor2), 
i * 2); + } + } + + #[test] + fn test_block_size_stays_capped() { + // Verify that even with massive block numbers, size stays at 32KB + let mut eull = ExpUnrolledLinkedList::default(); + + // Increment to a very large number + for _ in 0..200_000 { + eull.increment_num_blocks(); + } + + let block_size = get_block_size(eull.block_num); + assert_eq!(block_size, 32768, "Block size should be capped at 32KB"); + } + + #[test] + #[should_panic(expected = "ExpUnrolledLinkedList block count overflow")] + fn test_increment_overflow_protection() { + // Test that we panic gracefully if we somehow hit u32::MAX + // This is extremely unlikely in practice (would require 128TB of data) + let mut eull = ExpUnrolledLinkedList { + block_num: u32::MAX, + ..Default::default() + }; + + // This should panic with our custom error message + eull.increment_num_blocks(); + } } #[cfg(all(test, feature = "unstable"))]