From cdffce906c5cbddbfdf0f0684e880d3b109acbf4 Mon Sep 17 00:00:00 2001 From: Pascal Seitz Date: Mon, 30 Jan 2023 14:30:40 +0800 Subject: [PATCH] introduce SegmentAggregationCollector trait add dynamic dispatch for aggregation collection, to allow specialized implementations --- src/aggregation/bucket/histogram/histogram.rs | 12 +- src/aggregation/bucket/range.rs | 20 +- src/aggregation/bucket/term_agg.rs | 42 ++-- src/aggregation/collector.rs | 10 +- src/aggregation/intermediate_agg_result.rs | 15 +- src/aggregation/metric/stats.rs | 69 +++++- src/aggregation/mod.rs | 4 +- src/aggregation/segment_agg_result.rs | 210 +++++++++++++----- 8 files changed, 265 insertions(+), 117 deletions(-) diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs index 0ad1c62e0..f025fe57e 100644 --- a/src/aggregation/bucket/histogram/histogram.rs +++ b/src/aggregation/bucket/histogram/histogram.rs @@ -13,7 +13,9 @@ use crate::aggregation::agg_result::BucketEntry; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry, }; -use crate::aggregation::segment_agg_result::SegmentAggregationResultsCollector; +use crate::aggregation::segment_agg_result::{ + GenericSegmentAggregationResultsCollector, SegmentAggregationCollector, +}; use crate::aggregation::{f64_from_fastfield_u64, format_date}; use crate::schema::{Schema, Type}; use crate::{DocId, TantivyError}; @@ -184,7 +186,7 @@ pub(crate) struct SegmentHistogramBucketEntry { impl SegmentHistogramBucketEntry { pub(crate) fn into_intermediate_bucket_entry( self, - sub_aggregation: SegmentAggregationResultsCollector, + sub_aggregation: GenericSegmentAggregationResultsCollector, agg_with_accessor: &AggregationsWithAccessor, ) -> crate::Result { Ok(IntermediateHistogramBucketEntry { @@ -198,11 +200,11 @@ impl SegmentHistogramBucketEntry { /// The collector puts values from the fast field into the correct buckets and does a conversion to /// the correct datatype. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] pub struct SegmentHistogramCollector { /// The buckets containing the aggregation data. buckets: Vec, - sub_aggregations: Option>, + sub_aggregations: Option>, field_type: Type, interval: f64, offset: f64, @@ -300,7 +302,7 @@ impl SegmentHistogramCollector { None } else { let sub_aggregation = - SegmentAggregationResultsCollector::from_req_and_validate(sub_aggregation)?; + GenericSegmentAggregationResultsCollector::from_req_and_validate(sub_aggregation)?; Some(buckets.iter().map(|_| sub_aggregation.clone()).collect()) }; diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs index 0e493fdf7..51eb06330 100644 --- a/src/aggregation/bucket/range.rs +++ b/src/aggregation/bucket/range.rs @@ -11,7 +11,9 @@ use crate::aggregation::agg_req_with_accessor::{ use crate::aggregation::intermediate_agg_result::{ IntermediateBucketResult, IntermediateRangeBucketEntry, IntermediateRangeBucketResult, }; -use crate::aggregation::segment_agg_result::{BucketCount, SegmentAggregationResultsCollector}; +use crate::aggregation::segment_agg_result::{ + BucketCount, GenericSegmentAggregationResultsCollector, SegmentAggregationCollector, +}; use crate::aggregation::{ f64_from_fastfield_u64, f64_to_fastfield_u64, format_date, Key, SerializedKey, }; @@ -114,7 +116,7 @@ impl From> for InternalRangeAggregationRange { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] pub(crate) struct SegmentRangeAndBucketEntry { range: Range, bucket: SegmentRangeBucketEntry, @@ -122,18 +124,18 @@ pub(crate) struct SegmentRangeAndBucketEntry { /// The collector puts values from the fast field into the correct buckets and does a conversion to /// the correct datatype. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] pub struct SegmentRangeCollector { /// The buckets containing the aggregation data. buckets: Vec, field_type: Type, } -#[derive(Clone, PartialEq)] +#[derive(Clone)] pub(crate) struct SegmentRangeBucketEntry { pub key: Key, pub doc_count: u64, - pub sub_aggregation: Option, + pub sub_aggregation: Option, /// The from range of the bucket. Equals `f64::MIN` when `None`. pub from: Option, /// The to range of the bucket. Equals `f64::MAX` when `None`. Open interval, `to` is not @@ -227,9 +229,11 @@ impl SegmentRangeCollector { let sub_aggregation = if sub_aggregation.is_empty() { None } else { - Some(SegmentAggregationResultsCollector::from_req_and_validate( - sub_aggregation, - )?) + Some( + GenericSegmentAggregationResultsCollector::from_req_and_validate( + sub_aggregation, + )?, + ) }; Ok(SegmentRangeAndBucketEntry { diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index cc17f1627..5e8070cb8 100644 --- a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg.rs @@ -12,7 +12,10 @@ use crate::aggregation::agg_req_with_accessor::{ use crate::aggregation::intermediate_agg_result::{ IntermediateBucketResult, IntermediateTermBucketEntry, IntermediateTermBucketResult, }; -use crate::aggregation::segment_agg_result::SegmentAggregationResultsCollector; +use crate::aggregation::segment_agg_result::{ + build_segment_agg_collector, GenericSegmentAggregationResultsCollector, + SegmentAggregationCollector, +}; use crate::error::DataCorruption; use crate::schema::Type; use crate::{DocId, TantivyError}; @@ -196,17 +199,16 @@ impl TermsAggregationInternal { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, Default)] /// Container to store term_ids and their buckets. struct TermBuckets { pub(crate) entries: FxHashMap, - blueprint: Option, } -#[derive(Clone, PartialEq, Default)] +#[derive(Clone, Default)] struct TermBucketEntry { doc_count: u64, - sub_aggregations: Option>, + sub_aggregations: Option>, } impl Debug for TermBucketEntry { @@ -218,7 +220,7 @@ impl Debug for TermBucketEntry { } impl TermBucketEntry { - fn from_blueprint(blueprint: &Option>) -> Self { + fn from_blueprint(blueprint: &Option>) -> Self { Self { doc_count: 0, sub_aggregations: blueprint.clone(), @@ -247,18 +249,7 @@ impl TermBuckets { sub_aggregation: &AggregationsWithAccessor, _max_term_id: usize, ) -> crate::Result { - let has_sub_aggregations = sub_aggregation.is_empty(); - - let blueprint = if has_sub_aggregations { - let sub_aggregation = - SegmentAggregationResultsCollector::from_req_and_validate(sub_aggregation)?; - Some(sub_aggregation) - } else { - None - }; - Ok(TermBuckets { - blueprint, entries: Default::default(), }) } @@ -275,13 +266,12 @@ impl TermBuckets { /// The collector puts values from the fast field into the correct buckets and does a conversion to /// the correct datatype. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] pub struct SegmentTermCollector { /// The buckets containing the aggregation data. term_buckets: TermBuckets, req: TermsAggregationInternal, - field_type: Type, - blueprint: Option>, + blueprint: Option>, } pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) { @@ -293,12 +283,8 @@ impl SegmentTermCollector { pub(crate) fn from_req_and_validate( req: &TermsAggregation, sub_aggregations: &AggregationsWithAccessor, - field_type: Type, - accessor: &Column, ) -> crate::Result { - let max_term_id = accessor.max_value(); - let term_buckets = - TermBuckets::from_req_and_validate(sub_aggregations, max_term_id as usize)?; + let term_buckets = TermBuckets::default(); if let Some(custom_order) = req.order.as_ref() { // Validate sub aggregtion exists @@ -316,9 +302,8 @@ impl SegmentTermCollector { let has_sub_aggregations = !sub_aggregations.is_empty(); let blueprint = if has_sub_aggregations { - let sub_aggregation = - SegmentAggregationResultsCollector::from_req_and_validate(sub_aggregations)?; - Some(Box::new(sub_aggregation)) + let sub_aggregation = build_segment_agg_collector(sub_aggregations)?; + Some(sub_aggregation) } else { None }; @@ -326,7 +311,6 @@ impl SegmentTermCollector { Ok(SegmentTermCollector { req: TermsAggregationInternal::from_req(req), term_buckets, - field_type, blueprint, }) } diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs index a53ac268a..348da8c76 100644 --- a/src/aggregation/collector.rs +++ b/src/aggregation/collector.rs @@ -4,7 +4,10 @@ use super::agg_req::Aggregations; use super::agg_req_with_accessor::AggregationsWithAccessor; use super::agg_result::AggregationResults; use super::intermediate_agg_result::IntermediateAggregationResults; -use super::segment_agg_result::SegmentAggregationResultsCollector; +use super::segment_agg_result::{ + build_segment_agg_collector, GenericSegmentAggregationResultsCollector, + SegmentAggregationCollector, +}; use crate::aggregation::agg_req_with_accessor::get_aggs_with_accessor_and_validate; use crate::collector::{Collector, SegmentCollector}; use crate::schema::Schema; @@ -137,7 +140,7 @@ fn merge_fruits( /// `AggregationSegmentCollector` does the aggregation collection on a segment. pub struct AggregationSegmentCollector { aggs_with_accessor: AggregationsWithAccessor, - result: SegmentAggregationResultsCollector, + result: Box, error: Option, } @@ -151,8 +154,7 @@ impl AggregationSegmentCollector { ) -> crate::Result { let aggs_with_accessor = get_aggs_with_accessor_and_validate(agg, reader, Rc::default(), max_bucket_count)?; - let result = - SegmentAggregationResultsCollector::from_req_and_validate(&aggs_with_accessor)?; + let result = build_segment_agg_collector(&aggs_with_accessor)?; Ok(AggregationSegmentCollector { aggs_with_accessor, result, diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index 390cee13f..ffef02cea 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -222,24 +222,23 @@ pub enum IntermediateMetricResult { impl From for IntermediateMetricResult { fn from(tree: SegmentMetricResultCollector) -> Self { + use super::metric::SegmentStatsType; match tree { SegmentMetricResultCollector::Stats(collector) => match collector.collecting_for { - super::metric::SegmentStatsType::Average => IntermediateMetricResult::Average( + SegmentStatsType::Average => IntermediateMetricResult::Average( IntermediateAverage::from_collector(collector), ), - super::metric::SegmentStatsType::Count => { + SegmentStatsType::Count => { IntermediateMetricResult::Count(IntermediateCount::from_collector(collector)) } - super::metric::SegmentStatsType::Max => { + SegmentStatsType::Max => { IntermediateMetricResult::Max(IntermediateMax::from_collector(collector)) } - super::metric::SegmentStatsType::Min => { + SegmentStatsType::Min => { IntermediateMetricResult::Min(IntermediateMin::from_collector(collector)) } - super::metric::SegmentStatsType::Stats => { - IntermediateMetricResult::Stats(collector.stats) - } - super::metric::SegmentStatsType::Sum => { + SegmentStatsType::Stats => IntermediateMetricResult::Stats(collector.stats), + SegmentStatsType::Sum => { IntermediateMetricResult::Sum(IntermediateSum::from_collector(collector)) } }, diff --git a/src/aggregation/metric/stats.rs b/src/aggregation/metric/stats.rs index 32ee645f7..7951de1d4 100644 --- a/src/aggregation/metric/stats.rs +++ b/src/aggregation/metric/stats.rs @@ -1,10 +1,17 @@ use columnar::Column; use serde::{Deserialize, Serialize}; -use crate::aggregation::f64_from_fastfield_u64; +use crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor; +use crate::aggregation::intermediate_agg_result::{ + IntermediateAggregationResults, IntermediateMetricResult, +}; +use crate::aggregation::segment_agg_result::SegmentAggregationCollector; +use crate::aggregation::{f64_from_fastfield_u64, VecWithNames}; use crate::schema::Type; use crate::{DocId, TantivyError}; +use super::*; + /// A multi-value metric aggregation that computes a collection of statistics on numeric values that /// are extracted from the aggregated documents. /// See [`Stats`] for returned statistics. @@ -171,6 +178,66 @@ impl SegmentStatsCollector { } } +impl SegmentAggregationCollector for SegmentStatsCollector { + fn into_intermediate_aggregations_result( + self: Box, + agg_with_accessor: &AggregationsWithAccessor, + ) -> crate::Result { + let name = agg_with_accessor.metrics.keys[0].to_string(); + + let intermediate_metric_result = match self.collecting_for { + SegmentStatsType::Average => { + IntermediateMetricResult::Average(IntermediateAverage::from_collector(*self)) + } + SegmentStatsType::Count => { + IntermediateMetricResult::Count(IntermediateCount::from_collector(*self)) + } + SegmentStatsType::Max => { + IntermediateMetricResult::Max(IntermediateMax::from_collector(*self)) + } + SegmentStatsType::Min => { + IntermediateMetricResult::Min(IntermediateMin::from_collector(*self)) + } + SegmentStatsType::Stats => IntermediateMetricResult::Stats(self.stats), + SegmentStatsType::Sum => { + IntermediateMetricResult::Sum(IntermediateSum::from_collector(*self)) + } + }; + + let metrics = Some(VecWithNames::from_entries(vec![( + name, + intermediate_metric_result, + )])); + + Ok(IntermediateAggregationResults { + metrics, + buckets: None, + }) + } + + fn collect( + &mut self, + doc: crate::DocId, + agg_with_accessor: &AggregationsWithAccessor, + ) -> crate::Result<()> { + let accessor = &agg_with_accessor.metrics.values[0].accessor; + for val in accessor.values(doc) { + let val1 = f64_from_fastfield_u64(val, &self.field_type); + self.stats.collect(val1); + } + + Ok(()) + } + + fn flush_staged_docs( + &mut self, + _agg_with_accessor: &AggregationsWithAccessor, + _force_flush: bool, + ) -> crate::Result<()> { + Ok(()) + } +} + #[cfg(test)] mod tests { diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index 5247b9684..bf417471d 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -182,7 +182,7 @@ use crate::schema::Type; /// Represents an associative array `(key => values)` in a very efficient manner. #[derive(Clone, PartialEq, Serialize, Deserialize)] pub(crate) struct VecWithNames { - values: Vec, + pub(crate) values: Vec, keys: Vec, } @@ -1396,7 +1396,7 @@ mod tests { } #[bench] - fn bench_aggregation_terms_many(b: &mut Bencher) { + fn bench_aggregation_terms_many2(b: &mut Bencher) { let index = get_test_index_bench(false).unwrap(); let reader = index.reader().unwrap(); diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index 269095890..a56a1a377 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -25,15 +25,90 @@ use crate::{DocId, TantivyError}; pub(crate) const DOC_BLOCK_SIZE: usize = 64; pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE]; -#[derive(Clone, PartialEq)] -pub(crate) struct SegmentAggregationResultsCollector { +pub(crate) trait SegmentAggregationCollector: CollectorClone + Debug { + fn into_intermediate_aggregations_result( + self: Box, + agg_with_accessor: &AggregationsWithAccessor, + ) -> crate::Result; + + fn collect( + &mut self, + doc: crate::DocId, + agg_with_accessor: &AggregationsWithAccessor, + ) -> crate::Result<()>; + + fn flush_staged_docs( + &mut self, + agg_with_accessor: &AggregationsWithAccessor, + force_flush: bool, + ) -> crate::Result<()>; +} + +pub(crate) trait CollectorClone { + fn clone_box(&self) -> Box; +} + +impl CollectorClone for T +where + T: 'static + SegmentAggregationCollector + Clone, +{ + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } +} + +impl Clone for Box { + fn clone(&self) -> Box { + self.clone_box() + } +} + +pub(crate) fn build_segment_agg_collector( + req: &AggregationsWithAccessor, +) -> crate::Result> { + // Single metric special case + if req.buckets.is_empty() && req.metrics.len() == 1 { + let req = &req.metrics.values[0]; + let stats_collector = match &req.metric { + MetricAggregation::Average(AverageAggregation { .. }) => { + SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Average) + } + MetricAggregation::Count(CountAggregation { .. }) => { + SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Count) + } + MetricAggregation::Max(MaxAggregation { .. }) => { + SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Max) + } + MetricAggregation::Min(MinAggregation { .. }) => { + SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Min) + } + MetricAggregation::Stats(StatsAggregation { .. }) => { + SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Stats) + } + MetricAggregation::Sum(SumAggregation { .. }) => { + SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Sum) + } + }; + + return Ok(Box::new(stats_collector)); + } + + let agg = GenericSegmentAggregationResultsCollector::from_req_and_validate(req)?; + Ok(Box::new(agg)) +} + +#[derive(Clone)] +/// The GenericSegmentAggregationResultsCollector is the generic version of the collector, which +/// can handle arbitrary complexity of sub-aggregations. Ideally we never have to pick this one +/// and can provide specialized versions instead, that remove some of its overhead. +pub(crate) struct GenericSegmentAggregationResultsCollector { pub(crate) metrics: Option>, pub(crate) buckets: Option>, staged_docs: DocBlock, num_staged_docs: usize, } -impl Default for SegmentAggregationResultsCollector { +impl Default for GenericSegmentAggregationResultsCollector { fn default() -> Self { Self { metrics: Default::default(), @@ -44,7 +119,7 @@ impl Default for SegmentAggregationResultsCollector { } } -impl Debug for SegmentAggregationResultsCollector { +impl Debug for GenericSegmentAggregationResultsCollector { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SegmentAggregationResultsCollector") .field("metrics", &self.metrics) @@ -55,7 +130,74 @@ impl Debug for SegmentAggregationResultsCollector { } } -impl SegmentAggregationResultsCollector { +impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector { + fn into_intermediate_aggregations_result( + self: Box, + agg_with_accessor: &AggregationsWithAccessor, + ) -> crate::Result { + let buckets = if let Some(buckets) = self.buckets { + let entries = buckets + .into_iter() + .zip(agg_with_accessor.buckets.values()) + .map(|((key, bucket), acc)| Ok((key, bucket.into_intermediate_bucket_result(acc)?))) + .collect::>>()?; + Some(VecWithNames::from_entries(entries)) + } else { + None + }; + let metrics = self.metrics.map(VecWithNames::from_other); + + Ok(IntermediateAggregationResults { metrics, buckets }) + } + + fn collect( + &mut self, + doc: crate::DocId, + agg_with_accessor: &AggregationsWithAccessor, + ) -> crate::Result<()> { + self.staged_docs[self.num_staged_docs] = doc; + self.num_staged_docs += 1; + if self.num_staged_docs == self.staged_docs.len() { + self.flush_staged_docs(agg_with_accessor, false)?; + } + Ok(()) + } + + fn flush_staged_docs( + &mut self, + agg_with_accessor: &AggregationsWithAccessor, + force_flush: bool, + ) -> crate::Result<()> { + if self.num_staged_docs == 0 { + return Ok(()); + } + if let Some(metrics) = &mut self.metrics { + for (collector, agg_with_accessor) in + metrics.values_mut().zip(agg_with_accessor.metrics.values()) + { + collector + .collect_block(&self.staged_docs[..self.num_staged_docs], agg_with_accessor); + } + } + + if let Some(buckets) = &mut self.buckets { + for (collector, agg_with_accessor) in + buckets.values_mut().zip(agg_with_accessor.buckets.values()) + { + collector.collect_block( + &self.staged_docs[..self.num_staged_docs], + agg_with_accessor, + force_flush, + )?; + } + } + + self.num_staged_docs = 0; + Ok(()) + } +} + +impl GenericSegmentAggregationResultsCollector { pub fn into_intermediate_aggregations_result( self, agg_with_accessor: &AggregationsWithAccessor, @@ -106,60 +248,13 @@ impl SegmentAggregationResultsCollector { } else { Some(VecWithNames::from_entries(buckets)) }; - Ok(SegmentAggregationResultsCollector { + Ok(GenericSegmentAggregationResultsCollector { metrics, buckets, staged_docs: [0; DOC_BLOCK_SIZE], num_staged_docs: 0, }) } - - #[inline] - pub(crate) fn collect( - &mut self, - doc: crate::DocId, - agg_with_accessor: &AggregationsWithAccessor, - ) -> crate::Result<()> { - self.staged_docs[self.num_staged_docs] = doc; - self.num_staged_docs += 1; - if self.num_staged_docs == self.staged_docs.len() { - self.flush_staged_docs(agg_with_accessor, false)?; - } - Ok(()) - } - - pub(crate) fn flush_staged_docs( - &mut self, - agg_with_accessor: &AggregationsWithAccessor, - force_flush: bool, - ) -> crate::Result<()> { - if self.num_staged_docs == 0 { - return Ok(()); - } - if let Some(metrics) = &mut self.metrics { - for (collector, agg_with_accessor) in - metrics.values_mut().zip(agg_with_accessor.metrics.values()) - { - collector - .collect_block(&self.staged_docs[..self.num_staged_docs], agg_with_accessor); - } - } - - if let Some(buckets) = &mut self.buckets { - for (collector, agg_with_accessor) in - buckets.values_mut().zip(agg_with_accessor.buckets.values()) - { - collector.collect_block( - &self.staged_docs[..self.num_staged_docs], - agg_with_accessor, - force_flush, - )?; - } - } - - self.num_staged_docs = 0; - Ok(()) - } } #[derive(Clone, Debug, PartialEq)] @@ -215,7 +310,7 @@ impl SegmentMetricResultCollector { /// segments. /// The typical structure of Map is not suitable during collection for performance /// reasons. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] pub(crate) enum SegmentBucketResultCollector { Range(SegmentRangeCollector), Histogram(Box), @@ -243,12 +338,7 @@ impl SegmentBucketResultCollector { pub fn from_req_and_validate(req: &BucketAggregationWithAccessor) -> crate::Result { match &req.bucket_agg { BucketAggregationType::Terms(terms_req) => Ok(Self::Terms(Box::new( - SegmentTermCollector::from_req_and_validate( - terms_req, - &req.sub_aggregation, - req.field_type, - &req.accessor, - )?, + SegmentTermCollector::from_req_and_validate(terms_req, &req.sub_aggregation)?, ))), BucketAggregationType::Range(range_req) => { Ok(Self::Range(SegmentRangeCollector::from_req_and_validate(