handle empty results, empty indices, add tests

This commit is contained in:
Pascal Seitz
2022-03-17 10:01:24 +08:00
parent 691245bf20
commit 47dcbdbeae
8 changed files with 531 additions and 98 deletions

View File

@@ -51,6 +51,7 @@ use serde::{Deserialize, Serialize};
use super::bucket::HistogramAggregation;
pub use super::bucket::RangeAggregation;
use super::metric::{AverageAggregation, StatsAggregation};
use super::VecWithNames;
/// The top-level aggregation request structure, which contains [Aggregation] and their user defined
/// names. It is also used in [buckets](BucketAggregation) to define sub-aggregations.
@@ -58,6 +59,54 @@ use super::metric::{AverageAggregation, StatsAggregation};
/// The key is the user defined name of the aggregation.
pub type Aggregations = HashMap<String, Aggregation>;
/// Like Aggregations, but optimized to work with the aggregation result
///
/// Instead of a single map keyed by the user defined name, the metric and the bucket
/// aggregations of one request level are stored in two separate [`VecWithNames`] collections.
#[derive(Clone, Debug)]
pub(crate) struct CollectorAggregations {
/// The metric aggregations of this level together with their user defined names.
pub(crate) metrics: VecWithNames<MetricAggregation>,
/// The bucket aggregations of this level together with their user defined names.
pub(crate) buckets: VecWithNames<CollectorBucketAggregation>,
}
impl From<Aggregations> for CollectorAggregations {
fn from(aggs: Aggregations) -> Self {
let mut metrics = vec![];
let mut buckets = vec![];
for (key, agg) in aggs {
match agg {
Aggregation::Bucket(bucket) => buckets.push((
key,
CollectorBucketAggregation {
bucket_agg: bucket.bucket_agg,
sub_aggregation: bucket.sub_aggregation.into(),
},
)),
Aggregation::Metric(metric) => metrics.push((key, metric)),
}
}
Self {
metrics: VecWithNames::from_entries(metrics),
buckets: VecWithNames::from_entries(buckets),
}
}
}
/// A bucket aggregation request whose sub-aggregations are stored as
/// [`CollectorAggregations`].
#[derive(Clone, Debug)]
pub(crate) struct CollectorBucketAggregation {
/// Bucket aggregation strategy to group documents.
pub bucket_agg: BucketAggregationType,
/// The sub_aggregations in the buckets. Each bucket will aggregate on the document set in the
/// bucket.
pub sub_aggregation: CollectorAggregations,
}
impl CollectorBucketAggregation {
    /// Returns the histogram request wrapped by this bucket aggregation.
    ///
    /// # Panics
    ///
    /// Panics if the bucket aggregation is a range aggregation.
    pub(crate) fn as_histogram(&self) -> &HistogramAggregation {
        match self.bucket_agg {
            BucketAggregationType::Histogram(ref histogram_req) => histogram_req,
            BucketAggregationType::Range(_) => panic!("unexpected aggregation"),
        }
    }
}
/// Extract all fast field names used in the tree.
pub fn get_fast_field_names(aggs: &Aggregations) -> HashSet<String> {
let mut fast_field_names = Default::default();

View File

@@ -10,6 +10,7 @@ use std::collections::HashMap;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use super::agg_req::{Aggregations, CollectorAggregations, CollectorBucketAggregation};
use super::bucket::generate_buckets;
use super::intermediate_agg_result::{
IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry,
@@ -22,21 +23,52 @@ use super::Key;
/// The final aggregation result.
pub struct AggregationResults(pub HashMap<String, AggregationResult>);
impl From<IntermediateAggregationResults> for AggregationResults {
fn from(tree: IntermediateAggregationResults) -> Self {
Self(
tree.buckets
.unwrap_or_default()
.into_iter()
.map(|(key, bucket)| (key, AggregationResult::BucketResult(bucket.into())))
.chain(
tree.metrics
.unwrap_or_default()
.into_iter()
.map(|(key, metric)| (key, AggregationResult::MetricResult(metric.into()))),
impl From<(IntermediateAggregationResults, Aggregations)> for AggregationResults {
fn from(tree_and_req: (IntermediateAggregationResults, Aggregations)) -> Self {
let agg: CollectorAggregations = tree_and_req.1.into();
(tree_and_req.0, &agg).into()
}
}
impl From<(IntermediateAggregationResults, &CollectorAggregations)> for AggregationResults {
fn from(data: (IntermediateAggregationResults, &CollectorAggregations)) -> Self {
let tree = data.0;
let req = data.1;
let mut result = HashMap::default();
// Important assumption:
// When the tree contains buckets/metric, we expect it to have all buckets/metrics from the
// request
if let Some(buckets) = tree.buckets {
result.extend(buckets.into_iter().zip(req.buckets.values()).map(
|((key, bucket), req)| (key, AggregationResult::BucketResult((bucket, req).into())),
));
} else {
result.extend(req.buckets.iter().map(|(key, req)| {
let empty_bucket = IntermediateBucketResult::empty_from_req(&req.bucket_agg);
(
key.to_string(),
AggregationResult::BucketResult((empty_bucket, req).into()),
)
.collect(),
)
}));
}
if let Some(metrics) = tree.metrics {
result.extend(
metrics
.into_iter()
.map(|(key, metric)| (key, AggregationResult::MetricResult(metric.into()))),
);
} else {
result.extend(req.metrics.iter().map(|(key, req)| {
let empty_bucket = IntermediateMetricResult::empty_from_req(req);
(
key.to_string(),
AggregationResult::MetricResult(empty_bucket.into()),
)
}));
}
Self(result)
}
}
@@ -95,13 +127,15 @@ pub enum BucketResult {
},
}
impl From<IntermediateBucketResult> for BucketResult {
fn from(result: IntermediateBucketResult) -> Self {
match result {
impl From<(IntermediateBucketResult, &CollectorBucketAggregation)> for BucketResult {
fn from(result_and_req: (IntermediateBucketResult, &CollectorBucketAggregation)) -> Self {
let bucket_result = result_and_req.0;
let req = result_and_req.1;
match bucket_result {
IntermediateBucketResult::Range(range_map) => {
let mut buckets: Vec<RangeBucketEntry> = range_map
.into_iter()
.map(|(_, bucket)| bucket.into())
.map(|(_, bucket)| (bucket, &req.sub_aggregation).into())
.collect_vec();
buckets.sort_by(|a, b| {
@@ -112,20 +146,26 @@ impl From<IntermediateBucketResult> for BucketResult {
});
BucketResult::Range { buckets }
}
IntermediateBucketResult::Histogram { buckets, req } => {
let buckets = if req.min_doc_count() == 0 {
IntermediateBucketResult::Histogram { buckets } => {
let histogram_req = req.as_histogram();
let buckets = if histogram_req.min_doc_count() == 0 {
// With min_doc_count == 0, we may need to add buckets, so that there are no
// gaps, since intermediate result does not contain empty buckets (filtered to
// reduce serialization size).
let fill_gaps_buckets = if buckets.len() > 1 {
// buckets are sorted
let (min, max) = if buckets.is_empty() {
(f64::MAX, f64::MIN)
} else {
let min = buckets[0].key;
let max = buckets[buckets.len() - 1].key;
generate_buckets(&req, min, max)
} else {
vec![]
(min, max)
};
let fill_gaps_buckets = generate_buckets(histogram_req, min, max);
let sub_aggregation =
IntermediateAggregationResults::empty_from_req(&req.sub_aggregation);
buckets
.into_iter()
.merge_join_by(
@@ -138,21 +178,26 @@ impl From<IntermediateBucketResult> for BucketResult {
},
)
.map(|either| match either {
itertools::EitherOrBoth::Both(existing, _) => existing.into(),
itertools::EitherOrBoth::Left(existing) => existing.into(),
itertools::EitherOrBoth::Both(existing, _) => {
(existing, &req.sub_aggregation).into()
}
itertools::EitherOrBoth::Left(existing) => {
(existing, &req.sub_aggregation).into()
}
// Add missing bucket
itertools::EitherOrBoth::Right(bucket) => BucketEntry {
key: Key::F64(bucket),
doc_count: 0,
sub_aggregation: Default::default(),
sub_aggregation: (sub_aggregation.clone(), &req.sub_aggregation)
.into(),
},
})
.collect_vec()
} else {
buckets
.into_iter()
.filter(|bucket| bucket.doc_count >= req.min_doc_count())
.map(|bucket| bucket.into())
.filter(|bucket| bucket.doc_count >= histogram_req.min_doc_count())
.map(|bucket| (bucket, &req.sub_aggregation).into())
.collect_vec()
};
@@ -199,12 +244,14 @@ pub struct BucketEntry {
pub sub_aggregation: AggregationResults,
}
impl From<IntermediateHistogramBucketEntry> for BucketEntry {
fn from(entry: IntermediateHistogramBucketEntry) -> Self {
impl From<(IntermediateHistogramBucketEntry, &CollectorAggregations)> for BucketEntry {
fn from(entry_and_req: (IntermediateHistogramBucketEntry, &CollectorAggregations)) -> Self {
let entry = entry_and_req.0;
let req = entry_and_req.1;
BucketEntry {
key: Key::F64(entry.key),
doc_count: entry.doc_count,
sub_aggregation: entry.sub_aggregation.into(),
sub_aggregation: (entry.sub_aggregation, req).into(),
}
}
}
@@ -256,12 +303,14 @@ pub struct RangeBucketEntry {
pub to: Option<f64>,
}
impl From<IntermediateRangeBucketEntry> for RangeBucketEntry {
fn from(entry: IntermediateRangeBucketEntry) -> Self {
impl From<(IntermediateRangeBucketEntry, &CollectorAggregations)> for RangeBucketEntry {
fn from(entry_and_req: (IntermediateRangeBucketEntry, &CollectorAggregations)) -> Self {
let entry = entry_and_req.0;
let req = entry_and_req.1;
RangeBucketEntry {
key: entry.key,
doc_count: entry.doc_count,
sub_aggregation: entry.sub_aggregation.into(),
sub_aggregation: (entry.sub_aggregation, req).into(),
to: entry.to,
from: entry.from,
}

View File

@@ -158,7 +158,7 @@ pub struct SegmentHistogramCollector {
buckets: Vec<SegmentHistogramBucketEntry>,
sub_aggregations: Option<Vec<SegmentAggregationResultsCollector>>,
field_type: Type,
req: HistogramAggregation,
interval: f64,
offset: f64,
first_bucket_num: i64,
bounds: HistogramBounds,
@@ -195,10 +195,7 @@ impl SegmentHistogramCollector {
);
};
IntermediateBucketResult::Histogram {
buckets,
req: self.req,
}
IntermediateBucketResult::Histogram { buckets }
}
pub(crate) fn from_req_and_validate(
@@ -247,7 +244,7 @@ impl SegmentHistogramCollector {
Ok(Self {
buckets,
field_type,
req: req.clone(),
interval: req.interval,
offset: req.offset.unwrap_or(0f64),
first_bucket_num,
bounds,
@@ -263,7 +260,7 @@ impl SegmentHistogramCollector {
force_flush: bool,
) {
let bounds = self.bounds;
let interval = self.req.interval;
let interval = self.interval;
let offset = self.offset;
let first_bucket_num = self.first_bucket_num;
let get_bucket_num =
@@ -316,12 +313,12 @@ impl SegmentHistogramCollector {
if !bounds.contains(val) {
continue;
}
let bucket_pos = (get_bucket_num_f64(val, self.req.interval, self.offset) as i64
let bucket_pos = (get_bucket_num_f64(val, self.interval, self.offset) as i64
- self.first_bucket_num) as usize;
debug_assert_eq!(
self.buckets[bucket_pos].key,
get_bucket_val(val, self.req.interval, self.offset) as f64
get_bucket_val(val, self.interval, self.offset) as f64
);
self.increment_bucket(bucket_pos, *doc, &bucket_with_accessor.sub_aggregation);
}
@@ -347,7 +344,7 @@ impl SegmentHistogramCollector {
if bounds.contains(val) {
debug_assert_eq!(
self.buckets[bucket_pos].key,
get_bucket_val(val, self.req.interval, self.offset) as f64
get_bucket_val(val, self.interval, self.offset) as f64
);
self.increment_bucket(bucket_pos, doc, bucket_with_accessor);
@@ -449,6 +446,10 @@ fn generate_buckets_test() {
let buckets = generate_buckets(&histogram_req, 0.5, 0.75);
assert_eq!(buckets, vec![0.5]);
// no bucket
let buckets = generate_buckets(&histogram_req, f64::MAX, f64::MIN);
assert_eq!(buckets, vec![] as Vec<f64>);
// With extended_bounds
let histogram_req = HistogramAggregation {
field: "dummy".to_string(),
@@ -470,6 +471,10 @@ fn generate_buckets_test() {
let buckets = generate_buckets(&histogram_req, 0.5, 0.75);
assert_eq!(buckets, vec![0.0, 2.0, 4.0, 6.0, 8.0, 10.0]);
// no bucket, but extended_bounds
let buckets = generate_buckets(&histogram_req, f64::MAX, f64::MIN);
assert_eq!(buckets, vec![0.0, 2.0, 4.0, 6.0, 8.0, 10.0]);
// With invalid extended_bounds
let histogram_req = HistogramAggregation {
field: "dummy".to_string(),
@@ -525,8 +530,9 @@ mod tests {
use super::*;
use crate::aggregation::agg_req::{
Aggregation, Aggregations, BucketAggregation, BucketAggregationType,
Aggregation, Aggregations, BucketAggregation, BucketAggregationType, MetricAggregation,
};
use crate::aggregation::metric::{AverageAggregation, StatsAggregation};
use crate::aggregation::tests::{
get_test_index_2_segments, get_test_index_from_values, get_test_index_with_num_docs,
};
@@ -536,11 +542,29 @@ mod tests {
use crate::{Index, Term};
/// Executes `agg_req` on `index` without a restricting query by delegating to
/// `exec_request_with_query` with `None`.
fn exec_request(agg_req: Aggregations, index: &Index) -> crate::Result<Value> {
exec_request_with_query(agg_req, index, None)
}
fn exec_request_with_query(
agg_req: Aggregations,
index: &Index,
query: Option<(&str, &str)>,
) -> crate::Result<Value> {
let collector = AggregationCollector::from_aggs(agg_req);
let reader = index.reader()?;
let searcher = reader.searcher();
let agg_res = searcher.search(&AllQuery, &collector)?;
let agg_res = if let Some((field, term)) = query {
let text_field = reader.searcher().schema().get_field(field).unwrap();
let term_query = TermQuery::new(
Term::from_field_text(text_field, term),
IndexRecordOption::Basic,
);
searcher.search(&term_query, &collector)?
} else {
searcher.search(&AllQuery, &collector)?
};
let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
Ok(res)
@@ -760,6 +784,113 @@ mod tests {
Ok(())
}
/// Runs the extended bounds histogram checks with `merge_segments = false`.
#[test]
fn histogram_extended_bounds_test_multi_segment() -> crate::Result<()> {
histogram_extended_bounds_test_with_opt(false)
}
/// Runs the extended bounds histogram checks with `merge_segments = true`.
#[test]
fn histogram_extended_bounds_test_single_segment() -> crate::Result<()> {
histogram_extended_bounds_test_with_opt(true)
}
fn histogram_extended_bounds_test_with_opt(merge_segments: bool) -> crate::Result<()> {
let values = vec![5.0];
let index = get_test_index_from_values(merge_segments, &values)?;
let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
extended_bounds: Some(HistogramBounds {
min: 2.0,
max: 12.0,
}),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
)]
.into_iter()
.collect();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["histogram"]["buckets"][0]["key"], 2.0);
assert_eq!(res["histogram"]["buckets"][0]["doc_count"], 0);
assert_eq!(res["histogram"]["buckets"][1]["key"], 3.0);
assert_eq!(res["histogram"]["buckets"][1]["doc_count"], 0);
assert_eq!(res["histogram"]["buckets"][2]["doc_count"], 0);
assert_eq!(res["histogram"]["buckets"][10]["key"], 12.0);
assert_eq!(res["histogram"]["buckets"][10]["doc_count"], 0);
// 2 hits
let values = vec![5.0, 5.5];
let index = get_test_index_from_values(merge_segments, &values)?;
let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
extended_bounds: Some(HistogramBounds { min: 3.0, max: 6.0 }),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
)]
.into_iter()
.collect();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["histogram"]["buckets"][0]["key"], 3.0);
assert_eq!(res["histogram"]["buckets"][0]["doc_count"], 0);
assert_eq!(res["histogram"]["buckets"][1]["key"], 4.0);
assert_eq!(res["histogram"]["buckets"][1]["doc_count"], 0);
assert_eq!(res["histogram"]["buckets"][2]["key"], 5.0);
assert_eq!(res["histogram"]["buckets"][2]["doc_count"], 2);
assert_eq!(res["histogram"]["buckets"][3]["key"], 6.0);
assert_eq!(res["histogram"]["buckets"][3]["doc_count"], 0);
assert_eq!(res["histogram"]["buckets"][4], Value::Null);
// 1 hit outside bounds
let values = vec![15.0];
let index = get_test_index_from_values(merge_segments, &values)?;
let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
extended_bounds: Some(HistogramBounds { min: 3.0, max: 6.0 }),
hard_bounds: Some(HistogramBounds { min: 3.0, max: 6.0 }),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
)]
.into_iter()
.collect();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["histogram"]["buckets"][0]["key"], 3.0);
assert_eq!(res["histogram"]["buckets"][0]["doc_count"], 0);
assert_eq!(res["histogram"]["buckets"][1]["key"], 4.0);
assert_eq!(res["histogram"]["buckets"][1]["doc_count"], 0);
assert_eq!(res["histogram"]["buckets"][2]["key"], 5.0);
assert_eq!(res["histogram"]["buckets"][2]["doc_count"], 0);
assert_eq!(res["histogram"]["buckets"][3]["key"], 6.0);
assert_eq!(res["histogram"]["buckets"][3]["doc_count"], 0);
assert_eq!(res["histogram"]["buckets"][4], Value::Null);
Ok(())
}
#[test]
fn histogram_hard_bounds_test_multi_segment() -> crate::Result<()> {
histogram_hard_bounds_test_with_opt(false)
@@ -871,16 +1002,16 @@ mod tests {
}
#[test]
fn histogram_empty_bucket_behaviour_test_single_segment() -> crate::Result<()> {
histogram_empty_bucket_behaviour_test_with_opt(true)
fn histogram_empty_result_behaviour_test_single_segment() -> crate::Result<()> {
histogram_empty_result_behaviour_test_with_opt(true)
}
#[test]
fn histogram_empty_bucket_behaviour_test_multi_segment() -> crate::Result<()> {
histogram_empty_bucket_behaviour_test_with_opt(false)
fn histogram_empty_result_behaviour_test_multi_segment() -> crate::Result<()> {
histogram_empty_result_behaviour_test_with_opt(false)
}
fn histogram_empty_bucket_behaviour_test_with_opt(merge_segments: bool) -> crate::Result<()> {
fn histogram_empty_result_behaviour_test_with_opt(merge_segments: bool) -> crate::Result<()> {
let index = get_test_index_2_segments(merge_segments)?;
let agg_req: Aggregations = vec![(
@@ -897,30 +1028,130 @@ mod tests {
.into_iter()
.collect();
// let res = exec_request(agg_req, &index)?;
let res = exec_request_with_query(agg_req.clone(), &index, Some(("text", "blubberasdf")))?;
let reader = index.reader()?;
let text_field = reader.searcher().schema().get_field("text").unwrap();
let term_query = TermQuery::new(
Term::from_field_text(text_field, "nohit"),
IndexRecordOption::Basic,
assert_eq!(
res,
json!({
"histogram": {
"buckets": []
}
})
);
let collector = AggregationCollector::from_aggs(agg_req);
// test index without segments
let values = vec![];
let searcher = reader.searcher();
let agg_res = searcher.search(&term_query, &collector).unwrap();
// Don't merge empty segments
let index = get_test_index_from_values(false, &values)?;
let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
let res = exec_request_with_query(agg_req, &index, Some(("text", "blubberasdf")))?;
assert_eq!(res["histogram"]["buckets"][0]["key"], 6.0);
assert_eq!(res["histogram"]["buckets"][0]["doc_count"], 1);
assert_eq!(res["histogram"]["buckets"][37]["key"], 43.0);
assert_eq!(res["histogram"]["buckets"][37]["doc_count"], 0);
assert_eq!(res["histogram"]["buckets"][38]["key"], 44.0);
assert_eq!(res["histogram"]["buckets"][38]["doc_count"], 1);
assert_eq!(res["histogram"]["buckets"][39], Value::Null);
assert_eq!(
res,
json!({
"histogram": {
"buckets": []
}
})
);
// test index without segments
let values = vec![];
// Don't merge empty segments
let index = get_test_index_from_values(false, &values)?;
let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
extended_bounds: Some(HistogramBounds {
min: 2.0,
max: 12.0,
}),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
)]
.into_iter()
.collect();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["histogram"]["buckets"][0]["key"], 2.0);
assert_eq!(res["histogram"]["buckets"][0]["doc_count"], 0);
assert_eq!(res["histogram"]["buckets"][1]["key"], 3.0);
assert_eq!(res["histogram"]["buckets"][1]["doc_count"], 0);
assert_eq!(res["histogram"]["buckets"][2]["doc_count"], 0);
assert_eq!(res["histogram"]["buckets"][10]["key"], 12.0);
assert_eq!(res["histogram"]["buckets"][10]["doc_count"], 0);
let agg_req: Aggregations = vec![
(
"stats".to_string(),
Aggregation::Metric(MetricAggregation::Stats(StatsAggregation {
field: "score_f64".to_string(),
})),
),
(
"avg".to_string(),
Aggregation::Metric(MetricAggregation::Average(AverageAggregation {
field: "score_f64".to_string(),
})),
),
]
.into_iter()
.collect();
let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
extended_bounds: Some(HistogramBounds {
min: 2.0,
max: 12.0,
}),
..Default::default()
}),
sub_aggregation: agg_req,
}),
)]
.into_iter()
.collect();
let res = exec_request(agg_req, &index)?;
assert_eq!(
res["histogram"]["buckets"][0],
json!({
"avg": {
"value": Value::Null
},
"doc_count": 0,
"key": 2.0,
"stats": {
"sum": 0.0,
"count": 0,
"min": Value::Null,
"max": Value::Null,
"avg": Value::Null,
"standard_deviation": Value::Null,
}
})
);
assert_eq!(res["histogram"]["buckets"][0]["key"], 2.0);
assert_eq!(res["histogram"]["buckets"][0]["doc_count"], 0);
assert_eq!(res["histogram"]["buckets"][1]["key"], 3.0);
assert_eq!(res["histogram"]["buckets"][1]["doc_count"], 0);
assert_eq!(res["histogram"]["buckets"][2]["doc_count"], 0);
assert_eq!(res["histogram"]["buckets"][10]["key"], 12.0);
assert_eq!(res["histogram"]["buckets"][10]["doc_count"], 0);
Ok(())
}

View File

@@ -5,7 +5,7 @@ use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::SegmentAggregationResultsCollector;
use crate::aggregation::agg_req_with_accessor::get_aggs_with_accessor_and_validate;
use crate::collector::{Collector, SegmentCollector};
use crate::{SegmentReader, TantivyError};
use crate::SegmentReader;
/// Collector for aggregations.
///
@@ -86,7 +86,7 @@ impl Collector for AggregationCollector {
&self,
segment_fruits: Vec<<Self::Child as SegmentCollector>::Fruit>,
) -> crate::Result<Self::Fruit> {
merge_fruits(segment_fruits).map(|res| res.into())
merge_fruits(segment_fruits).map(|res| (res, self.agg.clone()).into())
}
}
@@ -99,9 +99,7 @@ fn merge_fruits(
}
Ok(fruit)
} else {
Err(TantivyError::InvalidArgument(
"no fruits provided in merge_fruits".to_string(),
))
Ok(IntermediateAggregationResults::default())
}
}

View File

@@ -8,7 +8,7 @@ use fnv::FnvHashMap;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use super::bucket::HistogramAggregation;
use super::agg_req::{BucketAggregationType, CollectorAggregations, MetricAggregation};
use super::metric::{IntermediateAverage, IntermediateStats};
use super::segment_agg_result::{
SegmentAggregationResultsCollector, SegmentBucketResultCollector, SegmentHistogramBucketEntry,
@@ -34,6 +34,42 @@ impl From<SegmentAggregationResultsCollector> for IntermediateAggregationResults
}
impl IntermediateAggregationResults {
/// Creates an intermediate result for `req` in which every requested metric and bucket
/// aggregation is present with its empty value.
///
/// `metrics` / `buckets` stay `None` when the request has no aggregation of that kind.
pub(crate) fn empty_from_req(req: &CollectorAggregations) -> Self {
    let mut metrics = None;
    if !req.metrics.is_empty() {
        let mut entries = Vec::new();
        for (name, metric_req) in req.metrics.iter() {
            let empty_metric = IntermediateMetricResult::empty_from_req(metric_req);
            entries.push((name.to_string(), empty_metric));
        }
        metrics = Some(VecWithNames::from_entries(entries));
    }

    let mut buckets = None;
    if !req.buckets.is_empty() {
        let mut entries = Vec::new();
        for (name, bucket_req) in req.buckets.iter() {
            let empty_bucket = IntermediateBucketResult::empty_from_req(&bucket_req.bucket_agg);
            entries.push((name.to_string(), empty_bucket));
        }
        buckets = Some(VecWithNames::from_entries(entries));
    }

    Self { metrics, buckets }
}
/// Merge an other intermediate aggregation result into this result.
///
/// The order of the values need to be the same on both results. This is ensured when the same
@@ -89,6 +125,16 @@ impl From<SegmentMetricResultCollector> for IntermediateMetricResult {
}
impl IntermediateMetricResult {
pub(crate) fn empty_from_req(req: &MetricAggregation) -> Self {
match req {
MetricAggregation::Average(_) => {
IntermediateMetricResult::Average(IntermediateAverage::default())
}
MetricAggregation::Stats(_) => {
IntermediateMetricResult::Stats(IntermediateStats::default())
}
}
}
fn merge_fruits(&mut self, other: IntermediateMetricResult) {
match (self, other) {
(
@@ -122,9 +168,6 @@ pub enum IntermediateBucketResult {
Histogram {
/// The buckets
buckets: Vec<IntermediateHistogramBucketEntry>,
/// The original request. It is used to compute the total range after merging segments and
/// get min_doc_count after merging all segment results.
req: HistogramAggregation,
},
}
@@ -140,6 +183,14 @@ impl From<SegmentBucketResultCollector> for IntermediateBucketResult {
}
impl IntermediateBucketResult {
pub(crate) fn empty_from_req(req: &BucketAggregationType) -> Self {
match req {
BucketAggregationType::Range(_) => IntermediateBucketResult::Range(Default::default()),
BucketAggregationType::Histogram(_) => {
IntermediateBucketResult::Histogram { buckets: vec![] }
}
}
}
fn merge_fruits(&mut self, other: IntermediateBucketResult) {
match (self, other) {
(
@@ -332,7 +383,9 @@ mod tests {
}
}
fn get_test_tree(data: &[(String, u64, String, u64)]) -> IntermediateAggregationResults {
fn get_intermediat_tree_with_ranges(
data: &[(String, u64, String, u64)],
) -> IntermediateAggregationResults {
let mut map = HashMap::new();
let mut buckets: FnvHashMap<_, _> = Default::default();
for (key, doc_count, sub_aggregation_key, sub_aggregation_count) in data {
@@ -363,18 +416,18 @@ mod tests {
#[test]
fn test_merge_fruits_tree_1() {
let mut tree_left = get_test_tree(&[
let mut tree_left = get_intermediat_tree_with_ranges(&[
("red".to_string(), 50, "1900".to_string(), 25),
("blue".to_string(), 30, "1900".to_string(), 30),
]);
let tree_right = get_test_tree(&[
let tree_right = get_intermediat_tree_with_ranges(&[
("red".to_string(), 60, "1900".to_string(), 30),
("blue".to_string(), 25, "1900".to_string(), 50),
]);
tree_left.merge_fruits(tree_right);
let tree_expected = get_test_tree(&[
let tree_expected = get_intermediat_tree_with_ranges(&[
("red".to_string(), 110, "1900".to_string(), 55),
("blue".to_string(), 55, "1900".to_string(), 80),
]);
@@ -384,18 +437,18 @@ mod tests {
#[test]
fn test_merge_fruits_tree_2() {
let mut tree_left = get_test_tree(&[
let mut tree_left = get_intermediat_tree_with_ranges(&[
("red".to_string(), 50, "1900".to_string(), 25),
("blue".to_string(), 30, "1900".to_string(), 30),
]);
let tree_right = get_test_tree(&[
let tree_right = get_intermediat_tree_with_ranges(&[
("red".to_string(), 60, "1900".to_string(), 30),
("green".to_string(), 25, "1900".to_string(), 50),
]);
tree_left.merge_fruits(tree_right);
let tree_expected = get_test_tree(&[
let tree_expected = get_intermediat_tree_with_ranges(&[
("red".to_string(), 110, "1900".to_string(), 55),
("blue".to_string(), 30, "1900".to_string(), 30),
("green".to_string(), 25, "1900".to_string(), 50),
@@ -403,4 +456,18 @@ mod tests {
assert_eq!(tree_left, tree_expected);
}
/// Merging a default (empty) intermediate result must leave the tree unchanged.
#[test]
fn test_merge_fruits_tree_empty() {
    let expected = get_intermediat_tree_with_ranges(&[
        ("red".to_string(), 50, "1900".to_string(), 25),
        ("blue".to_string(), 30, "1900".to_string(), 30),
    ]);

    let mut merged = expected.clone();
    merged.merge_fruits(IntermediateAggregationResults::default());

    assert_eq!(merged, expected);
}
}

View File

@@ -20,7 +20,7 @@ use crate::DocId;
/// "field": "score",
/// }
/// }
/// ```
/// ```
pub struct AverageAggregation {
/// The field name to compute the stats on.
pub field: String,

View File

@@ -17,7 +17,7 @@ use crate::DocId;
/// "field": "score",
/// }
/// }
/// ```
/// ```
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct StatsAggregation {
@@ -62,9 +62,8 @@ pub struct IntermediateStats {
min: f64,
max: f64,
}
impl IntermediateStats {
fn new() -> Self {
impl Default for IntermediateStats {
fn default() -> Self {
Self {
count: 0,
sum: 0.0,
@@ -73,7 +72,9 @@ impl IntermediateStats {
max: f64::MIN,
}
}
}
impl IntermediateStats {
pub(crate) fn avg(&self) -> Option<f64> {
if self.count == 0 {
None
@@ -142,7 +143,7 @@ impl SegmentStatsCollector {
pub fn from_req(field_type: Type) -> Self {
Self {
field_type,
stats: IntermediateStats::new(),
stats: IntermediateStats::default(),
}
}
pub(crate) fn collect_block(&mut self, doc: &[DocId], field: &DynamicFastFieldReader<u64>) {
@@ -182,12 +183,50 @@ mod tests {
};
use crate::aggregation::agg_result::AggregationResults;
use crate::aggregation::metric::StatsAggregation;
use crate::aggregation::tests::get_test_index_2_segments;
use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values};
use crate::aggregation::AggregationCollector;
use crate::query::TermQuery;
use crate::query::{AllQuery, TermQuery};
use crate::schema::IndexRecordOption;
use crate::Term;
/// A stats aggregation on an index built from no values must report a zero count and sum,
/// and `null` for every other statistic.
#[test]
fn test_aggregation_stats_empty_index() -> crate::Result<()> {
    // test index without segments
    let no_values = vec![];
    let index = get_test_index_from_values(false, &no_values)?;

    let stats_req = Aggregation::Metric(MetricAggregation::Stats(
        StatsAggregation::from_field_name("score".to_string()),
    ));
    let agg_req_1: Aggregations = std::iter::once(("stats".to_string(), stats_req)).collect();

    let collector = AggregationCollector::from_aggs(agg_req_1);
    let reader = index.reader()?;
    let agg_res: AggregationResults = reader.searcher().search(&AllQuery, &collector).unwrap();

    // Round-trip through JSON so the assertion checks the serialized shape.
    let serialized = serde_json::to_string(&agg_res)?;
    let res: Value = serde_json::from_str(&serialized)?;

    let expected = json!({
        "avg": Value::Null,
        "count": 0,
        "max": Value::Null,
        "min": Value::Null,
        "standard_deviation": Value::Null,
        "sum": 0.0
    });
    assert_eq!(res["stats"], expected);

    Ok(())
}
#[test]
fn test_aggregation_stats() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;

View File

@@ -453,10 +453,10 @@ mod tests {
.unwrap();
let agg_res: AggregationResults = if use_distributed_collector {
let collector = DistributedAggregationCollector::from_aggs(agg_req);
let collector = DistributedAggregationCollector::from_aggs(agg_req.clone());
let searcher = reader.searcher();
searcher.search(&term_query, &collector).unwrap().into()
(searcher.search(&term_query, &collector).unwrap(), agg_req).into()
} else {
let collector = AggregationCollector::from_aggs(agg_req);
@@ -835,7 +835,7 @@ mod tests {
// Test de/serialization roundtrip on intermediate_agg_result
let res: IntermediateAggregationResults =
serde_json::from_str(&serde_json::to_string(&res).unwrap()).unwrap();
res.into()
(res, agg_req.clone()).into()
} else {
let collector = AggregationCollector::from_aggs(agg_req.clone());