mirror of
https://github.com/quickwit-oss/tantivy.git
synced 2026-01-06 09:12:55 +00:00
Merge pull request #1794 from quickwit-oss/guilload/count-min-max-sum-aggs
Add count, min, max, and sum aggregations
This commit is contained in:
@@ -51,7 +51,10 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
pub use super::bucket::RangeAggregation;
|
||||
use super::bucket::{HistogramAggregation, TermsAggregation};
|
||||
use super::metric::{AverageAggregation, StatsAggregation};
|
||||
use super::metric::{
|
||||
AverageAggregation, CountAggregation, MaxAggregation, MinAggregation, StatsAggregation,
|
||||
SumAggregation,
|
||||
};
|
||||
use super::VecWithNames;
|
||||
|
||||
/// The top-level aggregation request structure, which contains [`Aggregation`] and their user
|
||||
@@ -237,20 +240,37 @@ impl BucketAggregationType {
|
||||
/// called multi-value numeric metrics aggregation.
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub enum MetricAggregation {
|
||||
/// Calculates the average.
|
||||
/// Computes the average.
|
||||
#[serde(rename = "avg")]
|
||||
Average(AverageAggregation),
|
||||
/// Counts the number of extracted values.
|
||||
#[serde(rename = "value_count")]
|
||||
Count(CountAggregation),
|
||||
/// Finds the maximum value.
|
||||
#[serde(rename = "max")]
|
||||
Max(MaxAggregation),
|
||||
/// Finds the minimum value.
|
||||
#[serde(rename = "min")]
|
||||
Min(MinAggregation),
|
||||
/// Calculates stats sum, average, min, max, standard_deviation on a field.
|
||||
#[serde(rename = "stats")]
|
||||
Stats(StatsAggregation),
|
||||
/// Computes the sum.
|
||||
#[serde(rename = "sum")]
|
||||
Sum(SumAggregation),
|
||||
}
|
||||
|
||||
impl MetricAggregation {
|
||||
fn get_fast_field_names(&self, fast_field_names: &mut HashSet<String>) {
|
||||
match self {
|
||||
MetricAggregation::Average(avg) => fast_field_names.insert(avg.field.to_string()),
|
||||
MetricAggregation::Stats(stats) => fast_field_names.insert(stats.field.to_string()),
|
||||
let fast_field_name = match self {
|
||||
MetricAggregation::Average(avg) => avg.field_name(),
|
||||
MetricAggregation::Count(count) => count.field_name(),
|
||||
MetricAggregation::Max(max) => max.field_name(),
|
||||
MetricAggregation::Min(min) => min.field_name(),
|
||||
MetricAggregation::Stats(stats) => stats.field_name(),
|
||||
MetricAggregation::Sum(sum) => sum.field_name(),
|
||||
};
|
||||
fast_field_names.insert(fast_field_name.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -258,6 +278,38 @@ impl MetricAggregation {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_metric_aggregations_deser() {
|
||||
let agg_req_json = r#"{
|
||||
"price_avg": { "avg": { "field": "price" } },
|
||||
"price_count": { "value_count": { "field": "price" } },
|
||||
"price_max": { "max": { "field": "price" } },
|
||||
"price_min": { "min": { "field": "price" } },
|
||||
"price_stats": { "stats": { "field": "price" } },
|
||||
"price_sum": { "sum": { "field": "price" } }
|
||||
}"#;
|
||||
let agg_req: Aggregations = serde_json::from_str(agg_req_json).unwrap();
|
||||
|
||||
assert!(
|
||||
matches!(agg_req.get("price_avg").unwrap(), Aggregation::Metric(MetricAggregation::Average(avg)) if avg.field == "price")
|
||||
);
|
||||
assert!(
|
||||
matches!(agg_req.get("price_count").unwrap(), Aggregation::Metric(MetricAggregation::Count(count)) if count.field == "price")
|
||||
);
|
||||
assert!(
|
||||
matches!(agg_req.get("price_max").unwrap(), Aggregation::Metric(MetricAggregation::Max(max)) if max.field == "price")
|
||||
);
|
||||
assert!(
|
||||
matches!(agg_req.get("price_min").unwrap(), Aggregation::Metric(MetricAggregation::Min(min)) if min.field == "price")
|
||||
);
|
||||
assert!(
|
||||
matches!(agg_req.get("price_stats").unwrap(), Aggregation::Metric(MetricAggregation::Stats(stats)) if stats.field == "price")
|
||||
);
|
||||
assert!(
|
||||
matches!(agg_req.get("price_sum").unwrap(), Aggregation::Metric(MetricAggregation::Sum(sum)) if sum.field == "price")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serialize_to_json_test() {
|
||||
let agg_req1: Aggregations = vec![(
|
||||
|
||||
@@ -8,7 +8,10 @@ use fastfield_codecs::Column;
|
||||
|
||||
use super::agg_req::{Aggregation, Aggregations, BucketAggregationType, MetricAggregation};
|
||||
use super::bucket::{HistogramAggregation, RangeAggregation, TermsAggregation};
|
||||
use super::metric::{AverageAggregation, StatsAggregation};
|
||||
use super::metric::{
|
||||
AverageAggregation, CountAggregation, MaxAggregation, MinAggregation, StatsAggregation,
|
||||
SumAggregation,
|
||||
};
|
||||
use super::segment_agg_result::BucketCount;
|
||||
use super::VecWithNames;
|
||||
use crate::fastfield::{type_and_cardinality, MultiValuedFastFieldReader};
|
||||
@@ -134,7 +137,11 @@ impl MetricAggregationWithAccessor {
|
||||
) -> crate::Result<MetricAggregationWithAccessor> {
|
||||
match &metric {
|
||||
MetricAggregation::Average(AverageAggregation { field: field_name })
|
||||
| MetricAggregation::Stats(StatsAggregation { field: field_name }) => {
|
||||
| MetricAggregation::Count(CountAggregation { field: field_name })
|
||||
| MetricAggregation::Max(MaxAggregation { field: field_name })
|
||||
| MetricAggregation::Min(MinAggregation { field: field_name })
|
||||
| MetricAggregation::Stats(StatsAggregation { field: field_name })
|
||||
| MetricAggregation::Sum(SumAggregation { field: field_name }) => {
|
||||
let (accessor, field_type) =
|
||||
get_ff_reader_and_validate(reader, field_name, Cardinality::SingleValue)?;
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ impl AggregationResults {
|
||||
} else {
|
||||
// Validation is be done during request parsing, so we can't reach this state.
|
||||
Err(TantivyError::InternalError(format!(
|
||||
"Can't find aggregation {:?} in sub_aggregations",
|
||||
"Can't find aggregation {:?} in sub-aggregations",
|
||||
name
|
||||
)))
|
||||
}
|
||||
@@ -70,27 +70,51 @@ impl AggregationResult {
|
||||
pub enum MetricResult {
|
||||
/// Average metric result.
|
||||
Average(SingleMetricResult),
|
||||
/// Count metric result.
|
||||
Count(SingleMetricResult),
|
||||
/// Max metric result.
|
||||
Max(SingleMetricResult),
|
||||
/// Min metric result.
|
||||
Min(SingleMetricResult),
|
||||
/// Stats metric result.
|
||||
Stats(Stats),
|
||||
/// Sum metric result.
|
||||
Sum(SingleMetricResult),
|
||||
}
|
||||
|
||||
impl MetricResult {
|
||||
fn get_value(&self, agg_property: &str) -> crate::Result<Option<f64>> {
|
||||
match self {
|
||||
MetricResult::Average(avg) => Ok(avg.value),
|
||||
MetricResult::Count(count) => Ok(count.value),
|
||||
MetricResult::Max(max) => Ok(max.value),
|
||||
MetricResult::Min(min) => Ok(min.value),
|
||||
MetricResult::Stats(stats) => stats.get_value(agg_property),
|
||||
MetricResult::Sum(sum) => Ok(sum.value),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl From<IntermediateMetricResult> for MetricResult {
|
||||
fn from(metric: IntermediateMetricResult) -> Self {
|
||||
match metric {
|
||||
IntermediateMetricResult::Average(avg_data) => {
|
||||
MetricResult::Average(avg_data.finalize().into())
|
||||
IntermediateMetricResult::Average(intermediate_avg) => {
|
||||
MetricResult::Average(intermediate_avg.finalize().into())
|
||||
}
|
||||
IntermediateMetricResult::Count(intermediate_count) => {
|
||||
MetricResult::Count(intermediate_count.finalize().into())
|
||||
}
|
||||
IntermediateMetricResult::Max(intermediate_max) => {
|
||||
MetricResult::Max(intermediate_max.finalize().into())
|
||||
}
|
||||
IntermediateMetricResult::Min(intermediate_min) => {
|
||||
MetricResult::Min(intermediate_min.finalize().into())
|
||||
}
|
||||
IntermediateMetricResult::Stats(intermediate_stats) => {
|
||||
MetricResult::Stats(intermediate_stats.finalize())
|
||||
}
|
||||
IntermediateMetricResult::Sum(intermediate_sum) => {
|
||||
MetricResult::Sum(intermediate_sum.finalize().into())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -100,13 +124,13 @@ impl From<IntermediateMetricResult> for MetricResult {
|
||||
#[serde(untagged)]
|
||||
pub enum BucketResult {
|
||||
/// This is the range entry for a bucket, which contains a key, count, from, to, and optionally
|
||||
/// sub_aggregations.
|
||||
/// sub-aggregations.
|
||||
Range {
|
||||
/// The range buckets sorted by range.
|
||||
buckets: BucketEntries<RangeBucketEntry>,
|
||||
},
|
||||
/// This is the histogram entry for a bucket, which contains a key, count, and optionally
|
||||
/// sub_aggregations.
|
||||
/// sub-aggregations.
|
||||
Histogram {
|
||||
/// The buckets.
|
||||
///
|
||||
@@ -151,7 +175,7 @@ pub enum BucketEntries<T> {
|
||||
}
|
||||
|
||||
/// This is the default entry for a bucket, which contains a key, count, and optionally
|
||||
/// sub_aggregations.
|
||||
/// sub-aggregations.
|
||||
///
|
||||
/// # JSON Format
|
||||
/// ```json
|
||||
@@ -201,7 +225,7 @@ impl GetDocCount for BucketEntry {
|
||||
}
|
||||
|
||||
/// This is the range entry for a bucket, which contains a key, count, and optionally
|
||||
/// sub_aggregations.
|
||||
/// sub-aggregations.
|
||||
///
|
||||
/// # JSON Format
|
||||
/// ```json
|
||||
@@ -237,7 +261,7 @@ pub struct RangeBucketEntry {
|
||||
/// Number of documents in the bucket.
|
||||
pub doc_count: u64,
|
||||
#[serde(flatten)]
|
||||
/// sub-aggregations in this bucket.
|
||||
/// Sub-aggregations in this bucket.
|
||||
pub sub_aggregation: AggregationResults,
|
||||
/// The from range of the bucket. Equals `f64::MIN` when `None`.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
|
||||
@@ -17,7 +17,10 @@ use super::bucket::{
|
||||
cut_off_buckets, get_agg_name_and_property, intermediate_histogram_buckets_to_final_buckets,
|
||||
GetDocCount, Order, OrderTarget, SegmentHistogramBucketEntry, TermsAggregation,
|
||||
};
|
||||
use super::metric::{IntermediateAverage, IntermediateStats};
|
||||
use super::metric::{
|
||||
IntermediateAverage, IntermediateCount, IntermediateMax, IntermediateMin, IntermediateStats,
|
||||
IntermediateSum,
|
||||
};
|
||||
use super::segment_agg_result::SegmentMetricResultCollector;
|
||||
use super::{format_date, Key, SerializedKey, VecWithNames};
|
||||
use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry};
|
||||
@@ -204,22 +207,42 @@ pub enum IntermediateAggregationResult {
|
||||
/// Holds the intermediate data for metric results
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub enum IntermediateMetricResult {
|
||||
/// Intermediate average result
|
||||
/// Intermediate average result.
|
||||
Average(IntermediateAverage),
|
||||
/// Intermediate stats result
|
||||
/// Intermediate count result.
|
||||
Count(IntermediateCount),
|
||||
/// Intermediate max result.
|
||||
Max(IntermediateMax),
|
||||
/// Intermediate min result.
|
||||
Min(IntermediateMin),
|
||||
/// Intermediate stats result.
|
||||
Stats(IntermediateStats),
|
||||
/// Intermediate sum result.
|
||||
Sum(IntermediateSum),
|
||||
}
|
||||
|
||||
impl From<SegmentMetricResultCollector> for IntermediateMetricResult {
|
||||
fn from(tree: SegmentMetricResultCollector) -> Self {
|
||||
match tree {
|
||||
SegmentMetricResultCollector::Stats(collector) => match collector.collecting_for {
|
||||
super::metric::SegmentStatsType::Average => IntermediateMetricResult::Average(
|
||||
IntermediateAverage::from_collector(collector),
|
||||
),
|
||||
super::metric::SegmentStatsType::Count => {
|
||||
IntermediateMetricResult::Count(IntermediateCount::from_collector(collector))
|
||||
}
|
||||
super::metric::SegmentStatsType::Max => {
|
||||
IntermediateMetricResult::Max(IntermediateMax::from_collector(collector))
|
||||
}
|
||||
super::metric::SegmentStatsType::Min => {
|
||||
IntermediateMetricResult::Min(IntermediateMin::from_collector(collector))
|
||||
}
|
||||
super::metric::SegmentStatsType::Stats => {
|
||||
IntermediateMetricResult::Stats(collector.stats)
|
||||
}
|
||||
super::metric::SegmentStatsType::Avg => IntermediateMetricResult::Average(
|
||||
IntermediateAverage::from_collector(collector),
|
||||
),
|
||||
super::metric::SegmentStatsType::Sum => {
|
||||
IntermediateMetricResult::Sum(IntermediateSum::from_collector(collector))
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -231,18 +254,36 @@ impl IntermediateMetricResult {
|
||||
MetricAggregation::Average(_) => {
|
||||
IntermediateMetricResult::Average(IntermediateAverage::default())
|
||||
}
|
||||
MetricAggregation::Count(_) => {
|
||||
IntermediateMetricResult::Count(IntermediateCount::default())
|
||||
}
|
||||
MetricAggregation::Max(_) => IntermediateMetricResult::Max(IntermediateMax::default()),
|
||||
MetricAggregation::Min(_) => IntermediateMetricResult::Min(IntermediateMin::default()),
|
||||
MetricAggregation::Stats(_) => {
|
||||
IntermediateMetricResult::Stats(IntermediateStats::default())
|
||||
}
|
||||
MetricAggregation::Sum(_) => IntermediateMetricResult::Sum(IntermediateSum::default()),
|
||||
}
|
||||
}
|
||||
fn merge_fruits(&mut self, other: IntermediateMetricResult) {
|
||||
match (self, other) {
|
||||
(
|
||||
IntermediateMetricResult::Average(avg_data_left),
|
||||
IntermediateMetricResult::Average(avg_data_right),
|
||||
IntermediateMetricResult::Average(avg_left),
|
||||
IntermediateMetricResult::Average(avg_right),
|
||||
) => {
|
||||
avg_data_left.merge_fruits(avg_data_right);
|
||||
avg_left.merge_fruits(avg_right);
|
||||
}
|
||||
(
|
||||
IntermediateMetricResult::Count(count_left),
|
||||
IntermediateMetricResult::Count(count_right),
|
||||
) => {
|
||||
count_left.merge_fruits(count_right);
|
||||
}
|
||||
(IntermediateMetricResult::Max(max_left), IntermediateMetricResult::Max(max_right)) => {
|
||||
max_left.merge_fruits(max_right);
|
||||
}
|
||||
(IntermediateMetricResult::Min(min_left), IntermediateMetricResult::Min(min_right)) => {
|
||||
min_left.merge_fruits(min_right);
|
||||
}
|
||||
(
|
||||
IntermediateMetricResult::Stats(stats_left),
|
||||
@@ -250,6 +291,9 @@ impl IntermediateMetricResult {
|
||||
) => {
|
||||
stats_left.merge_fruits(stats_right);
|
||||
}
|
||||
(IntermediateMetricResult::Sum(sum_left), IntermediateMetricResult::Sum(sum_right)) => {
|
||||
sum_left.merge_fruits(sum_right);
|
||||
}
|
||||
_ => {
|
||||
panic!("incompatible fruit types in tree");
|
||||
}
|
||||
|
||||
@@ -2,9 +2,8 @@ use std::fmt::Debug;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::SegmentStatsCollector;
|
||||
use super::{IntermediateStats, SegmentStatsCollector};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
/// A single-value metric aggregation that computes the average of numeric values that are
|
||||
/// extracted from the aggregated documents.
|
||||
/// Supported field types are u64, i64, and f64.
|
||||
@@ -18,47 +17,43 @@ use super::SegmentStatsCollector;
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct AverageAggregation {
|
||||
/// The field name to compute the stats on.
|
||||
/// The field name to compute the average on.
|
||||
pub field: String,
|
||||
}
|
||||
|
||||
impl AverageAggregation {
|
||||
/// Create new AverageAggregation from a field.
|
||||
/// Creates a new [`AverageAggregation`] instance from a field name.
|
||||
pub fn from_field_name(field_name: String) -> Self {
|
||||
AverageAggregation { field: field_name }
|
||||
Self { field: field_name }
|
||||
}
|
||||
/// Return the field name.
|
||||
/// Returns the field name the aggregation is computed on.
|
||||
pub fn field_name(&self) -> &str {
|
||||
&self.field
|
||||
}
|
||||
}
|
||||
|
||||
/// Contains mergeable version of average data.
|
||||
/// Intermediate result of the average aggregation that can be combined with other intermediate
|
||||
/// results.
|
||||
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct IntermediateAverage {
|
||||
pub(crate) sum: f64,
|
||||
pub(crate) doc_count: u64,
|
||||
stats: IntermediateStats,
|
||||
}
|
||||
|
||||
impl IntermediateAverage {
|
||||
/// Creates a new [`IntermediateAverage`] instance from a [`SegmentStatsCollector`].
|
||||
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
|
||||
Self {
|
||||
sum: collector.stats.sum,
|
||||
doc_count: collector.stats.count,
|
||||
stats: collector.stats,
|
||||
}
|
||||
}
|
||||
|
||||
/// Merge average data into this instance.
|
||||
/// Merges the other intermediate result into self.
|
||||
pub fn merge_fruits(&mut self, other: IntermediateAverage) {
|
||||
self.sum += other.sum;
|
||||
self.doc_count += other.doc_count;
|
||||
self.stats.merge_fruits(other.stats);
|
||||
}
|
||||
/// compute final result
|
||||
/// Computes the final average value.
|
||||
pub fn finalize(&self) -> Option<f64> {
|
||||
if self.doc_count == 0 {
|
||||
None
|
||||
} else {
|
||||
Some(self.sum / self.doc_count as f64)
|
||||
}
|
||||
self.stats.finalize().avg
|
||||
}
|
||||
}
|
||||
|
||||
59
src/aggregation/metric/count.rs
Normal file
59
src/aggregation/metric/count.rs
Normal file
@@ -0,0 +1,59 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::{IntermediateStats, SegmentStatsCollector};
|
||||
|
||||
/// A single-value metric aggregation that counts the number of values that are
|
||||
/// extracted from the aggregated documents.
|
||||
/// Supported field types are u64, i64, and f64.
|
||||
/// See [super::SingleMetricResult] for return value.
|
||||
///
|
||||
/// # JSON Format
|
||||
/// ```json
|
||||
/// {
|
||||
/// "value_count": {
|
||||
/// "field": "score",
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct CountAggregation {
|
||||
/// The field name to compute the minimum on.
|
||||
pub field: String,
|
||||
}
|
||||
|
||||
impl CountAggregation {
|
||||
/// Creates a new [`CountAggregation`] instance from a field name.
|
||||
pub fn from_field_name(field_name: String) -> Self {
|
||||
Self { field: field_name }
|
||||
}
|
||||
/// Returns the field name the aggregation is computed on.
|
||||
pub fn field_name(&self) -> &str {
|
||||
&self.field
|
||||
}
|
||||
}
|
||||
|
||||
/// Intermediate result of the count aggregation that can be combined with other intermediate
|
||||
/// results.
|
||||
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct IntermediateCount {
|
||||
stats: IntermediateStats,
|
||||
}
|
||||
|
||||
impl IntermediateCount {
|
||||
/// Creates a new [`IntermediateCount`] instance from a [`SegmentStatsCollector`].
|
||||
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
|
||||
Self {
|
||||
stats: collector.stats,
|
||||
}
|
||||
}
|
||||
/// Merges the other intermediate result into self.
|
||||
pub fn merge_fruits(&mut self, other: IntermediateCount) {
|
||||
self.stats.merge_fruits(other.stats);
|
||||
}
|
||||
/// Computes the final minimum value.
|
||||
pub fn finalize(&self) -> Option<f64> {
|
||||
Some(self.stats.finalize().count as f64)
|
||||
}
|
||||
}
|
||||
59
src/aggregation/metric/max.rs
Normal file
59
src/aggregation/metric/max.rs
Normal file
@@ -0,0 +1,59 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::{IntermediateStats, SegmentStatsCollector};
|
||||
|
||||
/// A single-value metric aggregation that computes the maximum of numeric values that are
|
||||
/// extracted from the aggregated documents.
|
||||
/// Supported field types are u64, i64, and f64.
|
||||
/// See [super::SingleMetricResult] for return value.
|
||||
///
|
||||
/// # JSON Format
|
||||
/// ```json
|
||||
/// {
|
||||
/// "max": {
|
||||
/// "field": "score",
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct MaxAggregation {
|
||||
/// The field name to compute the maximum on.
|
||||
pub field: String,
|
||||
}
|
||||
|
||||
impl MaxAggregation {
|
||||
/// Creates a new [`MaxAggregation`] instance from a field name.
|
||||
pub fn from_field_name(field_name: String) -> Self {
|
||||
Self { field: field_name }
|
||||
}
|
||||
/// Returns the field name the aggregation is computed on.
|
||||
pub fn field_name(&self) -> &str {
|
||||
&self.field
|
||||
}
|
||||
}
|
||||
|
||||
/// Intermediate result of the maximum aggregation that can be combined with other intermediate
|
||||
/// results.
|
||||
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct IntermediateMax {
|
||||
stats: IntermediateStats,
|
||||
}
|
||||
|
||||
impl IntermediateMax {
|
||||
/// Creates a new [`IntermediateMax`] instance from a [`SegmentStatsCollector`].
|
||||
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
|
||||
Self {
|
||||
stats: collector.stats,
|
||||
}
|
||||
}
|
||||
/// Merges the other intermediate result into self.
|
||||
pub fn merge_fruits(&mut self, other: IntermediateMax) {
|
||||
self.stats.merge_fruits(other.stats);
|
||||
}
|
||||
/// Computes the final maximum value.
|
||||
pub fn finalize(&self) -> Option<f64> {
|
||||
self.stats.finalize().max
|
||||
}
|
||||
}
|
||||
59
src/aggregation/metric/min.rs
Normal file
59
src/aggregation/metric/min.rs
Normal file
@@ -0,0 +1,59 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::{IntermediateStats, SegmentStatsCollector};
|
||||
|
||||
/// A single-value metric aggregation that computes the minimum of numeric values that are
|
||||
/// extracted from the aggregated documents.
|
||||
/// Supported field types are u64, i64, and f64.
|
||||
/// See [super::SingleMetricResult] for return value.
|
||||
///
|
||||
/// # JSON Format
|
||||
/// ```json
|
||||
/// {
|
||||
/// "min": {
|
||||
/// "field": "score",
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct MinAggregation {
|
||||
/// The field name to compute the minimum on.
|
||||
pub field: String,
|
||||
}
|
||||
|
||||
impl MinAggregation {
|
||||
/// Creates a new [`MinAggregation`] instance from a field name.
|
||||
pub fn from_field_name(field_name: String) -> Self {
|
||||
Self { field: field_name }
|
||||
}
|
||||
/// Returns the field name the aggregation is computed on.
|
||||
pub fn field_name(&self) -> &str {
|
||||
&self.field
|
||||
}
|
||||
}
|
||||
|
||||
/// Intermediate result of the minimum aggregation that can be combined with other intermediate
|
||||
/// results.
|
||||
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct IntermediateMin {
|
||||
stats: IntermediateStats,
|
||||
}
|
||||
|
||||
impl IntermediateMin {
|
||||
/// Creates a new [`IntermediateMin`] instance from a [`SegmentStatsCollector`].
|
||||
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
|
||||
Self {
|
||||
stats: collector.stats,
|
||||
}
|
||||
}
|
||||
/// Merges the other intermediate result into self.
|
||||
pub fn merge_fruits(&mut self, other: IntermediateMin) {
|
||||
self.stats.merge_fruits(other.stats);
|
||||
}
|
||||
/// Computes the final minimum value.
|
||||
pub fn finalize(&self) -> Option<f64> {
|
||||
self.stats.finalize().min
|
||||
}
|
||||
}
|
||||
@@ -3,10 +3,18 @@
|
||||
//! The aggregations in this family compute metrics, see [super::agg_req::MetricAggregation] for
|
||||
//! details.
|
||||
mod average;
|
||||
mod count;
|
||||
mod max;
|
||||
mod min;
|
||||
mod stats;
|
||||
mod sum;
|
||||
pub use average::*;
|
||||
pub use count::*;
|
||||
pub use max::*;
|
||||
pub use min::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
pub use stats::*;
|
||||
pub use sum::*;
|
||||
|
||||
/// Single-metric aggregations use this common result structure.
|
||||
///
|
||||
@@ -28,3 +36,61 @@ impl From<Option<f64>> for SingleMetricResult {
|
||||
Self { value }
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::aggregation::agg_req::Aggregations;
|
||||
use crate::aggregation::agg_result::AggregationResults;
|
||||
use crate::aggregation::AggregationCollector;
|
||||
use crate::query::AllQuery;
|
||||
use crate::schema::{Cardinality, NumericOptions, Schema};
|
||||
use crate::Index;
|
||||
|
||||
#[test]
|
||||
fn test_metric_aggregations() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let field_options = NumericOptions::default().set_fast(Cardinality::SingleValue);
|
||||
let field = schema_builder.add_f64_field("price", field_options);
|
||||
let index = Index::create_in_ram(schema_builder.build());
|
||||
let mut index_writer = index.writer_for_tests().unwrap();
|
||||
|
||||
for i in 0..3 {
|
||||
index_writer
|
||||
.add_document(doc!(
|
||||
field => i as f64,
|
||||
))
|
||||
.unwrap();
|
||||
}
|
||||
index_writer.commit().unwrap();
|
||||
|
||||
for i in 3..6 {
|
||||
index_writer
|
||||
.add_document(doc!(
|
||||
field => i as f64,
|
||||
))
|
||||
.unwrap();
|
||||
}
|
||||
index_writer.commit().unwrap();
|
||||
|
||||
let aggregations_json = r#"{
|
||||
"price_avg": { "avg": { "field": "price" } },
|
||||
"price_count": { "value_count": { "field": "price" } },
|
||||
"price_max": { "max": { "field": "price" } },
|
||||
"price_min": { "min": { "field": "price" } },
|
||||
"price_stats": { "stats": { "field": "price" } },
|
||||
"price_sum": { "sum": { "field": "price" } }
|
||||
}"#;
|
||||
let aggregations: Aggregations = serde_json::from_str(&aggregations_json).unwrap();
|
||||
let collector = AggregationCollector::from_aggs(aggregations, None, index.schema());
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
let aggregations_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();
|
||||
let aggregations_res_json = serde_json::to_value(&aggregations_res).unwrap();
|
||||
|
||||
assert_eq!(aggregations_res_json["price_avg"]["value"], 2.5);
|
||||
assert_eq!(aggregations_res_json["price_count"]["value"], 6.0);
|
||||
assert_eq!(aggregations_res_json["price_max"]["value"], 5.0);
|
||||
assert_eq!(aggregations_res_json["price_min"]["value"], 0.0);
|
||||
assert_eq!(aggregations_res_json["price_sum"]["value"], 15.0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,8 +5,8 @@ use crate::aggregation::f64_from_fastfield_u64;
|
||||
use crate::schema::Type;
|
||||
use crate::{DocId, TantivyError};
|
||||
|
||||
/// A multi-value metric aggregation that computes stats of numeric values that are
|
||||
/// extracted from the aggregated documents.
|
||||
/// A multi-value metric aggregation that computes a collection of statistics on numeric values that
|
||||
/// are extracted from the aggregated documents.
|
||||
/// Supported field types are `u64`, `i64`, and `f64`.
|
||||
/// See [`Stats`] for returned statistics.
|
||||
///
|
||||
@@ -26,11 +26,11 @@ pub struct StatsAggregation {
|
||||
}
|
||||
|
||||
impl StatsAggregation {
|
||||
/// Create new StatsAggregation from a field.
|
||||
/// Creates a new [`StatsAggregation`] instance from a field name.
|
||||
pub fn from_field_name(field_name: String) -> Self {
|
||||
StatsAggregation { field: field_name }
|
||||
}
|
||||
/// Return the field name.
|
||||
/// Returns the field name the aggregation is computed on.
|
||||
pub fn field_name(&self) -> &str {
|
||||
&self.field
|
||||
}
|
||||
@@ -43,13 +43,13 @@ pub struct Stats {
|
||||
pub count: u64,
|
||||
/// The sum of the fast field values.
|
||||
pub sum: f64,
|
||||
/// The standard deviation of the fast field values. `None` for count == 0.
|
||||
/// The standard deviation of the fast field values. `None` if count equals zero.
|
||||
pub standard_deviation: Option<f64>,
|
||||
/// The min value of the fast field values.
|
||||
pub min: Option<f64>,
|
||||
/// The max value of the fast field values.
|
||||
pub max: Option<f64>,
|
||||
/// The average of the values. `None` for count == 0.
|
||||
/// The average of the fast field values. `None` if count equals zero.
|
||||
pub avg: Option<f64>,
|
||||
}
|
||||
|
||||
@@ -63,27 +63,29 @@ impl Stats {
|
||||
"max" => Ok(self.max),
|
||||
"avg" => Ok(self.avg),
|
||||
_ => Err(TantivyError::InvalidArgument(format!(
|
||||
"unknown property {} on stats metric aggregation",
|
||||
"Unknown property {} on stats metric aggregation",
|
||||
agg_property
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// `IntermediateStats` contains the mergeable version for stats.
|
||||
/// Intermediate result of the stats aggregation that can be combined with other intermediate
|
||||
/// results.
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct IntermediateStats {
|
||||
/// the number of values
|
||||
/// The number of values.
|
||||
pub count: u64,
|
||||
/// the sum of the values
|
||||
/// The sum of the values.
|
||||
pub sum: f64,
|
||||
/// the squared sum of the values
|
||||
/// The sum of the squared values.
|
||||
pub squared_sum: f64,
|
||||
/// the min value of the values
|
||||
/// The min value of the values.
|
||||
pub min: f64,
|
||||
/// the max value of the values
|
||||
/// The max value of the values.
|
||||
pub max: f64,
|
||||
}
|
||||
|
||||
impl Default for IntermediateStats {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@@ -97,7 +99,7 @@ impl Default for IntermediateStats {
|
||||
}
|
||||
|
||||
impl IntermediateStats {
|
||||
pub(crate) fn avg(&self) -> Option<f64> {
|
||||
fn avg(&self) -> Option<f64> {
|
||||
if self.count == 0 {
|
||||
None
|
||||
} else {
|
||||
@@ -109,12 +111,12 @@ impl IntermediateStats {
|
||||
self.squared_sum / (self.count as f64)
|
||||
}
|
||||
|
||||
pub(crate) fn standard_deviation(&self) -> Option<f64> {
|
||||
fn standard_deviation(&self) -> Option<f64> {
|
||||
self.avg()
|
||||
.map(|average| (self.square_mean() - average * average).sqrt())
|
||||
}
|
||||
|
||||
/// Merge data from other stats into this instance.
|
||||
/// Merges the other stats intermediate result into self.
|
||||
pub fn merge_fruits(&mut self, other: IntermediateStats) {
|
||||
self.count += other.count;
|
||||
self.sum += other.sum;
|
||||
@@ -123,7 +125,7 @@ impl IntermediateStats {
|
||||
self.max = self.max.max(other.max);
|
||||
}
|
||||
|
||||
/// compute final resultimprove_docs
|
||||
/// Computes the final stats value.
|
||||
pub fn finalize(&self) -> Stats {
|
||||
let min = if self.count == 0 {
|
||||
None
|
||||
@@ -157,23 +159,27 @@ impl IntermediateStats {
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub(crate) enum SegmentStatsType {
|
||||
Average,
|
||||
Count,
|
||||
Max,
|
||||
Min,
|
||||
Stats,
|
||||
Avg,
|
||||
Sum,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub(crate) struct SegmentStatsCollector {
|
||||
pub(crate) stats: IntermediateStats,
|
||||
field_type: Type,
|
||||
pub(crate) collecting_for: SegmentStatsType,
|
||||
pub(crate) stats: IntermediateStats,
|
||||
}
|
||||
|
||||
impl SegmentStatsCollector {
|
||||
pub fn from_req(field_type: Type, collecting_for: SegmentStatsType) -> Self {
|
||||
Self {
|
||||
field_type,
|
||||
stats: IntermediateStats::default(),
|
||||
collecting_for,
|
||||
stats: IntermediateStats::default(),
|
||||
}
|
||||
}
|
||||
pub(crate) fn collect_block(&mut self, doc: &[DocId], field: &dyn Column<u64>) {
|
||||
|
||||
59
src/aggregation/metric/sum.rs
Normal file
59
src/aggregation/metric/sum.rs
Normal file
@@ -0,0 +1,59 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::{IntermediateStats, SegmentStatsCollector};
|
||||
|
||||
/// A single-value metric aggregation that sums up numeric values that are
|
||||
/// extracted from the aggregated documents.
|
||||
/// Supported field types are u64, i64, and f64.
|
||||
/// See [super::SingleMetricResult] for return value.
|
||||
///
|
||||
/// # JSON Format
|
||||
/// ```json
|
||||
/// {
|
||||
/// "sum": {
|
||||
/// "field": "score",
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct SumAggregation {
|
||||
/// The field name to compute the minimum on.
|
||||
pub field: String,
|
||||
}
|
||||
|
||||
impl SumAggregation {
|
||||
/// Creates a new [`SumAggregation`] instance from a field name.
|
||||
pub fn from_field_name(field_name: String) -> Self {
|
||||
Self { field: field_name }
|
||||
}
|
||||
/// Returns the field name the aggregation is computed on.
|
||||
pub fn field_name(&self) -> &str {
|
||||
&self.field
|
||||
}
|
||||
}
|
||||
|
||||
/// Intermediate result of the minimum aggregation that can be combined with other intermediate
|
||||
/// results.
|
||||
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct IntermediateSum {
|
||||
stats: IntermediateStats,
|
||||
}
|
||||
|
||||
impl IntermediateSum {
|
||||
/// Creates a new [`IntermediateSum`] instance from a [`SegmentStatsCollector`].
|
||||
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
|
||||
Self {
|
||||
stats: collector.stats,
|
||||
}
|
||||
}
|
||||
/// Merges the other intermediate result into self.
|
||||
pub fn merge_fruits(&mut self, other: IntermediateSum) {
|
||||
self.stats.merge_fruits(other.stats);
|
||||
}
|
||||
/// Computes the final minimum value.
|
||||
pub fn finalize(&self) -> Option<f64> {
|
||||
Some(self.stats.finalize().sum)
|
||||
}
|
||||
}
|
||||
@@ -216,8 +216,8 @@ impl<T: Clone> VecWithNames<T> {
|
||||
fn from_entries(mut entries: Vec<(String, T)>) -> Self {
|
||||
// Sort to ensure order of elements match across multiple instances
|
||||
entries.sort_by(|left, right| left.0.cmp(&right.0));
|
||||
let mut data = vec![];
|
||||
let mut data_names = vec![];
|
||||
let mut data = Vec::with_capacity(entries.len());
|
||||
let mut data_names = Vec::with_capacity(entries.len());
|
||||
for entry in entries {
|
||||
data_names.push(entry.0);
|
||||
data.push(entry.1);
|
||||
|
||||
@@ -15,7 +15,8 @@ use super::bucket::{SegmentHistogramCollector, SegmentRangeCollector, SegmentTer
|
||||
use super::collector::MAX_BUCKET_COUNT;
|
||||
use super::intermediate_agg_result::{IntermediateAggregationResults, IntermediateBucketResult};
|
||||
use super::metric::{
|
||||
AverageAggregation, SegmentStatsCollector, SegmentStatsType, StatsAggregation,
|
||||
AverageAggregation, CountAggregation, MaxAggregation, MinAggregation, SegmentStatsCollector,
|
||||
SegmentStatsType, StatsAggregation, SumAggregation,
|
||||
};
|
||||
use super::VecWithNames;
|
||||
use crate::aggregation::agg_req::BucketAggregationType;
|
||||
@@ -169,16 +170,36 @@ pub(crate) enum SegmentMetricResultCollector {
|
||||
impl SegmentMetricResultCollector {
|
||||
pub fn from_req_and_validate(req: &MetricAggregationWithAccessor) -> crate::Result<Self> {
|
||||
match &req.metric {
|
||||
MetricAggregation::Average(AverageAggregation { field: _ }) => {
|
||||
MetricAggregation::Average(AverageAggregation { .. }) => {
|
||||
Ok(SegmentMetricResultCollector::Stats(
|
||||
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Avg),
|
||||
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Average),
|
||||
))
|
||||
}
|
||||
MetricAggregation::Stats(StatsAggregation { field: _ }) => {
|
||||
MetricAggregation::Count(CountAggregation { .. }) => {
|
||||
Ok(SegmentMetricResultCollector::Stats(
|
||||
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Count),
|
||||
))
|
||||
}
|
||||
MetricAggregation::Max(MaxAggregation { .. }) => {
|
||||
Ok(SegmentMetricResultCollector::Stats(
|
||||
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Max),
|
||||
))
|
||||
}
|
||||
MetricAggregation::Min(MinAggregation { .. }) => {
|
||||
Ok(SegmentMetricResultCollector::Stats(
|
||||
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Min),
|
||||
))
|
||||
}
|
||||
MetricAggregation::Stats(StatsAggregation { .. }) => {
|
||||
Ok(SegmentMetricResultCollector::Stats(
|
||||
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Stats),
|
||||
))
|
||||
}
|
||||
MetricAggregation::Sum(SumAggregation { .. }) => {
|
||||
Ok(SegmentMetricResultCollector::Stats(
|
||||
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Sum),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
pub(crate) fn collect_block(&mut self, doc: &[DocId], metric: &MetricAggregationWithAccessor) {
|
||||
|
||||
Reference in New Issue
Block a user