diff --git a/benches/agg_bench.rs b/benches/agg_bench.rs index 7004b46e0..2f47bc0e4 100644 --- a/benches/agg_bench.rs +++ b/benches/agg_bench.rs @@ -71,6 +71,7 @@ fn bench_agg(mut group: InputGroup) { register!(group, histogram); register!(group, histogram_hard_bounds); register!(group, histogram_with_avg_sub_agg); + register!(group, histogram_with_term_agg_few); register!(group, avg_and_range_with_avg_sub_agg); group.run(); @@ -339,6 +340,17 @@ fn histogram_with_avg_sub_agg(index: &Index) { }); execute_agg(index, agg_req); } +fn histogram_with_term_agg_few(index: &Index) { + let agg_req = json!({ + "rangef64": { + "histogram": { "field": "score_f64", "interval": 10 }, + "aggs": { + "my_texts": { "terms": { "field": "text_few_terms" } } + } + } + }); + execute_agg(index, agg_req); +} fn avg_and_range_with_avg_sub_agg(index: &Index) { let agg_req = json!({ "rangef64": { diff --git a/src/aggregation/README.md b/src/aggregation/README.md index c96454d07..c0067dcfd 100644 --- a/src/aggregation/README.md +++ b/src/aggregation/README.md @@ -20,17 +20,16 @@ Contains all metric aggregations, like average aggregation. Metric aggregations #### agg_req agg_req contains the users aggregation request. Deserialization from json is compatible with elasticsearch aggregation requests. -#### agg_req_with_accessor -agg_req_with_accessor contains the users aggregation request enriched with fast field accessors etc, which are +#### agg_data +agg_data contains the users aggregation request enriched with fast field accessors etc, which are used during collection. #### segment_agg_result segment_agg_result contains the aggregation result tree, which is used for collection of a segment. -The tree from agg_req_with_accessor is passed during collection. +agg_data is passed during collection. #### intermediate_agg_result intermediate_agg_result contains the aggregation tree for merging with other trees. #### agg_result agg_result contains the final aggregation tree. 
- diff --git a/src/aggregation/agg_data.rs b/src/aggregation/agg_data.rs new file mode 100644 index 000000000..9b0827e6b --- /dev/null +++ b/src/aggregation/agg_data.rs @@ -0,0 +1,894 @@ +use columnar::{Column, ColumnType}; +use serde::Serialize; + +use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations}; +use crate::aggregation::agg_req_with_accessor::{ + get_all_ff_reader_or_empty, get_dynamic_columns, get_ff_reader, get_missing_val_as_u64_lenient, + get_numeric_or_date_column_types, +}; +use crate::aggregation::bucket::{ + HistogramAggReqData, HistogramBounds, MissingTermAggReqData, RangeAggReqData, + SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector, TermMissingAgg, + TermsAggReqData, TermsAggregation, TermsAggregationInternal, +}; +use crate::aggregation::metric::{ + AverageAggregation, CardinalityAggReqData, CardinalityAggregationReq, CountAggregation, + ExtendedStatsAggregation, MaxAggregation, MetricAggReqData, MinAggregation, + SegmentCardinalityCollector, SegmentExtendedStatsCollector, SegmentPercentilesCollector, + SegmentStatsCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData, + TopHitsSegmentCollector, +}; +use crate::aggregation::segment_agg_result::{ + GenericSegmentAggregationResultsCollector, SegmentAggregationCollector, +}; +use crate::aggregation::{f64_to_fastfield_u64, AggregationLimitsGuard, Key}; +use crate::{SegmentOrdinal, SegmentReader}; + +#[derive(Default)] +/// Datastructure holding all request data for executing aggregations on a segment. +/// It is passed to the collectors during collection. +pub struct AggregationsSegmentCtx { + /// Request data for each aggregation type. 
+ pub per_request: PerRequestAggSegCtx, + pub limits: AggregationLimitsGuard, +} + +impl AggregationsSegmentCtx { + pub(crate) fn push_term_req_data(&mut self, data: TermsAggReqData) -> usize { + self.per_request.term_req_data.push(Some(Box::new(data))); + self.per_request.term_req_data.len() - 1 + } + pub(crate) fn push_cardinality_req_data(&mut self, data: CardinalityAggReqData) -> usize { + self.per_request.cardinality_req_data.push(data); + self.per_request.cardinality_req_data.len() - 1 + } + pub(crate) fn push_metric_req_data(&mut self, data: MetricAggReqData) -> usize { + self.per_request.stats_metric_req_data.push(data); + self.per_request.stats_metric_req_data.len() - 1 + } + pub(crate) fn push_top_hits_req_data(&mut self, data: TopHitsAggReqData) -> usize { + self.per_request.top_hits_req_data.push(data); + self.per_request.top_hits_req_data.len() - 1 + } + pub(crate) fn push_missing_term_req_data(&mut self, data: MissingTermAggReqData) -> usize { + self.per_request.missing_term_req_data.push(data); + self.per_request.missing_term_req_data.len() - 1 + } + pub(crate) fn push_histogram_req_data(&mut self, data: HistogramAggReqData) -> usize { + self.per_request + .histogram_req_data + .push(Some(Box::new(data))); + self.per_request.histogram_req_data.len() - 1 + } + pub(crate) fn push_range_req_data(&mut self, data: RangeAggReqData) -> usize { + self.per_request.range_req_data.push(Some(Box::new(data))); + self.per_request.range_req_data.len() - 1 + } + + #[inline] + pub(crate) fn get_term_req_data(&self, idx: usize) -> &TermsAggReqData { + self.per_request.term_req_data[idx] + .as_deref() + .expect("term_req_data slot is empty (taken)") + } + #[inline] + pub(crate) fn get_cardinality_req_data(&self, idx: usize) -> &CardinalityAggReqData { + &self.per_request.cardinality_req_data[idx] + } + #[inline] + pub(crate) fn get_metric_req_data(&self, idx: usize) -> &MetricAggReqData { + &self.per_request.stats_metric_req_data[idx] + } + #[inline] + pub(crate) fn 
get_top_hits_req_data(&self, idx: usize) -> &TopHitsAggReqData { + &self.per_request.top_hits_req_data[idx] + } + #[inline] + pub(crate) fn get_missing_term_req_data(&self, idx: usize) -> &MissingTermAggReqData { + &self.per_request.missing_term_req_data[idx] + } + #[inline] + pub(crate) fn get_histogram_req_data(&self, idx: usize) -> &HistogramAggReqData { + self.per_request.histogram_req_data[idx] + .as_deref() + .expect("histogram_req_data slot is empty (taken)") + } + #[inline] + pub(crate) fn get_range_req_data(&self, idx: usize) -> &RangeAggReqData { + self.per_request.range_req_data[idx] + .as_deref() + .expect("range_req_data slot is empty (taken)") + } + + // ---------- mutable getters ---------- + + #[inline] + pub(crate) fn get_term_req_data_mut(&mut self, idx: usize) -> &mut TermsAggReqData { + self.per_request.term_req_data[idx] + .as_deref_mut() + .expect("term_req_data slot is empty (taken)") + } + #[inline] + pub(crate) fn get_cardinality_req_data_mut( + &mut self, + idx: usize, + ) -> &mut CardinalityAggReqData { + &mut self.per_request.cardinality_req_data[idx] + } + #[inline] + pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData { + &mut self.per_request.stats_metric_req_data[idx] + } + #[inline] + pub(crate) fn get_histogram_req_data_mut(&mut self, idx: usize) -> &mut HistogramAggReqData { + self.per_request.histogram_req_data[idx] + .as_deref_mut() + .expect("histogram_req_data slot is empty (taken)") + } + + // ---------- take / put (terms, histogram, range) ---------- + + /// Move out the boxed Terms request at `idx`, leaving `None`. + #[inline] + pub(crate) fn take_term_req_data(&mut self, idx: usize) -> Box { + self.per_request.term_req_data[idx] + .take() + .expect("term_req_data slot is empty (taken)") + } + + /// Put back a Terms request into an empty slot at `idx`. 
+ #[inline] + pub(crate) fn put_back_term_req_data(&mut self, idx: usize, value: Box) { + debug_assert!(self.per_request.term_req_data[idx].is_none()); + self.per_request.term_req_data[idx] = Some(value); + } + + /// Move out the boxed Histogram request at `idx`, leaving `None`. + #[inline] + pub(crate) fn take_histogram_req_data(&mut self, idx: usize) -> Box { + self.per_request.histogram_req_data[idx] + .take() + .expect("histogram_req_data slot is empty (taken)") + } + + /// Put back a Histogram request into an empty slot at `idx`. + #[inline] + pub(crate) fn put_back_histogram_req_data( + &mut self, + idx: usize, + value: Box, + ) { + debug_assert!(self.per_request.histogram_req_data[idx].is_none()); + self.per_request.histogram_req_data[idx] = Some(value); + } + + /// Move out the boxed Range request at `idx`, leaving `None`. + #[inline] + pub(crate) fn take_range_req_data(&mut self, idx: usize) -> Box { + self.per_request.range_req_data[idx] + .take() + .expect("range_req_data slot is empty (taken)") + } + + /// Put back a Range request into an empty slot at `idx`. + #[inline] + pub(crate) fn put_back_range_req_data(&mut self, idx: usize, value: Box) { + debug_assert!(self.per_request.range_req_data[idx].is_none()); + self.per_request.range_req_data[idx] = Some(value); + } +} + +/// Each type of aggregation has its own request data struct. +/// This struct holds all request data to execute the aggregation request on a single segment. +/// +/// The request tree is represented by `agg_tree` which contains nodes with references +/// into the various request data vectors. +#[derive(Default)] +pub struct PerRequestAggSegCtx { + // Box for cheap take/put - Only necessary for bucket aggs that have sub-aggregations + /// TermsAggReqData contains the request data for a terms aggregation. + pub term_req_data: Vec>>, + /// HistogramAggReqData contains the request data for a histogram aggregation. 
+ pub histogram_req_data: Vec>>, + /// RangeAggReqData contains the request data for a range aggregation. + pub range_req_data: Vec>>, + /// Shared by avg, min, max, sum, stats, extended_stats, count + pub stats_metric_req_data: Vec, + /// CardinalityAggReqData contains the request data for a cardinality aggregation. + pub cardinality_req_data: Vec, + /// TopHitsAggReqData contains the request data for a top_hits aggregation. + pub top_hits_req_data: Vec, + /// MissingTermAggReqData contains the request data for a missing term aggregation. + pub missing_term_req_data: Vec, + + /// Request tree used to build collectors. + pub agg_tree: Vec, +} + +impl PerRequestAggSegCtx { + pub fn get_name(&self, node: &AggRefNode) -> &str { + let idx = node.idx_in_req_data; + let kind = node.kind; + match kind { + AggKind::Terms => self.term_req_data[idx] + .as_deref() + .expect("term_req_data slot is empty (taken)") + .name + .as_str(), + AggKind::Cardinality => &self.cardinality_req_data[idx].name, + AggKind::StatsKind(_) => &self.stats_metric_req_data[idx].name, + AggKind::TopHits => &self.top_hits_req_data[idx].name, + AggKind::MissingTerm => &self.missing_term_req_data[idx].name, + AggKind::Histogram => self.histogram_req_data[idx] + .as_deref() + .expect("histogram_req_data slot is empty (taken)") + .name + .as_str(), + AggKind::DateHistogram => self.histogram_req_data[idx] + .as_deref() + .expect("histogram_req_data slot is empty (taken)") + .name + .as_str(), + AggKind::Range => self.range_req_data[idx] + .as_deref() + .expect("range_req_data slot is empty (taken)") + .name + .as_str(), + } + } + + /// Convert the aggregation tree into a serializable struct representation. + /// Each node contains: { name, kind, children }. 
+ pub fn get_view_tree(&self) -> Vec { + fn node_to_view(node: &AggRefNode, pr: &PerRequestAggSegCtx) -> AggTreeViewNode { + let mut children: Vec = + node.children.iter().map(|c| node_to_view(c, pr)).collect(); + children.sort_by_key(|v| serde_json::to_string(v).unwrap()); + AggTreeViewNode { + name: pr.get_name(node).to_string(), + kind: node.kind.as_str().to_string(), + children, + } + } + + let mut roots: Vec = self + .agg_tree + .iter() + .map(|n| node_to_view(n, self)) + .collect(); + roots.sort_by_key(|v| serde_json::to_string(v).unwrap()); + roots + } +} + +pub(crate) fn build_segment_agg_collectors_root( + req: &mut AggregationsSegmentCtx, +) -> crate::Result> { + build_segment_agg_collectors(req, &req.per_request.agg_tree.clone()) +} + +pub(crate) fn build_segment_agg_collectors( + req: &mut AggregationsSegmentCtx, + nodes: &[AggRefNode], +) -> crate::Result> { + let mut collectors = Vec::new(); + for node in nodes.iter() { + collectors.push(build_segment_agg_collector(req, node)?); + } + + // Single collector special case + if collectors.len() == 1 { + return Ok(collectors.pop().unwrap()); + } + let agg = GenericSegmentAggregationResultsCollector { aggs: collectors }; + Ok(Box::new(agg)) +} + +pub(crate) fn build_segment_agg_collector( + req: &mut AggregationsSegmentCtx, + node: &AggRefNode, +) -> crate::Result> { + match node.kind { + AggKind::Terms => Ok(Box::new(SegmentTermCollector::from_req_and_validate( + req, node, + )?)), + AggKind::MissingTerm => { + let req_data = &mut req.per_request.missing_term_req_data[node.idx_in_req_data]; + if req_data.accessors.is_empty() { + return Err(crate::TantivyError::InternalError( + "MissingTerm aggregation requires at least one field accessor.".to_string(), + )); + } + Ok(Box::new(TermMissingAgg::new(req, node)?)) + } + AggKind::Cardinality => { + let req_data = &mut req.get_cardinality_req_data_mut(node.idx_in_req_data); + Ok(Box::new(SegmentCardinalityCollector::from_req( + req_data.column_type, + 
node.idx_in_req_data, + ))) + } + AggKind::StatsKind(stats_type) => { + let req_data = &mut req.per_request.stats_metric_req_data[node.idx_in_req_data]; + match stats_type { + StatsType::Sum + | StatsType::Average + | StatsType::Count + | StatsType::Max + | StatsType::Min + | StatsType::Stats => Ok(Box::new(SegmentStatsCollector::from_req( + node.idx_in_req_data, + ))), + StatsType::ExtendedStats(sigma) => { + Ok(Box::new(SegmentExtendedStatsCollector::from_req( + req_data.field_type, + sigma, + node.idx_in_req_data, + req_data.missing, + ))) + } + StatsType::Percentiles => Ok(Box::new( + SegmentPercentilesCollector::from_req_and_validate(node.idx_in_req_data)?, + )), + } + } + AggKind::TopHits => { + let req_data = &mut req.per_request.top_hits_req_data[node.idx_in_req_data]; + Ok(Box::new(TopHitsSegmentCollector::from_req( + &req_data.req, + node.idx_in_req_data, + req_data.segment_ordinal, + ))) + } + AggKind::Histogram => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate( + req, node, + )?)), + AggKind::DateHistogram => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate( + req, node, + )?)), + AggKind::Range => Ok(Box::new(SegmentRangeCollector::from_req_and_validate( + req, node, + )?)), + } +} + +#[derive(Debug, Clone)] +pub struct AggRefNode { + pub kind: AggKind, + pub idx_in_req_data: usize, + pub children: Vec, +} +impl AggRefNode { + pub fn get_sub_agg(&self, name: &str, pr: &PerRequestAggSegCtx) -> Option<&AggRefNode> { + self.children + .iter() + .find(|&child| pr.get_name(child) == name) + } +} + +#[derive(Copy, Clone, Debug)] +pub enum AggKind { + Terms, + Cardinality, + /// One of: Statistics, Average, Min, Max, Sum, Count, Stats, ExtendedStats + StatsKind(StatsType), + TopHits, + MissingTerm, + Histogram, + DateHistogram, + Range, +} + +impl AggKind { + #[cfg_attr(not(test), allow(dead_code))] + fn as_str(&self) -> &'static str { + match self { + AggKind::Terms => "Terms", + AggKind::Cardinality => "Cardinality", + 
AggKind::StatsKind(_) => "Metric", + AggKind::TopHits => "TopHits", + AggKind::MissingTerm => "MissingTerm", + AggKind::Histogram => "Histogram", + AggKind::DateHistogram => "DateHistogram", + AggKind::Range => "Range", + } + } +} + +// ReqData structs moved to their respective collector modules + +/// Build AggregationsData by walking the request tree. +pub(crate) fn build_aggregations_data_from_req( + aggs: &Aggregations, + reader: &SegmentReader, + segment_ordinal: SegmentOrdinal, + limits: AggregationLimitsGuard, +) -> crate::Result { + let mut data = AggregationsSegmentCtx { + per_request: Default::default(), + limits, + }; + + for (name, agg) in aggs.iter() { + let nodes = build_nodes(name, agg, reader, segment_ordinal, &mut data)?; + data.per_request.agg_tree.extend(nodes); + } + Ok(data) +} + +fn build_nodes( + agg_name: &str, + req: &Aggregation, + reader: &SegmentReader, + segment_ordinal: SegmentOrdinal, + data: &mut AggregationsSegmentCtx, +) -> crate::Result> { + use AggregationVariants::*; + match &req.agg { + Range(range_req) => { + let (accessor, field_type) = get_ff_reader( + reader, + &range_req.field, + Some(get_numeric_or_date_column_types()), + )?; + let idx_in_req_data = data.push_range_req_data(RangeAggReqData { + accessor, + field_type, + column_block_accessor: Default::default(), + name: agg_name.to_string(), + req: range_req.clone(), + }); + let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?; + Ok(vec![AggRefNode { + kind: AggKind::Range, + idx_in_req_data, + children, + }]) + } + Histogram(histo_req) => { + let (accessor, field_type) = get_ff_reader( + reader, + &histo_req.field, + Some(get_numeric_or_date_column_types()), + )?; + let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData { + accessor, + field_type, + column_block_accessor: Default::default(), + name: agg_name.to_string(), + sub_aggregation_blueprint: None, + req: histo_req.clone(), + is_date_histogram: false, + bounds: 
HistogramBounds { + min: f64::MIN, + max: f64::MAX, + }, + offset: 0.0, + }); + let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?; + Ok(vec![AggRefNode { + kind: AggKind::Histogram, + idx_in_req_data, + children, + }]) + } + DateHistogram(date_req) => { + let (accessor, field_type) = + get_ff_reader(reader, &date_req.field, Some(&[ColumnType::DateTime]))?; + // Convert to histogram request, normalize to ns precision + let mut histo_req = date_req.to_histogram_req()?; + histo_req.normalize_date_time(); + let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData { + accessor, + field_type, + column_block_accessor: Default::default(), + name: agg_name.to_string(), + sub_aggregation_blueprint: None, + req: histo_req, + is_date_histogram: true, + bounds: HistogramBounds { + min: f64::MIN, + max: f64::MAX, + }, + offset: 0.0, + }); + let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?; + Ok(vec![AggRefNode { + kind: AggKind::DateHistogram, + idx_in_req_data, + children, + }]) + } + Terms(terms_req) => build_terms_or_cardinality_nodes( + agg_name, + &terms_req.field, + &terms_req.missing, + reader, + segment_ordinal, + data, + &req.sub_aggregation, + TermsOrCardinalityRequest::Terms(terms_req.clone()), + ), + Cardinality(card_req) => build_terms_or_cardinality_nodes( + agg_name, + &card_req.field, + &card_req.missing, + reader, + segment_ordinal, + data, + &req.sub_aggregation, + TermsOrCardinalityRequest::Cardinality(card_req.clone()), + ), + Average(AverageAggregation { field, missing, .. }) + | Max(MaxAggregation { field, missing, .. }) + | Min(MinAggregation { field, missing, .. }) + | Stats(StatsAggregation { field, missing, .. }) + | ExtendedStats(ExtendedStatsAggregation { field, missing, .. }) + | Sum(SumAggregation { field, missing, .. }) + | Count(CountAggregation { field, missing, .. 
}) => { + let allowed_column_types = if matches!(&req.agg, Count(_)) { + Some( + &[ + ColumnType::I64, + ColumnType::U64, + ColumnType::F64, + ColumnType::Str, + ColumnType::DateTime, + ColumnType::Bool, + ColumnType::IpAddr, + ][..], + ) + } else { + Some(get_numeric_or_date_column_types()) + }; + let collecting_for = match &req.agg { + Average(_) => StatsType::Average, + Max(_) => StatsType::Max, + Min(_) => StatsType::Min, + Stats(_) => StatsType::Stats, + ExtendedStats(req) => StatsType::ExtendedStats(req.sigma), + Sum(_) => StatsType::Sum, + Count(_) => StatsType::Count, + _ => { + return Err(crate::TantivyError::InvalidArgument( + "Internal error: unexpected aggregation type in metric aggregation \ + handling." + .to_string(), + )) + } + }; + let (accessor, field_type) = get_ff_reader(reader, field, allowed_column_types)?; + let idx_in_req_data = data.push_metric_req_data(MetricAggReqData { + accessor, + field_type, + column_block_accessor: Default::default(), + name: agg_name.to_string(), + collecting_for, + missing: *missing, + missing_u64: (*missing).and_then(|m| f64_to_fastfield_u64(m, &field_type)), + is_number_or_date_type: matches!( + field_type, + ColumnType::I64 | ColumnType::U64 | ColumnType::F64 | ColumnType::DateTime + ), + }); + let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?; + Ok(vec![AggRefNode { + kind: AggKind::StatsKind(collecting_for), + idx_in_req_data, + children, + }]) + } + // Percentiles handled as Metric as well + AggregationVariants::Percentiles(percentiles_req) => { + percentiles_req.validate()?; + let (accessor, field_type) = get_ff_reader( + reader, + percentiles_req.field_name(), + Some(get_numeric_or_date_column_types()), + )?; + let idx_in_req_data = data.push_metric_req_data(MetricAggReqData { + accessor, + field_type, + column_block_accessor: Default::default(), + name: agg_name.to_string(), + collecting_for: StatsType::Percentiles, + missing: percentiles_req.missing, + missing_u64: 
percentiles_req + .missing + .and_then(|m| f64_to_fastfield_u64(m, &field_type)), + is_number_or_date_type: matches!( + field_type, + ColumnType::I64 | ColumnType::U64 | ColumnType::F64 | ColumnType::DateTime + ), + }); + let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?; + Ok(vec![AggRefNode { + kind: AggKind::StatsKind(StatsType::Percentiles), + idx_in_req_data, + children, + }]) + } + AggregationVariants::TopHits(top_hits_req) => { + let mut top_hits = top_hits_req.clone(); + top_hits.validate_and_resolve_field_names(reader.fast_fields().columnar())?; + let accessors: Vec<(Column, ColumnType)> = top_hits + .field_names() + .iter() + .map(|field| get_ff_reader(reader, field, Some(get_numeric_or_date_column_types()))) + .collect::>()?; + + let value_accessors = top_hits + .value_field_names() + .iter() + .map(|field_name| { + Ok(( + field_name.to_string(), + get_dynamic_columns(reader, field_name)?, + )) + }) + .collect::>()?; + + let idx_in_req_data = data.push_top_hits_req_data(TopHitsAggReqData { + accessors, + value_accessors, + segment_ordinal, + name: agg_name.to_string(), + req: top_hits.clone(), + }); + let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?; + Ok(vec![AggRefNode { + kind: AggKind::TopHits, + idx_in_req_data, + children, + }]) + } + } +} + +fn build_children( + aggs: &Aggregations, + reader: &SegmentReader, + segment_ordinal: SegmentOrdinal, + data: &mut AggregationsSegmentCtx, +) -> crate::Result> { + let mut children = Vec::new(); + for (name, agg) in aggs.iter() { + children.extend(build_nodes(name, agg, reader, segment_ordinal, data)?); + } + Ok(children) +} + +fn get_term_agg_accessors( + reader: &SegmentReader, + field_name: &str, + missing: &Option, +) -> crate::Result, ColumnType)>> { + let allowed_column_types = [ + ColumnType::I64, + ColumnType::U64, + ColumnType::F64, + ColumnType::Str, + ColumnType::DateTime, + ColumnType::Bool, + ColumnType::IpAddr, + ]; + + // In 
case the column is empty we want the shim column to match the missing type + let fallback_type = missing + .as_ref() + .map(|missing| match missing { + Key::Str(_) => ColumnType::Str, + Key::F64(_) => ColumnType::F64, + Key::I64(_) => ColumnType::I64, + Key::U64(_) => ColumnType::U64, + }) + .unwrap_or(ColumnType::U64); + + let column_and_types = get_all_ff_reader_or_empty( + reader, + field_name, + Some(&allowed_column_types), + fallback_type, + )?; + + Ok(column_and_types) +} + +enum TermsOrCardinalityRequest { + Terms(TermsAggregation), + Cardinality(CardinalityAggregationReq), +} +impl TermsOrCardinalityRequest { + fn as_terms(&self) -> Option<&TermsAggregation> { + match self { + TermsOrCardinalityRequest::Terms(t) => Some(t), + _ => None, + } + } +} + +#[allow(clippy::too_many_arguments)] +fn build_terms_or_cardinality_nodes( + agg_name: &str, + field_name: &str, + missing: &Option, + reader: &SegmentReader, + segment_ordinal: SegmentOrdinal, + data: &mut AggregationsSegmentCtx, + sub_aggs: &Aggregations, + req: TermsOrCardinalityRequest, +) -> crate::Result> { + let mut nodes = Vec::new(); + + let str_dict_column = reader.fast_fields().str(field_name)?; + + let column_and_types = get_term_agg_accessors(reader, field_name, missing)?; + + // Special handling when missing + multi column or incompatible type on text/date. + let missing_and_more_than_one_col = column_and_types.len() > 1 && missing.is_some(); + let text_on_non_text_col = column_and_types.len() == 1 + && column_and_types[0].1 != ColumnType::Str + && matches!(missing, Some(Key::Str(_))); + + let use_special_missing_agg = missing_and_more_than_one_col || text_on_non_text_col; + + // If special missing handling is required, build a MissingTerm node that carries all + // accessors (across any column types) for existence checks. 
+ if use_special_missing_agg { + let fallback_type = missing + .as_ref() + .map(|missing| match missing { + Key::Str(_) => ColumnType::Str, + Key::F64(_) => ColumnType::F64, + Key::I64(_) => ColumnType::I64, + Key::U64(_) => ColumnType::U64, + }) + .unwrap_or(ColumnType::U64); + let all_accessors = get_all_ff_reader_or_empty(reader, field_name, None, fallback_type)? + .into_iter() + .collect::>(); + // This case only happens when we have term aggregation, or we fail + let req = req.as_terms().cloned().ok_or_else(|| { + crate::TantivyError::InvalidArgument( + "Cardinality aggregation with missing on non-text/number field is not supported." + .to_string(), + ) + })?; + + let children = build_children(sub_aggs, reader, segment_ordinal, data)?; + let idx_in_req_data = data.push_missing_term_req_data(MissingTermAggReqData { + accessors: all_accessors, + name: agg_name.to_string(), + req, + }); + nodes.push(AggRefNode { + kind: AggKind::MissingTerm, + idx_in_req_data, + children, + }); + } + + // Add one node per accessor to mirror previous behavior and allow per-type missing handling. + for (accessor, column_type) in column_and_types { + let missing_value_for_accessor = if use_special_missing_agg { + None + } else if let Some(m) = missing.as_ref() { + get_missing_val_as_u64_lenient(column_type, m, field_name)? 
+ } else { + None + }; + + let children = build_children(sub_aggs, reader, segment_ordinal, data)?; + let (idx, kind) = match req { + TermsOrCardinalityRequest::Terms(ref req) => { + let idx_in_req_data = data.push_term_req_data(TermsAggReqData { + accessor, + column_type, + str_dict_column: str_dict_column.clone(), + missing_value_for_accessor, + column_block_accessor: Default::default(), + name: agg_name.to_string(), + field_type: column_type, + req: TermsAggregationInternal::from_req(req), + // Will be filled later when building collectors + sub_aggregation_blueprint: None, + sug_aggregations: sub_aggs.clone(), + }); + (idx_in_req_data, AggKind::Terms) + } + TermsOrCardinalityRequest::Cardinality(ref req) => { + let idx_in_req_data = data.push_cardinality_req_data(CardinalityAggReqData { + accessor, + column_type, + str_dict_column: str_dict_column.clone(), + missing_value_for_accessor, + column_block_accessor: Default::default(), + name: agg_name.to_string(), + req: req.clone(), + }); + (idx_in_req_data, AggKind::Cardinality) + } + }; + nodes.push(AggRefNode { + kind, + idx_in_req_data: idx, + children, + }); + } + + Ok(nodes) +} + +/// Convert the aggregation tree to something serializable and easy to read. 
+#[derive(Serialize, Debug, Clone, PartialEq, Eq)] +pub struct AggTreeViewNode { + pub name: String, + pub kind: String, + #[serde(skip_serializing_if = "Vec::is_empty", default)] + pub children: Vec, +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::aggregation::agg_req::Aggregations; + use crate::aggregation::tests::get_test_index_2_segments; + + fn agg_from_json(val: serde_json::Value) -> crate::aggregation::agg_req::Aggregation { + serde_json::from_value(val).unwrap() + } + + #[test] + fn test_tree_roots_and_expansion_terms_missing_on_numeric() -> crate::Result<()> { + let index = get_test_index_2_segments(true)?; + let reader = index.reader()?; + let searcher = reader.searcher(); + let seg_reader = searcher.segment_reader(0u32); + + // Build request with: + // 1) Terms on numeric field with missing as string => expands to MissingTerm + Terms + // 2) Avg metric + // 3) Terms on string with child histogram + let terms_score_missing = agg_from_json(json!({ + "terms": {"field": "score", "missing": "NA"} + })); + let avg_score = agg_from_json(json!({ + "avg": {"field": "score"} + })); + let terms_string_with_child = agg_from_json(json!({ + "terms": {"field": "string_id"}, + "aggs": { + "histo": {"histogram": {"field": "score", "interval": 10.0}} + } + })); + + let aggs: Aggregations = vec![ + ("t_score_missing_str".to_string(), terms_score_missing), + ("avg_score".to_string(), avg_score), + ("terms_string".to_string(), terms_string_with_child), + ] + .into_iter() + .collect(); + + let data = build_aggregations_data_from_req(&aggs, seg_reader, 0u32, Default::default())?; + let printed_nodes = data.per_request.get_view_tree(); + let printed = serde_json::to_value(&printed_nodes).unwrap(); + + let expected = json!([ + {"name": "avg_score", "kind": "Metric"}, + {"name": "t_score_missing_str", "kind": "MissingTerm"}, + {"name": "t_score_missing_str", "kind": "Terms"}, + {"name": "terms_string", "kind": "Terms", "children": [ + {"name": "histo", "kind": 
"Histogram"} + ]} + ]); + assert_eq!( + printed, + expected, + "tree json:\n{}", + serde_json::to_string_pretty(&printed).unwrap() + ); + + Ok(()) + } +} diff --git a/src/aggregation/agg_req.rs b/src/aggregation/agg_req.rs index 7d06fae2b..e2a85c8fb 100644 --- a/src/aggregation/agg_req.rs +++ b/src/aggregation/agg_req.rs @@ -208,13 +208,6 @@ impl AggregationVariants { _ => None, } } - pub(crate) fn as_top_hits(&self) -> Option<&TopHitsAggregationReq> { - match &self { - AggregationVariants::TopHits(top_hits) => Some(top_hits), - _ => None, - } - } - pub(crate) fn as_percentile(&self) -> Option<&PercentilesAggregationReq> { match &self { AggregationVariants::Percentiles(percentile_req) => Some(percentile_req), diff --git a/src/aggregation/agg_req_with_accessor.rs b/src/aggregation/agg_req_with_accessor.rs index 5b5bfb6d7..eb44a734b 100644 --- a/src/aggregation/agg_req_with_accessor.rs +++ b/src/aggregation/agg_req_with_accessor.rs @@ -1,354 +1,11 @@ //! This will enhance the request tree with access to the fastfield and metadata. 
-use std::collections::HashMap; use std::io; -use columnar::{Column, ColumnBlockAccessor, ColumnType, DynamicColumn, StrColumn}; +use columnar::{Column, ColumnType}; -use super::agg_req::{Aggregation, AggregationVariants, Aggregations}; -use super::bucket::{ - DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation, -}; -use super::metric::{ - AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation, - MaxAggregation, MinAggregation, StatsAggregation, SumAggregation, -}; -use super::segment_agg_result::AggregationLimitsGuard; -use super::VecWithNames; use crate::aggregation::{f64_to_fastfield_u64, Key}; use crate::index::SegmentReader; -use crate::SegmentOrdinal; - -#[derive(Default)] -pub(crate) struct AggregationsWithAccessor { - pub aggs: VecWithNames, -} - -impl AggregationsWithAccessor { - fn from_data(aggs: VecWithNames) -> Self { - Self { aggs } - } - - pub fn is_empty(&self) -> bool { - self.aggs.is_empty() - } -} - -pub struct AggregationWithAccessor { - pub(crate) segment_ordinal: SegmentOrdinal, - /// In general there can be buckets without fast field access, e.g. buckets that are created - /// based on search terms. That is not that case currently, but eventually this needs to be - /// Option or moved. - pub(crate) accessor: Column, - /// Load insert u64 for missing use case - pub(crate) missing_value_for_accessor: Option, - pub(crate) str_dict_column: Option, - pub(crate) field_type: ColumnType, - pub(crate) sub_aggregation: AggregationsWithAccessor, - pub(crate) limits: AggregationLimitsGuard, - pub(crate) column_block_accessor: ColumnBlockAccessor, - /// Used for missing term aggregation, which checks all columns for existence. - /// And also for `top_hits` aggregation, which may sort on multiple fields. - /// By convention the missing aggregation is chosen, when this property is set - /// (instead bein set in `agg`). 
- /// If this needs to used by other aggregations, we need to refactor this. - // NOTE: we can make all other aggregations use this instead of the `accessor` and `field_type` - // (making them obsolete) But will it have a performance impact? - pub(crate) accessors: Vec<(Column, ColumnType)>, - /// Map field names to all associated column accessors. - /// This field is used for `docvalue_fields`, which is currently only supported for `top_hits`. - pub(crate) value_accessors: HashMap>, - pub(crate) agg: Aggregation, -} - -impl AggregationWithAccessor { - /// May return multiple accessors if the aggregation is e.g. on mixed field types. - fn try_from_agg( - agg: &Aggregation, - sub_aggregation: &Aggregations, - reader: &SegmentReader, - segment_ordinal: SegmentOrdinal, - limits: AggregationLimitsGuard, - ) -> crate::Result> { - let mut agg = agg.clone(); - - let add_agg_with_accessor = |agg: &Aggregation, - accessor: Column, - column_type: ColumnType, - aggs: &mut Vec| - -> crate::Result<()> { - let res = AggregationWithAccessor { - segment_ordinal, - accessor, - accessors: Default::default(), - value_accessors: Default::default(), - field_type: column_type, - sub_aggregation: get_aggs_with_segment_accessor_and_validate( - sub_aggregation, - reader, - segment_ordinal, - &limits, - )?, - agg: agg.clone(), - limits: limits.clone(), - missing_value_for_accessor: None, - str_dict_column: None, - column_block_accessor: Default::default(), - }; - aggs.push(res); - Ok(()) - }; - - let add_agg_with_accessors = |agg: &Aggregation, - accessors: Vec<(Column, ColumnType)>, - aggs: &mut Vec, - value_accessors: HashMap>| - -> crate::Result<()> { - let (accessor, field_type) = accessors.first().expect("at least one accessor"); - let limits = limits.clone(); - let res = AggregationWithAccessor { - segment_ordinal, - // TODO: We should do away with the `accessor` field altogether - accessor: accessor.clone(), - value_accessors, - field_type: *field_type, - accessors, - 
sub_aggregation: get_aggs_with_segment_accessor_and_validate( - sub_aggregation, - reader, - segment_ordinal, - &limits, - )?, - agg: agg.clone(), - limits, - missing_value_for_accessor: None, - str_dict_column: None, - column_block_accessor: Default::default(), - }; - aggs.push(res); - Ok(()) - }; - - let mut res: Vec = Vec::new(); - use AggregationVariants::*; - - match agg.agg { - Range(RangeAggregation { - field: ref field_name, - .. - }) => { - let (accessor, column_type) = - get_ff_reader(reader, field_name, Some(get_numeric_or_date_column_types()))?; - add_agg_with_accessor(&agg, accessor, column_type, &mut res)?; - } - Histogram(HistogramAggregation { - field: ref field_name, - .. - }) => { - let (accessor, column_type) = - get_ff_reader(reader, field_name, Some(get_numeric_or_date_column_types()))?; - add_agg_with_accessor(&agg, accessor, column_type, &mut res)?; - } - DateHistogram(DateHistogramAggregationReq { - field: ref field_name, - .. - }) => { - let (accessor, column_type) = - // Only DateTime is supported for DateHistogram - get_ff_reader(reader, field_name, Some(&[ColumnType::DateTime]))?; - add_agg_with_accessor(&agg, accessor, column_type, &mut res)?; - } - Terms(TermsAggregation { - field: ref field_name, - ref missing, - .. - }) - | Cardinality(CardinalityAggregationReq { - field: ref field_name, - ref missing, - .. 
- }) => { - let str_dict_column = reader.fast_fields().str(field_name)?; - let allowed_column_types = [ - ColumnType::I64, - ColumnType::U64, - ColumnType::F64, - ColumnType::Str, - ColumnType::DateTime, - ColumnType::Bool, - ColumnType::IpAddr, - // ColumnType::Bytes Unsupported - ]; - - // In case the column is empty we want the shim column to match the missing type - let fallback_type = missing - .as_ref() - .map(|missing| match missing { - Key::Str(_) => ColumnType::Str, - Key::F64(_) => ColumnType::F64, - Key::I64(_) => ColumnType::I64, - Key::U64(_) => ColumnType::U64, - }) - .unwrap_or(ColumnType::U64); - let column_and_types = get_all_ff_reader_or_empty( - reader, - field_name, - Some(&allowed_column_types), - fallback_type, - )?; - let missing_and_more_than_one_col = column_and_types.len() > 1 && missing.is_some(); - let text_on_non_text_col = column_and_types.len() == 1 - && column_and_types[0].1.numerical_type().is_some() - && missing - .as_ref() - .map(|m| matches!(m, Key::Str(_))) - .unwrap_or(false); - - // Actually we could convert the text to a number and have the fast path, if it is - // provided in Rfc3339 format. But this use case is probably common - // enough to justify the effort. 
- let text_on_date_col = column_and_types.len() == 1 - && column_and_types[0].1 == ColumnType::DateTime - && missing - .as_ref() - .map(|m| matches!(m, Key::Str(_))) - .unwrap_or(false); - - let use_special_missing_agg = - missing_and_more_than_one_col || text_on_non_text_col || text_on_date_col; - if use_special_missing_agg { - let column_and_types = - get_all_ff_reader_or_empty(reader, field_name, None, fallback_type)?; - - let accessors = column_and_types - .iter() - .map(|c_t| (c_t.0.clone(), c_t.1)) - .collect(); - add_agg_with_accessors(&agg, accessors, &mut res, Default::default())?; - } - - for (accessor, column_type) in column_and_types { - let missing_value_term_agg = if use_special_missing_agg { - None - } else { - missing.clone() - }; - - let missing_value_for_accessor = - if let Some(missing) = missing_value_term_agg.as_ref() { - get_missing_val_as_u64_lenient( - column_type, - missing, - agg.agg.get_fast_field_names()[0], - )? - } else { - None - }; - - let limits = limits.clone(); - let agg = AggregationWithAccessor { - segment_ordinal, - missing_value_for_accessor, - accessor, - accessors: Default::default(), - value_accessors: Default::default(), - field_type: column_type, - sub_aggregation: get_aggs_with_segment_accessor_and_validate( - sub_aggregation, - reader, - segment_ordinal, - &limits, - )?, - agg: agg.clone(), - str_dict_column: str_dict_column.clone(), - limits, - column_block_accessor: Default::default(), - }; - res.push(agg); - } - } - Average(AverageAggregation { - field: ref field_name, - .. - }) - | Max(MaxAggregation { - field: ref field_name, - .. - }) - | Min(MinAggregation { - field: ref field_name, - .. - }) - | Stats(StatsAggregation { - field: ref field_name, - .. - }) - | ExtendedStats(ExtendedStatsAggregation { - field: ref field_name, - .. - }) - | Sum(SumAggregation { - field: ref field_name, - .. 
- }) => { - let (accessor, column_type) = - get_ff_reader(reader, field_name, Some(get_numeric_or_date_column_types()))?; - add_agg_with_accessor(&agg, accessor, column_type, &mut res)?; - } - Count(CountAggregation { - field: ref field_name, - .. - }) => { - let allowed_column_types = [ - ColumnType::I64, - ColumnType::U64, - ColumnType::F64, - ColumnType::Str, - ColumnType::DateTime, - ColumnType::Bool, - ColumnType::IpAddr, - // ColumnType::Bytes Unsupported - ]; - let (accessor, column_type) = - get_ff_reader(reader, field_name, Some(&allowed_column_types))?; - add_agg_with_accessor(&agg, accessor, column_type, &mut res)?; - } - Percentiles(ref percentiles) => { - let (accessor, column_type) = get_ff_reader( - reader, - percentiles.field_name(), - Some(get_numeric_or_date_column_types()), - )?; - add_agg_with_accessor(&agg, accessor, column_type, &mut res)?; - } - TopHits(ref mut top_hits) => { - top_hits.validate_and_resolve_field_names(reader.fast_fields().columnar())?; - let accessors: Vec<(Column, ColumnType)> = top_hits - .field_names() - .iter() - .map(|field| { - get_ff_reader(reader, field, Some(get_numeric_or_date_column_types())) - }) - .collect::>()?; - - let value_accessors = top_hits - .value_field_names() - .iter() - .map(|field_name| { - Ok(( - field_name.to_string(), - get_dynamic_columns(reader, field_name)?, - )) - }) - .collect::>()?; - - add_agg_with_accessors(&agg, accessors, &mut res, value_accessors)?; - } - }; - - Ok(res) - } -} /// Get the missing value as internal u64 representation /// @@ -357,7 +14,7 @@ impl AggregationWithAccessor { /// we would get from the fast field, when we open it as u64_lenient_for_type. /// /// That way we can use it the same way as if it would come from the fastfield. 
-fn get_missing_val_as_u64_lenient( +pub(crate) fn get_missing_val_as_u64_lenient( column_type: ColumnType, missing: &Key, field_name: &str, @@ -388,7 +45,7 @@ fn get_missing_val_as_u64_lenient( Ok(missing_val) } -fn get_numeric_or_date_column_types() -> &'static [ColumnType] { +pub(crate) fn get_numeric_or_date_column_types() -> &'static [ColumnType] { &[ ColumnType::F64, ColumnType::U64, @@ -397,32 +54,8 @@ fn get_numeric_or_date_column_types() -> &'static [ColumnType] { ] } -pub(crate) fn get_aggs_with_segment_accessor_and_validate( - aggs: &Aggregations, - reader: &SegmentReader, - segment_ordinal: SegmentOrdinal, - limits: &AggregationLimitsGuard, -) -> crate::Result { - let mut aggss = Vec::new(); - for (key, agg) in aggs.iter() { - let aggs = AggregationWithAccessor::try_from_agg( - agg, - agg.sub_aggregation(), - reader, - segment_ordinal, - limits.clone(), - )?; - for agg in aggs { - aggss.push((key.to_string(), agg)); - } - } - Ok(AggregationsWithAccessor::from_data( - VecWithNames::from_entries(aggss), - )) -} - /// Get fast field reader or empty as default. -fn get_ff_reader( +pub(crate) fn get_ff_reader( reader: &SegmentReader, field_name: &str, allowed_column_types: Option<&[ColumnType]>, @@ -439,7 +72,7 @@ fn get_ff_reader( Ok(ff_field_with_type) } -fn get_dynamic_columns( +pub(crate) fn get_dynamic_columns( reader: &SegmentReader, field_name: &str, ) -> crate::Result> { @@ -455,7 +88,7 @@ fn get_dynamic_columns( /// Get all fast field reader or empty as default. /// /// Is guaranteed to return at least one column. 
-fn get_all_ff_reader_or_empty( +pub(crate) fn get_all_ff_reader_or_empty( reader: &SegmentReader, field_name: &str, allowed_column_types: Option<&[ColumnType]>, diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs index 4fe720201..428f2a0a6 100644 --- a/src/aggregation/bucket/histogram/histogram.rs +++ b/src/aggregation/bucket/histogram/histogram.rs @@ -1,25 +1,48 @@ use std::cmp::Ordering; +use columnar::{Column, ColumnBlockAccessor, ColumnType}; use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; use tantivy_bitpacker::minmax; +use crate::aggregation::agg_data::{ + build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx, +}; use crate::aggregation::agg_limits::MemoryConsumption; use crate::aggregation::agg_req::Aggregations; -use crate::aggregation::agg_req_with_accessor::{ - AggregationWithAccessor, AggregationsWithAccessor, -}; use crate::aggregation::agg_result::BucketEntry; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry, }; -use crate::aggregation::segment_agg_result::{ - build_segment_agg_collector, SegmentAggregationCollector, -}; +use crate::aggregation::segment_agg_result::SegmentAggregationCollector; use crate::aggregation::*; use crate::TantivyError; +/// Contains all information required by the SegmentHistogramCollector to perform the +/// histogram or date_histogram aggregation on a segment. +pub struct HistogramAggReqData { + /// The column accessor to access the fast field values. + pub accessor: Column, + /// The field type of the fast field. + pub field_type: ColumnType, + /// The column block accessor to access the fast field values. + pub column_block_accessor: ColumnBlockAccessor, + /// The name of the aggregation. + pub name: String, + /// The sub aggregation blueprint, used to create sub aggregations for each bucket. 
+ /// Will be filled during initialization of the collector. + pub sub_aggregation_blueprint: Option>, + /// The histogram aggregation request. + pub req: HistogramAggregation, + /// True if this is a date_histogram aggregation. + pub is_date_histogram: bool, + /// The bounds to limit the buckets to. + pub bounds: HistogramBounds, + /// The offset used to calculate the bucket position. + pub offset: f64, +} + /// Histogram is a bucket aggregation, where buckets are created dynamically for given `interval`. /// Each document value is rounded down to its bucket. /// @@ -234,12 +257,12 @@ impl SegmentHistogramBucketEntry { pub(crate) fn into_intermediate_bucket_entry( self, sub_aggregation: Option>, - agg_with_accessor: &AggregationsWithAccessor, + agg_data: &AggregationsSegmentCtx, ) -> crate::Result { let mut sub_aggregation_res = IntermediateAggregationResults::default(); if let Some(sub_aggregation) = sub_aggregation { sub_aggregation - .add_intermediate_aggregation_result(agg_with_accessor, &mut sub_aggregation_res)?; + .add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?; } Ok(IntermediateHistogramBucketEntry { key: self.key, @@ -256,24 +279,20 @@ pub struct SegmentHistogramCollector { /// The buckets containing the aggregation data. 
buckets: FxHashMap, sub_aggregations: FxHashMap>, - sub_aggregation_blueprint: Option>, - column_type: ColumnType, - interval: f64, - offset: f64, - bounds: HistogramBounds, accessor_idx: usize, } impl SegmentAggregationCollector for SegmentHistogramCollector { fn add_intermediate_aggregation_result( self: Box, - agg_with_accessor: &AggregationsWithAccessor, + agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, ) -> crate::Result<()> { - let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string(); - let agg_with_accessor = &agg_with_accessor.aggs.values[self.accessor_idx]; - - let bucket = self.into_intermediate_bucket_result(agg_with_accessor)?; + let name = agg_data + .get_histogram_req_data(self.accessor_idx) + .name + .clone(); + let bucket = self.into_intermediate_bucket_result(agg_data)?; results.push(name, IntermediateAggregationResult::Bucket(bucket))?; Ok(()) @@ -283,69 +302,59 @@ impl SegmentAggregationCollector for SegmentHistogramCollector { fn collect( &mut self, doc: crate::DocId, - agg_with_accessor: &mut AggregationsWithAccessor, + agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - self.collect_block(&[doc], agg_with_accessor) + self.collect_block(&[doc], agg_data) } #[inline] fn collect_block( &mut self, docs: &[crate::DocId], - agg_with_accessor: &mut AggregationsWithAccessor, + agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx]; - + let mut req = agg_data.take_histogram_req_data(self.accessor_idx); let mem_pre = self.get_memory_consumption(); - let bounds = self.bounds; - let interval = self.interval; - let offset = self.offset; + let bounds = req.bounds; + let interval = req.req.interval; + let offset = req.offset; let get_bucket_pos = |val| get_bucket_pos_f64(val, interval, offset) as i64; - bucket_agg_accessor + req.column_block_accessor.fetch_block(docs, &req.accessor); + for (doc, val) in req 
.column_block_accessor - .fetch_block(docs, &bucket_agg_accessor.accessor); - - for (doc, val) in bucket_agg_accessor - .column_block_accessor - .iter_docid_vals(docs, &bucket_agg_accessor.accessor) + .iter_docid_vals(docs, &req.accessor) { - let val = self.f64_from_fastfield_u64(val); - + let val = f64_from_fastfield_u64(val, &req.field_type); let bucket_pos = get_bucket_pos(val); - if bounds.contains(val) { let bucket = self.buckets.entry(bucket_pos).or_insert_with(|| { let key = get_bucket_key_from_pos(bucket_pos as f64, interval, offset); SegmentHistogramBucketEntry { key, doc_count: 0 } }); bucket.doc_count += 1; - if let Some(sub_aggregation_blueprint) = self.sub_aggregation_blueprint.as_mut() { + if let Some(sub_aggregation_blueprint) = req.sub_aggregation_blueprint.as_ref() { self.sub_aggregations .entry(bucket_pos) .or_insert_with(|| sub_aggregation_blueprint.clone()) - .collect(doc, &mut bucket_agg_accessor.sub_aggregation)?; + .collect(doc, agg_data)?; } } } + agg_data.put_back_histogram_req_data(self.accessor_idx, req); let mem_delta = self.get_memory_consumption() - mem_pre; if mem_delta > 0 { - bucket_agg_accessor - .limits - .add_memory_consumed(mem_delta as u64)?; + agg_data.limits.add_memory_consumed(mem_delta as u64)?; } Ok(()) } - fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> { - let sub_aggregation_accessor = - &mut agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation; - + fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { for sub_aggregation in self.sub_aggregations.values_mut() { - sub_aggregation.flush(sub_aggregation_accessor)?; + sub_aggregation.flush(agg_data)?; } Ok(()) @@ -362,65 +371,58 @@ impl SegmentHistogramCollector { /// Converts the collector result into a intermediate bucket result. 
pub fn into_intermediate_bucket_result( self, - agg_with_accessor: &AggregationWithAccessor, + agg_data: &AggregationsSegmentCtx, ) -> crate::Result { let mut buckets = Vec::with_capacity(self.buckets.len()); for (bucket_pos, bucket) in self.buckets { let bucket_res = bucket.into_intermediate_bucket_entry( self.sub_aggregations.get(&bucket_pos).cloned(), - &agg_with_accessor.sub_aggregation, + agg_data, ); buckets.push(bucket_res?); } buckets.sort_unstable_by(|b1, b2| b1.key.total_cmp(&b2.key)); + let is_date_agg = agg_data + .get_histogram_req_data(self.accessor_idx) + .field_type + == ColumnType::DateTime; Ok(IntermediateBucketResult::Histogram { buckets, - is_date_agg: self.column_type == ColumnType::DateTime, + is_date_agg, }) } pub(crate) fn from_req_and_validate( - mut req: HistogramAggregation, - sub_aggregation: &mut AggregationsWithAccessor, - field_type: ColumnType, - accessor_idx: usize, + agg_data: &mut AggregationsSegmentCtx, + node: &AggRefNode, ) -> crate::Result { - req.validate()?; - if field_type == ColumnType::DateTime { - req.normalize_date_time(); - } - - let sub_aggregation_blueprint = if sub_aggregation.is_empty() { - None + let blueprint = if !node.children.is_empty() { + Some(build_segment_agg_collectors(agg_data, &node.children)?) 
} else { - let sub_aggregation = build_segment_agg_collector(sub_aggregation)?; - Some(sub_aggregation) + None }; - - let bounds = req.hard_bounds.unwrap_or(HistogramBounds { + let req_data = agg_data.get_histogram_req_data_mut(node.idx_in_req_data); + req_data.req.validate()?; + if req_data.field_type == ColumnType::DateTime && !req_data.is_date_histogram { + req_data.req.normalize_date_time(); + } + req_data.bounds = req_data.req.hard_bounds.unwrap_or(HistogramBounds { min: f64::MIN, max: f64::MAX, }); + req_data.offset = req_data.req.offset.unwrap_or(0.0); + + req_data.sub_aggregation_blueprint = blueprint; Ok(Self { buckets: Default::default(), - column_type: field_type, - interval: req.interval, - offset: req.offset.unwrap_or(0.0), - bounds, sub_aggregations: Default::default(), - sub_aggregation_blueprint, - accessor_idx, + accessor_idx: node.idx_in_req_data, }) } - - #[inline] - fn f64_from_fastfield_u64(&self, val: u64) -> f64 { - f64_from_fastfield_u64(val, &self.column_type) - } } #[inline] diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs index 96242cc15..a2b092257 100644 --- a/src/aggregation/bucket/range.rs +++ b/src/aggregation/bucket/range.rs @@ -1,20 +1,36 @@ use std::fmt::Debug; use std::ops::Range; +use columnar::{Column, ColumnBlockAccessor, ColumnType}; use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; -use crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor; +use crate::aggregation::agg_data::{ + build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx, +}; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, IntermediateRangeBucketEntry, IntermediateRangeBucketResult, }; -use crate::aggregation::segment_agg_result::{ - build_segment_agg_collector, SegmentAggregationCollector, -}; +use crate::aggregation::segment_agg_result::SegmentAggregationCollector; use crate::aggregation::*; use 
crate::TantivyError; +/// Contains all information required by the SegmentRangeCollector to perform the +/// range aggregation on a segment. +pub struct RangeAggReqData { + /// The column accessor to access the fast field values. + pub accessor: Column, + /// The type of the fast field. + pub field_type: ColumnType, + /// The column block accessor to access the fast field values. + pub column_block_accessor: ColumnBlockAccessor, + /// The range aggregation request. + pub req: RangeAggregation, + /// The name of the aggregation. + pub name: String, +} + /// Provide user-defined buckets to aggregate on. /// /// Two special buckets will automatically be created to cover the whole range of values. @@ -161,12 +177,12 @@ impl Debug for SegmentRangeBucketEntry { impl SegmentRangeBucketEntry { pub(crate) fn into_intermediate_bucket_entry( self, - agg_with_accessor: &AggregationsWithAccessor, + agg_data: &AggregationsSegmentCtx, ) -> crate::Result { let mut sub_aggregation_res = IntermediateAggregationResults::default(); if let Some(sub_aggregation) = self.sub_aggregation { sub_aggregation - .add_intermediate_aggregation_result(agg_with_accessor, &mut sub_aggregation_res)? + .add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)? 
} else { Default::default() }; @@ -184,12 +200,14 @@ impl SegmentRangeBucketEntry { impl SegmentAggregationCollector for SegmentRangeCollector { fn add_intermediate_aggregation_result( self: Box, - agg_with_accessor: &AggregationsWithAccessor, + agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, ) -> crate::Result<()> { let field_type = self.column_type; - let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string(); - let sub_agg = &agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation; + let name = agg_data + .get_range_req_data(self.accessor_idx) + .name + .to_string(); let buckets: FxHashMap = self .buckets @@ -199,7 +217,7 @@ impl SegmentAggregationCollector for SegmentRangeCollector { range_to_string(&range_bucket.range, &field_type)?, range_bucket .bucket - .into_intermediate_bucket_entry(sub_agg)?, + .into_intermediate_bucket_entry(agg_data)?, )) }) .collect::>()?; @@ -218,66 +236,70 @@ impl SegmentAggregationCollector for SegmentRangeCollector { fn collect( &mut self, doc: crate::DocId, - agg_with_accessor: &mut AggregationsWithAccessor, + agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - self.collect_block(&[doc], agg_with_accessor) + self.collect_block(&[doc], agg_data) } #[inline] fn collect_block( &mut self, docs: &[crate::DocId], - agg_with_accessor: &mut AggregationsWithAccessor, + agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx]; + // Take request data to avoid borrow conflicts during sub-aggregation + let mut req = agg_data.take_range_req_data(self.accessor_idx); - bucket_agg_accessor - .column_block_accessor - .fetch_block(docs, &bucket_agg_accessor.accessor); + req.column_block_accessor.fetch_block(docs, &req.accessor); - for (doc, val) in bucket_agg_accessor + for (doc, val) in req .column_block_accessor - .iter_docid_vals(docs, &bucket_agg_accessor.accessor) + .iter_docid_vals(docs, 
&req.accessor) { let bucket_pos = self.get_bucket_pos(val); - let bucket = &mut self.buckets[bucket_pos]; - bucket.bucket.doc_count += 1; - if let Some(sub_aggregation) = &mut bucket.bucket.sub_aggregation { - sub_aggregation.collect(doc, &mut bucket_agg_accessor.sub_aggregation)?; + if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() { + sub_agg.collect(doc, agg_data)?; } } + agg_data.put_back_range_req_data(self.accessor_idx, req); + Ok(()) } - fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> { - let sub_aggregation_accessor = - &mut agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation; - + fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { for bucket in self.buckets.iter_mut() { if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() { - sub_agg.flush(sub_aggregation_accessor)?; + sub_agg.flush(agg_data)?; } } - Ok(()) } } impl SegmentRangeCollector { pub(crate) fn from_req_and_validate( - req: &RangeAggregation, - sub_aggregation: &mut AggregationsWithAccessor, - limits: &mut AggregationLimitsGuard, - field_type: ColumnType, - accessor_idx: usize, + req_data: &mut AggregationsSegmentCtx, + node: &AggRefNode, ) -> crate::Result { + let accessor_idx = node.idx_in_req_data; + let (field_type, ranges) = { + let req_view = req_data.get_range_req_data(node.idx_in_req_data); + (req_view.field_type, req_view.req.ranges.clone()) + }; + // The range input on the request is f64. // We need to convert to u64 ranges, because we read the values as u64. // The mapping from the conversion is monotonic so ordering is preserved. - let buckets: Vec<_> = extend_validate_ranges(&req.ranges, &field_type)? + let sub_agg_prototype = if !node.children.is_empty() { + Some(build_segment_agg_collectors(req_data, &node.children)?) + } else { + None + }; + + let buckets: Vec<_> = extend_validate_ranges(&ranges, &field_type)? 
.iter() .map(|range| { let key = range @@ -295,11 +317,7 @@ impl SegmentRangeCollector { } else { Some(f64_from_fastfield_u64(range.range.start, &field_type)) }; - let sub_aggregation = if sub_aggregation.is_empty() { - None - } else { - Some(build_segment_agg_collector(sub_aggregation)?) - }; + let sub_aggregation = sub_agg_prototype.clone(); Ok(SegmentRangeAndBucketEntry { range: range.range.clone(), @@ -314,7 +332,7 @@ impl SegmentRangeCollector { }) .collect::>()?; - limits.add_memory_consumed( + req_data.limits.add_memory_consumed( buckets.len() as u64 * std::mem::size_of::() as u64, )?; @@ -467,15 +485,45 @@ mod tests { ranges, ..Default::default() }; + // Build buckets directly as in from_req_and_validate without AggregationsData + let buckets: Vec<_> = extend_validate_ranges(&req.ranges, &field_type) + .expect("unexpected error in extend_validate_ranges") + .iter() + .map(|range| { + let key = range + .key + .clone() + .map(|key| Ok(Key::Str(key))) + .unwrap_or_else(|| range_to_key(&range.range, &field_type)) + .expect("unexpected error in range_to_key"); + let to = if range.range.end == u64::MAX { + None + } else { + Some(f64_from_fastfield_u64(range.range.end, &field_type)) + }; + let from = if range.range.start == u64::MIN { + None + } else { + Some(f64_from_fastfield_u64(range.range.start, &field_type)) + }; + SegmentRangeAndBucketEntry { + range: range.range.clone(), + bucket: SegmentRangeBucketEntry { + doc_count: 0, + sub_aggregation: None, + key, + from, + to, + }, + } + }) + .collect(); - SegmentRangeCollector::from_req_and_validate( - &req, - &mut Default::default(), - &mut AggregationLimitsGuard::default(), - field_type, - 0, - ) - .expect("unexpected error") + SegmentRangeCollector { + buckets, + column_type: field_type, + accessor_idx: 0, + } } #[test] diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index 6a93952f1..2582f6a8f 100644 --- a/src/aggregation/bucket/term_agg.rs +++ 
b/src/aggregation/bucket/term_agg.rs @@ -4,27 +4,52 @@ use std::net::Ipv6Addr; use columnar::column_values::CompactSpaceU64Accessor; use columnar::{ - ColumnType, Dictionary, MonotonicallyMappableToU128, MonotonicallyMappableToU64, NumericalValue, + Column, ColumnBlockAccessor, ColumnType, Dictionary, MonotonicallyMappableToU128, + MonotonicallyMappableToU64, NumericalValue, StrColumn, }; use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; use super::{CustomOrder, Order, OrderTarget}; -use crate::aggregation::agg_limits::MemoryConsumption; -use crate::aggregation::agg_req_with_accessor::{ - AggregationWithAccessor, AggregationsWithAccessor, +use crate::aggregation::agg_data::{ + build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx, }; +use crate::aggregation::agg_limits::MemoryConsumption; +use crate::aggregation::agg_req::Aggregations; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult, }; -use crate::aggregation::segment_agg_result::{ - build_segment_agg_collector, SegmentAggregationCollector, -}; +use crate::aggregation::segment_agg_result::SegmentAggregationCollector; use crate::aggregation::{format_date, Key}; use crate::error::DataCorruption; use crate::TantivyError; +/// Contains all information required by the SegmentTermCollector to perform the +/// terms aggregation on a segment. +pub struct TermsAggReqData { + /// The column accessor to access the fast field values. + pub accessor: Column, + /// The type of the column. + pub column_type: ColumnType, + /// The string dictionary column if the field is of type text. + pub str_dict_column: Option, + /// The missing value as u64 value. + pub missing_value_for_accessor: Option, + /// The column block accessor to access the fast field values. + pub column_block_accessor: ColumnBlockAccessor, + /// The type of the fast field. 
+ pub field_type: ColumnType, + /// Note: sub_aggregation_blueprint is filled later when building collectors + pub sub_aggregation_blueprint: Option>, + /// Used to build the correct nested result when we have an empty result. + pub sug_aggregations: Aggregations, + /// The name of the aggregation. + pub name: String, + /// The normalized term aggregation request. + pub req: TermsAggregationInternal, +} + /// Creates a bucket for every unique term and counts the number of occurrences. /// Note that doc_count in the response buckets equals term count here. /// @@ -168,7 +193,7 @@ pub struct TermsAggregation { /// Same as TermsAggregation, but with populated defaults. #[derive(Clone, Debug, PartialEq)] -pub(crate) struct TermsAggregationInternal { +pub struct TermsAggregationInternal { /// The field to aggregate on. pub field: String, /// By default, the top 10 terms with the most documents are returned. @@ -193,7 +218,11 @@ pub(crate) struct TermsAggregationInternal { /// *Expensive*: When set to 0, this will return all terms in the field. pub min_doc_count: u64, + /// Set the order. `String` is here a target, which is either "_count", "_key", or the name of + /// a metric sub_aggregation. pub order: CustomOrder, + + /// The missing parameter defines how documents that are missing a value should be treated. pub missing: Option, } @@ -233,12 +262,9 @@ impl TermBuckets { sub_aggs_mem + buckets_mem } - fn force_flush( - &mut self, - agg_with_accessor: &mut AggregationsWithAccessor, - ) -> crate::Result<()> { + fn force_flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { for sub_aggregations in &mut self.sub_aggs.values_mut() { - sub_aggregations.as_mut().flush(agg_with_accessor)?; + sub_aggregations.as_mut().flush(agg_data)?; } Ok(()) } @@ -250,9 +276,6 @@ impl TermBuckets { pub struct SegmentTermCollector { /// The buckets containing the aggregation data. 
term_buckets: TermBuckets, - req: TermsAggregationInternal, - blueprint: Option>, - column_type: ColumnType, accessor_idx: usize, } @@ -264,13 +287,12 @@ pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) { impl SegmentAggregationCollector for SegmentTermCollector { fn add_intermediate_aggregation_result( self: Box, - agg_with_accessor: &AggregationsWithAccessor, + agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, ) -> crate::Result<()> { - let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string(); - let agg_with_accessor = &agg_with_accessor.aggs.values[self.accessor_idx]; + let name = agg_data.get_term_req_data(self.accessor_idx).name.clone(); - let bucket = self.into_intermediate_bucket_result(agg_with_accessor)?; + let bucket = self.into_intermediate_bucket_result(agg_data)?; results.push(name, IntermediateAggregationResult::Bucket(bucket))?; Ok(()) @@ -280,65 +302,63 @@ impl SegmentAggregationCollector for SegmentTermCollector { fn collect( &mut self, doc: crate::DocId, - agg_with_accessor: &mut AggregationsWithAccessor, + agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - self.collect_block(&[doc], agg_with_accessor) + self.collect_block(&[doc], agg_data) } #[inline] fn collect_block( &mut self, docs: &[crate::DocId], - agg_with_accessor: &mut AggregationsWithAccessor, + agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx]; + let mut req_data = agg_data.take_term_req_data(self.accessor_idx); let mem_pre = self.get_memory_consumption(); - if let Some(missing) = bucket_agg_accessor.missing_value_for_accessor { - bucket_agg_accessor - .column_block_accessor - .fetch_block_with_missing(docs, &bucket_agg_accessor.accessor, missing); + if let Some(missing) = req_data.missing_value_for_accessor { + req_data.column_block_accessor.fetch_block_with_missing( + docs, + &req_data.accessor, + missing, + ); } 
else { - bucket_agg_accessor + req_data .column_block_accessor - .fetch_block(docs, &bucket_agg_accessor.accessor); + .fetch_block(docs, &req_data.accessor); } - for term_id in bucket_agg_accessor.column_block_accessor.iter_vals() { + for term_id in req_data.column_block_accessor.iter_vals() { let entry = self.term_buckets.entries.entry(term_id).or_default(); *entry += 1; } // has subagg - if let Some(blueprint) = self.blueprint.as_ref() { - for (doc, term_id) in bucket_agg_accessor + if let Some(blueprint) = req_data.sub_aggregation_blueprint.as_ref() { + for (doc, term_id) in req_data .column_block_accessor - .iter_docid_vals(docs, &bucket_agg_accessor.accessor) + .iter_docid_vals(docs, &req_data.accessor) { let sub_aggregations = self .term_buckets .sub_aggs .entry(term_id) .or_insert_with(|| blueprint.clone()); - sub_aggregations.collect(doc, &mut bucket_agg_accessor.sub_aggregation)?; + sub_aggregations.collect(doc, agg_data)?; } } let mem_delta = self.get_memory_consumption() - mem_pre; if mem_delta > 0 { - bucket_agg_accessor - .limits - .add_memory_consumed(mem_delta as u64)?; + agg_data.limits.add_memory_consumed(mem_delta as u64)?; } + agg_data.put_back_term_req_data(self.accessor_idx, req_data); Ok(()) } - fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> { - let sub_aggregation_accessor = - &mut agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation; - - self.term_buckets.force_flush(sub_aggregation_accessor)?; + fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { + self.term_buckets.force_flush(agg_data)?; Ok(()) } } @@ -351,11 +371,12 @@ impl SegmentTermCollector { } pub(crate) fn from_req_and_validate( - req: &TermsAggregation, - sub_aggregations: &mut AggregationsWithAccessor, - field_type: ColumnType, - accessor_idx: usize, + req_data: &mut AggregationsSegmentCtx, + node: &AggRefNode, ) -> crate::Result { + let terms_req_data = 
req_data.get_term_req_data(node.idx_in_req_data); + let field_type = terms_req_data.field_type; + let accessor_idx = node.idx_in_req_data; if field_type == ColumnType::Bytes { return Err(TantivyError::InvalidArgument(format!( "terms aggregation is not supported for column type {field_type:?}" @@ -363,33 +384,31 @@ impl SegmentTermCollector { } let term_buckets = TermBuckets::default(); - if let Some(custom_order) = req.order.as_ref() { - // Validate sub aggregation exists - if let OrderTarget::SubAggregation(sub_agg_name) = &custom_order.target { - let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name); + // Validate sub aggregation exists + if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target { + let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name); - sub_aggregations.aggs.get(agg_name).ok_or_else(|| { + node.get_sub_agg(agg_name, &req_data.per_request) + .ok_or_else(|| { TantivyError::InvalidArgument(format!( "could not find aggregation with name {agg_name} in metric \ sub_aggregations" )) })?; - } } - let has_sub_aggregations = !sub_aggregations.is_empty(); + let has_sub_aggregations = !node.children.is_empty(); let blueprint = if has_sub_aggregations { - let sub_aggregation = build_segment_agg_collector(sub_aggregations)?; + let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?; Some(sub_aggregation) } else { None }; + let terms_req_data = req_data.get_term_req_data_mut(node.idx_in_req_data); + terms_req_data.sub_aggregation_blueprint = blueprint; Ok(SegmentTermCollector { - req: TermsAggregationInternal::from_req(req), term_buckets, - blueprint, - column_type: field_type, accessor_idx, }) } @@ -397,19 +416,20 @@ impl SegmentTermCollector { #[inline] pub(crate) fn into_intermediate_bucket_result( mut self, - agg_with_accessor: &AggregationWithAccessor, + agg_data: &AggregationsSegmentCtx, ) -> crate::Result { + let term_req = agg_data.get_term_req_data(self.accessor_idx); 
let mut entries: Vec<(u64, u32)> = self.term_buckets.entries.into_iter().collect(); let order_by_sub_aggregation = - matches!(self.req.order.target, OrderTarget::SubAggregation(_)); + matches!(term_req.req.order.target, OrderTarget::SubAggregation(_)); - match self.req.order.target { + match &term_req.req.order.target { OrderTarget::Key => { // We rely on the fact, that term ordinals match the order of the strings // TODO: We could have a special collector, that keeps only TOP n results at any // time. - if self.req.order.order == Order::Desc { + if term_req.req.order.order == Order::Desc { entries.sort_unstable_by_key(|bucket| std::cmp::Reverse(bucket.0)); } else { entries.sort_unstable_by_key(|bucket| bucket.0); @@ -421,7 +441,7 @@ impl SegmentTermCollector { // to check). } OrderTarget::Count => { - if self.req.order.order == Order::Desc { + if term_req.req.order.order == Order::Desc { entries.sort_unstable_by_key(|bucket| std::cmp::Reverse(bucket.1)); } else { entries.sort_unstable_by_key(|bucket| bucket.1); @@ -432,7 +452,7 @@ impl SegmentTermCollector { let (term_doc_count_before_cutoff, sum_other_doc_count) = if order_by_sub_aggregation { (0, 0) } else { - cut_off_buckets(&mut entries, self.req.segment_size as usize) + cut_off_buckets(&mut entries, term_req.req.segment_size as usize) }; let mut dict: FxHashMap = Default::default(); @@ -440,7 +460,7 @@ impl SegmentTermCollector { let mut into_intermediate_bucket_entry = |id, doc_count| -> crate::Result { - let intermediate_entry = if self.blueprint.as_ref().is_some() { + let intermediate_entry = if term_req.sub_aggregation_blueprint.as_ref().is_some() { let mut sub_aggregation_res = IntermediateAggregationResults::default(); self.term_buckets .sub_aggs @@ -448,10 +468,7 @@ impl SegmentTermCollector { .unwrap_or_else(|| { panic!("Internal Error: could not find subaggregation for id {id}") }) - .add_intermediate_aggregation_result( - &agg_with_accessor.sub_aggregation, - &mut sub_aggregation_res, - )?; + 
.add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?; IntermediateTermBucketEntry { doc_count, @@ -466,9 +483,9 @@ impl SegmentTermCollector { Ok(intermediate_entry) }; - if self.column_type == ColumnType::Str { + if term_req.column_type == ColumnType::Str { let fallback_dict = Dictionary::empty(); - let term_dict = agg_with_accessor + let term_dict = term_req .str_dict_column .as_ref() .map(|el| el.dictionary()) @@ -479,7 +496,7 @@ impl SegmentTermCollector { if let Some(index) = entries.iter().position(|value| value.0 == u64::MAX) { let entry = entries[index]; let intermediate_entry = into_intermediate_bucket_entry(entry.0, entry.1)?; - let missing_key = self + let missing_key = term_req .req .missing .as_ref() @@ -530,14 +547,13 @@ impl SegmentTermCollector { }, )?; - if self.req.min_doc_count == 0 { + if term_req.req.min_doc_count == 0 { // TODO: Handle rev streaming for descending sorting by keys let mut stream = term_dict.stream()?; - let empty_sub_aggregation = IntermediateAggregationResults::empty_from_req( - agg_with_accessor.agg.sub_aggregation(), - ); + let empty_sub_aggregation = + IntermediateAggregationResults::empty_from_req(&term_req.sug_aggregations); while let Some((key, _ord)) = stream.next() { - if dict.len() >= self.req.segment_size as usize { + if dict.len() >= term_req.req.segment_size as usize { break; } @@ -554,21 +570,21 @@ impl SegmentTermCollector { }); } } - } else if self.column_type == ColumnType::DateTime { + } else if term_req.column_type == ColumnType::DateTime { for (val, doc_count) in entries { let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?; let val = i64::from_u64(val); let date = format_date(val)?; dict.insert(IntermediateKey::Str(date), intermediate_entry); } - } else if self.column_type == ColumnType::Bool { + } else if term_req.column_type == ColumnType::Bool { for (val, doc_count) in entries { let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?; let val = 
bool::from_u64(val); dict.insert(IntermediateKey::Bool(val), intermediate_entry); } - } else if self.column_type == ColumnType::IpAddr { - let compact_space_accessor = agg_with_accessor + } else if term_req.column_type == ColumnType::IpAddr { + let compact_space_accessor = term_req .accessor .values .clone() @@ -591,9 +607,9 @@ impl SegmentTermCollector { } else { for (val, doc_count) in entries { let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?; - if self.column_type == ColumnType::U64 { + if term_req.column_type == ColumnType::U64 { dict.insert(IntermediateKey::U64(val), intermediate_entry); - } else if self.column_type == ColumnType::I64 { + } else if term_req.column_type == ColumnType::I64 { dict.insert(IntermediateKey::I64(i64::from_u64(val)), intermediate_entry); } else { let val = f64::from_u64(val); diff --git a/src/aggregation/bucket/term_missing_agg.rs b/src/aggregation/bucket/term_missing_agg.rs index df24eee12..6f99af968 100644 --- a/src/aggregation/bucket/term_missing_agg.rs +++ b/src/aggregation/bucket/term_missing_agg.rs @@ -1,13 +1,32 @@ +use columnar::{Column, ColumnType}; use rustc_hash::FxHashMap; -use crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor; +use crate::aggregation::agg_data::{ + build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx, +}; +use crate::aggregation::bucket::term_agg::TermsAggregation; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult, }; -use crate::aggregation::segment_agg_result::{ - build_segment_agg_collector, SegmentAggregationCollector, -}; +use crate::aggregation::segment_agg_result::SegmentAggregationCollector; + +/// Special aggregation to handle missing values for term aggregations. +/// This missing aggregation will check multiple columns for existence. 
+/// +/// This is needed when: +/// - The field is multi-valued and we therefore have multiple columns +/// - The field is not text and missing is provided as string (we cannot use the numeric missing +/// value optimization) +#[derive(Default)] +pub struct MissingTermAggReqData { + /// The accessors to check for existence of a value. + pub accessors: Vec<(Column, ColumnType)>, + /// The name of the aggregation. + pub name: String, + /// The original terms aggregation request. + pub req: TermsAggregation, +} /// The specialized missing term aggregation. #[derive(Default, Debug, Clone)] @@ -18,12 +37,13 @@ pub struct TermMissingAgg { } impl TermMissingAgg { pub(crate) fn new( - accessor_idx: usize, - sub_aggregations: &mut AggregationsWithAccessor, + req_data: &mut AggregationsSegmentCtx, + node: &AggRefNode, ) -> crate::Result { - let has_sub_aggregations = !sub_aggregations.is_empty(); + let has_sub_aggregations = !node.children.is_empty(); + let accessor_idx = node.idx_in_req_data; let sub_agg = if has_sub_aggregations { - let sub_aggregation = build_segment_agg_collector(sub_aggregations)?; + let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?; Some(sub_aggregation) } else { None @@ -40,16 +60,11 @@ impl TermMissingAgg { impl SegmentAggregationCollector for TermMissingAgg { fn add_intermediate_aggregation_result( self: Box, - agg_with_accessor: &AggregationsWithAccessor, + agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, ) -> crate::Result<()> { - let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string(); - let agg_with_accessor = &agg_with_accessor.aggs.values[self.accessor_idx]; - let term_agg = agg_with_accessor - .agg - .agg - .as_term() - .expect("TermMissingAgg collector must be term agg req"); + let req_data = agg_data.get_missing_term_req_data(self.accessor_idx); + let term_agg = &req_data.req; let missing = term_agg .missing .as_ref() @@ -64,10 +79,7 @@ impl 
SegmentAggregationCollector for TermMissingAgg { }; if let Some(sub_agg) = self.sub_agg { let mut res = IntermediateAggregationResults::default(); - sub_agg.add_intermediate_aggregation_result( - &agg_with_accessor.sub_aggregation, - &mut res, - )?; + sub_agg.add_intermediate_aggregation_result(agg_data, &mut res)?; missing_entry.sub_aggregation = res; } entries.insert(missing.into(), missing_entry); @@ -80,7 +92,10 @@ impl SegmentAggregationCollector for TermMissingAgg { }, }; - results.push(name, IntermediateAggregationResult::Bucket(bucket))?; + results.push( + req_data.name.to_string(), + IntermediateAggregationResult::Bucket(bucket), + )?; Ok(()) } @@ -88,17 +103,17 @@ impl SegmentAggregationCollector for TermMissingAgg { fn collect( &mut self, doc: crate::DocId, - agg_with_accessor: &mut AggregationsWithAccessor, + agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let agg = &mut agg_with_accessor.aggs.values[self.accessor_idx]; - let has_value = agg + let req_data = agg_data.get_missing_term_req_data(self.accessor_idx); + let has_value = req_data .accessors .iter() .any(|(acc, _)| acc.index.has_value(doc)); if !has_value { self.missing_count += 1; if let Some(sub_agg) = self.sub_agg.as_mut() { - sub_agg.collect(doc, &mut agg.sub_aggregation)?; + sub_agg.collect(doc, agg_data)?; } } Ok(()) @@ -107,10 +122,10 @@ impl SegmentAggregationCollector for TermMissingAgg { fn collect_block( &mut self, docs: &[crate::DocId], - agg_with_accessor: &mut AggregationsWithAccessor, + agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { for doc in docs { - self.collect(*doc, agg_with_accessor)?; + self.collect(*doc, agg_data)?; } Ok(()) } diff --git a/src/aggregation/buf_collector.rs b/src/aggregation/buf_collector.rs index 15be6281b..e34c84760 100644 --- a/src/aggregation/buf_collector.rs +++ b/src/aggregation/buf_collector.rs @@ -1,6 +1,6 @@ -use super::agg_req_with_accessor::AggregationsWithAccessor; use 
super::intermediate_agg_result::IntermediateAggregationResults; use super::segment_agg_result::SegmentAggregationCollector; +use crate::aggregation::agg_data::AggregationsSegmentCtx; use crate::DocId; pub(crate) const DOC_BLOCK_SIZE: usize = 64; @@ -37,23 +37,23 @@ impl SegmentAggregationCollector for BufAggregationCollector { #[inline] fn add_intermediate_aggregation_result( self: Box, - agg_with_accessor: &AggregationsWithAccessor, + agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, ) -> crate::Result<()> { - Box::new(self.collector).add_intermediate_aggregation_result(agg_with_accessor, results) + Box::new(self.collector).add_intermediate_aggregation_result(agg_data, results) } #[inline] fn collect( &mut self, doc: crate::DocId, - agg_with_accessor: &mut AggregationsWithAccessor, + agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { self.staged_docs[self.num_staged_docs] = doc; self.num_staged_docs += 1; if self.num_staged_docs == self.staged_docs.len() { self.collector - .collect_block(&self.staged_docs[..self.num_staged_docs], agg_with_accessor)?; + .collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?; self.num_staged_docs = 0; } Ok(()) @@ -63,20 +63,20 @@ impl SegmentAggregationCollector for BufAggregationCollector { fn collect_block( &mut self, docs: &[crate::DocId], - agg_with_accessor: &mut AggregationsWithAccessor, + agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - self.collector.collect_block(docs, agg_with_accessor)?; + self.collector.collect_block(docs, agg_data)?; Ok(()) } #[inline] - fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> { + fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { self.collector - .collect_block(&self.staged_docs[..self.num_staged_docs], agg_with_accessor)?; + .collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?; self.num_staged_docs = 0; - 
self.collector.flush(agg_with_accessor)?; + self.collector.flush(agg_data)?; Ok(()) } diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs index 3c5ad4eae..10e3ef526 100644 --- a/src/aggregation/collector.rs +++ b/src/aggregation/collector.rs @@ -1,12 +1,11 @@ use super::agg_req::Aggregations; -use super::agg_req_with_accessor::AggregationsWithAccessor; use super::agg_result::AggregationResults; use super::buf_collector::BufAggregationCollector; use super::intermediate_agg_result::IntermediateAggregationResults; -use super::segment_agg_result::{ - build_segment_agg_collector, AggregationLimitsGuard, SegmentAggregationCollector, +use super::segment_agg_result::{AggregationLimitsGuard, SegmentAggregationCollector}; +use crate::aggregation::agg_data::{ + build_aggregations_data_from_req, build_segment_agg_collectors_root, AggregationsSegmentCtx, }; -use crate::aggregation::agg_req_with_accessor::get_aggs_with_segment_accessor_and_validate; use crate::collector::{Collector, SegmentCollector}; use crate::index::SegmentReader; use crate::{DocId, SegmentOrdinal, TantivyError}; @@ -135,7 +134,7 @@ fn merge_fruits( /// `AggregationSegmentCollector` does the aggregation collection on a segment. 
pub struct AggregationSegmentCollector { - aggs_with_accessor: AggregationsWithAccessor, + aggs_with_accessor: AggregationsSegmentCtx, agg_collector: BufAggregationCollector, error: Option, } @@ -149,12 +148,13 @@ impl AggregationSegmentCollector { segment_ordinal: SegmentOrdinal, limits: &AggregationLimitsGuard, ) -> crate::Result { - let mut aggs_with_accessor = - get_aggs_with_segment_accessor_and_validate(agg, reader, segment_ordinal, limits)?; + let mut agg_data = + build_aggregations_data_from_req(agg, reader, segment_ordinal, limits.clone())?; let result = - BufAggregationCollector::new(build_segment_agg_collector(&mut aggs_with_accessor)?); + BufAggregationCollector::new(build_segment_agg_collectors_root(&mut agg_data)?); + Ok(AggregationSegmentCollector { - aggs_with_accessor, + aggs_with_accessor: agg_data, agg_collector: result, error: None, }) diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index 58940de3f..f0309eafe 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -246,6 +246,7 @@ pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult /// An aggregation is either a bucket or a metric. 
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[allow(clippy::large_enum_variant)] pub enum IntermediateAggregationResult { /// Bucket variant Bucket(IntermediateBucketResult), diff --git a/src/aggregation/metric/cardinality.rs b/src/aggregation/metric/cardinality.rs index 4f494b486..8331b3ab3 100644 --- a/src/aggregation/metric/cardinality.rs +++ b/src/aggregation/metric/cardinality.rs @@ -2,15 +2,13 @@ use std::collections::hash_map::DefaultHasher; use std::hash::{BuildHasher, Hasher}; use columnar::column_values::CompactSpaceU64Accessor; -use columnar::Dictionary; +use columnar::{Column, ColumnBlockAccessor, ColumnType, Dictionary, StrColumn}; use common::f64_to_u64; use hyperloglogplus::{HyperLogLog, HyperLogLogPlus}; use rustc_hash::FxHashSet; use serde::{Deserialize, Serialize}; -use crate::aggregation::agg_req_with_accessor::{ - AggregationWithAccessor, AggregationsWithAccessor, -}; +use crate::aggregation::agg_data::AggregationsSegmentCtx; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, }; @@ -97,6 +95,25 @@ pub struct CardinalityAggregationReq { pub missing: Option, } +/// Contains all information required by the SegmentCardinalityCollector to perform the +/// cardinality aggregation on a segment. +pub struct CardinalityAggReqData { + /// The column accessor to access the fast field values. + pub accessor: Column, + /// The column_type of the field. + pub column_type: ColumnType, + /// The string dictionary column if the field is of type string. + pub str_dict_column: Option, + /// The missing value normalized to the internal u64 representation of the field type. + pub missing_value_for_accessor: Option, + /// The column block accessor to access the fast field values. + pub(crate) column_block_accessor: ColumnBlockAccessor, + /// The name of the aggregation. + pub name: String, + /// The aggregation request. 
+ pub req: CardinalityAggregationReq, +} + impl CardinalityAggregationReq { /// Creates a new [`CardinalityAggregationReq`] instance from a field name. pub fn from_field_name(field_name: String) -> Self { @@ -115,47 +132,44 @@ impl CardinalityAggregationReq { pub(crate) struct SegmentCardinalityCollector { cardinality: CardinalityCollector, entries: FxHashSet, - column_type: ColumnType, accessor_idx: usize, - missing: Option, } impl SegmentCardinalityCollector { - pub fn from_req(column_type: ColumnType, accessor_idx: usize, missing: &Option) -> Self { + pub fn from_req(column_type: ColumnType, accessor_idx: usize) -> Self { Self { cardinality: CardinalityCollector::new(column_type as u8), entries: Default::default(), - column_type, accessor_idx, - missing: missing.clone(), } } fn fetch_block_with_field( &mut self, docs: &[crate::DocId], - agg_accessor: &mut AggregationWithAccessor, + agg_data: &mut CardinalityAggReqData, ) { - if let Some(missing) = agg_accessor.missing_value_for_accessor { - agg_accessor.column_block_accessor.fetch_block_with_missing( + if let Some(missing) = agg_data.missing_value_for_accessor { + agg_data.column_block_accessor.fetch_block_with_missing( docs, - &agg_accessor.accessor, + &agg_data.accessor, missing, ); } else { - agg_accessor + agg_data .column_block_accessor - .fetch_block(docs, &agg_accessor.accessor); + .fetch_block(docs, &agg_data.accessor); } } fn into_intermediate_metric_result( mut self, - agg_with_accessor: &AggregationWithAccessor, + agg_data: &AggregationsSegmentCtx, ) -> crate::Result { - if self.column_type == ColumnType::Str { + let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx); + if req_data.column_type == ColumnType::Str { let fallback_dict = Dictionary::empty(); - let dict = agg_with_accessor + let dict = req_data .str_dict_column .as_ref() .map(|el| el.dictionary()) @@ -180,10 +194,10 @@ impl SegmentCardinalityCollector { })?; if has_missing { // Replace missing with the actual value provided - 
let missing_key = self - .missing - .as_ref() - .expect("Found sentinel value u64::MAX for term_ord but `missing` is not set"); + let missing_key = + req_data.req.missing.as_ref().expect( + "Found sentinel value u64::MAX for term_ord but `missing` is not set", + ); match missing_key { Key::Str(missing) => { self.cardinality.sketch.insert_any(&missing); @@ -209,13 +223,13 @@ impl SegmentCardinalityCollector { impl SegmentAggregationCollector for SegmentCardinalityCollector { fn add_intermediate_aggregation_result( self: Box, - agg_with_accessor: &AggregationsWithAccessor, + agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, ) -> crate::Result<()> { - let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string(); - let agg_with_accessor = &agg_with_accessor.aggs.values[self.accessor_idx]; + let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx); + let name = req_data.name.to_string(); - let intermediate_result = self.into_intermediate_metric_result(agg_with_accessor)?; + let intermediate_result = self.into_intermediate_metric_result(agg_data)?; results.push( name, IntermediateAggregationResult::Metric(intermediate_result), @@ -227,26 +241,26 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector { fn collect( &mut self, doc: crate::DocId, - agg_with_accessor: &mut AggregationsWithAccessor, + agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - self.collect_block(&[doc], agg_with_accessor) + self.collect_block(&[doc], agg_data) } fn collect_block( &mut self, docs: &[crate::DocId], - agg_with_accessor: &mut AggregationsWithAccessor, + agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx]; - self.fetch_block_with_field(docs, bucket_agg_accessor); + let req_data = agg_data.get_cardinality_req_data_mut(self.accessor_idx); + self.fetch_block_with_field(docs, req_data); - let col_block_accessor = 
&bucket_agg_accessor.column_block_accessor; - if self.column_type == ColumnType::Str { + let col_block_accessor = &req_data.column_block_accessor; + if req_data.column_type == ColumnType::Str { for term_ord in col_block_accessor.iter_vals() { self.entries.insert(term_ord); } - } else if self.column_type == ColumnType::IpAddr { - let compact_space_accessor = bucket_agg_accessor + } else if req_data.column_type == ColumnType::IpAddr { + let compact_space_accessor = req_data .accessor .values .clone() diff --git a/src/aggregation/metric/extended_stats.rs b/src/aggregation/metric/extended_stats.rs index d3cd59d9a..0250118a2 100644 --- a/src/aggregation/metric/extended_stats.rs +++ b/src/aggregation/metric/extended_stats.rs @@ -4,12 +4,11 @@ use std::mem; use serde::{Deserialize, Serialize}; use super::*; -use crate::aggregation::agg_req_with_accessor::{ - AggregationWithAccessor, AggregationsWithAccessor, -}; +use crate::aggregation::agg_data::AggregationsSegmentCtx; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, }; +use crate::aggregation::metric::MetricAggReqData; use crate::aggregation::segment_agg_result::SegmentAggregationCollector; use crate::aggregation::*; use crate::{DocId, TantivyError}; @@ -348,20 +347,20 @@ impl SegmentExtendedStatsCollector { pub(crate) fn collect_block_with_field( &mut self, docs: &[DocId], - agg_accessor: &mut AggregationWithAccessor, + req_data: &mut MetricAggReqData, ) { if let Some(missing) = self.missing.as_ref() { - agg_accessor.column_block_accessor.fetch_block_with_missing( + req_data.column_block_accessor.fetch_block_with_missing( docs, - &agg_accessor.accessor, + &req_data.accessor, *missing, ); } else { - agg_accessor + req_data .column_block_accessor - .fetch_block(docs, &agg_accessor.accessor); + .fetch_block(docs, &req_data.accessor); } - for val in agg_accessor.column_block_accessor.iter_vals() { + for val in 
req_data.column_block_accessor.iter_vals() { let val1 = f64_from_fastfield_u64(val, &self.field_type); self.extended_stats.collect(val1); } @@ -372,10 +371,10 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector { #[inline] fn add_intermediate_aggregation_result( self: Box, - agg_with_accessor: &AggregationsWithAccessor, + agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, ) -> crate::Result<()> { - let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string(); + let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone(); results.push( name, IntermediateAggregationResult::Metric(IntermediateMetricResult::ExtendedStats( @@ -390,12 +389,12 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector { fn collect( &mut self, doc: crate::DocId, - agg_with_accessor: &mut AggregationsWithAccessor, + agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let field = &agg_with_accessor.aggs.values[self.accessor_idx].accessor; + let req_data = agg_data.get_metric_req_data(self.accessor_idx); if let Some(missing) = self.missing { let mut has_val = false; - for val in field.values_for_doc(doc) { + for val in req_data.accessor.values_for_doc(doc) { let val1 = f64_from_fastfield_u64(val, &self.field_type); self.extended_stats.collect(val1); has_val = true; @@ -405,7 +404,7 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector { .collect(f64_from_fastfield_u64(missing, &self.field_type)); } } else { - for val in field.values_for_doc(doc) { + for val in req_data.accessor.values_for_doc(doc) { let val1 = f64_from_fastfield_u64(val, &self.field_type); self.extended_stats.collect(val1); } @@ -418,10 +417,10 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector { fn collect_block( &mut self, docs: &[crate::DocId], - agg_with_accessor: &mut AggregationsWithAccessor, + agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let field = &mut 
agg_with_accessor.aggs.values[self.accessor_idx]; - self.collect_block_with_field(docs, field); + let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx); + self.collect_block_with_field(docs, req_data); Ok(()) } } diff --git a/src/aggregation/metric/mod.rs b/src/aggregation/metric/mod.rs index 9d470bc22..6342b2045 100644 --- a/src/aggregation/metric/mod.rs +++ b/src/aggregation/metric/mod.rs @@ -31,6 +31,7 @@ use std::collections::HashMap; pub use average::*; pub use cardinality::*; +use columnar::{Column, ColumnBlockAccessor, ColumnType}; pub use count::*; pub use extended_stats::*; pub use max::*; @@ -44,6 +45,28 @@ pub use top_hits::*; use crate::schema::OwnedValue; +/// Contains all information required by metric aggregations like avg, min, max, sum, stats, +/// extended_stats, count, percentiles. +#[repr(C)] +pub struct MetricAggReqData { + /// True if the field is of number or date type. + pub is_number_or_date_type: bool, + /// The type of the field. + pub field_type: ColumnType, + /// The missing value normalized to the internal u64 representation of the field type. + pub missing_u64: Option, + /// The column block accessor to access the fast field values. + pub column_block_accessor: ColumnBlockAccessor, + /// The column accessor to access the fast field values. + pub accessor: Column, + /// Used when converting to intermediate result + pub collecting_for: StatsType, + /// The missing value + pub missing: Option, + /// The name of the aggregation. + pub name: String, +} + /// Single-metric aggregations use this common result structure. /// /// Main reason to wrap it in value is to match elasticsearch output structure. 
diff --git a/src/aggregation/metric/percentiles.rs b/src/aggregation/metric/percentiles.rs index d34440e9f..c846e2187 100644 --- a/src/aggregation/metric/percentiles.rs +++ b/src/aggregation/metric/percentiles.rs @@ -3,12 +3,11 @@ use std::fmt::Debug; use serde::{Deserialize, Serialize}; use super::*; -use crate::aggregation::agg_req_with_accessor::{ - AggregationWithAccessor, AggregationsWithAccessor, -}; +use crate::aggregation::agg_data::AggregationsSegmentCtx; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, }; +use crate::aggregation::metric::MetricAggReqData; use crate::aggregation::segment_agg_result::SegmentAggregationCollector; use crate::aggregation::*; use crate::{DocId, TantivyError}; @@ -112,7 +111,8 @@ impl PercentilesAggregationReq { &self.field } - fn validate(&self) -> crate::Result<()> { + /// Validates the request parameters. + pub fn validate(&self) -> crate::Result<()> { if let Some(percents) = self.percents.as_ref() { let all_in_range = percents .iter() @@ -133,10 +133,8 @@ impl PercentilesAggregationReq { #[derive(Clone, Debug, PartialEq)] pub(crate) struct SegmentPercentilesCollector { - field_type: ColumnType, pub(crate) percentiles: PercentilesCollector, pub(crate) accessor_idx: usize, - missing: Option, } #[derive(Clone, Serialize, Deserialize)] @@ -231,43 +229,32 @@ impl PercentilesCollector { } impl SegmentPercentilesCollector { - pub fn from_req_and_validate( - req: &PercentilesAggregationReq, - field_type: ColumnType, - accessor_idx: usize, - ) -> crate::Result { - req.validate()?; - let missing = req - .missing - .and_then(|val| f64_to_fastfield_u64(val, &field_type)); - + pub fn from_req_and_validate(accessor_idx: usize) -> crate::Result { Ok(Self { - field_type, percentiles: PercentilesCollector::new(), accessor_idx, - missing, }) } #[inline] pub(crate) fn collect_block_with_field( &mut self, docs: &[DocId], - agg_accessor: &mut 
AggregationWithAccessor, + req_data: &mut MetricAggReqData, ) { - if let Some(missing) = self.missing.as_ref() { - agg_accessor.column_block_accessor.fetch_block_with_missing( + if let Some(missing) = req_data.missing_u64.as_ref() { + req_data.column_block_accessor.fetch_block_with_missing( docs, - &agg_accessor.accessor, + &req_data.accessor, *missing, ); } else { - agg_accessor + req_data .column_block_accessor - .fetch_block(docs, &agg_accessor.accessor); + .fetch_block(docs, &req_data.accessor); } - for val in agg_accessor.column_block_accessor.iter_vals() { - let val1 = f64_from_fastfield_u64(val, &self.field_type); + for val in req_data.column_block_accessor.iter_vals() { + let val1 = f64_from_fastfield_u64(val, &req_data.field_type); self.percentiles.collect(val1); } } @@ -277,10 +264,10 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector { #[inline] fn add_intermediate_aggregation_result( self: Box, - agg_with_accessor: &AggregationsWithAccessor, + agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, ) -> crate::Result<()> { - let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string(); + let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone(); let intermediate_metric_result = IntermediateMetricResult::Percentiles(self.percentiles); results.push( @@ -295,24 +282,24 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector { fn collect( &mut self, doc: crate::DocId, - agg_with_accessor: &mut AggregationsWithAccessor, + agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let field = &agg_with_accessor.aggs.values[self.accessor_idx].accessor; + let req_data = agg_data.get_metric_req_data(self.accessor_idx); - if let Some(missing) = self.missing { + if let Some(missing) = req_data.missing_u64 { let mut has_val = false; - for val in field.values_for_doc(doc) { - let val1 = f64_from_fastfield_u64(val, &self.field_type); + for val in req_data.accessor.values_for_doc(doc) 
{ + let val1 = f64_from_fastfield_u64(val, &req_data.field_type); self.percentiles.collect(val1); has_val = true; } if !has_val { self.percentiles - .collect(f64_from_fastfield_u64(missing, &self.field_type)); + .collect(f64_from_fastfield_u64(missing, &req_data.field_type)); } } else { - for val in field.values_for_doc(doc) { - let val1 = f64_from_fastfield_u64(val, &self.field_type); + for val in req_data.accessor.values_for_doc(doc) { + let val1 = f64_from_fastfield_u64(val, &req_data.field_type); self.percentiles.collect(val1); } } @@ -324,10 +311,10 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector { fn collect_block( &mut self, docs: &[crate::DocId], - agg_with_accessor: &mut AggregationsWithAccessor, + agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let field = &mut agg_with_accessor.aggs.values[self.accessor_idx]; - self.collect_block_with_field(docs, field); + let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx); + self.collect_block_with_field(docs, req_data); Ok(()) } } diff --git a/src/aggregation/metric/stats.rs b/src/aggregation/metric/stats.rs index 99d2ddceb..56715fdea 100644 --- a/src/aggregation/metric/stats.rs +++ b/src/aggregation/metric/stats.rs @@ -3,12 +3,11 @@ use std::fmt::Debug; use serde::{Deserialize, Serialize}; use super::*; -use crate::aggregation::agg_req_with_accessor::{ - AggregationWithAccessor, AggregationsWithAccessor, -}; +use crate::aggregation::agg_data::AggregationsSegmentCtx; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, }; +use crate::aggregation::metric::MetricAggReqData; use crate::aggregation::segment_agg_result::SegmentAggregationCollector; use crate::aggregation::*; use crate::{DocId, TantivyError}; @@ -166,74 +165,65 @@ impl IntermediateStats { } } -#[derive(Clone, Debug, PartialEq)] -pub(crate) enum SegmentStatsType { +/// The type of stats aggregation to perform. 
+/// Note that not all stats types are supported in the stats aggregation. +#[derive(Clone, Copy, Debug)] +pub enum StatsType { + /// The average of the values. Average, + /// The count of the values. Count, + /// The maximum value. Max, + /// The minimum value. Min, + /// The stats (count, sum, min, max, avg) of the values. Stats, + /// The extended stats (count, sum, min, max, avg, sum_of_squares, variance, std_deviation, + ExtendedStats(Option), // sigma + /// The sum of the values. Sum, + /// The percentiles of the values. + Percentiles, } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] pub(crate) struct SegmentStatsCollector { - missing: Option, - field_type: ColumnType, - pub(crate) collecting_for: SegmentStatsType, pub(crate) stats: IntermediateStats, pub(crate) accessor_idx: usize, - val_cache: Vec, } impl SegmentStatsCollector { - pub fn from_req( - field_type: ColumnType, - collecting_for: SegmentStatsType, - accessor_idx: usize, - missing: Option, - ) -> Self { - let missing = missing.and_then(|val| f64_to_fastfield_u64(val, &field_type)); + pub fn from_req(accessor_idx: usize) -> Self { Self { - field_type, - collecting_for, stats: IntermediateStats::default(), accessor_idx, - missing, - val_cache: Default::default(), } } #[inline] pub(crate) fn collect_block_with_field( &mut self, docs: &[DocId], - agg_accessor: &mut AggregationWithAccessor, + req_data: &mut MetricAggReqData, ) { - if let Some(missing) = self.missing.as_ref() { - agg_accessor.column_block_accessor.fetch_block_with_missing( + if let Some(missing) = req_data.missing_u64.as_ref() { + req_data.column_block_accessor.fetch_block_with_missing( docs, - &agg_accessor.accessor, + &req_data.accessor, *missing, ); } else { - agg_accessor + req_data .column_block_accessor - .fetch_block(docs, &agg_accessor.accessor); + .fetch_block(docs, &req_data.accessor); } - if [ - ColumnType::I64, - ColumnType::U64, - ColumnType::F64, - ColumnType::DateTime, - ] - .contains(&self.field_type) - { - 
for val in agg_accessor.column_block_accessor.iter_vals() { - let val1 = f64_from_fastfield_u64(val, &self.field_type); + if req_data.is_number_or_date_type { + for val in req_data.column_block_accessor.iter_vals() { + let val1 = f64_from_fastfield_u64(val, &req_data.field_type); self.stats.collect(val1); } } else { - for _val in agg_accessor.column_block_accessor.iter_vals() { + for _val in req_data.column_block_accessor.iter_vals() { // we ignore the value and simply record that we got something self.stats.collect(0.0); } @@ -245,27 +235,28 @@ impl SegmentAggregationCollector for SegmentStatsCollector { #[inline] fn add_intermediate_aggregation_result( self: Box, - agg_with_accessor: &AggregationsWithAccessor, + agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, ) -> crate::Result<()> { - let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string(); + let req = agg_data.get_metric_req_data(self.accessor_idx); + let name = req.name.clone(); - let intermediate_metric_result = match self.collecting_for { - SegmentStatsType::Average => { + let intermediate_metric_result = match req.collecting_for { + StatsType::Average => { IntermediateMetricResult::Average(IntermediateAverage::from_collector(*self)) } - SegmentStatsType::Count => { + StatsType::Count => { IntermediateMetricResult::Count(IntermediateCount::from_collector(*self)) } - SegmentStatsType::Max => { - IntermediateMetricResult::Max(IntermediateMax::from_collector(*self)) - } - SegmentStatsType::Min => { - IntermediateMetricResult::Min(IntermediateMin::from_collector(*self)) - } - SegmentStatsType::Stats => IntermediateMetricResult::Stats(self.stats), - SegmentStatsType::Sum => { - IntermediateMetricResult::Sum(IntermediateSum::from_collector(*self)) + StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_collector(*self)), + StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_collector(*self)), + StatsType::Stats => 
IntermediateMetricResult::Stats(self.stats), + StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_collector(*self)), + _ => { + return Err(TantivyError::InvalidArgument(format!( + "Unsupported stats type for stats aggregation: {:?}", + req.collecting_for + ))) } }; @@ -281,23 +272,23 @@ impl SegmentAggregationCollector for SegmentStatsCollector { fn collect( &mut self, doc: crate::DocId, - agg_with_accessor: &mut AggregationsWithAccessor, + agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let field = &agg_with_accessor.aggs.values[self.accessor_idx].accessor; - if let Some(missing) = self.missing { + let req_data = agg_data.get_metric_req_data(self.accessor_idx); + if let Some(missing) = req_data.missing_u64 { let mut has_val = false; - for val in field.values_for_doc(doc) { - let val1 = f64_from_fastfield_u64(val, &self.field_type); + for val in req_data.accessor.values_for_doc(doc) { + let val1 = f64_from_fastfield_u64(val, &req_data.field_type); self.stats.collect(val1); has_val = true; } if !has_val { self.stats - .collect(f64_from_fastfield_u64(missing, &self.field_type)); + .collect(f64_from_fastfield_u64(missing, &req_data.field_type)); } } else { - for val in field.values_for_doc(doc) { - let val1 = f64_from_fastfield_u64(val, &self.field_type); + for val in req_data.accessor.values_for_doc(doc) { + let val1 = f64_from_fastfield_u64(val, &req_data.field_type); self.stats.collect(val1); } } @@ -309,10 +300,10 @@ impl SegmentAggregationCollector for SegmentStatsCollector { fn collect_block( &mut self, docs: &[crate::DocId], - agg_with_accessor: &mut AggregationsWithAccessor, + agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let field = &mut agg_with_accessor.aggs.values[self.accessor_idx]; - self.collect_block_with_field(docs, field); + let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx); + self.collect_block_with_field(docs, req_data); Ok(()) } } diff --git 
a/src/aggregation/metric/top_hits.rs b/src/aggregation/metric/top_hits.rs index a75f7e3bd..e59f9b210 100644 --- a/src/aggregation/metric/top_hits.rs +++ b/src/aggregation/metric/top_hits.rs @@ -9,6 +9,7 @@ use serde::ser::SerializeMap; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use super::{TopHitsMetricResult, TopHitsVecEntry}; +use crate::aggregation::agg_data::AggregationsSegmentCtx; use crate::aggregation::bucket::Order; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateMetricResult, @@ -18,6 +19,23 @@ use crate::aggregation::AggregationError; use crate::collector::TopNComputer; use crate::schema::OwnedValue; use crate::{DocAddress, DocId, SegmentOrdinal}; +// duplicate import removed; already imported above + +/// Contains all information required by the TopHitsSegmentCollector to perform the +/// top_hits aggregation on a segment. +#[derive(Default)] +pub struct TopHitsAggReqData { + /// The accessors to access the fast field values. + pub accessors: Vec<(Column, ColumnType)>, + /// The accessors to access the fast field values for retrieving document fields. + pub value_accessors: HashMap>, + /// The ordinal of the segment this request data is for. + pub segment_ordinal: SegmentOrdinal, + /// The name of the aggregation. + pub name: String, + /// The top_hits aggregation request. 
+ pub req: TopHitsAggregationReq, +} /// # Top Hits /// @@ -566,23 +584,18 @@ impl TopHitsSegmentCollector { impl SegmentAggregationCollector for TopHitsSegmentCollector { fn add_intermediate_aggregation_result( self: Box, - agg_with_accessor: &crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor, + agg_data: &AggregationsSegmentCtx, results: &mut crate::aggregation::intermediate_agg_result::IntermediateAggregationResults, ) -> crate::Result<()> { - let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string(); + let req_data = agg_data.get_top_hits_req_data(self.accessor_idx); - let value_accessors = &agg_with_accessor.aggs.values[self.accessor_idx].value_accessors; - let tophits_req = &agg_with_accessor.aggs.values[self.accessor_idx] - .agg - .agg - .as_top_hits() - .expect("aggregation request must be of type top hits"); + let value_accessors = &req_data.value_accessors; let intermediate_result = IntermediateMetricResult::TopHits( - self.into_top_hits_collector(value_accessors, tophits_req), + self.into_top_hits_collector(value_accessors, &req_data.req), ); results.push( - name, + req_data.name.to_string(), IntermediateAggregationResult::Metric(intermediate_result), ) } @@ -591,32 +604,22 @@ impl SegmentAggregationCollector for TopHitsSegmentCollector { fn collect( &mut self, doc_id: crate::DocId, - agg_with_accessor: &mut crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor, + agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let tophits_req = &agg_with_accessor.aggs.values[self.accessor_idx] - .agg - .agg - .as_top_hits() - .expect("aggregation request must be of type top hits"); - let accessors = &agg_with_accessor.aggs.values[self.accessor_idx].accessors; - self.collect_with(doc_id, tophits_req, accessors)?; + let req_data = agg_data.get_top_hits_req_data(self.accessor_idx); + self.collect_with(doc_id, &req_data.req, &req_data.accessors)?; Ok(()) } fn collect_block( &mut self, docs: &[crate::DocId], - 
agg_with_accessor: &mut crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor, + agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - let tophits_req = &agg_with_accessor.aggs.values[self.accessor_idx] - .agg - .agg - .as_top_hits() - .expect("aggregation request must be of type top hits"); - let accessors = &agg_with_accessor.aggs.values[self.accessor_idx].accessors; + let req_data = agg_data.get_top_hits_req_data(self.accessor_idx); // TODO: Consider getting fields with the column block accessor. for doc in docs { - self.collect_with(*doc, tophits_req, accessors)?; + self.collect_with(*doc, &req_data.req, &req_data.accessors)?; } Ok(()) } diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index 173ada3be..2ff47ea84 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -127,6 +127,7 @@ //! [`AggregationResults`](agg_result::AggregationResults) via the //! [`into_final_result`](intermediate_agg_result::IntermediateAggregationResults::into_final_result) method. +mod agg_data; mod agg_limits; pub mod agg_req; mod agg_req_with_accessor; @@ -140,7 +141,6 @@ pub mod intermediate_agg_result; pub mod metric; mod segment_agg_result; -use std::collections::HashMap; use std::fmt::Display; #[cfg(test)] @@ -257,80 +257,6 @@ where D: Deserializer<'de> { deserializer.deserialize_any(StringOrFloatVisitor) } -/// Represents an associative array `(key => values)` in a very efficient manner. 
-#[derive(PartialEq, Serialize, Deserialize)] -pub(crate) struct VecWithNames { - pub(crate) values: Vec, - keys: Vec, -} - -impl Clone for VecWithNames { - fn clone(&self) -> Self { - Self { - values: self.values.clone(), - keys: self.keys.clone(), - } - } -} - -impl Default for VecWithNames { - fn default() -> Self { - Self { - values: Default::default(), - keys: Default::default(), - } - } -} - -impl std::fmt::Debug for VecWithNames { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_map().entries(self.iter()).finish() - } -} - -impl From> for VecWithNames { - fn from(map: HashMap) -> Self { - VecWithNames::from_entries(map.into_iter().collect_vec()) - } -} - -impl VecWithNames { - fn from_entries(mut entries: Vec<(String, T)>) -> Self { - // Sort to ensure order of elements match across multiple instances - entries.sort_by(|left, right| left.0.cmp(&right.0)); - let mut data = Vec::with_capacity(entries.len()); - let mut data_names = Vec::with_capacity(entries.len()); - for entry in entries { - data_names.push(entry.0); - data.push(entry.1); - } - VecWithNames { - values: data, - keys: data_names, - } - } - fn iter(&self) -> impl Iterator + '_ { - self.keys().zip(self.values.iter()) - } - fn keys(&self) -> impl Iterator + '_ { - self.keys.iter().map(|key| key.as_str()) - } - fn values_mut(&mut self) -> impl Iterator + '_ { - self.values.iter_mut() - } - fn is_empty(&self) -> bool { - self.keys.is_empty() - } - fn len(&self) -> usize { - self.keys.len() - } - fn get(&self, name: &str) -> Option<&T> { - self.keys() - .position(|key| key == name) - .map(|pos| &self.values[pos]) - } -} - /// The serialized key is used in a `HashMap`. 
pub type SerializedKey = String; diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index 747543e23..5cc2650b6 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -6,48 +6,38 @@ use std::fmt::Debug; pub(crate) use super::agg_limits::AggregationLimitsGuard; -use super::agg_req::AggregationVariants; -use super::agg_req_with_accessor::{AggregationWithAccessor, AggregationsWithAccessor}; -use super::bucket::{SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector}; use super::intermediate_agg_result::IntermediateAggregationResults; -use super::metric::{ - AverageAggregation, CountAggregation, ExtendedStatsAggregation, MaxAggregation, MinAggregation, - SegmentPercentilesCollector, SegmentStatsCollector, SegmentStatsType, StatsAggregation, - SumAggregation, -}; -use crate::aggregation::bucket::TermMissingAgg; -use crate::aggregation::metric::{ - CardinalityAggregationReq, SegmentCardinalityCollector, SegmentExtendedStatsCollector, - TopHitsSegmentCollector, -}; +use crate::aggregation::agg_data::AggregationsSegmentCtx; -pub(crate) trait SegmentAggregationCollector: CollectorClone + Debug { +/// A SegmentAggregationCollector is used to collect aggregation results. +pub trait SegmentAggregationCollector: CollectorClone + Debug { fn add_intermediate_aggregation_result( self: Box, - agg_with_accessor: &AggregationsWithAccessor, + agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, ) -> crate::Result<()>; fn collect( &mut self, doc: crate::DocId, - agg_with_accessor: &mut AggregationsWithAccessor, + agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()>; fn collect_block( &mut self, docs: &[crate::DocId], - agg_with_accessor: &mut AggregationsWithAccessor, + agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()>; /// Finalize method. Some Aggregator collect blocks of docs before calling `collect_block`. 
/// This method ensures those staged docs will be collected. - fn flush(&mut self, _agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> { + fn flush(&mut self, _agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { Ok(()) } } -pub(crate) trait CollectorClone { +/// A helper trait to enable cloning of Box +pub trait CollectorClone { fn clone_box(&self) -> Box; } @@ -65,119 +55,6 @@ impl Clone for Box { } } -pub(crate) fn build_segment_agg_collector( - req: &mut AggregationsWithAccessor, -) -> crate::Result> { - // Single collector special case - if req.aggs.len() == 1 { - let req = &mut req.aggs.values[0]; - let accessor_idx = 0; - return build_single_agg_segment_collector(req, accessor_idx); - } - - let agg = GenericSegmentAggregationResultsCollector::from_req_and_validate(req)?; - Ok(Box::new(agg)) -} - -pub(crate) fn build_single_agg_segment_collector( - req: &mut AggregationWithAccessor, - accessor_idx: usize, -) -> crate::Result> { - use AggregationVariants::*; - match &req.agg.agg { - Terms(terms_req) => { - if req.accessors.is_empty() { - Ok(Box::new(SegmentTermCollector::from_req_and_validate( - terms_req, - &mut req.sub_aggregation, - req.field_type, - accessor_idx, - )?)) - } else { - Ok(Box::new(TermMissingAgg::new( - accessor_idx, - &mut req.sub_aggregation, - )?)) - } - } - Range(range_req) => Ok(Box::new(SegmentRangeCollector::from_req_and_validate( - range_req, - &mut req.sub_aggregation, - &mut req.limits, - req.field_type, - accessor_idx, - )?)), - Histogram(histogram) => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate( - histogram.clone(), - &mut req.sub_aggregation, - req.field_type, - accessor_idx, - )?)), - DateHistogram(histogram) => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate( - histogram.to_histogram_req()?, - &mut req.sub_aggregation, - req.field_type, - accessor_idx, - )?)), - Average(AverageAggregation { missing, .. 
}) => { - Ok(Box::new(SegmentStatsCollector::from_req( - req.field_type, - SegmentStatsType::Average, - accessor_idx, - *missing, - ))) - } - Count(CountAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req( - req.field_type, - SegmentStatsType::Count, - accessor_idx, - *missing, - ))), - Max(MaxAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req( - req.field_type, - SegmentStatsType::Max, - accessor_idx, - *missing, - ))), - Min(MinAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req( - req.field_type, - SegmentStatsType::Min, - accessor_idx, - *missing, - ))), - Stats(StatsAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req( - req.field_type, - SegmentStatsType::Stats, - accessor_idx, - *missing, - ))), - ExtendedStats(ExtendedStatsAggregation { missing, sigma, .. }) => Ok(Box::new( - SegmentExtendedStatsCollector::from_req(req.field_type, *sigma, accessor_idx, *missing), - )), - Sum(SumAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req( - req.field_type, - SegmentStatsType::Sum, - accessor_idx, - *missing, - ))), - Percentiles(percentiles_req) => Ok(Box::new( - SegmentPercentilesCollector::from_req_and_validate( - percentiles_req, - req.field_type, - accessor_idx, - )?, - )), - TopHits(top_hits_req) => Ok(Box::new(TopHitsSegmentCollector::from_req( - top_hits_req, - accessor_idx, - req.segment_ordinal, - ))), - Cardinality(CardinalityAggregationReq { missing, .. }) => Ok(Box::new( - SegmentCardinalityCollector::from_req(req.field_type, accessor_idx, missing), - )), - } -} - #[derive(Clone, Default)] /// The GenericSegmentAggregationResultsCollector is the generic version of the collector, which /// can handle arbitrary complexity of sub-aggregations. 
Ideally we never have to pick this one @@ -197,11 +74,11 @@ impl Debug for GenericSegmentAggregationResultsCollector { impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector { fn add_intermediate_aggregation_result( self: Box, - agg_with_accessor: &AggregationsWithAccessor, + agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, ) -> crate::Result<()> { for agg in self.aggs { - agg.add_intermediate_aggregation_result(agg_with_accessor, results)?; + agg.add_intermediate_aggregation_result(agg_data, results)?; } Ok(()) @@ -210,9 +87,9 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector { fn collect( &mut self, doc: crate::DocId, - agg_with_accessor: &mut AggregationsWithAccessor, + agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { - self.collect_block(&[doc], agg_with_accessor)?; + self.collect_block(&[doc], agg_data)?; Ok(()) } @@ -220,32 +97,19 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector { fn collect_block( &mut self, docs: &[crate::DocId], - agg_with_accessor: &mut AggregationsWithAccessor, + agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { for collector in &mut self.aggs { - collector.collect_block(docs, agg_with_accessor)?; + collector.collect_block(docs, agg_data)?; } Ok(()) } - fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> { + fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { for collector in &mut self.aggs { - collector.flush(agg_with_accessor)?; + collector.flush(agg_data)?; } Ok(()) } } - -impl GenericSegmentAggregationResultsCollector { - pub(crate) fn from_req_and_validate(req: &mut AggregationsWithAccessor) -> crate::Result { - let aggs = req - .aggs - .values_mut() - .enumerate() - .map(|(accessor_idx, req)| build_single_agg_segment_collector(req, accessor_idx)) - .collect::>>>()?; - - Ok(GenericSegmentAggregationResultsCollector { 
aggs }) - } -}