diff --git a/src/aggregation/agg_req.rs b/src/aggregation/agg_req.rs
index aaf4ee760..10b6db989 100644
--- a/src/aggregation/agg_req.rs
+++ b/src/aggregation/agg_req.rs
@@ -51,6 +51,7 @@ use serde::{Deserialize, Serialize};
 use super::bucket::HistogramAggregation;
 pub use super::bucket::RangeAggregation;
 use super::metric::{AverageAggregation, StatsAggregation};
+use super::VecWithNames;
 
 /// The top-level aggregation request structure, which contains [Aggregation] and their user defined
 /// names. It is also used in [buckets](BucketAggregation) to define sub-aggregations.
@@ -58,6 +59,54 @@ use super::metric::{AverageAggregation, StatsAggregation};
 /// The key is the user defined name of the aggregation.
 pub type Aggregations = HashMap<String, Aggregation>;
 
+/// Like `Aggregations`, but optimized to work with the aggregation result.
+#[derive(Clone, Debug)]
+pub(crate) struct CollectorAggregations {
+    pub(crate) metrics: VecWithNames<MetricAggregation>,
+    pub(crate) buckets: VecWithNames<CollectorBucketAggregation>,
+}
+
+impl From<Aggregations> for CollectorAggregations {
+    fn from(aggs: Aggregations) -> Self {
+        let mut metrics = vec![];
+        let mut buckets = vec![];
+        for (key, agg) in aggs {
+            match agg {
+                Aggregation::Bucket(bucket) => buckets.push((
+                    key,
+                    CollectorBucketAggregation {
+                        bucket_agg: bucket.bucket_agg,
+                        sub_aggregation: bucket.sub_aggregation.into(),
+                    },
+                )),
+                Aggregation::Metric(metric) => metrics.push((key, metric)),
+            }
+        }
+        Self {
+            metrics: VecWithNames::from_entries(metrics),
+            buckets: VecWithNames::from_entries(buckets),
+        }
+    }
+}
+
+#[derive(Clone, Debug)]
+pub(crate) struct CollectorBucketAggregation {
+    /// Bucket aggregation strategy to group documents.
+    pub bucket_agg: BucketAggregationType,
+    /// The sub_aggregations in the buckets. Each bucket will aggregate on the document set in the
+    /// bucket.
+    pub sub_aggregation: CollectorAggregations,
+}
+
+impl CollectorBucketAggregation {
+    pub(crate) fn as_histogram(&self) -> &HistogramAggregation {
+        match &self.bucket_agg {
+            BucketAggregationType::Range(_) => panic!("unexpected aggregation"),
+            BucketAggregationType::Histogram(histogram) => histogram,
+        }
+    }
+}
+
 /// Extract all fast field names used in the tree.
 pub fn get_fast_field_names(aggs: &Aggregations) -> HashSet<String> {
     let mut fast_field_names = Default::default();
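The `CollectorAggregations` type above is crate-private, but its shape follows directly from the public request types. A minimal sketch of the request side, assuming tantivy's public `aggregation` module exports as of this change; the name "avg_score" and the field "score" are illustrative:

    use tantivy::aggregation::agg_req::{Aggregation, Aggregations, MetricAggregation};
    use tantivy::aggregation::metric::AverageAggregation;

    // A request is a HashMap<String, Aggregation>, keyed by user-chosen names.
    let agg_req: Aggregations = vec![(
        "avg_score".to_string(),
        Aggregation::Metric(MetricAggregation::Average(AverageAggregation {
            field: "score".to_string(),
        })),
    )]
    .into_iter()
    .collect();
    // Internally, `CollectorAggregations::from(agg_req)` splits this map into
    // name-keyed metric and bucket vectors, so results can later be paired with
    // their originating requests by position instead of by hash lookup.
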
diff --git a/src/aggregation/agg_result.rs b/src/aggregation/agg_result.rs
index 4c34d0151..5918389a5 100644
--- a/src/aggregation/agg_result.rs
+++ b/src/aggregation/agg_result.rs
@@ -10,6 +10,7 @@ use std::collections::HashMap;
 use itertools::Itertools;
 use serde::{Deserialize, Serialize};
 
+use super::agg_req::{Aggregations, CollectorAggregations, CollectorBucketAggregation};
 use super::bucket::generate_buckets;
 use super::intermediate_agg_result::{
     IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry,
@@ -22,21 +23,52 @@ use super::Key;
 
 /// The final aggregation result.
 pub struct AggregationResults(pub HashMap<String, AggregationResult>);
 
-impl From<IntermediateAggregationResults> for AggregationResults {
-    fn from(tree: IntermediateAggregationResults) -> Self {
-        Self(
-            tree.buckets
-                .unwrap_or_default()
-                .into_iter()
-                .map(|(key, bucket)| (key, AggregationResult::BucketResult(bucket.into())))
-                .chain(
-                    tree.metrics
-                        .unwrap_or_default()
-                        .into_iter()
-                        .map(|(key, metric)| (key, AggregationResult::MetricResult(metric.into()))),
+impl From<(IntermediateAggregationResults, Aggregations)> for AggregationResults {
+    fn from(tree_and_req: (IntermediateAggregationResults, Aggregations)) -> Self {
+        let agg: CollectorAggregations = tree_and_req.1.into();
+        (tree_and_req.0, &agg).into()
+    }
+}
+
+impl From<(IntermediateAggregationResults, &CollectorAggregations)> for AggregationResults {
+    fn from(data: (IntermediateAggregationResults, &CollectorAggregations)) -> Self {
+        let tree = data.0;
+        let req = data.1;
+        let mut result = HashMap::default();
+
+        // Important assumption:
+        // When the tree contains buckets/metrics, we expect it to have all buckets/metrics
+        // from the request.
+        if let Some(buckets) = tree.buckets {
+            result.extend(buckets.into_iter().zip(req.buckets.values()).map(
+                |((key, bucket), req)| (key, AggregationResult::BucketResult((bucket, req).into())),
+            ));
+        } else {
+            result.extend(req.buckets.iter().map(|(key, req)| {
+                let empty_bucket = IntermediateBucketResult::empty_from_req(&req.bucket_agg);
+                (
+                    key.to_string(),
+                    AggregationResult::BucketResult((empty_bucket, req).into()),
                 )
-                .collect(),
-        )
+            }));
+        }
+
+        if let Some(metrics) = tree.metrics {
+            result.extend(
+                metrics
+                    .into_iter()
+                    .map(|(key, metric)| (key, AggregationResult::MetricResult(metric.into()))),
+            );
+        } else {
+            result.extend(req.metrics.iter().map(|(key, req)| {
+                let empty_metric = IntermediateMetricResult::empty_from_req(req);
+                (
+                    key.to_string(),
+                    AggregationResult::MetricResult(empty_metric.into()),
+                )
+            }));
+        }
+        Self(result)
     }
 }
@@ -95,13 +127,15 @@ pub enum BucketResult {
     },
 }
 
-impl From<IntermediateBucketResult> for BucketResult {
-    fn from(result: IntermediateBucketResult) -> Self {
-        match result {
+impl From<(IntermediateBucketResult, &CollectorBucketAggregation)> for BucketResult {
+    fn from(result_and_req: (IntermediateBucketResult, &CollectorBucketAggregation)) -> Self {
+        let bucket_result = result_and_req.0;
+        let req = result_and_req.1;
+        match bucket_result {
             IntermediateBucketResult::Range(range_map) => {
                 let mut buckets: Vec<RangeBucketEntry> = range_map
                     .into_iter()
-                    .map(|(_, bucket)| bucket.into())
+                    .map(|(_, bucket)| (bucket, &req.sub_aggregation).into())
                     .collect_vec();
 
                 buckets.sort_by(|a, b| {
@@ -112,20 +146,26 @@ impl From<IntermediateBucketResult> for BucketResult {
                 });
                 BucketResult::Range { buckets }
             }
-            IntermediateBucketResult::Histogram { buckets, req } => {
-                let buckets = if req.min_doc_count() == 0 {
+            IntermediateBucketResult::Histogram { buckets } => {
+                let histogram_req = req.as_histogram();
+                let buckets = if histogram_req.min_doc_count() == 0 {
                     // With min_doc_count == 0, we may need to add buckets, so that there are no
                    // gaps, since the intermediate result does not contain empty buckets
                    // (filtered to reduce serialization size).
-                    let fill_gaps_buckets = if buckets.len() > 1 {
-                        // buckets are sorted
+
+                    let (min, max) = if buckets.is_empty() {
+                        (f64::MAX, f64::MIN)
+                    } else {
                         let min = buckets[0].key;
                         let max = buckets[buckets.len() - 1].key;
-                        generate_buckets(&req, min, max)
-                    } else {
-                        vec![]
+                        (min, max)
                     };
+                    let fill_gaps_buckets = generate_buckets(histogram_req, min, max);
+
+                    let sub_aggregation =
+                        IntermediateAggregationResults::empty_from_req(&req.sub_aggregation);
+
                     buckets
                         .into_iter()
                         .merge_join_by(
@@ -138,21 +178,26 @@ impl From<IntermediateBucketResult> for BucketResult {
                             },
                         )
                         .map(|either| match either {
-                            itertools::EitherOrBoth::Both(existing, _) => existing.into(),
-                            itertools::EitherOrBoth::Left(existing) => existing.into(),
+                            itertools::EitherOrBoth::Both(existing, _) => {
+                                (existing, &req.sub_aggregation).into()
+                            }
+                            itertools::EitherOrBoth::Left(existing) => {
+                                (existing, &req.sub_aggregation).into()
+                            }
                             // Add missing bucket
                             itertools::EitherOrBoth::Right(bucket) => BucketEntry {
                                 key: Key::F64(bucket),
                                 doc_count: 0,
-                                sub_aggregation: Default::default(),
+                                sub_aggregation: (sub_aggregation.clone(), &req.sub_aggregation)
+                                    .into(),
                             },
                         })
                         .collect_vec()
                 } else {
                     buckets
                         .into_iter()
-                        .filter(|bucket| bucket.doc_count >= req.min_doc_count())
-                        .map(|bucket| bucket.into())
+                        .filter(|bucket| bucket.doc_count >= histogram_req.min_doc_count())
+                        .map(|bucket| (bucket, &req.sub_aggregation).into())
                         .collect_vec()
                 };
@@ -199,12 +244,14 @@ pub struct BucketEntry {
     pub sub_aggregation: AggregationResults,
 }
 
-impl From<IntermediateHistogramBucketEntry> for BucketEntry {
-    fn from(entry: IntermediateHistogramBucketEntry) -> Self {
+impl From<(IntermediateHistogramBucketEntry, &CollectorAggregations)> for BucketEntry {
+    fn from(entry_and_req: (IntermediateHistogramBucketEntry, &CollectorAggregations)) -> Self {
+        let entry = entry_and_req.0;
+        let req = entry_and_req.1;
         BucketEntry {
             key: Key::F64(entry.key),
             doc_count: entry.doc_count,
-            sub_aggregation: entry.sub_aggregation.into(),
+            sub_aggregation: (entry.sub_aggregation, req).into(),
         }
     }
 }
@@ -256,12 +303,14 @@ pub struct RangeBucketEntry {
     pub to: Option<f64>,
 }
 
-impl From<IntermediateRangeBucketEntry> for RangeBucketEntry {
-    fn from(entry: IntermediateRangeBucketEntry) -> Self {
+impl From<(IntermediateRangeBucketEntry, &CollectorAggregations)> for RangeBucketEntry {
+    fn from(entry_and_req: (IntermediateRangeBucketEntry, &CollectorAggregations)) -> Self {
+        let entry = entry_and_req.0;
+        let req = entry_and_req.1;
         RangeBucketEntry {
             key: entry.key,
             doc_count: entry.doc_count,
-            sub_aggregation: entry.sub_aggregation.into(),
+            sub_aggregation: (entry.sub_aggregation, req).into(),
             to: entry.to,
             from: entry.from,
         }
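The rewritten conversions mean an intermediate tree can no longer become a final `AggregationResults` on its own; the request must travel along so that empty buckets and metrics can be synthesized. A hedged sketch of the effect, assuming an `agg_req: Aggregations` containing a stats metric named "stats" is already in scope:

    use tantivy::aggregation::agg_result::AggregationResults;
    use tantivy::aggregation::intermediate_agg_result::IntermediateAggregationResults;

    // An empty tree, as produced when no segment matched the query.
    let empty_tree = IntermediateAggregationResults::default();
    // The request fills the gaps that the tree does not carry:
    let res: AggregationResults = (empty_tree, agg_req).into();
    // Serialized, this yields {"stats": {"count": 0, "sum": 0.0, "avg": null, ...}}
    // instead of an empty object (or, previously, an upstream error).

Note that the impl taking `Aggregations` by value first converts the request into the internal `CollectorAggregations` form and then delegates to the reference-based impl, which performs the positional zip shown above.
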
diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs
index 8c56dcbf6..c4ee8da08 100644
--- a/src/aggregation/bucket/histogram/histogram.rs
+++ b/src/aggregation/bucket/histogram/histogram.rs
@@ -158,7 +158,7 @@ pub struct SegmentHistogramCollector {
     buckets: Vec<SegmentHistogramBucketEntry>,
     sub_aggregations: Option<Vec<SegmentAggregationResultsCollector>>,
     field_type: Type,
-    req: HistogramAggregation,
+    interval: f64,
     offset: f64,
     first_bucket_num: i64,
     bounds: HistogramBounds,
@@ -195,10 +195,7 @@ impl SegmentHistogramCollector {
             );
         };
 
-        IntermediateBucketResult::Histogram {
-            buckets,
-            req: self.req,
-        }
+        IntermediateBucketResult::Histogram { buckets }
     }
 
     pub(crate) fn from_req_and_validate(
@@ -247,7 +244,7 @@ impl SegmentHistogramCollector {
         Ok(Self {
             buckets,
             field_type,
-            req: req.clone(),
+            interval: req.interval,
             offset: req.offset.unwrap_or(0f64),
             first_bucket_num,
             bounds,
@@ -263,7 +260,7 @@ impl SegmentHistogramCollector {
         force_flush: bool,
     ) {
        let bounds = self.bounds;
-        let interval = self.req.interval;
+        let interval = self.interval;
        let offset = self.offset;
        let first_bucket_num = self.first_bucket_num;
        let get_bucket_num =
@@ -316,12 +313,12 @@ impl SegmentHistogramCollector {
             if !bounds.contains(val) {
                 continue;
             }
-            let bucket_pos = (get_bucket_num_f64(val, self.req.interval, self.offset) as i64
+            let bucket_pos = (get_bucket_num_f64(val, self.interval, self.offset) as i64
                 - self.first_bucket_num) as usize;
 
             debug_assert_eq!(
                 self.buckets[bucket_pos].key,
-                get_bucket_val(val, self.req.interval, self.offset) as f64
+                get_bucket_val(val, self.interval, self.offset) as f64
             );
             self.increment_bucket(bucket_pos, *doc, &bucket_with_accessor.sub_aggregation);
         }
@@ -347,7 +344,7 @@ impl SegmentHistogramCollector {
         if bounds.contains(val) {
             debug_assert_eq!(
                 self.buckets[bucket_pos].key,
-                get_bucket_val(val, self.req.interval, self.offset) as f64
+                get_bucket_val(val, self.interval, self.offset) as f64
             );
             self.increment_bucket(bucket_pos, doc, bucket_with_accessor);
@@ -449,6 +446,10 @@ fn generate_buckets_test() {
     let buckets = generate_buckets(&histogram_req, 0.5, 0.75);
     assert_eq!(buckets, vec![0.5]);
 
+    // no bucket
+    let buckets = generate_buckets(&histogram_req, f64::MAX, f64::MIN);
+    assert_eq!(buckets, vec![] as Vec<f64>);
+
     // With extended_bounds
     let histogram_req = HistogramAggregation {
         field: "dummy".to_string(),
@@ -470,6 +471,10 @@ fn generate_buckets_test() {
     let buckets = generate_buckets(&histogram_req, 0.5, 0.75);
     assert_eq!(buckets, vec![0.0, 2.0, 4.0, 6.0, 8.0, 10.0]);
 
+    // no bucket, but extended_bounds
+    let buckets = generate_buckets(&histogram_req, f64::MAX, f64::MIN);
+    assert_eq!(buckets, vec![0.0, 2.0, 4.0, 6.0, 8.0, 10.0]);
+
     // With invalid extended_bounds
     let histogram_req = HistogramAggregation {
         field: "dummy".to_string(),
@@ -525,8 +530,9 @@ mod tests {
 
     use super::*;
     use crate::aggregation::agg_req::{
-        Aggregation, Aggregations, BucketAggregation, BucketAggregationType,
+        Aggregation, Aggregations, BucketAggregation, BucketAggregationType, MetricAggregation,
     };
+    use crate::aggregation::metric::{AverageAggregation, StatsAggregation};
     use crate::aggregation::tests::{
         get_test_index_2_segments, get_test_index_from_values, get_test_index_with_num_docs,
     };
@@ -536,11 +542,29 @@ mod tests {
     use crate::{Index, Term};
 
     fn exec_request(agg_req: Aggregations, index: &Index) -> crate::Result<Value> {
+        exec_request_with_query(agg_req, index, None)
+    }
+
+    fn exec_request_with_query(
+        agg_req: Aggregations,
+        index: &Index,
+        query: Option<(&str, &str)>,
+    ) -> crate::Result<Value> {
         let collector = AggregationCollector::from_aggs(agg_req);
 
         let reader = index.reader()?;
         let searcher = reader.searcher();
-        let agg_res = searcher.search(&AllQuery, &collector)?;
+        let agg_res = if let Some((field, term)) = query {
+            let text_field = reader.searcher().schema().get_field(field).unwrap();
+
+            let term_query = TermQuery::new(
+                Term::from_field_text(text_field, term),
+                IndexRecordOption::Basic,
+            );
+
+            searcher.search(&term_query, &collector)?
+        } else {
+            searcher.search(&AllQuery, &collector)?
+        };
 
         let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
         Ok(res)
@@ -760,6 +784,113 @@ mod tests {
 
         Ok(())
     }
+
+    #[test]
+    fn histogram_extended_bounds_test_multi_segment() -> crate::Result<()> {
+        histogram_extended_bounds_test_with_opt(false)
+    }
+    #[test]
+    fn histogram_extended_bounds_test_single_segment() -> crate::Result<()> {
+        histogram_extended_bounds_test_with_opt(true)
+    }
+    fn histogram_extended_bounds_test_with_opt(merge_segments: bool) -> crate::Result<()> {
+        let values = vec![5.0];
+        let index = get_test_index_from_values(merge_segments, &values)?;
+
+        let agg_req: Aggregations = vec![(
+            "histogram".to_string(),
+            Aggregation::Bucket(BucketAggregation {
+                bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
+                    field: "score_f64".to_string(),
+                    interval: 1.0,
+                    extended_bounds: Some(HistogramBounds {
+                        min: 2.0,
+                        max: 12.0,
+                    }),
+                    ..Default::default()
+                }),
+                sub_aggregation: Default::default(),
+            }),
+        )]
+        .into_iter()
+        .collect();
+
+        let res = exec_request(agg_req, &index)?;
+
+        assert_eq!(res["histogram"]["buckets"][0]["key"], 2.0);
+        assert_eq!(res["histogram"]["buckets"][0]["doc_count"], 0);
+        assert_eq!(res["histogram"]["buckets"][1]["key"], 3.0);
+        assert_eq!(res["histogram"]["buckets"][1]["doc_count"], 0);
+        assert_eq!(res["histogram"]["buckets"][2]["doc_count"], 0);
+        assert_eq!(res["histogram"]["buckets"][10]["key"], 12.0);
+        assert_eq!(res["histogram"]["buckets"][10]["doc_count"], 0);
+
+        // 2 hits
+        let values = vec![5.0, 5.5];
+        let index = get_test_index_from_values(merge_segments, &values)?;
+
+        let agg_req: Aggregations = vec![(
+            "histogram".to_string(),
+            Aggregation::Bucket(BucketAggregation {
+                bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
+                    field: "score_f64".to_string(),
+                    interval: 1.0,
+                    extended_bounds: Some(HistogramBounds { min: 3.0, max: 6.0 }),
+                    ..Default::default()
+                }),
+                sub_aggregation: Default::default(),
+            }),
+        )]
+        .into_iter()
+        .collect();
+
+        let res = exec_request(agg_req, &index)?;
+
+        assert_eq!(res["histogram"]["buckets"][0]["key"], 3.0);
+        assert_eq!(res["histogram"]["buckets"][0]["doc_count"], 0);
+        assert_eq!(res["histogram"]["buckets"][1]["key"], 4.0);
+        assert_eq!(res["histogram"]["buckets"][1]["doc_count"], 0);
+        assert_eq!(res["histogram"]["buckets"][2]["key"], 5.0);
+        assert_eq!(res["histogram"]["buckets"][2]["doc_count"], 2);
+        assert_eq!(res["histogram"]["buckets"][3]["key"], 6.0);
+        assert_eq!(res["histogram"]["buckets"][3]["doc_count"], 0);
+        assert_eq!(res["histogram"]["buckets"][4], Value::Null);
+
+        // 1 hit outside bounds
+        let values = vec![15.0];
+        let index = get_test_index_from_values(merge_segments, &values)?;
+
+        let agg_req: Aggregations = vec![(
+            "histogram".to_string(),
+            Aggregation::Bucket(BucketAggregation {
+                bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
+                    field: "score_f64".to_string(),
+                    interval: 1.0,
+                    extended_bounds: Some(HistogramBounds { min: 3.0, max: 6.0 }),
+                    hard_bounds: Some(HistogramBounds { min: 3.0, max: 6.0 }),
+                    ..Default::default()
+                }),
+                sub_aggregation: Default::default(),
+            }),
+        )]
+        .into_iter()
+        .collect();
+
+        let res = exec_request(agg_req, &index)?;
+
+        assert_eq!(res["histogram"]["buckets"][0]["key"], 3.0);
+        assert_eq!(res["histogram"]["buckets"][0]["doc_count"], 0);
+        assert_eq!(res["histogram"]["buckets"][1]["key"], 4.0);
+        assert_eq!(res["histogram"]["buckets"][1]["doc_count"], 0);
+        assert_eq!(res["histogram"]["buckets"][2]["key"], 5.0);
assert_eq!(res["histogram"]["buckets"][2]["doc_count"], 0); + assert_eq!(res["histogram"]["buckets"][3]["key"], 6.0); + assert_eq!(res["histogram"]["buckets"][3]["doc_count"], 0); + assert_eq!(res["histogram"]["buckets"][4], Value::Null); + + Ok(()) + } + #[test] fn histogram_hard_bounds_test_multi_segment() -> crate::Result<()> { histogram_hard_bounds_test_with_opt(false) @@ -871,16 +1002,16 @@ mod tests { } #[test] - fn histogram_empty_bucket_behaviour_test_single_segment() -> crate::Result<()> { - histogram_empty_bucket_behaviour_test_with_opt(true) + fn histogram_empty_result_behaviour_test_single_segment() -> crate::Result<()> { + histogram_empty_result_behaviour_test_with_opt(true) } #[test] - fn histogram_empty_bucket_behaviour_test_multi_segment() -> crate::Result<()> { - histogram_empty_bucket_behaviour_test_with_opt(false) + fn histogram_empty_result_behaviour_test_multi_segment() -> crate::Result<()> { + histogram_empty_result_behaviour_test_with_opt(false) } - fn histogram_empty_bucket_behaviour_test_with_opt(merge_segments: bool) -> crate::Result<()> { + fn histogram_empty_result_behaviour_test_with_opt(merge_segments: bool) -> crate::Result<()> { let index = get_test_index_2_segments(merge_segments)?; let agg_req: Aggregations = vec![( @@ -897,30 +1028,130 @@ mod tests { .into_iter() .collect(); - // let res = exec_request(agg_req, &index)?; + let res = exec_request_with_query(agg_req.clone(), &index, Some(("text", "blubberasdf")))?; - let reader = index.reader()?; - let text_field = reader.searcher().schema().get_field("text").unwrap(); - - let term_query = TermQuery::new( - Term::from_field_text(text_field, "nohit"), - IndexRecordOption::Basic, + assert_eq!( + res, + json!({ + "histogram": { + "buckets": [] + } + }) ); - let collector = AggregationCollector::from_aggs(agg_req); + // test index without segments + let values = vec![]; - let searcher = reader.searcher(); - let agg_res = searcher.search(&term_query, &collector).unwrap(); + // Don't merge empty segments + let index = get_test_index_from_values(false, &values)?; - let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?; + let res = exec_request_with_query(agg_req, &index, Some(("text", "blubberasdf")))?; - assert_eq!(res["histogram"]["buckets"][0]["key"], 6.0); - assert_eq!(res["histogram"]["buckets"][0]["doc_count"], 1); - assert_eq!(res["histogram"]["buckets"][37]["key"], 43.0); - assert_eq!(res["histogram"]["buckets"][37]["doc_count"], 0); - assert_eq!(res["histogram"]["buckets"][38]["key"], 44.0); - assert_eq!(res["histogram"]["buckets"][38]["doc_count"], 1); - assert_eq!(res["histogram"]["buckets"][39], Value::Null); + assert_eq!( + res, + json!({ + "histogram": { + "buckets": [] + } + }) + ); + + // test index without segments + let values = vec![]; + + // Don't merge empty segments + let index = get_test_index_from_values(false, &values)?; + + let agg_req: Aggregations = vec![( + "histogram".to_string(), + Aggregation::Bucket(BucketAggregation { + bucket_agg: BucketAggregationType::Histogram(HistogramAggregation { + field: "score_f64".to_string(), + interval: 1.0, + extended_bounds: Some(HistogramBounds { + min: 2.0, + max: 12.0, + }), + ..Default::default() + }), + sub_aggregation: Default::default(), + }), + )] + .into_iter() + .collect(); + + let res = exec_request(agg_req, &index)?; + + assert_eq!(res["histogram"]["buckets"][0]["key"], 2.0); + assert_eq!(res["histogram"]["buckets"][0]["doc_count"], 0); + assert_eq!(res["histogram"]["buckets"][1]["key"], 3.0); + 
assert_eq!(res["histogram"]["buckets"][1]["doc_count"], 0); + assert_eq!(res["histogram"]["buckets"][2]["doc_count"], 0); + assert_eq!(res["histogram"]["buckets"][10]["key"], 12.0); + assert_eq!(res["histogram"]["buckets"][10]["doc_count"], 0); + + let agg_req: Aggregations = vec![ + ( + "stats".to_string(), + Aggregation::Metric(MetricAggregation::Stats(StatsAggregation { + field: "score_f64".to_string(), + })), + ), + ( + "avg".to_string(), + Aggregation::Metric(MetricAggregation::Average(AverageAggregation { + field: "score_f64".to_string(), + })), + ), + ] + .into_iter() + .collect(); + + let agg_req: Aggregations = vec![( + "histogram".to_string(), + Aggregation::Bucket(BucketAggregation { + bucket_agg: BucketAggregationType::Histogram(HistogramAggregation { + field: "score_f64".to_string(), + interval: 1.0, + extended_bounds: Some(HistogramBounds { + min: 2.0, + max: 12.0, + }), + ..Default::default() + }), + sub_aggregation: agg_req, + }), + )] + .into_iter() + .collect(); + + let res = exec_request(agg_req, &index)?; + + assert_eq!( + res["histogram"]["buckets"][0], + json!({ + "avg": { + "value": Value::Null + }, + "doc_count": 0, + "key": 2.0, + "stats": { + "sum": 0.0, + "count": 0, + "min": Value::Null, + "max": Value::Null, + "avg": Value::Null, + "standard_deviation": Value::Null, + } + }) + ); + assert_eq!(res["histogram"]["buckets"][0]["key"], 2.0); + assert_eq!(res["histogram"]["buckets"][0]["doc_count"], 0); + assert_eq!(res["histogram"]["buckets"][1]["key"], 3.0); + assert_eq!(res["histogram"]["buckets"][1]["doc_count"], 0); + assert_eq!(res["histogram"]["buckets"][2]["doc_count"], 0); + assert_eq!(res["histogram"]["buckets"][10]["key"], 12.0); + assert_eq!(res["histogram"]["buckets"][10]["doc_count"], 0); Ok(()) } diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs index 36c74c38e..0e8b06217 100644 --- a/src/aggregation/collector.rs +++ b/src/aggregation/collector.rs @@ -5,7 +5,7 @@ use super::intermediate_agg_result::IntermediateAggregationResults; use super::segment_agg_result::SegmentAggregationResultsCollector; use crate::aggregation::agg_req_with_accessor::get_aggs_with_accessor_and_validate; use crate::collector::{Collector, SegmentCollector}; -use crate::{SegmentReader, TantivyError}; +use crate::SegmentReader; /// Collector for aggregations. 
diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs
index 36c74c38e..0e8b06217 100644
--- a/src/aggregation/collector.rs
+++ b/src/aggregation/collector.rs
@@ -5,7 +5,7 @@ use super::intermediate_agg_result::IntermediateAggregationResults;
 use super::segment_agg_result::SegmentAggregationResultsCollector;
 use crate::aggregation::agg_req_with_accessor::get_aggs_with_accessor_and_validate;
 use crate::collector::{Collector, SegmentCollector};
-use crate::{SegmentReader, TantivyError};
+use crate::SegmentReader;
 
 /// Collector for aggregations.
 ///
@@ -86,7 +86,7 @@ impl Collector for AggregationCollector {
         &self,
         segment_fruits: Vec<<Self::Child as SegmentCollector>::Fruit>,
     ) -> crate::Result<Self::Fruit> {
-        merge_fruits(segment_fruits).map(|res| res.into())
+        merge_fruits(segment_fruits).map(|res| (res, self.agg.clone()).into())
     }
 }
 
@@ -99,9 +99,7 @@ fn merge_fruits(
         }
         Ok(fruit)
     } else {
-        Err(TantivyError::InvalidArgument(
-            "no fruits provided in merge_fruits".to_string(),
-        ))
+        Ok(IntermediateAggregationResults::default())
     }
 }
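This `merge_fruits` change is the heart of the empty-result fix: a search that touches zero segments hands the collector an empty fruit list, which previously surfaced as an `InvalidArgument` error. A sketch of the whole function consistent with the context lines above (the loop body is reconstructed, not quoted from the file):

    fn merge_fruits(
        mut segment_fruits: Vec<IntermediateAggregationResults>,
    ) -> crate::Result<IntermediateAggregationResults> {
        if let Some(mut fruit) = segment_fruits.pop() {
            for next_fruit in segment_fruits {
                fruit.merge_fruits(next_fruit);
            }
            Ok(fruit)
        } else {
            // Zero segments searched: an empty default tree is now valid input,
            // because the final, request-aware conversion fills in the blanks.
            Ok(IntermediateAggregationResults::default())
        }
    }
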
diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs
index a5d36d2e9..023ff43c7 100644
--- a/src/aggregation/intermediate_agg_result.rs
+++ b/src/aggregation/intermediate_agg_result.rs
@@ -8,7 +8,7 @@ use fnv::FnvHashMap;
 use itertools::Itertools;
 use serde::{Deserialize, Serialize};
 
-use super::bucket::HistogramAggregation;
+use super::agg_req::{BucketAggregationType, CollectorAggregations, MetricAggregation};
 use super::metric::{IntermediateAverage, IntermediateStats};
 use super::segment_agg_result::{
     SegmentAggregationResultsCollector, SegmentBucketResultCollector, SegmentHistogramBucketEntry,
@@ -34,6 +34,42 @@ impl From<SegmentAggregationResultsCollector> for IntermediateAggregationResults
 }
 
 impl IntermediateAggregationResults {
+    pub(crate) fn empty_from_req(req: &CollectorAggregations) -> Self {
+        let metrics = if req.metrics.is_empty() {
+            None
+        } else {
+            let metrics = req
+                .metrics
+                .iter()
+                .map(|(key, req)| {
+                    (
+                        key.to_string(),
+                        IntermediateMetricResult::empty_from_req(req),
+                    )
+                })
+                .collect();
+            Some(VecWithNames::from_entries(metrics))
+        };
+
+        let buckets = if req.buckets.is_empty() {
+            None
+        } else {
+            let buckets = req
+                .buckets
+                .iter()
+                .map(|(key, req)| {
+                    (
+                        key.to_string(),
+                        IntermediateBucketResult::empty_from_req(&req.bucket_agg),
+                    )
+                })
+                .collect();
+            Some(VecWithNames::from_entries(buckets))
+        };
+
+        Self { metrics, buckets }
+    }
+
     /// Merge another intermediate aggregation result into this result.
     ///
     /// The order of the values needs to be the same on both results. This is ensured when the same
@@ -89,6 +125,16 @@ impl From<SegmentMetricResultCollector> for IntermediateMetricResult {
 }
 
 impl IntermediateMetricResult {
+    pub(crate) fn empty_from_req(req: &MetricAggregation) -> Self {
+        match req {
+            MetricAggregation::Average(_) => {
+                IntermediateMetricResult::Average(IntermediateAverage::default())
+            }
+            MetricAggregation::Stats(_) => {
+                IntermediateMetricResult::Stats(IntermediateStats::default())
+            }
+        }
+    }
     fn merge_fruits(&mut self, other: IntermediateMetricResult) {
         match (self, other) {
             (
@@ -122,9 +168,6 @@ pub enum IntermediateBucketResult {
     Histogram {
         /// The buckets
         buckets: Vec<IntermediateHistogramBucketEntry>,
-        /// The original request. It is used to compute the total range after merging segments and
-        /// get min_doc_count after merging all segment results.
-        req: HistogramAggregation,
     },
 }
 
@@ -140,6 +183,14 @@ impl From<SegmentBucketResultCollector> for IntermediateBucketResult {
 }
 
 impl IntermediateBucketResult {
+    pub(crate) fn empty_from_req(req: &BucketAggregationType) -> Self {
+        match req {
+            BucketAggregationType::Range(_) => IntermediateBucketResult::Range(Default::default()),
+            BucketAggregationType::Histogram(_) => {
+                IntermediateBucketResult::Histogram { buckets: vec![] }
+            }
+        }
+    }
     fn merge_fruits(&mut self, other: IntermediateBucketResult) {
         match (self, other) {
             (
@@ -332,7 +383,9 @@ mod tests {
         }
     }
 
-    fn get_test_tree(data: &[(String, u64, String, u64)]) -> IntermediateAggregationResults {
+    fn get_intermediat_tree_with_ranges(
+        data: &[(String, u64, String, u64)],
+    ) -> IntermediateAggregationResults {
         let mut map = HashMap::new();
         let mut buckets: FnvHashMap<_, _> = Default::default();
         for (key, doc_count, sub_aggregation_key, sub_aggregation_count) in data {
@@ -363,18 +416,18 @@ mod tests {
 
     #[test]
     fn test_merge_fruits_tree_1() {
-        let mut tree_left = get_test_tree(&[
+        let mut tree_left = get_intermediat_tree_with_ranges(&[
             ("red".to_string(), 50, "1900".to_string(), 25),
             ("blue".to_string(), 30, "1900".to_string(), 30),
         ]);
-        let tree_right = get_test_tree(&[
+        let tree_right = get_intermediat_tree_with_ranges(&[
             ("red".to_string(), 60, "1900".to_string(), 30),
             ("blue".to_string(), 25, "1900".to_string(), 50),
         ]);
 
         tree_left.merge_fruits(tree_right);
 
-        let tree_expected = get_test_tree(&[
+        let tree_expected = get_intermediat_tree_with_ranges(&[
             ("red".to_string(), 110, "1900".to_string(), 55),
             ("blue".to_string(), 55, "1900".to_string(), 80),
         ]);
@@ -384,18 +437,18 @@ mod tests {
 
     #[test]
     fn test_merge_fruits_tree_2() {
-        let mut tree_left = get_test_tree(&[
+        let mut tree_left = get_intermediat_tree_with_ranges(&[
             ("red".to_string(), 50, "1900".to_string(), 25),
             ("blue".to_string(), 30, "1900".to_string(), 30),
         ]);
-        let tree_right = get_test_tree(&[
+        let tree_right = get_intermediat_tree_with_ranges(&[
             ("red".to_string(), 60, "1900".to_string(), 30),
             ("green".to_string(), 25, "1900".to_string(), 50),
         ]);
 
         tree_left.merge_fruits(tree_right);
 
-        let tree_expected = get_test_tree(&[
+        let tree_expected = get_intermediat_tree_with_ranges(&[
             ("red".to_string(), 110, "1900".to_string(), 55),
             ("blue".to_string(), 30, "1900".to_string(), 30),
             ("green".to_string(), 25, "1900".to_string(), 50),
@@ -403,4 +456,18 @@ mod tests {
 
         assert_eq!(tree_left, tree_expected);
     }
+
+    #[test]
+    fn test_merge_fruits_tree_empty() {
+        let mut tree_left = get_intermediat_tree_with_ranges(&[
+            ("red".to_string(), 50, "1900".to_string(), 25),
+            ("blue".to_string(), 30, "1900".to_string(), 30),
+        ]);
+
+        let orig = tree_left.clone();
+
+        tree_left.merge_fruits(IntermediateAggregationResults::default());
+
+        assert_eq!(tree_left, orig);
+    }
 }
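The three `empty_from_req` constructors compose recursively, mirroring the request tree: empty results for sub-aggregations are built from the corresponding sub-requests. A crate-internal sketch, assuming a `collector_aggs: CollectorAggregations` converted from a request containing a single stats metric:

    let empty = IntermediateAggregationResults::empty_from_req(&collector_aggs);
    // metrics: Some([("stats", IntermediateMetricResult::Stats(count: 0, ...))])
    // buckets: None
    //
    // Two properties the tests above rely on:
    // 1. merging `empty` into any real tree is a no-op (test_merge_fruits_tree_empty);
    // 2. converting `empty` together with the request yields zero-count buckets and
    //    null metrics rather than missing JSON keys.
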
diff --git a/src/aggregation/metric/average.rs b/src/aggregation/metric/average.rs
index 7a12c7ef9..69f353824 100644
--- a/src/aggregation/metric/average.rs
+++ b/src/aggregation/metric/average.rs
@@ -20,7 +20,7 @@ use crate::DocId;
 /// "field": "score",
 /// }
 /// }
-/// ``` 
+/// ```
 pub struct AverageAggregation {
     /// The field name to compute the average on.
     pub field: String,
diff --git a/src/aggregation/metric/stats.rs b/src/aggregation/metric/stats.rs
index daf5c0400..d4c95c09b 100644
--- a/src/aggregation/metric/stats.rs
+++ b/src/aggregation/metric/stats.rs
@@ -17,7 +17,7 @@ use crate::DocId;
 /// "field": "score",
 /// }
 /// }
-/// ``` 
+/// ```
 #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
 pub struct StatsAggregation {
@@ -62,9 +62,8 @@ pub struct IntermediateStats {
     min: f64,
     max: f64,
 }
-
-impl IntermediateStats {
-    fn new() -> Self {
+impl Default for IntermediateStats {
+    fn default() -> Self {
         Self {
             count: 0,
             sum: 0.0,
@@ -73,7 +72,9 @@ impl IntermediateStats {
             max: f64::MIN,
         }
     }
+}
 
+impl IntermediateStats {
     pub(crate) fn avg(&self) -> Option<f64> {
         if self.count == 0 {
             None
@@ -142,7 +143,7 @@ impl SegmentStatsCollector {
     pub fn from_req(field_type: Type) -> Self {
         Self {
             field_type,
-            stats: IntermediateStats::new(),
+            stats: IntermediateStats::default(),
         }
     }
 
     pub(crate) fn collect_block(&mut self, doc: &[DocId], field: &DynamicFastFieldReader<u64>) {
@@ -182,12 +183,50 @@ mod tests {
     };
     use crate::aggregation::agg_result::AggregationResults;
     use crate::aggregation::metric::StatsAggregation;
-    use crate::aggregation::tests::get_test_index_2_segments;
+    use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values};
     use crate::aggregation::AggregationCollector;
-    use crate::query::TermQuery;
+    use crate::query::{AllQuery, TermQuery};
     use crate::schema::IndexRecordOption;
     use crate::Term;
 
+    #[test]
+    fn test_aggregation_stats_empty_index() -> crate::Result<()> {
+        // test index without segments
+        let values = vec![];
+
+        let index = get_test_index_from_values(false, &values)?;
+
+        let agg_req_1: Aggregations = vec![(
+            "stats".to_string(),
+            Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name(
+                "score".to_string(),
+            ))),
+        )]
+        .into_iter()
+        .collect();
+
+        let collector = AggregationCollector::from_aggs(agg_req_1);
+
+        let reader = index.reader()?;
+        let searcher = reader.searcher();
+        let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();
+
+        let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
+        assert_eq!(
+            res["stats"],
+            json!({
+                "avg": Value::Null,
+                "count": 0,
+                "max": Value::Null,
+                "min": Value::Null,
+                "standard_deviation": Value::Null,
+                "sum": 0.0
+            })
+        );
+
+        Ok(())
+    }
+
     #[test]
     fn test_aggregation_stats() -> crate::Result<()> {
         let index = get_test_index_2_segments(false)?;
diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs
index b91e96cad..a76095ae1 100644
--- a/src/aggregation/mod.rs
+++ b/src/aggregation/mod.rs
@@ -453,10 +453,10 @@ mod tests {
             .unwrap();
 
         let agg_res: AggregationResults = if use_distributed_collector {
-            let collector = DistributedAggregationCollector::from_aggs(agg_req);
+            let collector = DistributedAggregationCollector::from_aggs(agg_req.clone());
 
             let searcher = reader.searcher();
-            searcher.search(&term_query, &collector).unwrap().into()
+            (searcher.search(&term_query, &collector).unwrap(), agg_req).into()
         } else {
             let collector = AggregationCollector::from_aggs(agg_req);
 
@@ -835,7 +835,7 @@ mod tests {
             // Test de/serialization roundtrip on intermediate_agg_result
             let res: IntermediateAggregationResults =
                 serde_json::from_str(&serde_json::to_string(&res).unwrap()).unwrap();
-            res.into()
+            (res, agg_req.clone()).into()
         } else {
             let collector = AggregationCollector::from_aggs(agg_req.clone());
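End-to-end, the distributed path exercised by the mod.rs test now looks roughly like this. A condensed sketch, assuming `index`, `term_query`, and `agg_req: Aggregations` are in scope and the modules are publicly exported as in this crate; the serde roundtrip from the test is elided:

    use tantivy::aggregation::agg_result::AggregationResults;
    use tantivy::aggregation::intermediate_agg_result::IntermediateAggregationResults;
    use tantivy::aggregation::DistributedAggregationCollector;

    let collector = DistributedAggregationCollector::from_aggs(agg_req.clone());
    let searcher = index.reader()?.searcher();
    // Each shard/segment produces an intermediate tree; with no matching
    // segments this is now simply the default (empty) tree instead of an error.
    let intermediate: IntermediateAggregationResults =
        searcher.search(&term_query, &collector)?;
    // Only the final conversion needs the request, in order to fill in
    // empty buckets and metrics that no segment contributed.
    let agg_res: AggregationResults = (intermediate, agg_req).into();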