add memory limit for aggregations (#1942)

* add memory limit for aggregations

Introduce AggregationLimits to set a memory consumption limit and a bucket limit.
The memory limit is checked during aggregation; the bucket limit is checked before the aggregation result is returned.
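
For orientation, a minimal sketch (not part of the diff) of how the new signature is meant to be called from user code. The helper name run_aggregation and the pre-existing searcher and request JSON are assumptions; since AggregationLimits stays crate-internal in this commit, external callers pass Default::default():

use tantivy::aggregation::agg_req::Aggregations;
use tantivy::aggregation::agg_result::AggregationResults;
use tantivy::aggregation::AggregationCollector;
use tantivy::query::AllQuery;
use tantivy::Searcher;

fn run_aggregation(searcher: &Searcher, agg_req_json: &str) -> tantivy::Result<AggregationResults> {
    let agg_req: Aggregations = serde_json::from_str(agg_req_json)?;
    // The second argument now carries the limits; Default::default() yields the
    // default AggregationLimits (500 MB memory, 65_000 returned buckets).
    let collector = AggregationCollector::from_aggs(agg_req, Default::default());
    searcher.search(&AllQuery, &collector)
}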

* Apply suggestions from code review

Co-authored-by: Paul Masurel <paul@quickwit.io>

* add ByteCount with a human-readable format

---------

Co-authored-by: Paul Masurel <paul@quickwit.io>
PSeitz authored on 2023-03-16 13:21:07 +08:00; committed by GitHub
parent b6703f1b3c
commit 9e2faecf5b
27 changed files with 556 additions and 193 deletions

View File

@@ -3,7 +3,7 @@ use std::net::Ipv6Addr;
use std::sync::Arc;
use common::file_slice::FileSlice;
use common::{DateTime, HasLen, OwnedBytes};
use common::{ByteCount, DateTime, HasLen, OwnedBytes};
use crate::column::{BytesColumn, Column, StrColumn};
use crate::column_values::{monotonic_map_column, StrictlyMonotonicFn};
@@ -248,8 +248,8 @@ impl DynamicColumnHandle {
Ok(dynamic_column)
}
pub fn num_bytes(&self) -> usize {
self.file_slice.len()
pub fn num_bytes(&self) -> ByteCount {
self.file_slice.len().into()
}
pub fn column_type(&self) -> ColumnType {

View File

@@ -4,6 +4,8 @@ use std::{fmt, io, u64};
use ownedbytes::OwnedBytes;
use crate::ByteCount;
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct TinySet(u64);
@@ -386,8 +388,8 @@ impl ReadOnlyBitSet {
}
/// Number of bytes used in the bitset representation.
pub fn num_bytes(&self) -> usize {
self.data.len()
pub fn num_bytes(&self) -> ByteCount {
self.data.len().into()
}
}

common/src/byte_count.rs (new file, 108 lines)
View File

@@ -0,0 +1,108 @@
use std::iter::Sum;
use std::ops::{Add, AddAssign};
use serde::{Deserialize, Serialize};
/// Indicates space usage in bytes
#[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ByteCount(u64);
impl std::fmt::Debug for ByteCount {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&self.human_readable())
}
}
const SUFFIX_AND_THRESHOLD: [(&str, u64); 5] = [
("KB", 1_000),
("MB", 1_000_000),
("GB", 1_000_000_000),
("TB", 1_000_000_000_000),
("PB", 1_000_000_000_000_000),
];
impl ByteCount {
#[inline]
pub fn get_bytes(&self) -> u64 {
self.0
}
pub fn human_readable(&self) -> String {
for (suffix, threshold) in SUFFIX_AND_THRESHOLD.iter().rev() {
if self.get_bytes() >= *threshold {
let unit_num = self.get_bytes() as f64 / *threshold as f64;
return format!("{:.2} {}", unit_num, suffix);
}
}
format!("{:.2} B", self.get_bytes())
}
}
impl From<u64> for ByteCount {
fn from(value: u64) -> Self {
ByteCount(value)
}
}
impl From<usize> for ByteCount {
fn from(value: usize) -> Self {
ByteCount(value as u64)
}
}
impl Sum for ByteCount {
#[inline]
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.fold(ByteCount::default(), |acc, x| acc + x)
}
}
impl PartialEq<u64> for ByteCount {
#[inline]
fn eq(&self, other: &u64) -> bool {
self.get_bytes() == *other
}
}
impl PartialOrd<u64> for ByteCount {
#[inline]
fn partial_cmp(&self, other: &u64) -> Option<std::cmp::Ordering> {
self.get_bytes().partial_cmp(other)
}
}
impl Add for ByteCount {
type Output = Self;
#[inline]
fn add(self, other: Self) -> Self {
Self(self.get_bytes() + other.get_bytes())
}
}
impl AddAssign for ByteCount {
#[inline]
fn add_assign(&mut self, other: Self) {
*self = Self(self.get_bytes() + other.get_bytes());
}
}
#[cfg(test)]
mod test {
use crate::ByteCount;
#[test]
fn test_bytes() {
assert_eq!(ByteCount::from(0u64).human_readable(), "0 B");
assert_eq!(ByteCount::from(300u64).human_readable(), "300 B");
assert_eq!(ByteCount::from(1_000_000u64).human_readable(), "1.00 MB");
assert_eq!(ByteCount::from(1_500_000u64).human_readable(), "1.50 MB");
assert_eq!(
ByteCount::from(1_500_000_000u64).human_readable(),
"1.50 GB"
);
assert_eq!(
ByteCount::from(3_213_000_000_000u64).human_readable(),
"3.21 TB"
);
}
}
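
For reference, a small standalone sketch (not in the diff) of how ByteCount composes; the crate path is the one used inside this workspace, and total_size is a hypothetical helper. From, Sum, and the PartialOrd impl against u64 are what let the space-usage code further down fold and compare sizes directly:

use common::ByteCount;

fn total_size(parts: &[usize]) -> ByteCount {
    // From<usize> plus the Sum impl fold per-component sizes into one count.
    parts.iter().map(|&p| ByteCount::from(p)).sum()
}

#[test]
fn byte_count_composes() {
    let total = total_size(&[500, 1_000_000]);
    assert_eq!(total.get_bytes(), 1_000_500);
    assert_eq!(total.human_readable(), "1.00 MB");
    // PartialOrd<u64> allows comparisons against plain numbers.
    assert!(total > 1_000_000u64);
}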

View File

@@ -5,7 +5,7 @@ use std::{fmt, io};
use async_trait::async_trait;
use ownedbytes::{OwnedBytes, StableDeref};
use crate::HasLen;
use crate::{ByteCount, HasLen};
/// Objects that represents files sections in tantivy.
///
@@ -216,6 +216,11 @@ impl FileSlice {
pub fn slice_to(&self, to_offset: usize) -> FileSlice {
self.slice(0..to_offset)
}
/// Returns the byte count of the FileSlice.
pub fn num_bytes(&self) -> ByteCount {
self.range.len().into()
}
}
#[async_trait]

View File

@@ -5,6 +5,7 @@ use std::ops::Deref;
pub use byteorder::LittleEndian as Endianness;
mod bitset;
mod byte_count;
mod datetime;
pub mod file_slice;
mod group_by;
@@ -12,6 +13,7 @@ mod serialize;
mod vint;
mod writer;
pub use bitset::*;
pub use byte_count::ByteCount;
pub use datetime::{DatePrecision, DateTime};
pub use group_by::GroupByIteratorExtended;
pub use ownedbytes::{OwnedBytes, StableDeref};

View File

@@ -192,7 +192,7 @@ fn main() -> tantivy::Result<()> {
//
let agg_req: Aggregations = serde_json::from_str(agg_req_str)?;
let collector = AggregationCollector::from_aggs(agg_req, None);
let collector = AggregationCollector::from_aggs(agg_req, Default::default());
let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();
let res2: Value = serde_json::to_value(agg_res)?;
@@ -239,7 +239,7 @@ fn main() -> tantivy::Result<()> {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req, None);
let collector = AggregationCollector::from_aggs(agg_req, Default::default());
// We use the `AllQuery` which will pass all documents to the AggregationCollector.
let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();
@@ -287,7 +287,7 @@ fn main() -> tantivy::Result<()> {
let agg_req: Aggregations = serde_json::from_str(agg_req_str)?;
let collector = AggregationCollector::from_aggs(agg_req, None);
let collector = AggregationCollector::from_aggs(agg_req, Default::default());
let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();
let res: Value = serde_json::to_value(agg_res)?;

View File

@@ -0,0 +1,94 @@
use std::collections::HashMap;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use common::ByteCount;
use super::collector::DEFAULT_MEMORY_LIMIT;
use super::{AggregationError, DEFAULT_BUCKET_LIMIT};
use crate::TantivyError;
/// An estimate for memory consumption
pub trait MemoryConsumption {
fn memory_consumption(&self) -> usize;
}
impl<K, V, S> MemoryConsumption for HashMap<K, V, S> {
fn memory_consumption(&self) -> usize {
let num_items = self.capacity();
(std::mem::size_of::<K>() + std::mem::size_of::<V>()) * num_items
}
}
/// Aggregation limits after which the request fails. The memory limit defaults to
/// DEFAULT_MEMORY_LIMIT (500 MB). The memory counter is shared by all SegmentCollectors
/// of one request.
pub struct AggregationLimits {
/// The counter which is shared between the aggregations for one request.
memory_consumption: Arc<AtomicU64>,
/// The memory_limit in bytes
memory_limit: ByteCount,
/// The maximum number of buckets _returned_.
/// Intermediate buckets do not count toward this limit.
bucket_limit: u32,
}
impl Clone for AggregationLimits {
fn clone(&self) -> Self {
Self {
memory_consumption: Arc::clone(&self.memory_consumption),
memory_limit: self.memory_limit,
bucket_limit: self.bucket_limit,
}
}
}
impl Default for AggregationLimits {
fn default() -> Self {
Self {
memory_consumption: Default::default(),
memory_limit: DEFAULT_MEMORY_LIMIT.into(),
bucket_limit: DEFAULT_BUCKET_LIMIT,
}
}
}
impl AggregationLimits {
/// *memory_limit*
/// memory_limit is defined in bytes.
/// Aggregation fails when the estimated memory consumption of the aggregation is higher than
/// memory_limit.
/// memory_limit will default to `DEFAULT_MEMORY_LIMIT` (500MB)
///
/// *bucket_limit*
/// Limits the maximum number of buckets returned from an aggregation request.
/// bucket_limit will default to `DEFAULT_BUCKET_LIMIT` (65000)
pub fn new(memory_limit: Option<u64>, bucket_limit: Option<u32>) -> Self {
Self {
memory_consumption: Default::default(),
memory_limit: memory_limit.unwrap_or(DEFAULT_MEMORY_LIMIT).into(),
bucket_limit: bucket_limit.unwrap_or(DEFAULT_BUCKET_LIMIT),
}
}
pub(crate) fn validate_memory_consumption(&self) -> crate::Result<()> {
if self.get_memory_consumed() > self.memory_limit {
return Err(TantivyError::AggregationError(
AggregationError::MemoryExceeded {
limit: self.memory_limit,
current: self.get_memory_consumed(),
},
));
}
Ok(())
}
pub(crate) fn add_memory_consumed(&self, num_bytes: u64) {
self.memory_consumption
.fetch_add(num_bytes, std::sync::atomic::Ordering::Relaxed);
}
pub fn get_memory_consumed(&self) -> ByteCount {
self.memory_consumption
.load(std::sync::atomic::Ordering::Relaxed)
.into()
}
pub fn get_bucket_limit(&self) -> u32 {
self.bucket_limit
}
}
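
To make the sharing semantics concrete: clones of AggregationLimits share one Arc<AtomicU64> counter, so memory recorded by any segment collector counts against the same request-wide limit. A crate-internal sketch, illustrative only (the type is not public in this commit, and shared_counter_sketch is a hypothetical function):

fn shared_counter_sketch() -> crate::Result<()> {
    // 5 KB memory cap, default bucket cap.
    let limits = AggregationLimits::new(Some(5_000), None);
    // The clone shares the same atomic counter as `limits`.
    let per_segment = limits.clone();

    per_segment.add_memory_consumed(4_096);
    limits.validate_memory_consumption()?; // 4_096 B <= 5_000 B: still ok

    per_segment.add_memory_consumed(4_096);
    assert!(limits.validate_memory_consumption().is_err()); // 8_192 B > 5_000 B
    Ok(())
}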

View File

@@ -1,7 +1,5 @@
//! This will enhance the request tree with access to the fastfield and metadata.
use std::rc::Rc;
use std::sync::atomic::AtomicU32;
use std::sync::Arc;
use columnar::{Column, ColumnType, ColumnValues, StrColumn};
@@ -14,7 +12,7 @@ use super::metric::{
AverageAggregation, CountAggregation, MaxAggregation, MinAggregation, StatsAggregation,
SumAggregation,
};
use super::segment_agg_result::BucketCount;
use super::segment_agg_result::AggregationLimits;
use super::VecWithNames;
use crate::SegmentReader;
@@ -46,7 +44,7 @@ pub struct BucketAggregationWithAccessor {
pub(crate) field_type: ColumnType,
pub(crate) bucket_agg: BucketAggregationType,
pub(crate) sub_aggregation: AggregationsWithAccessor,
pub(crate) bucket_count: BucketCount,
pub(crate) limits: AggregationLimits,
}
impl BucketAggregationWithAccessor {
@@ -54,8 +52,7 @@ impl BucketAggregationWithAccessor {
bucket: &BucketAggregationType,
sub_aggregation: &Aggregations,
reader: &SegmentReader,
bucket_count: Rc<AtomicU32>,
max_bucket_count: u32,
limits: AggregationLimits,
) -> crate::Result<BucketAggregationWithAccessor> {
let mut str_dict_column = None;
let (accessor, field_type) = match &bucket {
@@ -83,15 +80,11 @@ impl BucketAggregationWithAccessor {
sub_aggregation: get_aggs_with_accessor_and_validate(
&sub_aggregation,
reader,
bucket_count.clone(),
max_bucket_count,
&limits.clone(),
)?,
bucket_agg: bucket.clone(),
str_dict_column,
bucket_count: BucketCount {
bucket_count,
max_bucket_count,
},
limits,
})
}
}
@@ -131,8 +124,7 @@ impl MetricAggregationWithAccessor {
pub(crate) fn get_aggs_with_accessor_and_validate(
aggs: &Aggregations,
reader: &SegmentReader,
bucket_count: Rc<AtomicU32>,
max_bucket_count: u32,
limits: &AggregationLimits,
) -> crate::Result<AggregationsWithAccessor> {
let mut metrics = vec![];
let mut buckets = vec![];
@@ -144,8 +136,7 @@ pub(crate) fn get_aggs_with_accessor_and_validate(
&bucket.bucket_agg,
&bucket.sub_aggregation,
reader,
Rc::clone(&bucket_count),
max_bucket_count,
limits.clone(),
)?,
)),
Aggregation::Metric(metric) => metrics.push((

View File

@@ -11,6 +11,7 @@ use super::agg_req::BucketAggregationInternal;
use super::bucket::GetDocCount;
use super::intermediate_agg_result::{IntermediateBucketResult, IntermediateMetricResult};
use super::metric::{SingleMetricResult, Stats};
use super::segment_agg_result::AggregationLimits;
use super::Key;
use crate::TantivyError;
@@ -19,6 +20,13 @@ use crate::TantivyError;
pub struct AggregationResults(pub FxHashMap<String, AggregationResult>);
impl AggregationResults {
pub(crate) fn get_bucket_count(&self) -> u64 {
self.0
.values()
.map(|agg| agg.get_bucket_count())
.sum::<u64>()
}
pub(crate) fn get_value_from_aggregation(
&self,
name: &str,
@@ -47,6 +55,13 @@ pub enum AggregationResult {
}
impl AggregationResult {
pub(crate) fn get_bucket_count(&self) -> u64 {
match self {
AggregationResult::BucketResult(bucket) => bucket.get_bucket_count(),
AggregationResult::MetricResult(_) => 0,
}
}
pub(crate) fn get_value_from_aggregation(
&self,
_name: &str,
@@ -153,9 +168,28 @@ pub enum BucketResult {
}
impl BucketResult {
pub(crate) fn empty_from_req(req: &BucketAggregationInternal) -> crate::Result<Self> {
pub(crate) fn get_bucket_count(&self) -> u64 {
match self {
BucketResult::Range { buckets } => {
buckets.iter().map(|bucket| bucket.get_bucket_count()).sum()
}
BucketResult::Histogram { buckets } => {
buckets.iter().map(|bucket| bucket.get_bucket_count()).sum()
}
BucketResult::Terms {
buckets,
sum_other_doc_count: _,
doc_count_error_upper_bound: _,
} => buckets.iter().map(|bucket| bucket.get_bucket_count()).sum(),
}
}
pub(crate) fn empty_from_req(
req: &BucketAggregationInternal,
limits: &AggregationLimits,
) -> crate::Result<Self> {
let empty_bucket = IntermediateBucketResult::empty_from_req(&req.bucket_agg);
empty_bucket.into_final_bucket_result(req)
empty_bucket.into_final_bucket_result(req, limits)
}
}
@@ -170,6 +204,15 @@ pub enum BucketEntries<T> {
HashMap(FxHashMap<String, T>),
}
impl<T> BucketEntries<T> {
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = &T> + 'a> {
match self {
BucketEntries::Vec(vec) => Box::new(vec.iter()),
BucketEntries::HashMap(map) => Box::new(map.values()),
}
}
}
/// This is the default entry for a bucket, which contains a key, count, and optionally
/// sub-aggregations.
///
@@ -209,6 +252,11 @@ pub struct BucketEntry {
/// Sub-aggregations in this bucket.
pub sub_aggregation: AggregationResults,
}
impl BucketEntry {
pub(crate) fn get_bucket_count(&self) -> u64 {
1 + self.sub_aggregation.get_bucket_count()
}
}
impl GetDocCount for &BucketEntry {
fn doc_count(&self) -> u64 {
self.doc_count
@@ -272,3 +320,8 @@ pub struct RangeBucketEntry {
#[serde(skip_serializing_if = "Option::is_none")]
pub to_as_string: Option<String>,
}
impl RangeBucketEntry {
pub(crate) fn get_bucket_count(&self) -> u64 {
1 + self.sub_aggregation.get_bucket_count()
}
}
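
The count checked against the bucket limit is therefore recursive: each returned bucket contributes one, plus every bucket inside its own sub-aggregations. A standalone illustration of the arithmetic (hypothetical numbers, not tied to the types above):

// A terms aggregation returning 10 buckets, each holding a 50-bucket histogram,
// counts as 10 * (1 + 50) = 510 buckets toward the limit.
fn nested_bucket_count(top_level_buckets: u64, sub_buckets_per_entry: u64) -> u64 {
    top_level_buckets * (1 + sub_buckets_per_entry)
}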

View File

@@ -9,6 +9,7 @@ use crate::aggregation::buf_collector::DOC_BLOCK_SIZE;
use crate::aggregation::collector::AggregationCollector;
use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults;
use crate::aggregation::metric::AverageAggregation;
use crate::aggregation::segment_agg_result::AggregationLimits;
use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms};
use crate::aggregation::DistributedAggregationCollector;
use crate::query::{AllQuery, TermQuery};
@@ -21,6 +22,10 @@ fn get_avg_req(field_name: &str) -> Aggregation {
))
}
fn get_collector(agg_req: Aggregations) -> AggregationCollector {
AggregationCollector::from_aggs(agg_req, Default::default())
}
// *** EVERY BUCKET-TYPE SHOULD BE TESTED HERE ***
fn test_aggregation_flushing(
merge_segments: bool,
@@ -98,15 +103,18 @@ fn test_aggregation_flushing(
.unwrap();
let agg_res: AggregationResults = if use_distributed_collector {
let collector = DistributedAggregationCollector::from_aggs(agg_req.clone(), None);
let collector = DistributedAggregationCollector::from_aggs(
agg_req.clone(),
AggregationLimits::default(),
);
let searcher = reader.searcher();
let intermediate_agg_result = searcher.search(&AllQuery, &collector).unwrap();
intermediate_agg_result
.into_final_bucket_result(agg_req)
.into_final_bucket_result(agg_req, &Default::default())
.unwrap()
} else {
let collector = AggregationCollector::from_aggs(agg_req, None);
let collector = get_collector(agg_req);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
@@ -243,7 +251,7 @@ fn test_aggregation_level1() -> crate::Result<()> {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap();
@@ -432,16 +440,18 @@ fn test_aggregation_level2(
};
let agg_res: AggregationResults = if use_distributed_collector {
let collector = DistributedAggregationCollector::from_aggs(agg_req.clone(), None);
let collector =
DistributedAggregationCollector::from_aggs(agg_req.clone(), Default::default());
let searcher = reader.searcher();
let res = searcher.search(&term_query, &collector).unwrap();
// Test de/serialization roundtrip on intermediate_agg_result
let res: IntermediateAggregationResults =
serde_json::from_str(&serde_json::to_string(&res).unwrap()).unwrap();
res.into_final_bucket_result(agg_req.clone()).unwrap()
res.into_final_bucket_result(agg_req.clone(), &Default::default())
.unwrap()
} else {
let collector = AggregationCollector::from_aggs(agg_req.clone(), None);
let collector = get_collector(agg_req.clone());
let searcher = reader.searcher();
searcher.search(&term_query, &collector).unwrap()
@@ -499,7 +509,7 @@ fn test_aggregation_level2(
);
// Test empty result set
let collector = AggregationCollector::from_aggs(agg_req, None);
let collector = get_collector(agg_req);
let searcher = reader.searcher();
searcher.search(&query_with_no_hits, &collector).unwrap();
@@ -562,7 +572,7 @@ fn test_aggregation_invalid_requests() -> crate::Result<()> {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
@@ -620,7 +630,7 @@ fn test_aggregation_on_json_object() {
)]
.into_iter()
.collect();
let aggregation_collector = AggregationCollector::from_aggs(agg, None);
let aggregation_collector = get_collector(agg);
let aggregation_results = searcher.search(&AllQuery, &aggregation_collector).unwrap();
let aggregation_res_json = serde_json::to_value(aggregation_results).unwrap();
assert_eq!(
@@ -690,7 +700,7 @@ fn test_aggregation_on_json_object_empty_columns() {
.into_iter()
.collect();
let aggregation_collector = AggregationCollector::from_aggs(agg, None);
let aggregation_collector = get_collector(agg);
let aggregation_results = searcher.search(&AllQuery, &aggregation_collector).unwrap();
let aggregation_res_json = serde_json::to_value(aggregation_results).unwrap();
assert_eq!(
@@ -721,9 +731,8 @@ fn test_aggregation_on_json_object_empty_columns() {
}
} "#;
let agg: Aggregations = serde_json::from_str(agg_req_str).unwrap();
let aggregation_results = searcher
.search(&AllQuery, &AggregationCollector::from_aggs(agg, None))
.unwrap();
let aggregation_collector = get_collector(agg);
let aggregation_results = searcher.search(&AllQuery, &aggregation_collector).unwrap();
let aggregation_res_json = serde_json::to_value(aggregation_results).unwrap();
assert_eq!(
&aggregation_res_json,
@@ -883,7 +892,7 @@ mod bench {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&term_query, &collector).unwrap()
@@ -912,7 +921,7 @@ mod bench {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&term_query, &collector).unwrap()
@@ -941,7 +950,7 @@ mod bench {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&term_query, &collector).unwrap()
@@ -978,7 +987,7 @@ mod bench {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&term_query, &collector).unwrap()
@@ -1008,7 +1017,7 @@ mod bench {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req, None);
let collector = get_collector(agg_req);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
@@ -1047,7 +1056,7 @@ mod bench {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req, None);
let collector = get_collector(agg_req);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
@@ -1077,7 +1086,7 @@ mod bench {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req, None);
let collector = get_collector(agg_req);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
@@ -1111,7 +1120,7 @@ mod bench {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req, None);
let collector = get_collector(agg_req);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
@@ -1149,7 +1158,7 @@ mod bench {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
@@ -1196,7 +1205,7 @@ mod bench {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
@@ -1236,7 +1245,7 @@ mod bench {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
});
@@ -1275,7 +1284,7 @@ mod bench {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
@@ -1306,7 +1315,7 @@ mod bench {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
@@ -1364,7 +1373,7 @@ mod bench {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&term_query, &collector).unwrap()

View File

@@ -7,6 +7,7 @@ use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use tantivy_bitpacker::minmax;
use crate::aggregation::agg_limits::MemoryConsumption;
use crate::aggregation::agg_req::AggregationsInternal;
use crate::aggregation::agg_req_with_accessor::{
AggregationsWithAccessor, BucketAggregationWithAccessor,
@@ -16,7 +17,7 @@ use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry,
};
use crate::aggregation::segment_agg_result::{
build_segment_agg_collector, SegmentAggregationCollector,
build_segment_agg_collector, AggregationLimits, SegmentAggregationCollector,
};
use crate::aggregation::{f64_from_fastfield_u64, format_date, VecWithNames};
use crate::{DocId, TantivyError};
@@ -249,6 +250,8 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
let sub_aggregation_accessor =
&agg_with_accessor.buckets.values[self.accessor_idx].sub_aggregation;
let mem_pre = self.get_memory_consumption();
let bounds = self.bounds;
let interval = self.interval;
let offset = self.offset;
@@ -271,6 +274,12 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
}
}
}
let mem_delta = self.get_memory_consumption() - mem_pre;
let limits = &agg_with_accessor.buckets.values[self.accessor_idx].limits;
limits.add_memory_consumed(mem_delta as u64);
limits.validate_memory_consumption()?;
Ok(())
}
@@ -287,6 +296,12 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
}
impl SegmentHistogramCollector {
fn get_memory_consumption(&self) -> usize {
let self_mem = std::mem::size_of::<Self>();
let sub_aggs_mem = self.sub_aggregations.memory_consumption();
let buckets_mem = self.buckets.memory_consumption();
self_mem + sub_aggs_mem + buckets_mem
}
pub fn into_intermediate_bucket_result(
self,
agg_with_accessor: &BucketAggregationWithAccessor,
@@ -389,6 +404,7 @@ fn intermediate_buckets_to_final_buckets_fill_gaps(
buckets: Vec<IntermediateHistogramBucketEntry>,
histogram_req: &HistogramAggregation,
sub_aggregation: &AggregationsInternal,
limits: &AggregationLimits,
) -> crate::Result<Vec<BucketEntry>> {
// Generate the full list of buckets without gaps.
//
@@ -396,7 +412,17 @@ fn intermediate_buckets_to_final_buckets_fill_gaps(
// extended_bounds from the request
let min_max = minmax(buckets.iter().map(|bucket| bucket.key));
// TODO add memory check
// memory check upfront
let (_, first_bucket_num, last_bucket_num) =
generate_bucket_pos_with_opt_minmax(histogram_req, min_max);
let added_buckets = (first_bucket_num..=last_bucket_num)
.count()
.saturating_sub(buckets.len());
limits.add_memory_consumed(
added_buckets as u64 * std::mem::size_of::<IntermediateHistogramBucketEntry>() as u64,
);
limits.validate_memory_consumption()?;
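// Illustration with hypothetical numbers: an interval of 0.1 over a 0.0..100.0 value
// range spans roughly 1_001 bucket positions. If the intermediate result holds only
// 100 non-empty buckets, the upfront charge above is
// (1_001 - 100) * size_of::<IntermediateHistogramBucketEntry>() bytes, validated
// against the limit before any gap buckets are materialized.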
// create buckets
let fill_gaps_buckets = generate_buckets_with_opt_minmax(histogram_req, min_max);
let empty_sub_aggregation = IntermediateAggregationResults::empty_from_req(sub_aggregation);
@@ -425,7 +451,9 @@ fn intermediate_buckets_to_final_buckets_fill_gaps(
sub_aggregation: empty_sub_aggregation.clone(),
},
})
.map(|intermediate_bucket| intermediate_bucket.into_final_bucket_entry(sub_aggregation))
.map(|intermediate_bucket| {
intermediate_bucket.into_final_bucket_entry(sub_aggregation, limits)
})
.collect::<crate::Result<Vec<_>>>()
}
@@ -435,18 +463,26 @@ pub(crate) fn intermediate_histogram_buckets_to_final_buckets(
column_type: Option<ColumnType>,
histogram_req: &HistogramAggregation,
sub_aggregation: &AggregationsInternal,
limits: &AggregationLimits,
) -> crate::Result<Vec<BucketEntry>> {
let mut buckets = if histogram_req.min_doc_count() == 0 {
// With min_doc_count != 0, we may need to add buckets, so that there are no
// gaps, since intermediate result does not contain empty buckets (filtered to
// reduce serialization size).
intermediate_buckets_to_final_buckets_fill_gaps(buckets, histogram_req, sub_aggregation)?
intermediate_buckets_to_final_buckets_fill_gaps(
buckets,
histogram_req,
sub_aggregation,
limits,
)?
} else {
buckets
.into_iter()
.filter(|histogram_bucket| histogram_bucket.doc_count >= histogram_req.min_doc_count())
.map(|histogram_bucket| histogram_bucket.into_final_bucket_entry(sub_aggregation))
.map(|histogram_bucket| {
histogram_bucket.into_final_bucket_entry(sub_aggregation, limits)
})
.collect::<crate::Result<Vec<_>>>()?
};
@@ -485,15 +521,27 @@ fn get_req_min_max(req: &HistogramAggregation, min_max: Option<(f64, f64)>) -> (
/// Computes the bucket position range for req.interval.
/// The range is computed from the provided min_max and the request's extended_bounds/hard_bounds.
/// Returns (offset, first_bucket_pos, last_bucket_pos).
pub(crate) fn generate_buckets_with_opt_minmax(
pub(crate) fn generate_bucket_pos_with_opt_minmax(
req: &HistogramAggregation,
min_max: Option<(f64, f64)>,
) -> Vec<f64> {
) -> (f64, i64, i64) {
let (min, max) = get_req_min_max(req, min_max);
let offset = req.offset.unwrap_or(0.0);
let first_bucket_num = get_bucket_pos_f64(min, req.interval, offset) as i64;
let last_bucket_num = get_bucket_pos_f64(max, req.interval, offset) as i64;
(offset, first_bucket_num, last_bucket_num)
}
/// Generates buckets with req.interval
/// Range is computed for provided min_max and request extended_bounds/hard_bounds
/// returns empty vec when there is no range to span
pub(crate) fn generate_buckets_with_opt_minmax(
req: &HistogramAggregation,
min_max: Option<(f64, f64)>,
) -> Vec<f64> {
let (offset, first_bucket_num, last_bucket_num) =
generate_bucket_pos_with_opt_minmax(req, min_max);
let mut buckets = Vec::with_capacity((first_bucket_num..=last_bucket_num).count());
for bucket_pos in first_bucket_num..=last_bucket_num {
let bucket_key = bucket_pos as f64 * req.interval + offset;
@@ -515,8 +563,8 @@ mod tests {
};
use crate::aggregation::metric::{AverageAggregation, StatsAggregation};
use crate::aggregation::tests::{
exec_request, exec_request_with_query, get_test_index_2_segments,
get_test_index_from_values, get_test_index_with_num_docs,
exec_request, exec_request_with_query, exec_request_with_query_and_memory_limit,
get_test_index_2_segments, get_test_index_from_values, get_test_index_with_num_docs,
};
#[test]
@@ -661,6 +709,40 @@ mod tests {
Ok(())
}
#[test]
fn histogram_memory_limit() -> crate::Result<()> {
let index = get_test_index_with_num_docs(true, 100)?;
let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 0.1,
..Default::default()
}),
sub_aggregation: Default::default(),
})),
)]
.into_iter()
.collect();
let res = exec_request_with_query_and_memory_limit(
agg_req,
&index,
None,
AggregationLimits::new(Some(5_000), None),
)
.unwrap_err();
assert_eq!(
res.to_string(),
"Aborting aggregation because memory limit was exceeded. Limit: 5.00 KB, Current: \
102.48 KB"
);
Ok(())
}
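// A hypothetical companion sketch (not part of this commit): the bucket limit can be
// exercised the same way by passing a small bucket cap instead of a memory cap.
#[test]
fn histogram_bucket_limit_sketch() -> crate::Result<()> {
    let index = get_test_index_with_num_docs(true, 100)?;
    let agg_req: Aggregations = vec![(
        "histogram".to_string(),
        Aggregation::Bucket(Box::new(BucketAggregation {
            bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
                field: "score_f64".to_string(),
                interval: 0.1,
                ..Default::default()
            }),
            sub_aggregation: Default::default(),
        })),
    )]
    .into_iter()
    .collect();
    let res = exec_request_with_query_and_memory_limit(
        agg_req,
        &index,
        None,
        // No memory cap, but at most 100 returned buckets.
        AggregationLimits::new(None, Some(100)),
    )
    .unwrap_err();
    assert!(res.to_string().contains("bucket limit was exceeded"));
    Ok(())
}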
#[test]
fn histogram_merge_test() -> crate::Result<()> {
// Merge buckets counts from different segments

View File

@@ -11,7 +11,7 @@ use crate::aggregation::intermediate_agg_result::{
IntermediateRangeBucketResult,
};
use crate::aggregation::segment_agg_result::{
build_segment_agg_collector, BucketCount, SegmentAggregationCollector,
build_segment_agg_collector, AggregationLimits, SegmentAggregationCollector,
};
use crate::aggregation::{
f64_from_fastfield_u64, f64_to_fastfield_u64, format_date, Key, SerializedKey, VecWithNames,
@@ -260,7 +260,7 @@ impl SegmentRangeCollector {
pub(crate) fn from_req_and_validate(
req: &RangeAggregation,
sub_aggregation: &AggregationsWithAccessor,
bucket_count: &BucketCount,
limits: &AggregationLimits,
field_type: ColumnType,
accessor_idx: usize,
) -> crate::Result<Self> {
@@ -304,8 +304,10 @@ impl SegmentRangeCollector {
})
.collect::<crate::Result<_>>()?;
bucket_count.add_count(buckets.len() as u32);
bucket_count.validate_bucket_count()?;
limits.add_memory_consumed(
buckets.len() as u64 * std::mem::size_of::<SegmentRangeAndBucketEntry>() as u64,
);
limits.validate_memory_consumption()?;
Ok(SegmentRangeCollector {
buckets,

View File

@@ -1,36 +1,36 @@
use std::rc::Rc;
use super::agg_req::Aggregations;
use super::agg_req_with_accessor::AggregationsWithAccessor;
use super::agg_result::AggregationResults;
use super::buf_collector::BufAggregationCollector;
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::{build_segment_agg_collector, SegmentAggregationCollector};
use super::segment_agg_result::{
build_segment_agg_collector, AggregationLimits, SegmentAggregationCollector,
};
use crate::aggregation::agg_req_with_accessor::get_aggs_with_accessor_and_validate;
use crate::collector::{Collector, SegmentCollector};
use crate::{SegmentReader, TantivyError};
/// The default max bucket count, before the aggregation fails.
pub const MAX_BUCKET_COUNT: u32 = 65000;
pub const DEFAULT_BUCKET_LIMIT: u32 = 65000;
/// The default memory limit in bytes before the aggregation fails. 500MB
pub const DEFAULT_MEMORY_LIMIT: u64 = 500_000_000;
/// Collector for aggregations.
///
/// The collector collects all aggregations by the underlying aggregation request.
pub struct AggregationCollector {
agg: Aggregations,
max_bucket_count: u32,
limits: AggregationLimits,
}
impl AggregationCollector {
/// Create collector from aggregation request.
///
/// Aggregation fails when the total bucket count is higher than max_bucket_count.
/// max_bucket_count will default to `MAX_BUCKET_COUNT` (65000) when unset
pub fn from_aggs(agg: Aggregations, max_bucket_count: Option<u32>) -> Self {
Self {
agg,
max_bucket_count: max_bucket_count.unwrap_or(MAX_BUCKET_COUNT),
}
/// Aggregation fails when a limit in `AggregationLimits` is exceeded (memory limit or
/// bucket limit).
pub fn from_aggs(agg: Aggregations, limits: AggregationLimits) -> Self {
Self { agg, limits }
}
}
@@ -44,18 +44,16 @@ impl AggregationCollector {
/// into the final `AggregationResults` via the `into_final_result()` method.
pub struct DistributedAggregationCollector {
agg: Aggregations,
max_bucket_count: u32,
limits: AggregationLimits,
}
impl DistributedAggregationCollector {
/// Create collector from aggregation request.
///
/// max_bucket_count will default to `MAX_BUCKET_COUNT` (65000) when unset
pub fn from_aggs(agg: Aggregations, max_bucket_count: Option<u32>) -> Self {
Self {
agg,
max_bucket_count: max_bucket_count.unwrap_or(MAX_BUCKET_COUNT),
}
/// Aggregation fails when a limit in `AggregationLimits` is exceeded (memory limit or
/// bucket limit).
pub fn from_aggs(agg: Aggregations, limits: AggregationLimits) -> Self {
Self { agg, limits }
}
}
@@ -69,11 +67,7 @@ impl Collector for DistributedAggregationCollector {
_segment_local_id: crate::SegmentOrdinal,
reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
AggregationSegmentCollector::from_agg_req_and_reader(
&self.agg,
reader,
self.max_bucket_count,
)
AggregationSegmentCollector::from_agg_req_and_reader(&self.agg, reader, &self.limits)
}
fn requires_scoring(&self) -> bool {
@@ -98,11 +92,7 @@ impl Collector for AggregationCollector {
_segment_local_id: crate::SegmentOrdinal,
reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
AggregationSegmentCollector::from_agg_req_and_reader(
&self.agg,
reader,
self.max_bucket_count,
)
AggregationSegmentCollector::from_agg_req_and_reader(&self.agg, reader, &self.limits)
}
fn requires_scoring(&self) -> bool {
@@ -114,7 +104,7 @@ impl Collector for AggregationCollector {
segment_fruits: Vec<<Self::Child as SegmentCollector>::Fruit>,
) -> crate::Result<Self::Fruit> {
let res = merge_fruits(segment_fruits)?;
res.into_final_bucket_result(self.agg.clone())
res.into_final_bucket_result(self.agg.clone(), &self.limits)
}
}
@@ -145,10 +135,9 @@ impl AggregationSegmentCollector {
pub fn from_agg_req_and_reader(
agg: &Aggregations,
reader: &SegmentReader,
max_bucket_count: u32,
limits: &AggregationLimits,
) -> crate::Result<Self> {
let aggs_with_accessor =
get_aggs_with_accessor_and_validate(agg, reader, Rc::default(), max_bucket_count)?;
let aggs_with_accessor = get_aggs_with_accessor_and_validate(agg, reader, limits)?;
let result =
BufAggregationCollector::new(build_segment_agg_collector(&aggs_with_accessor)?);
Ok(AggregationSegmentCollector {

View File

@@ -1,9 +1,33 @@
use common::ByteCount;
use super::bucket::DateHistogramParseError;
/// Errors that may occur while running an aggregation.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum AggregationError {
/// Failed to open the directory.
/// Date histogram parse error
#[error("Date histogram parse error: {0:?}")]
DateHistogramParseError(#[from] DateHistogramParseError),
/// Memory limit exceeded
#[error(
"Aborting aggregation because memory limit was exceeded. Limit: {limit:?}, Current: \
{current:?}"
)]
MemoryExceeded {
/// Memory consumption limit
limit: ByteCount,
/// Current memory consumption
current: ByteCount,
},
/// Bucket limit exceeded
#[error(
"Aborting aggregation because bucket limit was exceeded. Limit: {limit:?}, Current: \
{current:?}"
)]
BucketLimitExceeded {
/// Bucket limit
limit: u32,
/// Current num buckets
current: u32,
},
}
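
Downstream code can match on the new variants through TantivyError, which wraps AggregationError transparently (see the #[error(transparent)] change further down). A sketch of the caller side; the report helper is hypothetical and the re-export path tantivy::aggregation::AggregationError is assumed:

use tantivy::aggregation::AggregationError;
use tantivy::TantivyError;

fn report(err: TantivyError) {
    match err {
        TantivyError::AggregationError(AggregationError::MemoryExceeded { limit, current }) => {
            // ByteCount's Debug impl prints the human-readable form.
            eprintln!("aggregation used {:?}, allowed {:?}", current, limit);
        }
        TantivyError::AggregationError(AggregationError::BucketLimitExceeded { limit, current }) => {
            eprintln!("aggregation returned {} buckets, allowed {}", current, limit);
        }
        other => eprintln!("search failed: {}", other),
    }
}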

View File

@@ -22,9 +22,11 @@ use super::metric::{
IntermediateAverage, IntermediateCount, IntermediateMax, IntermediateMin, IntermediateStats,
IntermediateSum,
};
use super::{format_date, Key, SerializedKey, VecWithNames};
use super::segment_agg_result::AggregationLimits;
use super::{format_date, AggregationError, Key, SerializedKey, VecWithNames};
use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry};
use crate::aggregation::bucket::TermsAggregationInternal;
use crate::TantivyError;
/// Contains the intermediate aggregation result, which is optimized to be merged with other
/// intermediate results.
@@ -38,8 +40,23 @@ pub struct IntermediateAggregationResults {
impl IntermediateAggregationResults {
/// Convert intermediate result and its aggregation request to the final result.
pub fn into_final_bucket_result(self, req: Aggregations) -> crate::Result<AggregationResults> {
self.into_final_bucket_result_internal(&(req.into()))
pub fn into_final_bucket_result(
self,
req: Aggregations,
limits: &AggregationLimits,
) -> crate::Result<AggregationResults> {
// Count the returned buckets and validate them against the bucket limit.
let res = self.into_final_bucket_result_internal(&(req.into()), limits)?;
let bucket_count = res.get_bucket_count() as u32;
if bucket_count > limits.get_bucket_limit() {
return Err(TantivyError::AggregationError(
AggregationError::BucketLimitExceeded {
limit: limits.get_bucket_limit(),
current: bucket_count,
},
));
}
Ok(res)
}
/// Convert intermediate result and its aggregation request to the final result.
@@ -49,6 +66,7 @@ impl IntermediateAggregationResults {
pub(crate) fn into_final_bucket_result_internal(
self,
req: &AggregationsInternal,
limits: &AggregationLimits,
) -> crate::Result<AggregationResults> {
// Important assumption:
// When the tree contains buckets/metric, we expect it to have all buckets/metrics from the
@@ -56,11 +74,11 @@ impl IntermediateAggregationResults {
let mut results: FxHashMap<String, AggregationResult> = FxHashMap::default();
if let Some(buckets) = self.buckets {
convert_and_add_final_buckets_to_result(&mut results, buckets, &req.buckets)?
convert_and_add_final_buckets_to_result(&mut results, buckets, &req.buckets, limits)?
} else {
// When there are no buckets, we create empty buckets, so that the serialized json
// format is constant
add_empty_final_buckets_to_result(&mut results, &req.buckets)?
add_empty_final_buckets_to_result(&mut results, &req.buckets, limits)?
};
if let Some(metrics) = self.metrics {
@@ -161,10 +179,12 @@ fn add_empty_final_metrics_to_result(
fn add_empty_final_buckets_to_result(
results: &mut FxHashMap<String, AggregationResult>,
req_buckets: &VecWithNames<BucketAggregationInternal>,
limits: &AggregationLimits,
) -> crate::Result<()> {
let requested_buckets = req_buckets.iter();
for (key, req) in requested_buckets {
let empty_bucket = AggregationResult::BucketResult(BucketResult::empty_from_req(req)?);
let empty_bucket =
AggregationResult::BucketResult(BucketResult::empty_from_req(req, limits)?);
results.insert(key.to_string(), empty_bucket);
}
Ok(())
@@ -174,12 +194,13 @@ fn convert_and_add_final_buckets_to_result(
results: &mut FxHashMap<String, AggregationResult>,
buckets: VecWithNames<IntermediateBucketResult>,
req_buckets: &VecWithNames<BucketAggregationInternal>,
limits: &AggregationLimits,
) -> crate::Result<()> {
assert_eq!(buckets.len(), req_buckets.len());
let buckets_with_request = buckets.into_iter().zip(req_buckets.values());
for ((key, bucket), req) in buckets_with_request {
let result = AggregationResult::BucketResult(bucket.into_final_bucket_result(req)?);
let result = AggregationResult::BucketResult(bucket.into_final_bucket_result(req, limits)?);
results.insert(key, result);
}
Ok(())
@@ -287,6 +308,7 @@ impl IntermediateBucketResult {
pub(crate) fn into_final_bucket_result(
self,
req: &BucketAggregationInternal,
limits: &AggregationLimits,
) -> crate::Result<BucketResult> {
match self {
IntermediateBucketResult::Range(range_res) => {
@@ -299,6 +321,7 @@ impl IntermediateBucketResult {
req.as_range()
.expect("unexpected aggregation, expected histogram aggregation"),
range_res.column_type,
limits,
)
})
.collect::<crate::Result<Vec<_>>>()?;
@@ -337,6 +360,7 @@ impl IntermediateBucketResult {
column_type,
histogram_req,
&req.sub_aggregation,
limits,
)?;
let buckets = if histogram_req.keyed {
@@ -355,6 +379,7 @@ impl IntermediateBucketResult {
req.as_term()
.expect("unexpected aggregation, expected term aggregation"),
&req.sub_aggregation,
limits,
),
}
}
@@ -449,6 +474,7 @@ impl IntermediateTermBucketResult {
self,
req: &TermsAggregation,
sub_aggregation_req: &AggregationsInternal,
limits: &AggregationLimits,
) -> crate::Result<BucketResult> {
let req = TermsAggregationInternal::from_req(req);
let mut buckets: Vec<BucketEntry> = self
@@ -462,7 +488,7 @@ impl IntermediateTermBucketResult {
doc_count: entry.doc_count,
sub_aggregation: entry
.sub_aggregation
.into_final_bucket_result_internal(sub_aggregation_req)?,
.into_final_bucket_result_internal(sub_aggregation_req, limits)?,
})
})
.collect::<crate::Result<_>>()?;
@@ -582,6 +608,7 @@ impl IntermediateHistogramBucketEntry {
pub(crate) fn into_final_bucket_entry(
self,
req: &AggregationsInternal,
limits: &AggregationLimits,
) -> crate::Result<BucketEntry> {
Ok(BucketEntry {
key_as_string: None,
@@ -589,7 +616,7 @@ impl IntermediateHistogramBucketEntry {
doc_count: self.doc_count,
sub_aggregation: self
.sub_aggregation
.into_final_bucket_result_internal(req)?,
.into_final_bucket_result_internal(req, limits)?,
})
}
}
@@ -628,13 +655,14 @@ impl IntermediateRangeBucketEntry {
req: &AggregationsInternal,
_range_req: &RangeAggregation,
column_type: Option<ColumnType>,
limits: &AggregationLimits,
) -> crate::Result<RangeBucketEntry> {
let mut range_bucket_entry = RangeBucketEntry {
key: self.key,
doc_count: self.doc_count,
sub_aggregation: self
.sub_aggregation
.into_final_bucket_result_internal(req)?,
.into_final_bucket_result_internal(req, limits)?,
to: self.to,
from: self.from,
to_as_string: None,

View File

@@ -81,7 +81,7 @@ mod tests {
"price_sum": { "sum": { "field": "price" } }
}"#;
let aggregations: Aggregations = serde_json::from_str(aggregations_json).unwrap();
let collector = AggregationCollector::from_aggs(aggregations, None);
let collector = AggregationCollector::from_aggs(aggregations, Default::default());
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let aggregations_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();

View File

@@ -294,7 +294,7 @@ mod tests {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = AggregationCollector::from_aggs(agg_req_1, Default::default());
let reader = index.reader()?;
let searcher = reader.searcher();
@@ -331,7 +331,7 @@ mod tests {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = AggregationCollector::from_aggs(agg_req_1, Default::default());
let reader = index.reader()?;
let searcher = reader.searcher();
@@ -411,7 +411,7 @@ mod tests {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = AggregationCollector::from_aggs(agg_req_1, Default::default());
let searcher = reader.searcher();
let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap();

View File

@@ -70,7 +70,7 @@
//! .into_iter()
//! .collect();
//!
//! let collector = AggregationCollector::from_aggs(agg_req, None);
//! let collector = AggregationCollector::from_aggs(agg_req, Default::default());
//!
//! let searcher = reader.searcher();
//! let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();
@@ -155,6 +155,7 @@
//! [`AggregationResults`](agg_result::AggregationResults) via the
//! [`into_final_bucket_result`](intermediate_agg_result::IntermediateAggregationResults::into_final_bucket_result) method.
mod agg_limits;
pub mod agg_req;
mod agg_req_with_accessor;
pub mod agg_result;
@@ -165,6 +166,7 @@ mod date;
mod error;
pub mod intermediate_agg_result;
pub mod metric;
mod segment_agg_result;
use std::collections::HashMap;
use std::fmt::Display;
@@ -174,7 +176,7 @@ mod agg_tests;
pub use collector::{
AggregationCollector, AggregationSegmentCollector, DistributedAggregationCollector,
MAX_BUCKET_COUNT,
DEFAULT_BUCKET_LIMIT,
};
use columnar::{ColumnType, MonotonicallyMappableToU64};
pub(crate) use date::format_date;
@@ -345,6 +347,7 @@ mod tests {
use time::OffsetDateTime;
use super::agg_req::Aggregations;
use super::segment_agg_result::AggregationLimits;
use super::*;
use crate::indexer::NoMergePolicy;
use crate::query::{AllQuery, TermQuery};
@@ -369,7 +372,16 @@ mod tests {
index: &Index,
query: Option<(&str, &str)>,
) -> crate::Result<Value> {
let collector = AggregationCollector::from_aggs(agg_req, None);
exec_request_with_query_and_memory_limit(agg_req, index, query, Default::default())
}
pub fn exec_request_with_query_and_memory_limit(
agg_req: Aggregations,
index: &Index,
query: Option<(&str, &str)>,
limits: AggregationLimits,
) -> crate::Result<Value> {
let collector = AggregationCollector::from_aggs(agg_req, limits);
let reader = index.reader()?;
let searcher = reader.searcher();

View File

@@ -4,15 +4,13 @@
//! merging.
use std::fmt::Debug;
use std::rc::Rc;
use std::sync::atomic::AtomicU32;
pub(crate) use super::agg_limits::AggregationLimits;
use super::agg_req::MetricAggregation;
use super::agg_req_with_accessor::{
AggregationsWithAccessor, BucketAggregationWithAccessor, MetricAggregationWithAccessor,
};
use super::bucket::{SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector};
use super::collector::MAX_BUCKET_COUNT;
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::metric::{
AverageAggregation, CountAggregation, MaxAggregation, MinAggregation, SegmentStatsCollector,
@@ -20,7 +18,6 @@ use super::metric::{
};
use super::VecWithNames;
use crate::aggregation::agg_req::BucketAggregationType;
use crate::TantivyError;
pub(crate) trait SegmentAggregationCollector: CollectorClone + Debug {
fn into_intermediate_aggregations_result(
@@ -131,7 +128,7 @@ pub(crate) fn build_bucket_segment_agg_collector(
Ok(Box::new(SegmentRangeCollector::from_req_and_validate(
range_req,
&req.sub_aggregation,
&req.bucket_count,
&req.limits,
req.field_type,
accessor_idx,
)?))
@@ -284,37 +281,3 @@ impl GenericSegmentAggregationResultsCollector {
Ok(GenericSegmentAggregationResultsCollector { metrics, buckets })
}
}
#[derive(Clone)]
pub(crate) struct BucketCount {
/// The counter which is shared between the aggregations for one request.
pub(crate) bucket_count: Rc<AtomicU32>,
pub(crate) max_bucket_count: u32,
}
impl Default for BucketCount {
fn default() -> Self {
Self {
bucket_count: Default::default(),
max_bucket_count: MAX_BUCKET_COUNT,
}
}
}
impl BucketCount {
pub(crate) fn validate_bucket_count(&self) -> crate::Result<()> {
if self.get_count() > self.max_bucket_count {
return Err(TantivyError::InvalidArgument(
"Aborting aggregation because too many buckets were created".to_string(),
));
}
Ok(())
}
pub(crate) fn add_count(&self, count: u32) {
self.bucket_count
.fetch_add(count, std::sync::atomic::Ordering::Relaxed);
}
pub(crate) fn get_count(&self) -> u32 {
self.bucket_count.load(std::sync::atomic::Ordering::Relaxed)
}
}

View File

@@ -327,7 +327,7 @@ impl SegmentReader {
self.alive_bitset_opt
.as_ref()
.map(AliveBitSet::space_usage)
.unwrap_or(0),
.unwrap_or_default(),
))
}
}

View File

@@ -172,7 +172,7 @@ impl CompositeFile {
let mut fields = Vec::new();
for (&field_addr, byte_range) in &self.offsets_index {
let mut field_usage = FieldUsage::empty(field_addr.field);
field_usage.add_field_idx(field_addr.idx, byte_range.len());
field_usage.add_field_idx(field_addr.idx, byte_range.len().into());
fields.push(field_usage);
}
PerFieldSpaceUsage::new(fields)

View File

@@ -55,7 +55,7 @@ impl fmt::Debug for DataCorruption {
#[derive(Debug, Clone, Error)]
pub enum TantivyError {
/// Error when handling aggregations.
#[error("AggregationError {0:?}")]
#[error(transparent)]
AggregationError(#[from] AggregationError),
/// Failed to open the directory.
#[error("Failed to open the directory: '{0:?}'")]

View File

@@ -1,9 +1,8 @@
use std::io;
use std::io::Write;
use common::{intersect_bitsets, BitSet, OwnedBytes, ReadOnlyBitSet};
use common::{intersect_bitsets, BitSet, ByteCount, OwnedBytes, ReadOnlyBitSet};
use crate::space_usage::ByteCount;
use crate::DocId;
/// Write an alive `BitSet`

View File

@@ -80,7 +80,7 @@ mod tests {
use std::path::Path;
use columnar::{Column, MonotonicallyMappableToU64, StrColumn};
use common::{HasLen, TerminatingWrite};
use common::{ByteCount, HasLen, TerminatingWrite};
use once_cell::sync::Lazy;
use rand::prelude::SliceRandom;
use rand::rngs::StdRng;
@@ -862,16 +862,16 @@ mod tests {
#[test]
pub fn test_gcd_date() {
let size_prec_sec = test_gcd_date_with_codec(DatePrecision::Seconds);
assert!((1000 * 13 / 8..100 + 1000 * 13 / 8).contains(&size_prec_sec)); // 13 bits per val = ceil(log_2(number of seconds in 2hours);
assert!((1000 * 13 / 8..100 + 1000 * 13 / 8).contains(&size_prec_sec.get_bytes())); // 13 bits per val = ceil(log_2(number of seconds in 2hours);
let size_prec_micros = test_gcd_date_with_codec(DatePrecision::Microseconds);
assert!((1000 * 33 / 8..100 + 1000 * 33 / 8).contains(&size_prec_micros));
assert!((1000 * 33 / 8..100 + 1000 * 33 / 8).contains(&size_prec_micros.get_bytes()));
// 33 bits per val = ceil(log_2(number of microseconds in 2 hours));
}
fn test_gcd_date_with_codec(precision: DatePrecision) -> usize {
fn test_gcd_date_with_codec(precision: DatePrecision) -> ByteCount {
let mut rng = StdRng::seed_from_u64(2u64);
const T0: i64 = 1_662_345_825_012_529i64;
const ONE_HOUR_IN_MICROSECS: i64 = 3_600 * 1_000_000;

View File

@@ -6,6 +6,7 @@ use columnar::{
BytesColumn, Column, ColumnType, ColumnValues, ColumnarReader, DynamicColumn,
DynamicColumnHandle, HasAssociatedColumnType, StrColumn,
};
use common::ByteCount;
use crate::core::json_utils::encode_column_name;
use crate::directory::FileSlice;
@@ -42,7 +43,7 @@ impl FastFieldReaders {
let mut per_field_usages: Vec<FieldUsage> = Default::default();
for (field, field_entry) in schema.fields() {
let column_handles = self.columnar.read_columns(field_entry.name())?;
let num_bytes: usize = column_handles
let num_bytes: ByteCount = column_handles
.iter()
.map(|column_handle| column_handle.num_bytes())
.sum();
@@ -136,9 +137,9 @@ impl FastFieldReaders {
/// Returns the number of `bytes` associated with a column.
///
/// Returns 0 if the column does not exist.
pub fn column_num_bytes(&self, field: &str) -> crate::Result<usize> {
pub fn column_num_bytes(&self, field: &str) -> crate::Result<ByteCount> {
let Some(resolved_field_name) = self.resolve_field(field)? else {
return Ok(0);
return Ok(0u64.into());
};
Ok(self
.columnar

View File

@@ -9,14 +9,12 @@
use std::collections::HashMap;
use common::ByteCount;
use serde::{Deserialize, Serialize};
use crate::schema::Field;
use crate::SegmentComponent;
/// Indicates space usage in bytes
pub type ByteCount = usize;
/// Enum containing any of the possible space usage results for segment components.
pub enum ComponentSpaceUsage {
/// Data is stored per field in a uniform way
@@ -38,7 +36,7 @@ impl SearcherSpaceUsage {
pub(crate) fn new() -> SearcherSpaceUsage {
SearcherSpaceUsage {
segments: Vec::new(),
total: 0,
total: Default::default(),
}
}
@@ -260,7 +258,7 @@ impl FieldUsage {
pub(crate) fn empty(field: Field) -> FieldUsage {
FieldUsage {
field,
num_bytes: 0,
num_bytes: Default::default(),
sub_num_bytes: Vec::new(),
}
}
@@ -294,7 +292,7 @@ impl FieldUsage {
mod test {
use crate::core::Index;
use crate::schema::{Field, Schema, FAST, INDEXED, STORED, TEXT};
use crate::space_usage::{ByteCount, PerFieldSpaceUsage};
use crate::space_usage::PerFieldSpaceUsage;
use crate::Term;
#[test]
@@ -304,14 +302,14 @@ mod test {
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let searcher_space_usage = searcher.space_usage().unwrap();
assert_eq!(0, searcher_space_usage.total());
assert_eq!(searcher_space_usage.total(), 0u64);
}
fn expect_single_field(
field_space: &PerFieldSpaceUsage,
field: &Field,
min_size: ByteCount,
max_size: ByteCount,
min_size: u64,
max_size: u64,
) {
assert!(field_space.total() >= min_size);
assert!(field_space.total() <= max_size);
@@ -353,12 +351,12 @@ mod test {
expect_single_field(segment.termdict(), &name, 1, 512);
expect_single_field(segment.postings(), &name, 1, 512);
assert_eq!(0, segment.positions().total());
assert_eq!(segment.positions().total(), 0);
expect_single_field(segment.fast_fields(), &name, 1, 512);
expect_single_field(segment.fieldnorms(), &name, 1, 512);
// TODO: understand why the following fails
// assert_eq!(0, segment.store().total());
assert_eq!(0, segment.deletes());
assert_eq!(segment.deletes(), 0);
Ok(())
}
@@ -394,11 +392,11 @@ mod test {
expect_single_field(segment.termdict(), &name, 1, 512);
expect_single_field(segment.postings(), &name, 1, 512);
expect_single_field(segment.positions(), &name, 1, 512);
assert_eq!(0, segment.fast_fields().total());
assert_eq!(segment.fast_fields().total(), 0);
expect_single_field(segment.fieldnorms(), &name, 1, 512);
// TODO: understand why the following fails
// assert_eq!(0, segment.store().total());
assert_eq!(0, segment.deletes());
assert_eq!(segment.deletes(), 0);
Ok(())
}
@@ -430,14 +428,14 @@ mod test {
assert_eq!(4, segment.num_docs());
assert_eq!(0, segment.termdict().total());
assert_eq!(0, segment.postings().total());
assert_eq!(0, segment.positions().total());
assert_eq!(0, segment.fast_fields().total());
assert_eq!(0, segment.fieldnorms().total());
assert_eq!(segment.termdict().total(), 0);
assert_eq!(segment.postings().total(), 0);
assert_eq!(segment.positions().total(), 0);
assert_eq!(segment.fast_fields().total(), 0);
assert_eq!(segment.fieldnorms().total(), 0);
assert!(segment.store().total() > 0);
assert!(segment.store().total() < 512);
assert_eq!(0, segment.deletes());
assert_eq!(segment.deletes(), 0);
Ok(())
}
@@ -478,8 +476,8 @@ mod test {
expect_single_field(segment_space_usage.termdict(), &name, 1, 512);
expect_single_field(segment_space_usage.postings(), &name, 1, 512);
assert_eq!(0, segment_space_usage.positions().total());
assert_eq!(0, segment_space_usage.fast_fields().total());
assert_eq!(segment_space_usage.positions().total(), 0u64);
assert_eq!(segment_space_usage.fast_fields().total(), 0u64);
expect_single_field(segment_space_usage.fieldnorms(), &name, 1, 512);
assert!(segment_space_usage.deletes() > 0);
Ok(())

View File

@@ -5,7 +5,7 @@ use std::ops::{AddAssign, Range};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use common::{BinarySerializable, HasLen, OwnedBytes};
use common::{BinarySerializable, OwnedBytes};
use lru::LruCache;
use super::footer::DocStoreFooter;
@@ -122,7 +122,8 @@ impl StoreReader {
let (data_file, offset_index_file) = data_and_offset.split(footer.offset as usize);
let index_data = offset_index_file.read_bytes()?;
let space_usage = StoreSpaceUsage::new(data_file.len(), offset_index_file.len());
let space_usage =
StoreSpaceUsage::new(data_file.num_bytes(), offset_index_file.num_bytes());
let skip_index = SkipIndex::open(index_data);
Ok(StoreReader {
decompressor: footer.decompressor,