From b2573a3b16aec4f785736390effa4551bbefb65d Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Wed, 19 Nov 2025 18:41:10 +0100 Subject: [PATCH] low cardinality optimisation --- src/aggregation/agg_data.rs | 12 +- src/aggregation/agg_limits.rs | 2 +- .../bucket/term_agg/default_impl.rs | 196 +++++ .../bucket/term_agg/low_cardinality_impl.rs | 228 ++++++ .../bucket/{term_agg.rs => term_agg/mod.rs} | 674 +++++++----------- src/aggregation/segment_agg_result.rs | 5 +- 6 files changed, 697 insertions(+), 420 deletions(-) create mode 100644 src/aggregation/bucket/term_agg/default_impl.rs create mode 100644 src/aggregation/bucket/term_agg/low_cardinality_impl.rs rename src/aggregation/bucket/{term_agg.rs => term_agg/mod.rs} (82%) diff --git a/src/aggregation/agg_data.rs b/src/aggregation/agg_data.rs index 3b29830a7..3ed0167b1 100644 --- a/src/aggregation/agg_data.rs +++ b/src/aggregation/agg_data.rs @@ -10,10 +10,10 @@ use crate::aggregation::accessor_helpers::{ }; use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations}; use crate::aggregation::bucket::{ - FilterAggReqData, HistogramAggReqData, HistogramBounds, IncludeExcludeParam, - MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector, SegmentHistogramCollector, - SegmentRangeCollector, SegmentTermCollector, TermMissingAgg, TermsAggReqData, TermsAggregation, - TermsAggregationInternal, + build_segment_aggregation_collector, FilterAggReqData, HistogramAggReqData, HistogramBounds, + IncludeExcludeParam, MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector, + SegmentHistogramCollector, SegmentRangeCollector, TermMissingAgg, TermsAggReqData, + TermsAggregation, TermsAggregationInternal, }; use crate::aggregation::metric::{ AverageAggregation, CardinalityAggReqData, CardinalityAggregationReq, CountAggregation, @@ -373,9 +373,7 @@ pub(crate) fn build_segment_agg_collector( node: &AggRefNode, ) -> crate::Result> { match node.kind { - AggKind::Terms => 
Ok(Box::new(SegmentTermCollector::from_req_and_validate( - req, node, - )?)), + AggKind::Terms => build_segment_aggregation_collector(req, node), AggKind::MissingTerm => { let req_data = &mut req.per_request.missing_term_req_data[node.idx_in_req_data]; if req_data.accessors.is_empty() { diff --git a/src/aggregation/agg_limits.rs b/src/aggregation/agg_limits.rs index 76dfbca9d..3e7eee151 100644 --- a/src/aggregation/agg_limits.rs +++ b/src/aggregation/agg_limits.rs @@ -70,7 +70,7 @@ impl AggregationLimitsGuard { /// *memory_limit* /// memory_limit is defined in bytes. /// Aggregation fails when the estimated memory consumption of the aggregation is higher than - /// memory_limit. + /// memory_limit. /// memory_limit will default to `DEFAULT_MEMORY_LIMIT` (500MB) /// /// *bucket_limit* diff --git a/src/aggregation/bucket/term_agg/default_impl.rs b/src/aggregation/bucket/term_agg/default_impl.rs new file mode 100644 index 000000000..22f00426f --- /dev/null +++ b/src/aggregation/bucket/term_agg/default_impl.rs @@ -0,0 +1,196 @@ +use std::fmt::Debug; + +use columnar::ColumnType; +use rustc_hash::FxHashMap; + +use super::OrderTarget; +use crate::aggregation::agg_data::{ + build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx, +}; +use crate::aggregation::agg_limits::MemoryConsumption; +use crate::aggregation::bucket::get_agg_name_and_property; +use crate::aggregation::intermediate_agg_result::{ + IntermediateAggregationResult, IntermediateAggregationResults, +}; +use crate::aggregation::segment_agg_result::SegmentAggregationCollector; +use crate::TantivyError; + +#[derive(Clone, Debug, Default)] +/// Container to store term_ids/or u64 values and their buckets. 
+struct TermBuckets { + pub(crate) entries: FxHashMap, + pub(crate) sub_aggs: FxHashMap>, +} + +impl TermBuckets { + fn get_memory_consumption(&self) -> usize { + let sub_aggs_mem = self.sub_aggs.memory_consumption(); + let buckets_mem = self.entries.memory_consumption(); + sub_aggs_mem + buckets_mem + } + + fn force_flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { + for sub_aggregations in &mut self.sub_aggs.values_mut() { + sub_aggregations.as_mut().flush(agg_data)?; + } + Ok(()) + } +} + +/// The collector puts values from the fast field into the correct buckets and does a conversion to +/// the correct datatype. +#[derive(Clone, Debug)] +pub struct SegmentTermCollector { + /// The buckets containing the aggregation data. + term_buckets: TermBuckets, + accessor_idx: usize, +} + +impl SegmentAggregationCollector for SegmentTermCollector { + fn add_intermediate_aggregation_result( + self: Box, + agg_data: &AggregationsSegmentCtx, + results: &mut IntermediateAggregationResults, + ) -> crate::Result<()> { + let name = agg_data.get_term_req_data(self.accessor_idx).name.clone(); + + let entries: Vec<(u64, u32)> = self.term_buckets.entries.into_iter().collect(); + let bucket = super::into_intermediate_bucket_result( + self.accessor_idx, + entries, + self.term_buckets.sub_aggs, + agg_data, + )?; + results.push(name, IntermediateAggregationResult::Bucket(bucket))?; + + Ok(()) + } + + #[inline] + fn collect( + &mut self, + doc: crate::DocId, + agg_data: &mut AggregationsSegmentCtx, + ) -> crate::Result<()> { + self.collect_block(&[doc], agg_data) + } + + #[inline] + fn collect_block( + &mut self, + docs: &[crate::DocId], + agg_data: &mut AggregationsSegmentCtx, + ) -> crate::Result<()> { + let mut req_data = agg_data.take_term_req_data(self.accessor_idx); + + let mem_pre = self.get_memory_consumption(); + + if let Some(missing) = req_data.missing_value_for_accessor { + req_data.column_block_accessor.fetch_block_with_missing( + docs, + 
&req_data.accessor, + missing, + ); + } else { + req_data + .column_block_accessor + .fetch_block(docs, &req_data.accessor); + } + + for term_id in req_data.column_block_accessor.iter_vals() { + if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() { + if !allowed_bs.contains(term_id as u32) { + continue; + } + } + let entry = self.term_buckets.entries.entry(term_id).or_default(); + *entry += 1; + } + // has subagg + if let Some(blueprint) = req_data.sub_aggregation_blueprint.as_ref() { + for (doc, term_id) in req_data + .column_block_accessor + .iter_docid_vals(docs, &req_data.accessor) + { + if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() { + if !allowed_bs.contains(term_id as u32) { + continue; + } + } + let sub_aggregations = self + .term_buckets + .sub_aggs + .entry(term_id) + .or_insert_with(|| blueprint.clone()); + sub_aggregations.collect(doc, agg_data)?; + } + } + + let mem_delta = self.get_memory_consumption() - mem_pre; + if mem_delta > 0 { + agg_data + .context + .limits + .add_memory_consumed(mem_delta as u64)?; + } + agg_data.put_back_term_req_data(self.accessor_idx, req_data); + + Ok(()) + } + + fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { + self.term_buckets.force_flush(agg_data)?; + Ok(()) + } +} + +impl SegmentTermCollector { + pub fn from_req_and_validate( + req_data: &mut AggregationsSegmentCtx, + node: &AggRefNode, + ) -> crate::Result { + let terms_req_data = req_data.get_term_req_data(node.idx_in_req_data); + let column_type = terms_req_data.column_type; + let accessor_idx = node.idx_in_req_data; + if column_type == ColumnType::Bytes { + return Err(TantivyError::InvalidArgument(format!( + "terms aggregation is not supported for column type {column_type:?}" + ))); + } + let term_buckets = TermBuckets::default(); + + // Validate sub aggregation exists + if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target { + let (agg_name, _agg_property) = 
get_agg_name_and_property(sub_agg_name); + + node.get_sub_agg(agg_name, &req_data.per_request) + .ok_or_else(|| { + TantivyError::InvalidArgument(format!( + "could not find aggregation with name {agg_name} in metric \ + sub_aggregations" + )) + })?; + } + + let has_sub_aggregations = !node.children.is_empty(); + let blueprint = if has_sub_aggregations { + let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?; + Some(sub_aggregation) + } else { + None + }; + let terms_req_data = req_data.get_term_req_data_mut(node.idx_in_req_data); + terms_req_data.sub_aggregation_blueprint = blueprint; + + Ok(SegmentTermCollector { + term_buckets, + accessor_idx, + }) + } + + fn get_memory_consumption(&self) -> usize { + let self_mem = std::mem::size_of::(); + let term_buckets_mem = self.term_buckets.get_memory_consumption(); + self_mem + term_buckets_mem + } +} diff --git a/src/aggregation/bucket/term_agg/low_cardinality_impl.rs b/src/aggregation/bucket/term_agg/low_cardinality_impl.rs new file mode 100644 index 000000000..42e96416f --- /dev/null +++ b/src/aggregation/bucket/term_agg/low_cardinality_impl.rs @@ -0,0 +1,228 @@ +use std::vec; + +use rustc_hash::FxHashMap; + +use crate::aggregation::agg_data::{ + build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx, +}; +use crate::aggregation::bucket::{get_agg_name_and_property, OrderTarget}; +use crate::aggregation::intermediate_agg_result::{ + IntermediateAggregationResult, IntermediateAggregationResults, +}; +use crate::aggregation::segment_agg_result::SegmentAggregationCollector; +use crate::{DocId, TantivyError}; + +const MAX_BATCH_SIZE: usize = 1_024; + +#[derive(Debug, Clone)] +struct LowCardTermBuckets { + entries: Box<[u32]>, + sub_aggs: Vec>, + doc_buffers: Box<[Vec]>, +} + +impl LowCardTermBuckets { + pub fn with_num_buckets( + num_buckets: usize, + sub_aggs_blueprint_opt: Option<&Box>, + ) -> Self { + let sub_aggs = sub_aggs_blueprint_opt + .as_ref() + .map(|blueprint| { + 
std::iter::repeat_with(|| blueprint.clone_box()) + .take(num_buckets) + .collect::>() + }) + .unwrap_or_default(); + Self { + entries: vec![0; num_buckets].into_boxed_slice(), + sub_aggs, + doc_buffers: std::iter::repeat_with(|| Vec::with_capacity(MAX_BATCH_SIZE)) + .take(num_buckets) + .collect::>() + .into_boxed_slice(), + } + } + + fn get_memory_consumption(&self) -> usize { + std::mem::size_of::() + + self.entries.len() * std::mem::size_of::() + + self.doc_buffers.len() + * (std::mem::size_of::>() + + std::mem::size_of::() * MAX_BATCH_SIZE) + } +} + +#[derive(Debug, Clone)] +pub struct LowCardSegmentTermCollector { + term_buckets: LowCardTermBuckets, + accessor_idx: usize, +} + +impl LowCardSegmentTermCollector { + pub fn from_req_and_validate( + req_data: &mut AggregationsSegmentCtx, + node: &AggRefNode, + ) -> crate::Result { + let terms_req_data = req_data.get_term_req_data(node.idx_in_req_data); + let accessor_idx = node.idx_in_req_data; + let cardinality = terms_req_data + .accessor + .max_value() + .max(terms_req_data.missing_value_for_accessor.unwrap_or(0)) + + 1; + assert!(cardinality <= super::LOW_CARDINALITY_THRESHOLD); + + // Validate sub aggregation exists + if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target { + let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name); + + node.get_sub_agg(agg_name, &req_data.per_request) + .ok_or_else(|| { + TantivyError::InvalidArgument(format!( + "could not find aggregation with name {agg_name} in metric \ + sub_aggregations" + )) + })?; + } + + let has_sub_aggregations = !node.children.is_empty(); + let blueprint = if has_sub_aggregations { + let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?; + Some(sub_aggregation) + } else { + None + }; + let terms_req_data = req_data.get_term_req_data_mut(node.idx_in_req_data); + + let term_buckets = + LowCardTermBuckets::with_num_buckets(cardinality as usize, blueprint.as_ref()); + + 
terms_req_data.sub_aggregation_blueprint = blueprint; + + Ok(LowCardSegmentTermCollector { + term_buckets, + accessor_idx, + }) + } + + fn get_memory_consumption(&self) -> usize { + let self_mem = std::mem::size_of::(); + let term_buckets_mem = self.term_buckets.get_memory_consumption(); + self_mem + term_buckets_mem + } +} + +impl SegmentAggregationCollector for LowCardSegmentTermCollector { + fn add_intermediate_aggregation_result( + self: Box, + agg_data: &AggregationsSegmentCtx, + results: &mut IntermediateAggregationResults, + ) -> crate::Result<()> { + let name = agg_data.get_term_req_data(self.accessor_idx).name.clone(); + let sub_aggs: FxHashMap> = self + .term_buckets + .sub_aggs + .into_iter() + .enumerate() + .filter(|(bucket_id, _sub_agg)| self.term_buckets.entries[*bucket_id] > 0) + .map(|(bucket_id, sub_agg)| (bucket_id as u64, sub_agg)) + .collect(); + let entries: Vec<(u64, u32)> = self + .term_buckets + .entries + .iter() + .enumerate() + .filter(|(_, count)| **count > 0) + .map(|(bucket_id, count)| (bucket_id as u64, *count)) + .collect(); + + let bucket = + super::into_intermediate_bucket_result(self.accessor_idx, entries, sub_aggs, agg_data)?; + results.push(name, IntermediateAggregationResult::Bucket(bucket))?; + Ok(()) + } + + fn collect_block( + &mut self, + docs: &[crate::DocId], + agg_data: &mut AggregationsSegmentCtx, + ) -> crate::Result<()> { + if docs.len() > MAX_BATCH_SIZE { + for batch in docs.chunks(MAX_BATCH_SIZE) { + self.collect_block(batch, agg_data)?; + } + } + + let mut req_data = agg_data.take_term_req_data(self.accessor_idx); + + let mem_pre = self.get_memory_consumption(); + + if let Some(missing) = req_data.missing_value_for_accessor { + req_data.column_block_accessor.fetch_block_with_missing( + docs, + &req_data.accessor, + missing, + ); + } else { + req_data + .column_block_accessor + .fetch_block(docs, &req_data.accessor); + } + + // has subagg + if req_data.sub_aggregation_blueprint.is_some() { + for (doc, term_id) in 
req_data + .column_block_accessor + .iter_docid_vals(docs, &req_data.accessor) + { + if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() { + if !allowed_bs.contains(term_id as u32) { + continue; + } + } + self.term_buckets.doc_buffers[term_id as usize].push(doc); + } + for (bucket_id, docs) in self.term_buckets.doc_buffers.iter_mut().enumerate() { + self.term_buckets.entries[bucket_id] += docs.len() as u32; + self.term_buckets.sub_aggs[bucket_id].collect_block(&docs[..], agg_data)?; + docs.clear(); + } + } else { + for term_id in req_data.column_block_accessor.iter_vals() { + if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() { + if !allowed_bs.contains(term_id as u32) { + continue; + } + } + self.term_buckets.entries[term_id as usize] += 1; + } + } + + let mem_delta = self.get_memory_consumption() - mem_pre; + if mem_delta > 0 { + agg_data + .context + .limits + .add_memory_consumed(mem_delta as u64)?; + } + agg_data.put_back_term_req_data(self.accessor_idx, req_data); + + Ok(()) + } + + fn collect( + &mut self, + doc: crate::DocId, + agg_data: &mut AggregationsSegmentCtx, + ) -> crate::Result<()> { + self.collect_block(&[doc], agg_data) + } + + fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { + for sub_aggregations in &mut self.term_buckets.sub_aggs.iter_mut() { + sub_aggregations.as_mut().flush(agg_data)?; + } + Ok(()) + } +} diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg/mod.rs similarity index 82% rename from src/aggregation/bucket/term_agg.rs rename to src/aggregation/bucket/term_agg/mod.rs index 0b18eaa6b..79aac45f0 100644 --- a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg/mod.rs @@ -1,3 +1,6 @@ +mod default_impl; +mod low_cardinality_impl; + use std::fmt::Debug; use std::io; use std::net::Ipv6Addr; @@ -12,20 +15,24 @@ use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; use super::{CustomOrder, Order, OrderTarget}; -use 
/// Splits `name` at its first `.` into `(aggregation name, property)`.
///
/// When `name` contains no `.`, the whole input is returned as the aggregation name and the
/// property is the empty string.
pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) {
    match name.split_once('.') {
        Some((agg_name, agg_property)) => (agg_name, agg_property),
        None => (name, ""),
    }
}
-struct TermBuckets { - pub(crate) entries: FxHashMap, - pub(crate) sub_aggs: FxHashMap>, -} +const LOW_CARDINALITY_THRESHOLD: u64 = 10; -impl TermBuckets { - fn get_memory_consumption(&self) -> usize { - let sub_aggs_mem = self.sub_aggs.memory_consumption(); - let buckets_mem = self.entries.memory_consumption(); - sub_aggs_mem + buckets_mem +pub(crate) fn build_segment_aggregation_collector( + req: &mut AggregationsSegmentCtx, + node: &AggRefNode, +) -> crate::Result> { + let terms_req_data = req.get_term_req_data(node.idx_in_req_data); + let column_type = terms_req_data.column_type; + if column_type == ColumnType::Bytes { + return Err(TantivyError::InvalidArgument(format!( + "terms aggregation is not supported for column type {column_type:?}" + ))); } - fn force_flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - for sub_aggregations in &mut self.sub_aggs.values_mut() { - sub_aggregations.as_mut().flush(agg_data)?; - } - Ok(()) - } -} + let cardinality = terms_req_data + .accessor + .max_value() + .max(terms_req_data.missing_value_for_accessor.unwrap_or(0u64)) + .saturating_add(1); -/// The collector puts values from the fast field into the correct buckets and does a conversion to -/// the correct datatype. -#[derive(Clone, Debug)] -pub struct SegmentTermCollector { - /// The buckets containing the aggregation data. 
- term_buckets: TermBuckets, - accessor_idx: usize, -} - -pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) { - let (agg_name, agg_property) = name.split_once('.').unwrap_or((name, "")); - (agg_name, agg_property) -} - -impl SegmentAggregationCollector for SegmentTermCollector { - fn add_intermediate_aggregation_result( - self: Box, - agg_data: &AggregationsSegmentCtx, - results: &mut IntermediateAggregationResults, - ) -> crate::Result<()> { - let name = agg_data.get_term_req_data(self.accessor_idx).name.clone(); - - let bucket = self.into_intermediate_bucket_result(agg_data)?; - results.push(name, IntermediateAggregationResult::Bucket(bucket))?; - - Ok(()) - } - - #[inline] - fn collect( - &mut self, - doc: crate::DocId, - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - self.collect_block(&[doc], agg_data) - } - - #[inline] - fn collect_block( - &mut self, - docs: &[crate::DocId], - agg_data: &mut AggregationsSegmentCtx, - ) -> crate::Result<()> { - let mut req_data = agg_data.take_term_req_data(self.accessor_idx); - - let mem_pre = self.get_memory_consumption(); - - if let Some(missing) = req_data.missing_value_for_accessor { - req_data.column_block_accessor.fetch_block_with_missing( - docs, - &req_data.accessor, - missing, - ); - } else { - req_data - .column_block_accessor - .fetch_block(docs, &req_data.accessor); - } - - for term_id in req_data.column_block_accessor.iter_vals() { - if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() { - if !allowed_bs.contains(term_id as u32) { - continue; - } - } - let entry = self.term_buckets.entries.entry(term_id).or_default(); - *entry += 1; - } - // has subagg - if let Some(blueprint) = req_data.sub_aggregation_blueprint.as_ref() { - for (doc, term_id) in req_data - .column_block_accessor - .iter_docid_vals(docs, &req_data.accessor) - { - if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() { - if !allowed_bs.contains(term_id as u32) { - continue; - } - } - let 
sub_aggregations = self - .term_buckets - .sub_aggs - .entry(term_id) - .or_insert_with(|| blueprint.clone()); - sub_aggregations.collect(doc, agg_data)?; - } - } - - let mem_delta = self.get_memory_consumption() - mem_pre; - if mem_delta > 0 { - agg_data - .context - .limits - .add_memory_consumed(mem_delta as u64)?; - } - agg_data.put_back_term_req_data(self.accessor_idx, req_data); - - Ok(()) - } - - fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - self.term_buckets.force_flush(agg_data)?; - Ok(()) - } -} - -impl SegmentTermCollector { - fn get_memory_consumption(&self) -> usize { - let self_mem = std::mem::size_of::(); - let term_buckets_mem = self.term_buckets.get_memory_consumption(); - self_mem + term_buckets_mem - } - - pub(crate) fn from_req_and_validate( - req_data: &mut AggregationsSegmentCtx, - node: &AggRefNode, - ) -> crate::Result { - let terms_req_data = req_data.get_term_req_data(node.idx_in_req_data); - let column_type = terms_req_data.column_type; - let accessor_idx = node.idx_in_req_data; - if column_type == ColumnType::Bytes { - return Err(TantivyError::InvalidArgument(format!( - "terms aggregation is not supported for column type {column_type:?}" - ))); - } - let term_buckets = TermBuckets::default(); - - // Validate sub aggregation exists - if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target { - let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name); - - node.get_sub_agg(agg_name, &req_data.per_request) - .ok_or_else(|| { - TantivyError::InvalidArgument(format!( - "could not find aggregation with name {agg_name} in metric \ - sub_aggregations" - )) - })?; - } - - let has_sub_aggregations = !node.children.is_empty(); - let blueprint = if has_sub_aggregations { - let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?; - Some(sub_aggregation) - } else { - None - }; - let terms_req_data = req_data.get_term_req_data_mut(node.idx_in_req_data); 
- terms_req_data.sub_aggregation_blueprint = blueprint; - - Ok(SegmentTermCollector { - term_buckets, - accessor_idx, - }) - } - - #[inline] - pub(crate) fn into_intermediate_bucket_result( - mut self, - agg_data: &AggregationsSegmentCtx, - ) -> crate::Result { - let term_req = agg_data.get_term_req_data(self.accessor_idx); - let mut entries: Vec<(u64, u32)> = self.term_buckets.entries.into_iter().collect(); - - let order_by_sub_aggregation = - matches!(term_req.req.order.target, OrderTarget::SubAggregation(_)); - - match &term_req.req.order.target { - OrderTarget::Key => { - // We rely on the fact, that term ordinals match the order of the strings - // TODO: We could have a special collector, that keeps only TOP n results at any - // time. - if term_req.req.order.order == Order::Desc { - entries.sort_unstable_by_key(|bucket| std::cmp::Reverse(bucket.0)); - } else { - entries.sort_unstable_by_key(|bucket| bucket.0); - } - } - OrderTarget::SubAggregation(_name) => { - // don't sort and cut off since it's hard to make assumptions on the quality of the - // results when cutting off du to unknown nature of the sub_aggregation (possible - // to check). 
- } - OrderTarget::Count => { - if term_req.req.order.order == Order::Desc { - entries.sort_unstable_by_key(|bucket| std::cmp::Reverse(bucket.1)); - } else { - entries.sort_unstable_by_key(|bucket| bucket.1); - } - } - } - - let (term_doc_count_before_cutoff, sum_other_doc_count) = if order_by_sub_aggregation { - (0, 0) - } else { - cut_off_buckets(&mut entries, term_req.req.segment_size as usize) - }; - - let mut dict: FxHashMap = Default::default(); - dict.reserve(entries.len()); - - let mut into_intermediate_bucket_entry = - |id, doc_count| -> crate::Result { - let intermediate_entry = if term_req.sub_aggregation_blueprint.as_ref().is_some() { - let mut sub_aggregation_res = IntermediateAggregationResults::default(); - self.term_buckets - .sub_aggs - .remove(&id) - .unwrap_or_else(|| { - panic!("Internal Error: could not find subaggregation for id {id}") - }) - .add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?; - - IntermediateTermBucketEntry { - doc_count, - sub_aggregation: sub_aggregation_res, - } - } else { - IntermediateTermBucketEntry { - doc_count, - sub_aggregation: Default::default(), - } - }; - Ok(intermediate_entry) - }; - - if term_req.column_type == ColumnType::Str { - let fallback_dict = Dictionary::empty(); - let term_dict = term_req - .str_dict_column - .as_ref() - .map(|el| el.dictionary()) - .unwrap_or_else(|| &fallback_dict); - let mut buffer = Vec::new(); - - // special case for missing key - if let Some(index) = entries.iter().position(|value| value.0 == u64::MAX) { - let entry = entries[index]; - let intermediate_entry = into_intermediate_bucket_entry(entry.0, entry.1)?; - let missing_key = term_req - .req - .missing - .as_ref() - .expect("Found placeholder term_id but `missing` is None"); - match missing_key { - Key::Str(missing) => { - buffer.clear(); - buffer.extend_from_slice(missing.as_bytes()); - dict.insert( - IntermediateKey::Str( - String::from_utf8(buffer.to_vec()) - .expect("could not convert to String"), 
- ), - intermediate_entry, - ); - } - Key::F64(val) => { - dict.insert(IntermediateKey::F64(*val), intermediate_entry); - } - Key::U64(val) => { - dict.insert(IntermediateKey::U64(*val), intermediate_entry); - } - Key::I64(val) => { - dict.insert(IntermediateKey::I64(*val), intermediate_entry); - } - } - - entries.swap_remove(index); - } - - // Sort by term ord - entries.sort_unstable_by_key(|bucket| bucket.0); - let mut idx = 0; - term_dict.sorted_ords_to_term_cb( - entries.iter().map(|(term_id, _)| *term_id), - |term| { - let entry = entries[idx]; - let intermediate_entry = into_intermediate_bucket_entry(entry.0, entry.1) - .map_err(io::Error::other)?; - dict.insert( - IntermediateKey::Str( - String::from_utf8(term.to_vec()).expect("could not convert to String"), - ), - intermediate_entry, - ); - idx += 1; - Ok(()) - }, - )?; - - if term_req.req.min_doc_count == 0 { - // TODO: Handle rev streaming for descending sorting by keys - let mut stream = term_dict.stream()?; - let empty_sub_aggregation = - IntermediateAggregationResults::empty_from_req(&term_req.sug_aggregations); - while stream.advance() { - if dict.len() >= term_req.req.segment_size as usize { - break; - } - - // Respect allowed filters if present - if let Some(allowed_bs) = term_req.allowed_term_ids.as_ref() { - if !allowed_bs.contains(stream.term_ord() as u32) { - continue; - } - } - - let key = IntermediateKey::Str( - std::str::from_utf8(stream.key()) - .map_err(|utf8_err| DataCorruption::comment_only(utf8_err.to_string()))? 
- .to_string(), - ); - - dict.entry(key.clone()) - .or_insert_with(|| IntermediateTermBucketEntry { - doc_count: 0, - sub_aggregation: empty_sub_aggregation.clone(), - }); - } - } - } else if term_req.column_type == ColumnType::DateTime { - for (val, doc_count) in entries { - let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?; - let val = i64::from_u64(val); - let date = format_date(val)?; - dict.insert(IntermediateKey::Str(date), intermediate_entry); - } - } else if term_req.column_type == ColumnType::Bool { - for (val, doc_count) in entries { - let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?; - let val = bool::from_u64(val); - dict.insert(IntermediateKey::Bool(val), intermediate_entry); - } - } else if term_req.column_type == ColumnType::IpAddr { - let compact_space_accessor = term_req - .accessor - .values - .clone() - .downcast_arc::() - .map_err(|_| { - TantivyError::AggregationError( - crate::aggregation::AggregationError::InternalError( - "Type mismatch: Could not downcast to CompactSpaceU64Accessor" - .to_string(), - ), - ) - })?; - - for (val, doc_count) in entries { - let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?; - let val: u128 = compact_space_accessor.compact_to_u128(val as u32); - let val = Ipv6Addr::from_u128(val); - dict.insert(IntermediateKey::IpAddr(val), intermediate_entry); - } - } else { - for (val, doc_count) in entries { - let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?; - if term_req.column_type == ColumnType::U64 { - dict.insert(IntermediateKey::U64(val), intermediate_entry); - } else if term_req.column_type == ColumnType::I64 { - dict.insert(IntermediateKey::I64(i64::from_u64(val)), intermediate_entry); - } else { - let val = f64::from_u64(val); - let val: NumericalValue = val.into(); - - match val.normalize() { - NumericalValue::U64(val) => { - dict.insert(IntermediateKey::U64(val), intermediate_entry); - } - NumericalValue::I64(val) => { - 
dict.insert(IntermediateKey::I64(val), intermediate_entry);
-                        }
-                        NumericalValue::F64(val) => {
-                            dict.insert(IntermediateKey::F64(val), intermediate_entry);
-                        }
-                    }
-                };
-            }
-        };
-
-        Ok(IntermediateBucketResult::Terms {
-            buckets: IntermediateTermBucketResult {
-                entries: dict,
-                sum_other_doc_count,
-                doc_count_error_upper_bound: term_doc_count_before_cutoff,
-            },
-        })
+    if cardinality <= LOW_CARDINALITY_THRESHOLD {
+        Ok(Box::new(
+            LowCardSegmentTermCollector::from_req_and_validate(req, node)?,
+        ))
+    } else {
+        Ok(Box::new(SegmentTermCollector::from_req_and_validate(
+            req, node,
+        )?))
     }
 }
 
@@ -775,6 +401,252 @@ pub(crate) fn cut_off_buckets(
     (term_doc_count_before_cutoff, sum_other_doc_count)
 }
 
+/// Builds the `IntermediateBucketResult::Terms` of one segment from its collected buckets.
+///
+/// * `accessor_idx`: index of the terms request data inside `agg_data`.
+/// * `entries`: `(id, doc_count)` pairs, where `id` is a term ordinal for `Str` columns and a
+///   u64-encoded column value otherwise. For `Str` columns, `u64::MAX` marks the `missing` key.
+/// * `sub_aggs`: sub-aggregation collectors keyed by the same `id`; only consulted when the
+///   request has a sub-aggregation blueprint.
+///
+/// Unless the order targets a sub-aggregation, the entries are sorted by key or doc count and
+/// handed to `cut_off_buckets` together with `segment_size`. The remaining ids are then turned
+/// into `IntermediateKey`s according to the column type.
+///
+/// # Panics
+///
+/// Panics if a sub-aggregation collector is missing for an emitted id, if the placeholder id
+/// is present while `missing` is `None`, or if a term is not valid UTF-8.
+fn into_intermediate_bucket_result(
+    accessor_idx: usize,
+    mut entries: Vec<(u64, u32)>,
+    mut sub_aggs: FxHashMap<u64, Box<dyn SegmentAggregationCollector>>,
+    agg_data: &AggregationsSegmentCtx,
+) -> crate::Result<IntermediateBucketResult> {
+    let term_req = agg_data.get_term_req_data(accessor_idx);
+
+    let order_by_sub_aggregation =
+        matches!(term_req.req.order.target, OrderTarget::SubAggregation(_));
+
+    match &term_req.req.order.target {
+        OrderTarget::Key => {
+            // We rely on the fact that term ordinals match the order of the strings.
+            // TODO: We could have a special collector, that keeps only TOP n results at any
+            // time.
+            if term_req.req.order.order == Order::Desc {
+                entries.sort_unstable_by_key(|bucket| std::cmp::Reverse(bucket.0));
+            } else {
+                entries.sort_unstable_by_key(|bucket| bucket.0);
+            }
+        }
+        OrderTarget::SubAggregation(_name) => {
+            // don't sort and cut off since it's hard to make assumptions on the quality of the
+            // results when cutting off due to unknown nature of the sub_aggregation (possible
+            // to check).
+        }
+        OrderTarget::Count => {
+            if term_req.req.order.order == Order::Desc {
+                entries.sort_unstable_by_key(|bucket| std::cmp::Reverse(bucket.1));
+            } else {
+                entries.sort_unstable_by_key(|bucket| bucket.1);
+            }
+        }
+    }
+
+    let (term_doc_count_before_cutoff, sum_other_doc_count) = if order_by_sub_aggregation {
+        (0, 0)
+    } else {
+        cut_off_buckets(&mut entries, term_req.req.segment_size as usize)
+    };
+
+    let mut dict: FxHashMap<IntermediateKey, IntermediateTermBucketEntry> = Default::default();
+    dict.reserve(entries.len());
+
+    let mut into_intermediate_bucket_entry =
+        |id, doc_count| -> crate::Result<IntermediateTermBucketEntry> {
+            let intermediate_entry = if term_req.sub_aggregation_blueprint.as_ref().is_some() {
+                let mut sub_aggregation_res = IntermediateAggregationResults::default();
+                sub_aggs
+                    .remove(&id)
+                    .unwrap_or_else(|| {
+                        panic!("Internal Error: could not find subaggregation for id {id}")
+                    })
+                    .add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?;
+
+                IntermediateTermBucketEntry {
+                    doc_count,
+                    sub_aggregation: sub_aggregation_res,
+                }
+            } else {
+                IntermediateTermBucketEntry {
+                    doc_count,
+                    sub_aggregation: Default::default(),
+                }
+            };
+            Ok(intermediate_entry)
+        };
+
+    if term_req.column_type == ColumnType::Str {
+        let fallback_dict = Dictionary::empty();
+        let term_dict = term_req
+            .str_dict_column
+            .as_ref()
+            .map(|el| el.dictionary())
+            .unwrap_or_else(|| &fallback_dict);
+        let mut buffer = Vec::new();
+
+        // special case for missing key
+        if let Some(index) = entries.iter().position(|value| value.0 == u64::MAX) {
+            let entry = entries[index];
+            let intermediate_entry = into_intermediate_bucket_entry(entry.0, entry.1)?;
+            let missing_key = term_req
+                .req
+                .missing
+                .as_ref()
+                .expect("Found placeholder term_id but `missing` is None");
+            match missing_key {
+                Key::Str(missing) => {
+                    buffer.clear();
+                    buffer.extend_from_slice(missing.as_bytes());
+                    dict.insert(
+                        IntermediateKey::Str(
+                            String::from_utf8(buffer.to_vec())
+                                .expect("could not convert to String"),
+                        ),
+                        intermediate_entry,
+                    );
+                }
+                Key::F64(val) => {
+                    dict.insert(IntermediateKey::F64(*val), intermediate_entry);
+                }
+                Key::U64(val) => {
+                    dict.insert(IntermediateKey::U64(*val), intermediate_entry);
+                }
+                Key::I64(val) => {
+                    dict.insert(IntermediateKey::I64(*val), intermediate_entry);
+                }
+            }
+
+            // Remove the placeholder so it is not resolved through the term dictionary below.
+            entries.swap_remove(index);
+        }
+
+        // Sort by term ord
+        entries.sort_unstable_by_key(|bucket| bucket.0);
+        // `idx` tracks the entry for each callback (assumes one call per ord, in order).
+        let mut idx = 0;
+        term_dict.sorted_ords_to_term_cb(entries.iter().map(|(term_id, _)| *term_id), |term| {
+            let entry = entries[idx];
+            let intermediate_entry =
+                into_intermediate_bucket_entry(entry.0, entry.1).map_err(io::Error::other)?;
+            dict.insert(
+                IntermediateKey::Str(
+                    String::from_utf8(term.to_vec()).expect("could not convert to String"),
+                ),
+                intermediate_entry,
+            );
+            idx += 1;
+            Ok(())
+        })?;
+
+        // min_doc_count == 0: add zero-count buckets for dictionary terms not collected above,
+        // until `segment_size` buckets exist.
+        if term_req.req.min_doc_count == 0 {
+            // TODO: Handle rev streaming for descending sorting by keys
+            let mut stream = term_dict.stream()?;
+            let empty_sub_aggregation =
+                IntermediateAggregationResults::empty_from_req(&term_req.sug_aggregations);
+            while stream.advance() {
+                if dict.len() >= term_req.req.segment_size as usize {
+                    break;
+                }
+
+                // Respect allowed filters if present
+                if let Some(allowed_bs) = term_req.allowed_term_ids.as_ref() {
+                    if !allowed_bs.contains(stream.term_ord() as u32) {
+                        continue;
+                    }
+                }
+
+                let key = IntermediateKey::Str(
+                    std::str::from_utf8(stream.key())
+                        .map_err(|utf8_err| DataCorruption::comment_only(utf8_err.to_string()))?
+                        .to_string(),
+                );
+
+                dict.entry(key)
+                    .or_insert_with(|| IntermediateTermBucketEntry {
+                        doc_count: 0,
+                        sub_aggregation: empty_sub_aggregation.clone(),
+                    });
+            }
+        }
+    } else if term_req.column_type == ColumnType::DateTime {
+        for (val, doc_count) in entries {
+            let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?;
+            let val = i64::from_u64(val);
+            let date = format_date(val)?;
+            dict.insert(IntermediateKey::Str(date), intermediate_entry);
+        }
+    } else if term_req.column_type == ColumnType::Bool {
+        for (val, doc_count) in entries {
+            let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?;
+            let val = bool::from_u64(val);
+            dict.insert(IntermediateKey::Bool(val), intermediate_entry);
+        }
+    } else if term_req.column_type == ColumnType::IpAddr {
+        let compact_space_accessor = term_req
+            .accessor
+            .values
+            .clone()
+            .downcast_arc::<CompactSpaceU64Accessor>()
+            .map_err(|_| {
+                TantivyError::AggregationError(crate::aggregation::AggregationError::InternalError(
+                    "Type mismatch: Could not downcast to CompactSpaceU64Accessor".to_string(),
+                ))
+            })?;
+
+        for (val, doc_count) in entries {
+            let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?;
+            let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
+            let val = Ipv6Addr::from_u128(val);
+            dict.insert(IntermediateKey::IpAddr(val), intermediate_entry);
+        }
+    } else {
+        for (val, doc_count) in entries {
+            let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?;
+            if term_req.column_type == ColumnType::U64 {
+                dict.insert(IntermediateKey::U64(val), intermediate_entry);
+            } else if term_req.column_type == ColumnType::I64 {
+                dict.insert(IntermediateKey::I64(i64::from_u64(val)), intermediate_entry);
+            } else {
+                let val = f64::from_u64(val);
+                let val: NumericalValue = val.into();
+
+                match val.normalize() {
+                    NumericalValue::U64(val) => {
+                        dict.insert(IntermediateKey::U64(val), intermediate_entry);
+                    }
+                    NumericalValue::I64(val) => {
+                        dict.insert(IntermediateKey::I64(val), intermediate_entry);
+                    }
+                    NumericalValue::F64(val) => {
+                        dict.insert(IntermediateKey::F64(val), intermediate_entry);
+                    }
+                }
+            };
+        }
+    };
+
+    Ok(IntermediateBucketResult::Terms {
+        buckets: IntermediateTermBucketResult {
+            entries: dict,
+            sum_other_doc_count,
+            doc_count_error_upper_bound: term_doc_count_before_cutoff,
+        },
+    })
+}
+
 #[cfg(test)]
 mod tests {
     use std::net::IpAddr;
diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs
index 5cc2650b6..cecdb2d8e 100644
--- a/src/aggregation/segment_agg_result.rs
+++ b/src/aggregation/segment_agg_result.rs
@@ -17,11 +17,14 @@ pub trait SegmentAggregationCollector: CollectorClone + Debug {
         results: &mut IntermediateAggregationResults,
     ) -> crate::Result<()>;
 
+    #[inline]
     fn collect(
         &mut self,
         doc: crate::DocId,
         agg_data: &mut AggregationsSegmentCtx,
-    ) -> crate::Result<()>;
+    ) -> crate::Result<()> {
+        self.collect_block(&[doc], agg_data)
+    }
 
     fn collect_block(
         &mut self,