From 45ff0e3c5c5c7130585ff3aa8bcb41586755d3d3 Mon Sep 17 00:00:00 2001 From: PSeitz Date: Mon, 8 May 2023 16:15:09 +0800 Subject: [PATCH] clear memory consumption in AggregationLimits (#2022) * clear memory consumption in AggregationLimits clear memory consumption in AggregationLimits at the end of segment collection * switch to ResourceLimitGuard * unduplicate code * merge methods * Apply suggestions from code review Co-authored-by: Paul Masurel --------- Co-authored-by: Paul Masurel --- src/aggregation/agg_limits.rs | 79 ++++++++++++++----- src/aggregation/agg_req_with_accessor.rs | 14 ++-- src/aggregation/bucket/histogram/histogram.rs | 19 ++--- src/aggregation/bucket/range.rs | 15 ++-- src/aggregation/bucket/term_agg.rs | 8 +- src/aggregation/collector.rs | 7 +- src/aggregation/mod.rs | 21 +++-- src/aggregation/segment_agg_result.rs | 24 +++--- 8 files changed, 116 insertions(+), 71 deletions(-) diff --git a/src/aggregation/agg_limits.rs b/src/aggregation/agg_limits.rs index 0cd1df8bd..df93a9c13 100644 --- a/src/aggregation/agg_limits.rs +++ b/src/aggregation/agg_limits.rs @@ -1,12 +1,11 @@ use std::collections::HashMap; -use std::sync::atomic::AtomicU64; +use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use common::ByteCount; use super::collector::DEFAULT_MEMORY_LIMIT; use super::{AggregationError, DEFAULT_BUCKET_LIMIT}; -use crate::TantivyError; /// An estimate for memory consumption. Non recursive pub trait MemoryConsumption { @@ -68,28 +67,68 @@ impl AggregationLimits { bucket_limit: bucket_limit.unwrap_or(DEFAULT_BUCKET_LIMIT), } } - pub(crate) fn validate_memory_consumption(&self) -> crate::Result<()> { - if self.get_memory_consumed() > self.memory_limit { - return Err(TantivyError::AggregationError( - AggregationError::MemoryExceeded { - limit: self.memory_limit, - current: self.get_memory_consumed(), - }, - )); + + /// Create a new ResourceLimitGuard, that will release the memory when dropped. + pub fn new_guard(&self) -> ResourceLimitGuard { + ResourceLimitGuard { + /// The counter which is shared between the aggregations for one request. + memory_consumption: Arc::clone(&self.memory_consumption), + /// The memory_limit in bytes + memory_limit: self.memory_limit, + allocated_with_the_guard: 0, } + } + + pub(crate) fn add_memory_consumed(&self, num_bytes: u64) -> crate::Result<()> { + self.memory_consumption + .fetch_add(num_bytes, Ordering::Relaxed); + validate_memory_consumption(&self.memory_consumption, self.memory_limit)?; Ok(()) } - pub(crate) fn add_memory_consumed(&self, num_bytes: u64) { - self.memory_consumption - .fetch_add(num_bytes, std::sync::atomic::Ordering::Relaxed); - } - /// Returns the estimated memory consumed by the aggregations - pub fn get_memory_consumed(&self) -> ByteCount { - self.memory_consumption - .load(std::sync::atomic::Ordering::Relaxed) - .into() - } + pub(crate) fn get_bucket_limit(&self) -> u32 { self.bucket_limit } } + +fn validate_memory_consumption( + memory_consumption: &AtomicU64, + memory_limit: ByteCount, +) -> Result<(), AggregationError> { + // Load the estimated memory consumed by the aggregations + let memory_consumed: ByteCount = memory_consumption.load(Ordering::Relaxed).into(); + if memory_consumed > memory_limit { + return Err(AggregationError::MemoryExceeded { + limit: memory_limit, + current: memory_consumed, + }); + } + Ok(()) +} + +pub struct ResourceLimitGuard { + /// The counter which is shared between the aggregations for one request. + memory_consumption: Arc, + /// The memory_limit in bytes + memory_limit: ByteCount, + /// Allocated memory with this guard. + allocated_with_the_guard: u64, +} + +impl ResourceLimitGuard { + pub(crate) fn add_memory_consumed(&self, num_bytes: u64) -> crate::Result<()> { + self.memory_consumption + .fetch_add(num_bytes, Ordering::Relaxed); + validate_memory_consumption(&self.memory_consumption, self.memory_limit)?; + Ok(()) + } +} + +impl Drop for ResourceLimitGuard { + /// Removes the memory consumed tracked by this _instance_ of AggregationLimits. + /// This is used to clear the segment specific memory consumption all at once. + fn drop(&mut self) { + self.memory_consumption + .fetch_sub(self.allocated_with_the_guard, Ordering::Relaxed); + } +} diff --git a/src/aggregation/agg_req_with_accessor.rs b/src/aggregation/agg_req_with_accessor.rs index 1ac2b7da6..46116e8cb 100644 --- a/src/aggregation/agg_req_with_accessor.rs +++ b/src/aggregation/agg_req_with_accessor.rs @@ -2,6 +2,7 @@ use columnar::{Column, ColumnBlockAccessor, ColumnType, StrColumn}; +use super::agg_limits::ResourceLimitGuard; use super::agg_req::{Aggregation, AggregationVariants, Aggregations}; use super::bucket::{ DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation, @@ -14,7 +15,7 @@ use super::segment_agg_result::AggregationLimits; use super::VecWithNames; use crate::SegmentReader; -#[derive(Clone, Default)] +#[derive(Default)] pub(crate) struct AggregationsWithAccessor { pub aggs: VecWithNames, } @@ -29,7 +30,6 @@ impl AggregationsWithAccessor { } } -#[derive(Clone)] pub struct AggregationWithAccessor { /// In general there can be buckets without fast field access, e.g. buckets that are created /// based on search terms. So eventually this needs to be Option or moved. @@ -37,7 +37,7 @@ pub struct AggregationWithAccessor { pub(crate) str_dict_column: Option, pub(crate) field_type: ColumnType, pub(crate) sub_aggregation: AggregationsWithAccessor, - pub(crate) limits: AggregationLimits, + pub(crate) limits: ResourceLimitGuard, pub(crate) column_block_accessor: ColumnBlockAccessor, pub(crate) agg: Aggregation, } @@ -106,14 +106,14 @@ impl AggregationWithAccessor { Ok(AggregationWithAccessor { accessor, field_type, - sub_aggregation: get_aggs_with_accessor_and_validate( + sub_aggregation: get_aggs_with_segment_accessor_and_validate( &sub_aggregation, reader, - &limits.clone(), + &limits, )?, agg: agg.clone(), str_dict_column, - limits, + limits: limits.new_guard(), column_block_accessor: Default::default(), }) } @@ -128,7 +128,7 @@ fn get_numeric_or_date_column_types() -> &'static [ColumnType] { ] } -pub(crate) fn get_aggs_with_accessor_and_validate( +pub(crate) fn get_aggs_with_segment_accessor_and_validate( aggs: &Aggregations, reader: &SegmentReader, limits: &AggregationLimits, diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs index 55392acdc..010187a0b 100644 --- a/src/aggregation/bucket/histogram/histogram.rs +++ b/src/aggregation/bucket/histogram/histogram.rs @@ -281,9 +281,9 @@ impl SegmentAggregationCollector for SegmentHistogramCollector { } let mem_delta = self.get_memory_consumption() - mem_pre; - let limits = &agg_with_accessor.aggs.values[self.accessor_idx].limits; - limits.add_memory_consumed(mem_delta as u64); - limits.validate_memory_consumption()?; + bucket_agg_accessor + .limits + .add_memory_consumed(mem_delta as u64)?; Ok(()) } @@ -335,7 +335,7 @@ impl SegmentHistogramCollector { pub(crate) fn from_req_and_validate( req: &HistogramAggregation, - sub_aggregation: &AggregationsWithAccessor, + sub_aggregation: &mut AggregationsWithAccessor, field_type: ColumnType, accessor_idx: usize, ) -> crate::Result { @@ -402,8 +402,7 @@ fn intermediate_buckets_to_final_buckets_fill_gaps( .saturating_sub(buckets.len()); limits.add_memory_consumed( added_buckets as u64 * std::mem::size_of::() as u64, - ); - limits.validate_memory_consumption()?; + )?; // create buckets let fill_gaps_buckets = generate_buckets_with_opt_minmax(histogram_req, min_max); @@ -693,11 +692,9 @@ mod tests { AggregationLimits::new(Some(5_000), None), ) .unwrap_err(); - assert_eq!( - res.to_string(), - "Aborting aggregation because memory limit was exceeded. Limit: 5.00 KB, Current: \ - 59.82 KB" - ); + assert!(res.to_string().starts_with( + "Aborting aggregation because memory limit was exceeded. Limit: 5.00 KB, Current" + )); Ok(()) } diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs index c03414558..f478b2f2e 100644 --- a/src/aggregation/bucket/range.rs +++ b/src/aggregation/bucket/range.rs @@ -5,13 +5,14 @@ use columnar::{ColumnType, MonotonicallyMappableToU64}; use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; +use crate::aggregation::agg_limits::ResourceLimitGuard; use crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, IntermediateRangeBucketEntry, IntermediateRangeBucketResult, }; use crate::aggregation::segment_agg_result::{ - build_segment_agg_collector, AggregationLimits, SegmentAggregationCollector, + build_segment_agg_collector, SegmentAggregationCollector, }; use crate::aggregation::{ f64_from_fastfield_u64, f64_to_fastfield_u64, format_date, Key, SerializedKey, @@ -260,8 +261,8 @@ impl SegmentAggregationCollector for SegmentRangeCollector { impl SegmentRangeCollector { pub(crate) fn from_req_and_validate( req: &RangeAggregation, - sub_aggregation: &AggregationsWithAccessor, - limits: &AggregationLimits, + sub_aggregation: &mut AggregationsWithAccessor, + limits: &mut ResourceLimitGuard, field_type: ColumnType, accessor_idx: usize, ) -> crate::Result { @@ -307,8 +308,7 @@ impl SegmentRangeCollector { limits.add_memory_consumed( buckets.len() as u64 * std::mem::size_of::() as u64, - ); - limits.validate_memory_consumption()?; + )?; Ok(SegmentRangeCollector { buckets, @@ -450,6 +450,7 @@ mod tests { exec_request, exec_request_with_query, get_test_index_2_segments, get_test_index_with_num_docs, }; + use crate::aggregation::AggregationLimits; pub fn get_collector_from_ranges( ranges: Vec, @@ -463,8 +464,8 @@ mod tests { SegmentRangeCollector::from_req_and_validate( &req, - &Default::default(), - &Default::default(), + &mut Default::default(), + &mut AggregationLimits::default().new_guard(), field_type, 0, ) diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index 19b0c948c..80f2ce30e 100644 --- a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg.rs @@ -295,9 +295,9 @@ impl SegmentAggregationCollector for SegmentTermCollector { } let mem_delta = self.get_memory_consumption() - mem_pre; - let limits = &agg_with_accessor.aggs.values[self.accessor_idx].limits; - limits.add_memory_consumed(mem_delta as u64); - limits.validate_memory_consumption()?; + bucket_agg_accessor + .limits + .add_memory_consumed(mem_delta as u64)?; Ok(()) } @@ -320,7 +320,7 @@ impl SegmentTermCollector { pub(crate) fn from_req_and_validate( req: &TermsAggregation, - sub_aggregations: &AggregationsWithAccessor, + sub_aggregations: &mut AggregationsWithAccessor, field_type: ColumnType, accessor_idx: usize, ) -> crate::Result { diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs index 183cc2425..b3a0ed917 100644 --- a/src/aggregation/collector.rs +++ b/src/aggregation/collector.rs @@ -6,7 +6,7 @@ use super::intermediate_agg_result::IntermediateAggregationResults; use super::segment_agg_result::{ build_segment_agg_collector, AggregationLimits, SegmentAggregationCollector, }; -use crate::aggregation::agg_req_with_accessor::get_aggs_with_accessor_and_validate; +use crate::aggregation::agg_req_with_accessor::get_aggs_with_segment_accessor_and_validate; use crate::collector::{Collector, SegmentCollector}; use crate::{DocId, SegmentReader, TantivyError}; @@ -137,9 +137,10 @@ impl AggregationSegmentCollector { reader: &SegmentReader, limits: &AggregationLimits, ) -> crate::Result { - let aggs_with_accessor = get_aggs_with_accessor_and_validate(agg, reader, limits)?; + let mut aggs_with_accessor = + get_aggs_with_segment_accessor_and_validate(agg, reader, limits)?; let result = - BufAggregationCollector::new(build_segment_agg_collector(&aggs_with_accessor)?); + BufAggregationCollector::new(build_segment_agg_collector(&mut aggs_with_accessor)?); Ok(AggregationSegmentCollector { aggs_with_accessor, agg_collector: result, diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index d878b76fe..32741a29c 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -156,13 +156,22 @@ use itertools::Itertools; use serde::{Deserialize, Serialize}; /// Represents an associative array `(key => values)` in a very efficient manner. -#[derive(Clone, PartialEq, Serialize, Deserialize)] -pub(crate) struct VecWithNames { +#[derive(PartialEq, Serialize, Deserialize)] +pub(crate) struct VecWithNames { pub(crate) values: Vec, keys: Vec, } -impl Default for VecWithNames { +impl Clone for VecWithNames { + fn clone(&self) -> Self { + Self { + values: self.values.clone(), + keys: self.keys.clone(), + } + } +} + +impl Default for VecWithNames { fn default() -> Self { Self { values: Default::default(), @@ -171,19 +180,19 @@ impl Default for VecWithNames { } } -impl std::fmt::Debug for VecWithNames { +impl std::fmt::Debug for VecWithNames { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_map().entries(self.iter()).finish() } } -impl From> for VecWithNames { +impl From> for VecWithNames { fn from(map: HashMap) -> Self { VecWithNames::from_entries(map.into_iter().collect_vec()) } } -impl VecWithNames { +impl VecWithNames { fn push(&mut self, key: String, value: T) { self.keys.push(key); self.values.push(value); diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index 288e3f56c..27df21314 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -61,11 +61,11 @@ impl Clone for Box { } pub(crate) fn build_segment_agg_collector( - req: &AggregationsWithAccessor, + req: &mut AggregationsWithAccessor, ) -> crate::Result> { // Single collector special case if req.aggs.len() == 1 { - let req = &req.aggs.values[0]; + let req = &mut req.aggs.values[0]; let accessor_idx = 0; return build_single_agg_segment_collector(req, accessor_idx); } @@ -75,33 +75,33 @@ pub(crate) fn build_segment_agg_collector( } pub(crate) fn build_single_agg_segment_collector( - req: &AggregationWithAccessor, + req: &mut AggregationWithAccessor, accessor_idx: usize, ) -> crate::Result> { use AggregationVariants::*; match &req.agg.agg { Terms(terms_req) => Ok(Box::new(SegmentTermCollector::from_req_and_validate( terms_req, - &req.sub_aggregation, + &mut req.sub_aggregation, req.field_type, accessor_idx, )?)), Range(range_req) => Ok(Box::new(SegmentRangeCollector::from_req_and_validate( range_req, - &req.sub_aggregation, - &req.limits, + &mut req.sub_aggregation, + &mut req.limits, req.field_type, accessor_idx, )?)), Histogram(histogram) => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate( histogram, - &req.sub_aggregation, + &mut req.sub_aggregation, req.field_type, accessor_idx, )?)), DateHistogram(histogram) => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate( &histogram.to_histogram_req()?, - &req.sub_aggregation, + &mut req.sub_aggregation, req.field_type, accessor_idx, )?)), @@ -205,14 +205,12 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector { } impl GenericSegmentAggregationResultsCollector { - pub(crate) fn from_req_and_validate(req: &AggregationsWithAccessor) -> crate::Result { + pub(crate) fn from_req_and_validate(req: &mut AggregationsWithAccessor) -> crate::Result { let aggs = req .aggs - .iter() + .values_mut() .enumerate() - .map(|(accessor_idx, (_key, req))| { - build_single_agg_segment_collector(req, accessor_idx) - }) + .map(|(accessor_idx, req)| build_single_agg_segment_collector(req, accessor_idx)) .collect::>>>()?; Ok(GenericSegmentAggregationResultsCollector { aggs })