Files
tantivy/src/aggregation/agg_req.rs
PSeitz f8e79271ab Replace AggregationsWithAccessor (#2715)
* add nested histogram-termagg benchmark

* Replace AggregationsWithAccessor with AggData

With AggregationsWithAccessor pre-computation and caching was done on the collector level.
If you have 10000 sub collectors (e.g. a term aggregation with sub aggregations) this is very inefficient.
`AggData` instead moves the data from the collector to a node which reflects the cardinality of the request tree instead of the cardinality of the segment collector.
It also moves the global struct shared with all aggregations into aggregation-specific structs. So each aggregation has its own space to store cached data and aggregation-specific information.

This also somewhat loosens the dependency on the Elasticsearch aggregation structure.

Due to lifetime issues, we move the agg request specific object out of `AggData` during the collection and move it back at the end (for now). That's some unnecessary work, which costs CPU.

This allows better caching and will also pave the way for another potential optimization, by separating the collector and its storage. Currently we allocate a new collector for each sub aggregation bucket (for nested aggregations), but ideally we would have just one collector instance.

* renames

* move request data to agg request files

---------

Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
2025-10-14 09:22:11 +02:00

389 lines
13 KiB
Rust

//! Contains the aggregation request tree. Used to build an
//! [`AggregationCollector`](super::AggregationCollector).
//!
//! [`Aggregations`] is the top level entry point to create a request, which is a `HashMap<String,
//! Aggregation>`.
//!
//! Requests are compatible with the json format of elasticsearch.
//!
//! # Example
//!
//! ```
//! use tantivy::aggregation::agg_req::Aggregations;
//!
//! let elasticsearch_compatible_json_req = r#"
//! {
//! "range": {
//! "range": {
//! "field": "score",
//! "ranges": [
//! { "from": 3.0, "to": 7.0 },
//! { "from": 7.0, "to": 20.0 }
//! ]
//! }
//! }
//! }"#;
//! let _agg_req: Aggregations = serde_json::from_str(elasticsearch_compatible_json_req).unwrap();
//! ```
use std::collections::{HashMap, HashSet};
use serde::{Deserialize, Serialize};
use super::bucket::{
DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation,
};
use super::metric::{
AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation,
MaxAggregation, MinAggregation, PercentilesAggregationReq, StatsAggregation, SumAggregation,
TopHitsAggregationReq,
};
/// The top-level aggregation request structure, which maps user defined names to their
/// [`Aggregation`]. It is also used in bucket aggregations to define sub-aggregations.
///
/// The key is the user defined name of the aggregation.
pub type Aggregations = HashMap<String, Aggregation>;
/// Aggregation request.
///
/// An aggregation is either a bucket or a metric.
///
/// Deserialization does not use the derived field-by-field logic directly: because of the
/// `try_from` attribute below, input is first read into `AggregationForDeserialization` and
/// then converted.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(try_from = "AggregationForDeserialization")]
pub struct Aggregation {
/// The aggregation variant, which can be either a bucket or a metric.
///
/// Flattened, so the variant's key sits directly next to `aggs` in the serialized object.
#[serde(flatten)]
pub agg: AggregationVariants,
/// The sub-aggregations, as used by bucket aggregations. Each bucket will aggregate
/// on the document set in the bucket.
///
/// Stored under the `aggs` key and left out of the serialized output when empty.
#[serde(rename = "aggs")]
#[serde(skip_serializing_if = "Aggregations::is_empty")]
pub sub_aggregation: Aggregations,
}
/// Intermediate deserialization target for [`Aggregation`].
///
/// Flattening the aggregation enum directly would not let us display a proper
/// error message. This struct therefore splits the input in two: the `aggs`
/// entry is parsed right away, while everything else is kept as raw JSON and
/// turned into the aggregation variant in a second step.
#[derive(Deserialize)]
struct AggregationForDeserialization {
    /// Every key except `aggs`, still unparsed.
    #[serde(flatten)]
    pub aggs_remaining_json: serde_json::Value,
    /// Sub-aggregations read from the `aggs` key; empty when the key is absent.
    #[serde(default, rename = "aggs")]
    pub sub_aggregation: Aggregations,
}
impl TryFrom<AggregationForDeserialization> for Aggregation {
    type Error = serde_json::Error;

    /// Parses the raw JSON left over after splitting off `aggs` into an
    /// [`AggregationVariants`] and pairs it with the already parsed
    /// sub-aggregations.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error raised while parsing the aggregation variant.
    fn try_from(value: AggregationForDeserialization) -> serde_json::Result<Self> {
        let agg = serde_json::from_value::<AggregationVariants>(value.aggs_remaining_json)?;
        Ok(Self {
            agg,
            sub_aggregation: value.sub_aggregation,
        })
    }
}
impl Aggregation {
pub(crate) fn sub_aggregation(&self) -> &Aggregations {
&self.sub_aggregation
}
fn get_fast_field_names(&self, fast_field_names: &mut HashSet<String>) {
fast_field_names.extend(
self.agg
.get_fast_field_names()
.iter()
.map(|s| s.to_string()),
);
fast_field_names.extend(get_fast_field_names(&self.sub_aggregation));
}
}
/// Collects the names of all fast fields referenced anywhere in the request tree.
///
/// Each top-level aggregation adds its own field names and those of its
/// sub-aggregations to the returned set.
pub fn get_fast_field_names(aggs: &Aggregations) -> HashSet<String> {
    let mut collected: HashSet<String> = HashSet::new();
    aggs.values()
        .for_each(|agg| agg.get_fast_field_names(&mut collected));
    collected
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
/// All aggregation types.
///
/// Each variant is (de)serialized under the key given in its `rename` attribute.
pub enum AggregationVariants {
// Bucket aggregation types
/// Put data into buckets of user-defined ranges.
#[serde(rename = "range")]
Range(RangeAggregation),
/// Put data into a histogram.
#[serde(rename = "histogram")]
Histogram(HistogramAggregation),
/// Put data into a date histogram.
#[serde(rename = "date_histogram")]
DateHistogram(DateHistogramAggregationReq),
/// Put data into buckets of terms.
#[serde(rename = "terms")]
Terms(TermsAggregation),
// Metric aggregation types
/// Computes the average of the extracted values.
#[serde(rename = "avg")]
Average(AverageAggregation),
/// Counts the number of extracted values.
#[serde(rename = "value_count")]
Count(CountAggregation),
/// Finds the maximum value.
#[serde(rename = "max")]
Max(MaxAggregation),
/// Finds the minimum value.
#[serde(rename = "min")]
Min(MinAggregation),
/// Computes a collection of statistics (`min`, `max`, `sum`, `count`, and `avg`) over the
/// extracted values.
#[serde(rename = "stats")]
Stats(StatsAggregation),
/// Computes a collection of extended statistics (`min`, `max`, `sum`, `count`, `avg`,
/// `sum_of_squares`, `variance`, `variance_sampling`, `std_deviation`,
/// `std_deviation_sampling`) over the extracted values.
#[serde(rename = "extended_stats")]
ExtendedStats(ExtendedStatsAggregation),
/// Computes the sum of the extracted values.
#[serde(rename = "sum")]
Sum(SumAggregation),
/// Computes percentiles over the extracted values.
#[serde(rename = "percentiles")]
Percentiles(PercentilesAggregationReq),
/// Finds the top k values matching some order
#[serde(rename = "top_hits")]
TopHits(TopHitsAggregationReq),
/// Computes an estimate of the number of unique values
#[serde(rename = "cardinality")]
Cardinality(CardinalityAggregationReq),
}
impl AggregationVariants {
/// Returns the name of the fields used by the aggregation.
pub fn get_fast_field_names(&self) -> Vec<&str> {
match self {
AggregationVariants::Terms(terms) => vec![terms.field.as_str()],
AggregationVariants::Range(range) => vec![range.field.as_str()],
AggregationVariants::Histogram(histogram) => vec![histogram.field.as_str()],
AggregationVariants::DateHistogram(histogram) => vec![histogram.field.as_str()],
AggregationVariants::Average(avg) => vec![avg.field_name()],
AggregationVariants::Count(count) => vec![count.field_name()],
AggregationVariants::Max(max) => vec![max.field_name()],
AggregationVariants::Min(min) => vec![min.field_name()],
AggregationVariants::Stats(stats) => vec![stats.field_name()],
AggregationVariants::ExtendedStats(extended_stats) => vec![extended_stats.field_name()],
AggregationVariants::Sum(sum) => vec![sum.field_name()],
AggregationVariants::Percentiles(per) => vec![per.field_name()],
AggregationVariants::TopHits(top_hits) => top_hits.field_names(),
AggregationVariants::Cardinality(per) => vec![per.field_name()],
}
}
pub(crate) fn as_range(&self) -> Option<&RangeAggregation> {
match &self {
AggregationVariants::Range(range) => Some(range),
_ => None,
}
}
pub(crate) fn as_histogram(&self) -> crate::Result<Option<HistogramAggregation>> {
match &self {
AggregationVariants::Histogram(histogram) => Ok(Some(histogram.clone())),
AggregationVariants::DateHistogram(histogram) => {
Ok(Some(histogram.to_histogram_req()?))
}
_ => Ok(None),
}
}
pub(crate) fn as_term(&self) -> Option<&TermsAggregation> {
match &self {
AggregationVariants::Terms(terms) => Some(terms),
_ => None,
}
}
pub(crate) fn as_percentile(&self) -> Option<&PercentilesAggregationReq> {
match &self {
AggregationVariants::Percentiles(percentile_req) => Some(percentile_req),
_ => None,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A request made only of metric aggregations deserializes.
    #[test]
    fn deser_json_test() {
        let agg_req_json = r#"{
"price_avg": { "avg": { "field": "price" } },
"price_count": { "value_count": { "field": "price" } },
"price_max": { "max": { "field": "price" } },
"price_min": { "min": { "field": "price" } },
"price_stats": { "stats": { "field": "price" } },
"price_sum": { "sum": { "field": "price" } }
}"#;
        let _agg_req: Aggregations = serde_json::from_str(agg_req_json).unwrap();
    }

    /// Bucket aggregations carrying sub-aggregations under `aggs` deserialize.
    #[test]
    fn deser_json_test_bucket() {
        let agg_req_json = r#"
{
"termagg": {
"terms": {
"field": "json.mixed_type",
"order": { "min_price": "desc" }
},
"aggs": {
"min_price": { "min": { "field": "json.mixed_type" } }
}
},
"rangeagg": {
"range": {
"field": "json.mixed_type",
"ranges": [
{ "to": 3.0 },
{ "from": 19.0, "to": 20.0 },
{ "from": 20.0 }
]
},
"aggs": {
"average_in_range": { "avg": { "field": "json.mixed_type" } }
}
}
} "#;
        let _agg_req: Aggregations = serde_json::from_str(agg_req_json).unwrap();
    }

    /// Each metric key maps to the matching variant and keeps its field name.
    #[test]
    fn test_metric_aggregations_deser() {
        let agg_req_json = r#"{
"price_avg": { "avg": { "field": "price" } },
"price_count": { "value_count": { "field": "price" } },
"price_max": { "max": { "field": "price" } },
"price_min": { "min": { "field": "price" } },
"price_stats": { "stats": { "field": "price" } },
"price_sum": { "sum": { "field": "price" } }
}"#;
        let agg_req: Aggregations = serde_json::from_str(agg_req_json).unwrap();

        assert!(
            matches!(&agg_req.get("price_avg").unwrap().agg, AggregationVariants::Average(avg) if avg.field == "price")
        );
        assert!(
            matches!(&agg_req.get("price_count").unwrap().agg, AggregationVariants::Count(count) if count.field == "price")
        );
        assert!(
            matches!(&agg_req.get("price_max").unwrap().agg, AggregationVariants::Max(max) if max.field == "price")
        );
        assert!(
            matches!(&agg_req.get("price_min").unwrap().agg, AggregationVariants::Min(min) if min.field == "price")
        );
        assert!(
            matches!(&agg_req.get("price_stats").unwrap().agg, AggregationVariants::Stats(stats) if stats.field == "price")
        );
        assert!(
            matches!(&agg_req.get("price_sum").unwrap().agg, AggregationVariants::Sum(sum) if sum.field == "price")
        );
    }

    /// Deserializing and pretty-printing a request round-trips to the same text.
    #[test]
    fn serialize_to_json_test() {
        // The literal must match `serde_json::to_string_pretty` byte for byte,
        // which indents nested levels by two spaces; the whitespace inside this
        // raw string is therefore significant.
        let elasticsearch_compatible_json_req = r#"{
  "range": {
    "range": {
      "field": "score",
      "ranges": [
        {
          "to": 3.0
        },
        {
          "from": 3.0,
          "to": 7.0
        },
        {
          "from": 7.0,
          "to": 20.0
        },
        {
          "from": 20.0
        }
      ],
      "keyed": true
    }
  }
}"#;
        let agg_req1: Aggregations =
            serde_json::from_str(elasticsearch_compatible_json_req).unwrap();
        let agg_req2: String = serde_json::to_string_pretty(&agg_req1).unwrap();
        assert_eq!(agg_req2, elasticsearch_compatible_json_req);
    }

    /// Field names are gathered from top-level aggregations and sub-aggregations.
    #[test]
    fn test_get_fast_field_names() {
        let range_agg: Aggregation = serde_json::from_value(json!({
            "range": {
                "field": "score",
                "ranges": [
                    { "to": 3.0 },
                    { "from": 3.0, "to": 7.0 },
                    { "from": 7.0, "to": 20.0 },
                    { "from": 20.0 }
                ],
            }
        }))
        .unwrap();

        let agg_req1: Aggregations = serde_json::from_value(json!({
            "range1": range_agg,
            "range2":{
                "range": {
                    "field": "score2",
                    "ranges": [
                        { "to": 3.0 },
                        { "from": 3.0, "to": 7.0 },
                        { "from": 7.0, "to": 20.0 },
                        { "from": 20.0 }
                    ],
                },
                "aggs": {
                    "metric": {
                        "avg": {
                            "field": "field123"
                        }
                    }
                }
            }
        }))
        .unwrap();

        let expected: HashSet<String> = ["score", "score2", "field123"]
            .iter()
            .map(|name| name.to_string())
            .collect();
        assert_eq!(get_fast_field_names(&agg_req1), expected);
    }
}