switch to Aggregation without serde_untagged (#2003)

* refactor result handling

* remove Internal stuff

* merge different accessors

* switch to Aggregation without serde_untagged

* fix doctests
PSeitz
2023-04-25 14:54:51 +08:00
committed by GitHub
parent 7b31100208
commit 2e369db936
18 changed files with 839 additions and 1190 deletions
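Below is a minimal sketch of the request style this commit switches to: aggregation requests are now deserialized straight from JSON via serde, instead of being assembled from the removed Aggregation::Metric/Aggregation::Bucket enum constructors. The JSON shape is taken from the tests and benches in the diff; the collector constructor and its `Default::default()` limits argument are assumptions based on this era of the API (the diff's tests hide it behind a `get_collector` helper whose body is elided here).

use serde_json::json;
use tantivy::aggregation::agg_req::Aggregations;
use tantivy::aggregation::agg_result::AggregationResults;
use tantivy::aggregation::AggregationCollector;
use tantivy::query::AllQuery;
use tantivy::{Index, Result};

fn run_aggs(index: &Index) -> Result<AggregationResults> {
    // Metric and bucket aggregations share one JSON shape; sub-aggregations
    // nest under "aggs", mirroring Elasticsearch's request format.
    let agg_req: Aggregations = serde_json::from_value(json!({
        "average": { "avg": { "field": "score" } },
        "score_ranges": {
            "range": {
                "field": "score",
                "ranges": [ { "from": 3.0, "to": 7.0 }, { "from": 7.0, "to": 20.0 } ]
            },
            "aggs": { "average_in_range": { "avg": { "field": "score" } } }
        }
    }))
    .expect("valid aggregation request");
    // Assumed constructor; `Default::default()` stands in for the
    // AggregationLimits argument used by the tests in this commit.
    let collector = AggregationCollector::from_aggs(agg_req, Default::default());
    let searcher = index.reader()?.searcher();
    searcher.search(&AllQuery, &collector)
}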

View File

@@ -7,12 +7,8 @@
// ---
use serde_json::{Deserializer, Value};
use tantivy::aggregation::agg_req::{
Aggregation, Aggregations, BucketAggregation, BucketAggregationType, MetricAggregation,
};
use tantivy::aggregation::agg_req::Aggregations;
use tantivy::aggregation::agg_result::AggregationResults;
use tantivy::aggregation::bucket::{RangeAggregation, RangeAggregationRange};
use tantivy::aggregation::metric::AverageAggregation;
use tantivy::aggregation::AggregationCollector;
use tantivy::query::AllQuery;
use tantivy::schema::{self, IndexRecordOption, Schema, TextFieldIndexing, FAST};

View File

@@ -5,16 +5,10 @@ mod bench {
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use rand_distr::Distribution;
use serde_json::json;
use test::{self, Bencher};
use crate::aggregation::agg_req::{
Aggregation, Aggregations, BucketAggregation, BucketAggregationType, MetricAggregation,
};
use crate::aggregation::bucket::{
CustomOrder, HistogramAggregation, HistogramBounds, Order, OrderTarget, RangeAggregation,
TermsAggregation,
};
use crate::aggregation::metric::{AverageAggregation, StatsAggregation};
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::AggregationCollector;
use crate::query::{AllQuery, TermQuery};
use crate::schema::{IndexRecordOption, Schema, TextFieldIndexing, FAST, STRING};
@@ -153,14 +147,10 @@ mod bench {
IndexRecordOption::Basic,
);
let agg_req_1: Aggregations = vec![(
"average".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name("score".to_string()),
)),
)]
.into_iter()
.collect();
let agg_req_1: Aggregations = serde_json::from_value(json!({
"average": { "avg": { "field": "score", } }
}))
.unwrap();
let collector = get_collector(agg_req_1);
@@ -182,14 +172,10 @@ mod bench {
IndexRecordOption::Basic,
);
let agg_req_1: Aggregations = vec![(
"average_f64".to_string(),
Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name(
"score_f64".to_string(),
))),
)]
.into_iter()
.collect();
let agg_req_1: Aggregations = serde_json::from_value(json!({
"average_f64": { "stats": { "field": "score_f64", } }
}))
.unwrap();
let collector = get_collector(agg_req_1);
@@ -211,14 +197,10 @@ mod bench {
IndexRecordOption::Basic,
);
let agg_req_1: Aggregations = vec![(
"average_f64".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name("score_f64".to_string()),
)),
)]
.into_iter()
.collect();
let agg_req_1: Aggregations = serde_json::from_value(json!({
"average_f64": { "avg": { "field": "score_f64", } }
}))
.unwrap();
let collector = get_collector(agg_req_1);
@@ -265,22 +247,11 @@ mod bench {
IndexRecordOption::Basic,
);
let agg_req_1: Aggregations = vec![
(
"average_f64".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name("score_f64".to_string()),
)),
),
(
"average".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name("score".to_string()),
)),
),
]
.into_iter()
.collect();
let agg_req_1: Aggregations = serde_json::from_value(json!({
"average_f64": { "avg": { "field": "score_f64" } },
"average": { "avg": { "field": "score" } },
}))
.unwrap();
let collector = get_collector(agg_req_1);
@@ -296,21 +267,10 @@ mod bench {
let reader = index.reader().unwrap();
b.iter(|| {
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "text_few_terms".to_string(),
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": { "terms": { "field": "text_few_terms" } },
}))
.unwrap();
let collector = get_collector(agg_req);
@@ -326,30 +286,15 @@ mod bench {
let reader = index.reader().unwrap();
b.iter(|| {
let sub_agg_req: Aggregations = vec![(
"average_f64".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name("score_f64".to_string()),
)),
)]
.into_iter()
.collect();
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "text_many_terms".to_string(),
..Default::default()
}),
sub_aggregation: sub_agg_req,
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": { "field": "text_many_terms" },
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
.into(),
),
)]
.into_iter()
.collect();
},
}))
.unwrap();
let collector = get_collector(agg_req);
@@ -365,21 +310,10 @@ mod bench {
let reader = index.reader().unwrap();
b.iter(|| {
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "text_many_terms".to_string(),
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": { "terms": { "field": "text_many_terms" } },
}))
.unwrap();
let collector = get_collector(agg_req);
@@ -395,25 +329,10 @@ mod bench {
let reader = index.reader().unwrap();
b.iter(|| {
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "text_many_terms".to_string(),
order: Some(CustomOrder {
order: Order::Desc,
target: OrderTarget::Key,
}),
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": { "terms": { "field": "text_many_terms", "order": { "_key": "desc" } } },
}))
.unwrap();
let collector = get_collector(agg_req);
@@ -429,29 +348,17 @@ mod bench {
let reader = index.reader().unwrap();
b.iter(|| {
let agg_req_1: Aggregations = vec![(
"rangef64".to_string(),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score_f64".to_string(),
ranges: vec![
(3f64..7000f64).into(),
(7000f64..20000f64).into(),
(20000f64..30000f64).into(),
(30000f64..40000f64).into(),
(40000f64..50000f64).into(),
(50000f64..60000f64).into(),
],
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
let agg_req_1: Aggregations = serde_json::from_value(json!({
"range_f64": { "range": { "field": "score_f64", "ranges": [
{ "from": 3, "to": 7000 },
{ "from": 7000, "to": 20000 },
{ "from": 20000, "to": 30000 },
{ "from": 30000, "to": 40000 },
{ "from": 40000, "to": 50000 },
{ "from": 50000, "to": 60000 }
] } },
}))
.unwrap();
let collector = get_collector(agg_req_1);
@@ -467,38 +374,25 @@ mod bench {
let reader = index.reader().unwrap();
b.iter(|| {
let sub_agg_req: Aggregations = vec![(
"average_f64".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name("score_f64".to_string()),
)),
)]
.into_iter()
.collect();
let agg_req_1: Aggregations = vec![(
"rangef64".to_string(),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score_f64".to_string(),
ranges: vec![
(3f64..7000f64).into(),
(7000f64..20000f64).into(),
(20000f64..30000f64).into(),
(30000f64..40000f64).into(),
(40000f64..50000f64).into(),
(50000f64..60000f64).into(),
],
..Default::default()
}),
sub_aggregation: sub_agg_req,
let agg_req_1: Aggregations = serde_json::from_value(json!({
"rangef64": {
"range": {
"field": "score_f64",
"ranges": [
{ "from": 3, "to": 7000 },
{ "from": 7000, "to": 20000 },
{ "from": 20000, "to": 30000 },
{ "from": 30000, "to": 40000 },
{ "from": 40000, "to": 50000 },
{ "from": 50000, "to": 60000 }
]
},
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
.into(),
),
)]
.into_iter()
.collect();
},
}))
.unwrap();
let collector = get_collector(agg_req_1);
@@ -519,26 +413,10 @@ mod bench {
let reader = index.reader().unwrap();
b.iter(|| {
let agg_req_1: Aggregations = vec![(
"rangef64".to_string(),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 100f64,
hard_bounds: Some(HistogramBounds {
min: 1000.0,
max: 300_000.0,
}),
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
let agg_req_1: Aggregations = serde_json::from_value(json!({
"rangef64": { "histogram": { "field": "score_f64", "interval": 100, "hard_bounds": { "min": 1000, "max": 300000 } } },
}))
.unwrap();
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
@@ -553,31 +431,15 @@ mod bench {
let reader = index.reader().unwrap();
b.iter(|| {
let sub_agg_req: Aggregations = vec![(
"average_f64".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name("score_f64".to_string()),
)),
)]
.into_iter()
.collect();
let agg_req_1: Aggregations = vec![(
"rangef64".to_string(),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 100f64, // 1000 buckets
..Default::default()
}),
sub_aggregation: sub_agg_req,
let agg_req_1: Aggregations = serde_json::from_value(json!({
"rangef64": {
"histogram": { "field": "score_f64", "interval": 100 },
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
.into(),
),
)]
.into_iter()
.collect();
}
}))
.unwrap();
let collector = get_collector(agg_req_1);
@@ -593,22 +455,15 @@ mod bench {
let reader = index.reader().unwrap();
b.iter(|| {
let agg_req_1: Aggregations = vec![(
"rangef64".to_string(),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 100f64, // 1000 buckets
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
let agg_req_1: Aggregations = serde_json::from_value(json!({
"rangef64": {
"histogram": {
"field": "score_f64",
"interval": 100 // 1000 buckets
},
}
}))
.unwrap();
let collector = get_collector(agg_req_1);
@@ -630,43 +485,23 @@ mod bench {
IndexRecordOption::Basic,
);
let sub_agg_req_1: Aggregations = vec![(
"average_in_range".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name("score".to_string()),
)),
)]
.into_iter()
.collect();
let agg_req_1: Aggregations = vec![
(
"average".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name("score".to_string()),
)),
),
(
"rangef64".to_string(),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score_f64".to_string(),
ranges: vec![
(3f64..7000f64).into(),
(7000f64..20000f64).into(),
(20000f64..60000f64).into(),
],
..Default::default()
}),
sub_aggregation: sub_agg_req_1,
}
.into(),
),
),
]
.into_iter()
.collect();
let agg_req_1: Aggregations = serde_json::from_value(json!({
"rangef64": {
"range": {
"field": "score_f64",
"ranges": [
{ "from": 3, "to": 7000 },
{ "from": 7000, "to": 20000 },
{ "from": 20000, "to": 60000 }
]
},
"aggs": {
"average_in_range": { "avg": { "field": "score" } }
}
},
"average": { "avg": { "field": "score" } }
}))
.unwrap();
let collector = get_collector(agg_req_1);

View File

@@ -37,7 +37,6 @@ use super::metric::{
AverageAggregation, CountAggregation, MaxAggregation, MinAggregation,
PercentilesAggregationReq, StatsAggregation, SumAggregation,
};
use super::VecWithNames;
/// The top-level aggregation request structure, which contains [`Aggregation`]s and their
/// user-defined names. It is also used in [buckets](BucketAggregation) to define sub-aggregations.
@@ -45,73 +44,30 @@ use super::VecWithNames;
/// The key is the user-defined name of the aggregation.
pub type Aggregations = HashMap<String, Aggregation>;
/// Like Aggregations, but optimized to work with the aggregation result
#[derive(Clone, Debug)]
pub(crate) struct AggregationsInternal {
pub(crate) metrics: VecWithNames<MetricAggregation>,
pub(crate) buckets: VecWithNames<BucketAggregationInternal>,
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
/// Aggregation request.
///
/// An aggregation is either a bucket or a metric.
pub struct Aggregation {
/// The aggregation variant, which can be either a bucket or a metric.
#[serde(flatten)]
pub agg: AggregationVariants,
/// The sub_aggregations, only valid for bucket type aggregations. Each bucket will aggregate
/// on the document set in the bucket.
#[serde(rename = "aggs")]
#[serde(default)]
#[serde(skip_serializing_if = "Aggregations::is_empty")]
pub sub_aggregation: Aggregations,
}
impl From<Aggregations> for AggregationsInternal {
fn from(aggs: Aggregations) -> Self {
let mut metrics = vec![];
let mut buckets = vec![];
for (key, agg) in aggs {
match agg {
Aggregation::Bucket(bucket) => {
let sub_aggregation = bucket.get_sub_aggs().clone().into();
buckets.push((
key,
BucketAggregationInternal {
bucket_agg: bucket.bucket_agg,
sub_aggregation,
},
))
}
Aggregation::Metric(metric) => metrics.push((key, metric)),
}
}
Self {
metrics: VecWithNames::from_entries(metrics),
buckets: VecWithNames::from_entries(buckets),
}
}
}
#[derive(Clone, Debug)]
// Like BucketAggregation, but optimized to work with the result
pub(crate) struct BucketAggregationInternal {
/// Bucket aggregation strategy to group documents.
pub bucket_agg: BucketAggregationType,
/// The sub_aggregations in the buckets. Each bucket will aggregate on the document set in the
/// bucket.
sub_aggregation: AggregationsInternal,
}
impl BucketAggregationInternal {
pub(crate) fn sub_aggregation(&self) -> &AggregationsInternal {
impl Aggregation {
pub(crate) fn sub_aggregation(&self) -> &Aggregations {
&self.sub_aggregation
}
pub(crate) fn as_range(&self) -> Option<&RangeAggregation> {
match &self.bucket_agg {
BucketAggregationType::Range(range) => Some(range),
_ => None,
}
}
pub(crate) fn as_histogram(&self) -> crate::Result<Option<HistogramAggregation>> {
match &self.bucket_agg {
BucketAggregationType::Histogram(histogram) => Ok(Some(histogram.clone())),
BucketAggregationType::DateHistogram(histogram) => {
Ok(Some(histogram.to_histogram_req()?))
}
_ => Ok(None),
}
}
pub(crate) fn as_term(&self) -> Option<&TermsAggregation> {
match &self.bucket_agg {
BucketAggregationType::Terms(terms) => Some(terms),
_ => None,
}
fn get_fast_field_names(&self, fast_field_names: &mut HashSet<String>) {
fast_field_names.insert(self.agg.get_fast_field_name().to_string());
fast_field_names.extend(get_fast_field_names(&self.sub_aggregation));
}
}
@@ -124,100 +80,24 @@ pub fn get_fast_field_names(aggs: &Aggregations) -> HashSet<String> {
fast_field_names
}
/// Aggregation request of [`BucketAggregation`] or [`MetricAggregation`].
///
/// An aggregation is either a bucket or a metric.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Aggregation {
/// Bucket aggregation, see [`BucketAggregation`] for details.
Bucket(Box<BucketAggregation>),
/// Metric aggregation, see [`MetricAggregation`] for details.
Metric(MetricAggregation),
}
impl Aggregation {
fn get_fast_field_names(&self, fast_field_names: &mut HashSet<String>) {
match self {
Aggregation::Bucket(bucket) => bucket.get_fast_field_names(fast_field_names),
Aggregation::Metric(metric) => {
fast_field_names.insert(metric.get_fast_field_name().to_string());
}
}
}
}
/// BucketAggregations create buckets of documents. Each bucket is associated with a rule which
/// determines whether or not a document falls into it. In other words, the buckets
/// effectively define document sets. Buckets are not necessarily disjoint, therefore a document can
/// fall into multiple buckets. In addition to the buckets themselves, the bucket aggregations also
/// compute and return the number of documents for each bucket. Bucket aggregations, as opposed to
/// metric aggregations, can hold sub-aggregations. These sub-aggregations will be aggregated for
/// the buckets created by their "parent" bucket aggregation. There are different bucket
/// aggregators, each with a different "bucketing" strategy. Some define a single bucket, some
/// define a fixed number of buckets, and others dynamically create the buckets during the
/// aggregation process.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct BucketAggregation {
/// Bucket aggregation strategy to group documents.
#[serde(flatten)]
pub bucket_agg: BucketAggregationType,
/// The sub_aggregations in the buckets. Each bucket will aggregate on the document set in the
/// bucket.
#[serde(rename = "aggs")]
#[serde(default)]
#[serde(skip_serializing_if = "Aggregations::is_empty")]
pub sub_aggregation: Aggregations,
}
impl BucketAggregation {
pub(crate) fn get_sub_aggs(&self) -> &Aggregations {
&self.sub_aggregation
}
fn get_fast_field_names(&self, fast_field_names: &mut HashSet<String>) {
let fast_field_name = self.bucket_agg.get_fast_field_name();
fast_field_names.insert(fast_field_name.to_string());
fast_field_names.extend(get_fast_field_names(&self.sub_aggregation));
}
}
/// The bucket aggregation types.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum BucketAggregationType {
/// All aggregation types.
pub enum AggregationVariants {
// Bucket aggregation types
/// Put data into buckets of user-defined ranges.
#[serde(rename = "range")]
Range(RangeAggregation),
/// Put data into buckets of user-defined ranges.
/// Put data into a histogram.
#[serde(rename = "histogram")]
Histogram(HistogramAggregation),
/// Put data into buckets of user-defined ranges.
/// Put data into a date histogram.
#[serde(rename = "date_histogram")]
DateHistogram(DateHistogramAggregationReq),
/// Put data into buckets of terms.
#[serde(rename = "terms")]
Terms(TermsAggregation),
}
impl BucketAggregationType {
fn get_fast_field_name(&self) -> &str {
match self {
BucketAggregationType::Terms(terms) => terms.field.as_str(),
BucketAggregationType::Range(range) => range.field.as_str(),
BucketAggregationType::Histogram(histogram) => histogram.field.as_str(),
BucketAggregationType::DateHistogram(histogram) => histogram.field.as_str(),
}
}
}
/// The aggregations in this family compute metrics based on values extracted
/// from the documents that are being aggregated. Values are extracted from the fast field of
/// the document.
/// Some aggregations output a single numeric metric (e.g. Average) and are called
/// single-value numeric metrics aggregation, others generate multiple metrics (e.g. Stats) and are
/// called multi-value numeric metrics aggregation.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum MetricAggregation {
// Metric aggregation types
/// Computes the average of the extracted values.
#[serde(rename = "avg")]
Average(AverageAggregation),
@@ -242,31 +122,102 @@ pub enum MetricAggregation {
Percentiles(PercentilesAggregationReq),
}
impl MetricAggregation {
pub(crate) fn as_percentile(&self) -> Option<&PercentilesAggregationReq> {
impl AggregationVariants {
fn get_fast_field_name(&self) -> &str {
match self {
AggregationVariants::Terms(terms) => terms.field.as_str(),
AggregationVariants::Range(range) => range.field.as_str(),
AggregationVariants::Histogram(histogram) => histogram.field.as_str(),
AggregationVariants::DateHistogram(histogram) => histogram.field.as_str(),
AggregationVariants::Average(avg) => avg.field_name(),
AggregationVariants::Count(count) => count.field_name(),
AggregationVariants::Max(max) => max.field_name(),
AggregationVariants::Min(min) => min.field_name(),
AggregationVariants::Stats(stats) => stats.field_name(),
AggregationVariants::Sum(sum) => sum.field_name(),
AggregationVariants::Percentiles(per) => per.field_name(),
}
}
pub(crate) fn as_range(&self) -> Option<&RangeAggregation> {
match &self {
MetricAggregation::Percentiles(percentile_req) => Some(percentile_req),
AggregationVariants::Range(range) => Some(range),
_ => None,
}
}
pub(crate) fn as_histogram(&self) -> crate::Result<Option<HistogramAggregation>> {
match &self {
AggregationVariants::Histogram(histogram) => Ok(Some(histogram.clone())),
AggregationVariants::DateHistogram(histogram) => {
Ok(Some(histogram.to_histogram_req()?))
}
_ => Ok(None),
}
}
pub(crate) fn as_term(&self) -> Option<&TermsAggregation> {
match &self {
AggregationVariants::Terms(terms) => Some(terms),
_ => None,
}
}
fn get_fast_field_name(&self) -> &str {
match self {
MetricAggregation::Average(avg) => avg.field_name(),
MetricAggregation::Count(count) => count.field_name(),
MetricAggregation::Max(max) => max.field_name(),
MetricAggregation::Min(min) => min.field_name(),
MetricAggregation::Stats(stats) => stats.field_name(),
MetricAggregation::Sum(sum) => sum.field_name(),
MetricAggregation::Percentiles(per) => per.field_name(),
pub(crate) fn as_percentile(&self) -> Option<&PercentilesAggregationReq> {
match &self {
AggregationVariants::Percentiles(percentile_req) => Some(percentile_req),
_ => None,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn deser_json_test() {
let agg_req_json = r#"{
"price_avg": { "avg": { "field": "price" } },
"price_count": { "value_count": { "field": "price" } },
"price_max": { "max": { "field": "price" } },
"price_min": { "min": { "field": "price" } },
"price_stats": { "stats": { "field": "price" } },
"price_sum": { "sum": { "field": "price" } }
}"#;
let _agg_req: Aggregations = serde_json::from_str(agg_req_json).unwrap();
}
#[test]
fn deser_json_test_bucket() {
let agg_req_json = r#"
{
"termagg": {
"terms": {
"field": "json.mixed_type",
"order": { "min_price": "desc" }
},
"aggs": {
"min_price": { "min": { "field": "json.mixed_type" } }
}
},
"rangeagg": {
"range": {
"field": "json.mixed_type",
"ranges": [
{ "to": 3.0 },
{ "from": 19.0, "to": 20.0 },
{ "from": 20.0 }
]
},
"aggs": {
"average_in_range": { "avg": { "field": "json.mixed_type" } }
}
}
} "#;
let _agg_req: Aggregations = serde_json::from_str(agg_req_json).unwrap();
}
#[test]
fn test_metric_aggregations_deser() {
let agg_req_json = r#"{
@@ -280,22 +231,22 @@ mod tests {
let agg_req: Aggregations = serde_json::from_str(agg_req_json).unwrap();
assert!(
matches!(agg_req.get("price_avg").unwrap(), Aggregation::Metric(MetricAggregation::Average(avg)) if avg.field == "price")
matches!(&agg_req.get("price_avg").unwrap().agg, AggregationVariants::Average(avg) if avg.field == "price")
);
assert!(
matches!(agg_req.get("price_count").unwrap(), Aggregation::Metric(MetricAggregation::Count(count)) if count.field == "price")
matches!(&agg_req.get("price_count").unwrap().agg, AggregationVariants::Count(count) if count.field == "price")
);
assert!(
matches!(agg_req.get("price_max").unwrap(), Aggregation::Metric(MetricAggregation::Max(max)) if max.field == "price")
matches!(&agg_req.get("price_max").unwrap().agg, AggregationVariants::Max(max) if max.field == "price")
);
assert!(
matches!(agg_req.get("price_min").unwrap(), Aggregation::Metric(MetricAggregation::Min(min)) if min.field == "price")
matches!(&agg_req.get("price_min").unwrap().agg, AggregationVariants::Min(min) if min.field == "price")
);
assert!(
matches!(agg_req.get("price_stats").unwrap(), Aggregation::Metric(MetricAggregation::Stats(stats)) if stats.field == "price")
matches!(&agg_req.get("price_stats").unwrap().agg, AggregationVariants::Stats(stats) if stats.field == "price")
);
assert!(
matches!(agg_req.get("price_sum").unwrap(), Aggregation::Metric(MetricAggregation::Sum(sum)) if sum.field == "price")
matches!(&agg_req.get("price_sum").unwrap().agg, AggregationVariants::Sum(sum) if sum.field == "price")
);
}
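As a quick illustration of the new model consolidated in this file: a request is a map of user-defined names to `Aggregation` values, each holding the flattened `AggregationVariants` plus optional sub-aggregations under "aggs". A sketch using the same test-style access as the tests above; the re-export path and the visibility of `field` outside the crate are assumptions.

use tantivy::aggregation::agg_req::{AggregationVariants, Aggregations};

fn inspect_parsed_request() {
    let req: Aggregations = serde_json::from_str(
        r#"{ "price_avg": { "avg": { "field": "price" } } }"#,
    )
    .unwrap();
    let agg = req.get("price_avg").unwrap();
    // One struct now covers both metrics and buckets: the variant lives in
    // `agg`, sub-requests in `sub_aggregation` (serialized as "aggs").
    assert!(matches!(&agg.agg, AggregationVariants::Average(avg) if avg.field == "price"));
    assert!(agg.sub_aggregation.is_empty());
}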

View File

@@ -2,7 +2,7 @@
use columnar::{Column, ColumnBlockAccessor, ColumnType, StrColumn};
use super::agg_req::{Aggregation, Aggregations, BucketAggregationType, MetricAggregation};
use super::agg_req::{Aggregation, AggregationVariants, Aggregations};
use super::bucket::{
DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation,
};
@@ -16,34 +16,107 @@ use crate::SegmentReader;
#[derive(Clone, Default)]
pub(crate) struct AggregationsWithAccessor {
pub metrics: VecWithNames<MetricAggregationWithAccessor>,
pub buckets: VecWithNames<BucketAggregationWithAccessor>,
pub aggs: VecWithNames<AggregationWithAccessor>,
}
impl AggregationsWithAccessor {
fn from_data(
metrics: VecWithNames<MetricAggregationWithAccessor>,
buckets: VecWithNames<BucketAggregationWithAccessor>,
) -> Self {
Self { metrics, buckets }
fn from_data(aggs: VecWithNames<AggregationWithAccessor>) -> Self {
Self { aggs }
}
pub fn is_empty(&self) -> bool {
self.metrics.is_empty() && self.buckets.is_empty()
self.aggs.is_empty()
}
}
#[derive(Clone)]
pub struct BucketAggregationWithAccessor {
pub struct AggregationWithAccessor {
/// In general there can be buckets without fast field access, e.g. buckets that are created
/// based on search terms. So eventually this needs to be Option or moved.
pub(crate) accessor: Column<u64>,
pub(crate) str_dict_column: Option<StrColumn>,
pub(crate) field_type: ColumnType,
pub(crate) bucket_agg: BucketAggregationType,
pub(crate) sub_aggregation: AggregationsWithAccessor,
pub(crate) limits: AggregationLimits,
pub(crate) column_block_accessor: ColumnBlockAccessor<u64>,
pub(crate) agg: Aggregation,
}
impl AggregationWithAccessor {
fn try_from_agg(
agg: &Aggregation,
sub_aggregation: &Aggregations,
reader: &SegmentReader,
limits: AggregationLimits,
) -> crate::Result<AggregationWithAccessor> {
let mut str_dict_column = None;
use AggregationVariants::*;
let (accessor, field_type) = match &agg.agg {
Range(RangeAggregation {
field: field_name, ..
}) => get_ff_reader_and_validate(
reader,
field_name,
Some(get_numeric_or_date_column_types()),
)?,
Histogram(HistogramAggregation {
field: field_name, ..
}) => get_ff_reader_and_validate(
reader,
field_name,
Some(get_numeric_or_date_column_types()),
)?,
DateHistogram(DateHistogramAggregationReq {
field: field_name, ..
}) => get_ff_reader_and_validate(
reader,
field_name,
Some(get_numeric_or_date_column_types()),
)?,
Terms(TermsAggregation {
field: field_name, ..
}) => {
str_dict_column = reader.fast_fields().str(field_name)?;
get_ff_reader_and_validate(reader, field_name, None)?
}
Average(AverageAggregation { field: field_name })
| Count(CountAggregation { field: field_name })
| Max(MaxAggregation { field: field_name })
| Min(MinAggregation { field: field_name })
| Stats(StatsAggregation { field: field_name })
| Sum(SumAggregation { field: field_name }) => {
let (accessor, field_type) = get_ff_reader_and_validate(
reader,
field_name,
Some(get_numeric_or_date_column_types()),
)?;
(accessor, field_type)
}
Percentiles(percentiles) => {
let (accessor, field_type) = get_ff_reader_and_validate(
reader,
percentiles.field_name(),
Some(get_numeric_or_date_column_types()),
)?;
(accessor, field_type)
}
};
let sub_aggregation = sub_aggregation.clone();
Ok(AggregationWithAccessor {
accessor,
field_type,
sub_aggregation: get_aggs_with_accessor_and_validate(
&sub_aggregation,
reader,
&limits.clone(),
)?,
agg: agg.clone(),
str_dict_column,
limits,
column_block_accessor: Default::default(),
})
}
}
fn get_numeric_or_date_column_types() -> &'static [ColumnType] {
@@ -55,140 +128,25 @@ fn get_numeric_or_date_column_types() -> &'static [ColumnType] {
]
}
impl BucketAggregationWithAccessor {
fn try_from_bucket(
bucket: &BucketAggregationType,
sub_aggregation: &Aggregations,
reader: &SegmentReader,
limits: AggregationLimits,
) -> crate::Result<BucketAggregationWithAccessor> {
let mut str_dict_column = None;
let (accessor, field_type) = match &bucket {
BucketAggregationType::Range(RangeAggregation {
field: field_name, ..
}) => get_ff_reader_and_validate(
reader,
field_name,
Some(get_numeric_or_date_column_types()),
)?,
BucketAggregationType::Histogram(HistogramAggregation {
field: field_name, ..
}) => get_ff_reader_and_validate(
reader,
field_name,
Some(get_numeric_or_date_column_types()),
)?,
BucketAggregationType::DateHistogram(DateHistogramAggregationReq {
field: field_name,
..
}) => get_ff_reader_and_validate(
reader,
field_name,
Some(get_numeric_or_date_column_types()),
)?,
BucketAggregationType::Terms(TermsAggregation {
field: field_name, ..
}) => {
str_dict_column = reader.fast_fields().str(field_name)?;
get_ff_reader_and_validate(reader, field_name, None)?
}
};
let sub_aggregation = sub_aggregation.clone();
Ok(BucketAggregationWithAccessor {
accessor,
field_type,
sub_aggregation: get_aggs_with_accessor_and_validate(
&sub_aggregation,
reader,
&limits.clone(),
)?,
bucket_agg: bucket.clone(),
str_dict_column,
limits,
column_block_accessor: Default::default(),
})
}
}
/// Contains the metric request and the fast field accessor.
#[derive(Clone)]
pub struct MetricAggregationWithAccessor {
pub metric: MetricAggregation,
pub field_type: ColumnType,
pub accessor: Column<u64>,
pub column_block_accessor: ColumnBlockAccessor<u64>,
}
impl MetricAggregationWithAccessor {
fn try_from_metric(
metric: &MetricAggregation,
reader: &SegmentReader,
) -> crate::Result<MetricAggregationWithAccessor> {
match &metric {
MetricAggregation::Average(AverageAggregation { field: field_name })
| MetricAggregation::Count(CountAggregation { field: field_name })
| MetricAggregation::Max(MaxAggregation { field: field_name })
| MetricAggregation::Min(MinAggregation { field: field_name })
| MetricAggregation::Stats(StatsAggregation { field: field_name })
| MetricAggregation::Sum(SumAggregation { field: field_name }) => {
let (accessor, field_type) = get_ff_reader_and_validate(
reader,
field_name,
Some(get_numeric_or_date_column_types()),
)?;
Ok(MetricAggregationWithAccessor {
accessor,
field_type,
metric: metric.clone(),
column_block_accessor: Default::default(),
})
}
MetricAggregation::Percentiles(percentiles) => {
let (accessor, field_type) = get_ff_reader_and_validate(
reader,
percentiles.field_name(),
Some(get_numeric_or_date_column_types()),
)?;
Ok(MetricAggregationWithAccessor {
accessor,
field_type,
metric: metric.clone(),
column_block_accessor: Default::default(),
})
}
}
}
}
pub(crate) fn get_aggs_with_accessor_and_validate(
aggs: &Aggregations,
reader: &SegmentReader,
limits: &AggregationLimits,
) -> crate::Result<AggregationsWithAccessor> {
let mut metrics = vec![];
let mut buckets = vec![];
let mut aggss = Vec::new();
for (key, agg) in aggs.iter() {
match agg {
Aggregation::Bucket(bucket) => buckets.push((
key.to_string(),
BucketAggregationWithAccessor::try_from_bucket(
&bucket.bucket_agg,
bucket.get_sub_aggs(),
reader,
limits.clone(),
)?,
)),
Aggregation::Metric(metric) => metrics.push((
key.to_string(),
MetricAggregationWithAccessor::try_from_metric(metric, reader)?,
)),
}
aggss.push((
key.to_string(),
AggregationWithAccessor::try_from_agg(
agg,
agg.sub_aggregation(),
reader,
limits.clone(),
)?,
));
}
Ok(AggregationsWithAccessor::from_data(
VecWithNames::from_entries(metrics),
VecWithNames::from_entries(buckets),
VecWithNames::from_entries(aggss),
))
}

View File

@@ -7,11 +7,8 @@
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use super::agg_req::BucketAggregationInternal;
use super::bucket::GetDocCount;
use super::intermediate_agg_result::IntermediateBucketResult;
use super::metric::{PercentilesMetricResult, SingleMetricResult, Stats};
use super::segment_agg_result::AggregationLimits;
use super::{AggregationError, Key};
use crate::TantivyError;
@@ -164,14 +161,6 @@ impl BucketResult {
} => buckets.iter().map(|bucket| bucket.get_bucket_count()).sum(),
}
}
pub(crate) fn empty_from_req(
req: &BucketAggregationInternal,
limits: &AggregationLimits,
) -> crate::Result<Self> {
let empty_bucket = IntermediateBucketResult::empty_from_req(&req.bucket_agg);
empty_bucket.into_final_bucket_result(req, limits)
}
}
/// This is the wrapper for bucket entries, which can be a vector or a hashmap.

View File

@@ -1,11 +1,10 @@
use serde_json::Value;
use crate::aggregation::agg_req::{Aggregation, Aggregations, MetricAggregation};
use crate::aggregation::agg_req::{Aggregation, Aggregations};
use crate::aggregation::agg_result::AggregationResults;
use crate::aggregation::buf_collector::DOC_BLOCK_SIZE;
use crate::aggregation::collector::AggregationCollector;
use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults;
use crate::aggregation::metric::AverageAggregation;
use crate::aggregation::segment_agg_result::AggregationLimits;
use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms};
use crate::aggregation::DistributedAggregationCollector;
@@ -14,9 +13,12 @@ use crate::schema::{IndexRecordOption, Schema, FAST};
use crate::{Index, Term};
fn get_avg_req(field_name: &str) -> Aggregation {
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name(field_name.to_string()),
))
serde_json::from_value(json!({
"avg": {
"field": field_name,
}
}))
.unwrap()
}
fn get_collector(agg_req: Aggregations) -> AggregationCollector {
@@ -195,6 +197,74 @@ fn test_aggregation_flushing_variants() {
test_aggregation_flushing(true, true).unwrap();
}
#[test]
fn test_aggregation_level1_simple() -> crate::Result<()> {
let index = get_test_index_2_segments(true)?;
let reader = index.reader()?;
let text_field = reader.searcher().schema().get_field("text").unwrap();
let term_query = TermQuery::new(
Term::from_field_text(text_field, "cool"),
IndexRecordOption::Basic,
);
let range_agg = |field_name: &str| -> Aggregation {
serde_json::from_value(json!({
"range": {
"field": field_name,
"ranges": [ { "from": 3.0f64, "to": 7.0f64 }, { "from": 7.0f64, "to": 20.0f64 } ]
}
}))
.unwrap()
};
let agg_req_1: Aggregations = vec![
("average".to_string(), get_avg_req("score")),
("range".to_string(), range_agg("score")),
]
.into_iter()
.collect();
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap();
let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
assert_eq!(res["average"]["value"], 12.142857142857142);
assert_eq!(
res["range"]["buckets"],
json!(
[
{
"key": "*-3",
"doc_count": 1,
"to": 3.0
},
{
"key": "3-7",
"doc_count": 2,
"from": 3.0,
"to": 7.0
},
{
"key": "7-20",
"doc_count": 3,
"from": 7.0,
"to": 20.0
},
{
"key": "20-*",
"doc_count": 1,
"from": 20.0
}
])
);
Ok(())
}
#[test]
fn test_aggregation_level1() -> crate::Result<()> {
let index = get_test_index_2_segments(true)?;
@@ -449,14 +519,14 @@ fn test_aggregation_invalid_requests() -> crate::Result<()> {
let reader = index.reader()?;
let avg_on_field = |field_name: &str| {
let agg_req_1: Aggregations = vec![(
"average".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name(field_name.to_string()),
)),
)]
.into_iter()
.collect();
let agg_req_1: Aggregations = serde_json::from_value(json!({
"average": {
"avg": {
"field": field_name,
},
}
}))
.unwrap();
let collector = get_collector(agg_req_1);
@@ -471,6 +541,32 @@ fn test_aggregation_invalid_requests() -> crate::Result<()> {
r#"InvalidArgument("Field \"dummy_text\" is not configured as fast field")"#
);
let agg_req_1: Result<Aggregations, serde_json::Error> = serde_json::from_value(json!({
"average": {
"avg": {
"fieldd": "a",
},
}
}));
assert_eq!(agg_req_1.is_err(), true);
assert_eq!(agg_req_1.unwrap_err().to_string(), "missing field `field`");
let agg_req_1: Result<Aggregations, serde_json::Error> = serde_json::from_value(json!({
"average": {
"doesnotmatchanyagg": {
"field": "a",
},
}
}));
assert_eq!(agg_req_1.is_err(), true);
// TODO: This should list valid values
assert_eq!(
agg_req_1.unwrap_err().to_string(),
"no variant of enum AggregationVariants found in flattened data"
);
// TODO: This should return an error
// let agg_res = avg_on_field("not_exist_field").unwrap_err();
// assert_eq!(

View File

@@ -8,18 +8,19 @@ use serde::{Deserialize, Serialize};
use tantivy_bitpacker::minmax;
use crate::aggregation::agg_limits::MemoryConsumption;
use crate::aggregation::agg_req::AggregationsInternal;
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::agg_req_with_accessor::{
AggregationsWithAccessor, BucketAggregationWithAccessor,
AggregationWithAccessor, AggregationsWithAccessor,
};
use crate::aggregation::agg_result::BucketEntry;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry,
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateHistogramBucketEntry,
};
use crate::aggregation::segment_agg_result::{
build_segment_agg_collector, AggregationLimits, SegmentAggregationCollector,
};
use crate::aggregation::{f64_from_fastfield_u64, format_date, VecWithNames};
use crate::aggregation::{f64_from_fastfield_u64, format_date};
use crate::TantivyError;
/// Histogram is a bucket aggregation, where buckets are created dynamically for given `interval`.
@@ -190,11 +191,13 @@ impl SegmentHistogramBucketEntry {
sub_aggregation: Box<dyn SegmentAggregationCollector>,
agg_with_accessor: &AggregationsWithAccessor,
) -> crate::Result<IntermediateHistogramBucketEntry> {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
sub_aggregation
.add_intermediate_aggregation_result(agg_with_accessor, &mut sub_aggregation_res)?;
Ok(IntermediateHistogramBucketEntry {
key: self.key,
doc_count: self.doc_count,
sub_aggregation: sub_aggregation
.into_intermediate_aggregations_result(agg_with_accessor)?,
sub_aggregation: sub_aggregation_res,
})
}
}
@@ -215,20 +218,18 @@ pub struct SegmentHistogramCollector {
}
impl SegmentAggregationCollector for SegmentHistogramCollector {
fn into_intermediate_aggregations_result(
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
) -> crate::Result<IntermediateAggregationResults> {
let name = agg_with_accessor.buckets.keys[self.accessor_idx].to_string();
let agg_with_accessor = &agg_with_accessor.buckets.values[self.accessor_idx];
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let agg_with_accessor = &agg_with_accessor.aggs.values[self.accessor_idx];
let bucket = self.into_intermediate_bucket_result(agg_with_accessor)?;
let buckets = Some(VecWithNames::from_entries(vec![(name, bucket)]));
results.push(name, IntermediateAggregationResult::Bucket(bucket));
Ok(IntermediateAggregationResults {
metrics: None,
buckets,
})
Ok(())
}
#[inline]
@@ -246,7 +247,7 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
let bucket_agg_accessor = &mut agg_with_accessor.buckets.values[self.accessor_idx];
let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx];
let mem_pre = self.get_memory_consumption();
@@ -280,7 +281,7 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
}
let mem_delta = self.get_memory_consumption() - mem_pre;
let limits = &agg_with_accessor.buckets.values[self.accessor_idx].limits;
let limits = &agg_with_accessor.aggs.values[self.accessor_idx].limits;
limits.add_memory_consumed(mem_delta as u64);
limits.validate_memory_consumption()?;
@@ -289,7 +290,7 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
let sub_aggregation_accessor =
&mut agg_with_accessor.buckets.values[self.accessor_idx].sub_aggregation;
&mut agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation;
for sub_aggregation in self.sub_aggregations.values_mut() {
sub_aggregation.flush(sub_aggregation_accessor)?;
@@ -308,7 +309,7 @@ impl SegmentHistogramCollector {
}
pub fn into_intermediate_bucket_result(
self,
agg_with_accessor: &BucketAggregationWithAccessor,
agg_with_accessor: &AggregationWithAccessor,
) -> crate::Result<IntermediateBucketResult> {
let mut buckets = Vec::with_capacity(self.buckets.len());
@@ -384,7 +385,7 @@ fn get_bucket_key_from_pos(bucket_pos: f64, interval: f64, offset: f64) -> f64 {
fn intermediate_buckets_to_final_buckets_fill_gaps(
buckets: Vec<IntermediateHistogramBucketEntry>,
histogram_req: &HistogramAggregation,
sub_aggregation: &AggregationsInternal,
sub_aggregation: &Aggregations,
limits: &AggregationLimits,
) -> crate::Result<Vec<BucketEntry>> {
// Generate the full list of buckets without gaps.
@@ -443,7 +444,7 @@ pub(crate) fn intermediate_histogram_buckets_to_final_buckets(
buckets: Vec<IntermediateHistogramBucketEntry>,
column_type: Option<ColumnType>,
histogram_req: &HistogramAggregation,
sub_aggregation: &AggregationsInternal,
sub_aggregation: &Aggregations,
limits: &AggregationLimits,
) -> crate::Result<Vec<BucketEntry>> {
let mut buckets = if histogram_req.min_doc_count() == 0 {
@@ -695,7 +696,7 @@ mod tests {
assert_eq!(
res.to_string(),
"Aborting aggregation because memory limit was exceeded. Limit: 5.00 KB, Current: \
102.48 KB"
59.71 KB"
);
Ok(())

View File

@@ -2,6 +2,16 @@
//!
//! BucketAggregations create buckets of documents
//! [`BucketAggregation`](super::agg_req::BucketAggregation).
//! Each bucket is associated with a rule which
//! determines whether or not a document falls into it. In other words, the buckets
//! effectively define document sets. Buckets are not necessarily disjoint, therefore a document can
//! fall into multiple buckets. In addition to the buckets themselves, the bucket aggregations also
//! compute and return the number of documents for each bucket. Bucket aggregations, as opposed to
//! metric aggregations, can hold sub-aggregations. These sub-aggregations will be aggregated for
//! the buckets created by their "parent" bucket aggregation. There are different bucket
//! aggregators, each with a different "bucketing" strategy. Some define a single bucket, some
//! define a fixed number of buckets, and others dynamically create the buckets during the
//! aggregation process.
//!
//! Results of final buckets are [`BucketResult`](super::agg_result::BucketResult).
//! Results of intermediate buckets are

View File

@@ -7,14 +7,14 @@ use serde::{Deserialize, Serialize};
use crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResults, IntermediateBucketResult, IntermediateRangeBucketEntry,
IntermediateRangeBucketResult,
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateRangeBucketEntry, IntermediateRangeBucketResult,
};
use crate::aggregation::segment_agg_result::{
build_segment_agg_collector, AggregationLimits, SegmentAggregationCollector,
};
use crate::aggregation::{
f64_from_fastfield_u64, f64_to_fastfield_u64, format_date, Key, SerializedKey, VecWithNames,
f64_from_fastfield_u64, f64_to_fastfield_u64, format_date, Key, SerializedKey,
};
use crate::TantivyError;
@@ -157,8 +157,10 @@ impl SegmentRangeBucketEntry {
self,
agg_with_accessor: &AggregationsWithAccessor,
) -> crate::Result<IntermediateRangeBucketEntry> {
let sub_aggregation = if let Some(sub_aggregation) = self.sub_aggregation {
sub_aggregation.into_intermediate_aggregations_result(agg_with_accessor)?
let mut sub_aggregation_res = IntermediateAggregationResults::default();
if let Some(sub_aggregation) = self.sub_aggregation {
sub_aggregation
.add_intermediate_aggregation_result(agg_with_accessor, &mut sub_aggregation_res)?
} else {
Default::default()
};
@@ -166,7 +168,7 @@ impl SegmentRangeBucketEntry {
Ok(IntermediateRangeBucketEntry {
key: self.key,
doc_count: self.doc_count,
sub_aggregation,
sub_aggregation: sub_aggregation_res,
from: self.from,
to: self.to,
})
@@ -174,13 +176,14 @@ impl SegmentRangeBucketEntry {
}
impl SegmentAggregationCollector for SegmentRangeCollector {
fn into_intermediate_aggregations_result(
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
) -> crate::Result<IntermediateAggregationResults> {
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
let field_type = self.column_type;
let name = agg_with_accessor.buckets.keys[self.accessor_idx].to_string();
let sub_agg = &agg_with_accessor.buckets.values[self.accessor_idx].sub_aggregation;
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let sub_agg = &agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation;
let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = self
.buckets
@@ -200,12 +203,9 @@ impl SegmentAggregationCollector for SegmentRangeCollector {
column_type: Some(self.column_type),
});
let buckets = Some(VecWithNames::from_entries(vec![(name, bucket)]));
results.push(name, IntermediateAggregationResult::Bucket(bucket));
Ok(IntermediateAggregationResults {
metrics: None,
buckets,
})
Ok(())
}
#[inline]
@@ -223,7 +223,7 @@ impl SegmentAggregationCollector for SegmentRangeCollector {
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
let bucket_agg_accessor = &mut agg_with_accessor.buckets.values[self.accessor_idx];
let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx];
bucket_agg_accessor
.column_block_accessor
@@ -245,7 +245,7 @@ impl SegmentAggregationCollector for SegmentRangeCollector {
fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
let sub_aggregation_accessor =
&mut agg_with_accessor.buckets.values[self.accessor_idx].sub_aggregation;
&mut agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation;
for bucket in self.buckets.iter_mut() {
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() {

View File

@@ -7,16 +7,16 @@ use serde::{Deserialize, Serialize};
use super::{CustomOrder, Order, OrderTarget};
use crate::aggregation::agg_limits::MemoryConsumption;
use crate::aggregation::agg_req_with_accessor::{
AggregationsWithAccessor, BucketAggregationWithAccessor,
AggregationWithAccessor, AggregationsWithAccessor,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResults, IntermediateBucketResult, IntermediateTermBucketEntry,
IntermediateTermBucketResult,
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateTermBucketEntry, IntermediateTermBucketResult,
};
use crate::aggregation::segment_agg_result::{
build_segment_agg_collector, SegmentAggregationCollector,
};
use crate::aggregation::{f64_from_fastfield_u64, Key, VecWithNames};
use crate::aggregation::{f64_from_fastfield_u64, Key};
use crate::error::DataCorruption;
use crate::TantivyError;
@@ -246,20 +246,18 @@ pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) {
}
impl SegmentAggregationCollector for SegmentTermCollector {
fn into_intermediate_aggregations_result(
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
) -> crate::Result<IntermediateAggregationResults> {
let name = agg_with_accessor.buckets.keys[self.accessor_idx].to_string();
let agg_with_accessor = &agg_with_accessor.buckets.values[self.accessor_idx];
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let agg_with_accessor = &agg_with_accessor.aggs.values[self.accessor_idx];
let bucket = self.into_intermediate_bucket_result(agg_with_accessor)?;
let buckets = Some(VecWithNames::from_entries(vec![(name, bucket)]));
results.push(name, IntermediateAggregationResult::Bucket(bucket));
Ok(IntermediateAggregationResults {
metrics: None,
buckets,
})
Ok(())
}
#[inline]
@@ -277,7 +275,7 @@ impl SegmentAggregationCollector for SegmentTermCollector {
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
let bucket_agg_accessor = &mut agg_with_accessor.buckets.values[self.accessor_idx];
let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx];
let mem_pre = self.get_memory_consumption();
@@ -301,7 +299,7 @@ impl SegmentAggregationCollector for SegmentTermCollector {
}
let mem_delta = self.get_memory_consumption() - mem_pre;
let limits = &agg_with_accessor.buckets.values[self.accessor_idx].limits;
let limits = &agg_with_accessor.aggs.values[self.accessor_idx].limits;
limits.add_memory_consumed(mem_delta as u64);
limits.validate_memory_consumption()?;
@@ -310,7 +308,7 @@ impl SegmentAggregationCollector for SegmentTermCollector {
fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
let sub_aggregation_accessor =
&mut agg_with_accessor.buckets.values[self.accessor_idx].sub_aggregation;
&mut agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation;
self.term_buckets.force_flush(sub_aggregation_accessor)?;
Ok(())
@@ -337,7 +335,7 @@ impl SegmentTermCollector {
if let OrderTarget::SubAggregation(sub_agg_name) = &custom_order.target {
let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name);
sub_aggregations.metrics.get(agg_name).ok_or_else(|| {
sub_aggregations.aggs.get(agg_name).ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"could not find aggregation with name {} in metric sub_aggregations",
agg_name
@@ -366,7 +364,7 @@ impl SegmentTermCollector {
#[inline]
pub(crate) fn into_intermediate_bucket_result(
mut self,
agg_with_accessor: &BucketAggregationWithAccessor,
agg_with_accessor: &AggregationWithAccessor,
) -> crate::Result<IntermediateBucketResult> {
let mut entries: Vec<(u64, u64)> = self.term_buckets.entries.into_iter().collect();
@@ -410,21 +408,24 @@ impl SegmentTermCollector {
let mut into_intermediate_bucket_entry =
|id, doc_count| -> crate::Result<IntermediateTermBucketEntry> {
let intermediate_entry = if self.blueprint.as_ref().is_some() {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
self.term_buckets
.sub_aggs
.remove(&id)
.unwrap_or_else(|| {
panic!(
"Internal Error: could not find subaggregation for id {}",
id
)
})
.add_intermediate_aggregation_result(
&agg_with_accessor.sub_aggregation,
&mut sub_aggregation_res,
)?;
IntermediateTermBucketEntry {
doc_count,
sub_aggregation: self
.term_buckets
.sub_aggs
.remove(&id)
.unwrap_or_else(|| {
panic!(
"Internal Error: could not find subaggregation for id {}",
id
)
})
.into_intermediate_aggregations_result(
&agg_with_accessor.sub_aggregation,
)?,
sub_aggregation: sub_aggregation_res,
}
} else {
IntermediateTermBucketEntry {
@@ -522,8 +523,7 @@ pub(crate) fn cut_off_buckets<T: GetDocCount + Debug>(
#[cfg(test)]
mod tests {
use crate::aggregation::agg_req::{Aggregation, Aggregations, MetricAggregation};
use crate::aggregation::metric::{AverageAggregation, StatsAggregation};
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::tests::{
exec_request, exec_request_with_query, exec_request_with_query_and_memory_limit,
get_test_index_from_terms, get_test_index_from_values_and_terms,
@@ -638,23 +638,6 @@ mod tests {
];
let index = get_test_index_from_values_and_terms(merge_segments, &segment_and_terms)?;
let _sub_agg: Aggregations = vec![
(
"avg_score".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name("score".to_string()),
)),
),
(
"stats_score".to_string(),
Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name(
"score".to_string(),
))),
),
]
.into_iter()
.collect();
let sub_agg: Aggregations = serde_json::from_value(json!({
"avg_score": {
"avg": {

View File

@@ -35,11 +35,12 @@ impl BufAggregationCollector {
impl SegmentAggregationCollector for BufAggregationCollector {
#[inline]
fn into_intermediate_aggregations_result(
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
) -> crate::Result<IntermediateAggregationResults> {
Box::new(self.collector).into_intermediate_aggregations_result(agg_with_accessor)
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
Box::new(self.collector).add_intermediate_aggregation_result(agg_with_accessor, results)
}
#[inline]

View File

@@ -184,6 +184,13 @@ impl SegmentCollector for AggregationSegmentCollector {
return Err(err);
}
self.agg_collector.flush(&mut self.aggs_with_accessor)?;
Box::new(self.agg_collector).into_intermediate_aggregations_result(&self.aggs_with_accessor)
let mut sub_aggregation_res = IntermediateAggregationResults::default();
Box::new(self.agg_collector).add_intermediate_aggregation_result(
&self.aggs_with_accessor,
&mut sub_aggregation_res,
)?;
Ok(sub_aggregation_res)
}
}

View File

@@ -10,10 +10,7 @@ use rustc_hash::FxHashMap;
use serde::ser::SerializeSeq;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use super::agg_req::{
Aggregations, AggregationsInternal, BucketAggregationInternal, BucketAggregationType,
MetricAggregation,
};
use super::agg_req::{Aggregation, AggregationVariants, Aggregations};
use super::agg_result::{AggregationResult, BucketResult, MetricResult, RangeBucketEntry};
use super::bucket::{
cut_off_buckets, get_agg_name_and_property, intermediate_histogram_buckets_to_final_buckets,
@@ -34,20 +31,22 @@ use crate::TantivyError;
/// intermediate results.
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct IntermediateAggregationResults {
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) metrics: Option<VecWithNames<IntermediateMetricResult>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) buckets: Option<VecWithNames<IntermediateBucketResult>>,
pub(crate) aggs_res: VecWithNames<IntermediateAggregationResult>,
}
impl IntermediateAggregationResults {
/// Add a result
pub fn push(&mut self, key: String, value: IntermediateAggregationResult) {
self.aggs_res.push(key, value);
}
/// Convert intermediate result and its aggregation request to the final result.
pub fn into_final_result(
self,
req: Aggregations,
limits: &AggregationLimits,
) -> crate::Result<AggregationResults> {
let res = self.into_final_result_internal(&(req.into()), limits)?;
let res = self.into_final_result_internal(&req, limits)?;
let bucket_count = res.get_bucket_count() as u32;
if bucket_count > limits.get_bucket_limit() {
return Err(TantivyError::AggregationError(
@@ -66,67 +65,35 @@ impl IntermediateAggregationResults {
/// for internal processing, by splitting metrics and buckets into separate groups.
pub(crate) fn into_final_result_internal(
self,
req: &AggregationsInternal,
req: &Aggregations,
limits: &AggregationLimits,
) -> crate::Result<AggregationResults> {
// Important assumption:
// When the tree contains buckets/metrics, we expect it to have all buckets/metrics from the
// request
let mut results: FxHashMap<String, AggregationResult> = FxHashMap::default();
if let Some(buckets) = self.buckets {
convert_and_add_final_buckets_to_result(&mut results, buckets, &req.buckets, limits)?
} else {
// When there are no buckets, we create empty buckets, so that the serialized json
// format is constant
add_empty_final_buckets_to_result(&mut results, &req.buckets, limits)?
};
if let Some(metrics) = self.metrics {
convert_and_add_final_metrics_to_result(&mut results, metrics, &req.metrics);
} else {
// When there are no metrics, we create empty metric results, so that the serialized
// json format is constant
add_empty_final_metrics_to_result(&mut results, &req.metrics)?;
for (key, agg_res) in self.aggs_res.into_iter() {
let req = req.get(key.as_str()).unwrap();
results.insert(key, agg_res.into_final_result(req, limits)?);
}
// Handle empty results
if results.len() != req.len() {
for (key, req) in req.iter() {
if !results.contains_key(key) {
let empty_res = empty_from_req(req);
results.insert(key.to_string(), empty_res.into_final_result(req, limits)?);
}
}
}
Ok(AggregationResults(results))
}
pub(crate) fn empty_from_req(req: &AggregationsInternal) -> Self {
let metrics = if req.metrics.is_empty() {
None
} else {
let metrics = req
.metrics
.iter()
.map(|(key, req)| {
(
key.to_string(),
IntermediateMetricResult::empty_from_req(req),
)
})
.collect();
Some(VecWithNames::from_entries(metrics))
};
pub(crate) fn empty_from_req(req: &Aggregations) -> Self {
let mut aggs_res: VecWithNames<IntermediateAggregationResult> = VecWithNames::default();
for (key, req) in req.iter() {
let empty_res = empty_from_req(req);
aggs_res.push(key.to_string(), empty_res);
}
let buckets = if req.buckets.is_empty() {
None
} else {
let buckets = req
.buckets
.iter()
.map(|(key, req)| {
(
key.to_string(),
IntermediateBucketResult::empty_from_req(&req.bucket_agg),
)
})
.collect();
Some(VecWithNames::from_entries(buckets))
};
Self { metrics, buckets }
Self { aggs_res }
}
/// Merge another intermediate aggregation result into this result.
@@ -134,85 +101,50 @@ impl IntermediateAggregationResults {
/// The order of the values needs to be the same on both results. This is ensured when the
/// same keys are present in the underlying `VecWithNames` struct.
pub fn merge_fruits(&mut self, other: IntermediateAggregationResults) -> crate::Result<()> {
if let (Some(buckets_left), Some(buckets_right)) = (&mut self.buckets, other.buckets) {
for (bucket_left, bucket_right) in
buckets_left.values_mut().zip(buckets_right.into_values())
{
bucket_left.merge_fruits(bucket_right)?;
}
}
if let (Some(metrics_left), Some(metrics_right)) = (&mut self.metrics, other.metrics) {
for (metric_left, metric_right) in
metrics_left.values_mut().zip(metrics_right.into_values())
{
metric_left.merge_fruits(metric_right)?;
}
for (left, right) in self.aggs_res.values_mut().zip(other.aggs_res.into_values()) {
left.merge_fruits(right)?;
}
Ok(())
}
}
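`merge_fruits` is the hook that makes aggregations distributable: each segment or node produces an `IntermediateAggregationResults`, partial results are merged pairwise, and only the merged tree is converted into a final result. A hedged sketch of that flow (the `merge_partials` helper is illustrative; it assumes all partials were collected for the same request, so their keys line up):

```
use tantivy::aggregation::agg_req::Aggregations;
use tantivy::aggregation::agg_result::AggregationResults;
use tantivy::aggregation::intermediate_agg_result::IntermediateAggregationResults;
use tantivy::aggregation::AggregationLimits;

fn merge_partials(
    mut partials: Vec<IntermediateAggregationResults>,
    req: Aggregations,
) -> tantivy::Result<AggregationResults> {
    // Fold all partial trees into the first one.
    let mut merged = partials.pop().expect("at least one partial result");
    for partial in partials {
        merged.merge_fruits(partial)?;
    }
    // Convert once, at the very end, on the coordinating side.
    merged.into_final_result(req, &AggregationLimits::default())
}
```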
fn convert_and_add_final_metrics_to_result(
results: &mut FxHashMap<String, AggregationResult>,
metrics: VecWithNames<IntermediateMetricResult>,
metrics_req: &VecWithNames<MetricAggregation>,
) {
let metric_result_with_request = metrics.into_iter().zip(metrics_req.values());
results.extend(
metric_result_with_request
.into_iter()
.map(|((key, metric), req)| {
(
key,
AggregationResult::MetricResult(metric.into_final_metric_result(req)),
)
}),
);
}
fn add_empty_final_metrics_to_result(
results: &mut FxHashMap<String, AggregationResult>,
req_metrics: &VecWithNames<MetricAggregation>,
) -> crate::Result<()> {
results.extend(req_metrics.iter().map(|(key, req)| {
let empty_bucket = IntermediateMetricResult::empty_from_req(req);
(
key.to_string(),
AggregationResult::MetricResult(empty_bucket.into_final_metric_result(req)),
)
}));
Ok(())
}
fn add_empty_final_buckets_to_result(
results: &mut FxHashMap<String, AggregationResult>,
req_buckets: &VecWithNames<BucketAggregationInternal>,
limits: &AggregationLimits,
) -> crate::Result<()> {
let requested_buckets = req_buckets.iter();
for (key, req) in requested_buckets {
let empty_bucket =
AggregationResult::BucketResult(BucketResult::empty_from_req(req, limits)?);
results.insert(key.to_string(), empty_bucket);
pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult {
use AggregationVariants::*;
match req.agg {
Terms(_) => IntermediateAggregationResult::Bucket(IntermediateBucketResult::Terms(
Default::default(),
)),
Range(_) => IntermediateAggregationResult::Bucket(IntermediateBucketResult::Range(
Default::default(),
)),
Histogram(_) | DateHistogram(_) => {
IntermediateAggregationResult::Bucket(IntermediateBucketResult::Histogram {
buckets: Vec::new(),
column_type: None,
})
}
Average(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::Average(
IntermediateAverage::default(),
)),
Count(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::Count(
IntermediateCount::default(),
)),
Max(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::Max(
IntermediateMax::default(),
)),
Min(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::Min(
IntermediateMin::default(),
)),
Stats(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::Stats(
IntermediateStats::default(),
)),
Sum(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::Sum(
IntermediateSum::default(),
)),
Percentiles(_) => IntermediateAggregationResult::Metric(
IntermediateMetricResult::Percentiles(PercentilesCollector::default()),
),
}
Ok(())
}
fn convert_and_add_final_buckets_to_result(
results: &mut FxHashMap<String, AggregationResult>,
buckets: VecWithNames<IntermediateBucketResult>,
req_buckets: &VecWithNames<BucketAggregationInternal>,
limits: &AggregationLimits,
) -> crate::Result<()> {
assert_eq!(buckets.len(), req_buckets.len());
let buckets_with_request = buckets.into_iter().zip(req_buckets.values());
for ((key, bucket), req) in buckets_with_request {
let result = AggregationResult::BucketResult(bucket.into_final_bucket_result(req, limits)?);
results.insert(key, result);
}
Ok(())
}
/// An aggregation is either a bucket or a metric.
@@ -224,6 +156,37 @@ pub enum IntermediateAggregationResult {
Metric(IntermediateMetricResult),
}
impl IntermediateAggregationResult {
pub(crate) fn into_final_result(
self,
req: &Aggregation,
limits: &AggregationLimits,
) -> crate::Result<AggregationResult> {
let res = match self {
IntermediateAggregationResult::Bucket(bucket) => {
AggregationResult::BucketResult(bucket.into_final_bucket_result(req, limits)?)
}
IntermediateAggregationResult::Metric(metric) => {
AggregationResult::MetricResult(metric.into_final_metric_result(req))
}
};
Ok(res)
}
fn merge_fruits(&mut self, other: IntermediateAggregationResult) -> crate::Result<()> {
match (self, other) {
(
IntermediateAggregationResult::Bucket(b1),
IntermediateAggregationResult::Bucket(b2),
) => b1.merge_fruits(b2),
(
IntermediateAggregationResult::Metric(m1),
IntermediateAggregationResult::Metric(m2),
) => m1.merge_fruits(m2),
_ => panic!("aggregation result type mismatch (mixed metric and buckets)"),
}
}
}
/// Holds the intermediate data for metric results
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum IntermediateMetricResult {
@@ -244,7 +207,7 @@ pub enum IntermediateMetricResult {
}
impl IntermediateMetricResult {
fn into_final_metric_result(self, req: &MetricAggregation) -> MetricResult {
fn into_final_metric_result(self, req: &Aggregation) -> MetricResult {
match self {
IntermediateMetricResult::Average(intermediate_avg) => {
MetricResult::Average(intermediate_avg.finalize().into())
@@ -265,30 +228,12 @@ impl IntermediateMetricResult {
MetricResult::Sum(intermediate_sum.finalize().into())
}
IntermediateMetricResult::Percentiles(percentiles) => MetricResult::Percentiles(
percentiles.into_final_result(req.as_percentile().expect("unexpected metric type")),
percentiles
.into_final_result(req.agg.as_percentile().expect("unexpected metric type")),
),
}
}
pub(crate) fn empty_from_req(req: &MetricAggregation) -> Self {
match req {
MetricAggregation::Average(_) => {
IntermediateMetricResult::Average(IntermediateAverage::default())
}
MetricAggregation::Count(_) => {
IntermediateMetricResult::Count(IntermediateCount::default())
}
MetricAggregation::Max(_) => IntermediateMetricResult::Max(IntermediateMax::default()),
MetricAggregation::Min(_) => IntermediateMetricResult::Min(IntermediateMin::default()),
MetricAggregation::Stats(_) => {
IntermediateMetricResult::Stats(IntermediateStats::default())
}
MetricAggregation::Sum(_) => IntermediateMetricResult::Sum(IntermediateSum::default()),
MetricAggregation::Percentiles(_) => {
IntermediateMetricResult::Percentiles(PercentilesCollector::default())
}
}
}
fn merge_fruits(&mut self, other: IntermediateMetricResult) -> crate::Result<()> {
match (self, other) {
(
@@ -355,7 +300,7 @@ pub enum IntermediateBucketResult {
impl IntermediateBucketResult {
pub(crate) fn into_final_bucket_result(
self,
req: &BucketAggregationInternal,
req: &Aggregation,
limits: &AggregationLimits,
) -> crate::Result<BucketResult> {
match self {
@@ -366,7 +311,8 @@ impl IntermediateBucketResult {
.map(|bucket| {
bucket.into_final_bucket_entry(
req.sub_aggregation(),
req.as_range()
req.agg
.as_range()
.expect("unexpected aggregation, expected histogram aggregation"),
range_res.column_type,
limits,
@@ -381,6 +327,7 @@ impl IntermediateBucketResult {
});
let is_keyed = req
.agg
.as_range()
.expect("unexpected aggregation, expected range aggregation")
.keyed;
@@ -401,6 +348,7 @@ impl IntermediateBucketResult {
buckets,
} => {
let histogram_req = &req
.agg
.as_histogram()?
.expect("unexpected aggregation, expected histogram aggregation");
let buckets = intermediate_histogram_buckets_to_final_buckets(
@@ -424,7 +372,8 @@ impl IntermediateBucketResult {
Ok(BucketResult::Histogram { buckets })
}
IntermediateBucketResult::Terms(terms) => terms.into_final_result(
req.as_term()
req.agg
.as_term()
.expect("unexpected aggregation, expected term aggregation"),
req.sub_aggregation(),
limits,
@@ -432,18 +381,6 @@ impl IntermediateBucketResult {
}
}
pub(crate) fn empty_from_req(req: &BucketAggregationType) -> Self {
match req {
BucketAggregationType::Terms(_) => IntermediateBucketResult::Terms(Default::default()),
BucketAggregationType::Range(_) => IntermediateBucketResult::Range(Default::default()),
BucketAggregationType::Histogram(_) | BucketAggregationType::DateHistogram(_) => {
IntermediateBucketResult::Histogram {
buckets: vec![],
column_type: None,
}
}
}
}
fn merge_fruits(&mut self, other: IntermediateBucketResult) -> crate::Result<()> {
match (self, other) {
(
@@ -551,7 +488,7 @@ impl IntermediateTermBucketResult {
pub(crate) fn into_final_result(
self,
req: &TermsAggregation,
sub_aggregation_req: &AggregationsInternal,
sub_aggregation_req: &Aggregations,
limits: &AggregationLimits,
) -> crate::Result<BucketResult> {
let req = TermsAggregationInternal::from_req(req);
@@ -687,7 +624,7 @@ pub struct IntermediateHistogramBucketEntry {
impl IntermediateHistogramBucketEntry {
pub(crate) fn into_final_bucket_entry(
self,
req: &AggregationsInternal,
req: &Aggregations,
limits: &AggregationLimits,
) -> crate::Result<BucketEntry> {
Ok(BucketEntry {
@@ -732,7 +669,7 @@ pub struct IntermediateRangeBucketEntry {
impl IntermediateRangeBucketEntry {
pub(crate) fn into_final_bucket_entry(
self,
req: &AggregationsInternal,
req: &Aggregations,
_range_req: &RangeAggregation,
column_type: Option<ColumnType>,
limits: &AggregationLimits,
@@ -825,14 +762,15 @@ mod tests {
}
map.insert(
"my_agg_level2".to_string(),
IntermediateBucketResult::Range(IntermediateRangeBucketResult {
buckets,
column_type: None,
}),
IntermediateAggregationResult::Bucket(IntermediateBucketResult::Range(
IntermediateRangeBucketResult {
buckets,
column_type: None,
},
)),
);
IntermediateAggregationResults {
buckets: Some(VecWithNames::from_entries(map.into_iter().collect())),
metrics: Default::default(),
aggs_res: VecWithNames::from_entries(map.into_iter().collect()),
}
}
@@ -858,14 +796,15 @@ mod tests {
}
map.insert(
"my_agg_level1".to_string(),
IntermediateBucketResult::Range(IntermediateRangeBucketResult {
buckets,
column_type: None,
}),
IntermediateAggregationResult::Bucket(IntermediateBucketResult::Range(
IntermediateRangeBucketResult {
buckets,
column_type: None,
},
)),
);
IntermediateAggregationResults {
buckets: Some(VecWithNames::from_entries(map.into_iter().collect())),
metrics: Default::default(),
aggs_res: VecWithNames::from_entries(map.into_iter().collect()),
}
}

View File

@@ -1,7 +1,12 @@
//! Module for all metric aggregations.
//!
//! The aggregations in this family compute metrics, see [super::agg_req::MetricAggregation] for
//! details.
//! The aggregations in this family compute metrics based on values extracted
//! from the documents that are being aggregated. Values are extracted from the fast field of
//! the document.
//! Some aggregations output a single numeric metric (e.g. Average) and are called
//! single-value numeric metrics aggregations; others generate multiple metrics (e.g. Stats) and
//! are called multi-value numeric metrics aggregations.
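A request mixing both families parses from the same JSON format; a small sketch (the aggregation names and the `score` field are illustrative):

```
use serde_json::json;
use tantivy::aggregation::agg_req::Aggregations;

fn main() {
    // "avg" is a single-value metric; "stats" is a multi-value metric that
    // reports count, min, max, sum and avg in one result.
    let req: Aggregations = serde_json::from_value(json!({
        "score_avg": { "avg": { "field": "score" } },
        "score_stats": { "stats": { "field": "score" } }
    }))
    .unwrap();
    assert_eq!(req.len(), 2);
}
```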
mod average;
mod count;
mod max;

View File

@@ -5,13 +5,13 @@ use serde::{Deserialize, Serialize};
use super::*;
use crate::aggregation::agg_req_with_accessor::{
AggregationsWithAccessor, MetricAggregationWithAccessor,
AggregationWithAccessor, AggregationsWithAccessor,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResults, IntermediateMetricResult,
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::{f64_from_fastfield_u64, AggregationError, VecWithNames};
use crate::aggregation::{f64_from_fastfield_u64, AggregationError};
use crate::{DocId, TantivyError};
/// # Percentiles
@@ -240,7 +240,7 @@ impl SegmentPercentilesCollector {
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
agg_accessor: &mut MetricAggregationWithAccessor,
agg_accessor: &mut AggregationWithAccessor,
) {
agg_accessor
.column_block_accessor
@@ -255,22 +255,20 @@ impl SegmentPercentilesCollector {
impl SegmentAggregationCollector for SegmentPercentilesCollector {
#[inline]
fn into_intermediate_aggregations_result(
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
) -> crate::Result<IntermediateAggregationResults> {
let name = agg_with_accessor.metrics.keys[self.accessor_idx].to_string();
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let intermediate_metric_result = IntermediateMetricResult::Percentiles(self.percentiles);
let metrics = Some(VecWithNames::from_entries(vec![(
results.push(
name,
intermediate_metric_result,
)]));
IntermediateAggregationResult::Metric(intermediate_metric_result),
);
Ok(IntermediateAggregationResults {
metrics,
buckets: None,
})
Ok(())
}
#[inline]
@@ -279,7 +277,7 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector {
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
let field = &agg_with_accessor.metrics.values[self.accessor_idx].accessor;
let field = &agg_with_accessor.aggs.values[self.accessor_idx].accessor;
for val in field.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
@@ -295,7 +293,7 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector {
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
let field = &mut agg_with_accessor.metrics.values[self.accessor_idx];
let field = &mut agg_with_accessor.aggs.values[self.accessor_idx];
self.collect_block_with_field(docs, field);
Ok(())
}
@@ -310,9 +308,8 @@ mod tests {
use rand::SeedableRng;
use serde_json::Value;
use crate::aggregation::agg_req::{Aggregation, Aggregations, MetricAggregation};
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::agg_result::AggregationResults;
use crate::aggregation::metric::PercentilesAggregationReq;
use crate::aggregation::tests::{
get_test_index_from_values, get_test_index_from_values_and_terms,
};
@@ -326,14 +323,14 @@ mod tests {
let index = get_test_index_from_values(false, &values)?;
let agg_req_1: Aggregations = vec![(
"percentiles".to_string(),
Aggregation::Metric(MetricAggregation::Percentiles(
PercentilesAggregationReq::from_field_name("score".to_string()),
)),
)]
.into_iter()
.collect();
let agg_req_1: Aggregations = serde_json::from_value(json!({
"percentiles": {
"percentiles": {
"field": "score",
}
},
}))
.unwrap();
let collector = AggregationCollector::from_aggs(agg_req_1, Default::default());
@@ -364,14 +361,14 @@ mod tests {
let index = get_test_index_from_values(false, &values)?;
let agg_req_1: Aggregations = vec![(
"percentiles".to_string(),
Aggregation::Metric(MetricAggregation::Percentiles(
PercentilesAggregationReq::from_field_name("score".to_string()),
)),
)]
.into_iter()
.collect();
let agg_req_1: Aggregations = serde_json::from_value(json!({
"percentiles": {
"percentiles": {
"field": "score",
}
},
}))
.unwrap();
let collector = AggregationCollector::from_aggs(agg_req_1, Default::default());

View File

@@ -3,13 +3,13 @@ use serde::{Deserialize, Serialize};
use super::*;
use crate::aggregation::agg_req_with_accessor::{
AggregationsWithAccessor, MetricAggregationWithAccessor,
AggregationWithAccessor, AggregationsWithAccessor,
};
use crate::aggregation::f64_from_fastfield_u64;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResults, IntermediateMetricResult,
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::{f64_from_fastfield_u64, VecWithNames};
use crate::{DocId, TantivyError};
/// A multi-value metric aggregation that computes a collection of statistics on numeric values that
@@ -179,7 +179,7 @@ impl SegmentStatsCollector {
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
agg_accessor: &mut MetricAggregationWithAccessor,
agg_accessor: &mut AggregationWithAccessor,
) {
agg_accessor
.column_block_accessor
@@ -194,11 +194,12 @@ impl SegmentStatsCollector {
impl SegmentAggregationCollector for SegmentStatsCollector {
#[inline]
fn into_intermediate_aggregations_result(
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
) -> crate::Result<IntermediateAggregationResults> {
let name = agg_with_accessor.metrics.keys[self.accessor_idx].to_string();
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let intermediate_metric_result = match self.collecting_for {
SegmentStatsType::Average => {
@@ -219,15 +220,12 @@ impl SegmentAggregationCollector for SegmentStatsCollector {
}
};
let metrics = Some(VecWithNames::from_entries(vec![(
results.push(
name,
intermediate_metric_result,
)]));
IntermediateAggregationResult::Metric(intermediate_metric_result),
);
Ok(IntermediateAggregationResults {
metrics,
buckets: None,
})
Ok(())
}
#[inline]
@@ -236,7 +234,7 @@ impl SegmentAggregationCollector for SegmentStatsCollector {
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
let field = &agg_with_accessor.metrics.values[self.accessor_idx].accessor;
let field = &agg_with_accessor.aggs.values[self.accessor_idx].accessor;
for val in field.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
@@ -252,7 +250,7 @@ impl SegmentAggregationCollector for SegmentStatsCollector {
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
let field = &mut agg_with_accessor.metrics.values[self.accessor_idx];
let field = &mut agg_with_accessor.aggs.values[self.accessor_idx];
self.collect_block_with_field(docs, field);
Ok(())
}
@@ -263,9 +261,8 @@ mod tests {
use serde_json::Value;
use crate::aggregation::agg_req::{Aggregation, Aggregations, MetricAggregation};
use crate::aggregation::agg_req::{Aggregation, Aggregations};
use crate::aggregation::agg_result::AggregationResults;
use crate::aggregation::metric::StatsAggregation;
use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values};
use crate::aggregation::AggregationCollector;
use crate::query::{AllQuery, TermQuery};
@@ -279,14 +276,14 @@ mod tests {
let index = get_test_index_from_values(false, &values)?;
let agg_req_1: Aggregations = vec![(
"stats".to_string(),
Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name(
"score".to_string(),
))),
)]
.into_iter()
.collect();
let agg_req_1: Aggregations = serde_json::from_value(json!({
"stats": {
"stats": {
"field": "score",
},
}
}))
.unwrap();
let collector = AggregationCollector::from_aggs(agg_req_1, Default::default());
@@ -315,14 +312,14 @@ mod tests {
let index = get_test_index_from_values(false, &values)?;
let agg_req_1: Aggregations = vec![(
"stats".to_string(),
Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name(
"score".to_string(),
))),
)]
.into_iter()
.collect();
let agg_req_1: Aggregations = serde_json::from_value(json!({
"stats": {
"stats": {
"field": "score",
},
}
}))
.unwrap();
let collector = AggregationCollector::from_aggs(agg_req_1, Default::default());
@@ -374,29 +371,25 @@ mod tests {
.unwrap()
};
let agg_req_1: Aggregations = vec![
(
"stats_i64".to_string(),
Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name(
"score_i64".to_string(),
))),
),
(
"stats_f64".to_string(),
Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name(
"score_f64".to_string(),
))),
),
(
"stats".to_string(),
Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name(
"score".to_string(),
))),
),
("range".to_string(), range_agg),
]
.into_iter()
.collect();
let agg_req_1: Aggregations = serde_json::from_value(json!({
"stats_i64": {
"stats": {
"field": "score_i64",
},
},
"stats_f64": {
"stats": {
"field": "score_f64",
},
},
"stats": {
"stats": {
"field": "score",
},
},
"range": range_agg
}))
.unwrap();
let collector = AggregationCollector::from_aggs(agg_req_1, Default::default());

View File

@@ -49,34 +49,6 @@
//! Compute the average metric by building [`agg_req::Aggregations`], which is built from a
//! `(String, agg_req::Aggregation)` iterator.
//!
//! ```
//! use tantivy::aggregation::agg_req::{Aggregations, Aggregation, MetricAggregation};
//! use tantivy::aggregation::AggregationCollector;
//! use tantivy::aggregation::metric::AverageAggregation;
//! use tantivy::query::AllQuery;
//! use tantivy::aggregation::agg_result::AggregationResults;
//! use tantivy::IndexReader;
//!
//! # #[allow(dead_code)]
//! fn aggregate_on_index(reader: &IndexReader) {
//! let agg_req: Aggregations = vec![
//! (
//! "average".to_string(),
//! Aggregation::Metric(MetricAggregation::Average(
//! AverageAggregation::from_field_name("score".to_string()),
//! )),
//! ),
//! ]
//! .into_iter()
//! .collect();
//!
//! let collector = AggregationCollector::from_aggs(agg_req, Default::default());
//!
//! let searcher = reader.searcher();
//! let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();
//! }
//! ```
//! # Example JSON
//! Requests are compatible with the elasticsearch JSON request format.
//!
//! ```
@@ -116,32 +88,24 @@
//! aggregation and then calculate the average on each bucket.
//! ```
//! use tantivy::aggregation::agg_req::*;
//! use tantivy::aggregation::metric::AverageAggregation;
//! use tantivy::aggregation::bucket::RangeAggregation;
//! let sub_agg_req_1: Aggregations = vec![(
//! "average_in_range".to_string(),
//! Aggregation::Metric(MetricAggregation::Average(
//! AverageAggregation::from_field_name("score".to_string()),
//! )),
//! )]
//! .into_iter()
//! .collect();
//! use serde_json::json;
//!
//! let agg_req_1: Aggregations = vec![
//! (
//! "range".to_string(),
//! Aggregation::Bucket(Box::new(BucketAggregation {
//! bucket_agg: BucketAggregationType::Range(RangeAggregation{
//! field: "score".to_string(),
//! ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
//! keyed: false,
//! }),
//! sub_aggregation: sub_agg_req_1.clone(),
//! })),
//! ),
//! ]
//! .into_iter()
//! .collect();
//! let agg_req_1: Aggregations = serde_json::from_value(json!({
//! "rangef64": {
//! "range": {
//! "field": "score",
//! "ranges": [
//! { "from": 3, "to": 7000 },
//! { "from": 7000, "to": 20000 },
//! { "from": 50000, "to": 60000 }
//! ]
//! },
//! "aggs": {
//! "average_in_range": { "avg": { "field": "score" } }
//! }
//! },
//! }))
//! .unwrap();
//! ```
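//!
//! The response mirrors the nesting of the request: each range bucket carries the
//! result of its `average_in_range` sub-aggregation. A rough sketch of the shape
//! (bucket keys and numbers depend on the indexed documents; key formatting is
//! approximate):
//!
//! ```text
//! {
//!   "rangef64": {
//!     "buckets": [
//!       { "key": "3-7000", "doc_count": 2, "average_in_range": { "value": 5.0 } }
//!     ]
//!   }
//! }
//! ```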
//!
//! # Distributed Aggregation
@@ -216,9 +180,9 @@ impl<T: Clone> From<HashMap<String, T>> for VecWithNames<T> {
}
impl<T: Clone> VecWithNames<T> {
fn extend(&mut self, entries: VecWithNames<T>) {
self.keys.extend(entries.keys);
self.values.extend(entries.values);
fn push(&mut self, key: String, value: T) {
self.keys.push(key);
self.values.push(value);
}
fn from_entries(mut entries: Vec<(String, T)>) -> Self {
@@ -247,9 +211,6 @@ impl<T: Clone> VecWithNames<T> {
fn into_values(self) -> impl Iterator<Item = T> {
self.values.into_iter()
}
fn values(&self) -> impl Iterator<Item = &T> + '_ {
self.values.iter()
}
fn values_mut(&mut self) -> impl Iterator<Item = &mut T> + '_ {
self.values.iter_mut()
}

View File

@@ -6,10 +6,8 @@
use std::fmt::Debug;
pub(crate) use super::agg_limits::AggregationLimits;
use super::agg_req::MetricAggregation;
use super::agg_req_with_accessor::{
AggregationsWithAccessor, BucketAggregationWithAccessor, MetricAggregationWithAccessor,
};
use super::agg_req::AggregationVariants;
use super::agg_req_with_accessor::{AggregationWithAccessor, AggregationsWithAccessor};
use super::bucket::{SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector};
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::metric::{
@@ -17,14 +15,13 @@ use super::metric::{
SegmentPercentilesCollector, SegmentStatsCollector, SegmentStatsType, StatsAggregation,
SumAggregation,
};
use super::VecWithNames;
use crate::aggregation::agg_req::BucketAggregationType;
pub(crate) trait SegmentAggregationCollector: CollectorClone + Debug {
fn into_intermediate_aggregations_result(
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
) -> crate::Result<IntermediateAggregationResults>;
results: &mut IntermediateAggregationResults,
) -> crate::Result<()>;
fn collect(
&mut self,
@@ -66,52 +63,79 @@ impl Clone for Box<dyn SegmentAggregationCollector> {
pub(crate) fn build_segment_agg_collector(
req: &AggregationsWithAccessor,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
// Single metric special case
if req.buckets.is_empty() && req.metrics.len() == 1 {
let req = &req.metrics.values[0];
// Single collector special case
if req.aggs.len() == 1 {
let req = &req.aggs.values[0];
let accessor_idx = 0;
return build_metric_segment_agg_collector(req, accessor_idx);
}
// Single bucket special case
if req.metrics.is_empty() && req.buckets.len() == 1 {
let req = &req.buckets.values[0];
let accessor_idx = 0;
return build_bucket_segment_agg_collector(req, accessor_idx);
return build_single_agg_segment_collector(req, accessor_idx);
}
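// More than one aggregation at this level: fall back to the generic
// collector, which builds one boxed sub-collector per requested aggregation.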
let agg = GenericSegmentAggregationResultsCollector::from_req_and_validate(req)?;
Ok(Box::new(agg))
}
pub(crate) fn build_metric_segment_agg_collector(
req: &MetricAggregationWithAccessor,
pub(crate) fn build_single_agg_segment_collector(
req: &AggregationWithAccessor,
accessor_idx: usize,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
match &req.metric {
MetricAggregation::Average(AverageAggregation { .. }) => {
Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Average,
accessor_idx,
)))
}
MetricAggregation::Count(CountAggregation { .. }) => Ok(Box::new(
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Count, accessor_idx),
)),
MetricAggregation::Max(MaxAggregation { .. }) => Ok(Box::new(
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Max, accessor_idx),
)),
MetricAggregation::Min(MinAggregation { .. }) => Ok(Box::new(
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Min, accessor_idx),
)),
MetricAggregation::Stats(StatsAggregation { .. }) => Ok(Box::new(
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Stats, accessor_idx),
)),
MetricAggregation::Sum(SumAggregation { .. }) => Ok(Box::new(
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Sum, accessor_idx),
)),
MetricAggregation::Percentiles(percentiles_req) => Ok(Box::new(
use AggregationVariants::*;
match &req.agg.agg {
Terms(terms_req) => Ok(Box::new(SegmentTermCollector::from_req_and_validate(
terms_req,
&req.sub_aggregation,
req.field_type,
accessor_idx,
)?)),
Range(range_req) => Ok(Box::new(SegmentRangeCollector::from_req_and_validate(
range_req,
&req.sub_aggregation,
&req.limits,
req.field_type,
accessor_idx,
)?)),
Histogram(histogram) => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
histogram,
&req.sub_aggregation,
req.field_type,
accessor_idx,
)?)),
DateHistogram(histogram) => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
&histogram.to_histogram_req()?,
&req.sub_aggregation,
req.field_type,
accessor_idx,
)?)),
Average(AverageAggregation { .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Average,
accessor_idx,
))),
Count(CountAggregation { .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Count,
accessor_idx,
))),
Max(MaxAggregation { .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Max,
accessor_idx,
))),
Min(MinAggregation { .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Min,
accessor_idx,
))),
Stats(StatsAggregation { .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Stats,
accessor_idx,
))),
Sum(SumAggregation { .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Sum,
accessor_idx,
))),
Percentiles(percentiles_req) => Ok(Box::new(
SegmentPercentilesCollector::from_req_and_validate(
percentiles_req,
req.field_type,
@@ -121,96 +145,33 @@ pub(crate) fn build_metric_segment_agg_collector(
}
}
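Each arm of the match above corresponds to one key in the JSON request format; a quick sketch of requests that would route to a few of these collectors (field names are illustrative):

```
use serde_json::json;
use tantivy::aggregation::agg_req::Aggregations;

fn main() {
    let req: Aggregations = serde_json::from_value(json!({
        "by_term": { "terms": { "field": "category" } },      // -> SegmentTermCollector
        "by_range": { "range": {                              // -> SegmentRangeCollector
            "field": "score",
            "ranges": [ { "from": 0.0, "to": 10.0 } ]
        } },
        "by_hist": { "histogram": {                           // -> SegmentHistogramCollector
            "field": "score", "interval": 10.0
        } },
        "min_score": { "min": { "field": "score" } }          // -> SegmentStatsCollector
    }))
    .unwrap();
    assert_eq!(req.len(), 4);
}
```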
pub(crate) fn build_bucket_segment_agg_collector(
req: &BucketAggregationWithAccessor,
accessor_idx: usize,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
match &req.bucket_agg {
BucketAggregationType::Terms(terms_req) => {
Ok(Box::new(SegmentTermCollector::from_req_and_validate(
terms_req,
&req.sub_aggregation,
req.field_type,
accessor_idx,
)?))
}
BucketAggregationType::Range(range_req) => {
Ok(Box::new(SegmentRangeCollector::from_req_and_validate(
range_req,
&req.sub_aggregation,
&req.limits,
req.field_type,
accessor_idx,
)?))
}
BucketAggregationType::Histogram(histogram) => {
Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
histogram,
&req.sub_aggregation,
req.field_type,
accessor_idx,
)?))
}
BucketAggregationType::DateHistogram(histogram) => {
Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
&histogram.to_histogram_req()?,
&req.sub_aggregation,
req.field_type,
accessor_idx,
)?))
}
}
}
#[derive(Clone, Default)]
/// The GenericSegmentAggregationResultsCollector is the generic version of the collector, which
/// can handle sub-aggregations of arbitrary complexity. Ideally we never have to pick this one
/// and can provide specialized versions instead that avoid some of its overhead.
pub(crate) struct GenericSegmentAggregationResultsCollector {
pub(crate) metrics: Option<Vec<Box<dyn SegmentAggregationCollector>>>,
pub(crate) buckets: Option<Vec<Box<dyn SegmentAggregationCollector>>>,
pub(crate) aggs: Vec<Box<dyn SegmentAggregationCollector>>,
}
impl Debug for GenericSegmentAggregationResultsCollector {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentAggregationResultsCollector")
.field("metrics", &self.metrics)
.field("buckets", &self.buckets)
.field("aggs", &self.aggs)
.finish()
}
}
impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
fn into_intermediate_aggregations_result(
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
) -> crate::Result<IntermediateAggregationResults> {
let buckets = if let Some(buckets) = self.buckets {
let mut intermediate_buckets = VecWithNames::default();
for bucket in buckets {
// TODO too many allocations?
let res = bucket.into_intermediate_aggregations_result(agg_with_accessor)?;
// unwrap is fine since we only have buckets here
intermediate_buckets.extend(res.buckets.unwrap());
}
Some(intermediate_buckets)
} else {
None
};
let metrics = if let Some(metrics) = self.metrics {
let mut intermediate_metrics = VecWithNames::default();
for metric in metrics {
// TODO too many allocations?
let res = metric.into_intermediate_aggregations_result(agg_with_accessor)?;
// unwrap is fine since we only have metrics here
intermediate_metrics.extend(res.metrics.unwrap());
}
Some(intermediate_metrics)
} else {
None
};
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
for agg in self.aggs {
agg.add_intermediate_aggregation_result(agg_with_accessor, results)?;
}
Ok(IntermediateAggregationResults { metrics, buckets })
Ok(())
}
fn collect(
@@ -228,31 +189,16 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
if let Some(metrics) = self.metrics.as_mut() {
for collector in metrics {
collector.collect_block(docs, agg_with_accessor)?;
}
}
if let Some(buckets) = self.buckets.as_mut() {
for collector in buckets {
collector.collect_block(docs, agg_with_accessor)?;
}
for collector in &mut self.aggs {
collector.collect_block(docs, agg_with_accessor)?;
}
Ok(())
}
fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
if let Some(metrics) = &mut self.metrics {
for collector in metrics {
collector.flush(agg_with_accessor)?;
}
}
if let Some(buckets) = &mut self.buckets {
for collector in buckets {
collector.flush(agg_with_accessor)?;
}
for collector in &mut self.aggs {
collector.flush(agg_with_accessor)?;
}
Ok(())
}
@@ -260,34 +206,15 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
impl GenericSegmentAggregationResultsCollector {
pub(crate) fn from_req_and_validate(req: &AggregationsWithAccessor) -> crate::Result<Self> {
let buckets = req
.buckets
let aggs = req
.aggs
.iter()
.enumerate()
.map(|(accessor_idx, (_key, req))| {
build_bucket_segment_agg_collector(req, accessor_idx)
})
.collect::<crate::Result<Vec<Box<dyn SegmentAggregationCollector>>>>()?;
let metrics = req
.metrics
.iter()
.enumerate()
.map(|(accessor_idx, (_key, req))| {
build_metric_segment_agg_collector(req, accessor_idx)
build_single_agg_segment_collector(req, accessor_idx)
})
.collect::<crate::Result<Vec<Box<dyn SegmentAggregationCollector>>>>()?;
let metrics = if metrics.is_empty() {
None
} else {
Some(metrics)
};
let buckets = if buckets.is_empty() {
None
} else {
Some(buckets)
};
Ok(GenericSegmentAggregationResultsCollector { metrics, buckets })
Ok(GenericSegmentAggregationResultsCollector { aggs })
}
}