From 9e2faecf5b5c75bd72cf9b8009a1e9bdb6fc25d0 Mon Sep 17 00:00:00 2001 From: PSeitz Date: Thu, 16 Mar 2023 13:21:07 +0800 Subject: [PATCH] add memory limit for aggregations (#1942) * add memory limit for aggregations introduce AggregationLimits to set memory consumption limit and bucket limits memory limit is checked during aggregation, bucket limit is checked before returning the aggregation request. * Apply suggestions from code review Co-authored-by: Paul Masurel * add ByteCount with human readable format --------- Co-authored-by: Paul Masurel --- columnar/src/dynamic_column.rs | 6 +- common/src/bitset.rs | 6 +- common/src/byte_count.rs | 108 ++++++++++++++++++ common/src/file_slice.rs | 7 +- common/src/lib.rs | 2 + examples/aggregation.rs | 6 +- src/aggregation/agg_limits.rs | 94 +++++++++++++++ src/aggregation/agg_req_with_accessor.rs | 23 ++-- src/aggregation/agg_result.rs | 57 ++++++++- src/aggregation/agg_tests.rs | 65 ++++++----- src/aggregation/bucket/histogram/histogram.rs | 100 ++++++++++++++-- src/aggregation/bucket/range.rs | 10 +- src/aggregation/collector.rs | 55 ++++----- src/aggregation/error.rs | 26 ++++- src/aggregation/intermediate_agg_result.rs | 48 ++++++-- src/aggregation/metric/mod.rs | 2 +- src/aggregation/metric/stats.rs | 6 +- src/aggregation/mod.rs | 18 ++- src/aggregation/segment_agg_result.rs | 41 +------ src/core/segment_reader.rs | 2 +- src/directory/composite_file.rs | 2 +- src/error.rs | 2 +- src/fastfield/alive_bitset.rs | 3 +- src/fastfield/mod.rs | 8 +- src/fastfield/readers.rs | 7 +- src/space_usage/mod.rs | 40 +++---- src/store/reader.rs | 5 +- 27 files changed, 556 insertions(+), 193 deletions(-) create mode 100644 common/src/byte_count.rs create mode 100644 src/aggregation/agg_limits.rs diff --git a/columnar/src/dynamic_column.rs b/columnar/src/dynamic_column.rs index 675ad99e7..31e3bab45 100644 --- a/columnar/src/dynamic_column.rs +++ b/columnar/src/dynamic_column.rs @@ -3,7 +3,7 @@ use std::net::Ipv6Addr; use std::sync::Arc; use common::file_slice::FileSlice; -use common::{DateTime, HasLen, OwnedBytes}; +use common::{ByteCount, DateTime, HasLen, OwnedBytes}; use crate::column::{BytesColumn, Column, StrColumn}; use crate::column_values::{monotonic_map_column, StrictlyMonotonicFn}; @@ -248,8 +248,8 @@ impl DynamicColumnHandle { Ok(dynamic_column) } - pub fn num_bytes(&self) -> usize { - self.file_slice.len() + pub fn num_bytes(&self) -> ByteCount { + self.file_slice.len().into() } pub fn column_type(&self) -> ColumnType { diff --git a/common/src/bitset.rs b/common/src/bitset.rs index 74d687f46..6932b0416 100644 --- a/common/src/bitset.rs +++ b/common/src/bitset.rs @@ -4,6 +4,8 @@ use std::{fmt, io, u64}; use ownedbytes::OwnedBytes; +use crate::ByteCount; + #[derive(Clone, Copy, Eq, PartialEq)] pub struct TinySet(u64); @@ -386,8 +388,8 @@ impl ReadOnlyBitSet { } /// Number of bytes used in the bitset representation. 
-    pub fn num_bytes(&self) -> usize {
-        self.data.len()
+    pub fn num_bytes(&self) -> ByteCount {
+        self.data.len().into()
     }
 }
diff --git a/common/src/byte_count.rs b/common/src/byte_count.rs
new file mode 100644
index 000000000..1aa969dcb
--- /dev/null
+++ b/common/src/byte_count.rs
@@ -0,0 +1,108 @@
+use std::iter::Sum;
+use std::ops::{Add, AddAssign};
+
+use serde::{Deserialize, Serialize};
+
+/// Indicates space usage in bytes
+#[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
+pub struct ByteCount(u64);
+
+impl std::fmt::Debug for ByteCount {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str(&self.human_readable())
+    }
+}
+
+const SUFFIX_AND_THRESHOLD: [(&str, u64); 5] = [
+    ("KB", 1_000),
+    ("MB", 1_000_000),
+    ("GB", 1_000_000_000),
+    ("TB", 1_000_000_000_000),
+    ("PB", 1_000_000_000_000_000),
+];
+
+impl ByteCount {
+    #[inline]
+    pub fn get_bytes(&self) -> u64 {
+        self.0
+    }
+
+    pub fn human_readable(&self) -> String {
+        for (suffix, threshold) in SUFFIX_AND_THRESHOLD.iter().rev() {
+            if self.get_bytes() >= *threshold {
+                let unit_num = self.get_bytes() as f64 / *threshold as f64;
+                return format!("{:.2} {}", unit_num, suffix);
+            }
+        }
+        format!("{:.2} B", self.get_bytes())
+    }
+}
+
+impl From<u64> for ByteCount {
+    fn from(value: u64) -> Self {
+        ByteCount(value)
+    }
+}
+impl From<usize> for ByteCount {
+    fn from(value: usize) -> Self {
+        ByteCount(value as u64)
+    }
+}
+
+impl Sum for ByteCount {
+    #[inline]
+    fn sum<I: Iterator<Item = ByteCount>>(iter: I) -> Self {
+        iter.fold(ByteCount::default(), |acc, x| acc + x)
+    }
+}
+
+impl PartialEq<u64> for ByteCount {
+    #[inline]
+    fn eq(&self, other: &u64) -> bool {
+        self.get_bytes() == *other
+    }
+}
+
+impl PartialOrd<u64> for ByteCount {
+    #[inline]
+    fn partial_cmp(&self, other: &u64) -> Option<std::cmp::Ordering> {
+        self.get_bytes().partial_cmp(other)
+    }
+}
+
+impl Add for ByteCount {
+    type Output = Self;
+
+    #[inline]
+    fn add(self, other: Self) -> Self {
+        Self(self.get_bytes() + other.get_bytes())
+    }
+}
+
+impl AddAssign for ByteCount {
+    #[inline]
+    fn add_assign(&mut self, other: Self) {
+        *self = Self(self.get_bytes() + other.get_bytes());
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use crate::ByteCount;
+
+    #[test]
+    fn test_bytes() {
+        assert_eq!(ByteCount::from(0u64).human_readable(), "0 B");
+        assert_eq!(ByteCount::from(300u64).human_readable(), "300 B");
+        assert_eq!(ByteCount::from(1_000_000u64).human_readable(), "1.00 MB");
+        assert_eq!(ByteCount::from(1_500_000u64).human_readable(), "1.50 MB");
+        assert_eq!(
+            ByteCount::from(1_500_000_000u64).human_readable(),
+            "1.50 GB"
+        );
+        assert_eq!(
+            ByteCount::from(3_213_000_000_000u64).human_readable(),
+            "3.21 TB"
+        );
+    }
+}
diff --git a/common/src/file_slice.rs b/common/src/file_slice.rs
index ae4175d10..1ebe2d600 100644
--- a/common/src/file_slice.rs
+++ b/common/src/file_slice.rs
@@ -5,7 +5,7 @@ use std::{fmt, io};
 use async_trait::async_trait;
 use ownedbytes::{OwnedBytes, StableDeref};
 
-use crate::HasLen;
+use crate::{ByteCount, HasLen};
 
 /// Objects that represents files sections in tantivy.
 ///
@@ -216,6 +216,11 @@ impl FileSlice {
     pub fn slice_to(&self, to_offset: usize) -> FileSlice {
         self.slice(0..to_offset)
     }
+
+    /// Returns the byte count of the FileSlice.
+    pub fn num_bytes(&self) -> ByteCount {
+        self.range.len().into()
+    }
 }
 
 #[async_trait]
diff --git a/common/src/lib.rs b/common/src/lib.rs
index 93e44a82b..68a29ed4d 100644
--- a/common/src/lib.rs
+++ b/common/src/lib.rs
@@ -5,6 +5,7 @@ use std::ops::Deref;
 pub use byteorder::LittleEndian as Endianness;
 
 mod bitset;
+mod byte_count;
 mod datetime;
 pub mod file_slice;
 mod group_by;
@@ -12,6 +13,7 @@ mod serialize;
 mod vint;
 mod writer;
 pub use bitset::*;
+pub use byte_count::ByteCount;
 pub use datetime::{DatePrecision, DateTime};
 pub use group_by::GroupByIteratorExtended;
 pub use ownedbytes::{OwnedBytes, StableDeref};
diff --git a/examples/aggregation.rs b/examples/aggregation.rs
index 5ff62c717..24c11513c 100644
--- a/examples/aggregation.rs
+++ b/examples/aggregation.rs
@@ -192,7 +192,7 @@ fn main() -> tantivy::Result<()> {
     //
     let agg_req: Aggregations = serde_json::from_str(agg_req_str)?;
 
-    let collector = AggregationCollector::from_aggs(agg_req, None);
+    let collector = AggregationCollector::from_aggs(agg_req, Default::default());
 
     let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();
     let res2: Value = serde_json::to_value(agg_res)?;
@@ -239,7 +239,7 @@ fn main() -> tantivy::Result<()> {
     .into_iter()
     .collect();
 
-    let collector = AggregationCollector::from_aggs(agg_req, None);
+    let collector = AggregationCollector::from_aggs(agg_req, Default::default());
 
     // We use the `AllQuery` which will pass all documents to the AggregationCollector.
     let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();
@@ -287,7 +287,7 @@ fn main() -> tantivy::Result<()> {
 
     let agg_req: Aggregations = serde_json::from_str(agg_req_str)?;
 
-    let collector = AggregationCollector::from_aggs(agg_req, None);
+    let collector = AggregationCollector::from_aggs(agg_req, Default::default());
 
     let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();
     let res: Value = serde_json::to_value(agg_res)?;
diff --git a/src/aggregation/agg_limits.rs b/src/aggregation/agg_limits.rs
new file mode 100644
index 000000000..c55dec02e
--- /dev/null
+++ b/src/aggregation/agg_limits.rs
@@ -0,0 +1,94 @@
+use std::collections::HashMap;
+use std::sync::atomic::AtomicU64;
+use std::sync::Arc;
+
+use common::ByteCount;
+
+use super::collector::DEFAULT_MEMORY_LIMIT;
+use super::{AggregationError, DEFAULT_BUCKET_LIMIT};
+use crate::TantivyError;
+
+/// An estimate for memory consumption
+pub trait MemoryConsumption {
+    fn memory_consumption(&self) -> usize;
+}
+
+impl<K, V, S> MemoryConsumption for HashMap<K, V, S> {
+    fn memory_consumption(&self) -> usize {
+        let num_items = self.capacity();
+        (std::mem::size_of::<K>() + std::mem::size_of::<V>()) * num_items
+    }
+}
+
+/// Aggregation memory limit after which the request fails. Defaults to DEFAULT_MEMORY_LIMIT
+/// (500MB). The limit is shared by all SegmentCollectors.
+pub struct AggregationLimits {
+    /// The counter which is shared between the aggregations for one request.
+    memory_consumption: Arc<AtomicU64>,
+    /// The memory_limit in bytes
+    memory_limit: ByteCount,
+    /// The maximum number of buckets _returned_.
+    /// This does not count intermediate buckets.
+    bucket_limit: u32,
+}
+impl Clone for AggregationLimits {
+    fn clone(&self) -> Self {
+        Self {
+            memory_consumption: Arc::clone(&self.memory_consumption),
+            memory_limit: self.memory_limit,
+            bucket_limit: self.bucket_limit,
+        }
+    }
+}
+
+impl Default for AggregationLimits {
+    fn default() -> Self {
+        Self {
+            memory_consumption: Default::default(),
+            memory_limit: DEFAULT_MEMORY_LIMIT.into(),
+            bucket_limit: DEFAULT_BUCKET_LIMIT,
+        }
+    }
+}
+
+impl AggregationLimits {
+    /// *memory_limit*
+    /// memory_limit is defined in bytes.
+    /// Aggregation fails when the estimated memory consumption of the aggregation is higher than
+    /// memory_limit.
+    /// memory_limit will default to `DEFAULT_MEMORY_LIMIT` (500MB)
+    ///
+    /// *bucket_limit*
+    /// Limits the maximum number of buckets returned from an aggregation request.
+    /// bucket_limit will default to `DEFAULT_BUCKET_LIMIT` (65000)
+    pub fn new(memory_limit: Option<u64>, bucket_limit: Option<u32>) -> Self {
+        Self {
+            memory_consumption: Default::default(),
+            memory_limit: memory_limit.unwrap_or(DEFAULT_MEMORY_LIMIT).into(),
+            bucket_limit: bucket_limit.unwrap_or(DEFAULT_BUCKET_LIMIT),
+        }
+    }
+    pub(crate) fn validate_memory_consumption(&self) -> crate::Result<()> {
+        if self.get_memory_consumed() > self.memory_limit {
+            return Err(TantivyError::AggregationError(
+                AggregationError::MemoryExceeded {
+                    limit: self.memory_limit,
+                    current: self.get_memory_consumed(),
+                },
+            ));
+        }
+        Ok(())
+    }
+    pub(crate) fn add_memory_consumed(&self, num_bytes: u64) {
+        self.memory_consumption
+            .fetch_add(num_bytes, std::sync::atomic::Ordering::Relaxed);
+    }
+    pub fn get_memory_consumed(&self) -> ByteCount {
+        self.memory_consumption
+            .load(std::sync::atomic::Ordering::Relaxed)
+            .into()
+    }
+    pub fn get_bucket_limit(&self) -> u32 {
+        self.bucket_limit
+    }
+}
diff --git a/src/aggregation/agg_req_with_accessor.rs b/src/aggregation/agg_req_with_accessor.rs
index 66c7d1afc..c5d6c21b6 100644
--- a/src/aggregation/agg_req_with_accessor.rs
+++ b/src/aggregation/agg_req_with_accessor.rs
@@ -1,7 +1,5 @@
 //! This will enhance the request tree with access to the fastfield and metadata.
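The shared `Arc<AtomicU64>` in `AggregationLimits` above is what makes the memory limit request-wide: every per-segment collector clones the same `AggregationLimits`, adds its estimated allocations to that one counter, and re-checks the limit. A minimal, crate-internal sketch of the intended accounting pattern follows; the collector type and its bucket map are hypothetical, and only `AggregationLimits` and `MemoryConsumption` come from this patch.

```rust
// Paths as seen from inside the `aggregation` module (agg_limits is private to it).
use std::collections::HashMap;

use crate::aggregation::agg_limits::{AggregationLimits, MemoryConsumption};

/// Hypothetical per-segment state, used only to illustrate the accounting calls.
struct ExampleSegmentBuckets {
    buckets: HashMap<u64, u64>,
}

impl ExampleSegmentBuckets {
    fn report_memory(&self, limits: &AggregationLimits) -> crate::Result<()> {
        // Estimate the footprint of the bucket map via the MemoryConsumption helper...
        let estimate = self.buckets.memory_consumption() as u64;
        // ...add it to the counter shared by every segment collector of this request...
        limits.add_memory_consumed(estimate);
        // ...and abort the whole aggregation once the shared total exceeds the limit.
        limits.validate_memory_consumption()
    }
}
```

Cloning an `AggregationLimits` clones the `Arc`, not the counter, so the amounts reported by different segments accumulate into a single request-wide figure.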
-use std::rc::Rc; -use std::sync::atomic::AtomicU32; use std::sync::Arc; use columnar::{Column, ColumnType, ColumnValues, StrColumn}; @@ -14,7 +12,7 @@ use super::metric::{ AverageAggregation, CountAggregation, MaxAggregation, MinAggregation, StatsAggregation, SumAggregation, }; -use super::segment_agg_result::BucketCount; +use super::segment_agg_result::AggregationLimits; use super::VecWithNames; use crate::SegmentReader; @@ -46,7 +44,7 @@ pub struct BucketAggregationWithAccessor { pub(crate) field_type: ColumnType, pub(crate) bucket_agg: BucketAggregationType, pub(crate) sub_aggregation: AggregationsWithAccessor, - pub(crate) bucket_count: BucketCount, + pub(crate) limits: AggregationLimits, } impl BucketAggregationWithAccessor { @@ -54,8 +52,7 @@ impl BucketAggregationWithAccessor { bucket: &BucketAggregationType, sub_aggregation: &Aggregations, reader: &SegmentReader, - bucket_count: Rc, - max_bucket_count: u32, + limits: AggregationLimits, ) -> crate::Result { let mut str_dict_column = None; let (accessor, field_type) = match &bucket { @@ -83,15 +80,11 @@ impl BucketAggregationWithAccessor { sub_aggregation: get_aggs_with_accessor_and_validate( &sub_aggregation, reader, - bucket_count.clone(), - max_bucket_count, + &limits.clone(), )?, bucket_agg: bucket.clone(), str_dict_column, - bucket_count: BucketCount { - bucket_count, - max_bucket_count, - }, + limits, }) } } @@ -131,8 +124,7 @@ impl MetricAggregationWithAccessor { pub(crate) fn get_aggs_with_accessor_and_validate( aggs: &Aggregations, reader: &SegmentReader, - bucket_count: Rc, - max_bucket_count: u32, + limits: &AggregationLimits, ) -> crate::Result { let mut metrics = vec![]; let mut buckets = vec![]; @@ -144,8 +136,7 @@ pub(crate) fn get_aggs_with_accessor_and_validate( &bucket.bucket_agg, &bucket.sub_aggregation, reader, - Rc::clone(&bucket_count), - max_bucket_count, + limits.clone(), )?, )), Aggregation::Metric(metric) => metrics.push(( diff --git a/src/aggregation/agg_result.rs b/src/aggregation/agg_result.rs index 6ce7e6749..d47a7d07b 100644 --- a/src/aggregation/agg_result.rs +++ b/src/aggregation/agg_result.rs @@ -11,6 +11,7 @@ use super::agg_req::BucketAggregationInternal; use super::bucket::GetDocCount; use super::intermediate_agg_result::{IntermediateBucketResult, IntermediateMetricResult}; use super::metric::{SingleMetricResult, Stats}; +use super::segment_agg_result::AggregationLimits; use super::Key; use crate::TantivyError; @@ -19,6 +20,13 @@ use crate::TantivyError; pub struct AggregationResults(pub FxHashMap); impl AggregationResults { + pub(crate) fn get_bucket_count(&self) -> u64 { + self.0 + .values() + .map(|agg| agg.get_bucket_count()) + .sum::() + } + pub(crate) fn get_value_from_aggregation( &self, name: &str, @@ -47,6 +55,13 @@ pub enum AggregationResult { } impl AggregationResult { + pub(crate) fn get_bucket_count(&self) -> u64 { + match self { + AggregationResult::BucketResult(bucket) => bucket.get_bucket_count(), + AggregationResult::MetricResult(_) => 0, + } + } + pub(crate) fn get_value_from_aggregation( &self, _name: &str, @@ -153,9 +168,28 @@ pub enum BucketResult { } impl BucketResult { - pub(crate) fn empty_from_req(req: &BucketAggregationInternal) -> crate::Result { + pub(crate) fn get_bucket_count(&self) -> u64 { + match self { + BucketResult::Range { buckets } => { + buckets.iter().map(|bucket| bucket.get_bucket_count()).sum() + } + BucketResult::Histogram { buckets } => { + buckets.iter().map(|bucket| bucket.get_bucket_count()).sum() + } + BucketResult::Terms { + buckets, + 
sum_other_doc_count: _, + doc_count_error_upper_bound: _, + } => buckets.iter().map(|bucket| bucket.get_bucket_count()).sum(), + } + } + + pub(crate) fn empty_from_req( + req: &BucketAggregationInternal, + limits: &AggregationLimits, + ) -> crate::Result { let empty_bucket = IntermediateBucketResult::empty_from_req(&req.bucket_agg); - empty_bucket.into_final_bucket_result(req) + empty_bucket.into_final_bucket_result(req, limits) } } @@ -170,6 +204,15 @@ pub enum BucketEntries { HashMap(FxHashMap), } +impl BucketEntries { + fn iter<'a>(&'a self) -> Box + 'a> { + match self { + BucketEntries::Vec(vec) => Box::new(vec.iter()), + BucketEntries::HashMap(map) => Box::new(map.values()), + } + } +} + /// This is the default entry for a bucket, which contains a key, count, and optionally /// sub-aggregations. /// @@ -209,6 +252,11 @@ pub struct BucketEntry { /// Sub-aggregations in this bucket. pub sub_aggregation: AggregationResults, } +impl BucketEntry { + pub(crate) fn get_bucket_count(&self) -> u64 { + 1 + self.sub_aggregation.get_bucket_count() + } +} impl GetDocCount for &BucketEntry { fn doc_count(&self) -> u64 { self.doc_count @@ -272,3 +320,8 @@ pub struct RangeBucketEntry { #[serde(skip_serializing_if = "Option::is_none")] pub to_as_string: Option, } +impl RangeBucketEntry { + pub(crate) fn get_bucket_count(&self) -> u64 { + 1 + self.sub_aggregation.get_bucket_count() + } +} diff --git a/src/aggregation/agg_tests.rs b/src/aggregation/agg_tests.rs index 818c43348..8819ddc52 100644 --- a/src/aggregation/agg_tests.rs +++ b/src/aggregation/agg_tests.rs @@ -9,6 +9,7 @@ use crate::aggregation::buf_collector::DOC_BLOCK_SIZE; use crate::aggregation::collector::AggregationCollector; use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults; use crate::aggregation::metric::AverageAggregation; +use crate::aggregation::segment_agg_result::AggregationLimits; use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms}; use crate::aggregation::DistributedAggregationCollector; use crate::query::{AllQuery, TermQuery}; @@ -21,6 +22,10 @@ fn get_avg_req(field_name: &str) -> Aggregation { )) } +fn get_collector(agg_req: Aggregations) -> AggregationCollector { + AggregationCollector::from_aggs(agg_req, Default::default()) +} + // *** EVERY BUCKET-TYPE SHOULD BE TESTED HERE *** fn test_aggregation_flushing( merge_segments: bool, @@ -98,15 +103,18 @@ fn test_aggregation_flushing( .unwrap(); let agg_res: AggregationResults = if use_distributed_collector { - let collector = DistributedAggregationCollector::from_aggs(agg_req.clone(), None); + let collector = DistributedAggregationCollector::from_aggs( + agg_req.clone(), + AggregationLimits::default(), + ); let searcher = reader.searcher(); let intermediate_agg_result = searcher.search(&AllQuery, &collector).unwrap(); intermediate_agg_result - .into_final_bucket_result(agg_req) + .into_final_bucket_result(agg_req, &Default::default()) .unwrap() } else { - let collector = AggregationCollector::from_aggs(agg_req, None); + let collector = get_collector(agg_req); let searcher = reader.searcher(); searcher.search(&AllQuery, &collector).unwrap() @@ -243,7 +251,7 @@ fn test_aggregation_level1() -> crate::Result<()> { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1, None); + let collector = get_collector(agg_req_1); let searcher = reader.searcher(); let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap(); @@ -432,16 +440,18 @@ fn 
test_aggregation_level2( }; let agg_res: AggregationResults = if use_distributed_collector { - let collector = DistributedAggregationCollector::from_aggs(agg_req.clone(), None); + let collector = + DistributedAggregationCollector::from_aggs(agg_req.clone(), Default::default()); let searcher = reader.searcher(); let res = searcher.search(&term_query, &collector).unwrap(); // Test de/serialization roundtrip on intermediate_agg_result let res: IntermediateAggregationResults = serde_json::from_str(&serde_json::to_string(&res).unwrap()).unwrap(); - res.into_final_bucket_result(agg_req.clone()).unwrap() + res.into_final_bucket_result(agg_req.clone(), &Default::default()) + .unwrap() } else { - let collector = AggregationCollector::from_aggs(agg_req.clone(), None); + let collector = get_collector(agg_req.clone()); let searcher = reader.searcher(); searcher.search(&term_query, &collector).unwrap() @@ -499,7 +509,7 @@ fn test_aggregation_level2( ); // Test empty result set - let collector = AggregationCollector::from_aggs(agg_req, None); + let collector = get_collector(agg_req); let searcher = reader.searcher(); searcher.search(&query_with_no_hits, &collector).unwrap(); @@ -562,7 +572,7 @@ fn test_aggregation_invalid_requests() -> crate::Result<()> { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1, None); + let collector = get_collector(agg_req_1); let searcher = reader.searcher(); @@ -620,7 +630,7 @@ fn test_aggregation_on_json_object() { )] .into_iter() .collect(); - let aggregation_collector = AggregationCollector::from_aggs(agg, None); + let aggregation_collector = get_collector(agg); let aggregation_results = searcher.search(&AllQuery, &aggregation_collector).unwrap(); let aggregation_res_json = serde_json::to_value(aggregation_results).unwrap(); assert_eq!( @@ -690,7 +700,7 @@ fn test_aggregation_on_json_object_empty_columns() { .into_iter() .collect(); - let aggregation_collector = AggregationCollector::from_aggs(agg, None); + let aggregation_collector = get_collector(agg); let aggregation_results = searcher.search(&AllQuery, &aggregation_collector).unwrap(); let aggregation_res_json = serde_json::to_value(aggregation_results).unwrap(); assert_eq!( @@ -721,9 +731,8 @@ fn test_aggregation_on_json_object_empty_columns() { } } "#; let agg: Aggregations = serde_json::from_str(agg_req_str).unwrap(); - let aggregation_results = searcher - .search(&AllQuery, &AggregationCollector::from_aggs(agg, None)) - .unwrap(); + let aggregation_collector = get_collector(agg); + let aggregation_results = searcher.search(&AllQuery, &aggregation_collector).unwrap(); let aggregation_res_json = serde_json::to_value(aggregation_results).unwrap(); assert_eq!( &aggregation_res_json, @@ -883,7 +892,7 @@ mod bench { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1, None); + let collector = get_collector(agg_req_1); let searcher = reader.searcher(); searcher.search(&term_query, &collector).unwrap() @@ -912,7 +921,7 @@ mod bench { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1, None); + let collector = get_collector(agg_req_1); let searcher = reader.searcher(); searcher.search(&term_query, &collector).unwrap() @@ -941,7 +950,7 @@ mod bench { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1, None); + let collector = get_collector(agg_req_1); let searcher = reader.searcher(); searcher.search(&term_query, &collector).unwrap() @@ -978,7 +987,7 @@ mod bench { .into_iter() 
.collect(); - let collector = AggregationCollector::from_aggs(agg_req_1, None); + let collector = get_collector(agg_req_1); let searcher = reader.searcher(); searcher.search(&term_query, &collector).unwrap() @@ -1008,7 +1017,7 @@ mod bench { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req, None); + let collector = get_collector(agg_req); let searcher = reader.searcher(); searcher.search(&AllQuery, &collector).unwrap() @@ -1047,7 +1056,7 @@ mod bench { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req, None); + let collector = get_collector(agg_req); let searcher = reader.searcher(); searcher.search(&AllQuery, &collector).unwrap() @@ -1077,7 +1086,7 @@ mod bench { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req, None); + let collector = get_collector(agg_req); let searcher = reader.searcher(); searcher.search(&AllQuery, &collector).unwrap() @@ -1111,7 +1120,7 @@ mod bench { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req, None); + let collector = get_collector(agg_req); let searcher = reader.searcher(); searcher.search(&AllQuery, &collector).unwrap() @@ -1149,7 +1158,7 @@ mod bench { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1, None); + let collector = get_collector(agg_req_1); let searcher = reader.searcher(); searcher.search(&AllQuery, &collector).unwrap() @@ -1196,7 +1205,7 @@ mod bench { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1, None); + let collector = get_collector(agg_req_1); let searcher = reader.searcher(); searcher.search(&AllQuery, &collector).unwrap() @@ -1236,7 +1245,7 @@ mod bench { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1, None); + let collector = get_collector(agg_req_1); let searcher = reader.searcher(); searcher.search(&AllQuery, &collector).unwrap() }); @@ -1275,7 +1284,7 @@ mod bench { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1, None); + let collector = get_collector(agg_req_1); let searcher = reader.searcher(); searcher.search(&AllQuery, &collector).unwrap() @@ -1306,7 +1315,7 @@ mod bench { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1, None); + let collector = get_collector(agg_req_1); let searcher = reader.searcher(); searcher.search(&AllQuery, &collector).unwrap() @@ -1364,7 +1373,7 @@ mod bench { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1, None); + let collector = get_collector(agg_req_1); let searcher = reader.searcher(); searcher.search(&term_query, &collector).unwrap() diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs index a1ab3ddb6..ee5f4b1c5 100644 --- a/src/aggregation/bucket/histogram/histogram.rs +++ b/src/aggregation/bucket/histogram/histogram.rs @@ -7,6 +7,7 @@ use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; use tantivy_bitpacker::minmax; +use crate::aggregation::agg_limits::MemoryConsumption; use crate::aggregation::agg_req::AggregationsInternal; use crate::aggregation::agg_req_with_accessor::{ AggregationsWithAccessor, BucketAggregationWithAccessor, @@ -16,7 +17,7 @@ use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry, }; use crate::aggregation::segment_agg_result::{ - build_segment_agg_collector, 
SegmentAggregationCollector, + build_segment_agg_collector, AggregationLimits, SegmentAggregationCollector, }; use crate::aggregation::{f64_from_fastfield_u64, format_date, VecWithNames}; use crate::{DocId, TantivyError}; @@ -249,6 +250,8 @@ impl SegmentAggregationCollector for SegmentHistogramCollector { let sub_aggregation_accessor = &agg_with_accessor.buckets.values[self.accessor_idx].sub_aggregation; + let mem_pre = self.get_memory_consumption(); + let bounds = self.bounds; let interval = self.interval; let offset = self.offset; @@ -271,6 +274,12 @@ impl SegmentAggregationCollector for SegmentHistogramCollector { } } } + + let mem_delta = self.get_memory_consumption() - mem_pre; + let limits = &agg_with_accessor.buckets.values[self.accessor_idx].limits; + limits.add_memory_consumed(mem_delta as u64); + limits.validate_memory_consumption()?; + Ok(()) } @@ -287,6 +296,12 @@ impl SegmentAggregationCollector for SegmentHistogramCollector { } impl SegmentHistogramCollector { + fn get_memory_consumption(&self) -> usize { + let self_mem = std::mem::size_of::(); + let sub_aggs_mem = self.sub_aggregations.memory_consumption(); + let buckets_mem = self.buckets.memory_consumption(); + self_mem + sub_aggs_mem + buckets_mem + } pub fn into_intermediate_bucket_result( self, agg_with_accessor: &BucketAggregationWithAccessor, @@ -389,6 +404,7 @@ fn intermediate_buckets_to_final_buckets_fill_gaps( buckets: Vec, histogram_req: &HistogramAggregation, sub_aggregation: &AggregationsInternal, + limits: &AggregationLimits, ) -> crate::Result> { // Generate the full list of buckets without gaps. // @@ -396,7 +412,17 @@ fn intermediate_buckets_to_final_buckets_fill_gaps( // extended_bounds from the request let min_max = minmax(buckets.iter().map(|bucket| bucket.key)); - // TODO add memory check + // memory check upfront + let (_, first_bucket_num, last_bucket_num) = + generate_bucket_pos_with_opt_minmax(histogram_req, min_max); + let added_buckets = (first_bucket_num..=last_bucket_num) + .count() + .saturating_sub(buckets.len()); + limits.add_memory_consumed( + added_buckets as u64 * std::mem::size_of::() as u64, + ); + limits.validate_memory_consumption()?; + // create buckets let fill_gaps_buckets = generate_buckets_with_opt_minmax(histogram_req, min_max); let empty_sub_aggregation = IntermediateAggregationResults::empty_from_req(sub_aggregation); @@ -425,7 +451,9 @@ fn intermediate_buckets_to_final_buckets_fill_gaps( sub_aggregation: empty_sub_aggregation.clone(), }, }) - .map(|intermediate_bucket| intermediate_bucket.into_final_bucket_entry(sub_aggregation)) + .map(|intermediate_bucket| { + intermediate_bucket.into_final_bucket_entry(sub_aggregation, limits) + }) .collect::>>() } @@ -435,18 +463,26 @@ pub(crate) fn intermediate_histogram_buckets_to_final_buckets( column_type: Option, histogram_req: &HistogramAggregation, sub_aggregation: &AggregationsInternal, + limits: &AggregationLimits, ) -> crate::Result> { let mut buckets = if histogram_req.min_doc_count() == 0 { // With min_doc_count != 0, we may need to add buckets, so that there are no // gaps, since intermediate result does not contain empty buckets (filtered to // reduce serialization size). - intermediate_buckets_to_final_buckets_fill_gaps(buckets, histogram_req, sub_aggregation)? + intermediate_buckets_to_final_buckets_fill_gaps( + buckets, + histogram_req, + sub_aggregation, + limits, + )? 
} else { buckets .into_iter() .filter(|histogram_bucket| histogram_bucket.doc_count >= histogram_req.min_doc_count()) - .map(|histogram_bucket| histogram_bucket.into_final_bucket_entry(sub_aggregation)) + .map(|histogram_bucket| { + histogram_bucket.into_final_bucket_entry(sub_aggregation, limits) + }) .collect::>>()? }; @@ -485,15 +521,27 @@ fn get_req_min_max(req: &HistogramAggregation, min_max: Option<(f64, f64)>) -> ( /// Generates buckets with req.interval /// Range is computed for provided min_max and request extended_bounds/hard_bounds /// returns empty vec when there is no range to span -pub(crate) fn generate_buckets_with_opt_minmax( +pub(crate) fn generate_bucket_pos_with_opt_minmax( req: &HistogramAggregation, min_max: Option<(f64, f64)>, -) -> Vec { +) -> (f64, i64, i64) { let (min, max) = get_req_min_max(req, min_max); let offset = req.offset.unwrap_or(0.0); let first_bucket_num = get_bucket_pos_f64(min, req.interval, offset) as i64; let last_bucket_num = get_bucket_pos_f64(max, req.interval, offset) as i64; + (offset, first_bucket_num, last_bucket_num) +} + +/// Generates buckets with req.interval +/// Range is computed for provided min_max and request extended_bounds/hard_bounds +/// returns empty vec when there is no range to span +pub(crate) fn generate_buckets_with_opt_minmax( + req: &HistogramAggregation, + min_max: Option<(f64, f64)>, +) -> Vec { + let (offset, first_bucket_num, last_bucket_num) = + generate_bucket_pos_with_opt_minmax(req, min_max); let mut buckets = Vec::with_capacity((first_bucket_num..=last_bucket_num).count()); for bucket_pos in first_bucket_num..=last_bucket_num { let bucket_key = bucket_pos as f64 * req.interval + offset; @@ -515,8 +563,8 @@ mod tests { }; use crate::aggregation::metric::{AverageAggregation, StatsAggregation}; use crate::aggregation::tests::{ - exec_request, exec_request_with_query, get_test_index_2_segments, - get_test_index_from_values, get_test_index_with_num_docs, + exec_request, exec_request_with_query, exec_request_with_query_and_memory_limit, + get_test_index_2_segments, get_test_index_from_values, get_test_index_with_num_docs, }; #[test] @@ -661,6 +709,40 @@ mod tests { Ok(()) } + #[test] + fn histogram_memory_limit() -> crate::Result<()> { + let index = get_test_index_with_num_docs(true, 100)?; + + let agg_req: Aggregations = vec![( + "histogram".to_string(), + Aggregation::Bucket(Box::new(BucketAggregation { + bucket_agg: BucketAggregationType::Histogram(HistogramAggregation { + field: "score_f64".to_string(), + interval: 0.1, + ..Default::default() + }), + sub_aggregation: Default::default(), + })), + )] + .into_iter() + .collect(); + + let res = exec_request_with_query_and_memory_limit( + agg_req, + &index, + None, + AggregationLimits::new(Some(5_000), None), + ) + .unwrap_err(); + assert_eq!( + res.to_string(), + "Aborting aggregation because memory limit was exceeded. 
Limit: 5.00 KB, Current: \ + 102.48 KB" + ); + + Ok(()) + } + #[test] fn histogram_merge_test() -> crate::Result<()> { // Merge buckets counts from different segments diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs index 78b7316d4..87453730f 100644 --- a/src/aggregation/bucket/range.rs +++ b/src/aggregation/bucket/range.rs @@ -11,7 +11,7 @@ use crate::aggregation::intermediate_agg_result::{ IntermediateRangeBucketResult, }; use crate::aggregation::segment_agg_result::{ - build_segment_agg_collector, BucketCount, SegmentAggregationCollector, + build_segment_agg_collector, AggregationLimits, SegmentAggregationCollector, }; use crate::aggregation::{ f64_from_fastfield_u64, f64_to_fastfield_u64, format_date, Key, SerializedKey, VecWithNames, @@ -260,7 +260,7 @@ impl SegmentRangeCollector { pub(crate) fn from_req_and_validate( req: &RangeAggregation, sub_aggregation: &AggregationsWithAccessor, - bucket_count: &BucketCount, + limits: &AggregationLimits, field_type: ColumnType, accessor_idx: usize, ) -> crate::Result { @@ -304,8 +304,10 @@ impl SegmentRangeCollector { }) .collect::>()?; - bucket_count.add_count(buckets.len() as u32); - bucket_count.validate_bucket_count()?; + limits.add_memory_consumed( + buckets.len() as u64 * std::mem::size_of::() as u64, + ); + limits.validate_memory_consumption()?; Ok(SegmentRangeCollector { buckets, diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs index b00c03690..ffbf1f82f 100644 --- a/src/aggregation/collector.rs +++ b/src/aggregation/collector.rs @@ -1,36 +1,36 @@ -use std::rc::Rc; - use super::agg_req::Aggregations; use super::agg_req_with_accessor::AggregationsWithAccessor; use super::agg_result::AggregationResults; use super::buf_collector::BufAggregationCollector; use super::intermediate_agg_result::IntermediateAggregationResults; -use super::segment_agg_result::{build_segment_agg_collector, SegmentAggregationCollector}; +use super::segment_agg_result::{ + build_segment_agg_collector, AggregationLimits, SegmentAggregationCollector, +}; use crate::aggregation::agg_req_with_accessor::get_aggs_with_accessor_and_validate; use crate::collector::{Collector, SegmentCollector}; use crate::{SegmentReader, TantivyError}; /// The default max bucket count, before the aggregation fails. -pub const MAX_BUCKET_COUNT: u32 = 65000; +pub const DEFAULT_BUCKET_LIMIT: u32 = 65000; + +/// The default memory limit in bytes before the aggregation fails. 500MB +pub const DEFAULT_MEMORY_LIMIT: u64 = 500_000_000; /// Collector for aggregations. /// /// The collector collects all aggregations by the underlying aggregation request. pub struct AggregationCollector { agg: Aggregations, - max_bucket_count: u32, + limits: AggregationLimits, } impl AggregationCollector { /// Create collector from aggregation request. /// - /// Aggregation fails when the total bucket count is higher than max_bucket_count. - /// max_bucket_count will default to `MAX_BUCKET_COUNT` (65000) when unset - pub fn from_aggs(agg: Aggregations, max_bucket_count: Option) -> Self { - Self { - agg, - max_bucket_count: max_bucket_count.unwrap_or(MAX_BUCKET_COUNT), - } + /// Aggregation fails when the limits in `AggregationLimits` is exceeded. (memory limit and + /// bucket limit) + pub fn from_aggs(agg: Aggregations, limits: AggregationLimits) -> Self { + Self { agg, limits } } } @@ -44,18 +44,16 @@ impl AggregationCollector { /// into the final `AggregationResults` via the `into_final_result()` method. 
pub struct DistributedAggregationCollector { agg: Aggregations, - max_bucket_count: u32, + limits: AggregationLimits, } impl DistributedAggregationCollector { /// Create collector from aggregation request. /// - /// max_bucket_count will default to `MAX_BUCKET_COUNT` (65000) when unset - pub fn from_aggs(agg: Aggregations, max_bucket_count: Option) -> Self { - Self { - agg, - max_bucket_count: max_bucket_count.unwrap_or(MAX_BUCKET_COUNT), - } + /// Aggregation fails when the limits in `AggregationLimits` is exceeded. (memory limit and + /// bucket limit) + pub fn from_aggs(agg: Aggregations, limits: AggregationLimits) -> Self { + Self { agg, limits } } } @@ -69,11 +67,7 @@ impl Collector for DistributedAggregationCollector { _segment_local_id: crate::SegmentOrdinal, reader: &crate::SegmentReader, ) -> crate::Result { - AggregationSegmentCollector::from_agg_req_and_reader( - &self.agg, - reader, - self.max_bucket_count, - ) + AggregationSegmentCollector::from_agg_req_and_reader(&self.agg, reader, &self.limits) } fn requires_scoring(&self) -> bool { @@ -98,11 +92,7 @@ impl Collector for AggregationCollector { _segment_local_id: crate::SegmentOrdinal, reader: &crate::SegmentReader, ) -> crate::Result { - AggregationSegmentCollector::from_agg_req_and_reader( - &self.agg, - reader, - self.max_bucket_count, - ) + AggregationSegmentCollector::from_agg_req_and_reader(&self.agg, reader, &self.limits) } fn requires_scoring(&self) -> bool { @@ -114,7 +104,7 @@ impl Collector for AggregationCollector { segment_fruits: Vec<::Fruit>, ) -> crate::Result { let res = merge_fruits(segment_fruits)?; - res.into_final_bucket_result(self.agg.clone()) + res.into_final_bucket_result(self.agg.clone(), &self.limits) } } @@ -145,10 +135,9 @@ impl AggregationSegmentCollector { pub fn from_agg_req_and_reader( agg: &Aggregations, reader: &SegmentReader, - max_bucket_count: u32, + limits: &AggregationLimits, ) -> crate::Result { - let aggs_with_accessor = - get_aggs_with_accessor_and_validate(agg, reader, Rc::default(), max_bucket_count)?; + let aggs_with_accessor = get_aggs_with_accessor_and_validate(agg, reader, limits)?; let result = BufAggregationCollector::new(build_segment_agg_collector(&aggs_with_accessor)?); Ok(AggregationSegmentCollector { diff --git a/src/aggregation/error.rs b/src/aggregation/error.rs index b04d07861..a2b864c33 100644 --- a/src/aggregation/error.rs +++ b/src/aggregation/error.rs @@ -1,9 +1,33 @@ +use common::ByteCount; + use super::bucket::DateHistogramParseError; /// Error that may occur when opening a directory #[derive(Debug, Clone, PartialEq, Eq, Error)] pub enum AggregationError { - /// Failed to open the directory. + /// Date histogram parse error #[error("Date histogram parse error: {0:?}")] DateHistogramParseError(#[from] DateHistogramParseError), + /// Memory limit exceeded + #[error( + "Aborting aggregation because memory limit was exceeded. Limit: {limit:?}, Current: \ + {current:?}" + )] + MemoryExceeded { + /// Memory consumption limit + limit: ByteCount, + /// Current memory consumption + current: ByteCount, + }, + /// Bucket limit exceeded + #[error( + "Aborting aggregation because bucket limit was exceeded. 
Limit: {limit:?}, Current: \ + {current:?}" + )] + BucketLimitExceeded { + /// Bucket limit + limit: u32, + /// Current num buckets + current: u32, + }, } diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index 7088ac148..81dd64411 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -22,9 +22,11 @@ use super::metric::{ IntermediateAverage, IntermediateCount, IntermediateMax, IntermediateMin, IntermediateStats, IntermediateSum, }; -use super::{format_date, Key, SerializedKey, VecWithNames}; +use super::segment_agg_result::AggregationLimits; +use super::{format_date, AggregationError, Key, SerializedKey, VecWithNames}; use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry}; use crate::aggregation::bucket::TermsAggregationInternal; +use crate::TantivyError; /// Contains the intermediate aggregation result, which is optimized to be merged with other /// intermediate results. @@ -38,8 +40,23 @@ pub struct IntermediateAggregationResults { impl IntermediateAggregationResults { /// Convert intermediate result and its aggregation request to the final result. - pub fn into_final_bucket_result(self, req: Aggregations) -> crate::Result { - self.into_final_bucket_result_internal(&(req.into())) + pub fn into_final_bucket_result( + self, + req: Aggregations, + limits: &AggregationLimits, + ) -> crate::Result { + // TODO count and validate buckets + let res = self.into_final_bucket_result_internal(&(req.into()), limits)?; + let bucket_count = res.get_bucket_count() as u32; + if bucket_count > limits.get_bucket_limit() { + return Err(TantivyError::AggregationError( + AggregationError::BucketLimitExceeded { + limit: limits.get_bucket_limit(), + current: bucket_count, + }, + )); + } + Ok(res) } /// Convert intermediate result and its aggregation request to the final result. @@ -49,6 +66,7 @@ impl IntermediateAggregationResults { pub(crate) fn into_final_bucket_result_internal( self, req: &AggregationsInternal, + limits: &AggregationLimits, ) -> crate::Result { // Important assumption: // When the tree contains buckets/metric, we expect it to have all buckets/metrics from the @@ -56,11 +74,11 @@ impl IntermediateAggregationResults { let mut results: FxHashMap = FxHashMap::default(); if let Some(buckets) = self.buckets { - convert_and_add_final_buckets_to_result(&mut results, buckets, &req.buckets)? + convert_and_add_final_buckets_to_result(&mut results, buckets, &req.buckets, limits)? } else { // When there are no buckets, we create empty buckets, so that the serialized json // format is constant - add_empty_final_buckets_to_result(&mut results, &req.buckets)? + add_empty_final_buckets_to_result(&mut results, &req.buckets, limits)? 
}; if let Some(metrics) = self.metrics { @@ -161,10 +179,12 @@ fn add_empty_final_metrics_to_result( fn add_empty_final_buckets_to_result( results: &mut FxHashMap, req_buckets: &VecWithNames, + limits: &AggregationLimits, ) -> crate::Result<()> { let requested_buckets = req_buckets.iter(); for (key, req) in requested_buckets { - let empty_bucket = AggregationResult::BucketResult(BucketResult::empty_from_req(req)?); + let empty_bucket = + AggregationResult::BucketResult(BucketResult::empty_from_req(req, limits)?); results.insert(key.to_string(), empty_bucket); } Ok(()) @@ -174,12 +194,13 @@ fn convert_and_add_final_buckets_to_result( results: &mut FxHashMap, buckets: VecWithNames, req_buckets: &VecWithNames, + limits: &AggregationLimits, ) -> crate::Result<()> { assert_eq!(buckets.len(), req_buckets.len()); let buckets_with_request = buckets.into_iter().zip(req_buckets.values()); for ((key, bucket), req) in buckets_with_request { - let result = AggregationResult::BucketResult(bucket.into_final_bucket_result(req)?); + let result = AggregationResult::BucketResult(bucket.into_final_bucket_result(req, limits)?); results.insert(key, result); } Ok(()) @@ -287,6 +308,7 @@ impl IntermediateBucketResult { pub(crate) fn into_final_bucket_result( self, req: &BucketAggregationInternal, + limits: &AggregationLimits, ) -> crate::Result { match self { IntermediateBucketResult::Range(range_res) => { @@ -299,6 +321,7 @@ impl IntermediateBucketResult { req.as_range() .expect("unexpected aggregation, expected histogram aggregation"), range_res.column_type, + limits, ) }) .collect::>>()?; @@ -337,6 +360,7 @@ impl IntermediateBucketResult { column_type, histogram_req, &req.sub_aggregation, + limits, )?; let buckets = if histogram_req.keyed { @@ -355,6 +379,7 @@ impl IntermediateBucketResult { req.as_term() .expect("unexpected aggregation, expected term aggregation"), &req.sub_aggregation, + limits, ), } } @@ -449,6 +474,7 @@ impl IntermediateTermBucketResult { self, req: &TermsAggregation, sub_aggregation_req: &AggregationsInternal, + limits: &AggregationLimits, ) -> crate::Result { let req = TermsAggregationInternal::from_req(req); let mut buckets: Vec = self @@ -462,7 +488,7 @@ impl IntermediateTermBucketResult { doc_count: entry.doc_count, sub_aggregation: entry .sub_aggregation - .into_final_bucket_result_internal(sub_aggregation_req)?, + .into_final_bucket_result_internal(sub_aggregation_req, limits)?, }) }) .collect::>()?; @@ -582,6 +608,7 @@ impl IntermediateHistogramBucketEntry { pub(crate) fn into_final_bucket_entry( self, req: &AggregationsInternal, + limits: &AggregationLimits, ) -> crate::Result { Ok(BucketEntry { key_as_string: None, @@ -589,7 +616,7 @@ impl IntermediateHistogramBucketEntry { doc_count: self.doc_count, sub_aggregation: self .sub_aggregation - .into_final_bucket_result_internal(req)?, + .into_final_bucket_result_internal(req, limits)?, }) } } @@ -628,13 +655,14 @@ impl IntermediateRangeBucketEntry { req: &AggregationsInternal, _range_req: &RangeAggregation, column_type: Option, + limits: &AggregationLimits, ) -> crate::Result { let mut range_bucket_entry = RangeBucketEntry { key: self.key, doc_count: self.doc_count, sub_aggregation: self .sub_aggregation - .into_final_bucket_result_internal(req)?, + .into_final_bucket_result_internal(req, limits)?, to: self.to, from: self.from, to_as_string: None, diff --git a/src/aggregation/metric/mod.rs b/src/aggregation/metric/mod.rs index fd2e6ae23..4812f2062 100644 --- a/src/aggregation/metric/mod.rs +++ b/src/aggregation/metric/mod.rs @@ 
-81,7 +81,7 @@ mod tests { "price_sum": { "sum": { "field": "price" } } }"#; let aggregations: Aggregations = serde_json::from_str(aggregations_json).unwrap(); - let collector = AggregationCollector::from_aggs(aggregations, None); + let collector = AggregationCollector::from_aggs(aggregations, Default::default()); let reader = index.reader().unwrap(); let searcher = reader.searcher(); let aggregations_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap(); diff --git a/src/aggregation/metric/stats.rs b/src/aggregation/metric/stats.rs index 099207cca..b5ee41a1a 100644 --- a/src/aggregation/metric/stats.rs +++ b/src/aggregation/metric/stats.rs @@ -294,7 +294,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1, None); + let collector = AggregationCollector::from_aggs(agg_req_1, Default::default()); let reader = index.reader()?; let searcher = reader.searcher(); @@ -331,7 +331,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1, None); + let collector = AggregationCollector::from_aggs(agg_req_1, Default::default()); let reader = index.reader()?; let searcher = reader.searcher(); @@ -411,7 +411,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1, None); + let collector = AggregationCollector::from_aggs(agg_req_1, Default::default()); let searcher = reader.searcher(); let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap(); diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index 0cc786d86..6f0e04d18 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -70,7 +70,7 @@ //! .into_iter() //! .collect(); //! -//! let collector = AggregationCollector::from_aggs(agg_req, None); +//! let collector = AggregationCollector::from_aggs(agg_req, Default::default()); //! //! let searcher = reader.searcher(); //! let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap(); @@ -155,6 +155,7 @@ //! [`AggregationResults`](agg_result::AggregationResults) via the //! [`into_final_bucket_result`](intermediate_agg_result::IntermediateAggregationResults::into_final_bucket_result) method. 
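The module docs above keep using the default limits (`Default::default()`); when a caller wants a tighter cap, the same `AggregationLimits` also carries the bucket limit that replaces the old `MAX_BUCKET_COUNT` constant, and that cap is only checked once the final result is assembled. A short, crate-internal sketch under those assumptions (`agg_req` and the searcher are assumed to exist):

```rust
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::agg_result::AggregationResults;
use crate::aggregation::segment_agg_result::AggregationLimits;
use crate::aggregation::AggregationCollector;
use crate::query::AllQuery;
use crate::Searcher;

fn capped_buckets(
    searcher: &Searcher,
    agg_req: Aggregations,
) -> crate::Result<AggregationResults> {
    // Default memory limit (500 MB), but at most 10_000 buckets in the final response.
    let limits = AggregationLimits::new(None, Some(10_000));
    let collector = AggregationCollector::from_aggs(agg_req, limits);
    // A request that would return more than 10_000 buckets fails with
    // `AggregationError::BucketLimitExceeded { .. }` instead of growing unbounded.
    searcher.search(&AllQuery, &collector)
}
```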
+mod agg_limits; pub mod agg_req; mod agg_req_with_accessor; pub mod agg_result; @@ -165,6 +166,7 @@ mod date; mod error; pub mod intermediate_agg_result; pub mod metric; + mod segment_agg_result; use std::collections::HashMap; use std::fmt::Display; @@ -174,7 +176,7 @@ mod agg_tests; pub use collector::{ AggregationCollector, AggregationSegmentCollector, DistributedAggregationCollector, - MAX_BUCKET_COUNT, + DEFAULT_BUCKET_LIMIT, }; use columnar::{ColumnType, MonotonicallyMappableToU64}; pub(crate) use date::format_date; @@ -345,6 +347,7 @@ mod tests { use time::OffsetDateTime; use super::agg_req::Aggregations; + use super::segment_agg_result::AggregationLimits; use super::*; use crate::indexer::NoMergePolicy; use crate::query::{AllQuery, TermQuery}; @@ -369,7 +372,16 @@ mod tests { index: &Index, query: Option<(&str, &str)>, ) -> crate::Result { - let collector = AggregationCollector::from_aggs(agg_req, None); + exec_request_with_query_and_memory_limit(agg_req, index, query, Default::default()) + } + + pub fn exec_request_with_query_and_memory_limit( + agg_req: Aggregations, + index: &Index, + query: Option<(&str, &str)>, + limits: AggregationLimits, + ) -> crate::Result { + let collector = AggregationCollector::from_aggs(agg_req, limits); let reader = index.reader()?; let searcher = reader.searcher(); diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index b91a51437..6b8400119 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -4,15 +4,13 @@ //! merging. use std::fmt::Debug; -use std::rc::Rc; -use std::sync::atomic::AtomicU32; +pub(crate) use super::agg_limits::AggregationLimits; use super::agg_req::MetricAggregation; use super::agg_req_with_accessor::{ AggregationsWithAccessor, BucketAggregationWithAccessor, MetricAggregationWithAccessor, }; use super::bucket::{SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector}; -use super::collector::MAX_BUCKET_COUNT; use super::intermediate_agg_result::IntermediateAggregationResults; use super::metric::{ AverageAggregation, CountAggregation, MaxAggregation, MinAggregation, SegmentStatsCollector, @@ -20,7 +18,6 @@ use super::metric::{ }; use super::VecWithNames; use crate::aggregation::agg_req::BucketAggregationType; -use crate::TantivyError; pub(crate) trait SegmentAggregationCollector: CollectorClone + Debug { fn into_intermediate_aggregations_result( @@ -131,7 +128,7 @@ pub(crate) fn build_bucket_segment_agg_collector( Ok(Box::new(SegmentRangeCollector::from_req_and_validate( range_req, &req.sub_aggregation, - &req.bucket_count, + &req.limits, req.field_type, accessor_idx, )?)) @@ -284,37 +281,3 @@ impl GenericSegmentAggregationResultsCollector { Ok(GenericSegmentAggregationResultsCollector { metrics, buckets }) } } - -#[derive(Clone)] -pub(crate) struct BucketCount { - /// The counter which is shared between the aggregations for one request. 
- pub(crate) bucket_count: Rc, - pub(crate) max_bucket_count: u32, -} - -impl Default for BucketCount { - fn default() -> Self { - Self { - bucket_count: Default::default(), - max_bucket_count: MAX_BUCKET_COUNT, - } - } -} - -impl BucketCount { - pub(crate) fn validate_bucket_count(&self) -> crate::Result<()> { - if self.get_count() > self.max_bucket_count { - return Err(TantivyError::InvalidArgument( - "Aborting aggregation because too many buckets were created".to_string(), - )); - } - Ok(()) - } - pub(crate) fn add_count(&self, count: u32) { - self.bucket_count - .fetch_add(count, std::sync::atomic::Ordering::Relaxed); - } - pub(crate) fn get_count(&self) -> u32 { - self.bucket_count.load(std::sync::atomic::Ordering::Relaxed) - } -} diff --git a/src/core/segment_reader.rs b/src/core/segment_reader.rs index 698ea2f12..f42610e12 100644 --- a/src/core/segment_reader.rs +++ b/src/core/segment_reader.rs @@ -327,7 +327,7 @@ impl SegmentReader { self.alive_bitset_opt .as_ref() .map(AliveBitSet::space_usage) - .unwrap_or(0), + .unwrap_or_default(), )) } } diff --git a/src/directory/composite_file.rs b/src/directory/composite_file.rs index d4af43a89..d33b67b95 100644 --- a/src/directory/composite_file.rs +++ b/src/directory/composite_file.rs @@ -172,7 +172,7 @@ impl CompositeFile { let mut fields = Vec::new(); for (&field_addr, byte_range) in &self.offsets_index { let mut field_usage = FieldUsage::empty(field_addr.field); - field_usage.add_field_idx(field_addr.idx, byte_range.len()); + field_usage.add_field_idx(field_addr.idx, byte_range.len().into()); fields.push(field_usage); } PerFieldSpaceUsage::new(fields) diff --git a/src/error.rs b/src/error.rs index 816074a19..0089e550c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -55,7 +55,7 @@ impl fmt::Debug for DataCorruption { #[derive(Debug, Clone, Error)] pub enum TantivyError { /// Error when handling aggregations. - #[error("AggregationError {0:?}")] + #[error(transparent)] AggregationError(#[from] AggregationError), /// Failed to open the directory. 
#[error("Failed to open the directory: '{0:?}'")] diff --git a/src/fastfield/alive_bitset.rs b/src/fastfield/alive_bitset.rs index 24391d773..11d7463c7 100644 --- a/src/fastfield/alive_bitset.rs +++ b/src/fastfield/alive_bitset.rs @@ -1,9 +1,8 @@ use std::io; use std::io::Write; -use common::{intersect_bitsets, BitSet, OwnedBytes, ReadOnlyBitSet}; +use common::{intersect_bitsets, BitSet, ByteCount, OwnedBytes, ReadOnlyBitSet}; -use crate::space_usage::ByteCount; use crate::DocId; /// Write an alive `BitSet` diff --git a/src/fastfield/mod.rs b/src/fastfield/mod.rs index 92e6850ab..81b893694 100644 --- a/src/fastfield/mod.rs +++ b/src/fastfield/mod.rs @@ -80,7 +80,7 @@ mod tests { use std::path::Path; use columnar::{Column, MonotonicallyMappableToU64, StrColumn}; - use common::{HasLen, TerminatingWrite}; + use common::{ByteCount, HasLen, TerminatingWrite}; use once_cell::sync::Lazy; use rand::prelude::SliceRandom; use rand::rngs::StdRng; @@ -862,16 +862,16 @@ mod tests { #[test] pub fn test_gcd_date() { let size_prec_sec = test_gcd_date_with_codec(DatePrecision::Seconds); - assert!((1000 * 13 / 8..100 + 1000 * 13 / 8).contains(&size_prec_sec)); // 13 bits per val = ceil(log_2(number of seconds in 2hours); + assert!((1000 * 13 / 8..100 + 1000 * 13 / 8).contains(&size_prec_sec.get_bytes())); // 13 bits per val = ceil(log_2(number of seconds in 2hours); let size_prec_micros = test_gcd_date_with_codec(DatePrecision::Microseconds); - assert!((1000 * 33 / 8..100 + 1000 * 33 / 8).contains(&size_prec_micros)); + assert!((1000 * 33 / 8..100 + 1000 * 33 / 8).contains(&size_prec_micros.get_bytes())); // 33 bits per // val = ceil(log_2(number // of microsecsseconds // in 2hours); } - fn test_gcd_date_with_codec(precision: DatePrecision) -> usize { + fn test_gcd_date_with_codec(precision: DatePrecision) -> ByteCount { let mut rng = StdRng::seed_from_u64(2u64); const T0: i64 = 1_662_345_825_012_529i64; const ONE_HOUR_IN_MICROSECS: i64 = 3_600 * 1_000_000; diff --git a/src/fastfield/readers.rs b/src/fastfield/readers.rs index d394aff70..ea9dafa22 100644 --- a/src/fastfield/readers.rs +++ b/src/fastfield/readers.rs @@ -6,6 +6,7 @@ use columnar::{ BytesColumn, Column, ColumnType, ColumnValues, ColumnarReader, DynamicColumn, DynamicColumnHandle, HasAssociatedColumnType, StrColumn, }; +use common::ByteCount; use crate::core::json_utils::encode_column_name; use crate::directory::FileSlice; @@ -42,7 +43,7 @@ impl FastFieldReaders { let mut per_field_usages: Vec = Default::default(); for (field, field_entry) in schema.fields() { let column_handles = self.columnar.read_columns(field_entry.name())?; - let num_bytes: usize = column_handles + let num_bytes: ByteCount = column_handles .iter() .map(|column_handle| column_handle.num_bytes()) .sum(); @@ -136,9 +137,9 @@ impl FastFieldReaders { /// Returns the number of `bytes` associated with a column. /// /// Returns 0 if the column does not exist. - pub fn column_num_bytes(&self, field: &str) -> crate::Result { + pub fn column_num_bytes(&self, field: &str) -> crate::Result { let Some(resolved_field_name) = self.resolve_field(field)? 
else { - return Ok(0); + return Ok(0u64.into()); }; Ok(self .columnar diff --git a/src/space_usage/mod.rs b/src/space_usage/mod.rs index 8c203f924..ff0af1c89 100644 --- a/src/space_usage/mod.rs +++ b/src/space_usage/mod.rs @@ -9,14 +9,12 @@ use std::collections::HashMap; +use common::ByteCount; use serde::{Deserialize, Serialize}; use crate::schema::Field; use crate::SegmentComponent; -/// Indicates space usage in bytes -pub type ByteCount = usize; - /// Enum containing any of the possible space usage results for segment components. pub enum ComponentSpaceUsage { /// Data is stored per field in a uniform way @@ -38,7 +36,7 @@ impl SearcherSpaceUsage { pub(crate) fn new() -> SearcherSpaceUsage { SearcherSpaceUsage { segments: Vec::new(), - total: 0, + total: Default::default(), } } @@ -260,7 +258,7 @@ impl FieldUsage { pub(crate) fn empty(field: Field) -> FieldUsage { FieldUsage { field, - num_bytes: 0, + num_bytes: Default::default(), sub_num_bytes: Vec::new(), } } @@ -294,7 +292,7 @@ impl FieldUsage { mod test { use crate::core::Index; use crate::schema::{Field, Schema, FAST, INDEXED, STORED, TEXT}; - use crate::space_usage::{ByteCount, PerFieldSpaceUsage}; + use crate::space_usage::PerFieldSpaceUsage; use crate::Term; #[test] @@ -304,14 +302,14 @@ mod test { let reader = index.reader().unwrap(); let searcher = reader.searcher(); let searcher_space_usage = searcher.space_usage().unwrap(); - assert_eq!(0, searcher_space_usage.total()); + assert_eq!(searcher_space_usage.total(), 0u64); } fn expect_single_field( field_space: &PerFieldSpaceUsage, field: &Field, - min_size: ByteCount, - max_size: ByteCount, + min_size: u64, + max_size: u64, ) { assert!(field_space.total() >= min_size); assert!(field_space.total() <= max_size); @@ -353,12 +351,12 @@ mod test { expect_single_field(segment.termdict(), &name, 1, 512); expect_single_field(segment.postings(), &name, 1, 512); - assert_eq!(0, segment.positions().total()); + assert_eq!(segment.positions().total(), 0); expect_single_field(segment.fast_fields(), &name, 1, 512); expect_single_field(segment.fieldnorms(), &name, 1, 512); // TODO: understand why the following fails // assert_eq!(0, segment.store().total()); - assert_eq!(0, segment.deletes()); + assert_eq!(segment.deletes(), 0); Ok(()) } @@ -394,11 +392,11 @@ mod test { expect_single_field(segment.termdict(), &name, 1, 512); expect_single_field(segment.postings(), &name, 1, 512); expect_single_field(segment.positions(), &name, 1, 512); - assert_eq!(0, segment.fast_fields().total()); + assert_eq!(segment.fast_fields().total(), 0); expect_single_field(segment.fieldnorms(), &name, 1, 512); // TODO: understand why the following fails // assert_eq!(0, segment.store().total()); - assert_eq!(0, segment.deletes()); + assert_eq!(segment.deletes(), 0); Ok(()) } @@ -430,14 +428,14 @@ mod test { assert_eq!(4, segment.num_docs()); - assert_eq!(0, segment.termdict().total()); - assert_eq!(0, segment.postings().total()); - assert_eq!(0, segment.positions().total()); - assert_eq!(0, segment.fast_fields().total()); - assert_eq!(0, segment.fieldnorms().total()); + assert_eq!(segment.termdict().total(), 0); + assert_eq!(segment.postings().total(), 0); + assert_eq!(segment.positions().total(), 0); + assert_eq!(segment.fast_fields().total(), 0); + assert_eq!(segment.fieldnorms().total(), 0); assert!(segment.store().total() > 0); assert!(segment.store().total() < 512); - assert_eq!(0, segment.deletes()); + assert_eq!(segment.deletes(), 0); Ok(()) } @@ -478,8 +476,8 @@ mod test { 
expect_single_field(segment_space_usage.termdict(), &name, 1, 512); expect_single_field(segment_space_usage.postings(), &name, 1, 512); - assert_eq!(0, segment_space_usage.positions().total()); - assert_eq!(0, segment_space_usage.fast_fields().total()); + assert_eq!(segment_space_usage.positions().total(), 0u64); + assert_eq!(segment_space_usage.fast_fields().total(), 0u64); expect_single_field(segment_space_usage.fieldnorms(), &name, 1, 512); assert!(segment_space_usage.deletes() > 0); Ok(()) diff --git a/src/store/reader.rs b/src/store/reader.rs index 9b9ea1647..03a277dd8 100644 --- a/src/store/reader.rs +++ b/src/store/reader.rs @@ -5,7 +5,7 @@ use std::ops::{AddAssign, Range}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Arc, Mutex}; -use common::{BinarySerializable, HasLen, OwnedBytes}; +use common::{BinarySerializable, OwnedBytes}; use lru::LruCache; use super::footer::DocStoreFooter; @@ -122,7 +122,8 @@ impl StoreReader { let (data_file, offset_index_file) = data_and_offset.split(footer.offset as usize); let index_data = offset_index_file.read_bytes()?; - let space_usage = StoreSpaceUsage::new(data_file.len(), offset_index_file.len()); + let space_usage = + StoreSpaceUsage::new(data_file.num_bytes(), offset_index_file.num_bytes()); let skip_index = SkipIndex::open(index_data); Ok(StoreReader { decompressor: footer.decompressor,
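Taken together, the memory side of the change wires up as follows: pick a byte budget, hand it to the collector, and the request aborts with a human-readable `ByteCount` message once the shared counter crosses it. The sketch below is condensed, crate-internal code in the style of the `histogram_memory_limit` test above; index/reader setup and the aggregation request are assumed to exist, and the public re-export path of `AggregationLimits` is not settled by this patch.

```rust
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::agg_result::AggregationResults;
use crate::aggregation::segment_agg_result::AggregationLimits;
use crate::aggregation::AggregationCollector;
use crate::query::AllQuery;
use crate::IndexReader;

fn run_with_memory_limit(reader: &IndexReader, agg_req: Aggregations) -> crate::Result<()> {
    // 10 MB budget for this request, default bucket limit (65_000).
    let limits = AggregationLimits::new(Some(10_000_000), None);
    let collector = AggregationCollector::from_aggs(agg_req, limits);
    let searcher = reader.searcher();
    match searcher.search(&AllQuery, &collector) {
        // Within budget: the usual final result.
        Ok(agg_res) => {
            let _res: AggregationResults = agg_res;
            Ok(())
        }
        // Over budget: with `#[error(transparent)]` on TantivyError::AggregationError,
        // the error displays as e.g.
        // "Aborting aggregation because memory limit was exceeded. Limit: 10.00 MB, Current: ..."
        Err(err) => {
            eprintln!("{err}");
            Err(err)
        }
    }
}
```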