From 55b0b52457c87df45250243c3310b4ffad85b088 Mon Sep 17 00:00:00 2001 From: PSeitz Date: Tue, 17 Sep 2024 14:25:47 +0800 Subject: [PATCH] Fix AggregationLimits (#2495) * change AggregationLimits behavior This fixes an issue encountered with the current behavior of AggregationLimits. Previously we had AggregationLimits and ResourceLimitGuard, which both track the memory, but only ResourceLimitGuard released memory when dropped, while AggregationLimits did not. This PR changes AggregationLimits to be a guard itself and removes the ResourceLimitGuard. * rename AggregationLimits to AggregationLimitsGuard --- src/aggregation/agg_limits.rs | 65 +++++++------------ src/aggregation/agg_req_with_accessor.rs | 17 ++--- src/aggregation/agg_tests.rs | 8 +-- src/aggregation/bucket/histogram/histogram.rs | 6 +- src/aggregation/bucket/range.rs | 5 +- src/aggregation/bucket/term_agg.rs | 4 +- src/aggregation/collector.rs | 14 ++-- src/aggregation/intermediate_agg_result.rs | 18 ++--- src/aggregation/mod.rs | 4 +- src/aggregation/segment_agg_result.rs | 4 +- 10 files changed, 62 insertions(+), 83 deletions(-) diff --git a/src/aggregation/agg_limits.rs b/src/aggregation/agg_limits.rs index db7928f5f..76dfbca9d 100644 --- a/src/aggregation/agg_limits.rs +++ b/src/aggregation/agg_limits.rs @@ -21,7 +21,10 @@ impl MemoryConsumption for HashMap { /// Aggregation memory limit after which the request fails. Defaults to DEFAULT_MEMORY_LIMIT /// (500MB). The limit is shared by all SegmentCollectors -pub struct AggregationLimits { +/// +/// The memory limit is also a guard, which tracks how much it allocated and releases its memory +/// on the shared counter. Cloning will create a new guard. +pub struct AggregationLimitsGuard { /// The counter which is shared between the aggregations for one request. 
memory_consumption: Arc, /// The memory_limit in bytes @@ -29,28 +32,41 @@ pub struct AggregationLimits { /// The maximum number of buckets _returned_ /// This is not counting intermediate buckets. bucket_limit: u32, + /// Allocated memory with this guard. + allocated_with_the_guard: u64, } -impl Clone for AggregationLimits { +impl Clone for AggregationLimitsGuard { fn clone(&self) -> Self { Self { memory_consumption: Arc::clone(&self.memory_consumption), memory_limit: self.memory_limit, bucket_limit: self.bucket_limit, + allocated_with_the_guard: 0, } } } -impl Default for AggregationLimits { +impl Drop for AggregationLimitsGuard { + /// Removes the memory consumed tracked by this _instance_ of AggregationLimits. + /// This is used to clear the segment specific memory consumption all at once. + fn drop(&mut self) { + self.memory_consumption + .fetch_sub(self.allocated_with_the_guard, Ordering::Relaxed); + } +} + +impl Default for AggregationLimitsGuard { fn default() -> Self { Self { memory_consumption: Default::default(), memory_limit: DEFAULT_MEMORY_LIMIT.into(), bucket_limit: DEFAULT_BUCKET_LIMIT, + allocated_with_the_guard: 0, } } } -impl AggregationLimits { +impl AggregationLimitsGuard { /// *memory_limit* /// memory_limit is defined in bytes. /// Aggregation fails when the estimated memory consumption of the aggregation is higher than @@ -67,24 +83,15 @@ impl AggregationLimits { memory_consumption: Default::default(), memory_limit: memory_limit.unwrap_or(DEFAULT_MEMORY_LIMIT).into(), bucket_limit: bucket_limit.unwrap_or(DEFAULT_BUCKET_LIMIT), - } - } - - /// Create a new ResourceLimitGuard, that will release the memory when dropped. - pub fn new_guard(&self) -> ResourceLimitGuard { - ResourceLimitGuard { - // The counter which is shared between the aggregations for one request. 
- memory_consumption: Arc::clone(&self.memory_consumption), - // The memory_limit in bytes - memory_limit: self.memory_limit, allocated_with_the_guard: 0, } } - pub(crate) fn add_memory_consumed(&self, add_num_bytes: u64) -> crate::Result<()> { + pub(crate) fn add_memory_consumed(&mut self, add_num_bytes: u64) -> crate::Result<()> { let prev_value = self .memory_consumption .fetch_add(add_num_bytes, Ordering::Relaxed); + self.allocated_with_the_guard += add_num_bytes; validate_memory_consumption(prev_value + add_num_bytes, self.memory_limit)?; Ok(()) } @@ -109,34 +116,6 @@ fn validate_memory_consumption( Ok(()) } -pub struct ResourceLimitGuard { - /// The counter which is shared between the aggregations for one request. - memory_consumption: Arc, - /// The memory_limit in bytes - memory_limit: ByteCount, - /// Allocated memory with this guard. - allocated_with_the_guard: u64, -} - -impl ResourceLimitGuard { - pub(crate) fn add_memory_consumed(&self, add_num_bytes: u64) -> crate::Result<()> { - let prev_value = self - .memory_consumption - .fetch_add(add_num_bytes, Ordering::Relaxed); - validate_memory_consumption(prev_value + add_num_bytes, self.memory_limit)?; - Ok(()) - } -} - -impl Drop for ResourceLimitGuard { - /// Removes the memory consumed tracked by this _instance_ of AggregationLimits. - /// This is used to clear the segment specific memory consumption all at once. 
- fn drop(&mut self) { - self.memory_consumption - .fetch_sub(self.allocated_with_the_guard, Ordering::Relaxed); - } -} - #[cfg(test)] mod tests { use crate::aggregation::tests::exec_request_with_query; diff --git a/src/aggregation/agg_req_with_accessor.rs b/src/aggregation/agg_req_with_accessor.rs index bd7528d02..986d2e1d0 100644 --- a/src/aggregation/agg_req_with_accessor.rs +++ b/src/aggregation/agg_req_with_accessor.rs @@ -5,7 +5,6 @@ use std::io; use columnar::{Column, ColumnBlockAccessor, ColumnType, DynamicColumn, StrColumn}; -use super::agg_limits::ResourceLimitGuard; use super::agg_req::{Aggregation, AggregationVariants, Aggregations}; use super::bucket::{ DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation, @@ -14,7 +13,7 @@ use super::metric::{ AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation, MaxAggregation, MinAggregation, StatsAggregation, SumAggregation, }; -use super::segment_agg_result::AggregationLimits; +use super::segment_agg_result::AggregationLimitsGuard; use super::VecWithNames; use crate::aggregation::{f64_to_fastfield_u64, Key}; use crate::index::SegmentReader; @@ -46,7 +45,7 @@ pub struct AggregationWithAccessor { pub(crate) str_dict_column: Option, pub(crate) field_type: ColumnType, pub(crate) sub_aggregation: AggregationsWithAccessor, - pub(crate) limits: ResourceLimitGuard, + pub(crate) limits: AggregationLimitsGuard, pub(crate) column_block_accessor: ColumnBlockAccessor, /// Used for missing term aggregation, which checks all columns for existence. /// And also for `top_hits` aggregation, which may sort on multiple fields. 
@@ -69,7 +68,7 @@ impl AggregationWithAccessor { sub_aggregation: &Aggregations, reader: &SegmentReader, segment_ordinal: SegmentOrdinal, - limits: AggregationLimits, + limits: AggregationLimitsGuard, ) -> crate::Result> { let mut agg = agg.clone(); @@ -91,7 +90,7 @@ impl AggregationWithAccessor { &limits, )?, agg: agg.clone(), - limits: limits.new_guard(), + limits: limits.clone(), missing_value_for_accessor: None, str_dict_column: None, column_block_accessor: Default::default(), @@ -106,6 +105,7 @@ impl AggregationWithAccessor { value_accessors: HashMap>| -> crate::Result<()> { let (accessor, field_type) = accessors.first().expect("at least one accessor"); + let limits = limits.clone(); let res = AggregationWithAccessor { segment_ordinal, // TODO: We should do away with the `accessor` field altogether @@ -120,7 +120,7 @@ impl AggregationWithAccessor { &limits, )?, agg: agg.clone(), - limits: limits.new_guard(), + limits, missing_value_for_accessor: None, str_dict_column: None, column_block_accessor: Default::default(), @@ -245,6 +245,7 @@ impl AggregationWithAccessor { None }; + let limits = limits.clone(); let agg = AggregationWithAccessor { segment_ordinal, missing_value_for_accessor, @@ -260,7 +261,7 @@ impl AggregationWithAccessor { )?, agg: agg.clone(), str_dict_column: str_dict_column.clone(), - limits: limits.new_guard(), + limits, column_block_accessor: Default::default(), }; res.push(agg); @@ -386,7 +387,7 @@ pub(crate) fn get_aggs_with_segment_accessor_and_validate( aggs: &Aggregations, reader: &SegmentReader, segment_ordinal: SegmentOrdinal, - limits: &AggregationLimits, + limits: &AggregationLimitsGuard, ) -> crate::Result { let mut aggss = Vec::new(); for (key, agg) in aggs.iter() { diff --git a/src/aggregation/agg_tests.rs b/src/aggregation/agg_tests.rs index a4ac827a0..3ee9726ca 100644 --- a/src/aggregation/agg_tests.rs +++ b/src/aggregation/agg_tests.rs @@ -5,7 +5,7 @@ use crate::aggregation::agg_result::AggregationResults; use 
crate::aggregation::buf_collector::DOC_BLOCK_SIZE; use crate::aggregation::collector::AggregationCollector; use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults; -use crate::aggregation::segment_agg_result::AggregationLimits; +use crate::aggregation::segment_agg_result::AggregationLimitsGuard; use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms}; use crate::aggregation::DistributedAggregationCollector; use crate::query::{AllQuery, TermQuery}; @@ -130,7 +130,7 @@ fn test_aggregation_flushing( let agg_res: AggregationResults = if use_distributed_collector { let collector = DistributedAggregationCollector::from_aggs( agg_req.clone(), - AggregationLimits::default(), + AggregationLimitsGuard::default(), ); let searcher = reader.searcher(); @@ -146,7 +146,7 @@ fn test_aggregation_flushing( .expect("Post deserialization failed"); intermediate_agg_result - .into_final_result(agg_req, &Default::default()) + .into_final_result(agg_req, Default::default()) .unwrap() } else { let collector = get_collector(agg_req); @@ -460,7 +460,7 @@ fn test_aggregation_level2( let searcher = reader.searcher(); let res = searcher.search(&term_query, &collector).unwrap(); - res.into_final_result(agg_req.clone(), &Default::default()) + res.into_final_result(agg_req.clone(), Default::default()) .unwrap() } else { let collector = get_collector(agg_req.clone()); diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs index 833fd9cd5..b690b0064 100644 --- a/src/aggregation/bucket/histogram/histogram.rs +++ b/src/aggregation/bucket/histogram/histogram.rs @@ -438,7 +438,7 @@ fn intermediate_buckets_to_final_buckets_fill_gaps( buckets: Vec, histogram_req: &HistogramAggregation, sub_aggregation: &Aggregations, - limits: &AggregationLimits, + limits: &mut AggregationLimitsGuard, ) -> crate::Result> { // Generate the full list of buckets without gaps. 
// @@ -496,7 +496,7 @@ pub(crate) fn intermediate_histogram_buckets_to_final_buckets( is_date_agg: bool, histogram_req: &HistogramAggregation, sub_aggregation: &Aggregations, - limits: &AggregationLimits, + limits: &mut AggregationLimitsGuard, ) -> crate::Result> { // Normalization is column type dependent. // The request used in the the call to final is not yet be normalized. @@ -750,7 +750,7 @@ mod tests { agg_req, &index, None, - AggregationLimits::new(Some(5_000), None), + AggregationLimitsGuard::new(Some(5_000), None), ) .unwrap_err(); assert!(res.to_string().starts_with( diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs index 2e29d97ae..9761f610d 100644 --- a/src/aggregation/bucket/range.rs +++ b/src/aggregation/bucket/range.rs @@ -4,7 +4,6 @@ use std::ops::Range; use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; -use crate::aggregation::agg_limits::ResourceLimitGuard; use crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, @@ -270,7 +269,7 @@ impl SegmentRangeCollector { pub(crate) fn from_req_and_validate( req: &RangeAggregation, sub_aggregation: &mut AggregationsWithAccessor, - limits: &ResourceLimitGuard, + limits: &mut AggregationLimitsGuard, field_type: ColumnType, accessor_idx: usize, ) -> crate::Result { @@ -471,7 +470,7 @@ mod tests { SegmentRangeCollector::from_req_and_validate( &req, &mut Default::default(), - &AggregationLimits::default().new_guard(), + &mut AggregationLimitsGuard::default(), field_type, 0, ) diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index 75a23761b..951576ac9 100644 --- a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg.rs @@ -669,7 +669,7 @@ mod tests { exec_request, exec_request_with_query, exec_request_with_query_and_memory_limit, get_test_index_from_terms, 
get_test_index_from_values_and_terms, }; - use crate::aggregation::AggregationLimits; + use crate::aggregation::AggregationLimitsGuard; use crate::indexer::NoMergePolicy; use crate::schema::{IntoIpv6Addr, Schema, FAST, STRING}; use crate::{Index, IndexWriter}; @@ -1424,7 +1424,7 @@ mod tests { agg_req, &index, None, - AggregationLimits::new(Some(50_000), None), + AggregationLimitsGuard::new(Some(50_000), None), ) .unwrap_err(); assert!(res diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs index 2b9ee6f61..3c5ad4eae 100644 --- a/src/aggregation/collector.rs +++ b/src/aggregation/collector.rs @@ -4,7 +4,7 @@ use super::agg_result::AggregationResults; use super::buf_collector::BufAggregationCollector; use super::intermediate_agg_result::IntermediateAggregationResults; use super::segment_agg_result::{ - build_segment_agg_collector, AggregationLimits, SegmentAggregationCollector, + build_segment_agg_collector, AggregationLimitsGuard, SegmentAggregationCollector, }; use crate::aggregation::agg_req_with_accessor::get_aggs_with_segment_accessor_and_validate; use crate::collector::{Collector, SegmentCollector}; @@ -22,7 +22,7 @@ pub const DEFAULT_MEMORY_LIMIT: u64 = 500_000_000; /// The collector collects all aggregations by the underlying aggregation request. pub struct AggregationCollector { agg: Aggregations, - limits: AggregationLimits, + limits: AggregationLimitsGuard, } impl AggregationCollector { @@ -30,7 +30,7 @@ impl AggregationCollector { /// /// Aggregation fails when the limits in `AggregationLimits` is exceeded. (memory limit and /// bucket limit) - pub fn from_aggs(agg: Aggregations, limits: AggregationLimits) -> Self { + pub fn from_aggs(agg: Aggregations, limits: AggregationLimitsGuard) -> Self { Self { agg, limits } } } @@ -45,7 +45,7 @@ impl AggregationCollector { /// into the final `AggregationResults` via the `into_final_result()` method. 
pub struct DistributedAggregationCollector { agg: Aggregations, - limits: AggregationLimits, + limits: AggregationLimitsGuard, } impl DistributedAggregationCollector { @@ -53,7 +53,7 @@ impl DistributedAggregationCollector { /// /// Aggregation fails when the limits in `AggregationLimits` is exceeded. (memory limit and /// bucket limit) - pub fn from_aggs(agg: Aggregations, limits: AggregationLimits) -> Self { + pub fn from_aggs(agg: Aggregations, limits: AggregationLimitsGuard) -> Self { Self { agg, limits } } } @@ -115,7 +115,7 @@ impl Collector for AggregationCollector { segment_fruits: Vec<::Fruit>, ) -> crate::Result { let res = merge_fruits(segment_fruits)?; - res.into_final_result(self.agg.clone(), &self.limits) + res.into_final_result(self.agg.clone(), self.limits.clone()) } } @@ -147,7 +147,7 @@ impl AggregationSegmentCollector { agg: &Aggregations, reader: &SegmentReader, segment_ordinal: SegmentOrdinal, - limits: &AggregationLimits, + limits: &AggregationLimitsGuard, ) -> crate::Result { let mut aggs_with_accessor = get_aggs_with_segment_accessor_and_validate(agg, reader, segment_ordinal, limits)?; diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index f4cef1a51..0d94af933 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -22,7 +22,7 @@ use super::metric::{ IntermediateAverage, IntermediateCount, IntermediateExtendedStats, IntermediateMax, IntermediateMin, IntermediateStats, IntermediateSum, PercentilesCollector, TopHitsTopNComputer, }; -use super::segment_agg_result::AggregationLimits; +use super::segment_agg_result::AggregationLimitsGuard; use super::{format_date, AggregationError, Key, SerializedKey}; use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry}; use crate::aggregation::bucket::TermsAggregationInternal; @@ -122,9 +122,9 @@ impl IntermediateAggregationResults { pub fn into_final_result( self, req: 
Aggregations, - limits: &AggregationLimits, + mut limits: AggregationLimitsGuard, ) -> crate::Result { - let res = self.into_final_result_internal(&req, limits)?; + let res = self.into_final_result_internal(&req, &mut limits)?; let bucket_count = res.get_bucket_count() as u32; if bucket_count > limits.get_bucket_limit() { return Err(TantivyError::AggregationError( @@ -141,7 +141,7 @@ impl IntermediateAggregationResults { pub(crate) fn into_final_result_internal( self, req: &Aggregations, - limits: &AggregationLimits, + limits: &mut AggregationLimitsGuard, ) -> crate::Result { let mut results: FxHashMap = FxHashMap::default(); for (key, agg_res) in self.aggs_res.into_iter() { @@ -257,7 +257,7 @@ impl IntermediateAggregationResult { pub(crate) fn into_final_result( self, req: &Aggregation, - limits: &AggregationLimits, + limits: &mut AggregationLimitsGuard, ) -> crate::Result { let res = match self { IntermediateAggregationResult::Bucket(bucket) => { @@ -432,7 +432,7 @@ impl IntermediateBucketResult { pub(crate) fn into_final_bucket_result( self, req: &Aggregation, - limits: &AggregationLimits, + limits: &mut AggregationLimitsGuard, ) -> crate::Result { match self { IntermediateBucketResult::Range(range_res) => { @@ -596,7 +596,7 @@ impl IntermediateTermBucketResult { self, req: &TermsAggregation, sub_aggregation_req: &Aggregations, - limits: &AggregationLimits, + limits: &mut AggregationLimitsGuard, ) -> crate::Result { let req = TermsAggregationInternal::from_req(req); let mut buckets: Vec = self @@ -723,7 +723,7 @@ impl IntermediateHistogramBucketEntry { pub(crate) fn into_final_bucket_entry( self, req: &Aggregations, - limits: &AggregationLimits, + limits: &mut AggregationLimitsGuard, ) -> crate::Result { Ok(BucketEntry { key_as_string: None, @@ -758,7 +758,7 @@ impl IntermediateRangeBucketEntry { req: &Aggregations, _range_req: &RangeAggregation, column_type: Option, - limits: &AggregationLimits, + limits: &mut AggregationLimitsGuard, ) -> crate::Result { let 
mut range_bucket_entry = RangeBucketEntry { key: self.key.into(), diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index f5f7d3142..b87865d88 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -148,7 +148,7 @@ mod agg_tests; use core::fmt; -pub use agg_limits::AggregationLimits; +pub use agg_limits::AggregationLimitsGuard; pub use collector::{ AggregationCollector, AggregationSegmentCollector, DistributedAggregationCollector, DEFAULT_BUCKET_LIMIT, @@ -458,7 +458,7 @@ mod tests { agg_req: Aggregations, index: &Index, query: Option<(&str, &str)>, - limits: AggregationLimits, + limits: AggregationLimitsGuard, ) -> crate::Result { let collector = AggregationCollector::from_aggs(agg_req, limits); diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index 5023b943d..747543e23 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -5,7 +5,7 @@ use std::fmt::Debug; -pub(crate) use super::agg_limits::AggregationLimits; +pub(crate) use super::agg_limits::AggregationLimitsGuard; use super::agg_req::AggregationVariants; use super::agg_req_with_accessor::{AggregationWithAccessor, AggregationsWithAccessor}; use super::bucket::{SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector}; @@ -103,7 +103,7 @@ pub(crate) fn build_single_agg_segment_collector( Range(range_req) => Ok(Box::new(SegmentRangeCollector::from_req_and_validate( range_req, &mut req.sub_aggregation, - &req.limits, + &mut req.limits, req.field_type, accessor_idx, )?)),