From c45eb9a9face33e8ad7955de70e8bb5482032e58 Mon Sep 17 00:00:00 2001 From: Pascal Seitz Date: Tue, 26 Apr 2022 11:22:26 +0800 Subject: [PATCH] improve readability, add json test --- src/aggregation/agg_result.rs | 132 +++++++++++++-------- src/aggregation/bucket/mod.rs | 4 +- src/aggregation/bucket/term_agg.rs | 52 +++++++- src/aggregation/intermediate_agg_result.rs | 17 +-- src/aggregation/mod.rs | 3 + 5 files changed, 145 insertions(+), 63 deletions(-) diff --git a/src/aggregation/agg_result.rs b/src/aggregation/agg_result.rs index fc1b225a9..9eca13d63 100644 --- a/src/aggregation/agg_result.rs +++ b/src/aggregation/agg_result.rs @@ -9,14 +9,16 @@ use std::collections::HashMap; use serde::{Deserialize, Serialize}; -use super::agg_req::{Aggregations, AggregationsInternal, BucketAggregationInternal}; +use super::agg_req::{ + Aggregations, AggregationsInternal, BucketAggregationInternal, MetricAggregation, +}; use super::bucket::{intermediate_buckets_to_final_buckets, GetDocCount}; use super::intermediate_agg_result::{ IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry, IntermediateMetricResult, IntermediateRangeBucketEntry, }; use super::metric::{SingleMetricResult, Stats}; -use super::Key; +use super::{Key, VecWithNames}; use crate::TantivyError; #[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)] @@ -53,60 +55,86 @@ impl AggregationResults { /// Internal function, CollectorAggregations is used instead Aggregations, which is optimized /// for internal processing, by splitting metric and buckets into seperate groups. pub(crate) fn from_intermediate_and_req_internal( - results: IntermediateAggregationResults, + intermediate_results: IntermediateAggregationResults, req: &AggregationsInternal, ) -> crate::Result { // Important assumption: // When the tree contains buckets/metric, we expect it to have all buckets/metrics from the // request - let mut result: HashMap<_, _> = if let Some(buckets) = results.buckets { - buckets - .into_iter() - .zip(req.buckets.values()) - .map(|((key, bucket), req)| { - Ok(( - key, - AggregationResult::BucketResult(BucketResult::from_intermediate_and_req( - bucket, req, - )?), - )) - }) - .collect::>>()? + let mut results: HashMap = HashMap::new(); + + if let Some(buckets) = intermediate_results.buckets { + add_coverted_final_buckets_to_result(&mut results, buckets, &req.buckets)? } else { - req.buckets - .iter() - .map(|(key, req)| { - let empty_bucket = IntermediateBucketResult::empty_from_req(&req.bucket_agg); - Ok(( - key.to_string(), - AggregationResult::BucketResult(BucketResult::from_intermediate_and_req( - empty_bucket, - req, - )?), - )) - }) - .collect::>>()? + // When there are no buckets, we create empty buckets, so that the serialized json + // format is constant + add_empty_final_buckets_to_result(&mut results, &req.buckets)? }; - if let Some(metrics) = results.metrics { - result.extend( - metrics - .into_iter() - .map(|(key, metric)| (key, AggregationResult::MetricResult(metric.into()))), - ); + if let Some(metrics) = intermediate_results.metrics { + add_converted_final_metrics_to_result(&mut results, metrics); } else { - result.extend(req.metrics.iter().map(|(key, req)| { - let empty_bucket = IntermediateMetricResult::empty_from_req(req); - ( - key.to_string(), - AggregationResult::MetricResult(empty_bucket.into()), - ) - })); + // When there are no metrics, we create empty metric results, so that the serialized + // json format is constant + add_empty_final_metrics_to_result(&mut results, &req.metrics)?; } - Ok(Self(result)) + Ok(Self(results)) } } +fn add_converted_final_metrics_to_result( + results: &mut HashMap, + metrics: VecWithNames, +) { + results.extend( + metrics + .into_iter() + .map(|(key, metric)| (key, AggregationResult::MetricResult(metric.into()))), + ); +} + +fn add_empty_final_metrics_to_result( + results: &mut HashMap, + req_metrics: &VecWithNames, +) -> crate::Result<()> { + results.extend(req_metrics.iter().map(|(key, req)| { + let empty_bucket = IntermediateMetricResult::empty_from_req(req); + ( + key.to_string(), + AggregationResult::MetricResult(empty_bucket.into()), + ) + })); + Ok(()) +} + +fn add_empty_final_buckets_to_result( + results: &mut HashMap, + req_buckets: &VecWithNames, +) -> crate::Result<()> { + let requested_buckets = req_buckets.iter(); + for (key, req) in requested_buckets { + let empty_bucket = AggregationResult::BucketResult(BucketResult::empty_from_req(req)?); + results.insert(key.to_string(), empty_bucket); + } + Ok(()) +} + +fn add_coverted_final_buckets_to_result( + results: &mut HashMap, + buckets: VecWithNames, + req_buckets: &VecWithNames, +) -> crate::Result<()> { + assert_eq!(buckets.len(), req_buckets.len()); + + let buckets_with_request = buckets.into_iter().zip(req_buckets.values()); + for ((key, bucket), req) in buckets_with_request { + let result = + AggregationResult::BucketResult(BucketResult::from_intermediate_and_req(bucket, req)?); + results.insert(key, result); + } + Ok(()) +} + #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] #[serde(untagged)] /// An aggregation is either a bucket or a metric. @@ -200,6 +228,12 @@ pub enum BucketResult { } impl BucketResult { + pub(crate) fn empty_from_req(req: &BucketAggregationInternal) -> crate::Result { + let empty_bucket = IntermediateBucketResult::empty_from_req(&req.bucket_agg); + + Ok(BucketResult::from_intermediate_and_req(empty_bucket, req)?) + } + fn from_intermediate_and_req( bucket_result: IntermediateBucketResult, req: &BucketAggregationInternal, @@ -214,11 +248,11 @@ impl BucketResult { }) .collect::>>()?; - buckets.sort_by(|a, b| { + buckets.sort_by(|left, right| { // TODO use total_cmp next stable rust release - a.from + left.from .unwrap_or(f64::MIN) - .partial_cmp(&b.from.unwrap_or(f64::MIN)) + .partial_cmp(&right.from.unwrap_or(f64::MIN)) .unwrap_or(Ordering::Equal) }); Ok(BucketResult::Range { buckets }) @@ -226,14 +260,16 @@ impl BucketResult { IntermediateBucketResult::Histogram { buckets } => { let buckets = intermediate_buckets_to_final_buckets( buckets, - req.as_histogram().expect("unexpected aggregation"), + req.as_histogram() + .expect("unexpected aggregation, expected histogram aggregation"), &req.sub_aggregation, )?; Ok(BucketResult::Histogram { buckets }) } IntermediateBucketResult::Terms(terms) => terms.into_final_result( - req.as_term().expect("unexpected aggregation"), + req.as_term() + .expect("unexpected aggregation, expected term aggregation"), &req.sub_aggregation, ), } diff --git a/src/aggregation/bucket/mod.rs b/src/aggregation/bucket/mod.rs index e4cdc8dc6..a47437952 100644 --- a/src/aggregation/bucket/mod.rs +++ b/src/aggregation/bucket/mod.rs @@ -132,9 +132,9 @@ fn custom_order_serde_test() { assert_eq!(order, order_deser); - let order_deser: serde_json::Result = serde_json::from_str(&"{}"); + let order_deser: serde_json::Result = serde_json::from_str("{}"); assert!(order_deser.is_err()); - let order_deser: serde_json::Result = serde_json::from_str(&"[]"); + let order_deser: serde_json::Result = serde_json::from_str("[]"); assert!(order_deser.is_err()); } diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index 01a967a10..3323e09bc 100644 --- a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg.rs @@ -70,6 +70,7 @@ pub struct TermsAggregation { pub field: String, /// By default, the top 10 terms with the most documents are returned. /// Larger values for size are more expensive. + #[serde(skip_serializing_if = "Option::is_none", default)] pub size: Option, /// Unused by tantivy. @@ -79,6 +80,7 @@ pub struct TermsAggregation { /// The default value in elasticsearch is size * 1.5 + 10. /// /// Should never be smaller than size. + #[serde(skip_serializing_if = "Option::is_none", default)] pub shard_size: Option, /// The get more accurate results, we fetch more than `size` from each segment. @@ -86,6 +88,7 @@ pub struct TermsAggregation { /// Increasing this value is will increase the cost for more accuracy. /// /// Defaults to 10 * size. + #[serde(skip_serializing_if = "Option::is_none", default)] pub segment_size: Option, /// If you set the `show_term_doc_count_error` parameter to true, the terms aggregation will @@ -94,11 +97,13 @@ pub struct TermsAggregation { /// each segment that didn’t fit into `shard_size`. /// /// Defaults to true when ordering by counts desc. + #[serde(skip_serializing_if = "Option::is_none", default)] pub show_term_doc_count_error: Option, /// Filter all terms than are lower `min_doc_count`. Defaults to 1. /// /// **Expensive**: When set to 0, this will return all terms in the field. + #[serde(skip_serializing_if = "Option::is_none", default)] pub min_doc_count: Option, /// Set the order. `String` is here a target, which is either "_count", "_key", or the name of @@ -112,6 +117,7 @@ pub struct TermsAggregation { /// { "_count": "asc" } /// { "_key": "asc" } /// { "average_price": "asc" } + #[serde(skip_serializing_if = "Option::is_none", default)] pub order: Option, } @@ -290,6 +296,7 @@ impl SegmentTermCollector { TermBuckets::from_req_and_validate(sub_aggregations, max_term_id as usize)?; if let Some(custom_order) = req.order.as_ref() { + // Validate sub aggregtion exists if let OrderTarget::SubAggregation(sub_agg_name) = &custom_order.target { let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name); @@ -335,8 +342,8 @@ impl SegmentTermCollector { // defer order and cut_off after loading the texts from the dictionary } OrderTarget::SubAggregation(_name) => { - // don't sort of cutt off since it's hard to make assumptions on the quality of the - // results when cutting off, du to unknown nature of the sub_aggregation (possible + // don't sort and cut off since it's hard to make assumptions on the quality of the + // results when cutting off du to unknown nature of the sub_aggregation (possible // to check). } OrderTarget::Count => { @@ -1164,6 +1171,47 @@ mod tests { Ok(()) } + + #[test] + fn test_json_format() -> crate::Result<()> { + let agg_req: Aggregations = vec![( + "term_agg_test".to_string(), + Aggregation::Bucket(BucketAggregation { + bucket_agg: BucketAggregationType::Terms(TermsAggregation { + field: "string_id".to_string(), + size: Some(2), + segment_size: Some(2), + order: Some(CustomOrder { + target: OrderTarget::Key, + order: Order::Desc, + }), + ..Default::default() + }), + sub_aggregation: Default::default(), + }), + )] + .into_iter() + .collect(); + + let elasticsearch_compatible_json = json!( + { + "term_agg_test":{ + "terms": { + "field": "string_id", + "size": 2u64, + "segment_size": 2u64, + "order": {"_key": "desc"} + } + } + }); + + let agg_req_deser: Aggregations = + serde_json::from_str(&serde_json::to_string(&elasticsearch_compatible_json).unwrap()) + .unwrap(); + assert_eq!(agg_req, agg_req_deser); + + Ok(()) + } } #[cfg(all(test, feature = "unstable"))] diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index 59d02ac60..936caf38a 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -206,7 +206,7 @@ impl IntermediateBucketResult { .. }, ) => { - let mut buckets = buckets_left + let buckets = buckets_left .drain(..) .merge_join_by(buckets_right.into_iter(), |left, right| { left.key.partial_cmp(&right.key).unwrap_or(Ordering::Equal) @@ -221,7 +221,7 @@ impl IntermediateBucketResult { }) .collect(); - std::mem::swap(buckets_left, &mut buckets); + *buckets_left = buckets; } (IntermediateBucketResult::Range(_), _) => { panic!("try merge on different types") @@ -276,18 +276,13 @@ impl IntermediateTermBucketResult { let order = req.order.order; match req.order.target { OrderTarget::Key => { - buckets.sort_by(|bucket1, bucket2| { + buckets.sort_by(|left, right| { if req.order.order == Order::Desc { - bucket1 - .key - .partial_cmp(&bucket2.key) - .expect("expected type string, which is always sortable") + left.key.partial_cmp(&right.key) } else { - bucket2 - .key - .partial_cmp(&bucket1.key) - .expect("expected type string, which is always sortable") + right.key.partial_cmp(&left.key) } + .expect("expected type string, which is always sortable") }); } OrderTarget::Count => { diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index 7709a926d..193a94d04 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -247,6 +247,9 @@ impl VecWithNames { fn is_empty(&self) -> bool { self.keys.is_empty() } + fn len(&self) -> usize { + self.keys.len() + } fn get(&self, name: &str) -> Option<&T> { self.keys() .position(|key| key == name)