From 1095c9b0736ff005da7ed10b730c44b053f2dec9 Mon Sep 17 00:00:00 2001 From: giovannicuccu Date: Tue, 4 Jun 2024 08:25:17 +0200 Subject: [PATCH] Issue 1787 extended stats (#2247) * first version of extended stats along with its tests * using IntermediateExtendStats instead of IntermediateStats with all tests passing * Created struct for request and response * first test with extended_stats * kahan summation and tests with approximate equality * version ready for merge * removed approx dependency * refactor for using ExtendedStats only when needed * interim version * refined version with code formatted * refactored a struct * cosmetic refactor * fix after merge * fix format * added extended_stat bench * merge and new benchmark for extended stats * split stat segment collectors * wrapped intermediate extended stat with a box to limit memory usage * Revert "wrapped intermediate extended stat with a box to limit memory usage" This reverts commit 5b4aa9f393edca90ee68ebbb883a28fbc3942c21. * some code reformat, commented kahan summation * refactor after review * refactor after code review * fix after incorrectly restoring kahan summation * modifications for code review + bug fix in merge_fruit * refactor assert_nearly_equals macro * update after code review --------- Co-authored-by: Giovanni Cuccu --- benches/agg_bench.rs | 8 +- src/aggregation/agg_req.rs | 8 +- src/aggregation/agg_req_with_accessor.rs | 8 +- src/aggregation/agg_result.rs | 7 +- src/aggregation/intermediate_agg_result.rs | 18 +- src/aggregation/metric/extended_stats.rs | 1182 ++++++++++++++++++++ src/aggregation/metric/mod.rs | 2 + src/aggregation/metric/stats.rs | 32 +- src/aggregation/segment_agg_result.rs | 7 +- src/collector/top_score_collector.rs | 5 +- src/lib.rs | 20 +- 11 files changed, 1271 insertions(+), 26 deletions(-) create mode 100644 src/aggregation/metric/extended_stats.rs diff --git a/benches/agg_bench.rs b/benches/agg_bench.rs index 3359c426b..f7a06608f 100644 --- a/benches/agg_bench.rs +++ b/benches/agg_bench.rs @@ -47,6 +47,7 @@ fn bench_agg(mut group: InputGroup) { register!(group, average_f64); register!(group, average_f64_u64); register!(group, stats_f64); + register!(group, extendedstats_f64); register!(group, percentiles_f64); register!(group, terms_few); register!(group, terms_many); @@ -105,7 +106,12 @@ fn stats_f64(index: &Index) { }); exec_term_with_agg(index, agg_req) } - +fn extendedstats_f64(index: &Index) { + let agg_req = json!({ + "extendedstats_f64": { "extended_stats": { "field": "score_f64", } } + }); + exec_term_with_agg(index, agg_req) +} fn percentiles_f64(index: &Index) { let agg_req = json!({ "mypercentiles": { diff --git a/src/aggregation/agg_req.rs b/src/aggregation/agg_req.rs index eee4090de..fea06bdf9 100644 --- a/src/aggregation/agg_req.rs +++ b/src/aggregation/agg_req.rs @@ -34,7 +34,7 @@ use super::bucket::{ DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation, }; use super::metric::{ - AverageAggregation, CountAggregation, MaxAggregation, MinAggregation, + AverageAggregation, CountAggregation, ExtendedStatsAggregation, MaxAggregation, MinAggregation, PercentilesAggregationReq, StatsAggregation, SumAggregation, TopHitsAggregation, }; @@ -146,6 +146,11 @@ pub enum AggregationVariants { /// extracted values. #[serde(rename = "stats")] Stats(StatsAggregation), + /// Computes a collection of estended statistics (`min`, `max`, `sum`, `count`, `avg`, + /// `sum_of_squares`, `variance`, `variance_sampling`, `std_deviation`, + /// `std_deviation_sampling`) over the extracted values. + #[serde(rename = "extended_stats")] + ExtendedStats(ExtendedStatsAggregation), /// Computes the sum of the extracted values. #[serde(rename = "sum")] Sum(SumAggregation), @@ -170,6 +175,7 @@ impl AggregationVariants { AggregationVariants::Max(max) => vec![max.field_name()], AggregationVariants::Min(min) => vec![min.field_name()], AggregationVariants::Stats(stats) => vec![stats.field_name()], + AggregationVariants::ExtendedStats(extended_stats) => vec![extended_stats.field_name()], AggregationVariants::Sum(sum) => vec![sum.field_name()], AggregationVariants::Percentiles(per) => vec![per.field_name()], AggregationVariants::TopHits(top_hits) => top_hits.field_names(), diff --git a/src/aggregation/agg_req_with_accessor.rs b/src/aggregation/agg_req_with_accessor.rs index 76b09dea0..a1fb15ec9 100644 --- a/src/aggregation/agg_req_with_accessor.rs +++ b/src/aggregation/agg_req_with_accessor.rs @@ -11,8 +11,8 @@ use super::bucket::{ DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation, }; use super::metric::{ - AverageAggregation, CountAggregation, MaxAggregation, MinAggregation, StatsAggregation, - SumAggregation, + AverageAggregation, CountAggregation, ExtendedStatsAggregation, MaxAggregation, MinAggregation, + StatsAggregation, SumAggregation, }; use super::segment_agg_result::AggregationLimits; use super::VecWithNames; @@ -276,6 +276,10 @@ impl AggregationWithAccessor { field: ref field_name, .. }) + | ExtendedStats(ExtendedStatsAggregation { + field: ref field_name, + .. + }) | Sum(SumAggregation { field: ref field_name, .. diff --git a/src/aggregation/agg_result.rs b/src/aggregation/agg_result.rs index b7ef24706..aac1efac4 100644 --- a/src/aggregation/agg_result.rs +++ b/src/aggregation/agg_result.rs @@ -8,7 +8,9 @@ use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; use super::bucket::GetDocCount; -use super::metric::{PercentilesMetricResult, SingleMetricResult, Stats, TopHitsMetricResult}; +use super::metric::{ + ExtendedStats, PercentilesMetricResult, SingleMetricResult, Stats, TopHitsMetricResult, +}; use super::{AggregationError, Key}; use crate::TantivyError; @@ -88,6 +90,8 @@ pub enum MetricResult { Min(SingleMetricResult), /// Stats metric result. Stats(Stats), + /// ExtendedStats metric result. + ExtendedStats(Box), /// Sum metric result. Sum(SingleMetricResult), /// Percentiles metric result. @@ -104,6 +108,7 @@ impl MetricResult { MetricResult::Max(max) => Ok(max.value), MetricResult::Min(min) => Ok(min.value), MetricResult::Stats(stats) => stats.get_value(agg_property), + MetricResult::ExtendedStats(extended_stats) => extended_stats.get_value(agg_property), MetricResult::Sum(sum) => Ok(sum.value), MetricResult::Percentiles(_) => Err(TantivyError::AggregationError( AggregationError::InvalidRequest("percentiles can't be used to order".to_string()), diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index d59c58c62..3f472854b 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -19,8 +19,8 @@ use super::bucket::{ GetDocCount, Order, OrderTarget, RangeAggregation, TermsAggregation, }; use super::metric::{ - IntermediateAverage, IntermediateCount, IntermediateMax, IntermediateMin, IntermediateStats, - IntermediateSum, PercentilesCollector, TopHitsTopNComputer, + IntermediateAverage, IntermediateCount, IntermediateExtendedStats, IntermediateMax, + IntermediateMin, IntermediateStats, IntermediateSum, PercentilesCollector, TopHitsTopNComputer, }; use super::segment_agg_result::AggregationLimits; use super::{format_date, AggregationError, Key, SerializedKey}; @@ -215,6 +215,9 @@ pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult Stats(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::Stats( IntermediateStats::default(), )), + ExtendedStats(_) => IntermediateAggregationResult::Metric( + IntermediateMetricResult::ExtendedStats(IntermediateExtendedStats::default()), + ), Sum(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::Sum( IntermediateSum::default(), )), @@ -282,6 +285,8 @@ pub enum IntermediateMetricResult { Min(IntermediateMin), /// Intermediate stats result. Stats(IntermediateStats), + /// Intermediate stats result. + ExtendedStats(IntermediateExtendedStats), /// Intermediate sum result. Sum(IntermediateSum), /// Intermediate top_hits result @@ -306,6 +311,9 @@ impl IntermediateMetricResult { IntermediateMetricResult::Stats(intermediate_stats) => { MetricResult::Stats(intermediate_stats.finalize()) } + IntermediateMetricResult::ExtendedStats(intermediate_stats) => { + MetricResult::ExtendedStats(intermediate_stats.finalize()) + } IntermediateMetricResult::Sum(intermediate_sum) => { MetricResult::Sum(intermediate_sum.finalize().into()) } @@ -346,6 +354,12 @@ impl IntermediateMetricResult { ) => { stats_left.merge_fruits(stats_right); } + ( + IntermediateMetricResult::ExtendedStats(extended_stats_left), + IntermediateMetricResult::ExtendedStats(extended_stats_right), + ) => { + extended_stats_left.merge_fruits(extended_stats_right); + } (IntermediateMetricResult::Sum(sum_left), IntermediateMetricResult::Sum(sum_right)) => { sum_left.merge_fruits(sum_right); } diff --git a/src/aggregation/metric/extended_stats.rs b/src/aggregation/metric/extended_stats.rs new file mode 100644 index 000000000..ba80c1f4d --- /dev/null +++ b/src/aggregation/metric/extended_stats.rs @@ -0,0 +1,1182 @@ +use std::fmt::Debug; +use std::mem; + +use serde::{Deserialize, Serialize}; + +use super::*; +use crate::aggregation::agg_req_with_accessor::{ + AggregationWithAccessor, AggregationsWithAccessor, +}; +use crate::aggregation::intermediate_agg_result::{ + IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, +}; +use crate::aggregation::segment_agg_result::SegmentAggregationCollector; +use crate::aggregation::*; +use crate::{DocId, TantivyError}; + +/// A multi-value metric aggregation that computes a collection of extended statistics +/// on numeric values that are extracted +/// from the aggregated documents. +/// See [`ExtendedStats`] for returned statistics. +/// +/// # JSON Format +/// ```json +/// { +/// "extended_stats": { +/// "field": "score" +/// } +/// } +/// ``` + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct ExtendedStatsAggregation { + /// The field name to compute the stats on. + pub field: String, + /// The missing parameter defines how documents that are missing a value should be treated. + /// By default they will be ignored but it is also possible to treat them as if they had a + /// value. Examples in JSON format: + /// { "field": "my_numbers", "missing": "10.0" } + #[serde(default)] + pub missing: Option, + /// The sigma parameter defines how standard_deviation_bound_are_calculated. + /// This can be a useful way to visualize variance of your data. + /// The default value is 2. Examples in JSON format: + /// { "field": "my_numbers", "sigma": "3.0" } + #[serde(default)] + pub sigma: Option, +} + +impl ExtendedStatsAggregation { + /// Creates a new [`ExtendedStatsAggregation`] instance from a field name. + pub fn from_field_name(field_name: String) -> Self { + ExtendedStatsAggregation { + field: field_name, + missing: None, + sigma: None, + } + } + /// Returns the field name the aggregation is computed on. + pub fn field_name(&self) -> &str { + &self.field + } +} + +/// Extended stats contains a collection of statistics +/// they extends stats adding variance, standard deviation +/// and bound informations +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct ExtendedStats { + /// The number of documents. + pub count: u64, + /// The sum of the fast field values. + pub sum: f64, + /// The min value of the fast field values. + pub min: Option, + /// The max value of the fast field values. + pub max: Option, + /// The average of the fast field values. `None` if count equals zero. + pub avg: Option, + /// The sum of squares of the fast field values. `None` if count equals zero. + pub sum_of_squares: Option, + /// The variance of the fast field values. `None` if count is is 0 or 1. + pub variance: Option, + /// The variance population of the fast field values, always equal to variance. `None` if count + /// is is 0 or 1. + pub variance_population: Option, + /// The variance sampling of the fast field values, always equal to variance. `None` if count + /// is is 0 or 1. + pub variance_sampling: Option, + /// The standard deviation of the fast field values. `None` if count is is 0 or 1. + pub std_deviation: Option, + /// The standard deviation of the fast field values, always equal to variance. `None` if count + /// is is 0 or 1. + pub std_deviation_population: Option, + /// The standard deviation sampling of the fast field values. `None` + /// if count is is 0 or 1. + pub std_deviation_sampling: Option, + /// The standard deviation bounds of the fast field values, always equal to variance. `None` + /// if count is is 0 or 1. + pub std_deviation_bounds: Option, +} + +/// A sub struct for ExtendedStat containing deviation bounds +/// the values depend on sigma and represent +/// the bounds from the average with a distance of +/// std_deviation*sigma +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct StandardDeviationBounds { + /// upper bound -> avg + std_dev*sigma + pub upper: f64, + /// lower bound -> avg - std_dev*sigma + pub lower: f64, + /// upper bound sampling -> avg + std_dev_sampling*sigma + pub upper_sampling: f64, + /// lower bound sampling -> avg - std_dev_sampling*sigma + pub lower_sampling: f64, + /// same as upper + pub upper_population: f64, + /// same as lower + pub lower_population: f64, +} + +impl ExtendedStats { + pub(crate) fn get_value(&self, agg_property: &str) -> crate::Result> { + match agg_property { + "count" => Ok(Some(self.count as f64)), + "sum" => Ok(Some(self.sum)), + "min" => Ok(self.min), + "max" => Ok(self.max), + "avg" => Ok(self.avg), + "variance" => Ok(self.variance), + "variance_sampling" => Ok(self.variance_sampling), + "variance_population" => Ok(self.variance_population), + "sum_of_squares" => Ok(self.sum_of_squares), + "std_deviation" => Ok(self.std_deviation), + "std_deviation_sampling" => Ok(self.std_deviation_sampling), + "std_deviation_population" => Ok(self.std_deviation_population), + "std_deviation_bounds.lower" => Ok(self + .std_deviation_bounds + .as_ref() + .map(|bounds| bounds.lower)), + "std_deviation_bounds.lower_population" => Ok(self + .std_deviation_bounds + .as_ref() + .map(|bounds| bounds.lower_population)), + "std_deviation_bounds.lower_sampling" => Ok(self + .std_deviation_bounds + .as_ref() + .map(|bounds| bounds.lower_sampling)), + "std_deviation_bounds.upper" => Ok(self + .std_deviation_bounds + .as_ref() + .map(|bounds| bounds.upper)), + "std_deviation_bounds.upper_population" => Ok(self + .std_deviation_bounds + .as_ref() + .map(|bounds| bounds.upper_population)), + "std_deviation_bounds.upper_sampling" => Ok(self + .std_deviation_bounds + .as_ref() + .map(|bounds| bounds.upper_sampling)), + _ => Err(TantivyError::InvalidArgument(format!( + "Unknown property {agg_property} on stats metric aggregation" + ))), + } + } +} + +/// Intermediate result of the extended stats aggregation that can be combined with other +/// intermediate results. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct IntermediateExtendedStats { + intermediate_stats: IntermediateStats, + /// The number of extracted values. + /// The sum of square values, it's referred as M2 in Welford's online algorithm + sum_of_squares: f64, + /// The sum of square values as computed by elastic search + sum_of_squares_elastic: f64, + /// The delta for sum of squares as computed by elastic search needed for the Kahan algorithm + delta_sum_for_squares_elastic: f64, + /// The mean is an intermediate value need for calculating the variance + /// as per [Welford's online algorithm](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm) + mean: f64, + /// The value used for computing standard deviation bounds + sigma: f64, +} + +impl Default for IntermediateExtendedStats { + fn default() -> Self { + Self { + intermediate_stats: IntermediateStats::default(), + sum_of_squares: 0.0, + sum_of_squares_elastic: 0.0, + delta_sum_for_squares_elastic: 0.0, + mean: 0.0, + // The default value is the same of ElasticSearch + sigma: 2.0, + } + } +} + +impl IntermediateExtendedStats { + /// Creates a new IntermediateExtendedStats using an option + /// containing the sigma to be used for calculating bound values. + pub fn with_sigma(sigma: Option) -> Self { + Self { + intermediate_stats: IntermediateStats::default(), + sum_of_squares: 0.0, + sum_of_squares_elastic: 0.0, + delta_sum_for_squares_elastic: 0.0, + mean: 0.0, + // The default value is the same of ElasticSearch + sigma: sigma.unwrap_or(2.0), + } + } + /// Merges the other stats intermediate result into self. + pub fn merge_fruits(&mut self, other: IntermediateExtendedStats) { + if other.intermediate_stats.count == 0 { + return; + } + if self.intermediate_stats.count == 0 { + let _ = mem::replace(self, other); + return; + } + let new_count = self.intermediate_stats.count + other.intermediate_stats.count; + let delta = other.mean - self.mean; + self.sum_of_squares += other.sum_of_squares + + delta + * delta + * self.intermediate_stats.count as f64 + * other.intermediate_stats.count as f64 + / new_count as f64; + self.mean = (self.intermediate_stats.sum as f64 + other.intermediate_stats.sum as f64) + / new_count as f64; + self.sum_of_squares_elastic += other.sum_of_squares_elastic; + self.delta_sum_for_squares_elastic += other.delta_sum_for_squares_elastic; + self.intermediate_stats + .merge_fruits(other.intermediate_stats); + } + + /// Computes the final stats value. + pub fn finalize(&self) -> Box { + let (min, max, avg, sum_of_squares) = if self.intermediate_stats.count == 0 { + (None, None, None, None) + } else { + ( + Some(self.intermediate_stats.min), + Some(self.intermediate_stats.max), + Some(self.mean), + Some(self.sum_of_squares_elastic), + ) + }; + let (variance, variance_sampling) = if self.intermediate_stats.count <= 1 { + (None, None) + } else { + ( + Some(self.sum_of_squares / self.intermediate_stats.count as f64), + Some(self.sum_of_squares / (self.intermediate_stats.count - 1) as f64), + ) + }; + let std_deviation = variance.map(|v| v.sqrt()); + let std_deviation_sampling = variance_sampling.map(|v| v.sqrt()); + let std_deviation_bounds = + if let (Some(std_deviation_val), Some(std_deviation_sampling_val)) = + (std_deviation, std_deviation_sampling) + { + let upper = self.mean + std_deviation_val * self.sigma; + let lower = self.mean - std_deviation_val * self.sigma; + let upper_sampling = self.mean + std_deviation_sampling_val * self.sigma; + let lower_sampling = self.mean - std_deviation_sampling_val * self.sigma; + Some(StandardDeviationBounds { + upper, + lower, + upper_sampling, + lower_sampling, + upper_population: upper, + lower_population: lower, + }) + } else { + None + }; + Box::new(ExtendedStats { + count: self.intermediate_stats.count, + sum: self.intermediate_stats.sum, + min, + max, + avg, + sum_of_squares, + variance, + variance_population: variance, + variance_sampling, + std_deviation, + std_deviation_population: std_deviation, + std_deviation_sampling, + std_deviation_bounds, + }) + } + + fn update_variance(&mut self, value: f64) { + let delta = value - self.mean; + // this is not what the Welford's online algorithm prescribes but + // using the pseudo code from wikipedia there was a small rounding + // error (in 15th decimal place) that caused a test + //(test_aggregation_level1 in agg_test.rs) + // failure + self.mean = self.intermediate_stats.sum / self.intermediate_stats.count as f64; + // self.mean += delta / self.count as f64; + let delta2 = value - self.mean; + self.sum_of_squares += delta * delta2; + } + + #[inline] + fn collect(&mut self, value: f64) { + self.intermediate_stats.collect(value); + // kahan algorithm for sum_of_squares_elastic + let y = value * value - self.delta_sum_for_squares_elastic; + let t = self.sum_of_squares_elastic + y; + self.delta_sum_for_squares_elastic = (t - self.sum_of_squares_elastic) - y; + self.sum_of_squares_elastic = t; + self.update_variance(value); + } +} + +#[derive(Clone, Debug, PartialEq)] +pub(crate) struct SegmentExtendedStatsCollector { + missing: Option, + field_type: ColumnType, + pub(crate) extended_stats: IntermediateExtendedStats, + pub(crate) accessor_idx: usize, + val_cache: Vec, +} + +impl SegmentExtendedStatsCollector { + pub fn from_req( + field_type: ColumnType, + sigma: Option, + accessor_idx: usize, + missing: Option, + ) -> Self { + let missing = missing.and_then(|val| f64_to_fastfield_u64(val, &field_type)); + Self { + field_type, + extended_stats: IntermediateExtendedStats::with_sigma(sigma), + accessor_idx, + missing, + val_cache: Default::default(), + } + } + #[inline] + pub(crate) fn collect_block_with_field( + &mut self, + docs: &[DocId], + agg_accessor: &mut AggregationWithAccessor, + ) { + if let Some(missing) = self.missing.as_ref() { + agg_accessor.column_block_accessor.fetch_block_with_missing( + docs, + &agg_accessor.accessor, + *missing, + ); + } else { + agg_accessor + .column_block_accessor + .fetch_block(docs, &agg_accessor.accessor); + } + for val in agg_accessor.column_block_accessor.iter_vals() { + let val1 = f64_from_fastfield_u64(val, &self.field_type); + self.extended_stats.collect(val1); + } + } +} + +impl SegmentAggregationCollector for SegmentExtendedStatsCollector { + #[inline] + fn add_intermediate_aggregation_result( + self: Box, + agg_with_accessor: &AggregationsWithAccessor, + results: &mut IntermediateAggregationResults, + ) -> crate::Result<()> { + let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string(); + results.push( + name, + IntermediateAggregationResult::Metric(IntermediateMetricResult::ExtendedStats( + self.extended_stats, + )), + )?; + + Ok(()) + } + + #[inline] + fn collect( + &mut self, + doc: crate::DocId, + agg_with_accessor: &mut AggregationsWithAccessor, + ) -> crate::Result<()> { + let field = &agg_with_accessor.aggs.values[self.accessor_idx].accessor; + if let Some(missing) = self.missing { + let mut has_val = false; + for val in field.values_for_doc(doc) { + let val1 = f64_from_fastfield_u64(val, &self.field_type); + self.extended_stats.collect(val1); + has_val = true; + } + if !has_val { + self.extended_stats + .collect(f64_from_fastfield_u64(missing, &self.field_type)); + } + } else { + for val in field.values_for_doc(doc) { + let val1 = f64_from_fastfield_u64(val, &self.field_type); + self.extended_stats.collect(val1); + } + } + + Ok(()) + } + + #[inline] + fn collect_block( + &mut self, + docs: &[crate::DocId], + agg_with_accessor: &mut AggregationsWithAccessor, + ) -> crate::Result<()> { + let field = &mut agg_with_accessor.aggs.values[self.accessor_idx]; + self.collect_block_with_field(docs, field); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + + use crate::aggregation::agg_req::Aggregations; + use crate::aggregation::agg_result::AggregationResults; + use crate::aggregation::metric::IntermediateExtendedStats; + use crate::aggregation::tests::get_test_index_from_values; + use crate::aggregation::AggregationCollector; + use crate::assert_nearly_equals; + use crate::query::AllQuery; + + const EPSILON_FOR_TEST: f64 = 0.000000000002; + + #[test] + fn test_aggregation_extended_stats_no_variance() -> crate::Result<()> { + let values = vec![1.0]; + + let index = get_test_index_from_values(false, &values)?; + + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "my_stats": { + "extended_stats": { + "field": "score_f64", + }, + } + })) + .unwrap(); + + let collector = AggregationCollector::from_aggs(agg_req_1, Default::default()); + + let reader = index.reader()?; + let searcher = reader.searcher(); + let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap(); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "count")? + .unwrap(), + 1.0 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "min")? + .unwrap(), + 1.0 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "max")? + .unwrap(), + 1.0 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "sum")? + .unwrap(), + 1.0 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "avg")? + .unwrap(), + 1.0 + ); + + assert!(agg_res + .get_value_from_aggregation("my_stats", "std_deviation")? + .is_none()); + assert!(agg_res + .get_value_from_aggregation("my_stats", "std_deviation_population")? + .is_none()); + assert!(agg_res + .get_value_from_aggregation("my_stats", "std_deviation_sampling")? + .is_none()); + assert!(agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.lower")? + .is_none()); + assert!(agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.lower_population")? + .is_none()); + assert!(agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.lower_sampling")? + .is_none()); + assert!(agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.upper")? + .is_none()); + assert!(agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.upper_population")? + .is_none()); + assert!(agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.upper_sampling")? + .is_none()); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "sum_of_squares")? + .unwrap(), + 1.0 + ); + assert!(agg_res + .get_value_from_aggregation("my_stats", "variance_population")? + .is_none()); + assert!(agg_res + .get_value_from_aggregation("my_stats", "variance")? + .is_none()); + assert!(agg_res + .get_value_from_aggregation("my_stats", "variance_sampling")? + .is_none()); + + Ok(()) + } + + #[test] + fn test_aggregation_extended_stats() -> crate::Result<()> { + let values = vec![1.0, 3.0, 4.0, 5.0, 8.0, 10.0]; + + let index = get_test_index_from_values(false, &values)?; + + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "my_stats": { + "extended_stats": { + "field": "score_f64", + }, + } + })) + .unwrap(); + + let collector = AggregationCollector::from_aggs(agg_req_1, Default::default()); + + let reader = index.reader()?; + let searcher = reader.searcher(); + let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap(); + const EXPECTED_VARIANCE: f64 = 9.138888888888888; + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "count")? + .unwrap(), + 6.0 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "min")? + .unwrap(), + 1.0 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "max")? + .unwrap(), + 10.0 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "sum")? + .unwrap(), + 31.0 + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "avg")? + .unwrap(), + 5.166666666666667, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation")? + .unwrap(), + EXPECTED_VARIANCE.sqrt(), + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_population")? + .unwrap(), + EXPECTED_VARIANCE.sqrt(), + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_sampling")? + .unwrap(), + 3.311595788538611, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.lower")? + .unwrap(), + -0.8794523824056837, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.lower_population")? + .unwrap(), + -0.8794523824056837, + 0.00000000000001 + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.lower_sampling")? + .unwrap(), + -1.4565249104105549, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.upper")? + .unwrap(), + 11.212785715739017, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.upper_population")? + .unwrap(), + 11.212785715739017, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.upper_sampling")? + .unwrap(), + 11.78985824374389, + EPSILON_FOR_TEST + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "sum_of_squares")? + .unwrap(), + 215.0 + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "variance_population")? + .unwrap(), + EXPECTED_VARIANCE, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "variance")? + .unwrap(), + EXPECTED_VARIANCE, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "variance_sampling")? + .unwrap(), + 10.966666666666663, + EPSILON_FOR_TEST + ); + + Ok(()) + } + + #[test] + fn test_aggregation_extended_stats_with_sigma() -> crate::Result<()> { + let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; + + let index = get_test_index_from_values(false, &values)?; + + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "my_stats": { + "extended_stats": { + "field": "score_f64", + "sigma": 1.5 + }, + } + })) + .unwrap(); + + let collector = AggregationCollector::from_aggs(agg_req_1, Default::default()); + + let reader = index.reader()?; + let searcher = reader.searcher(); + let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap(); + + const EXPECTED_VARIANCE: f64 = 2.9166666666666665; + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "count")? + .unwrap(), + 6.0 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "min")? + .unwrap(), + 1.0 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "max")? + .unwrap(), + 6.0 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "sum")? + .unwrap(), + 21.0 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "avg")? + .unwrap(), + 3.5 + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation")? + .unwrap(), + EXPECTED_VARIANCE.sqrt(), + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_population")? + .unwrap(), + EXPECTED_VARIANCE.sqrt(), + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_sampling")? + .unwrap(), + 1.8708286933869709, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.lower")? + .unwrap(), + 0.9382623085101005, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.lower_population")? + .unwrap(), + 0.9382623085101005, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.lower_sampling")? + .unwrap(), + 0.6937569599195434, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.upper")? + .unwrap(), + 6.061737691489899, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.upper_population")? + .unwrap(), + 6.061737691489899, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.upper_sampling")? + .unwrap(), + 6.3062430400804566, + EPSILON_FOR_TEST + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "sum_of_squares")? + .unwrap(), + 91.0 + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "variance_population")? + .unwrap(), + EXPECTED_VARIANCE, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "variance")? + .unwrap(), + EXPECTED_VARIANCE, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "variance_sampling")? + .unwrap(), + 3.5, + EPSILON_FOR_TEST + ); + + Ok(()) + } + + #[test] + fn test_aggregation_extended_stats_with_variance_similar_to_mean() -> crate::Result<()> { + let values = vec![50.01, 50.02, 50.01, 50.03, 50.01, 50.02]; + + let index = get_test_index_from_values(false, &values)?; + + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "my_stats": { + "extended_stats": { + "field": "score_f64", + "sigma": 1.5 + }, + } + })) + .unwrap(); + + let collector = AggregationCollector::from_aggs(agg_req_1, Default::default()); + + let reader = index.reader()?; + let searcher = reader.searcher(); + let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap(); + const EXPECTED_VARIANCE: f64 = 5.5555555555608854e-5; + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "count")? + .unwrap(), + 6.0 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "min")? + .unwrap(), + 50.01 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "max")? + .unwrap(), + 50.03 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "sum")? + .unwrap(), + 300.1 + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "avg")? + .unwrap(), + 50.01666666666667, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation")? + .unwrap(), + EXPECTED_VARIANCE.sqrt(), + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_population")? + .unwrap(), + EXPECTED_VARIANCE.sqrt(), + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_sampling")? + .unwrap(), + 0.008164965809279263, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.lower")? + .unwrap(), + 50.00548632677917, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.lower_population")? + .unwrap(), + 50.00548632677917, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.lower_sampling")? + .unwrap(), + 50.00441921795275, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.upper")? + .unwrap(), + 50.027847006554175, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.upper_population")? + .unwrap(), + 50.027847006554175, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.upper_sampling")? + .unwrap(), + 50.028914115380594, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "sum_of_squares")? + .unwrap(), + 15010.002, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "variance_population")? + .unwrap(), + EXPECTED_VARIANCE, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "variance")? + .unwrap(), + EXPECTED_VARIANCE, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "variance_sampling")? + .unwrap(), + 6.666666666670718e-5, + EPSILON_FOR_TEST + ); + + Ok(()) + } + + #[test] + fn extended_stat_zero_value() { + let intermediate_extend_stats = IntermediateExtendedStats::default(); + let extended_stats = intermediate_extend_stats.finalize(); + assert!(extended_stats.variance.is_none()); + assert!(extended_stats.variance_population.is_none()); + assert!(extended_stats.variance_sampling.is_none()); + assert!(extended_stats.sum_of_squares.is_none()); + assert!(extended_stats.std_deviation.is_none()); + assert!(extended_stats.std_deviation_population.is_none()); + assert!(extended_stats.std_deviation_sampling.is_none()); + assert!(extended_stats.std_deviation_bounds.is_none()); + } + + #[test] + fn extended_stat_one_value() { + let mut intermediate_extend_stats = IntermediateExtendedStats::default(); + intermediate_extend_stats.collect(1.0f64); + let extended_stats = intermediate_extend_stats.finalize(); + assert!(extended_stats.variance.is_none()); + assert!(extended_stats.variance_population.is_none()); + assert!(extended_stats.variance_sampling.is_none()); + assert!(extended_stats.std_deviation.is_none()); + assert!(extended_stats.std_deviation_population.is_none()); + assert!(extended_stats.std_deviation_sampling.is_none()); + assert!(extended_stats.std_deviation_bounds.is_none()); + let sum_of_squares = extended_stats.sum_of_squares.unwrap(); + assert_eq!(1.0f64, sum_of_squares); + } + + #[test] + fn extended_stat_multiple_values() { + let mut intermediate_extend_stats = IntermediateExtendedStats::default(); + intermediate_extend_stats.collect(1.0f64); + intermediate_extend_stats.collect(3.0f64); + intermediate_extend_stats.collect(4.0f64); + intermediate_extend_stats.collect(5.0f64); + intermediate_extend_stats.collect(8.0f64); + intermediate_extend_stats.collect(10.0f64); + let extended_stats = intermediate_extend_stats.finalize(); + let variance = extended_stats.variance.unwrap(); + const EXPECTED_VARIANCE: f64 = 9.138888888888888; + assert_eq!(EXPECTED_VARIANCE, variance); + let variance_population = extended_stats.variance_population.unwrap(); + assert_eq!(EXPECTED_VARIANCE, variance_population); + let variance_sampling = extended_stats.variance_sampling.unwrap(); + assert_eq!(10.966666666666665f64, variance_sampling); + let std_deviation = extended_stats.std_deviation.unwrap(); + assert_eq!(EXPECTED_VARIANCE.sqrt(), std_deviation); + let std_deviation_population = extended_stats.std_deviation_population.unwrap(); + assert_eq!(EXPECTED_VARIANCE.sqrt(), std_deviation_population); + let std_deviation_sampling = extended_stats.std_deviation_sampling.unwrap(); + assert_eq!(10.966666666666665f64.sqrt(), std_deviation_sampling); + let sum_of_squares = extended_stats.sum_of_squares.unwrap(); + assert_eq!(215.0, sum_of_squares); + let avg = extended_stats.avg.unwrap(); + assert_eq!(5.166666666666667, avg); + } + + #[test] + fn merge_empty_with_one_value() { + let mut intermediate_extend_stats = IntermediateExtendedStats::default(); + let mut intermediate_extend_stats1 = IntermediateExtendedStats::default(); + intermediate_extend_stats1.collect(1.0f64); + intermediate_extend_stats.merge_fruits(intermediate_extend_stats1); + let extended_stats = intermediate_extend_stats.finalize(); + assert!(extended_stats.variance.is_none()); + assert!(extended_stats.variance_population.is_none()); + assert!(extended_stats.variance_sampling.is_none()); + assert!(extended_stats.std_deviation.is_none()); + assert!(extended_stats.std_deviation_population.is_none()); + assert!(extended_stats.std_deviation_sampling.is_none()); + let sum_of_squares = extended_stats.sum_of_squares.unwrap(); + assert_eq!(1.0f64, sum_of_squares); + } + + #[test] + fn merge_empty_with_multiple_values() { + let mut intermediate_extend_stats1 = IntermediateExtendedStats::default(); + intermediate_extend_stats1.collect(1.0f64); + intermediate_extend_stats1.collect(2.0f64); + intermediate_extend_stats1.collect(3.0f64); + intermediate_extend_stats1.collect(4.0f64); + intermediate_extend_stats1.collect(5.0f64); + + let mut intermediate_extend_stats = IntermediateExtendedStats::default(); + intermediate_extend_stats.merge_fruits(intermediate_extend_stats1); + let extended_stats = intermediate_extend_stats.finalize(); + const EXPECTED_VARIANCE: f64 = 2.0; + let variance = extended_stats.variance.unwrap(); + assert_eq!(EXPECTED_VARIANCE, variance); + let variance_population = extended_stats.variance_population.unwrap(); + assert_eq!(EXPECTED_VARIANCE, variance_population); + let variance_sampling = extended_stats.variance_sampling.unwrap(); + assert_eq!(2.5f64, variance_sampling); + let std_deviation = extended_stats.std_deviation.unwrap(); + assert_eq!(EXPECTED_VARIANCE.sqrt(), std_deviation); + let std_deviation_population = extended_stats.std_deviation_population.unwrap(); + assert_eq!(EXPECTED_VARIANCE.sqrt(), std_deviation_population); + let std_deviation_sampling = extended_stats.std_deviation_sampling.unwrap(); + assert_eq!(2.5f64.sqrt(), std_deviation_sampling); + let sum_of_squares = extended_stats.sum_of_squares.unwrap(); + assert_eq!(55f64, sum_of_squares); + } + + #[test] + fn merge_non_empty_extended_stats() { + let mut intermediate_extend_stats1 = IntermediateExtendedStats::default(); + intermediate_extend_stats1.collect(3.0f64); + intermediate_extend_stats1.collect(4.0f64); + intermediate_extend_stats1.collect(5.0f64); + + let mut intermediate_extend_stats = IntermediateExtendedStats::default(); + intermediate_extend_stats.collect(1.0f64); + intermediate_extend_stats.collect(2.0f64); + intermediate_extend_stats.merge_fruits(intermediate_extend_stats1); + let extended_stats = intermediate_extend_stats.finalize(); + + let variance = extended_stats.variance.unwrap(); + assert_eq!(2.0f64, variance); + let variance_population = extended_stats.variance_population.unwrap(); + assert_eq!(2.0f64, variance_population); + let variance_sampling = extended_stats.variance_sampling.unwrap(); + assert_eq!(2.5f64, variance_sampling); + let std_deviation = extended_stats.std_deviation.unwrap(); + assert_eq!(2.0f64.sqrt(), std_deviation); + let std_deviation_population = extended_stats.std_deviation_population.unwrap(); + assert_eq!(2.0f64.sqrt(), std_deviation_population); + let std_deviation_sampling = extended_stats.std_deviation_sampling.unwrap(); + assert_eq!(2.5f64.sqrt(), std_deviation_sampling); + let sum_of_squares = extended_stats.sum_of_squares.unwrap(); + assert_eq!(55f64, sum_of_squares); + + let mut intermediate_extend_stats = IntermediateExtendedStats::default(); + intermediate_extend_stats.collect(1.0f64); + intermediate_extend_stats.collect(3.0f64); + intermediate_extend_stats.collect(4.0f64); + let mut intermediate_extend_stats1 = IntermediateExtendedStats::default(); + intermediate_extend_stats1.collect(5.0f64); + intermediate_extend_stats1.collect(8.0f64); + intermediate_extend_stats1.collect(10.0f64); + intermediate_extend_stats.merge_fruits(intermediate_extend_stats1); + let extended_stats = intermediate_extend_stats.finalize(); + const EXPECTED_VARIANCE: f64 = 9.138888888888888; + let variance = extended_stats.variance.unwrap(); + assert_eq!(EXPECTED_VARIANCE, variance); + let variance_population = extended_stats.variance_population.unwrap(); + assert_eq!(EXPECTED_VARIANCE, variance_population); + let variance_sampling = extended_stats.variance_sampling.unwrap(); + assert_eq!(10.966666666666665f64, variance_sampling); + let std_deviation = extended_stats.std_deviation.unwrap(); + assert_eq!(EXPECTED_VARIANCE.sqrt(), std_deviation); + let std_deviation_population = extended_stats.std_deviation_population.unwrap(); + assert_eq!(EXPECTED_VARIANCE.sqrt(), std_deviation_population); + let std_deviation_sampling = extended_stats.std_deviation_sampling.unwrap(); + assert_eq!(10.966666666666665f64.sqrt(), std_deviation_sampling); + let sum_of_squares = extended_stats.sum_of_squares.unwrap(); + assert_eq!(215f64, sum_of_squares); + let avg = extended_stats.avg.unwrap(); + assert_eq!(5.166666666666667, avg); + } + + #[test] + fn merge_and_then_collect_non_empty_extended_stats() { + let mut intermediate_extend_stats = IntermediateExtendedStats::default(); + intermediate_extend_stats.collect(1.0f64); + intermediate_extend_stats.collect(3.0f64); + + let mut intermediate_extend_stats1 = IntermediateExtendedStats::default(); + intermediate_extend_stats1.collect(5.0f64); + intermediate_extend_stats1.collect(8.0f64); + intermediate_extend_stats1.collect(10.0f64); + intermediate_extend_stats.merge_fruits(intermediate_extend_stats1); + intermediate_extend_stats.collect(4.0f64); + let extended_stats = intermediate_extend_stats.finalize(); + const EXPECTED_VARIANCE: f64 = 9.138888888888888; + let variance = extended_stats.variance.unwrap(); + assert_nearly_equals!(EXPECTED_VARIANCE, variance, EPSILON_FOR_TEST); + let variance_population = extended_stats.variance_population.unwrap(); + assert_nearly_equals!(EXPECTED_VARIANCE, variance_population, EPSILON_FOR_TEST); + let variance_sampling = extended_stats.variance_sampling.unwrap(); + assert_nearly_equals!(10.966666666666665, variance_sampling, EPSILON_FOR_TEST); + let std_deviation = extended_stats.std_deviation.unwrap(); + assert_nearly_equals!(EXPECTED_VARIANCE.sqrt(), std_deviation, EPSILON_FOR_TEST); + let std_deviation_population = extended_stats.std_deviation_population.unwrap(); + assert_nearly_equals!( + EXPECTED_VARIANCE.sqrt(), + std_deviation_population, + EPSILON_FOR_TEST + ); + let std_deviation_sampling = extended_stats.std_deviation_sampling.unwrap(); + assert_nearly_equals!( + 10.966666666666665_f64.sqrt(), + std_deviation_sampling, + EPSILON_FOR_TEST + ); + let sum_of_squares = extended_stats.sum_of_squares.unwrap(); + assert_eq!(215.0, sum_of_squares); + let avg = extended_stats.avg.unwrap(); + assert_eq!(5.166666666666667, avg); + } +} diff --git a/src/aggregation/metric/mod.rs b/src/aggregation/metric/mod.rs index 6da583a59..4c64737b5 100644 --- a/src/aggregation/metric/mod.rs +++ b/src/aggregation/metric/mod.rs @@ -18,6 +18,7 @@ mod average; mod count; +mod extended_stats; mod max; mod min; mod percentiles; @@ -29,6 +30,7 @@ use std::collections::HashMap; pub use average::*; pub use count::*; +pub use extended_stats::*; pub use max::*; pub use min::*; pub use percentiles::*; diff --git a/src/aggregation/metric/stats.rs b/src/aggregation/metric/stats.rs index bd0c52269..8b8be3999 100644 --- a/src/aggregation/metric/stats.rs +++ b/src/aggregation/metric/stats.rs @@ -1,3 +1,5 @@ +use std::fmt::Debug; + use serde::{Deserialize, Serialize}; use super::*; @@ -85,13 +87,15 @@ impl Stats { #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct IntermediateStats { /// The number of extracted values. - count: u64, + pub(crate) count: u64, /// The sum of the extracted values. - sum: f64, + pub(crate) sum: f64, + /// delta for sum needed for [Kahan algorithm for summation](https://en.wikipedia.org/wiki/Kahan_summation_algorithm) + pub(crate) delta: f64, /// The min value. - min: f64, + pub(crate) min: f64, /// The max value. - max: f64, + pub(crate) max: f64, } impl Default for IntermediateStats { @@ -99,6 +103,7 @@ impl Default for IntermediateStats { Self { count: 0, sum: 0.0, + delta: 0.0, min: f64::MAX, max: f64::MIN, } @@ -109,7 +114,13 @@ impl IntermediateStats { /// Merges the other stats intermediate result into self. pub fn merge_fruits(&mut self, other: IntermediateStats) { self.count += other.count; - self.sum += other.sum; + + // kahan algorithm for sum + let y = other.sum - (self.delta + other.delta); + let t = self.sum + y; + self.delta = (t - self.sum) - y; + self.sum = t; + self.min = self.min.min(other.min); self.max = self.max.max(other.max); } @@ -141,9 +152,15 @@ impl IntermediateStats { } #[inline] - fn collect(&mut self, value: f64) { + pub(in crate::aggregation::metric) fn collect(&mut self, value: f64) { self.count += 1; - self.sum += value; + + // kahan algorithm for sum + let y = value - self.delta; + let t = self.sum + y; + self.delta = (t - self.sum) - y; + self.sum = t; + self.min = self.min.min(value); self.max = self.max.max(value); } @@ -288,7 +305,6 @@ impl SegmentAggregationCollector for SegmentStatsCollector { #[cfg(test)] mod tests { - use serde_json::Value; use crate::aggregation::agg_req::{Aggregation, Aggregations}; diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index 76f0eb284..cc0aab5a6 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -11,12 +11,12 @@ use super::agg_req_with_accessor::{AggregationWithAccessor, AggregationsWithAcce use super::bucket::{SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector}; use super::intermediate_agg_result::IntermediateAggregationResults; use super::metric::{ - AverageAggregation, CountAggregation, MaxAggregation, MinAggregation, + AverageAggregation, CountAggregation, ExtendedStatsAggregation, MaxAggregation, MinAggregation, SegmentPercentilesCollector, SegmentStatsCollector, SegmentStatsType, StatsAggregation, SumAggregation, }; use crate::aggregation::bucket::TermMissingAgg; -use crate::aggregation::metric::TopHitsSegmentCollector; +use crate::aggregation::metric::{SegmentExtendedStatsCollector, TopHitsSegmentCollector}; pub(crate) trait SegmentAggregationCollector: CollectorClone + Debug { fn add_intermediate_aggregation_result( @@ -148,6 +148,9 @@ pub(crate) fn build_single_agg_segment_collector( accessor_idx, *missing, ))), + ExtendedStats(ExtendedStatsAggregation { missing, sigma, .. }) => Ok(Box::new( + SegmentExtendedStatsCollector::from_req(req.field_type, *sigma, accessor_idx, *missing), + )), Sum(SumAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req( req.field_type, SegmentStatsType::Sum, diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 415625bc1..4545e54d7 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -871,7 +871,10 @@ mod tests { use crate::schema::{Field, Schema, FAST, STORED, TEXT}; use crate::time::format_description::well_known::Rfc3339; use crate::time::OffsetDateTime; - use crate::{DateTime, DocAddress, DocId, Index, IndexWriter, Order, Score, SegmentReader}; + use crate::{ + assert_nearly_equals, DateTime, DocAddress, DocId, Index, IndexWriter, Order, Score, + SegmentReader, + }; fn make_index() -> crate::Result { let mut schema_builder = Schema::builder(); diff --git a/src/lib.rs b/src/lib.rs index 5ac35c98d..ed6463c22 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -397,16 +397,20 @@ pub mod tests { #[macro_export] macro_rules! assert_nearly_equals { ($left:expr, $right:expr) => {{ - match (&$left, &$right) { - (left_val, right_val) => { + assert_nearly_equals!($left, $right, 0.0005); + }}; + ($left:expr, $right:expr, $epsilon:expr) => {{ + match (&$left, &$right, &$epsilon) { + (left_val, right_val, epsilon_val) => { let diff = (left_val - right_val).abs(); - let add = left_val.abs() + right_val.abs(); - if diff > 0.0005 * add { + + if diff > *epsilon_val { panic!( - r#"assertion failed: `(left ~= right)` - left: `{:?}`, - right: `{:?}`"#, - &*left_val, &*right_val + r#"assertion failed: `abs(left-right)>epsilon` + left: `{:?}`, + right: `{:?}`, + epsilon: `{:?}`"#, + &*left_val, &*right_val, &*epsilon_val ) } }