From 8a1079b2dcccbbff0f1bf4c04085f1918b420703 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philippe=20No=C3=ABl?= <21990816+philippemnoel@users.noreply.github.com> Date: Wed, 10 Dec 2025 18:10:42 -0500 Subject: [PATCH 01/26] expose AddOperation and with_max_doc (#7) (#2762) Co-authored-by: Ming --- src/index/segment.rs | 2 +- src/indexer/mod.rs | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/index/segment.rs b/src/index/segment.rs index 4c9382cb0..fcd32a1ff 100644 --- a/src/index/segment.rs +++ b/src/index/segment.rs @@ -46,7 +46,7 @@ impl Segment { /// /// This method is only used when updating `max_doc` from 0 /// as we finalize a fresh new segment. - pub(crate) fn with_max_doc(self, max_doc: u32) -> Segment { + pub fn with_max_doc(self, max_doc: u32) -> Segment { Segment { index: self.index, meta: self.meta.with_max_doc(max_doc), diff --git a/src/indexer/mod.rs b/src/indexer/mod.rs index 2d86aa461..ee53bdc7a 100644 --- a/src/indexer/mod.rs +++ b/src/indexer/mod.rs @@ -36,8 +36,7 @@ pub use self::index_writer::{IndexWriter, IndexWriterOptions}; pub use self::log_merge_policy::LogMergePolicy; pub use self::merge_operation::MergeOperation; pub use self::merge_policy::{MergeCandidate, MergePolicy, NoMergePolicy}; -use self::operation::AddOperation; -pub use self::operation::UserOperation; +pub use self::operation::{AddOperation, UserOperation}; pub use self::prepared_commit::PreparedCommit; pub use self::segment_entry::SegmentEntry; pub(crate) use self::segment_serializer::SegmentSerializer; From 14cc24614e0cfc82086461a4f97a4f04447c1027 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philippe=20No=C3=ABl?= <21990816+philippemnoel@users.noreply.github.com> Date: Wed, 10 Dec 2025 18:11:03 -0500 Subject: [PATCH 02/26] Make DeleteMeta pub (#2765) Co-authored-by: Ming Ying --- src/index/index_meta.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/index/index_meta.rs b/src/index/index_meta.rs index 86eaa35d6..d95ce6ff7 100644 --- a/src/index/index_meta.rs +++ b/src/index/index_meta.rs @@ -13,9 +13,9 @@ use crate::store::Compressor; use crate::{Inventory, Opstamp, TrackedObject}; #[derive(Clone, Debug, Serialize, Deserialize)] -struct DeleteMeta { +pub struct DeleteMeta { num_deleted_docs: u32, - opstamp: Opstamp, + pub opstamp: Opstamp, } #[derive(Clone, Default)] @@ -213,7 +213,7 @@ impl SegmentMeta { struct InnerSegmentMeta { segment_id: SegmentId, max_doc: u32, - deletes: Option, + pub deletes: Option, /// If you want to avoid the SegmentComponent::TempStore file to be covered by /// garbage collection and deleted, set this to true. This is used during merge. #[serde(skip)] From 22dde8f9aee31244822697690446ea161fdbf715 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philippe=20No=C3=ABl?= <21990816+philippemnoel@users.noreply.github.com> Date: Wed, 10 Dec 2025 19:22:15 -0500 Subject: [PATCH 03/26] chore: Make some delete-related functions public (#46) (#2766) Co-authored-by: Ming --- src/indexer/index_writer.rs | 2 +- src/indexer/mod.rs | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/indexer/index_writer.rs b/src/indexer/index_writer.rs index 1ba92d6de..4ce5e1db5 100644 --- a/src/indexer/index_writer.rs +++ b/src/indexer/index_writer.rs @@ -128,7 +128,7 @@ fn compute_deleted_bitset( /// is `==` target_opstamp. /// For instance, there was no delete operation between the state of the `segment_entry` and /// the `target_opstamp`, `segment_entry` is not updated. -pub(crate) fn advance_deletes( +pub fn advance_deletes( mut segment: Segment, segment_entry: &mut SegmentEntry, target_opstamp: Opstamp, diff --git a/src/indexer/mod.rs b/src/indexer/mod.rs index ee53bdc7a..a6d3cab38 100644 --- a/src/indexer/mod.rs +++ b/src/indexer/mod.rs @@ -4,7 +4,7 @@ //! `IndexWriter` is the main entry point for that, which created from //! [`Index::writer`](crate::Index::writer). -pub(crate) mod delete_queue; +pub mod delete_queue; pub(crate) mod path_to_unordered_id; pub(crate) mod doc_id_mapping; @@ -32,11 +32,11 @@ mod stamper; use crossbeam_channel as channel; use smallvec::SmallVec; -pub use self::index_writer::{IndexWriter, IndexWriterOptions}; +pub use self::index_writer::{advance_deletes, IndexWriter, IndexWriterOptions}; pub use self::log_merge_policy::LogMergePolicy; pub use self::merge_operation::MergeOperation; pub use self::merge_policy::{MergeCandidate, MergePolicy, NoMergePolicy}; -pub use self::operation::{AddOperation, UserOperation}; +pub use self::operation::{AddOperation, DeleteOperation, UserOperation}; pub use self::prepared_commit::PreparedCommit; pub use self::segment_entry::SegmentEntry; pub(crate) use self::segment_serializer::SegmentSerializer; From 5ba0031f7d7383cf04c7261f1f9b3dede2ecf053 Mon Sep 17 00:00:00 2001 From: PSeitz-dd Date: Thu, 11 Dec 2025 11:23:50 +0100 Subject: [PATCH 04/26] move rand_distr to dev_dep (#2772) --- stacker/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stacker/Cargo.toml b/stacker/Cargo.toml index c78c23051..38b293dff 100644 --- a/stacker/Cargo.toml +++ b/stacker/Cargo.toml @@ -11,7 +11,6 @@ description = "term hashmap used for indexing" murmurhash32 = "0.3" common = { version = "0.10", path = "../common/", package = "tantivy-common" } ahash = { version = "0.8.11", default-features = false, optional = true } -rand_distr = "0.4.3" [[bench]] @@ -29,6 +28,7 @@ zipf = "7.0.0" rustc-hash = "2.1.0" proptest = "1.2.0" binggan = { version = "0.14.0" } +rand_distr = "0.4.3" [features] compare_hash_only = ["ahash"] # Compare hash only, not the key in the Hashmap From e9020d17d4d9dfbc9d6a536bcb85ebae93fc4376 Mon Sep 17 00:00:00 2001 From: PSeitz-dd Date: Thu, 11 Dec 2025 11:35:58 +0100 Subject: [PATCH 05/26] fix coverage (#2769) --- .github/workflows/coverage.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 7201897d7..95167ba41 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -15,11 +15,11 @@ jobs: steps: - uses: actions/checkout@v4 - name: Install Rust - run: rustup toolchain install nightly-2024-07-01 --profile minimal --component llvm-tools-preview + run: rustup toolchain install nightly-2025-12-01 --profile minimal --component llvm-tools-preview - uses: Swatinem/rust-cache@v2 - uses: taiki-e/install-action@cargo-llvm-cov - name: Generate code coverage - run: cargo +nightly-2024-07-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info + run: cargo +nightly-2025-12-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 continue-on-error: true From d0e16001357b0238645d2e09db59e913b09fee07 Mon Sep 17 00:00:00 2001 From: trinity-1686a Date: Sun, 14 Dec 2025 10:10:45 +0100 Subject: [PATCH 06/26] fix bug with minimum_should_match and AllScorer (#2774) --- src/query/all_query.rs | 6 ++- src/query/boolean_query/boolean_weight.rs | 8 ++-- src/query/boolean_query/mod.rs | 47 ++++++++++++++++++++++- 3 files changed, 54 insertions(+), 7 deletions(-) diff --git a/src/query/all_query.rs b/src/query/all_query.rs index 11172f9ed..16a83ec56 100644 --- a/src/query/all_query.rs +++ b/src/query/all_query.rs @@ -23,7 +23,11 @@ pub struct AllWeight; impl Weight for AllWeight { fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result> { let all_scorer = AllScorer::new(reader.max_doc()); - Ok(Box::new(BoostScorer::new(all_scorer, boost))) + if boost != 1.0 { + Ok(Box::new(BoostScorer::new(all_scorer, boost))) + } else { + Ok(Box::new(all_scorer)) + } } fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result { diff --git a/src/query/boolean_query/boolean_weight.rs b/src/query/boolean_query/boolean_weight.rs index 9e8cedf2e..a39249130 100644 --- a/src/query/boolean_query/boolean_weight.rs +++ b/src/query/boolean_query/boolean_weight.rs @@ -193,18 +193,18 @@ impl BooleanWeight { return Ok(SpecializedScorer::Other(Box::new(EmptyScorer))); } - let minimum_number_should_match = self + let effective_minimum_number_should_match = self .minimum_number_should_match .saturating_sub(should_special_scorer_counts.num_all_scorers); let should_scorers: ShouldScorersCombinationMethod = { let num_of_should_scorers = should_scorers.len(); - if minimum_number_should_match > num_of_should_scorers { + if effective_minimum_number_should_match > num_of_should_scorers { // We don't have enough scorers to satisfy the minimum number of should matches. // The request will match no documents. return Ok(SpecializedScorer::Other(Box::new(EmptyScorer))); } - match minimum_number_should_match { + match effective_minimum_number_should_match { 0 if num_of_should_scorers == 0 => ShouldScorersCombinationMethod::Ignored, 0 => ShouldScorersCombinationMethod::Optional(scorer_union( should_scorers, @@ -226,7 +226,7 @@ impl BooleanWeight { scorer_disjunction( should_scorers, score_combiner_fn(), - self.minimum_number_should_match, + effective_minimum_number_should_match, ), )), } diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs index 0ddc5a26c..5dc042c46 100644 --- a/src/query/boolean_query/mod.rs +++ b/src/query/boolean_query/mod.rs @@ -9,12 +9,14 @@ pub use self::boolean_weight::BooleanWeight; #[cfg(test)] mod tests { + use std::ops::Bound; + use super::*; use crate::collector::tests::TEST_COLLECTOR_WITH_SCORE; - use crate::collector::TopDocs; + use crate::collector::{Count, TopDocs}; use crate::query::term_query::TermScorer; use crate::query::{ - AllScorer, EmptyScorer, EnableScoring, Intersection, Occur, Query, QueryParser, + AllScorer, EmptyScorer, EnableScoring, Intersection, Occur, Query, QueryParser, RangeQuery, RequiredOptionalScorer, Scorer, SumCombiner, TermQuery, }; use crate::schema::*; @@ -374,4 +376,45 @@ mod tests { } Ok(()) } + + #[test] + pub fn test_min_should_match_with_all_query() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let text_field = schema_builder.add_text_field("text", TEXT); + let num_field = + schema_builder.add_i64_field("num", NumericOptions::default().set_fast().set_indexed()); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut index_writer: IndexWriter = index.writer_for_tests()?; + + index_writer.add_document(doc!(text_field => "apple", num_field => 10i64))?; + index_writer.add_document(doc!(text_field => "banana", num_field => 20i64))?; + index_writer.commit()?; + + let searcher = index.reader()?.searcher(); + + let effective_all_match_query: Box = Box::new(RangeQuery::new( + Bound::Excluded(Term::from_field_i64(num_field, 0)), + Bound::Unbounded, + )); + let term_query: Box = Box::new(TermQuery::new( + Term::from_field_text(text_field, "apple"), + IndexRecordOption::Basic, + )); + + // in some previous version, we would remove the 2 all_match, but then say we need *4* + // matches out of the 3 term queries, which matches nothing. + let mut bool_query = BooleanQuery::new(vec![ + (Occur::Should, effective_all_match_query.box_clone()), + (Occur::Should, effective_all_match_query.box_clone()), + (Occur::Should, term_query.box_clone()), + (Occur::Should, term_query.box_clone()), + (Occur::Should, term_query.box_clone()), + ]); + bool_query.set_minimum_number_should_match(4); + let count = searcher.search(&bool_query, &Count)?; + assert_eq!(count, 1); + + Ok(()) + } } From ba61ed6ef3e46b06764db7179efece1502d895eb Mon Sep 17 00:00:00 2001 From: Ming Date: Tue, 16 Dec 2025 16:50:41 -0500 Subject: [PATCH 07/26] fix: vint buffer can overflow (#2778) * fix vint overflow * comment --- src/postings/compression/mod.rs | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/postings/compression/mod.rs b/src/postings/compression/mod.rs index 487da620c..6b7b0de9f 100644 --- a/src/postings/compression/mod.rs +++ b/src/postings/compression/mod.rs @@ -1,8 +1,10 @@ use bitpacking::{BitPacker, BitPacker4x}; -use common::FixedSize; pub const COMPRESSION_BLOCK_SIZE: usize = BitPacker4x::BLOCK_LEN; -const COMPRESSED_BLOCK_MAX_SIZE: usize = COMPRESSION_BLOCK_SIZE * u32::SIZE_IN_BYTES; +// in vint encoding, each byte stores 7 bits of data, so we need at most 32 / 7 = 4.57 bytes to +// store a u32 in the worst case, rounding up to 5 bytes total +const MAX_VINT_SIZE: usize = 5; +const COMPRESSED_BLOCK_MAX_SIZE: usize = COMPRESSION_BLOCK_SIZE * MAX_VINT_SIZE; mod vint; @@ -267,7 +269,6 @@ impl VIntDecoder for BlockDecoder { #[cfg(test)] pub(crate) mod tests { - use super::*; use crate::TERMINATED; @@ -372,6 +373,13 @@ pub(crate) mod tests { } } } + + #[test] + fn test_compress_vint_unsorted_does_not_overflow() { + let mut encoder = BlockEncoder::new(); + let input: Vec = vec![u32::MAX; COMPRESSION_BLOCK_SIZE]; + encoder.compress_vint_unsorted(&input); + } } #[cfg(all(test, feature = "unstable"))] From e3c9be1f92b35f74369cae374fb6da6c8dec7d22 Mon Sep 17 00:00:00 2001 From: Moe Date: Tue, 16 Dec 2025 13:52:02 -0800 Subject: [PATCH 08/26] fix: boolean query incorrectly dropping documents when AllScorer is present (#2760) * Fixed the range issue. * Fixed the second all scorer issue * Improved docs + tests * Improved code. * Fixed lint issues. * Improved tests + logic based on PR comments. * Fixed lint issues. * Increase the document count. * Improved the prop-tests * Expand the index size, and remove unused parameter. --------- Co-authored-by: Stu Hood --- src/query/boolean_query/boolean_weight.rs | 166 ++++++--- src/query/boolean_query/mod.rs | 422 ++++++++++++++++++++++ 2 files changed, 547 insertions(+), 41 deletions(-) diff --git a/src/query/boolean_query/boolean_weight.rs b/src/query/boolean_query/boolean_weight.rs index a39249130..c46e9b0b1 100644 --- a/src/query/boolean_query/boolean_weight.rs +++ b/src/query/boolean_query/boolean_weight.rs @@ -97,6 +97,65 @@ fn into_box_scorer( } } +/// Returns the effective MUST scorer, accounting for removed AllScorers. +/// +/// When AllScorer instances are removed from must_scorers as an optimization, +/// we must restore the "match all" semantics if the list becomes empty. +fn effective_must_scorer( + must_scorers: Vec>, + removed_all_scorer_count: usize, + max_doc: DocId, + num_docs: u32, +) -> Option> { + if must_scorers.is_empty() { + if removed_all_scorer_count > 0 { + // Had AllScorer(s) only - all docs match + Some(Box::new(AllScorer::new(max_doc))) + } else { + // No MUST constraint at all + None + } + } else { + Some(intersect_scorers(must_scorers, num_docs)) + } +} + +/// Returns a SHOULD scorer with AllScorer union if any were removed. +/// +/// For union semantics (OR): if any SHOULD clause was an AllScorer, the result +/// should include all documents. We restore this by unioning with AllScorer. +/// +/// When `scoring_enabled` is false, we can just return AllScorer alone since +/// we don't need score contributions from the should_scorer. +fn effective_should_scorer_for_union( + should_scorer: SpecializedScorer, + removed_all_scorer_count: usize, + max_doc: DocId, + num_docs: u32, + score_combiner_fn: impl Fn() -> TScoreCombiner, + scoring_enabled: bool, +) -> SpecializedScorer { + if removed_all_scorer_count > 0 { + if scoring_enabled { + // Need to union to get score contributions from both + let all_scorers: Vec> = vec![ + into_box_scorer(should_scorer, &score_combiner_fn, num_docs), + Box::new(AllScorer::new(max_doc)), + ]; + SpecializedScorer::Other(Box::new(BufferedUnionScorer::build( + all_scorers, + score_combiner_fn, + num_docs, + ))) + } else { + // Scoring disabled - AllScorer alone is sufficient + SpecializedScorer::Other(Box::new(AllScorer::new(max_doc))) + } + } else { + should_scorer + } +} + enum ShouldScorersCombinationMethod { // Should scorers are irrelevant. Ignored, @@ -246,53 +305,78 @@ impl BooleanWeight { let include_scorer = match (should_scorers, must_scorers) { (ShouldScorersCombinationMethod::Ignored, must_scorers) => { - let boxed_scorer: Box = if must_scorers.is_empty() { - // We do not have any should scorers, nor all scorers. - // There are still two cases here. - // - // If this follows the removal of some AllScorers in the should/must clauses, - // then we match all documents. - // - // Otherwise, it is really just an EmptyScorer. - if must_special_scorer_counts.num_all_scorers - + should_special_scorer_counts.num_all_scorers - > 0 - { - Box::new(AllScorer::new(reader.max_doc())) - } else { - Box::new(EmptyScorer) - } - } else { - intersect_scorers(must_scorers, num_docs) - }; + // No SHOULD clauses (or they were absorbed into MUST). + // Result depends entirely on MUST + any removed AllScorers. + let combined_all_scorer_count = must_special_scorer_counts.num_all_scorers + + should_special_scorer_counts.num_all_scorers; + let boxed_scorer: Box = effective_must_scorer( + must_scorers, + combined_all_scorer_count, + reader.max_doc(), + num_docs, + ) + .unwrap_or_else(|| Box::new(EmptyScorer)); SpecializedScorer::Other(boxed_scorer) } (ShouldScorersCombinationMethod::Optional(should_scorer), must_scorers) => { - if must_scorers.is_empty() && must_special_scorer_counts.num_all_scorers == 0 { - // Optional options are promoted to required if no must scorers exists. - should_scorer - } else { - let must_scorer = intersect_scorers(must_scorers, num_docs); - if self.scoring_enabled { - SpecializedScorer::Other(Box::new(RequiredOptionalScorer::< - _, - _, - TScoreCombiner, - >::new( - must_scorer, - into_box_scorer(should_scorer, &score_combiner_fn, num_docs), - ))) - } else { - SpecializedScorer::Other(must_scorer) + // Optional SHOULD: contributes to scoring but not required for matching. + match effective_must_scorer( + must_scorers, + must_special_scorer_counts.num_all_scorers, + reader.max_doc(), + num_docs, + ) { + None => { + // No MUST constraint: promote SHOULD to required. + // Must preserve any removed AllScorers from SHOULD via union. + effective_should_scorer_for_union( + should_scorer, + should_special_scorer_counts.num_all_scorers, + reader.max_doc(), + num_docs, + &score_combiner_fn, + self.scoring_enabled, + ) + } + Some(must_scorer) => { + // Has MUST constraint: SHOULD only affects scoring. + if self.scoring_enabled { + SpecializedScorer::Other(Box::new(RequiredOptionalScorer::< + _, + _, + TScoreCombiner, + >::new( + must_scorer, + into_box_scorer(should_scorer, &score_combiner_fn, num_docs), + ))) + } else { + SpecializedScorer::Other(must_scorer) + } } } } - (ShouldScorersCombinationMethod::Required(should_scorer), mut must_scorers) => { - if must_scorers.is_empty() { - should_scorer - } else { - must_scorers.push(into_box_scorer(should_scorer, &score_combiner_fn, num_docs)); - SpecializedScorer::Other(intersect_scorers(must_scorers, num_docs)) + (ShouldScorersCombinationMethod::Required(should_scorer), must_scorers) => { + // Required SHOULD: at least `minimum_number_should_match` must match. + // Semantics: (MUST constraint) AND (SHOULD constraint) + match effective_must_scorer( + must_scorers, + must_special_scorer_counts.num_all_scorers, + reader.max_doc(), + num_docs, + ) { + None => { + // No MUST constraint: SHOULD alone determines matching. + should_scorer + } + Some(must_scorer) => { + // Has MUST constraint: intersect MUST with SHOULD. + let should_boxed = + into_box_scorer(should_scorer, &score_combiner_fn, num_docs); + SpecializedScorer::Other(intersect_scorers( + vec![must_scorer, should_boxed], + num_docs, + )) + } } } }; diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs index 5dc042c46..2d7936f00 100644 --- a/src/query/boolean_query/mod.rs +++ b/src/query/boolean_query/mod.rs @@ -417,4 +417,426 @@ mod tests { Ok(()) } + + // ========================================================================= + // AllScorer Preservation Regression Tests + // ========================================================================= + // + // These tests verify the fix for a bug where AllScorer instances (produced by + // queries matching all documents, such as range queries covering all values) + // were incorrectly removed from Boolean query processing, causing documents + // to be unexpectedly excluded from results. + // + // The bug manifested in several scenarios: + // 1. SHOULD + SHOULD where one clause is AllScorer + // 2. MUST (AllScorer) + SHOULD + // 3. Range queries in Boolean clauses when all documents match the range + + /// Regression test: SHOULD clause with AllScorer combined with other SHOULD clauses. + /// + /// When a SHOULD clause produces an AllScorer (e.g., from a range query matching + /// all documents), the Boolean query should still match all documents. + /// + /// Bug before fix: AllScorer was removed during optimization, leaving only the + /// other SHOULD clauses, which incorrectly excluded documents. + #[test] + pub fn test_should_with_all_scorer_regression() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let text_field = schema_builder.add_text_field("text", TEXT); + let num_field = + schema_builder.add_i64_field("num", NumericOptions::default().set_fast().set_indexed()); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut index_writer: IndexWriter = index.writer_for_tests()?; + + // All docs have num > 0, so range query will return AllScorer + index_writer.add_document(doc!(text_field => "hello", num_field => 10i64))?; + index_writer.add_document(doc!(text_field => "world", num_field => 20i64))?; + index_writer.add_document(doc!(text_field => "hello world", num_field => 30i64))?; + index_writer.add_document(doc!(text_field => "foo", num_field => 40i64))?; + index_writer.add_document(doc!(text_field => "bar", num_field => 50i64))?; + index_writer.add_document(doc!(text_field => "baz", num_field => 60i64))?; + index_writer.commit()?; + + let searcher = index.reader()?.searcher(); + + // Range query matching all docs (returns AllScorer) + let all_match_query: Box = Box::new(RangeQuery::new( + Bound::Excluded(Term::from_field_i64(num_field, 0)), + Bound::Unbounded, + )); + let term_query: Box = Box::new(TermQuery::new( + Term::from_field_text(text_field, "hello"), + IndexRecordOption::Basic, + )); + + // Verify range matches all 6 docs + assert_eq!(searcher.search(all_match_query.as_ref(), &Count)?, 6); + + // RangeQuery(all) OR TermQuery should match all 6 docs + let bool_query = BooleanQuery::new(vec![ + (Occur::Should, all_match_query.box_clone()), + (Occur::Should, term_query.box_clone()), + ]); + let count = searcher.search(&bool_query, &Count)?; + assert_eq!(count, 6, "SHOULD with AllScorer should match all docs"); + + // Order should not matter + let bool_query_reversed = BooleanQuery::new(vec![ + (Occur::Should, term_query.box_clone()), + (Occur::Should, all_match_query.box_clone()), + ]); + let count_reversed = searcher.search(&bool_query_reversed, &Count)?; + assert_eq!( + count_reversed, 6, + "Order of SHOULD clauses should not matter" + ); + + Ok(()) + } + + /// Regression test: MUST clause with AllScorer combined with SHOULD clause. + /// + /// When MUST contains an AllScorer, all documents satisfy the MUST constraint. + /// The SHOULD clause should only affect scoring, not filtering. + /// + /// Bug before fix: AllScorer was removed, leaving an empty must_scorers vector. + /// intersect_scorers([]) incorrectly returned EmptyScorer, matching 0 documents. + #[test] + pub fn test_must_all_with_should_regression() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let text_field = schema_builder.add_text_field("text", TEXT); + let num_field = + schema_builder.add_i64_field("num", NumericOptions::default().set_fast().set_indexed()); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut index_writer: IndexWriter = index.writer_for_tests()?; + + // All docs have num > 0, so range query will return AllScorer + index_writer.add_document(doc!(text_field => "apple", num_field => 10i64))?; + index_writer.add_document(doc!(text_field => "banana", num_field => 20i64))?; + index_writer.add_document(doc!(text_field => "cherry", num_field => 30i64))?; + index_writer.add_document(doc!(text_field => "date", num_field => 40i64))?; + index_writer.commit()?; + + let searcher = index.reader()?.searcher(); + + // Range query matching all docs (returns AllScorer) + let all_match_query: Box = Box::new(RangeQuery::new( + Bound::Excluded(Term::from_field_i64(num_field, 0)), + Bound::Unbounded, + )); + let term_query: Box = Box::new(TermQuery::new( + Term::from_field_text(text_field, "apple"), + IndexRecordOption::Basic, + )); + + // Verify range matches all 4 docs + assert_eq!(searcher.search(all_match_query.as_ref(), &Count)?, 4); + + // MUST(range matching all) AND SHOULD(term) should match all 4 docs + let bool_query = BooleanQuery::new(vec![ + (Occur::Must, all_match_query.box_clone()), + (Occur::Should, term_query.box_clone()), + ]); + let count = searcher.search(&bool_query, &Count)?; + assert_eq!(count, 4, "MUST AllScorer + SHOULD should match all docs"); + + Ok(()) + } + + /// Regression test: Range queries in Boolean clauses when all documents match. + /// + /// Range queries can return AllScorer as an optimization when all indexed values + /// fall within the range. This test ensures such queries work correctly in + /// Boolean combinations. + /// + /// This is the most common real-world manifestation of the bug, occurring in + /// queries like: (age > 50 OR name = 'Alice') AND status = 'active' + /// when all documents have age > 50. + #[test] + pub fn test_range_query_all_match_in_boolean() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let name_field = schema_builder.add_text_field("name", TEXT); + let age_field = + schema_builder.add_i64_field("age", NumericOptions::default().set_fast().set_indexed()); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut index_writer: IndexWriter = index.writer_for_tests()?; + + // All documents have age > 50, so range query will return AllScorer + index_writer.add_document(doc!(name_field => "alice", age_field => 55_i64))?; + index_writer.add_document(doc!(name_field => "bob", age_field => 60_i64))?; + index_writer.add_document(doc!(name_field => "charlie", age_field => 70_i64))?; + index_writer.add_document(doc!(name_field => "diana", age_field => 80_i64))?; + index_writer.commit()?; + + let searcher = index.reader()?.searcher(); + + let range_query: Box = Box::new(RangeQuery::new( + Bound::Excluded(Term::from_field_i64(age_field, 50)), + Bound::Unbounded, + )); + let term_query: Box = Box::new(TermQuery::new( + Term::from_field_text(name_field, "alice"), + IndexRecordOption::Basic, + )); + + // Verify preconditions + assert_eq!(searcher.search(range_query.as_ref(), &Count)?, 4); + assert_eq!(searcher.search(term_query.as_ref(), &Count)?, 1); + + // SHOULD(range) OR SHOULD(term): range matches all, so result is 4 + let should_query = BooleanQuery::new(vec![ + (Occur::Should, range_query.box_clone()), + (Occur::Should, term_query.box_clone()), + ]); + assert_eq!( + searcher.search(&should_query, &Count)?, + 4, + "SHOULD range OR term should match all" + ); + + // MUST(range) AND SHOULD(term): range matches all, term is optional + let must_should_query = BooleanQuery::new(vec![ + (Occur::Must, range_query.box_clone()), + (Occur::Should, term_query.box_clone()), + ]); + assert_eq!( + searcher.search(&must_should_query, &Count)?, + 4, + "MUST range + SHOULD term should match all" + ); + + Ok(()) + } + + /// Test multiple AllScorer instances in different clause types. + /// + /// Verifies correct behavior when AllScorers appear in multiple positions. + #[test] + pub fn test_multiple_all_scorers() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let text_field = schema_builder.add_text_field("text", TEXT); + let num_field = + schema_builder.add_i64_field("num", NumericOptions::default().set_fast().set_indexed()); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut index_writer: IndexWriter = index.writer_for_tests()?; + + // All docs have num > 0, so range queries will return AllScorer + index_writer.add_document(doc!(text_field => "doc1", num_field => 10i64))?; + index_writer.add_document(doc!(text_field => "doc2", num_field => 20i64))?; + index_writer.add_document(doc!(text_field => "doc3", num_field => 30i64))?; + index_writer.commit()?; + + let searcher = index.reader()?.searcher(); + + // Two different range queries that both match all docs (return AllScorer) + let all_query1: Box = Box::new(RangeQuery::new( + Bound::Excluded(Term::from_field_i64(num_field, 0)), + Bound::Unbounded, + )); + let all_query2: Box = Box::new(RangeQuery::new( + Bound::Excluded(Term::from_field_i64(num_field, 5)), + Bound::Unbounded, + )); + let term_query: Box = Box::new(TermQuery::new( + Term::from_field_text(text_field, "doc1"), + IndexRecordOption::Basic, + )); + + // Multiple AllScorers in SHOULD + let multi_all_should = BooleanQuery::new(vec![ + (Occur::Should, all_query1.box_clone()), + (Occur::Should, all_query2.box_clone()), + (Occur::Should, term_query.box_clone()), + ]); + assert_eq!( + searcher.search(&multi_all_should, &Count)?, + 3, + "Multiple AllScorers in SHOULD" + ); + + // AllScorer in both MUST and SHOULD + let all_must_and_should = BooleanQuery::new(vec![ + (Occur::Must, all_query1.box_clone()), + (Occur::Should, all_query2.box_clone()), + ]); + assert_eq!( + searcher.search(&all_must_and_should, &Count)?, + 3, + "AllScorer in both MUST and SHOULD" + ); + + Ok(()) + } +} + +/// A proptest which generates arbitrary permutations of a simple boolean AST, and then matches +/// the result against an index which contains all permutations of documents with N fields. +#[cfg(test)] +mod proptest_boolean_query { + use std::collections::{BTreeMap, HashSet}; + use std::ops::Bound; + + use proptest::collection::vec; + use proptest::prelude::*; + + use crate::collector::DocSetCollector; + use crate::query::{AllQuery, BooleanQuery, Occur, Query, RangeQuery, TermQuery}; + use crate::schema::{Field, NumericOptions, OwnedValue, Schema, TEXT}; + use crate::{DocId, Index, Term}; + + #[derive(Debug, Clone)] + enum BooleanQueryAST { + /// Matches all documents via AllQuery (wraps AllScorer in BoostScorer) + All, + /// Matches all documents via RangeQuery (returns bare AllScorer) + /// This is the actual trigger for the AllScorer preservation bug + RangeAll, + /// Matches documents where the field has value "true" + Leaf { + field_idx: usize, + }, + Union(Vec), + Intersection(Vec), + } + + impl BooleanQueryAST { + fn matches(&self, doc_id: DocId) -> bool { + match self { + BooleanQueryAST::All => true, + BooleanQueryAST::RangeAll => true, + BooleanQueryAST::Leaf { field_idx } => Self::matches_field(doc_id, *field_idx), + BooleanQueryAST::Union(children) => { + children.iter().any(|child| child.matches(doc_id)) + } + BooleanQueryAST::Intersection(children) => { + children.iter().all(|child| child.matches(doc_id)) + } + } + } + + fn matches_field(doc_id: DocId, field_idx: usize) -> bool { + ((doc_id as usize) >> field_idx) & 1 == 1 + } + + fn to_query(&self, fields: &[Field], range_field: Field) -> Box { + match self { + BooleanQueryAST::All => Box::new(AllQuery), + BooleanQueryAST::RangeAll => { + // Range query that matches all docs (all have value >= 0) + // This returns bare AllScorer, triggering the bug we fixed + Box::new(RangeQuery::new( + Bound::Included(Term::from_field_i64(range_field, 0)), + Bound::Unbounded, + )) + } + BooleanQueryAST::Leaf { field_idx } => Box::new(TermQuery::new( + Term::from_field_text(fields[*field_idx], "true"), + crate::schema::IndexRecordOption::Basic, + )), + BooleanQueryAST::Union(children) => { + let sub_queries = children + .iter() + .map(|child| (Occur::Should, child.to_query(fields, range_field))) + .collect(); + Box::new(BooleanQuery::new(sub_queries)) + } + BooleanQueryAST::Intersection(children) => { + let sub_queries = children + .iter() + .map(|child| (Occur::Must, child.to_query(fields, range_field))) + .collect(); + Box::new(BooleanQuery::new(sub_queries)) + } + } + } + } + + fn doc_ids(num_docs: usize, num_fields: usize) -> impl Iterator { + let permutations = 1 << num_fields; + let copies = (num_docs as f32 / permutations as f32).ceil() as u32; + (0..(permutations * copies)).into_iter() + } + + fn create_index_with_boolean_permutations( + num_docs: usize, + num_fields: usize, + ) -> (Index, Vec, Field) { + let mut schema_builder = Schema::builder(); + let fields: Vec = (0..num_fields) + .map(|i| schema_builder.add_text_field(&format!("field_{}", i), TEXT)) + .collect(); + // Add a numeric field for RangeQuery tests - all docs have value = doc_id + let range_field = schema_builder.add_i64_field( + "range_field", + NumericOptions::default().set_fast().set_indexed(), + ); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut writer = index.writer_for_tests().unwrap(); + + for doc_id in doc_ids(num_docs, num_fields) { + let mut doc: BTreeMap<_, OwnedValue> = BTreeMap::default(); + for (field_idx, &field) in fields.iter().enumerate() { + if (doc_id >> field_idx) & 1 == 1 { + doc.insert(field, "true".into()); + } + } + // All docs have non-negative values, so RangeQuery(>=0) matches all + doc.insert(range_field, (doc_id as i64).into()); + writer.add_document(doc).unwrap(); + } + writer.commit().unwrap(); + (index, fields, range_field) + } + + fn arb_boolean_query_ast(num_fields: usize) -> impl Strategy { + // Leaf strategies: term queries, AllQuery, and RangeQuery matching all docs + let leaf = prop_oneof![ + (0..num_fields).prop_map(|field_idx| BooleanQueryAST::Leaf { field_idx }), + Just(BooleanQueryAST::All), + Just(BooleanQueryAST::RangeAll), + ]; + leaf.prop_recursive( + 8, // 8 levels of recursion + 256, // 256 nodes max + 10, // 10 items per collection + |inner| { + prop_oneof![ + vec(inner.clone(), 1..10).prop_map(BooleanQueryAST::Union), + vec(inner, 1..10).prop_map(BooleanQueryAST::Intersection), + ] + }, + ) + } + + #[test] + fn proptest_boolean_query() { + // In the presence of optimizations around buffering, it can take large numbers of + // documents to uncover some issues. + let num_docs = 10000; + let num_fields = 8; + let num_docs = 1 << num_fields; + let (index, fields, range_field) = + create_index_with_boolean_permutations(num_docs, num_fields); + let searcher = index.reader().unwrap().searcher(); + proptest!(|(ast in arb_boolean_query_ast(num_fields))| { + let query = ast.to_query(&fields, range_field); + + let mut matching_docs = HashSet::new(); + for doc_id in doc_ids(num_docs, num_fields) { + if ast.matches(doc_id as DocId) { + matching_docs.insert(doc_id as DocId); + } + } + + let doc_addresses = searcher.search(&*query, &DocSetCollector).unwrap(); + let result_docs: HashSet = + doc_addresses.into_iter().map(|doc_address| doc_address.doc_id).collect(); + prop_assert_eq!(result_docs, matching_docs); + }); + } } From 73657dff775bff27967d4855b812deeec94765a4 Mon Sep 17 00:00:00 2001 From: Moe Date: Tue, 16 Dec 2025 13:57:12 -0800 Subject: [PATCH 09/26] fix: fixed integer overflow in ExpUnrolledLinkedList for large datasets (#2735) * Fixed the overflow issue. * Fixed lint issues. * Applied PR fixes. * Fixed a lint issue. --- stacker/src/expull.rs | 219 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 209 insertions(+), 10 deletions(-) diff --git a/stacker/src/expull.rs b/stacker/src/expull.rs index 5fd00db3d..189dc95f2 100644 --- a/stacker/src/expull.rs +++ b/stacker/src/expull.rs @@ -5,7 +5,7 @@ use common::serialize_vint_u32; use crate::fastcpy::fast_short_slice_copy; use crate::{Addr, MemoryArena}; -const FIRST_BLOCK_NUM: u16 = 2; +const FIRST_BLOCK_NUM: u32 = 2; /// An exponential unrolled link. /// @@ -33,8 +33,8 @@ pub struct ExpUnrolledLinkedList { // u16, since the max size of each block is (1< { } } -// The block size is 2^block_num + 2, but max 2^15= 32k -// Initial size is 8, for the first block => block_num == 1 +// The block size is 2^block_num, but max 2^15 = 32KB +// Initial size is 8 bytes (2^3), for the first block => block_num == 2 +// Block size caps at 32KB (2^15) regardless of how high block_num goes #[inline] -fn get_block_size(block_num: u16) -> u16 { - 1 << block_num.min(15) +fn get_block_size(block_num: u32) -> u16 { + // Cap at 15 to prevent block sizes > 32KB + // block_num can now be much larger than 15, but block size maxes out + let exp = block_num.min(15) as u32; + (1u32 << exp) as u16 } impl ExpUnrolledLinkedList { + #[inline(always)] pub fn increment_num_blocks(&mut self) { - self.block_num += 1; + // Add overflow check as a safety measure + // With u32, we can handle up to ~4 billion blocks before overflow + // At 32KB per block (max size), that's 128 TB of data + self.block_num = self + .block_num + .checked_add(1) + .expect("ExpUnrolledLinkedList block count overflow - exceeded 4 billion blocks"); } #[inline] @@ -132,9 +143,26 @@ impl ExpUnrolledLinkedList { if addr.is_null() { return; } - let last_block_len = get_block_size(self.block_num) as usize - self.remaining_cap as usize; - // Full Blocks + // Calculate last block length with bounds checking to prevent underflow + let block_size = get_block_size(self.block_num) as usize; + let last_block_len = block_size.saturating_sub(self.remaining_cap as usize); + + // Safety check: if remaining_cap > block_size, the metadata is corrupted + assert!( + self.remaining_cap as usize <= block_size, + "ExpUnrolledLinkedList metadata corruption detected: remaining_cap ({}) > block_size \ + ({}). This indicates a serious bug, please report! (block_num={}, head={:?}, \ + tail={:?})", + self.remaining_cap, + block_size, + self.block_num, + self.head, + self.tail + ); + + // Full Blocks (iterate through all blocks except the last one) + // Note: Blocks are numbered starting from FIRST_BLOCK_NUM+1 (=3) after first allocation for block_num in FIRST_BLOCK_NUM + 1..self.block_num { let cap = get_block_size(block_num) as usize; let data = arena.slice(addr, cap); @@ -259,6 +287,177 @@ mod tests { assert_eq!(&vec1[..], &res1[..]); assert_eq!(&vec2[..], &res2[..]); } + + // Tests for u32 block_num fix (issue with large arrays) + + #[test] + fn test_block_num_exceeds_u16_max() { + // Test that we can handle more than 65,535 blocks (old u16 limit) + let mut eull = ExpUnrolledLinkedList::default(); + + // Simulate allocating 70,000 blocks (exceeds u16::MAX of 65,535) + for _ in 0..70_000 { + eull.increment_num_blocks(); + } + + // Verify block_num is correct + assert_eq!(eull.block_num, FIRST_BLOCK_NUM + 70_000); + + // Verify we can still get block size (should be capped at 32KB) + let block_size = get_block_size(eull.block_num); + assert_eq!(block_size, 1 << 15); // 32KB max + } + + #[test] + fn test_large_dataset_simulation() { + // Simulate the scenario: large arrays requiring many blocks + // We write enough data to require thousands of blocks + let mut arena = MemoryArena::default(); + let mut eull = ExpUnrolledLinkedList::default(); + + // Write 100 MB of data (this will require ~3,200 blocks at 32KB each) + // This is enough to validate the system works with large datasets + // but not so much that the test is slow + let bytes_per_write = 10_000; + let num_writes = 10_000; // 10k * 10k = 100 MB + + let data: Vec = (0..bytes_per_write).map(|i| (i % 256) as u8).collect(); + for _ in 0..num_writes { + eull.writer(&mut arena).extend_from_slice(&data); + } + + // Verify we allocated many blocks (should be in the thousands) + assert!( + eull.block_num > 1000, + "block_num ({}) should be > 1000 for this much data", + eull.block_num + ); + + // Verify we can read back correctly + let mut buffer = Vec::new(); + eull.read_to_end(&arena, &mut buffer); + assert_eq!(buffer.len(), bytes_per_write * num_writes); + + // Verify data integrity on a sample + for i in 0..bytes_per_write { + assert_eq!(buffer[i], (i % 256) as u8); + } + } + + #[test] + fn test_get_block_size_with_large_block_num() { + // Test that get_block_size handles large u32 values correctly + + // Small block numbers (under 15) + assert_eq!(get_block_size(2), 4); // 2^2 = 4 + assert_eq!(get_block_size(3), 8); // 2^3 = 8 + assert_eq!(get_block_size(10), 1024); // 2^10 = 1KB + + // At the cap (15) + assert_eq!(get_block_size(15), 32768); // 2^15 = 32KB + + // Beyond the cap (should stay at 32KB) + assert_eq!(get_block_size(16), 32768); + assert_eq!(get_block_size(100), 32768); + assert_eq!(get_block_size(65_536), 32768); // Old u16::MAX + 1 + assert_eq!(get_block_size(100_000), 32768); + assert_eq!(get_block_size(1_000_000), 32768); + } + + #[test] + fn test_increment_blocks_near_u16_boundary() { + // Test incrementing around the old u16::MAX boundary + let mut eull = ExpUnrolledLinkedList::default(); + + // Set to just before old limit + for _ in 0..65_533 { + eull.increment_num_blocks(); + } + assert_eq!(eull.block_num, FIRST_BLOCK_NUM + 65_533); + + // Cross the old u16::MAX boundary (this would have overflowed before) + eull.increment_num_blocks(); // 65,534 + eull.increment_num_blocks(); // 65,535 (old max) + eull.increment_num_blocks(); // 65,536 (would overflow u16) + eull.increment_num_blocks(); // 65,537 + + // Verify we're past the old limit + assert_eq!(eull.block_num, FIRST_BLOCK_NUM + 65_537); + } + + #[test] + fn test_write_and_read_with_many_blocks() { + // Test that write/read works correctly with many blocks + let mut arena = MemoryArena::default(); + let mut eull = ExpUnrolledLinkedList::default(); + + // Write data that will span many blocks + let test_data: Vec = (0..50_000).map(|i| (i % 256) as u8).collect(); + eull.writer(&mut arena).extend_from_slice(&test_data); + + // Read it back + let mut buffer = Vec::new(); + eull.read_to_end(&arena, &mut buffer); + + // Verify data integrity + assert_eq!(buffer.len(), test_data.len()); + assert_eq!(&buffer[..], &test_data[..]); + } + + #[test] + fn test_multiple_eull_with_large_block_counts() { + // Test multiple ExpUnrolledLinkedLists with high block counts + // (simulates parallel columnar writes) + let mut arena = MemoryArena::default(); + let mut eull1 = ExpUnrolledLinkedList::default(); + let mut eull2 = ExpUnrolledLinkedList::default(); + + // Write different data to each + for i in 0..10_000u32 { + eull1.writer(&mut arena).write_u32_vint(i); + eull2.writer(&mut arena).write_u32_vint(i * 2); + } + + // Read back and verify + let mut buf1 = Vec::new(); + let mut buf2 = Vec::new(); + eull1.read_to_end(&arena, &mut buf1); + eull2.read_to_end(&arena, &mut buf2); + + // Deserialize and check + let mut cursor1 = &buf1[..]; + let mut cursor2 = &buf2[..]; + for i in 0..10_000u32 { + assert_eq!(read_u32_vint(&mut cursor1), i); + assert_eq!(read_u32_vint(&mut cursor2), i * 2); + } + } + + #[test] + fn test_block_size_stays_capped() { + // Verify that even with massive block numbers, size stays at 32KB + let mut eull = ExpUnrolledLinkedList::default(); + + // Increment to a very large number + for _ in 0..200_000 { + eull.increment_num_blocks(); + } + + let block_size = get_block_size(eull.block_num); + assert_eq!(block_size, 32768, "Block size should be capped at 32KB"); + } + + #[test] + #[should_panic(expected = "ExpUnrolledLinkedList block count overflow")] + fn test_increment_overflow_protection() { + // Test that we panic gracefully if we somehow hit u32::MAX + // This is extremely unlikely in practice (would require 128TB of data) + let mut eull = ExpUnrolledLinkedList::default(); + eull.block_num = u32::MAX; + + // This should panic with our custom error message + eull.increment_num_blocks(); + } } #[cfg(all(test, feature = "unstable"))] From c0f21a45ae99a37996d1edaceb434d24ab3f057c Mon Sep 17 00:00:00 2001 From: Stu Hood Date: Thu, 18 Dec 2025 03:13:23 -0800 Subject: [PATCH 10/26] Use a strict comparison in TopNComputer (#2777) * Remove `(Partial)Ord` from `ComparableDoc`, and unify comparison between `TopNComputer` and `Comparator`. * Doc cleanups. * Require Ord for `ComparableDoc`. * Semantics are actually _ascending_ DocId order. * Adjust docs again for ascending DocId order. * minor change --------- Co-authored-by: Paul Masurel --- src/collector/sort_key/mod.rs | 32 +++++++--- src/collector/sort_key/order.rs | 3 +- src/collector/top_collector.rs | 58 +++--------------- src/collector/top_score_collector.rs | 91 ++++++++++++++++++++-------- src/indexer/delete_queue.rs | 18 ++++-- src/indexer/operation.rs | 6 ++ 6 files changed, 116 insertions(+), 92 deletions(-) diff --git a/src/collector/sort_key/mod.rs b/src/collector/sort_key/mod.rs index a66115633..3bfb3b1c8 100644 --- a/src/collector/sort_key/mod.rs +++ b/src/collector/sort_key/mod.rs @@ -11,7 +11,26 @@ pub use sort_by_string::SortByString; pub use sort_key_computer::{SegmentSortKeyComputer, SortKeyComputer}; #[cfg(test)] -mod tests { +pub(crate) mod tests { + + // By spec, regardless of whether ascending or descending order was requested, in presence of a + // tie, we sort by ascending doc id/doc address. + pub(crate) fn sort_hits( + hits: &mut [ComparableDoc], + order: Order, + ) { + if order.is_asc() { + hits.sort_by(|l, r| l.sort_key.cmp(&r.sort_key).then(l.doc.cmp(&r.doc))); + } else { + hits.sort_by(|l, r| { + l.sort_key + .cmp(&r.sort_key) + .reverse() // This is descending + .then(l.doc.cmp(&r.doc)) + }); + } + } + use std::collections::HashMap; use std::ops::Range; @@ -372,15 +391,10 @@ mod tests { // Using the TopDocs collector should always be equivalent to sorting, skipping the // offset, and then taking the limit. - let sorted_docs: Vec<_> = if order.is_desc() { - let mut comparable_docs: Vec> = + let sorted_docs: Vec<_> = { + let mut comparable_docs: Vec> = all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect(); - comparable_docs.sort(); - comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect() - } else { - let mut comparable_docs: Vec> = - all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect(); - comparable_docs.sort(); + sort_hits(&mut comparable_docs, order); comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect() }; let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::>(); diff --git a/src/collector/sort_key/order.rs b/src/collector/sort_key/order.rs index 923d5cb8e..40a718b90 100644 --- a/src/collector/sort_key/order.rs +++ b/src/collector/sort_key/order.rs @@ -30,7 +30,8 @@ impl Comparator for NaturalComparator { /// first. /// /// The ReverseComparator does not necessarily imply that the sort order is reversed compared -/// to the NaturalComparator. In presence of a tie, both version will retain the higher doc ids. +/// to the NaturalComparator. In presence of a tie on the sort key, documents will always be +/// sorted by ascending `DocId`/`DocAddress` in TopN results, regardless of the comparator. #[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)] pub struct ReverseComparator; diff --git a/src/collector/top_collector.rs b/src/collector/top_collector.rs index 6981c86c9..1990b3837 100644 --- a/src/collector/top_collector.rs +++ b/src/collector/top_collector.rs @@ -1,64 +1,22 @@ -use std::cmp::Ordering; - use serde::{Deserialize, Serialize}; /// Contains a feature (field, score, etc.) of a document along with the document address. /// -/// It guarantees stable sorting: in case of a tie on the feature, the document -/// address is used. -/// -/// The REVERSE_ORDER generic parameter controls whether the by-feature order -/// should be reversed, which is useful for achieving for example largest-first -/// semantics without having to wrap the feature in a `Reverse`. -#[derive(Clone, Default, Serialize, Deserialize)] -pub struct ComparableDoc { +/// Used only by TopNComputer, which implements the actual comparison via a `Comparator`. +#[derive(Clone, Default, Eq, PartialEq, Serialize, Deserialize)] +pub struct ComparableDoc { /// The feature of the document. In practice, this is - /// is any type that implements `PartialOrd`. + /// is a type which can be compared with a `Comparator`. pub sort_key: T, - /// The document address. In practice, this is any - /// type that implements `PartialOrd`, and is guaranteed - /// to be unique for each document. + /// The document address. In practice, this is either a `DocId` or `DocAddress`. pub doc: D, } -impl std::fmt::Debug - for ComparableDoc -{ + +impl std::fmt::Debug for ComparableDoc { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct(format!("ComparableDoc<_, _ {R}").as_str()) + f.debug_struct("ComparableDoc") .field("feature", &self.sort_key) .field("doc", &self.doc) .finish() } } - -impl PartialOrd for ComparableDoc { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for ComparableDoc { - #[inline] - fn cmp(&self, other: &Self) -> Ordering { - let by_feature = self - .sort_key - .partial_cmp(&other.sort_key) - .map(|ord| if R { ord.reverse() } else { ord }) - .unwrap_or(Ordering::Equal); - - let lazy_by_doc_address = || self.doc.partial_cmp(&other.doc).unwrap_or(Ordering::Equal); - - // In case of a tie on the feature, we sort by ascending - // `DocAddress` in order to ensure a stable sorting of the - // documents. - by_feature.then_with(lazy_by_doc_address) - } -} - -impl PartialEq for ComparableDoc { - fn eq(&self, other: &Self) -> bool { - self.cmp(other) == Ordering::Equal - } -} - -impl Eq for ComparableDoc {} diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 78c344dbe..3c3f1beb9 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -23,10 +23,9 @@ use crate::{DocAddress, DocId, Order, Score, SegmentReader}; /// The theoretical complexity for collecting the top `K` out of `N` documents /// is `O(N + K)`. /// -/// This collector does not guarantee a stable sorting in case of a tie on the -/// document score, for stable sorting `PartialOrd` needs to resolve on other fields -/// like docid in case of score equality. -/// Only then, it is suitable for pagination. +/// This collector guarantees a stable sorting in case of a tie on the +/// document score/sort key: The document address (`DocAddress`) is used as a tie breaker. +/// In case of a tie on the sort key, documents are always sorted by ascending `DocAddress`. /// /// ```rust /// use tantivy::collector::TopDocs; @@ -500,8 +499,13 @@ where /// /// For TopN == 0, it will be relative expensive. /// -/// When using the natural comparator, the top N computer returns the top N elements in -/// descending order, as expected for a top N. +/// The TopNComputer will tiebreak by using ascending `D` (DocId or DocAddress): +/// i.e., in case of a tie on the sort key, the `DocId|DocAddress` are always sorted in +/// ascending order, regardless of the `Comparator` used for the `Score` type. +/// +/// NOTE: Items must be `push`ed to the TopNComputer in ascending `DocId|DocAddress` order, as the +/// threshold used to eliminate docs does not include the `DocId` or `DocAddress`: this provides +/// the ascending `DocId|DocAddress` tie-breaking behavior without additional comparisons. #[derive(Serialize, Deserialize)] #[serde(from = "TopNComputerDeser")] pub struct TopNComputer { @@ -580,6 +584,18 @@ where } } +#[inline(always)] +fn compare_for_top_k>( + c: &C, + lhs: &ComparableDoc, + rhs: &ComparableDoc, +) -> std::cmp::Ordering { + c.compare(&lhs.sort_key, &rhs.sort_key) + .reverse() // Reverse here because we want top K. + .then_with(|| lhs.doc.cmp(&rhs.doc)) // Regardless of asc/desc, in presence of a tie, we + // sort by doc id +} + impl TopNComputer where D: Ord, @@ -600,10 +616,13 @@ where /// Push a new document to the top n. /// If the document is below the current threshold, it will be ignored. + /// + /// NOTE: `push` must be called in ascending `DocId`/`DocAddress` order. #[inline] pub fn push(&mut self, sort_key: TSortKey, doc: D) { if let Some(last_median) = &self.threshold { - if self.comparator.compare(&sort_key, last_median) == Ordering::Less { + // See the struct docs for an explanation of why this comparison is strict. + if self.comparator.compare(&sort_key, last_median) != Ordering::Greater { return; } } @@ -629,9 +648,7 @@ where fn truncate_top_n(&mut self) -> TSortKey { // Use select_nth_unstable to find the top nth score let (_, median_el, _) = self.buffer.select_nth_unstable_by(self.top_n, |lhs, rhs| { - self.comparator - .compare(&rhs.sort_key, &lhs.sort_key) - .then_with(|| lhs.doc.cmp(&rhs.doc)) + compare_for_top_k(&self.comparator, lhs, rhs) }); let median_score = median_el.sort_key.clone(); @@ -646,11 +663,8 @@ where if self.buffer.len() > self.top_n { self.truncate_top_n(); } - self.buffer.sort_unstable_by(|left, right| { - self.comparator - .compare(&right.sort_key, &left.sort_key) - .then_with(|| left.doc.cmp(&right.doc)) - }); + self.buffer + .sort_unstable_by(|lhs, rhs| compare_for_top_k(&self.comparator, lhs, rhs)); self.buffer } @@ -755,6 +769,33 @@ mod tests { ); } + #[test] + fn test_topn_computer_duplicates() { + let mut computer: TopNComputer = + TopNComputer::new_with_comparator(2, NaturalComparator); + + computer.push(1u32, 1u32); + computer.push(1u32, 2u32); + computer.push(1u32, 3u32); + computer.push(1u32, 4u32); + computer.push(1u32, 5u32); + + // In the presence of duplicates, DocIds are always ascending order. + assert_eq!( + computer.into_sorted_vec(), + &[ + ComparableDoc { + sort_key: 1u32, + doc: 1u32, + }, + ComparableDoc { + sort_key: 1u32, + doc: 2u32, + } + ] + ); + } + #[test] fn test_topn_computer_no_panic() { for top_n in 0..10 { @@ -772,14 +813,17 @@ mod tests { #[test] fn test_topn_computer_asc_prop( limit in 0..10_usize, - docs in proptest::collection::vec((0..100_u64, 0..100_u64), 0..100_usize), + mut docs in proptest::collection::vec((0..100_u64, 0..100_u64), 0..100_usize), ) { + // NB: TopNComputer must receive inputs in ascending DocId order. + docs.sort_by_key(|(_, doc_id)| *doc_id); let mut computer: TopNComputer<_, _, ReverseComparator> = TopNComputer::new_with_comparator(limit, ReverseComparator); for (feature, doc) in &docs { computer.push(*feature, *doc); } - let mut comparable_docs: Vec> = docs.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc }).collect::>(); - comparable_docs.sort(); + let mut comparable_docs: Vec> = + docs.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc }).collect(); + crate::collector::sort_key::tests::sort_hits(&mut comparable_docs, Order::Asc); comparable_docs.truncate(limit); prop_assert_eq!( computer.into_sorted_vec(), @@ -1406,15 +1450,10 @@ mod tests { // Using the TopDocs collector should always be equivalent to sorting, skipping the // offset, and then taking the limit. - let sorted_docs: Vec<_> = if order.is_desc() { - let mut comparable_docs: Vec> = + let sorted_docs: Vec<_> = { + let mut comparable_docs: Vec> = all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect(); - comparable_docs.sort(); - comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect() - } else { - let mut comparable_docs: Vec> = - all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect(); - comparable_docs.sort(); + crate::collector::sort_key::tests::sort_hits(&mut comparable_docs, order); comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect() }; let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::>(); diff --git a/src/indexer/delete_queue.rs b/src/indexer/delete_queue.rs index 3aa9f0d85..25c84fb36 100644 --- a/src/indexer/delete_queue.rs +++ b/src/indexer/delete_queue.rs @@ -23,13 +23,18 @@ struct InnerDeleteQueue { last_block: Weak, } +/// The delete queue is a linked list storing delete operations. +/// +/// Several consumers can hold a reference to it. Delete operations +/// get dropped/gc'ed when no more consumers are holding a reference +/// to them. #[derive(Clone)] pub struct DeleteQueue { inner: Arc>, } impl DeleteQueue { - // Creates a new delete queue. + /// Creates a new empty delete queue. pub fn new() -> DeleteQueue { DeleteQueue { inner: Arc::default(), @@ -58,10 +63,10 @@ impl DeleteQueue { block } - // Creates a new cursor that makes it possible to - // consume future delete operations. - // - // Past delete operations are not accessible. + /// Creates a new cursor that makes it possible to + /// consume future delete operations. + /// + /// Past delete operations are not accessible. pub fn cursor(&self) -> DeleteCursor { let last_block = self.get_last_block(); let operations_len = last_block.operations.len(); @@ -71,7 +76,7 @@ impl DeleteQueue { } } - // Appends a new delete operations. + /// Appends a new delete operations. pub fn push(&self, delete_operation: DeleteOperation) { self.inner .write() @@ -169,6 +174,7 @@ struct Block { next: NextBlock, } +/// As we process delete operations, keeps track of our position. #[derive(Clone)] pub struct DeleteCursor { block: Arc, diff --git a/src/indexer/operation.rs b/src/indexer/operation.rs index 69bffec17..9316f6fa7 100644 --- a/src/indexer/operation.rs +++ b/src/indexer/operation.rs @@ -5,14 +5,20 @@ use crate::Opstamp; /// Timestamped Delete operation. pub struct DeleteOperation { + /// Operation stamp. + /// It is used to check whether the delete operation + /// applies to an added document operation. pub opstamp: Opstamp, + /// Weight is used to define the set of documents to be deleted. pub target: Box, } /// Timestamped Add operation. #[derive(Eq, PartialEq, Debug)] pub struct AddOperation { + /// Operation stamp. pub opstamp: Opstamp, + /// Document to be added. pub document: D, } From ce97beb86f9f1f49f63c387e644c2f21d97405a8 Mon Sep 17 00:00:00 2001 From: Stu Hood Date: Tue, 23 Dec 2025 01:22:20 -0700 Subject: [PATCH 11/26] Add support for natural-order-with-none-highest in `TopDocs::order_by` (#2780) * Add `ComparatorEnum::NaturalNoneHigher`. * Fix comments. --- src/collector/sort_key/order.rs | 129 +++++++++++++++++++++++++++++--- 1 file changed, 119 insertions(+), 10 deletions(-) diff --git a/src/collector/sort_key/order.rs b/src/collector/sort_key/order.rs index 40a718b90..e89154c96 100644 --- a/src/collector/sort_key/order.rs +++ b/src/collector/sort_key/order.rs @@ -12,8 +12,13 @@ pub trait Comparator: Send + Sync + std::fmt::Debug + Default { fn compare(&self, lhs: &T, rhs: &T) -> Ordering; } -/// With the natural comparator, the top k collector will return -/// the top documents in decreasing order. +/// Compare values naturally (e.g. 1 < 2). +/// +/// When used with `TopDocs`, which reverses the order, this results in a +/// "Descending" sort (Greatest values first). +/// +/// `None` (or Null for `OwnedValue`) values are considered to be smaller than any other value, +/// and will therefore appear last in a descending sort (e.g. `[Some(20), Some(10), None]`). #[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)] pub struct NaturalComparator; @@ -24,14 +29,18 @@ impl Comparator for NaturalComparator { } } -/// Sorts document in reverse order. +/// Compare values in reverse (e.g. 2 < 1). /// -/// If the sort key is None, it will considered as the lowest value, and will therefore appear -/// first. +/// When used with `TopDocs`, which reverses the order, this results in an +/// "Ascending" sort (Smallest values first). +/// +/// `None` is considered smaller than `Some` in the underlying comparator, but because the +/// comparison is reversed, `None` is effectively treated as the lowest value in the resulting +/// Ascending sort (e.g. `[None, Some(10), Some(20)]`). /// /// The ReverseComparator does not necessarily imply that the sort order is reversed compared /// to the NaturalComparator. In presence of a tie on the sort key, documents will always be -/// sorted by ascending `DocId`/`DocAddress` in TopN results, regardless of the comparator. +/// sorted by ascending `DocId`/`DocAddress` in TopN results, regardless of the sort key's order. #[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)] pub struct ReverseComparator; @@ -44,11 +53,15 @@ where NaturalComparator: Comparator } } -/// Sorts document in reverse order, but considers None as having the lowest value. +/// Compare values in reverse, but treating `None` as lower than `Some`. +/// +/// When used with `TopDocs`, which reverses the order, this results in an +/// "Ascending" sort (Smallest values first), but with `None` values appearing last +/// (e.g. `[Some(10), Some(20), None]`). /// /// This is usually what is wanted when sorting by a field in an ascending order. -/// For instance, in a e-commerce website, if I sort by price ascending, I most likely want the -/// cheapest items first, and the items without a price at last. +/// For instance, in an e-commerce website, if sorting by price ascending, +/// the cheapest items would appear first, and items without a price would appear last. #[derive(Debug, Copy, Clone, Default)] pub struct ReverseNoneIsLowerComparator; @@ -108,6 +121,70 @@ impl Comparator for ReverseNoneIsLowerComparator { } } +/// Compare values naturally, but treating `None` as higher than `Some`. +/// +/// When used with `TopDocs`, which reverses the order, this results in a +/// "Descending" sort (Greatest values first), but with `None` values appearing first +/// (e.g. `[None, Some(20), Some(10)]`). +#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)] +pub struct NaturalNoneIsHigherComparator; + +impl Comparator> for NaturalNoneIsHigherComparator +where NaturalComparator: Comparator +{ + #[inline(always)] + fn compare(&self, lhs_opt: &Option, rhs_opt: &Option) -> Ordering { + match (lhs_opt, rhs_opt) { + (None, None) => Ordering::Equal, + (None, Some(_)) => Ordering::Greater, + (Some(_), None) => Ordering::Less, + (Some(lhs), Some(rhs)) => NaturalComparator.compare(lhs, rhs), + } + } +} + +impl Comparator for NaturalNoneIsHigherComparator { + #[inline(always)] + fn compare(&self, lhs: &u32, rhs: &u32) -> Ordering { + NaturalComparator.compare(lhs, rhs) + } +} + +impl Comparator for NaturalNoneIsHigherComparator { + #[inline(always)] + fn compare(&self, lhs: &u64, rhs: &u64) -> Ordering { + NaturalComparator.compare(lhs, rhs) + } +} + +impl Comparator for NaturalNoneIsHigherComparator { + #[inline(always)] + fn compare(&self, lhs: &f64, rhs: &f64) -> Ordering { + NaturalComparator.compare(lhs, rhs) + } +} + +impl Comparator for NaturalNoneIsHigherComparator { + #[inline(always)] + fn compare(&self, lhs: &f32, rhs: &f32) -> Ordering { + NaturalComparator.compare(lhs, rhs) + } +} + +impl Comparator for NaturalNoneIsHigherComparator { + #[inline(always)] + fn compare(&self, lhs: &i64, rhs: &i64) -> Ordering { + NaturalComparator.compare(lhs, rhs) + } +} + +impl Comparator for NaturalNoneIsHigherComparator { + #[inline(always)] + fn compare(&self, lhs: &String, rhs: &String) -> Ordering { + NaturalComparator.compare(lhs, rhs) + } +} + /// An enum representing the different sort orders. #[derive(Debug, Clone, Copy, Eq, PartialEq, Default)] pub enum ComparatorEnum { @@ -116,8 +193,10 @@ pub enum ComparatorEnum { Natural, /// Reverse order (See [ReverseComparator]) Reverse, - /// Reverse order by treating None as the lowest value.(See [ReverseNoneLowerComparator]) + /// Reverse order by treating None as the lowest value. (See [ReverseNoneLowerComparator]) ReverseNoneLower, + /// Natural order but treating None as the highest value. (See [NaturalNoneIsHigherComparator]) + NaturalNoneHigher, } impl From for ComparatorEnum { @@ -134,6 +213,7 @@ where ReverseNoneIsLowerComparator: Comparator, NaturalComparator: Comparator, ReverseComparator: Comparator, + NaturalNoneIsHigherComparator: Comparator, { #[inline(always)] fn compare(&self, lhs: &T, rhs: &T) -> Ordering { @@ -141,6 +221,7 @@ where ComparatorEnum::Natural => NaturalComparator.compare(lhs, rhs), ComparatorEnum::Reverse => ReverseComparator.compare(lhs, rhs), ComparatorEnum::ReverseNoneLower => ReverseNoneIsLowerComparator.compare(lhs, rhs), + ComparatorEnum::NaturalNoneHigher => NaturalNoneIsHigherComparator.compare(lhs, rhs), } } } @@ -347,3 +428,31 @@ where .convert_segment_sort_key(sort_key) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_natural_none_is_higher() { + let comp = NaturalNoneIsHigherComparator; + let null = None; + let v1 = Some(1_u64); + let v2 = Some(2_u64); + + // NaturalNoneIsGreaterComparator logic: + // 1. Delegates to NaturalComparator for non-nulls. + // NaturalComparator compare(2, 1) -> 2.cmp(1) -> Greater. + assert_eq!(comp.compare(&v2, &v1), Ordering::Greater); + + // 2. Treats None (Null) as Greater than any value. + // compare(None, Some(2)) should be Greater. + assert_eq!(comp.compare(&null, &v2), Ordering::Greater); + + // compare(Some(1), None) should be Less. + assert_eq!(comp.compare(&v1, &null), Ordering::Less); + + // compare(None, None) should be Equal. + assert_eq!(comp.compare(&null, &null), Ordering::Equal); + } +} From e0b62e00ac8b2e39a48f6d72c3ace040e845ebd5 Mon Sep 17 00:00:00 2001 From: ChangRui-Ryan Date: Mon, 29 Dec 2025 23:55:28 +0800 Subject: [PATCH 12/26] optimize RangeDocSet for non-overlapping query ranges (#2783) --- Cargo.toml | 4 + benches/range_queries.rs | 365 ++++++++++++++++++ .../src/column_values/u64_based/bitpacked.rs | 6 - .../range_query/fast_field_range_doc_set.rs | 59 +++ 4 files changed, 428 insertions(+), 6 deletions(-) create mode 100644 benches/range_queries.rs diff --git a/Cargo.toml b/Cargo.toml index 32d7bd990..dfbb1ea1c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -176,3 +176,7 @@ harness = false [[bench]] name = "and_or_queries" harness = false + +[[bench]] +name = "range_queries" +harness = false diff --git a/benches/range_queries.rs b/benches/range_queries.rs new file mode 100644 index 000000000..56aaf54b9 --- /dev/null +++ b/benches/range_queries.rs @@ -0,0 +1,365 @@ +use std::ops::Bound; + +use binggan::{black_box, BenchGroup, BenchRunner}; +use rand::prelude::*; +use rand::rngs::StdRng; +use rand::SeedableRng; +use tantivy::collector::{Count, DocSetCollector, TopDocs}; +use tantivy::query::RangeQuery; +use tantivy::schema::{Schema, FAST, INDEXED}; +use tantivy::{doc, Index, Order, ReloadPolicy, Searcher, Term}; + +#[derive(Clone)] +struct BenchIndex { + #[allow(dead_code)] + index: Index, + searcher: Searcher, +} + +fn build_shared_indices(num_docs: usize, distribution: &str) -> BenchIndex { + // Schema with fast fields only + let mut schema_builder = Schema::builder(); + let f_num_rand_fast = schema_builder.add_u64_field("num_rand_fast", INDEXED | FAST); + let f_num_asc_fast = schema_builder.add_u64_field("num_asc_fast", INDEXED | FAST); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema.clone()); + + // Populate index with stable RNG for reproducibility. + let mut rng = StdRng::from_seed([7u8; 32]); + + { + let mut writer = index.writer_with_num_threads(1, 4_000_000_000).unwrap(); + + match distribution { + "dense" => { + for doc_id in 0..num_docs { + let num_rand = rng.gen_range(0u64..1000u64); + let num_asc = (doc_id / 10000) as u64; + + writer + .add_document(doc!( + f_num_rand_fast=>num_rand, + f_num_asc_fast=>num_asc, + )) + .unwrap(); + } + } + "sparse" => { + for doc_id in 0..num_docs { + let num_rand = rng.gen_range(0u64..10000000u64); + let num_asc = doc_id as u64; + + writer + .add_document(doc!( + f_num_rand_fast=>num_rand, + f_num_asc_fast=>num_asc, + )) + .unwrap(); + } + } + _ => { + panic!("Unsupported distribution type"); + } + } + writer.commit().unwrap(); + } + + // Prepare reader/searcher once. + let reader = index + .reader_builder() + .reload_policy(ReloadPolicy::Manual) + .try_into() + .unwrap(); + let searcher = reader.searcher(); + + BenchIndex { index, searcher } +} + +fn main() { + // Prepare corpora with varying scenarios + let scenarios = vec![ + // Dense distribution - random values in small range (0-999) + ( + "dense_values_search_low_value_range".to_string(), + 10_000_000, + "dense", + 0, + 9, + ), + ( + "dense_values_search_high_value_range".to_string(), + 10_000_000, + "dense", + 990, + 999, + ), + ( + "dense_values_search_out_of_range".to_string(), + 10_000_000, + "dense", + 1000, + 1002, + ), + ( + "sparse_values_search_low_value_range".to_string(), + 10_000_000, + "sparse", + 0, + 9, + ), + ( + "sparse_values_search_high_value_range".to_string(), + 10_000_000, + "sparse", + 9_999_990, + 9_999_999, + ), + ( + "sparse_values_search_out_of_range".to_string(), + 10_000_000, + "sparse", + 10_000_000, + 10_000_002, + ), + ]; + + let mut runner = BenchRunner::new(); + for (scenario_id, n, num_rand_distribution, range_low, range_high) in scenarios { + // Build index for this scenario + let bench_index = build_shared_indices(n, num_rand_distribution); + + // Create benchmark group + let mut group = runner.new_group(); + + // Now set the name (this moves scenario_id) + group.set_name(scenario_id); + + // Define fast field types + let field_names = ["num_rand_fast", "num_asc_fast"]; + + // Generate range queries for fast fields + for &field_name in &field_names { + // Create the range query + let field = bench_index.searcher.schema().get_field(field_name).unwrap(); + let lower_term = Term::from_field_u64(field, range_low); + let upper_term = Term::from_field_u64(field, range_high); + + let query = RangeQuery::new(Bound::Included(lower_term), Bound::Included(upper_term)); + + run_benchmark_tasks( + &mut group, + &bench_index, + query, + field_name, + range_low, + range_high, + ); + } + + group.run(); + } +} + +/// Run all benchmark tasks for a given range query and field name +fn run_benchmark_tasks( + bench_group: &mut BenchGroup, + bench_index: &BenchIndex, + query: RangeQuery, + field_name: &str, + range_low: u64, + range_high: u64, +) { + // Test count + add_bench_task_count( + bench_group, + bench_index, + query.clone(), + "count", + field_name, + range_low, + range_high, + ); + + // Test top 100 by the field (ascending order) + { + let collector_name = format!("top100_by_{}_asc", field_name); + let field_name_owned = field_name.to_string(); + add_bench_task_top100_asc( + bench_group, + bench_index, + query.clone(), + &collector_name, + field_name, + range_low, + range_high, + field_name_owned, + ); + } + + // Test top 100 by the field (descending order) + { + let collector_name = format!("top100_by_{}_desc", field_name); + let field_name_owned = field_name.to_string(); + add_bench_task_top100_desc( + bench_group, + bench_index, + query, + &collector_name, + field_name, + range_low, + range_high, + field_name_owned, + ); + } +} + +fn add_bench_task_count( + bench_group: &mut BenchGroup, + bench_index: &BenchIndex, + query: RangeQuery, + collector_name: &str, + field_name: &str, + range_low: u64, + range_high: u64, +) { + let task_name = format!( + "range_{}_[{} TO {}]_{}", + field_name, range_low, range_high, collector_name + ); + + let search_task = CountSearchTask { + searcher: bench_index.searcher.clone(), + query, + }; + bench_group.register(task_name, move |_| black_box(search_task.run())); +} + +fn add_bench_task_docset( + bench_group: &mut BenchGroup, + bench_index: &BenchIndex, + query: RangeQuery, + collector_name: &str, + field_name: &str, + range_low: u64, + range_high: u64, +) { + let task_name = format!( + "range_{}_[{} TO {}]_{}", + field_name, range_low, range_high, collector_name + ); + + let search_task = DocSetSearchTask { + searcher: bench_index.searcher.clone(), + query, + }; + bench_group.register(task_name, move |_| black_box(search_task.run())); +} + +fn add_bench_task_top100_asc( + bench_group: &mut BenchGroup, + bench_index: &BenchIndex, + query: RangeQuery, + collector_name: &str, + field_name: &str, + range_low: u64, + range_high: u64, + field_name_owned: String, +) { + let task_name = format!( + "range_{}_[{} TO {}]_{}", + field_name, range_low, range_high, collector_name + ); + + let search_task = Top100AscSearchTask { + searcher: bench_index.searcher.clone(), + query, + field_name: field_name_owned, + }; + bench_group.register(task_name, move |_| black_box(search_task.run())); +} + +fn add_bench_task_top100_desc( + bench_group: &mut BenchGroup, + bench_index: &BenchIndex, + query: RangeQuery, + collector_name: &str, + field_name: &str, + range_low: u64, + range_high: u64, + field_name_owned: String, +) { + let task_name = format!( + "range_{}_[{} TO {}]_{}", + field_name, range_low, range_high, collector_name + ); + + let search_task = Top100DescSearchTask { + searcher: bench_index.searcher.clone(), + query, + field_name: field_name_owned, + }; + bench_group.register(task_name, move |_| black_box(search_task.run())); +} + +struct CountSearchTask { + searcher: Searcher, + query: RangeQuery, +} + +impl CountSearchTask { + #[inline(never)] + pub fn run(&self) -> usize { + self.searcher.search(&self.query, &Count).unwrap() + } +} + +struct DocSetSearchTask { + searcher: Searcher, + query: RangeQuery, +} + +impl DocSetSearchTask { + #[inline(never)] + pub fn run(&self) -> usize { + let result = self.searcher.search(&self.query, &DocSetCollector).unwrap(); + result.len() + } +} + +struct Top100AscSearchTask { + searcher: Searcher, + query: RangeQuery, + field_name: String, +} + +impl Top100AscSearchTask { + #[inline(never)] + pub fn run(&self) -> usize { + let collector = + TopDocs::with_limit(100).order_by_fast_field::(&self.field_name, Order::Asc); + let result = self.searcher.search(&self.query, &collector).unwrap(); + for (_score, doc_address) in &result { + let _doc: tantivy::TantivyDocument = self.searcher.doc(*doc_address).unwrap(); + } + result.len() + } +} + +struct Top100DescSearchTask { + searcher: Searcher, + query: RangeQuery, + field_name: String, +} + +impl Top100DescSearchTask { + #[inline(never)] + pub fn run(&self) -> usize { + let collector = + TopDocs::with_limit(100).order_by_fast_field::(&self.field_name, Order::Desc); + let result = self.searcher.search(&self.query, &collector).unwrap(); + for (_score, doc_address) in &result { + let _doc: tantivy::TantivyDocument = self.searcher.doc(*doc_address).unwrap(); + } + result.len() + } +} diff --git a/columnar/src/column_values/u64_based/bitpacked.rs b/columnar/src/column_values/u64_based/bitpacked.rs index fde012937..71319cbec 100644 --- a/columnar/src/column_values/u64_based/bitpacked.rs +++ b/columnar/src/column_values/u64_based/bitpacked.rs @@ -41,12 +41,6 @@ fn transform_range_before_linear_transformation( if range.is_empty() { return None; } - if stats.min_value > *range.end() { - return None; - } - if stats.max_value < *range.start() { - return None; - } let shifted_range = range.start().saturating_sub(stats.min_value)..=range.end().saturating_sub(stats.min_value); let start_before_gcd_multiplication: u64 = div_ceil(*shifted_range.start(), stats.gcd); diff --git a/src/query/range_query/fast_field_range_doc_set.rs b/src/query/range_query/fast_field_range_doc_set.rs index dd4b8fe68..5a76f7e9d 100644 --- a/src/query/range_query/fast_field_range_doc_set.rs +++ b/src/query/range_query/fast_field_range_doc_set.rs @@ -62,6 +62,17 @@ pub(crate) struct RangeDocSet { const DEFAULT_FETCH_HORIZON: u32 = 128; impl RangeDocSet { pub(crate) fn new(value_range: RangeInclusive, column: Column) -> Self { + if *value_range.start() > column.max_value() || *value_range.end() < column.min_value() { + return Self { + value_range, + column, + loaded_docs: VecCursor::new(), + next_fetch_start: TERMINATED, + fetch_horizon: DEFAULT_FETCH_HORIZON, + last_seek_pos_opt: None, + }; + } + let mut range_docset = Self { value_range, column, @@ -236,4 +247,52 @@ mod tests { let count = searcher.search(&query, &Count).unwrap(); assert_eq!(count, 500); } + + #[test] + fn range_query_no_overlap_optimization() { + let mut schema_builder = schema::SchemaBuilder::new(); + let id_field = schema_builder.add_text_field("id", schema::STRING); + let value_field = schema_builder.add_u64_field("value", schema::FAST | schema::INDEXED); + + let dir = RamDirectory::default(); + let index = IndexBuilder::new() + .schema(schema_builder.build()) + .open_or_create(dir) + .unwrap(); + + { + let mut writer = index.writer(15_000_000).unwrap(); + + // Add documents with values in the range [10, 20] + for i in 0..100 { + let mut doc = TantivyDocument::new(); + doc.add_text(id_field, format!("doc{i}")); + doc.add_u64(value_field, 10 + (i % 11) as u64); // values in range 10-20 + + writer.add_document(doc).unwrap(); + } + writer.commit().unwrap(); + } + + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + + // Test a range query [100, 200] that has no overlap with data range [10, 20] + let query = RangeQuery::new( + Bound::Included(Term::from_field_u64(value_field, 100)), + Bound::Included(Term::from_field_u64(value_field, 200)), + ); + + let count = searcher.search(&query, &Count).unwrap(); + assert_eq!(count, 0); // should return 0 results since there's no overlap + + // Test another non-overlapping range: [0, 5] while data range is [10, 20] + let query2 = RangeQuery::new( + Bound::Included(Term::from_field_u64(value_field, 0)), + Bound::Included(Term::from_field_u64(value_field, 5)), + ); + + let count2 = searcher.search(&query2, &Count).unwrap(); + assert_eq!(count2, 0); // should return 0 results since there's no overlap + } } From 923f0508f291e3dbeeb4db020722436eecb5a596 Mon Sep 17 00:00:00 2001 From: PSeitz Date: Tue, 30 Dec 2025 21:43:25 +0800 Subject: [PATCH 13/26] seek_exact + cost based intersection (#2538) * seek_exact + cost based intersection Adds `seek_exact` and `cost` to `DocSet` for a more efficient intersection. Unlike `seek`, `seek_exact` does not require the DocSet to advance to the next hit, if the target does not exist. `cost` allows to address the different DocSet types and their cost model and is used to determine the DocSet that drives the intersection. E.g. fast field range queries may do a full scan. Phrase queries load the positions to check if a we have a hit. They both have a higher cost than their size_hint would suggest. Improves `size_hint` estimation for intersection and union, by having a estimation based on random distribution with a co-location factor. Refactor range query benchmark. Closes #2531 *Future Work* Implement `seek_exact` for BufferedUnionScorer and RangeDocSet (fast field range queries) Evaluate replacing `seek` with `seek_exact` to reduce code complexity * Apply suggestions from code review Co-authored-by: Paul Masurel * add API contract verfication * impl seek_exact on union * rename seek_exact * add mixed AND OR test, fix buffered_union * Add a proptest of BooleanQuery. (#2690) * fix build * Increase the document count. * fix merge conflict * fix debug assert * Fix compilation errors after rebase - Remove duplicate proptest_boolean_query module - Remove duplicate cost() method implementations - Fix TopDocs API usage (add .order_by_score()) - Remove duplicate imports - Remove unused variable assignments --------- Co-authored-by: Paul Masurel Co-authored-by: Pascal Seitz Co-authored-by: Stu Hood --- Cargo.toml | 8 +- benches/range_query.rs | 260 ++++++++++ src/docset.rs | 47 ++ src/postings/compression/mod.rs | 1 + src/query/all_query.rs | 9 + src/query/boolean_query/block_wand.rs | 2 +- src/query/boolean_query/mod.rs | 1 - src/query/boost_query.rs | 3 + src/query/disjunction.rs | 10 + src/query/intersection.rs | 96 +++- src/query/mod.rs | 76 ++- .../phrase_prefix_scorer.rs | 8 + src/query/phrase_query/phrase_scorer.rs | 25 +- .../range_query/fast_field_range_doc_set.rs | 36 +- .../range_query/range_query_fastfield.rs | 446 ------------------ src/query/reqopt_scorer.rs | 5 + src/query/term_query/term_scorer.rs | 3 + src/query/union/buffered_union.rs | 31 +- src/query/union/simple_union.rs | 1 + 19 files changed, 581 insertions(+), 487 deletions(-) create mode 100644 benches/range_query.rs diff --git a/Cargo.toml b/Cargo.toml index dfbb1ea1c..28cd49b07 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -75,12 +75,12 @@ typetag = "0.2.21" winapi = "0.3.9" [dev-dependencies] -binggan = "0.14.0" +binggan = "0.14.2" rand = "0.8.5" maplit = "1.0.2" matches = "0.1.9" pretty_assertions = "1.2.1" -proptest = "1.0.0" +proptest = "1.7.0" test-log = "0.2.10" futures = "0.3.21" paste = "1.0.11" @@ -173,6 +173,10 @@ harness = false name = "exists_json" harness = false +[[bench]] +name = "range_query" +harness = false + [[bench]] name = "and_or_queries" harness = false diff --git a/benches/range_query.rs b/benches/range_query.rs new file mode 100644 index 000000000..bf46666f3 --- /dev/null +++ b/benches/range_query.rs @@ -0,0 +1,260 @@ +use std::fmt::Display; +use std::net::Ipv6Addr; +use std::ops::RangeInclusive; + +use binggan::plugins::PeakMemAllocPlugin; +use binggan::{black_box, BenchRunner, OutputValue, PeakMemAlloc, INSTRUMENTED_SYSTEM}; +use columnar::MonotonicallyMappableToU128; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use tantivy::collector::{Count, TopDocs}; +use tantivy::query::QueryParser; +use tantivy::schema::*; +use tantivy::{doc, Index}; + +#[global_allocator] +pub static GLOBAL: &PeakMemAlloc = &INSTRUMENTED_SYSTEM; + +fn main() { + bench_range_query(); +} + +fn bench_range_query() { + let index = get_index_0_to_100(); + let mut runner = BenchRunner::new(); + runner.add_plugin(PeakMemAllocPlugin::new(GLOBAL)); + + runner.set_name("range_query on u64"); + let field_name_and_descr: Vec<_> = vec![ + ("id", "Single Valued Range Field"), + ("ids", "Multi Valued Range Field"), + ]; + let range_num_hits = vec![ + ("90_percent", get_90_percent()), + ("10_percent", get_10_percent()), + ("1_percent", get_1_percent()), + ]; + + test_range(&mut runner, &index, &field_name_and_descr, range_num_hits); + + runner.set_name("range_query on ip"); + let field_name_and_descr: Vec<_> = vec![ + ("ip", "Single Valued Range Field"), + ("ips", "Multi Valued Range Field"), + ]; + let range_num_hits = vec![ + ("90_percent", get_90_percent_ip()), + ("10_percent", get_10_percent_ip()), + ("1_percent", get_1_percent_ip()), + ]; + + test_range(&mut runner, &index, &field_name_and_descr, range_num_hits); +} + +fn test_range( + runner: &mut BenchRunner, + index: &Index, + field_name_and_descr: &[(&str, &str)], + range_num_hits: Vec<(&str, RangeInclusive)>, +) { + for (field, suffix) in field_name_and_descr { + let term_num_hits = vec![ + ("", ""), + ("1_percent", "veryfew"), + ("10_percent", "few"), + ("90_percent", "most"), + ]; + let mut group = runner.new_group(); + group.set_name(suffix); + // all intersect combinations + for (range_name, range) in &range_num_hits { + for (term_name, term) in &term_num_hits { + let index = &index; + let test_name = if term_name.is_empty() { + format!("id_range_hit_{}", range_name) + } else { + format!( + "id_range_hit_{}_intersect_with_term_{}", + range_name, term_name + ) + }; + group.register(test_name, move |_| { + let query = if term_name.is_empty() { + "".to_string() + } else { + format!("AND id_name:{}", term) + }; + black_box(execute_query(field, range, &query, index)); + }); + } + } + group.run(); + } +} + +fn get_index_0_to_100() -> Index { + let mut rng = StdRng::from_seed([1u8; 32]); + let num_vals = 100_000; + let docs: Vec<_> = (0..num_vals) + .map(|_i| { + let id_name = if rng.gen_bool(0.01) { + "veryfew".to_string() // 1% + } else if rng.gen_bool(0.1) { + "few".to_string() // 9% + } else { + "most".to_string() // 90% + }; + Doc { + id_name, + id: rng.gen_range(0..100), + // Multiply by 1000, so that we create most buckets in the compact space + // The benches depend on this range to select n-percent of elements with the + // methods below. + ip: Ipv6Addr::from_u128(rng.gen_range(0..100) * 1000), + } + }) + .collect(); + + create_index_from_docs(&docs) +} + +#[derive(Clone, Debug)] +pub struct Doc { + pub id_name: String, + pub id: u64, + pub ip: Ipv6Addr, +} + +pub fn create_index_from_docs(docs: &[Doc]) -> Index { + let mut schema_builder = Schema::builder(); + let id_u64_field = schema_builder.add_u64_field("id", INDEXED | STORED | FAST); + let ids_u64_field = + schema_builder.add_u64_field("ids", NumericOptions::default().set_fast().set_indexed()); + + let id_f64_field = schema_builder.add_f64_field("id_f64", INDEXED | STORED | FAST); + let ids_f64_field = schema_builder.add_f64_field( + "ids_f64", + NumericOptions::default().set_fast().set_indexed(), + ); + + let id_i64_field = schema_builder.add_i64_field("id_i64", INDEXED | STORED | FAST); + let ids_i64_field = schema_builder.add_i64_field( + "ids_i64", + NumericOptions::default().set_fast().set_indexed(), + ); + + let text_field = schema_builder.add_text_field("id_name", STRING | STORED); + let text_field2 = schema_builder.add_text_field("id_name_fast", STRING | STORED | FAST); + + let ip_field = schema_builder.add_ip_addr_field("ip", FAST); + let ips_field = schema_builder.add_ip_addr_field("ips", FAST); + + let schema = schema_builder.build(); + + let index = Index::create_in_ram(schema); + + { + let mut index_writer = index.writer_with_num_threads(1, 50_000_000).unwrap(); + for doc in docs.iter() { + index_writer + .add_document(doc!( + ids_i64_field => doc.id as i64, + ids_i64_field => doc.id as i64, + ids_f64_field => doc.id as f64, + ids_f64_field => doc.id as f64, + ids_u64_field => doc.id, + ids_u64_field => doc.id, + id_u64_field => doc.id, + id_f64_field => doc.id as f64, + id_i64_field => doc.id as i64, + text_field => doc.id_name.to_string(), + text_field2 => doc.id_name.to_string(), + ips_field => doc.ip, + ips_field => doc.ip, + ip_field => doc.ip, + )) + .unwrap(); + } + + index_writer.commit().unwrap(); + } + index +} + +fn get_90_percent() -> RangeInclusive { + 0..=90 +} + +fn get_10_percent() -> RangeInclusive { + 0..=10 +} + +fn get_1_percent() -> RangeInclusive { + 10..=10 +} + +fn get_90_percent_ip() -> RangeInclusive { + let start = Ipv6Addr::from_u128(0); + let end = Ipv6Addr::from_u128(90 * 1000); + start..=end +} + +fn get_10_percent_ip() -> RangeInclusive { + let start = Ipv6Addr::from_u128(0); + let end = Ipv6Addr::from_u128(10 * 1000); + start..=end +} + +fn get_1_percent_ip() -> RangeInclusive { + let start = Ipv6Addr::from_u128(10 * 1000); + let end = Ipv6Addr::from_u128(10 * 1000); + start..=end +} + +struct NumHits { + count: usize, +} +impl OutputValue for NumHits { + fn column_title() -> &'static str { + "NumHits" + } + fn format(&self) -> Option { + Some(self.count.to_string()) + } +} + +fn execute_query( + field: &str, + id_range: &RangeInclusive, + suffix: &str, + index: &Index, +) -> NumHits { + let gen_query_inclusive = |from: &T, to: &T| { + format!( + "{}:[{} TO {}] {}", + field, + &from.to_string(), + &to.to_string(), + suffix + ) + }; + + let query = gen_query_inclusive(id_range.start(), id_range.end()); + execute_query_(&query, index) +} + +fn execute_query_(query: &str, index: &Index) -> NumHits { + let query_from_text = |text: &str| { + QueryParser::for_index(index, vec![]) + .parse_query(text) + .unwrap() + }; + let query = query_from_text(query); + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + let num_hits = searcher + .search(&query, &(TopDocs::with_limit(10).order_by_score(), Count)) + .unwrap() + .1; + NumHits { count: num_hits } +} diff --git a/src/docset.rs b/src/docset.rs index 7de138da6..01ea1125a 100644 --- a/src/docset.rs +++ b/src/docset.rs @@ -40,6 +40,8 @@ pub trait DocSet: Send { /// of `DocSet` should support it. /// /// Calling `seek(TERMINATED)` is also legal and is the normal way to consume a `DocSet`. + /// + /// `target` has to be larger or equal to `.doc()` when calling `seek`. fn seek(&mut self, target: DocId) -> DocId { let mut doc = self.doc(); debug_assert!(doc <= target); @@ -49,6 +51,33 @@ pub trait DocSet: Send { doc } + /// Seeks to the target if possible and returns true if the target is in the DocSet. + /// + /// DocSets that already have an efficient `seek` method don't need to implement + /// `seek_into_the_danger_zone`. All wrapper DocSets should forward + /// `seek_into_the_danger_zone` to the underlying DocSet. + /// + /// ## API Behaviour + /// If `seek_into_the_danger_zone` is returning true, a call to `doc()` has to return target. + /// If `seek_into_the_danger_zone` is returning false, a call to `doc()` may return any doc + /// between the last doc that matched and target or a doc that is a valid next hit after + /// target. The DocSet is considered to be in an invalid state until + /// `seek_into_the_danger_zone` returns true again. + /// + /// `target` needs to be equal or larger than `doc` when in a valid state. + /// + /// Consecutive calls are not allowed to have decreasing `target` values. + /// + /// # Warning + /// This is an advanced API used by intersection. The API contract is tricky, avoid using it. + fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { + let current_doc = self.doc(); + if current_doc < target { + self.seek(target); + } + self.doc() == target + } + /// Fills a given mutable buffer with the next doc ids from the /// `DocSet` /// @@ -94,6 +123,15 @@ pub trait DocSet: Send { /// which would be the number of documents in the DocSet. /// /// By default this returns `size_hint()`. + /// + /// DocSets may have vastly different cost depending on their type, + /// e.g. an intersection with 10 hits is much cheaper than + /// a phrase search with 10 hits, since it needs to load positions. + /// + /// ### Future Work + /// We may want to differentiate `DocSet` costs more more granular, e.g. + /// creation_cost, advance_cost, seek_cost on to get a good estimation + /// what query types to choose. fn cost(&self) -> u64 { self.size_hint() as u64 } @@ -137,6 +175,10 @@ impl DocSet for &mut dyn DocSet { (**self).seek(target) } + fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { + (**self).seek_into_the_danger_zone(target) + } + fn doc(&self) -> u32 { (**self).doc() } @@ -169,6 +211,11 @@ impl DocSet for Box { unboxed.seek(target) } + fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { + let unboxed: &mut TDocSet = self.borrow_mut(); + unboxed.seek_into_the_danger_zone(target) + } + fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize { let unboxed: &mut TDocSet = self.borrow_mut(); unboxed.fill_buffer(buffer) diff --git a/src/postings/compression/mod.rs b/src/postings/compression/mod.rs index 6b7b0de9f..62eeca3d5 100644 --- a/src/postings/compression/mod.rs +++ b/src/postings/compression/mod.rs @@ -9,6 +9,7 @@ const COMPRESSED_BLOCK_MAX_SIZE: usize = COMPRESSION_BLOCK_SIZE * MAX_VINT_SIZE; mod vint; /// Returns the size in bytes of a compressed block, given `num_bits`. +#[inline] pub fn compressed_block_size(num_bits: u8) -> usize { (num_bits as usize) * COMPRESSION_BLOCK_SIZE / 8 } diff --git a/src/query/all_query.rs b/src/query/all_query.rs index 16a83ec56..612d7408c 100644 --- a/src/query/all_query.rs +++ b/src/query/all_query.rs @@ -62,6 +62,15 @@ impl DocSet for AllScorer { self.doc } + fn seek(&mut self, target: DocId) -> DocId { + debug_assert!(target >= self.doc); + self.doc = target; + if self.doc >= self.max_doc { + self.doc = TERMINATED; + } + self.doc + } + fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize { if self.doc() == TERMINATED { return 0; diff --git a/src/query/boolean_query/block_wand.rs b/src/query/boolean_query/block_wand.rs index c6710b09c..6b2f2d6e3 100644 --- a/src/query/boolean_query/block_wand.rs +++ b/src/query/boolean_query/block_wand.rs @@ -483,7 +483,7 @@ mod tests { let checkpoints_for_each_pruning = compute_checkpoints_for_each_pruning(term_scorers.clone(), top_k); let checkpoints_manual = - compute_checkpoints_manual(term_scorers.clone(), top_k, 100_000); + compute_checkpoints_manual(term_scorers.clone(), top_k, max_doc as u32); assert_eq!(checkpoints_for_each_pruning.len(), checkpoints_manual.len()); for (&(left_doc, left_score), &(right_doc, right_score)) in checkpoints_for_each_pruning .iter() diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs index 2d7936f00..a78cbe8e9 100644 --- a/src/query/boolean_query/mod.rs +++ b/src/query/boolean_query/mod.rs @@ -817,7 +817,6 @@ mod proptest_boolean_query { fn proptest_boolean_query() { // In the presence of optimizations around buffering, it can take large numbers of // documents to uncover some issues. - let num_docs = 10000; let num_fields = 8; let num_docs = 1 << num_fields; let (index, fields, range_field) = diff --git a/src/query/boost_query.rs b/src/query/boost_query.rs index 06678287f..ecbf3d8d6 100644 --- a/src/query/boost_query.rs +++ b/src/query/boost_query.rs @@ -104,6 +104,9 @@ impl DocSet for BoostScorer { fn seek(&mut self, target: DocId) -> DocId { self.underlying.seek(target) } + fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { + self.underlying.seek_into_the_danger_zone(target) + } fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize { self.underlying.fill_buffer(buffer) diff --git a/src/query/disjunction.rs b/src/query/disjunction.rs index 910e207df..b2f1080fc 100644 --- a/src/query/disjunction.rs +++ b/src/query/disjunction.rs @@ -62,6 +62,16 @@ impl DocSet for ScorerWrapper { self.current_doc = doc_id; doc_id } + fn seek(&mut self, target: DocId) -> DocId { + let doc_id = self.scorer.seek(target); + self.current_doc = doc_id; + doc_id + } + fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { + let found = self.scorer.seek_into_the_danger_zone(target); + self.current_doc = self.scorer.doc(); + found + } fn doc(&self) -> DocId { self.current_doc diff --git a/src/query/intersection.rs b/src/query/intersection.rs index 10e257c43..3e8677d98 100644 --- a/src/query/intersection.rs +++ b/src/query/intersection.rs @@ -1,5 +1,5 @@ +use super::size_hint::estimate_intersection; use crate::docset::{DocSet, TERMINATED}; -use crate::query::size_hint::estimate_intersection; use crate::query::term_query::TermScorer; use crate::query::{EmptyScorer, Scorer}; use crate::{DocId, Score}; @@ -12,6 +12,9 @@ use crate::{DocId, Score}; /// For better performance, the function uses a /// specialized implementation if the two /// shortest scorers are `TermScorer`s. +/// +/// num_docs_segment is the number of documents in the segment. It is used for estimating the +/// `size_hint` of the intersection. pub fn intersect_scorers( mut scorers: Vec>, num_docs_segment: u32, @@ -105,32 +108,44 @@ impl DocSet for Intersection DocId { let (left, right) = (&mut self.left, &mut self.right); let mut candidate = left.advance(); + if candidate == TERMINATED { + return TERMINATED; + } - 'outer: loop { + loop { // In the first part we look for a document in the intersection // of the two rarest `DocSet` in the intersection. loop { - let right_doc = right.seek(candidate); - candidate = left.seek(right_doc); - if candidate == right_doc { + if right.seek_into_the_danger_zone(candidate) { break; } + let right_doc = right.doc(); + // TODO: Think about which value would make sense here + // It depends on the DocSet implementation, when a seek would outweigh an advance. + if right_doc > candidate.wrapping_add(100) { + candidate = left.seek(right_doc); + } else { + candidate = left.advance(); + } + if candidate == TERMINATED { + return TERMINATED; + } } debug_assert_eq!(left.doc(), right.doc()); - // test the remaining scorers; - for docset in self.others.iter_mut() { - let seek_doc = docset.seek(candidate); - if seek_doc > candidate { - candidate = left.seek(seek_doc); - continue 'outer; - } + // test the remaining scorers + if self + .others + .iter_mut() + .all(|docset| docset.seek_into_the_danger_zone(candidate)) + { + debug_assert_eq!(candidate, self.left.doc()); + debug_assert_eq!(candidate, self.right.doc()); + debug_assert!(self.others.iter().all(|docset| docset.doc() == candidate)); + return candidate; } - debug_assert_eq!(candidate, self.left.doc()); - debug_assert_eq!(candidate, self.right.doc()); - debug_assert!(self.others.iter().all(|docset| docset.doc() == candidate)); - return candidate; + candidate = left.advance(); } } @@ -146,6 +161,19 @@ impl DocSet for Intersection bool { + self.left.seek_into_the_danger_zone(target) + && self.right.seek_into_the_danger_zone(target) + && self + .others + .iter_mut() + .all(|docset| docset.seek_into_the_danger_zone(target)) + } + fn doc(&self) -> DocId { self.left.doc() } @@ -181,6 +209,8 @@ where #[cfg(test)] mod tests { + use proptest::prelude::*; + use super::Intersection; use crate::docset::{DocSet, TERMINATED}; use crate::postings::tests::test_skip_against_unoptimized; @@ -270,4 +300,38 @@ mod tests { let intersection = Intersection::new(vec![a, b, c], 10); assert_eq!(intersection.doc(), TERMINATED); } + + // Strategy to generate sorted and deduplicated vectors of u32 document IDs + fn sorted_deduped_vec(max_val: u32, max_size: usize) -> impl Strategy> { + prop::collection::vec(0..max_val, 0..max_size).prop_map(|mut vec| { + vec.sort(); + vec.dedup(); + vec + }) + } + + proptest! { + #[test] + fn prop_test_intersection_consistency( + a in sorted_deduped_vec(100, 10), + b in sorted_deduped_vec(100, 10), + num_docs in 100u32..500u32 + ) { + let left = VecDocSet::from(a.clone()); + let right = VecDocSet::from(b.clone()); + let mut intersection = Intersection::new(vec![left, right], num_docs); + + let expected: Vec = a.iter() + .cloned() + .filter(|doc| b.contains(doc)) + .collect(); + + for expected_doc in expected { + assert_eq!(intersection.doc(), expected_doc); + intersection.advance(); + } + assert_eq!(intersection.doc(), TERMINATED); + } + + } } diff --git a/src/query/mod.rs b/src/query/mod.rs index d609a0402..0bc865921 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -70,9 +70,83 @@ pub use self::weight::Weight; #[cfg(test)] mod tests { + use crate::collector::TopDocs; + use crate::query::phrase_query::tests::create_index; use crate::query::QueryParser; use crate::schema::{Schema, TEXT}; - use crate::{Index, Term}; + use crate::{DocAddress, Index, Term}; + + #[test] + pub fn test_mixed_intersection_and_union() -> crate::Result<()> { + let index = create_index(&["a b", "a c", "a b c", "b"])?; + let schema = index.schema(); + let text_field = schema.get_field("text").unwrap(); + let searcher = index.reader()?.searcher(); + + let do_search = |term: &str| { + let query = QueryParser::for_index(&index, vec![text_field]) + .parse_query(term) + .unwrap(); + let top_docs: Vec<(f32, DocAddress)> = searcher + .search(&query, &TopDocs::with_limit(10).order_by_score()) + .unwrap(); + + top_docs.iter().map(|el| el.1.doc_id).collect::>() + }; + + assert_eq!(do_search("a AND b"), vec![0, 2]); + assert_eq!(do_search("(a OR b) AND C"), vec![2, 1]); + // The intersection code has special code for more than 2 intersections + // left, right + others + // The will place the union in the "others" insersection to that seek_into_the_danger_zone + // is called + assert_eq!( + do_search("(a OR b) AND (c OR a) AND (b OR c)"), + vec![2, 1, 0] + ); + + Ok(()) + } + + #[test] + pub fn test_mixed_intersection_and_union_with_skip() -> crate::Result<()> { + // Test 4096 skip in BufferedUnionScorer + let mut data: Vec<&str> = Vec::new(); + data.push("a b"); + let zz_data = vec!["z z"; 5000]; + data.extend_from_slice(&zz_data); + data.extend_from_slice(&["a c"]); + data.extend_from_slice(&zz_data); + data.extend_from_slice(&["a b c", "b"]); + let index = create_index(&data)?; + let schema = index.schema(); + let text_field = schema.get_field("text").unwrap(); + let searcher = index.reader()?.searcher(); + + let do_search = |term: &str| { + let query = QueryParser::for_index(&index, vec![text_field]) + .parse_query(term) + .unwrap(); + let top_docs: Vec<(f32, DocAddress)> = searcher + .search(&query, &TopDocs::with_limit(10).order_by_score()) + .unwrap(); + + top_docs.iter().map(|el| el.1.doc_id).collect::>() + }; + + assert_eq!(do_search("a AND b"), vec![0, 10002]); + assert_eq!(do_search("(a OR b) AND C"), vec![10002, 5001]); + // The intersection code has special code for more than 2 intersections + // left, right + others + // The will place the union in the "others" insersection to that seek_into_the_danger_zone + // is called + assert_eq!( + do_search("(a OR b) AND (c OR a) AND (b OR c)"), + vec![10002, 5001, 0] + ); + + Ok(()) + } #[test] fn test_query_terms() { diff --git a/src/query/phrase_prefix_query/phrase_prefix_scorer.rs b/src/query/phrase_prefix_query/phrase_prefix_scorer.rs index 14933f3ae..cc7bb7886 100644 --- a/src/query/phrase_prefix_query/phrase_prefix_scorer.rs +++ b/src/query/phrase_prefix_query/phrase_prefix_scorer.rs @@ -193,6 +193,14 @@ impl DocSet for PhrasePrefixScorer { self.advance() } + fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { + if self.phrase_scorer.seek_into_the_danger_zone(target) { + self.matches_prefix() + } else { + false + } + } + fn doc(&self) -> DocId { self.phrase_scorer.doc() } diff --git a/src/query/phrase_query/phrase_scorer.rs b/src/query/phrase_query/phrase_scorer.rs index 12a94dce3..4f8541cd2 100644 --- a/src/query/phrase_query/phrase_scorer.rs +++ b/src/query/phrase_query/phrase_scorer.rs @@ -382,8 +382,9 @@ impl PhraseScorer { PostingsWithOffset::new(postings, (max_offset - offset) as u32) }) .collect::>(); + let intersection_docset = Intersection::new(postings_with_offsets, num_docs); let mut scorer = PhraseScorer { - intersection_docset: Intersection::new(postings_with_offsets, num_docs), + intersection_docset, num_terms: num_docsets, left_positions: Vec::with_capacity(100), right_positions: Vec::with_capacity(100), @@ -529,20 +530,34 @@ impl DocSet for PhraseScorer { self.advance() } + fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { + debug_assert!(target >= self.doc()); + if self.intersection_docset.seek_into_the_danger_zone(target) && self.phrase_match() { + return true; + } + false + } + fn doc(&self) -> DocId { self.intersection_docset.doc() } fn size_hint(&self) -> u32 { - self.intersection_docset.size_hint() + // We adjust the intersection estimate, since actual phrase hits are much lower than where + // the all appear. + // The estimate should depend on average field length, e.g. if the field is really short + // a phrase hit is more likely + self.intersection_docset.size_hint() / (10 * self.num_terms as u32) } /// Returns a best-effort hint of the /// cost to drive the docset. fn cost(&self) -> u64 { - // Evaluating phrase matches is generally more expensive than simple term matches, - // as it requires loading and comparing positions. Use a conservative multiplier - // based on the number of terms. + // While determing a potential hit is cheap for phrases, evaluating an actual hit is + // expensive since it requires to load positions for a doc and check if they are next to + // each other. + // So the cost estimation would be the number of times we need to check if a doc is a hit * + // 10 * self.num_terms. self.intersection_docset.size_hint() as u64 * 10 * self.num_terms as u64 } } diff --git a/src/query/range_query/fast_field_range_doc_set.rs b/src/query/range_query/fast_field_range_doc_set.rs index 5a76f7e9d..24d2b1fe3 100644 --- a/src/query/range_query/fast_field_range_doc_set.rs +++ b/src/query/range_query/fast_field_range_doc_set.rs @@ -92,6 +92,9 @@ impl RangeDocSet { /// Returns true if more data could be fetched fn fetch_block(&mut self) { + if self.next_fetch_start >= self.column.num_docs() { + return; + } const MAX_HORIZON: u32 = 100_000; while self.loaded_docs.is_empty() { let finished_to_end = self.fetch_horizon(self.fetch_horizon); @@ -116,10 +119,10 @@ impl RangeDocSet { fn fetch_horizon(&mut self, horizon: u32) -> bool { let mut finished_to_end = false; - let limit = self.column.num_docs(); - let mut end = self.next_fetch_start + horizon; - if end >= limit { - end = limit; + let num_docs = self.column.num_docs(); + let mut fetch_end = self.next_fetch_start + horizon; + if fetch_end >= num_docs { + fetch_end = num_docs; finished_to_end = true; } @@ -127,7 +130,7 @@ impl RangeDocSet { let doc_buffer: &mut Vec = self.loaded_docs.get_cleared_data(); self.column.get_docids_for_value_range( self.value_range.clone(), - self.next_fetch_start..end, + self.next_fetch_start..fetch_end, doc_buffer, ); if let Some(last_doc) = last_doc { @@ -135,7 +138,7 @@ impl RangeDocSet { self.loaded_docs.next(); } } - self.next_fetch_start = end; + self.next_fetch_start = fetch_end; finished_to_end } @@ -147,9 +150,6 @@ impl DocSet for RangeDocSe if let Some(docid) = self.loaded_docs.next() { return docid; } - if self.next_fetch_start >= self.column.num_docs() { - return TERMINATED; - } self.fetch_block(); self.loaded_docs.current().unwrap_or(TERMINATED) } @@ -185,15 +185,25 @@ impl DocSet for RangeDocSe } fn size_hint(&self) -> u32 { - self.column.num_docs() + // TODO: Implement a better size hint + self.column.num_docs() / 10 } /// Returns a best-effort hint of the /// cost to drive the docset. fn cost(&self) -> u64 { - // Advancing the docset is relatively expensive since it scans the column. - // Keep cost relative to a term query driver; use num_docs as baseline. - self.column.num_docs() as u64 + // Advancing the docset is pretty expensive since it scans the whole column, there is no + // index currently (will change with an kd-tree) + // Since we use SIMD to scan the fast field range query we lower the cost a little bit, + // assuming that we hit 10% of the docs like in size_hint. + // + // If we would return a cost higher than num_docs, we would never choose ff range query as + // the driver in a DocSet, when intersecting a term query with a fast field. But + // it's the faster choice when the term query has a lot of docids and the range + // query has not. + // + // Ideally this would take the fast field codec into account + (self.column.num_docs() as f64 * 0.8) as u64 } } diff --git a/src/query/range_query/range_query_fastfield.rs b/src/query/range_query/range_query_fastfield.rs index 54cf0cad5..e379e108e 100644 --- a/src/query/range_query/range_query_fastfield.rs +++ b/src/query/range_query/range_query_fastfield.rs @@ -1598,449 +1598,3 @@ pub(crate) mod ip_range_tests { Ok(()) } } - -#[cfg(all(test, feature = "unstable"))] -mod bench { - - use rand::rngs::StdRng; - use rand::{Rng, SeedableRng}; - use test::Bencher; - - use super::tests::*; - use super::*; - use crate::collector::Count; - use crate::query::QueryParser; - use crate::Index; - - fn get_index_0_to_100() -> Index { - let mut rng = StdRng::from_seed([1u8; 32]); - let num_vals = 100_000; - let docs: Vec<_> = (0..num_vals) - .map(|_i| { - let id_name = if rng.gen_bool(0.01) { - "veryfew".to_string() // 1% - } else if rng.gen_bool(0.1) { - "few".to_string() // 9% - } else { - "many".to_string() // 90% - }; - Doc { - id_name, - id: rng.gen_range(0..100), - } - }) - .collect(); - - create_index_from_docs(&docs, false) - } - - fn get_90_percent() -> RangeInclusive { - 0..=90 - } - - fn get_10_percent() -> RangeInclusive { - 0..=10 - } - - fn get_1_percent() -> RangeInclusive { - 10..=10 - } - - fn execute_query( - field: &str, - id_range: RangeInclusive, - suffix: &str, - index: &Index, - ) -> usize { - let gen_query_inclusive = |from: &u64, to: &u64| { - format!( - "{}:[{} TO {}] {}", - field, - &from.to_string(), - &to.to_string(), - suffix - ) - }; - - let query = gen_query_inclusive(id_range.start(), id_range.end()); - let query_from_text = |text: &str| { - QueryParser::for_index(index, vec![]) - .parse_query(text) - .unwrap() - }; - let query = query_from_text(&query); - let reader = index.reader().unwrap(); - let searcher = reader.searcher(); - searcher.search(&query, &(Count)).unwrap() - } - - #[bench] - fn bench_id_range_hit_90_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("id", get_90_percent(), "", &index)); - } - - #[bench] - fn bench_id_range_hit_10_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("id", get_10_percent(), "", &index)); - } - - #[bench] - fn bench_id_range_hit_1_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("id", get_1_percent(), "", &index)); - } - - #[bench] - fn bench_id_range_hit_10_percent_intersect_with_10_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("id", get_10_percent(), "AND id_name:few", &index)); - } - - #[bench] - fn bench_id_range_hit_1_percent_intersect_with_10_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("id", get_1_percent(), "AND id_name:few", &index)); - } - - #[bench] - fn bench_id_range_hit_1_percent_intersect_with_90_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("id", get_1_percent(), "AND id_name:many", &index)); - } - - #[bench] - fn bench_id_range_hit_1_percent_intersect_with_1_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("id", get_1_percent(), "AND id_name:veryfew", &index)); - } - - #[bench] - fn bench_id_range_hit_10_percent_intersect_with_90_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("id", get_10_percent(), "AND id_name:many", &index)); - } - - #[bench] - fn bench_id_range_hit_90_percent_intersect_with_90_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("id", get_90_percent(), "AND id_name:many", &index)); - } - - #[bench] - fn bench_id_range_hit_90_percent_intersect_with_10_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("id", get_90_percent(), "AND id_name:few", &index)); - } - - #[bench] - fn bench_id_range_hit_90_percent_intersect_with_1_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("id", get_90_percent(), "AND id_name:veryfew", &index)); - } - - #[bench] - fn bench_id_range_hit_90_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("ids", get_90_percent(), "", &index)); - } - - #[bench] - fn bench_id_range_hit_10_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("ids", get_10_percent(), "", &index)); - } - - #[bench] - fn bench_id_range_hit_1_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("ids", get_1_percent(), "", &index)); - } - - #[bench] - fn bench_id_range_hit_10_percent_intersect_with_10_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("ids", get_10_percent(), "AND id_name:few", &index)); - } - - #[bench] - fn bench_id_range_hit_1_percent_intersect_with_10_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("ids", get_1_percent(), "AND id_name:few", &index)); - } - - #[bench] - fn bench_id_range_hit_1_percent_intersect_with_90_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("ids", get_1_percent(), "AND id_name:many", &index)); - } - - #[bench] - fn bench_id_range_hit_1_percent_intersect_with_1_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("ids", get_1_percent(), "AND id_name:veryfew", &index)); - } - - #[bench] - fn bench_id_range_hit_10_percent_intersect_with_90_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("ids", get_10_percent(), "AND id_name:many", &index)); - } - - #[bench] - fn bench_id_range_hit_90_percent_intersect_with_90_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("ids", get_90_percent(), "AND id_name:many", &index)); - } - - #[bench] - fn bench_id_range_hit_90_percent_intersect_with_10_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("ids", get_90_percent(), "AND id_name:few", &index)); - } - - #[bench] - fn bench_id_range_hit_90_percent_intersect_with_1_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("ids", get_90_percent(), "AND id_name:veryfew", &index)); - } -} - -#[cfg(all(test, feature = "unstable"))] -mod bench_ip { - - use rand::rngs::StdRng; - use rand::{Rng, SeedableRng}; - use test::Bencher; - - use super::ip_range_tests::*; - use super::*; - use crate::collector::Count; - use crate::query::QueryParser; - use crate::Index; - - fn get_index_0_to_100() -> Index { - let mut rng = StdRng::from_seed([1u8; 32]); - let num_vals = 100_000; - let docs: Vec<_> = (0..num_vals) - .map(|_i| { - let id = if rng.gen_bool(0.01) { - "veryfew".to_string() // 1% - } else if rng.gen_bool(0.1) { - "few".to_string() // 9% - } else { - "many".to_string() // 90% - }; - Doc { - id, - // Multiply by 1000, so that we create many buckets in the compact space - // The benches depend on this range to select n-percent of elements with the - // methods below. - ip: Ipv6Addr::from_u128(rng.gen_range(0..100) * 1000), - } - }) - .collect(); - - create_index_from_ip_docs(&docs) - } - - fn get_90_percent() -> RangeInclusive { - let start = Ipv6Addr::from_u128(0); - let end = Ipv6Addr::from_u128(90 * 1000); - start..=end - } - - fn get_10_percent() -> RangeInclusive { - let start = Ipv6Addr::from_u128(0); - let end = Ipv6Addr::from_u128(10 * 1000); - start..=end - } - - fn get_1_percent() -> RangeInclusive { - let start = Ipv6Addr::from_u128(10 * 1000); - let end = Ipv6Addr::from_u128(10 * 1000); - start..=end - } - - fn execute_query( - field: &str, - ip_range: RangeInclusive, - suffix: &str, - index: &Index, - ) -> usize { - let gen_query_inclusive = |from: &Ipv6Addr, to: &Ipv6Addr| { - format!( - "{}:[{} TO {}] {}", - field, - &from.to_string(), - &to.to_string(), - suffix - ) - }; - - let query = gen_query_inclusive(ip_range.start(), ip_range.end()); - let query_from_text = |text: &str| { - QueryParser::for_index(index, vec![]) - .parse_query(text) - .unwrap() - }; - let query = query_from_text(&query); - let reader = index.reader().unwrap(); - let searcher = reader.searcher(); - searcher.search(&query, &(Count)).unwrap() - } - - #[bench] - fn bench_ip_range_hit_90_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ip", get_90_percent(), "", &index)); - } - - #[bench] - fn bench_ip_range_hit_10_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ip", get_10_percent(), "", &index)); - } - - #[bench] - fn bench_ip_range_hit_1_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ip", get_1_percent(), "", &index)); - } - - #[bench] - fn bench_ip_range_hit_10_percent_intersect_with_10_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ip", get_10_percent(), "AND id:few", &index)); - } - - #[bench] - fn bench_ip_range_hit_1_percent_intersect_with_10_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ip", get_1_percent(), "AND id:few", &index)); - } - - #[bench] - fn bench_ip_range_hit_1_percent_intersect_with_90_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ip", get_1_percent(), "AND id:many", &index)); - } - - #[bench] - fn bench_ip_range_hit_1_percent_intersect_with_1_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ip", get_1_percent(), "AND id:veryfew", &index)); - } - - #[bench] - fn bench_ip_range_hit_10_percent_intersect_with_90_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ip", get_10_percent(), "AND id:many", &index)); - } - - #[bench] - fn bench_ip_range_hit_90_percent_intersect_with_90_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ip", get_90_percent(), "AND id:many", &index)); - } - - #[bench] - fn bench_ip_range_hit_90_percent_intersect_with_10_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ip", get_90_percent(), "AND id:few", &index)); - } - - #[bench] - fn bench_ip_range_hit_90_percent_intersect_with_1_percent(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ip", get_90_percent(), "AND id:veryfew", &index)); - } - - #[bench] - fn bench_ip_range_hit_90_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ips", get_90_percent(), "", &index)); - } - - #[bench] - fn bench_ip_range_hit_10_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ips", get_10_percent(), "", &index)); - } - - #[bench] - fn bench_ip_range_hit_1_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ips", get_1_percent(), "", &index)); - } - - #[bench] - fn bench_ip_range_hit_10_percent_intersect_with_10_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ips", get_10_percent(), "AND id:few", &index)); - } - - #[bench] - fn bench_ip_range_hit_1_percent_intersect_with_10_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ips", get_1_percent(), "AND id:few", &index)); - } - - #[bench] - fn bench_ip_range_hit_1_percent_intersect_with_90_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - bench.iter(|| execute_query("ips", get_1_percent(), "AND id:many", &index)); - } - - #[bench] - fn bench_ip_range_hit_1_percent_intersect_with_1_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ips", get_1_percent(), "AND id:veryfew", &index)); - } - - #[bench] - fn bench_ip_range_hit_10_percent_intersect_with_90_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ips", get_10_percent(), "AND id:many", &index)); - } - - #[bench] - fn bench_ip_range_hit_90_percent_intersect_with_90_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ips", get_90_percent(), "AND id:many", &index)); - } - - #[bench] - fn bench_ip_range_hit_90_percent_intersect_with_10_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ips", get_90_percent(), "AND id:few", &index)); - } - - #[bench] - fn bench_ip_range_hit_90_percent_intersect_with_1_percent_multi(bench: &mut Bencher) { - let index = get_index_0_to_100(); - - bench.iter(|| execute_query("ips", get_90_percent(), "AND id:veryfew", &index)); - } -} diff --git a/src/query/reqopt_scorer.rs b/src/query/reqopt_scorer.rs index be9e14692..45857567c 100644 --- a/src/query/reqopt_scorer.rs +++ b/src/query/reqopt_scorer.rs @@ -56,6 +56,11 @@ where self.req_scorer.seek(target) } + fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { + self.score_cache = None; + self.req_scorer.seek_into_the_danger_zone(target) + } + fn doc(&self) -> DocId { self.req_scorer.doc() } diff --git a/src/query/term_query/term_scorer.rs b/src/query/term_query/term_scorer.rs index 5c020febd..293aa7871 100644 --- a/src/query/term_query/term_scorer.rs +++ b/src/query/term_query/term_scorer.rs @@ -98,14 +98,17 @@ impl TermScorer { } impl DocSet for TermScorer { + #[inline] fn advance(&mut self) -> DocId { self.postings.advance() } + #[inline] fn seek(&mut self, target: DocId) -> DocId { self.postings.seek(target) } + #[inline] fn doc(&self) -> DocId { self.postings.doc() } diff --git a/src/query/union/buffered_union.rs b/src/query/union/buffered_union.rs index 3c726b8a7..70299ad6f 100644 --- a/src/query/union/buffered_union.rs +++ b/src/query/union/buffered_union.rs @@ -15,7 +15,7 @@ const HORIZON: u32 = 64u32 * 64u32; // This function is similar except that it does is not unstable, and // it does not keep the original vector ordering. // -// Also, it does not "yield" any elements. +// Elements are dropped and not yielded. fn unordered_drain_filter(v: &mut Vec, mut predicate: P) where P: FnMut(&mut T) -> bool { let mut i = 0; @@ -143,6 +143,12 @@ impl BufferedUnionScorer bool { + // wrapping_sub, because target may be < window_start_doc + let gap = target.wrapping_sub(self.window_start_doc); + gap < HORIZON + } } impl DocSet for BufferedUnionScorer @@ -217,7 +223,27 @@ where } } - // TODO Also implement `count` with deletes efficiently. + fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { + if self.is_in_horizon(target) { + // Our value is within the buffered horizon and the docset may already have been + // processed and removed, so we need to use seek, which uses the regular advance. + self.seek(target) == target + } else { + // The docsets are not in the buffered range, so we can use seek_into_the_danger_zone + // of the underlying docsets + let is_hit = self + .docsets + .iter_mut() + .any(|docset| docset.seek_into_the_danger_zone(target)); + + // The API requires the DocSet to be in a valid state when `seek_into_the_danger_zone` + // returns true. + if is_hit { + self.seek(target); + } + is_hit + } + } fn doc(&self) -> DocId { self.doc @@ -231,6 +257,7 @@ where self.docsets.iter().map(|docset| docset.cost()).sum() } + // TODO Also implement `count` with deletes efficiently. fn count_including_deleted(&mut self) -> u32 { if self.doc == TERMINATED { return 0; diff --git a/src/query/union/simple_union.rs b/src/query/union/simple_union.rs index 61cbb94b6..b153a7f22 100644 --- a/src/query/union/simple_union.rs +++ b/src/query/union/simple_union.rs @@ -92,6 +92,7 @@ impl DocSet for SimpleUnion { } fn size_hint(&self) -> u32 { + // TODO: use estimate_union self.docsets .iter() .map(|docset| docset.size_hint()) From 75d7989cc693a8c24d40bfcf0036eb5ee92b1c98 Mon Sep 17 00:00:00 2001 From: ChangRui-Ryan Date: Wed, 31 Dec 2025 19:00:53 +0800 Subject: [PATCH 14/26] add benchmark for boolean query with range sub query (#2787) --- Cargo.toml | 4 + benches/bool_queries_with_range.rs | 288 +++++++++++++++++++++++++++++ 2 files changed, 292 insertions(+) create mode 100644 benches/bool_queries_with_range.rs diff --git a/Cargo.toml b/Cargo.toml index 28cd49b07..10d1c8400 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -184,3 +184,7 @@ harness = false [[bench]] name = "range_queries" harness = false + +[[bench]] +name = "bool_queries_with_range" +harness = false diff --git a/benches/bool_queries_with_range.rs b/benches/bool_queries_with_range.rs new file mode 100644 index 000000000..9123ccf3a --- /dev/null +++ b/benches/bool_queries_with_range.rs @@ -0,0 +1,288 @@ +use binggan::{black_box, BenchGroup, BenchRunner}; +use rand::prelude::*; +use rand::rngs::StdRng; +use rand::SeedableRng; +use tantivy::collector::{Collector, Count, DocSetCollector, TopDocs}; +use tantivy::query::{Query, QueryParser}; +use tantivy::schema::{Schema, FAST, INDEXED, TEXT}; +use tantivy::{doc, Index, Order, ReloadPolicy, Searcher}; + +#[derive(Clone)] +struct BenchIndex { + #[allow(dead_code)] + index: Index, + searcher: Searcher, + query_parser: QueryParser, +} + +fn build_shared_indices(num_docs: usize, p_title_a: f32, distribution: &str) -> BenchIndex { + // Unified schema + let mut schema_builder = Schema::builder(); + let f_title = schema_builder.add_text_field("title", TEXT); + let f_num_rand = schema_builder.add_u64_field("num_rand", INDEXED); + let f_num_asc = schema_builder.add_u64_field("num_asc", INDEXED); + let f_num_rand_fast = schema_builder.add_u64_field("num_rand_fast", INDEXED | FAST); + let f_num_asc_fast = schema_builder.add_u64_field("num_asc_fast", INDEXED | FAST); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema.clone()); + + // Populate index with stable RNG for reproducibility. + let mut rng = StdRng::from_seed([7u8; 32]); + + { + let mut writer = index.writer_with_num_threads(1, 4_000_000_000).unwrap(); + + match distribution { + "dense" => { + for doc_id in 0..num_docs { + // Always add title to avoid empty documents + let title_token = if rng.gen_bool(p_title_a as f64) { + "a" + } else { + "b" + }; + + let num_rand = rng.gen_range(0u64..1000u64); + + let num_asc = (doc_id / 10000) as u64; + + writer + .add_document(doc!( + f_title=>title_token, + f_num_rand=>num_rand, + f_num_asc=>num_asc, + f_num_rand_fast=>num_rand, + f_num_asc_fast=>num_asc, + )) + .unwrap(); + } + } + "sparse" => { + for doc_id in 0..num_docs { + // Always add title to avoid empty documents + let title_token = if rng.gen_bool(p_title_a as f64) { + "a" + } else { + "b" + }; + + let num_rand = rng.gen_range(0u64..10000000u64); + + let num_asc = doc_id as u64; + + writer + .add_document(doc!( + f_title=>title_token, + f_num_rand=>num_rand, + f_num_asc=>num_asc, + f_num_rand_fast=>num_rand, + f_num_asc_fast=>num_asc, + )) + .unwrap(); + } + } + _ => { + panic!("Unsupported distribution type"); + } + } + writer.commit().unwrap(); + } + + // Prepare reader/searcher once. + let reader = index + .reader_builder() + .reload_policy(ReloadPolicy::Manual) + .try_into() + .unwrap(); + let searcher = reader.searcher(); + + // Build query parser for title field + let qp_title = QueryParser::for_index(&index, vec![f_title]); + + BenchIndex { + index, + searcher, + query_parser: qp_title, + } +} + +fn main() { + // Prepare corpora with varying scenarios + let scenarios = vec![ + ( + "dense and 99% a".to_string(), + 10_000_000, + 0.99, + "dense", + 0, + 9, + ), + ( + "dense and 99% a".to_string(), + 10_000_000, + 0.99, + "dense", + 990, + 999, + ), + ( + "sparse and 99% a".to_string(), + 10_000_000, + 0.99, + "sparse", + 0, + 9, + ), + ( + "sparse and 99% a".to_string(), + 10_000_000, + 0.99, + "sparse", + 9_999_990, + 9_999_999, + ), + ]; + + let mut runner = BenchRunner::new(); + for (scenario_id, n, p_title_a, num_rand_distribution, range_low, range_high) in scenarios { + // Build index for this scenario + let bench_index = build_shared_indices(n, p_title_a, num_rand_distribution); + + // Create benchmark group + let mut group = runner.new_group(); + + // Now set the name (this moves scenario_id) + group.set_name(scenario_id); + + // Define all four field types + let field_names = ["num_rand", "num_asc", "num_rand_fast", "num_asc_fast"]; + + // Define the three terms we want to test with + let terms = ["a", "b", "z"]; + + // Generate all combinations of terms and field names + let mut queries = Vec::new(); + for &term in &terms { + for &field_name in &field_names { + let query_str = format!( + "{} AND {}:[{} TO {}]", + term, field_name, range_low, range_high + ); + queries.push((query_str, field_name.to_string())); + } + } + + let query_str = format!( + "{}:[{} TO {}] AND {}:[{} TO {}]", + "num_rand_fast", range_low, range_high, "num_asc_fast", range_low, range_high + ); + queries.push((query_str, "num_asc_fast".to_string())); + + // Run all benchmark tasks for each query and its corresponding field name + for (query_str, field_name) in queries { + run_benchmark_tasks(&mut group, &bench_index, &query_str, &field_name); + } + + group.run(); + } +} + +/// Run all benchmark tasks for a given query string and field name +fn run_benchmark_tasks( + bench_group: &mut BenchGroup, + bench_index: &BenchIndex, + query_str: &str, + field_name: &str, +) { + // Test count + add_bench_task(bench_group, bench_index, query_str, Count, "count"); + + // Test all results + add_bench_task( + bench_group, + bench_index, + query_str, + DocSetCollector, + "all results", + ); + + // Test top 100 by the field (if it's a FAST field) + if field_name.ends_with("_fast") { + // Ascending order + { + let collector_name = format!("top100_by_{}_asc", field_name); + let field_name_owned = field_name.to_string(); + add_bench_task( + bench_group, + bench_index, + query_str, + TopDocs::with_limit(100).order_by_fast_field::(field_name_owned, Order::Asc), + &collector_name, + ); + } + + // Descending order + { + let collector_name = format!("top100_by_{}_desc", field_name); + let field_name_owned = field_name.to_string(); + add_bench_task( + bench_group, + bench_index, + query_str, + TopDocs::with_limit(100).order_by_fast_field::(field_name_owned, Order::Desc), + &collector_name, + ); + } + } +} + +fn add_bench_task( + bench_group: &mut BenchGroup, + bench_index: &BenchIndex, + query_str: &str, + collector: C, + collector_name: &str, +) { + let task_name = format!("{}_{}", query_str.replace(" ", "_"), collector_name); + let query = bench_index.query_parser.parse_query(query_str).unwrap(); + let search_task = SearchTask { + searcher: bench_index.searcher.clone(), + collector, + query, + }; + bench_group.register(task_name, move |_| black_box(search_task.run())); +} + +struct SearchTask { + searcher: Searcher, + collector: C, + query: Box, +} + +impl SearchTask { + #[inline(never)] + pub fn run(&self) -> usize { + let result = self.searcher.search(&self.query, &self.collector).unwrap(); + if let Some(count) = (&result as &dyn std::any::Any).downcast_ref::() { + *count + } else if let Some(top_docs) = (&result as &dyn std::any::Any) + .downcast_ref::, tantivy::DocAddress)>>() + { + top_docs.len() + } else if let Some(top_docs) = + (&result as &dyn std::any::Any).downcast_ref::>() + { + top_docs.len() + } else if let Some(doc_set) = (&result as &dyn std::any::Any) + .downcast_ref::>() + { + doc_set.len() + } else { + eprintln!( + "Unknown collector result type: {:?}", + std::any::type_name::() + ); + 0 + } + } +} From b11605f045d3f8faf9324a7cc7a20bac41199948 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Wed, 31 Dec 2025 18:02:00 +0100 Subject: [PATCH 15/26] Addressing clippy comments (#2789) Co-authored-by: Paul Masurel --- src/indexer/delete_queue.rs | 11 ++--------- src/indexer/index_writer.rs | 2 +- src/indexer/mod.rs | 2 +- src/indexer/segment_register.rs | 2 +- src/query/boolean_query/mod.rs | 6 +++--- stacker/src/expull.rs | 9 ++++++--- 6 files changed, 14 insertions(+), 18 deletions(-) diff --git a/src/indexer/delete_queue.rs b/src/indexer/delete_queue.rs index 25c84fb36..7cc485e1d 100644 --- a/src/indexer/delete_queue.rs +++ b/src/indexer/delete_queue.rs @@ -28,19 +28,12 @@ struct InnerDeleteQueue { /// Several consumers can hold a reference to it. Delete operations /// get dropped/gc'ed when no more consumers are holding a reference /// to them. -#[derive(Clone)] +#[derive(Clone, Default)] pub struct DeleteQueue { inner: Arc>, } impl DeleteQueue { - /// Creates a new empty delete queue. - pub fn new() -> DeleteQueue { - DeleteQueue { - inner: Arc::default(), - } - } - fn get_last_block(&self) -> Arc { { // try get the last block with simply acquiring the read lock. @@ -267,7 +260,7 @@ mod tests { #[test] fn test_deletequeue() { - let delete_queue = DeleteQueue::new(); + let delete_queue = DeleteQueue::default(); let make_op = |i: usize| DeleteOperation { opstamp: i as u64, diff --git a/src/indexer/index_writer.rs b/src/indexer/index_writer.rs index 4ce5e1db5..1e07dd210 100644 --- a/src/indexer/index_writer.rs +++ b/src/indexer/index_writer.rs @@ -303,7 +303,7 @@ impl IndexWriter { let (document_sender, document_receiver) = crossbeam_channel::bounded(PIPELINE_MAX_SIZE_IN_DOCS); - let delete_queue = DeleteQueue::new(); + let delete_queue = DeleteQueue::default(); let current_opstamp = index.load_metas()?.opstamp; diff --git a/src/indexer/mod.rs b/src/indexer/mod.rs index a6d3cab38..d96344b60 100644 --- a/src/indexer/mod.rs +++ b/src/indexer/mod.rs @@ -4,7 +4,7 @@ //! `IndexWriter` is the main entry point for that, which created from //! [`Index::writer`](crate::Index::writer). -pub mod delete_queue; +pub(crate) mod delete_queue; pub(crate) mod path_to_unordered_id; pub(crate) mod doc_id_mapping; diff --git a/src/indexer/segment_register.rs b/src/indexer/segment_register.rs index 0e7046310..fa7bfafa4 100644 --- a/src/indexer/segment_register.rs +++ b/src/indexer/segment_register.rs @@ -117,7 +117,7 @@ mod tests { #[test] fn test_segment_register() { let inventory = SegmentMetaInventory::default(); - let delete_queue = DeleteQueue::new(); + let delete_queue = DeleteQueue::default(); let mut segment_register = SegmentRegister::default(); let segment_id_a = SegmentId::generate_random(); diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs index a78cbe8e9..681881c11 100644 --- a/src/query/boolean_query/mod.rs +++ b/src/query/boolean_query/mod.rs @@ -678,7 +678,7 @@ mod tests { #[cfg(test)] mod proptest_boolean_query { use std::collections::{BTreeMap, HashSet}; - use std::ops::Bound; + use std::ops::{Bound, Range}; use proptest::collection::vec; use proptest::prelude::*; @@ -755,10 +755,10 @@ mod proptest_boolean_query { } } - fn doc_ids(num_docs: usize, num_fields: usize) -> impl Iterator { + fn doc_ids(num_docs: usize, num_fields: usize) -> Range { let permutations = 1 << num_fields; let copies = (num_docs as f32 / permutations as f32).ceil() as u32; - (0..(permutations * copies)).into_iter() + 0..(permutations * copies) } fn create_index_with_boolean_permutations( diff --git a/stacker/src/expull.rs b/stacker/src/expull.rs index 189dc95f2..3b6353b38 100644 --- a/stacker/src/expull.rs +++ b/stacker/src/expull.rs @@ -117,7 +117,7 @@ impl ExpUnrolledLinkedListWriter<'_> { fn get_block_size(block_num: u32) -> u16 { // Cap at 15 to prevent block sizes > 32KB // block_num can now be much larger than 15, but block size maxes out - let exp = block_num.min(15) as u32; + let exp: u32 = block_num.min(15u32); (1u32 << exp) as u16 } @@ -309,6 +309,7 @@ mod tests { } #[test] + #[allow(clippy::needless_range_loop)] fn test_large_dataset_simulation() { // Simulate the scenario: large arrays requiring many blocks // We write enough data to require thousands of blocks @@ -452,8 +453,10 @@ mod tests { fn test_increment_overflow_protection() { // Test that we panic gracefully if we somehow hit u32::MAX // This is extremely unlikely in practice (would require 128TB of data) - let mut eull = ExpUnrolledLinkedList::default(); - eull.block_num = u32::MAX; + let mut eull = ExpUnrolledLinkedList { + block_num: u32::MAX, + ..Default::default() + }; // This should panic with our custom error message eull.increment_num_blocks(); From 4987495ee486c757a381fccc42df5933f338a10e Mon Sep 17 00:00:00 2001 From: Stu Hood Date: Fri, 2 Jan 2026 02:28:47 -0700 Subject: [PATCH 16/26] Add an erased `SortKeyComputer` to sort on types which are not known until runtime (#2770) * Remove PartialOrd bound on compared values. * Fix declared `SortKey` type of `impl<..> SortKeyComputer for (HeadSortKeyComputer, TailSortKeyComputer)` * Add a SortByOwnedValue implementation to provide a type-erased column. * Add support for comparing mismatched `OwnedValue` types. * Support JSON columns. * Refer to https://github.com/quickwit-oss/tantivy/issues/2776 * Rename to `SortByErasedType`. * Comment on transitivity. Co-authored-by: Paul Masurel * Fix clippy warnings in new code. --------- Co-authored-by: Paul Masurel --- src/collector/sort_key/mod.rs | 61 ++- src/collector/sort_key/order.rs | 113 +++++- src/collector/sort_key/sort_by_erased_type.rs | 361 ++++++++++++++++++ src/collector/sort_key/sort_by_score.rs | 2 +- .../sort_key/sort_by_static_fast_value.rs | 4 +- src/collector/sort_key/sort_by_string.rs | 6 +- src/collector/sort_key/sort_key_computer.rs | 32 +- src/collector/top_score_collector.rs | 5 +- src/schema/document/owned_value.rs | 25 ++ 9 files changed, 581 insertions(+), 28 deletions(-) create mode 100644 src/collector/sort_key/sort_by_erased_type.rs diff --git a/src/collector/sort_key/mod.rs b/src/collector/sort_key/mod.rs index 3bfb3b1c8..391873298 100644 --- a/src/collector/sort_key/mod.rs +++ b/src/collector/sort_key/mod.rs @@ -1,10 +1,12 @@ mod order; +mod sort_by_erased_type; mod sort_by_score; mod sort_by_static_fast_value; mod sort_by_string; mod sort_key_computer; pub use order::*; +pub use sort_by_erased_type::SortByErasedType; pub use sort_by_score::SortBySimilarityScore; pub use sort_by_static_fast_value::SortByStaticFastValue; pub use sort_by_string::SortByString; @@ -34,11 +36,13 @@ pub(crate) mod tests { use std::collections::HashMap; use std::ops::Range; - use crate::collector::sort_key::{SortBySimilarityScore, SortByStaticFastValue, SortByString}; + use crate::collector::sort_key::{ + SortByErasedType, SortBySimilarityScore, SortByStaticFastValue, SortByString, + }; use crate::collector::{ComparableDoc, DocSetCollector, TopDocs}; use crate::indexer::NoMergePolicy; use crate::query::{AllQuery, QueryParser}; - use crate::schema::{Schema, FAST, TEXT}; + use crate::schema::{OwnedValue, Schema, FAST, TEXT}; use crate::{DocAddress, Document, Index, Order, Score, Searcher}; fn make_index() -> crate::Result { @@ -313,11 +317,9 @@ pub(crate) mod tests { (SortBySimilarityScore, score_order), (SortByString::for_field("city"), city_order), )); - Ok(searcher - .search(&AllQuery, &top_collector)? - .into_iter() - .map(|(f, doc)| (f, ids[&doc])) - .collect()) + let results: Vec<((Score, Option), DocAddress)> = + searcher.search(&AllQuery, &top_collector)?; + Ok(results.into_iter().map(|(f, doc)| (f, ids[&doc])).collect()) } assert_eq!( @@ -342,6 +344,51 @@ pub(crate) mod tests { Ok(()) } + #[test] + fn test_order_by_score_then_owned_value() -> crate::Result<()> { + let index = make_index()?; + + type SortKey = (Score, OwnedValue); + + fn query( + index: &Index, + score_order: Order, + city_order: Order, + ) -> crate::Result> { + let searcher = index.reader()?.searcher(); + let ids = id_mapping(&searcher); + + let top_collector = TopDocs::with_limit(4).order_by::<(Score, OwnedValue)>(( + (SortBySimilarityScore, score_order), + (SortByErasedType::for_field("city"), city_order), + )); + let results: Vec<((Score, OwnedValue), DocAddress)> = + searcher.search(&AllQuery, &top_collector)?; + Ok(results.into_iter().map(|(f, doc)| (f, ids[&doc])).collect()) + } + + assert_eq!( + &query(&index, Order::Asc, Order::Asc)?, + &[ + ((1.0, OwnedValue::Str("austin".to_owned())), 0), + ((1.0, OwnedValue::Str("greenville".to_owned())), 1), + ((1.0, OwnedValue::Str("tokyo".to_owned())), 2), + ((1.0, OwnedValue::Null), 3), + ] + ); + + assert_eq!( + &query(&index, Order::Asc, Order::Desc)?, + &[ + ((1.0, OwnedValue::Str("tokyo".to_owned())), 2), + ((1.0, OwnedValue::Str("greenville".to_owned())), 1), + ((1.0, OwnedValue::Str("austin".to_owned())), 0), + ((1.0, OwnedValue::Null), 3), + ] + ); + Ok(()) + } + use proptest::prelude::*; proptest! { diff --git a/src/collector/sort_key/order.rs b/src/collector/sort_key/order.rs index e89154c96..c2f346901 100644 --- a/src/collector/sort_key/order.rs +++ b/src/collector/sort_key/order.rs @@ -1,11 +1,70 @@ use std::cmp::Ordering; +use columnar::MonotonicallyMappableToU64; use serde::{Deserialize, Serialize}; use crate::collector::{SegmentSortKeyComputer, SortKeyComputer}; -use crate::schema::Schema; +use crate::schema::{OwnedValue, Schema}; use crate::{DocId, Order, Score}; +fn compare_owned_value(lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering { + match (lhs, rhs) { + (OwnedValue::Null, OwnedValue::Null) => Ordering::Equal, + (OwnedValue::Null, _) => { + if NULLS_FIRST { + Ordering::Less + } else { + Ordering::Greater + } + } + (_, OwnedValue::Null) => { + if NULLS_FIRST { + Ordering::Greater + } else { + Ordering::Less + } + } + (OwnedValue::Str(a), OwnedValue::Str(b)) => a.cmp(b), + (OwnedValue::PreTokStr(a), OwnedValue::PreTokStr(b)) => a.cmp(b), + (OwnedValue::U64(a), OwnedValue::U64(b)) => a.cmp(b), + (OwnedValue::I64(a), OwnedValue::I64(b)) => a.cmp(b), + (OwnedValue::F64(a), OwnedValue::F64(b)) => a.to_u64().cmp(&b.to_u64()), + (OwnedValue::Bool(a), OwnedValue::Bool(b)) => a.cmp(b), + (OwnedValue::Date(a), OwnedValue::Date(b)) => a.cmp(b), + (OwnedValue::Facet(a), OwnedValue::Facet(b)) => a.cmp(b), + (OwnedValue::Bytes(a), OwnedValue::Bytes(b)) => a.cmp(b), + (OwnedValue::IpAddr(a), OwnedValue::IpAddr(b)) => a.cmp(b), + (OwnedValue::U64(a), OwnedValue::I64(b)) => { + if *b < 0 { + Ordering::Greater + } else { + a.cmp(&(*b as u64)) + } + } + (OwnedValue::I64(a), OwnedValue::U64(b)) => { + if *a < 0 { + Ordering::Less + } else { + (*a as u64).cmp(b) + } + } + (OwnedValue::U64(a), OwnedValue::F64(b)) => (*a as f64).to_u64().cmp(&b.to_u64()), + (OwnedValue::F64(a), OwnedValue::U64(b)) => a.to_u64().cmp(&(*b as f64).to_u64()), + (OwnedValue::I64(a), OwnedValue::F64(b)) => (*a as f64).to_u64().cmp(&b.to_u64()), + (OwnedValue::F64(a), OwnedValue::I64(b)) => a.to_u64().cmp(&(*b as f64).to_u64()), + (a, b) => { + let ord = a.discriminant_value().cmp(&b.discriminant_value()); + // If the discriminant is equal, it's because a new type was added, but hasn't been + // included in this `match` statement. + assert!( + ord != Ordering::Equal, + "Unimplemented comparison for type of {a:?}, {b:?}" + ); + ord + } + } +} + /// Comparator trait defining the order in which documents should be ordered. pub trait Comparator: Send + Sync + std::fmt::Debug + Default { /// Return the order between two values. @@ -29,6 +88,17 @@ impl Comparator for NaturalComparator { } } +/// A (partial) implementation of comparison for OwnedValue. +/// +/// Intended for use within columns of homogenous types, and so will panic for OwnedValues with +/// mismatched types. The one exception is Null, for which we do define all comparisons. +impl Comparator for NaturalComparator { + #[inline(always)] + fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering { + compare_owned_value::(lhs, rhs) + } +} + /// Compare values in reverse (e.g. 2 < 1). /// /// When used with `TopDocs`, which reverses the order, this results in an @@ -121,6 +191,13 @@ impl Comparator for ReverseNoneIsLowerComparator { } } +impl Comparator for ReverseNoneIsLowerComparator { + #[inline(always)] + fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering { + compare_owned_value::(rhs, lhs) + } +} + /// Compare values naturally, but treating `None` as higher than `Some`. /// /// When used with `TopDocs`, which reverses the order, this results in a @@ -185,6 +262,13 @@ impl Comparator for NaturalNoneIsHigherComparator { } } +impl Comparator for NaturalNoneIsHigherComparator { + #[inline(always)] + fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering { + compare_owned_value::(lhs, rhs) + } +} + /// An enum representing the different sort orders. #[derive(Debug, Clone, Copy, Eq, PartialEq, Default)] pub enum ComparatorEnum { @@ -404,11 +488,12 @@ impl SegmentSortKeyComput for SegmentSortKeyComputerWithComparator where TSegmentSortKeyComputer: SegmentSortKeyComputer, - TSegmentSortKey: PartialOrd + Clone + 'static + Sync + Send, + TSegmentSortKey: Clone + 'static + Sync + Send, TComparator: Comparator + 'static + Sync + Send, { type SortKey = TSegmentSortKeyComputer::SortKey; type SegmentSortKey = TSegmentSortKey; + type SegmentComparator = TComparator; fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey { self.segment_sort_key_computer.segment_sort_key(doc, score) @@ -432,6 +517,7 @@ where #[cfg(test)] mod tests { use super::*; + use crate::schema::OwnedValue; #[test] fn test_natural_none_is_higher() { @@ -455,4 +541,27 @@ mod tests { // compare(None, None) should be Equal. assert_eq!(comp.compare(&null, &null), Ordering::Equal); } + + #[test] + fn test_mixed_ownedvalue_compare() { + let u = OwnedValue::U64(10); + let i = OwnedValue::I64(10); + let f = OwnedValue::F64(10.0); + + let nc = NaturalComparator; + assert_eq!(nc.compare(&u, &i), Ordering::Equal); + assert_eq!(nc.compare(&u, &f), Ordering::Equal); + assert_eq!(nc.compare(&i, &f), Ordering::Equal); + + let u2 = OwnedValue::U64(11); + assert_eq!(nc.compare(&u2, &f), Ordering::Greater); + + let s = OwnedValue::Str("a".to_string()); + // Str < U64 + assert_eq!(nc.compare(&s, &u), Ordering::Less); + // Str < I64 + assert_eq!(nc.compare(&s, &i), Ordering::Less); + // Str < F64 + assert_eq!(nc.compare(&s, &f), Ordering::Less); + } } diff --git a/src/collector/sort_key/sort_by_erased_type.rs b/src/collector/sort_key/sort_by_erased_type.rs new file mode 100644 index 000000000..d15dd130c --- /dev/null +++ b/src/collector/sort_key/sort_by_erased_type.rs @@ -0,0 +1,361 @@ +use columnar::{ColumnType, MonotonicallyMappableToU64}; + +use crate::collector::sort_key::{ + NaturalComparator, SortBySimilarityScore, SortByStaticFastValue, SortByString, +}; +use crate::collector::{SegmentSortKeyComputer, SortKeyComputer}; +use crate::fastfield::FastFieldNotAvailableError; +use crate::schema::OwnedValue; +use crate::{DateTime, DocId, Score}; + +/// Sort by the boxed / OwnedValue representation of either a fast field, or of the score. +/// +/// Using the OwnedValue representation allows for type erasure, and can be useful when sort orders +/// are not known until runtime. But it comes with a performance cost: wherever possible, prefer to +/// use a SortKeyComputer implementation with a known-type at compile time. +#[derive(Debug, Clone)] +pub enum SortByErasedType { + /// Sort by a fast field + Field(String), + /// Sort by score + Score, +} + +impl SortByErasedType { + /// Creates a new sort key computer which will sort by the given fast field column, with type + /// erasure. + pub fn for_field(column_name: impl ToString) -> Self { + Self::Field(column_name.to_string()) + } + + /// Creates a new sort key computer which will sort by score, with type erasure. + pub fn for_score() -> Self { + Self::Score + } +} + +trait ErasedSegmentSortKeyComputer: Send + Sync { + fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option; + fn convert_segment_sort_key(&self, sort_key: Option) -> OwnedValue; +} + +struct ErasedSegmentSortKeyComputerWrapper { + inner: C, + converter: F, +} + +impl ErasedSegmentSortKeyComputer for ErasedSegmentSortKeyComputerWrapper +where + C: SegmentSortKeyComputer> + Send + Sync, + F: Fn(C::SortKey) -> OwnedValue + Send + Sync + 'static, +{ + fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option { + self.inner.segment_sort_key(doc, score) + } + + fn convert_segment_sort_key(&self, sort_key: Option) -> OwnedValue { + let val = self.inner.convert_segment_sort_key(sort_key); + (self.converter)(val) + } +} + +struct ScoreSegmentSortKeyComputer { + segment_computer: SortBySimilarityScore, +} + +impl ErasedSegmentSortKeyComputer for ScoreSegmentSortKeyComputer { + fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option { + let score_value: f64 = self.segment_computer.segment_sort_key(doc, score).into(); + Some(score_value.to_u64()) + } + + fn convert_segment_sort_key(&self, sort_key: Option) -> OwnedValue { + let score_value: u64 = sort_key.expect("This implementation always produces a score."); + OwnedValue::F64(f64::from_u64(score_value)) + } +} + +impl SortKeyComputer for SortByErasedType { + type SortKey = OwnedValue; + type Child = ErasedColumnSegmentSortKeyComputer; + type Comparator = NaturalComparator; + + fn requires_scoring(&self) -> bool { + matches!(self, Self::Score) + } + + fn segment_sort_key_computer( + &self, + segment_reader: &crate::SegmentReader, + ) -> crate::Result { + let inner: Box = match self { + Self::Field(column_name) => { + let fast_fields = segment_reader.fast_fields(); + // TODO: We currently double-open the column to avoid relying on the implementation + // details of `SortByString` or `SortByStaticFastValue`. Once + // https://github.com/quickwit-oss/tantivy/issues/2776 is resolved, we should + // consider directly constructing the appropriate `SegmentSortKeyComputer` type for + // the column that we open here. + let (_column, column_type) = + fast_fields.u64_lenient(column_name)?.ok_or_else(|| { + FastFieldNotAvailableError { + field_name: column_name.to_owned(), + } + })?; + + match column_type { + ColumnType::Str => { + let computer = SortByString::for_field(column_name); + let inner = computer.segment_sort_key_computer(segment_reader)?; + Box::new(ErasedSegmentSortKeyComputerWrapper { + inner, + converter: |val: Option| { + val.map(OwnedValue::Str).unwrap_or(OwnedValue::Null) + }, + }) + } + ColumnType::U64 => { + let computer = SortByStaticFastValue::::for_field(column_name); + let inner = computer.segment_sort_key_computer(segment_reader)?; + Box::new(ErasedSegmentSortKeyComputerWrapper { + inner, + converter: |val: Option| { + val.map(OwnedValue::U64).unwrap_or(OwnedValue::Null) + }, + }) + } + ColumnType::I64 => { + let computer = SortByStaticFastValue::::for_field(column_name); + let inner = computer.segment_sort_key_computer(segment_reader)?; + Box::new(ErasedSegmentSortKeyComputerWrapper { + inner, + converter: |val: Option| { + val.map(OwnedValue::I64).unwrap_or(OwnedValue::Null) + }, + }) + } + ColumnType::F64 => { + let computer = SortByStaticFastValue::::for_field(column_name); + let inner = computer.segment_sort_key_computer(segment_reader)?; + Box::new(ErasedSegmentSortKeyComputerWrapper { + inner, + converter: |val: Option| { + val.map(OwnedValue::F64).unwrap_or(OwnedValue::Null) + }, + }) + } + ColumnType::Bool => { + let computer = SortByStaticFastValue::::for_field(column_name); + let inner = computer.segment_sort_key_computer(segment_reader)?; + Box::new(ErasedSegmentSortKeyComputerWrapper { + inner, + converter: |val: Option| { + val.map(OwnedValue::Bool).unwrap_or(OwnedValue::Null) + }, + }) + } + ColumnType::DateTime => { + let computer = SortByStaticFastValue::::for_field(column_name); + let inner = computer.segment_sort_key_computer(segment_reader)?; + Box::new(ErasedSegmentSortKeyComputerWrapper { + inner, + converter: |val: Option| { + val.map(OwnedValue::Date).unwrap_or(OwnedValue::Null) + }, + }) + } + column_type => { + return Err(crate::TantivyError::SchemaError(format!( + "Field `{}` is of type {column_type:?}, which is not supported for \ + sorting by owned value yet.", + column_name + ))) + } + } + } + Self::Score => Box::new(ScoreSegmentSortKeyComputer { + segment_computer: SortBySimilarityScore, + }), + }; + Ok(ErasedColumnSegmentSortKeyComputer { inner }) + } +} + +pub struct ErasedColumnSegmentSortKeyComputer { + inner: Box, +} + +impl SegmentSortKeyComputer for ErasedColumnSegmentSortKeyComputer { + type SortKey = OwnedValue; + type SegmentSortKey = Option; + type SegmentComparator = NaturalComparator; + + #[inline(always)] + fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option { + self.inner.segment_sort_key(doc, score) + } + + fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> OwnedValue { + self.inner.convert_segment_sort_key(segment_sort_key) + } +} + +#[cfg(test)] +mod tests { + use crate::collector::sort_key::{ComparatorEnum, SortByErasedType}; + use crate::collector::TopDocs; + use crate::query::AllQuery; + use crate::schema::{OwnedValue, Schema, FAST, TEXT}; + use crate::Index; + + #[test] + fn test_sort_by_owned_u64() { + let mut schema_builder = Schema::builder(); + let id_field = schema_builder.add_u64_field("id", FAST); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut writer = index.writer_for_tests().unwrap(); + writer.add_document(doc!(id_field => 10u64)).unwrap(); + writer.add_document(doc!(id_field => 2u64)).unwrap(); + writer.add_document(doc!()).unwrap(); + writer.commit().unwrap(); + + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + + let collector = TopDocs::with_limit(10) + .order_by((SortByErasedType::for_field("id"), ComparatorEnum::Natural)); + let top_docs = searcher.search(&AllQuery, &collector).unwrap(); + + let values: Vec = top_docs.into_iter().map(|(key, _)| key).collect(); + + assert_eq!( + values, + vec![OwnedValue::U64(10), OwnedValue::U64(2), OwnedValue::Null] + ); + + let collector = TopDocs::with_limit(10).order_by(( + SortByErasedType::for_field("id"), + ComparatorEnum::ReverseNoneLower, + )); + let top_docs = searcher.search(&AllQuery, &collector).unwrap(); + + let values: Vec = top_docs.into_iter().map(|(key, _)| key).collect(); + + assert_eq!( + values, + vec![OwnedValue::U64(2), OwnedValue::U64(10), OwnedValue::Null] + ); + } + + #[test] + fn test_sort_by_owned_string() { + let mut schema_builder = Schema::builder(); + let city_field = schema_builder.add_text_field("city", FAST | TEXT); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut writer = index.writer_for_tests().unwrap(); + writer.add_document(doc!(city_field => "tokyo")).unwrap(); + writer.add_document(doc!(city_field => "austin")).unwrap(); + writer.add_document(doc!()).unwrap(); + writer.commit().unwrap(); + + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + + let collector = TopDocs::with_limit(10).order_by(( + SortByErasedType::for_field("city"), + ComparatorEnum::ReverseNoneLower, + )); + let top_docs = searcher.search(&AllQuery, &collector).unwrap(); + + let values: Vec = top_docs.into_iter().map(|(key, _)| key).collect(); + + assert_eq!( + values, + vec![ + OwnedValue::Str("austin".to_string()), + OwnedValue::Str("tokyo".to_string()), + OwnedValue::Null + ] + ); + } + + #[test] + fn test_sort_by_owned_reverse() { + let mut schema_builder = Schema::builder(); + let id_field = schema_builder.add_u64_field("id", FAST); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut writer = index.writer_for_tests().unwrap(); + writer.add_document(doc!(id_field => 10u64)).unwrap(); + writer.add_document(doc!(id_field => 2u64)).unwrap(); + writer.add_document(doc!()).unwrap(); + writer.commit().unwrap(); + + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + + let collector = TopDocs::with_limit(10) + .order_by((SortByErasedType::for_field("id"), ComparatorEnum::Reverse)); + let top_docs = searcher.search(&AllQuery, &collector).unwrap(); + + let values: Vec = top_docs.into_iter().map(|(key, _)| key).collect(); + + assert_eq!( + values, + vec![OwnedValue::Null, OwnedValue::U64(2), OwnedValue::U64(10)] + ); + } + + #[test] + fn test_sort_by_owned_score() { + let mut schema_builder = Schema::builder(); + let body_field = schema_builder.add_text_field("body", TEXT); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut writer = index.writer_for_tests().unwrap(); + writer.add_document(doc!(body_field => "a a")).unwrap(); + writer.add_document(doc!(body_field => "a")).unwrap(); + writer.commit().unwrap(); + + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + let query_parser = crate::query::QueryParser::for_index(&index, vec![body_field]); + let query = query_parser.parse_query("a").unwrap(); + + // Sort by score descending (Natural) + let collector = TopDocs::with_limit(10) + .order_by((SortByErasedType::for_score(), ComparatorEnum::Natural)); + let top_docs = searcher.search(&query, &collector).unwrap(); + + let values: Vec = top_docs + .into_iter() + .map(|(key, _)| match key { + OwnedValue::F64(val) => val, + _ => panic!("Wrong type {key:?}"), + }) + .collect(); + + assert_eq!(values.len(), 2); + assert!(values[0] > values[1]); + + // Sort by score ascending (ReverseNoneLower) + let collector = TopDocs::with_limit(10).order_by(( + SortByErasedType::for_score(), + ComparatorEnum::ReverseNoneLower, + )); + let top_docs = searcher.search(&query, &collector).unwrap(); + + let values: Vec = top_docs + .into_iter() + .map(|(key, _)| match key { + OwnedValue::F64(val) => val, + _ => panic!("Wrong type {key:?}"), + }) + .collect(); + + assert_eq!(values.len(), 2); + assert!(values[0] < values[1]); + } +} diff --git a/src/collector/sort_key/sort_by_score.rs b/src/collector/sort_key/sort_by_score.rs index df8b0dd75..a23660e56 100644 --- a/src/collector/sort_key/sort_by_score.rs +++ b/src/collector/sort_key/sort_by_score.rs @@ -63,8 +63,8 @@ impl SortKeyComputer for SortBySimilarityScore { impl SegmentSortKeyComputer for SortBySimilarityScore { type SortKey = Score; - type SegmentSortKey = Score; + type SegmentComparator = NaturalComparator; #[inline(always)] fn segment_sort_key(&mut self, _doc: DocId, score: Score) -> Score { diff --git a/src/collector/sort_key/sort_by_static_fast_value.rs b/src/collector/sort_key/sort_by_static_fast_value.rs index b38b8b034..44a4e1d8d 100644 --- a/src/collector/sort_key/sort_by_static_fast_value.rs +++ b/src/collector/sort_key/sort_by_static_fast_value.rs @@ -34,9 +34,7 @@ impl SortByStaticFastValue { impl SortKeyComputer for SortByStaticFastValue { type Child = SortByFastValueSegmentSortKeyComputer; - type SortKey = Option; - type Comparator = NaturalComparator; fn check_schema(&self, schema: &crate::schema::Schema) -> crate::Result<()> { @@ -84,8 +82,8 @@ pub struct SortByFastValueSegmentSortKeyComputer { impl SegmentSortKeyComputer for SortByFastValueSegmentSortKeyComputer { type SortKey = Option; - type SegmentSortKey = Option; + type SegmentComparator = NaturalComparator; #[inline(always)] fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Self::SegmentSortKey { diff --git a/src/collector/sort_key/sort_by_string.rs b/src/collector/sort_key/sort_by_string.rs index 41ef22e9b..2dd0b4592 100644 --- a/src/collector/sort_key/sort_by_string.rs +++ b/src/collector/sort_key/sort_by_string.rs @@ -30,9 +30,7 @@ impl SortByString { impl SortKeyComputer for SortByString { type SortKey = Option; - type Child = ByStringColumnSegmentSortKeyComputer; - type Comparator = NaturalComparator; fn segment_sort_key_computer( @@ -50,8 +48,8 @@ pub struct ByStringColumnSegmentSortKeyComputer { impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer { type SortKey = Option; - type SegmentSortKey = Option; + type SegmentComparator = NaturalComparator; #[inline(always)] fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Option { @@ -60,6 +58,8 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer { } fn convert_segment_sort_key(&self, term_ord_opt: Option) -> Option { + // TODO: Individual lookups to the dictionary like this are very likely to repeatedly + // decompress the same blocks. See https://github.com/quickwit-oss/tantivy/issues/2776 let term_ord = term_ord_opt?; let str_column = self.str_column_opt.as_ref()?; let mut bytes = Vec::new(); diff --git a/src/collector/sort_key/sort_key_computer.rs b/src/collector/sort_key/sort_key_computer.rs index d56fa7cd0..6aab919a9 100644 --- a/src/collector/sort_key/sort_key_computer.rs +++ b/src/collector/sort_key/sort_key_computer.rs @@ -12,13 +12,21 @@ use crate::{DocAddress, DocId, Result, Score, SegmentReader}; /// It is the segment local version of the [`SortKeyComputer`]. pub trait SegmentSortKeyComputer: 'static { /// The final score being emitted. - type SortKey: 'static + PartialOrd + Send + Sync + Clone; + type SortKey: 'static + Send + Sync + Clone; /// Sort key used by at the segment level by the `SegmentSortKeyComputer`. /// /// It is typically small like a `u64`, and is meant to be converted /// to the final score at the end of the collection of the segment. - type SegmentSortKey: 'static + PartialOrd + Clone + Send + Sync + Clone; + type SegmentSortKey: 'static + Clone + Send + Sync + Clone; + + /// Comparator type. + type SegmentComparator: Comparator + 'static; + + /// Returns the segment sort key comparator. + fn segment_comparator(&self) -> Self::SegmentComparator { + Self::SegmentComparator::default() + } /// Computes the sort key for the given document and score. fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey; @@ -47,7 +55,7 @@ pub trait SegmentSortKeyComputer: 'static { left: &Self::SegmentSortKey, right: &Self::SegmentSortKey, ) -> Ordering { - NaturalComparator.compare(left, right) + self.segment_comparator().compare(left, right) } /// Implementing this method makes it possible to avoid computing @@ -81,7 +89,7 @@ pub trait SegmentSortKeyComputer: 'static { /// the sort key at a segment scale. pub trait SortKeyComputer: Sync { /// The sort key type. - type SortKey: 'static + Send + Sync + PartialOrd + Clone + std::fmt::Debug; + type SortKey: 'static + Send + Sync + Clone + std::fmt::Debug; /// Type of the associated [`SegmentSortKeyComputer`]. type Child: SegmentSortKeyComputer; /// Comparator type. @@ -136,10 +144,7 @@ where HeadSortKeyComputer: SortKeyComputer, TailSortKeyComputer: SortKeyComputer, { - type SortKey = ( - ::SortKey, - ::SortKey, - ); + type SortKey = (HeadSortKeyComputer::SortKey, TailSortKeyComputer::SortKey); type Child = (HeadSortKeyComputer::Child, TailSortKeyComputer::Child); type Comparator = ( @@ -188,6 +193,11 @@ where TailSegmentSortKeyComputer::SegmentSortKey, ); + type SegmentComparator = ( + HeadSegmentSortKeyComputer::SegmentComparator, + TailSegmentSortKeyComputer::SegmentComparator, + ); + /// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on /// its ordering. /// @@ -269,11 +279,12 @@ impl SegmentSortKeyComputer for MappedSegmentSortKeyComputer where T: SegmentSortKeyComputer, - PreviousScore: 'static + Clone + Send + Sync + PartialOrd, - NewScore: 'static + Clone + Send + Sync + PartialOrd, + PreviousScore: 'static + Clone + Send + Sync, + NewScore: 'static + Clone + Send + Sync, { type SortKey = NewScore; type SegmentSortKey = T::SegmentSortKey; + type SegmentComparator = T::SegmentComparator; fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey { self.sort_key_computer.segment_sort_key(doc, score) @@ -463,6 +474,7 @@ where { type SortKey = TSortKey; type SegmentSortKey = TSortKey; + type SegmentComparator = NaturalComparator; fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> TSortKey { (self)(doc) diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 3c3f1beb9..0ce1c611a 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -324,7 +324,7 @@ impl TopDocs { sort_key_computer: impl SortKeyComputer + Send + 'static, ) -> impl Collector> where - TSortKey: 'static + Clone + Send + Sync + PartialOrd + std::fmt::Debug, + TSortKey: 'static + Clone + Send + Sync + std::fmt::Debug, { TopBySortKeyCollector::new(sort_key_computer, self.doc_range()) } @@ -445,7 +445,7 @@ where F: 'static + Send + Sync + Fn(&SegmentReader) -> TTweakScoreSortKeyFn, TTweakScoreSortKeyFn: 'static + Fn(DocId, Score) -> TSortKey, TweakScoreSegmentSortKeyComputer: - SegmentSortKeyComputer, + SegmentSortKeyComputer, TSortKey: 'static + PartialOrd + Clone + Send + Sync + std::fmt::Debug, { type SortKey = TSortKey; @@ -480,6 +480,7 @@ where { type SortKey = TSortKey; type SegmentSortKey = TSortKey; + type SegmentComparator = NaturalComparator; fn segment_sort_key(&mut self, doc: DocId, score: Score) -> TSortKey { (self.sort_key_fn)(doc, score) diff --git a/src/schema/document/owned_value.rs b/src/schema/document/owned_value.rs index 9fbf1f8c2..49a6b1ac7 100644 --- a/src/schema/document/owned_value.rs +++ b/src/schema/document/owned_value.rs @@ -58,6 +58,31 @@ impl AsRef for OwnedValue { } } +impl OwnedValue { + /// Returns a u8 discriminant value for the `OwnedValue` variant. + /// + /// This can be used to sort `OwnedValue` instances by their type. + pub fn discriminant_value(&self) -> u8 { + match self { + OwnedValue::Null => 0, + OwnedValue::Str(_) => 1, + OwnedValue::PreTokStr(_) => 2, + // It is key to make sure U64, I64, F64 are grouped together in there, otherwise we + // might be breaking transivity. + OwnedValue::U64(_) => 3, + OwnedValue::I64(_) => 4, + OwnedValue::F64(_) => 5, + OwnedValue::Bool(_) => 6, + OwnedValue::Date(_) => 7, + OwnedValue::Facet(_) => 8, + OwnedValue::Bytes(_) => 9, + OwnedValue::Array(_) => 10, + OwnedValue::Object(_) => 11, + OwnedValue::IpAddr(_) => 12, + } + } +} + impl<'a> Value<'a> for &'a OwnedValue { type ArrayIter = std::slice::Iter<'a, OwnedValue>; type ObjectIter = ObjectMapIter<'a>; From 6443b631773739155c6c3737a22cec740134b8f1 Mon Sep 17 00:00:00 2001 From: trinity-1686a Date: Fri, 2 Jan 2026 10:32:37 +0100 Subject: [PATCH 17/26] document 1bit hole and some queries supporting running with just fastfield (#2779) * add small doc on some queries using fast field when not indexed * document 1 unused bit in skiplist --- src/postings/skip.rs | 12 ++++++++---- src/schema/mod.rs | 4 ++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/postings/skip.rs b/src/postings/skip.rs index c36690444..dd762ca46 100644 --- a/src/postings/skip.rs +++ b/src/postings/skip.rs @@ -6,17 +6,21 @@ use crate::{DocId, Score, TERMINATED}; // doc num bits uses the following encoding: // given 0b a b cdefgh -// |1|2| 3 | +// |1|2|3| 4 | // - 1: unused // - 2: is delta-1 encoded. 0 if not, 1, if yes -// - 3: a 6 bit number in 0..=32, the actual bitwidth +// - 3: unused +// - 4: a 5 bit number in 0..32, the actual bitwidth. Bitpacking could in theory say this is 32 +// (requiring a 6th bit), but the biggest doc_id we can want to encode is TERMINATED-1, which can +// be represented on 31b without delta encoding. fn encode_bitwidth(bitwidth: u8, delta_1: bool) -> u8 { + assert!(bitwidth < 32); bitwidth | ((delta_1 as u8) << 6) } fn decode_bitwidth(raw_bitwidth: u8) -> (u8, bool) { let delta_1 = ((raw_bitwidth >> 6) & 1) != 0; - let bitwidth = raw_bitwidth & 0x3f; + let bitwidth = raw_bitwidth & 0x1f; (bitwidth, delta_1) } @@ -430,7 +434,7 @@ mod tests { #[test] fn test_encode_decode_bitwidth() { - for bitwidth in 0..=32 { + for bitwidth in 0..32 { for delta_1 in [false, true] { assert_eq!( (bitwidth, delta_1), diff --git a/src/schema/mod.rs b/src/schema/mod.rs index 1cd4b7243..c8af359d9 100644 --- a/src/schema/mod.rs +++ b/src/schema/mod.rs @@ -98,6 +98,10 @@ //! make it possible to access the value given the doc id rapidly. This is useful if the value //! of the field is required during scoring or collection for instance. //! +//! Some queries may leverage Fast fields when run on a field that is not indexed. This can be +//! handy if that kind of request is infrequent, however note that searching on a Fast field is +//! generally much slower than searching in an index. +//! //! ``` //! use tantivy::schema::*; //! let mut schema_builder = Schema::builder(); From 242a1531bf5eefa66dcc18d3c1f42781c6f2e6ff Mon Sep 17 00:00:00 2001 From: PSeitz Date: Fri, 2 Jan 2026 18:30:51 +0800 Subject: [PATCH 18/26] fix flaky test (#2784) Signed-off-by: Pascal Seitz --- src/indexer/delete_queue.rs | 27 ++++++++++++++------------- src/indexer/mod.rs | 1 + src/lib.rs | 8 ++++++-- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/src/indexer/delete_queue.rs b/src/indexer/delete_queue.rs index 7cc485e1d..1a269caed 100644 --- a/src/indexer/delete_queue.rs +++ b/src/indexer/delete_queue.rs @@ -4,19 +4,20 @@ use std::sync::{Arc, RwLock, Weak}; use super::operation::DeleteOperation; use crate::Opstamp; -// The DeleteQueue is similar in conceptually to a multiple -// consumer single producer broadcast channel. -// -// All consumer will receive all messages. -// -// Consumer of the delete queue are holding a `DeleteCursor`, -// which points to a specific place of the `DeleteQueue`. -// -// New consumer can be created in two ways -// - calling `delete_queue.cursor()` returns a cursor, that will include all future delete operation -// (and some or none of the past operations... The client is in charge of checking the opstamps.). -// - cloning an existing cursor returns a new cursor, that is at the exact same position, and can -// now advance independently from the original cursor. +/// The DeleteQueue is similar in conceptually to a multiple +/// consumer single producer broadcast channel. +/// +/// All consumer will receive all messages. +/// +/// Consumer of the delete queue are holding a `DeleteCursor`, +/// which points to a specific place of the `DeleteQueue`. +/// +/// New consumer can be created in two ways +/// - calling `delete_queue.cursor()` returns a cursor, that will include all future delete +/// operation (and some or none of the past operations... The client is in charge of checking the +/// opstamps.). +/// - cloning an existing cursor returns a new cursor, that is at the exact same position, and can +/// now advance independently from the original cursor. #[derive(Default)] struct InnerDeleteQueue { writer: Vec, diff --git a/src/indexer/mod.rs b/src/indexer/mod.rs index d96344b60..53cc57034 100644 --- a/src/indexer/mod.rs +++ b/src/indexer/mod.rs @@ -4,6 +4,7 @@ //! `IndexWriter` is the main entry point for that, which created from //! [`Index::writer`](crate::Index::writer). +/// Delete queue implementation for broadcasting delete operations to consumers. pub(crate) mod delete_queue; pub(crate) mod path_to_unordered_id; diff --git a/src/lib.rs b/src/lib.rs index 22eab343a..7890d6188 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,6 +17,7 @@ //! //! ```rust //! # use std::path::Path; +//! # use std::fs; //! # use tempfile::TempDir; //! # use tantivy::collector::TopDocs; //! # use tantivy::query::QueryParser; @@ -27,8 +28,11 @@ //! # // Let's create a temporary directory for the //! # // sake of this example //! # if let Ok(dir) = TempDir::new() { -//! # run_example(dir.path()).unwrap(); -//! # dir.close().unwrap(); +//! # let index_path = dir.path().join("index"); +//! # // In case the directory already exists, we remove it +//! # let _ = fs::remove_dir_all(&index_path); +//! # fs::create_dir_all(&index_path).unwrap(); +//! # run_example(&index_path).unwrap(); //! # } //! # } //! # From 735c588f4f89b67eab6f05285cd93f8e14816574 Mon Sep 17 00:00:00 2001 From: PSeitz Date: Fri, 2 Jan 2026 19:06:51 +0800 Subject: [PATCH 19/26] fix union performance regression (#2790) * add inlines * fix union performance regression Remove unwrap from hotpath generates better assembly. closes #2788 --- src/collector/sort_key/order.rs | 2 +- src/postings/mod.rs | 1 + src/query/all_query.rs | 1 + src/query/boost_query.rs | 1 + src/query/const_score_query.rs | 1 + src/query/disjunction.rs | 2 ++ src/query/empty_query.rs | 1 + src/query/exclude.rs | 1 + src/query/intersection.rs | 3 +++ src/query/phrase_prefix_query/phrase_prefix_scorer.rs | 2 ++ src/query/phrase_query/phrase_scorer.rs | 1 + src/query/reqopt_scorer.rs | 1 + src/query/score_combiner.rs | 3 +++ src/query/scorer.rs | 1 + src/query/term_query/term_scorer.rs | 1 + src/query/union/buffered_union.rs | 4 ++++ 16 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/collector/sort_key/order.rs b/src/collector/sort_key/order.rs index c2f346901..3cac357ad 100644 --- a/src/collector/sort_key/order.rs +++ b/src/collector/sort_key/order.rs @@ -84,7 +84,7 @@ pub struct NaturalComparator; impl Comparator for NaturalComparator { #[inline(always)] fn compare(&self, lhs: &T, rhs: &T) -> Ordering { - lhs.partial_cmp(rhs).unwrap() + lhs.partial_cmp(rhs).unwrap_or(Ordering::Equal) } } diff --git a/src/postings/mod.rs b/src/postings/mod.rs index efc0e069d..b9c400859 100644 --- a/src/postings/mod.rs +++ b/src/postings/mod.rs @@ -527,6 +527,7 @@ pub(crate) mod tests { } impl Scorer for UnoptimizedDocSet { + #[inline] fn score(&mut self) -> Score { self.0.score() } diff --git a/src/query/all_query.rs b/src/query/all_query.rs index 612d7408c..5431a3a1b 100644 --- a/src/query/all_query.rs +++ b/src/query/all_query.rs @@ -105,6 +105,7 @@ impl DocSet for AllScorer { } impl Scorer for AllScorer { + #[inline] fn score(&mut self) -> Score { 1.0 } diff --git a/src/query/boost_query.rs b/src/query/boost_query.rs index ecbf3d8d6..cc4c10f7a 100644 --- a/src/query/boost_query.rs +++ b/src/query/boost_query.rs @@ -134,6 +134,7 @@ impl DocSet for BoostScorer { } impl Scorer for BoostScorer { + #[inline] fn score(&mut self) -> Score { self.underlying.score() * self.boost } diff --git a/src/query/const_score_query.rs b/src/query/const_score_query.rs index 570c7feca..d07e6a96f 100644 --- a/src/query/const_score_query.rs +++ b/src/query/const_score_query.rs @@ -137,6 +137,7 @@ impl DocSet for ConstScorer { } impl Scorer for ConstScorer { + #[inline] fn score(&mut self) -> Score { self.score } diff --git a/src/query/disjunction.rs b/src/query/disjunction.rs index b2f1080fc..ca7eab20d 100644 --- a/src/query/disjunction.rs +++ b/src/query/disjunction.rs @@ -173,6 +173,7 @@ impl DocSet impl Scorer for Disjunction { + #[inline] fn score(&mut self) -> Score { self.current_score } @@ -307,6 +308,7 @@ mod tests { } impl Scorer for DummyScorer { + #[inline] fn score(&mut self) -> Score { self.foo.get(self.cursor).map(|x| x.1).unwrap_or(0.0) } diff --git a/src/query/empty_query.rs b/src/query/empty_query.rs index 86ff84c08..2fa1772bd 100644 --- a/src/query/empty_query.rs +++ b/src/query/empty_query.rs @@ -55,6 +55,7 @@ impl DocSet for EmptyScorer { } impl Scorer for EmptyScorer { + #[inline] fn score(&mut self) -> Score { 0.0 } diff --git a/src/query/exclude.rs b/src/query/exclude.rs index 0b13e66e0..15e609c1e 100644 --- a/src/query/exclude.rs +++ b/src/query/exclude.rs @@ -84,6 +84,7 @@ where TScorer: Scorer, TDocSetExclude: DocSet + 'static, { + #[inline] fn score(&mut self) -> Score { self.underlying_docset.score() } diff --git a/src/query/intersection.rs b/src/query/intersection.rs index 3e8677d98..d536dcf05 100644 --- a/src/query/intersection.rs +++ b/src/query/intersection.rs @@ -105,6 +105,7 @@ impl Intersection { } impl DocSet for Intersection { + #[inline] fn advance(&mut self) -> DocId { let (left, right) = (&mut self.left, &mut self.right); let mut candidate = left.advance(); @@ -174,6 +175,7 @@ impl DocSet for Intersection DocId { self.left.doc() } @@ -200,6 +202,7 @@ where TScorer: Scorer, TOtherScorer: Scorer, { + #[inline] fn score(&mut self) -> Score { self.left.score() + self.right.score() diff --git a/src/query/phrase_prefix_query/phrase_prefix_scorer.rs b/src/query/phrase_prefix_query/phrase_prefix_scorer.rs index cc7bb7886..8b03089fa 100644 --- a/src/query/phrase_prefix_query/phrase_prefix_scorer.rs +++ b/src/query/phrase_prefix_query/phrase_prefix_scorer.rs @@ -81,6 +81,7 @@ impl DocSet for PhraseKind { } impl Scorer for PhraseKind { + #[inline] fn score(&mut self) -> Score { match self { PhraseKind::SinglePrefix { positions, .. } => { @@ -215,6 +216,7 @@ impl DocSet for PhrasePrefixScorer { } impl Scorer for PhrasePrefixScorer { + #[inline] fn score(&mut self) -> Score { // TODO modify score?? self.phrase_scorer.score() diff --git a/src/query/phrase_query/phrase_scorer.rs b/src/query/phrase_query/phrase_scorer.rs index 4f8541cd2..108783b40 100644 --- a/src/query/phrase_query/phrase_scorer.rs +++ b/src/query/phrase_query/phrase_scorer.rs @@ -563,6 +563,7 @@ impl DocSet for PhraseScorer { } impl Scorer for PhraseScorer { + #[inline] fn score(&mut self) -> Score { let doc = self.doc(); let fieldnorm_id = self.fieldnorm_reader.fieldnorm_id(doc); diff --git a/src/query/reqopt_scorer.rs b/src/query/reqopt_scorer.rs index 45857567c..bed99f5b7 100644 --- a/src/query/reqopt_scorer.rs +++ b/src/query/reqopt_scorer.rs @@ -81,6 +81,7 @@ where TOptScorer: Scorer, TScoreCombiner: ScoreCombiner, { + #[inline] fn score(&mut self) -> Score { if let Some(score) = self.score_cache { return score; diff --git a/src/query/score_combiner.rs b/src/query/score_combiner.rs index a49f8b104..2fe760c3d 100644 --- a/src/query/score_combiner.rs +++ b/src/query/score_combiner.rs @@ -29,6 +29,7 @@ impl ScoreCombiner for DoNothingCombiner { fn clear(&mut self) {} + #[inline] fn score(&self) -> Score { 1.0 } @@ -49,6 +50,7 @@ impl ScoreCombiner for SumCombiner { self.score = 0.0; } + #[inline] fn score(&self) -> Score { self.score } @@ -86,6 +88,7 @@ impl ScoreCombiner for DisjunctionMaxCombiner { self.sum = 0.0; } + #[inline] fn score(&self) -> Score { self.max + (self.sum - self.max) * self.tie_breaker } diff --git a/src/query/scorer.rs b/src/query/scorer.rs index 69448042f..e91fc2fbc 100644 --- a/src/query/scorer.rs +++ b/src/query/scorer.rs @@ -18,6 +18,7 @@ pub trait Scorer: downcast_rs::Downcast + DocSet + 'static { impl_downcast!(Scorer); impl Scorer for Box { + #[inline] fn score(&mut self) -> Score { self.deref_mut().score() } diff --git a/src/query/term_query/term_scorer.rs b/src/query/term_query/term_scorer.rs index 293aa7871..00fb8ca0b 100644 --- a/src/query/term_query/term_scorer.rs +++ b/src/query/term_query/term_scorer.rs @@ -119,6 +119,7 @@ impl DocSet for TermScorer { } impl Scorer for TermScorer { + #[inline] fn score(&mut self) -> Score { let fieldnorm_id = self.fieldnorm_id(); let term_freq = self.term_freq(); diff --git a/src/query/union/buffered_union.rs b/src/query/union/buffered_union.rs index 70299ad6f..ee554e357 100644 --- a/src/query/union/buffered_union.rs +++ b/src/query/union/buffered_union.rs @@ -128,6 +128,7 @@ impl BufferedUnionScorer bool { while self.bucket_idx < HORIZON_NUM_TINYBITSETS { if let Some(val) = self.bitsets[self.bucket_idx].pop_lowest() { @@ -156,6 +157,7 @@ where TScorer: Scorer, TScoreCombiner: ScoreCombiner, { + #[inline] fn advance(&mut self) -> DocId { if self.advance_buffered() { return self.doc; @@ -245,6 +247,7 @@ where } } + #[inline] fn doc(&self) -> DocId { self.doc } @@ -286,6 +289,7 @@ where TScoreCombiner: ScoreCombiner, TScorer: Scorer, { + #[inline] fn score(&mut self) -> Score { self.score } From 77505c3d03727baa7acb4fc7bd1ade645aed4141 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Fri, 2 Jan 2026 12:40:42 +0100 Subject: [PATCH 20/26] Making stemming optional. (#2791) Fixed code and CI to run on no default features. Co-authored-by: Paul Masurel --- .github/workflows/test.yml | 30 +++++++--- Cargo.toml | 5 +- .../{ => mmap_directory}/file_watcher.rs | 0 .../mod.rs} | 4 +- src/directory/mod.rs | 1 - src/index/index_meta.rs | 9 ++- src/indexer/segment_writer.rs | 8 +-- src/lib.rs | 6 +- src/snippet/mod.rs | 4 +- src/store/mod.rs | 3 +- src/store/reader.rs | 4 +- src/tokenizer/mod.rs | 57 ++----------------- src/tokenizer/stemmer.rs | 57 +++++++++++++++++++ src/tokenizer/tokenizer_manager.rs | 23 ++++---- 14 files changed, 123 insertions(+), 88 deletions(-) rename src/directory/{ => mmap_directory}/file_watcher.rs (100%) rename src/directory/{mmap_directory.rs => mmap_directory/mod.rs} (99%) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 13080f11d..3a6ba2df9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -39,11 +39,11 @@ jobs: - name: Check Formatting run: cargo +nightly fmt --all -- --check - + - name: Check Stable Compilation run: cargo build --all-features - + - name: Check Bench Compilation run: cargo +nightly bench --no-run --profile=dev --all-features @@ -59,10 +59,10 @@ jobs: strategy: matrix: - features: [ - { label: "all", flags: "mmap,stopwords,lz4-compression,zstd-compression,failpoints" }, - { label: "quickwit", flags: "mmap,quickwit,failpoints" } - ] + features: + - { label: "all", flags: "mmap,stopwords,lz4-compression,zstd-compression,failpoints,stemmer" } + - { label: "quickwit", flags: "mmap,quickwit,failpoints" } + - { label: "none", flags: "" } name: test-${{ matrix.features.label}} @@ -80,7 +80,21 @@ jobs: - uses: Swatinem/rust-cache@v2 - name: Run tests - run: cargo +stable nextest run --features ${{ matrix.features.flags }} --verbose --workspace + run: | + # if matrix.feature.flags is empty then run on --lib to avoid compiling examples + # (as most of them rely on mmap) otherwise run all + if [ -z "${{ matrix.features.flags }}" ]; then + cargo +stable nextest run --lib --no-default-features --verbose --workspace + else + cargo +stable nextest run --features ${{ matrix.features.flags }} --no-default-features --verbose --workspace + fi - name: Run doctests - run: cargo +stable test --doc --features ${{ matrix.features.flags }} --verbose --workspace + run: | + # if matrix.feature.flags is empty then run on --lib to avoid compiling examples + # (as most of them rely on mmap) otherwise run all + if [ -z "${{ matrix.features.flags }}" ]; then + echo "no doctest for no feature flag" + else + cargo +stable test --doc --features ${{ matrix.features.flags }} --verbose --workspace + fi diff --git a/Cargo.toml b/Cargo.toml index 10d1c8400..40eff7814 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,7 +37,7 @@ fs4 = { version = "0.13.1", optional = true } levenshtein_automata = "0.2.1" uuid = { version = "1.0.0", features = ["v4", "serde"] } crossbeam-channel = "0.5.4" -rust-stemmers = "1.2.0" +rust-stemmers = { version = "1.2.0", optional = true } downcast-rs = "2.0.1" bitpacking = { version = "0.9.2", default-features = false, features = [ "bitpacker4x", @@ -113,7 +113,8 @@ debug-assertions = true overflow-checks = true [features] -default = ["mmap", "stopwords", "lz4-compression", "columnar-zstd-compression"] +default = ["mmap", "stopwords", "lz4-compression", "columnar-zstd-compression", "stemmer"] +stemmer = ["rust-stemmers"] mmap = ["fs4", "tempfile", "memmap2"] stopwords = [] diff --git a/src/directory/file_watcher.rs b/src/directory/mmap_directory/file_watcher.rs similarity index 100% rename from src/directory/file_watcher.rs rename to src/directory/mmap_directory/file_watcher.rs diff --git a/src/directory/mmap_directory.rs b/src/directory/mmap_directory/mod.rs similarity index 99% rename from src/directory/mmap_directory.rs rename to src/directory/mmap_directory/mod.rs index f4785ef72..60ef82b30 100644 --- a/src/directory/mmap_directory.rs +++ b/src/directory/mmap_directory/mod.rs @@ -1,3 +1,5 @@ +mod file_watcher; + use std::collections::HashMap; use std::fmt; use std::fs::{self, File, OpenOptions}; @@ -7,6 +9,7 @@ use std::path::{Path, PathBuf}; use std::sync::{Arc, RwLock, Weak}; use common::StableDeref; +use file_watcher::FileWatcher; use fs4::fs_std::FileExt; #[cfg(all(feature = "mmap", unix))] pub use memmap2::Advice; @@ -18,7 +21,6 @@ use crate::core::META_FILEPATH; use crate::directory::error::{ DeleteError, LockError, OpenDirectoryError, OpenReadError, OpenWriteError, }; -use crate::directory::file_watcher::FileWatcher; use crate::directory::{ AntiCallToken, Directory, DirectoryLock, FileHandle, Lock, OwnedBytes, TerminatingWrite, WatchCallback, WatchHandle, WritePtr, diff --git a/src/directory/mod.rs b/src/directory/mod.rs index 7fab7e051..d4494d307 100644 --- a/src/directory/mod.rs +++ b/src/directory/mod.rs @@ -5,7 +5,6 @@ mod mmap_directory; mod directory; mod directory_lock; -mod file_watcher; pub mod footer; mod managed_directory; mod ram_directory; diff --git a/src/index/index_meta.rs b/src/index/index_meta.rs index d95ce6ff7..d06d706c4 100644 --- a/src/index/index_meta.rs +++ b/src/index/index_meta.rs @@ -404,7 +404,10 @@ mod tests { schema_builder.build() }; let index_metas = IndexMeta { - index_settings: IndexSettings::default(), + index_settings: IndexSettings { + docstore_compression: Compressor::None, + ..Default::default() + }, segments: Vec::new(), schema, opstamp: 0u64, @@ -413,7 +416,7 @@ mod tests { let json = serde_json::ser::to_string(&index_metas).expect("serialization failed"); assert_eq!( json, - r#"{"index_settings":{"docstore_compression":"lz4","docstore_blocksize":16384},"segments":[],"schema":[{"name":"text","type":"text","options":{"indexing":{"record":"position","fieldnorms":true,"tokenizer":"default"},"stored":false,"fast":false}}],"opstamp":0}"# + r#"{"index_settings":{"docstore_compression":"none","docstore_blocksize":16384},"segments":[],"schema":[{"name":"text","type":"text","options":{"indexing":{"record":"position","fieldnorms":true,"tokenizer":"default"},"stored":false,"fast":false}}],"opstamp":0}"# ); let deser_meta: UntrackedIndexMeta = serde_json::from_str(&json).unwrap(); @@ -494,6 +497,8 @@ mod tests { #[test] #[cfg(feature = "lz4-compression")] fn test_index_settings_default() { + use crate::store::Compressor; + let mut index_settings = IndexSettings::default(); assert_eq!( index_settings, diff --git a/src/indexer/segment_writer.rs b/src/indexer/segment_writer.rs index 72152cffa..94e3f0de2 100644 --- a/src/indexer/segment_writer.rs +++ b/src/indexer/segment_writer.rs @@ -421,10 +421,9 @@ fn remap_and_write( #[cfg(test)] mod tests { use std::collections::BTreeMap; - use std::path::{Path, PathBuf}; + use std::path::Path; use columnar::ColumnType; - use tempfile::TempDir; use crate::collector::{Count, TopDocs}; use crate::directory::RamDirectory; @@ -1067,10 +1066,7 @@ mod tests { let mut schema_builder = Schema::builder(); schema_builder.add_text_field("title", text_options); let schema = schema_builder.build(); - let tempdir = TempDir::new().unwrap(); - let tempdir_path = PathBuf::from(tempdir.path()); - Index::create_in_dir(&tempdir_path, schema).unwrap(); - let index = Index::open_in_dir(tempdir_path).unwrap(); + let index = Index::create_in_ram(schema); let schema = index.schema(); let mut index_writer = index.writer(50_000_000).unwrap(); let title = schema.get_field("title").unwrap(); diff --git a/src/lib.rs b/src/lib.rs index 7890d6188..f0b3120a5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -207,6 +207,7 @@ mod docset; mod reader; #[cfg(test)] +#[cfg(feature = "mmap")] mod compat_tests; pub use self::reader::{IndexReader, IndexReaderBuilder, ReloadPolicy, Warmer}; @@ -1174,12 +1175,11 @@ pub mod tests { #[test] fn test_validate_checksum() -> crate::Result<()> { - let index_path = tempfile::tempdir().expect("dir"); let mut builder = Schema::builder(); let body = builder.add_text_field("body", TEXT | STORED); let schema = builder.build(); - let index = Index::create_in_dir(&index_path, schema)?; - let mut writer: IndexWriter = index.writer(50_000_000)?; + let index = Index::create_in_ram(schema); + let mut writer: IndexWriter = index.writer_for_tests()?; writer.set_merge_policy(Box::new(NoMergePolicy)); for _ in 0..5000 { writer.add_document(doc!(body => "foo"))?; diff --git a/src/snippet/mod.rs b/src/snippet/mod.rs index 020e6b588..ee61b534a 100644 --- a/src/snippet/mod.rs +++ b/src/snippet/mod.rs @@ -483,7 +483,7 @@ mod tests { use super::{collapse_overlapped_ranges, search_fragments, select_best_fragment_combination}; use crate::query::QueryParser; - use crate::schema::{IndexRecordOption, Schema, TextFieldIndexing, TextOptions, TEXT}; + use crate::schema::{Schema, TEXT}; use crate::snippet::SnippetGenerator; use crate::tokenizer::{NgramTokenizer, SimpleTokenizer}; use crate::Index; @@ -727,8 +727,10 @@ Survey in 2016, 2017, and 2018."#; Ok(()) } + #[cfg(feature = "stemmer")] #[test] fn test_snippet_generator() -> crate::Result<()> { + use crate::schema::{IndexRecordOption, TextFieldIndexing, TextOptions}; let mut schema_builder = Schema::builder(); let text_options = TextOptions::default().set_indexing_options( TextFieldIndexing::default() diff --git a/src/store/mod.rs b/src/store/mod.rs index 582643515..cccf4d8f9 100644 --- a/src/store/mod.rs +++ b/src/store/mod.rs @@ -102,6 +102,7 @@ pub(crate) mod tests { } const NUM_DOCS: usize = 1_000; + #[test] fn test_doc_store_iter_with_delete_bug_1077() -> crate::Result<()> { // this will cover deletion of the first element in a checkpoint @@ -113,7 +114,7 @@ pub(crate) mod tests { let directory = RamDirectory::create(); let store_wrt = directory.open_write(path)?; let schema = - write_lorem_ipsum_store(store_wrt, NUM_DOCS, Compressor::Lz4, BLOCK_SIZE, true); + write_lorem_ipsum_store(store_wrt, NUM_DOCS, Compressor::default(), BLOCK_SIZE, true); let field_title = schema.get_field("title").unwrap(); let store_file = directory.open_read(path)?; let store = StoreReader::open(store_file, 10)?; diff --git a/src/store/reader.rs b/src/store/reader.rs index fb1533988..a4105abec 100644 --- a/src/store/reader.rs +++ b/src/store/reader.rs @@ -465,7 +465,7 @@ mod tests { let directory = RamDirectory::create(); let path = Path::new("store"); let writer = directory.open_write(path)?; - let schema = write_lorem_ipsum_store(writer, 500, Compressor::default(), BLOCK_SIZE, true); + let schema = write_lorem_ipsum_store(writer, 500, Compressor::None, BLOCK_SIZE, true); let title = schema.get_field("title").unwrap(); let store_file = directory.open_read(path)?; let store = StoreReader::open(store_file, DOCSTORE_CACHE_CAPACITY)?; @@ -499,7 +499,7 @@ mod tests { assert_eq!(store.cache_stats().cache_hits, 1); assert_eq!(store.cache_stats().cache_misses, 2); - assert_eq!(store.cache.peek_lru(), Some(11207)); + assert_eq!(store.cache.peek_lru(), Some(232206)); Ok(()) } diff --git a/src/tokenizer/mod.rs b/src/tokenizer/mod.rs index 5a5435562..31c518fd4 100644 --- a/src/tokenizer/mod.rs +++ b/src/tokenizer/mod.rs @@ -132,13 +132,14 @@ mod regex_tokenizer; mod remove_long; mod simple_tokenizer; mod split_compound_words; -mod stemmer; mod stop_word_filter; mod tokenized_string; mod tokenizer; mod tokenizer_manager; mod whitespace_tokenizer; +#[cfg(feature = "stemmer")] +mod stemmer; pub use tokenizer_api::{BoxTokenStream, Token, TokenFilter, TokenStream, Tokenizer}; pub use self::alphanum_only::AlphaNumOnlyFilter; @@ -151,6 +152,7 @@ pub use self::regex_tokenizer::RegexTokenizer; pub use self::remove_long::RemoveLongFilter; pub use self::simple_tokenizer::{SimpleTokenStream, SimpleTokenizer}; pub use self::split_compound_words::SplitCompoundWords; +#[cfg(feature = "stemmer")] pub use self::stemmer::{Language, Stemmer}; pub use self::stop_word_filter::StopWordFilter; pub use self::tokenized_string::{PreTokenizedStream, PreTokenizedString}; @@ -167,10 +169,7 @@ pub const MAX_TOKEN_LEN: usize = u16::MAX as usize - 5; #[cfg(test)] pub(crate) mod tests { - use super::{ - Language, LowerCaser, RemoveLongFilter, SimpleTokenizer, Stemmer, Token, TokenizerManager, - }; - use crate::tokenizer::TextAnalyzer; + use super::{Token, TokenizerManager}; /// This is a function that can be used in tests and doc tests /// to assert a token's correctness. @@ -205,59 +204,15 @@ pub(crate) mod tests { } #[test] - fn test_en_tokenizer() { + fn test_tokenizer_does_not_exist() { let tokenizer_manager = TokenizerManager::default(); assert!(tokenizer_manager.get("en_doesnotexist").is_none()); - let mut en_tokenizer = tokenizer_manager.get("en_stem").unwrap(); - let mut tokens: Vec = vec![]; - { - let mut add_token = |token: &Token| { - tokens.push(token.clone()); - }; - en_tokenizer - .token_stream("Hello, happy tax payer!") - .process(&mut add_token); - } - - assert_eq!(tokens.len(), 4); - assert_token(&tokens[0], 0, "hello", 0, 5); - assert_token(&tokens[1], 1, "happi", 7, 12); - assert_token(&tokens[2], 2, "tax", 13, 16); - assert_token(&tokens[3], 3, "payer", 17, 22); - } - - #[test] - fn test_non_en_tokenizer() { - let tokenizer_manager = TokenizerManager::default(); - tokenizer_manager.register( - "el_stem", - TextAnalyzer::builder(SimpleTokenizer::default()) - .filter(RemoveLongFilter::limit(40)) - .filter(LowerCaser) - .filter(Stemmer::new(Language::Greek)) - .build(), - ); - let mut en_tokenizer = tokenizer_manager.get("el_stem").unwrap(); - let mut tokens: Vec = vec![]; - { - let mut add_token = |token: &Token| { - tokens.push(token.clone()); - }; - en_tokenizer - .token_stream("Καλημέρα, χαρούμενε φορολογούμενε!") - .process(&mut add_token); - } - - assert_eq!(tokens.len(), 3); - assert_token(&tokens[0], 0, "καλημερ", 0, 16); - assert_token(&tokens[1], 1, "χαρουμεν", 18, 36); - assert_token(&tokens[2], 2, "φορολογουμεν", 37, 63); } #[test] fn test_tokenizer_empty() { let tokenizer_manager = TokenizerManager::default(); - let mut en_tokenizer = tokenizer_manager.get("en_stem").unwrap(); + let mut en_tokenizer = tokenizer_manager.get("default").unwrap(); { let mut tokens: Vec = vec![]; { diff --git a/src/tokenizer/stemmer.rs b/src/tokenizer/stemmer.rs index fc87440ce..764efc2ee 100644 --- a/src/tokenizer/stemmer.rs +++ b/src/tokenizer/stemmer.rs @@ -142,3 +142,60 @@ impl TokenStream for StemmerTokenStream { self.tail.token_mut() } } + +#[cfg(test)] +mod tests { + use tokenizer_api::Token; + + use super::*; + use crate::tokenizer::tests::assert_token; + use crate::tokenizer::{LowerCaser, SimpleTokenizer, TextAnalyzer, TokenizerManager}; + + #[test] + fn test_en_stem() { + let tokenizer_manager = TokenizerManager::default(); + let mut en_tokenizer = tokenizer_manager.get("en_stem").unwrap(); + let mut tokens: Vec = vec![]; + { + let mut add_token = |token: &Token| { + tokens.push(token.clone()); + }; + en_tokenizer + .token_stream("Dogs are the bests!") + .process(&mut add_token); + } + + assert_eq!(tokens.len(), 4); + assert_token(&tokens[0], 0, "dog", 0, 4); + assert_token(&tokens[1], 1, "are", 5, 8); + assert_token(&tokens[2], 2, "the", 9, 12); + assert_token(&tokens[3], 3, "best", 13, 18); + } + + #[test] + fn test_non_en_stem() { + let tokenizer_manager = TokenizerManager::default(); + tokenizer_manager.register( + "el_stem", + TextAnalyzer::builder(SimpleTokenizer::default()) + .filter(LowerCaser) + .filter(Stemmer::new(Language::Greek)) + .build(), + ); + let mut el_tokenizer = tokenizer_manager.get("el_stem").unwrap(); + let mut tokens: Vec = vec![]; + { + let mut add_token = |token: &Token| { + tokens.push(token.clone()); + }; + el_tokenizer + .token_stream("Καλημέρα, χαρούμενε φορολογούμενε!") + .process(&mut add_token); + } + + assert_eq!(tokens.len(), 3); + assert_token(&tokens[0], 0, "καλημερ", 0, 16); + assert_token(&tokens[1], 1, "χαρουμεν", 18, 36); + assert_token(&tokens[2], 2, "φορολογουμεν", 37, 63); + } +} diff --git a/src/tokenizer/tokenizer_manager.rs b/src/tokenizer/tokenizer_manager.rs index a0bdbcc0c..8bdbba7bd 100644 --- a/src/tokenizer/tokenizer_manager.rs +++ b/src/tokenizer/tokenizer_manager.rs @@ -1,10 +1,9 @@ use std::collections::HashMap; use std::sync::{Arc, RwLock}; -use crate::tokenizer::stemmer::Language; use crate::tokenizer::tokenizer::TextAnalyzer; use crate::tokenizer::{ - LowerCaser, RawTokenizer, RemoveLongFilter, SimpleTokenizer, Stemmer, WhitespaceTokenizer, + LowerCaser, RawTokenizer, RemoveLongFilter, SimpleTokenizer, WhitespaceTokenizer, }; /// The tokenizer manager serves as a store for @@ -64,14 +63,18 @@ impl Default for TokenizerManager { .filter(LowerCaser) .build(), ); - manager.register( - "en_stem", - TextAnalyzer::builder(SimpleTokenizer::default()) - .filter(RemoveLongFilter::limit(40)) - .filter(LowerCaser) - .filter(Stemmer::new(Language::English)) - .build(), - ); + #[cfg(feature = "stemmer")] + { + use crate::tokenizer::stemmer::{Language, Stemmer}; + manager.register( + "en_stem", + TextAnalyzer::builder(SimpleTokenizer::default()) + .filter(RemoveLongFilter::limit(40)) + .filter(LowerCaser) // The stemmer does not lowercase + .filter(Stemmer::new(Language::English)) + .build(), + ); + } manager.register("whitespace", WhitespaceTokenizer::default()); manager } From db2ecc6057ed875a5e32c79401fc9899b2262824 Mon Sep 17 00:00:00 2001 From: ChangRui-Ryan Date: Mon, 5 Jan 2026 17:03:01 +0800 Subject: [PATCH 21/26] fix Column.first method parameter type (#2792) --- columnar/src/column/mod.rs | 4 ++-- columnar/src/tests.rs | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/columnar/src/column/mod.rs b/columnar/src/column/mod.rs index cc2938bb8..f6a50b45f 100644 --- a/columnar/src/column/mod.rs +++ b/columnar/src/column/mod.rs @@ -85,8 +85,8 @@ impl Column { } #[inline] - pub fn first(&self, row_id: RowId) -> Option { - self.values_for_doc(row_id).next() + pub fn first(&self, doc_id: DocId) -> Option { + self.values_for_doc(doc_id).next() } /// Load the first value for each docid in the provided slice. diff --git a/columnar/src/tests.rs b/columnar/src/tests.rs index 5fa537466..5c4a9366c 100644 --- a/columnar/src/tests.rs +++ b/columnar/src/tests.rs @@ -60,7 +60,7 @@ fn test_dataframe_writer_bool() { let DynamicColumn::Bool(bool_col) = dyn_bool_col else { panic!(); }; - let vals: Vec> = (0..5).map(|row_id| bool_col.first(row_id)).collect(); + let vals: Vec> = (0..5).map(|doc_id| bool_col.first(doc_id)).collect(); assert_eq!(&vals, &[None, Some(false), None, Some(true), None,]); } @@ -108,7 +108,7 @@ fn test_dataframe_writer_ip_addr() { let DynamicColumn::IpAddr(ip_col) = dyn_bool_col else { panic!(); }; - let vals: Vec> = (0..5).map(|row_id| ip_col.first(row_id)).collect(); + let vals: Vec> = (0..5).map(|doc_id| ip_col.first(doc_id)).collect(); assert_eq!( &vals, &[ @@ -169,7 +169,7 @@ fn test_dictionary_encoded_str() { let DynamicColumn::Str(str_col) = col_handles[0].open().unwrap() else { panic!(); }; - let index: Vec> = (0..5).map(|row_id| str_col.ords().first(row_id)).collect(); + let index: Vec> = (0..5).map(|doc_id| str_col.ords().first(doc_id)).collect(); assert_eq!(index, &[None, Some(0), None, Some(2), Some(1)]); assert_eq!(str_col.num_rows(), 5); let mut term_buffer = String::new(); @@ -204,7 +204,7 @@ fn test_dictionary_encoded_bytes() { panic!(); }; let index: Vec> = (0..5) - .map(|row_id| bytes_col.ords().first(row_id)) + .map(|doc_id| bytes_col.ords().first(doc_id)) .collect(); assert_eq!(index, &[None, Some(0), None, Some(2), Some(1)]); assert_eq!(bytes_col.num_rows(), 5); From 65b5a1a306487d7eb0192496ae0f16ec7d6098c3 Mon Sep 17 00:00:00 2001 From: PSeitz-dd Date: Tue, 6 Jan 2026 11:50:55 +0100 Subject: [PATCH 22/26] one collector per agg request instead per bucket (#2759) * improve bench * add more tests for new collection type * one collector per agg request instead per bucket In this refactoring a collector knows in which bucket of the parent their data is in. This allows to convert the previous approach of one collector per bucket to one collector per request. low card bucket optimization * reduce dynamic dispatch, faster term agg * use radix map, fix prepare_max_bucket use paged term map in term agg use special no sub agg term map impl * specialize columntype in stats * remove stacktrace bloat, use &mut helper increase cache to 2048 * cleanup remove clone move data in term req, single doc opt for stats * add comment * share column block accessor * simplify fetch block in column_block_accessor * split subaggcache into two trait impls * move partitions to heap * fix name, add comment --------- Co-authored-by: Pascal Seitz --- benches/agg_bench.rs | 136 +-- columnar/src/block_accessor.rs | 10 +- common/src/bitset.rs | 8 + src/aggregation/agg_data.rs | 113 +-- src/aggregation/agg_tests.rs | 437 ++++++++- src/aggregation/bucket/filter.rs | 146 +-- src/aggregation/bucket/histogram/histogram.rs | 125 +-- src/aggregation/bucket/range.rs | 331 +++---- src/aggregation/bucket/term_agg.rs | 829 +++++++++++------- src/aggregation/bucket/term_missing_agg.rs | 91 +- src/aggregation/buf_collector.rs | 87 -- src/aggregation/cached_sub_aggs.rs | 245 ++++++ src/aggregation/collector.rs | 53 +- src/aggregation/intermediate_agg_result.rs | 11 +- src/aggregation/metric/average.rs | 6 +- src/aggregation/metric/cardinality.rs | 127 +-- src/aggregation/metric/count.rs | 6 +- src/aggregation/metric/extended_stats.rs | 108 +-- src/aggregation/metric/max.rs | 6 +- src/aggregation/metric/min.rs | 6 +- src/aggregation/metric/mod.rs | 4 +- src/aggregation/metric/percentiles.rs | 101 +-- src/aggregation/metric/stats.rs | 198 +++-- src/aggregation/metric/sum.rs | 6 +- src/aggregation/metric/top_hits.rs | 133 +-- src/aggregation/mod.rs | 49 +- src/aggregation/segment_agg_result.rs | 103 ++- src/core/executor.rs | 10 +- 28 files changed, 2232 insertions(+), 1253 deletions(-) delete mode 100644 src/aggregation/buf_collector.rs create mode 100644 src/aggregation/cached_sub_aggs.rs diff --git a/benches/agg_bench.rs b/benches/agg_bench.rs index a4115b604..642532597 100644 --- a/benches/agg_bench.rs +++ b/benches/agg_bench.rs @@ -54,33 +54,33 @@ fn bench_agg(mut group: InputGroup) { register!(group, stats_f64); register!(group, extendedstats_f64); register!(group, percentiles_f64); - register!(group, terms_few); + register!(group, terms_7); register!(group, terms_all_unique); - register!(group, terms_many); + register!(group, terms_150_000); register!(group, terms_many_top_1000); register!(group, terms_many_order_by_term); register!(group, terms_many_with_top_hits); register!(group, terms_all_unique_with_avg_sub_agg); register!(group, terms_many_with_avg_sub_agg); - register!(group, terms_few_with_avg_sub_agg); register!(group, terms_status_with_avg_sub_agg); - register!(group, terms_status); - register!(group, terms_few_with_histogram); register!(group, terms_status_with_histogram); + register!(group, terms_zipf_1000); + register!(group, terms_zipf_1000_with_histogram); + register!(group, terms_zipf_1000_with_avg_sub_agg); register!(group, terms_many_json_mixed_type_with_avg_sub_agg); register!(group, cardinality_agg); - register!(group, terms_few_with_cardinality_agg); + register!(group, terms_status_with_cardinality_agg); register!(group, range_agg); register!(group, range_agg_with_avg_sub_agg); - register!(group, range_agg_with_term_agg_few); + register!(group, range_agg_with_term_agg_status); register!(group, range_agg_with_term_agg_many); register!(group, histogram); register!(group, histogram_hard_bounds); register!(group, histogram_with_avg_sub_agg); - register!(group, histogram_with_term_agg_few); + register!(group, histogram_with_term_agg_status); register!(group, avg_and_range_with_avg_sub_agg); // Filter aggregation benchmarks @@ -159,10 +159,10 @@ fn cardinality_agg(index: &Index) { }); execute_agg(index, agg_req); } -fn terms_few_with_cardinality_agg(index: &Index) { +fn terms_status_with_cardinality_agg(index: &Index) { let agg_req = json!({ "my_texts": { - "terms": { "field": "text_few_terms" }, + "terms": { "field": "text_few_terms_status" }, "aggs": { "cardinality": { "cardinality": { @@ -175,13 +175,7 @@ fn terms_few_with_cardinality_agg(index: &Index) { execute_agg(index, agg_req); } -fn terms_few(index: &Index) { - let agg_req = json!({ - "my_texts": { "terms": { "field": "text_few_terms" } }, - }); - execute_agg(index, agg_req); -} -fn terms_status(index: &Index) { +fn terms_7(index: &Index) { let agg_req = json!({ "my_texts": { "terms": { "field": "text_few_terms_status" } }, }); @@ -194,7 +188,7 @@ fn terms_all_unique(index: &Index) { execute_agg(index, agg_req); } -fn terms_many(index: &Index) { +fn terms_150_000(index: &Index) { let agg_req = json!({ "my_texts": { "terms": { "field": "text_many_terms" } }, }); @@ -253,17 +247,6 @@ fn terms_all_unique_with_avg_sub_agg(index: &Index) { }); execute_agg(index, agg_req); } -fn terms_few_with_histogram(index: &Index) { - let agg_req = json!({ - "my_texts": { - "terms": { "field": "text_few_terms" }, - "aggs": { - "histo": {"histogram": { "field": "score_f64", "interval": 10 }} - } - } - }); - execute_agg(index, agg_req); -} fn terms_status_with_histogram(index: &Index) { let agg_req = json!({ "my_texts": { @@ -276,17 +259,18 @@ fn terms_status_with_histogram(index: &Index) { execute_agg(index, agg_req); } -fn terms_few_with_avg_sub_agg(index: &Index) { +fn terms_zipf_1000_with_histogram(index: &Index) { let agg_req = json!({ "my_texts": { - "terms": { "field": "text_few_terms" }, + "terms": { "field": "text_1000_terms_zipf" }, "aggs": { - "average_f64": { "avg": { "field": "score_f64" } } + "histo": {"histogram": { "field": "score_f64", "interval": 10 }} } - }, + } }); execute_agg(index, agg_req); } + fn terms_status_with_avg_sub_agg(index: &Index) { let agg_req = json!({ "my_texts": { @@ -299,6 +283,25 @@ fn terms_status_with_avg_sub_agg(index: &Index) { execute_agg(index, agg_req); } +fn terms_zipf_1000_with_avg_sub_agg(index: &Index) { + let agg_req = json!({ + "my_texts": { + "terms": { "field": "text_1000_terms_zipf" }, + "aggs": { + "average_f64": { "avg": { "field": "score_f64" } } + } + }, + }); + execute_agg(index, agg_req); +} + +fn terms_zipf_1000(index: &Index) { + let agg_req = json!({ + "my_texts": { "terms": { "field": "text_1000_terms_zipf" } }, + }); + execute_agg(index, agg_req); +} + fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) { let agg_req = json!({ "my_texts": { @@ -354,7 +357,7 @@ fn range_agg_with_avg_sub_agg(index: &Index) { execute_agg(index, agg_req); } -fn range_agg_with_term_agg_few(index: &Index) { +fn range_agg_with_term_agg_status(index: &Index) { let agg_req = json!({ "rangef64": { "range": { @@ -369,7 +372,7 @@ fn range_agg_with_term_agg_few(index: &Index) { ] }, "aggs": { - "my_texts": { "terms": { "field": "text_few_terms" } }, + "my_texts": { "terms": { "field": "text_few_terms_status" } }, } }, }); @@ -425,12 +428,12 @@ fn histogram_with_avg_sub_agg(index: &Index) { }); execute_agg(index, agg_req); } -fn histogram_with_term_agg_few(index: &Index) { +fn histogram_with_term_agg_status(index: &Index) { let agg_req = json!({ "rangef64": { "histogram": { "field": "score_f64", "interval": 10 }, "aggs": { - "my_texts": { "terms": { "field": "text_few_terms" } } + "my_texts": { "terms": { "field": "text_few_terms_status" } } } } }); @@ -475,6 +478,13 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector { } fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result { + // Flag to use existing index + let reuse_index = std::env::var("REUSE_AGG_BENCH_INDEX").is_ok(); + if reuse_index && std::path::Path::new("agg_bench").exists() { + return Index::open_in_dir("agg_bench"); + } + // crreate dir + std::fs::create_dir_all("agg_bench")?; let mut schema_builder = Schema::builder(); let text_fieldtype = tantivy::schema::TextOptions::default() .set_indexing_options( @@ -486,24 +496,44 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result { let text_field_all_unique_terms = schema_builder.add_text_field("text_all_unique_terms", STRING | FAST); let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST); - let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST); - let text_field_few_terms = schema_builder.add_text_field("text_few_terms", STRING | FAST); let text_field_few_terms_status = schema_builder.add_text_field("text_few_terms_status", STRING | FAST); + let text_field_1000_terms_zipf = + schema_builder.add_text_field("text_1000_terms_zipf", STRING | FAST); let score_fieldtype = tantivy::schema::NumericOptions::default().set_fast(); let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone()); let score_field_f64 = schema_builder.add_f64_field("score_f64", score_fieldtype.clone()); let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype); - let index = Index::create_from_tempdir(schema_builder.build())?; - let few_terms_data = ["INFO", "ERROR", "WARN", "DEBUG"]; - // Approximate production log proportions: INFO dominant, WARN and DEBUG occasional, ERROR rare. - let log_level_distribution = WeightedIndex::new([80u32, 3, 12, 5]).unwrap(); + // use tmp dir + let index = if reuse_index { + Index::create_in_dir("agg_bench", schema_builder.build())? + } else { + Index::create_from_tempdir(schema_builder.build())? + }; + // Approximate log proportions + let status_field_data = [ + ("INFO", 8000), + ("ERROR", 300), + ("WARN", 1200), + ("DEBUG", 500), + ("OK", 500), + ("CRITICAL", 20), + ("EMERGENCY", 1), + ]; + let log_level_distribution = + WeightedIndex::new(status_field_data.iter().map(|item| item.1)).unwrap(); let lg_norm = rand_distr::LogNormal::new(2.996f64, 0.979f64).unwrap(); let many_terms_data = (0..150_000) .map(|num| format!("author{num}")) .collect::>(); + + // Prepare 1000 unique terms sampled using a Zipf distribution. + // Exponent ~1.1 approximates top-20 terms covering around ~20%. + let terms_1000: Vec = (1..=1000).map(|i| format!("term_{i}")).collect(); + let zipf_1000 = rand_distr::Zipf::new(1000, 1.1f64).unwrap(); + { let mut rng = StdRng::from_seed([1u8; 32]); let mut index_writer = index.writer_with_num_threads(1, 200_000_000)?; @@ -513,8 +543,12 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result { index_writer.add_document(doc!())?; } if cardinality == Cardinality::Multivalued { - let log_level_sample_a = few_terms_data[log_level_distribution.sample(&mut rng)]; - let log_level_sample_b = few_terms_data[log_level_distribution.sample(&mut rng)]; + let log_level_sample_a = status_field_data[log_level_distribution.sample(&mut rng)].0; + let log_level_sample_b = status_field_data[log_level_distribution.sample(&mut rng)].0; + let idx_a = zipf_1000.sample(&mut rng) as usize - 1; + let idx_b = zipf_1000.sample(&mut rng) as usize - 1; + let term_1000_a = &terms_1000[idx_a]; + let term_1000_b = &terms_1000[idx_b]; index_writer.add_document(doc!( json_field => json!({"mixed_type": 10.0}), json_field => json!({"mixed_type": 10.0}), @@ -524,10 +558,10 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result { text_field_all_unique_terms => "coolo", text_field_many_terms => "cool", text_field_many_terms => "cool", - text_field_few_terms => "cool", - text_field_few_terms => "cool", text_field_few_terms_status => log_level_sample_a, text_field_few_terms_status => log_level_sample_b, + text_field_1000_terms_zipf => term_1000_a.as_str(), + text_field_1000_terms_zipf => term_1000_b.as_str(), score_field => 1u64, score_field => 1u64, score_field_f64 => lg_norm.sample(&mut rng), @@ -554,8 +588,8 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result { json_field => json, text_field_all_unique_terms => format!("unique_term_{}", rng.gen::()), text_field_many_terms => many_terms_data.choose(&mut rng).unwrap().to_string(), - text_field_few_terms => few_terms_data.choose(&mut rng).unwrap().to_string(), - text_field_few_terms_status => few_terms_data[log_level_distribution.sample(&mut rng)], + text_field_few_terms_status => status_field_data[log_level_distribution.sample(&mut rng)].0, + text_field_1000_terms_zipf => terms_1000[zipf_1000.sample(&mut rng) as usize - 1].as_str(), score_field => val as u64, score_field_f64 => lg_norm.sample(&mut rng), score_field_i64 => val as i64, @@ -607,7 +641,7 @@ fn filter_agg_all_query_with_sub_aggs(index: &Index) { "avg_score": { "avg": { "field": "score" } }, "stats_score": { "stats": { "field": "score_f64" } }, "terms_text": { - "terms": { "field": "text_few_terms" } + "terms": { "field": "text_few_terms_status" } } } } @@ -623,7 +657,7 @@ fn filter_agg_term_query_with_sub_aggs(index: &Index) { "avg_score": { "avg": { "field": "score" } }, "stats_score": { "stats": { "field": "score_f64" } }, "terms_text": { - "terms": { "field": "text_few_terms" } + "terms": { "field": "text_few_terms_status" } } } } diff --git a/columnar/src/block_accessor.rs b/columnar/src/block_accessor.rs index 6bd24ba3b..9926553a8 100644 --- a/columnar/src/block_accessor.rs +++ b/columnar/src/block_accessor.rs @@ -29,12 +29,20 @@ impl } } #[inline] - pub fn fetch_block_with_missing(&mut self, docs: &[u32], accessor: &Column, missing: T) { + pub fn fetch_block_with_missing( + &mut self, + docs: &[u32], + accessor: &Column, + missing: Option, + ) { self.fetch_block(docs, accessor); // no missing values if accessor.index.get_cardinality().is_full() { return; } + let Some(missing) = missing else { + return; + }; // We can compare docid_cache length with docs to find missing docs // For multi value columns we can't rely on the length and always need to scan diff --git a/common/src/bitset.rs b/common/src/bitset.rs index 8e98e6780..94e4ca5ae 100644 --- a/common/src/bitset.rs +++ b/common/src/bitset.rs @@ -181,6 +181,14 @@ pub struct BitSet { len: u64, max_value: u32, } +impl std::fmt::Debug for BitSet { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("BitSet") + .field("len", &self.len) + .field("max_value", &self.max_value) + .finish() + } +} fn num_buckets(max_val: u32) -> u32 { max_val.div_ceil(64u32) diff --git a/src/aggregation/agg_data.rs b/src/aggregation/agg_data.rs index deedb5781..de521505e 100644 --- a/src/aggregation/agg_data.rs +++ b/src/aggregation/agg_data.rs @@ -1,4 +1,4 @@ -use columnar::{Column, ColumnType, StrColumn}; +use columnar::{Column, ColumnBlockAccessor, ColumnType, StrColumn}; use common::BitSet; use rustc_hash::FxHashSet; use serde::Serialize; @@ -10,16 +10,16 @@ use crate::aggregation::accessor_helpers::{ }; use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations}; use crate::aggregation::bucket::{ - FilterAggReqData, HistogramAggReqData, HistogramBounds, IncludeExcludeParam, - MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector, SegmentHistogramCollector, - SegmentRangeCollector, TermMissingAgg, TermsAggReqData, TermsAggregation, + build_segment_filter_collector, build_segment_range_collector, FilterAggReqData, + HistogramAggReqData, HistogramBounds, IncludeExcludeParam, MissingTermAggReqData, + RangeAggReqData, SegmentHistogramCollector, TermMissingAgg, TermsAggReqData, TermsAggregation, TermsAggregationInternal, }; use crate::aggregation::metric::{ - AverageAggregation, CardinalityAggReqData, CardinalityAggregationReq, CountAggregation, - ExtendedStatsAggregation, MaxAggregation, MetricAggReqData, MinAggregation, - SegmentCardinalityCollector, SegmentExtendedStatsCollector, SegmentPercentilesCollector, - SegmentStatsCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData, + build_segment_stats_collector, AverageAggregation, CardinalityAggReqData, + CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation, MaxAggregation, + MetricAggReqData, MinAggregation, SegmentCardinalityCollector, SegmentExtendedStatsCollector, + SegmentPercentilesCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData, TopHitsSegmentCollector, }; use crate::aggregation::segment_agg_result::{ @@ -35,6 +35,7 @@ pub struct AggregationsSegmentCtx { /// Request data for each aggregation type. pub per_request: PerRequestAggSegCtx, pub context: AggContextParams, + pub column_block_accessor: ColumnBlockAccessor, } impl AggregationsSegmentCtx { @@ -107,21 +108,14 @@ impl AggregationsSegmentCtx { .as_deref() .expect("range_req_data slot is empty (taken)") } - #[inline] - pub(crate) fn get_filter_req_data(&self, idx: usize) -> &FilterAggReqData { - self.per_request.filter_req_data[idx] - .as_deref() - .expect("filter_req_data slot is empty (taken)") - } // ---------- mutable getters ---------- #[inline] - pub(crate) fn get_term_req_data_mut(&mut self, idx: usize) -> &mut TermsAggReqData { - self.per_request.term_req_data[idx] - .as_deref_mut() - .expect("term_req_data slot is empty (taken)") + pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData { + &mut self.per_request.stats_metric_req_data[idx] } + #[inline] pub(crate) fn get_cardinality_req_data_mut( &mut self, @@ -129,10 +123,7 @@ impl AggregationsSegmentCtx { ) -> &mut CardinalityAggReqData { &mut self.per_request.cardinality_req_data[idx] } - #[inline] - pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData { - &mut self.per_request.stats_metric_req_data[idx] - } + #[inline] pub(crate) fn get_histogram_req_data_mut(&mut self, idx: usize) -> &mut HistogramAggReqData { self.per_request.histogram_req_data[idx] @@ -142,21 +133,6 @@ impl AggregationsSegmentCtx { // ---------- take / put (terms, histogram, range) ---------- - /// Move out the boxed Terms request at `idx`, leaving `None`. - #[inline] - pub(crate) fn take_term_req_data(&mut self, idx: usize) -> Box { - self.per_request.term_req_data[idx] - .take() - .expect("term_req_data slot is empty (taken)") - } - - /// Put back a Terms request into an empty slot at `idx`. - #[inline] - pub(crate) fn put_back_term_req_data(&mut self, idx: usize, value: Box) { - debug_assert!(self.per_request.term_req_data[idx].is_none()); - self.per_request.term_req_data[idx] = Some(value); - } - /// Move out the boxed Histogram request at `idx`, leaving `None`. #[inline] pub(crate) fn take_histogram_req_data(&mut self, idx: usize) -> Box { @@ -320,6 +296,7 @@ impl PerRequestAggSegCtx { /// Convert the aggregation tree into a serializable struct representation. /// Each node contains: { name, kind, children }. + #[allow(dead_code)] pub fn get_view_tree(&self) -> Vec { fn node_to_view(node: &AggRefNode, pr: &PerRequestAggSegCtx) -> AggTreeViewNode { let mut children: Vec = @@ -345,12 +322,19 @@ impl PerRequestAggSegCtx { pub(crate) fn build_segment_agg_collectors_root( req: &mut AggregationsSegmentCtx, ) -> crate::Result> { - build_segment_agg_collectors(req, &req.per_request.agg_tree.clone()) + build_segment_agg_collectors_generic(req, &req.per_request.agg_tree.clone()) } pub(crate) fn build_segment_agg_collectors( req: &mut AggregationsSegmentCtx, nodes: &[AggRefNode], +) -> crate::Result> { + build_segment_agg_collectors_generic(req, nodes) +} + +fn build_segment_agg_collectors_generic( + req: &mut AggregationsSegmentCtx, + nodes: &[AggRefNode], ) -> crate::Result> { let mut collectors = Vec::new(); for node in nodes.iter() { @@ -388,6 +372,8 @@ pub(crate) fn build_segment_agg_collector( Ok(Box::new(SegmentCardinalityCollector::from_req( req_data.column_type, node.idx_in_req_data, + req_data.accessor.clone(), + req_data.missing_value_for_accessor, ))) } AggKind::StatsKind(stats_type) => { @@ -398,20 +384,21 @@ pub(crate) fn build_segment_agg_collector( | StatsType::Count | StatsType::Max | StatsType::Min - | StatsType::Stats => Ok(Box::new(SegmentStatsCollector::from_req( - node.idx_in_req_data, - ))), - StatsType::ExtendedStats(sigma) => { - Ok(Box::new(SegmentExtendedStatsCollector::from_req( - req_data.field_type, - sigma, - node.idx_in_req_data, - req_data.missing, - ))) - } - StatsType::Percentiles => Ok(Box::new( - SegmentPercentilesCollector::from_req_and_validate(node.idx_in_req_data)?, + | StatsType::Stats => build_segment_stats_collector(req_data), + StatsType::ExtendedStats(sigma) => Ok(Box::new( + SegmentExtendedStatsCollector::from_req(req_data, sigma), )), + StatsType::Percentiles => { + let req_data = req.get_metric_req_data_mut(node.idx_in_req_data); + Ok(Box::new( + SegmentPercentilesCollector::from_req_and_validate( + req_data.field_type, + req_data.missing_u64, + req_data.accessor.clone(), + node.idx_in_req_data, + ), + )) + } } } AggKind::TopHits => { @@ -428,12 +415,8 @@ pub(crate) fn build_segment_agg_collector( AggKind::DateHistogram => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate( req, node, )?)), - AggKind::Range => Ok(Box::new(SegmentRangeCollector::from_req_and_validate( - req, node, - )?)), - AggKind::Filter => Ok(Box::new(SegmentFilterCollector::from_req_and_validate( - req, node, - )?)), + AggKind::Range => Ok(build_segment_range_collector(req, node)?), + AggKind::Filter => build_segment_filter_collector(req, node), } } @@ -493,6 +476,7 @@ pub(crate) fn build_aggregations_data_from_req( let mut data = AggregationsSegmentCtx { per_request: Default::default(), context, + column_block_accessor: ColumnBlockAccessor::default(), }; for (name, agg) in aggs.iter() { @@ -521,9 +505,9 @@ fn build_nodes( let idx_in_req_data = data.push_range_req_data(RangeAggReqData { accessor, field_type, - column_block_accessor: Default::default(), name: agg_name.to_string(), req: range_req.clone(), + is_top_level, }); let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?; Ok(vec![AggRefNode { @@ -541,9 +525,7 @@ fn build_nodes( let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData { accessor, field_type, - column_block_accessor: Default::default(), name: agg_name.to_string(), - sub_aggregation_blueprint: None, req: histo_req.clone(), is_date_histogram: false, bounds: HistogramBounds { @@ -568,9 +550,7 @@ fn build_nodes( let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData { accessor, field_type, - column_block_accessor: Default::default(), name: agg_name.to_string(), - sub_aggregation_blueprint: None, req: histo_req, is_date_histogram: true, bounds: HistogramBounds { @@ -650,7 +630,6 @@ fn build_nodes( let idx_in_req_data = data.push_metric_req_data(MetricAggReqData { accessor, field_type, - column_block_accessor: Default::default(), name: agg_name.to_string(), collecting_for, missing: *missing, @@ -678,7 +657,6 @@ fn build_nodes( let idx_in_req_data = data.push_metric_req_data(MetricAggReqData { accessor, field_type, - column_block_accessor: Default::default(), name: agg_name.to_string(), collecting_for: StatsType::Percentiles, missing: percentiles_req.missing, @@ -753,6 +731,7 @@ fn build_nodes( segment_reader: reader.clone(), evaluator, matching_docs_buffer, + is_top_level, }); let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?; Ok(vec![AggRefNode { @@ -895,7 +874,7 @@ fn build_terms_or_cardinality_nodes( }); } - // Add one node per accessor to mirror previous behavior and allow per-type missing handling. + // Add one node per accessor for (accessor, column_type) in column_and_types { let missing_value_for_accessor = if use_special_missing_agg { None @@ -926,11 +905,8 @@ fn build_terms_or_cardinality_nodes( column_type, str_dict_column: str_dict_column.clone(), missing_value_for_accessor, - column_block_accessor: Default::default(), name: agg_name.to_string(), req: TermsAggregationInternal::from_req(req), - // Will be filled later when building collectors - sub_aggregation_blueprint: None, sug_aggregations: sub_aggs.clone(), allowed_term_ids, is_top_level, @@ -943,7 +919,6 @@ fn build_terms_or_cardinality_nodes( column_type, str_dict_column: str_dict_column.clone(), missing_value_for_accessor, - column_block_accessor: Default::default(), name: agg_name.to_string(), req: req.clone(), }); diff --git a/src/aggregation/agg_tests.rs b/src/aggregation/agg_tests.rs index fede0c7c7..49a8afb37 100644 --- a/src/aggregation/agg_tests.rs +++ b/src/aggregation/agg_tests.rs @@ -2,15 +2,441 @@ use serde_json::Value; use crate::aggregation::agg_req::{Aggregation, Aggregations}; use crate::aggregation::agg_result::AggregationResults; -use crate::aggregation::buf_collector::DOC_BLOCK_SIZE; use crate::aggregation::collector::AggregationCollector; use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults; use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms}; use crate::aggregation::DistributedAggregationCollector; +use crate::docset::COLLECT_BLOCK_BUFFER_LEN; use crate::query::{AllQuery, TermQuery}; use crate::schema::{IndexRecordOption, Schema, FAST}; use crate::{Index, IndexWriter, Term}; +// The following tests ensure that each bucket aggregation type correctly functions as a +// sub-aggregation of another bucket aggregation in two scenarios: +// 1) The parent has more buckets than the child sub-aggregation +// 2) The child sub-aggregation has more buckets than the parent +// +// These scenarios exercise the bucket id mapping and sub-aggregation routing logic. + +#[test] +fn test_terms_as_subagg_parent_more_vs_child_more() -> crate::Result<()> { + let index = get_test_index_2_segments(false)?; + + // Case A: parent has more buckets than child + // Parent: range with 4 buckets + // Child: terms on text -> 2 buckets + let agg_parent_more: Aggregations = serde_json::from_value(json!({ + "parent_range": { + "range": { + "field": "score", + "ranges": [ + {"to": 3.0}, + {"from": 3.0, "to": 7.0}, + {"from": 7.0, "to": 20.0}, + {"from": 20.0} + ] + }, + "aggs": { + "child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}} + } + } + })) + .unwrap(); + + let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?; + // Exact expected structure and counts + assert_eq!( + res["parent_range"]["buckets"], + json!([ + { + "key": "*-3", + "doc_count": 1, + "to": 3.0, + "child_terms": { + "buckets": [ + {"doc_count": 1, "key": "cool"} + ], + "sum_other_doc_count": 0 + } + }, + { + "key": "3-7", + "doc_count": 3, + "from": 3.0, + "to": 7.0, + "child_terms": { + "buckets": [ + {"doc_count": 2, "key": "cool"}, + {"doc_count": 1, "key": "nohit"} + ], + "sum_other_doc_count": 0 + } + }, + { + "key": "7-20", + "doc_count": 3, + "from": 7.0, + "to": 20.0, + "child_terms": { + "buckets": [ + {"doc_count": 3, "key": "cool"} + ], + "sum_other_doc_count": 0 + } + }, + { + "key": "20-*", + "doc_count": 2, + "from": 20.0, + "child_terms": { + "buckets": [ + {"doc_count": 1, "key": "cool"}, + {"doc_count": 1, "key": "nohit"} + ], + "sum_other_doc_count": 0 + } + } + ]) + ); + + // Case B: child has more buckets than parent + // Parent: histogram on score with large interval -> 1 bucket + // Child: terms on text -> 2 buckets (cool/nohit) + let agg_child_more: Aggregations = serde_json::from_value(json!({ + "parent_hist": { + "histogram": {"field": "score", "interval": 100.0}, + "aggs": { + "child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}} + } + } + })) + .unwrap(); + + let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?; + assert_eq!( + res["parent_hist"], + json!({ + "buckets": [ + { + "key": 0.0, + "doc_count": 9, + "child_terms": { + "buckets": [ + {"doc_count": 7, "key": "cool"}, + {"doc_count": 2, "key": "nohit"} + ], + "sum_other_doc_count": 0 + } + } + ] + }) + ); + + Ok(()) +} + +#[test] +fn test_range_as_subagg_parent_more_vs_child_more() -> crate::Result<()> { + let index = get_test_index_2_segments(false)?; + + // Case A: parent has more buckets than child + // Parent: range with 5 buckets + // Child: coarse range with 3 buckets + let agg_parent_more: Aggregations = serde_json::from_value(json!({ + "parent_range": { + "range": { + "field": "score", + "ranges": [ + {"to": 3.0}, + {"from": 3.0, "to": 7.0}, + {"from": 7.0, "to": 11.0}, + {"from": 11.0, "to": 20.0}, + {"from": 20.0} + ] + }, + "aggs": { + "child_range": { + "range": { + "field": "score", + "ranges": [ + {"to": 3.0}, + {"from": 3.0, "to": 20.0} + ] + } + } + } + } + })) + .unwrap(); + let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?; + assert_eq!( + res["parent_range"]["buckets"], + json!([ + {"key": "*-3", "doc_count": 1, "to": 3.0, + "child_range": {"buckets": [ + {"key": "*-3", "doc_count": 1, "to": 3.0}, + {"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0}, + {"key": "20-*", "doc_count": 0, "from": 20.0} + ]} + }, + {"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0, + "child_range": {"buckets": [ + {"key": "*-3", "doc_count": 0, "to": 3.0}, + {"key": "3-20", "doc_count": 3, "from": 3.0, "to": 20.0}, + {"key": "20-*", "doc_count": 0, "from": 20.0} + ]} + }, + {"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0, + "child_range": {"buckets": [ + {"key": "*-3", "doc_count": 0, "to": 3.0}, + {"key": "3-20", "doc_count": 1, "from": 3.0, "to": 20.0}, + {"key": "20-*", "doc_count": 0, "from": 20.0} + ]} + }, + {"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0, + "child_range": {"buckets": [ + {"key": "*-3", "doc_count": 0, "to": 3.0}, + {"key": "3-20", "doc_count": 2, "from": 3.0, "to": 20.0}, + {"key": "20-*", "doc_count": 0, "from": 20.0} + ]} + }, + {"key": "20-*", "doc_count": 2, "from": 20.0, + "child_range": {"buckets": [ + {"key": "*-3", "doc_count": 0, "to": 3.0}, + {"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0}, + {"key": "20-*", "doc_count": 2, "from": 20.0} + ]} + } + ]) + ); + + // Case B: child has more buckets than parent + // Parent: terms on text (2 buckets) + // Child: range with 4 buckets + let agg_child_more: Aggregations = serde_json::from_value(json!({ + "parent_terms": { + "terms": {"field": "text"}, + "aggs": { + "child_range": { + "range": { + "field": "score", + "ranges": [ + {"to": 3.0}, + {"from": 3.0, "to": 7.0}, + {"from": 7.0, "to": 20.0} + ] + } + } + } + } + })) + .unwrap(); + let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?; + + assert_eq!( + res["parent_terms"], + json!({ + "buckets": [ + { + "key": "cool", + "doc_count": 7, + "child_range": { + "buckets": [ + {"key": "*-3", "doc_count": 1, "to": 3.0}, + {"key": "3-7", "doc_count": 2, "from": 3.0, "to": 7.0}, + {"key": "7-20", "doc_count": 3, "from": 7.0, "to": 20.0}, + {"key": "20-*", "doc_count": 1, "from": 20.0} + ] + } + }, + { + "key": "nohit", + "doc_count": 2, + "child_range": { + "buckets": [ + {"key": "*-3", "doc_count": 0, "to": 3.0}, + {"key": "3-7", "doc_count": 1, "from": 3.0, "to": 7.0}, + {"key": "7-20", "doc_count": 0, "from": 7.0, "to": 20.0}, + {"key": "20-*", "doc_count": 1, "from": 20.0} + ] + } + } + ], + "doc_count_error_upper_bound": 0, + "sum_other_doc_count": 0 + }) + ); + + Ok(()) +} + +#[test] +fn test_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> { + let index = get_test_index_2_segments(false)?; + + // Case A: parent has more buckets than child + // Parent: range with several ranges + // Child: histogram with large interval (single bucket per parent) + let agg_parent_more: Aggregations = serde_json::from_value(json!({ + "parent_range": { + "range": { + "field": "score", + "ranges": [ + {"to": 3.0}, + {"from": 3.0, "to": 7.0}, + {"from": 7.0, "to": 11.0}, + {"from": 11.0, "to": 20.0}, + {"from": 20.0} + ] + }, + "aggs": { + "child_hist": {"histogram": {"field": "score", "interval": 100.0}} + } + } + })) + .unwrap(); + let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?; + assert_eq!( + res["parent_range"]["buckets"], + json!([ + {"key": "*-3", "doc_count": 1, "to": 3.0, + "child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]} + }, + {"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0, + "child_hist": {"buckets": [ {"key": 0.0, "doc_count": 3} ]} + }, + {"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0, + "child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]} + }, + {"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0, + "child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]} + }, + {"key": "20-*", "doc_count": 2, "from": 20.0, + "child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]} + } + ]) + ); + + // Case B: child has more buckets than parent + // Parent: terms on text -> 2 buckets + // Child: histogram with small interval -> multiple buckets including empties + let agg_child_more: Aggregations = serde_json::from_value(json!({ + "parent_terms": { + "terms": {"field": "text"}, + "aggs": { + "child_hist": {"histogram": {"field": "score", "interval": 10.0}} + } + } + })) + .unwrap(); + let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?; + assert_eq!( + res["parent_terms"], + json!({ + "buckets": [ + { + "key": "cool", + "doc_count": 7, + "child_hist": { + "buckets": [ + {"key": 0.0, "doc_count": 4}, + {"key": 10.0, "doc_count": 2}, + {"key": 20.0, "doc_count": 0}, + {"key": 30.0, "doc_count": 0}, + {"key": 40.0, "doc_count": 1} + ] + } + }, + { + "key": "nohit", + "doc_count": 2, + "child_hist": { + "buckets": [ + {"key": 0.0, "doc_count": 1}, + {"key": 10.0, "doc_count": 0}, + {"key": 20.0, "doc_count": 0}, + {"key": 30.0, "doc_count": 0}, + {"key": 40.0, "doc_count": 1} + ] + } + } + ], + "doc_count_error_upper_bound": 0, + "sum_other_doc_count": 0 + }) + ); + + Ok(()) +} + +#[test] +fn test_date_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> { + let index = get_test_index_2_segments(false)?; + + // Case A: parent has more buckets than child + // Parent: range with several buckets + // Child: date_histogram with 30d -> single bucket per parent + let agg_parent_more: Aggregations = serde_json::from_value(json!({ + "parent_range": { + "range": { + "field": "score", + "ranges": [ + {"to": 3.0}, + {"from": 3.0, "to": 7.0}, + {"from": 7.0, "to": 11.0}, + {"from": 11.0, "to": 20.0}, + {"from": 20.0} + ] + }, + "aggs": { + "child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "30d"}} + } + } + })) + .unwrap(); + let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?; + let buckets = res["parent_range"]["buckets"].as_array().unwrap(); + // Verify each parent bucket has exactly one child date bucket with matching doc_count + for bucket in buckets { + let parent_count = bucket["doc_count"].as_u64().unwrap(); + let child_buckets = bucket["child_date_hist"]["buckets"].as_array().unwrap(); + assert_eq!(child_buckets.len(), 1); + assert_eq!(child_buckets[0]["doc_count"], parent_count); + } + + // Case B: child has more buckets than parent + // Parent: terms on text (2 buckets) + // Child: date_histogram with 1d -> multiple buckets + let agg_child_more: Aggregations = serde_json::from_value(json!({ + "parent_terms": { + "terms": {"field": "text"}, + "aggs": { + "child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "1d"}} + } + } + })) + .unwrap(); + let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?; + let buckets = res["parent_terms"]["buckets"].as_array().unwrap(); + + // cool bucket + assert_eq!(buckets[0]["key"], "cool"); + let cool_buckets = buckets[0]["child_date_hist"]["buckets"].as_array().unwrap(); + assert_eq!(cool_buckets.len(), 3); + assert_eq!(cool_buckets[0]["doc_count"], 1); // day 0 + assert_eq!(cool_buckets[1]["doc_count"], 4); // day 1 + assert_eq!(cool_buckets[2]["doc_count"], 2); // day 2 + + // nohit bucket + assert_eq!(buckets[1]["key"], "nohit"); + let nohit_buckets = buckets[1]["child_date_hist"]["buckets"].as_array().unwrap(); + assert_eq!(nohit_buckets.len(), 2); + assert_eq!(nohit_buckets[0]["doc_count"], 1); // day 1 + assert_eq!(nohit_buckets[1]["doc_count"], 1); // day 2 + + Ok(()) +} + fn get_avg_req(field_name: &str) -> Aggregation { serde_json::from_value(json!({ "avg": { @@ -25,6 +451,10 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector { } // *** EVERY BUCKET-TYPE SHOULD BE TESTED HERE *** +// Note: The flushng part of these tests are outdated, since the buffering change after converting +// the collection into one collector per request instead of per bucket. +// +// However they are useful as they test a complex aggregation requests. fn test_aggregation_flushing( merge_segments: bool, use_distributed_collector: bool, @@ -37,8 +467,9 @@ fn test_aggregation_flushing( let reader = index.reader()?; - assert_eq!(DOC_BLOCK_SIZE, 64); - // In the tree we cache Documents of DOC_BLOCK_SIZE, before passing them down as one block. + assert_eq!(COLLECT_BLOCK_BUFFER_LEN, 64); + // In the tree we cache documents of COLLECT_BLOCK_BUFFER_LEN before passing them down as one + // block. // // Build a request so that on the first level we have one full cache, which is then flushed. // The same cache should have some residue docs at the end, which are flushed (Range 0-70) diff --git a/src/aggregation/bucket/filter.rs b/src/aggregation/bucket/filter.rs index 18f2a692a..73518238a 100644 --- a/src/aggregation/bucket/filter.rs +++ b/src/aggregation/bucket/filter.rs @@ -6,10 +6,14 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer}; use crate::aggregation::agg_data::{ build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx, }; +use crate::aggregation::cached_sub_aggs::{ + CachedSubAggs, HighCardSubAggCache, LowCardSubAggCache, SubAggCache, +}; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, }; -use crate::aggregation::segment_agg_result::{CollectorClone, SegmentAggregationCollector}; +use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector}; +use crate::aggregation::BucketId; use crate::docset::DocSet; use crate::query::{AllQuery, EnableScoring, Query, QueryParser}; use crate::schema::Schema; @@ -404,15 +408,18 @@ pub struct FilterAggReqData { pub evaluator: DocumentQueryEvaluator, /// Reusable buffer for matching documents to minimize allocations during collection pub matching_docs_buffer: Vec, + /// True if this filter aggregation is at the top level of the aggregation tree (not nested). + pub is_top_level: bool, } impl FilterAggReqData { pub(crate) fn get_memory_consumption(&self) -> usize { // Estimate: name + segment reader reference + bitset + buffer capacity self.name.len() - + std::mem::size_of::() - + self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes) - + self.matching_docs_buffer.capacity() * std::mem::size_of::() + + std::mem::size_of::() + + self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes) + + self.matching_docs_buffer.capacity() * std::mem::size_of::() + + std::mem::size_of::() } } @@ -489,17 +496,24 @@ impl Debug for DocumentQueryEvaluator { } } -/// Segment collector for filter aggregation -pub struct SegmentFilterCollector { - /// Document count in this bucket +#[derive(Debug, Clone, PartialEq, Copy)] +struct DocCount { doc_count: u64, + bucket_id: BucketId, +} + +/// Segment collector for filter aggregation +pub struct SegmentFilterCollector { + /// Document counts per parent bucket + parent_buckets: Vec, /// Sub-aggregation collectors - sub_aggregations: Option>, + sub_aggregations: Option>, + bucket_id_provider: BucketIdProvider, /// Accessor index for this filter aggregation (to access FilterAggReqData) accessor_idx: usize, } -impl SegmentFilterCollector { +impl SegmentFilterCollector { /// Create a new filter segment collector following the new agg_data pattern pub(crate) fn from_req_and_validate( req: &mut AggregationsSegmentCtx, @@ -511,47 +525,75 @@ impl SegmentFilterCollector { } else { None }; + let sub_agg_collector = sub_agg_collector.map(CachedSubAggs::new); Ok(SegmentFilterCollector { - doc_count: 0, + parent_buckets: Vec::new(), sub_aggregations: sub_agg_collector, accessor_idx: node.idx_in_req_data, + bucket_id_provider: BucketIdProvider::default(), }) } } -impl Debug for SegmentFilterCollector { +pub(crate) fn build_segment_filter_collector( + req: &mut AggregationsSegmentCtx, + node: &AggRefNode, +) -> crate::Result> { + let is_top_level = req.per_request.filter_req_data[node.idx_in_req_data] + .as_ref() + .expect("filter_req_data slot is empty") + .is_top_level; + + if is_top_level { + Ok(Box::new( + SegmentFilterCollector::::from_req_and_validate(req, node)?, + )) + } else { + Ok(Box::new( + SegmentFilterCollector::::from_req_and_validate(req, node)?, + )) + } +} + +impl Debug for SegmentFilterCollector { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SegmentFilterCollector") - .field("doc_count", &self.doc_count) + .field("buckets", &self.parent_buckets) .field("has_sub_aggs", &self.sub_aggregations.is_some()) .field("accessor_idx", &self.accessor_idx) .finish() } } -impl CollectorClone for SegmentFilterCollector { - fn clone_box(&self) -> Box { - // For now, panic - this needs proper implementation with weight recreation - panic!("SegmentFilterCollector cloning not yet implemented - requires weight recreation") - } -} - -impl SegmentAggregationCollector for SegmentFilterCollector { +impl SegmentAggregationCollector for SegmentFilterCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { let mut sub_results = IntermediateAggregationResults::default(); + let bucket_opt = self.parent_buckets.get(parent_bucket_id as usize); - if let Some(sub_aggs) = self.sub_aggregations { - sub_aggs.add_intermediate_aggregation_result(agg_data, &mut sub_results)?; + if let Some(sub_aggs) = &mut self.sub_aggregations { + sub_aggs + .get_sub_agg_collector() + .add_intermediate_aggregation_result( + agg_data, + &mut sub_results, + // Here we create a new bucket ID for sub-aggregations if the bucket doesn't + // exist, so that sub-aggregations can still produce results (e.g., zero doc + // count) + bucket_opt + .map(|bucket| bucket.bucket_id) + .unwrap_or(self.bucket_id_provider.next_bucket_id()), + )?; } // Create the filter bucket result let filter_bucket_result = IntermediateBucketResult::Filter { - doc_count: self.doc_count, + doc_count: bucket_opt.map(|b| b.doc_count).unwrap_or(0), sub_aggregations: sub_results, }; @@ -570,32 +612,17 @@ impl SegmentAggregationCollector for SegmentFilterCollector { Ok(()) } - fn collect(&mut self, doc: DocId, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - // Access the evaluator from FilterAggReqData - let req_data = agg_data.get_filter_req_data(self.accessor_idx); - - // O(1) BitSet lookup to check if document matches filter - if req_data.evaluator.matches_document(doc) { - self.doc_count += 1; - - // If we have sub-aggregations, collect on them for this filtered document - if let Some(sub_aggs) = &mut self.sub_aggregations { - sub_aggs.collect(doc, agg_data)?; - } - } - Ok(()) - } - - #[inline] - fn collect_block( + fn collect( &mut self, - docs: &[DocId], + parent_bucket_id: BucketId, + docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { if docs.is_empty() { return Ok(()); } + let mut bucket = self.parent_buckets[parent_bucket_id as usize]; // Take the request data to avoid borrow checker issues with sub-aggregations let mut req = agg_data.take_filter_req_data(self.accessor_idx); @@ -604,18 +631,24 @@ impl SegmentAggregationCollector for SegmentFilterCollector { req.evaluator .filter_batch(docs, &mut req.matching_docs_buffer); - self.doc_count += req.matching_docs_buffer.len() as u64; + bucket.doc_count += req.matching_docs_buffer.len() as u64; // Batch process sub-aggregations if we have matches if !req.matching_docs_buffer.is_empty() { if let Some(sub_aggs) = &mut self.sub_aggregations { - // Use collect_block for better sub-aggregation performance - sub_aggs.collect_block(&req.matching_docs_buffer, agg_data)?; + for &doc_id in &req.matching_docs_buffer { + sub_aggs.push(bucket.bucket_id, doc_id); + } } } // Put the request data back agg_data.put_back_filter_req_data(self.accessor_idx, req); + if let Some(sub_aggs) = &mut self.sub_aggregations { + sub_aggs.check_flush_local(agg_data)?; + } + // put back bucket + self.parent_buckets[parent_bucket_id as usize] = bucket; Ok(()) } @@ -626,6 +659,21 @@ impl SegmentAggregationCollector for SegmentFilterCollector { } Ok(()) } + + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + while self.parent_buckets.len() <= max_bucket as usize { + let bucket_id = self.bucket_id_provider.next_bucket_id(); + self.parent_buckets.push(DocCount { + doc_count: 0, + bucket_id, + }); + } + Ok(()) + } } /// Intermediate result for filter aggregation @@ -1519,9 +1567,9 @@ mod tests { let searcher = reader.searcher(); let agg = json!({ - "test": { - "filter": deserialized, - "aggs": { "count": { "value_count": { "field": "brand" } } } + "test": { + "filter": deserialized, + "aggs": { "count": { "value_count": { "field": "brand" } } } } }); diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs index 36c0fe57e..adf7936c6 100644 --- a/src/aggregation/bucket/histogram/histogram.rs +++ b/src/aggregation/bucket/histogram/histogram.rs @@ -1,6 +1,6 @@ use std::cmp::Ordering; -use columnar::{Column, ColumnBlockAccessor, ColumnType}; +use columnar::{Column, ColumnType}; use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; use tantivy_bitpacker::minmax; @@ -8,14 +8,14 @@ use tantivy_bitpacker::minmax; use crate::aggregation::agg_data::{ build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx, }; -use crate::aggregation::agg_limits::MemoryConsumption; use crate::aggregation::agg_req::Aggregations; use crate::aggregation::agg_result::BucketEntry; +use crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardCachedSubAggs}; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry, }; -use crate::aggregation::segment_agg_result::SegmentAggregationCollector; +use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector}; use crate::aggregation::*; use crate::TantivyError; @@ -26,13 +26,8 @@ pub struct HistogramAggReqData { pub accessor: Column, /// The field type of the fast field. pub field_type: ColumnType, - /// The column block accessor to access the fast field values. - pub column_block_accessor: ColumnBlockAccessor, /// The name of the aggregation. pub name: String, - /// The sub aggregation blueprint, used to create sub aggregations for each bucket. - /// Will be filled during initialization of the collector. - pub sub_aggregation_blueprint: Option>, /// The histogram aggregation request. pub req: HistogramAggregation, /// True if this is a date_histogram aggregation. @@ -257,18 +252,24 @@ impl HistogramBounds { pub(crate) struct SegmentHistogramBucketEntry { pub key: f64, pub doc_count: u64, + pub bucket_id: BucketId, } impl SegmentHistogramBucketEntry { pub(crate) fn into_intermediate_bucket_entry( self, - sub_aggregation: Option>, + sub_aggregation: &mut Option, agg_data: &AggregationsSegmentCtx, ) -> crate::Result { let mut sub_aggregation_res = IntermediateAggregationResults::default(); if let Some(sub_aggregation) = sub_aggregation { sub_aggregation - .add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?; + .get_sub_agg_collector() + .add_intermediate_aggregation_result( + agg_data, + &mut sub_aggregation_res, + self.bucket_id, + )?; } Ok(IntermediateHistogramBucketEntry { key: self.key, @@ -278,27 +279,38 @@ impl SegmentHistogramBucketEntry { } } +#[derive(Clone, Debug, Default)] +struct HistogramBuckets { + pub buckets: FxHashMap, +} + /// The collector puts values from the fast field into the correct buckets and does a conversion to /// the correct datatype. -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct SegmentHistogramCollector { /// The buckets containing the aggregation data. - buckets: FxHashMap, - sub_aggregations: FxHashMap>, + /// One Histogram bucket per parent bucket id. + parent_buckets: Vec, + sub_agg: Option, accessor_idx: usize, + bucket_id_provider: BucketIdProvider, } impl SegmentAggregationCollector for SegmentHistogramCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { let name = agg_data .get_histogram_req_data(self.accessor_idx) .name .clone(); - let bucket = self.into_intermediate_bucket_result(agg_data)?; + // TODO: avoid prepare_max_bucket here and handle empty buckets. + self.prepare_max_bucket(parent_bucket_id, agg_data)?; + let histogram = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]); + let bucket = self.add_intermediate_bucket_result(agg_data, histogram)?; results.push(name, IntermediateAggregationResult::Bucket(bucket))?; Ok(()) @@ -307,44 +319,40 @@ impl SegmentAggregationCollector for SegmentHistogramCollector { #[inline] fn collect( &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.collect_block(&[doc], agg_data) - } - - #[inline] - fn collect_block( - &mut self, + parent_bucket_id: BucketId, docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let mut req = agg_data.take_histogram_req_data(self.accessor_idx); + let req = agg_data.take_histogram_req_data(self.accessor_idx); let mem_pre = self.get_memory_consumption(); + let buckets = &mut self.parent_buckets[parent_bucket_id as usize].buckets; let bounds = req.bounds; let interval = req.req.interval; let offset = req.offset; let get_bucket_pos = |val| get_bucket_pos_f64(val, interval, offset) as i64; - req.column_block_accessor.fetch_block(docs, &req.accessor); - for (doc, val) in req + agg_data + .column_block_accessor + .fetch_block(docs, &req.accessor); + for (doc, val) in agg_data .column_block_accessor .iter_docid_vals(docs, &req.accessor) { - let val = f64_from_fastfield_u64(val, &req.field_type); + let val = f64_from_fastfield_u64(val, req.field_type); let bucket_pos = get_bucket_pos(val); if bounds.contains(val) { - let bucket = self.buckets.entry(bucket_pos).or_insert_with(|| { + let bucket = buckets.entry(bucket_pos).or_insert_with(|| { let key = get_bucket_key_from_pos(bucket_pos as f64, interval, offset); - SegmentHistogramBucketEntry { key, doc_count: 0 } + SegmentHistogramBucketEntry { + key, + doc_count: 0, + bucket_id: self.bucket_id_provider.next_bucket_id(), + } }); bucket.doc_count += 1; - if let Some(sub_aggregation_blueprint) = req.sub_aggregation_blueprint.as_ref() { - self.sub_aggregations - .entry(bucket_pos) - .or_insert_with(|| sub_aggregation_blueprint.clone()) - .collect(doc, agg_data)?; + if let Some(sub_agg) = &mut self.sub_agg { + sub_agg.push(bucket.bucket_id, doc); } } } @@ -358,14 +366,30 @@ impl SegmentAggregationCollector for SegmentHistogramCollector { .add_memory_consumed(mem_delta as u64)?; } + if let Some(sub_agg) = &mut self.sub_agg { + sub_agg.check_flush_local(agg_data)?; + } + Ok(()) } fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - for sub_aggregation in self.sub_aggregations.values_mut() { + if let Some(sub_aggregation) = &mut self.sub_agg { sub_aggregation.flush(agg_data)?; } + Ok(()) + } + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + while self.parent_buckets.len() <= max_bucket as usize { + self.parent_buckets.push(HistogramBuckets { + buckets: FxHashMap::default(), + }); + } Ok(()) } } @@ -373,22 +397,19 @@ impl SegmentAggregationCollector for SegmentHistogramCollector { impl SegmentHistogramCollector { fn get_memory_consumption(&self) -> usize { let self_mem = std::mem::size_of::(); - let sub_aggs_mem = self.sub_aggregations.memory_consumption(); - let buckets_mem = self.buckets.memory_consumption(); - self_mem + sub_aggs_mem + buckets_mem + let buckets_mem = self.parent_buckets.len() * std::mem::size_of::(); + self_mem + buckets_mem } /// Converts the collector result into a intermediate bucket result. - pub fn into_intermediate_bucket_result( - self, + fn add_intermediate_bucket_result( + &mut self, agg_data: &AggregationsSegmentCtx, + histogram: HistogramBuckets, ) -> crate::Result { - let mut buckets = Vec::with_capacity(self.buckets.len()); + let mut buckets = Vec::with_capacity(histogram.buckets.len()); - for (bucket_pos, bucket) in self.buckets { - let bucket_res = bucket.into_intermediate_bucket_entry( - self.sub_aggregations.get(&bucket_pos).cloned(), - agg_data, - ); + for bucket in histogram.buckets.into_values() { + let bucket_res = bucket.into_intermediate_bucket_entry(&mut self.sub_agg, agg_data); buckets.push(bucket_res?); } @@ -408,7 +429,7 @@ impl SegmentHistogramCollector { agg_data: &mut AggregationsSegmentCtx, node: &AggRefNode, ) -> crate::Result { - let blueprint = if !node.children.is_empty() { + let sub_agg = if !node.children.is_empty() { Some(build_segment_agg_collectors(agg_data, &node.children)?) } else { None @@ -423,13 +444,13 @@ impl SegmentHistogramCollector { max: f64::MAX, }); req_data.offset = req_data.req.offset.unwrap_or(0.0); - - req_data.sub_aggregation_blueprint = blueprint; + let sub_agg = sub_agg.map(CachedSubAggs::new); Ok(Self { - buckets: Default::default(), - sub_aggregations: Default::default(), + parent_buckets: Default::default(), + sub_agg, accessor_idx: node.idx_in_req_data, + bucket_id_provider: BucketIdProvider::default(), }) } } diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs index c26872e9b..46e0065ce 100644 --- a/src/aggregation/bucket/range.rs +++ b/src/aggregation/bucket/range.rs @@ -1,18 +1,22 @@ use std::fmt::Debug; use std::ops::Range; -use columnar::{Column, ColumnBlockAccessor, ColumnType}; +use columnar::{Column, ColumnType}; use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; use crate::aggregation::agg_data::{ build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx, }; +use crate::aggregation::agg_limits::AggregationLimitsGuard; +use crate::aggregation::cached_sub_aggs::{ + CachedSubAggs, HighCardSubAggCache, LowCardCachedSubAggs, LowCardSubAggCache, SubAggCache, +}; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, IntermediateRangeBucketEntry, IntermediateRangeBucketResult, }; -use crate::aggregation::segment_agg_result::SegmentAggregationCollector; +use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector}; use crate::aggregation::*; use crate::TantivyError; @@ -23,12 +27,12 @@ pub struct RangeAggReqData { pub accessor: Column, /// The type of the fast field. pub field_type: ColumnType, - /// The column block accessor to access the fast field values. - pub column_block_accessor: ColumnBlockAccessor, /// The range aggregation request. pub req: RangeAggregation, /// The name of the aggregation. pub name: String, + /// Whether this is a top-level aggregation. + pub is_top_level: bool, } impl RangeAggReqData { @@ -151,19 +155,47 @@ pub(crate) struct SegmentRangeAndBucketEntry { /// The collector puts values from the fast field into the correct buckets and does a conversion to /// the correct datatype. -#[derive(Clone, Debug)] -pub struct SegmentRangeCollector { +pub struct SegmentRangeCollector { /// The buckets containing the aggregation data. - buckets: Vec, + /// One for each ParentBucketId + parent_buckets: Vec>, column_type: ColumnType, pub(crate) accessor_idx: usize, + sub_agg: Option>, + /// Here things get a bit weird. We need to assign unique bucket ids across all + /// parent buckets. So we keep track of the next available bucket id here. + /// This allows a kind of flattening of the bucket ids across all parent buckets. + /// E.g. in nested aggregations: + /// Term Agg -> Range aggregation -> Stats aggregation + /// E.g. the Term Agg creates 3 buckets ["INFO", "ERROR", "WARN"], each of these has a Range + /// aggregation with 4 buckets. The Range aggregation will create buckets with ids: + /// - INFO: 0,1,2,3 + /// - ERROR: 4,5,6,7 + /// - WARN: 8,9,10,11 + /// + /// This allows the Stats aggregation to have unique bucket ids to refer to. + bucket_id_provider: BucketIdProvider, + limits: AggregationLimitsGuard, } +impl Debug for SegmentRangeCollector { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SegmentRangeCollector") + .field("parent_buckets_len", &self.parent_buckets.len()) + .field("column_type", &self.column_type) + .field("accessor_idx", &self.accessor_idx) + .field("has_sub_agg", &self.sub_agg.is_some()) + .finish() + } +} + +/// TODO: Bad naming, there's also SegmentRangeAndBucketEntry #[derive(Clone)] pub(crate) struct SegmentRangeBucketEntry { pub key: Key, pub doc_count: u64, - pub sub_aggregation: Option>, + // pub sub_aggregation: Option>, + pub bucket_id: BucketId, /// The from range of the bucket. Equals `f64::MIN` when `None`. pub from: Option, /// The to range of the bucket. Equals `f64::MAX` when `None`. Open interval, `to` is not @@ -184,48 +216,50 @@ impl Debug for SegmentRangeBucketEntry { impl SegmentRangeBucketEntry { pub(crate) fn into_intermediate_bucket_entry( self, - agg_data: &AggregationsSegmentCtx, ) -> crate::Result { - let mut sub_aggregation_res = IntermediateAggregationResults::default(); - if let Some(sub_aggregation) = self.sub_aggregation { - sub_aggregation - .add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)? - } else { - Default::default() - }; + let sub_aggregation = IntermediateAggregationResults::default(); Ok(IntermediateRangeBucketEntry { key: self.key.into(), doc_count: self.doc_count, - sub_aggregation: sub_aggregation_res, + sub_aggregation_res: sub_aggregation, from: self.from, to: self.to, }) } } -impl SegmentAggregationCollector for SegmentRangeCollector { +impl SegmentAggregationCollector for SegmentRangeCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { + self.prepare_max_bucket(parent_bucket_id, agg_data)?; let field_type = self.column_type; let name = agg_data .get_range_req_data(self.accessor_idx) .name .to_string(); - let buckets: FxHashMap = self - .buckets + let buckets = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]); + + let buckets: FxHashMap = buckets .into_iter() - .map(move |range_bucket| { - Ok(( - range_to_string(&range_bucket.range, &field_type)?, - range_bucket - .bucket - .into_intermediate_bucket_entry(agg_data)?, - )) + .map(|range_bucket| { + let bucket_id = range_bucket.bucket.bucket_id; + let mut agg = range_bucket.bucket.into_intermediate_bucket_entry()?; + if let Some(sub_aggregation) = &mut self.sub_agg { + sub_aggregation + .get_sub_agg_collector() + .add_intermediate_aggregation_result( + agg_data, + &mut agg.sub_aggregation_res, + bucket_id, + )?; + } + Ok((range_to_string(&range_bucket.range, &field_type)?, agg)) }) .collect::>()?; @@ -242,73 +276,114 @@ impl SegmentAggregationCollector for SegmentRangeCollector { #[inline] fn collect( &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.collect_block(&[doc], agg_data) - } - - #[inline] - fn collect_block( - &mut self, + parent_bucket_id: BucketId, docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - // Take request data to avoid borrow conflicts during sub-aggregation - let mut req = agg_data.take_range_req_data(self.accessor_idx); + let req = agg_data.take_range_req_data(self.accessor_idx); - req.column_block_accessor.fetch_block(docs, &req.accessor); + agg_data + .column_block_accessor + .fetch_block(docs, &req.accessor); - for (doc, val) in req + let buckets = &mut self.parent_buckets[parent_bucket_id as usize]; + + for (doc, val) in agg_data .column_block_accessor .iter_docid_vals(docs, &req.accessor) { - let bucket_pos = self.get_bucket_pos(val); - let bucket = &mut self.buckets[bucket_pos]; + let bucket_pos = get_bucket_pos(val, buckets); + let bucket = &mut buckets[bucket_pos]; bucket.bucket.doc_count += 1; - if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() { - sub_agg.collect(doc, agg_data)?; + if let Some(sub_agg) = self.sub_agg.as_mut() { + sub_agg.push(bucket.bucket.bucket_id, doc); } } agg_data.put_back_range_req_data(self.accessor_idx, req); + if let Some(sub_agg) = self.sub_agg.as_mut() { + sub_agg.check_flush_local(agg_data)?; + } Ok(()) } fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - for bucket in self.buckets.iter_mut() { - if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() { - sub_agg.flush(agg_data)?; - } + if let Some(sub_agg) = self.sub_agg.as_mut() { + sub_agg.flush(agg_data)?; } Ok(()) } + + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + while self.parent_buckets.len() <= max_bucket as usize { + let new_buckets = self.create_new_buckets(agg_data)?; + self.parent_buckets.push(new_buckets); + } + + Ok(()) + } +} +/// Build a concrete `SegmentRangeCollector` with either a Vec- or HashMap-backed +/// bucket storage, depending on the column type and aggregation level. +pub(crate) fn build_segment_range_collector( + agg_data: &mut AggregationsSegmentCtx, + node: &AggRefNode, +) -> crate::Result> { + let accessor_idx = node.idx_in_req_data; + let req_data = agg_data.get_range_req_data(node.idx_in_req_data); + let field_type = req_data.field_type; + + // TODO: A better metric instead of is_top_level would be the number of buckets expected. + // E.g. If range agg is not top level, but the parent is a bucket agg with less than 10 buckets, + // we can are still in low cardinality territory. + let is_low_card = req_data.is_top_level && req_data.req.ranges.len() <= 64; + + let sub_agg = if !node.children.is_empty() { + Some(build_segment_agg_collectors(agg_data, &node.children)?) + } else { + None + }; + + if is_low_card { + Ok(Box::new(SegmentRangeCollector:: { + sub_agg: sub_agg.map(LowCardCachedSubAggs::new), + column_type: field_type, + accessor_idx, + parent_buckets: Vec::new(), + bucket_id_provider: BucketIdProvider::default(), + limits: agg_data.context.limits.clone(), + })) + } else { + Ok(Box::new(SegmentRangeCollector:: { + sub_agg: sub_agg.map(CachedSubAggs::new), + column_type: field_type, + accessor_idx, + parent_buckets: Vec::new(), + bucket_id_provider: BucketIdProvider::default(), + limits: agg_data.context.limits.clone(), + })) + } } -impl SegmentRangeCollector { - pub(crate) fn from_req_and_validate( - req_data: &mut AggregationsSegmentCtx, - node: &AggRefNode, - ) -> crate::Result { - let accessor_idx = node.idx_in_req_data; - let (field_type, ranges) = { - let req_view = req_data.get_range_req_data(node.idx_in_req_data); - (req_view.field_type, req_view.req.ranges.clone()) - }; - +impl SegmentRangeCollector { + pub(crate) fn create_new_buckets( + &mut self, + agg_data: &AggregationsSegmentCtx, + ) -> crate::Result> { + let field_type = self.column_type; + let req_data = agg_data.get_range_req_data(self.accessor_idx); // The range input on the request is f64. // We need to convert to u64 ranges, because we read the values as u64. // The mapping from the conversion is monotonic so ordering is preserved. - let sub_agg_prototype = if !node.children.is_empty() { - Some(build_segment_agg_collectors(req_data, &node.children)?) - } else { - None - }; - - let buckets: Vec<_> = extend_validate_ranges(&ranges, &field_type)? + let buckets: Vec<_> = extend_validate_ranges(&req_data.req.ranges, &field_type)? .iter() .map(|range| { + let bucket_id = self.bucket_id_provider.next_bucket_id(); let key = range .key .clone() @@ -317,20 +392,20 @@ impl SegmentRangeCollector { let to = if range.range.end == u64::MAX { None } else { - Some(f64_from_fastfield_u64(range.range.end, &field_type)) + Some(f64_from_fastfield_u64(range.range.end, field_type)) }; let from = if range.range.start == u64::MIN { None } else { - Some(f64_from_fastfield_u64(range.range.start, &field_type)) + Some(f64_from_fastfield_u64(range.range.start, field_type)) }; - let sub_aggregation = sub_agg_prototype.clone(); + // let sub_aggregation = sub_agg_prototype.clone(); Ok(SegmentRangeAndBucketEntry { range: range.range.clone(), bucket: SegmentRangeBucketEntry { doc_count: 0, - sub_aggregation, + bucket_id, key, from, to, @@ -339,27 +414,20 @@ impl SegmentRangeCollector { }) .collect::>()?; - req_data.context.limits.add_memory_consumed( + self.limits.add_memory_consumed( buckets.len() as u64 * std::mem::size_of::() as u64, )?; - - Ok(SegmentRangeCollector { - buckets, - column_type: field_type, - accessor_idx, - }) - } - - #[inline] - fn get_bucket_pos(&self, val: u64) -> usize { - let pos = self - .buckets - .binary_search_by_key(&val, |probe| probe.range.start) - .unwrap_or_else(|pos| pos - 1); - debug_assert!(self.buckets[pos].range.contains(&val)); - pos + Ok(buckets) } } +#[inline] +fn get_bucket_pos(val: u64, buckets: &[SegmentRangeAndBucketEntry]) -> usize { + let pos = buckets + .binary_search_by_key(&val, |probe| probe.range.start) + .unwrap_or_else(|pos| pos - 1); + debug_assert!(buckets[pos].range.contains(&val)); + pos +} /// Converts the user provided f64 range value to fast field value space. /// @@ -456,7 +524,7 @@ pub(crate) fn range_to_string( let val = i64::from_u64(val); format_date(val) } else { - Ok(f64_from_fastfield_u64(val, field_type).to_string()) + Ok(f64_from_fastfield_u64(val, *field_type).to_string()) } }; @@ -486,7 +554,7 @@ mod tests { pub fn get_collector_from_ranges( ranges: Vec, field_type: ColumnType, - ) -> SegmentRangeCollector { + ) -> SegmentRangeCollector { let req = RangeAggregation { field: "dummy".to_string(), ranges, @@ -506,30 +574,33 @@ mod tests { let to = if range.range.end == u64::MAX { None } else { - Some(f64_from_fastfield_u64(range.range.end, &field_type)) + Some(f64_from_fastfield_u64(range.range.end, field_type)) }; let from = if range.range.start == u64::MIN { None } else { - Some(f64_from_fastfield_u64(range.range.start, &field_type)) + Some(f64_from_fastfield_u64(range.range.start, field_type)) }; SegmentRangeAndBucketEntry { range: range.range.clone(), bucket: SegmentRangeBucketEntry { doc_count: 0, - sub_aggregation: None, key, from, to, + bucket_id: 0, }, } }) .collect(); SegmentRangeCollector { - buckets, + parent_buckets: vec![buckets], column_type: field_type, accessor_idx: 0, + sub_agg: None, + bucket_id_provider: Default::default(), + limits: AggregationLimitsGuard::default(), } } @@ -776,7 +847,7 @@ mod tests { let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()]; let collector = get_collector_from_ranges(buckets, ColumnType::F64); - let buckets = collector.buckets; + let buckets = collector.parent_buckets[0].clone(); assert_eq!(buckets[0].range.start, u64::MIN); assert_eq!(buckets[0].range.end, 10f64.to_u64()); assert_eq!(buckets[1].range.start, 10f64.to_u64()); @@ -799,7 +870,7 @@ mod tests { ]; let collector = get_collector_from_ranges(buckets, ColumnType::F64); - let buckets = collector.buckets; + let buckets = collector.parent_buckets[0].clone(); assert_eq!(buckets[0].range.start, u64::MIN); assert_eq!(buckets[0].range.end, 10f64.to_u64()); assert_eq!(buckets[1].range.start, 10f64.to_u64()); @@ -814,7 +885,7 @@ mod tests { let buckets = vec![(-10f64..-1f64).into()]; let collector = get_collector_from_ranges(buckets, ColumnType::F64); - let buckets = collector.buckets; + let buckets = collector.parent_buckets[0].clone(); assert_eq!(&buckets[0].bucket.key.to_string(), "*--10"); assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "-1-*"); } @@ -823,7 +894,7 @@ mod tests { let buckets = vec![(0f64..10f64).into()]; let collector = get_collector_from_ranges(buckets, ColumnType::F64); - let buckets = collector.buckets; + let buckets = collector.parent_buckets[0].clone(); assert_eq!(&buckets[0].bucket.key.to_string(), "*-0"); assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "10-*"); } @@ -832,7 +903,7 @@ mod tests { fn range_binary_search_test_u64() { let check_ranges = |ranges: Vec| { let collector = get_collector_from_ranges(ranges, ColumnType::U64); - let search = |val: u64| collector.get_bucket_pos(val); + let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]); assert_eq!(search(u64::MIN), 0); assert_eq!(search(9), 0); @@ -878,7 +949,7 @@ mod tests { let ranges = vec![(10.0..100.0).into()]; let collector = get_collector_from_ranges(ranges, ColumnType::F64); - let search = |val: u64| collector.get_bucket_pos(val); + let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]); assert_eq!(search(u64::MIN), 0); assert_eq!(search(9f64.to_u64()), 0); @@ -890,63 +961,3 @@ mod tests { // the max value } } - -#[cfg(all(test, feature = "unstable"))] -mod bench { - - use itertools::Itertools; - use rand::seq::SliceRandom; - use rand::thread_rng; - - use super::*; - use crate::aggregation::bucket::range::tests::get_collector_from_ranges; - - const TOTAL_DOCS: u64 = 1_000_000u64; - const NUM_DOCS: u64 = 50_000u64; - - fn get_collector_with_buckets(num_buckets: u64, num_docs: u64) -> SegmentRangeCollector { - let bucket_size = num_docs / num_buckets; - let mut buckets: Vec = vec![]; - for i in 0..num_buckets { - let bucket_start = (i * bucket_size) as f64; - buckets.push((bucket_start..bucket_start + bucket_size as f64).into()) - } - - get_collector_from_ranges(buckets, ColumnType::U64) - } - - fn get_rand_docs(total_docs: u64, num_docs_returned: u64) -> Vec { - let mut rng = thread_rng(); - - let all_docs = (0..total_docs - 1).collect_vec(); - let mut vals = all_docs - .as_slice() - .choose_multiple(&mut rng, num_docs_returned as usize) - .cloned() - .collect_vec(); - vals.sort(); - vals - } - - fn bench_range_binary_search(b: &mut test::Bencher, num_buckets: u64) { - let collector = get_collector_with_buckets(num_buckets, TOTAL_DOCS); - let vals = get_rand_docs(TOTAL_DOCS, NUM_DOCS); - b.iter(|| { - let mut bucket_pos = 0; - for val in &vals { - bucket_pos = collector.get_bucket_pos(*val); - } - bucket_pos - }) - } - - #[bench] - fn bench_range_100_buckets(b: &mut test::Bencher) { - bench_range_binary_search(b, 100) - } - - #[bench] - fn bench_range_10_buckets(b: &mut test::Bencher) { - bench_range_binary_search(b, 10) - } -} diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index d87cd0078..ed2793bd1 100644 --- a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg.rs @@ -4,10 +4,10 @@ use std::net::Ipv6Addr; use columnar::column_values::CompactSpaceU64Accessor; use columnar::{ - Column, ColumnBlockAccessor, ColumnType, Dictionary, MonotonicallyMappableToU128, - MonotonicallyMappableToU64, NumericalValue, StrColumn, + Column, ColumnType, Dictionary, MonotonicallyMappableToU128, MonotonicallyMappableToU64, + NumericalValue, StrColumn, }; -use common::BitSet; +use common::{BitSet, TinySet}; use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; @@ -17,18 +17,21 @@ use crate::aggregation::agg_data::{ }; use crate::aggregation::agg_limits::MemoryConsumption; use crate::aggregation::agg_req::Aggregations; -use crate::aggregation::buf_collector::BufAggregationCollector; +use crate::aggregation::cached_sub_aggs::{ + CachedSubAggs, HighCardSubAggCache, LowCardCachedSubAggs, LowCardSubAggCache, SubAggCache, +}; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult, }; -use crate::aggregation::segment_agg_result::SegmentAggregationCollector; -use crate::aggregation::{format_date, Key}; +use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector}; +use crate::aggregation::{format_date, BucketId, Key}; use crate::error::DataCorruption; use crate::TantivyError; /// Contains all information required by the SegmentTermCollector to perform the /// terms aggregation on a segment. +#[derive(Debug, Clone)] pub struct TermsAggReqData { /// The column accessor to access the fast field values. pub accessor: Column, @@ -38,10 +41,6 @@ pub struct TermsAggReqData { pub str_dict_column: Option, /// The missing value as u64 value. pub missing_value_for_accessor: Option, - /// The column block accessor to access the fast field values. - pub column_block_accessor: ColumnBlockAccessor, - /// Note: sub_aggregation_blueprint is filled later when building collectors - pub sub_aggregation_blueprint: Option>, /// Used to build the correct nested result when we have an empty result. pub sug_aggregations: Aggregations, /// The name of the aggregation. @@ -257,9 +256,9 @@ pub struct TermsAggregation { /// Internally, `missing` requires some specialized handling in some scenarios. /// /// Simple Case: - /// In the simplest case, we can just put the missing value in the termmap use that. In case of - /// text we put a special u64::MAX and replace it at the end with the actual missing value, - /// when loading the text. + /// In the simplest case, we can just put the missing value in the termmap and use that. In + /// case of text we put a special u64::MAX and replace it at the end with the actual + /// missing value, when loading the text. /// Special Case 1: /// If we have multiple columns on one field, we need to have a union on the indices on both /// columns, to find docids without a value. That requires a special missing aggregation. @@ -334,85 +333,9 @@ impl TermsAggregationInternal { } } -impl<'a> From<&'a dyn SegmentAggregationCollector> for BufAggregationCollector { - #[inline(always)] - fn from(sub_agg_blueprint_opt: &'a dyn SegmentAggregationCollector) -> Self { - let sub_agg = sub_agg_blueprint_opt.clone_box(); - BufAggregationCollector::new(sub_agg) - } -} - -#[derive(Debug, Clone)] -struct BoxedAggregation(Box); - -impl<'a> From<&'a dyn SegmentAggregationCollector> for BoxedAggregation { - #[inline(always)] - fn from(sub_agg_blueprint: &'a dyn SegmentAggregationCollector) -> Self { - BoxedAggregation(sub_agg_blueprint.clone_box()) - } -} - -impl SegmentAggregationCollector for BoxedAggregation { - #[inline(always)] - fn add_intermediate_aggregation_result( - self: Box, - agg_data: &AggregationsSegmentCtx, - results: &mut IntermediateAggregationResults, - ) -> crate::Result<()> { - self.0 - .add_intermediate_aggregation_result(agg_data, results) - } - - #[inline(always)] - fn collect( - &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.0.collect(doc, agg_data) - } - - #[inline(always)] - fn collect_block( - &mut self, - docs: &[crate::DocId], - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.0.collect_block(docs, agg_data) - } -} - -#[derive(Debug, Clone, Copy)] -struct NoSubAgg; - -impl SegmentAggregationCollector for NoSubAgg { - #[inline(always)] - fn add_intermediate_aggregation_result( - self: Box, - _agg_data: &AggregationsSegmentCtx, - _results: &mut IntermediateAggregationResults, - ) -> crate::Result<()> { - Ok(()) - } - - #[inline(always)] - fn collect( - &mut self, - _doc: crate::DocId, - _agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - Ok(()) - } - - #[inline(always)] - fn collect_block( - &mut self, - _docs: &[crate::DocId], - _agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - Ok(()) - } -} +/// The treshold for maximum number of terms to use a Vec-backed bucket storage. +/// TODO: Benchmark to validate the threshold +pub const MAX_NUM_TERMS_FOR_VEC: u64 = 100; /// Build a concrete `SegmentTermCollector` with either a Vec- or HashMap-backed /// bucket storage, depending on the column type and aggregation level. @@ -420,11 +343,8 @@ pub(crate) fn build_segment_term_collector( req_data: &mut AggregationsSegmentCtx, node: &AggRefNode, ) -> crate::Result> { - let accessor_idx = node.idx_in_req_data; - let column_type = { - let terms_req_data = req_data.get_term_req_data(accessor_idx); - terms_req_data.column_type - }; + let terms_req_data = req_data.get_term_req_data(node.idx_in_req_data).clone(); + let column_type = terms_req_data.column_type; if column_type == ColumnType::Bytes { return Err(TantivyError::InvalidArgument(format!( @@ -434,7 +354,6 @@ pub(crate) fn build_segment_term_collector( // Validate sub aggregation exists when ordering by sub-aggregation. { - let terms_req_data = req_data.get_term_req_data(accessor_idx); if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target { let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name); @@ -450,127 +369,115 @@ pub(crate) fn build_segment_term_collector( // Build sub-aggregation blueprint if there are children. let has_sub_aggregations = !node.children.is_empty(); - let blueprint = if has_sub_aggregations { - let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?; - Some(sub_aggregation) - } else { - None - }; - { - let terms_req_data_mut = req_data.get_term_req_data_mut(accessor_idx); - terms_req_data_mut.sub_aggregation_blueprint = blueprint; - } - - // Decide whether to use a Vec-backed or HashMap-backed bucket storage. - let terms_req_data = req_data.get_term_req_data(accessor_idx); // TODO: A better metric instead of is_top_level would be the number of buckets expected. // E.g. If term agg is not top level, but the parent is a bucket agg with less than 10 buckets, // we can still use Vec. - let can_use_vec = terms_req_data.is_top_level; - - // TODO: Benchmark to validate the threshold - const MAX_NUM_TERMS_FOR_VEC: usize = 100; + let is_top_level = terms_req_data.is_top_level; // Let's see if we can use a vec to aggregate our data // instead of a hashmap. let col_max_value = terms_req_data.accessor.max_value(); - let max_term: usize = - col_max_value.max(terms_req_data.missing_value_for_accessor.unwrap_or(0u64)) as usize; + let max_term_id: u64 = + col_max_value.max(terms_req_data.missing_value_for_accessor.unwrap_or(0u64)); - // - use a Vec instead of a hashmap for our aggregation. - // - buffer aggregation of our child aggregations (in any) - #[allow(clippy::collapsible_else_if)] - if can_use_vec && max_term < MAX_NUM_TERMS_FOR_VEC { - if has_sub_aggregations { - let sub_agg_blueprint = &req_data - .get_term_req_data_mut(accessor_idx) - .sub_aggregation_blueprint - .as_ref() - .ok_or_else(|| { - // Handle the error case here - // For example, return an error message or a default value - TantivyError::InternalError("Sub-aggregation blueprint not found".to_string()) - })?; - let term_buckets = VecTermBuckets::new(max_term + 1, || { - let collector_clone = sub_agg_blueprint.clone_box(); - BufAggregationCollector::new(collector_clone) - }); - let collector = SegmentTermCollector { - term_buckets, - accessor_idx, - }; - Ok(Box::new(collector)) - } else { - let term_buckets = VecTermBuckets::new(max_term + 1, || NoSubAgg); - let collector = SegmentTermCollector { - term_buckets, - accessor_idx, - }; - Ok(Box::new(collector)) - } + let sub_agg_collector = if has_sub_aggregations { + Some(build_segment_agg_collectors(req_data, &node.children)?) } else { - if has_sub_aggregations { - let term_buckets: HashMapTermBuckets = HashMapTermBuckets::default(); - let collector: SegmentTermCollector> = - SegmentTermCollector { - term_buckets, - accessor_idx, - }; - Ok(Box::new(collector)) - } else { - let term_buckets: HashMapTermBuckets = HashMapTermBuckets::default(); - let collector: SegmentTermCollector> = - SegmentTermCollector { - term_buckets, - accessor_idx, - }; - Ok(Box::new(collector)) - } + None + }; + + let mut bucket_id_provider = BucketIdProvider::default(); + // Decide which bucket storage is best suited for this aggregation. + if is_top_level && max_term_id < MAX_NUM_TERMS_FOR_VEC && !has_sub_aggregations { + let term_buckets = VecTermBucketsNoAgg::new(max_term_id + 1, &mut bucket_id_provider); + let collector: SegmentTermCollector<_, HighCardSubAggCache> = SegmentTermCollector { + parent_buckets: vec![term_buckets], + sub_agg: None, + bucket_id_provider, + max_term_id, + terms_req_data, + }; + Ok(Box::new(collector)) + } else if is_top_level && max_term_id < MAX_NUM_TERMS_FOR_VEC { + let term_buckets = VecTermBuckets::new(max_term_id + 1, &mut bucket_id_provider); + let sub_agg = sub_agg_collector.map(LowCardCachedSubAggs::new); + let collector: SegmentTermCollector<_, LowCardSubAggCache> = SegmentTermCollector { + parent_buckets: vec![term_buckets], + sub_agg, + bucket_id_provider, + max_term_id, + terms_req_data, + }; + Ok(Box::new(collector)) + } else if max_term_id < 8_000_000 && is_top_level { + let term_buckets: PagedTermMap = + PagedTermMap::new(max_term_id + 1, &mut bucket_id_provider); + // Build sub-aggregation blueprint (flat pairs) + let sub_agg = sub_agg_collector.map(CachedSubAggs::new); + let collector: SegmentTermCollector = + SegmentTermCollector { + parent_buckets: vec![term_buckets], + sub_agg, + bucket_id_provider, + max_term_id, + terms_req_data, + }; + Ok(Box::new(collector)) + } else { + let term_buckets: HashMapTermBuckets = HashMapTermBuckets::default(); + // Build sub-aggregation blueprint (flat pairs) + let sub_agg = sub_agg_collector.map(CachedSubAggs::new); + let collector: SegmentTermCollector = + SegmentTermCollector { + parent_buckets: vec![term_buckets], + sub_agg, + bucket_id_provider, + max_term_id, + terms_req_data, + }; + Ok(Box::new(collector)) } } -#[derive(Debug, Clone)] -struct Bucket { +#[derive(Debug, Clone, Copy, Default)] +struct Bucket { pub count: u32, - pub sub_agg: SubAgg, + pub bucket_id: BucketId, } -impl Bucket { +impl Bucket { #[inline(always)] - fn new(sub_agg: SubAgg) -> Self { - Self { count: 0, sub_agg } + fn new(bucket_id: BucketId) -> Self { + Self { + count: 0, + bucket_id, + } } } /// Abstraction over the storage used for term buckets (counts only). trait TermAggregationMap: Clone + Debug + 'static { - type SubAggregation: SegmentAggregationCollector + Debug + Clone + 'static; + /// Create a new instance with a strict upper bound on term ids. + fn new(max_term_id: u64, bucket_id_provider: &mut BucketIdProvider) -> Self; /// Estimate the memory consumption of this struct in bytes. fn get_memory_consumption(&self) -> usize; - /// Returns the bucket associated to a given term_id. - fn term_entry( - &mut self, - term_id: u64, - blue_print: &dyn SegmentAggregationCollector, - ) -> &mut Bucket; - - /// If the tree of aggregations contains buffered aggregations, flush them. - fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()>; + /// Increments the count and returns the bucket_id associated to a given term_id. + fn term_entry(&mut self, term_id: u64, bucket_id_provider: &mut BucketIdProvider) -> BucketId; /// Returns the term aggregation as a vector of (term_id, bucket) pairs, /// in any order. - fn into_vec(self) -> Vec<(u64, Bucket)>; + fn into_vec(self) -> Vec<(u64, Bucket)>; } #[derive(Clone, Debug)] -struct HashMapTermBuckets { - bucket_map: FxHashMap>, +struct HashMapTermBuckets { + bucket_map: FxHashMap, } -impl Default for HashMapTermBuckets { +impl Default for HashMapTermBuckets { #[inline(always)] fn default() -> Self { Self { @@ -579,83 +486,188 @@ impl Default for HashMapTermBuckets { } } -impl< - SubAgg: Debug - + Clone - + SegmentAggregationCollector - + for<'a> From<&'a dyn SegmentAggregationCollector> - + 'static, - > TermAggregationMap for HashMapTermBuckets -{ - type SubAggregation = SubAgg; +const PAGE_SHIFT: usize = 10; +const PAGE_SIZE: usize = 1 << PAGE_SHIFT; // 1024 +const PAGE_MASK: usize = PAGE_SIZE - 1; +const BITMASK_LEN: usize = PAGE_SIZE / 64; +#[derive(Clone, Debug)] +struct Page { + /// Bitmask indicating which offsets are present. + /// It is chunked into TinySet words. + presence: [TinySet; BITMASK_LEN], + data: [Bucket; PAGE_SIZE], +} + +impl Page { + fn new() -> Self { + Self { + presence: [TinySet::empty(); BITMASK_LEN], + data: [Bucket::default(); PAGE_SIZE], + } + } + + #[inline] + fn is_set(&self, offset: usize) -> bool { + let bucket_idx = offset / 64; + let bit_idx = offset % 64; + self.presence[bucket_idx].contains(bit_idx as u32) + } + + #[inline] + fn set_present(&mut self, offset: usize) { + let bucket_idx = offset / 64; + let bit_idx = offset % 64; + self.presence[bucket_idx].insert_mut(bit_idx as u32); + } + + // Flattened iteration logic + fn collect_items(&self, base_term_id: u64, result: &mut Vec<(u64, Bucket)>) { + for (bucket_pos, &tiny_set) in self.presence.iter().enumerate() { + let base_offset = bucket_pos * 64; + + for bit in tiny_set.into_iter() { + let offset = base_offset + bit as usize; + result.push((base_term_id + offset as u64, self.data[offset])); + } + } + } +} + +/// A paged term map implementation for moderate sized term id sets. +/// Uses a fixed size vector of pages, each page containing a fixed size array of buckets. +/// +/// Each page covers a range of term ids. Pages are allocated on demand. +/// This implementation is more memory efficient than a full Vec for high cardinality term id sets, +/// +/// It has a fixed cost of `num_pages * 8 bytes` for the page directory. +/// For 1 million terms, this is 8 * 1024 = 8KB. +/// +/// Note that for nested aggregations we create one TermAggregationMap per parent bucket. +/// For example, with 100 parent buckets and 1 million terms, this is 800KB overhead for the page +/// directories only. Therefore, this implementation is only enabled for top-level aggregations +/// TODO: pass expected number of buckets from parent instead of strict is_top_level flag. +#[derive(Clone, Debug, Default)] +struct PagedTermMap { + // Fixed size vector based on max_term_id + pages: Vec>>, + mem_usage: usize, +} + +impl PagedTermMap {} + +impl TermAggregationMap for PagedTermMap { + #[inline] + fn get_memory_consumption(&self) -> usize { + self.mem_usage + std::mem::size_of::() + } + + #[inline] + fn term_entry(&mut self, term_id: u64, bucket_id_provider: &mut BucketIdProvider) -> BucketId { + let term_id = term_id as usize; + let page_idx = term_id >> PAGE_SHIFT; + let offset = term_id & PAGE_MASK; + + // This panics if term_id > max_term_id + let page = match &mut self.pages[page_idx] { + Some(p) => p, + None => { + let new_page = Box::new(Page::new()); + self.mem_usage += std::mem::size_of::(); + self.pages[page_idx] = Some(new_page); + self.pages[page_idx].as_mut().unwrap() + } + }; + + if page.is_set(offset) { + let bucket = &mut page.data[offset]; + bucket.count += 1; + bucket.bucket_id + } else { + let new_id = bucket_id_provider.next_bucket_id(); + page.data[offset] = Bucket { + count: 1, + bucket_id: new_id, + }; + page.set_present(offset); + new_id + } + } + + fn into_vec(self) -> Vec<(u64, Bucket)> { + // estimate 16 entries per non-empty page + let estimated_count = self.pages.iter().filter(|p| p.is_some()).count() * 16; + let mut result = Vec::with_capacity(estimated_count); + + for (i, page_opt) in self.pages.into_iter().enumerate() { + if let Some(page) = page_opt { + let base_term_id = (i << PAGE_SHIFT) as u64; + page.collect_items(base_term_id, &mut result); + } + } + result + } + + /// Initialize with a strict upper bound. + /// Panics if you try to insert a term_id > max_term_id. + fn new(max_term_id: u64, _bucket_id_provider: &mut BucketIdProvider) -> Self { + let max_page_idx = (max_term_id as usize) >> PAGE_SHIFT; + let num_pages = max_page_idx + 1; + + // Pre-allocate the directory (pointers only, not the heavy pages) + // Memory cost: num_pages * 8 bytes + let pages = vec![None; num_pages]; + + let mem_usage = pages.capacity() * std::mem::size_of::>>(); + + Self { pages, mem_usage } + } +} + +impl TermAggregationMap for HashMapTermBuckets { #[inline] fn get_memory_consumption(&self) -> usize { self.bucket_map.memory_consumption() } #[inline(always)] - fn term_entry( - &mut self, - term_id: u64, - sub_agg_blueprint: &dyn SegmentAggregationCollector, - ) -> &mut Bucket { - self.bucket_map + fn term_entry(&mut self, term_id: u64, bucket_id_provider: &mut BucketIdProvider) -> BucketId { + let bucket = self + .bucket_map .entry(term_id) - .or_insert_with(|| Bucket::new(SubAgg::from(sub_agg_blueprint))) + .or_insert_with(|| Bucket::new(bucket_id_provider.next_bucket_id())); + bucket.count += 1; + bucket.bucket_id } - #[inline(always)] - fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - for bucket in self.bucket_map.values_mut() { - bucket.sub_agg.flush(agg_data)?; - } - Ok(()) - } - - fn into_vec(self) -> Vec<(u64, Bucket)> { + fn into_vec(self) -> Vec<(u64, Bucket)> { self.bucket_map.into_iter().collect() } + + #[inline] + fn new(_max_term_id: u64, _bucket_id_provider: &mut BucketIdProvider) -> Self { + Self::default() + } } /// An optimized term map implementation for a compact set of term ordinals. #[derive(Clone, Debug)] -struct VecTermBuckets { - buckets: Vec>, +struct VecTermBucketsNoAgg { + buckets: Vec, } -impl VecTermBuckets { - fn new(num_terms: usize, item_factory_fn: impl Fn() -> SubAgg) -> Self { - VecTermBuckets { - buckets: std::iter::repeat_with(item_factory_fn) - .map(Bucket::new) - .take(num_terms) - .collect(), - } - } -} - -impl TermAggregationMap - for VecTermBuckets -{ - type SubAggregation = SubAgg; - +impl TermAggregationMap for VecTermBucketsNoAgg { /// Estimate the memory consumption of this struct in bytes. fn get_memory_consumption(&self) -> usize { // We do not include `std::mem::size_of::()` // It is already measure by the parent aggregation. // - // The root aggregation mem size is not measure but we do not care. - self.buckets.capacity() * std::mem::size_of::>() + self.buckets.capacity() * std::mem::size_of::() } /// Add an occurrence of the given term id. #[inline(always)] - fn term_entry( - &mut self, - term_id: u64, - _sub_agg_blueprint: &dyn SegmentAggregationCollector, - ) -> &mut Bucket { + fn term_entry(&mut self, term_id: u64, _bucket_id_provider: &mut BucketIdProvider) -> BucketId { let term_id_usize = term_id as usize; debug_assert!( term_id_usize < self.buckets.len(), @@ -663,20 +675,69 @@ impl TermAggregat term_id, self.buckets.len() ); - unsafe { self.buckets.get_unchecked_mut(term_id_usize) } + let count = unsafe { self.buckets.get_unchecked_mut(term_id_usize) }; + *count += 1; + 0 // unused } - #[inline(always)] - fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - for bucket in &mut self.buckets { - if bucket.count > 0 { - bucket.sub_agg.flush(agg_data)?; - } + fn into_vec(self) -> Vec<(u64, Bucket)> { + self.buckets + .into_iter() + .enumerate() + .filter(|(_term_id, count)| *count > 0) + .map(|(term_id, count)| { + ( + term_id as u64, + Bucket { + count, + bucket_id: 0, // unused, there are no sub-aggregations + }, + ) + }) + .collect() + } + + fn new(num_terms: u64, _bucket_id_provider: &mut BucketIdProvider) -> Self { + Self { + buckets: std::iter::repeat_with(|| 0) + .take(num_terms as usize) + .collect(), } - Ok(()) + } +} + +/// An optimized term map implementation for a compact set of term ordinals. +#[derive(Clone, Debug)] +struct VecTermBuckets { + buckets: Vec, +} + +impl TermAggregationMap for VecTermBuckets { + /// Estimate the memory consumption of this struct in bytes. + fn get_memory_consumption(&self) -> usize { + // We do not include `std::mem::size_of::()` + // It is already measure by the parent aggregation. + // + // The root aggregation mem size is not measure but we do not care. + self.buckets.capacity() * std::mem::size_of::() } - fn into_vec(self) -> Vec<(u64, Bucket)> { + /// Add an occurrence of the given term id. + #[inline(always)] + fn term_entry(&mut self, term_id: u64, _bucket_id_provider: &mut BucketIdProvider) -> BucketId { + let term_id_usize = term_id as usize; + debug_assert!( + term_id_usize < self.buckets.len(), + "term_id {} out of bounds for VecTermBuckets (len={})", + term_id, + self.buckets.len() + ); + let bucket = unsafe { self.buckets.get_unchecked_mut(term_id_usize) }; + bucket.count += 1; + bucket.bucket_id + } + + fn into_vec(self) -> Vec<(u64, Bucket)> { self.buckets .into_iter() .enumerate() @@ -684,22 +745,26 @@ impl TermAggregat .map(|(term_id, bucket)| (term_id as u64, bucket)) .collect() } -} -impl<'a> From<&'a dyn SegmentAggregationCollector> for NoSubAgg { - #[inline(always)] - fn from(_: &'a dyn SegmentAggregationCollector) -> Self { - Self + fn new(num_terms: u64, bucket_id_provider: &mut BucketIdProvider) -> Self { + VecTermBuckets { + buckets: std::iter::repeat_with(|| Bucket::new(bucket_id_provider.next_bucket_id())) + .take(num_terms as usize) + .collect(), + } } } /// The collector puts values from the fast field into the correct buckets and does a conversion to /// the correct datatype. -#[derive(Clone, Debug)] -struct SegmentTermCollector { +#[derive(Debug)] +struct SegmentTermCollector { /// The buckets containing the aggregation data. - term_buckets: TermMap, - accessor_idx: usize, + parent_buckets: Vec, + sub_agg: Option>, + bucket_id_provider: BucketIdProvider, + max_term_id: u64, + terms_req_data: TermsAggReqData, } pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) { @@ -707,18 +772,26 @@ pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) { (agg_name, agg_property) } -impl SegmentAggregationCollector for SegmentTermCollector -where - TermMap: TermAggregationMap, - TermMap::SubAggregation: for<'a> From<&'a dyn SegmentAggregationCollector>, +impl SegmentAggregationCollector + for SegmentTermCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + bucket: BucketId, ) -> crate::Result<()> { - let name = agg_data.get_term_req_data(self.accessor_idx).name.clone(); - let bucket = self.into_intermediate_bucket_result(agg_data)?; + // TODO: avoid prepare_max_bucket here and handle empty buckets. + self.prepare_max_bucket(bucket, agg_data)?; + let bucket = std::mem::replace( + &mut self.parent_buckets[bucket as usize], + TermMap::new(0, &mut self.bucket_id_provider), + ); + let term_req = &self.terms_req_data; + let name = term_req.name.clone(); + + let bucket = + Self::into_intermediate_bucket_result(term_req, &mut self.sub_agg, bucket, agg_data)?; results.push(name, IntermediateAggregationResult::Bucket(bucket))?; Ok(()) } @@ -726,65 +799,49 @@ where #[inline] fn collect( &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.collect_block(&[doc], agg_data) - } - - #[inline] - fn collect_block( - &mut self, + parent_bucket_id: BucketId, docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let mut req_data = agg_data.take_term_req_data(self.accessor_idx); - let mem_pre = self.get_memory_consumption(); - if let Some(missing) = req_data.missing_value_for_accessor { - req_data.column_block_accessor.fetch_block_with_missing( - docs, - &req_data.accessor, - missing, - ); - } else { - req_data - .column_block_accessor - .fetch_block(docs, &req_data.accessor); - } + let req_data = &mut self.terms_req_data; - if std::any::TypeId::of::() == std::any::TypeId::of::() { - for term_id in req_data.column_block_accessor.iter_vals() { - if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() { - if !allowed_bs.contains(term_id as u32) { - continue; - } - } - let bucket = self.term_buckets.term_entry(term_id, &NoSubAgg); - bucket.count += 1; + agg_data.column_block_accessor.fetch_block_with_missing( + docs, + &req_data.accessor, + req_data.missing_value_for_accessor, + ); + + if let Some(sub_agg) = &mut self.sub_agg { + let term_buckets = &mut self.parent_buckets[parent_bucket_id as usize]; + let it = agg_data + .column_block_accessor + .iter_docid_vals(docs, &req_data.accessor); + if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() { + let it = it.filter(move |&(_doc, term_id)| allowed_bs.contains(term_id as u32)); + Self::collect_terms_with_docs( + it, + term_buckets, + &mut self.bucket_id_provider, + sub_agg, + ); + } else { + Self::collect_terms_with_docs( + it, + term_buckets, + &mut self.bucket_id_provider, + sub_agg, + ); } } else { - let Some(sub_aggregation_blueprint) = req_data.sub_aggregation_blueprint.as_deref() - else { - return Err(TantivyError::InternalError( - "Could not find sub-aggregation blueprint".to_string(), - )); - }; - for (doc, term_id) in req_data - .column_block_accessor - .iter_docid_vals(docs, &req_data.accessor) - { - if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() { - if !allowed_bs.contains(term_id as u32) { - continue; - } - } - let bucket = self - .term_buckets - .term_entry(term_id, sub_aggregation_blueprint); - bucket.count += 1; - bucket.sub_agg.collect(doc, agg_data)?; + let term_buckets = &mut self.parent_buckets[parent_bucket_id as usize]; + let it = agg_data.column_block_accessor.iter_vals(); + if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() { + let it = it.filter(move |&term_id| allowed_bs.contains(term_id as u32)); + Self::collect_terms(it, term_buckets, &mut self.bucket_id_provider); + } else { + Self::collect_terms(it, term_buckets, &mut self.bucket_id_provider); } } @@ -795,14 +852,31 @@ where .limits .add_memory_consumed(mem_delta as u64)?; } - agg_data.put_back_term_req_data(self.accessor_idx, req_data); + if let Some(sub_agg) = &mut self.sub_agg { + sub_agg.check_flush_local(agg_data)?; + } Ok(()) } - #[inline(always)] + #[inline] fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - self.term_buckets.flush(agg_data)?; + if let Some(sub_agg) = &mut self.sub_agg { + sub_agg.flush(agg_data)?; + } + Ok(()) + } + + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + while self.parent_buckets.len() <= max_bucket as usize { + let term_buckets: TermMap = + TermMap::new(self.max_term_id, &mut self.bucket_id_provider); + self.parent_buckets.push(term_buckets); + } Ok(()) } } @@ -831,20 +905,26 @@ fn extract_missing_value( Some((key, bucket)) } -impl SegmentTermCollector -where TermMap: TermAggregationMap +impl SegmentTermCollector +where + TermMap: TermAggregationMap, + C: SubAggCache, { fn get_memory_consumption(&self) -> usize { - self.term_buckets.get_memory_consumption() + self.parent_buckets + .iter() + .map(|b| b.get_memory_consumption()) + .sum() } #[inline] pub(crate) fn into_intermediate_bucket_result( - self, + term_req: &TermsAggReqData, + sub_agg: &mut Option>, + term_buckets: TermMap, agg_data: &AggregationsSegmentCtx, ) -> crate::Result { - let term_req = agg_data.get_term_req_data(self.accessor_idx); - let mut entries: Vec<(u64, Bucket)> = self.term_buckets.into_vec(); + let mut entries: Vec<(u64, Bucket)> = term_buckets.into_vec(); let order_by_sub_aggregation = matches!(term_req.req.order.target, OrderTarget::SubAggregation(_)); @@ -884,23 +964,28 @@ where TermMap: TermAggregationMap dict.reserve(entries.len()); let into_intermediate_bucket_entry = - |bucket: Bucket| -> crate::Result { - let intermediate_entry = if term_req.sub_aggregation_blueprint.as_ref().is_some() { + |bucket: Bucket, + sub_agg: &mut Option>| + -> crate::Result { + if let Some(sub_agg) = sub_agg { let mut sub_aggregation_res = IntermediateAggregationResults::default(); - // TODO remove box new - Box::new(bucket.sub_agg) - .add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?; - IntermediateTermBucketEntry { + sub_agg + .get_sub_agg_collector() + .add_intermediate_aggregation_result( + agg_data, + &mut sub_aggregation_res, + bucket.bucket_id, + )?; + Ok(IntermediateTermBucketEntry { doc_count: bucket.count, sub_aggregation: sub_aggregation_res, - } + }) } else { - IntermediateTermBucketEntry { + Ok(IntermediateTermBucketEntry { doc_count: bucket.count, sub_aggregation: Default::default(), - } - }; - Ok(intermediate_entry) + }) + } }; if term_req.column_type == ColumnType::Str { @@ -913,21 +998,20 @@ where TermMap: TermAggregationMap if let Some((intermediate_key, bucket)) = extract_missing_value(&mut entries, term_req) { - let intermediate_entry = into_intermediate_bucket_entry(bucket)?; + let intermediate_entry = into_intermediate_bucket_entry(bucket, sub_agg)?; dict.insert(intermediate_key, intermediate_entry); } // Sort by term ord entries.sort_unstable_by_key(|bucket| bucket.0); - let (term_ids, buckets): (Vec, Vec>) = - entries.into_iter().unzip(); + let (term_ids, buckets): (Vec, Vec) = entries.into_iter().unzip(); let mut buckets_it = buckets.into_iter(); term_dict.sorted_ords_to_term_cb(term_ids.into_iter(), |term| { let bucket = buckets_it.next().unwrap(); let intermediate_entry = - into_intermediate_bucket_entry(bucket).map_err(io::Error::other)?; + into_intermediate_bucket_entry(bucket, sub_agg).map_err(io::Error::other)?; dict.insert( IntermediateKey::Str( String::from_utf8(term.to_vec()).expect("could not convert to String"), @@ -969,14 +1053,14 @@ where TermMap: TermAggregationMap } } else if term_req.column_type == ColumnType::DateTime { for (val, doc_count) in entries { - let intermediate_entry = into_intermediate_bucket_entry(doc_count)?; + let intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?; let val = i64::from_u64(val); let date = format_date(val)?; dict.insert(IntermediateKey::Str(date), intermediate_entry); } } else if term_req.column_type == ColumnType::Bool { for (val, doc_count) in entries { - let intermediate_entry = into_intermediate_bucket_entry(doc_count)?; + let intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?; let val = bool::from_u64(val); dict.insert(IntermediateKey::Bool(val), intermediate_entry); } @@ -996,14 +1080,14 @@ where TermMap: TermAggregationMap })?; for (val, doc_count) in entries { - let intermediate_entry = into_intermediate_bucket_entry(doc_count)?; + let intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?; let val: u128 = compact_space_accessor.compact_to_u128(val as u32); let val = Ipv6Addr::from_u128(val); dict.insert(IntermediateKey::IpAddr(val), intermediate_entry); } } else { for (val, doc_count) in entries { - let intermediate_entry = into_intermediate_bucket_entry(doc_count)?; + let intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?; if term_req.column_type == ColumnType::U64 { dict.insert(IntermediateKey::U64(val), intermediate_entry); } else if term_req.column_type == ColumnType::I64 { @@ -1037,6 +1121,32 @@ where TermMap: TermAggregationMap } } +impl SegmentTermCollector { + #[inline] + fn collect_terms_with_docs( + iter: impl Iterator, + term_buckets: &mut TermMap, + bucket_id_provider: &mut BucketIdProvider, + sub_agg: &mut CachedSubAggs, + ) { + for (doc, term_id) in iter { + let bucket_id = term_buckets.term_entry(term_id, bucket_id_provider); + sub_agg.push(bucket_id, doc); + } + } + + #[inline] + fn collect_terms( + iter: impl Iterator, + term_buckets: &mut TermMap, + bucket_id_provider: &mut BucketIdProvider, + ) { + for term_id in iter { + term_buckets.term_entry(term_id, bucket_id_provider); + } + } +} + pub(crate) trait GetDocCount { fn doc_count(&self) -> u64; } @@ -1047,7 +1157,7 @@ impl GetDocCount for (String, IntermediateTermBucketEntry) { } } -impl GetDocCount for (u64, Bucket) { +impl GetDocCount for (u64, Bucket) { fn doc_count(&self) -> u64 { self.1.count as u64 } @@ -1079,8 +1189,10 @@ mod tests { use common::DateTime; use time::{Date, Month}; + use super::{PagedTermMap, TermAggregationMap, PAGE_SIZE}; use crate::aggregation::agg_req::Aggregations; use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults; + use crate::aggregation::segment_agg_result::BucketIdProvider; use crate::aggregation::tests::{ exec_request, exec_request_with_query, exec_request_with_query_and_memory_limit, get_test_index_from_terms, get_test_index_from_values_and_terms, @@ -1091,6 +1203,43 @@ mod tests { use crate::schema::{IntoIpv6Addr, Schema, FAST, STRING}; use crate::{Index, IndexWriter}; + #[test] + fn paged_term_map_reuses_buckets_and_counts() { + let mut bucket_id_provider = BucketIdProvider::default(); + let mut map = PagedTermMap::new((PAGE_SIZE * 2) as u64, &mut bucket_id_provider); + + let bucket_first = map.term_entry(5, &mut bucket_id_provider); + let bucket_second_page = map.term_entry((PAGE_SIZE + 7) as u64, &mut bucket_id_provider); + + // Reinsertions should increment counts and reuse bucket ids + assert_eq!(map.term_entry(5, &mut bucket_id_provider), bucket_first); + assert_eq!( + map.term_entry((PAGE_SIZE + 7) as u64, &mut bucket_id_provider), + bucket_second_page + ); + + // High offset exercises the TinySet presence word boundaries. + let bucket_high_bit = map.term_entry(63, &mut bucket_id_provider); + + let mut entries = map.into_vec(); + entries.sort_by_key(|(term_id, _)| *term_id); + + let expected = vec![ + (5u64, bucket_first, 2u32), + (63u64, bucket_high_bit, 1u32), + ((PAGE_SIZE + 7) as u64, bucket_second_page, 2u32), + ]; + + assert_eq!(entries.len(), expected.len()); + for ((term_id, bucket), (expected_term, expected_bucket_id, expected_count)) in + entries.into_iter().zip(expected) + { + assert_eq!(term_id, expected_term); + assert_eq!(bucket.bucket_id, expected_bucket_id); + assert_eq!(bucket.count, expected_count); + } + } + #[test] fn terms_aggregation_test_single_segment() -> crate::Result<()> { terms_aggregation_test_merge_segment(true) diff --git a/src/aggregation/bucket/term_missing_agg.rs b/src/aggregation/bucket/term_missing_agg.rs index 66f39927a..47c3989c6 100644 --- a/src/aggregation/bucket/term_missing_agg.rs +++ b/src/aggregation/bucket/term_missing_agg.rs @@ -5,11 +5,13 @@ use crate::aggregation::agg_data::{ build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx, }; use crate::aggregation::bucket::term_agg::TermsAggregation; +use crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardCachedSubAggs}; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult, }; -use crate::aggregation::segment_agg_result::SegmentAggregationCollector; +use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector}; +use crate::aggregation::BucketId; /// Special aggregation to handle missing values for term aggregations. /// This missing aggregation will check multiple columns for existence. @@ -35,41 +37,55 @@ impl MissingTermAggReqData { } } -/// The specialized missing term aggregation. #[derive(Default, Debug, Clone)] -pub struct TermMissingAgg { +struct MissingCount { missing_count: u32, + bucket_id: BucketId, +} + +/// The specialized missing term aggregation. +#[derive(Default, Debug)] +pub struct TermMissingAgg { accessor_idx: usize, - sub_agg: Option>, + sub_agg: Option, + /// Idx = parent bucket id, Value = missing count for that bucket + missing_count_per_bucket: Vec, + bucket_id_provider: BucketIdProvider, } impl TermMissingAgg { pub(crate) fn new( - req_data: &mut AggregationsSegmentCtx, + agg_data: &mut AggregationsSegmentCtx, node: &AggRefNode, ) -> crate::Result { let has_sub_aggregations = !node.children.is_empty(); let accessor_idx = node.idx_in_req_data; let sub_agg = if has_sub_aggregations { - let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?; + let sub_aggregation = build_segment_agg_collectors(agg_data, &node.children)?; Some(sub_aggregation) } else { None }; + let sub_agg = sub_agg.map(CachedSubAggs::new); + let bucket_id_provider = BucketIdProvider::default(); + Ok(Self { accessor_idx, sub_agg, - ..Default::default() + missing_count_per_bucket: Vec::new(), + bucket_id_provider, }) } } impl SegmentAggregationCollector for TermMissingAgg { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { + self.prepare_max_bucket(parent_bucket_id, agg_data)?; let req_data = agg_data.get_missing_term_req_data(self.accessor_idx); let term_agg = &req_data.req; let missing = term_agg @@ -80,13 +96,16 @@ impl SegmentAggregationCollector for TermMissingAgg { let mut entries: FxHashMap = Default::default(); + let missing_count = &self.missing_count_per_bucket[parent_bucket_id as usize]; let mut missing_entry = IntermediateTermBucketEntry { - doc_count: self.missing_count, + doc_count: missing_count.missing_count, sub_aggregation: Default::default(), }; - if let Some(sub_agg) = self.sub_agg { + if let Some(sub_agg) = &mut self.sub_agg { let mut res = IntermediateAggregationResults::default(); - sub_agg.add_intermediate_aggregation_result(agg_data, &mut res)?; + sub_agg + .get_sub_agg_collector() + .add_intermediate_aggregation_result(agg_data, &mut res, missing_count.bucket_id)?; missing_entry.sub_aggregation = res; } entries.insert(missing.into(), missing_entry); @@ -109,30 +128,52 @@ impl SegmentAggregationCollector for TermMissingAgg { fn collect( &mut self, - doc: crate::DocId, + parent_bucket_id: BucketId, + docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { + let bucket = &mut self.missing_count_per_bucket[parent_bucket_id as usize]; let req_data = agg_data.get_missing_term_req_data(self.accessor_idx); - let has_value = req_data - .accessors - .iter() - .any(|(acc, _)| acc.index.has_value(doc)); - if !has_value { - self.missing_count += 1; - if let Some(sub_agg) = self.sub_agg.as_mut() { - sub_agg.collect(doc, agg_data)?; + + for doc in docs { + let doc = *doc; + let has_value = req_data + .accessors + .iter() + .any(|(acc, _)| acc.index.has_value(doc)); + if !has_value { + bucket.missing_count += 1; + + if let Some(sub_agg) = self.sub_agg.as_mut() { + sub_agg.push(bucket.bucket_id, doc); + } } } + + if let Some(sub_agg) = self.sub_agg.as_mut() { + sub_agg.check_flush_local(agg_data)?; + } Ok(()) } - fn collect_block( + fn prepare_max_bucket( &mut self, - docs: &[crate::DocId], - agg_data: &mut AggregationsSegmentCtx, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, ) -> crate::Result<()> { - for doc in docs { - self.collect(*doc, agg_data)?; + while self.missing_count_per_bucket.len() <= max_bucket as usize { + let bucket_id = self.bucket_id_provider.next_bucket_id(); + self.missing_count_per_bucket.push(MissingCount { + missing_count: 0, + bucket_id, + }); + } + Ok(()) + } + + fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { + if let Some(sub_agg) = self.sub_agg.as_mut() { + sub_agg.flush(agg_data)?; } Ok(()) } diff --git a/src/aggregation/buf_collector.rs b/src/aggregation/buf_collector.rs deleted file mode 100644 index 17bc1ed35..000000000 --- a/src/aggregation/buf_collector.rs +++ /dev/null @@ -1,87 +0,0 @@ -use super::intermediate_agg_result::IntermediateAggregationResults; -use super::segment_agg_result::SegmentAggregationCollector; -use crate::aggregation::agg_data::AggregationsSegmentCtx; -use crate::DocId; - -#[cfg(test)] -pub(crate) const DOC_BLOCK_SIZE: usize = 64; - -#[cfg(not(test))] -pub(crate) const DOC_BLOCK_SIZE: usize = 256; - -pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE]; - -/// BufAggregationCollector buffers documents before calling collect_block(). -#[derive(Clone)] -pub(crate) struct BufAggregationCollector { - pub(crate) collector: Box, - staged_docs: DocBlock, - num_staged_docs: usize, -} - -impl std::fmt::Debug for BufAggregationCollector { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("SegmentAggregationResultsCollector") - .field("staged_docs", &&self.staged_docs[..self.num_staged_docs]) - .field("num_staged_docs", &self.num_staged_docs) - .finish() - } -} - -impl BufAggregationCollector { - pub fn new(collector: Box) -> Self { - Self { - collector, - num_staged_docs: 0, - staged_docs: [0; DOC_BLOCK_SIZE], - } - } -} - -impl SegmentAggregationCollector for BufAggregationCollector { - #[inline] - fn add_intermediate_aggregation_result( - self: Box, - agg_data: &AggregationsSegmentCtx, - results: &mut IntermediateAggregationResults, - ) -> crate::Result<()> { - Box::new(self.collector).add_intermediate_aggregation_result(agg_data, results) - } - - #[inline] - fn collect( - &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.staged_docs[self.num_staged_docs] = doc; - self.num_staged_docs += 1; - if self.num_staged_docs == self.staged_docs.len() { - self.collector - .collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?; - self.num_staged_docs = 0; - } - Ok(()) - } - - #[inline] - fn collect_block( - &mut self, - docs: &[crate::DocId], - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.collector.collect_block(docs, agg_data)?; - Ok(()) - } - - #[inline] - fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - self.collector - .collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?; - self.num_staged_docs = 0; - - self.collector.flush(agg_data)?; - - Ok(()) - } -} diff --git a/src/aggregation/cached_sub_aggs.rs b/src/aggregation/cached_sub_aggs.rs new file mode 100644 index 000000000..f97da31ab --- /dev/null +++ b/src/aggregation/cached_sub_aggs.rs @@ -0,0 +1,245 @@ +use std::fmt::Debug; + +use super::segment_agg_result::SegmentAggregationCollector; +use crate::aggregation::agg_data::AggregationsSegmentCtx; +use crate::aggregation::bucket::MAX_NUM_TERMS_FOR_VEC; +use crate::aggregation::BucketId; +use crate::DocId; + +/// A cache for sub-aggregations, storing doc ids per bucket id. +/// Depending on the cardinality of the parent aggregation, we use different +/// storage strategies. +/// +/// ## Low Cardinality +/// Cardinality here refers to the number of unique flattened buckets that can be created +/// by the parent aggregation. +/// Flattened buckets are the result of combining all buckets per collector +/// into a single list of buckets, where each bucket is identified by its BucketId. +/// +/// ## Usage +/// Since this is caching for sub-aggregations, it is only used by bucket +/// aggregations. +/// +/// TODO: consider using a more advanced data structure for high cardinality +/// aggregations. +/// What this datastructure does in general is to group docs by bucket id. +#[derive(Debug)] +pub(crate) struct CachedSubAggs { + cache: C, + sub_agg_collector: Box, + num_docs: usize, +} + +pub type LowCardCachedSubAggs = CachedSubAggs; +pub type HighCardCachedSubAggs = CachedSubAggs; + +const FLUSH_THRESHOLD: usize = 2048; + +/// A trait for caching sub-aggregation doc ids per bucket id. +/// Different implementations can be used depending on the cardinality +/// of the parent aggregation. +pub trait SubAggCache: Debug { + fn new() -> Self; + fn push(&mut self, bucket_id: BucketId, doc_id: DocId); + fn flush_local( + &mut self, + sub_agg: &mut Box, + agg_data: &mut AggregationsSegmentCtx, + force: bool, + ) -> crate::Result<()>; +} + +impl CachedSubAggs { + pub fn new(sub_agg: Box) -> Self { + Self { + cache: Backend::new(), + sub_agg_collector: sub_agg, + num_docs: 0, + } + } + + pub fn get_sub_agg_collector(&mut self) -> &mut Box { + &mut self.sub_agg_collector + } + + #[inline] + pub fn push(&mut self, bucket_id: BucketId, doc_id: DocId) { + self.cache.push(bucket_id, doc_id); + self.num_docs += 1; + } + + /// Check if we need to flush based on the number of documents cached. + /// If so, flushes the cache to the provided aggregation collector. + pub fn check_flush_local( + &mut self, + agg_data: &mut AggregationsSegmentCtx, + ) -> crate::Result<()> { + if self.num_docs >= FLUSH_THRESHOLD { + self.cache + .flush_local(&mut self.sub_agg_collector, agg_data, false)?; + self.num_docs = 0; + } + Ok(()) + } + + /// Note: this _does_ flush the sub aggregations. + pub fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { + if self.num_docs != 0 { + self.cache + .flush_local(&mut self.sub_agg_collector, agg_data, true)?; + self.num_docs = 0; + } + self.sub_agg_collector.flush(agg_data)?; + Ok(()) + } +} + +/// Number of partitions for high cardinality sub-aggregation cache. +const NUM_PARTITIONS: usize = 16; + +#[derive(Debug)] +pub(crate) struct HighCardSubAggCache { + /// This weird partitioning is used to do some cheap grouping on the bucket ids. + /// bucket ids are dense, e.g. when we don't detect the cardinality as low cardinality, + /// but there are just 16 bucket ids, each bucket id will go to its own partition. + /// + /// We want to keep this cheap, because high cardinality aggregations can have a lot of + /// buckets, and there may be nothing to group. + partitions: Box<[PartitionEntry; NUM_PARTITIONS]>, +} + +impl HighCardSubAggCache { + #[inline] + fn clear(&mut self) { + for partition in self.partitions.iter_mut() { + partition.clear(); + } + } +} + +#[derive(Debug, Clone, Default)] +struct PartitionEntry { + bucket_ids: Vec, + docs: Vec, +} + +impl PartitionEntry { + #[inline] + fn clear(&mut self) { + self.bucket_ids.clear(); + self.docs.clear(); + } +} + +impl SubAggCache for HighCardSubAggCache { + fn new() -> Self { + Self { + partitions: Box::new(core::array::from_fn(|_| PartitionEntry::default())), + } + } + + fn push(&mut self, bucket_id: BucketId, doc_id: DocId) { + let idx = bucket_id % NUM_PARTITIONS as u32; + let slot = &mut self.partitions[idx as usize]; + slot.bucket_ids.push(bucket_id); + slot.docs.push(doc_id); + } + + fn flush_local( + &mut self, + sub_agg: &mut Box, + agg_data: &mut AggregationsSegmentCtx, + _force: bool, + ) -> crate::Result<()> { + let mut max_bucket = 0u32; + for partition in self.partitions.iter() { + if let Some(&local_max) = partition.bucket_ids.iter().max() { + max_bucket = max_bucket.max(local_max); + } + } + + sub_agg.prepare_max_bucket(max_bucket, agg_data)?; + + for slot in self.partitions.iter() { + if !slot.bucket_ids.is_empty() { + // Reduce dynamic dispatch overhead by collecting a full partition in one call. + sub_agg.collect_multiple(&slot.bucket_ids, &slot.docs, agg_data)?; + } + } + + self.clear(); + Ok(()) + } +} + +#[derive(Debug)] +pub(crate) struct LowCardSubAggCache { + /// Cache doc ids per bucket for sub-aggregations. + /// + /// The outer Vec is indexed by BucketId. + per_bucket_docs: Vec>, +} + +impl LowCardSubAggCache { + #[inline] + fn clear(&mut self) { + for v in &mut self.per_bucket_docs { + v.clear(); + } + } +} + +impl SubAggCache for LowCardSubAggCache { + fn new() -> Self { + Self { + per_bucket_docs: Vec::new(), + } + } + + fn push(&mut self, bucket_id: BucketId, doc_id: DocId) { + let idx = bucket_id as usize; + if self.per_bucket_docs.len() <= idx { + self.per_bucket_docs.resize_with(idx + 1, Vec::new); + } + self.per_bucket_docs[idx].push(doc_id); + } + + fn flush_local( + &mut self, + sub_agg: &mut Box, + agg_data: &mut AggregationsSegmentCtx, + force: bool, + ) -> crate::Result<()> { + // Pre-aggregated: call collect per bucket. + let max_bucket = (self.per_bucket_docs.len() as BucketId).saturating_sub(1); + sub_agg.prepare_max_bucket(max_bucket, agg_data)?; + // The threshold above which we flush buckets individually. + // Note: We need to make sure that we don't lock ourselves into a situation where we hit + // the FLUSH_THRESHOLD, but never flush any buckets. (except the final flush) + let mut bucket_treshold = FLUSH_THRESHOLD / (self.per_bucket_docs.len().max(1) * 2); + const _: () = { + // MAX_NUM_TERMS_FOR_VEC threshold is used for term aggregations + // Note: There may be other flexible values, for other aggregations, but we can use the + // const value here as a upper bound. (better than nothing) + let bucket_treshold_limit = FLUSH_THRESHOLD / (MAX_NUM_TERMS_FOR_VEC as usize * 2); + assert!( + bucket_treshold_limit > 0, + "Bucket threshold must be greater than 0" + ); + }; + if force { + bucket_treshold = 0; + } + for (bucket_id, docs) in self + .per_bucket_docs + .iter() + .enumerate() + .filter(|(_, docs)| docs.len() > bucket_treshold) + { + sub_agg.collect(bucket_id as BucketId, docs, agg_data)?; + } + + self.clear(); + Ok(()) + } +} diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs index 4c4c2c7f1..59e9c677d 100644 --- a/src/aggregation/collector.rs +++ b/src/aggregation/collector.rs @@ -1,9 +1,9 @@ use super::agg_req::Aggregations; use super::agg_result::AggregationResults; -use super::buf_collector::BufAggregationCollector; +use super::cached_sub_aggs::LowCardCachedSubAggs; use super::intermediate_agg_result::IntermediateAggregationResults; -use super::segment_agg_result::SegmentAggregationCollector; use super::AggContextParams; +// group buffering strategy is chosen explicitly by callers; no need to hash-group on the fly. use crate::aggregation::agg_data::{ build_aggregations_data_from_req, build_segment_agg_collectors_root, AggregationsSegmentCtx, }; @@ -136,7 +136,7 @@ fn merge_fruits( /// `AggregationSegmentCollector` does the aggregation collection on a segment. pub struct AggregationSegmentCollector { aggs_with_accessor: AggregationsSegmentCtx, - agg_collector: BufAggregationCollector, + agg_collector: LowCardCachedSubAggs, error: Option, } @@ -151,8 +151,11 @@ impl AggregationSegmentCollector { ) -> crate::Result { let mut agg_data = build_aggregations_data_from_req(agg, reader, segment_ordinal, context.clone())?; - let result = - BufAggregationCollector::new(build_segment_agg_collectors_root(&mut agg_data)?); + let mut result = + LowCardCachedSubAggs::new(build_segment_agg_collectors_root(&mut agg_data)?); + result + .get_sub_agg_collector() + .prepare_max_bucket(0, &agg_data)?; // prepare for bucket zero Ok(AggregationSegmentCollector { aggs_with_accessor: agg_data, @@ -170,26 +173,31 @@ impl SegmentCollector for AggregationSegmentCollector { if self.error.is_some() { return; } - if let Err(err) = self + self.agg_collector.push(0, doc); + match self .agg_collector - .collect(doc, &mut self.aggs_with_accessor) + .check_flush_local(&mut self.aggs_with_accessor) { - self.error = Some(err); + Ok(_) => {} + Err(e) => { + self.error = Some(e); + } } } - - /// The query pushes the documents to the collector via this method. - /// - /// Only valid for Collectors that ignore docs fn collect_block(&mut self, docs: &[DocId]) { if self.error.is_some() { return; } - if let Err(err) = self - .agg_collector - .collect_block(docs, &mut self.aggs_with_accessor) - { - self.error = Some(err); + + match self.agg_collector.get_sub_agg_collector().collect( + 0, + docs, + &mut self.aggs_with_accessor, + ) { + Ok(_) => {} + Err(e) => { + self.error = Some(e); + } } } @@ -200,10 +208,13 @@ impl SegmentCollector for AggregationSegmentCollector { self.agg_collector.flush(&mut self.aggs_with_accessor)?; let mut sub_aggregation_res = IntermediateAggregationResults::default(); - Box::new(self.agg_collector).add_intermediate_aggregation_result( - &self.aggs_with_accessor, - &mut sub_aggregation_res, - )?; + self.agg_collector + .get_sub_agg_collector() + .add_intermediate_aggregation_result( + &self.aggs_with_accessor, + &mut sub_aggregation_res, + 0, + )?; Ok(sub_aggregation_res) } diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index 104131461..b20e8a042 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -792,7 +792,7 @@ pub struct IntermediateRangeBucketEntry { /// The number of documents in the bucket. pub doc_count: u64, /// The sub_aggregation in this bucket. - pub sub_aggregation: IntermediateAggregationResults, + pub sub_aggregation_res: IntermediateAggregationResults, /// The from range of the bucket. Equals `f64::MIN` when `None`. pub from: Option, /// The to range of the bucket. Equals `f64::MAX` when `None`. @@ -811,7 +811,7 @@ impl IntermediateRangeBucketEntry { key: self.key.into(), doc_count: self.doc_count, sub_aggregation: self - .sub_aggregation + .sub_aggregation_res .into_final_result_internal(req, limits)?, to: self.to, from: self.from, @@ -857,7 +857,8 @@ impl MergeFruits for IntermediateTermBucketEntry { impl MergeFruits for IntermediateRangeBucketEntry { fn merge_fruits(&mut self, other: IntermediateRangeBucketEntry) -> crate::Result<()> { self.doc_count += other.doc_count; - self.sub_aggregation.merge_fruits(other.sub_aggregation)?; + self.sub_aggregation_res + .merge_fruits(other.sub_aggregation_res)?; Ok(()) } } @@ -887,7 +888,7 @@ mod tests { IntermediateRangeBucketEntry { key: IntermediateKey::Str(key.to_string()), doc_count: *doc_count, - sub_aggregation: Default::default(), + sub_aggregation_res: Default::default(), from: None, to: None, }, @@ -920,7 +921,7 @@ mod tests { doc_count: *doc_count, from: None, to: None, - sub_aggregation: get_sub_test_tree(&[( + sub_aggregation_res: get_sub_test_tree(&[( sub_aggregation_key.to_string(), *sub_aggregation_count, )]), diff --git a/src/aggregation/metric/average.rs b/src/aggregation/metric/average.rs index e707f2b00..57f694984 100644 --- a/src/aggregation/metric/average.rs +++ b/src/aggregation/metric/average.rs @@ -52,10 +52,8 @@ pub struct IntermediateAverage { impl IntermediateAverage { /// Creates a new [`IntermediateAverage`] instance from a [`SegmentStatsCollector`]. - pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { - Self { - stats: collector.stats, - } + pub(crate) fn from_stats(stats: IntermediateStats) -> Self { + Self { stats } } /// Merges the other intermediate result into self. pub fn merge_fruits(&mut self, other: IntermediateAverage) { diff --git a/src/aggregation/metric/cardinality.rs b/src/aggregation/metric/cardinality.rs index 8f3bdd3e5..c184848d8 100644 --- a/src/aggregation/metric/cardinality.rs +++ b/src/aggregation/metric/cardinality.rs @@ -2,7 +2,7 @@ use std::collections::hash_map::DefaultHasher; use std::hash::{BuildHasher, Hasher}; use columnar::column_values::CompactSpaceU64Accessor; -use columnar::{Column, ColumnBlockAccessor, ColumnType, Dictionary, StrColumn}; +use columnar::{Column, ColumnType, Dictionary, StrColumn}; use common::f64_to_u64; use hyperloglogplus::{HyperLogLog, HyperLogLogPlus}; use rustc_hash::FxHashSet; @@ -106,8 +106,6 @@ pub struct CardinalityAggReqData { pub str_dict_column: Option, /// The missing value normalized to the internal u64 representation of the field type. pub missing_value_for_accessor: Option, - /// The column block accessor to access the fast field values. - pub(crate) column_block_accessor: ColumnBlockAccessor, /// The name of the aggregation. pub name: String, /// The aggregation request. @@ -135,45 +133,34 @@ impl CardinalityAggregationReq { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] pub(crate) struct SegmentCardinalityCollector { - cardinality: CardinalityCollector, - entries: FxHashSet, + buckets: Vec, accessor_idx: usize, + /// The column accessor to access the fast field values. + accessor: Column, + /// The column_type of the field. + column_type: ColumnType, + /// The missing value normalized to the internal u64 representation of the field type. + missing_value_for_accessor: Option, } -impl SegmentCardinalityCollector { - pub fn from_req(column_type: ColumnType, accessor_idx: usize) -> Self { +#[derive(Clone, Debug, PartialEq, Default)] +pub(crate) struct SegmentCardinalityCollectorBucket { + cardinality: CardinalityCollector, + entries: FxHashSet, +} +impl SegmentCardinalityCollectorBucket { + pub fn new(column_type: ColumnType) -> Self { Self { cardinality: CardinalityCollector::new(column_type as u8), - entries: Default::default(), - accessor_idx, + entries: FxHashSet::default(), } } - - fn fetch_block_with_field( - &mut self, - docs: &[crate::DocId], - agg_data: &mut CardinalityAggReqData, - ) { - if let Some(missing) = agg_data.missing_value_for_accessor { - agg_data.column_block_accessor.fetch_block_with_missing( - docs, - &agg_data.accessor, - missing, - ); - } else { - agg_data - .column_block_accessor - .fetch_block(docs, &agg_data.accessor); - } - } - fn into_intermediate_metric_result( mut self, - agg_data: &AggregationsSegmentCtx, + req_data: &CardinalityAggReqData, ) -> crate::Result { - let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx); if req_data.column_type == ColumnType::Str { let fallback_dict = Dictionary::empty(); let dict = req_data @@ -194,6 +181,7 @@ impl SegmentCardinalityCollector { term_ids.push(term_ord as u32); } } + term_ids.sort_unstable(); dict.sorted_ords_to_term_cb(term_ids.iter().map(|term| *term as u64), |term| { self.cardinality.sketch.insert_any(&term); @@ -227,16 +215,49 @@ impl SegmentCardinalityCollector { } } +impl SegmentCardinalityCollector { + pub fn from_req( + column_type: ColumnType, + accessor_idx: usize, + accessor: Column, + missing_value_for_accessor: Option, + ) -> Self { + Self { + buckets: vec![SegmentCardinalityCollectorBucket::new(column_type); 1], + column_type, + accessor_idx, + accessor, + missing_value_for_accessor, + } + } + + fn fetch_block_with_field( + &mut self, + docs: &[crate::DocId], + agg_data: &mut AggregationsSegmentCtx, + ) { + agg_data.column_block_accessor.fetch_block_with_missing( + docs, + &self.accessor, + self.missing_value_for_accessor, + ); + } +} + impl SegmentAggregationCollector for SegmentCardinalityCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { + self.prepare_max_bucket(parent_bucket_id, agg_data)?; let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx); let name = req_data.name.to_string(); + // take the bucket in buckets and replace it with a new empty one + let bucket = std::mem::take(&mut self.buckets[parent_bucket_id as usize]); - let intermediate_result = self.into_intermediate_metric_result(agg_data)?; + let intermediate_result = bucket.into_intermediate_metric_result(req_data)?; results.push( name, IntermediateAggregationResult::Metric(intermediate_result), @@ -247,27 +268,20 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector { fn collect( &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.collect_block(&[doc], agg_data) - } - - fn collect_block( - &mut self, + parent_bucket_id: BucketId, docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let req_data = agg_data.get_cardinality_req_data_mut(self.accessor_idx); - self.fetch_block_with_field(docs, req_data); + self.fetch_block_with_field(docs, agg_data); + let bucket = &mut self.buckets[parent_bucket_id as usize]; - let col_block_accessor = &req_data.column_block_accessor; - if req_data.column_type == ColumnType::Str { + let col_block_accessor = &agg_data.column_block_accessor; + if self.column_type == ColumnType::Str { for term_ord in col_block_accessor.iter_vals() { - self.entries.insert(term_ord); + bucket.entries.insert(term_ord); } - } else if req_data.column_type == ColumnType::IpAddr { - let compact_space_accessor = req_data + } else if self.column_type == ColumnType::IpAddr { + let compact_space_accessor = self .accessor .values .clone() @@ -282,16 +296,29 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector { })?; for val in col_block_accessor.iter_vals() { let val: u128 = compact_space_accessor.compact_to_u128(val as u32); - self.cardinality.sketch.insert_any(&val); + bucket.cardinality.sketch.insert_any(&val); } } else { for val in col_block_accessor.iter_vals() { - self.cardinality.sketch.insert_any(&val); + bucket.cardinality.sketch.insert_any(&val); } } Ok(()) } + + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + if max_bucket as usize >= self.buckets.len() { + self.buckets.resize_with(max_bucket as usize + 1, || { + SegmentCardinalityCollectorBucket::new(self.column_type) + }); + } + Ok(()) + } } #[derive(Clone, Debug, Serialize, Deserialize)] diff --git a/src/aggregation/metric/count.rs b/src/aggregation/metric/count.rs index ac550a38f..b28ced047 100644 --- a/src/aggregation/metric/count.rs +++ b/src/aggregation/metric/count.rs @@ -52,10 +52,8 @@ pub struct IntermediateCount { impl IntermediateCount { /// Creates a new [`IntermediateCount`] instance from a [`SegmentStatsCollector`]. - pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { - Self { - stats: collector.stats, - } + pub(crate) fn from_stats(stats: IntermediateStats) -> Self { + Self { stats } } /// Merges the other intermediate result into self. pub fn merge_fruits(&mut self, other: IntermediateCount) { diff --git a/src/aggregation/metric/extended_stats.rs b/src/aggregation/metric/extended_stats.rs index d7302e5f5..e71426790 100644 --- a/src/aggregation/metric/extended_stats.rs +++ b/src/aggregation/metric/extended_stats.rs @@ -8,10 +8,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, }; -use crate::aggregation::metric::MetricAggReqData; use crate::aggregation::segment_agg_result::SegmentAggregationCollector; use crate::aggregation::*; -use crate::{DocId, TantivyError}; +use crate::TantivyError; /// A multi-value metric aggregation that computes a collection of extended statistics /// on numeric values that are extracted @@ -318,51 +317,28 @@ impl IntermediateExtendedStats { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] pub(crate) struct SegmentExtendedStatsCollector { + name: String, missing: Option, field_type: ColumnType, - pub(crate) extended_stats: IntermediateExtendedStats, - pub(crate) accessor_idx: usize, - val_cache: Vec, + accessor: columnar::Column, + buckets: Vec, + sigma: Option, } impl SegmentExtendedStatsCollector { - pub fn from_req( - field_type: ColumnType, - sigma: Option, - accessor_idx: usize, - missing: Option, - ) -> Self { - let missing = missing.and_then(|val| f64_to_fastfield_u64(val, &field_type)); + pub fn from_req(req: &MetricAggReqData, sigma: Option) -> Self { + let missing = req + .missing + .and_then(|val| f64_to_fastfield_u64(val, &req.field_type)); Self { - field_type, - extended_stats: IntermediateExtendedStats::with_sigma(sigma), - accessor_idx, + name: req.name.clone(), + field_type: req.field_type, + accessor: req.accessor.clone(), missing, - val_cache: Default::default(), - } - } - #[inline] - pub(crate) fn collect_block_with_field( - &mut self, - docs: &[DocId], - req_data: &mut MetricAggReqData, - ) { - if let Some(missing) = self.missing.as_ref() { - req_data.column_block_accessor.fetch_block_with_missing( - docs, - &req_data.accessor, - *missing, - ); - } else { - req_data - .column_block_accessor - .fetch_block(docs, &req_data.accessor); - } - for val in req_data.column_block_accessor.iter_vals() { - let val1 = f64_from_fastfield_u64(val, &self.field_type); - self.extended_stats.collect(val1); + buckets: vec![IntermediateExtendedStats::with_sigma(sigma); 16], + sigma, } } } @@ -370,15 +346,18 @@ impl SegmentExtendedStatsCollector { impl SegmentAggregationCollector for SegmentExtendedStatsCollector { #[inline] fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { - let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone(); + let name = self.name.clone(); + self.prepare_max_bucket(parent_bucket_id, agg_data)?; + let extended_stats = std::mem::take(&mut self.buckets[parent_bucket_id as usize]); results.push( name, IntermediateAggregationResult::Metric(IntermediateMetricResult::ExtendedStats( - self.extended_stats, + extended_stats, )), )?; @@ -388,39 +367,36 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector { #[inline] fn collect( &mut self, - doc: crate::DocId, + parent_bucket_id: BucketId, + docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let req_data = agg_data.get_metric_req_data(self.accessor_idx); - if let Some(missing) = self.missing { - let mut has_val = false; - for val in req_data.accessor.values_for_doc(doc) { - let val1 = f64_from_fastfield_u64(val, &self.field_type); - self.extended_stats.collect(val1); - has_val = true; - } - if !has_val { - self.extended_stats - .collect(f64_from_fastfield_u64(missing, &self.field_type)); - } - } else { - for val in req_data.accessor.values_for_doc(doc) { - let val1 = f64_from_fastfield_u64(val, &self.field_type); - self.extended_stats.collect(val1); - } + let mut extended_stats = self.buckets[parent_bucket_id as usize].clone(); + + agg_data + .column_block_accessor + .fetch_block_with_missing(docs, &self.accessor, self.missing); + for val in agg_data.column_block_accessor.iter_vals() { + let val1 = f64_from_fastfield_u64(val, self.field_type); + extended_stats.collect(val1); } + // store back + self.buckets[parent_bucket_id as usize] = extended_stats; + Ok(()) } - #[inline] - fn collect_block( + fn prepare_max_bucket( &mut self, - docs: &[crate::DocId], - agg_data: &mut AggregationsSegmentCtx, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, ) -> crate::Result<()> { - let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx); - self.collect_block_with_field(docs, req_data); + if self.buckets.len() <= max_bucket as usize { + self.buckets.resize_with(max_bucket as usize + 1, || { + IntermediateExtendedStats::with_sigma(self.sigma) + }); + } Ok(()) } } diff --git a/src/aggregation/metric/max.rs b/src/aggregation/metric/max.rs index 89c6e4458..59af7e2de 100644 --- a/src/aggregation/metric/max.rs +++ b/src/aggregation/metric/max.rs @@ -52,10 +52,8 @@ pub struct IntermediateMax { impl IntermediateMax { /// Creates a new [`IntermediateMax`] instance from a [`SegmentStatsCollector`]. - pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { - Self { - stats: collector.stats, - } + pub(crate) fn from_stats(stats: IntermediateStats) -> Self { + Self { stats } } /// Merges the other intermediate result into self. pub fn merge_fruits(&mut self, other: IntermediateMax) { diff --git a/src/aggregation/metric/min.rs b/src/aggregation/metric/min.rs index 61fd2ecd2..ecf2fcafc 100644 --- a/src/aggregation/metric/min.rs +++ b/src/aggregation/metric/min.rs @@ -52,10 +52,8 @@ pub struct IntermediateMin { impl IntermediateMin { /// Creates a new [`IntermediateMin`] instance from a [`SegmentStatsCollector`]. - pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { - Self { - stats: collector.stats, - } + pub(crate) fn from_stats(stats: IntermediateStats) -> Self { + Self { stats } } /// Merges the other intermediate result into self. pub fn merge_fruits(&mut self, other: IntermediateMin) { diff --git a/src/aggregation/metric/mod.rs b/src/aggregation/metric/mod.rs index 3537af8a6..d3a448a38 100644 --- a/src/aggregation/metric/mod.rs +++ b/src/aggregation/metric/mod.rs @@ -31,7 +31,7 @@ use std::collections::HashMap; pub use average::*; pub use cardinality::*; -use columnar::{Column, ColumnBlockAccessor, ColumnType}; +use columnar::{Column, ColumnType}; pub use count::*; pub use extended_stats::*; pub use max::*; @@ -55,8 +55,6 @@ pub struct MetricAggReqData { pub field_type: ColumnType, /// The missing value normalized to the internal u64 representation of the field type. pub missing_u64: Option, - /// The column block accessor to access the fast field values. - pub column_block_accessor: ColumnBlockAccessor, /// The column accessor to access the fast field values. pub accessor: Column, /// Used when converting to intermediate result diff --git a/src/aggregation/metric/percentiles.rs b/src/aggregation/metric/percentiles.rs index c846e2187..ff9de45f1 100644 --- a/src/aggregation/metric/percentiles.rs +++ b/src/aggregation/metric/percentiles.rs @@ -7,10 +7,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, }; -use crate::aggregation::metric::MetricAggReqData; use crate::aggregation::segment_agg_result::SegmentAggregationCollector; use crate::aggregation::*; -use crate::{DocId, TantivyError}; +use crate::TantivyError; /// # Percentiles /// @@ -131,10 +130,16 @@ impl PercentilesAggregationReq { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] pub(crate) struct SegmentPercentilesCollector { - pub(crate) percentiles: PercentilesCollector, + pub(crate) buckets: Vec, pub(crate) accessor_idx: usize, + /// The type of the field. + pub field_type: ColumnType, + /// The missing value normalized to the internal u64 representation of the field type. + pub missing_u64: Option, + /// The column accessor to access the fast field values. + pub accessor: Column, } #[derive(Clone, Serialize, Deserialize)] @@ -229,33 +234,18 @@ impl PercentilesCollector { } impl SegmentPercentilesCollector { - pub fn from_req_and_validate(accessor_idx: usize) -> crate::Result { - Ok(Self { - percentiles: PercentilesCollector::new(), + pub fn from_req_and_validate( + field_type: ColumnType, + missing_u64: Option, + accessor: Column, + accessor_idx: usize, + ) -> Self { + Self { + buckets: Vec::with_capacity(64), + field_type, + missing_u64, + accessor, accessor_idx, - }) - } - #[inline] - pub(crate) fn collect_block_with_field( - &mut self, - docs: &[DocId], - req_data: &mut MetricAggReqData, - ) { - if let Some(missing) = req_data.missing_u64.as_ref() { - req_data.column_block_accessor.fetch_block_with_missing( - docs, - &req_data.accessor, - *missing, - ); - } else { - req_data - .column_block_accessor - .fetch_block(docs, &req_data.accessor); - } - - for val in req_data.column_block_accessor.iter_vals() { - let val1 = f64_from_fastfield_u64(val, &req_data.field_type); - self.percentiles.collect(val1); } } } @@ -263,12 +253,18 @@ impl SegmentPercentilesCollector { impl SegmentAggregationCollector for SegmentPercentilesCollector { #[inline] fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone(); - let intermediate_metric_result = IntermediateMetricResult::Percentiles(self.percentiles); + self.prepare_max_bucket(parent_bucket_id, agg_data)?; + // Swap collector with an empty one to avoid cloning + let percentiles_collector = std::mem::take(&mut self.buckets[parent_bucket_id as usize]); + + let intermediate_metric_result = + IntermediateMetricResult::Percentiles(percentiles_collector); results.push( name, @@ -281,40 +277,33 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector { #[inline] fn collect( &mut self, - doc: crate::DocId, + parent_bucket_id: BucketId, + docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let req_data = agg_data.get_metric_req_data(self.accessor_idx); + let percentiles = &mut self.buckets[parent_bucket_id as usize]; + agg_data.column_block_accessor.fetch_block_with_missing( + docs, + &self.accessor, + self.missing_u64, + ); - if let Some(missing) = req_data.missing_u64 { - let mut has_val = false; - for val in req_data.accessor.values_for_doc(doc) { - let val1 = f64_from_fastfield_u64(val, &req_data.field_type); - self.percentiles.collect(val1); - has_val = true; - } - if !has_val { - self.percentiles - .collect(f64_from_fastfield_u64(missing, &req_data.field_type)); - } - } else { - for val in req_data.accessor.values_for_doc(doc) { - let val1 = f64_from_fastfield_u64(val, &req_data.field_type); - self.percentiles.collect(val1); - } + for val in agg_data.column_block_accessor.iter_vals() { + let val1 = f64_from_fastfield_u64(val, self.field_type); + percentiles.collect(val1); } Ok(()) } - #[inline] - fn collect_block( + fn prepare_max_bucket( &mut self, - docs: &[crate::DocId], - agg_data: &mut AggregationsSegmentCtx, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, ) -> crate::Result<()> { - let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx); - self.collect_block_with_field(docs, req_data); + while self.buckets.len() <= max_bucket as usize { + self.buckets.push(PercentilesCollector::new()); + } Ok(()) } } diff --git a/src/aggregation/metric/stats.rs b/src/aggregation/metric/stats.rs index 56715fdea..c43a6a259 100644 --- a/src/aggregation/metric/stats.rs +++ b/src/aggregation/metric/stats.rs @@ -1,5 +1,6 @@ use std::fmt::Debug; +use columnar::{Column, ColumnType}; use serde::{Deserialize, Serialize}; use super::*; @@ -7,10 +8,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, }; -use crate::aggregation::metric::MetricAggReqData; use crate::aggregation::segment_agg_result::SegmentAggregationCollector; use crate::aggregation::*; -use crate::{DocId, TantivyError}; +use crate::TantivyError; /// A multi-value metric aggregation that computes a collection of statistics on numeric values that /// are extracted from the aggregated documents. @@ -83,7 +83,7 @@ impl Stats { /// Intermediate result of the stats aggregation that can be combined with other intermediate /// results. -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)] pub struct IntermediateStats { /// The number of extracted values. pub(crate) count: u64, @@ -187,75 +187,75 @@ pub enum StatsType { Percentiles, } +fn create_collector( + req: &MetricAggReqData, +) -> Box { + Box::new(SegmentStatsCollector:: { + name: req.name.clone(), + collecting_for: req.collecting_for, + is_number_or_date_type: req.is_number_or_date_type, + missing_u64: req.missing_u64, + accessor: req.accessor.clone(), + buckets: vec![IntermediateStats::default()], + }) +} + +/// Build a concrete `SegmentStatsCollector` depending on the column type. +pub(crate) fn build_segment_stats_collector( + req: &MetricAggReqData, +) -> crate::Result> { + match req.field_type { + ColumnType::I64 => Ok(create_collector::<{ ColumnType::I64 as u8 }>(req)), + ColumnType::U64 => Ok(create_collector::<{ ColumnType::U64 as u8 }>(req)), + ColumnType::F64 => Ok(create_collector::<{ ColumnType::F64 as u8 }>(req)), + ColumnType::Bool => Ok(create_collector::<{ ColumnType::Bool as u8 }>(req)), + ColumnType::DateTime => Ok(create_collector::<{ ColumnType::DateTime as u8 }>(req)), + ColumnType::Bytes => Ok(create_collector::<{ ColumnType::Bytes as u8 }>(req)), + ColumnType::Str => Ok(create_collector::<{ ColumnType::Str as u8 }>(req)), + ColumnType::IpAddr => Ok(create_collector::<{ ColumnType::IpAddr as u8 }>(req)), + } +} + +#[repr(C)] #[derive(Clone, Debug)] -pub(crate) struct SegmentStatsCollector { - pub(crate) stats: IntermediateStats, - pub(crate) accessor_idx: usize, +pub(crate) struct SegmentStatsCollector { + pub(crate) missing_u64: Option, + pub(crate) accessor: Column, + pub(crate) is_number_or_date_type: bool, + pub(crate) buckets: Vec, + pub(crate) name: String, + pub(crate) collecting_for: StatsType, } -impl SegmentStatsCollector { - pub fn from_req(accessor_idx: usize) -> Self { - Self { - stats: IntermediateStats::default(), - accessor_idx, - } - } - #[inline] - pub(crate) fn collect_block_with_field( - &mut self, - docs: &[DocId], - req_data: &mut MetricAggReqData, - ) { - if let Some(missing) = req_data.missing_u64.as_ref() { - req_data.column_block_accessor.fetch_block_with_missing( - docs, - &req_data.accessor, - *missing, - ); - } else { - req_data - .column_block_accessor - .fetch_block(docs, &req_data.accessor); - } - if req_data.is_number_or_date_type { - for val in req_data.column_block_accessor.iter_vals() { - let val1 = f64_from_fastfield_u64(val, &req_data.field_type); - self.stats.collect(val1); - } - } else { - for _val in req_data.column_block_accessor.iter_vals() { - // we ignore the value and simply record that we got something - self.stats.collect(0.0); - } - } - } -} - -impl SegmentAggregationCollector for SegmentStatsCollector { +impl SegmentAggregationCollector + for SegmentStatsCollector +{ #[inline] fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { - let req = agg_data.get_metric_req_data(self.accessor_idx); - let name = req.name.clone(); + let name = self.name.clone(); - let intermediate_metric_result = match req.collecting_for { + self.prepare_max_bucket(parent_bucket_id, agg_data)?; + let stats = self.buckets[parent_bucket_id as usize]; + let intermediate_metric_result = match self.collecting_for { StatsType::Average => { - IntermediateMetricResult::Average(IntermediateAverage::from_collector(*self)) + IntermediateMetricResult::Average(IntermediateAverage::from_stats(stats)) } StatsType::Count => { - IntermediateMetricResult::Count(IntermediateCount::from_collector(*self)) + IntermediateMetricResult::Count(IntermediateCount::from_stats(stats)) } - StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_collector(*self)), - StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_collector(*self)), - StatsType::Stats => IntermediateMetricResult::Stats(self.stats), - StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_collector(*self)), + StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_stats(stats)), + StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_stats(stats)), + StatsType::Stats => IntermediateMetricResult::Stats(stats), + StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_stats(stats)), _ => { return Err(TantivyError::InvalidArgument(format!( "Unsupported stats type for stats aggregation: {:?}", - req.collecting_for + self.collecting_for ))) } }; @@ -271,41 +271,67 @@ impl SegmentAggregationCollector for SegmentStatsCollector { #[inline] fn collect( &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - let req_data = agg_data.get_metric_req_data(self.accessor_idx); - if let Some(missing) = req_data.missing_u64 { - let mut has_val = false; - for val in req_data.accessor.values_for_doc(doc) { - let val1 = f64_from_fastfield_u64(val, &req_data.field_type); - self.stats.collect(val1); - has_val = true; - } - if !has_val { - self.stats - .collect(f64_from_fastfield_u64(missing, &req_data.field_type)); - } - } else { - for val in req_data.accessor.values_for_doc(doc) { - let val1 = f64_from_fastfield_u64(val, &req_data.field_type); - self.stats.collect(val1); - } - } - - Ok(()) - } - - #[inline] - fn collect_block( - &mut self, + parent_bucket_id: BucketId, docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx); - self.collect_block_with_field(docs, req_data); + // TODO: remove once we fetch all values for all bucket ids in one go + if docs.len() == 1 && self.missing_u64.is_none() { + collect_stats::( + &mut self.buckets[parent_bucket_id as usize], + self.accessor.values_for_doc(docs[0]), + self.is_number_or_date_type, + )?; + + return Ok(()); + } + agg_data.column_block_accessor.fetch_block_with_missing( + docs, + &self.accessor, + self.missing_u64, + ); + collect_stats::( + &mut self.buckets[parent_bucket_id as usize], + agg_data.column_block_accessor.iter_vals(), + self.is_number_or_date_type, + )?; + Ok(()) } + + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + let required_buckets = (max_bucket as usize) + 1; + if self.buckets.len() < required_buckets { + self.buckets + .resize_with(required_buckets, IntermediateStats::default); + } + Ok(()) + } +} + +#[inline] +fn collect_stats( + stats: &mut IntermediateStats, + vals: impl Iterator, + is_number_or_date_type: bool, +) -> crate::Result<()> { + if is_number_or_date_type { + for val in vals { + let val1 = convert_to_f64::(val); + stats.collect(val1); + } + } else { + for _val in vals { + // we ignore the value and simply record that we got something + stats.collect(0.0); + } + } + + Ok(()) } #[cfg(test)] diff --git a/src/aggregation/metric/sum.rs b/src/aggregation/metric/sum.rs index 86f661679..2487c4e9d 100644 --- a/src/aggregation/metric/sum.rs +++ b/src/aggregation/metric/sum.rs @@ -52,10 +52,8 @@ pub struct IntermediateSum { impl IntermediateSum { /// Creates a new [`IntermediateSum`] instance from a [`SegmentStatsCollector`]. - pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { - Self { - stats: collector.stats, - } + pub(crate) fn from_stats(stats: IntermediateStats) -> Self { + Self { stats } } /// Merges the other intermediate result into self. pub fn merge_fruits(&mut self, other: IntermediateSum) { diff --git a/src/aggregation/metric/top_hits.rs b/src/aggregation/metric/top_hits.rs index 6a8bdf826..54e5a5ced 100644 --- a/src/aggregation/metric/top_hits.rs +++ b/src/aggregation/metric/top_hits.rs @@ -15,12 +15,11 @@ use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateMetricResult, }; use crate::aggregation::segment_agg_result::SegmentAggregationCollector; -use crate::aggregation::AggregationError; +use crate::aggregation::{AggregationError, BucketId}; use crate::collector::sort_key::ReverseComparator; use crate::collector::TopNComputer; use crate::schema::OwnedValue; use crate::{DocAddress, DocId, SegmentOrdinal}; -// duplicate import removed; already imported above /// Contains all information required by the TopHitsSegmentCollector to perform the /// top_hits aggregation on a segment. @@ -472,7 +471,10 @@ impl TopHitsTopNComputer { /// Create a new TopHitsCollector pub fn new(req: &TopHitsAggregationReq) -> Self { Self { - top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)), + top_n: TopNComputer::new_with_comparator( + req.size + req.from.unwrap_or(0), + ReverseComparator, + ), req: req.clone(), } } @@ -518,7 +520,8 @@ impl TopHitsTopNComputer { pub(crate) struct TopHitsSegmentCollector { segment_ordinal: SegmentOrdinal, accessor_idx: usize, - top_n: TopNComputer, DocAddress, ReverseComparator>, + buckets: Vec, DocAddress, ReverseComparator>>, + num_hits: usize, } impl TopHitsSegmentCollector { @@ -527,19 +530,29 @@ impl TopHitsSegmentCollector { accessor_idx: usize, segment_ordinal: SegmentOrdinal, ) -> Self { + let num_hits = req.size + req.from.unwrap_or(0); Self { - top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)), + num_hits, segment_ordinal, accessor_idx, + buckets: vec![TopNComputer::new_with_comparator(num_hits, ReverseComparator); 1], } } - fn into_top_hits_collector( - self, + fn get_top_hits_computer( + &mut self, + parent_bucket_id: BucketId, value_accessors: &HashMap>, req: &TopHitsAggregationReq, ) -> TopHitsTopNComputer { + if parent_bucket_id as usize >= self.buckets.len() { + return TopHitsTopNComputer::new(req); + } + let top_n = std::mem::replace( + &mut self.buckets[parent_bucket_id as usize], + TopNComputer::new(0), + ); let mut top_hits_computer = TopHitsTopNComputer::new(req); - let top_results = self.top_n.into_vec(); + let top_results = top_n.into_vec(); for res in top_results { let doc_value_fields = req.get_document_field_data(value_accessors, res.doc.doc_id); @@ -554,54 +567,24 @@ impl TopHitsSegmentCollector { top_hits_computer } - - /// TODO add a specialized variant for a single sort field - fn collect_with( - &mut self, - doc_id: crate::DocId, - req: &TopHitsAggregationReq, - accessors: &[(Column, ColumnType)], - ) -> crate::Result<()> { - let sorts: Vec = req - .sort - .iter() - .enumerate() - .map(|(idx, KeyOrder { order, .. })| { - let order = *order; - let value = accessors - .get(idx) - .expect("could not find field in accessors") - .0 - .values_for_doc(doc_id) - .next(); - DocValueAndOrder { value, order } - }) - .collect(); - - self.top_n.push( - sorts, - DocAddress { - segment_ord: self.segment_ordinal, - doc_id, - }, - ); - Ok(()) - } } impl SegmentAggregationCollector for TopHitsSegmentCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut crate::aggregation::intermediate_agg_result::IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { let req_data = agg_data.get_top_hits_req_data(self.accessor_idx); let value_accessors = &req_data.value_accessors; - let intermediate_result = IntermediateMetricResult::TopHits( - self.into_top_hits_collector(value_accessors, &req_data.req), - ); + let intermediate_result = IntermediateMetricResult::TopHits(self.get_top_hits_computer( + parent_bucket_id, + value_accessors, + &req_data.req, + )); results.push( req_data.name.to_string(), IntermediateAggregationResult::Metric(intermediate_result), @@ -611,26 +594,56 @@ impl SegmentAggregationCollector for TopHitsSegmentCollector { /// TODO: Consider a caching layer to reduce the call overhead fn collect( &mut self, - doc_id: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - let req_data = agg_data.get_top_hits_req_data(self.accessor_idx); - self.collect_with(doc_id, &req_data.req, &req_data.accessors)?; - Ok(()) - } - - fn collect_block( - &mut self, + parent_bucket_id: BucketId, docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { + let top_n = &mut self.buckets[parent_bucket_id as usize]; let req_data = agg_data.get_top_hits_req_data(self.accessor_idx); - // TODO: Consider getting fields with the column block accessor. - for doc in docs { - self.collect_with(*doc, &req_data.req, &req_data.accessors)?; + let req = &req_data.req; + let accessors = &req_data.accessors; + for &doc_id in docs { + // TODO: this is terrible, a new vec is allocated for every doc + // We can fetch blocks instead + // We don't need to store the order for every value + let sorts: Vec = req + .sort + .iter() + .enumerate() + .map(|(idx, KeyOrder { order, .. })| { + let order = *order; + let value = accessors + .get(idx) + .expect("could not find field in accessors") + .0 + .values_for_doc(doc_id) + .next(); + DocValueAndOrder { value, order } + }) + .collect(); + + top_n.push( + sorts, + DocAddress { + segment_ord: self.segment_ordinal, + doc_id, + }, + ); } Ok(()) } + + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + self.buckets.resize( + (max_bucket as usize) + 1, + TopNComputer::new_with_comparator(self.num_hits, ReverseComparator), + ); + Ok(()) + } } #[cfg(test)] @@ -746,7 +759,7 @@ mod tests { ], "from": 0, } - } + } })) .unwrap(); @@ -875,7 +888,7 @@ mod tests { "mixed.*", ], } - } + } }))?; let collector = AggregationCollector::from_aggs(d, Default::default()); diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index ddf60ea4c..b4a080d6a 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -133,7 +133,7 @@ mod agg_limits; pub mod agg_req; pub mod agg_result; pub mod bucket; -mod buf_collector; +pub(crate) mod cached_sub_aggs; mod collector; mod date; mod error; @@ -162,6 +162,19 @@ use serde::{Deserialize, Deserializer, Serialize}; use crate::tokenizer::TokenizerManager; +/// A bucket id is a dense identifier for a bucket within an aggregation. +/// It is used to index into a Vec that hold per-bucket data. +/// +/// For example, in a terms aggregation, each unique term will be assigned a incremental BucketId. +/// This BucketId will be forwarded to sub-aggregations to identify the parent bucket. +/// +/// This allows to have a single AggregationCollector instance per aggregation, +/// that can handle multiple buckets efficiently. +/// +/// The API to call sub-aggregations is therefore a &[(BucketId, &[DocId])]. +/// For that we'll need a buffer. One Vec per bucket aggregation is needed. +pub type BucketId = u32; + /// Context parameters for aggregation execution /// /// This struct holds shared resources needed during aggregation execution: @@ -335,19 +348,37 @@ impl Display for Key { } } +pub(crate) fn convert_to_f64(val: u64) -> f64 { + if COLUMN_TYPE_ID == ColumnType::U64 as u8 { + val as f64 + } else if COLUMN_TYPE_ID == ColumnType::I64 as u8 + || COLUMN_TYPE_ID == ColumnType::DateTime as u8 + { + i64::from_u64(val) as f64 + } else if COLUMN_TYPE_ID == ColumnType::F64 as u8 { + f64::from_u64(val) + } else if COLUMN_TYPE_ID == ColumnType::Bool as u8 { + val as f64 + } else { + panic!( + "ColumnType ID {} cannot be converted to f64 metric", + COLUMN_TYPE_ID + ) + } +} + /// Inverse of `to_fastfield_u64`. Used to convert to `f64` for metrics. /// /// # Panics /// Only `u64`, `f64`, `date`, and `i64` are supported. -pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: &ColumnType) -> f64 { +pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: ColumnType) -> f64 { match field_type { - ColumnType::U64 => val as f64, - ColumnType::I64 | ColumnType::DateTime => i64::from_u64(val) as f64, - ColumnType::F64 => f64::from_u64(val), - ColumnType::Bool => val as f64, - _ => { - panic!("unexpected type {field_type:?}. This should not happen") - } + ColumnType::U64 => convert_to_f64::<{ ColumnType::U64 as u8 }>(val), + ColumnType::I64 => convert_to_f64::<{ ColumnType::I64 as u8 }>(val), + ColumnType::F64 => convert_to_f64::<{ ColumnType::F64 as u8 }>(val), + ColumnType::Bool => convert_to_f64::<{ ColumnType::Bool as u8 }>(val), + ColumnType::DateTime => convert_to_f64::<{ ColumnType::DateTime as u8 }>(val), + _ => panic!("unexpected type {field_type:?}. This should not happen"), } } diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index 5cc2650b6..7bd13f1cd 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -8,25 +8,67 @@ use std::fmt::Debug; pub(crate) use super::agg_limits::AggregationLimitsGuard; use super::intermediate_agg_result::IntermediateAggregationResults; use crate::aggregation::agg_data::AggregationsSegmentCtx; +use crate::aggregation::BucketId; + +/// Monotonically increasing provider of BucketIds. +#[derive(Debug, Clone, Default)] +pub struct BucketIdProvider(u32); +impl BucketIdProvider { + /// Get the next BucketId. + pub fn next_bucket_id(&mut self) -> BucketId { + let bucket_id = self.0; + self.0 += 1; + bucket_id + } +} /// A SegmentAggregationCollector is used to collect aggregation results. -pub trait SegmentAggregationCollector: CollectorClone + Debug { +pub trait SegmentAggregationCollector: Debug { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()>; + /// Note: The caller needs to call `prepare_max_bucket` before calling `collect`. fn collect( &mut self, - doc: crate::DocId, + parent_bucket_id: BucketId, + docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()>; - fn collect_block( + /// Collect docs for multiple buckets in one call. + /// Minimizes dynamic dispatch overhead when collecting many buckets. + /// + /// Note: The caller needs to call `prepare_max_bucket` before calling `collect`. + fn collect_multiple( &mut self, + bucket_ids: &[BucketId], docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, + ) -> crate::Result<()> { + debug_assert_eq!(bucket_ids.len(), docs.len()); + let mut start = 0; + while start < bucket_ids.len() { + let bucket_id = bucket_ids[start]; + let mut end = start + 1; + while end < bucket_ids.len() && bucket_ids[end] == bucket_id { + end += 1; + } + self.collect(bucket_id, &docs[start..end], agg_data)?; + start = end; + } + Ok(()) + } + + /// Prepare the collector for collecting up to BucketId `max_bucket`. + /// This is useful so we can split allocation ahead of time of collecting. + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + agg_data: &AggregationsSegmentCtx, ) -> crate::Result<()>; /// Finalize method. Some Aggregator collect blocks of docs before calling `collect_block`. @@ -36,26 +78,7 @@ pub trait SegmentAggregationCollector: CollectorClone + Debug { } } -/// A helper trait to enable cloning of Box -pub trait CollectorClone { - fn clone_box(&self) -> Box; -} - -impl CollectorClone for T -where T: 'static + SegmentAggregationCollector + Clone -{ - fn clone_box(&self) -> Box { - Box::new(self.clone()) - } -} - -impl Clone for Box { - fn clone(&self) -> Box { - self.clone_box() - } -} - -#[derive(Clone, Default)] +#[derive(Default)] /// The GenericSegmentAggregationResultsCollector is the generic version of the collector, which /// can handle arbitrary complexity of sub-aggregations. Ideally we never have to pick this one /// and can provide specialized versions instead, that remove some of its overhead. @@ -73,12 +96,13 @@ impl Debug for GenericSegmentAggregationResultsCollector { impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { - for agg in self.aggs { - agg.add_intermediate_aggregation_result(agg_data, results)?; + for agg in &mut self.aggs { + agg.add_intermediate_aggregation_result(agg_data, results, parent_bucket_id)?; } Ok(()) @@ -86,23 +110,13 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector { fn collect( &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.collect_block(&[doc], agg_data)?; - - Ok(()) - } - - fn collect_block( - &mut self, + parent_bucket_id: BucketId, docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { for collector in &mut self.aggs { - collector.collect_block(docs, agg_data)?; + collector.collect(parent_bucket_id, docs, agg_data)?; } - Ok(()) } @@ -112,4 +126,15 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector { } Ok(()) } + + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + for collector in &mut self.aggs { + collector.prepare_max_bucket(max_bucket, agg_data)?; + } + Ok(()) + } } diff --git a/src/core/executor.rs b/src/core/executor.rs index 8cc7e0026..f11644599 100644 --- a/src/core/executor.rs +++ b/src/core/executor.rs @@ -48,7 +48,15 @@ impl Executor { F: Sized + Sync + Fn(A) -> crate::Result, { match self { - Executor::SingleThread => args.map(f).collect::>(), + Executor::SingleThread => { + // Avoid `collect`, since the stacktrace is blown up by it, which makes profiling + // harder. + let mut result = Vec::with_capacity(args.size_hint().0); + for arg in args { + result.push(f(arg)?); + } + Ok(result) + } Executor::ThreadPool(pool) => { let args: Vec = args.collect(); let num_fruits = args.len(); From d904630e6a03cb8269f68275392c179304b04d5a Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Thu, 8 Jan 2026 15:50:22 +0100 Subject: [PATCH 23/26] Bumped bitpacking version (#2797) Co-authored-by: Paul Masurel --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 40eff7814..a6d59f335 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,7 +39,7 @@ uuid = { version = "1.0.0", features = ["v4", "serde"] } crossbeam-channel = "0.5.4" rust-stemmers = { version = "1.2.0", optional = true } downcast-rs = "2.0.1" -bitpacking = { version = "0.9.2", default-features = false, features = [ +bitpacking = { version = "0.9.3", default-features = false, features = [ "bitpacker4x", ] } census = "0.4.2" From 947c0d5f4054f2d3e471f3c339b9ce33d6c011f5 Mon Sep 17 00:00:00 2001 From: Alex Lazar Date: Fri, 9 Jan 2026 23:25:51 -0800 Subject: [PATCH 24/26] Bump lru to 0.16.3 per dependabot --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index a6d59f335..a2731789e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,7 +50,7 @@ fail = { version = "0.5.0", optional = true } time = { version = "0.3.35", features = ["serde-well-known"] } smallvec = "1.8.0" rayon = "1.5.2" -lru = "0.12.0" +lru = "0.16.3" fastdivide = "0.4.0" itertools = "0.14.0" measure_time = "0.9.0" From c92e831dde738163054729a2276ed74dfa1a8eee Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Mon, 12 Jan 2026 13:53:43 +0100 Subject: [PATCH 25/26] Minor refactoring in PostingsSerializer (#2801) Removes the Write generics argument in PostingsSerializer. This removes useless generic. Prepares the path for codecs. Removes one useless CountingWrite layer. etc. Co-authored-by: Paul Masurel --- src/postings/segment_postings.rs | 7 ++- src/postings/serializer.rs | 86 +++++++++++++++----------------- 2 files changed, 42 insertions(+), 51 deletions(-) diff --git a/src/postings/segment_postings.rs b/src/postings/segment_postings.rs index d9ba33eb2..e9046bd3c 100644 --- a/src/postings/segment_postings.rs +++ b/src/postings/segment_postings.rs @@ -70,13 +70,13 @@ impl SegmentPostings { let mut buffer = Vec::new(); { let mut postings_serializer = - PostingsSerializer::new(&mut buffer, 0.0, IndexRecordOption::Basic, None); + PostingsSerializer::new(0.0, IndexRecordOption::Basic, None); postings_serializer.new_term(docs.len() as u32, false); for &doc in docs { postings_serializer.write_doc(doc, 1u32); } postings_serializer - .close_term(docs.len() as u32) + .close_term(docs.len() as u32, &mut buffer) .expect("In memory Serialization should never fail."); } let block_segment_postings = BlockSegmentPostings::open( @@ -115,7 +115,6 @@ impl SegmentPostings { }) .unwrap_or(0.0); let mut postings_serializer = PostingsSerializer::new( - &mut buffer, average_field_norm, IndexRecordOption::WithFreqs, fieldnorm_reader, @@ -125,7 +124,7 @@ impl SegmentPostings { postings_serializer.write_doc(doc, tf); } postings_serializer - .close_term(doc_and_tfs.len() as u32) + .close_term(doc_and_tfs.len() as u32, &mut buffer) .unwrap(); let block_segment_postings = BlockSegmentPostings::open( doc_and_tfs.len() as u32, diff --git a/src/postings/serializer.rs b/src/postings/serializer.rs index c0ee8483c..08c3c7542 100644 --- a/src/postings/serializer.rs +++ b/src/postings/serializer.rs @@ -104,10 +104,12 @@ impl InvertedIndexSerializer { /// the serialization of a specific field. pub struct FieldSerializer<'a> { term_dictionary_builder: TermDictionaryBuilder<&'a mut CountingWriter>, - postings_serializer: PostingsSerializer<&'a mut CountingWriter>, + postings_serializer: PostingsSerializer, positions_serializer_opt: Option>>, current_term_info: TermInfo, term_open: bool, + postings_write: &'a mut CountingWriter, + postings_start_offset: u64, } impl<'a> FieldSerializer<'a> { @@ -128,27 +130,30 @@ impl<'a> FieldSerializer<'a> { .as_ref() .map(|ff_reader| total_num_tokens as Score / ff_reader.num_docs() as Score) .unwrap_or(0.0); - let postings_serializer = PostingsSerializer::new( - postings_write, - average_fieldnorm, - index_record_option, - fieldnorm_reader, - ); + let postings_serializer = + PostingsSerializer::new(average_fieldnorm, index_record_option, fieldnorm_reader); let positions_serializer_opt = if index_record_option.has_positions() { Some(PositionSerializer::new(positions_write)) } else { None }; + let postings_start_offset = postings_write.written_bytes(); Ok(FieldSerializer { term_dictionary_builder, postings_serializer, positions_serializer_opt, current_term_info: TermInfo::default(), term_open: false, + postings_write, + postings_start_offset, }) } + fn postings_offset(&self) -> usize { + (self.postings_write.written_bytes() - self.postings_start_offset) as usize + } + fn current_term_info(&self) -> TermInfo { let positions_start = if let Some(positions_serializer) = self.positions_serializer_opt.as_ref() { @@ -156,7 +161,7 @@ impl<'a> FieldSerializer<'a> { } else { 0u64 } as usize; - let addr = self.postings_serializer.written_bytes() as usize; + let addr = self.postings_offset(); TermInfo { doc_freq: 0, postings_range: addr..addr, @@ -213,21 +218,22 @@ impl<'a> FieldSerializer<'a> { crate::fail_point!("FieldSerializer::close_term", |msg: Option| { Err(io::Error::new(io::ErrorKind::Other, format!("{msg:?}"))) }); - if self.term_open { - self.postings_serializer - .close_term(self.current_term_info.doc_freq)?; - self.current_term_info.postings_range.end = - self.postings_serializer.written_bytes() as usize; - if let Some(positions_serializer) = self.positions_serializer_opt.as_mut() { - positions_serializer.close_term()?; - self.current_term_info.positions_range.end = - positions_serializer.written_bytes() as usize; - } - self.term_dictionary_builder - .insert_value(&self.current_term_info)?; - self.term_open = false; + if !self.term_open { + return Ok(()); + }; + + self.postings_serializer + .close_term(self.current_term_info.doc_freq, self.postings_write)?; + self.current_term_info.postings_range.end = self.postings_offset(); + if let Some(positions_serializer) = self.positions_serializer_opt.as_mut() { + positions_serializer.close_term()?; + self.current_term_info.positions_range.end = + positions_serializer.written_bytes() as usize; } + self.term_dictionary_builder + .insert_value(&self.current_term_info)?; + self.term_open = false; Ok(()) } @@ -237,7 +243,7 @@ impl<'a> FieldSerializer<'a> { if let Some(positions_serializer) = self.positions_serializer_opt { positions_serializer.close()?; } - self.postings_serializer.close()?; + self.postings_write.flush()?; self.term_dictionary_builder.finish()?; Ok(()) } @@ -291,8 +297,7 @@ impl Block { } } -pub struct PostingsSerializer { - output_write: CountingWriter, +pub struct PostingsSerializer { last_doc_id_encoded: u32, block_encoder: BlockEncoder, @@ -310,16 +315,13 @@ pub struct PostingsSerializer { term_has_freq: bool, } -impl PostingsSerializer { +impl PostingsSerializer { pub fn new( - write: W, avg_fieldnorm: Score, mode: IndexRecordOption, fieldnorm_reader: Option, - ) -> PostingsSerializer { + ) -> PostingsSerializer { PostingsSerializer { - output_write: CountingWriter::wrap(write), - block_encoder: BlockEncoder::new(), block: Box::new(Block::new()), @@ -422,11 +424,11 @@ impl PostingsSerializer { } } - fn close(mut self) -> io::Result<()> { - self.postings_write.flush() - } - - pub fn close_term(&mut self, doc_freq: u32) -> io::Result<()> { + pub fn close_term( + &mut self, + doc_freq: u32, + output_write: &mut impl std::io::Write, + ) -> io::Result<()> { if !self.block.is_empty() { // we have doc ids waiting to be written // this happens when the number of doc ids is @@ -451,26 +453,16 @@ impl PostingsSerializer { } if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 { let skip_data = self.skip_write.data(); - VInt(skip_data.len() as u64).serialize(&mut self.output_write)?; - self.output_write.write_all(skip_data)?; + VInt(skip_data.len() as u64).serialize(output_write)?; + output_write.write_all(skip_data)?; } - self.output_write.write_all(&self.postings_write[..])?; + output_write.write_all(&self.postings_write[..])?; self.skip_write.clear(); self.postings_write.clear(); self.bm25_weight = None; Ok(()) } - /// Returns the number of bytes written in the postings write object - /// at this point. - /// When called before writing the postings of a term, this value is used as - /// start offset. - /// When called after writing the postings of a term, this value is used as a - /// end offset. - fn written_bytes(&self) -> u64 { - self.output_write.written_bytes() - } - fn clear(&mut self) { self.block.clear(); self.last_doc_id_encoded = 0; From 12977bc7c4cfa6e97936e8e0ecd871183bced2a4 Mon Sep 17 00:00:00 2001 From: trinity-1686a Date: Wed, 14 Jan 2026 10:19:09 +0100 Subject: [PATCH 26/26] upgrade some dependancies (#2802) including rand, which had a few breaking changes --- Cargo.toml | 6 +++--- benches/agg_bench.rs | 12 ++++++------ benches/and_or_queries.rs | 16 ++++++++-------- benches/bool_queries_with_range.rs | 8 ++++---- benches/range_queries.rs | 4 ++-- benches/range_query.rs | 8 ++++---- bitpacker/Cargo.toml | 2 +- bitpacker/benches/bench.rs | 4 ++-- columnar/Cargo.toml | 2 +- columnar/benches/bench_column_values_get.rs | 2 +- columnar/benches/bench_create_column_values.rs | 2 +- columnar/benches/bench_optional_index.rs | 4 ++-- columnar/benches/bench_values_u128.rs | 2 +- columnar/benches/bench_values_u64.rs | 2 +- columnar/src/column_values/u64_based/linear.rs | 2 +- columnar/src/column_values/u64_based/tests.rs | 2 +- common/Cargo.toml | 2 +- common/benches/bench.rs | 4 ++-- common/src/bitset.rs | 2 +- src/collector/facet_collector.rs | 17 +++++++---------- src/collector/sort_key_top_collector.rs | 2 +- src/fastfield/alive_bitset.rs | 4 ++-- src/fastfield/mod.rs | 2 +- src/functional_test.rs | 18 +++++++++--------- src/lib.rs | 4 ++-- src/postings/compression/mod.rs | 5 ++++- src/postings/mod.rs | 6 +++--- src/query/phrase_query/regex_phrase_weight.rs | 2 +- src/query/range_query/range_query.rs | 2 +- src/query/range_query/range_query_fastfield.rs | 2 +- src/query/term_query/term_scorer.rs | 4 ++-- src/termdict/fst_termdict/merger.rs | 6 +++--- sstable/Cargo.toml | 2 +- sstable/benches/stream_bench.rs | 4 ++-- stacker/Cargo.toml | 4 ++-- stacker/benches/bench.rs | 4 ++-- stacker/fuzz_test/Cargo.toml | 4 ++-- stacker/fuzz_test/src/main.rs | 2 +- 38 files changed, 90 insertions(+), 90 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a2731789e..476117656 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ regex = { version = "1.5.5", default-features = false, features = [ aho-corasick = "1.0" tantivy-fst = "0.5" memmap2 = { version = "0.9.0", optional = true } -lz4_flex = { version = "0.11", default-features = false, optional = true } +lz4_flex = { version = "0.12", default-features = false, optional = true } zstd = { version = "0.13", optional = true, default-features = false } tempfile = { version = "3.12.0", optional = true } log = "0.4.16" @@ -76,7 +76,7 @@ winapi = "0.3.9" [dev-dependencies] binggan = "0.14.2" -rand = "0.8.5" +rand = "0.9" maplit = "1.0.2" matches = "0.1.9" pretty_assertions = "1.2.1" @@ -85,7 +85,7 @@ test-log = "0.2.10" futures = "0.3.21" paste = "1.0.11" more-asserts = "0.3.1" -rand_distr = "0.4.3" +rand_distr = "0.5" time = { version = "0.3.10", features = ["serde-well-known", "macros"] } postcard = { version = "1.0.4", features = [ "use-std", diff --git a/benches/agg_bench.rs b/benches/agg_bench.rs index 642532597..9313cca7a 100644 --- a/benches/agg_bench.rs +++ b/benches/agg_bench.rs @@ -1,8 +1,8 @@ use binggan::plugins::PeakMemAllocPlugin; use binggan::{black_box, InputGroup, PeakMemAlloc, INSTRUMENTED_SYSTEM}; -use rand::distributions::WeightedIndex; -use rand::prelude::SliceRandom; +use rand::distr::weighted::WeightedIndex; use rand::rngs::StdRng; +use rand::seq::IndexedRandom; use rand::{Rng, SeedableRng}; use rand_distr::Distribution; use serde_json::json; @@ -532,7 +532,7 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result { // Prepare 1000 unique terms sampled using a Zipf distribution. // Exponent ~1.1 approximates top-20 terms covering around ~20%. let terms_1000: Vec = (1..=1000).map(|i| format!("term_{i}")).collect(); - let zipf_1000 = rand_distr::Zipf::new(1000, 1.1f64).unwrap(); + let zipf_1000 = rand_distr::Zipf::new(1000.0, 1.1f64).unwrap(); { let mut rng = StdRng::from_seed([1u8; 32]); @@ -576,8 +576,8 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result { } let _val_max = 1_000_000.0; for _ in 0..doc_with_value { - let val: f64 = rng.gen_range(0.0..1_000_000.0); - let json = if rng.gen_bool(0.1) { + let val: f64 = rng.random_range(0.0..1_000_000.0); + let json = if rng.random_bool(0.1) { // 10% are numeric values json!({ "mixed_type": val }) } else { @@ -586,7 +586,7 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result { index_writer.add_document(doc!( text_field => "cool", json_field => json, - text_field_all_unique_terms => format!("unique_term_{}", rng.gen::()), + text_field_all_unique_terms => format!("unique_term_{}", rng.random::()), text_field_many_terms => many_terms_data.choose(&mut rng).unwrap().to_string(), text_field_few_terms_status => status_field_data[log_level_distribution.sample(&mut rng)].0, text_field_1000_terms_zipf => terms_1000[zipf_1000.sample(&mut rng) as usize - 1].as_str(), diff --git a/benches/and_or_queries.rs b/benches/and_or_queries.rs index 805061c18..5dd213685 100644 --- a/benches/and_or_queries.rs +++ b/benches/and_or_queries.rs @@ -55,29 +55,29 @@ fn build_shared_indices(num_docs: usize, p_a: f32, p_b: f32, p_c: f32) -> (Bench { let mut writer = index.writer_with_num_threads(1, 500_000_000).unwrap(); for _ in 0..num_docs { - let has_a = rng.gen_bool(p_a as f64); - let has_b = rng.gen_bool(p_b as f64); - let has_c = rng.gen_bool(p_c as f64); - let score = rng.gen_range(0u64..100u64); - let score2 = rng.gen_range(0u64..100_000u64); + let has_a = rng.random_bool(p_a as f64); + let has_b = rng.random_bool(p_b as f64); + let has_c = rng.random_bool(p_c as f64); + let score = rng.random_range(0u64..100u64); + let score2 = rng.random_range(0u64..100_000u64); let mut title_tokens: Vec<&str> = Vec::new(); let mut body_tokens: Vec<&str> = Vec::new(); if has_a { - if rng.gen_bool(0.1) { + if rng.random_bool(0.1) { title_tokens.push("a"); } else { body_tokens.push("a"); } } if has_b { - if rng.gen_bool(0.1) { + if rng.random_bool(0.1) { title_tokens.push("b"); } else { body_tokens.push("b"); } } if has_c { - if rng.gen_bool(0.1) { + if rng.random_bool(0.1) { title_tokens.push("c"); } else { body_tokens.push("c"); diff --git a/benches/bool_queries_with_range.rs b/benches/bool_queries_with_range.rs index 9123ccf3a..9b2849300 100644 --- a/benches/bool_queries_with_range.rs +++ b/benches/bool_queries_with_range.rs @@ -36,13 +36,13 @@ fn build_shared_indices(num_docs: usize, p_title_a: f32, distribution: &str) -> "dense" => { for doc_id in 0..num_docs { // Always add title to avoid empty documents - let title_token = if rng.gen_bool(p_title_a as f64) { + let title_token = if rng.random_bool(p_title_a as f64) { "a" } else { "b" }; - let num_rand = rng.gen_range(0u64..1000u64); + let num_rand = rng.random_range(0u64..1000u64); let num_asc = (doc_id / 10000) as u64; @@ -60,13 +60,13 @@ fn build_shared_indices(num_docs: usize, p_title_a: f32, distribution: &str) -> "sparse" => { for doc_id in 0..num_docs { // Always add title to avoid empty documents - let title_token = if rng.gen_bool(p_title_a as f64) { + let title_token = if rng.random_bool(p_title_a as f64) { "a" } else { "b" }; - let num_rand = rng.gen_range(0u64..10000000u64); + let num_rand = rng.random_range(0u64..10000000u64); let num_asc = doc_id as u64; diff --git a/benches/range_queries.rs b/benches/range_queries.rs index 56aaf54b9..c8095a01b 100644 --- a/benches/range_queries.rs +++ b/benches/range_queries.rs @@ -33,7 +33,7 @@ fn build_shared_indices(num_docs: usize, distribution: &str) -> BenchIndex { match distribution { "dense" => { for doc_id in 0..num_docs { - let num_rand = rng.gen_range(0u64..1000u64); + let num_rand = rng.random_range(0u64..1000u64); let num_asc = (doc_id / 10000) as u64; writer @@ -46,7 +46,7 @@ fn build_shared_indices(num_docs: usize, distribution: &str) -> BenchIndex { } "sparse" => { for doc_id in 0..num_docs { - let num_rand = rng.gen_range(0u64..10000000u64); + let num_rand = rng.random_range(0u64..10000000u64); let num_asc = doc_id as u64; writer diff --git a/benches/range_query.rs b/benches/range_query.rs index bf46666f3..e0feddd66 100644 --- a/benches/range_query.rs +++ b/benches/range_query.rs @@ -97,20 +97,20 @@ fn get_index_0_to_100() -> Index { let num_vals = 100_000; let docs: Vec<_> = (0..num_vals) .map(|_i| { - let id_name = if rng.gen_bool(0.01) { + let id_name = if rng.random_bool(0.01) { "veryfew".to_string() // 1% - } else if rng.gen_bool(0.1) { + } else if rng.random_bool(0.1) { "few".to_string() // 9% } else { "most".to_string() // 90% }; Doc { id_name, - id: rng.gen_range(0..100), + id: rng.random_range(0..100), // Multiply by 1000, so that we create most buckets in the compact space // The benches depend on this range to select n-percent of elements with the // methods below. - ip: Ipv6Addr::from_u128(rng.gen_range(0..100) * 1000), + ip: Ipv6Addr::from_u128(rng.random_range(0..100) * 1000), } }) .collect(); diff --git a/bitpacker/Cargo.toml b/bitpacker/Cargo.toml index 3b2a3e15e..945bd0082 100644 --- a/bitpacker/Cargo.toml +++ b/bitpacker/Cargo.toml @@ -18,5 +18,5 @@ homepage = "https://github.com/quickwit-oss/tantivy" bitpacking = { version = "0.9.2", default-features = false, features = ["bitpacker1x"] } [dev-dependencies] -rand = "0.8" +rand = "0.9" proptest = "1" diff --git a/bitpacker/benches/bench.rs b/bitpacker/benches/bench.rs index 7544687c2..12bfeb53e 100644 --- a/bitpacker/benches/bench.rs +++ b/bitpacker/benches/bench.rs @@ -4,8 +4,8 @@ extern crate test; #[cfg(test)] mod tests { + use rand::rng; use rand::seq::IteratorRandom; - use rand::thread_rng; use tantivy_bitpacker::{BitPacker, BitUnpacker, BlockedBitpacker}; use test::Bencher; @@ -27,7 +27,7 @@ mod tests { let num_els = 1_000_000u32; let bit_unpacker = BitUnpacker::new(bit_width); let data = create_bitpacked_data(bit_width, num_els); - let idxs: Vec = (0..num_els).choose_multiple(&mut thread_rng(), 100_000); + let idxs: Vec = (0..num_els).choose_multiple(&mut rng(), 100_000); b.iter(|| { let mut out = 0u64; for &idx in &idxs { diff --git a/columnar/Cargo.toml b/columnar/Cargo.toml index 9eeafe2d0..b91ab36ff 100644 --- a/columnar/Cargo.toml +++ b/columnar/Cargo.toml @@ -22,7 +22,7 @@ downcast-rs = "2.0.1" [dev-dependencies] proptest = "1" more-asserts = "0.3.1" -rand = "0.8" +rand = "0.9" binggan = "0.14.0" [[bench]] diff --git a/columnar/benches/bench_column_values_get.rs b/columnar/benches/bench_column_values_get.rs index d486b0dde..f2c1674ef 100644 --- a/columnar/benches/bench_column_values_get.rs +++ b/columnar/benches/bench_column_values_get.rs @@ -9,7 +9,7 @@ use tantivy_columnar::column_values::{CodecType, serialize_and_load_u64_based_co fn get_data() -> Vec { let mut rng = StdRng::seed_from_u64(2u64); let mut data: Vec<_> = (100..55_000_u64) - .map(|num| num + rng.r#gen::() as u64) + .map(|num| num + rng.random::() as u64) .collect(); data.push(99_000); data.insert(1000, 2000); diff --git a/columnar/benches/bench_create_column_values.rs b/columnar/benches/bench_create_column_values.rs index aa04e0661..339dbb199 100644 --- a/columnar/benches/bench_create_column_values.rs +++ b/columnar/benches/bench_create_column_values.rs @@ -6,7 +6,7 @@ use tantivy_columnar::column_values::{CodecType, serialize_u64_based_column_valu fn get_data() -> Vec { let mut rng = StdRng::seed_from_u64(2u64); let mut data: Vec<_> = (100..55_000_u64) - .map(|num| num + rng.r#gen::() as u64) + .map(|num| num + rng.random::() as u64) .collect(); data.push(99_000); data.insert(1000, 2000); diff --git a/columnar/benches/bench_optional_index.rs b/columnar/benches/bench_optional_index.rs index c157f1455..03ff1df97 100644 --- a/columnar/benches/bench_optional_index.rs +++ b/columnar/benches/bench_optional_index.rs @@ -8,7 +8,7 @@ const TOTAL_NUM_VALUES: u32 = 1_000_000; fn gen_optional_index(fill_ratio: f64) -> OptionalIndex { let mut rng: StdRng = StdRng::from_seed([1u8; 32]); let vals: Vec = (0..TOTAL_NUM_VALUES) - .map(|_| rng.gen_bool(fill_ratio)) + .map(|_| rng.random_bool(fill_ratio)) .enumerate() .filter(|(_pos, val)| *val) .map(|(pos, _)| pos as u32) @@ -25,7 +25,7 @@ fn random_range_iterator( let mut rng: StdRng = StdRng::from_seed([1u8; 32]); let mut current = start; std::iter::from_fn(move || { - current += rng.gen_range(avg_step_size - avg_deviation..=avg_step_size + avg_deviation); + current += rng.random_range(avg_step_size - avg_deviation..=avg_step_size + avg_deviation); if current >= end { None } else { Some(current) } }) } diff --git a/columnar/benches/bench_values_u128.rs b/columnar/benches/bench_values_u128.rs index e0b4f0a1f..09173c678 100644 --- a/columnar/benches/bench_values_u128.rs +++ b/columnar/benches/bench_values_u128.rs @@ -39,7 +39,7 @@ fn get_data_50percent_item() -> Vec { let mut data = vec![]; for _ in 0..300_000 { - let val = rng.gen_range(1..=100); + let val = rng.random_range(1..=100); data.push(val); } data.push(SINGLE_ITEM); diff --git a/columnar/benches/bench_values_u64.rs b/columnar/benches/bench_values_u64.rs index 36711c776..f0419d8c6 100644 --- a/columnar/benches/bench_values_u64.rs +++ b/columnar/benches/bench_values_u64.rs @@ -34,7 +34,7 @@ fn get_data_50percent_item() -> Vec { let mut data = vec![]; for _ in 0..300_000 { - let val = rng.gen_range(1..=100); + let val = rng.random_range(1..=100); data.push(val); } data.push(SINGLE_ITEM); diff --git a/columnar/src/column_values/u64_based/linear.rs b/columnar/src/column_values/u64_based/linear.rs index dbfa13a4c..7caf3bdfb 100644 --- a/columnar/src/column_values/u64_based/linear.rs +++ b/columnar/src/column_values/u64_based/linear.rs @@ -268,7 +268,7 @@ mod tests { #[test] fn linear_interpol_fast_field_rand() { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..50 { let mut data = (0..10_000).map(|_| rng.next_u64()).collect::>(); create_and_validate::(&data, "random"); diff --git a/columnar/src/column_values/u64_based/tests.rs b/columnar/src/column_values/u64_based/tests.rs index 6b2697263..ff5b7051a 100644 --- a/columnar/src/column_values/u64_based/tests.rs +++ b/columnar/src/column_values/u64_based/tests.rs @@ -122,7 +122,7 @@ pub(crate) fn create_and_validate( assert_eq!(vals, buffer); if !vals.is_empty() { - let test_rand_idx = rand::thread_rng().gen_range(0..=vals.len() - 1); + let test_rand_idx = rand::rng().random_range(0..=vals.len() - 1); let expected_positions: Vec = vals .iter() .enumerate() diff --git a/common/Cargo.toml b/common/Cargo.toml index 206329d39..e5e922869 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -21,5 +21,5 @@ serde = { version = "1.0.136", features = ["derive"] } [dev-dependencies] binggan = "0.14.0" proptest = "1.0.0" -rand = "0.8.4" +rand = "0.9" diff --git a/common/benches/bench.rs b/common/benches/bench.rs index 81260e116..a0b1f9451 100644 --- a/common/benches/bench.rs +++ b/common/benches/bench.rs @@ -1,6 +1,6 @@ use binggan::{BenchRunner, black_box}; +use rand::rng; use rand::seq::IteratorRandom; -use rand::thread_rng; use tantivy_common::{BitSet, TinySet, serialize_vint_u32}; fn bench_vint() { @@ -17,7 +17,7 @@ fn bench_vint() { black_box(out); }); - let vals: Vec = (0..20_000).choose_multiple(&mut thread_rng(), 100_000); + let vals: Vec = (0..20_000).choose_multiple(&mut rng(), 100_000); runner.bench_function("bench_vint_rand", move |_| { let mut out = 0u64; for val in vals.iter().cloned() { diff --git a/common/src/bitset.rs b/common/src/bitset.rs index 94e4ca5ae..e005ca40b 100644 --- a/common/src/bitset.rs +++ b/common/src/bitset.rs @@ -416,7 +416,7 @@ mod tests { use std::collections::HashSet; use ownedbytes::OwnedBytes; - use rand::distributions::Bernoulli; + use rand::distr::Bernoulli; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; diff --git a/src/collector/facet_collector.rs b/src/collector/facet_collector.rs index a94ec03e8..6eb2c3ee7 100644 --- a/src/collector/facet_collector.rs +++ b/src/collector/facet_collector.rs @@ -486,9 +486,9 @@ mod tests { use std::collections::BTreeSet; use columnar::Dictionary; - use rand::distributions::Uniform; + use rand::distr::Uniform; use rand::prelude::SliceRandom; - use rand::{thread_rng, Rng}; + use rand::{rng, Rng}; use super::{FacetCollector, FacetCounts}; use crate::collector::facet_collector::compress_mapping; @@ -731,7 +731,7 @@ mod tests { let schema = schema_builder.build(); let index = Index::create_in_ram(schema); - let uniform = Uniform::new_inclusive(1, 100_000); + let uniform = Uniform::new_inclusive(1, 100_000).unwrap(); let mut docs: Vec = vec![("a", 10), ("b", 100), ("c", 7), ("d", 12), ("e", 21)] .into_iter() @@ -741,14 +741,11 @@ mod tests { std::iter::repeat_n(doc, count) }) .map(|mut doc| { - doc.add_facet( - facet_field, - &format!("/facet/{}", thread_rng().sample(uniform)), - ); + doc.add_facet(facet_field, &format!("/facet/{}", rng().sample(uniform))); doc }) .collect(); - docs[..].shuffle(&mut thread_rng()); + docs[..].shuffle(&mut rng()); let mut index_writer: IndexWriter = index.writer_for_tests().unwrap(); for doc in docs { @@ -822,8 +819,8 @@ mod tests { #[cfg(all(test, feature = "unstable"))] mod bench { + use rand::rng; use rand::seq::SliceRandom; - use rand::thread_rng; use test::Bencher; use crate::collector::FacetCollector; @@ -846,7 +843,7 @@ mod bench { } } // 40425 docs - docs[..].shuffle(&mut thread_rng()); + docs[..].shuffle(&mut rng()); let mut index_writer: IndexWriter = index.writer_for_tests().unwrap(); for doc in docs { diff --git a/src/collector/sort_key_top_collector.rs b/src/collector/sort_key_top_collector.rs index 3ca27fc75..9ca47581b 100644 --- a/src/collector/sort_key_top_collector.rs +++ b/src/collector/sort_key_top_collector.rs @@ -160,7 +160,7 @@ mod tests { expected: &[(crate::Score, usize)], ) { let mut vals: Vec<(crate::Score, usize)> = (0..10).map(|val| (val as f32, val)).collect(); - vals.shuffle(&mut rand::thread_rng()); + vals.shuffle(&mut rand::rng()); let vals_merged = merge_top_k(vals.into_iter(), doc_range, ComparatorEnum::from(order)); assert_eq!(&vals_merged, expected); } diff --git a/src/fastfield/alive_bitset.rs b/src/fastfield/alive_bitset.rs index 11d7463c7..bbdc82a45 100644 --- a/src/fastfield/alive_bitset.rs +++ b/src/fastfield/alive_bitset.rs @@ -162,7 +162,7 @@ mod tests { mod bench { use rand::prelude::IteratorRandom; - use rand::thread_rng; + use rand::rng; use test::Bencher; use super::AliveBitSet; @@ -176,7 +176,7 @@ mod bench { } fn remove_rand(raw: &mut Vec) { - let i = (0..raw.len()).choose(&mut thread_rng()).unwrap(); + let i = (0..raw.len()).choose(&mut rng()).unwrap(); raw.remove(i); } diff --git a/src/fastfield/mod.rs b/src/fastfield/mod.rs index 726b9b76a..aca53c212 100644 --- a/src/fastfield/mod.rs +++ b/src/fastfield/mod.rs @@ -879,7 +879,7 @@ mod tests { const ONE_HOUR_IN_MICROSECS: i64 = 3_600 * 1_000_000; let times: Vec = std::iter::repeat_with(|| { // +- One hour. - let t = T0 + rng.gen_range(-ONE_HOUR_IN_MICROSECS..ONE_HOUR_IN_MICROSECS); + let t = T0 + rng.random_range(-ONE_HOUR_IN_MICROSECS..ONE_HOUR_IN_MICROSECS); DateTime::from_timestamp_micros(t) }) .take(1_000) diff --git a/src/functional_test.rs b/src/functional_test.rs index 1548d8096..9606bb7a7 100644 --- a/src/functional_test.rs +++ b/src/functional_test.rs @@ -1,6 +1,6 @@ use std::collections::HashSet; -use rand::{thread_rng, Rng}; +use rand::{rng, Rng}; use crate::indexer::index_writer::MEMORY_BUDGET_NUM_BYTES_MIN; use crate::schema::*; @@ -29,7 +29,7 @@ fn test_functional_store() -> crate::Result<()> { let index = Index::create_in_ram(schema); let reader = index.reader()?; - let mut rng = thread_rng(); + let mut rng = rng(); let mut index_writer: IndexWriter = index.writer_with_num_threads(3, 3 * MEMORY_BUDGET_NUM_BYTES_MIN)?; @@ -38,9 +38,9 @@ fn test_functional_store() -> crate::Result<()> { let mut doc_id = 0u64; for _iteration in 0..get_num_iterations() { - let num_docs: usize = rng.gen_range(0..4); + let num_docs: usize = rng.random_range(0..4); if !doc_set.is_empty() { - let doc_to_remove_id = rng.gen_range(0..doc_set.len()); + let doc_to_remove_id = rng.random_range(0..doc_set.len()); let removed_doc_id = doc_set.swap_remove(doc_to_remove_id); index_writer.delete_term(Term::from_field_u64(id_field, removed_doc_id)); } @@ -70,10 +70,10 @@ const LOREM: &str = "Doc Lorem ipsum dolor sit amet, consectetur adipiscing elit cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat \ non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."; fn get_text() -> String { - use rand::seq::SliceRandom; - let mut rng = thread_rng(); + use rand::seq::IndexedRandom; + let mut rng = rng(); let tokens: Vec<_> = LOREM.split(' ').collect(); - let random_val = rng.gen_range(0..20); + let random_val = rng.random_range(0..20); (0..random_val) .map(|_| tokens.choose(&mut rng).unwrap()) @@ -101,7 +101,7 @@ fn test_functional_indexing_unsorted() -> crate::Result<()> { let index = Index::create_from_tempdir(schema)?; let reader = index.reader()?; - let mut rng = thread_rng(); + let mut rng = rng(); let mut index_writer: IndexWriter = index.writer_with_num_threads(3, 3 * MEMORY_BUDGET_NUM_BYTES_MIN)?; @@ -110,7 +110,7 @@ fn test_functional_indexing_unsorted() -> crate::Result<()> { let mut uncommitted_docs: HashSet = HashSet::new(); for _ in 0..get_num_iterations() { - let random_val = rng.gen_range(0..20); + let random_val = rng.random_range(0..20); if random_val == 0 { index_writer.commit()?; committed_docs.extend(&uncommitted_docs); diff --git a/src/lib.rs b/src/lib.rs index f0b3120a5..2747fe8ef 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -377,7 +377,7 @@ pub mod tests { use common::{BinarySerializable, FixedSize}; use query_grammar::{UserInputAst, UserInputLeaf, UserInputLiteral}; - use rand::distributions::{Bernoulli, Uniform}; + use rand::distr::{Bernoulli, Uniform}; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use time::OffsetDateTime; @@ -428,7 +428,7 @@ pub mod tests { pub fn generate_nonunique_unsorted(max_value: u32, n_elems: usize) -> Vec { let seed: [u8; 32] = [1; 32]; StdRng::from_seed(seed) - .sample_iter(&Uniform::new(0u32, max_value)) + .sample_iter(&Uniform::new(0u32, max_value).unwrap()) .take(n_elems) .collect::>() } diff --git a/src/postings/compression/mod.rs b/src/postings/compression/mod.rs index 62eeca3d5..0ddf7e3df 100644 --- a/src/postings/compression/mod.rs +++ b/src/postings/compression/mod.rs @@ -397,7 +397,10 @@ mod bench { let mut seed: [u8; 32] = [0; 32]; seed[31] = seed_val; let mut rng = StdRng::from_seed(seed); - (0u32..).filter(|_| rng.gen_bool(ratio)).take(n).collect() + (0u32..) + .filter(|_| rng.random_bool(ratio)) + .take(n) + .collect() } pub fn generate_array(n: usize, ratio: f64) -> Vec { diff --git a/src/postings/mod.rs b/src/postings/mod.rs index b9c400859..d60ad597d 100644 --- a/src/postings/mod.rs +++ b/src/postings/mod.rs @@ -604,13 +604,13 @@ mod bench { let mut index_writer: IndexWriter = index.writer_for_tests().unwrap(); for _ in 0..posting_list_size { let mut doc = TantivyDocument::default(); - if rng.gen_bool(1f64 / 15f64) { + if rng.random_bool(1f64 / 15f64) { doc.add_text(text_field, "a"); } - if rng.gen_bool(1f64 / 10f64) { + if rng.random_bool(1f64 / 10f64) { doc.add_text(text_field, "b"); } - if rng.gen_bool(1f64 / 5f64) { + if rng.random_bool(1f64 / 5f64) { doc.add_text(text_field, "c"); } doc.add_text(text_field, "d"); diff --git a/src/query/phrase_query/regex_phrase_weight.rs b/src/query/phrase_query/regex_phrase_weight.rs index 4e850d2e2..9cefc555a 100644 --- a/src/query/phrase_query/regex_phrase_weight.rs +++ b/src/query/phrase_query/regex_phrase_weight.rs @@ -311,7 +311,7 @@ mod tests { #![proptest_config(ProptestConfig::with_cases(50))] #[test] fn test_phrase_regex_with_random_strings(mut random_strings in proptest::collection::vec("[c-z ]{0,10}", 1..100), num_occurrences in 1..150_usize) { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); // Insert "aaa ccc" the specified number of times into the list for _ in 0..num_occurrences { diff --git a/src/query/range_query/range_query.rs b/src/query/range_query/range_query.rs index 1893a06a5..a597c8dca 100644 --- a/src/query/range_query/range_query.rs +++ b/src/query/range_query/range_query.rs @@ -429,7 +429,7 @@ mod tests { docs.push(doc); } - docs.shuffle(&mut rand::thread_rng()); + docs.shuffle(&mut rand::rng()); let mut docs_it = docs.into_iter(); for doc in (&mut docs_it).take(50) { index_writer.add_document(doc)?; diff --git a/src/query/range_query/range_query_fastfield.rs b/src/query/range_query/range_query_fastfield.rs index e379e108e..68da73c92 100644 --- a/src/query/range_query/range_query_fastfield.rs +++ b/src/query/range_query/range_query_fastfield.rs @@ -491,7 +491,7 @@ mod tests { use common::DateTime; use proptest::prelude::*; use rand::rngs::StdRng; - use rand::seq::SliceRandom; + use rand::seq::IndexedRandom; use rand::SeedableRng; use time::format_description::well_known::Rfc3339; use time::OffsetDateTime; diff --git a/src/query/term_query/term_scorer.rs b/src/query/term_query/term_scorer.rs index 00fb8ca0b..6c7c5b17a 100644 --- a/src/query/term_query/term_scorer.rs +++ b/src/query/term_query/term_scorer.rs @@ -304,10 +304,10 @@ mod tests { let mut writer: IndexWriter = index.writer_with_num_threads(3, 3 * MEMORY_BUDGET_NUM_BYTES_MIN)?; use rand::Rng; - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); writer.set_merge_policy(Box::new(NoMergePolicy)); for _ in 0..3_000 { - let term_freq = rng.gen_range(1..10000); + let term_freq = rng.random_range(1..10000); let words: Vec<&str> = std::iter::repeat_n("bbbb", term_freq).collect(); let text = words.join(" "); writer.add_document(doc!(text_field=>text))?; diff --git a/src/termdict/fst_termdict/merger.rs b/src/termdict/fst_termdict/merger.rs index e8a064deb..43147a5ae 100644 --- a/src/termdict/fst_termdict/merger.rs +++ b/src/termdict/fst_termdict/merger.rs @@ -95,7 +95,7 @@ impl<'a> TermMerger<'a> { #[cfg(all(test, feature = "unstable"))] mod bench { use rand::distributions::Alphanumeric; - use rand::{thread_rng, Rng}; + use rand::{rng, Rng}; use test::{self, Bencher}; use super::TermMerger; @@ -117,9 +117,9 @@ mod bench { let buffer: Vec = { let mut terms = vec![]; for _i in 0..num_terms { - let rand_string: String = thread_rng() + let rand_string: String = rng() .sample_iter(&Alphanumeric) - .take(thread_rng().gen_range(30..42)) + .take(rng().random_range(30..42)) .map(char::from) .collect(); terms.push(rand_string); diff --git a/sstable/Cargo.toml b/sstable/Cargo.toml index 7b353cece..813692e26 100644 --- a/sstable/Cargo.toml +++ b/sstable/Cargo.toml @@ -25,7 +25,7 @@ zstd-compression = ["zstd"] proptest = "1" criterion = { version = "0.5", default-features = false } names = "0.14" -rand = "0.8" +rand = "0.9" [[bench]] name = "stream_bench" diff --git a/sstable/benches/stream_bench.rs b/sstable/benches/stream_bench.rs index cffe41e26..70dcdd8e3 100644 --- a/sstable/benches/stream_bench.rs +++ b/sstable/benches/stream_bench.rs @@ -10,9 +10,9 @@ use tantivy_sstable::{Dictionary, MonotonicU64SSTable}; const CHARSET: &[u8] = b"abcdefghij"; fn generate_key(rng: &mut impl Rng) -> String { - let len = rng.gen_range(3..12); + let len = rng.random_range(3..12); std::iter::from_fn(|| { - let idx = rng.gen_range(0..CHARSET.len()); + let idx = rng.random_range(0..CHARSET.len()); Some(CHARSET[idx] as char) }) .take(len) diff --git a/stacker/Cargo.toml b/stacker/Cargo.toml index 38b293dff..81388bdfd 100644 --- a/stacker/Cargo.toml +++ b/stacker/Cargo.toml @@ -23,12 +23,12 @@ name = "hashmap" path = "example/hashmap.rs" [dev-dependencies] -rand = "0.8.5" +rand = "0.9" zipf = "7.0.0" rustc-hash = "2.1.0" proptest = "1.2.0" binggan = { version = "0.14.0" } -rand_distr = "0.4.3" +rand_distr = "0.5" [features] compare_hash_only = ["ahash"] # Compare hash only, not the key in the Hashmap diff --git a/stacker/benches/bench.rs b/stacker/benches/bench.rs index ed5ea5eeb..03f801308 100644 --- a/stacker/benches/bench.rs +++ b/stacker/benches/bench.rs @@ -90,10 +90,10 @@ fn bench_vint() { } // benchmark zipfs distribution numbers { - use rand::distributions::Distribution; + use rand::distr::Distribution; use rand::rngs::StdRng; let mut rng = StdRng::from_seed([3u8; 32]); - let zipf = zipf::ZipfDistribution::new(10_000, 1.03).unwrap(); + let zipf = rand_distr::Zipf::new(10_000.0f64, 1.03).unwrap(); let numbers: Vec<[u8; 8]> = (0..num_numbers) .map(|_| zipf.sample(&mut rng).to_le_bytes()) .collect(); diff --git a/stacker/fuzz_test/Cargo.toml b/stacker/fuzz_test/Cargo.toml index 02478c95b..f71b36a37 100644 --- a/stacker/fuzz_test/Cargo.toml +++ b/stacker/fuzz_test/Cargo.toml @@ -7,8 +7,8 @@ edition = "2021" [dependencies] ahash = "0.8.7" -rand = "0.8.5" -rand_distr = "0.4.3" +rand = "0.9" +rand_distr = "0.5" tantivy-stacker = { version = "0.2.0", path = ".." } [workspace] diff --git a/stacker/fuzz_test/src/main.rs b/stacker/fuzz_test/src/main.rs index 2367ddc33..efe72d921 100644 --- a/stacker/fuzz_test/src/main.rs +++ b/stacker/fuzz_test/src/main.rs @@ -14,7 +14,7 @@ fn test_with_seed(seed: u64) { let mut hash_map = AHashMap::new(); let mut arena_hashmap = ArenaHashMap::default(); let mut rng = StdRng::seed_from_u64(seed); - let key_count = rng.gen_range(1_000..=1_000_000); + let key_count = rng.random_range(1_000..=1_000_000); let exp = Exp::new(0.05).unwrap(); for _ in 0..key_count {