diff --git a/src/aggregation/agg_data.rs b/src/aggregation/agg_data.rs index deedb5781..4a496af25 100644 --- a/src/aggregation/agg_data.rs +++ b/src/aggregation/agg_data.rs @@ -10,9 +10,9 @@ use crate::aggregation::accessor_helpers::{ }; use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations}; use crate::aggregation::bucket::{ - FilterAggReqData, HistogramAggReqData, HistogramBounds, IncludeExcludeParam, - MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector, SegmentHistogramCollector, - SegmentRangeCollector, TermMissingAgg, TermsAggReqData, TermsAggregation, + build_segment_range_collector, FilterAggReqData, HistogramAggReqData, HistogramBounds, + IncludeExcludeParam, MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector, + SegmentHistogramCollector, TermMissingAgg, TermsAggReqData, TermsAggregation, TermsAggregationInternal, }; use crate::aggregation::metric::{ @@ -107,21 +107,14 @@ impl AggregationsSegmentCtx { .as_deref() .expect("range_req_data slot is empty (taken)") } - #[inline] - pub(crate) fn get_filter_req_data(&self, idx: usize) -> &FilterAggReqData { - self.per_request.filter_req_data[idx] - .as_deref() - .expect("filter_req_data slot is empty (taken)") - } // ---------- mutable getters ---------- #[inline] - pub(crate) fn get_term_req_data_mut(&mut self, idx: usize) -> &mut TermsAggReqData { - self.per_request.term_req_data[idx] - .as_deref_mut() - .expect("term_req_data slot is empty (taken)") + pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData { + &mut self.per_request.stats_metric_req_data[idx] } + #[inline] pub(crate) fn get_cardinality_req_data_mut( &mut self, @@ -129,10 +122,7 @@ impl AggregationsSegmentCtx { ) -> &mut CardinalityAggReqData { &mut self.per_request.cardinality_req_data[idx] } - #[inline] - pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData { - &mut self.per_request.stats_metric_req_data[idx] - } + #[inline] 
pub(crate) fn get_histogram_req_data_mut(&mut self, idx: usize) -> &mut HistogramAggReqData { self.per_request.histogram_req_data[idx] @@ -345,12 +335,19 @@ impl PerRequestAggSegCtx { pub(crate) fn build_segment_agg_collectors_root( req: &mut AggregationsSegmentCtx, ) -> crate::Result> { - build_segment_agg_collectors(req, &req.per_request.agg_tree.clone()) + build_segment_agg_collectors_generic(req, &req.per_request.agg_tree.clone()) } pub(crate) fn build_segment_agg_collectors( req: &mut AggregationsSegmentCtx, nodes: &[AggRefNode], +) -> crate::Result> { + build_segment_agg_collectors_generic(req, nodes) +} + +fn build_segment_agg_collectors_generic( + req: &mut AggregationsSegmentCtx, + nodes: &[AggRefNode], ) -> crate::Result> { let mut collectors = Vec::new(); for node in nodes.iter() { @@ -398,17 +395,10 @@ pub(crate) fn build_segment_agg_collector( | StatsType::Count | StatsType::Max | StatsType::Min - | StatsType::Stats => Ok(Box::new(SegmentStatsCollector::from_req( - node.idx_in_req_data, - ))), - StatsType::ExtendedStats(sigma) => { - Ok(Box::new(SegmentExtendedStatsCollector::from_req( - req_data.field_type, - sigma, - node.idx_in_req_data, - req_data.missing, - ))) - } + | StatsType::Stats => Ok(Box::new(SegmentStatsCollector::from_req(req_data))), + StatsType::ExtendedStats(sigma) => Ok(Box::new( + SegmentExtendedStatsCollector::from_req(req_data, sigma), + )), StatsType::Percentiles => Ok(Box::new( SegmentPercentilesCollector::from_req_and_validate(node.idx_in_req_data)?, )), @@ -428,9 +418,7 @@ pub(crate) fn build_segment_agg_collector( AggKind::DateHistogram => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate( req, node, )?)), - AggKind::Range => Ok(Box::new(SegmentRangeCollector::from_req_and_validate( - req, node, - )?)), + AggKind::Range => Ok(build_segment_range_collector(req, node)?), AggKind::Filter => Ok(Box::new(SegmentFilterCollector::from_req_and_validate( req, node, )?)), @@ -524,6 +512,7 @@ fn build_nodes( 
column_block_accessor: Default::default(), name: agg_name.to_string(), req: range_req.clone(), + is_top_level, }); let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?; Ok(vec![AggRefNode { @@ -543,7 +532,6 @@ fn build_nodes( field_type, column_block_accessor: Default::default(), name: agg_name.to_string(), - sub_aggregation_blueprint: None, req: histo_req.clone(), is_date_histogram: false, bounds: HistogramBounds { @@ -570,7 +558,6 @@ fn build_nodes( field_type, column_block_accessor: Default::default(), name: agg_name.to_string(), - sub_aggregation_blueprint: None, req: histo_req, is_date_histogram: true, bounds: HistogramBounds { @@ -929,8 +916,6 @@ fn build_terms_or_cardinality_nodes( column_block_accessor: Default::default(), name: agg_name.to_string(), req: TermsAggregationInternal::from_req(req), - // Will be filled later when building collectors - sub_aggregation_blueprint: None, sug_aggregations: sub_aggs.clone(), allowed_term_ids, is_top_level, diff --git a/src/aggregation/agg_tests.rs b/src/aggregation/agg_tests.rs index cbd529146..49a8afb37 100644 --- a/src/aggregation/agg_tests.rs +++ b/src/aggregation/agg_tests.rs @@ -2,11 +2,11 @@ use serde_json::Value; use crate::aggregation::agg_req::{Aggregation, Aggregations}; use crate::aggregation::agg_result::AggregationResults; -use crate::aggregation::buf_collector::DOC_BLOCK_SIZE; use crate::aggregation::collector::AggregationCollector; use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults; use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms}; use crate::aggregation::DistributedAggregationCollector; +use crate::docset::COLLECT_BLOCK_BUFFER_LEN; use crate::query::{AllQuery, TermQuery}; use crate::schema::{IndexRecordOption, Schema, FAST}; use crate::{Index, IndexWriter, Term}; @@ -467,8 +467,9 @@ fn test_aggregation_flushing( let reader = index.reader()?; - assert_eq!(DOC_BLOCK_SIZE, 64); - // In the tree we 
cache Documents of DOC_BLOCK_SIZE, before passing them down as one block. + assert_eq!(COLLECT_BLOCK_BUFFER_LEN, 64); + // In the tree we cache documents of COLLECT_BLOCK_BUFFER_LEN before passing them down as one + // block. // // Build a request so that on the first level we have one full cache, which is then flushed. // The same cache should have some residue docs at the end, which are flushed (Range 0-70) diff --git a/src/aggregation/bucket/filter.rs b/src/aggregation/bucket/filter.rs index 18f2a692a..3d9e3ae82 100644 --- a/src/aggregation/bucket/filter.rs +++ b/src/aggregation/bucket/filter.rs @@ -6,10 +6,14 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer}; use crate::aggregation::agg_data::{ build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx, }; +use crate::aggregation::cached_sub_aggs::CachedSubAggs; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, }; -use crate::aggregation::segment_agg_result::{CollectorClone, SegmentAggregationCollector}; +use crate::aggregation::segment_agg_result::{ + BucketIdProvider, CollectorClone, SegmentAggregationCollector, +}; +use crate::aggregation::BucketId; use crate::docset::DocSet; use crate::query::{AllQuery, EnableScoring, Query, QueryParser}; use crate::schema::Schema; @@ -410,9 +414,9 @@ impl FilterAggReqData { pub(crate) fn get_memory_consumption(&self) -> usize { // Estimate: name + segment reader reference + bitset + buffer capacity self.name.len() - + std::mem::size_of::() - + self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes) - + self.matching_docs_buffer.capacity() * std::mem::size_of::() + + std::mem::size_of::() + + self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes) + + self.matching_docs_buffer.capacity() * std::mem::size_of::() } } @@ -489,12 +493,19 @@ impl Debug for DocumentQueryEvaluator { } } +#[derive(Debug, Clone, PartialEq, Copy)] +struct DocCount { + 
doc_count: u64, + bucket_id: BucketId, +} + /// Segment collector for filter aggregation pub struct SegmentFilterCollector { - /// Document count in this bucket - doc_count: u64, + /// Document counts per bucket + buckets: Vec, /// Sub-aggregation collectors - sub_aggregations: Option>, + sub_aggregations: Option>, + bucket_id_provider: BucketIdProvider, /// Accessor index for this filter aggregation (to access FilterAggReqData) accessor_idx: usize, } @@ -511,11 +522,13 @@ impl SegmentFilterCollector { } else { None }; + let sub_agg_collector = sub_agg_collector.map(CachedSubAggs::new); Ok(SegmentFilterCollector { - doc_count: 0, + buckets: Vec::new(), sub_aggregations: sub_agg_collector, accessor_idx: node.idx_in_req_data, + bucket_id_provider: BucketIdProvider::default(), }) } } @@ -523,7 +536,7 @@ impl SegmentFilterCollector { impl Debug for SegmentFilterCollector { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SegmentFilterCollector") - .field("doc_count", &self.doc_count) + .field("buckets", &self.buckets) .field("has_sub_aggs", &self.sub_aggregations.is_some()) .field("accessor_idx", &self.accessor_idx) .finish() @@ -539,19 +552,32 @@ impl CollectorClone for SegmentFilterCollector { impl SegmentAggregationCollector for SegmentFilterCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { let mut sub_results = IntermediateAggregationResults::default(); + let bucket_opt = self.buckets.get(parent_bucket_id as usize); - if let Some(sub_aggs) = self.sub_aggregations { - sub_aggs.add_intermediate_aggregation_result(agg_data, &mut sub_results)?; + if let Some(sub_aggs) = &mut self.sub_aggregations { + sub_aggs + .get_sub_agg_collector() + .add_intermediate_aggregation_result( + agg_data, + &mut sub_results, + // Here we create a new bucket ID for sub-aggregations if the 
bucket doesn't + // exist, so that sub-aggregations can still produce results (e.g., zero doc + // count) + bucket_opt + .map(|bucket| bucket.bucket_id) + .unwrap_or_else(|| self.bucket_id_provider.next_bucket_id()), + )?; } // Create the filter bucket result let filter_bucket_result = IntermediateBucketResult::Filter { - doc_count: self.doc_count, + doc_count: bucket_opt.map(|b| b.doc_count).unwrap_or(0), sub_aggregations: sub_results, }; @@ -570,32 +596,17 @@ impl SegmentAggregationCollector for SegmentFilterCollector { Ok(()) } - fn collect(&mut self, doc: DocId, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - // Access the evaluator from FilterAggReqData - let req_data = agg_data.get_filter_req_data(self.accessor_idx); - - // O(1) BitSet lookup to check if document matches filter - if req_data.evaluator.matches_document(doc) { - self.doc_count += 1; - - // If we have sub-aggregations, collect on them for this filtered document - if let Some(sub_aggs) = &mut self.sub_aggregations { - sub_aggs.collect(doc, agg_data)?; - } - } - Ok(()) - } - - #[inline] - fn collect_block( + fn collect( &mut self, - docs: &[DocId], + parent_bucket_id: BucketId, + docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { if docs.is_empty() { return Ok(()); } + let mut bucket = self.buckets[parent_bucket_id as usize]; // Take the request data to avoid borrow checker issues with sub-aggregations let mut req = agg_data.take_filter_req_data(self.accessor_idx); @@ -604,18 +615,24 @@ impl SegmentAggregationCollector for SegmentFilterCollector { req.evaluator .filter_batch(docs, &mut req.matching_docs_buffer); - self.doc_count += req.matching_docs_buffer.len() as u64; + bucket.doc_count += req.matching_docs_buffer.len() as u64; // Batch process sub-aggregations if we have matches if !req.matching_docs_buffer.is_empty() { if let Some(sub_aggs) = &mut self.sub_aggregations { - // Use collect_block for better sub-aggregation performance -
sub_aggs.collect_block(&req.matching_docs_buffer, agg_data)?; + for &doc_id in &req.matching_docs_buffer { + sub_aggs.push(bucket.bucket_id, doc_id); + } } } // Put the request data back agg_data.put_back_filter_req_data(self.accessor_idx, req); + if let Some(sub_aggs) = &mut self.sub_aggregations { + sub_aggs.check_flush_local(agg_data)?; + } + // put back bucket + self.buckets[parent_bucket_id as usize] = bucket; Ok(()) } @@ -626,6 +643,21 @@ impl SegmentAggregationCollector for SegmentFilterCollector { } Ok(()) } + + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + while self.buckets.len() <= max_bucket as usize { + let bucket_id = self.bucket_id_provider.next_bucket_id(); + self.buckets.push(DocCount { + doc_count: 0, + bucket_id, + }); + } + Ok(()) + } } /// Intermediate result for filter aggregation @@ -1519,9 +1551,9 @@ mod tests { let searcher = reader.searcher(); let agg = json!({ - "test": { - "filter": deserialized, - "aggs": { "count": { "value_count": { "field": "brand" } } } + "test": { + "filter": deserialized, + "aggs": { "count": { "value_count": { "field": "brand" } } } } }); diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs index 36c0fe57e..378cf7bf4 100644 --- a/src/aggregation/bucket/histogram/histogram.rs +++ b/src/aggregation/bucket/histogram/histogram.rs @@ -8,14 +8,14 @@ use tantivy_bitpacker::minmax; use crate::aggregation::agg_data::{ build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx, }; -use crate::aggregation::agg_limits::MemoryConsumption; use crate::aggregation::agg_req::Aggregations; use crate::aggregation::agg_result::BucketEntry; +use crate::aggregation::cached_sub_aggs::CachedSubAggs; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry, }; -use 
crate::aggregation::segment_agg_result::SegmentAggregationCollector; +use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector}; use crate::aggregation::*; use crate::TantivyError; @@ -30,9 +30,6 @@ pub struct HistogramAggReqData { pub column_block_accessor: ColumnBlockAccessor, /// The name of the aggregation. pub name: String, - /// The sub aggregation blueprint, used to create sub aggregations for each bucket. - /// Will be filled during initialization of the collector. - pub sub_aggregation_blueprint: Option>, /// The histogram aggregation request. pub req: HistogramAggregation, /// True if this is a date_histogram aggregation. @@ -257,18 +254,24 @@ impl HistogramBounds { pub(crate) struct SegmentHistogramBucketEntry { pub key: f64, pub doc_count: u64, + pub bucket_id: BucketId, } impl SegmentHistogramBucketEntry { pub(crate) fn into_intermediate_bucket_entry( self, - sub_aggregation: Option>, + sub_aggregation: &mut Option, agg_data: &AggregationsSegmentCtx, ) -> crate::Result { let mut sub_aggregation_res = IntermediateAggregationResults::default(); if let Some(sub_aggregation) = sub_aggregation { sub_aggregation - .add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?; + .get_sub_agg_collector() + .add_intermediate_aggregation_result( + agg_data, + &mut sub_aggregation_res, + self.bucket_id, + )?; } Ok(IntermediateHistogramBucketEntry { key: self.key, @@ -278,27 +281,37 @@ impl SegmentHistogramBucketEntry { } } +#[derive(Clone, Debug, Default)] +struct HistogramBuckets { + pub buckets: FxHashMap, +} + /// The collector puts values from the fast field into the correct buckets and does a conversion to /// the correct datatype. #[derive(Clone, Debug)] pub struct SegmentHistogramCollector { /// The buckets containing the aggregation data. - buckets: FxHashMap, - sub_aggregations: FxHashMap>, + /// One Histogram bucket per parent bucket id. 
+ buckets: Vec, + sub_agg: Option, accessor_idx: usize, + bucket_id_provider: BucketIdProvider, } impl SegmentAggregationCollector for SegmentHistogramCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { let name = agg_data .get_histogram_req_data(self.accessor_idx) .name .clone(); - let bucket = self.into_intermediate_bucket_result(agg_data)?; + self.prepare_max_bucket(parent_bucket_id, agg_data)?; + let histogram = std::mem::take(&mut self.buckets[parent_bucket_id as usize]); + let bucket = self.add_intermediate_bucket_result(agg_data, histogram)?; results.push(name, IntermediateAggregationResult::Bucket(bucket))?; Ok(()) @@ -307,20 +320,13 @@ impl SegmentAggregationCollector for SegmentHistogramCollector { #[inline] fn collect( &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.collect_block(&[doc], agg_data) - } - - #[inline] - fn collect_block( - &mut self, + parent_bucket_id: BucketId, docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { let mut req = agg_data.take_histogram_req_data(self.accessor_idx); let mem_pre = self.get_memory_consumption(); + let buckets = &mut self.buckets[parent_bucket_id as usize].buckets; let bounds = req.bounds; let interval = req.req.interval; @@ -335,16 +341,17 @@ impl SegmentAggregationCollector for SegmentHistogramCollector { let val = f64_from_fastfield_u64(val, &req.field_type); let bucket_pos = get_bucket_pos(val); if bounds.contains(val) { - let bucket = self.buckets.entry(bucket_pos).or_insert_with(|| { + let bucket = buckets.entry(bucket_pos).or_insert_with(|| { let key = get_bucket_key_from_pos(bucket_pos as f64, interval, offset); - SegmentHistogramBucketEntry { key, doc_count: 0 } + SegmentHistogramBucketEntry { + key, + doc_count: 0, + bucket_id: 
self.bucket_id_provider.next_bucket_id(), + } }); bucket.doc_count += 1; - if let Some(sub_aggregation_blueprint) = req.sub_aggregation_blueprint.as_ref() { - self.sub_aggregations - .entry(bucket_pos) - .or_insert_with(|| sub_aggregation_blueprint.clone()) - .collect(doc, agg_data)?; + if let Some(sub_agg) = &mut self.sub_agg { + sub_agg.push(bucket.bucket_id, doc); } } } @@ -358,14 +365,30 @@ impl SegmentAggregationCollector for SegmentHistogramCollector { .add_memory_consumed(mem_delta as u64)?; } + if let Some(sub_agg) = &mut self.sub_agg { + sub_agg.check_flush_local(agg_data)?; + } + Ok(()) } fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - for sub_aggregation in self.sub_aggregations.values_mut() { + if let Some(sub_aggregation) = &mut self.sub_agg { sub_aggregation.flush(agg_data)?; } + Ok(()) + } + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + while self.buckets.len() <= max_bucket as usize { + self.buckets.push(HistogramBuckets { + buckets: FxHashMap::default(), + }); + } Ok(()) } } @@ -373,22 +396,19 @@ impl SegmentAggregationCollector for SegmentHistogramCollector { impl SegmentHistogramCollector { fn get_memory_consumption(&self) -> usize { let self_mem = std::mem::size_of::(); - let sub_aggs_mem = self.sub_aggregations.memory_consumption(); - let buckets_mem = self.buckets.memory_consumption(); - self_mem + sub_aggs_mem + buckets_mem + let buckets_mem = self.buckets.len() * std::mem::size_of::(); + self_mem + buckets_mem } /// Converts the collector result into a intermediate bucket result. 
- pub fn into_intermediate_bucket_result( - self, + fn add_intermediate_bucket_result( + &mut self, agg_data: &AggregationsSegmentCtx, + histogram: HistogramBuckets, ) -> crate::Result { - let mut buckets = Vec::with_capacity(self.buckets.len()); + let mut buckets = Vec::with_capacity(histogram.buckets.len()); - for (bucket_pos, bucket) in self.buckets { - let bucket_res = bucket.into_intermediate_bucket_entry( - self.sub_aggregations.get(&bucket_pos).cloned(), - agg_data, - ); + for bucket in histogram.buckets.into_values() { + let bucket_res = bucket.into_intermediate_bucket_entry(&mut self.sub_agg, agg_data); buckets.push(bucket_res?); } @@ -408,7 +428,7 @@ impl SegmentHistogramCollector { agg_data: &mut AggregationsSegmentCtx, node: &AggRefNode, ) -> crate::Result { - let blueprint = if !node.children.is_empty() { + let sub_agg = if !node.children.is_empty() { Some(build_segment_agg_collectors(agg_data, &node.children)?) } else { None @@ -423,13 +443,13 @@ impl SegmentHistogramCollector { max: f64::MAX, }); req_data.offset = req_data.req.offset.unwrap_or(0.0); - - req_data.sub_aggregation_blueprint = blueprint; + let sub_agg = sub_agg.map(CachedSubAggs::new); Ok(Self { buckets: Default::default(), - sub_aggregations: Default::default(), + sub_agg, accessor_idx: node.idx_in_req_data, + bucket_id_provider: BucketIdProvider::default(), }) } } diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs index c26872e9b..abf746f5d 100644 --- a/src/aggregation/bucket/range.rs +++ b/src/aggregation/bucket/range.rs @@ -8,11 +8,13 @@ use serde::{Deserialize, Serialize}; use crate::aggregation::agg_data::{ build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx, }; +use crate::aggregation::agg_limits::AggregationLimitsGuard; +use crate::aggregation::cached_sub_aggs::CachedSubAggs; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, IntermediateRangeBucketEntry, IntermediateRangeBucketResult, }; -use
crate::aggregation::segment_agg_result::SegmentAggregationCollector; +use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector}; use crate::aggregation::*; use crate::TantivyError; @@ -29,6 +31,8 @@ pub struct RangeAggReqData { pub req: RangeAggregation, /// The name of the aggregation. pub name: String, + /// Whether this is a top-level aggregation. + pub is_top_level: bool, } impl RangeAggReqData { @@ -151,19 +155,48 @@ pub(crate) struct SegmentRangeAndBucketEntry { /// The collector puts values from the fast field into the correct buckets and does a conversion to /// the correct datatype. -#[derive(Clone, Debug)] -pub struct SegmentRangeCollector { +#[derive(Clone)] +pub struct SegmentRangeCollector { /// The buckets containing the aggregation data. - buckets: Vec, + /// One for each ParentBucketId + parent_buckets: Vec>, column_type: ColumnType, pub(crate) accessor_idx: usize, + sub_agg: Option>, + /// Here things get a bit weird. We need to assign unique bucket ids across all + /// parent buckets. So we keep track of the next available bucket id here. + /// This allows a kind of flattening of the bucket ids across all parent buckets. + /// E.g. in nested aggregations: + /// Term Agg -> Range aggregation -> Stats aggregation + /// E.g. the Term Agg creates 3 buckets ["INFO", "ERROR", "WARN"], each of these has a Range + /// aggregation with 4 buckets. The Range aggregation will create buckets with ids: + /// - INFO: 0,1,2,3 + /// - ERROR: 4,5,6,7 + /// - WARN: 8,9,10,11 + /// + /// This allows the Stats aggregation to have unique bucket ids to refer to. 
+ bucket_id_provider: BucketIdProvider, + limits: AggregationLimitsGuard, } +impl Debug for SegmentRangeCollector { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SegmentRangeCollector") + .field("parent_buckets_len", &self.parent_buckets.len()) + .field("column_type", &self.column_type) + .field("accessor_idx", &self.accessor_idx) + .field("has_sub_agg", &self.sub_agg.is_some()) + .finish() + } +} + +/// TODO: Bad naming, there's also SegmentRangeAndBucketEntry #[derive(Clone)] pub(crate) struct SegmentRangeBucketEntry { pub key: Key, pub doc_count: u64, - pub sub_aggregation: Option>, + // pub sub_aggregation: Option>, + pub bucket_id: BucketId, /// The from range of the bucket. Equals `f64::MIN` when `None`. pub from: Option, /// The to range of the bucket. Equals `f64::MAX` when `None`. Open interval, `to` is not @@ -184,48 +217,50 @@ impl Debug for SegmentRangeBucketEntry { impl SegmentRangeBucketEntry { pub(crate) fn into_intermediate_bucket_entry( self, - agg_data: &AggregationsSegmentCtx, ) -> crate::Result { - let mut sub_aggregation_res = IntermediateAggregationResults::default(); - if let Some(sub_aggregation) = self.sub_aggregation { - sub_aggregation - .add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)? 
- } else { - Default::default() - }; + let sub_aggregation = IntermediateAggregationResults::default(); Ok(IntermediateRangeBucketEntry { key: self.key.into(), doc_count: self.doc_count, - sub_aggregation: sub_aggregation_res, + sub_aggregation_res: sub_aggregation, from: self.from, to: self.to, }) } } -impl SegmentAggregationCollector for SegmentRangeCollector { +impl SegmentAggregationCollector for SegmentRangeCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { + self.prepare_max_bucket(parent_bucket_id, agg_data)?; let field_type = self.column_type; let name = agg_data .get_range_req_data(self.accessor_idx) .name .to_string(); - let buckets: FxHashMap = self - .buckets + let buckets = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]); + + let buckets: FxHashMap = buckets .into_iter() - .map(move |range_bucket| { - Ok(( - range_to_string(&range_bucket.range, &field_type)?, - range_bucket - .bucket - .into_intermediate_bucket_entry(agg_data)?, - )) + .map(|range_bucket| { + let bucket_id = range_bucket.bucket.bucket_id; + let mut agg = range_bucket.bucket.into_intermediate_bucket_entry()?; + if let Some(sub_aggregation) = &mut self.sub_agg { + sub_aggregation + .get_sub_agg_collector() + .add_intermediate_aggregation_result( + agg_data, + &mut agg.sub_aggregation_res, + bucket_id, + )?; + } + Ok((range_to_string(&range_bucket.range, &field_type)?, agg)) }) .collect::>()?; @@ -242,73 +277,112 @@ impl SegmentAggregationCollector for SegmentRangeCollector { #[inline] fn collect( &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.collect_block(&[doc], agg_data) - } - - #[inline] - fn collect_block( - &mut self, + parent_bucket_id: BucketId, docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - // Take 
request data to avoid borrow conflicts during sub-aggregation let mut req = agg_data.take_range_req_data(self.accessor_idx); req.column_block_accessor.fetch_block(docs, &req.accessor); + let buckets = &mut self.parent_buckets[parent_bucket_id as usize]; + for (doc, val) in req .column_block_accessor .iter_docid_vals(docs, &req.accessor) { - let bucket_pos = self.get_bucket_pos(val); - let bucket = &mut self.buckets[bucket_pos]; + let bucket_pos = get_bucket_pos(val, buckets); + let bucket = &mut buckets[bucket_pos]; bucket.bucket.doc_count += 1; - if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() { - sub_agg.collect(doc, agg_data)?; + if let Some(sub_agg) = self.sub_agg.as_mut() { + sub_agg.push(bucket.bucket.bucket_id, doc); } } agg_data.put_back_range_req_data(self.accessor_idx, req); + if let Some(sub_agg) = self.sub_agg.as_mut() { + sub_agg.check_flush_local(agg_data)?; + } Ok(()) } fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - for bucket in self.buckets.iter_mut() { - if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() { - sub_agg.flush(agg_data)?; - } + if let Some(sub_agg) = self.sub_agg.as_mut() { + sub_agg.flush(agg_data)?; } Ok(()) } + + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + while self.parent_buckets.len() <= max_bucket as usize { + let new_buckets = self.create_new_buckets(agg_data)?; + self.parent_buckets.push(new_buckets); + } + + Ok(()) + } +} +/// Build a concrete `SegmentRangeCollector` with either a Vec- or HashMap-backed +/// bucket storage, depending on the column type and aggregation level. 
+pub(crate) fn build_segment_range_collector( + agg_data: &mut AggregationsSegmentCtx, + node: &AggRefNode, +) -> crate::Result> { + let accessor_idx = node.idx_in_req_data; + let req_data = agg_data.get_range_req_data(node.idx_in_req_data); + let field_type = req_data.field_type; + + // TODO: A better metric instead of is_top_level would be the number of buckets expected. + // E.g. if the range agg is not top level, but the parent is a bucket agg with fewer than 10 + // buckets, we are still in low cardinality territory. + let is_low_card = req_data.is_top_level && req_data.req.ranges.len() <= 64; + + let sub_agg = if !node.children.is_empty() { + Some(build_segment_agg_collectors(agg_data, &node.children)?) + } else { + None + }; + + if is_low_card { + Ok(Box::new(SegmentRangeCollector { + sub_agg: sub_agg.map(CachedSubAggs::::new), + column_type: field_type, + accessor_idx, + parent_buckets: Vec::new(), + bucket_id_provider: BucketIdProvider::default(), + limits: agg_data.context.limits.clone(), + })) + } else { + Ok(Box::new(SegmentRangeCollector { + sub_agg: sub_agg.map(CachedSubAggs::::new), + column_type: field_type, + accessor_idx, + parent_buckets: Vec::new(), + bucket_id_provider: BucketIdProvider::default(), + limits: agg_data.context.limits.clone(), + })) + } } -impl SegmentRangeCollector { - pub(crate) fn from_req_and_validate( - req_data: &mut AggregationsSegmentCtx, - node: &AggRefNode, - ) -> crate::Result { - let accessor_idx = node.idx_in_req_data; - let (field_type, ranges) = { - let req_view = req_data.get_range_req_data(node.idx_in_req_data); - (req_view.field_type, req_view.req.ranges.clone()) - }; - +impl SegmentRangeCollector { + pub(crate) fn create_new_buckets( + &mut self, + agg_data: &AggregationsSegmentCtx, + ) -> crate::Result> { + let field_type = self.column_type; + let req_data = agg_data.get_range_req_data(self.accessor_idx); // The range input on the request is f64.
// We need to convert to u64 ranges, because we read the values as u64. // The mapping from the conversion is monotonic so ordering is preserved. - let sub_agg_prototype = if !node.children.is_empty() { - Some(build_segment_agg_collectors(req_data, &node.children)?) - } else { - None - }; - - let buckets: Vec<_> = extend_validate_ranges(&ranges, &field_type)? + let buckets: Vec<_> = extend_validate_ranges(&req_data.req.ranges, &field_type)? .iter() .map(|range| { + let bucket_id = self.bucket_id_provider.next_bucket_id(); let key = range .key .clone() @@ -324,13 +398,13 @@ impl SegmentRangeCollector { } else { Some(f64_from_fastfield_u64(range.range.start, &field_type)) }; - let sub_aggregation = sub_agg_prototype.clone(); + // let sub_aggregation = sub_agg_prototype.clone(); Ok(SegmentRangeAndBucketEntry { range: range.range.clone(), bucket: SegmentRangeBucketEntry { doc_count: 0, - sub_aggregation, + bucket_id, key, from, to, @@ -339,27 +413,20 @@ impl SegmentRangeCollector { }) .collect::>()?; - req_data.context.limits.add_memory_consumed( + self.limits.add_memory_consumed( buckets.len() as u64 * std::mem::size_of::() as u64, )?; - - Ok(SegmentRangeCollector { - buckets, - column_type: field_type, - accessor_idx, - }) - } - - #[inline] - fn get_bucket_pos(&self, val: u64) -> usize { - let pos = self - .buckets - .binary_search_by_key(&val, |probe| probe.range.start) - .unwrap_or_else(|pos| pos - 1); - debug_assert!(self.buckets[pos].range.contains(&val)); - pos + Ok(buckets) } } +#[inline] +fn get_bucket_pos(val: u64, buckets: &[SegmentRangeAndBucketEntry]) -> usize { + let pos = buckets + .binary_search_by_key(&val, |probe| probe.range.start) + .unwrap_or_else(|pos| pos - 1); + debug_assert!(buckets[pos].range.contains(&val)); + pos +} /// Converts the user provided f64 range value to fast field value space. 
/// @@ -517,19 +584,22 @@ mod tests { range: range.range.clone(), bucket: SegmentRangeBucketEntry { doc_count: 0, - sub_aggregation: None, key, from, to, + bucket_id: 0, }, } }) .collect(); SegmentRangeCollector { - buckets, + parent_buckets: vec![buckets], column_type: field_type, accessor_idx: 0, + sub_agg: None, + bucket_id_provider: Default::default(), + limits: AggregationLimitsGuard::default(), } } @@ -776,7 +846,7 @@ mod tests { let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()]; let collector = get_collector_from_ranges(buckets, ColumnType::F64); - let buckets = collector.buckets; + let buckets = collector.parent_buckets[0].clone(); assert_eq!(buckets[0].range.start, u64::MIN); assert_eq!(buckets[0].range.end, 10f64.to_u64()); assert_eq!(buckets[1].range.start, 10f64.to_u64()); @@ -799,7 +869,7 @@ mod tests { ]; let collector = get_collector_from_ranges(buckets, ColumnType::F64); - let buckets = collector.buckets; + let buckets = collector.parent_buckets[0].clone(); assert_eq!(buckets[0].range.start, u64::MIN); assert_eq!(buckets[0].range.end, 10f64.to_u64()); assert_eq!(buckets[1].range.start, 10f64.to_u64()); @@ -814,7 +884,7 @@ mod tests { let buckets = vec![(-10f64..-1f64).into()]; let collector = get_collector_from_ranges(buckets, ColumnType::F64); - let buckets = collector.buckets; + let buckets = collector.parent_buckets[0].clone(); assert_eq!(&buckets[0].bucket.key.to_string(), "*--10"); assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "-1-*"); } @@ -823,7 +893,7 @@ mod tests { let buckets = vec![(0f64..10f64).into()]; let collector = get_collector_from_ranges(buckets, ColumnType::F64); - let buckets = collector.buckets; + let buckets = collector.parent_buckets[0].clone(); assert_eq!(&buckets[0].bucket.key.to_string(), "*-0"); assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "10-*"); } @@ -832,7 +902,7 @@ mod tests { fn range_binary_search_test_u64() { let check_ranges = |ranges: Vec| { let collector = 
get_collector_from_ranges(ranges, ColumnType::U64); - let search = |val: u64| collector.get_bucket_pos(val); + let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]); assert_eq!(search(u64::MIN), 0); assert_eq!(search(9), 0); @@ -878,7 +948,7 @@ mod tests { let ranges = vec![(10.0..100.0).into()]; let collector = get_collector_from_ranges(ranges, ColumnType::F64); - let search = |val: u64| collector.get_bucket_pos(val); + let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]); assert_eq!(search(u64::MIN), 0); assert_eq!(search(9f64.to_u64()), 0); @@ -890,63 +960,3 @@ mod tests { // the max value } } - -#[cfg(all(test, feature = "unstable"))] -mod bench { - - use itertools::Itertools; - use rand::seq::SliceRandom; - use rand::thread_rng; - - use super::*; - use crate::aggregation::bucket::range::tests::get_collector_from_ranges; - - const TOTAL_DOCS: u64 = 1_000_000u64; - const NUM_DOCS: u64 = 50_000u64; - - fn get_collector_with_buckets(num_buckets: u64, num_docs: u64) -> SegmentRangeCollector { - let bucket_size = num_docs / num_buckets; - let mut buckets: Vec = vec![]; - for i in 0..num_buckets { - let bucket_start = (i * bucket_size) as f64; - buckets.push((bucket_start..bucket_start + bucket_size as f64).into()) - } - - get_collector_from_ranges(buckets, ColumnType::U64) - } - - fn get_rand_docs(total_docs: u64, num_docs_returned: u64) -> Vec { - let mut rng = thread_rng(); - - let all_docs = (0..total_docs - 1).collect_vec(); - let mut vals = all_docs - .as_slice() - .choose_multiple(&mut rng, num_docs_returned as usize) - .cloned() - .collect_vec(); - vals.sort(); - vals - } - - fn bench_range_binary_search(b: &mut test::Bencher, num_buckets: u64) { - let collector = get_collector_with_buckets(num_buckets, TOTAL_DOCS); - let vals = get_rand_docs(TOTAL_DOCS, NUM_DOCS); - b.iter(|| { - let mut bucket_pos = 0; - for val in &vals { - bucket_pos = collector.get_bucket_pos(*val); - } - bucket_pos - }) - } - - #[bench] - fn 
bench_range_100_buckets(b: &mut test::Bencher) { - bench_range_binary_search(b, 100) - } - - #[bench] - fn bench_range_10_buckets(b: &mut test::Bencher) { - bench_range_binary_search(b, 10) - } -} diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index d87cd0078..994cae4b8 100644 --- a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg.rs @@ -17,13 +17,13 @@ use crate::aggregation::agg_data::{ }; use crate::aggregation::agg_limits::MemoryConsumption; use crate::aggregation::agg_req::Aggregations; -use crate::aggregation::buf_collector::BufAggregationCollector; +use crate::aggregation::cached_sub_aggs::CachedSubAggs; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult, }; -use crate::aggregation::segment_agg_result::SegmentAggregationCollector; -use crate::aggregation::{format_date, Key}; +use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector}; +use crate::aggregation::{format_date, BucketId, Key}; use crate::error::DataCorruption; use crate::TantivyError; @@ -40,8 +40,6 @@ pub struct TermsAggReqData { pub missing_value_for_accessor: Option, /// The column block accessor to access the fast field values. pub column_block_accessor: ColumnBlockAccessor, - /// Note: sub_aggregation_blueprint is filled later when building collectors - pub sub_aggregation_blueprint: Option>, /// Used to build the correct nested result when we have an empty result. pub sug_aggregations: Aggregations, /// The name of the aggregation. @@ -257,9 +255,9 @@ pub struct TermsAggregation { /// Internally, `missing` requires some specialized handling in some scenarios. /// /// Simple Case: - /// In the simplest case, we can just put the missing value in the termmap use that. 
In case of - /// text we put a special u64::MAX and replace it at the end with the actual missing value, - /// when loading the text. + /// In the simplest case, we can just put the missing value in the termmap and use that. In + /// case of text we put a special u64::MAX and replace it at the end with the actual + /// missing value, when loading the text. /// Special Case 1: /// If we have multiple columns on one field, we need to have a union on the indices on both /// columns, to find docids without a value. That requires a special missing aggregation. @@ -334,86 +332,6 @@ impl TermsAggregationInternal { } } -impl<'a> From<&'a dyn SegmentAggregationCollector> for BufAggregationCollector { - #[inline(always)] - fn from(sub_agg_blueprint_opt: &'a dyn SegmentAggregationCollector) -> Self { - let sub_agg = sub_agg_blueprint_opt.clone_box(); - BufAggregationCollector::new(sub_agg) - } -} - -#[derive(Debug, Clone)] -struct BoxedAggregation(Box); - -impl<'a> From<&'a dyn SegmentAggregationCollector> for BoxedAggregation { - #[inline(always)] - fn from(sub_agg_blueprint: &'a dyn SegmentAggregationCollector) -> Self { - BoxedAggregation(sub_agg_blueprint.clone_box()) - } -} - -impl SegmentAggregationCollector for BoxedAggregation { - #[inline(always)] - fn add_intermediate_aggregation_result( - self: Box, - agg_data: &AggregationsSegmentCtx, - results: &mut IntermediateAggregationResults, - ) -> crate::Result<()> { - self.0 - .add_intermediate_aggregation_result(agg_data, results) - } - - #[inline(always)] - fn collect( - &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.0.collect(doc, agg_data) - } - - #[inline(always)] - fn collect_block( - &mut self, - docs: &[crate::DocId], - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.0.collect_block(docs, agg_data) - } -} - -#[derive(Debug, Clone, Copy)] -struct NoSubAgg; - -impl SegmentAggregationCollector for NoSubAgg { - #[inline(always)] - fn 
add_intermediate_aggregation_result( - self: Box, - _agg_data: &AggregationsSegmentCtx, - _results: &mut IntermediateAggregationResults, - ) -> crate::Result<()> { - Ok(()) - } - - #[inline(always)] - fn collect( - &mut self, - _doc: crate::DocId, - _agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - Ok(()) - } - - #[inline(always)] - fn collect_block( - &mut self, - _docs: &[crate::DocId], - _agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - Ok(()) - } -} - /// Build a concrete `SegmentTermCollector` with either a Vec- or HashMap-backed /// bucket storage, depending on the column type and aggregation level. pub(crate) fn build_segment_term_collector( @@ -450,16 +368,6 @@ pub(crate) fn build_segment_term_collector( // Build sub-aggregation blueprint if there are children. let has_sub_aggregations = !node.children.is_empty(); - let blueprint = if has_sub_aggregations { - let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?; - Some(sub_aggregation) - } else { - None - }; - { - let terms_req_data_mut = req_data.get_term_req_data_mut(accessor_idx); - terms_req_data_mut.sub_aggregation_blueprint = blueprint; - } // Decide whether to use a Vec-backed or HashMap-backed bucket storage. let terms_req_data = req_data.get_term_req_data(accessor_idx); @@ -478,75 +386,56 @@ pub(crate) fn build_segment_term_collector( let max_term: usize = col_max_value.max(terms_req_data.missing_value_for_accessor.unwrap_or(0u64)) as usize; - // - use a Vec instead of a hashmap for our aggregation. 
- // - buffer aggregation of our child aggregations (in any) - #[allow(clippy::collapsible_else_if)] - if can_use_vec && max_term < MAX_NUM_TERMS_FOR_VEC { - if has_sub_aggregations { - let sub_agg_blueprint = &req_data - .get_term_req_data_mut(accessor_idx) - .sub_aggregation_blueprint - .as_ref() - .ok_or_else(|| { - // Handle the error case here - // For example, return an error message or a default value - TantivyError::InternalError("Sub-aggregation blueprint not found".to_string()) - })?; - let term_buckets = VecTermBuckets::new(max_term + 1, || { - let collector_clone = sub_agg_blueprint.clone_box(); - BufAggregationCollector::new(collector_clone) - }); - let collector = SegmentTermCollector { - term_buckets, - accessor_idx, - }; - Ok(Box::new(collector)) - } else { - let term_buckets = VecTermBuckets::new(max_term + 1, || NoSubAgg); - let collector = SegmentTermCollector { - term_buckets, - accessor_idx, - }; - Ok(Box::new(collector)) - } + let sub_agg_collector = if has_sub_aggregations { + Some(build_segment_agg_collectors(req_data, &node.children)?) } else { - if has_sub_aggregations { - let term_buckets: HashMapTermBuckets = HashMapTermBuckets::default(); - let collector: SegmentTermCollector> = - SegmentTermCollector { - term_buckets, - accessor_idx, - }; - Ok(Box::new(collector)) - } else { - let term_buckets: HashMapTermBuckets = HashMapTermBuckets::default(); - let collector: SegmentTermCollector> = - SegmentTermCollector { - term_buckets, - accessor_idx, - }; - Ok(Box::new(collector)) - } + None + }; + + let mut bucket_id_provider = BucketIdProvider::default(); + // - use a Vec instead of a hashmap for our aggregation. 
+ if can_use_vec && max_term < MAX_NUM_TERMS_FOR_VEC { + let term_buckets = VecTermBuckets::new(max_term + 1, &mut bucket_id_provider); + let sub_agg = sub_agg_collector.map(CachedSubAggs::::new); + let collector: SegmentTermCollector<_, true> = SegmentTermCollector { + buckets: vec![term_buckets], + accessor_idx, + sub_agg, + bucket_id_provider, + }; + Ok(Box::new(collector)) + } else { + let term_buckets: HashMapTermBuckets = HashMapTermBuckets::default(); + // Build sub-aggregation blueprint (flat pairs) + let sub_agg = sub_agg_collector.map(CachedSubAggs::::new); + let collector: SegmentTermCollector = SegmentTermCollector { + buckets: vec![term_buckets], + accessor_idx, + sub_agg, + bucket_id_provider, + }; + Ok(Box::new(collector)) } } #[derive(Debug, Clone)] -struct Bucket { +struct Bucket { pub count: u32, - pub sub_agg: SubAgg, + pub bucket_id: BucketId, } -impl Bucket { +impl Bucket { #[inline(always)] - fn new(sub_agg: SubAgg) -> Self { - Self { count: 0, sub_agg } + fn new(bucket_id: BucketId) -> Self { + Self { + count: 0, + bucket_id, + } } } /// Abstraction over the storage used for term buckets (counts only). -trait TermAggregationMap: Clone + Debug + 'static { - type SubAggregation: SegmentAggregationCollector + Debug + Clone + 'static; - +trait TermAggregationMap: Clone + Debug + Default + 'static { /// Estimate the memory consumption of this struct in bytes. fn get_memory_consumption(&self) -> usize; @@ -554,23 +443,20 @@ trait TermAggregationMap: Clone + Debug + 'static { fn term_entry( &mut self, term_id: u64, - blue_print: &dyn SegmentAggregationCollector, - ) -> &mut Bucket; - - /// If the tree of aggregations contains buffered aggregations, flush them. - fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()>; + bucket_id_provider: &mut BucketIdProvider, + ) -> &mut Bucket; /// Returns the term aggregation as a vector of (term_id, bucket) pairs, /// in any order. 
- fn into_vec(self) -> Vec<(u64, Bucket)>; + fn into_vec(self) -> Vec<(u64, Bucket)>; } #[derive(Clone, Debug)] -struct HashMapTermBuckets { - bucket_map: FxHashMap>, +struct HashMapTermBuckets { + bucket_map: FxHashMap, } -impl Default for HashMapTermBuckets { +impl Default for HashMapTermBuckets { #[inline(always)] fn default() -> Self { Self { @@ -579,16 +465,7 @@ impl Default for HashMapTermBuckets { } } -impl< - SubAgg: Debug - + Clone - + SegmentAggregationCollector - + for<'a> From<&'a dyn SegmentAggregationCollector> - + 'static, - > TermAggregationMap for HashMapTermBuckets -{ - type SubAggregation = SubAgg; - +impl TermAggregationMap for HashMapTermBuckets { #[inline] fn get_memory_consumption(&self) -> usize { self.bucket_map.memory_consumption() @@ -598,55 +475,42 @@ impl< fn term_entry( &mut self, term_id: u64, - sub_agg_blueprint: &dyn SegmentAggregationCollector, - ) -> &mut Bucket { + bucket_id_provider: &mut BucketIdProvider, + ) -> &mut Bucket { self.bucket_map .entry(term_id) - .or_insert_with(|| Bucket::new(SubAgg::from(sub_agg_blueprint))) + .or_insert_with(|| Bucket::new(bucket_id_provider.next_bucket_id())) } - #[inline(always)] - fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - for bucket in self.bucket_map.values_mut() { - bucket.sub_agg.flush(agg_data)?; - } - Ok(()) - } - - fn into_vec(self) -> Vec<(u64, Bucket)> { + fn into_vec(self) -> Vec<(u64, Bucket)> { self.bucket_map.into_iter().collect() } } /// An optimized term map implementation for a compact set of term ordinals. 
-#[derive(Clone, Debug)] -struct VecTermBuckets { - buckets: Vec>, +#[derive(Clone, Debug, Default)] +struct VecTermBuckets { + buckets: Vec, } -impl VecTermBuckets { - fn new(num_terms: usize, item_factory_fn: impl Fn() -> SubAgg) -> Self { +impl VecTermBuckets { + fn new(num_terms: usize, bucket_id_provider: &mut BucketIdProvider) -> Self { VecTermBuckets { - buckets: std::iter::repeat_with(item_factory_fn) - .map(Bucket::new) + buckets: std::iter::repeat_with(|| Bucket::new(bucket_id_provider.next_bucket_id())) .take(num_terms) .collect(), } } } -impl TermAggregationMap - for VecTermBuckets -{ - type SubAggregation = SubAgg; - +impl TermAggregationMap for VecTermBuckets { /// Estimate the memory consumption of this struct in bytes. fn get_memory_consumption(&self) -> usize { // We do not include `std::mem::size_of::()` // It is already measure by the parent aggregation. // // The root aggregation mem size is not measure but we do not care. - self.buckets.capacity() * std::mem::size_of::>() + self.buckets.capacity() * std::mem::size_of::() } /// Add an occurrence of the given term id. 
@@ -654,8 +518,8 @@ impl TermAggregat fn term_entry( &mut self, term_id: u64, - _sub_agg_blueprint: &dyn SegmentAggregationCollector, - ) -> &mut Bucket { + _bucket_id_provider: &mut BucketIdProvider, + ) -> &mut Bucket { let term_id_usize = term_id as usize; debug_assert!( term_id_usize < self.buckets.len(), @@ -666,17 +530,7 @@ impl TermAggregat unsafe { self.buckets.get_unchecked_mut(term_id_usize) } } - #[inline(always)] - fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - for bucket in &mut self.buckets { - if bucket.count > 0 { - bucket.sub_agg.flush(agg_data)?; - } - } - Ok(()) - } - - fn into_vec(self) -> Vec<(u64, Bucket)> { + fn into_vec(self) -> Vec<(u64, Bucket)> { self.buckets .into_iter() .enumerate() @@ -686,20 +540,15 @@ impl TermAggregat } } -impl<'a> From<&'a dyn SegmentAggregationCollector> for NoSubAgg { - #[inline(always)] - fn from(_: &'a dyn SegmentAggregationCollector) -> Self { - Self - } -} - /// The collector puts values from the fast field into the correct buckets and does a conversion to /// the correct datatype. #[derive(Clone, Debug)] -struct SegmentTermCollector { +struct SegmentTermCollector { /// The buckets containing the aggregation data. 
- term_buckets: TermMap, + buckets: Vec, + sub_agg: Option>, accessor_idx: usize, + bucket_id_provider: BucketIdProvider, } pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) { @@ -707,18 +556,22 @@ pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) { (agg_name, agg_property) } -impl SegmentAggregationCollector for SegmentTermCollector -where - TermMap: TermAggregationMap, - TermMap::SubAggregation: for<'a> From<&'a dyn SegmentAggregationCollector>, +impl SegmentAggregationCollector + for SegmentTermCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + bucket: BucketId, ) -> crate::Result<()> { - let name = agg_data.get_term_req_data(self.accessor_idx).name.clone(); - let bucket = self.into_intermediate_bucket_result(agg_data)?; + self.prepare_max_bucket(bucket, agg_data)?; + let bucket = std::mem::take(&mut self.buckets[bucket as usize]); + let term_req = agg_data.get_term_req_data(self.accessor_idx); + let name = term_req.name.clone(); + + let bucket = + Self::into_intermediate_bucket_result(term_req, &mut self.sub_agg, bucket, agg_data)?; results.push(name, IntermediateAggregationResult::Bucket(bucket))?; Ok(()) } @@ -726,22 +579,14 @@ where #[inline] fn collect( &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.collect_block(&[doc], agg_data) - } - - #[inline] - fn collect_block( - &mut self, + parent_bucket_id: BucketId, docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let mut req_data = agg_data.take_term_req_data(self.accessor_idx); - let mem_pre = self.get_memory_consumption(); + let mut req_data = agg_data.take_term_req_data(self.accessor_idx); + if let Some(missing) = req_data.missing_value_for_accessor { req_data.column_block_accessor.fetch_block_with_missing( docs, @@ -754,37 +599,36 @@ where .fetch_block(docs, 
&req_data.accessor); } - if std::any::TypeId::of::() == std::any::TypeId::of::() { - for term_id in req_data.column_block_accessor.iter_vals() { - if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() { - if !allowed_bs.contains(term_id as u32) { - continue; - } - } - let bucket = self.term_buckets.term_entry(term_id, &NoSubAgg); - bucket.count += 1; + // Build iterators first to avoid overlapping borrows with self's fields. + if let Some(sub_agg) = &mut self.sub_agg { + let term_buckets = &mut self.buckets[parent_bucket_id as usize]; + let it = req_data + .column_block_accessor + .iter_docid_vals(docs, &req_data.accessor); + if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() { + let it = it.filter(move |&(_doc, term_id)| allowed_bs.contains(term_id as u32)); + Self::collect_terms_with_docs( + it, + term_buckets, + &mut self.bucket_id_provider, + sub_agg, + ); + } else { + Self::collect_terms_with_docs( + it, + term_buckets, + &mut self.bucket_id_provider, + sub_agg, + ); } } else { - let Some(sub_aggregation_blueprint) = req_data.sub_aggregation_blueprint.as_deref() - else { - return Err(TantivyError::InternalError( - "Could not find sub-aggregation blueprint".to_string(), - )); - }; - for (doc, term_id) in req_data - .column_block_accessor - .iter_docid_vals(docs, &req_data.accessor) - { - if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() { - if !allowed_bs.contains(term_id as u32) { - continue; - } - } - let bucket = self - .term_buckets - .term_entry(term_id, sub_aggregation_blueprint); - bucket.count += 1; - bucket.sub_agg.collect(doc, agg_data)?; + let term_buckets = &mut self.buckets[parent_bucket_id as usize]; + let it = req_data.column_block_accessor.iter_vals(); + if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() { + let it = it.filter(move |&term_id| allowed_bs.contains(term_id as u32)); + Self::collect_terms(it, term_buckets, &mut self.bucket_id_provider); + } else { + Self::collect_terms(it, term_buckets, &mut 
self.bucket_id_provider); } } @@ -796,13 +640,30 @@ where .add_memory_consumed(mem_delta as u64)?; } agg_data.put_back_term_req_data(self.accessor_idx, req_data); + if let Some(sub_agg) = &mut self.sub_agg { + sub_agg.check_flush_local(agg_data)?; + } Ok(()) } #[inline(always)] fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - self.term_buckets.flush(agg_data)?; + if let Some(sub_agg) = &mut self.sub_agg { + sub_agg.flush(agg_data)?; + } + Ok(()) + } + + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + while self.buckets.len() <= max_bucket as usize { + let term_buckets: TermMap = TermMap::default(); + self.buckets.push(term_buckets); + } Ok(()) } } @@ -831,20 +692,24 @@ fn extract_missing_value( Some((key, bucket)) } -impl SegmentTermCollector +impl SegmentTermCollector where TermMap: TermAggregationMap { fn get_memory_consumption(&self) -> usize { - self.term_buckets.get_memory_consumption() + self.buckets + .iter() + .map(|b| b.get_memory_consumption()) + .sum() } #[inline] pub(crate) fn into_intermediate_bucket_result( - self, + term_req: &TermsAggReqData, + sub_agg: &mut Option>, + term_buckets: TermMap, agg_data: &AggregationsSegmentCtx, ) -> crate::Result { - let term_req = agg_data.get_term_req_data(self.accessor_idx); - let mut entries: Vec<(u64, Bucket)> = self.term_buckets.into_vec(); + let mut entries: Vec<(u64, Bucket)> = term_buckets.into_vec(); let order_by_sub_aggregation = matches!(term_req.req.order.target, OrderTarget::SubAggregation(_)); @@ -884,23 +749,28 @@ where TermMap: TermAggregationMap dict.reserve(entries.len()); let into_intermediate_bucket_entry = - |bucket: Bucket| -> crate::Result { - let intermediate_entry = if term_req.sub_aggregation_blueprint.as_ref().is_some() { + |bucket: Bucket, + sub_agg: &mut Option>| + -> crate::Result { + if let Some(sub_agg) = sub_agg { let mut sub_aggregation_res = 
IntermediateAggregationResults::default(); - // TODO remove box new - Box::new(bucket.sub_agg) - .add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?; - IntermediateTermBucketEntry { + sub_agg + .get_sub_agg_collector() + .add_intermediate_aggregation_result( + agg_data, + &mut sub_aggregation_res, + bucket.bucket_id, + )?; + Ok(IntermediateTermBucketEntry { doc_count: bucket.count, sub_aggregation: sub_aggregation_res, - } + }) } else { - IntermediateTermBucketEntry { + Ok(IntermediateTermBucketEntry { doc_count: bucket.count, sub_aggregation: Default::default(), - } - }; - Ok(intermediate_entry) + }) + } }; if term_req.column_type == ColumnType::Str { @@ -913,21 +783,20 @@ where TermMap: TermAggregationMap if let Some((intermediate_key, bucket)) = extract_missing_value(&mut entries, term_req) { - let intermediate_entry = into_intermediate_bucket_entry(bucket)?; + let intermediate_entry = into_intermediate_bucket_entry(bucket, sub_agg)?; dict.insert(intermediate_key, intermediate_entry); } // Sort by term ord entries.sort_unstable_by_key(|bucket| bucket.0); - let (term_ids, buckets): (Vec, Vec>) = - entries.into_iter().unzip(); + let (term_ids, buckets): (Vec, Vec) = entries.into_iter().unzip(); let mut buckets_it = buckets.into_iter(); term_dict.sorted_ords_to_term_cb(term_ids.into_iter(), |term| { let bucket = buckets_it.next().unwrap(); let intermediate_entry = - into_intermediate_bucket_entry(bucket).map_err(io::Error::other)?; + into_intermediate_bucket_entry(bucket, sub_agg).map_err(io::Error::other)?; dict.insert( IntermediateKey::Str( String::from_utf8(term.to_vec()).expect("could not convert to String"), @@ -969,14 +838,14 @@ where TermMap: TermAggregationMap } } else if term_req.column_type == ColumnType::DateTime { for (val, doc_count) in entries { - let intermediate_entry = into_intermediate_bucket_entry(doc_count)?; + let intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?; let val = i64::from_u64(val); let 
date = format_date(val)?; dict.insert(IntermediateKey::Str(date), intermediate_entry); } } else if term_req.column_type == ColumnType::Bool { for (val, doc_count) in entries { - let intermediate_entry = into_intermediate_bucket_entry(doc_count)?; + let intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?; let val = bool::from_u64(val); dict.insert(IntermediateKey::Bool(val), intermediate_entry); } @@ -996,14 +865,14 @@ where TermMap: TermAggregationMap })?; for (val, doc_count) in entries { - let intermediate_entry = into_intermediate_bucket_entry(doc_count)?; + let intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?; let val: u128 = compact_space_accessor.compact_to_u128(val as u32); let val = Ipv6Addr::from_u128(val); dict.insert(IntermediateKey::IpAddr(val), intermediate_entry); } } else { for (val, doc_count) in entries { - let intermediate_entry = into_intermediate_bucket_entry(doc_count)?; + let intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?; if term_req.column_type == ColumnType::U64 { dict.insert(IntermediateKey::U64(val), intermediate_entry); } else if term_req.column_type == ColumnType::I64 { @@ -1037,6 +906,38 @@ where TermMap: TermAggregationMap } } +impl SegmentTermCollector { + #[inline] + fn collect_terms_with_docs( + it: I, + term_buckets: &mut TermMap, + bucket_id_provider: &mut BucketIdProvider, + sub_agg: &mut CachedSubAggs, + ) where + I: Iterator, + { + for (doc, term_id) in it { + let bucket = term_buckets.term_entry(term_id, bucket_id_provider); + bucket.count += 1; + sub_agg.push(bucket.bucket_id, doc); + } + } + + #[inline] + fn collect_terms( + it: I, + term_buckets: &mut TermMap, + bucket_id_provider: &mut BucketIdProvider, + ) where + I: Iterator, + { + for term_id in it { + let bucket = term_buckets.term_entry(term_id, bucket_id_provider); + bucket.count += 1; + } + } +} + pub(crate) trait GetDocCount { fn doc_count(&self) -> u64; } @@ -1047,7 +948,7 @@ impl GetDocCount 
for (String, IntermediateTermBucketEntry) { } } -impl GetDocCount for (u64, Bucket) { +impl GetDocCount for (u64, Bucket) { fn doc_count(&self) -> u64 { self.1.count as u64 } diff --git a/src/aggregation/bucket/term_missing_agg.rs b/src/aggregation/bucket/term_missing_agg.rs index 66f39927a..c8ab51832 100644 --- a/src/aggregation/bucket/term_missing_agg.rs +++ b/src/aggregation/bucket/term_missing_agg.rs @@ -5,11 +5,13 @@ use crate::aggregation::agg_data::{ build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx, }; use crate::aggregation::bucket::term_agg::TermsAggregation; +use crate::aggregation::cached_sub_aggs::CachedSubAggs; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult, }; -use crate::aggregation::segment_agg_result::SegmentAggregationCollector; +use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector}; +use crate::aggregation::BucketId; /// Special aggregation to handle missing values for term aggregations. /// This missing aggregation will check multiple columns for existence. @@ -35,41 +37,55 @@ impl MissingTermAggReqData { } } +#[derive(Default, Debug, Clone)] +struct MissingCount { + missing_count: u32, + bucket_id: BucketId, +} + /// The specialized missing term aggregation. 
#[derive(Default, Debug, Clone)] pub struct TermMissingAgg { - missing_count: u32, accessor_idx: usize, - sub_agg: Option>, + sub_agg: Option, + /// Idx = bucket id, Value = missing count for that bucket + missing_count_per_bucket: Vec, + bucket_id_provider: BucketIdProvider, } impl TermMissingAgg { pub(crate) fn new( - req_data: &mut AggregationsSegmentCtx, + agg_data: &mut AggregationsSegmentCtx, node: &AggRefNode, ) -> crate::Result { let has_sub_aggregations = !node.children.is_empty(); let accessor_idx = node.idx_in_req_data; let sub_agg = if has_sub_aggregations { - let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?; + let sub_aggregation = build_segment_agg_collectors(agg_data, &node.children)?; Some(sub_aggregation) } else { None }; + let sub_agg = sub_agg.map(CachedSubAggs::new); + let bucket_id_provider = BucketIdProvider::default(); + Ok(Self { accessor_idx, sub_agg, - ..Default::default() + missing_count_per_bucket: Vec::new(), + bucket_id_provider, }) } } impl SegmentAggregationCollector for TermMissingAgg { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { + self.prepare_max_bucket(parent_bucket_id, agg_data)?; let req_data = agg_data.get_missing_term_req_data(self.accessor_idx); let term_agg = &req_data.req; let missing = term_agg @@ -80,13 +96,16 @@ impl SegmentAggregationCollector for TermMissingAgg { let mut entries: FxHashMap = Default::default(); + let missing_count = &self.missing_count_per_bucket[parent_bucket_id as usize]; let mut missing_entry = IntermediateTermBucketEntry { - doc_count: self.missing_count, + doc_count: missing_count.missing_count, sub_aggregation: Default::default(), }; - if let Some(sub_agg) = self.sub_agg { + if let Some(sub_agg) = &mut self.sub_agg { let mut res = IntermediateAggregationResults::default(); - 
sub_agg.add_intermediate_aggregation_result(agg_data, &mut res)?; + sub_agg + .get_sub_agg_collector() + .add_intermediate_aggregation_result(agg_data, &mut res, missing_count.bucket_id)?; missing_entry.sub_aggregation = res; } entries.insert(missing.into(), missing_entry); @@ -109,30 +128,52 @@ impl SegmentAggregationCollector for TermMissingAgg { fn collect( &mut self, - doc: crate::DocId, + parent_bucket_id: BucketId, + docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { + let bucket = &mut self.missing_count_per_bucket[parent_bucket_id as usize]; let req_data = agg_data.get_missing_term_req_data(self.accessor_idx); - let has_value = req_data - .accessors - .iter() - .any(|(acc, _)| acc.index.has_value(doc)); - if !has_value { - self.missing_count += 1; - if let Some(sub_agg) = self.sub_agg.as_mut() { - sub_agg.collect(doc, agg_data)?; + + for doc in docs { + let doc = *doc; + let has_value = req_data + .accessors + .iter() + .any(|(acc, _)| acc.index.has_value(doc)); + if !has_value { + bucket.missing_count += 1; + + if let Some(sub_agg) = self.sub_agg.as_mut() { + sub_agg.push(bucket.bucket_id, doc); + } } } + + if let Some(sub_agg) = self.sub_agg.as_mut() { + sub_agg.check_flush_local(agg_data)?; + } Ok(()) } - fn collect_block( + fn prepare_max_bucket( &mut self, - docs: &[crate::DocId], - agg_data: &mut AggregationsSegmentCtx, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, ) -> crate::Result<()> { - for doc in docs { - self.collect(*doc, agg_data)?; + while self.missing_count_per_bucket.len() <= max_bucket as usize { + let bucket_id = self.bucket_id_provider.next_bucket_id(); + self.missing_count_per_bucket.push(MissingCount { + missing_count: 0, + bucket_id, + }); + } + Ok(()) + } + + fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { + if let Some(sub_agg) = self.sub_agg.as_mut() { + sub_agg.flush(agg_data)?; } Ok(()) } diff --git a/src/aggregation/buf_collector.rs 
b/src/aggregation/buf_collector.rs deleted file mode 100644 index 17bc1ed35..000000000 --- a/src/aggregation/buf_collector.rs +++ /dev/null @@ -1,87 +0,0 @@ -use super::intermediate_agg_result::IntermediateAggregationResults; -use super::segment_agg_result::SegmentAggregationCollector; -use crate::aggregation::agg_data::AggregationsSegmentCtx; -use crate::DocId; - -#[cfg(test)] -pub(crate) const DOC_BLOCK_SIZE: usize = 64; - -#[cfg(not(test))] -pub(crate) const DOC_BLOCK_SIZE: usize = 256; - -pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE]; - -/// BufAggregationCollector buffers documents before calling collect_block(). -#[derive(Clone)] -pub(crate) struct BufAggregationCollector { - pub(crate) collector: Box, - staged_docs: DocBlock, - num_staged_docs: usize, -} - -impl std::fmt::Debug for BufAggregationCollector { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("SegmentAggregationResultsCollector") - .field("staged_docs", &&self.staged_docs[..self.num_staged_docs]) - .field("num_staged_docs", &self.num_staged_docs) - .finish() - } -} - -impl BufAggregationCollector { - pub fn new(collector: Box) -> Self { - Self { - collector, - num_staged_docs: 0, - staged_docs: [0; DOC_BLOCK_SIZE], - } - } -} - -impl SegmentAggregationCollector for BufAggregationCollector { - #[inline] - fn add_intermediate_aggregation_result( - self: Box, - agg_data: &AggregationsSegmentCtx, - results: &mut IntermediateAggregationResults, - ) -> crate::Result<()> { - Box::new(self.collector).add_intermediate_aggregation_result(agg_data, results) - } - - #[inline] - fn collect( - &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.staged_docs[self.num_staged_docs] = doc; - self.num_staged_docs += 1; - if self.num_staged_docs == self.staged_docs.len() { - self.collector - .collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?; - self.num_staged_docs = 0; - } - Ok(()) - } - - #[inline] - fn 
collect_block( - &mut self, - docs: &[crate::DocId], - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.collector.collect_block(docs, agg_data)?; - Ok(()) - } - - #[inline] - fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - self.collector - .collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?; - self.num_staged_docs = 0; - - self.collector.flush(agg_data)?; - - Ok(()) - } -} diff --git a/src/aggregation/cached_sub_aggs.rs b/src/aggregation/cached_sub_aggs.rs new file mode 100644 index 000000000..bf8106ea6 --- /dev/null +++ b/src/aggregation/cached_sub_aggs.rs @@ -0,0 +1,212 @@ +use super::segment_agg_result::SegmentAggregationCollector; +use crate::aggregation::agg_data::AggregationsSegmentCtx; +use crate::aggregation::BucketId; +use crate::DocId; + +#[derive(Clone, Debug)] +/// A cache for sub-aggregations, storing doc ids per bucket id. +/// Depending on the cardinality of the parent aggregation, we use different +/// storage strategies. +/// +/// ## Low Cardinality +/// Cardinality here refers to the number of unique flattened buckets that can be created +/// by the parent aggregation. +/// Flattened buckets are the result of combining all buckets per collector +/// into a single list of buckets, where each bucket is identified by its BucketId. +/// +/// ## Usage +/// Since this is caching for sub-aggregations, it is only used by bucket +/// aggregations. +/// +/// TODO: consider using a more advanced data structure for high cardinality +/// aggregations. +/// What this datastructure does in general is to group docs by bucket id. +pub(crate) struct CachedSubAggs { + /// Only used when LOWCARD is true. + /// Cache doc ids per bucket for sub-aggregations. + /// + /// The outer Vec is indexed by BucketId. + per_bucket_docs: Vec>, + /// Only used when LOWCARD is false. 
+ /// For higher cardinalities we use a partitioned approach to store + /// + /// partitioned Vec<(BucketId, DocId)> pairs to improve grouping locality. + partitions: [PartitionEntry; NUM_PARTITIONS], + pub(crate) sub_agg_collector: Box, + num_docs: usize, +} + +const FLUSH_THRESHOLD: usize = 1024; +const NUM_PARTITIONS: usize = 16; + +impl CachedSubAggs { + pub fn get_sub_agg_collector(&mut self) -> &mut Box { + &mut self.sub_agg_collector + } + + pub fn new(sub_agg: Box) -> Self { + Self { + per_bucket_docs: Vec::new(), + num_docs: 0, + sub_agg_collector: sub_agg, + partitions: core::array::from_fn(|_| PartitionEntry::new()), + } + } + + #[inline] + pub fn clear(&mut self) { + for v in &mut self.per_bucket_docs { + v.clear(); + } + for partition in &mut self.partitions { + partition.clear(); + } + self.num_docs = 0; + } + + #[inline] + pub fn push(&mut self, bucket_id: BucketId, doc_id: DocId) { + if LOWCARD { + let idx = bucket_id as usize; + if self.per_bucket_docs.len() <= idx { + self.per_bucket_docs.resize_with(idx + 1, Vec::new); + } + self.per_bucket_docs[idx].push(doc_id); + } else { + let idx = bucket_id % NUM_PARTITIONS as u32; + let slot = &mut self.partitions[idx as usize]; + slot.bucket_ids.push(bucket_id); + slot.docs.push(doc_id); + } + self.num_docs += 1; + } + + #[inline] + pub fn extend_with_bucket_zero(&mut self, docs: &[DocId]) { + debug_assert!( + LOWCARD, + "extend_with_bucket_zero only valid for single bucket" + ); + if self.per_bucket_docs.is_empty() { + self.per_bucket_docs.resize_with(1, Vec::new); + } + self.per_bucket_docs[0].extend_from_slice(docs); + self.num_docs += docs.len(); + } + + /// Check if we need to flush based on the number of documents cached. + /// If so, flushes the cache to the provided aggregation collector. 
+ pub fn check_flush_local( + &mut self, + agg_data: &mut AggregationsSegmentCtx, + ) -> crate::Result<()> { + if self.num_docs >= FLUSH_THRESHOLD { + self.flush_local(agg_data)?; + } + Ok(()) + } + + /// Note: this does _not_ flush the sub aggregations + fn flush_local(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { + if LOWCARD { + // Pre-aggregated: call collect per bucket. + let max_bucket = (self.per_bucket_docs.len() as BucketId).saturating_sub(1); + self.sub_agg_collector + .prepare_max_bucket(max_bucket, agg_data)?; + for (bucket_id, docs) in self + .per_bucket_docs + .iter() + .enumerate() + .filter(|(_, docs)| !docs.is_empty()) + { + self.sub_agg_collector + .collect(bucket_id as BucketId, docs, agg_data)?; + } + } else { + let mut max_bucket = 0u32; + for partition in &self.partitions { + if let Some(&local_max) = partition.bucket_ids.iter().max() { + max_bucket = max_bucket.max(local_max); + } + } + + self.sub_agg_collector + .prepare_max_bucket(max_bucket, agg_data)?; + + for slot in &self.partitions { + for (bucket_id, docs) in slot.groups() { + // TODO: this may be a lot of calls to a boxed trait object. + // Consider passing the struct directly to avoid dynamic dispatch. 
+ self.sub_agg_collector.collect(bucket_id, docs, agg_data)?; + } + } + } + self.clear(); + Ok(()) + } + + /// Note: this _does_ flush the sub aggregations + pub fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { + if self.num_docs != 0 { + self.flush_local(agg_data)?; + } + self.sub_agg_collector.flush(agg_data)?; + Ok(()) + } +} + +#[derive(Debug, Clone)] +struct PartitionEntry { + bucket_ids: Vec, + docs: Vec, +} + +impl PartitionEntry { + #[inline] + fn new() -> Self { + Self { + bucket_ids: Vec::with_capacity(FLUSH_THRESHOLD / NUM_PARTITIONS), + docs: Vec::with_capacity(FLUSH_THRESHOLD / NUM_PARTITIONS), + } + } + + #[inline] + fn clear(&mut self) { + self.bucket_ids.clear(); + self.docs.clear(); + } + + #[inline] + fn groups(&self) -> PartitionGroups<'_> { + PartitionGroups { + bucket_ids: &self.bucket_ids, + docs: &self.docs, + idx: 0, + } + } +} + +struct PartitionGroups<'a> { + bucket_ids: &'a [BucketId], + docs: &'a [DocId], + idx: usize, +} + +impl<'a> Iterator for PartitionGroups<'a> { + type Item = (BucketId, &'a [DocId]); + + fn next(&mut self) -> Option { + if self.idx >= self.bucket_ids.len() { + return None; + } + + let bucket_id = self.bucket_ids[self.idx]; + let start = self.idx; + let mut end = start + 1; + while end < self.bucket_ids.len() && self.bucket_ids[end] == bucket_id { + end += 1; + } + self.idx = end; + Some((bucket_id, &self.docs[start..end])) + } +} diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs index 4c4c2c7f1..e6993eae6 100644 --- a/src/aggregation/collector.rs +++ b/src/aggregation/collector.rs @@ -1,9 +1,9 @@ use super::agg_req::Aggregations; use super::agg_result::AggregationResults; -use super::buf_collector::BufAggregationCollector; +use super::cached_sub_aggs::CachedSubAggs; use super::intermediate_agg_result::IntermediateAggregationResults; -use super::segment_agg_result::SegmentAggregationCollector; use super::AggContextParams; +// group buffering strategy is chosen 
explicitly by callers; no need to hash-group on the fly. use crate::aggregation::agg_data::{ build_aggregations_data_from_req, build_segment_agg_collectors_root, AggregationsSegmentCtx, }; @@ -136,7 +136,7 @@ fn merge_fruits( /// `AggregationSegmentCollector` does the aggregation collection on a segment. pub struct AggregationSegmentCollector { aggs_with_accessor: AggregationsSegmentCtx, - agg_collector: BufAggregationCollector, + agg_collector: CachedSubAggs, error: Option, } @@ -151,8 +151,7 @@ impl AggregationSegmentCollector { ) -> crate::Result { let mut agg_data = build_aggregations_data_from_req(agg, reader, segment_ordinal, context.clone())?; - let result = - BufAggregationCollector::new(build_segment_agg_collectors_root(&mut agg_data)?); + let result = CachedSubAggs::new(build_segment_agg_collectors_root(&mut agg_data)?); Ok(AggregationSegmentCollector { aggs_with_accessor: agg_data, @@ -170,26 +169,30 @@ impl SegmentCollector for AggregationSegmentCollector { if self.error.is_some() { return; } - if let Err(err) = self + self.agg_collector.push(0, doc); + match self .agg_collector - .collect(doc, &mut self.aggs_with_accessor) + .check_flush_local(&mut self.aggs_with_accessor) { - self.error = Some(err); + Ok(_) => {} + Err(e) => { + self.error = Some(e); + } } } - - /// The query pushes the documents to the collector via this method. 
- /// - /// Only valid for Collectors that ignore docs fn collect_block(&mut self, docs: &[DocId]) { if self.error.is_some() { return; } - if let Err(err) = self + self.agg_collector.extend_with_bucket_zero(docs); + match self .agg_collector - .collect_block(docs, &mut self.aggs_with_accessor) + .check_flush_local(&mut self.aggs_with_accessor) { - self.error = Some(err); + Ok(_) => {} + Err(e) => { + self.error = Some(e); + } } } @@ -200,10 +203,13 @@ impl SegmentCollector for AggregationSegmentCollector { self.agg_collector.flush(&mut self.aggs_with_accessor)?; let mut sub_aggregation_res = IntermediateAggregationResults::default(); - Box::new(self.agg_collector).add_intermediate_aggregation_result( - &self.aggs_with_accessor, - &mut sub_aggregation_res, - )?; + self.agg_collector + .get_sub_agg_collector() + .add_intermediate_aggregation_result( + &self.aggs_with_accessor, + &mut sub_aggregation_res, + 0, + )?; Ok(sub_aggregation_res) } diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index 104131461..b20e8a042 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -792,7 +792,7 @@ pub struct IntermediateRangeBucketEntry { /// The number of documents in the bucket. pub doc_count: u64, /// The sub_aggregation in this bucket. - pub sub_aggregation: IntermediateAggregationResults, + pub sub_aggregation_res: IntermediateAggregationResults, /// The from range of the bucket. Equals `f64::MIN` when `None`. pub from: Option, /// The to range of the bucket. Equals `f64::MAX` when `None`. 
@@ -811,7 +811,7 @@ impl IntermediateRangeBucketEntry { key: self.key.into(), doc_count: self.doc_count, sub_aggregation: self - .sub_aggregation + .sub_aggregation_res .into_final_result_internal(req, limits)?, to: self.to, from: self.from, @@ -857,7 +857,8 @@ impl MergeFruits for IntermediateTermBucketEntry { impl MergeFruits for IntermediateRangeBucketEntry { fn merge_fruits(&mut self, other: IntermediateRangeBucketEntry) -> crate::Result<()> { self.doc_count += other.doc_count; - self.sub_aggregation.merge_fruits(other.sub_aggregation)?; + self.sub_aggregation_res + .merge_fruits(other.sub_aggregation_res)?; Ok(()) } } @@ -887,7 +888,7 @@ mod tests { IntermediateRangeBucketEntry { key: IntermediateKey::Str(key.to_string()), doc_count: *doc_count, - sub_aggregation: Default::default(), + sub_aggregation_res: Default::default(), from: None, to: None, }, @@ -920,7 +921,7 @@ mod tests { doc_count: *doc_count, from: None, to: None, - sub_aggregation: get_sub_test_tree(&[( + sub_aggregation_res: get_sub_test_tree(&[( sub_aggregation_key.to_string(), *sub_aggregation_count, )]), diff --git a/src/aggregation/metric/average.rs b/src/aggregation/metric/average.rs index e707f2b00..57f694984 100644 --- a/src/aggregation/metric/average.rs +++ b/src/aggregation/metric/average.rs @@ -52,10 +52,8 @@ pub struct IntermediateAverage { impl IntermediateAverage { /// Creates a new [`IntermediateAverage`] instance from a [`SegmentStatsCollector`]. - pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { - Self { - stats: collector.stats, - } + pub(crate) fn from_stats(stats: IntermediateStats) -> Self { + Self { stats } } /// Merges the other intermediate result into self. 
pub fn merge_fruits(&mut self, other: IntermediateAverage) { diff --git a/src/aggregation/metric/cardinality.rs b/src/aggregation/metric/cardinality.rs index 8f3bdd3e5..77387fc5e 100644 --- a/src/aggregation/metric/cardinality.rs +++ b/src/aggregation/metric/cardinality.rs @@ -137,43 +137,27 @@ impl CardinalityAggregationReq { #[derive(Clone, Debug, PartialEq)] pub(crate) struct SegmentCardinalityCollector { - cardinality: CardinalityCollector, - entries: FxHashSet, + buckets: Vec, + column_type: ColumnType, accessor_idx: usize, } -impl SegmentCardinalityCollector { - pub fn from_req(column_type: ColumnType, accessor_idx: usize) -> Self { +#[derive(Clone, Debug, PartialEq, Default)] +pub(crate) struct SegmentCardinalityCollectorBucket { + cardinality: CardinalityCollector, + entries: FxHashSet, +} +impl SegmentCardinalityCollectorBucket { + pub fn new(column_type: ColumnType) -> Self { Self { cardinality: CardinalityCollector::new(column_type as u8), - entries: Default::default(), - accessor_idx, + entries: FxHashSet::default(), } } - - fn fetch_block_with_field( - &mut self, - docs: &[crate::DocId], - agg_data: &mut CardinalityAggReqData, - ) { - if let Some(missing) = agg_data.missing_value_for_accessor { - agg_data.column_block_accessor.fetch_block_with_missing( - docs, - &agg_data.accessor, - missing, - ); - } else { - agg_data - .column_block_accessor - .fetch_block(docs, &agg_data.accessor); - } - } - fn into_intermediate_metric_result( mut self, - agg_data: &AggregationsSegmentCtx, + req_data: &CardinalityAggReqData, ) -> crate::Result { - let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx); if req_data.column_type == ColumnType::Str { let fallback_dict = Dictionary::empty(); let dict = req_data @@ -194,6 +178,7 @@ impl SegmentCardinalityCollector { term_ids.push(term_ord as u32); } } + term_ids.sort_unstable(); dict.sorted_ords_to_term_cb(term_ids.iter().map(|term| *term as u64), |term| { self.cardinality.sketch.insert_any(&term); @@ 
-227,16 +212,48 @@ impl SegmentCardinalityCollector { } } +impl SegmentCardinalityCollector { + pub fn from_req(column_type: ColumnType, accessor_idx: usize) -> Self { + Self { + buckets: vec![SegmentCardinalityCollectorBucket::new(column_type); 1], + column_type, + accessor_idx, + } + } + + fn fetch_block_with_field( + &mut self, + docs: &[crate::DocId], + agg_data: &mut CardinalityAggReqData, + ) { + if let Some(missing) = agg_data.missing_value_for_accessor { + agg_data.column_block_accessor.fetch_block_with_missing( + docs, + &agg_data.accessor, + missing, + ); + } else { + agg_data + .column_block_accessor + .fetch_block(docs, &agg_data.accessor); + } + } +} + impl SegmentAggregationCollector for SegmentCardinalityCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { + self.prepare_max_bucket(parent_bucket_id, agg_data)?; let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx); let name = req_data.name.to_string(); + // take the bucket in buckets and replace it with a new empty one + let bucket = std::mem::take(&mut self.buckets[parent_bucket_id as usize]); - let intermediate_result = self.into_intermediate_metric_result(agg_data)?; + let intermediate_result = bucket.into_intermediate_metric_result(req_data)?; results.push( name, IntermediateAggregationResult::Metric(intermediate_result), @@ -247,24 +264,18 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector { fn collect( &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.collect_block(&[doc], agg_data) - } - - fn collect_block( - &mut self, + parent_bucket_id: BucketId, docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { let req_data = agg_data.get_cardinality_req_data_mut(self.accessor_idx); self.fetch_block_with_field(docs, 
req_data); + let bucket = &mut self.buckets[parent_bucket_id as usize]; let col_block_accessor = &req_data.column_block_accessor; if req_data.column_type == ColumnType::Str { for term_ord in col_block_accessor.iter_vals() { - self.entries.insert(term_ord); + bucket.entries.insert(term_ord); } } else if req_data.column_type == ColumnType::IpAddr { let compact_space_accessor = req_data @@ -282,16 +293,29 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector { })?; for val in col_block_accessor.iter_vals() { let val: u128 = compact_space_accessor.compact_to_u128(val as u32); - self.cardinality.sketch.insert_any(&val); + bucket.cardinality.sketch.insert_any(&val); } } else { for val in col_block_accessor.iter_vals() { - self.cardinality.sketch.insert_any(&val); + bucket.cardinality.sketch.insert_any(&val); } } Ok(()) } + + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + if max_bucket as usize >= self.buckets.len() { + self.buckets.resize_with(max_bucket as usize + 1, || { + SegmentCardinalityCollectorBucket::new(self.column_type) + }); + } + Ok(()) + } } #[derive(Clone, Debug, Serialize, Deserialize)] diff --git a/src/aggregation/metric/count.rs b/src/aggregation/metric/count.rs index ac550a38f..b28ced047 100644 --- a/src/aggregation/metric/count.rs +++ b/src/aggregation/metric/count.rs @@ -52,10 +52,8 @@ pub struct IntermediateCount { impl IntermediateCount { /// Creates a new [`IntermediateCount`] instance from a [`SegmentStatsCollector`]. - pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { - Self { - stats: collector.stats, - } + pub(crate) fn from_stats(stats: IntermediateStats) -> Self { + Self { stats } } /// Merges the other intermediate result into self. 
pub fn merge_fruits(&mut self, other: IntermediateCount) { diff --git a/src/aggregation/metric/extended_stats.rs b/src/aggregation/metric/extended_stats.rs index d7302e5f5..ad7165b8d 100644 --- a/src/aggregation/metric/extended_stats.rs +++ b/src/aggregation/metric/extended_stats.rs @@ -8,10 +8,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, }; -use crate::aggregation::metric::MetricAggReqData; use crate::aggregation::segment_agg_result::SegmentAggregationCollector; use crate::aggregation::*; -use crate::{DocId, TantivyError}; +use crate::TantivyError; /// A multi-value metric aggregation that computes a collection of extended statistics /// on numeric values that are extracted @@ -318,51 +317,30 @@ impl IntermediateExtendedStats { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] pub(crate) struct SegmentExtendedStatsCollector { + name: String, missing: Option, field_type: ColumnType, - pub(crate) extended_stats: IntermediateExtendedStats, - pub(crate) accessor_idx: usize, - val_cache: Vec, + accessor: columnar::Column, + column_block_accessor: columnar::ColumnBlockAccessor, + buckets: Vec, + sigma: Option, } impl SegmentExtendedStatsCollector { - pub fn from_req( - field_type: ColumnType, - sigma: Option, - accessor_idx: usize, - missing: Option, - ) -> Self { - let missing = missing.and_then(|val| f64_to_fastfield_u64(val, &field_type)); + pub fn from_req(req: &MetricAggReqData, sigma: Option) -> Self { + let missing = req + .missing + .and_then(|val| f64_to_fastfield_u64(val, &req.field_type)); Self { - field_type, - extended_stats: IntermediateExtendedStats::with_sigma(sigma), - accessor_idx, + name: req.name.clone(), + field_type: req.field_type, + accessor: req.accessor.clone(), + column_block_accessor: req.column_block_accessor.clone(), missing, - val_cache: Default::default(), - } - } - 
#[inline] - pub(crate) fn collect_block_with_field( - &mut self, - docs: &[DocId], - req_data: &mut MetricAggReqData, - ) { - if let Some(missing) = self.missing.as_ref() { - req_data.column_block_accessor.fetch_block_with_missing( - docs, - &req_data.accessor, - *missing, - ); - } else { - req_data - .column_block_accessor - .fetch_block(docs, &req_data.accessor); - } - for val in req_data.column_block_accessor.iter_vals() { - let val1 = f64_from_fastfield_u64(val, &self.field_type); - self.extended_stats.collect(val1); + buckets: vec![IntermediateExtendedStats::with_sigma(sigma); 16], + sigma, } } } @@ -370,15 +348,18 @@ impl SegmentExtendedStatsCollector { impl SegmentAggregationCollector for SegmentExtendedStatsCollector { #[inline] fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { - let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone(); + let name = self.name.clone(); + self.prepare_max_bucket(parent_bucket_id, agg_data)?; + let extended_stats = std::mem::take(&mut self.buckets[parent_bucket_id as usize]); results.push( name, IntermediateAggregationResult::Metric(IntermediateMetricResult::ExtendedStats( - self.extended_stats, + extended_stats, )), )?; @@ -388,39 +369,39 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector { #[inline] fn collect( &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, + parent_bucket_id: BucketId, + docs: &[crate::DocId], + _agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let req_data = agg_data.get_metric_req_data(self.accessor_idx); - if let Some(missing) = self.missing { - let mut has_val = false; - for val in req_data.accessor.values_for_doc(doc) { - let val1 = f64_from_fastfield_u64(val, &self.field_type); - self.extended_stats.collect(val1); - has_val = true; - } - if !has_val { - self.extended_stats - 
.collect(f64_from_fastfield_u64(missing, &self.field_type)); - } + let mut extended_stats = self.buckets[parent_bucket_id as usize].clone(); + + if let Some(missing) = self.missing.as_ref() { + self.column_block_accessor + .fetch_block_with_missing(docs, &self.accessor, *missing); } else { - for val in req_data.accessor.values_for_doc(doc) { - let val1 = f64_from_fastfield_u64(val, &self.field_type); - self.extended_stats.collect(val1); - } + self.column_block_accessor.fetch_block(docs, &self.accessor); } + for val in self.column_block_accessor.iter_vals() { + let val1 = f64_from_fastfield_u64(val, &self.field_type); + extended_stats.collect(val1); + } + + // store back + self.buckets[parent_bucket_id as usize] = extended_stats; Ok(()) } - #[inline] - fn collect_block( + fn prepare_max_bucket( &mut self, - docs: &[crate::DocId], - agg_data: &mut AggregationsSegmentCtx, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, ) -> crate::Result<()> { - let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx); - self.collect_block_with_field(docs, req_data); + if self.buckets.len() <= max_bucket as usize { + self.buckets.resize_with(max_bucket as usize + 1, || { + IntermediateExtendedStats::with_sigma(self.sigma) + }); + } Ok(()) } } diff --git a/src/aggregation/metric/max.rs b/src/aggregation/metric/max.rs index 89c6e4458..59af7e2de 100644 --- a/src/aggregation/metric/max.rs +++ b/src/aggregation/metric/max.rs @@ -52,10 +52,8 @@ pub struct IntermediateMax { impl IntermediateMax { /// Creates a new [`IntermediateMax`] instance from a [`SegmentStatsCollector`]. - pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { - Self { - stats: collector.stats, - } + pub(crate) fn from_stats(stats: IntermediateStats) -> Self { + Self { stats } } /// Merges the other intermediate result into self. 
pub fn merge_fruits(&mut self, other: IntermediateMax) { diff --git a/src/aggregation/metric/min.rs b/src/aggregation/metric/min.rs index 61fd2ecd2..ecf2fcafc 100644 --- a/src/aggregation/metric/min.rs +++ b/src/aggregation/metric/min.rs @@ -52,10 +52,8 @@ pub struct IntermediateMin { impl IntermediateMin { /// Creates a new [`IntermediateMin`] instance from a [`SegmentStatsCollector`]. - pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { - Self { - stats: collector.stats, - } + pub(crate) fn from_stats(stats: IntermediateStats) -> Self { + Self { stats } } /// Merges the other intermediate result into self. pub fn merge_fruits(&mut self, other: IntermediateMin) { diff --git a/src/aggregation/metric/percentiles.rs b/src/aggregation/metric/percentiles.rs index c846e2187..f155b498e 100644 --- a/src/aggregation/metric/percentiles.rs +++ b/src/aggregation/metric/percentiles.rs @@ -7,10 +7,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, }; -use crate::aggregation::metric::MetricAggReqData; use crate::aggregation::segment_agg_result::SegmentAggregationCollector; use crate::aggregation::*; -use crate::{DocId, TantivyError}; +use crate::TantivyError; /// # Percentiles /// @@ -133,7 +132,7 @@ impl PercentilesAggregationReq { #[derive(Clone, Debug, PartialEq)] pub(crate) struct SegmentPercentilesCollector { - pub(crate) percentiles: PercentilesCollector, + pub(crate) buckets: Vec, pub(crate) accessor_idx: usize, } @@ -231,16 +230,46 @@ impl PercentilesCollector { impl SegmentPercentilesCollector { pub fn from_req_and_validate(accessor_idx: usize) -> crate::Result { Ok(Self { - percentiles: PercentilesCollector::new(), + buckets: Vec::with_capacity(64), accessor_idx, }) } +} + +impl SegmentAggregationCollector for SegmentPercentilesCollector { #[inline] - pub(crate) fn collect_block_with_field( + fn 
add_intermediate_aggregation_result( &mut self, - docs: &[DocId], - req_data: &mut MetricAggReqData, - ) { + agg_data: &AggregationsSegmentCtx, + results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, + ) -> crate::Result<()> { + let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone(); + self.prepare_max_bucket(parent_bucket_id, agg_data)?; + // Swap collector with an empty one to avoid cloning + let percentiles_collector = std::mem::take(&mut self.buckets[parent_bucket_id as usize]); + + let intermediate_metric_result = + IntermediateMetricResult::Percentiles(percentiles_collector); + + results.push( + name, + IntermediateAggregationResult::Metric(intermediate_metric_result), + )?; + + Ok(()) + } + + #[inline] + fn collect( + &mut self, + parent_bucket_id: BucketId, + docs: &[crate::DocId], + agg_data: &mut AggregationsSegmentCtx, + ) -> crate::Result<()> { + let percentiles = &mut self.buckets[parent_bucket_id as usize]; + let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx); + if let Some(missing) = req_data.missing_u64.as_ref() { req_data.column_block_accessor.fetch_block_with_missing( docs, @@ -255,66 +284,20 @@ impl SegmentPercentilesCollector { for val in req_data.column_block_accessor.iter_vals() { let val1 = f64_from_fastfield_u64(val, &req_data.field_type); - self.percentiles.collect(val1); - } - } -} - -impl SegmentAggregationCollector for SegmentPercentilesCollector { - #[inline] - fn add_intermediate_aggregation_result( - self: Box, - agg_data: &AggregationsSegmentCtx, - results: &mut IntermediateAggregationResults, - ) -> crate::Result<()> { - let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone(); - let intermediate_metric_result = IntermediateMetricResult::Percentiles(self.percentiles); - - results.push( - name, - IntermediateAggregationResult::Metric(intermediate_metric_result), - )?; - - Ok(()) - } - - #[inline] - fn collect( - &mut self, - doc: crate::DocId, - agg_data: &mut 
AggregationsSegmentCtx, - ) -> crate::Result<()> { - let req_data = agg_data.get_metric_req_data(self.accessor_idx); - - if let Some(missing) = req_data.missing_u64 { - let mut has_val = false; - for val in req_data.accessor.values_for_doc(doc) { - let val1 = f64_from_fastfield_u64(val, &req_data.field_type); - self.percentiles.collect(val1); - has_val = true; - } - if !has_val { - self.percentiles - .collect(f64_from_fastfield_u64(missing, &req_data.field_type)); - } - } else { - for val in req_data.accessor.values_for_doc(doc) { - let val1 = f64_from_fastfield_u64(val, &req_data.field_type); - self.percentiles.collect(val1); - } + percentiles.collect(val1); } Ok(()) } - #[inline] - fn collect_block( + fn prepare_max_bucket( &mut self, - docs: &[crate::DocId], - agg_data: &mut AggregationsSegmentCtx, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, ) -> crate::Result<()> { - let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx); - self.collect_block_with_field(docs, req_data); + while self.buckets.len() <= max_bucket as usize { + self.buckets.push(PercentilesCollector::new()); + } Ok(()) } } diff --git a/src/aggregation/metric/stats.rs b/src/aggregation/metric/stats.rs index 56715fdea..4f0c5e90c 100644 --- a/src/aggregation/metric/stats.rs +++ b/src/aggregation/metric/stats.rs @@ -1,5 +1,6 @@ use std::fmt::Debug; +use columnar::{Column, ColumnBlockAccessor, ColumnType}; use serde::{Deserialize, Serialize}; use super::*; @@ -7,10 +8,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, }; -use crate::aggregation::metric::MetricAggReqData; use crate::aggregation::segment_agg_result::SegmentAggregationCollector; use crate::aggregation::*; -use crate::{DocId, TantivyError}; +use crate::TantivyError; /// A multi-value metric aggregation that computes a collection of statistics on numeric values 
that /// are extracted from the aggregated documents. @@ -83,7 +83,7 @@ impl Stats { /// Intermediate result of the stats aggregation that can be combined with other intermediate /// results. -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)] pub struct IntermediateStats { /// The number of extracted values. pub(crate) count: u64, @@ -189,44 +189,28 @@ pub enum StatsType { #[derive(Clone, Debug)] pub(crate) struct SegmentStatsCollector { - pub(crate) stats: IntermediateStats, - pub(crate) accessor_idx: usize, + pub(crate) name: String, + pub(crate) collecting_for: StatsType, + pub(crate) is_number_or_date_type: bool, + pub(crate) field_type: ColumnType, + pub(crate) missing_u64: Option, + pub(crate) column_block_accessor: ColumnBlockAccessor, + pub(crate) accessor: Column, + pub(crate) buckets: Vec, } impl SegmentStatsCollector { - pub fn from_req(accessor_idx: usize) -> Self { + pub fn from_req(req: &MetricAggReqData) -> Self { + let buckets = vec![IntermediateStats::default()]; Self { - stats: IntermediateStats::default(), - accessor_idx, - } - } - #[inline] - pub(crate) fn collect_block_with_field( - &mut self, - docs: &[DocId], - req_data: &mut MetricAggReqData, - ) { - if let Some(missing) = req_data.missing_u64.as_ref() { - req_data.column_block_accessor.fetch_block_with_missing( - docs, - &req_data.accessor, - *missing, - ); - } else { - req_data - .column_block_accessor - .fetch_block(docs, &req_data.accessor); - } - if req_data.is_number_or_date_type { - for val in req_data.column_block_accessor.iter_vals() { - let val1 = f64_from_fastfield_u64(val, &req_data.field_type); - self.stats.collect(val1); - } - } else { - for _val in req_data.column_block_accessor.iter_vals() { - // we ignore the value and simply record that we got something - self.stats.collect(0.0); - } + name: req.name.clone(), + collecting_for: req.collecting_for, + is_number_or_date_type: req.is_number_or_date_type, 
+ field_type: req.field_type, + missing_u64: req.missing_u64, + column_block_accessor: req.column_block_accessor.clone(), + accessor: req.accessor.clone(), + buckets, } } } @@ -234,28 +218,30 @@ impl SegmentStatsCollector { impl SegmentAggregationCollector for SegmentStatsCollector { #[inline] fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { - let req = agg_data.get_metric_req_data(self.accessor_idx); - let name = req.name.clone(); + let name = self.name.clone(); - let intermediate_metric_result = match req.collecting_for { + self.prepare_max_bucket(parent_bucket_id, agg_data)?; + let stats = self.buckets[parent_bucket_id as usize]; + let intermediate_metric_result = match self.collecting_for { StatsType::Average => { - IntermediateMetricResult::Average(IntermediateAverage::from_collector(*self)) + IntermediateMetricResult::Average(IntermediateAverage::from_stats(stats)) } StatsType::Count => { - IntermediateMetricResult::Count(IntermediateCount::from_collector(*self)) + IntermediateMetricResult::Count(IntermediateCount::from_stats(stats)) } - StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_collector(*self)), - StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_collector(*self)), - StatsType::Stats => IntermediateMetricResult::Stats(self.stats), - StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_collector(*self)), + StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_stats(stats)), + StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_stats(stats)), + StatsType::Stats => IntermediateMetricResult::Stats(stats), + StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_stats(stats)), _ => { return Err(TantivyError::InvalidArgument(format!( "Unsupported stats type for stats aggregation: {:?}", - 
req.collecting_for + self.collecting_for ))) } }; @@ -271,39 +257,44 @@ impl SegmentAggregationCollector for SegmentStatsCollector { #[inline] fn collect( &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, + parent_bucket_id: BucketId, + docs: &[crate::DocId], + _agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let req_data = agg_data.get_metric_req_data(self.accessor_idx); - if let Some(missing) = req_data.missing_u64 { - let mut has_val = false; - for val in req_data.accessor.values_for_doc(doc) { - let val1 = f64_from_fastfield_u64(val, &req_data.field_type); - self.stats.collect(val1); - has_val = true; - } - if !has_val { - self.stats - .collect(f64_from_fastfield_u64(missing, &req_data.field_type)); + // Copying the stats to avoid aliasing optimization issues + let mut stats = self.buckets[parent_bucket_id as usize]; + if let Some(missing) = self.missing_u64.as_ref() { + self.column_block_accessor + .fetch_block_with_missing(docs, &self.accessor, *missing); + } else { + self.column_block_accessor.fetch_block(docs, &self.accessor); + } + if self.is_number_or_date_type { + for val in self.column_block_accessor.iter_vals() { + let val1 = f64_from_fastfield_u64(val, &self.field_type); + stats.collect(val1); } } else { - for val in req_data.accessor.values_for_doc(doc) { - let val1 = f64_from_fastfield_u64(val, &req_data.field_type); - self.stats.collect(val1); + for _val in self.column_block_accessor.iter_vals() { + // we ignore the value and simply record that we got something + stats.collect(0.0); } } + self.buckets[parent_bucket_id as usize] = stats; Ok(()) } - #[inline] - fn collect_block( + fn prepare_max_bucket( &mut self, - docs: &[crate::DocId], - agg_data: &mut AggregationsSegmentCtx, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, ) -> crate::Result<()> { - let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx); - self.collect_block_with_field(docs, req_data); + let required_buckets = 
(max_bucket as usize) + 1; + if self.buckets.len() < required_buckets { + self.buckets + .resize_with(required_buckets, IntermediateStats::default); + } Ok(()) } } diff --git a/src/aggregation/metric/sum.rs b/src/aggregation/metric/sum.rs index 86f661679..2487c4e9d 100644 --- a/src/aggregation/metric/sum.rs +++ b/src/aggregation/metric/sum.rs @@ -52,10 +52,8 @@ pub struct IntermediateSum { impl IntermediateSum { /// Creates a new [`IntermediateSum`] instance from a [`SegmentStatsCollector`]. - pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { - Self { - stats: collector.stats, - } + pub(crate) fn from_stats(stats: IntermediateStats) -> Self { + Self { stats } } /// Merges the other intermediate result into self. pub fn merge_fruits(&mut self, other: IntermediateSum) { diff --git a/src/aggregation/metric/top_hits.rs b/src/aggregation/metric/top_hits.rs index 6a8bdf826..02b6be41b 100644 --- a/src/aggregation/metric/top_hits.rs +++ b/src/aggregation/metric/top_hits.rs @@ -15,7 +15,7 @@ use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateMetricResult, }; use crate::aggregation::segment_agg_result::SegmentAggregationCollector; -use crate::aggregation::AggregationError; +use crate::aggregation::{AggregationError, BucketId}; use crate::collector::sort_key::ReverseComparator; use crate::collector::TopNComputer; use crate::schema::OwnedValue; @@ -472,7 +472,10 @@ impl TopHitsTopNComputer { /// Create a new TopHitsCollector pub fn new(req: &TopHitsAggregationReq) -> Self { Self { - top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)), + top_n: TopNComputer::new_with_comparator( + req.size + req.from.unwrap_or(0), + ReverseComparator, + ), req: req.clone(), } } @@ -518,7 +521,8 @@ impl TopHitsTopNComputer { pub(crate) struct TopHitsSegmentCollector { segment_ordinal: SegmentOrdinal, accessor_idx: usize, - top_n: TopNComputer, DocAddress, ReverseComparator>, + buckets: Vec, DocAddress, ReverseComparator>>, 
+ num_hits: usize, } impl TopHitsSegmentCollector { @@ -527,19 +531,29 @@ impl TopHitsSegmentCollector { accessor_idx: usize, segment_ordinal: SegmentOrdinal, ) -> Self { + let num_hits = req.size + req.from.unwrap_or(0); Self { - top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)), + num_hits, segment_ordinal, accessor_idx, + buckets: vec![TopNComputer::new_with_comparator(num_hits, ReverseComparator); 1], } } - fn into_top_hits_collector( - self, + fn get_top_hits_computer( + &mut self, + parent_bucket_id: BucketId, value_accessors: &HashMap>, req: &TopHitsAggregationReq, ) -> TopHitsTopNComputer { + if parent_bucket_id as usize >= self.buckets.len() { + return TopHitsTopNComputer::new(req); + } + let top_n = std::mem::replace( + &mut self.buckets[parent_bucket_id as usize], + TopNComputer::new(0), + ); let mut top_hits_computer = TopHitsTopNComputer::new(req); - let top_results = self.top_n.into_vec(); + let top_results = top_n.into_vec(); for res in top_results { let doc_value_fields = req.get_document_field_data(value_accessors, res.doc.doc_id); @@ -554,54 +568,24 @@ impl TopHitsSegmentCollector { top_hits_computer } - - /// TODO add a specialized variant for a single sort field - fn collect_with( - &mut self, - doc_id: crate::DocId, - req: &TopHitsAggregationReq, - accessors: &[(Column, ColumnType)], - ) -> crate::Result<()> { - let sorts: Vec = req - .sort - .iter() - .enumerate() - .map(|(idx, KeyOrder { order, .. 
})| { - let order = *order; - let value = accessors - .get(idx) - .expect("could not find field in accessors") - .0 - .values_for_doc(doc_id) - .next(); - DocValueAndOrder { value, order } - }) - .collect(); - - self.top_n.push( - sorts, - DocAddress { - segment_ord: self.segment_ordinal, - doc_id, - }, - ); - Ok(()) - } } impl SegmentAggregationCollector for TopHitsSegmentCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut crate::aggregation::intermediate_agg_result::IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { let req_data = agg_data.get_top_hits_req_data(self.accessor_idx); let value_accessors = &req_data.value_accessors; - let intermediate_result = IntermediateMetricResult::TopHits( - self.into_top_hits_collector(value_accessors, &req_data.req), - ); + let intermediate_result = IntermediateMetricResult::TopHits(self.get_top_hits_computer( + parent_bucket_id, + value_accessors, + &req_data.req, + )); results.push( req_data.name.to_string(), IntermediateAggregationResult::Metric(intermediate_result), @@ -611,26 +595,57 @@ impl SegmentAggregationCollector for TopHitsSegmentCollector { /// TODO: Consider a caching layer to reduce the call overhead fn collect( &mut self, - doc_id: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - let req_data = agg_data.get_top_hits_req_data(self.accessor_idx); - self.collect_with(doc_id, &req_data.req, &req_data.accessors)?; - Ok(()) - } - - fn collect_block( - &mut self, + parent_bucket_id: BucketId, docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { + let top_n = &mut self.buckets[parent_bucket_id as usize]; let req_data = agg_data.get_top_hits_req_data(self.accessor_idx); - // TODO: Consider getting fields with the column block accessor. 
- for doc in docs { - self.collect_with(*doc, &req_data.req, &req_data.accessors)?; + let req = &req_data.req; + let accessors = &req_data.accessors; + for doc_id in docs { + let doc_id = *doc_id; + // TODO: this is terrible, a new vec is allocated for every doc + // We can fetch blocks instead + // We don't need to store the order for every value + let sorts: Vec = req + .sort + .iter() + .enumerate() + .map(|(idx, KeyOrder { order, .. })| { + let order = *order; + let value = accessors + .get(idx) + .expect("could not find field in accessors") + .0 + .values_for_doc(doc_id) + .next(); + DocValueAndOrder { value, order } + }) + .collect(); + + top_n.push( + sorts, + DocAddress { + segment_ord: self.segment_ordinal, + doc_id, + }, + ); } Ok(()) } + + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + _agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + self.buckets.resize( + (max_bucket as usize) + 1, + TopNComputer::new_with_comparator(self.num_hits, ReverseComparator), + ); + Ok(()) + } } #[cfg(test)] @@ -746,7 +761,7 @@ mod tests { ], "from": 0, } - } + } })) .unwrap(); @@ -875,7 +890,7 @@ mod tests { "mixed.*", ], } - } + } }))?; let collector = AggregationCollector::from_aggs(d, Default::default()); diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index ddf60ea4c..eb8fd4773 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -133,7 +133,7 @@ mod agg_limits; pub mod agg_req; pub mod agg_result; pub mod bucket; -mod buf_collector; +pub(crate) mod cached_sub_aggs; mod collector; mod date; mod error; @@ -162,6 +162,19 @@ use serde::{Deserialize, Deserializer, Serialize}; use crate::tokenizer::TokenizerManager; +/// A bucket id is a dense identifier for a bucket within an aggregation. +/// It is used to index into a Vec that hold per-bucket data. +/// +/// For example, in a terms aggregation, each unique term will be assigned a incremental BucketId. 
+/// This BucketId will be forwarded to sub-aggregations to identify the parent bucket. +/// +/// This allows to have a single AggregationCollector instance per aggregation, +/// that can handle multiple buckets efficiently. +/// +/// The API to call sub-aggregations is therefore a &[(BucketId, &[DocId])]. +/// For that we'll need a buffer. One Vec per bucket aggregation is needed. +pub type BucketId = u32; + /// Context parameters for aggregation execution /// /// This struct holds shared resources needed during aggregation execution: diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index 5cc2650b6..fc2457b81 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -8,25 +8,42 @@ use std::fmt::Debug; pub(crate) use super::agg_limits::AggregationLimitsGuard; use super::intermediate_agg_result::IntermediateAggregationResults; use crate::aggregation::agg_data::AggregationsSegmentCtx; +use crate::aggregation::BucketId; + +/// Monotonically increasing provider of BucketIds. +#[derive(Debug, Clone, Default)] +pub struct BucketIdProvider(u32); +impl BucketIdProvider { + /// Get the next BucketId. + pub fn next_bucket_id(&mut self) -> BucketId { + let bucket_id = self.0; + self.0 += 1; + bucket_id + } +} /// A SegmentAggregationCollector is used to collect aggregation results. pub trait SegmentAggregationCollector: CollectorClone + Debug { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()>; fn collect( &mut self, - doc: crate::DocId, + parent_bucket_id: BucketId, + docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()>; - fn collect_block( + /// Prepare the collector for collecting up to ParentBucketId `max_bucket`. + /// This is useful so we can split allocation ahead of time of collecting. 
+ fn prepare_max_bucket( &mut self, - docs: &[crate::DocId], - agg_data: &mut AggregationsSegmentCtx, + max_bucket: BucketId, + agg_data: &AggregationsSegmentCtx, ) -> crate::Result<()>; /// Finalize method. Some Aggregator collect blocks of docs before calling `collect_block`. @@ -73,12 +90,13 @@ impl Debug for GenericSegmentAggregationResultsCollector { impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector { fn add_intermediate_aggregation_result( - self: Box, + &mut self, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, + parent_bucket_id: BucketId, ) -> crate::Result<()> { - for agg in self.aggs { - agg.add_intermediate_aggregation_result(agg_data, results)?; + for agg in &mut self.aggs { + agg.add_intermediate_aggregation_result(agg_data, results, parent_bucket_id)?; } Ok(()) @@ -86,23 +104,13 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector { fn collect( &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.collect_block(&[doc], agg_data)?; - - Ok(()) - } - - fn collect_block( - &mut self, + parent_bucket_id: BucketId, docs: &[crate::DocId], agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { for collector in &mut self.aggs { - collector.collect_block(docs, agg_data)?; + collector.collect(parent_bucket_id, docs, agg_data)?; } - Ok(()) } @@ -112,4 +120,15 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector { } Ok(()) } + + fn prepare_max_bucket( + &mut self, + max_bucket: BucketId, + agg_data: &AggregationsSegmentCtx, + ) -> crate::Result<()> { + for collector in &mut self.aggs { + collector.prepare_max_bucket(max_bucket, agg_data)?; + } + Ok(()) + } }