Created struct for request and response

This commit is contained in:
Giovanni Cuccu
2023-11-04 16:02:10 +01:00
parent 86bdb8b95c
commit db91df9f70
6 changed files with 282 additions and 169 deletions

View File

@@ -35,7 +35,7 @@ use super::bucket::{
};
use super::metric::{
AverageAggregation, CountAggregation, MaxAggregation, MinAggregation,
PercentilesAggregationReq, StatsAggregation, SumAggregation,
PercentilesAggregationReq, StatsAggregation, SumAggregation, ExtendedStatsAggregation
};
/// The top-level aggregation request structure, which contains [`Aggregation`] and their user
@@ -141,6 +141,11 @@ pub enum AggregationVariants {
/// extracted values.
#[serde(rename = "stats")]
Stats(StatsAggregation),
/// Computes a collection of extended statistics (`min`, `max`, `sum`, `count`, `avg`,
/// `sum_of_squares`, `variance`, `variance_sampling`, `standard_deviation`,
/// `standard_deviation_sampling`) over the extracted values.
#[serde(rename = "extended_stats")]
ExtendedStats(ExtendedStatsAggregation),
/// Computes the sum of the extracted values.
#[serde(rename = "sum")]
Sum(SumAggregation),
@@ -162,6 +167,7 @@ impl AggregationVariants {
AggregationVariants::Max(max) => max.field_name(),
AggregationVariants::Min(min) => min.field_name(),
AggregationVariants::Stats(stats) => stats.field_name(),
AggregationVariants::ExtendedStats(extended_stats) => extended_stats.field_name(),
AggregationVariants::Sum(sum) => sum.field_name(),
AggregationVariants::Percentiles(per) => per.field_name(),
}

View File

@@ -9,7 +9,7 @@ use super::bucket::{
};
use super::metric::{
AverageAggregation, CountAggregation, MaxAggregation, MinAggregation, StatsAggregation,
SumAggregation,
SumAggregation, ExtendedStatsAggregation,
};
use super::segment_agg_result::AggregationLimits;
use super::VecWithNames;
@@ -229,6 +229,9 @@ impl AggregationWithAccessor {
| Stats(StatsAggregation {
field: field_name, ..
})
| ExtendedStats(ExtendedStatsAggregation {
field: field_name, ..
})
| Sum(SumAggregation {
field: field_name, ..
}) => {

View File

@@ -8,7 +8,7 @@ use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use super::bucket::GetDocCount;
use super::metric::{PercentilesMetricResult, SingleMetricResult, Stats};
use super::metric::{PercentilesMetricResult, SingleMetricResult, Stats, ExtendedStats};
use super::{AggregationError, Key};
use crate::TantivyError;
@@ -88,6 +88,8 @@ pub enum MetricResult {
Min(SingleMetricResult),
/// Stats metric result.
Stats(Stats),
/// ExtendedStats metric result.
ExtendedStats(ExtendedStats),
/// Sum metric result.
Sum(SingleMetricResult),
/// Sum metric result.
@@ -102,6 +104,7 @@ impl MetricResult {
MetricResult::Max(max) => Ok(max.value),
MetricResult::Min(min) => Ok(min.value),
MetricResult::Stats(stats) => stats.get_value(agg_property),
MetricResult::ExtendedStats(extended_stats) => extended_stats.get_value(agg_property),
MetricResult::Sum(sum) => Ok(sum.value),
MetricResult::Percentiles(_) => Err(TantivyError::AggregationError(
AggregationError::InvalidRequest("percentiles can't be used to order".to_string()),

View File

@@ -19,7 +19,7 @@ use super::bucket::{
};
use super::metric::{
IntermediateAverage, IntermediateCount, IntermediateMax, IntermediateMin, IntermediateStats,
IntermediateSum, PercentilesCollector,
IntermediateSum, PercentilesCollector,IntermediateExtendedStats,
};
use super::segment_agg_result::AggregationLimits;
use super::{format_date, AggregationError, Key, SerializedKey};
@@ -199,6 +199,9 @@ pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult
Stats(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::Stats(
IntermediateStats::default(),
)),
ExtendedStats(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::ExtendedStats(
IntermediateExtendedStats::default(),
)),
Sum(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::Sum(
IntermediateSum::default(),
)),
@@ -263,6 +266,8 @@ pub enum IntermediateMetricResult {
Min(IntermediateMin),
/// Intermediate stats result.
Stats(IntermediateStats),
/// Intermediate extended stats result.
ExtendedStats(IntermediateExtendedStats),
/// Intermediate sum result.
Sum(IntermediateSum),
}
@@ -285,6 +290,9 @@ impl IntermediateMetricResult {
IntermediateMetricResult::Stats(intermediate_stats) => {
MetricResult::Stats(intermediate_stats.finalize())
}
IntermediateMetricResult::ExtendedStats(intermediate_stats) => {
MetricResult::ExtendedStats(intermediate_stats.finalize())
}
IntermediateMetricResult::Sum(intermediate_sum) => {
MetricResult::Sum(intermediate_sum.finalize().into())
}

View File

@@ -51,6 +51,41 @@ impl StatsAggregation {
}
}
/// Request parameters for the `extended_stats` metric aggregation.
///
/// Only `field` is required when deserializing; `missing` and `sigma` default to `None`
/// when they are absent from the request.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ExtendedStatsAggregation {
/// The field name to compute the stats on.
pub field: String,
/// The missing parameter defines how documents that are missing a value should be treated.
/// By default they will be ignored but it is also possible to treat them as if they had a
/// value. Example in JSON format:
/// { "field": "my_numbers", "missing": 10.0 }
#[serde(default)]
pub missing: Option<f64>,
/// The sigma parameter is meant to control how the standard deviation bounds are calculated.
/// This can be a useful way to visualize variance of your data.
/// When omitted, the collector falls back to 2. Example in JSON format:
/// { "field": "my_numbers", "sigma": 3.0 }
#[serde(default)]
pub sigma: Option<f64>,
}
impl ExtendedStatsAggregation {
/// Creates a new [`ExtendedStatsAggregation`] instance from a field name.
pub fn from_field_name(field_name: String) -> Self {
ExtendedStatsAggregation {
field: field_name,
missing: None,
sigma: None,
}
}
/// Returns the field name the aggregation is computed on.
pub fn field_name(&self) -> &str {
&self.field
}
}
/// Stats contains a collection of statistics.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Stats {
@@ -81,6 +116,60 @@ impl Stats {
}
}
/// Extended stats contains a collection of statistics.
/// It extends the plain stats result with the sum of squares, the variance and the
/// standard deviation, each of the latter two in a population and a sampling flavour.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ExtendedStats {
/// The number of collected values.
pub count: u64,
/// The sum of the fast field values.
pub sum: f64,
/// The min value of the fast field values. `None` if count equals zero.
pub min: Option<f64>,
/// The max value of the fast field values. `None` if count equals zero.
pub max: Option<f64>,
/// The average of the fast field values. `None` if count equals zero.
pub avg: Option<f64>,
/// The sum of squares of the fast field values. `None` if count equals zero.
/// As filled in by `IntermediateExtendedStats::finalize`, this is the accumulated sum of
/// squared differences from the mean.
pub sum_of_squares: Option<f64>,
/// The population variance of the fast field values (`sum_of_squares / count`).
/// `None` if count is less than 2.
pub variance: Option<f64>,
/// The population variance of the fast field values, always equal to `variance`.
/// `None` if count is less than 2.
pub variance_population: Option<f64>,
/// The sampling variance of the fast field values (`sum_of_squares / (count - 1)`).
/// `None` if count is less than 2.
pub variance_sampling: Option<f64>,
/// The standard deviation of the fast field values, i.e. the square root of `variance`.
/// `None` if count is less than 2.
pub standard_deviation: Option<f64>,
/// The population standard deviation of the fast field values, always equal to
/// `standard_deviation`. `None` if count is less than 2.
pub standard_deviation_population: Option<f64>,
/// The sampling standard deviation of the fast field values, i.e. the square root of
/// `variance_sampling`. `None` if count is less than 2.
pub standard_deviation_sampling: Option<f64>,
}
impl ExtendedStats {
    /// Looks up one statistic by its property name.
    ///
    /// `count` and `sum` are always present; every other property is returned as stored and
    /// may therefore be `None`.
    ///
    /// # Errors
    ///
    /// Returns `TantivyError::InvalidArgument` when `agg_property` is not one of the known
    /// property names.
    pub(crate) fn get_value(&self, agg_property: &str) -> crate::Result<Option<f64>> {
        let value = match agg_property {
            "count" => Some(self.count as f64),
            "sum" => Some(self.sum),
            "min" => self.min,
            "max" => self.max,
            "avg" => self.avg,
            "variance" => self.variance,
            "variance_sampling" => self.variance_sampling,
            "variance_population" => self.variance_population,
            "sum_of_squares" => self.sum_of_squares,
            "standard_deviation" => self.standard_deviation,
            "standard_deviation_sampling" => self.standard_deviation_sampling,
            "standard_deviation_population" => self.standard_deviation_population,
            _ => {
                let message =
                    format!("Unknown property {agg_property} on stats metric aggregation");
                return Err(TantivyError::InvalidArgument(message));
            }
        };
        Ok(value)
    }
}
/*
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
@@ -185,6 +274,157 @@ impl IntermediateStats {
}
/// Intermediate result of the extended stats aggregation that can be combined with other intermediate
/// results.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct IntermediateExtendedStats {
/// The number of extracted values.
count: u64,
/// The sum of the extracted values.
sum: f64,
/// The min value. An empty accumulator holds `f64::MAX` here.
min: f64,
/// The max value. An empty accumulator holds `f64::MIN` here.
max: f64,
/// The running sum of squared differences from the mean, referred to as `M2` in Welford's
/// online algorithm. Note that this is not the plain sum of the squared values.
sum_of_squares: f64,
/// The running mean, an intermediate value needed for calculating the variance
/// as per [Welford's online algorithm](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm).
mean: f64,
/// The value intended for computing standard deviation bounds.
/// NOTE(review): the methods defined next to this struct store it but never read it;
/// confirm where the bounds are supposed to be produced.
sigma: f64,
}
impl Default for IntermediateExtendedStats {
    /// Returns an empty accumulator using the fallback sigma of `2.0`.
    fn default() -> Self {
        // `with_sigma(None)` yields exactly the empty state with `sigma == 2.0`, so the
        // initial values are spelled out in a single place.
        Self::with_sigma(None)
    }
}
impl IntermediateExtendedStats {
    /// Creates an empty accumulator. `sigma` falls back to `2.0` when it is `None`.
    pub fn with_sigma(sigma: Option<f64>) -> Self {
        let sigma = sigma.unwrap_or(2.0);
        Self {
            count: 0,
            sum: 0.0,
            min: f64::MAX,
            max: f64::MIN,
            sum_of_squares: 0.0,
            mean: 0.0,
            sigma,
        }
    }

    /// Merges the other stats intermediate result into self.
    ///
    /// `sigma` of `other` is not taken into account; `self` keeps its own value.
    pub fn merge_fruits(&mut self, other: IntermediateExtendedStats) {
        self.min = self.min.min(other.min);
        self.max = self.max.max(other.max);
        if other.count == 0 {
            // Nothing was collected on the other side: count, sum and variance state stay.
            return;
        }
        if self.count == 0 {
            // Nothing collected here yet: adopt the other side's running state.
            self.count = other.count;
            self.mean = other.mean;
            self.sum_of_squares = other.sum_of_squares;
        } else {
            // Parallel version of Welford's online algorithm. Both means are derived from
            // sum and count because that is more precise (and sum is already available).
            let merged_count = self.count + other.count;
            let own_avg = self.sum / self.count as f64;
            let other_avg = other.sum / other.count as f64;
            let delta = other_avg - own_avg;
            let cross_term =
                delta * delta * self.count as f64 * other.count as f64 / merged_count as f64;
            self.sum_of_squares += other.sum_of_squares + cross_term;
            self.mean = (self.sum + other.sum) / merged_count as f64;
            self.count = merged_count;
        }
        self.sum += other.sum;
    }

    /// Computes the final stats value.
    ///
    /// `min`, `max`, `avg` and `sum_of_squares` are `None` for an empty accumulator; the
    /// variance and standard deviation fields are `None` unless at least two values were
    /// collected.
    pub fn finalize(&self) -> ExtendedStats {
        let has_values = self.count != 0;
        let has_spread = self.count > 1;

        let variance = has_spread.then(|| self.sum_of_squares / self.count as f64);
        let variance_sampling = has_spread.then(|| self.sum_of_squares / (self.count - 1) as f64);
        let standard_deviation = variance.map(f64::sqrt);
        let standard_deviation_sampling = variance_sampling.map(f64::sqrt);

        ExtendedStats {
            count: self.count,
            sum: self.sum,
            min: has_values.then(|| self.min),
            max: has_values.then(|| self.max),
            avg: has_values.then(|| self.mean),
            sum_of_squares: has_values.then(|| self.sum_of_squares),
            variance,
            variance_population: variance,
            variance_sampling,
            standard_deviation,
            standard_deviation_population: standard_deviation,
            standard_deviation_sampling,
        }
    }

    /// Feeds one value into the accumulator.
    fn collect(&mut self, value: f64) {
        self.min = self.min.min(value);
        self.max = self.max.max(value);
        // `update_variance` relies on `count` and `sum` already including `value`.
        self.count += 1;
        self.sum += value;
        self.update_variance(value);
    }

    /// Updates `mean` and `sum_of_squares` for a value already added to `count` and `sum`.
    fn update_variance(&mut self, value: f64) {
        let previous_mean = self.mean;
        // This is not what Welford's online algorithm prescribes: the mean is recomputed
        // from sum and count. Using the pseudo code from wikipedia there was a small
        // rounding error (in the 15th decimal place) that caused a test
        // (test_aggregation_level1 in agg_test.rs) failure.
        self.mean = self.sum / self.count as f64;
        self.sum_of_squares += (value - previous_mean) * (value - self.mean);
    }
}
#[derive(Clone, Debug, PartialEq)]
pub(crate) enum SegmentStatsType {
Average,
@@ -198,6 +438,7 @@ pub(crate) enum SegmentStatsType {
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct SegmentStatsCollector {
missing: Option<u64>,
sigma: Option<f64>,
field_type: ColumnType,
pub(crate) collecting_for: SegmentStatsType,
pub(crate) stats: IntermediateExtendedStats,
@@ -211,14 +452,16 @@ impl SegmentStatsCollector {
collecting_for: SegmentStatsType,
accessor_idx: usize,
missing: Option<f64>,
sigma: Option<f64>,
) -> Self {
let missing = missing.and_then(|val| f64_to_fastfield_u64(val, &field_type));
Self {
field_type,
collecting_for,
stats: IntermediateExtendedStats::default(),
stats: IntermediateExtendedStats::with_sigma(sigma),
accessor_idx,
missing,
sigma,
val_cache: Default::default(),
}
}
@@ -324,169 +567,6 @@ impl SegmentAggregationCollector for SegmentStatsCollector {
}
}
/// Extended stats contains a collection of statistics.
/// It extends the plain stats result with the sum of squares, the variance and the
/// standard deviation, each of the latter two in a population and a sampling flavour.
pub struct ExtendedStats {
/// The number of collected values.
pub count: u64,
/// The sum of the fast field values.
pub sum: f64,
/// The min value of the fast field values. `None` if count equals zero.
pub min: Option<f64>,
/// The max value of the fast field values. `None` if count equals zero.
pub max: Option<f64>,
/// The average of the fast field values. `None` if count equals zero.
pub avg: Option<f64>,
/// The sum of squares of the fast field values. `None` if count equals zero.
/// As filled in by `IntermediateExtendedStats::finalize`, this is the accumulated sum of
/// squared differences from the mean.
pub sum_of_squares: Option<f64>,
/// The population variance of the fast field values (`sum_of_squares / count`).
/// `None` if count is less than 2.
pub variance: Option<f64>,
/// The population variance of the fast field values, always equal to `variance`.
/// `None` if count is less than 2.
pub variance_population: Option<f64>,
/// The sampling variance of the fast field values (`sum_of_squares / (count - 1)`).
/// `None` if count is less than 2.
pub variance_sampling: Option<f64>,
/// The standard deviation of the fast field values, i.e. the square root of `variance`.
/// `None` if count is less than 2.
pub standard_deviation: Option<f64>,
/// The population standard deviation of the fast field values, always equal to
/// `standard_deviation`. `None` if count is less than 2.
pub standard_deviation_population: Option<f64>,
/// The sampling standard deviation of the fast field values, i.e. the square root of
/// `variance_sampling`. `None` if count is less than 2.
pub standard_deviation_sampling: Option<f64>,
}
/// Intermediate result of the extended stats aggregation that can be combined with other intermediate
/// results.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct IntermediateExtendedStats {
/// The number of extracted values.
count: u64,
/// The sum of the extracted values.
sum: f64,
/// The min value. An empty accumulator holds `f64::MAX` here.
min: f64,
/// The max value. An empty accumulator holds `f64::MIN` here.
max: f64,
/// The running sum of squared differences from the mean, referred to as `M2` in Welford's
/// online algorithm. Note that this is not the plain sum of the squared values.
sum_of_squares: f64,
/// The running mean, an intermediate value needed for calculating the variance
/// as per [Welford's online algorithm](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm).
mean: f64,
}
impl Default for IntermediateExtendedStats {
fn default() -> Self {
Self {
count: 0,
sum: 0.0,
min: f64::MAX,
max: f64::MIN,
sum_of_squares: 0.0,
mean: 0.0,
}
}
}
impl IntermediateExtendedStats {
    /// Merges the other stats intermediate result into self.
    pub fn merge_fruits(&mut self, other: IntermediateExtendedStats) {
        self.min = self.min.min(other.min);
        self.max = self.max.max(other.max);
        if other.count == 0 {
            // Nothing was collected on the other side: count, sum and variance state stay.
            return;
        }
        if self.count == 0 {
            // Nothing collected here yet: adopt the other side's running state.
            self.count = other.count;
            self.mean = other.mean;
            self.sum_of_squares = other.sum_of_squares;
        } else {
            // Combine both partial results; the means are derived from sum and count.
            let merged_count = self.count + other.count;
            let own_avg = self.sum / self.count as f64;
            let other_avg = other.sum / other.count as f64;
            let delta = other_avg - own_avg;
            let cross_term =
                delta * delta * self.count as f64 * other.count as f64 / merged_count as f64;
            self.sum_of_squares += other.sum_of_squares + cross_term;
            self.mean = (self.sum + other.sum) / merged_count as f64;
            self.count = merged_count;
        }
        self.sum += other.sum;
    }

    /// Computes the final stats value.
    ///
    /// `min`, `max`, `avg` and `sum_of_squares` are `None` for an empty accumulator; the
    /// variance and standard deviation fields are `None` unless at least two values were
    /// collected.
    pub fn finalize(&self) -> ExtendedStats {
        let has_values = self.count != 0;
        let has_spread = self.count > 1;

        let variance = has_spread.then(|| self.sum_of_squares / self.count as f64);
        let variance_sampling = has_spread.then(|| self.sum_of_squares / (self.count - 1) as f64);
        let standard_deviation = variance.map(f64::sqrt);
        let standard_deviation_sampling = variance_sampling.map(f64::sqrt);

        ExtendedStats {
            count: self.count,
            sum: self.sum,
            min: has_values.then(|| self.min),
            max: has_values.then(|| self.max),
            avg: has_values.then(|| self.mean),
            sum_of_squares: has_values.then(|| self.sum_of_squares),
            variance,
            variance_population: variance,
            variance_sampling,
            standard_deviation,
            standard_deviation_population: standard_deviation,
            standard_deviation_sampling,
        }
    }

    /// Feeds one value into the accumulator.
    fn collect(&mut self, value: f64) {
        self.min = self.min.min(value);
        self.max = self.max.max(value);
        // `update_variance` relies on `count` and `sum` already including `value`.
        self.count += 1;
        self.sum += value;
        self.update_variance(value);
    }

    /// Updates `mean` and `sum_of_squares` for a value already added to `count` and `sum`.
    fn update_variance(&mut self, value: f64) {
        let previous_mean = self.mean;
        // This is not what Welford's online algorithm prescribes: the mean is recomputed
        // from sum and count. Using the pseudo code from wikipedia there was a small
        // rounding error (in the 15th decimal place) that caused a test
        // (test_aggregation_level1 in agg_test.rs) failure.
        self.mean = self.sum / self.count as f64;
        self.sum_of_squares += (value - previous_mean) * (value - self.mean);
    }
}
#[cfg(test)]
mod tests {

View File

@@ -13,7 +13,7 @@ use super::intermediate_agg_result::IntermediateAggregationResults;
use super::metric::{
AverageAggregation, CountAggregation, MaxAggregation, MinAggregation,
SegmentPercentilesCollector, SegmentStatsCollector, SegmentStatsType, StatsAggregation,
SumAggregation,
SumAggregation, ExtendedStatsAggregation,
};
use crate::aggregation::bucket::TermMissingAgg;
@@ -121,6 +121,7 @@ pub(crate) fn build_single_agg_segment_collector(
SegmentStatsType::Average,
accessor_idx,
*missing,
None,
)))
}
Count(CountAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
@@ -128,30 +129,42 @@ pub(crate) fn build_single_agg_segment_collector(
SegmentStatsType::Count,
accessor_idx,
*missing,
None,
))),
Max(MaxAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Max,
accessor_idx,
*missing,
None,
))),
Min(MinAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Min,
accessor_idx,
*missing,
None,
))),
Stats(StatsAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Stats,
accessor_idx,
*missing,
None,
))),
ExtendedStats(ExtendedStatsAggregation { missing, sigma, .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Stats,
accessor_idx,
*missing,
*sigma,
))),
Sum(SumAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Sum,
accessor_idx,
*missing,
None,
))),
Percentiles(percentiles_req) => Ok(Box::new(
SegmentPercentilesCollector::from_req_and_validate(