diff --git a/examples/aggregation.rs b/examples/aggregation.rs index 0946c2a08..eb03e7815 100644 --- a/examples/aggregation.rs +++ b/examples/aggregation.rs @@ -7,12 +7,8 @@ // --- use serde_json::{Deserializer, Value}; -use tantivy::aggregation::agg_req::{ - Aggregation, Aggregations, BucketAggregation, BucketAggregationType, MetricAggregation, -}; +use tantivy::aggregation::agg_req::Aggregations; use tantivy::aggregation::agg_result::AggregationResults; -use tantivy::aggregation::bucket::{RangeAggregation, RangeAggregationRange}; -use tantivy::aggregation::metric::AverageAggregation; use tantivy::aggregation::AggregationCollector; use tantivy::query::AllQuery; use tantivy::schema::{self, IndexRecordOption, Schema, TextFieldIndexing, FAST}; diff --git a/src/aggregation/agg_bench.rs b/src/aggregation/agg_bench.rs index 9f8ca504f..12a898968 100644 --- a/src/aggregation/agg_bench.rs +++ b/src/aggregation/agg_bench.rs @@ -5,16 +5,10 @@ mod bench { use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use rand_distr::Distribution; + use serde_json::json; use test::{self, Bencher}; - use crate::aggregation::agg_req::{ - Aggregation, Aggregations, BucketAggregation, BucketAggregationType, MetricAggregation, - }; - use crate::aggregation::bucket::{ - CustomOrder, HistogramAggregation, HistogramBounds, Order, OrderTarget, RangeAggregation, - TermsAggregation, - }; - use crate::aggregation::metric::{AverageAggregation, StatsAggregation}; + use crate::aggregation::agg_req::Aggregations; use crate::aggregation::AggregationCollector; use crate::query::{AllQuery, TermQuery}; use crate::schema::{IndexRecordOption, Schema, TextFieldIndexing, FAST, STRING}; @@ -153,14 +147,10 @@ mod bench { IndexRecordOption::Basic, ); - let agg_req_1: Aggregations = vec![( - "average".to_string(), - Aggregation::Metric(MetricAggregation::Average( - AverageAggregation::from_field_name("score".to_string()), - )), - )] - .into_iter() - .collect(); + let agg_req_1: Aggregations = 
serde_json::from_value(json!({ + "average": { "avg": { "field": "score", } } + })) + .unwrap(); let collector = get_collector(agg_req_1); @@ -182,14 +172,10 @@ mod bench { IndexRecordOption::Basic, ); - let agg_req_1: Aggregations = vec![( - "average_f64".to_string(), - Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name( - "score_f64".to_string(), - ))), - )] - .into_iter() - .collect(); + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "average_f64": { "stats": { "field": "score_f64", } } + })) + .unwrap(); let collector = get_collector(agg_req_1); @@ -211,14 +197,10 @@ mod bench { IndexRecordOption::Basic, ); - let agg_req_1: Aggregations = vec![( - "average_f64".to_string(), - Aggregation::Metric(MetricAggregation::Average( - AverageAggregation::from_field_name("score_f64".to_string()), - )), - )] - .into_iter() - .collect(); + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "average_f64": { "avg": { "field": "score_f64", } } + })) + .unwrap(); let collector = get_collector(agg_req_1); @@ -265,22 +247,11 @@ mod bench { IndexRecordOption::Basic, ); - let agg_req_1: Aggregations = vec![ - ( - "average_f64".to_string(), - Aggregation::Metric(MetricAggregation::Average( - AverageAggregation::from_field_name("score_f64".to_string()), - )), - ), - ( - "average".to_string(), - Aggregation::Metric(MetricAggregation::Average( - AverageAggregation::from_field_name("score".to_string()), - )), - ), - ] - .into_iter() - .collect(); + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "average_f64": { "avg": { "field": "score_f64" } }, + "average": { "avg": { "field": "score" } }, + })) + .unwrap(); let collector = get_collector(agg_req_1); @@ -296,21 +267,10 @@ mod bench { let reader = index.reader().unwrap(); b.iter(|| { - let agg_req: Aggregations = vec![( - "my_texts".to_string(), - Aggregation::Bucket( - BucketAggregation { - bucket_agg: BucketAggregationType::Terms(TermsAggregation { - field: 
"text_few_terms".to_string(), - ..Default::default() - }), - sub_aggregation: Default::default(), - } - .into(), - ), - )] - .into_iter() - .collect(); + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_texts": { "terms": { "field": "text_few_terms" } }, + })) + .unwrap(); let collector = get_collector(agg_req); @@ -326,30 +286,15 @@ mod bench { let reader = index.reader().unwrap(); b.iter(|| { - let sub_agg_req: Aggregations = vec![( - "average_f64".to_string(), - Aggregation::Metric(MetricAggregation::Average( - AverageAggregation::from_field_name("score_f64".to_string()), - )), - )] - .into_iter() - .collect(); - - let agg_req: Aggregations = vec![( - "my_texts".to_string(), - Aggregation::Bucket( - BucketAggregation { - bucket_agg: BucketAggregationType::Terms(TermsAggregation { - field: "text_many_terms".to_string(), - ..Default::default() - }), - sub_aggregation: sub_agg_req, + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_texts": { + "terms": { "field": "text_many_terms" }, + "aggs": { + "average_f64": { "avg": { "field": "score_f64" } } } - .into(), - ), - )] - .into_iter() - .collect(); + }, + })) + .unwrap(); let collector = get_collector(agg_req); @@ -365,21 +310,10 @@ mod bench { let reader = index.reader().unwrap(); b.iter(|| { - let agg_req: Aggregations = vec![( - "my_texts".to_string(), - Aggregation::Bucket( - BucketAggregation { - bucket_agg: BucketAggregationType::Terms(TermsAggregation { - field: "text_many_terms".to_string(), - ..Default::default() - }), - sub_aggregation: Default::default(), - } - .into(), - ), - )] - .into_iter() - .collect(); + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_texts": { "terms": { "field": "text_many_terms" } }, + })) + .unwrap(); let collector = get_collector(agg_req); @@ -395,25 +329,10 @@ mod bench { let reader = index.reader().unwrap(); b.iter(|| { - let agg_req: Aggregations = vec![( - "my_texts".to_string(), - Aggregation::Bucket( - BucketAggregation { 
- bucket_agg: BucketAggregationType::Terms(TermsAggregation { - field: "text_many_terms".to_string(), - order: Some(CustomOrder { - order: Order::Desc, - target: OrderTarget::Key, - }), - ..Default::default() - }), - sub_aggregation: Default::default(), - } - .into(), - ), - )] - .into_iter() - .collect(); + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_texts": { "terms": { "field": "text_many_terms", "order": { "_key": "desc" } } }, + })) + .unwrap(); let collector = get_collector(agg_req); @@ -429,29 +348,17 @@ mod bench { let reader = index.reader().unwrap(); b.iter(|| { - let agg_req_1: Aggregations = vec![( - "rangef64".to_string(), - Aggregation::Bucket( - BucketAggregation { - bucket_agg: BucketAggregationType::Range(RangeAggregation { - field: "score_f64".to_string(), - ranges: vec![ - (3f64..7000f64).into(), - (7000f64..20000f64).into(), - (20000f64..30000f64).into(), - (30000f64..40000f64).into(), - (40000f64..50000f64).into(), - (50000f64..60000f64).into(), - ], - ..Default::default() - }), - sub_aggregation: Default::default(), - } - .into(), - ), - )] - .into_iter() - .collect(); + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "range_f64": { "range": { "field": "score_f64", "ranges": [ + { "from": 3, "to": 7000 }, + { "from": 7000, "to": 20000 }, + { "from": 20000, "to": 30000 }, + { "from": 30000, "to": 40000 }, + { "from": 40000, "to": 50000 }, + { "from": 50000, "to": 60000 } + ] } }, + })) + .unwrap(); let collector = get_collector(agg_req_1); @@ -467,38 +374,25 @@ mod bench { let reader = index.reader().unwrap(); b.iter(|| { - let sub_agg_req: Aggregations = vec![( - "average_f64".to_string(), - Aggregation::Metric(MetricAggregation::Average( - AverageAggregation::from_field_name("score_f64".to_string()), - )), - )] - .into_iter() - .collect(); - - let agg_req_1: Aggregations = vec![( - "rangef64".to_string(), - Aggregation::Bucket( - BucketAggregation { - bucket_agg: 
BucketAggregationType::Range(RangeAggregation { - field: "score_f64".to_string(), - ranges: vec![ - (3f64..7000f64).into(), - (7000f64..20000f64).into(), - (20000f64..30000f64).into(), - (30000f64..40000f64).into(), - (40000f64..50000f64).into(), - (50000f64..60000f64).into(), - ], - ..Default::default() - }), - sub_aggregation: sub_agg_req, + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "rangef64": { + "range": { + "field": "score_f64", + "ranges": [ + { "from": 3, "to": 7000 }, + { "from": 7000, "to": 20000 }, + { "from": 20000, "to": 30000 }, + { "from": 30000, "to": 40000 }, + { "from": 40000, "to": 50000 }, + { "from": 50000, "to": 60000 } + ] + }, + "aggs": { + "average_f64": { "avg": { "field": "score_f64" } } } - .into(), - ), - )] - .into_iter() - .collect(); + }, + })) + .unwrap(); let collector = get_collector(agg_req_1); @@ -519,26 +413,10 @@ mod bench { let reader = index.reader().unwrap(); b.iter(|| { - let agg_req_1: Aggregations = vec![( - "rangef64".to_string(), - Aggregation::Bucket( - BucketAggregation { - bucket_agg: BucketAggregationType::Histogram(HistogramAggregation { - field: "score_f64".to_string(), - interval: 100f64, - hard_bounds: Some(HistogramBounds { - min: 1000.0, - max: 300_000.0, - }), - ..Default::default() - }), - sub_aggregation: Default::default(), - } - .into(), - ), - )] - .into_iter() - .collect(); + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "rangef64": { "histogram": { "field": "score_f64", "interval": 100, "hard_bounds": { "min": 1000, "max": 300000 } } }, + })) + .unwrap(); let collector = get_collector(agg_req_1); let searcher = reader.searcher(); @@ -553,31 +431,15 @@ mod bench { let reader = index.reader().unwrap(); b.iter(|| { - let sub_agg_req: Aggregations = vec![( - "average_f64".to_string(), - Aggregation::Metric(MetricAggregation::Average( - AverageAggregation::from_field_name("score_f64".to_string()), - )), - )] - .into_iter() - .collect(); - - let agg_req_1: Aggregations 
= vec![( - "rangef64".to_string(), - Aggregation::Bucket( - BucketAggregation { - bucket_agg: BucketAggregationType::Histogram(HistogramAggregation { - field: "score_f64".to_string(), - interval: 100f64, // 1000 buckets - ..Default::default() - }), - sub_aggregation: sub_agg_req, + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "rangef64": { + "histogram": { "field": "score_f64", "interval": 100 }, + "aggs": { + "average_f64": { "avg": { "field": "score_f64" } } } - .into(), - ), - )] - .into_iter() - .collect(); + } + })) + .unwrap(); let collector = get_collector(agg_req_1); @@ -593,22 +455,15 @@ mod bench { let reader = index.reader().unwrap(); b.iter(|| { - let agg_req_1: Aggregations = vec![( - "rangef64".to_string(), - Aggregation::Bucket( - BucketAggregation { - bucket_agg: BucketAggregationType::Histogram(HistogramAggregation { - field: "score_f64".to_string(), - interval: 100f64, // 1000 buckets - ..Default::default() - }), - sub_aggregation: Default::default(), - } - .into(), - ), - )] - .into_iter() - .collect(); + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "rangef64": { + "histogram": { + "field": "score_f64", + "interval": 100 // 1000 buckets + }, + } + })) + .unwrap(); let collector = get_collector(agg_req_1); @@ -630,43 +485,23 @@ mod bench { IndexRecordOption::Basic, ); - let sub_agg_req_1: Aggregations = vec![( - "average_in_range".to_string(), - Aggregation::Metric(MetricAggregation::Average( - AverageAggregation::from_field_name("score".to_string()), - )), - )] - .into_iter() - .collect(); - - let agg_req_1: Aggregations = vec![ - ( - "average".to_string(), - Aggregation::Metric(MetricAggregation::Average( - AverageAggregation::from_field_name("score".to_string()), - )), - ), - ( - "rangef64".to_string(), - Aggregation::Bucket( - BucketAggregation { - bucket_agg: BucketAggregationType::Range(RangeAggregation { - field: "score_f64".to_string(), - ranges: vec![ - (3f64..7000f64).into(), - 
(7000f64..20000f64).into(), - (20000f64..60000f64).into(), - ], - ..Default::default() - }), - sub_aggregation: sub_agg_req_1, - } - .into(), - ), - ), - ] - .into_iter() - .collect(); + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "rangef64": { + "range": { + "field": "score_f64", + "ranges": [ + { "from": 3, "to": 7000 }, + { "from": 7000, "to": 20000 }, + { "from": 20000, "to": 60000 } + ] + }, + "aggs": { + "average_in_range": { "avg": { "field": "score" } } + } + }, + "average": { "avg": { "field": "score" } } + })) + .unwrap(); let collector = get_collector(agg_req_1); diff --git a/src/aggregation/agg_req.rs b/src/aggregation/agg_req.rs index 8e3974062..1aa7fa6a5 100644 --- a/src/aggregation/agg_req.rs +++ b/src/aggregation/agg_req.rs @@ -37,7 +37,6 @@ use super::metric::{ AverageAggregation, CountAggregation, MaxAggregation, MinAggregation, PercentilesAggregationReq, StatsAggregation, SumAggregation, }; -use super::VecWithNames; /// The top-level aggregation request structure, which contains [`Aggregation`] and their user /// defined names. It is also used in [buckets](BucketAggregation) to define sub-aggregations. @@ -45,73 +44,30 @@ use super::VecWithNames; /// The key is the user defined name of the aggregation. pub type Aggregations = HashMap; -/// Like Aggregations, but optimized to work with the aggregation result -#[derive(Clone, Debug)] -pub(crate) struct AggregationsInternal { - pub(crate) metrics: VecWithNames, - pub(crate) buckets: VecWithNames, +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +/// Aggregation request. +/// +/// An aggregation is either a bucket or a metric. +pub struct Aggregation { + /// The aggregation variant, which can be either a bucket or a metric. + #[serde(flatten)] + pub agg: AggregationVariants, + /// The sub_aggregations, only valid for bucket type aggregations. Each bucket will aggregate + /// on the document set in the bucket. 
+ #[serde(rename = "aggs")] + #[serde(default)] + #[serde(skip_serializing_if = "Aggregations::is_empty")] + pub sub_aggregation: Aggregations, } -impl From for AggregationsInternal { - fn from(aggs: Aggregations) -> Self { - let mut metrics = vec![]; - let mut buckets = vec![]; - for (key, agg) in aggs { - match agg { - Aggregation::Bucket(bucket) => { - let sub_aggregation = bucket.get_sub_aggs().clone().into(); - buckets.push(( - key, - BucketAggregationInternal { - bucket_agg: bucket.bucket_agg, - sub_aggregation, - }, - )) - } - Aggregation::Metric(metric) => metrics.push((key, metric)), - } - } - Self { - metrics: VecWithNames::from_entries(metrics), - buckets: VecWithNames::from_entries(buckets), - } - } -} - -#[derive(Clone, Debug)] -// Like BucketAggregation, but optimized to work with the result -pub(crate) struct BucketAggregationInternal { - /// Bucket aggregation strategy to group documents. - pub bucket_agg: BucketAggregationType, - /// The sub_aggregations in the buckets. Each bucket will aggregate on the document set in the - /// bucket. 
- sub_aggregation: AggregationsInternal, -} - -impl BucketAggregationInternal { - pub(crate) fn sub_aggregation(&self) -> &AggregationsInternal { +impl Aggregation { + pub(crate) fn sub_aggregation(&self) -> &Aggregations { &self.sub_aggregation } - pub(crate) fn as_range(&self) -> Option<&RangeAggregation> { - match &self.bucket_agg { - BucketAggregationType::Range(range) => Some(range), - _ => None, - } - } - pub(crate) fn as_histogram(&self) -> crate::Result> { - match &self.bucket_agg { - BucketAggregationType::Histogram(histogram) => Ok(Some(histogram.clone())), - BucketAggregationType::DateHistogram(histogram) => { - Ok(Some(histogram.to_histogram_req()?)) - } - _ => Ok(None), - } - } - pub(crate) fn as_term(&self) -> Option<&TermsAggregation> { - match &self.bucket_agg { - BucketAggregationType::Terms(terms) => Some(terms), - _ => None, - } + + fn get_fast_field_names(&self, fast_field_names: &mut HashSet) { + fast_field_names.insert(self.agg.get_fast_field_name().to_string()); + fast_field_names.extend(get_fast_field_names(&self.sub_aggregation)); } } @@ -124,100 +80,24 @@ pub fn get_fast_field_names(aggs: &Aggregations) -> HashSet { fast_field_names } -/// Aggregation request of [`BucketAggregation`] or [`MetricAggregation`]. -/// -/// An aggregation is either a bucket or a metric. #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -#[serde(untagged)] -pub enum Aggregation { - /// Bucket aggregation, see [`BucketAggregation`] for details. - Bucket(Box), - /// Metric aggregation, see [`MetricAggregation`] for details. - Metric(MetricAggregation), -} - -impl Aggregation { - fn get_fast_field_names(&self, fast_field_names: &mut HashSet) { - match self { - Aggregation::Bucket(bucket) => bucket.get_fast_field_names(fast_field_names), - Aggregation::Metric(metric) => { - fast_field_names.insert(metric.get_fast_field_name().to_string()); - } - } - } -} - -/// BucketAggregations create buckets of documents. 
Each bucket is associated with a rule which -/// determines whether or not a document in the falls into it. In other words, the buckets -/// effectively define document sets. Buckets are not necessarily disjunct, therefore a document can -/// fall into multiple buckets. In addition to the buckets themselves, the bucket aggregations also -/// compute and return the number of documents for each bucket. Bucket aggregations, as opposed to -/// metric aggregations, can hold sub-aggregations. These sub-aggregations will be aggregated for -/// the buckets created by their "parent" bucket aggregation. There are different bucket -/// aggregators, each with a different "bucketing" strategy. Some define a single bucket, some -/// define fixed number of multiple buckets, and others dynamically create the buckets during the -/// aggregation process. -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct BucketAggregation { - /// Bucket aggregation strategy to group documents. - #[serde(flatten)] - pub bucket_agg: BucketAggregationType, - /// The sub_aggregations in the buckets. Each bucket will aggregate on the document set in the - /// bucket. - #[serde(rename = "aggs")] - #[serde(default)] - #[serde(skip_serializing_if = "Aggregations::is_empty")] - pub sub_aggregation: Aggregations, -} - -impl BucketAggregation { - pub(crate) fn get_sub_aggs(&self) -> &Aggregations { - &self.sub_aggregation - } - fn get_fast_field_names(&self, fast_field_names: &mut HashSet) { - let fast_field_name = self.bucket_agg.get_fast_field_name(); - fast_field_names.insert(fast_field_name.to_string()); - fast_field_names.extend(get_fast_field_names(&self.sub_aggregation)); - } -} - -/// The bucket aggregation types. -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub enum BucketAggregationType { +/// All aggregation types. +pub enum AggregationVariants { + // Bucket aggregation types /// Put data into buckets of user-defined ranges. 
#[serde(rename = "range")] Range(RangeAggregation), - /// Put data into buckets of user-defined ranges. + /// Put data into a histogram. #[serde(rename = "histogram")] Histogram(HistogramAggregation), - /// Put data into buckets of user-defined ranges. + /// Put data into a date histogram. #[serde(rename = "date_histogram")] DateHistogram(DateHistogramAggregationReq), /// Put data into buckets of terms. #[serde(rename = "terms")] Terms(TermsAggregation), -} -impl BucketAggregationType { - fn get_fast_field_name(&self) -> &str { - match self { - BucketAggregationType::Terms(terms) => terms.field.as_str(), - BucketAggregationType::Range(range) => range.field.as_str(), - BucketAggregationType::Histogram(histogram) => histogram.field.as_str(), - BucketAggregationType::DateHistogram(histogram) => histogram.field.as_str(), - } - } -} - -/// The aggregations in this family compute metrics based on values extracted -/// from the documents that are being aggregated. Values are extracted from the fast field of -/// the document. - -/// Some aggregations output a single numeric metric (e.g. Average) and are called -/// single-value numeric metrics aggregation, others generate multiple metrics (e.g. Stats) and are -/// called multi-value numeric metrics aggregation. -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub enum MetricAggregation { + // Metric aggregation types /// Computes the average of the extracted values. 
#[serde(rename = "avg")] Average(AverageAggregation), @@ -242,31 +122,102 @@ pub enum MetricAggregation { Percentiles(PercentilesAggregationReq), } -impl MetricAggregation { - pub(crate) fn as_percentile(&self) -> Option<&PercentilesAggregationReq> { +impl AggregationVariants { + fn get_fast_field_name(&self) -> &str { + match self { + AggregationVariants::Terms(terms) => terms.field.as_str(), + AggregationVariants::Range(range) => range.field.as_str(), + AggregationVariants::Histogram(histogram) => histogram.field.as_str(), + AggregationVariants::DateHistogram(histogram) => histogram.field.as_str(), + AggregationVariants::Average(avg) => avg.field_name(), + AggregationVariants::Count(count) => count.field_name(), + AggregationVariants::Max(max) => max.field_name(), + AggregationVariants::Min(min) => min.field_name(), + AggregationVariants::Stats(stats) => stats.field_name(), + AggregationVariants::Sum(sum) => sum.field_name(), + AggregationVariants::Percentiles(per) => per.field_name(), + } + } + + pub(crate) fn as_range(&self) -> Option<&RangeAggregation> { match &self { - MetricAggregation::Percentiles(percentile_req) => Some(percentile_req), + AggregationVariants::Range(range) => Some(range), + _ => None, + } + } + pub(crate) fn as_histogram(&self) -> crate::Result> { + match &self { + AggregationVariants::Histogram(histogram) => Ok(Some(histogram.clone())), + AggregationVariants::DateHistogram(histogram) => { + Ok(Some(histogram.to_histogram_req()?)) + } + _ => Ok(None), + } + } + pub(crate) fn as_term(&self) -> Option<&TermsAggregation> { + match &self { + AggregationVariants::Terms(terms) => Some(terms), _ => None, } } - fn get_fast_field_name(&self) -> &str { - match self { - MetricAggregation::Average(avg) => avg.field_name(), - MetricAggregation::Count(count) => count.field_name(), - MetricAggregation::Max(max) => max.field_name(), - MetricAggregation::Min(min) => min.field_name(), - MetricAggregation::Stats(stats) => stats.field_name(), - 
MetricAggregation::Sum(sum) => sum.field_name(), - MetricAggregation::Percentiles(per) => per.field_name(), + pub(crate) fn as_percentile(&self) -> Option<&PercentilesAggregationReq> { + match &self { + AggregationVariants::Percentiles(percentile_req) => Some(percentile_req), + _ => None, } } } #[cfg(test)] mod tests { + use super::*; + #[test] + fn deser_json_test() { + let agg_req_json = r#"{ + "price_avg": { "avg": { "field": "price" } }, + "price_count": { "value_count": { "field": "price" } }, + "price_max": { "max": { "field": "price" } }, + "price_min": { "min": { "field": "price" } }, + "price_stats": { "stats": { "field": "price" } }, + "price_sum": { "sum": { "field": "price" } } + }"#; + let _agg_req: Aggregations = serde_json::from_str(agg_req_json).unwrap(); + } + + #[test] + fn deser_json_test_bucket() { + let agg_req_json = r#" + { + "termagg": { + "terms": { + "field": "json.mixed_type", + "order": { "min_price": "desc" } + }, + "aggs": { + "min_price": { "min": { "field": "json.mixed_type" } } + } + }, + "rangeagg": { + "range": { + "field": "json.mixed_type", + "ranges": [ + { "to": 3.0 }, + { "from": 19.0, "to": 20.0 }, + { "from": 20.0 } + ] + }, + "aggs": { + "average_in_range": { "avg": { "field": "json.mixed_type" } } + } + } + } "#; + + let _agg_req: Aggregations = serde_json::from_str(agg_req_json).unwrap(); + } + #[test] fn test_metric_aggregations_deser() { let agg_req_json = r#"{ @@ -280,22 +231,22 @@ mod tests { let agg_req: Aggregations = serde_json::from_str(agg_req_json).unwrap(); assert!( - matches!(agg_req.get("price_avg").unwrap(), Aggregation::Metric(MetricAggregation::Average(avg)) if avg.field == "price") + matches!(&agg_req.get("price_avg").unwrap().agg, AggregationVariants::Average(avg) if avg.field == "price") ); assert!( - matches!(agg_req.get("price_count").unwrap(), Aggregation::Metric(MetricAggregation::Count(count)) if count.field == "price") + matches!(&agg_req.get("price_count").unwrap().agg, 
AggregationVariants::Count(count) if count.field == "price") ); assert!( - matches!(agg_req.get("price_max").unwrap(), Aggregation::Metric(MetricAggregation::Max(max)) if max.field == "price") + matches!(&agg_req.get("price_max").unwrap().agg, AggregationVariants::Max(max) if max.field == "price") ); assert!( - matches!(agg_req.get("price_min").unwrap(), Aggregation::Metric(MetricAggregation::Min(min)) if min.field == "price") + matches!(&agg_req.get("price_min").unwrap().agg, AggregationVariants::Min(min) if min.field == "price") ); assert!( - matches!(agg_req.get("price_stats").unwrap(), Aggregation::Metric(MetricAggregation::Stats(stats)) if stats.field == "price") + matches!(&agg_req.get("price_stats").unwrap().agg, AggregationVariants::Stats(stats) if stats.field == "price") ); assert!( - matches!(agg_req.get("price_sum").unwrap(), Aggregation::Metric(MetricAggregation::Sum(sum)) if sum.field == "price") + matches!(&agg_req.get("price_sum").unwrap().agg, AggregationVariants::Sum(sum) if sum.field == "price") ); } diff --git a/src/aggregation/agg_req_with_accessor.rs b/src/aggregation/agg_req_with_accessor.rs index 7b26e2f32..1ac2b7da6 100644 --- a/src/aggregation/agg_req_with_accessor.rs +++ b/src/aggregation/agg_req_with_accessor.rs @@ -2,7 +2,7 @@ use columnar::{Column, ColumnBlockAccessor, ColumnType, StrColumn}; -use super::agg_req::{Aggregation, Aggregations, BucketAggregationType, MetricAggregation}; +use super::agg_req::{Aggregation, AggregationVariants, Aggregations}; use super::bucket::{ DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation, }; @@ -16,34 +16,107 @@ use crate::SegmentReader; #[derive(Clone, Default)] pub(crate) struct AggregationsWithAccessor { - pub metrics: VecWithNames, - pub buckets: VecWithNames, + pub aggs: VecWithNames, } impl AggregationsWithAccessor { - fn from_data( - metrics: VecWithNames, - buckets: VecWithNames, - ) -> Self { - Self { metrics, buckets } + fn from_data(aggs: VecWithNames) -> 
Self { + Self { aggs } } pub fn is_empty(&self) -> bool { - self.metrics.is_empty() && self.buckets.is_empty() + self.aggs.is_empty() } } #[derive(Clone)] -pub struct BucketAggregationWithAccessor { +pub struct AggregationWithAccessor { /// In general there can be buckets without fast field access, e.g. buckets that are created /// based on search terms. So eventually this needs to be Option or moved. pub(crate) accessor: Column, pub(crate) str_dict_column: Option, pub(crate) field_type: ColumnType, - pub(crate) bucket_agg: BucketAggregationType, pub(crate) sub_aggregation: AggregationsWithAccessor, pub(crate) limits: AggregationLimits, pub(crate) column_block_accessor: ColumnBlockAccessor, + pub(crate) agg: Aggregation, +} + +impl AggregationWithAccessor { + fn try_from_agg( + agg: &Aggregation, + sub_aggregation: &Aggregations, + reader: &SegmentReader, + limits: AggregationLimits, + ) -> crate::Result { + let mut str_dict_column = None; + use AggregationVariants::*; + let (accessor, field_type) = match &agg.agg { + Range(RangeAggregation { + field: field_name, .. + }) => get_ff_reader_and_validate( + reader, + field_name, + Some(get_numeric_or_date_column_types()), + )?, + Histogram(HistogramAggregation { + field: field_name, .. + }) => get_ff_reader_and_validate( + reader, + field_name, + Some(get_numeric_or_date_column_types()), + )?, + DateHistogram(DateHistogramAggregationReq { + field: field_name, .. + }) => get_ff_reader_and_validate( + reader, + field_name, + Some(get_numeric_or_date_column_types()), + )?, + Terms(TermsAggregation { + field: field_name, .. + }) => { + str_dict_column = reader.fast_fields().str(field_name)?; + get_ff_reader_and_validate(reader, field_name, None)? 
+ } + Average(AverageAggregation { field: field_name }) + | Count(CountAggregation { field: field_name }) + | Max(MaxAggregation { field: field_name }) + | Min(MinAggregation { field: field_name }) + | Stats(StatsAggregation { field: field_name }) + | Sum(SumAggregation { field: field_name }) => { + let (accessor, field_type) = get_ff_reader_and_validate( + reader, + field_name, + Some(get_numeric_or_date_column_types()), + )?; + + (accessor, field_type) + } + Percentiles(percentiles) => { + let (accessor, field_type) = get_ff_reader_and_validate( + reader, + percentiles.field_name(), + Some(get_numeric_or_date_column_types()), + )?; + (accessor, field_type) + } + }; + let sub_aggregation = sub_aggregation.clone(); + Ok(AggregationWithAccessor { + accessor, + field_type, + sub_aggregation: get_aggs_with_accessor_and_validate( + &sub_aggregation, + reader, + &limits.clone(), + )?, + agg: agg.clone(), + str_dict_column, + limits, + column_block_accessor: Default::default(), + }) + } } fn get_numeric_or_date_column_types() -> &'static [ColumnType] { @@ -55,140 +128,25 @@ fn get_numeric_or_date_column_types() -> &'static [ColumnType] { ] } -impl BucketAggregationWithAccessor { - fn try_from_bucket( - bucket: &BucketAggregationType, - sub_aggregation: &Aggregations, - reader: &SegmentReader, - limits: AggregationLimits, - ) -> crate::Result { - let mut str_dict_column = None; - let (accessor, field_type) = match &bucket { - BucketAggregationType::Range(RangeAggregation { - field: field_name, .. - }) => get_ff_reader_and_validate( - reader, - field_name, - Some(get_numeric_or_date_column_types()), - )?, - BucketAggregationType::Histogram(HistogramAggregation { - field: field_name, .. - }) => get_ff_reader_and_validate( - reader, - field_name, - Some(get_numeric_or_date_column_types()), - )?, - BucketAggregationType::DateHistogram(DateHistogramAggregationReq { - field: field_name, - .. 
- }) => get_ff_reader_and_validate( - reader, - field_name, - Some(get_numeric_or_date_column_types()), - )?, - BucketAggregationType::Terms(TermsAggregation { - field: field_name, .. - }) => { - str_dict_column = reader.fast_fields().str(field_name)?; - get_ff_reader_and_validate(reader, field_name, None)? - } - }; - let sub_aggregation = sub_aggregation.clone(); - Ok(BucketAggregationWithAccessor { - accessor, - field_type, - sub_aggregation: get_aggs_with_accessor_and_validate( - &sub_aggregation, - reader, - &limits.clone(), - )?, - bucket_agg: bucket.clone(), - str_dict_column, - limits, - column_block_accessor: Default::default(), - }) - } -} - -/// Contains the metric request and the fast field accessor. -#[derive(Clone)] -pub struct MetricAggregationWithAccessor { - pub metric: MetricAggregation, - pub field_type: ColumnType, - pub accessor: Column, - pub column_block_accessor: ColumnBlockAccessor, -} - -impl MetricAggregationWithAccessor { - fn try_from_metric( - metric: &MetricAggregation, - reader: &SegmentReader, - ) -> crate::Result { - match &metric { - MetricAggregation::Average(AverageAggregation { field: field_name }) - | MetricAggregation::Count(CountAggregation { field: field_name }) - | MetricAggregation::Max(MaxAggregation { field: field_name }) - | MetricAggregation::Min(MinAggregation { field: field_name }) - | MetricAggregation::Stats(StatsAggregation { field: field_name }) - | MetricAggregation::Sum(SumAggregation { field: field_name }) => { - let (accessor, field_type) = get_ff_reader_and_validate( - reader, - field_name, - Some(get_numeric_or_date_column_types()), - )?; - - Ok(MetricAggregationWithAccessor { - accessor, - field_type, - metric: metric.clone(), - column_block_accessor: Default::default(), - }) - } - MetricAggregation::Percentiles(percentiles) => { - let (accessor, field_type) = get_ff_reader_and_validate( - reader, - percentiles.field_name(), - Some(get_numeric_or_date_column_types()), - )?; - - 
Ok(MetricAggregationWithAccessor { - accessor, - field_type, - metric: metric.clone(), - column_block_accessor: Default::default(), - }) - } - } - } -} - pub(crate) fn get_aggs_with_accessor_and_validate( aggs: &Aggregations, reader: &SegmentReader, limits: &AggregationLimits, ) -> crate::Result { - let mut metrics = vec![]; - let mut buckets = vec![]; + let mut aggss = Vec::new(); for (key, agg) in aggs.iter() { - match agg { - Aggregation::Bucket(bucket) => buckets.push(( - key.to_string(), - BucketAggregationWithAccessor::try_from_bucket( - &bucket.bucket_agg, - bucket.get_sub_aggs(), - reader, - limits.clone(), - )?, - )), - Aggregation::Metric(metric) => metrics.push(( - key.to_string(), - MetricAggregationWithAccessor::try_from_metric(metric, reader)?, - )), - } + aggss.push(( + key.to_string(), + AggregationWithAccessor::try_from_agg( + agg, + agg.sub_aggregation(), + reader, + limits.clone(), + )?, + )); } Ok(AggregationsWithAccessor::from_data( - VecWithNames::from_entries(metrics), - VecWithNames::from_entries(buckets), + VecWithNames::from_entries(aggss), )) } diff --git a/src/aggregation/agg_result.rs b/src/aggregation/agg_result.rs index 09068f51d..bb95858ba 100644 --- a/src/aggregation/agg_result.rs +++ b/src/aggregation/agg_result.rs @@ -7,11 +7,8 @@ use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; -use super::agg_req::BucketAggregationInternal; use super::bucket::GetDocCount; -use super::intermediate_agg_result::IntermediateBucketResult; use super::metric::{PercentilesMetricResult, SingleMetricResult, Stats}; -use super::segment_agg_result::AggregationLimits; use super::{AggregationError, Key}; use crate::TantivyError; @@ -164,14 +161,6 @@ impl BucketResult { } => buckets.iter().map(|bucket| bucket.get_bucket_count()).sum(), } } - - pub(crate) fn empty_from_req( - req: &BucketAggregationInternal, - limits: &AggregationLimits, - ) -> crate::Result { - let empty_bucket = IntermediateBucketResult::empty_from_req(&req.bucket_agg); - 
empty_bucket.into_final_bucket_result(req, limits) - } } /// This is the wrapper of buckets entries, which can be vector or hashmap diff --git a/src/aggregation/agg_tests.rs b/src/aggregation/agg_tests.rs index 2da0ce36d..ad1dc2494 100644 --- a/src/aggregation/agg_tests.rs +++ b/src/aggregation/agg_tests.rs @@ -1,11 +1,10 @@ use serde_json::Value; -use crate::aggregation::agg_req::{Aggregation, Aggregations, MetricAggregation}; +use crate::aggregation::agg_req::{Aggregation, Aggregations}; use crate::aggregation::agg_result::AggregationResults; use crate::aggregation::buf_collector::DOC_BLOCK_SIZE; use crate::aggregation::collector::AggregationCollector; use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults; -use crate::aggregation::metric::AverageAggregation; use crate::aggregation::segment_agg_result::AggregationLimits; use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms}; use crate::aggregation::DistributedAggregationCollector; @@ -14,9 +13,12 @@ use crate::schema::{IndexRecordOption, Schema, FAST}; use crate::{Index, Term}; fn get_avg_req(field_name: &str) -> Aggregation { - Aggregation::Metric(MetricAggregation::Average( - AverageAggregation::from_field_name(field_name.to_string()), - )) + serde_json::from_value(json!({ + "avg": { + "field": field_name, + } + })) + .unwrap() } fn get_collector(agg_req: Aggregations) -> AggregationCollector { @@ -195,6 +197,74 @@ fn test_aggregation_flushing_variants() { test_aggregation_flushing(true, true).unwrap(); } +#[test] +fn test_aggregation_level1_simple() -> crate::Result<()> { + let index = get_test_index_2_segments(true)?; + + let reader = index.reader()?; + let text_field = reader.searcher().schema().get_field("text").unwrap(); + + let term_query = TermQuery::new( + Term::from_field_text(text_field, "cool"), + IndexRecordOption::Basic, + ); + + let range_agg = |field_name: &str| -> Aggregation { + serde_json::from_value(json!({ + "range": { + 
"field": field_name, + "ranges": [ { "from": 3.0f64, "to": 7.0f64 }, { "from": 7.0f64, "to": 20.0f64 } ] + } + })) + .unwrap() + }; + + let agg_req_1: Aggregations = vec![ + ("average".to_string(), get_avg_req("score")), + ("range".to_string(), range_agg("score")), + ] + .into_iter() + .collect(); + + let collector = get_collector(agg_req_1); + + let searcher = reader.searcher(); + let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap(); + + let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?; + assert_eq!(res["average"]["value"], 12.142857142857142); + assert_eq!( + res["range"]["buckets"], + json!( + [ + { + "key": "*-3", + "doc_count": 1, + "to": 3.0 + }, + { + "key": "3-7", + "doc_count": 2, + "from": 3.0, + "to": 7.0 + }, + { + "key": "7-20", + "doc_count": 3, + "from": 7.0, + "to": 20.0 + }, + { + "key": "20-*", + "doc_count": 1, + "from": 20.0 + } + ]) + ); + + Ok(()) +} + #[test] fn test_aggregation_level1() -> crate::Result<()> { let index = get_test_index_2_segments(true)?; @@ -449,14 +519,14 @@ fn test_aggregation_invalid_requests() -> crate::Result<()> { let reader = index.reader()?; let avg_on_field = |field_name: &str| { - let agg_req_1: Aggregations = vec![( - "average".to_string(), - Aggregation::Metric(MetricAggregation::Average( - AverageAggregation::from_field_name(field_name.to_string()), - )), - )] - .into_iter() - .collect(); + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "average": { + "avg": { + "field": field_name, + }, + } + })) + .unwrap(); let collector = get_collector(agg_req_1); @@ -471,6 +541,32 @@ fn test_aggregation_invalid_requests() -> crate::Result<()> { r#"InvalidArgument("Field \"dummy_text\" is not configured as fast field")"# ); + let agg_req_1: Result = serde_json::from_value(json!({ + "average": { + "avg": { + "fieldd": "a", + }, + } + })); + + assert_eq!(agg_req_1.is_err(), true); + assert_eq!(agg_req_1.unwrap_err().to_string(), "missing field `field`"); + 
+ let agg_req_1: Result = serde_json::from_value(json!({ + "average": { + "doesnotmatchanyagg": { + "field": "a", + }, + } + })); + + assert_eq!(agg_req_1.is_err(), true); + // TODO: This should list valid values + assert_eq!( + agg_req_1.unwrap_err().to_string(), + "no variant of enum AggregationVariants found in flattened data" + ); + // TODO: This should return an error // let agg_res = avg_on_field("not_exist_field").unwrap_err(); // assert_eq!( diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs index 4056bcb4e..7d6112f59 100644 --- a/src/aggregation/bucket/histogram/histogram.rs +++ b/src/aggregation/bucket/histogram/histogram.rs @@ -8,18 +8,19 @@ use serde::{Deserialize, Serialize}; use tantivy_bitpacker::minmax; use crate::aggregation::agg_limits::MemoryConsumption; -use crate::aggregation::agg_req::AggregationsInternal; +use crate::aggregation::agg_req::Aggregations; use crate::aggregation::agg_req_with_accessor::{ - AggregationsWithAccessor, BucketAggregationWithAccessor, + AggregationWithAccessor, AggregationsWithAccessor, }; use crate::aggregation::agg_result::BucketEntry; use crate::aggregation::intermediate_agg_result::{ - IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry, + IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, + IntermediateHistogramBucketEntry, }; use crate::aggregation::segment_agg_result::{ build_segment_agg_collector, AggregationLimits, SegmentAggregationCollector, }; -use crate::aggregation::{f64_from_fastfield_u64, format_date, VecWithNames}; +use crate::aggregation::{f64_from_fastfield_u64, format_date}; use crate::TantivyError; /// Histogram is a bucket aggregation, where buckets are created dynamically for given `interval`. 
@@ -190,11 +191,13 @@ impl SegmentHistogramBucketEntry { sub_aggregation: Box, agg_with_accessor: &AggregationsWithAccessor, ) -> crate::Result { + let mut sub_aggregation_res = IntermediateAggregationResults::default(); + sub_aggregation + .add_intermediate_aggregation_result(agg_with_accessor, &mut sub_aggregation_res)?; Ok(IntermediateHistogramBucketEntry { key: self.key, doc_count: self.doc_count, - sub_aggregation: sub_aggregation - .into_intermediate_aggregations_result(agg_with_accessor)?, + sub_aggregation: sub_aggregation_res, }) } } @@ -215,20 +218,18 @@ pub struct SegmentHistogramCollector { } impl SegmentAggregationCollector for SegmentHistogramCollector { - fn into_intermediate_aggregations_result( + fn add_intermediate_aggregation_result( self: Box, agg_with_accessor: &AggregationsWithAccessor, - ) -> crate::Result { - let name = agg_with_accessor.buckets.keys[self.accessor_idx].to_string(); - let agg_with_accessor = &agg_with_accessor.buckets.values[self.accessor_idx]; + results: &mut IntermediateAggregationResults, + ) -> crate::Result<()> { + let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string(); + let agg_with_accessor = &agg_with_accessor.aggs.values[self.accessor_idx]; let bucket = self.into_intermediate_bucket_result(agg_with_accessor)?; - let buckets = Some(VecWithNames::from_entries(vec![(name, bucket)])); + results.push(name, IntermediateAggregationResult::Bucket(bucket)); - Ok(IntermediateAggregationResults { - metrics: None, - buckets, - }) + Ok(()) } #[inline] @@ -246,7 +247,7 @@ impl SegmentAggregationCollector for SegmentHistogramCollector { docs: &[crate::DocId], agg_with_accessor: &mut AggregationsWithAccessor, ) -> crate::Result<()> { - let bucket_agg_accessor = &mut agg_with_accessor.buckets.values[self.accessor_idx]; + let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx]; let mem_pre = self.get_memory_consumption(); @@ -280,7 +281,7 @@ impl SegmentAggregationCollector for 
SegmentHistogramCollector { } let mem_delta = self.get_memory_consumption() - mem_pre; - let limits = &agg_with_accessor.buckets.values[self.accessor_idx].limits; + let limits = &agg_with_accessor.aggs.values[self.accessor_idx].limits; limits.add_memory_consumed(mem_delta as u64); limits.validate_memory_consumption()?; @@ -289,7 +290,7 @@ impl SegmentAggregationCollector for SegmentHistogramCollector { fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> { let sub_aggregation_accessor = - &mut agg_with_accessor.buckets.values[self.accessor_idx].sub_aggregation; + &mut agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation; for sub_aggregation in self.sub_aggregations.values_mut() { sub_aggregation.flush(sub_aggregation_accessor)?; @@ -308,7 +309,7 @@ impl SegmentHistogramCollector { } pub fn into_intermediate_bucket_result( self, - agg_with_accessor: &BucketAggregationWithAccessor, + agg_with_accessor: &AggregationWithAccessor, ) -> crate::Result { let mut buckets = Vec::with_capacity(self.buckets.len()); @@ -384,7 +385,7 @@ fn get_bucket_key_from_pos(bucket_pos: f64, interval: f64, offset: f64) -> f64 { fn intermediate_buckets_to_final_buckets_fill_gaps( buckets: Vec, histogram_req: &HistogramAggregation, - sub_aggregation: &AggregationsInternal, + sub_aggregation: &Aggregations, limits: &AggregationLimits, ) -> crate::Result> { // Generate the full list of buckets without gaps. @@ -443,7 +444,7 @@ pub(crate) fn intermediate_histogram_buckets_to_final_buckets( buckets: Vec, column_type: Option, histogram_req: &HistogramAggregation, - sub_aggregation: &AggregationsInternal, + sub_aggregation: &Aggregations, limits: &AggregationLimits, ) -> crate::Result> { let mut buckets = if histogram_req.min_doc_count() == 0 { @@ -695,7 +696,7 @@ mod tests { assert_eq!( res.to_string(), "Aborting aggregation because memory limit was exceeded. 
Limit: 5.00 KB, Current: \ - 102.48 KB" + 59.71 KB" ); Ok(()) diff --git a/src/aggregation/bucket/mod.rs b/src/aggregation/bucket/mod.rs index 3d178d962..3ccc53e97 100644 --- a/src/aggregation/bucket/mod.rs +++ b/src/aggregation/bucket/mod.rs @@ -2,6 +2,16 @@ //! //! BucketAggregations create buckets of documents //! [`BucketAggregation`](super::agg_req::BucketAggregation). +//! Each bucket is associated with a rule which +//! determines whether or not a document falls into it. In other words, the buckets +//! effectively define document sets. Buckets are not necessarily disjunct, therefore a document can +//! fall into multiple buckets. In addition to the buckets themselves, the bucket aggregations also +//! compute and return the number of documents for each bucket. Bucket aggregations, as opposed to +//! metric aggregations, can hold sub-aggregations. These sub-aggregations will be aggregated for +//! the buckets created by their "parent" bucket aggregation. There are different bucket +//! aggregators, each with a different "bucketing" strategy. Some define a single bucket, some +//! define a fixed number of multiple buckets, and others dynamically create the buckets during the +//! aggregation process. //! //! Results of final buckets are [`BucketResult`](super::agg_result::BucketResult). //!
Results of intermediate buckets are diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs index 591db3831..82c4cbddc 100644 --- a/src/aggregation/bucket/range.rs +++ b/src/aggregation/bucket/range.rs @@ -7,14 +7,14 @@ use serde::{Deserialize, Serialize}; use crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor; use crate::aggregation::intermediate_agg_result::{ - IntermediateAggregationResults, IntermediateBucketResult, IntermediateRangeBucketEntry, - IntermediateRangeBucketResult, + IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, + IntermediateRangeBucketEntry, IntermediateRangeBucketResult, }; use crate::aggregation::segment_agg_result::{ build_segment_agg_collector, AggregationLimits, SegmentAggregationCollector, }; use crate::aggregation::{ - f64_from_fastfield_u64, f64_to_fastfield_u64, format_date, Key, SerializedKey, VecWithNames, + f64_from_fastfield_u64, f64_to_fastfield_u64, format_date, Key, SerializedKey, }; use crate::TantivyError; @@ -157,8 +157,10 @@ impl SegmentRangeBucketEntry { self, agg_with_accessor: &AggregationsWithAccessor, ) -> crate::Result { - let sub_aggregation = if let Some(sub_aggregation) = self.sub_aggregation { - sub_aggregation.into_intermediate_aggregations_result(agg_with_accessor)? + let mut sub_aggregation_res = IntermediateAggregationResults::default(); + if let Some(sub_aggregation) = self.sub_aggregation { + sub_aggregation + .add_intermediate_aggregation_result(agg_with_accessor, &mut sub_aggregation_res)? 
} else { Default::default() }; @@ -166,7 +168,7 @@ impl SegmentRangeBucketEntry { Ok(IntermediateRangeBucketEntry { key: self.key, doc_count: self.doc_count, - sub_aggregation, + sub_aggregation: sub_aggregation_res, from: self.from, to: self.to, }) @@ -174,13 +176,14 @@ impl SegmentRangeBucketEntry { } impl SegmentAggregationCollector for SegmentRangeCollector { - fn into_intermediate_aggregations_result( + fn add_intermediate_aggregation_result( self: Box, agg_with_accessor: &AggregationsWithAccessor, - ) -> crate::Result { + results: &mut IntermediateAggregationResults, + ) -> crate::Result<()> { let field_type = self.column_type; - let name = agg_with_accessor.buckets.keys[self.accessor_idx].to_string(); - let sub_agg = &agg_with_accessor.buckets.values[self.accessor_idx].sub_aggregation; + let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string(); + let sub_agg = &agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation; let buckets: FxHashMap = self .buckets @@ -200,12 +203,9 @@ impl SegmentAggregationCollector for SegmentRangeCollector { column_type: Some(self.column_type), }); - let buckets = Some(VecWithNames::from_entries(vec![(name, bucket)])); + results.push(name, IntermediateAggregationResult::Bucket(bucket)); - Ok(IntermediateAggregationResults { - metrics: None, - buckets, - }) + Ok(()) } #[inline] @@ -223,7 +223,7 @@ impl SegmentAggregationCollector for SegmentRangeCollector { docs: &[crate::DocId], agg_with_accessor: &mut AggregationsWithAccessor, ) -> crate::Result<()> { - let bucket_agg_accessor = &mut agg_with_accessor.buckets.values[self.accessor_idx]; + let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx]; bucket_agg_accessor .column_block_accessor @@ -245,7 +245,7 @@ impl SegmentAggregationCollector for SegmentRangeCollector { fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> { let sub_aggregation_accessor = - &mut 
agg_with_accessor.buckets.values[self.accessor_idx].sub_aggregation; + &mut agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation; for bucket in self.buckets.iter_mut() { if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() { diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index 329ea120b..0a86ecd67 100644 --- a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg.rs @@ -7,16 +7,16 @@ use serde::{Deserialize, Serialize}; use super::{CustomOrder, Order, OrderTarget}; use crate::aggregation::agg_limits::MemoryConsumption; use crate::aggregation::agg_req_with_accessor::{ - AggregationsWithAccessor, BucketAggregationWithAccessor, + AggregationWithAccessor, AggregationsWithAccessor, }; use crate::aggregation::intermediate_agg_result::{ - IntermediateAggregationResults, IntermediateBucketResult, IntermediateTermBucketEntry, - IntermediateTermBucketResult, + IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, + IntermediateTermBucketEntry, IntermediateTermBucketResult, }; use crate::aggregation::segment_agg_result::{ build_segment_agg_collector, SegmentAggregationCollector, }; -use crate::aggregation::{f64_from_fastfield_u64, Key, VecWithNames}; +use crate::aggregation::{f64_from_fastfield_u64, Key}; use crate::error::DataCorruption; use crate::TantivyError; @@ -246,20 +246,18 @@ pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) { } impl SegmentAggregationCollector for SegmentTermCollector { - fn into_intermediate_aggregations_result( + fn add_intermediate_aggregation_result( self: Box, agg_with_accessor: &AggregationsWithAccessor, - ) -> crate::Result { - let name = agg_with_accessor.buckets.keys[self.accessor_idx].to_string(); - let agg_with_accessor = &agg_with_accessor.buckets.values[self.accessor_idx]; + results: &mut IntermediateAggregationResults, + ) -> crate::Result<()> { + let name = 
agg_with_accessor.aggs.keys[self.accessor_idx].to_string(); + let agg_with_accessor = &agg_with_accessor.aggs.values[self.accessor_idx]; let bucket = self.into_intermediate_bucket_result(agg_with_accessor)?; - let buckets = Some(VecWithNames::from_entries(vec![(name, bucket)])); + results.push(name, IntermediateAggregationResult::Bucket(bucket)); - Ok(IntermediateAggregationResults { - metrics: None, - buckets, - }) + Ok(()) } #[inline] @@ -277,7 +275,7 @@ impl SegmentAggregationCollector for SegmentTermCollector { docs: &[crate::DocId], agg_with_accessor: &mut AggregationsWithAccessor, ) -> crate::Result<()> { - let bucket_agg_accessor = &mut agg_with_accessor.buckets.values[self.accessor_idx]; + let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx]; let mem_pre = self.get_memory_consumption(); @@ -301,7 +299,7 @@ impl SegmentAggregationCollector for SegmentTermCollector { } let mem_delta = self.get_memory_consumption() - mem_pre; - let limits = &agg_with_accessor.buckets.values[self.accessor_idx].limits; + let limits = &agg_with_accessor.aggs.values[self.accessor_idx].limits; limits.add_memory_consumed(mem_delta as u64); limits.validate_memory_consumption()?; @@ -310,7 +308,7 @@ impl SegmentAggregationCollector for SegmentTermCollector { fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> { let sub_aggregation_accessor = - &mut agg_with_accessor.buckets.values[self.accessor_idx].sub_aggregation; + &mut agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation; self.term_buckets.force_flush(sub_aggregation_accessor)?; Ok(()) @@ -337,7 +335,7 @@ impl SegmentTermCollector { if let OrderTarget::SubAggregation(sub_agg_name) = &custom_order.target { let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name); - sub_aggregations.metrics.get(agg_name).ok_or_else(|| { + sub_aggregations.aggs.get(agg_name).ok_or_else(|| { TantivyError::InvalidArgument(format!( "could not find aggregation 
with name {} in metric sub_aggregations", agg_name @@ -366,7 +364,7 @@ impl SegmentTermCollector { #[inline] pub(crate) fn into_intermediate_bucket_result( mut self, - agg_with_accessor: &BucketAggregationWithAccessor, + agg_with_accessor: &AggregationWithAccessor, ) -> crate::Result { let mut entries: Vec<(u64, u64)> = self.term_buckets.entries.into_iter().collect(); @@ -410,21 +408,24 @@ impl SegmentTermCollector { let mut into_intermediate_bucket_entry = |id, doc_count| -> crate::Result { let intermediate_entry = if self.blueprint.as_ref().is_some() { + let mut sub_aggregation_res = IntermediateAggregationResults::default(); + self.term_buckets + .sub_aggs + .remove(&id) + .unwrap_or_else(|| { + panic!( + "Internal Error: could not find subaggregation for id {}", + id + ) + }) + .add_intermediate_aggregation_result( + &agg_with_accessor.sub_aggregation, + &mut sub_aggregation_res, + )?; + IntermediateTermBucketEntry { doc_count, - sub_aggregation: self - .term_buckets - .sub_aggs - .remove(&id) - .unwrap_or_else(|| { - panic!( - "Internal Error: could not find subaggregation for id {}", - id - ) - }) - .into_intermediate_aggregations_result( - &agg_with_accessor.sub_aggregation, - )?, + sub_aggregation: sub_aggregation_res, } } else { IntermediateTermBucketEntry { @@ -522,8 +523,7 @@ pub(crate) fn cut_off_buckets( #[cfg(test)] mod tests { - use crate::aggregation::agg_req::{Aggregation, Aggregations, MetricAggregation}; - use crate::aggregation::metric::{AverageAggregation, StatsAggregation}; + use crate::aggregation::agg_req::Aggregations; use crate::aggregation::tests::{ exec_request, exec_request_with_query, exec_request_with_query_and_memory_limit, get_test_index_from_terms, get_test_index_from_values_and_terms, @@ -638,23 +638,6 @@ mod tests { ]; let index = get_test_index_from_values_and_terms(merge_segments, &segment_and_terms)?; - let _sub_agg: Aggregations = vec![ - ( - "avg_score".to_string(), - Aggregation::Metric(MetricAggregation::Average( - 
AverageAggregation::from_field_name("score".to_string()), - )), - ), - ( - "stats_score".to_string(), - Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name( - "score".to_string(), - ))), - ), - ] - .into_iter() - .collect(); - let sub_agg: Aggregations = serde_json::from_value(json!({ "avg_score": { "avg": { diff --git a/src/aggregation/buf_collector.rs b/src/aggregation/buf_collector.rs index d8ec399bc..15be6281b 100644 --- a/src/aggregation/buf_collector.rs +++ b/src/aggregation/buf_collector.rs @@ -35,11 +35,12 @@ impl BufAggregationCollector { impl SegmentAggregationCollector for BufAggregationCollector { #[inline] - fn into_intermediate_aggregations_result( + fn add_intermediate_aggregation_result( self: Box, agg_with_accessor: &AggregationsWithAccessor, - ) -> crate::Result { - Box::new(self.collector).into_intermediate_aggregations_result(agg_with_accessor) + results: &mut IntermediateAggregationResults, + ) -> crate::Result<()> { + Box::new(self.collector).add_intermediate_aggregation_result(agg_with_accessor, results) } #[inline] diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs index dc80a7e53..183cc2425 100644 --- a/src/aggregation/collector.rs +++ b/src/aggregation/collector.rs @@ -184,6 +184,13 @@ impl SegmentCollector for AggregationSegmentCollector { return Err(err); } self.agg_collector.flush(&mut self.aggs_with_accessor)?; - Box::new(self.agg_collector).into_intermediate_aggregations_result(&self.aggs_with_accessor) + + let mut sub_aggregation_res = IntermediateAggregationResults::default(); + Box::new(self.agg_collector).add_intermediate_aggregation_result( + &self.aggs_with_accessor, + &mut sub_aggregation_res, + )?; + + Ok(sub_aggregation_res) } } diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index 182fc90ab..8956d8af9 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -10,10 +10,7 @@ 
use rustc_hash::FxHashMap; use serde::ser::SerializeSeq; use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use super::agg_req::{ - Aggregations, AggregationsInternal, BucketAggregationInternal, BucketAggregationType, - MetricAggregation, -}; +use super::agg_req::{Aggregation, AggregationVariants, Aggregations}; use super::agg_result::{AggregationResult, BucketResult, MetricResult, RangeBucketEntry}; use super::bucket::{ cut_off_buckets, get_agg_name_and_property, intermediate_histogram_buckets_to_final_buckets, @@ -34,20 +31,22 @@ use crate::TantivyError; /// intermediate results. #[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct IntermediateAggregationResults { - #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) metrics: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) buckets: Option>, + pub(crate) aggs_res: VecWithNames, } impl IntermediateAggregationResults { + /// Add a result + pub fn push(&mut self, key: String, value: IntermediateAggregationResult) { + self.aggs_res.push(key, value); + } + /// Convert intermediate result and its aggregation request to the final result. pub fn into_final_result( self, req: Aggregations, limits: &AggregationLimits, ) -> crate::Result { - let res = self.into_final_result_internal(&(req.into()), limits)?; + let res = self.into_final_result_internal(&req, limits)?; let bucket_count = res.get_bucket_count() as u32; if bucket_count > limits.get_bucket_limit() { return Err(TantivyError::AggregationError( @@ -66,67 +65,35 @@ impl IntermediateAggregationResults { /// for internal processing, by splitting metric and buckets into separate groups. 
pub(crate) fn into_final_result_internal( self, - req: &AggregationsInternal, + req: &Aggregations, limits: &AggregationLimits, ) -> crate::Result { - // Important assumption: - // When the tree contains buckets/metric, we expect it to have all buckets/metrics from the - // request let mut results: FxHashMap = FxHashMap::default(); - - if let Some(buckets) = self.buckets { - convert_and_add_final_buckets_to_result(&mut results, buckets, &req.buckets, limits)? - } else { - // When there are no buckets, we create empty buckets, so that the serialized json - // format is constant - add_empty_final_buckets_to_result(&mut results, &req.buckets, limits)? - }; - - if let Some(metrics) = self.metrics { - convert_and_add_final_metrics_to_result(&mut results, metrics, &req.metrics); - } else { - // When there are no metrics, we create empty metric results, so that the serialized - // json format is constant - add_empty_final_metrics_to_result(&mut results, &req.metrics)?; + for (key, agg_res) in self.aggs_res.into_iter() { + let req = req.get(key.as_str()).unwrap(); + results.insert(key, agg_res.into_final_result(req, limits)?); + } + // Handle empty results + if results.len() != req.len() { + for (key, req) in req.iter() { + if !results.contains_key(key) { + let empty_res = empty_from_req(req); + results.insert(key.to_string(), empty_res.into_final_result(req, limits)?); + } + } } Ok(AggregationResults(results)) } - pub(crate) fn empty_from_req(req: &AggregationsInternal) -> Self { - let metrics = if req.metrics.is_empty() { - None - } else { - let metrics = req - .metrics - .iter() - .map(|(key, req)| { - ( - key.to_string(), - IntermediateMetricResult::empty_from_req(req), - ) - }) - .collect(); - Some(VecWithNames::from_entries(metrics)) - }; + pub(crate) fn empty_from_req(req: &Aggregations) -> Self { + let mut aggs_res: VecWithNames = VecWithNames::default(); + for (key, req) in req.iter() { + let empty_res = empty_from_req(req); + aggs_res.push(key.to_string(), 
empty_res); + } - let buckets = if req.buckets.is_empty() { - None - } else { - let buckets = req - .buckets - .iter() - .map(|(key, req)| { - ( - key.to_string(), - IntermediateBucketResult::empty_from_req(&req.bucket_agg), - ) - }) - .collect(); - Some(VecWithNames::from_entries(buckets)) - }; - - Self { metrics, buckets } + Self { aggs_res } } /// Merge another intermediate aggregation result into this result. @@ -134,85 +101,50 @@ impl IntermediateAggregationResults { /// The order of the values need to be the same on both results. This is ensured when the same /// (key values) are present on the underlying `VecWithNames` struct. pub fn merge_fruits(&mut self, other: IntermediateAggregationResults) -> crate::Result<()> { - if let (Some(buckets_left), Some(buckets_right)) = (&mut self.buckets, other.buckets) { - for (bucket_left, bucket_right) in - buckets_left.values_mut().zip(buckets_right.into_values()) - { - bucket_left.merge_fruits(bucket_right)?; - } - } - - if let (Some(metrics_left), Some(metrics_right)) = (&mut self.metrics, other.metrics) { - for (metric_left, metric_right) in - metrics_left.values_mut().zip(metrics_right.into_values()) - { - metric_left.merge_fruits(metric_right)?; - } + for (left, right) in self.aggs_res.values_mut().zip(other.aggs_res.into_values()) { + left.merge_fruits(right)?; } Ok(()) } } -fn convert_and_add_final_metrics_to_result( - results: &mut FxHashMap, - metrics: VecWithNames, - metrics_req: &VecWithNames, -) { - let metric_result_with_request = metrics.into_iter().zip(metrics_req.values()); - results.extend( - metric_result_with_request - .into_iter() - .map(|((key, metric), req)| { - ( - key, - AggregationResult::MetricResult(metric.into_final_metric_result(req)), - ) - }), - ); -} - -fn add_empty_final_metrics_to_result( - results: &mut FxHashMap, - req_metrics: &VecWithNames, -) -> crate::Result<()> { - results.extend(req_metrics.iter().map(|(key, req)| { - let empty_bucket = 
IntermediateMetricResult::empty_from_req(req); - ( - key.to_string(), - AggregationResult::MetricResult(empty_bucket.into_final_metric_result(req)), - ) - })); - Ok(()) -} - -fn add_empty_final_buckets_to_result( - results: &mut FxHashMap, - req_buckets: &VecWithNames, - limits: &AggregationLimits, -) -> crate::Result<()> { - let requested_buckets = req_buckets.iter(); - for (key, req) in requested_buckets { - let empty_bucket = - AggregationResult::BucketResult(BucketResult::empty_from_req(req, limits)?); - results.insert(key.to_string(), empty_bucket); +pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult { + use AggregationVariants::*; + match req.agg { + Terms(_) => IntermediateAggregationResult::Bucket(IntermediateBucketResult::Terms( + Default::default(), + )), + Range(_) => IntermediateAggregationResult::Bucket(IntermediateBucketResult::Range( + Default::default(), + )), + Histogram(_) | DateHistogram(_) => { + IntermediateAggregationResult::Bucket(IntermediateBucketResult::Histogram { + buckets: Vec::new(), + column_type: None, + }) + } + Average(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::Average( + IntermediateAverage::default(), + )), + Count(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::Count( + IntermediateCount::default(), + )), + Max(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::Max( + IntermediateMax::default(), + )), + Min(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::Min( + IntermediateMin::default(), + )), + Stats(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::Stats( + IntermediateStats::default(), + )), + Sum(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::Sum( + IntermediateSum::default(), + )), + Percentiles(_) => IntermediateAggregationResult::Metric( + IntermediateMetricResult::Percentiles(PercentilesCollector::default()), + ), } - Ok(()) -} - -fn 
convert_and_add_final_buckets_to_result( - results: &mut FxHashMap, - buckets: VecWithNames, - req_buckets: &VecWithNames, - limits: &AggregationLimits, -) -> crate::Result<()> { - assert_eq!(buckets.len(), req_buckets.len()); - - let buckets_with_request = buckets.into_iter().zip(req_buckets.values()); - for ((key, bucket), req) in buckets_with_request { - let result = AggregationResult::BucketResult(bucket.into_final_bucket_result(req, limits)?); - results.insert(key, result); - } - Ok(()) } /// An aggregation is either a bucket or a metric. @@ -224,6 +156,37 @@ pub enum IntermediateAggregationResult { Metric(IntermediateMetricResult), } +impl IntermediateAggregationResult { + pub(crate) fn into_final_result( + self, + req: &Aggregation, + limits: &AggregationLimits, + ) -> crate::Result { + let res = match self { + IntermediateAggregationResult::Bucket(bucket) => { + AggregationResult::BucketResult(bucket.into_final_bucket_result(req, limits)?) + } + IntermediateAggregationResult::Metric(metric) => { + AggregationResult::MetricResult(metric.into_final_metric_result(req)) + } + }; + Ok(res) + } + fn merge_fruits(&mut self, other: IntermediateAggregationResult) -> crate::Result<()> { + match (self, other) { + ( + IntermediateAggregationResult::Bucket(b1), + IntermediateAggregationResult::Bucket(b2), + ) => b1.merge_fruits(b2), + ( + IntermediateAggregationResult::Metric(m1), + IntermediateAggregationResult::Metric(m2), + ) => m1.merge_fruits(m2), + _ => panic!("aggregation result type mismatch (mixed metric and buckets)"), + } + } +} + /// Holds the intermediate data for metric results #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub enum IntermediateMetricResult { @@ -244,7 +207,7 @@ pub enum IntermediateMetricResult { } impl IntermediateMetricResult { - fn into_final_metric_result(self, req: &MetricAggregation) -> MetricResult { + fn into_final_metric_result(self, req: &Aggregation) -> MetricResult { match self { 
IntermediateMetricResult::Average(intermediate_avg) => { MetricResult::Average(intermediate_avg.finalize().into()) @@ -265,30 +228,12 @@ impl IntermediateMetricResult { MetricResult::Sum(intermediate_sum.finalize().into()) } IntermediateMetricResult::Percentiles(percentiles) => MetricResult::Percentiles( - percentiles.into_final_result(req.as_percentile().expect("unexpected metric type")), + percentiles + .into_final_result(req.agg.as_percentile().expect("unexpected metric type")), ), } } - pub(crate) fn empty_from_req(req: &MetricAggregation) -> Self { - match req { - MetricAggregation::Average(_) => { - IntermediateMetricResult::Average(IntermediateAverage::default()) - } - MetricAggregation::Count(_) => { - IntermediateMetricResult::Count(IntermediateCount::default()) - } - MetricAggregation::Max(_) => IntermediateMetricResult::Max(IntermediateMax::default()), - MetricAggregation::Min(_) => IntermediateMetricResult::Min(IntermediateMin::default()), - MetricAggregation::Stats(_) => { - IntermediateMetricResult::Stats(IntermediateStats::default()) - } - MetricAggregation::Sum(_) => IntermediateMetricResult::Sum(IntermediateSum::default()), - MetricAggregation::Percentiles(_) => { - IntermediateMetricResult::Percentiles(PercentilesCollector::default()) - } - } - } fn merge_fruits(&mut self, other: IntermediateMetricResult) -> crate::Result<()> { match (self, other) { ( @@ -355,7 +300,7 @@ pub enum IntermediateBucketResult { impl IntermediateBucketResult { pub(crate) fn into_final_bucket_result( self, - req: &BucketAggregationInternal, + req: &Aggregation, limits: &AggregationLimits, ) -> crate::Result { match self { @@ -366,7 +311,8 @@ impl IntermediateBucketResult { .map(|bucket| { bucket.into_final_bucket_entry( req.sub_aggregation(), - req.as_range() + req.agg + .as_range() .expect("unexpected aggregation, expected histogram aggregation"), range_res.column_type, limits, @@ -381,6 +327,7 @@ impl IntermediateBucketResult { }); let is_keyed = req + .agg .as_range() 
.expect("unexpected aggregation, expected range aggregation") .keyed; @@ -401,6 +348,7 @@ impl IntermediateBucketResult { buckets, } => { let histogram_req = &req + .agg .as_histogram()? .expect("unexpected aggregation, expected histogram aggregation"); let buckets = intermediate_histogram_buckets_to_final_buckets( @@ -424,7 +372,8 @@ impl IntermediateBucketResult { Ok(BucketResult::Histogram { buckets }) } IntermediateBucketResult::Terms(terms) => terms.into_final_result( - req.as_term() + req.agg + .as_term() .expect("unexpected aggregation, expected term aggregation"), req.sub_aggregation(), limits, @@ -432,18 +381,6 @@ impl IntermediateBucketResult { } } - pub(crate) fn empty_from_req(req: &BucketAggregationType) -> Self { - match req { - BucketAggregationType::Terms(_) => IntermediateBucketResult::Terms(Default::default()), - BucketAggregationType::Range(_) => IntermediateBucketResult::Range(Default::default()), - BucketAggregationType::Histogram(_) | BucketAggregationType::DateHistogram(_) => { - IntermediateBucketResult::Histogram { - buckets: vec![], - column_type: None, - } - } - } - } fn merge_fruits(&mut self, other: IntermediateBucketResult) -> crate::Result<()> { match (self, other) { ( @@ -551,7 +488,7 @@ impl IntermediateTermBucketResult { pub(crate) fn into_final_result( self, req: &TermsAggregation, - sub_aggregation_req: &AggregationsInternal, + sub_aggregation_req: &Aggregations, limits: &AggregationLimits, ) -> crate::Result { let req = TermsAggregationInternal::from_req(req); @@ -687,7 +624,7 @@ pub struct IntermediateHistogramBucketEntry { impl IntermediateHistogramBucketEntry { pub(crate) fn into_final_bucket_entry( self, - req: &AggregationsInternal, + req: &Aggregations, limits: &AggregationLimits, ) -> crate::Result { Ok(BucketEntry { @@ -732,7 +669,7 @@ pub struct IntermediateRangeBucketEntry { impl IntermediateRangeBucketEntry { pub(crate) fn into_final_bucket_entry( self, - req: &AggregationsInternal, + req: &Aggregations, _range_req: 
&RangeAggregation, column_type: Option, limits: &AggregationLimits, @@ -825,14 +762,15 @@ mod tests { } map.insert( "my_agg_level2".to_string(), - IntermediateBucketResult::Range(IntermediateRangeBucketResult { - buckets, - column_type: None, - }), + IntermediateAggregationResult::Bucket(IntermediateBucketResult::Range( + IntermediateRangeBucketResult { + buckets, + column_type: None, + }, + )), ); IntermediateAggregationResults { - buckets: Some(VecWithNames::from_entries(map.into_iter().collect())), - metrics: Default::default(), + aggs_res: VecWithNames::from_entries(map.into_iter().collect()), } } @@ -858,14 +796,15 @@ mod tests { } map.insert( "my_agg_level1".to_string(), - IntermediateBucketResult::Range(IntermediateRangeBucketResult { - buckets, - column_type: None, - }), + IntermediateAggregationResult::Bucket(IntermediateBucketResult::Range( + IntermediateRangeBucketResult { + buckets, + column_type: None, + }, + )), ); IntermediateAggregationResults { - buckets: Some(VecWithNames::from_entries(map.into_iter().collect())), - metrics: Default::default(), + aggs_res: VecWithNames::from_entries(map.into_iter().collect()), } } diff --git a/src/aggregation/metric/mod.rs b/src/aggregation/metric/mod.rs index b1d05a05d..50ff0389f 100644 --- a/src/aggregation/metric/mod.rs +++ b/src/aggregation/metric/mod.rs @@ -1,7 +1,12 @@ //! Module for all metric aggregations. //! -//! The aggregations in this family compute metrics, see [super::agg_req::MetricAggregation] for -//! details. +//! The aggregations in this family compute metrics based on values extracted +//! from the documents that are being aggregated. Values are extracted from the fast field of +//! the document. +//! Some aggregations output a single numeric metric (e.g. Average) and are called +//! single-value numeric metrics aggregation, others generate multiple metrics (e.g. Stats) and are +//! called multi-value numeric metrics aggregation. 
+ mod average; mod count; mod max; diff --git a/src/aggregation/metric/percentiles.rs b/src/aggregation/metric/percentiles.rs index cf51606a5..db496a7dc 100644 --- a/src/aggregation/metric/percentiles.rs +++ b/src/aggregation/metric/percentiles.rs @@ -5,13 +5,13 @@ use serde::{Deserialize, Serialize}; use super::*; use crate::aggregation::agg_req_with_accessor::{ - AggregationsWithAccessor, MetricAggregationWithAccessor, + AggregationWithAccessor, AggregationsWithAccessor, }; use crate::aggregation::intermediate_agg_result::{ - IntermediateAggregationResults, IntermediateMetricResult, + IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, }; use crate::aggregation::segment_agg_result::SegmentAggregationCollector; -use crate::aggregation::{f64_from_fastfield_u64, AggregationError, VecWithNames}; +use crate::aggregation::{f64_from_fastfield_u64, AggregationError}; use crate::{DocId, TantivyError}; /// # Percentiles @@ -240,7 +240,7 @@ impl SegmentPercentilesCollector { pub(crate) fn collect_block_with_field( &mut self, docs: &[DocId], - agg_accessor: &mut MetricAggregationWithAccessor, + agg_accessor: &mut AggregationWithAccessor, ) { agg_accessor .column_block_accessor @@ -255,22 +255,20 @@ impl SegmentPercentilesCollector { impl SegmentAggregationCollector for SegmentPercentilesCollector { #[inline] - fn into_intermediate_aggregations_result( + fn add_intermediate_aggregation_result( self: Box, agg_with_accessor: &AggregationsWithAccessor, - ) -> crate::Result { - let name = agg_with_accessor.metrics.keys[self.accessor_idx].to_string(); + results: &mut IntermediateAggregationResults, + ) -> crate::Result<()> { + let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string(); let intermediate_metric_result = IntermediateMetricResult::Percentiles(self.percentiles); - let metrics = Some(VecWithNames::from_entries(vec![( + results.push( name, - intermediate_metric_result, - )])); + 
IntermediateAggregationResult::Metric(intermediate_metric_result), + ); - Ok(IntermediateAggregationResults { - metrics, - buckets: None, - }) + Ok(()) } #[inline] @@ -279,7 +277,7 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector { doc: crate::DocId, agg_with_accessor: &mut AggregationsWithAccessor, ) -> crate::Result<()> { - let field = &agg_with_accessor.metrics.values[self.accessor_idx].accessor; + let field = &agg_with_accessor.aggs.values[self.accessor_idx].accessor; for val in field.values_for_doc(doc) { let val1 = f64_from_fastfield_u64(val, &self.field_type); @@ -295,7 +293,7 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector { docs: &[crate::DocId], agg_with_accessor: &mut AggregationsWithAccessor, ) -> crate::Result<()> { - let field = &mut agg_with_accessor.metrics.values[self.accessor_idx]; + let field = &mut agg_with_accessor.aggs.values[self.accessor_idx]; self.collect_block_with_field(docs, field); Ok(()) } @@ -310,9 +308,8 @@ mod tests { use rand::SeedableRng; use serde_json::Value; - use crate::aggregation::agg_req::{Aggregation, Aggregations, MetricAggregation}; + use crate::aggregation::agg_req::Aggregations; use crate::aggregation::agg_result::AggregationResults; - use crate::aggregation::metric::PercentilesAggregationReq; use crate::aggregation::tests::{ get_test_index_from_values, get_test_index_from_values_and_terms, }; @@ -326,14 +323,14 @@ mod tests { let index = get_test_index_from_values(false, &values)?; - let agg_req_1: Aggregations = vec![( - "percentiles".to_string(), - Aggregation::Metric(MetricAggregation::Percentiles( - PercentilesAggregationReq::from_field_name("score".to_string()), - )), - )] - .into_iter() - .collect(); + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "percentiles": { + "percentiles": { + "field": "score", + } + }, + })) + .unwrap(); let collector = AggregationCollector::from_aggs(agg_req_1, Default::default()); @@ -364,14 +361,14 @@ mod tests { let index = 
get_test_index_from_values(false, &values)?; - let agg_req_1: Aggregations = vec![( - "percentiles".to_string(), - Aggregation::Metric(MetricAggregation::Percentiles( - PercentilesAggregationReq::from_field_name("score".to_string()), - )), - )] - .into_iter() - .collect(); + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "percentiles": { + "percentiles": { + "field": "score", + } + }, + })) + .unwrap(); let collector = AggregationCollector::from_aggs(agg_req_1, Default::default()); diff --git a/src/aggregation/metric/stats.rs b/src/aggregation/metric/stats.rs index b7bfe8f6c..bd63f08dd 100644 --- a/src/aggregation/metric/stats.rs +++ b/src/aggregation/metric/stats.rs @@ -3,13 +3,13 @@ use serde::{Deserialize, Serialize}; use super::*; use crate::aggregation::agg_req_with_accessor::{ - AggregationsWithAccessor, MetricAggregationWithAccessor, + AggregationWithAccessor, AggregationsWithAccessor, }; +use crate::aggregation::f64_from_fastfield_u64; use crate::aggregation::intermediate_agg_result::{ - IntermediateAggregationResults, IntermediateMetricResult, + IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, }; use crate::aggregation::segment_agg_result::SegmentAggregationCollector; -use crate::aggregation::{f64_from_fastfield_u64, VecWithNames}; use crate::{DocId, TantivyError}; /// A multi-value metric aggregation that computes a collection of statistics on numeric values that @@ -179,7 +179,7 @@ impl SegmentStatsCollector { pub(crate) fn collect_block_with_field( &mut self, docs: &[DocId], - agg_accessor: &mut MetricAggregationWithAccessor, + agg_accessor: &mut AggregationWithAccessor, ) { agg_accessor .column_block_accessor @@ -194,11 +194,12 @@ impl SegmentStatsCollector { impl SegmentAggregationCollector for SegmentStatsCollector { #[inline] - fn into_intermediate_aggregations_result( + fn add_intermediate_aggregation_result( self: Box, agg_with_accessor: &AggregationsWithAccessor, - ) -> crate::Result { - let 
name = agg_with_accessor.metrics.keys[self.accessor_idx].to_string(); + results: &mut IntermediateAggregationResults, + ) -> crate::Result<()> { + let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string(); let intermediate_metric_result = match self.collecting_for { SegmentStatsType::Average => { @@ -219,15 +220,12 @@ impl SegmentAggregationCollector for SegmentStatsCollector { } }; - let metrics = Some(VecWithNames::from_entries(vec![( + results.push( name, - intermediate_metric_result, - )])); + IntermediateAggregationResult::Metric(intermediate_metric_result), + ); - Ok(IntermediateAggregationResults { - metrics, - buckets: None, - }) + Ok(()) } #[inline] @@ -236,7 +234,7 @@ impl SegmentAggregationCollector for SegmentStatsCollector { doc: crate::DocId, agg_with_accessor: &mut AggregationsWithAccessor, ) -> crate::Result<()> { - let field = &agg_with_accessor.metrics.values[self.accessor_idx].accessor; + let field = &agg_with_accessor.aggs.values[self.accessor_idx].accessor; for val in field.values_for_doc(doc) { let val1 = f64_from_fastfield_u64(val, &self.field_type); @@ -252,7 +250,7 @@ impl SegmentAggregationCollector for SegmentStatsCollector { docs: &[crate::DocId], agg_with_accessor: &mut AggregationsWithAccessor, ) -> crate::Result<()> { - let field = &mut agg_with_accessor.metrics.values[self.accessor_idx]; + let field = &mut agg_with_accessor.aggs.values[self.accessor_idx]; self.collect_block_with_field(docs, field); Ok(()) } @@ -263,9 +261,8 @@ mod tests { use serde_json::Value; - use crate::aggregation::agg_req::{Aggregation, Aggregations, MetricAggregation}; + use crate::aggregation::agg_req::{Aggregation, Aggregations}; use crate::aggregation::agg_result::AggregationResults; - use crate::aggregation::metric::StatsAggregation; use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values}; use crate::aggregation::AggregationCollector; use crate::query::{AllQuery, TermQuery}; @@ -279,14 +276,14 @@ mod tests { let 
index = get_test_index_from_values(false, &values)?; - let agg_req_1: Aggregations = vec![( - "stats".to_string(), - Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name( - "score".to_string(), - ))), - )] - .into_iter() - .collect(); + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "stats": { + "stats": { + "field": "score", + }, + } + })) + .unwrap(); let collector = AggregationCollector::from_aggs(agg_req_1, Default::default()); @@ -315,14 +312,14 @@ mod tests { let index = get_test_index_from_values(false, &values)?; - let agg_req_1: Aggregations = vec![( - "stats".to_string(), - Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name( - "score".to_string(), - ))), - )] - .into_iter() - .collect(); + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "stats": { + "stats": { + "field": "score", + }, + } + })) + .unwrap(); let collector = AggregationCollector::from_aggs(agg_req_1, Default::default()); @@ -374,29 +371,25 @@ mod tests { .unwrap() }; - let agg_req_1: Aggregations = vec![ - ( - "stats_i64".to_string(), - Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name( - "score_i64".to_string(), - ))), - ), - ( - "stats_f64".to_string(), - Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name( - "score_f64".to_string(), - ))), - ), - ( - "stats".to_string(), - Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name( - "score".to_string(), - ))), - ), - ("range".to_string(), range_agg), - ] - .into_iter() - .collect(); + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "stats_i64": { + "stats": { + "field": "score_i64", + }, + }, + "stats_f64": { + "stats": { + "field": "score_f64", + }, + }, + "stats": { + "stats": { + "field": "score", + }, + }, + "range": range_agg + })) + .unwrap(); let collector = AggregationCollector::from_aggs(agg_req_1, Default::default()); diff --git a/src/aggregation/mod.rs 
b/src/aggregation/mod.rs index 71df51ca1..c47693817 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -49,34 +49,6 @@ //! Compute the average metric, by building [`agg_req::Aggregations`], which is built from an //! `(String, agg_req::Aggregation)` iterator. //! -//! ``` -//! use tantivy::aggregation::agg_req::{Aggregations, Aggregation, MetricAggregation}; -//! use tantivy::aggregation::AggregationCollector; -//! use tantivy::aggregation::metric::AverageAggregation; -//! use tantivy::query::AllQuery; -//! use tantivy::aggregation::agg_result::AggregationResults; -//! use tantivy::IndexReader; -//! -//! # #[allow(dead_code)] -//! fn aggregate_on_index(reader: &IndexReader) { -//! let agg_req: Aggregations = vec![ -//! ( -//! "average".to_string(), -//! Aggregation::Metric(MetricAggregation::Average( -//! AverageAggregation::from_field_name("score".to_string()), -//! )), -//! ), -//! ] -//! .into_iter() -//! .collect(); -//! -//! let collector = AggregationCollector::from_aggs(agg_req, Default::default()); -//! -//! let searcher = reader.searcher(); -//! let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap(); -//! } -//! ``` -//! # Example JSON //! Requests are compatible with the elasticsearch JSON request format. //! //! ``` @@ -116,32 +88,24 @@ //! aggregation and then calculate the average on each bucket. //! ``` //! use tantivy::aggregation::agg_req::*; -//! use tantivy::aggregation::metric::AverageAggregation; -//! use tantivy::aggregation::bucket::RangeAggregation; -//! let sub_agg_req_1: Aggregations = vec![( -//! "average_in_range".to_string(), -//! Aggregation::Metric(MetricAggregation::Average( -//! AverageAggregation::from_field_name("score".to_string()), -//! )), -//! )] -//! .into_iter() -//! .collect(); +//! use serde_json::json; //! -//! let agg_req_1: Aggregations = vec![ -//! ( -//! "range".to_string(), -//! Aggregation::Bucket(Box::new(BucketAggregation { -//! 
bucket_agg: BucketAggregationType::Range(RangeAggregation{ -//! field: "score".to_string(), -//! ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()], -//! keyed: false, -//! }), -//! sub_aggregation: sub_agg_req_1.clone(), -//! })), -//! ), -//! ] -//! .into_iter() -//! .collect(); +//! let agg_req_1: Aggregations = serde_json::from_value(json!({ +//! "rangef64": { +//! "range": { +//! "field": "score", +//! "ranges": [ +//! { "from": 3, "to": 7000 }, +//! { "from": 7000, "to": 20000 }, +//! { "from": 50000, "to": 60000 } +//! ] +//! }, +//! "aggs": { +//! "average_in_range": { "avg": { "field": "score" } } +//! } +//! }, +//! })) +//! .unwrap(); //! ``` //! //! # Distributed Aggregation @@ -216,9 +180,9 @@ impl From> for VecWithNames { } impl VecWithNames { - fn extend(&mut self, entries: VecWithNames) { - self.keys.extend(entries.keys); - self.values.extend(entries.values); + fn push(&mut self, key: String, value: T) { + self.keys.push(key); + self.values.push(value); } fn from_entries(mut entries: Vec<(String, T)>) -> Self { @@ -247,9 +211,6 @@ impl VecWithNames { fn into_values(self) -> impl Iterator { self.values.into_iter() } - fn values(&self) -> impl Iterator + '_ { - self.values.iter() - } fn values_mut(&mut self) -> impl Iterator + '_ { self.values.iter_mut() } diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index 074ff782f..1353c56c6 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -6,10 +6,8 @@ use std::fmt::Debug; pub(crate) use super::agg_limits::AggregationLimits; -use super::agg_req::MetricAggregation; -use super::agg_req_with_accessor::{ - AggregationsWithAccessor, BucketAggregationWithAccessor, MetricAggregationWithAccessor, -}; +use super::agg_req::AggregationVariants; +use super::agg_req_with_accessor::{AggregationWithAccessor, AggregationsWithAccessor}; use super::bucket::{SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector}; use 
super::intermediate_agg_result::IntermediateAggregationResults; use super::metric::{ @@ -17,14 +15,13 @@ use super::metric::{ SegmentPercentilesCollector, SegmentStatsCollector, SegmentStatsType, StatsAggregation, SumAggregation, }; -use super::VecWithNames; -use crate::aggregation::agg_req::BucketAggregationType; pub(crate) trait SegmentAggregationCollector: CollectorClone + Debug { - fn into_intermediate_aggregations_result( + fn add_intermediate_aggregation_result( self: Box, agg_with_accessor: &AggregationsWithAccessor, - ) -> crate::Result; + results: &mut IntermediateAggregationResults, + ) -> crate::Result<()>; fn collect( &mut self, @@ -66,52 +63,79 @@ impl Clone for Box { pub(crate) fn build_segment_agg_collector( req: &AggregationsWithAccessor, ) -> crate::Result> { - // Single metric special case - if req.buckets.is_empty() && req.metrics.len() == 1 { - let req = &req.metrics.values[0]; + // Single collector special case: with exactly one aggregation, skip the generic wrapper + if req.aggs.len() == 1 { + let req = &req.aggs.values[0]; let accessor_idx = 0; - return build_metric_segment_agg_collector(req, accessor_idx); - } - - // Single bucket special case - if req.metrics.is_empty() && req.buckets.len() == 1 { - let req = &req.buckets.values[0]; - let accessor_idx = 0; - return build_bucket_segment_agg_collector(req, accessor_idx); + return build_single_agg_segment_collector(req, accessor_idx); } let agg = GenericSegmentAggregationResultsCollector::from_req_and_validate(req)?; Ok(Box::new(agg)) } -pub(crate) fn build_metric_segment_agg_collector( - req: &MetricAggregationWithAccessor, +pub(crate) fn build_single_agg_segment_collector( + req: &AggregationWithAccessor, accessor_idx: usize, ) -> crate::Result> { - match &req.metric { - MetricAggregation::Average(AverageAggregation { .. }) => { - Ok(Box::new(SegmentStatsCollector::from_req( - req.field_type, - SegmentStatsType::Average, - accessor_idx, - ))) - } - MetricAggregation::Count(CountAggregation { .. 
}) => Ok(Box::new( - SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Count, accessor_idx), - )), - MetricAggregation::Max(MaxAggregation { .. }) => Ok(Box::new( - SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Max, accessor_idx), - )), - MetricAggregation::Min(MinAggregation { .. }) => Ok(Box::new( - SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Min, accessor_idx), - )), - MetricAggregation::Stats(StatsAggregation { .. }) => Ok(Box::new( - SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Stats, accessor_idx), - )), - MetricAggregation::Sum(SumAggregation { .. }) => Ok(Box::new( - SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Sum, accessor_idx), - )), - MetricAggregation::Percentiles(percentiles_req) => Ok(Box::new( + use AggregationVariants::*; + match &req.agg.agg { + Terms(terms_req) => Ok(Box::new(SegmentTermCollector::from_req_and_validate( + terms_req, + &req.sub_aggregation, + req.field_type, + accessor_idx, + )?)), + Range(range_req) => Ok(Box::new(SegmentRangeCollector::from_req_and_validate( + range_req, + &req.sub_aggregation, + &req.limits, + req.field_type, + accessor_idx, + )?)), + Histogram(histogram) => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate( + histogram, + &req.sub_aggregation, + req.field_type, + accessor_idx, + )?)), + DateHistogram(histogram) => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate( + &histogram.to_histogram_req()?, + &req.sub_aggregation, + req.field_type, + accessor_idx, + )?)), + Average(AverageAggregation { .. }) => Ok(Box::new(SegmentStatsCollector::from_req( + req.field_type, + SegmentStatsType::Average, + accessor_idx, + ))), + Count(CountAggregation { .. }) => Ok(Box::new(SegmentStatsCollector::from_req( + req.field_type, + SegmentStatsType::Count, + accessor_idx, + ))), + Max(MaxAggregation { .. 
}) => Ok(Box::new(SegmentStatsCollector::from_req( + req.field_type, + SegmentStatsType::Max, + accessor_idx, + ))), + Min(MinAggregation { .. }) => Ok(Box::new(SegmentStatsCollector::from_req( + req.field_type, + SegmentStatsType::Min, + accessor_idx, + ))), + Stats(StatsAggregation { .. }) => Ok(Box::new(SegmentStatsCollector::from_req( + req.field_type, + SegmentStatsType::Stats, + accessor_idx, + ))), + Sum(SumAggregation { .. }) => Ok(Box::new(SegmentStatsCollector::from_req( + req.field_type, + SegmentStatsType::Sum, + accessor_idx, + ))), + Percentiles(percentiles_req) => Ok(Box::new( SegmentPercentilesCollector::from_req_and_validate( percentiles_req, req.field_type, @@ -121,96 +145,33 @@ pub(crate) fn build_metric_segment_agg_collector( } } -pub(crate) fn build_bucket_segment_agg_collector( - req: &BucketAggregationWithAccessor, - accessor_idx: usize, -) -> crate::Result> { - match &req.bucket_agg { - BucketAggregationType::Terms(terms_req) => { - Ok(Box::new(SegmentTermCollector::from_req_and_validate( - terms_req, - &req.sub_aggregation, - req.field_type, - accessor_idx, - )?)) - } - BucketAggregationType::Range(range_req) => { - Ok(Box::new(SegmentRangeCollector::from_req_and_validate( - range_req, - &req.sub_aggregation, - &req.limits, - req.field_type, - accessor_idx, - )?)) - } - BucketAggregationType::Histogram(histogram) => { - Ok(Box::new(SegmentHistogramCollector::from_req_and_validate( - histogram, - &req.sub_aggregation, - req.field_type, - accessor_idx, - )?)) - } - BucketAggregationType::DateHistogram(histogram) => { - Ok(Box::new(SegmentHistogramCollector::from_req_and_validate( - &histogram.to_histogram_req()?, - &req.sub_aggregation, - req.field_type, - accessor_idx, - )?)) - } - } -} - #[derive(Clone, Default)] /// The GenericSegmentAggregationResultsCollector is the generic version of the collector, which /// can handle arbitrary complexity of sub-aggregations. 
Ideally we never have to pick this one /// and can provide specialized versions instead, that remove some of its overhead. pub(crate) struct GenericSegmentAggregationResultsCollector { - pub(crate) metrics: Option>>, - pub(crate) buckets: Option>>, + pub(crate) aggs: Vec>, } impl Debug for GenericSegmentAggregationResultsCollector { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SegmentAggregationResultsCollector") - .field("metrics", &self.metrics) - .field("buckets", &self.buckets) + .field("aggs", &self.aggs) .finish() } } impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector { - fn into_intermediate_aggregations_result( + fn add_intermediate_aggregation_result( self: Box, agg_with_accessor: &AggregationsWithAccessor, - ) -> crate::Result { - let buckets = if let Some(buckets) = self.buckets { - let mut intermeditate_buckets = VecWithNames::default(); - for bucket in buckets { - // TODO too many allocations? - let res = bucket.into_intermediate_aggregations_result(agg_with_accessor)?; - // unwrap is fine since we only have buckets here - intermeditate_buckets.extend(res.buckets.unwrap()); - } - Some(intermeditate_buckets) - } else { - None - }; - let metrics = if let Some(metrics) = self.metrics { - let mut intermeditate_metrics = VecWithNames::default(); - for metric in metrics { - // TODO too many allocations? 
- let res = metric.into_intermediate_aggregations_result(agg_with_accessor)?; - // unwrap is fine since we only have metrics here - intermeditate_metrics.extend(res.metrics.unwrap()); - } - Some(intermeditate_metrics) - } else { - None - }; + results: &mut IntermediateAggregationResults, + ) -> crate::Result<()> { + for agg in self.aggs { + agg.add_intermediate_aggregation_result(agg_with_accessor, results)?; + } - Ok(IntermediateAggregationResults { metrics, buckets }) + Ok(()) } fn collect( @@ -228,31 +189,16 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector { docs: &[crate::DocId], agg_with_accessor: &mut AggregationsWithAccessor, ) -> crate::Result<()> { - if let Some(metrics) = self.metrics.as_mut() { - for collector in metrics { - collector.collect_block(docs, agg_with_accessor)?; - } - } - - if let Some(buckets) = self.buckets.as_mut() { - for collector in buckets { - collector.collect_block(docs, agg_with_accessor)?; - } + for collector in &mut self.aggs { + collector.collect_block(docs, agg_with_accessor)?; } Ok(()) } fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> { - if let Some(metrics) = &mut self.metrics { - for collector in metrics { - collector.flush(agg_with_accessor)?; - } - } - if let Some(buckets) = &mut self.buckets { - for collector in buckets { - collector.flush(agg_with_accessor)?; - } + for collector in &mut self.aggs { + collector.flush(agg_with_accessor)?; } Ok(()) } @@ -260,34 +206,15 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector { impl GenericSegmentAggregationResultsCollector { pub(crate) fn from_req_and_validate(req: &AggregationsWithAccessor) -> crate::Result { - let buckets = req - .buckets + let aggs = req + .aggs .iter() .enumerate() .map(|(accessor_idx, (_key, req))| { - build_bucket_segment_agg_collector(req, accessor_idx) - }) - .collect::>>>()?; - let metrics = req - .metrics - .iter() - .enumerate() - 
.map(|(accessor_idx, (_key, req))| { - build_metric_segment_agg_collector(req, accessor_idx) + build_single_agg_segment_collector(req, accessor_idx) }) .collect::>>>()?; - let metrics = if metrics.is_empty() { - None - } else { - Some(metrics) - }; - - let buckets = if buckets.is_empty() { - None - } else { - Some(buckets) - }; - Ok(GenericSegmentAggregationResultsCollector { metrics, buckets }) + Ok(GenericSegmentAggregationResultsCollector { aggs }) } }