diff --git a/src/aggregation/agg_req.rs b/src/aggregation/agg_req.rs index 8e3599108..3730a1aec 100644 --- a/src/aggregation/agg_req.rs +++ b/src/aggregation/agg_req.rs @@ -35,7 +35,7 @@ use super::bucket::{ }; use super::metric::{ AverageAggregation, CountAggregation, MaxAggregation, MinAggregation, - PercentilesAggregationReq, StatsAggregation, SumAggregation, + PercentilesAggregationReq, StatsAggregation, SumAggregation, ExtendedStatsAggregation }; /// The top-level aggregation request structure, which contains [`Aggregation`] and their user @@ -141,6 +141,11 @@ pub enum AggregationVariants { /// extracted values. #[serde(rename = "stats")] Stats(StatsAggregation), + /// Computes a collection of extended statistics (`min`, `max`, `sum`, `count`, `avg`, + /// `sum_of_squares`, `variance`, `variance_sampling`, `standard_deviation`, + /// `standard_deviation_sampling`) over the extracted values. + #[serde(rename = "extended_stats")] + ExtendedStats(ExtendedStatsAggregation), /// Computes the sum of the extracted values. 
#[serde(rename = "sum")] Sum(SumAggregation), @@ -162,6 +167,7 @@ impl AggregationVariants { AggregationVariants::Max(max) => max.field_name(), AggregationVariants::Min(min) => min.field_name(), AggregationVariants::Stats(stats) => stats.field_name(), + AggregationVariants::ExtendedStats(extended_stats) => extended_stats.field_name(), AggregationVariants::Sum(sum) => sum.field_name(), AggregationVariants::Percentiles(per) => per.field_name(), } diff --git a/src/aggregation/agg_req_with_accessor.rs b/src/aggregation/agg_req_with_accessor.rs index e6f960d05..d9b00cf1a 100644 --- a/src/aggregation/agg_req_with_accessor.rs +++ b/src/aggregation/agg_req_with_accessor.rs @@ -9,7 +9,7 @@ use super::bucket::{ }; use super::metric::{ AverageAggregation, CountAggregation, MaxAggregation, MinAggregation, StatsAggregation, - SumAggregation, + SumAggregation, ExtendedStatsAggregation, }; use super::segment_agg_result::AggregationLimits; use super::VecWithNames; @@ -229,6 +229,9 @@ impl AggregationWithAccessor { | Stats(StatsAggregation { field: field_name, .. }) + | ExtendedStats(ExtendedStatsAggregation { + field: field_name, .. + }) | Sum(SumAggregation { field: field_name, .. }) => { diff --git a/src/aggregation/agg_result.rs b/src/aggregation/agg_result.rs index ff9e7716f..64db3c66e 100644 --- a/src/aggregation/agg_result.rs +++ b/src/aggregation/agg_result.rs @@ -8,7 +8,7 @@ use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; use super::bucket::GetDocCount; -use super::metric::{PercentilesMetricResult, SingleMetricResult, Stats}; +use super::metric::{PercentilesMetricResult, SingleMetricResult, Stats, ExtendedStats}; use super::{AggregationError, Key}; use crate::TantivyError; @@ -88,6 +88,8 @@ pub enum MetricResult { Min(SingleMetricResult), /// Stats metric result. Stats(Stats), + /// ExtendedStats metric result. + ExtendedStats(ExtendedStats), /// Sum metric result. Sum(SingleMetricResult), /// Sum metric result. 
@@ -102,6 +104,7 @@ impl MetricResult { MetricResult::Max(max) => Ok(max.value), MetricResult::Min(min) => Ok(min.value), MetricResult::Stats(stats) => stats.get_value(agg_property), + MetricResult::ExtendedStats(extended_stats) => extended_stats.get_value(agg_property), MetricResult::Sum(sum) => Ok(sum.value), MetricResult::Percentiles(_) => Err(TantivyError::AggregationError( AggregationError::InvalidRequest("percentiles can't be used to order".to_string()), diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index 4bb056d5c..8193f68d7 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -19,7 +19,7 @@ use super::bucket::{ }; use super::metric::{ IntermediateAverage, IntermediateCount, IntermediateMax, IntermediateMin, IntermediateStats, - IntermediateSum, PercentilesCollector, + IntermediateSum, PercentilesCollector,IntermediateExtendedStats, }; use super::segment_agg_result::AggregationLimits; use super::{format_date, AggregationError, Key, SerializedKey}; @@ -199,6 +199,9 @@ pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult Stats(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::Stats( IntermediateStats::default(), )), + ExtendedStats(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::ExtendedStats( + IntermediateExtendedStats::default(), + )), Sum(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::Sum( IntermediateSum::default(), )), @@ -263,6 +266,8 @@ pub enum IntermediateMetricResult { Min(IntermediateMin), /// Intermediate stats result. Stats(IntermediateStats), + /// Intermediate extended stats result. + ExtendedStats(IntermediateExtendedStats), /// Intermediate sum result. 
Sum(IntermediateSum), } @@ -285,6 +290,9 @@ impl IntermediateMetricResult { IntermediateMetricResult::Stats(intermediate_stats) => { MetricResult::Stats(intermediate_stats.finalize()) } + IntermediateMetricResult::ExtendedStats(intermediate_stats) => { + MetricResult::ExtendedStats(intermediate_stats.finalize()) + } IntermediateMetricResult::Sum(intermediate_sum) => { MetricResult::Sum(intermediate_sum.finalize().into()) } diff --git a/src/aggregation/metric/stats.rs b/src/aggregation/metric/stats.rs index e1bd383b4..aa3cb90d7 100644 --- a/src/aggregation/metric/stats.rs +++ b/src/aggregation/metric/stats.rs @@ -51,6 +51,41 @@ impl StatsAggregation { } } + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct ExtendedStatsAggregation { + /// The field name to compute the stats on. + pub field: String, + /// The missing parameter defines how documents that are missing a value should be treated. + /// By default they will be ignored but it is also possible to treat them as if they had a + /// value. Examples in JSON format: + /// { "field": "my_numbers", "missing": "10.0" } + #[serde(default)] + pub missing: Option, + /// The sigma parameter defines how the standard deviation bounds are calculated. + /// This can be a useful way to visualize the variance of your data. + /// The default value is 2. Examples in JSON format: + /// { "field": "my_numbers", "sigma": "3.0" } + #[serde(default)] + pub sigma: Option, +} + +impl ExtendedStatsAggregation { + /// Creates a new [`ExtendedStatsAggregation`] instance from a field name. + pub fn from_field_name(field_name: String) -> Self { + ExtendedStatsAggregation { + field: field_name, + missing: None, + sigma: None, + } + } + /// Returns the field name the aggregation is computed on. + pub fn field_name(&self) -> &str { + &self.field + } +} + + /// Stats contains a collection of statistics. 
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct Stats { @@ -81,6 +116,60 @@ impl Stats { } } +/// Extended stats contains a collection of statistics. +/// It extends stats by adding the sum of squares, variance and standard deviation +/// (population and sampling variants); no bound information is included. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct ExtendedStats { + /// The number of documents. + pub count: u64, + /// The sum of the fast field values. + pub sum: f64, + /// The min value of the fast field values. + pub min: Option, + /// The max value of the fast field values. + pub max: Option, + /// The average of the fast field values. `None` if count equals zero. + pub avg: Option, + /// The sum of squared differences from the mean of the fast field values (Welford's M2, not the raw sum of squared values). `None` if count equals zero. + pub sum_of_squares: Option, + /// The variance of the fast field values. `None` if count is less than 2. + pub variance: Option, + /// The population variance of the fast field values, always equal to variance. `None` if count is less than 2. + pub variance_population: Option, + /// The sampling variance of the fast field values (divides by `count - 1` instead of `count`). `None` if count is less than 2. + pub variance_sampling: Option, + /// The standard deviation of the fast field values. `None` if count is less than 2. + pub standard_deviation: Option, + /// The population standard deviation of the fast field values, always equal to standard_deviation. `None` if count is less than 2. + pub standard_deviation_population: Option, + /// The sampling standard deviation of the fast field values (square root of variance_sampling). `None` if count is less than 2. 
+ pub standard_deviation_sampling: Option, +} + +impl ExtendedStats { + pub(crate) fn get_value(&self, agg_property: &str) -> crate::Result> { + match agg_property { + "count" => Ok(Some(self.count as f64)), + "sum" => Ok(Some(self.sum)), + "min" => Ok(self.min), + "max" => Ok(self.max), + "avg" => Ok(self.avg), + "variance" => Ok(self.variance), + "variance_sampling" => Ok(self.variance_sampling), + "variance_population" => Ok(self.variance_population), + "sum_of_squares" => Ok(self.sum_of_squares), + "standard_deviation" => Ok(self.standard_deviation), + "standard_deviation_sampling" => Ok(self.standard_deviation_sampling), + "standard_deviation_population" => Ok(self.standard_deviation_population), + _ => Err(TantivyError::InvalidArgument(format!( + "Unknown property {agg_property} on stats metric aggregation" + ))), + } + } +} + + /* #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] @@ -185,6 +274,157 @@ impl IntermediateStats { } +/// Intermediate result of the extended stats aggregation that can be combined with other intermediate +/// results. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct IntermediateExtendedStats { + /// The number of extracted values. + count: u64, + /// The sum of the extracted values. + sum: f64, + /// The min value. + min: f64, + /// The max value. 
+ max: f64, + // The sum of squared differences from the mean, referred to as M2 in Welford's online algorithm + sum_of_squares: f64, + // The mean, an intermediate value needed for calculating the variance + // as per [Welford's online algorithm](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm) + mean: f64, + // the value used for computing standard deviation bounds + sigma: f64, +} + +impl Default for IntermediateExtendedStats { + fn default() -> Self { + Self { + count: 0, + sum: 0.0, + min: f64::MAX, + max: f64::MIN, + sum_of_squares: 0.0, + mean: 0.0, + sigma: 2.0, + } + } +} + +impl IntermediateExtendedStats { + + pub fn with_sigma(sigma: Option) -> Self { + Self { + count: 0, + sum: 0.0, + min: f64::MAX, + max: f64::MIN, + sum_of_squares: 0.0, + mean: 0.0, + sigma: sigma.unwrap_or(2.0), + } + } + /// Merges the other stats intermediate result into self. + pub fn merge_fruits(&mut self, other: IntermediateExtendedStats) { + + self.min = self.min.min(other.min); + self.max = self.max.max(other.max); + + + if other.count!=0 { + if self.count==0 { + self.sum_of_squares=other.sum_of_squares; + self.count=other.count; + self.mean=other.mean; + } else { + // parallel version of Welford's online algorithm + // the mean is computed using sum and count because + // it's more precise (and sum is already available) + let new_count=self.count+other.count; + let delta = other.sum/other.count as f64 - self.sum/self.count as f64; + self.sum_of_squares += other.sum_of_squares + delta * delta * self.count as f64 * other.count as f64/new_count as f64; + self.count =new_count; + //self.mean=self.mean + delta*other.count as f64/new_count as f64; + self.mean=(self.sum as f64 + other.sum as f64)/new_count as f64; + + } + self.sum += other.sum; + } + + } + + + + /// Computes the final stats value. 
+ pub fn finalize(&self) -> ExtendedStats { + let min = if self.count == 0 { + None + } else { + Some(self.min) + }; + let max = if self.count == 0 { + None + } else { + Some(self.max) + }; + let avg = if self.count == 0 { + None + } else { + Some(self.mean) + }; + let sum_of_squares = if self.count == 0 { + None + } else { + Some(self.sum_of_squares) + }; + let variance = if self.count <= 1 { + None + } else { + Some(self.sum_of_squares/self.count as f64) + }; + let variance_sampling = if self.count <= 1 { + None + } else { + Some(self.sum_of_squares/(self.count-1) as f64) + }; + let standard_deviation = variance.map(|v| v.sqrt()); + let standard_deviation_sampling = variance_sampling.map(|v| v.sqrt()); + + ExtendedStats { + count: self.count, + sum: self.sum, + min, + max, + avg, + sum_of_squares, + variance, + variance_population: variance, + variance_sampling, + standard_deviation, + standard_deviation_population: standard_deviation, + standard_deviation_sampling + } + } + + fn collect(&mut self, value: f64) { + self.count += 1; + self.sum += value; + self.min = self.min.min(value); + self.max = self.max.max(value); + self.update_variance(value); + } + + fn update_variance(&mut self, value: f64) { + let delta = value - self.mean; + //this is not what the Welford's online algorithm prescribes but + //using the pseudo code from wikipedia there was a small rounding + //error (in 15th decimal place) that caused a test + //(test_aggregation_level1 in agg_test.rs) + //failure + self.mean = self.sum / self.count as f64; + let delta2 = value - self.mean; + self.sum_of_squares += delta * delta2; + } +} + #[derive(Clone, Debug, PartialEq)] pub(crate) enum SegmentStatsType { Average, @@ -198,6 +438,7 @@ pub(crate) enum SegmentStatsType { #[derive(Clone, Debug, PartialEq)] pub(crate) struct SegmentStatsCollector { missing: Option, + sigma: Option, field_type: ColumnType, pub(crate) collecting_for: SegmentStatsType, pub(crate) stats: IntermediateExtendedStats, @@ -211,14 
+452,16 @@ impl SegmentStatsCollector { collecting_for: SegmentStatsType, accessor_idx: usize, missing: Option, + sigma: Option, ) -> Self { let missing = missing.and_then(|val| f64_to_fastfield_u64(val, &field_type)); Self { field_type, collecting_for, - stats: IntermediateExtendedStats::default(), + stats: IntermediateExtendedStats::with_sigma(sigma), accessor_idx, missing, + sigma, val_cache: Default::default(), } } @@ -324,169 +567,6 @@ impl SegmentAggregationCollector for SegmentStatsCollector { } } -/// Extended stats contains a collection of statistics -/// they extends stats adding variance, standard deviation -/// and bound informations -pub struct ExtendedStats { - /// The number of documents. - pub count: u64, - /// The sum of the fast field values. - pub sum: f64, - /// The min value of the fast field values. - pub min: Option, - /// The max value of the fast field values. - pub max: Option, - /// The average of the fast field values. `None` if count equals zero. - pub avg: Option, - /// The sum of squares of the fast field values. `None` if count equals zero. - pub sum_of_squares: Option, - /// The variance of the fast field values. `None` if count is less then 2. - pub variance: Option, - /// The variance population of the fast field values, always equal to variance. `None` if count is less then 2. - pub variance_population: Option, - /// The variance sampling of the fast field values, always equal to variance. `None` if count is less then 2. - pub variance_sampling: Option, - /// The standard deviation of the fast field values. `None` if count is less then 2. - pub standard_deviation: Option, - /// The standard deviation of the fast field values, always equal to variance. `None` if count is less then 2. - pub standard_deviation_population: Option, - /// The standard deviation sampling of the fast field values, always equal to variance. `None` if count is less then 2. 
- pub standard_deviation_sampling: Option, -} - -/// Intermediate result of the extended stats aggregation that can be combined with other intermediate -/// results. -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct IntermediateExtendedStats { - /// The number of extracted values. - count: u64, - /// The sum of the extracted values. - sum: f64, - /// The min value. - min: f64, - /// The max value. - max: f64, - // The sum of the square values it's referred as M2 in Welford's online algorithm - sum_of_squares: f64, - // The mean an intermediate value need for calculating the variance - // as per [Welford's online algorithm](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm) - mean: f64, -} - -impl Default for IntermediateExtendedStats { - fn default() -> Self { - Self { - count: 0, - sum: 0.0, - min: f64::MAX, - max: f64::MIN, - sum_of_squares: 0.0, - mean: 0.0, - } - } -} - -impl IntermediateExtendedStats { - /// Merges the other stats intermediate result into self. - pub fn merge_fruits(&mut self, other: IntermediateExtendedStats) { - - self.min = self.min.min(other.min); - self.max = self.max.max(other.max); - - - if other.count!=0 { - if self.count==0 { - self.sum_of_squares=other.sum_of_squares; - self.count=other.count; - self.mean=other.mean; - } else { - let new_count=self.count+other.count; - let delta = other.sum/other.count as f64 - self.sum/self.count as f64; - self.sum_of_squares += other.sum_of_squares + delta * delta * self.count as f64 * other.count as f64/new_count as f64; - self.count =new_count; - //self.mean=self.mean + delta*other.count as f64/new_count as f64; - self.mean=(self.sum as f64 + other.sum as f64)/new_count as f64; - - } - self.sum += other.sum; - } - - } - - - - /// Computes the final stats value. 
- pub fn finalize(&self) -> ExtendedStats { - let min = if self.count == 0 { - None - } else { - Some(self.min) - }; - let max = if self.count == 0 { - None - } else { - Some(self.max) - }; - let avg = if self.count == 0 { - None - } else { - Some(self.mean) - }; - let sum_of_squares = if self.count == 0 { - None - } else { - Some(self.sum_of_squares) - }; - let variance = if self.count <= 1 { - None - } else { - Some(self.sum_of_squares/self.count as f64) - }; - let variance_sampling = if self.count <= 1 { - None - } else { - Some(self.sum_of_squares/(self.count-1) as f64) - }; - let standard_deviation = variance.map(|v| v.sqrt()); - let standard_deviation_sampling = variance_sampling.map(|v| v.sqrt()); - - ExtendedStats { - count: self.count, - sum: self.sum, - min, - max, - avg, - sum_of_squares, - variance, - variance_population: variance, - variance_sampling, - standard_deviation, - standard_deviation_population: standard_deviation, - standard_deviation_sampling - } - } - - fn collect(&mut self, value: f64) { - self.count += 1; - self.sum += value; - self.min = self.min.min(value); - self.max = self.max.max(value); - self.update_variance(value); - } - - fn update_variance(&mut self, value: f64) { - let delta = value - self.mean; - //this is not what the Welford's online algorithm prescribes but - //using the pseudo code from wikipedia there was a small rounding - //error (in 15th decimal place) that caused a test - //(test_aggregation_level1 in agg_test.rs) - //failure - self.mean = self.sum / self.count as f64; - let delta2 = value - self.mean; - self.sum_of_squares += delta * delta2; - } -} - #[cfg(test)] mod tests { diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index e57579647..8ae9382b3 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -13,7 +13,7 @@ use super::intermediate_agg_result::IntermediateAggregationResults; use super::metric::{ AverageAggregation, 
CountAggregation, MaxAggregation, MinAggregation, SegmentPercentilesCollector, SegmentStatsCollector, SegmentStatsType, StatsAggregation, - SumAggregation, + SumAggregation, ExtendedStatsAggregation, }; use crate::aggregation::bucket::TermMissingAgg; @@ -121,6 +121,7 @@ pub(crate) fn build_single_agg_segment_collector( SegmentStatsType::Average, accessor_idx, *missing, + None, ))) } Count(CountAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req( @@ -128,30 +129,42 @@ pub(crate) fn build_single_agg_segment_collector( SegmentStatsType::Count, accessor_idx, *missing, + None, ))), Max(MaxAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req( req.field_type, SegmentStatsType::Max, accessor_idx, *missing, + None, ))), Min(MinAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req( req.field_type, SegmentStatsType::Min, accessor_idx, *missing, + None, ))), Stats(StatsAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req( req.field_type, SegmentStatsType::Stats, accessor_idx, *missing, + None, ))), + ExtendedStats(ExtendedStatsAggregation { missing, sigma, .. }) => Ok(Box::new(SegmentStatsCollector::from_req( + req.field_type, + SegmentStatsType::Stats, + accessor_idx, + *missing, + *sigma, + ))), Sum(SumAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req( req.field_type, SegmentStatsType::Sum, accessor_idx, *missing, + None, ))), Percentiles(percentiles_req) => Ok(Box::new( SegmentPercentilesCollector::from_req_and_validate(