diff --git a/src/aggregation/agg_limits.rs b/src/aggregation/agg_limits.rs index db7928f5f..76dfbca9d 100644 --- a/src/aggregation/agg_limits.rs +++ b/src/aggregation/agg_limits.rs @@ -21,7 +21,10 @@ impl MemoryConsumption for HashMap { /// Aggregation memory limit after which the request fails. Defaults to DEFAULT_MEMORY_LIMIT /// (500MB). The limit is shared by all SegmentCollectors -pub struct AggregationLimits { +/// +/// The memory limit is also a guard, which tracks how much it allocated and releases it's memory +/// on the shared counter. Cloning will create a new guard. +pub struct AggregationLimitsGuard { /// The counter which is shared between the aggregations for one request. memory_consumption: Arc, /// The memory_limit in bytes @@ -29,28 +32,41 @@ pub struct AggregationLimits { /// The maximum number of buckets _returned_ /// This is not counting intermediate buckets. bucket_limit: u32, + /// Allocated memory with this guard. + allocated_with_the_guard: u64, } -impl Clone for AggregationLimits { +impl Clone for AggregationLimitsGuard { fn clone(&self) -> Self { Self { memory_consumption: Arc::clone(&self.memory_consumption), memory_limit: self.memory_limit, bucket_limit: self.bucket_limit, + allocated_with_the_guard: 0, } } } -impl Default for AggregationLimits { +impl Drop for AggregationLimitsGuard { + /// Removes the memory consumed tracked by this _instance_ of AggregationLimits. + /// This is used to clear the segment specific memory consumption all at once. + fn drop(&mut self) { + self.memory_consumption + .fetch_sub(self.allocated_with_the_guard, Ordering::Relaxed); + } +} + +impl Default for AggregationLimitsGuard { fn default() -> Self { Self { memory_consumption: Default::default(), memory_limit: DEFAULT_MEMORY_LIMIT.into(), bucket_limit: DEFAULT_BUCKET_LIMIT, + allocated_with_the_guard: 0, } } } -impl AggregationLimits { +impl AggregationLimitsGuard { /// *memory_limit* /// memory_limit is defined in bytes. /// Aggregation fails when the estimated memory consumption of the aggregation is higher than @@ -67,24 +83,15 @@ impl AggregationLimits { memory_consumption: Default::default(), memory_limit: memory_limit.unwrap_or(DEFAULT_MEMORY_LIMIT).into(), bucket_limit: bucket_limit.unwrap_or(DEFAULT_BUCKET_LIMIT), - } - } - - /// Create a new ResourceLimitGuard, that will release the memory when dropped. - pub fn new_guard(&self) -> ResourceLimitGuard { - ResourceLimitGuard { - // The counter which is shared between the aggregations for one request. - memory_consumption: Arc::clone(&self.memory_consumption), - // The memory_limit in bytes - memory_limit: self.memory_limit, allocated_with_the_guard: 0, } } - pub(crate) fn add_memory_consumed(&self, add_num_bytes: u64) -> crate::Result<()> { + pub(crate) fn add_memory_consumed(&mut self, add_num_bytes: u64) -> crate::Result<()> { let prev_value = self .memory_consumption .fetch_add(add_num_bytes, Ordering::Relaxed); + self.allocated_with_the_guard += add_num_bytes; validate_memory_consumption(prev_value + add_num_bytes, self.memory_limit)?; Ok(()) } @@ -109,34 +116,6 @@ fn validate_memory_consumption( Ok(()) } -pub struct ResourceLimitGuard { - /// The counter which is shared between the aggregations for one request. - memory_consumption: Arc, - /// The memory_limit in bytes - memory_limit: ByteCount, - /// Allocated memory with this guard. - allocated_with_the_guard: u64, -} - -impl ResourceLimitGuard { - pub(crate) fn add_memory_consumed(&self, add_num_bytes: u64) -> crate::Result<()> { - let prev_value = self - .memory_consumption - .fetch_add(add_num_bytes, Ordering::Relaxed); - validate_memory_consumption(prev_value + add_num_bytes, self.memory_limit)?; - Ok(()) - } -} - -impl Drop for ResourceLimitGuard { - /// Removes the memory consumed tracked by this _instance_ of AggregationLimits. - /// This is used to clear the segment specific memory consumption all at once. - fn drop(&mut self) { - self.memory_consumption - .fetch_sub(self.allocated_with_the_guard, Ordering::Relaxed); - } -} - #[cfg(test)] mod tests { use crate::aggregation::tests::exec_request_with_query; diff --git a/src/aggregation/agg_req_with_accessor.rs b/src/aggregation/agg_req_with_accessor.rs index bd7528d02..986d2e1d0 100644 --- a/src/aggregation/agg_req_with_accessor.rs +++ b/src/aggregation/agg_req_with_accessor.rs @@ -5,7 +5,6 @@ use std::io; use columnar::{Column, ColumnBlockAccessor, ColumnType, DynamicColumn, StrColumn}; -use super::agg_limits::ResourceLimitGuard; use super::agg_req::{Aggregation, AggregationVariants, Aggregations}; use super::bucket::{ DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation, @@ -14,7 +13,7 @@ use super::metric::{ AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation, MaxAggregation, MinAggregation, StatsAggregation, SumAggregation, }; -use super::segment_agg_result::AggregationLimits; +use super::segment_agg_result::AggregationLimitsGuard; use super::VecWithNames; use crate::aggregation::{f64_to_fastfield_u64, Key}; use crate::index::SegmentReader; @@ -46,7 +45,7 @@ pub struct AggregationWithAccessor { pub(crate) str_dict_column: Option, pub(crate) field_type: ColumnType, pub(crate) sub_aggregation: AggregationsWithAccessor, - pub(crate) limits: ResourceLimitGuard, + pub(crate) limits: AggregationLimitsGuard, pub(crate) column_block_accessor: ColumnBlockAccessor, /// Used for missing term aggregation, which checks all columns for existence. /// And also for `top_hits` aggregation, which may sort on multiple fields. @@ -69,7 +68,7 @@ impl AggregationWithAccessor { sub_aggregation: &Aggregations, reader: &SegmentReader, segment_ordinal: SegmentOrdinal, - limits: AggregationLimits, + limits: AggregationLimitsGuard, ) -> crate::Result> { let mut agg = agg.clone(); @@ -91,7 +90,7 @@ impl AggregationWithAccessor { &limits, )?, agg: agg.clone(), - limits: limits.new_guard(), + limits: limits.clone(), missing_value_for_accessor: None, str_dict_column: None, column_block_accessor: Default::default(), @@ -106,6 +105,7 @@ impl AggregationWithAccessor { value_accessors: HashMap>| -> crate::Result<()> { let (accessor, field_type) = accessors.first().expect("at least one accessor"); + let limits = limits.clone(); let res = AggregationWithAccessor { segment_ordinal, // TODO: We should do away with the `accessor` field altogether @@ -120,7 +120,7 @@ impl AggregationWithAccessor { &limits, )?, agg: agg.clone(), - limits: limits.new_guard(), + limits, missing_value_for_accessor: None, str_dict_column: None, column_block_accessor: Default::default(), @@ -245,6 +245,7 @@ impl AggregationWithAccessor { None }; + let limits = limits.clone(); let agg = AggregationWithAccessor { segment_ordinal, missing_value_for_accessor, @@ -260,7 +261,7 @@ impl AggregationWithAccessor { )?, agg: agg.clone(), str_dict_column: str_dict_column.clone(), - limits: limits.new_guard(), + limits, column_block_accessor: Default::default(), }; res.push(agg); @@ -386,7 +387,7 @@ pub(crate) fn get_aggs_with_segment_accessor_and_validate( aggs: &Aggregations, reader: &SegmentReader, segment_ordinal: SegmentOrdinal, - limits: &AggregationLimits, + limits: &AggregationLimitsGuard, ) -> crate::Result { let mut aggss = Vec::new(); for (key, agg) in aggs.iter() { diff --git a/src/aggregation/agg_tests.rs b/src/aggregation/agg_tests.rs index a4ac827a0..3ee9726ca 100644 --- a/src/aggregation/agg_tests.rs +++ b/src/aggregation/agg_tests.rs @@ -5,7 +5,7 @@ use crate::aggregation::agg_result::AggregationResults; use crate::aggregation::buf_collector::DOC_BLOCK_SIZE; use crate::aggregation::collector::AggregationCollector; use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults; -use crate::aggregation::segment_agg_result::AggregationLimits; +use crate::aggregation::segment_agg_result::AggregationLimitsGuard; use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms}; use crate::aggregation::DistributedAggregationCollector; use crate::query::{AllQuery, TermQuery}; @@ -130,7 +130,7 @@ fn test_aggregation_flushing( let agg_res: AggregationResults = if use_distributed_collector { let collector = DistributedAggregationCollector::from_aggs( agg_req.clone(), - AggregationLimits::default(), + AggregationLimitsGuard::default(), ); let searcher = reader.searcher(); @@ -146,7 +146,7 @@ fn test_aggregation_flushing( .expect("Post deserialization failed"); intermediate_agg_result - .into_final_result(agg_req, &Default::default()) + .into_final_result(agg_req, Default::default()) .unwrap() } else { let collector = get_collector(agg_req); @@ -460,7 +460,7 @@ fn test_aggregation_level2( let searcher = reader.searcher(); let res = searcher.search(&term_query, &collector).unwrap(); - res.into_final_result(agg_req.clone(), &Default::default()) + res.into_final_result(agg_req.clone(), Default::default()) .unwrap() } else { let collector = get_collector(agg_req.clone()); diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs index 833fd9cd5..b690b0064 100644 --- a/src/aggregation/bucket/histogram/histogram.rs +++ b/src/aggregation/bucket/histogram/histogram.rs @@ -438,7 +438,7 @@ fn intermediate_buckets_to_final_buckets_fill_gaps( buckets: Vec, histogram_req: &HistogramAggregation, sub_aggregation: &Aggregations, - limits: &AggregationLimits, + limits: &mut AggregationLimitsGuard, ) -> crate::Result> { // Generate the full list of buckets without gaps. // @@ -496,7 +496,7 @@ pub(crate) fn intermediate_histogram_buckets_to_final_buckets( is_date_agg: bool, histogram_req: &HistogramAggregation, sub_aggregation: &Aggregations, - limits: &AggregationLimits, + limits: &mut AggregationLimitsGuard, ) -> crate::Result> { // Normalization is column type dependent. // The request used in the the call to final is not yet be normalized. @@ -750,7 +750,7 @@ mod tests { agg_req, &index, None, - AggregationLimits::new(Some(5_000), None), + AggregationLimitsGuard::new(Some(5_000), None), ) .unwrap_err(); assert!(res.to_string().starts_with( diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs index 2e29d97ae..9761f610d 100644 --- a/src/aggregation/bucket/range.rs +++ b/src/aggregation/bucket/range.rs @@ -4,7 +4,6 @@ use std::ops::Range; use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; -use crate::aggregation::agg_limits::ResourceLimitGuard; use crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, @@ -270,7 +269,7 @@ impl SegmentRangeCollector { pub(crate) fn from_req_and_validate( req: &RangeAggregation, sub_aggregation: &mut AggregationsWithAccessor, - limits: &ResourceLimitGuard, + limits: &mut AggregationLimitsGuard, field_type: ColumnType, accessor_idx: usize, ) -> crate::Result { @@ -471,7 +470,7 @@ mod tests { SegmentRangeCollector::from_req_and_validate( &req, &mut Default::default(), - &AggregationLimits::default().new_guard(), + &mut AggregationLimitsGuard::default(), field_type, 0, ) diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index 75a23761b..951576ac9 100644 --- a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg.rs @@ -669,7 +669,7 @@ mod tests { exec_request, exec_request_with_query, exec_request_with_query_and_memory_limit, get_test_index_from_terms, get_test_index_from_values_and_terms, }; - use crate::aggregation::AggregationLimits; + use crate::aggregation::AggregationLimitsGuard; use crate::indexer::NoMergePolicy; use crate::schema::{IntoIpv6Addr, Schema, FAST, STRING}; use crate::{Index, IndexWriter}; @@ -1424,7 +1424,7 @@ mod tests { agg_req, &index, None, - AggregationLimits::new(Some(50_000), None), + AggregationLimitsGuard::new(Some(50_000), None), ) .unwrap_err(); assert!(res diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs index 2b9ee6f61..3c5ad4eae 100644 --- a/src/aggregation/collector.rs +++ b/src/aggregation/collector.rs @@ -4,7 +4,7 @@ use super::agg_result::AggregationResults; use super::buf_collector::BufAggregationCollector; use super::intermediate_agg_result::IntermediateAggregationResults; use super::segment_agg_result::{ - build_segment_agg_collector, AggregationLimits, SegmentAggregationCollector, + build_segment_agg_collector, AggregationLimitsGuard, SegmentAggregationCollector, }; use crate::aggregation::agg_req_with_accessor::get_aggs_with_segment_accessor_and_validate; use crate::collector::{Collector, SegmentCollector}; @@ -22,7 +22,7 @@ pub const DEFAULT_MEMORY_LIMIT: u64 = 500_000_000; /// The collector collects all aggregations by the underlying aggregation request. pub struct AggregationCollector { agg: Aggregations, - limits: AggregationLimits, + limits: AggregationLimitsGuard, } impl AggregationCollector { @@ -30,7 +30,7 @@ impl AggregationCollector { /// /// Aggregation fails when the limits in `AggregationLimits` is exceeded. (memory limit and /// bucket limit) - pub fn from_aggs(agg: Aggregations, limits: AggregationLimits) -> Self { + pub fn from_aggs(agg: Aggregations, limits: AggregationLimitsGuard) -> Self { Self { agg, limits } } } @@ -45,7 +45,7 @@ impl AggregationCollector { /// into the final `AggregationResults` via the `into_final_result()` method. pub struct DistributedAggregationCollector { agg: Aggregations, - limits: AggregationLimits, + limits: AggregationLimitsGuard, } impl DistributedAggregationCollector { @@ -53,7 +53,7 @@ impl DistributedAggregationCollector { /// /// Aggregation fails when the limits in `AggregationLimits` is exceeded. (memory limit and /// bucket limit) - pub fn from_aggs(agg: Aggregations, limits: AggregationLimits) -> Self { + pub fn from_aggs(agg: Aggregations, limits: AggregationLimitsGuard) -> Self { Self { agg, limits } } } @@ -115,7 +115,7 @@ impl Collector for AggregationCollector { segment_fruits: Vec<::Fruit>, ) -> crate::Result { let res = merge_fruits(segment_fruits)?; - res.into_final_result(self.agg.clone(), &self.limits) + res.into_final_result(self.agg.clone(), self.limits.clone()) } } @@ -147,7 +147,7 @@ impl AggregationSegmentCollector { agg: &Aggregations, reader: &SegmentReader, segment_ordinal: SegmentOrdinal, - limits: &AggregationLimits, + limits: &AggregationLimitsGuard, ) -> crate::Result { let mut aggs_with_accessor = get_aggs_with_segment_accessor_and_validate(agg, reader, segment_ordinal, limits)?; diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index f4cef1a51..0d94af933 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -22,7 +22,7 @@ use super::metric::{ IntermediateAverage, IntermediateCount, IntermediateExtendedStats, IntermediateMax, IntermediateMin, IntermediateStats, IntermediateSum, PercentilesCollector, TopHitsTopNComputer, }; -use super::segment_agg_result::AggregationLimits; +use super::segment_agg_result::AggregationLimitsGuard; use super::{format_date, AggregationError, Key, SerializedKey}; use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry}; use crate::aggregation::bucket::TermsAggregationInternal; @@ -122,9 +122,9 @@ impl IntermediateAggregationResults { pub fn into_final_result( self, req: Aggregations, - limits: &AggregationLimits, + mut limits: AggregationLimitsGuard, ) -> crate::Result { - let res = self.into_final_result_internal(&req, limits)?; + let res = self.into_final_result_internal(&req, &mut limits)?; let bucket_count = res.get_bucket_count() as u32; if bucket_count > limits.get_bucket_limit() { return Err(TantivyError::AggregationError( @@ -141,7 +141,7 @@ impl IntermediateAggregationResults { pub(crate) fn into_final_result_internal( self, req: &Aggregations, - limits: &AggregationLimits, + limits: &mut AggregationLimitsGuard, ) -> crate::Result { let mut results: FxHashMap = FxHashMap::default(); for (key, agg_res) in self.aggs_res.into_iter() { @@ -257,7 +257,7 @@ impl IntermediateAggregationResult { pub(crate) fn into_final_result( self, req: &Aggregation, - limits: &AggregationLimits, + limits: &mut AggregationLimitsGuard, ) -> crate::Result { let res = match self { IntermediateAggregationResult::Bucket(bucket) => { @@ -432,7 +432,7 @@ impl IntermediateBucketResult { pub(crate) fn into_final_bucket_result( self, req: &Aggregation, - limits: &AggregationLimits, + limits: &mut AggregationLimitsGuard, ) -> crate::Result { match self { IntermediateBucketResult::Range(range_res) => { @@ -596,7 +596,7 @@ impl IntermediateTermBucketResult { self, req: &TermsAggregation, sub_aggregation_req: &Aggregations, - limits: &AggregationLimits, + limits: &mut AggregationLimitsGuard, ) -> crate::Result { let req = TermsAggregationInternal::from_req(req); let mut buckets: Vec = self @@ -723,7 +723,7 @@ impl IntermediateHistogramBucketEntry { pub(crate) fn into_final_bucket_entry( self, req: &Aggregations, - limits: &AggregationLimits, + limits: &mut AggregationLimitsGuard, ) -> crate::Result { Ok(BucketEntry { key_as_string: None, @@ -758,7 +758,7 @@ impl IntermediateRangeBucketEntry { req: &Aggregations, _range_req: &RangeAggregation, column_type: Option, - limits: &AggregationLimits, + limits: &mut AggregationLimitsGuard, ) -> crate::Result { let mut range_bucket_entry = RangeBucketEntry { key: self.key.into(), diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index f5f7d3142..b87865d88 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -148,7 +148,7 @@ mod agg_tests; use core::fmt; -pub use agg_limits::AggregationLimits; +pub use agg_limits::AggregationLimitsGuard; pub use collector::{ AggregationCollector, AggregationSegmentCollector, DistributedAggregationCollector, DEFAULT_BUCKET_LIMIT, @@ -458,7 +458,7 @@ mod tests { agg_req: Aggregations, index: &Index, query: Option<(&str, &str)>, - limits: AggregationLimits, + limits: AggregationLimitsGuard, ) -> crate::Result { let collector = AggregationCollector::from_aggs(agg_req, limits); diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index 5023b943d..747543e23 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -5,7 +5,7 @@ use std::fmt::Debug; -pub(crate) use super::agg_limits::AggregationLimits; +pub(crate) use super::agg_limits::AggregationLimitsGuard; use super::agg_req::AggregationVariants; use super::agg_req_with_accessor::{AggregationWithAccessor, AggregationsWithAccessor}; use super::bucket::{SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector}; @@ -103,7 +103,7 @@ pub(crate) fn build_single_agg_segment_collector( Range(range_req) => Ok(Box::new(SegmentRangeCollector::from_req_and_validate( range_req, &mut req.sub_aggregation, - &req.limits, + &mut req.limits, req.field_type, accessor_idx, )?)),