From 2ce4da8b6667c4553166f8f2c2bd5a84e2d97cd0 Mon Sep 17 00:00:00 2001 From: Pascal Seitz Date: Thu, 4 Dec 2025 10:19:37 +0800 Subject: [PATCH] one collector per agg request instead per bucket In this refactoring a collector knows in which bucket of the parent their data is in. This allows to convert the previous approach of one collector per bucket to one collector per request. low card bucket optimization --- src/aggregation/agg_data.rs | 57 +- src/aggregation/agg_tests.rs | 7 +- src/aggregation/bucket/filter.rs | 108 ++-- src/aggregation/bucket/histogram/histogram.rs | 104 ++-- src/aggregation/bucket/range.rs | 306 +++++------ src/aggregation/bucket/term_agg.rs | 487 +++++++----------- src/aggregation/bucket/term_missing_agg.rs | 89 +++- src/aggregation/buf_collector.rs | 87 ---- src/aggregation/cached_sub_aggs.rs | 212 ++++++++ src/aggregation/collector.rs | 44 +- src/aggregation/intermediate_agg_result.rs | 11 +- src/aggregation/metric/average.rs | 6 +- src/aggregation/metric/cardinality.rs | 104 ++-- src/aggregation/metric/count.rs | 6 +- src/aggregation/metric/extended_stats.rs | 113 ++-- src/aggregation/metric/max.rs | 6 +- src/aggregation/metric/min.rs | 6 +- src/aggregation/metric/percentiles.rs | 105 ++-- src/aggregation/metric/stats.rs | 131 +++-- src/aggregation/metric/sum.rs | 6 +- src/aggregation/metric/top_hits.rs | 133 ++--- src/aggregation/mod.rs | 15 +- src/aggregation/segment_agg_result.rs | 59 ++- 23 files changed, 1170 insertions(+), 1032 deletions(-) delete mode 100644 src/aggregation/buf_collector.rs create mode 100644 src/aggregation/cached_sub_aggs.rs diff --git a/src/aggregation/agg_data.rs b/src/aggregation/agg_data.rs index deedb5781..4a496af25 100644 --- a/src/aggregation/agg_data.rs +++ b/src/aggregation/agg_data.rs @@ -10,9 +10,9 @@ use crate::aggregation::accessor_helpers::{ }; use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations}; use crate::aggregation::bucket::{ - FilterAggReqData, HistogramAggReqData, HistogramBounds, IncludeExcludeParam, - MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector, SegmentHistogramCollector, - SegmentRangeCollector, TermMissingAgg, TermsAggReqData, TermsAggregation, + build_segment_range_collector, FilterAggReqData, HistogramAggReqData, HistogramBounds, + IncludeExcludeParam, MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector, + SegmentHistogramCollector, TermMissingAgg, TermsAggReqData, TermsAggregation, TermsAggregationInternal, }; use crate::aggregation::metric::{ @@ -107,21 +107,14 @@ impl AggregationsSegmentCtx { .as_deref() .expect("range_req_data slot is empty (taken)") } - #[inline] - pub(crate) fn get_filter_req_data(&self, idx: usize) -> &FilterAggReqData { - self.per_request.filter_req_data[idx] - .as_deref() - .expect("filter_req_data slot is empty (taken)") - } // ---------- mutable getters ---------- #[inline] - pub(crate) fn get_term_req_data_mut(&mut self, idx: usize) -> &mut TermsAggReqData { - self.per_request.term_req_data[idx] - .as_deref_mut() - .expect("term_req_data slot is empty (taken)") + pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData { + &mut self.per_request.stats_metric_req_data[idx] } + #[inline] pub(crate) fn get_cardinality_req_data_mut( &mut self, @@ -129,10 +122,7 @@ impl AggregationsSegmentCtx { ) -> &mut CardinalityAggReqData { &mut self.per_request.cardinality_req_data[idx] } - #[inline] - pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData { - &mut self.per_request.stats_metric_req_data[idx] - } + #[inline] pub(crate) fn get_histogram_req_data_mut(&mut self, idx: usize) -> &mut HistogramAggReqData { self.per_request.histogram_req_data[idx] @@ -345,12 +335,19 @@ impl PerRequestAggSegCtx { pub(crate) fn build_segment_agg_collectors_root( req: &mut AggregationsSegmentCtx, ) -> crate::Result> { - build_segment_agg_collectors(req, &req.per_request.agg_tree.clone()) + build_segment_agg_collectors_generic(req, &req.per_request.agg_tree.clone()) } pub(crate) fn build_segment_agg_collectors( req: &mut AggregationsSegmentCtx, nodes: &[AggRefNode], +) -> crate::Result> { + build_segment_agg_collectors_generic(req, nodes) +} + +fn build_segment_agg_collectors_generic( + req: &mut AggregationsSegmentCtx, + nodes: &[AggRefNode], ) -> crate::Result> { let mut collectors = Vec::new(); for node in nodes.iter() { @@ -398,17 +395,10 @@ pub(crate) fn build_segment_agg_collector( | StatsType::Count | StatsType::Max | StatsType::Min - | StatsType::Stats => Ok(Box::new(SegmentStatsCollector::from_req( - node.idx_in_req_data, - ))), - StatsType::ExtendedStats(sigma) => { - Ok(Box::new(SegmentExtendedStatsCollector::from_req( - req_data.field_type, - sigma, - node.idx_in_req_data, - req_data.missing, - ))) - } + | StatsType::Stats => Ok(Box::new(SegmentStatsCollector::from_req(req_data))), + StatsType::ExtendedStats(sigma) => Ok(Box::new( + SegmentExtendedStatsCollector::from_req(req_data, sigma), + )), StatsType::Percentiles => Ok(Box::new( SegmentPercentilesCollector::from_req_and_validate(node.idx_in_req_data)?, )), @@ -428,9 +418,7 @@ pub(crate) fn build_segment_agg_collector( AggKind::DateHistogram => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate( req, node, )?)), - AggKind::Range => Ok(Box::new(SegmentRangeCollector::from_req_and_validate( - req, node, - )?)), + AggKind::Range => Ok(build_segment_range_collector(req, node)?), AggKind::Filter => Ok(Box::new(SegmentFilterCollector::from_req_and_validate( req, node, )?)), @@ -524,6 +512,7 @@ fn build_nodes( column_block_accessor: Default::default(), name: agg_name.to_string(), req: range_req.clone(), + is_top_level, }); let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?; Ok(vec![AggRefNode { @@ -543,7 +532,6 @@ fn build_nodes( field_type, column_block_accessor: Default::default(), name: agg_name.to_string(), - sub_aggregation_blueprint: None, req: histo_req.clone(), is_date_histogram: false, bounds: HistogramBounds { @@ -570,7 +558,6 @@ fn build_nodes( field_type, column_block_accessor: Default::default(), name: agg_name.to_string(), - sub_aggregation_blueprint: None, req: histo_req, is_date_histogram: true, bounds: HistogramBounds { @@ -929,8 +916,6 @@ fn build_terms_or_cardinality_nodes( column_block_accessor: Default::default(), name: agg_name.to_string(), req: TermsAggregationInternal::from_req(req), - // Will be filled later when building collectors - sub_aggregation_blueprint: None, sug_aggregations: sub_aggs.clone(), allowed_term_ids, is_top_level, diff --git a/src/aggregation/agg_tests.rs b/src/aggregation/agg_tests.rs index cbd529146..49a8afb37 100644 --- a/src/aggregation/agg_tests.rs +++ b/src/aggregation/agg_tests.rs @@ -2,11 +2,11 @@ use serde_json::Value; use crate::aggregation::agg_req::{Aggregation, Aggregations}; use crate::aggregation::agg_result::AggregationResults; -use crate::aggregation::buf_collector::DOC_BLOCK_SIZE; use crate::aggregation::collector::AggregationCollector; use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults; use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms}; use crate::aggregation::DistributedAggregationCollector; +use crate::docset::COLLECT_BLOCK_BUFFER_LEN; use crate::query::{AllQuery, TermQuery}; use crate::schema::{IndexRecordOption, Schema, FAST}; use crate::{Index, IndexWriter, Term}; @@ -467,8 +467,9 @@ fn test_aggregation_flushing( let reader = index.reader()?; - assert_eq!(DOC_BLOCK_SIZE, 64); - // In the tree we cache Documents of DOC_BLOCK_SIZE, before passing them down as one block. + assert_eq!(COLLECT_BLOCK_BUFFER_LEN, 64); + // In the tree we cache documents of COLLECT_BLOCK_BUFFER_LEN before passing them down as one + // block. // // Build a request so that on the first level we have one full cache, which is then flushed. // The same cache should have some residue docs at the end, which are flushed (Range 0-70) diff --git a/src/aggregation/bucket/filter.rs b/src/aggregation/bucket/filter.rs index 18f2a692a..3d9e3ae82 100644 --- a/src/aggregation/bucket/filter.rs +++ b/src/aggregation/bucket/filter.rs @@ -6,10 +6,14 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer}; use crate::aggregation::agg_data::{ build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx, }; +use crate::aggregation::cached_sub_aggs::CachedSubAggs; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, }; -use crate::aggregation::segment_agg_result::{CollectorClone, SegmentAggregationCollector}; +use crate::aggregation::segment_agg_result::{ + BucketIdProvider, CollectorClone, SegmentAggregationCollector, +}; +use crate::aggregation::BucketId; use crate::docset::DocSet; use crate::query::{AllQuery, EnableScoring, Query, QueryParser}; use crate::schema::Schema; @@ -410,9 +414,9 @@ impl FilterAggReqData { pub(crate) fn get_memory_consumption(&self) -> usize { // Estimate: name + segment reader reference + bitset + buffer capacity self.name.len() - + std::mem::size_of::() - + self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes) - + self.matching_docs_buffer.capacity() * std::mem::size_of::() + + std::mem::size_of::() + + self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes) + + self.matching_docs_buffer.capacity() * std::mem::size_of::() } } @@ -489,12 +493,19 @@ impl Debug for DocumentQueryEvaluator { } } +#[derive(Debug, Clone, PartialEq, Copy)] +struct DocCount { + doc_count: u64, + bucket_id: BucketId, +} + /// Segment collector for filter aggregation pub struct SegmentFilterCollector { - /// Document count in this bucket - doc_count: u64, + /// Document counts per bucket + buckets: Vec, /// Sub-aggregation collectors - sub_aggregations: Option>, + sub_aggregations: Option>, + bucket_id_provider: BucketIdProvider, /// Accessor index for this filter aggregation (to access FilterAggReqData) accessor_idx: usize, } @@ -511,11 +522,13 @@ impl SegmentFilterCollector { } else { None }; + let sub_agg_collector = sub_agg_collector.map(CachedSubAggs::new); Ok(SegmentFilterCollector { - doc_count: 0, + buckets: Vec::new(), sub_aggregations: sub_agg_collector, accessor_idx: node.idx_in_req_data, + bucket_id_provider: BucketIdProvider::default(), }) } } @@ -523,7 +536,7 @@ impl SegmentFilterCollector { impl Debug for SegmentFilterCollector { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SegmentFilterCollector") - .field("doc_count", &self.doc_count) + .field("buckets", &self.buckets) .field("has_sub_aggs", &self.sub_aggregations.is_some()) .field("accessor_idx", &self.accessor_idx) .finish() @@ -539,19 +552,32 @@ impl CollectorClone for SegmentFilterCollector { impl SegmentAggregationCollector for SegmentFilterCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { let mut sub_results = IntermediateAggregationResults::default(); + let bucket_opt = self.buckets.get(parent_bucket_id as usize); - if let Some(sub_aggs) = self.sub_aggregations { - sub_aggs.add_intermediate_aggregation_result(agg_data, &mut sub_results)?; + if let Some(sub_aggs) = &mut self.sub_aggregations { + sub_aggs + .get_sub_agg_collector() + .add_intermediate_aggregation_result( + agg_data, + &mut sub_results, + // Here we create a new bucket ID for sub-aggregations if the bucket doesn't + // exist, so that sub-aggregations can still produce results (e.g., zero doc + // count) + bucket_opt + .map(|bucket| bucket.bucket_id) + .unwrap_or(self.bucket_id_provider.next_bucket_id()), + )?; } // Create the filter bucket result let filter_bucket_result = IntermediateBucketResult::Filter { - doc_count: self.doc_count, + doc_count: bucket_opt.map(|b| b.doc_count).unwrap_or(0), sub_aggregations: sub_results, }; @@ -570,32 +596,17 @@ impl SegmentAggregationCollector for SegmentFilterCollector { Ok(()) } - fn collect(&mut self, doc: DocId, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - // Access the evaluator from FilterAggReqData - let req_data = agg_data.get_filter_req_data(self.accessor_idx); - - // O(1) BitSet lookup to check if document matches filter - if req_data.evaluator.matches_document(doc) { - self.doc_count += 1; - - // If we have sub-aggregations, collect on them for this filtered document - if let Some(sub_aggs) = &mut self.sub_aggregations { - sub_aggs.collect(doc, agg_data)?; - } - } - Ok(()) - } - - #[inline] - fn collect_block( + fn collect( &mut self, - docs: &[DocId], + parent_bucket_id: BucketId, + docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { if docs.is_empty() { return Ok(()); } + let mut bucket = self.buckets[parent_bucket_id as usize]; // Take the request data to avoid borrow checker issues with sub-aggregations let mut req = agg_data.take_filter_req_data(self.accessor_idx); @@ -604,18 +615,24 @@ impl SegmentAggregationCollector for SegmentFilterCollector { req.evaluator .filter_batch(docs, &mut req.matching_docs_buffer); - self.doc_count += req.matching_docs_buffer.len() as u64; + bucket.doc_count += req.matching_docs_buffer.len() as u64; // Batch process sub-aggregations if we have matches if !req.matching_docs_buffer.is_empty() { if let Some(sub_aggs) = &mut self.sub_aggregations { - // Use collect_block for better sub-aggregation performance - sub_aggs.collect_block(&req.matching_docs_buffer, agg_data)?; + for &doc_id in &req.matching_docs_buffer { + sub_aggs.push(bucket.bucket_id, doc_id); + } } } // Put the request data back agg_data.put_back_filter_req_data(self.accessor_idx, req); + if let Some(sub_aggs) = &mut self.sub_aggregations { + sub_aggs.check_flush_local(agg_data)?; + } + // put back bucket + self.buckets[parent_bucket_id as usize] = bucket; Ok(()) } @@ -626,6 +643,21 @@ impl SegmentAggregationCollector for SegmentFilterCollector { } Ok(()) } + + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + while self.buckets.len() <= max_bucket as usize { + let bucket_id = self.bucket_id_provider.next_bucket_id(); + self.buckets.push(DocCount { + doc_count: 0, + bucket_id, + }); + } + Ok(()) + } } /// Intermediate result for filter aggregation @@ -1519,9 +1551,9 @@ mod tests { let searcher = reader.searcher(); let agg = json!({ - "test": { - "filter": deserialized, - "aggs": { "count": { "value_count": { "field": "brand" } } } + "test": { + "filter": deserialized, + "aggs": { "count": { "value_count": { "field": "brand" } } } } }); diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs index 36c0fe57e..378cf7bf4 100644 --- a/src/aggregation/bucket/histogram/histogram.rs +++ b/src/aggregation/bucket/histogram/histogram.rs @@ -8,14 +8,14 @@ use tantivy_bitpacker::minmax; use crate::aggregation::agg_data::{ build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx, }; -use crate::aggregation::agg_limits::MemoryConsumption; use crate::aggregation::agg_req::Aggregations; use crate::aggregation::agg_result::BucketEntry; +use crate::aggregation::cached_sub_aggs::CachedSubAggs; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry, }; -use crate::aggregation::segment_agg_result::SegmentAggregationCollector; +use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector}; use crate::aggregation::*; use crate::TantivyError; @@ -30,9 +30,6 @@ pub struct HistogramAggReqData { pub column_block_accessor: ColumnBlockAccessor, /// The name of the aggregation. pub name: String, - /// The sub aggregation blueprint, used to create sub aggregations for each bucket. - /// Will be filled during initialization of the collector. - pub sub_aggregation_blueprint: Option>, /// The histogram aggregation request. pub req: HistogramAggregation, /// True if this is a date_histogram aggregation. @@ -257,18 +254,24 @@ impl HistogramBounds { pub(crate) struct SegmentHistogramBucketEntry { pub key: f64, pub doc_count: u64, + pub bucket_id: BucketId, } impl SegmentHistogramBucketEntry { pub(crate) fn into_intermediate_bucket_entry( self, - sub_aggregation: Option>, + sub_aggregation: &mut Option, agg_data: &AggregationsSegmentCtx, ) -> crate::Result { let mut sub_aggregation_res = IntermediateAggregationResults::default(); if let Some(sub_aggregation) = sub_aggregation { sub_aggregation - .add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?; + .get_sub_agg_collector() + .add_intermediate_aggregation_result( + agg_data, + &mut sub_aggregation_res, + self.bucket_id, + )?; } Ok(IntermediateHistogramBucketEntry { key: self.key, @@ -278,27 +281,37 @@ impl SegmentHistogramBucketEntry { } } +#[derive(Clone, Debug, Default)] +struct HistogramBuckets { + pub buckets: FxHashMap, +} + /// The collector puts values from the fast field into the correct buckets and does a conversion to /// the correct datatype. #[derive(Clone, Debug)] pub struct SegmentHistogramCollector { /// The buckets containing the aggregation data. - buckets: FxHashMap, - sub_aggregations: FxHashMap>, + /// One Histogram bucket per parent bucket id. + buckets: Vec, + sub_agg: Option, accessor_idx: usize, + bucket_id_provider: BucketIdProvider, } impl SegmentAggregationCollector for SegmentHistogramCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { let name = agg_data .get_histogram_req_data(self.accessor_idx) .name .clone(); - let bucket = self.into_intermediate_bucket_result(agg_data)?; + self.prepare_max_bucket(parent_bucket_id, agg_data)?; + let histogram = std::mem::take(&mut self.buckets[parent_bucket_id as usize]); + let bucket = self.add_intermediate_bucket_result(agg_data, histogram)?; results.push(name, IntermediateAggregationResult::Bucket(bucket))?; Ok(()) @@ -307,20 +320,13 @@ impl SegmentAggregationCollector for SegmentHistogramCollector { #[inline] fn collect( &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.collect_block(&[doc], agg_data) - } - - #[inline] - fn collect_block( - &mut self, + parent_bucket_id: BucketId, docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { let mut req = agg_data.take_histogram_req_data(self.accessor_idx); let mem_pre = self.get_memory_consumption(); + let buckets = &mut self.buckets[parent_bucket_id as usize].buckets; let bounds = req.bounds; let interval = req.req.interval; @@ -335,16 +341,17 @@ impl SegmentAggregationCollector for SegmentHistogramCollector { let val = f64_from_fastfield_u64(val, &req.field_type); let bucket_pos = get_bucket_pos(val); if bounds.contains(val) { - let bucket = self.buckets.entry(bucket_pos).or_insert_with(|| { + let bucket = buckets.entry(bucket_pos).or_insert_with(|| { let key = get_bucket_key_from_pos(bucket_pos as f64, interval, offset); - SegmentHistogramBucketEntry { key, doc_count: 0 } + SegmentHistogramBucketEntry { + key, + doc_count: 0, + bucket_id: self.bucket_id_provider.next_bucket_id(), + } }); bucket.doc_count += 1; - if let Some(sub_aggregation_blueprint) = req.sub_aggregation_blueprint.as_ref() { - self.sub_aggregations - .entry(bucket_pos) - .or_insert_with(|| sub_aggregation_blueprint.clone()) - .collect(doc, agg_data)?; + if let Some(sub_agg) = &mut self.sub_agg { + sub_agg.push(bucket.bucket_id, doc); } } } @@ -358,14 +365,30 @@ impl SegmentAggregationCollector for SegmentHistogramCollector { .add_memory_consumed(mem_delta as u64)?; } + if let Some(sub_agg) = &mut self.sub_agg { + sub_agg.check_flush_local(agg_data)?; + } + Ok(()) } fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - for sub_aggregation in self.sub_aggregations.values_mut() { + if let Some(sub_aggregation) = &mut self.sub_agg { sub_aggregation.flush(agg_data)?; } + Ok(()) + } + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + while self.buckets.len() <= max_bucket as usize { + self.buckets.push(HistogramBuckets { + buckets: FxHashMap::default(), + }); + } Ok(()) } } @@ -373,22 +396,19 @@ impl SegmentAggregationCollector for SegmentHistogramCollector { impl SegmentHistogramCollector { fn get_memory_consumption(&self) -> usize { let self_mem = std::mem::size_of::(); - let sub_aggs_mem = self.sub_aggregations.memory_consumption(); - let buckets_mem = self.buckets.memory_consumption(); - self_mem + sub_aggs_mem + buckets_mem + let buckets_mem = self.buckets.len() * std::mem::size_of::(); + self_mem + buckets_mem } /// Converts the collector result into a intermediate bucket result. - pub fn into_intermediate_bucket_result( - self, + fn add_intermediate_bucket_result( + &mut self, agg_data: &AggregationsSegmentCtx, + histogram: HistogramBuckets, ) -> crate::Result { let mut buckets = Vec::with_capacity(self.buckets.len()); - for (bucket_pos, bucket) in self.buckets { - let bucket_res = bucket.into_intermediate_bucket_entry( - self.sub_aggregations.get(&bucket_pos).cloned(), - agg_data, - ); + for bucket in histogram.buckets.into_values() { + let bucket_res = bucket.into_intermediate_bucket_entry(&mut self.sub_agg, agg_data); buckets.push(bucket_res?); } @@ -408,7 +428,7 @@ impl SegmentHistogramCollector { agg_data: &mut AggregationsSegmentCtx, node: &AggRefNode, ) -> crate::Result { - let blueprint = if !node.children.is_empty() { + let sub_agg = if !node.children.is_empty() { Some(build_segment_agg_collectors(agg_data, &node.children)?) } else { None @@ -423,13 +443,13 @@ impl SegmentHistogramCollector { max: f64::MAX, }); req_data.offset = req_data.req.offset.unwrap_or(0.0); - - req_data.sub_aggregation_blueprint = blueprint; + let sub_agg = sub_agg.map(CachedSubAggs::new); Ok(Self { buckets: Default::default(), - sub_aggregations: Default::default(), + sub_agg, accessor_idx: node.idx_in_req_data, + bucket_id_provider: BucketIdProvider::default(), }) } } diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs index c26872e9b..abf746f5d 100644 --- a/src/aggregation/bucket/range.rs +++ b/src/aggregation/bucket/range.rs @@ -8,11 +8,13 @@ use serde::{Deserialize, Serialize}; use crate::aggregation::agg_data::{ build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx, }; +use crate::aggregation::agg_limits::AggregationLimitsGuard; +use crate::aggregation::cached_sub_aggs::CachedSubAggs; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, IntermediateRangeBucketEntry, IntermediateRangeBucketResult, }; -use crate::aggregation::segment_agg_result::SegmentAggregationCollector; +use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector}; use crate::aggregation::*; use crate::TantivyError; @@ -29,6 +31,8 @@ pub struct RangeAggReqData { pub req: RangeAggregation, /// The name of the aggregation. pub name: String, + /// Whether this is a top-level aggregation. + pub is_top_level: bool, } impl RangeAggReqData { @@ -151,19 +155,48 @@ pub(crate) struct SegmentRangeAndBucketEntry { /// The collector puts values from the fast field into the correct buckets and does a conversion to /// the correct datatype. -#[derive(Clone, Debug)] -pub struct SegmentRangeCollector { +#[derive(Clone)] +pub struct SegmentRangeCollector { /// The buckets containing the aggregation data. - buckets: Vec, + /// One for each ParentBucketId + parent_buckets: Vec>, column_type: ColumnType, pub(crate) accessor_idx: usize, + sub_agg: Option>, + /// Here things get a bit weird. We need to assign unique bucket ids across all + /// parent buckets. So we keep track of the next available bucket id here. + /// This allows a kind of flattening of the bucket ids across all parent buckets. + /// E.g. in nested aggregations: + /// Term Agg -> Range aggregation -> Stats aggregation + /// E.g. the Term Agg creates 3 buckets ["INFO", "ERROR", "WARN"], each of these has a Range + /// aggregation with 4 buckets. The Range aggregation will create buckets with ids: + /// - INFO: 0,1,2,3 + /// - ERROR: 4,5,6,7 + /// - WARN: 8,9,10,11 + /// + /// This allows the Stats aggregation to have unique bucket ids to refer to. + bucket_id_provider: BucketIdProvider, + limits: AggregationLimitsGuard, } +impl Debug for SegmentRangeCollector { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SegmentRangeCollector") + .field("parent_buckets_len", &self.parent_buckets.len()) + .field("column_type", &self.column_type) + .field("accessor_idx", &self.accessor_idx) + .field("has_sub_agg", &self.sub_agg.is_some()) + .finish() + } +} + +/// TODO: Bad naming, there's also SegmentRangeAndBucketEntry #[derive(Clone)] pub(crate) struct SegmentRangeBucketEntry { pub key: Key, pub doc_count: u64, - pub sub_aggregation: Option>, + // pub sub_aggregation: Option>, + pub bucket_id: BucketId, /// The from range of the bucket. Equals `f64::MIN` when `None`. pub from: Option, /// The to range of the bucket. Equals `f64::MAX` when `None`. Open interval, `to` is not @@ -184,48 +217,50 @@ impl Debug for SegmentRangeBucketEntry { impl SegmentRangeBucketEntry { pub(crate) fn into_intermediate_bucket_entry( self, - agg_data: &AggregationsSegmentCtx, ) -> crate::Result { - let mut sub_aggregation_res = IntermediateAggregationResults::default(); - if let Some(sub_aggregation) = self.sub_aggregation { - sub_aggregation - .add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)? - } else { - Default::default() - }; + let sub_aggregation = IntermediateAggregationResults::default(); Ok(IntermediateRangeBucketEntry { key: self.key.into(), doc_count: self.doc_count, - sub_aggregation: sub_aggregation_res, + sub_aggregation_res: sub_aggregation, from: self.from, to: self.to, }) } } -impl SegmentAggregationCollector for SegmentRangeCollector { +impl SegmentAggregationCollector for SegmentRangeCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { + self.prepare_max_bucket(parent_bucket_id, agg_data)?; let field_type = self.column_type; let name = agg_data .get_range_req_data(self.accessor_idx) .name .to_string(); - let buckets: FxHashMap = self - .buckets + let buckets = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]); + + let buckets: FxHashMap = buckets .into_iter() - .map(move |range_bucket| { - Ok(( - range_to_string(&range_bucket.range, &field_type)?, - range_bucket - .bucket - .into_intermediate_bucket_entry(agg_data)?, - )) + .map(|range_bucket| { + let bucket_id = range_bucket.bucket.bucket_id; + let mut agg = range_bucket.bucket.into_intermediate_bucket_entry()?; + if let Some(sub_aggregation) = &mut self.sub_agg { + sub_aggregation + .get_sub_agg_collector() + .add_intermediate_aggregation_result( + agg_data, + &mut agg.sub_aggregation_res, + bucket_id, + )?; + } + Ok((range_to_string(&range_bucket.range, &field_type)?, agg)) }) .collect::>()?; @@ -242,73 +277,112 @@ impl SegmentAggregationCollector for SegmentRangeCollector { #[inline] fn collect( &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.collect_block(&[doc], agg_data) - } - - #[inline] - fn collect_block( - &mut self, + parent_bucket_id: BucketId, docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - // Take request data to avoid borrow conflicts during sub-aggregation let mut req = agg_data.take_range_req_data(self.accessor_idx); req.column_block_accessor.fetch_block(docs, &req.accessor); + let buckets = &mut self.parent_buckets[parent_bucket_id as usize]; + for (doc, val) in req .column_block_accessor .iter_docid_vals(docs, &req.accessor) { - let bucket_pos = self.get_bucket_pos(val); - let bucket = &mut self.buckets[bucket_pos]; + let bucket_pos = get_bucket_pos(val, buckets); + let bucket = &mut buckets[bucket_pos]; bucket.bucket.doc_count += 1; - if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() { - sub_agg.collect(doc, agg_data)?; + if let Some(sub_agg) = self.sub_agg.as_mut() { + sub_agg.push(bucket.bucket.bucket_id, doc); } } agg_data.put_back_range_req_data(self.accessor_idx, req); + if let Some(sub_agg) = self.sub_agg.as_mut() { + sub_agg.check_flush_local(agg_data)?; + } Ok(()) } fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - for bucket in self.buckets.iter_mut() { - if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() { - sub_agg.flush(agg_data)?; - } + if let Some(sub_agg) = self.sub_agg.as_mut() { + sub_agg.flush(agg_data)?; } Ok(()) } + + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + while self.parent_buckets.len() <= max_bucket as usize { + let new_buckets = self.create_new_buckets(agg_data)?; + self.parent_buckets.push(new_buckets); + } + + Ok(()) + } +} +/// Build a concrete `SegmentRangeCollector` with either a Vec- or HashMap-backed +/// bucket storage, depending on the column type and aggregation level. +pub(crate) fn build_segment_range_collector( + agg_data: &mut AggregationsSegmentCtx, + node: &AggRefNode, +) -> crate::Result> { + let accessor_idx = node.idx_in_req_data; + let req_data = agg_data.get_range_req_data(node.idx_in_req_data); + let field_type = req_data.field_type; + + // TODO: A better metric instead of is_top_level would be the number of buckets expected. + // E.g. If range agg is not top level, but the parent is a bucket agg with less than 10 buckets, + // we can are still in low cardinality territory. + let is_low_card = req_data.is_top_level && req_data.req.ranges.len() <= 64; + + let sub_agg = if !node.children.is_empty() { + Some(build_segment_agg_collectors(agg_data, &node.children)?) + } else { + None + }; + + if is_low_card { + Ok(Box::new(SegmentRangeCollector { + sub_agg: sub_agg.map(CachedSubAggs::::new), + column_type: field_type, + accessor_idx, + parent_buckets: Vec::new(), + bucket_id_provider: BucketIdProvider::default(), + limits: agg_data.context.limits.clone(), + })) + } else { + Ok(Box::new(SegmentRangeCollector { + sub_agg: sub_agg.map(CachedSubAggs::::new), + column_type: field_type, + accessor_idx, + parent_buckets: Vec::new(), + bucket_id_provider: BucketIdProvider::default(), + limits: agg_data.context.limits.clone(), + })) + } } -impl SegmentRangeCollector { - pub(crate) fn from_req_and_validate( - req_data: &mut AggregationsSegmentCtx, - node: &AggRefNode, - ) -> crate::Result { - let accessor_idx = node.idx_in_req_data; - let (field_type, ranges) = { - let req_view = req_data.get_range_req_data(node.idx_in_req_data); - (req_view.field_type, req_view.req.ranges.clone()) - }; - +impl SegmentRangeCollector { + pub(crate) fn create_new_buckets( + &mut self, + agg_data: &AggregationsSegmentCtx, + ) -> crate::Result> { + let field_type = self.column_type; + let req_data = agg_data.get_range_req_data(self.accessor_idx); // The range input on the request is f64. // We need to convert to u64 ranges, because we read the values as u64. // The mapping from the conversion is monotonic so ordering is preserved. - let sub_agg_prototype = if !node.children.is_empty() { - Some(build_segment_agg_collectors(req_data, &node.children)?) - } else { - None - }; - - let buckets: Vec<_> = extend_validate_ranges(&ranges, &field_type)? + let buckets: Vec<_> = extend_validate_ranges(&req_data.req.ranges, &field_type)? .iter() .map(|range| { + let bucket_id = self.bucket_id_provider.next_bucket_id(); let key = range .key .clone() @@ -324,13 +398,13 @@ impl SegmentRangeCollector { } else { Some(f64_from_fastfield_u64(range.range.start, &field_type)) }; - let sub_aggregation = sub_agg_prototype.clone(); + // let sub_aggregation = sub_agg_prototype.clone(); Ok(SegmentRangeAndBucketEntry { range: range.range.clone(), bucket: SegmentRangeBucketEntry { doc_count: 0, - sub_aggregation, + bucket_id, key, from, to, @@ -339,27 +413,20 @@ impl SegmentRangeCollector { }) .collect::>()?; - req_data.context.limits.add_memory_consumed( + self.limits.add_memory_consumed( buckets.len() as u64 * std::mem::size_of::() as u64, )?; - - Ok(SegmentRangeCollector { - buckets, - column_type: field_type, - accessor_idx, - }) - } - - #[inline] - fn get_bucket_pos(&self, val: u64) -> usize { - let pos = self - .buckets - .binary_search_by_key(&val, |probe| probe.range.start) - .unwrap_or_else(|pos| pos - 1); - debug_assert!(self.buckets[pos].range.contains(&val)); - pos + Ok(buckets) } } +#[inline] +fn get_bucket_pos(val: u64, buckets: &[SegmentRangeAndBucketEntry]) -> usize { + let pos = buckets + .binary_search_by_key(&val, |probe| probe.range.start) + .unwrap_or_else(|pos| pos - 1); + debug_assert!(buckets[pos].range.contains(&val)); + pos +} /// Converts the user provided f64 range value to fast field value space. /// @@ -517,19 +584,22 @@ mod tests { range: range.range.clone(), bucket: SegmentRangeBucketEntry { doc_count: 0, - sub_aggregation: None, key, from, to, + bucket_id: 0, }, } }) .collect(); SegmentRangeCollector { - buckets, + parent_buckets: vec![buckets], column_type: field_type, accessor_idx: 0, + sub_agg: None, + bucket_id_provider: Default::default(), + limits: AggregationLimitsGuard::default(), } } @@ -776,7 +846,7 @@ mod tests { let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()]; let collector = get_collector_from_ranges(buckets, ColumnType::F64); - let buckets = collector.buckets; + let buckets = collector.parent_buckets[0].clone(); assert_eq!(buckets[0].range.start, u64::MIN); assert_eq!(buckets[0].range.end, 10f64.to_u64()); assert_eq!(buckets[1].range.start, 10f64.to_u64()); @@ -799,7 +869,7 @@ mod tests { ]; let collector = get_collector_from_ranges(buckets, ColumnType::F64); - let buckets = collector.buckets; + let buckets = collector.parent_buckets[0].clone(); assert_eq!(buckets[0].range.start, u64::MIN); assert_eq!(buckets[0].range.end, 10f64.to_u64()); assert_eq!(buckets[1].range.start, 10f64.to_u64()); @@ -814,7 +884,7 @@ mod tests { let buckets = vec![(-10f64..-1f64).into()]; let collector = get_collector_from_ranges(buckets, ColumnType::F64); - let buckets = collector.buckets; + let buckets = collector.parent_buckets[0].clone(); assert_eq!(&buckets[0].bucket.key.to_string(), "*--10"); assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "-1-*"); } @@ -823,7 +893,7 @@ mod tests { let buckets = vec![(0f64..10f64).into()]; let collector = get_collector_from_ranges(buckets, ColumnType::F64); - let buckets = collector.buckets; + let buckets = collector.parent_buckets[0].clone(); assert_eq!(&buckets[0].bucket.key.to_string(), "*-0"); assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "10-*"); } @@ -832,7 +902,7 @@ mod tests { fn range_binary_search_test_u64() { let check_ranges = |ranges: Vec| { let collector = get_collector_from_ranges(ranges, ColumnType::U64); - let search = |val: u64| collector.get_bucket_pos(val); + let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]); assert_eq!(search(u64::MIN), 0); assert_eq!(search(9), 0); @@ -878,7 +948,7 @@ mod tests { let ranges = vec![(10.0..100.0).into()]; let collector = get_collector_from_ranges(ranges, ColumnType::F64); - let search = |val: u64| collector.get_bucket_pos(val); + let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]); assert_eq!(search(u64::MIN), 0); assert_eq!(search(9f64.to_u64()), 0); @@ -890,63 +960,3 @@ mod tests { // the max value } } - -#[cfg(all(test, feature = "unstable"))] -mod bench { - - use itertools::Itertools; - use rand::seq::SliceRandom; - use rand::thread_rng; - - use super::*; - use crate::aggregation::bucket::range::tests::get_collector_from_ranges; - - const TOTAL_DOCS: u64 = 1_000_000u64; - const NUM_DOCS: u64 = 50_000u64; - - fn get_collector_with_buckets(num_buckets: u64, num_docs: u64) -> SegmentRangeCollector { - let bucket_size = num_docs / num_buckets; - let mut buckets: Vec = vec![]; - for i in 0..num_buckets { - let bucket_start = (i * bucket_size) as f64; - buckets.push((bucket_start..bucket_start + bucket_size as f64).into()) - } - - get_collector_from_ranges(buckets, ColumnType::U64) - } - - fn get_rand_docs(total_docs: u64, num_docs_returned: u64) -> Vec { - let mut rng = thread_rng(); - - let all_docs = (0..total_docs - 1).collect_vec(); - let mut vals = all_docs - .as_slice() - .choose_multiple(&mut rng, num_docs_returned as usize) - .cloned() - .collect_vec(); - vals.sort(); - vals - } - - fn bench_range_binary_search(b: &mut test::Bencher, num_buckets: u64) { - let collector = get_collector_with_buckets(num_buckets, TOTAL_DOCS); - let vals = get_rand_docs(TOTAL_DOCS, NUM_DOCS); - b.iter(|| { - let mut bucket_pos = 0; - for val in &vals { - bucket_pos = collector.get_bucket_pos(*val); - } - bucket_pos - }) - } - - #[bench] - fn bench_range_100_buckets(b: &mut test::Bencher) { - bench_range_binary_search(b, 100) - } - - #[bench] - fn bench_range_10_buckets(b: &mut test::Bencher) { - bench_range_binary_search(b, 10) - } -} diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index d87cd0078..994cae4b8 100644 --- a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg.rs @@ -17,13 +17,13 @@ use crate::aggregation::agg_data::{ }; use crate::aggregation::agg_limits::MemoryConsumption; use crate::aggregation::agg_req::Aggregations; -use crate::aggregation::buf_collector::BufAggregationCollector; +use crate::aggregation::cached_sub_aggs::CachedSubAggs; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult, }; -use crate::aggregation::segment_agg_result::SegmentAggregationCollector; -use crate::aggregation::{format_date, Key}; +use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector}; +use crate::aggregation::{format_date, BucketId, Key}; use crate::error::DataCorruption; use crate::TantivyError; @@ -40,8 +40,6 @@ pub struct TermsAggReqData { pub missing_value_for_accessor: Option, /// The column block accessor to access the fast field values. pub column_block_accessor: ColumnBlockAccessor, - /// Note: sub_aggregation_blueprint is filled later when building collectors - pub sub_aggregation_blueprint: Option>, /// Used to build the correct nested result when we have an empty result. pub sug_aggregations: Aggregations, /// The name of the aggregation. @@ -257,9 +255,9 @@ pub struct TermsAggregation { /// Internally, `missing` requires some specialized handling in some scenarios. /// /// Simple Case: - /// In the simplest case, we can just put the missing value in the termmap use that. In case of - /// text we put a special u64::MAX and replace it at the end with the actual missing value, - /// when loading the text. + /// In the simplest case, we can just put the missing value in the termmap and use that. In + /// case of text we put a special u64::MAX and replace it at the end with the actual + /// missing value, when loading the text. /// Special Case 1: /// If we have multiple columns on one field, we need to have a union on the indices on both /// columns, to find docids without a value. That requires a special missing aggregation. @@ -334,86 +332,6 @@ impl TermsAggregationInternal { } } -impl<'a> From<&'a dyn SegmentAggregationCollector> for BufAggregationCollector { - #[inline(always)] - fn from(sub_agg_blueprint_opt: &'a dyn SegmentAggregationCollector) -> Self { - let sub_agg = sub_agg_blueprint_opt.clone_box(); - BufAggregationCollector::new(sub_agg) - } -} - -#[derive(Debug, Clone)] -struct BoxedAggregation(Box); - -impl<'a> From<&'a dyn SegmentAggregationCollector> for BoxedAggregation { - #[inline(always)] - fn from(sub_agg_blueprint: &'a dyn SegmentAggregationCollector) -> Self { - BoxedAggregation(sub_agg_blueprint.clone_box()) - } -} - -impl SegmentAggregationCollector for BoxedAggregation { - #[inline(always)] - fn add_intermediate_aggregation_result( - self: Box, - agg_data: &AggregationsSegmentCtx, - results: &mut IntermediateAggregationResults, - ) -> crate::Result<()> { - self.0 - .add_intermediate_aggregation_result(agg_data, results) - } - - #[inline(always)] - fn collect( - &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.0.collect(doc, agg_data) - } - - #[inline(always)] - fn collect_block( - &mut self, - docs: &[crate::DocId], - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.0.collect_block(docs, agg_data) - } -} - -#[derive(Debug, Clone, Copy)] -struct NoSubAgg; - -impl SegmentAggregationCollector for NoSubAgg { - #[inline(always)] - fn add_intermediate_aggregation_result( - self: Box, - _agg_data: &AggregationsSegmentCtx, - _results: &mut IntermediateAggregationResults, - ) -> crate::Result<()> { - Ok(()) - } - - #[inline(always)] - fn collect( - &mut self, - _doc: crate::DocId, - _agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - Ok(()) - } - - #[inline(always)] - fn collect_block( - &mut self, - _docs: &[crate::DocId], - _agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - Ok(()) - } -} - /// Build a concrete `SegmentTermCollector` with either a Vec- or HashMap-backed /// bucket storage, depending on the column type and aggregation level. pub(crate) fn build_segment_term_collector( @@ -450,16 +368,6 @@ pub(crate) fn build_segment_term_collector( // Build sub-aggregation blueprint if there are children. let has_sub_aggregations = !node.children.is_empty(); - let blueprint = if has_sub_aggregations { - let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?; - Some(sub_aggregation) - } else { - None - }; - { - let terms_req_data_mut = req_data.get_term_req_data_mut(accessor_idx); - terms_req_data_mut.sub_aggregation_blueprint = blueprint; - } // Decide whether to use a Vec-backed or HashMap-backed bucket storage. let terms_req_data = req_data.get_term_req_data(accessor_idx); @@ -478,75 +386,56 @@ pub(crate) fn build_segment_term_collector( let max_term: usize = col_max_value.max(terms_req_data.missing_value_for_accessor.unwrap_or(0u64)) as usize; - // - use a Vec instead of a hashmap for our aggregation. - // - buffer aggregation of our child aggregations (in any) - #[allow(clippy::collapsible_else_if)] - if can_use_vec && max_term < MAX_NUM_TERMS_FOR_VEC { - if has_sub_aggregations { - let sub_agg_blueprint = &req_data - .get_term_req_data_mut(accessor_idx) - .sub_aggregation_blueprint - .as_ref() - .ok_or_else(|| { - // Handle the error case here - // For example, return an error message or a default value - TantivyError::InternalError("Sub-aggregation blueprint not found".to_string()) - })?; - let term_buckets = VecTermBuckets::new(max_term + 1, || { - let collector_clone = sub_agg_blueprint.clone_box(); - BufAggregationCollector::new(collector_clone) - }); - let collector = SegmentTermCollector { - term_buckets, - accessor_idx, - }; - Ok(Box::new(collector)) - } else { - let term_buckets = VecTermBuckets::new(max_term + 1, || NoSubAgg); - let collector = SegmentTermCollector { - term_buckets, - accessor_idx, - }; - Ok(Box::new(collector)) - } + let sub_agg_collector = if has_sub_aggregations { + Some(build_segment_agg_collectors(req_data, &node.children)?) } else { - if has_sub_aggregations { - let term_buckets: HashMapTermBuckets = HashMapTermBuckets::default(); - let collector: SegmentTermCollector> = - SegmentTermCollector { - term_buckets, - accessor_idx, - }; - Ok(Box::new(collector)) - } else { - let term_buckets: HashMapTermBuckets = HashMapTermBuckets::default(); - let collector: SegmentTermCollector> = - SegmentTermCollector { - term_buckets, - accessor_idx, - }; - Ok(Box::new(collector)) - } + None + }; + + let mut bucket_id_provider = BucketIdProvider::default(); + // - use a Vec instead of a hashmap for our aggregation. + if can_use_vec && max_term < MAX_NUM_TERMS_FOR_VEC { + let term_buckets = VecTermBuckets::new(max_term + 1, &mut bucket_id_provider); + let sub_agg = sub_agg_collector.map(CachedSubAggs::::new); + let collector: SegmentTermCollector<_, true> = SegmentTermCollector { + buckets: vec![term_buckets], + accessor_idx, + sub_agg, + bucket_id_provider, + }; + Ok(Box::new(collector)) + } else { + let term_buckets: HashMapTermBuckets = HashMapTermBuckets::default(); + // Build sub-aggregation blueprint (flat pairs) + let sub_agg = sub_agg_collector.map(CachedSubAggs::::new); + let collector: SegmentTermCollector = SegmentTermCollector { + buckets: vec![term_buckets], + accessor_idx, + sub_agg, + bucket_id_provider, + }; + Ok(Box::new(collector)) } } #[derive(Debug, Clone)] -struct Bucket { +struct Bucket { pub count: u32, - pub sub_agg: SubAgg, + pub bucket_id: BucketId, } -impl Bucket { +impl Bucket { #[inline(always)] - fn new(sub_agg: SubAgg) -> Self { - Self { count: 0, sub_agg } + fn new(bucket_id: BucketId) -> Self { + Self { + count: 0, + bucket_id, + } } } /// Abstraction over the storage used for term buckets (counts only). -trait TermAggregationMap: Clone + Debug + 'static { - type SubAggregation: SegmentAggregationCollector + Debug + Clone + 'static; - +trait TermAggregationMap: Clone + Debug + Default + 'static { /// Estimate the memory consumption of this struct in bytes. fn get_memory_consumption(&self) -> usize; @@ -554,23 +443,20 @@ trait TermAggregationMap: Clone + Debug + 'static { fn term_entry( &mut self, term_id: u64, - blue_print: &dyn SegmentAggregationCollector, - ) -> &mut Bucket; - - /// If the tree of aggregations contains buffered aggregations, flush them. - fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()>; + bucket_id_provider: &mut BucketIdProvider, + ) -> &mut Bucket; /// Returns the term aggregation as a vector of (term_id, bucket) pairs, /// in any order. - fn into_vec(self) -> Vec<(u64, Bucket)>; + fn into_vec(self) -> Vec<(u64, Bucket)>; } #[derive(Clone, Debug)] -struct HashMapTermBuckets { - bucket_map: FxHashMap>, +struct HashMapTermBuckets { + bucket_map: FxHashMap, } -impl Default for HashMapTermBuckets { +impl Default for HashMapTermBuckets { #[inline(always)] fn default() -> Self { Self { @@ -579,16 +465,7 @@ impl Default for HashMapTermBuckets { } } -impl< - SubAgg: Debug - + Clone - + SegmentAggregationCollector - + for<'a> From<&'a dyn SegmentAggregationCollector> - + 'static, - > TermAggregationMap for HashMapTermBuckets -{ - type SubAggregation = SubAgg; - +impl TermAggregationMap for HashMapTermBuckets { #[inline] fn get_memory_consumption(&self) -> usize { self.bucket_map.memory_consumption() @@ -598,55 +475,42 @@ impl< fn term_entry( &mut self, term_id: u64, - sub_agg_blueprint: &dyn SegmentAggregationCollector, - ) -> &mut Bucket { + bucket_id_provider: &mut BucketIdProvider, + ) -> &mut Bucket { self.bucket_map .entry(term_id) - .or_insert_with(|| Bucket::new(SubAgg::from(sub_agg_blueprint))) + .or_insert_with(|| Bucket::new(bucket_id_provider.next_bucket_id())) } - #[inline(always)] - fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - for bucket in self.bucket_map.values_mut() { - bucket.sub_agg.flush(agg_data)?; - } - Ok(()) - } - - fn into_vec(self) -> Vec<(u64, Bucket)> { + fn into_vec(self) -> Vec<(u64, Bucket)> { self.bucket_map.into_iter().collect() } } /// An optimized term map implementation for a compact set of term ordinals. -#[derive(Clone, Debug)] -struct VecTermBuckets { - buckets: Vec>, +#[derive(Clone, Debug, Default)] +struct VecTermBuckets { + buckets: Vec, } -impl VecTermBuckets { - fn new(num_terms: usize, item_factory_fn: impl Fn() -> SubAgg) -> Self { +impl VecTermBuckets { + fn new(num_terms: usize, bucket_id_provider: &mut BucketIdProvider) -> Self { VecTermBuckets { - buckets: std::iter::repeat_with(item_factory_fn) - .map(Bucket::new) + buckets: std::iter::repeat_with(|| Bucket::new(bucket_id_provider.next_bucket_id())) .take(num_terms) .collect(), } } } -impl TermAggregationMap - for VecTermBuckets -{ - type SubAggregation = SubAgg; - +impl TermAggregationMap for VecTermBuckets { /// Estimate the memory consumption of this struct in bytes. fn get_memory_consumption(&self) -> usize { // We do not include `std::mem::size_of::()` // It is already measure by the parent aggregation. // // The root aggregation mem size is not measure but we do not care. - self.buckets.capacity() * std::mem::size_of::>() + self.buckets.capacity() * std::mem::size_of::() } /// Add an occurrence of the given term id. @@ -654,8 +518,8 @@ impl TermAggregat fn term_entry( &mut self, term_id: u64, - _sub_agg_blueprint: &dyn SegmentAggregationCollector, - ) -> &mut Bucket { + _bucket_id_provider: &mut BucketIdProvider, + ) -> &mut Bucket { let term_id_usize = term_id as usize; debug_assert!( term_id_usize < self.buckets.len(), @@ -666,17 +530,7 @@ impl TermAggregat unsafe { self.buckets.get_unchecked_mut(term_id_usize) } } - #[inline(always)] - fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - for bucket in &mut self.buckets { - if bucket.count > 0 { - bucket.sub_agg.flush(agg_data)?; - } - } - Ok(()) - } - - fn into_vec(self) -> Vec<(u64, Bucket)> { + fn into_vec(self) -> Vec<(u64, Bucket)> { self.buckets .into_iter() .enumerate() @@ -686,20 +540,15 @@ impl TermAggregat } } -impl<'a> From<&'a dyn SegmentAggregationCollector> for NoSubAgg { - #[inline(always)] - fn from(_: &'a dyn SegmentAggregationCollector) -> Self { - Self - } -} - /// The collector puts values from the fast field into the correct buckets and does a conversion to /// the correct datatype. #[derive(Clone, Debug)] -struct SegmentTermCollector { +struct SegmentTermCollector { /// The buckets containing the aggregation data. - term_buckets: TermMap, + buckets: Vec, + sub_agg: Option>, accessor_idx: usize, + bucket_id_provider: BucketIdProvider, } pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) { @@ -707,18 +556,22 @@ pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) { (agg_name, agg_property) } -impl SegmentAggregationCollector for SegmentTermCollector -where - TermMap: TermAggregationMap, - TermMap::SubAggregation: for<'a> From<&'a dyn SegmentAggregationCollector>, +impl SegmentAggregationCollector + for SegmentTermCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + bucket: BucketId, ) -> crate::Result<()> { - let name = agg_data.get_term_req_data(self.accessor_idx).name.clone(); - let bucket = self.into_intermediate_bucket_result(agg_data)?; + self.prepare_max_bucket(bucket, agg_data)?; + let bucket = std::mem::take(&mut self.buckets[bucket as usize]); + let term_req = agg_data.get_term_req_data(self.accessor_idx); + let name = term_req.name.clone(); + + let bucket = + Self::into_intermediate_bucket_result(term_req, &mut self.sub_agg, bucket, agg_data)?; results.push(name, IntermediateAggregationResult::Bucket(bucket))?; Ok(()) } @@ -726,22 +579,14 @@ where #[inline] fn collect( &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.collect_block(&[doc], agg_data) - } - - #[inline] - fn collect_block( - &mut self, + parent_bucket_id: BucketId, docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let mut req_data = agg_data.take_term_req_data(self.accessor_idx); - let mem_pre = self.get_memory_consumption(); + let mut req_data = agg_data.take_term_req_data(self.accessor_idx); + if let Some(missing) = req_data.missing_value_for_accessor { req_data.column_block_accessor.fetch_block_with_missing( docs, @@ -754,37 +599,36 @@ where .fetch_block(docs, &req_data.accessor); } - if std::any::TypeId::of::() == std::any::TypeId::of::() { - for term_id in req_data.column_block_accessor.iter_vals() { - if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() { - if !allowed_bs.contains(term_id as u32) { - continue; - } - } - let bucket = self.term_buckets.term_entry(term_id, &NoSubAgg); - bucket.count += 1; + // Build iterators first to avoid overlapping borrows with self's fields. + if let Some(sub_agg) = &mut self.sub_agg { + let term_buckets = &mut self.buckets[parent_bucket_id as usize]; + let it = req_data + .column_block_accessor + .iter_docid_vals(docs, &req_data.accessor); + if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() { + let it = it.filter(move |&(_doc, term_id)| allowed_bs.contains(term_id as u32)); + Self::collect_terms_with_docs( + it, + term_buckets, + &mut self.bucket_id_provider, + sub_agg, + ); + } else { + Self::collect_terms_with_docs( + it, + term_buckets, + &mut self.bucket_id_provider, + sub_agg, + ); } } else { - let Some(sub_aggregation_blueprint) = req_data.sub_aggregation_blueprint.as_deref() - else { - return Err(TantivyError::InternalError( - "Could not find sub-aggregation blueprint".to_string(), - )); - }; - for (doc, term_id) in req_data - .column_block_accessor - .iter_docid_vals(docs, &req_data.accessor) - { - if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() { - if !allowed_bs.contains(term_id as u32) { - continue; - } - } - let bucket = self - .term_buckets - .term_entry(term_id, sub_aggregation_blueprint); - bucket.count += 1; - bucket.sub_agg.collect(doc, agg_data)?; + let term_buckets = &mut self.buckets[parent_bucket_id as usize]; + let it = req_data.column_block_accessor.iter_vals(); + if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() { + let it = it.filter(move |&term_id| allowed_bs.contains(term_id as u32)); + Self::collect_terms(it, term_buckets, &mut self.bucket_id_provider); + } else { + Self::collect_terms(it, term_buckets, &mut self.bucket_id_provider); } } @@ -796,13 +640,30 @@ where .add_memory_consumed(mem_delta as u64)?; } agg_data.put_back_term_req_data(self.accessor_idx, req_data); + if let Some(sub_agg) = &mut self.sub_agg { + sub_agg.check_flush_local(agg_data)?; + } Ok(()) } #[inline(always)] fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - self.term_buckets.flush(agg_data)?; + if let Some(sub_agg) = &mut self.sub_agg { + sub_agg.flush(agg_data)?; + } + Ok(()) + } + + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + while self.buckets.len() <= max_bucket as usize { + let term_buckets: TermMap = TermMap::default(); + self.buckets.push(term_buckets); + } Ok(()) } } @@ -831,20 +692,24 @@ fn extract_missing_value( Some((key, bucket)) } -impl SegmentTermCollector +impl SegmentTermCollector where TermMap: TermAggregationMap { fn get_memory_consumption(&self) -> usize { - self.term_buckets.get_memory_consumption() + self.buckets + .iter() + .map(|b| b.get_memory_consumption()) + .sum() } #[inline] pub(crate) fn into_intermediate_bucket_result( - self, + term_req: &TermsAggReqData, + sub_agg: &mut Option>, + term_buckets: TermMap, agg_data: &AggregationsSegmentCtx, ) -> crate::Result { - let term_req = agg_data.get_term_req_data(self.accessor_idx); - let mut entries: Vec<(u64, Bucket)> = self.term_buckets.into_vec(); + let mut entries: Vec<(u64, Bucket)> = term_buckets.into_vec(); let order_by_sub_aggregation = matches!(term_req.req.order.target, OrderTarget::SubAggregation(_)); @@ -884,23 +749,28 @@ where TermMap: TermAggregationMap dict.reserve(entries.len()); let into_intermediate_bucket_entry = - |bucket: Bucket| -> crate::Result { - let intermediate_entry = if term_req.sub_aggregation_blueprint.as_ref().is_some() { + |bucket: Bucket, + sub_agg: &mut Option>| + -> crate::Result { + if let Some(sub_agg) = sub_agg { let mut sub_aggregation_res = IntermediateAggregationResults::default(); - // TODO remove box new - Box::new(bucket.sub_agg) - .add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?; - IntermediateTermBucketEntry { + sub_agg + .get_sub_agg_collector() + .add_intermediate_aggregation_result( + agg_data, + &mut sub_aggregation_res, + bucket.bucket_id, + )?; + Ok(IntermediateTermBucketEntry { doc_count: bucket.count, sub_aggregation: sub_aggregation_res, - } + }) } else { - IntermediateTermBucketEntry { + Ok(IntermediateTermBucketEntry { doc_count: bucket.count, sub_aggregation: Default::default(), - } - }; - Ok(intermediate_entry) + }) + } }; if term_req.column_type == ColumnType::Str { @@ -913,21 +783,20 @@ where TermMap: TermAggregationMap if let Some((intermediate_key, bucket)) = extract_missing_value(&mut entries, term_req) { - let intermediate_entry = into_intermediate_bucket_entry(bucket)?; + let intermediate_entry = into_intermediate_bucket_entry(bucket, sub_agg)?; dict.insert(intermediate_key, intermediate_entry); } // Sort by term ord entries.sort_unstable_by_key(|bucket| bucket.0); - let (term_ids, buckets): (Vec, Vec>) = - entries.into_iter().unzip(); + let (term_ids, buckets): (Vec, Vec) = entries.into_iter().unzip(); let mut buckets_it = buckets.into_iter(); term_dict.sorted_ords_to_term_cb(term_ids.into_iter(), |term| { let bucket = buckets_it.next().unwrap(); let intermediate_entry = - into_intermediate_bucket_entry(bucket).map_err(io::Error::other)?; + into_intermediate_bucket_entry(bucket, sub_agg).map_err(io::Error::other)?; dict.insert( IntermediateKey::Str( String::from_utf8(term.to_vec()).expect("could not convert to String"), @@ -969,14 +838,14 @@ where TermMap: TermAggregationMap } } else if term_req.column_type == ColumnType::DateTime { for (val, doc_count) in entries { - let intermediate_entry = into_intermediate_bucket_entry(doc_count)?; + let intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?; let val = i64::from_u64(val); let date = format_date(val)?; dict.insert(IntermediateKey::Str(date), intermediate_entry); } } else if term_req.column_type == ColumnType::Bool { for (val, doc_count) in entries { - let intermediate_entry = into_intermediate_bucket_entry(doc_count)?; + let intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?; let val = bool::from_u64(val); dict.insert(IntermediateKey::Bool(val), intermediate_entry); } @@ -996,14 +865,14 @@ where TermMap: TermAggregationMap })?; for (val, doc_count) in entries { - let intermediate_entry = into_intermediate_bucket_entry(doc_count)?; + let intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?; let val: u128 = compact_space_accessor.compact_to_u128(val as u32); let val = Ipv6Addr::from_u128(val); dict.insert(IntermediateKey::IpAddr(val), intermediate_entry); } } else { for (val, doc_count) in entries { - let intermediate_entry = into_intermediate_bucket_entry(doc_count)?; + let intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?; if term_req.column_type == ColumnType::U64 { dict.insert(IntermediateKey::U64(val), intermediate_entry); } else if term_req.column_type == ColumnType::I64 { @@ -1037,6 +906,38 @@ where TermMap: TermAggregationMap } } +impl SegmentTermCollector { + #[inline] + fn collect_terms_with_docs( + it: I, + term_buckets: &mut TermMap, + bucket_id_provider: &mut BucketIdProvider, + sub_agg: &mut CachedSubAggs, + ) where + I: Iterator, + { + for (doc, term_id) in it { + let bucket = term_buckets.term_entry(term_id, bucket_id_provider); + bucket.count += 1; + sub_agg.push(bucket.bucket_id, doc); + } + } + + #[inline] + fn collect_terms( + it: I, + term_buckets: &mut TermMap, + bucket_id_provider: &mut BucketIdProvider, + ) where + I: Iterator, + { + for term_id in it { + let bucket = term_buckets.term_entry(term_id, bucket_id_provider); + bucket.count += 1; + } + } +} + pub(crate) trait GetDocCount { fn doc_count(&self) -> u64; } @@ -1047,7 +948,7 @@ impl GetDocCount for (String, IntermediateTermBucketEntry) { } } -impl GetDocCount for (u64, Bucket) { +impl GetDocCount for (u64, Bucket) { fn doc_count(&self) -> u64 { self.1.count as u64 } diff --git a/src/aggregation/bucket/term_missing_agg.rs b/src/aggregation/bucket/term_missing_agg.rs index 66f39927a..c8ab51832 100644 --- a/src/aggregation/bucket/term_missing_agg.rs +++ b/src/aggregation/bucket/term_missing_agg.rs @@ -5,11 +5,13 @@ use crate::aggregation::agg_data::{ build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx, }; use crate::aggregation::bucket::term_agg::TermsAggregation; +use crate::aggregation::cached_sub_aggs::CachedSubAggs; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult, }; -use crate::aggregation::segment_agg_result::SegmentAggregationCollector; +use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector}; +use crate::aggregation::BucketId; /// Special aggregation to handle missing values for term aggregations. /// This missing aggregation will check multiple columns for existence. @@ -35,41 +37,55 @@ impl MissingTermAggReqData { } } +#[derive(Default, Debug, Clone)] +struct MissingCount { + missing_count: u32, + bucket_id: BucketId, +} + /// The specialized missing term aggregation. #[derive(Default, Debug, Clone)] pub struct TermMissingAgg { - missing_count: u32, accessor_idx: usize, - sub_agg: Option>, + sub_agg: Option, + /// Idx = bucket id, Value = missing count for that bucket + missing_count_per_bucket: Vec, + bucket_id_provider: BucketIdProvider, } impl TermMissingAgg { pub(crate) fn new( - req_data: &mut AggregationsSegmentCtx, + agg_data: &mut AggregationsSegmentCtx, node: &AggRefNode, ) -> crate::Result { let has_sub_aggregations = !node.children.is_empty(); let accessor_idx = node.idx_in_req_data; let sub_agg = if has_sub_aggregations { - let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?; + let sub_aggregation = build_segment_agg_collectors(agg_data, &node.children)?; Some(sub_aggregation) } else { None }; + let sub_agg = sub_agg.map(CachedSubAggs::new); + let bucket_id_provider = BucketIdProvider::default(); + Ok(Self { accessor_idx, sub_agg, - ..Default::default() + missing_count_per_bucket: Vec::new(), + bucket_id_provider, }) } } impl SegmentAggregationCollector for TermMissingAgg { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { + self.prepare_max_bucket(parent_bucket_id, agg_data)?; let req_data = agg_data.get_missing_term_req_data(self.accessor_idx); let term_agg = &req_data.req; let missing = term_agg @@ -80,13 +96,16 @@ impl SegmentAggregationCollector for TermMissingAgg { let mut entries: FxHashMap = Default::default(); + let missing_count = &self.missing_count_per_bucket[parent_bucket_id as usize]; let mut missing_entry = IntermediateTermBucketEntry { - doc_count: self.missing_count, + doc_count: missing_count.missing_count, sub_aggregation: Default::default(), }; - if let Some(sub_agg) = self.sub_agg { + if let Some(sub_agg) = &mut self.sub_agg { let mut res = IntermediateAggregationResults::default(); - sub_agg.add_intermediate_aggregation_result(agg_data, &mut res)?; + sub_agg + .get_sub_agg_collector() + .add_intermediate_aggregation_result(agg_data, &mut res, missing_count.bucket_id)?; missing_entry.sub_aggregation = res; } entries.insert(missing.into(), missing_entry); @@ -109,30 +128,52 @@ impl SegmentAggregationCollector for TermMissingAgg { fn collect( &mut self, - doc: crate::DocId, + parent_bucket_id: BucketId, + docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { + let bucket = &mut self.missing_count_per_bucket[parent_bucket_id as usize]; let req_data = agg_data.get_missing_term_req_data(self.accessor_idx); - let has_value = req_data - .accessors - .iter() - .any(|(acc, _)| acc.index.has_value(doc)); - if !has_value { - self.missing_count += 1; - if let Some(sub_agg) = self.sub_agg.as_mut() { - sub_agg.collect(doc, agg_data)?; + + for doc in docs { + let doc = *doc; + let has_value = req_data + .accessors + .iter() + .any(|(acc, _)| acc.index.has_value(doc)); + if !has_value { + bucket.missing_count += 1; + + if let Some(sub_agg) = self.sub_agg.as_mut() { + sub_agg.push(bucket.bucket_id, doc); + } } } + + if let Some(sub_agg) = self.sub_agg.as_mut() { + sub_agg.check_flush_local(agg_data)?; + } Ok(()) } - fn collect_block( + fn prepare_max_bucket( &mut self, - docs: &[crate::DocId], - agg_data: &mut AggregationsSegmentCtx, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, ) -> crate::Result<()> { - for doc in docs { - self.collect(*doc, agg_data)?; + while self.missing_count_per_bucket.len() <= max_bucket as usize { + let bucket_id = self.bucket_id_provider.next_bucket_id(); + self.missing_count_per_bucket.push(MissingCount { + missing_count: 0, + bucket_id, + }); + } + Ok(()) + } + + fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { + if let Some(sub_agg) = self.sub_agg.as_mut() { + sub_agg.flush(agg_data)?; } Ok(()) } diff --git a/src/aggregation/buf_collector.rs b/src/aggregation/buf_collector.rs deleted file mode 100644 index 17bc1ed35..000000000 --- a/src/aggregation/buf_collector.rs +++ /dev/null @@ -1,87 +0,0 @@ -use super::intermediate_agg_result::IntermediateAggregationResults; -use super::segment_agg_result::SegmentAggregationCollector; -use crate::aggregation::agg_data::AggregationsSegmentCtx; -use crate::DocId; - -#[cfg(test)] -pub(crate) const DOC_BLOCK_SIZE: usize = 64; - -#[cfg(not(test))] -pub(crate) const DOC_BLOCK_SIZE: usize = 256; - -pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE]; - -/// BufAggregationCollector buffers documents before calling collect_block(). -#[derive(Clone)] -pub(crate) struct BufAggregationCollector { - pub(crate) collector: Box, - staged_docs: DocBlock, - num_staged_docs: usize, -} - -impl std::fmt::Debug for BufAggregationCollector { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("SegmentAggregationResultsCollector") - .field("staged_docs", &&self.staged_docs[..self.num_staged_docs]) - .field("num_staged_docs", &self.num_staged_docs) - .finish() - } -} - -impl BufAggregationCollector { - pub fn new(collector: Box) -> Self { - Self { - collector, - num_staged_docs: 0, - staged_docs: [0; DOC_BLOCK_SIZE], - } - } -} - -impl SegmentAggregationCollector for BufAggregationCollector { - #[inline] - fn add_intermediate_aggregation_result( - self: Box, - agg_data: &AggregationsSegmentCtx, - results: &mut IntermediateAggregationResults, - ) -> crate::Result<()> { - Box::new(self.collector).add_intermediate_aggregation_result(agg_data, results) - } - - #[inline] - fn collect( - &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.staged_docs[self.num_staged_docs] = doc; - self.num_staged_docs += 1; - if self.num_staged_docs == self.staged_docs.len() { - self.collector - .collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?; - self.num_staged_docs = 0; - } - Ok(()) - } - - #[inline] - fn collect_block( - &mut self, - docs: &[crate::DocId], - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.collector.collect_block(docs, agg_data)?; - Ok(()) - } - - #[inline] - fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - self.collector - .collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?; - self.num_staged_docs = 0; - - self.collector.flush(agg_data)?; - - Ok(()) - } -} diff --git a/src/aggregation/cached_sub_aggs.rs b/src/aggregation/cached_sub_aggs.rs new file mode 100644 index 000000000..bf8106ea6 --- /dev/null +++ b/src/aggregation/cached_sub_aggs.rs @@ -0,0 +1,212 @@ +use super::segment_agg_result::SegmentAggregationCollector; +use crate::aggregation::agg_data::AggregationsSegmentCtx; +use crate::aggregation::BucketId; +use crate::DocId; + +#[derive(Clone, Debug)] +/// A cache for sub-aggregations, storing doc ids per bucket id. +/// Depending on the cardinality of the parent aggregation, we use different +/// storage strategies. +/// +/// ## Low Cardinality +/// Cardinality here refers to the number of unique flattened buckets that can be created +/// by the parent aggregation. +/// Flattened buckets are the result of combining all buckets per collector +/// into a single list of buckets, where each bucket is identified by its BucketId. +/// +/// ## Usage +/// Since this is caching for sub-aggregations, it is only used by bucket +/// aggregations. +/// +/// TODO: consider using a more advanced data structure for high cardinality +/// aggregations. +/// What this datastructure does in general is to group docs by bucket id. +pub(crate) struct CachedSubAggs { + /// Only used when LOWCARD is true. + /// Cache doc ids per bucket for sub-aggregations. + /// + /// The outer Vec is indexed by BucketId. + per_bucket_docs: Vec>, + /// Only used when LOWCARD is false. + /// For higher cardinalities we use a partitioned approach to store + /// + /// partitioned Vec<(BucketId, DocId)> pairs to improve grouping locality. + partitions: [PartitionEntry; NUM_PARTITIONS], + pub(crate) sub_agg_collector: Box, + num_docs: usize, +} + +const FLUSH_THRESHOLD: usize = 1024; +const NUM_PARTITIONS: usize = 16; + +impl CachedSubAggs { + pub fn get_sub_agg_collector(&mut self) -> &mut Box { + &mut self.sub_agg_collector + } + + pub fn new(sub_agg: Box) -> Self { + Self { + per_bucket_docs: Vec::new(), + num_docs: 0, + sub_agg_collector: sub_agg, + partitions: core::array::from_fn(|_| PartitionEntry::new()), + } + } + + #[inline] + pub fn clear(&mut self) { + for v in &mut self.per_bucket_docs { + v.clear(); + } + for partition in &mut self.partitions { + partition.clear(); + } + self.num_docs = 0; + } + + #[inline] + pub fn push(&mut self, bucket_id: BucketId, doc_id: DocId) { + if LOWCARD { + let idx = bucket_id as usize; + if self.per_bucket_docs.len() <= idx { + self.per_bucket_docs.resize_with(idx + 1, Vec::new); + } + self.per_bucket_docs[idx].push(doc_id); + } else { + let idx = bucket_id % NUM_PARTITIONS as u32; + let slot = &mut self.partitions[idx as usize]; + slot.bucket_ids.push(bucket_id); + slot.docs.push(doc_id); + } + self.num_docs += 1; + } + + #[inline] + pub fn extend_with_bucket_zero(&mut self, docs: &[DocId]) { + debug_assert!( + LOWCARD, + "extend_with_bucket_zero only valid for single bucket" + ); + if self.per_bucket_docs.is_empty() { + self.per_bucket_docs.resize_with(1, Vec::new); + } + self.per_bucket_docs[0].extend_from_slice(docs); + self.num_docs += docs.len(); + } + + /// Check if we need to flush based on the number of documents cached. + /// If so, flushes the cache to the provided aggregation collector. + pub fn check_flush_local( + &mut self, + agg_data: &mut AggregationsSegmentCtx, + ) -> crate::Result<()> { + if self.num_docs >= FLUSH_THRESHOLD { + self.flush_local(agg_data)?; + } + Ok(()) + } + + /// Note: this does _not_ flush the sub aggregations + fn flush_local(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { + if LOWCARD { + // Pre-aggregated: call collect per bucket. + let max_bucket = (self.per_bucket_docs.len() as BucketId).saturating_sub(1); + self.sub_agg_collector + .prepare_max_bucket(max_bucket, agg_data)?; + for (bucket_id, docs) in self + .per_bucket_docs + .iter() + .enumerate() + .filter(|(_, docs)| !docs.is_empty()) + { + self.sub_agg_collector + .collect(bucket_id as BucketId, docs, agg_data)?; + } + } else { + let mut max_bucket = 0u32; + for partition in &self.partitions { + if let Some(&local_max) = partition.bucket_ids.iter().max() { + max_bucket = max_bucket.max(local_max); + } + } + + self.sub_agg_collector + .prepare_max_bucket(max_bucket, agg_data)?; + + for slot in &self.partitions { + for (bucket_id, docs) in slot.groups() { + // TODO: this may be a lot of calls to a boxed trait object. + // Consider passing the struct directly to avoid dynamic dispatch. + self.sub_agg_collector.collect(bucket_id, docs, agg_data)?; + } + } + } + self.clear(); + Ok(()) + } + + /// Note: this _does_ flush the sub aggregations + pub fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { + if self.num_docs != 0 { + self.flush_local(agg_data)?; + } + self.sub_agg_collector.flush(agg_data)?; + Ok(()) + } +} + +#[derive(Debug, Clone)] +struct PartitionEntry { + bucket_ids: Vec, + docs: Vec, +} + +impl PartitionEntry { + #[inline] + fn new() -> Self { + Self { + bucket_ids: Vec::with_capacity(FLUSH_THRESHOLD / NUM_PARTITIONS), + docs: Vec::with_capacity(FLUSH_THRESHOLD / NUM_PARTITIONS), + } + } + + #[inline] + fn clear(&mut self) { + self.bucket_ids.clear(); + self.docs.clear(); + } + + #[inline] + fn groups(&self) -> PartitionGroups<'_> { + PartitionGroups { + bucket_ids: &self.bucket_ids, + docs: &self.docs, + idx: 0, + } + } +} + +struct PartitionGroups<'a> { + bucket_ids: &'a [BucketId], + docs: &'a [DocId], + idx: usize, +} + +impl<'a> Iterator for PartitionGroups<'a> { + type Item = (BucketId, &'a [DocId]); + + fn next(&mut self) -> Option { + if self.idx >= self.bucket_ids.len() { + return None; + } + + let bucket_id = self.bucket_ids[self.idx]; + let start = self.idx; + let mut end = start + 1; + while end < self.bucket_ids.len() && self.bucket_ids[end] == bucket_id { + end += 1; + } + self.idx = end; + Some((bucket_id, &self.docs[start..end])) + } +} diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs index 4c4c2c7f1..e6993eae6 100644 --- a/src/aggregation/collector.rs +++ b/src/aggregation/collector.rs @@ -1,9 +1,9 @@ use super::agg_req::Aggregations; use super::agg_result::AggregationResults; -use super::buf_collector::BufAggregationCollector; +use super::cached_sub_aggs::CachedSubAggs; use super::intermediate_agg_result::IntermediateAggregationResults; -use super::segment_agg_result::SegmentAggregationCollector; use super::AggContextParams; +// group buffering strategy is chosen explicitly by callers; no need to hash-group on the fly. use crate::aggregation::agg_data::{ build_aggregations_data_from_req, build_segment_agg_collectors_root, AggregationsSegmentCtx, }; @@ -136,7 +136,7 @@ fn merge_fruits( /// `AggregationSegmentCollector` does the aggregation collection on a segment. pub struct AggregationSegmentCollector { aggs_with_accessor: AggregationsSegmentCtx, - agg_collector: BufAggregationCollector, + agg_collector: CachedSubAggs, error: Option, } @@ -151,8 +151,7 @@ impl AggregationSegmentCollector { ) -> crate::Result { let mut agg_data = build_aggregations_data_from_req(agg, reader, segment_ordinal, context.clone())?; - let result = - BufAggregationCollector::new(build_segment_agg_collectors_root(&mut agg_data)?); + let result = CachedSubAggs::new(build_segment_agg_collectors_root(&mut agg_data)?); Ok(AggregationSegmentCollector { aggs_with_accessor: agg_data, @@ -170,26 +169,30 @@ impl SegmentCollector for AggregationSegmentCollector { if self.error.is_some() { return; } - if let Err(err) = self + self.agg_collector.push(0, doc); + match self .agg_collector - .collect(doc, &mut self.aggs_with_accessor) + .check_flush_local(&mut self.aggs_with_accessor) { - self.error = Some(err); + Ok(_) => {} + Err(e) => { + self.error = Some(e); + } } } - - /// The query pushes the documents to the collector via this method. - /// - /// Only valid for Collectors that ignore docs fn collect_block(&mut self, docs: &[DocId]) { if self.error.is_some() { return; } - if let Err(err) = self + self.agg_collector.extend_with_bucket_zero(docs); + match self .agg_collector - .collect_block(docs, &mut self.aggs_with_accessor) + .check_flush_local(&mut self.aggs_with_accessor) { - self.error = Some(err); + Ok(_) => {} + Err(e) => { + self.error = Some(e); + } } } @@ -200,10 +203,13 @@ impl SegmentCollector for AggregationSegmentCollector { self.agg_collector.flush(&mut self.aggs_with_accessor)?; let mut sub_aggregation_res = IntermediateAggregationResults::default(); - Box::new(self.agg_collector).add_intermediate_aggregation_result( - &self.aggs_with_accessor, - &mut sub_aggregation_res, - )?; + self.agg_collector + .get_sub_agg_collector() + .add_intermediate_aggregation_result( + &self.aggs_with_accessor, + &mut sub_aggregation_res, + 0, + )?; Ok(sub_aggregation_res) } diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index 104131461..b20e8a042 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -792,7 +792,7 @@ pub struct IntermediateRangeBucketEntry { /// The number of documents in the bucket. pub doc_count: u64, /// The sub_aggregation in this bucket. - pub sub_aggregation: IntermediateAggregationResults, + pub sub_aggregation_res: IntermediateAggregationResults, /// The from range of the bucket. Equals `f64::MIN` when `None`. pub from: Option, /// The to range of the bucket. Equals `f64::MAX` when `None`. @@ -811,7 +811,7 @@ impl IntermediateRangeBucketEntry { key: self.key.into(), doc_count: self.doc_count, sub_aggregation: self - .sub_aggregation + .sub_aggregation_res .into_final_result_internal(req, limits)?, to: self.to, from: self.from, @@ -857,7 +857,8 @@ impl MergeFruits for IntermediateTermBucketEntry { impl MergeFruits for IntermediateRangeBucketEntry { fn merge_fruits(&mut self, other: IntermediateRangeBucketEntry) -> crate::Result<()> { self.doc_count += other.doc_count; - self.sub_aggregation.merge_fruits(other.sub_aggregation)?; + self.sub_aggregation_res + .merge_fruits(other.sub_aggregation_res)?; Ok(()) } } @@ -887,7 +888,7 @@ mod tests { IntermediateRangeBucketEntry { key: IntermediateKey::Str(key.to_string()), doc_count: *doc_count, - sub_aggregation: Default::default(), + sub_aggregation_res: Default::default(), from: None, to: None, }, @@ -920,7 +921,7 @@ mod tests { doc_count: *doc_count, from: None, to: None, - sub_aggregation: get_sub_test_tree(&[( + sub_aggregation_res: get_sub_test_tree(&[( sub_aggregation_key.to_string(), *sub_aggregation_count, )]), diff --git a/src/aggregation/metric/average.rs b/src/aggregation/metric/average.rs index e707f2b00..57f694984 100644 --- a/src/aggregation/metric/average.rs +++ b/src/aggregation/metric/average.rs @@ -52,10 +52,8 @@ pub struct IntermediateAverage { impl IntermediateAverage { /// Creates a new [`IntermediateAverage`] instance from a [`SegmentStatsCollector`]. - pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { - Self { - stats: collector.stats, - } + pub(crate) fn from_stats(stats: IntermediateStats) -> Self { + Self { stats } } /// Merges the other intermediate result into self. pub fn merge_fruits(&mut self, other: IntermediateAverage) { diff --git a/src/aggregation/metric/cardinality.rs b/src/aggregation/metric/cardinality.rs index 8f3bdd3e5..77387fc5e 100644 --- a/src/aggregation/metric/cardinality.rs +++ b/src/aggregation/metric/cardinality.rs @@ -137,43 +137,27 @@ impl CardinalityAggregationReq { #[derive(Clone, Debug, PartialEq)] pub(crate) struct SegmentCardinalityCollector { - cardinality: CardinalityCollector, - entries: FxHashSet, + buckets: Vec, + column_type: ColumnType, accessor_idx: usize, } -impl SegmentCardinalityCollector { - pub fn from_req(column_type: ColumnType, accessor_idx: usize) -> Self { +#[derive(Clone, Debug, PartialEq, Default)] +pub(crate) struct SegmentCardinalityCollectorBucket { + cardinality: CardinalityCollector, + entries: FxHashSet, +} +impl SegmentCardinalityCollectorBucket { + pub fn new(column_type: ColumnType) -> Self { Self { cardinality: CardinalityCollector::new(column_type as u8), - entries: Default::default(), - accessor_idx, + entries: FxHashSet::default(), } } - - fn fetch_block_with_field( - &mut self, - docs: &[crate::DocId], - agg_data: &mut CardinalityAggReqData, - ) { - if let Some(missing) = agg_data.missing_value_for_accessor { - agg_data.column_block_accessor.fetch_block_with_missing( - docs, - &agg_data.accessor, - missing, - ); - } else { - agg_data - .column_block_accessor - .fetch_block(docs, &agg_data.accessor); - } - } - fn into_intermediate_metric_result( mut self, - agg_data: &AggregationsSegmentCtx, + req_data: &CardinalityAggReqData, ) -> crate::Result { - let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx); if req_data.column_type == ColumnType::Str { let fallback_dict = Dictionary::empty(); let dict = req_data @@ -194,6 +178,7 @@ impl SegmentCardinalityCollector { term_ids.push(term_ord as u32); } } + term_ids.sort_unstable(); dict.sorted_ords_to_term_cb(term_ids.iter().map(|term| *term as u64), |term| { self.cardinality.sketch.insert_any(&term); @@ -227,16 +212,48 @@ impl SegmentCardinalityCollector { } } +impl SegmentCardinalityCollector { + pub fn from_req(column_type: ColumnType, accessor_idx: usize) -> Self { + Self { + buckets: vec![SegmentCardinalityCollectorBucket::new(column_type); 1], + column_type, + accessor_idx, + } + } + + fn fetch_block_with_field( + &mut self, + docs: &[crate::DocId], + agg_data: &mut CardinalityAggReqData, + ) { + if let Some(missing) = agg_data.missing_value_for_accessor { + agg_data.column_block_accessor.fetch_block_with_missing( + docs, + &agg_data.accessor, + missing, + ); + } else { + agg_data + .column_block_accessor + .fetch_block(docs, &agg_data.accessor); + } + } +} + impl SegmentAggregationCollector for SegmentCardinalityCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { + self.prepare_max_bucket(parent_bucket_id, agg_data)?; let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx); let name = req_data.name.to_string(); + // take the bucket in buckets and replace it with a new empty one + let bucket = std::mem::take(&mut self.buckets[parent_bucket_id as usize]); - let intermediate_result = self.into_intermediate_metric_result(agg_data)?; + let intermediate_result = bucket.into_intermediate_metric_result(req_data)?; results.push( name, IntermediateAggregationResult::Metric(intermediate_result), @@ -247,24 +264,18 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector { fn collect( &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.collect_block(&[doc], agg_data) - } - - fn collect_block( - &mut self, + parent_bucket_id: BucketId, docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { let req_data = agg_data.get_cardinality_req_data_mut(self.accessor_idx); self.fetch_block_with_field(docs, req_data); + let bucket = &mut self.buckets[parent_bucket_id as usize]; let col_block_accessor = &req_data.column_block_accessor; if req_data.column_type == ColumnType::Str { for term_ord in col_block_accessor.iter_vals() { - self.entries.insert(term_ord); + bucket.entries.insert(term_ord); } } else if req_data.column_type == ColumnType::IpAddr { let compact_space_accessor = req_data @@ -282,16 +293,29 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector { })?; for val in col_block_accessor.iter_vals() { let val: u128 = compact_space_accessor.compact_to_u128(val as u32); - self.cardinality.sketch.insert_any(&val); + bucket.cardinality.sketch.insert_any(&val); } } else { for val in col_block_accessor.iter_vals() { - self.cardinality.sketch.insert_any(&val); + bucket.cardinality.sketch.insert_any(&val); } } Ok(()) } + + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + if max_bucket as usize >= self.buckets.len() { + self.buckets.resize_with(max_bucket as usize + 1, || { + SegmentCardinalityCollectorBucket::new(self.column_type) + }); + } + Ok(()) + } } #[derive(Clone, Debug, Serialize, Deserialize)] diff --git a/src/aggregation/metric/count.rs b/src/aggregation/metric/count.rs index ac550a38f..b28ced047 100644 --- a/src/aggregation/metric/count.rs +++ b/src/aggregation/metric/count.rs @@ -52,10 +52,8 @@ pub struct IntermediateCount { impl IntermediateCount { /// Creates a new [`IntermediateCount`] instance from a [`SegmentStatsCollector`]. - pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { - Self { - stats: collector.stats, - } + pub(crate) fn from_stats(stats: IntermediateStats) -> Self { + Self { stats } } /// Merges the other intermediate result into self. pub fn merge_fruits(&mut self, other: IntermediateCount) { diff --git a/src/aggregation/metric/extended_stats.rs b/src/aggregation/metric/extended_stats.rs index d7302e5f5..ad7165b8d 100644 --- a/src/aggregation/metric/extended_stats.rs +++ b/src/aggregation/metric/extended_stats.rs @@ -8,10 +8,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, }; -use crate::aggregation::metric::MetricAggReqData; use crate::aggregation::segment_agg_result::SegmentAggregationCollector; use crate::aggregation::*; -use crate::{DocId, TantivyError}; +use crate::TantivyError; /// A multi-value metric aggregation that computes a collection of extended statistics /// on numeric values that are extracted @@ -318,51 +317,30 @@ impl IntermediateExtendedStats { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] pub(crate) struct SegmentExtendedStatsCollector { + name: String, missing: Option, field_type: ColumnType, - pub(crate) extended_stats: IntermediateExtendedStats, - pub(crate) accessor_idx: usize, - val_cache: Vec, + accessor: columnar::Column, + column_block_accessor: columnar::ColumnBlockAccessor, + buckets: Vec, + sigma: Option, } impl SegmentExtendedStatsCollector { - pub fn from_req( - field_type: ColumnType, - sigma: Option, - accessor_idx: usize, - missing: Option, - ) -> Self { - let missing = missing.and_then(|val| f64_to_fastfield_u64(val, &field_type)); + pub fn from_req(req: &MetricAggReqData, sigma: Option) -> Self { + let missing = req + .missing + .and_then(|val| f64_to_fastfield_u64(val, &req.field_type)); Self { - field_type, - extended_stats: IntermediateExtendedStats::with_sigma(sigma), - accessor_idx, + name: req.name.clone(), + field_type: req.field_type, + accessor: req.accessor.clone(), + column_block_accessor: req.column_block_accessor.clone(), missing, - val_cache: Default::default(), - } - } - #[inline] - pub(crate) fn collect_block_with_field( - &mut self, - docs: &[DocId], - req_data: &mut MetricAggReqData, - ) { - if let Some(missing) = self.missing.as_ref() { - req_data.column_block_accessor.fetch_block_with_missing( - docs, - &req_data.accessor, - *missing, - ); - } else { - req_data - .column_block_accessor - .fetch_block(docs, &req_data.accessor); - } - for val in req_data.column_block_accessor.iter_vals() { - let val1 = f64_from_fastfield_u64(val, &self.field_type); - self.extended_stats.collect(val1); + buckets: vec![IntermediateExtendedStats::with_sigma(sigma); 16], + sigma, } } } @@ -370,15 +348,18 @@ impl SegmentExtendedStatsCollector { impl SegmentAggregationCollector for SegmentExtendedStatsCollector { #[inline] fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { - let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone(); + let name = self.name.clone(); + self.prepare_max_bucket(parent_bucket_id, agg_data)?; + let extended_stats = std::mem::take(&mut self.buckets[parent_bucket_id as usize]); results.push( name, IntermediateAggregationResult::Metric(IntermediateMetricResult::ExtendedStats( - self.extended_stats, + extended_stats, )), )?; @@ -388,39 +369,39 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector { #[inline] fn collect( &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, + parent_bucket_id: BucketId, + docs: &[crate::DocId], + _agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let req_data = agg_data.get_metric_req_data(self.accessor_idx); - if let Some(missing) = self.missing { - let mut has_val = false; - for val in req_data.accessor.values_for_doc(doc) { - let val1 = f64_from_fastfield_u64(val, &self.field_type); - self.extended_stats.collect(val1); - has_val = true; - } - if !has_val { - self.extended_stats - .collect(f64_from_fastfield_u64(missing, &self.field_type)); - } + let mut extended_stats = self.buckets[parent_bucket_id as usize].clone(); + + if let Some(missing) = self.missing.as_ref() { + self.column_block_accessor + .fetch_block_with_missing(docs, &self.accessor, *missing); } else { - for val in req_data.accessor.values_for_doc(doc) { - let val1 = f64_from_fastfield_u64(val, &self.field_type); - self.extended_stats.collect(val1); - } + self.column_block_accessor.fetch_block(docs, &self.accessor); } + for val in self.column_block_accessor.iter_vals() { + let val1 = f64_from_fastfield_u64(val, &self.field_type); + extended_stats.collect(val1); + } + + // store back + self.buckets[parent_bucket_id as usize] = extended_stats; Ok(()) } - #[inline] - fn collect_block( + fn prepare_max_bucket( &mut self, - docs: &[crate::DocId], - agg_data: &mut AggregationsSegmentCtx, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, ) -> crate::Result<()> { - let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx); - self.collect_block_with_field(docs, req_data); + if self.buckets.len() <= max_bucket as usize { + self.buckets.resize_with(max_bucket as usize + 1, || { + IntermediateExtendedStats::with_sigma(self.sigma) + }); + } Ok(()) } } diff --git a/src/aggregation/metric/max.rs b/src/aggregation/metric/max.rs index 89c6e4458..59af7e2de 100644 --- a/src/aggregation/metric/max.rs +++ b/src/aggregation/metric/max.rs @@ -52,10 +52,8 @@ pub struct IntermediateMax { impl IntermediateMax { /// Creates a new [`IntermediateMax`] instance from a [`SegmentStatsCollector`]. - pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { - Self { - stats: collector.stats, - } + pub(crate) fn from_stats(stats: IntermediateStats) -> Self { + Self { stats } } /// Merges the other intermediate result into self. pub fn merge_fruits(&mut self, other: IntermediateMax) { diff --git a/src/aggregation/metric/min.rs b/src/aggregation/metric/min.rs index 61fd2ecd2..ecf2fcafc 100644 --- a/src/aggregation/metric/min.rs +++ b/src/aggregation/metric/min.rs @@ -52,10 +52,8 @@ pub struct IntermediateMin { impl IntermediateMin { /// Creates a new [`IntermediateMin`] instance from a [`SegmentStatsCollector`]. - pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { - Self { - stats: collector.stats, - } + pub(crate) fn from_stats(stats: IntermediateStats) -> Self { + Self { stats } } /// Merges the other intermediate result into self. pub fn merge_fruits(&mut self, other: IntermediateMin) { diff --git a/src/aggregation/metric/percentiles.rs b/src/aggregation/metric/percentiles.rs index c846e2187..f155b498e 100644 --- a/src/aggregation/metric/percentiles.rs +++ b/src/aggregation/metric/percentiles.rs @@ -7,10 +7,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, }; -use crate::aggregation::metric::MetricAggReqData; use crate::aggregation::segment_agg_result::SegmentAggregationCollector; use crate::aggregation::*; -use crate::{DocId, TantivyError}; +use crate::TantivyError; /// # Percentiles /// @@ -133,7 +132,7 @@ impl PercentilesAggregationReq { #[derive(Clone, Debug, PartialEq)] pub(crate) struct SegmentPercentilesCollector { - pub(crate) percentiles: PercentilesCollector, + pub(crate) buckets: Vec, pub(crate) accessor_idx: usize, } @@ -231,16 +230,46 @@ impl PercentilesCollector { impl SegmentPercentilesCollector { pub fn from_req_and_validate(accessor_idx: usize) -> crate::Result { Ok(Self { - percentiles: PercentilesCollector::new(), + buckets: Vec::with_capacity(64), accessor_idx, }) } +} + +impl SegmentAggregationCollector for SegmentPercentilesCollector { #[inline] - pub(crate) fn collect_block_with_field( + fn add_intermediate_aggregation_result( &mut self, - docs: &[DocId], - req_data: &mut MetricAggReqData, - ) { + agg_data: &AggregationsSegmentCtx, + results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, + ) -> crate::Result<()> { + let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone(); + self.prepare_max_bucket(parent_bucket_id, agg_data)?; + // Swap collector with an empty one to avoid cloning + let percentiles_collector = std::mem::take(&mut self.buckets[parent_bucket_id as usize]); + + let intermediate_metric_result = + IntermediateMetricResult::Percentiles(percentiles_collector); + + results.push( + name, + IntermediateAggregationResult::Metric(intermediate_metric_result), + )?; + + Ok(()) + } + + #[inline] + fn collect( + &mut self, + parent_bucket_id: BucketId, + docs: &[crate::DocId], + agg_data: &mut AggregationsSegmentCtx, + ) -> crate::Result<()> { + let percentiles = &mut self.buckets[parent_bucket_id as usize]; + let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx); + if let Some(missing) = req_data.missing_u64.as_ref() { req_data.column_block_accessor.fetch_block_with_missing( docs, @@ -255,66 +284,20 @@ impl SegmentPercentilesCollector { for val in req_data.column_block_accessor.iter_vals() { let val1 = f64_from_fastfield_u64(val, &req_data.field_type); - self.percentiles.collect(val1); - } - } -} - -impl SegmentAggregationCollector for SegmentPercentilesCollector { - #[inline] - fn add_intermediate_aggregation_result( - self: Box, - agg_data: &AggregationsSegmentCtx, - results: &mut IntermediateAggregationResults, - ) -> crate::Result<()> { - let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone(); - let intermediate_metric_result = IntermediateMetricResult::Percentiles(self.percentiles); - - results.push( - name, - IntermediateAggregationResult::Metric(intermediate_metric_result), - )?; - - Ok(()) - } - - #[inline] - fn collect( - &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - let req_data = agg_data.get_metric_req_data(self.accessor_idx); - - if let Some(missing) = req_data.missing_u64 { - let mut has_val = false; - for val in req_data.accessor.values_for_doc(doc) { - let val1 = f64_from_fastfield_u64(val, &req_data.field_type); - self.percentiles.collect(val1); - has_val = true; - } - if !has_val { - self.percentiles - .collect(f64_from_fastfield_u64(missing, &req_data.field_type)); - } - } else { - for val in req_data.accessor.values_for_doc(doc) { - let val1 = f64_from_fastfield_u64(val, &req_data.field_type); - self.percentiles.collect(val1); - } + percentiles.collect(val1); } Ok(()) } - #[inline] - fn collect_block( + fn prepare_max_bucket( &mut self, - docs: &[crate::DocId], - agg_data: &mut AggregationsSegmentCtx, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, ) -> crate::Result<()> { - let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx); - self.collect_block_with_field(docs, req_data); + while self.buckets.len() <= max_bucket as usize { + self.buckets.push(PercentilesCollector::new()); + } Ok(()) } } diff --git a/src/aggregation/metric/stats.rs b/src/aggregation/metric/stats.rs index 56715fdea..4f0c5e90c 100644 --- a/src/aggregation/metric/stats.rs +++ b/src/aggregation/metric/stats.rs @@ -1,5 +1,6 @@ use std::fmt::Debug; +use columnar::{Column, ColumnBlockAccessor, ColumnType}; use serde::{Deserialize, Serialize}; use super::*; @@ -7,10 +8,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, }; -use crate::aggregation::metric::MetricAggReqData; use crate::aggregation::segment_agg_result::SegmentAggregationCollector; use crate::aggregation::*; -use crate::{DocId, TantivyError}; +use crate::TantivyError; /// A multi-value metric aggregation that computes a collection of statistics on numeric values that /// are extracted from the aggregated documents. @@ -83,7 +83,7 @@ impl Stats { /// Intermediate result of the stats aggregation that can be combined with other intermediate /// results. -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)] pub struct IntermediateStats { /// The number of extracted values. pub(crate) count: u64, @@ -189,44 +189,28 @@ pub enum StatsType { #[derive(Clone, Debug)] pub(crate) struct SegmentStatsCollector { - pub(crate) stats: IntermediateStats, - pub(crate) accessor_idx: usize, + pub(crate) name: String, + pub(crate) collecting_for: StatsType, + pub(crate) is_number_or_date_type: bool, + pub(crate) field_type: ColumnType, + pub(crate) missing_u64: Option, + pub(crate) column_block_accessor: ColumnBlockAccessor, + pub(crate) accessor: Column, + pub(crate) buckets: Vec, } impl SegmentStatsCollector { - pub fn from_req(accessor_idx: usize) -> Self { + pub fn from_req(req: &MetricAggReqData) -> Self { + let buckets = vec![IntermediateStats::default()]; Self { - stats: IntermediateStats::default(), - accessor_idx, - } - } - #[inline] - pub(crate) fn collect_block_with_field( - &mut self, - docs: &[DocId], - req_data: &mut MetricAggReqData, - ) { - if let Some(missing) = req_data.missing_u64.as_ref() { - req_data.column_block_accessor.fetch_block_with_missing( - docs, - &req_data.accessor, - *missing, - ); - } else { - req_data - .column_block_accessor - .fetch_block(docs, &req_data.accessor); - } - if req_data.is_number_or_date_type { - for val in req_data.column_block_accessor.iter_vals() { - let val1 = f64_from_fastfield_u64(val, &req_data.field_type); - self.stats.collect(val1); - } - } else { - for _val in req_data.column_block_accessor.iter_vals() { - // we ignore the value and simply record that we got something - self.stats.collect(0.0); - } + name: req.name.clone(), + collecting_for: req.collecting_for, + is_number_or_date_type: req.is_number_or_date_type, + field_type: req.field_type, + missing_u64: req.missing_u64, + column_block_accessor: req.column_block_accessor.clone(), + accessor: req.accessor.clone(), + buckets, } } } @@ -234,28 +218,30 @@ impl SegmentStatsCollector { impl SegmentAggregationCollector for SegmentStatsCollector { #[inline] fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { - let req = agg_data.get_metric_req_data(self.accessor_idx); - let name = req.name.clone(); + let name = self.name.clone(); - let intermediate_metric_result = match req.collecting_for { + self.prepare_max_bucket(parent_bucket_id, agg_data)?; + let stats = self.buckets[parent_bucket_id as usize]; + let intermediate_metric_result = match self.collecting_for { StatsType::Average => { - IntermediateMetricResult::Average(IntermediateAverage::from_collector(*self)) + IntermediateMetricResult::Average(IntermediateAverage::from_stats(stats)) } StatsType::Count => { - IntermediateMetricResult::Count(IntermediateCount::from_collector(*self)) + IntermediateMetricResult::Count(IntermediateCount::from_stats(stats)) } - StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_collector(*self)), - StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_collector(*self)), - StatsType::Stats => IntermediateMetricResult::Stats(self.stats), - StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_collector(*self)), + StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_stats(stats)), + StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_stats(stats)), + StatsType::Stats => IntermediateMetricResult::Stats(stats), + StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_stats(stats)), _ => { return Err(TantivyError::InvalidArgument(format!( "Unsupported stats type for stats aggregation: {:?}", - req.collecting_for + self.collecting_for ))) } }; @@ -271,39 +257,44 @@ impl SegmentAggregationCollector for SegmentStatsCollector { #[inline] fn collect( &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, + parent_bucket_id: BucketId, + docs: &[crate::DocId], + _agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let req_data = agg_data.get_metric_req_data(self.accessor_idx); - if let Some(missing) = req_data.missing_u64 { - let mut has_val = false; - for val in req_data.accessor.values_for_doc(doc) { - let val1 = f64_from_fastfield_u64(val, &req_data.field_type); - self.stats.collect(val1); - has_val = true; - } - if !has_val { - self.stats - .collect(f64_from_fastfield_u64(missing, &req_data.field_type)); + // Copying the stats to avoid aliasing optimization issues + let mut stats = self.buckets[parent_bucket_id as usize]; + if let Some(missing) = self.missing_u64.as_ref() { + self.column_block_accessor + .fetch_block_with_missing(docs, &self.accessor, *missing); + } else { + self.column_block_accessor.fetch_block(docs, &self.accessor); + } + if self.is_number_or_date_type { + for val in self.column_block_accessor.iter_vals() { + let val1 = f64_from_fastfield_u64(val, &self.field_type); + stats.collect(val1); } } else { - for val in req_data.accessor.values_for_doc(doc) { - let val1 = f64_from_fastfield_u64(val, &req_data.field_type); - self.stats.collect(val1); + for _val in self.column_block_accessor.iter_vals() { + // we ignore the value and simply record that we got something + stats.collect(0.0); } } + self.buckets[parent_bucket_id as usize] = stats; Ok(()) } - #[inline] - fn collect_block( + fn prepare_max_bucket( &mut self, - docs: &[crate::DocId], - agg_data: &mut AggregationsSegmentCtx, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, ) -> crate::Result<()> { - let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx); - self.collect_block_with_field(docs, req_data); + let required_buckets = (max_bucket as usize) + 1; + if self.buckets.len() < required_buckets { + self.buckets + .resize_with(required_buckets, IntermediateStats::default); + } Ok(()) } } diff --git a/src/aggregation/metric/sum.rs b/src/aggregation/metric/sum.rs index 86f661679..2487c4e9d 100644 --- a/src/aggregation/metric/sum.rs +++ b/src/aggregation/metric/sum.rs @@ -52,10 +52,8 @@ pub struct IntermediateSum { impl IntermediateSum { /// Creates a new [`IntermediateSum`] instance from a [`SegmentStatsCollector`]. - pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { - Self { - stats: collector.stats, - } + pub(crate) fn from_stats(stats: IntermediateStats) -> Self { + Self { stats } } /// Merges the other intermediate result into self. pub fn merge_fruits(&mut self, other: IntermediateSum) { diff --git a/src/aggregation/metric/top_hits.rs b/src/aggregation/metric/top_hits.rs index 6a8bdf826..02b6be41b 100644 --- a/src/aggregation/metric/top_hits.rs +++ b/src/aggregation/metric/top_hits.rs @@ -15,7 +15,7 @@ use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateMetricResult, }; use crate::aggregation::segment_agg_result::SegmentAggregationCollector; -use crate::aggregation::AggregationError; +use crate::aggregation::{AggregationError, BucketId}; use crate::collector::sort_key::ReverseComparator; use crate::collector::TopNComputer; use crate::schema::OwnedValue; @@ -472,7 +472,10 @@ impl TopHitsTopNComputer { /// Create a new TopHitsCollector pub fn new(req: &TopHitsAggregationReq) -> Self { Self { - top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)), + top_n: TopNComputer::new_with_comparator( + req.size + req.from.unwrap_or(0), + ReverseComparator, + ), req: req.clone(), } } @@ -518,7 +521,8 @@ impl TopHitsTopNComputer { pub(crate) struct TopHitsSegmentCollector { segment_ordinal: SegmentOrdinal, accessor_idx: usize, - top_n: TopNComputer, DocAddress, ReverseComparator>, + buckets: Vec, DocAddress, ReverseComparator>>, + num_hits: usize, } impl TopHitsSegmentCollector { @@ -527,19 +531,29 @@ impl TopHitsSegmentCollector { accessor_idx: usize, segment_ordinal: SegmentOrdinal, ) -> Self { + let num_hits = req.size + req.from.unwrap_or(0); Self { - top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)), + num_hits, segment_ordinal, accessor_idx, + buckets: vec![TopNComputer::new_with_comparator(num_hits, ReverseComparator); 1], } } - fn into_top_hits_collector( - self, + fn get_top_hits_computer( + &mut self, + parent_bucket_id: BucketId, value_accessors: &HashMap>, req: &TopHitsAggregationReq, ) -> TopHitsTopNComputer { + if parent_bucket_id as usize >= self.buckets.len() { + return TopHitsTopNComputer::new(req); + } + let top_n = std::mem::replace( + &mut self.buckets[parent_bucket_id as usize], + TopNComputer::new(0), + ); let mut top_hits_computer = TopHitsTopNComputer::new(req); - let top_results = self.top_n.into_vec(); + let top_results = top_n.into_vec(); for res in top_results { let doc_value_fields = req.get_document_field_data(value_accessors, res.doc.doc_id); @@ -554,54 +568,24 @@ impl TopHitsSegmentCollector { top_hits_computer } - - /// TODO add a specialized variant for a single sort field - fn collect_with( - &mut self, - doc_id: crate::DocId, - req: &TopHitsAggregationReq, - accessors: &[(Column, ColumnType)], - ) -> crate::Result<()> { - let sorts: Vec = req - .sort - .iter() - .enumerate() - .map(|(idx, KeyOrder { order, .. })| { - let order = *order; - let value = accessors - .get(idx) - .expect("could not find field in accessors") - .0 - .values_for_doc(doc_id) - .next(); - DocValueAndOrder { value, order } - }) - .collect(); - - self.top_n.push( - sorts, - DocAddress { - segment_ord: self.segment_ordinal, - doc_id, - }, - ); - Ok(()) - } } impl SegmentAggregationCollector for TopHitsSegmentCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut crate::aggregation::intermediate_agg_result::IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { let req_data = agg_data.get_top_hits_req_data(self.accessor_idx); let value_accessors = &req_data.value_accessors; - let intermediate_result = IntermediateMetricResult::TopHits( - self.into_top_hits_collector(value_accessors, &req_data.req), - ); + let intermediate_result = IntermediateMetricResult::TopHits(self.get_top_hits_computer( + parent_bucket_id, + value_accessors, + &req_data.req, + )); results.push( req_data.name.to_string(), IntermediateAggregationResult::Metric(intermediate_result), @@ -611,26 +595,57 @@ impl SegmentAggregationCollector for TopHitsSegmentCollector { /// TODO: Consider a caching layer to reduce the call overhead fn collect( &mut self, - doc_id: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - let req_data = agg_data.get_top_hits_req_data(self.accessor_idx); - self.collect_with(doc_id, &req_data.req, &req_data.accessors)?; - Ok(()) - } - - fn collect_block( - &mut self, + parent_bucket_id: BucketId, docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { + let top_n = &mut self.buckets[parent_bucket_id as usize]; let req_data = agg_data.get_top_hits_req_data(self.accessor_idx); - // TODO: Consider getting fields with the column block accessor. - for doc in docs { - self.collect_with(*doc, &req_data.req, &req_data.accessors)?; + let req = &req_data.req; + let accessors = &req_data.accessors; + for doc_id in docs { + let doc_id = *doc_id; + // TODO: this is terrible, a new vec is allocated for every doc + // We can fetch blocks instead + // We don't need to store the order for every value + let sorts: Vec = req + .sort + .iter() + .enumerate() + .map(|(idx, KeyOrder { order, .. })| { + let order = *order; + let value = accessors + .get(idx) + .expect("could not find field in accessors") + .0 + .values_for_doc(doc_id) + .next(); + DocValueAndOrder { value, order } + }) + .collect(); + + top_n.push( + sorts, + DocAddress { + segment_ord: self.segment_ordinal, + doc_id, + }, + ); } Ok(()) } + + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + self.buckets.resize( + (max_bucket as usize) + 1, + TopNComputer::new_with_comparator(self.num_hits, ReverseComparator), + ); + Ok(()) + } } #[cfg(test)] @@ -746,7 +761,7 @@ mod tests { ], "from": 0, } - } + } })) .unwrap(); @@ -875,7 +890,7 @@ mod tests { "mixed.*", ], } - } + } }))?; let collector = AggregationCollector::from_aggs(d, Default::default()); diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index ddf60ea4c..eb8fd4773 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -133,7 +133,7 @@ mod agg_limits; pub mod agg_req; pub mod agg_result; pub mod bucket; -mod buf_collector; +pub(crate) mod cached_sub_aggs; mod collector; mod date; mod error; @@ -162,6 +162,19 @@ use serde::{Deserialize, Deserializer, Serialize}; use crate::tokenizer::TokenizerManager; +/// A bucket id is a dense identifier for a bucket within an aggregation. +/// It is used to index into a Vec that hold per-bucket data. +/// +/// For example, in a terms aggregation, each unique term will be assigned a incremental BucketId. +/// This BucketId will be forwarded to sub-aggregations to identify the parent bucket. +/// +/// This allows to have a single AggregationCollector instance per aggregation, +/// that can handle multiple buckets efficiently. +/// +/// The API to call sub-aggregations is therefore a &[(BucketId, &[DocId])]. +/// For that we'll need a buffer. One Vec per bucket aggregation is needed. +pub type BucketId = u32; + /// Context parameters for aggregation execution /// /// This struct holds shared resources needed during aggregation execution: diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index 5cc2650b6..fc2457b81 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -8,25 +8,42 @@ use std::fmt::Debug; pub(crate) use super::agg_limits::AggregationLimitsGuard; use super::intermediate_agg_result::IntermediateAggregationResults; use crate::aggregation::agg_data::AggregationsSegmentCtx; +use crate::aggregation::BucketId; + +/// Monotonically increasing provider of BucketIds. +#[derive(Debug, Clone, Default)] +pub struct BucketIdProvider(u32); +impl BucketIdProvider { + /// Get the next BucketId. + pub fn next_bucket_id(&mut self) -> BucketId { + let bucket_id = self.0; + self.0 += 1; + bucket_id + } +} /// A SegmentAggregationCollector is used to collect aggregation results. pub trait SegmentAggregationCollector: CollectorClone + Debug { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()>; fn collect( &mut self, - doc: crate::DocId, + parent_bucket_id: BucketId, + docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()>; - fn collect_block( + /// Prepare the collector for collecting up to ParentBucketId `max_bucket`. + /// This is useful so we can split allocation ahead of time of collecting. + fn prepare_max_bucket( &mut self, - docs: &[crate::DocId], - agg_data: &mut AggregationsSegmentCtx, + max_bucket: BucketId, + agg_data: &AggregationsSegmentCtx, ) -> crate::Result<()>; /// Finalize method. Some Aggregator collect blocks of docs before calling `collect_block`. @@ -73,12 +90,13 @@ impl Debug for GenericSegmentAggregationResultsCollector { impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { - for agg in self.aggs { - agg.add_intermediate_aggregation_result(agg_data, results)?; + for agg in &mut self.aggs { + agg.add_intermediate_aggregation_result(agg_data, results, parent_bucket_id)?; } Ok(()) @@ -86,23 +104,13 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector { fn collect( &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.collect_block(&[doc], agg_data)?; - - Ok(()) - } - - fn collect_block( - &mut self, + parent_bucket_id: BucketId, docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { for collector in &mut self.aggs { - collector.collect_block(docs, agg_data)?; + collector.collect(parent_bucket_id, docs, agg_data)?; } - Ok(()) } @@ -112,4 +120,15 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector { } Ok(()) } + + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + for collector in &mut self.aggs { + collector.prepare_max_bucket(max_bucket, agg_data)?; + } + Ok(()) + } }