From 9efb1f7787b23b98ae08914f9505cc033cc0989d Mon Sep 17 00:00:00 2001 From: Giovanni Cuccu Date: Fri, 10 Nov 2023 16:47:18 +0100 Subject: [PATCH] version ready for merge --- src/aggregation/metric/stats.rs | 680 ++++++++++++++++++++++++++------ src/lib.rs | 17 + 2 files changed, 585 insertions(+), 112 deletions(-) diff --git a/src/aggregation/metric/stats.rs b/src/aggregation/metric/stats.rs index 73d024c25..d7cc815f9 100644 --- a/src/aggregation/metric/stats.rs +++ b/src/aggregation/metric/stats.rs @@ -166,13 +166,23 @@ pub struct ExtendedStats { pub std_deviation_bounds: Option, } +/// A sub struct for ExtendedStat containing deviation bounds +/// the values depend on sigma and represent +/// the bounds from the average with a distance of +/// std_deviation*sigma #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct StandardDeviationBounds { + /// upper bound -> avg + std_dev*sigma pub upper: f64, + /// lower bound -> avg - std_dev*sigma pub lower: f64, + /// upper bound sampling -> avg + std_dev_sampling*sigma pub upper_sampling: f64, + /// lower bound sampling -> avg - std_dev_sampling*sigma pub lower_sampling: f64, + /// same as upper pub upper_population: f64, + /// same as lower pub lower_population: f64, } @@ -191,12 +201,30 @@ impl ExtendedStats { "std_deviation" => Ok(self.std_deviation), "std_deviation_sampling" => Ok(self.std_deviation_sampling), "std_deviation_population" => Ok(self.std_deviation_population), - "std_deviation_bounds.lower" => Ok(self.std_deviation_bounds.as_ref().map(|bounds| bounds.lower)), - "std_deviation_bounds.lower_population" => Ok(self.std_deviation_bounds.as_ref().map(|bounds| bounds.lower_population)), - "std_deviation_bounds.lower_sampling" => Ok(self.std_deviation_bounds.as_ref().map(|bounds| bounds.lower_sampling)), - "std_deviation_bounds.upper" => Ok(self.std_deviation_bounds.as_ref().map(|bounds| bounds.upper)), - "std_deviation_bounds.upper_population" => Ok(self.std_deviation_bounds.as_ref().map(|bounds| bounds.upper_population)), - "std_deviation_bounds.upper_sampling" => Ok(self.std_deviation_bounds.as_ref().map(|bounds| bounds.upper_sampling)), + "std_deviation_bounds.lower" => Ok(self + .std_deviation_bounds + .as_ref() + .map(|bounds| bounds.lower)), + "std_deviation_bounds.lower_population" => Ok(self + .std_deviation_bounds + .as_ref() + .map(|bounds| bounds.lower_population)), + "std_deviation_bounds.lower_sampling" => Ok(self + .std_deviation_bounds + .as_ref() + .map(|bounds| bounds.lower_sampling)), + "std_deviation_bounds.upper" => Ok(self + .std_deviation_bounds + .as_ref() + .map(|bounds| bounds.upper)), + "std_deviation_bounds.upper_population" => Ok(self + .std_deviation_bounds + .as_ref() + .map(|bounds| bounds.upper_population)), + "std_deviation_bounds.upper_sampling" => Ok(self + .std_deviation_bounds + .as_ref() + .map(|bounds| bounds.upper_sampling)), _ => Err(TantivyError::InvalidArgument(format!( "Unknown property {agg_property} on stats metric aggregation" ))), @@ -204,6 +232,8 @@ impl ExtendedStats { } } +/// Intermediate result of the stats aggregation that can be combined with other intermediate +/// results. #[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct IntermediateStats { @@ -243,14 +273,18 @@ pub struct IntermediateExtendedStats { count: u64, /// The sum of the extracted values. sum: f64, - /// delta for sum needed by [Kahan algorithm for summation](https://en.wikipedia.org/wiki/Kahan_summation_algorithm) + /// delta for sum needed for [Kahan algorithm for summation](https://en.wikipedia.org/wiki/Kahan_summation_algorithm) delta: f64, /// The min value. min: f64, /// The max value. max: f64, - // The sum of the square values it's referred as M2 in Welford's online algorithm + // The sum of square values, it's referred as M2 in Welford's online algorithm sum_of_squares: f64, + // The sum of square values as computed by elastic search + sum_of_squares_elastic: f64, + /// delta for sum of squares as computed by elastic search needed for the Kahan algorithm + delta_sum_for_squares_elastic: f64, // The mean an intermediate value need for calculating the variance // as per [Welford's online algorithm](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm) mean: f64, @@ -267,6 +301,8 @@ impl Default for IntermediateExtendedStats { min: f64::MAX, max: f64::MIN, sum_of_squares: 0.0, + sum_of_squares_elastic: 0.0, + delta_sum_for_squares_elastic: 0.0, mean: 0.0, sigma: 2.0, } @@ -284,6 +320,8 @@ impl IntermediateExtendedStats { min: f64::MAX, max: f64::MIN, sum_of_squares: 0.0, + sum_of_squares_elastic: 0.0, + delta_sum_for_squares_elastic: 0.0, mean: 0.0, sigma: sigma.unwrap_or(2.0), } @@ -292,24 +330,28 @@ impl IntermediateExtendedStats { pub fn merge_fruits(&mut self, other: IntermediateExtendedStats) { self.min = self.min.min(other.min); self.max = self.max.max(other.max); - + if other.count != 0 { if self.count == 0 { self.sum_of_squares = other.sum_of_squares; + self.sum_of_squares_elastic = other.sum_of_squares_elastic; self.count = other.count; self.mean = other.mean; self.sum = other.sum; self.delta = other.delta; + self.delta_sum_for_squares_elastic = other.delta_sum_for_squares_elastic } else { - if other.count==1 { + if other.count == 1 { self.collect(other.sum); - } else if self.count==1 { - let sum=self.sum; + } else if self.count == 1 { + let sum = self.sum; self.sum_of_squares = other.sum_of_squares; + self.sum_of_squares_elastic = other.sum_of_squares_elastic; self.count = other.count; self.mean = other.mean; - self.sum=other.sum; + self.sum = other.sum; self.delta = other.delta; + self.delta_sum_for_squares_elastic = other.delta_sum_for_squares_elastic; self.collect(sum); } else { // parallel version of Welford's online algorithm @@ -324,10 +366,11 @@ impl IntermediateExtendedStats { self.mean = (self.sum as f64 + other.sum as f64) / new_count as f64; self.sum += other.sum; self.delta += other.delta; - } - } + self.sum_of_squares_elastic += other.sum_of_squares_elastic; + self.delta_sum_for_squares_elastic += other.delta_sum_for_squares_elastic + } + } } - } /// Computes the final stats value. @@ -350,7 +393,7 @@ impl IntermediateExtendedStats { let sum_of_squares = if self.count == 0 { None } else { - Some(self.sum_of_squares) + Some(self.sum_of_squares_elastic) }; let variance = if self.count <= 1 { None @@ -367,10 +410,10 @@ impl IntermediateExtendedStats { let std_deviation_bounds = if std_deviation.is_none() { None } else { - let upper=self.mean + std_deviation.unwrap() * self.sigma; - let lower=self.mean - std_deviation.unwrap() * self.sigma; - let upper_sampling=self.mean + std_deviation_sampling.unwrap() * self.sigma; - let lower_sampling=self.mean - std_deviation_sampling.unwrap() * self.sigma; + let upper = self.mean + std_deviation.unwrap() * self.sigma; + let lower = self.mean - std_deviation.unwrap() * self.sigma; + let upper_sampling = self.mean + std_deviation_sampling.unwrap() * self.sigma; + let lower_sampling = self.mean - std_deviation_sampling.unwrap() * self.sigma; Some(StandardDeviationBounds { upper, lower, @@ -399,13 +442,19 @@ impl IntermediateExtendedStats { fn collect(&mut self, value: f64) { self.count += 1; - - //kahan algorithm for summation - let y=value-self.delta; - let t=self.sum+y; - self.delta=(t-self.sum)-y; + + // kahan algorithm for sum + let y = value - self.delta; + let t = self.sum + y; + self.delta = (t - self.sum) - y; self.sum = t; + // kahan algorithm for sum_of_squares_elastic + let y = value * value - self.delta_sum_for_squares_elastic; + let t = self.sum_of_squares_elastic + y; + self.delta_sum_for_squares_elastic = (t - self.sum_of_squares_elastic) - y; + self.sum_of_squares_elastic = t; + self.min = self.min.min(value); self.max = self.max.max(value); self.update_variance(value); @@ -419,7 +468,7 @@ impl IntermediateExtendedStats { //(test_aggregation_level1 in agg_test.rs) // failure self.mean = self.sum / self.count as f64; - //self.mean += delta / self.count as f64; + // self.mean += delta / self.count as f64; let delta2 = value - self.mean; self.sum_of_squares += delta * delta2; } @@ -570,9 +619,8 @@ impl SegmentAggregationCollector for SegmentStatsCollector { #[cfg(test)] mod tests { - use serde_json::Value; - use approx::assert_relative_eq; + use crate::aggregation::agg_req::{Aggregation, Aggregations}; use crate::aggregation::agg_result::AggregationResults; use crate::aggregation::metric::IntermediateExtendedStats; @@ -582,9 +630,9 @@ mod tests { use crate::aggregation::AggregationCollector; use crate::query::{AllQuery, TermQuery}; use crate::schema::{IndexRecordOption, Schema, FAST}; - use crate::{Index, IndexWriter, Term}; + use crate::{assert_nearly_equals, Index, IndexWriter, Term}; - const EPSILON_FOR_TEST : f64 = 0.00000000000001; + const EPSILON_FOR_TEST: f64 = 0.00000000000002; #[test] fn test_aggregation_stats_empty_index() -> crate::Result<()> { @@ -927,9 +975,105 @@ mod tests { Ok(()) } + #[test] + fn test_aggregation_extended_stats_no_variance() -> crate::Result<()> { + let values = vec![1.0]; + + let index = get_test_index_from_values(false, &values)?; + + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "my_stats": { + "extended_stats": { + "field": "score_f64", + }, + } + })) + .unwrap(); + + let collector = AggregationCollector::from_aggs(agg_req_1, Default::default()); + + let reader = index.reader()?; + let searcher = reader.searcher(); + let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap(); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "count")? + .unwrap(), + 1.0 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "min")? + .unwrap(), + 1.0 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "max")? + .unwrap(), + 1.0 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "sum")? + .unwrap(), + 1.0 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "avg")? + .unwrap(), + 1.0 + ); + + assert!(agg_res + .get_value_from_aggregation("my_stats", "std_deviation")? + .is_none()); + assert!(agg_res + .get_value_from_aggregation("my_stats", "std_deviation_population")? + .is_none()); + assert!(agg_res + .get_value_from_aggregation("my_stats", "std_deviation_sampling")? + .is_none()); + assert!(agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.lower")? + .is_none()); + assert!(agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.lower_population")? + .is_none()); + assert!(agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.lower_sampling")? + .is_none()); + assert!(agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.upper")? + .is_none()); + assert!(agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.upper_population")? + .is_none()); + assert!(agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.upper_sampling")? + .is_none()); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "sum_of_squares")? + .unwrap(), + 1.0 + ); + assert!(agg_res + .get_value_from_aggregation("my_stats", "variance_population")? + .is_none()); + assert!(agg_res + .get_value_from_aggregation("my_stats", "variance")? + .is_none()); + assert!(agg_res + .get_value_from_aggregation("my_stats", "variance_sampling")? + .is_none()); + + Ok(()) + } + #[test] fn test_aggregation_extended_stats() -> crate::Result<()> { - let values = vec![1.0, 3.0, 4.0, 5.0, 8.0, 10.0]; let index = get_test_index_from_values(false, &values)?; @@ -948,25 +1092,128 @@ mod tests { let reader = index.reader()?; let searcher = reader.searcher(); let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap(); - const EXPECTED_VARIANCE : f64 = 9.138888888888888; - assert_eq!(agg_res.get_value_from_aggregation("my_stats", "count")?.unwrap(),6.0); - assert_eq!(agg_res.get_value_from_aggregation("my_stats", "min")?.unwrap(),1.0); - assert_eq!(agg_res.get_value_from_aggregation("my_stats", "max")?.unwrap(),10.0); - assert_eq!(agg_res.get_value_from_aggregation("my_stats", "sum")?.unwrap(),31.0); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "avg")?.unwrap(),5.166666666666667, epsilon = f64::EPSILON); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "std_deviation")?.unwrap(),EXPECTED_VARIANCE.sqrt(), epsilon = f64::EPSILON); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "std_deviation_population")?.unwrap(),EXPECTED_VARIANCE.sqrt(), epsilon = f64::EPSILON); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "std_deviation_sampling")?.unwrap(),3.311595788538611, epsilon = f64::EPSILON); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "std_deviation_bounds.lower")?.unwrap(),-0.8794523824056837, epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "std_deviation_bounds.lower_population")?.unwrap(),-0.8794523824056837, epsilon = 0.00000000000001); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "std_deviation_bounds.lower_sampling")?.unwrap(),-1.4565249104105549, epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "std_deviation_bounds.upper")?.unwrap(),11.212785715739017, epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "std_deviation_bounds.upper_population")?.unwrap(),11.212785715739017, epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "std_deviation_bounds.upper_sampling")?.unwrap(),11.78985824374389, epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "sum_of_squares")?.unwrap(),54.83333333333333, epsilon = f64::EPSILON); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "variance_population")?.unwrap(),EXPECTED_VARIANCE, epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "variance")?.unwrap(),EXPECTED_VARIANCE, epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "variance_sampling")?.unwrap(),10.966666666666663, epsilon = EPSILON_FOR_TEST); + const EXPECTED_VARIANCE: f64 = 9.138888888888888; + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "count")? + .unwrap(), + 6.0 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "min")? + .unwrap(), + 1.0 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "max")? + .unwrap(), + 10.0 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "sum")? + .unwrap(), + 31.0 + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "avg")? + .unwrap(), + 5.166666666666667, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation")? + .unwrap(), + EXPECTED_VARIANCE.sqrt(), + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_population")? + .unwrap(), + EXPECTED_VARIANCE.sqrt(), + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_sampling")? + .unwrap(), + 3.311595788538611, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.lower")? + .unwrap(), + -0.8794523824056837, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.lower_population")? + .unwrap(), + -0.8794523824056837, + 0.00000000000001 + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.lower_sampling")? + .unwrap(), + -1.4565249104105549, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.upper")? + .unwrap(), + 11.212785715739017, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.upper_population")? + .unwrap(), + 11.212785715739017, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.upper_sampling")? + .unwrap(), + 11.78985824374389, + EPSILON_FOR_TEST + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "sum_of_squares")? + .unwrap(), + 215.0 + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "variance_population")? + .unwrap(), + EXPECTED_VARIANCE, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "variance")? + .unwrap(), + EXPECTED_VARIANCE, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "variance_sampling")? + .unwrap(), + 10.966666666666663, + EPSILON_FOR_TEST + ); Ok(()) } @@ -993,28 +1240,130 @@ mod tests { let searcher = reader.searcher(); let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap(); - const EXPECTED_VARIANCE : f64 = 2.9166666666666665; - assert_eq!(agg_res.get_value_from_aggregation("my_stats", "count")?.unwrap(),6.0); - assert_eq!(agg_res.get_value_from_aggregation("my_stats", "min")?.unwrap(),1.0); - assert_eq!(agg_res.get_value_from_aggregation("my_stats", "max")?.unwrap(),6.0); - assert_eq!(agg_res.get_value_from_aggregation("my_stats", "sum")?.unwrap(),21.0); - assert_eq!(agg_res.get_value_from_aggregation("my_stats", "avg")?.unwrap(),3.5); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "std_deviation")?.unwrap(),EXPECTED_VARIANCE.sqrt(), epsilon = f64::EPSILON); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "std_deviation_population")?.unwrap(),EXPECTED_VARIANCE.sqrt(), epsilon = f64::EPSILON); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "std_deviation_sampling")?.unwrap(),1.8708286933869709, epsilon = f64::EPSILON); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "std_deviation_bounds.lower")?.unwrap(),0.9382623085101005, epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "std_deviation_bounds.lower_population")?.unwrap(),0.9382623085101005, epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "std_deviation_bounds.lower_sampling")?.unwrap(),0.6937569599195434, epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "std_deviation_bounds.upper")?.unwrap(),6.061737691489899, epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "std_deviation_bounds.upper_population")?.unwrap(),6.061737691489899, epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "std_deviation_bounds.upper_sampling")?.unwrap(),6.3062430400804566, epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "sum_of_squares")?.unwrap(),17.5, epsilon = f64::EPSILON); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "variance_population")?.unwrap(),EXPECTED_VARIANCE, epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "variance")?.unwrap(),EXPECTED_VARIANCE, epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "variance_sampling")?.unwrap(),3.5, epsilon = EPSILON_FOR_TEST); + const EXPECTED_VARIANCE: f64 = 2.9166666666666665; + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "count")? + .unwrap(), + 6.0 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "min")? + .unwrap(), + 1.0 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "max")? + .unwrap(), + 6.0 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "sum")? + .unwrap(), + 21.0 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "avg")? + .unwrap(), + 3.5 + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation")? + .unwrap(), + EXPECTED_VARIANCE.sqrt(), + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_population")? + .unwrap(), + EXPECTED_VARIANCE.sqrt(), + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_sampling")? + .unwrap(), + 1.8708286933869709, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.lower")? + .unwrap(), + 0.9382623085101005, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.lower_population")? + .unwrap(), + 0.9382623085101005, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.lower_sampling")? + .unwrap(), + 0.6937569599195434, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.upper")? + .unwrap(), + 6.061737691489899, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.upper_population")? + .unwrap(), + 6.061737691489899, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.upper_sampling")? + .unwrap(), + 6.3062430400804566, + EPSILON_FOR_TEST + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "sum_of_squares")? + .unwrap(), + 91.0 + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "variance_population")? + .unwrap(), + EXPECTED_VARIANCE, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "variance")? + .unwrap(), + EXPECTED_VARIANCE, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "variance_sampling")? + .unwrap(), + 3.5, + EPSILON_FOR_TEST + ); Ok(()) - } + } #[test] fn test_aggregation_extended_stats_with_variance_similar_to_mean() -> crate::Result<()> { @@ -1037,25 +1386,128 @@ mod tests { let reader = index.reader()?; let searcher = reader.searcher(); let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap(); - const EXPECTED_VARIANCE : f64 = 5.5555555555608854e-5; - assert_eq!(agg_res.get_value_from_aggregation("my_stats", "count")?.unwrap(),6.0); - assert_eq!(agg_res.get_value_from_aggregation("my_stats", "min")?.unwrap(),50.01); - assert_eq!(agg_res.get_value_from_aggregation("my_stats", "max")?.unwrap(),50.03); - assert_eq!(agg_res.get_value_from_aggregation("my_stats", "sum")?.unwrap(),300.1); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "avg")?.unwrap(),50.01666666666667,epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "std_deviation")?.unwrap(),EXPECTED_VARIANCE.sqrt(), epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "std_deviation_population")?.unwrap(),EXPECTED_VARIANCE.sqrt(), epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "std_deviation_sampling")?.unwrap(),0.008164965809279263, epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "std_deviation_bounds.lower")?.unwrap(),50.00548632677917, epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "std_deviation_bounds.lower_population")?.unwrap(),50.00548632677917, epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "std_deviation_bounds.lower_sampling")?.unwrap(),50.00441921795275, epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "std_deviation_bounds.upper")?.unwrap(),50.027847006554175, epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "std_deviation_bounds.upper_population")?.unwrap(),50.027847006554175, epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "std_deviation_bounds.upper_sampling")?.unwrap(),50.028914115380594, epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "sum_of_squares")?.unwrap(),0.00033333333333346484, epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "variance_population")?.unwrap(),EXPECTED_VARIANCE, epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "variance")?.unwrap(),EXPECTED_VARIANCE, epsilon = EPSILON_FOR_TEST); - assert_relative_eq!(agg_res.get_value_from_aggregation("my_stats", "variance_sampling")?.unwrap(),6.666666666670718e-5, epsilon = EPSILON_FOR_TEST); + const EXPECTED_VARIANCE: f64 = 5.5555555555608854e-5; + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "count")? + .unwrap(), + 6.0 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "min")? + .unwrap(), + 50.01 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "max")? + .unwrap(), + 50.03 + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "sum")? + .unwrap(), + 300.1 + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "avg")? + .unwrap(), + 50.01666666666667, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation")? + .unwrap(), + EXPECTED_VARIANCE.sqrt(), + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_population")? + .unwrap(), + EXPECTED_VARIANCE.sqrt(), + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_sampling")? + .unwrap(), + 0.008164965809279263, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.lower")? + .unwrap(), + 50.00548632677917, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.lower_population")? + .unwrap(), + 50.00548632677917, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.lower_sampling")? + .unwrap(), + 50.00441921795275, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.upper")? + .unwrap(), + 50.027847006554175, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.upper_population")? + .unwrap(), + 50.027847006554175, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "std_deviation_bounds.upper_sampling")? + .unwrap(), + 50.028914115380594, + EPSILON_FOR_TEST + ); + assert_eq!( + agg_res + .get_value_from_aggregation("my_stats", "sum_of_squares")? + .unwrap(), + 15010.002 + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "variance_population")? + .unwrap(), + EXPECTED_VARIANCE, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "variance")? + .unwrap(), + EXPECTED_VARIANCE, + EPSILON_FOR_TEST + ); + assert_nearly_equals!( + agg_res + .get_value_from_aggregation("my_stats", "variance_sampling")? + .unwrap(), + 6.666666666670718e-5, + EPSILON_FOR_TEST + ); Ok(()) } @@ -1087,12 +1539,11 @@ mod tests { assert!(extended_stats.std_deviation_sampling.is_none()); assert!(extended_stats.std_deviation_bounds.is_none()); let sum_of_squares = extended_stats.sum_of_squares.unwrap(); - assert_eq!(0.0f64, sum_of_squares); + assert_eq!(1.0f64, sum_of_squares); } #[test] fn extended_stat_multiple_values() { - let mut intermediate_extend_stats = IntermediateExtendedStats::default(); intermediate_extend_stats.collect(1.0f64); intermediate_extend_stats.collect(3.0f64); @@ -1102,7 +1553,7 @@ mod tests { intermediate_extend_stats.collect(10.0f64); let extended_stats = intermediate_extend_stats.finalize(); let variance = extended_stats.variance.unwrap(); - const EXPECTED_VARIANCE : f64 = 9.138888888888888; + const EXPECTED_VARIANCE: f64 = 9.138888888888888; assert_eq!(EXPECTED_VARIANCE, variance); let variance_population = extended_stats.variance_population.unwrap(); assert_eq!(EXPECTED_VARIANCE, variance_population); @@ -1115,7 +1566,7 @@ mod tests { let std_deviation_sampling = extended_stats.std_deviation_sampling.unwrap(); assert_eq!(10.966666666666665f64.sqrt(), std_deviation_sampling); let sum_of_squares = extended_stats.sum_of_squares.unwrap(); - assert_eq!(54.83333333333333f64, sum_of_squares); + assert_eq!(215.0, sum_of_squares); let avg = extended_stats.avg.unwrap(); assert_eq!(5.166666666666667, avg); } @@ -1134,7 +1585,7 @@ mod tests { assert!(extended_stats.std_deviation_population.is_none()); assert!(extended_stats.std_deviation_sampling.is_none()); let sum_of_squares = extended_stats.sum_of_squares.unwrap(); - assert_eq!(0.0f64, sum_of_squares); + assert_eq!(1.0f64, sum_of_squares); } #[test] @@ -1149,7 +1600,7 @@ mod tests { let mut intermediate_extend_stats = IntermediateExtendedStats::default(); intermediate_extend_stats.merge_fruits(intermediate_extend_stats1); let extended_stats = intermediate_extend_stats.finalize(); - const EXPECTED_VARIANCE : f64 = 2.0; + const EXPECTED_VARIANCE: f64 = 2.0; let variance = extended_stats.variance.unwrap(); assert_eq!(EXPECTED_VARIANCE, variance); let variance_population = extended_stats.variance_population.unwrap(); @@ -1163,7 +1614,7 @@ mod tests { let std_deviation_sampling = extended_stats.std_deviation_sampling.unwrap(); assert_eq!(2.5f64.sqrt(), std_deviation_sampling); let sum_of_squares = extended_stats.sum_of_squares.unwrap(); - assert_eq!(10f64, sum_of_squares); + assert_eq!(55f64, sum_of_squares); } #[test] @@ -1192,7 +1643,7 @@ mod tests { let std_deviation_sampling = extended_stats.std_deviation_sampling.unwrap(); assert_eq!(2.5f64.sqrt(), std_deviation_sampling); let sum_of_squares = extended_stats.sum_of_squares.unwrap(); - assert_eq!(10f64, sum_of_squares); + assert_eq!(55f64, sum_of_squares); let mut intermediate_extend_stats = IntermediateExtendedStats::default(); intermediate_extend_stats.collect(1.0f64); @@ -1204,7 +1655,7 @@ mod tests { intermediate_extend_stats1.collect(10.0f64); intermediate_extend_stats.merge_fruits(intermediate_extend_stats1); let extended_stats = intermediate_extend_stats.finalize(); - const EXPECTED_VARIANCE : f64 = 9.138888888888888; + const EXPECTED_VARIANCE: f64 = 9.138888888888888; let variance = extended_stats.variance.unwrap(); assert_eq!(EXPECTED_VARIANCE, variance); let variance_population = extended_stats.variance_population.unwrap(); @@ -1218,13 +1669,11 @@ mod tests { let std_deviation_sampling = extended_stats.std_deviation_sampling.unwrap(); assert_eq!(10.966666666666665f64.sqrt(), std_deviation_sampling); let sum_of_squares = extended_stats.sum_of_squares.unwrap(); - assert_eq!(54.83333333333333f64, sum_of_squares); + assert_eq!(215f64, sum_of_squares); let avg = extended_stats.avg.unwrap(); assert_eq!(5.166666666666667, avg); } - - #[test] fn merge_and_then_collect_non_empty_extended_stats() { let mut intermediate_extend_stats = IntermediateExtendedStats::default(); @@ -1238,23 +1687,30 @@ mod tests { intermediate_extend_stats.merge_fruits(intermediate_extend_stats1); intermediate_extend_stats.collect(4.0f64); let extended_stats = intermediate_extend_stats.finalize(); - const EXPECTED_VARIANCE : f64 = 9.138888888888888; + const EXPECTED_VARIANCE: f64 = 9.138888888888888; let variance = extended_stats.variance.unwrap(); - assert_relative_eq!(EXPECTED_VARIANCE,variance, epsilon = EPSILON_FOR_TEST); - let variance_population = extended_stats.variance_population.unwrap(); - assert_relative_eq!(EXPECTED_VARIANCE,variance_population, epsilon = EPSILON_FOR_TEST); + assert_nearly_equals!(EXPECTED_VARIANCE, variance, EPSILON_FOR_TEST); + let variance_population = extended_stats.variance_population.unwrap(); + assert_nearly_equals!(EXPECTED_VARIANCE, variance_population, EPSILON_FOR_TEST); let variance_sampling = extended_stats.variance_sampling.unwrap(); - assert_relative_eq!(10.966666666666665,variance_sampling, epsilon = EPSILON_FOR_TEST); + assert_nearly_equals!(10.966666666666665, variance_sampling, EPSILON_FOR_TEST); let std_deviation = extended_stats.std_deviation.unwrap(); - assert_relative_eq!(EXPECTED_VARIANCE.sqrt(),std_deviation, epsilon = EPSILON_FOR_TEST); + assert_nearly_equals!(EXPECTED_VARIANCE.sqrt(), std_deviation, EPSILON_FOR_TEST); let std_deviation_population = extended_stats.std_deviation_population.unwrap(); - assert_relative_eq!(EXPECTED_VARIANCE.sqrt(),std_deviation_population, epsilon = EPSILON_FOR_TEST); + assert_nearly_equals!( + EXPECTED_VARIANCE.sqrt(), + std_deviation_population, + EPSILON_FOR_TEST + ); let std_deviation_sampling = extended_stats.std_deviation_sampling.unwrap(); - assert_relative_eq!(10.966666666666665_f64.sqrt(),std_deviation_sampling, epsilon = EPSILON_FOR_TEST); + assert_nearly_equals!( + 10.966666666666665_f64.sqrt(), + std_deviation_sampling, + EPSILON_FOR_TEST + ); let sum_of_squares = extended_stats.sum_of_squares.unwrap(); - assert_relative_eq!(54.83333333333333,sum_of_squares, epsilon = EPSILON_FOR_TEST); + assert_eq!(215.0, sum_of_squares); let avg = extended_stats.avg.unwrap(); assert_eq!(5.166666666666667, avg); } - } diff --git a/src/lib.rs b/src/lib.rs index 61751e8dd..1196fea1f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -418,6 +418,23 @@ pub mod tests { } } }}; + ($left:expr, $right:expr, $epsilon:expr) => {{ + match (&$left, &$right, &$epsilon) { + (left_val, right_val, epsilon_val) => { + let diff = (left_val - right_val).abs(); + + if diff > *epsilon_val { + panic!( + r#"assertion failed: `abs(left-right)>epsilon` + left: `{:?}`, + right: `{:?}`, + epsilon: `{:?}`"#, + &*left_val, &*right_val, &*epsilon_val + ) + } + } + } + }}; } pub fn generate_nonunique_unsorted(max_value: u32, n_elems: usize) -> Vec {