diff --git a/benches/agg_bench.rs b/benches/agg_bench.rs index 8df1ba539..751f27d96 100644 --- a/benches/agg_bench.rs +++ b/benches/agg_bench.rs @@ -59,6 +59,8 @@ fn bench_agg(mut group: InputGroup) { register!(group, terms_many_order_by_term); register!(group, terms_many_with_top_hits); register!(group, terms_many_with_avg_sub_agg); + register!(group, terms_few_with_avg_sub_agg); + register!(group, terms_many_json_mixed_type_with_avg_sub_agg); register!(group, cardinality_agg); @@ -220,6 +222,19 @@ fn terms_many_with_avg_sub_agg(index: &Index) { }); execute_agg(index, agg_req); } + +fn terms_few_with_avg_sub_agg(index: &Index) { + let agg_req = json!({ + "my_texts": { + "terms": { "field": "text_few_terms" }, + "aggs": { + "average_f64": { "avg": { "field": "score_f64" } } + } + }, + }); + execute_agg(index, agg_req); +} + fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) { let agg_req = json!({ "my_texts": { diff --git a/src/aggregation/accessor_helpers.rs b/src/aggregation/accessor_helpers.rs index eb44a734b..fa51041e4 100644 --- a/src/aggregation/accessor_helpers.rs +++ b/src/aggregation/accessor_helpers.rs @@ -16,15 +16,16 @@ use crate::index::SegmentReader; /// That way we can use it the same way as if it would come from the fastfield. pub(crate) fn get_missing_val_as_u64_lenient( column_type: ColumnType, + column_max_value: u64, missing: &Key, field_name: &str, ) -> crate::Result> { let missing_val = match missing { - Key::Str(_) if column_type == ColumnType::Str => Some(u64::MAX), + Key::Str(_) if column_type == ColumnType::Str => Some(column_max_value + 1), // Allow fallback to number on text fields - Key::F64(_) if column_type == ColumnType::Str => Some(u64::MAX), - Key::U64(_) if column_type == ColumnType::Str => Some(u64::MAX), - Key::I64(_) if column_type == ColumnType::Str => Some(u64::MAX), + Key::F64(_) if column_type == ColumnType::Str => Some(column_max_value + 1), + Key::U64(_) if column_type == ColumnType::Str => Some(column_max_value + 1), + Key::I64(_) if column_type == ColumnType::Str => Some(column_max_value + 1), Key::F64(val) if column_type.numerical_type().is_some() => { f64_to_fastfield_u64(*val, &column_type) } diff --git a/src/aggregation/agg_data.rs b/src/aggregation/agg_data.rs index 3b29830a7..deedb5781 100644 --- a/src/aggregation/agg_data.rs +++ b/src/aggregation/agg_data.rs @@ -12,7 +12,7 @@ use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations use crate::aggregation::bucket::{ FilterAggReqData, HistogramAggReqData, HistogramBounds, IncludeExcludeParam, MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector, SegmentHistogramCollector, - SegmentRangeCollector, SegmentTermCollector, TermMissingAgg, TermsAggReqData, TermsAggregation, + SegmentRangeCollector, TermMissingAgg, TermsAggReqData, TermsAggregation, TermsAggregationInternal, }; use crate::aggregation::metric::{ @@ -373,9 +373,7 @@ pub(crate) fn build_segment_agg_collector( node: &AggRefNode, ) -> crate::Result> { match node.kind { - AggKind::Terms => Ok(Box::new(SegmentTermCollector::from_req_and_validate( - req, node, - )?)), + AggKind::Terms => crate::aggregation::bucket::build_segment_term_collector(req, node), AggKind::MissingTerm => { let req_data = &mut req.per_request.missing_term_req_data[node.idx_in_req_data]; if req_data.accessors.is_empty() { @@ -498,7 +496,7 @@ pub(crate) fn build_aggregations_data_from_req( }; for (name, agg) in aggs.iter() { - let nodes = build_nodes(name, agg, reader, segment_ordinal, &mut data)?; + let nodes = build_nodes(name, agg, reader, segment_ordinal, &mut data, true)?; data.per_request.agg_tree.extend(nodes); } Ok(data) @@ -510,6 +508,7 @@ fn build_nodes( reader: &SegmentReader, segment_ordinal: SegmentOrdinal, data: &mut AggregationsSegmentCtx, + is_top_level: bool, ) -> crate::Result> { use AggregationVariants::*; match &req.agg { @@ -596,6 +595,7 @@ fn build_nodes( data, &req.sub_aggregation, TermsOrCardinalityRequest::Terms(terms_req.clone()), + is_top_level, ), Cardinality(card_req) => build_terms_or_cardinality_nodes( agg_name, @@ -606,6 +606,7 @@ fn build_nodes( data, &req.sub_aggregation, TermsOrCardinalityRequest::Cardinality(card_req.clone()), + is_top_level, ), Average(AverageAggregation { field, missing, .. }) | Max(MaxAggregation { field, missing, .. }) @@ -734,7 +735,7 @@ fn build_nodes( // Build the query and evaluator upfront let schema = reader.schema(); let tokenizers = &data.context.tokenizers; - let query = filter_req.parse_query(&schema, tokenizers)?; + let query = filter_req.parse_query(schema, tokenizers)?; let evaluator = crate::aggregation::bucket::DocumentQueryEvaluator::new( query, schema.clone(), @@ -771,7 +772,14 @@ fn build_children( ) -> crate::Result> { let mut children = Vec::new(); for (name, agg) in aggs.iter() { - children.extend(build_nodes(name, agg, reader, segment_ordinal, data)?); + children.extend(build_nodes( + name, + agg, + reader, + segment_ordinal, + data, + false, + )?); } Ok(children) } @@ -835,6 +843,7 @@ fn build_terms_or_cardinality_nodes( data: &mut AggregationsSegmentCtx, sub_aggs: &Aggregations, req: TermsOrCardinalityRequest, + is_top_level: bool, ) -> crate::Result> { let mut nodes = Vec::new(); @@ -891,7 +900,7 @@ fn build_terms_or_cardinality_nodes( let missing_value_for_accessor = if use_special_missing_agg { None } else if let Some(m) = missing.as_ref() { - get_missing_val_as_u64_lenient(column_type, m, field_name)? + get_missing_val_as_u64_lenient(column_type, accessor.max_value(), m, field_name)? } else { None }; @@ -924,6 +933,7 @@ fn build_terms_or_cardinality_nodes( sub_aggregation_blueprint: None, sug_aggregations: sub_aggs.clone(), allowed_term_ids, + is_top_level, }); (idx_in_req_data, AggKind::Terms) } diff --git a/src/aggregation/agg_limits.rs b/src/aggregation/agg_limits.rs index 76dfbca9d..8e023e09b 100644 --- a/src/aggregation/agg_limits.rs +++ b/src/aggregation/agg_limits.rs @@ -35,6 +35,7 @@ pub struct AggregationLimitsGuard { /// Allocated memory with this guard. allocated_with_the_guard: u64, } + impl Clone for AggregationLimitsGuard { fn clone(&self) -> Self { Self { @@ -70,7 +71,7 @@ impl AggregationLimitsGuard { /// *memory_limit* /// memory_limit is defined in bytes. /// Aggregation fails when the estimated memory consumption of the aggregation is higher than - /// memory_limit. + /// memory_limit. /// memory_limit will default to `DEFAULT_MEMORY_LIMIT` (500MB) /// /// *bucket_limit* diff --git a/src/aggregation/bucket/filter.rs b/src/aggregation/bucket/filter.rs index 18fc40b25..d4461bf1f 100644 --- a/src/aggregation/bucket/filter.rs +++ b/src/aggregation/bucket/filter.rs @@ -639,16 +639,14 @@ pub struct IntermediateFilterBucketResult { #[cfg(test)] mod tests { - use std::time::Instant; - use serde_json::{json, Value}; use super::*; use crate::aggregation::agg_req::Aggregations; use crate::aggregation::agg_result::AggregationResults; use crate::aggregation::{AggContextParams, AggregationCollector}; - use crate::query::{AllQuery, QueryParser, TermQuery}; - use crate::schema::{IndexRecordOption, Schema, Term, FAST, INDEXED, STORED, TEXT}; + use crate::query::{AllQuery, TermQuery}; + use crate::schema::{IndexRecordOption, Schema, Term, FAST, INDEXED, TEXT}; use crate::{doc, Index, IndexWriter}; // Test helper functions diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index 0b18eaa6b..53ce7a5e5 100644 --- a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg.rs @@ -17,6 +17,7 @@ use crate::aggregation::agg_data::{ }; use crate::aggregation::agg_limits::MemoryConsumption; use crate::aggregation::agg_req::Aggregations; +use crate::aggregation::buf_collector::BufAggregationCollector; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult, @@ -49,6 +50,8 @@ pub struct TermsAggReqData { pub req: TermsAggregationInternal, /// Preloaded allowed term ords (string columns only). If set, only ords present are collected. pub allowed_term_ids: Option, + /// True if this terms aggregation is at the top level of the aggregation tree (not nested). + pub is_top_level: bool, } impl TermsAggReqData { @@ -331,34 +334,371 @@ impl TermsAggregationInternal { } } -#[derive(Clone, Debug, Default)] -/// Container to store term_ids/or u64 values and their buckets. -struct TermBuckets { - pub(crate) entries: FxHashMap, - pub(crate) sub_aggs: FxHashMap>, +impl<'a> From<&'a dyn SegmentAggregationCollector> for BufAggregationCollector { + #[inline(always)] + fn from(sub_agg_blueprint_opt: &'a dyn SegmentAggregationCollector) -> Self { + let sub_agg = sub_agg_blueprint_opt.clone_box(); + BufAggregationCollector::new(sub_agg) + } } -impl TermBuckets { - fn get_memory_consumption(&self) -> usize { - let sub_aggs_mem = self.sub_aggs.memory_consumption(); - let buckets_mem = self.entries.memory_consumption(); - sub_aggs_mem + buckets_mem +#[derive(Debug, Clone)] +struct BoxedAggregation(Box); + +impl<'a> From<&'a dyn SegmentAggregationCollector> for BoxedAggregation { + #[inline(always)] + fn from(sub_agg_blueprint: &'a dyn SegmentAggregationCollector) -> Self { + BoxedAggregation(sub_agg_blueprint.clone_box()) + } +} + +impl SegmentAggregationCollector for BoxedAggregation { + #[inline(always)] + fn add_intermediate_aggregation_result( + self: Box, + agg_data: &AggregationsSegmentCtx, + results: &mut IntermediateAggregationResults, + ) -> crate::Result<()> { + self.0 + .add_intermediate_aggregation_result(agg_data, results) } - fn force_flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - for sub_aggregations in &mut self.sub_aggs.values_mut() { - sub_aggregations.as_mut().flush(agg_data)?; + #[inline(always)] + fn collect( + &mut self, + doc: crate::DocId, + agg_data: &mut AggregationsSegmentCtx, + ) -> crate::Result<()> { + self.0.collect(doc, agg_data) + } + + #[inline(always)] + fn collect_block( + &mut self, + docs: &[crate::DocId], + agg_data: &mut AggregationsSegmentCtx, + ) -> crate::Result<()> { + self.0.collect_block(docs, agg_data) + } +} + +#[derive(Debug, Clone, Copy)] +struct NoSubAgg; + +impl SegmentAggregationCollector for NoSubAgg { + #[inline(always)] + fn add_intermediate_aggregation_result( + self: Box, + _agg_data: &AggregationsSegmentCtx, + _results: &mut IntermediateAggregationResults, + ) -> crate::Result<()> { + Ok(()) + } + + #[inline(always)] + fn collect( + &mut self, + _doc: crate::DocId, + _agg_data: &mut AggregationsSegmentCtx, + ) -> crate::Result<()> { + Ok(()) + } + + #[inline(always)] + fn collect_block( + &mut self, + _docs: &[crate::DocId], + _agg_data: &mut AggregationsSegmentCtx, + ) -> crate::Result<()> { + Ok(()) + } +} + +/// Build a concrete `SegmentTermCollector` with either a Vec- or HashMap-backed +/// bucket storage, depending on the column type and aggregation level. +pub(crate) fn build_segment_term_collector( + req_data: &mut AggregationsSegmentCtx, + node: &AggRefNode, +) -> crate::Result> { + let accessor_idx = node.idx_in_req_data; + let column_type = { + let terms_req_data = req_data.get_term_req_data(accessor_idx); + terms_req_data.column_type + }; + + if column_type == ColumnType::Bytes { + return Err(TantivyError::InvalidArgument(format!( + "terms aggregation is not supported for column type {column_type:?}" + ))); + } + + // Validate sub aggregation exists when ordering by sub-aggregation. + { + let terms_req_data = req_data.get_term_req_data(accessor_idx); + if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target { + let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name); + + node.get_sub_agg(agg_name, &req_data.per_request) + .ok_or_else(|| { + TantivyError::InvalidArgument(format!( + "could not find aggregation with name {agg_name} in metric \ + sub_aggregations" + )) + })?; + } + } + + // Build sub-aggregation blueprint if there are children. + let has_sub_aggregations = !node.children.is_empty(); + let blueprint = if has_sub_aggregations { + let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?; + Some(sub_aggregation) + } else { + None + }; + { + let terms_req_data_mut = req_data.get_term_req_data_mut(accessor_idx); + terms_req_data_mut.sub_aggregation_blueprint = blueprint; + } + + // Decide whether to use a Vec-backed or HashMap-backed bucket storage. + let terms_req_data = req_data.get_term_req_data(accessor_idx); + + // TODO: A better metric instead of is_top_level would be the number of buckets expected. + // E.g. If term agg is not top level, but the parent is a bucket agg with less than 10 buckets, + // we can still use Vec. + let can_use_vec = terms_req_data.is_top_level; + + // TODO: Benchmark to validate the threshold + const MAX_NUM_TERMS_FOR_VEC: usize = 100; + + // Let's see if we can use a vec to aggregate our data + // instead of a hashmap. + let col_max_value = terms_req_data.accessor.max_value(); + let max_term: usize = + col_max_value.max(terms_req_data.missing_value_for_accessor.unwrap_or(0u64)) as usize; + + // - use a Vec instead of a hashmap for our aggregation. + // - buffer aggregation of our child aggregations (in any) + #[allow(clippy::collapsible_else_if)] + if can_use_vec && max_term < MAX_NUM_TERMS_FOR_VEC { + if has_sub_aggregations { + let sub_agg_blueprint = &req_data + .get_term_req_data_mut(accessor_idx) + .sub_aggregation_blueprint + .as_ref() + .ok_or_else(|| { + // Handle the error case here + // For example, return an error message or a default value + TantivyError::InternalError("Sub-aggregation blueprint not found".to_string()) + })?; + let term_buckets = VecTermBuckets::new(max_term + 1, || { + let collector_clone = sub_agg_blueprint.clone_box(); + BufAggregationCollector::new(collector_clone) + }); + let collector = SegmentTermCollector { + term_buckets, + accessor_idx, + }; + Ok(Box::new(collector)) + } else { + let term_buckets = VecTermBuckets::new(max_term + 1, || NoSubAgg); + let collector = SegmentTermCollector { + term_buckets, + accessor_idx, + }; + Ok(Box::new(collector)) + } + } else { + if has_sub_aggregations { + let term_buckets: HashMapTermBuckets = HashMapTermBuckets::default(); + let collector: SegmentTermCollector> = + SegmentTermCollector { + term_buckets, + accessor_idx, + }; + Ok(Box::new(collector)) + } else { + let term_buckets: HashMapTermBuckets = HashMapTermBuckets::default(); + let collector: SegmentTermCollector> = + SegmentTermCollector { + term_buckets, + accessor_idx, + }; + Ok(Box::new(collector)) + } + } +} + +#[derive(Debug, Clone)] +struct Bucket { + pub count: u32, + pub sub_agg: SubAgg, +} + +impl Bucket { + #[inline(always)] + fn new(sub_agg: SubAgg) -> Self { + Self { count: 0, sub_agg } + } +} + +/// Abstraction over the storage used for term buckets (counts only). +trait TermAggregationMap: Clone + Debug + 'static { + type SubAggregation: SegmentAggregationCollector + Debug + Clone + 'static; + + /// Estimate the memory consumption of this struct in bytes. + fn get_memory_consumption(&self) -> usize; + + /// Returns the bucket assocaited to a given term_id. + fn term_entry( + &mut self, + term_id: u64, + blue_print: &dyn SegmentAggregationCollector, + ) -> &mut Bucket; + + /// If the tree of aggregations contains buffered aggregations, flush them. + fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()>; + + /// Returns the term aggregation as a vector of (term_id, bucket) pairs, + /// in any order. + fn into_vec(self) -> Vec<(u64, Bucket)>; +} + +#[derive(Clone, Debug)] +struct HashMapTermBuckets { + bucket_map: FxHashMap>, +} + +impl Default for HashMapTermBuckets { + #[inline(always)] + fn default() -> Self { + Self { + bucket_map: FxHashMap::default(), + } + } +} + +impl< + SubAgg: Debug + + Clone + + SegmentAggregationCollector + + for<'a> From<&'a dyn SegmentAggregationCollector> + + 'static, + > TermAggregationMap for HashMapTermBuckets +{ + type SubAggregation = SubAgg; + + #[inline] + fn get_memory_consumption(&self) -> usize { + self.bucket_map.memory_consumption() + } + + #[inline(always)] + fn term_entry( + &mut self, + term_id: u64, + sub_agg_blueprint: &dyn SegmentAggregationCollector, + ) -> &mut Bucket { + self.bucket_map + .entry(term_id) + .or_insert_with(|| Bucket::new(SubAgg::from(sub_agg_blueprint))) + } + + #[inline(always)] + fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { + for bucket in self.bucket_map.values_mut() { + bucket.sub_agg.flush(agg_data)?; } Ok(()) } + + fn into_vec(self) -> Vec<(u64, Bucket)> { + self.bucket_map.into_iter().collect() + } +} + +/// An optimized term map implementation for a compact set of term ordinals. +#[derive(Clone, Debug)] +struct VecTermBuckets { + buckets: Vec>, +} + +impl VecTermBuckets { + fn new(num_terms: usize, item_factory_fn: impl Fn() -> SubAgg) -> Self { + VecTermBuckets { + buckets: std::iter::repeat_with(item_factory_fn) + .map(Bucket::new) + .take(num_terms) + .collect(), + } + } +} + +impl TermAggregationMap + for VecTermBuckets +{ + type SubAggregation = SubAgg; + + /// Estimate the memory consumption of this struct in bytes. + fn get_memory_consumption(&self) -> usize { + // We do not include `std::mem::size_of::()` + // It is already measure by the parent aggregation. + // + // The root aggregation mem size is not measure but we do not care. + self.buckets.capacity() * std::mem::size_of::>() + } + + /// Add an occurrence of the given term id. + #[inline(always)] + fn term_entry( + &mut self, + term_id: u64, + _sub_agg_blueprint: &dyn SegmentAggregationCollector, + ) -> &mut Bucket { + let term_id_usize = term_id as usize; + debug_assert!( + term_id_usize < self.buckets.len(), + "term_id {} out of bounds for VecTermBuckets (len={})", + term_id, + self.buckets.len() + ); + unsafe { self.buckets.get_unchecked_mut(term_id_usize) } + } + + #[inline(always)] + fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { + for bucket in &mut self.buckets { + if bucket.count > 0 { + bucket.sub_agg.flush(agg_data)?; + } + } + Ok(()) + } + + fn into_vec(self) -> Vec<(u64, Bucket)> { + self.buckets + .into_iter() + .enumerate() + .filter(|(_, bucket)| bucket.count > 0) + .map(|(term_id, bucket)| (term_id as u64, bucket)) + .collect() + } +} + +impl<'a> From<&'a dyn SegmentAggregationCollector> for NoSubAgg { + #[inline(always)] + fn from(_: &'a dyn SegmentAggregationCollector) -> Self { + Self + } } /// The collector puts values from the fast field into the correct buckets and does a conversion to /// the correct datatype. #[derive(Clone, Debug)] -pub struct SegmentTermCollector { +struct SegmentTermCollector { /// The buckets containing the aggregation data. - term_buckets: TermBuckets, + term_buckets: TermMap, accessor_idx: usize, } @@ -367,17 +707,19 @@ pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) { (agg_name, agg_property) } -impl SegmentAggregationCollector for SegmentTermCollector { +impl SegmentAggregationCollector for SegmentTermCollector +where + TermMap: TermAggregationMap, + TermMap::SubAggregation: for<'a> From<&'a dyn SegmentAggregationCollector>, +{ fn add_intermediate_aggregation_result( self: Box, agg_data: &AggregationsSegmentCtx, results: &mut IntermediateAggregationResults, ) -> crate::Result<()> { let name = agg_data.get_term_req_data(self.accessor_idx).name.clone(); - let bucket = self.into_intermediate_bucket_result(agg_data)?; results.push(name, IntermediateAggregationResult::Bucket(bucket))?; - Ok(()) } @@ -412,17 +754,23 @@ impl SegmentAggregationCollector for SegmentTermCollector { .fetch_block(docs, &req_data.accessor); } - for term_id in req_data.column_block_accessor.iter_vals() { - if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() { - if !allowed_bs.contains(term_id as u32) { - continue; + if std::any::TypeId::of::() == std::any::TypeId::of::() { + for term_id in req_data.column_block_accessor.iter_vals() { + if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() { + if !allowed_bs.contains(term_id as u32) { + continue; + } } + let bucket = self.term_buckets.term_entry(term_id, &NoSubAgg); + bucket.count += 1; } - let entry = self.term_buckets.entries.entry(term_id).or_default(); - *entry += 1; - } - // has subagg - if let Some(blueprint) = req_data.sub_aggregation_blueprint.as_ref() { + } else { + let Some(sub_aggregation_blueprint) = req_data.sub_aggregation_blueprint.as_deref() + else { + return Err(TantivyError::InternalError( + "Could not find sub-aggregation blueprint".to_string(), + )); + }; for (doc, term_id) in req_data .column_block_accessor .iter_docid_vals(docs, &req_data.accessor) @@ -432,12 +780,11 @@ impl SegmentAggregationCollector for SegmentTermCollector { continue; } } - let sub_aggregations = self + let bucket = self .term_buckets - .sub_aggs - .entry(term_id) - .or_insert_with(|| blueprint.clone()); - sub_aggregations.collect(doc, agg_data)?; + .term_entry(term_id, sub_aggregation_blueprint); + bucket.count += 1; + bucket.sub_agg.collect(doc, agg_data)?; } } @@ -453,69 +800,51 @@ impl SegmentAggregationCollector for SegmentTermCollector { Ok(()) } + #[inline(always)] fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { - self.term_buckets.force_flush(agg_data)?; + self.term_buckets.flush(agg_data)?; Ok(()) } } -impl SegmentTermCollector { +/// Missing value are represented as a sentinel value in the column. +/// +/// This function extracts the missing value from the entries vector, +/// computes the intermediate key, and returns it the key and the bucket +/// in an Option. +fn extract_missing_value( + entries: &mut Vec<(u64, T)>, + term_req: &TermsAggReqData, +) -> Option<(IntermediateKey, T)> { + let missing_sentinel = term_req.missing_value_for_accessor?; + let missing_value_entry_pos = entries + .iter() + .position(|(term_id, _)| *term_id == missing_sentinel)?; + let (_term_id, bucket) = entries.swap_remove(missing_value_entry_pos); + let missing_key = term_req.req.missing.as_ref()?; + let key = match missing_key { + Key::Str(missing) => IntermediateKey::Str(missing.clone()), + Key::F64(val) => IntermediateKey::F64(*val), + Key::U64(val) => IntermediateKey::U64(*val), + Key::I64(val) => IntermediateKey::I64(*val), + }; + Some((key, bucket)) +} + +impl SegmentTermCollector +where TermMap: TermAggregationMap +{ fn get_memory_consumption(&self) -> usize { - let self_mem = std::mem::size_of::(); - let term_buckets_mem = self.term_buckets.get_memory_consumption(); - self_mem + term_buckets_mem - } - - pub(crate) fn from_req_and_validate( - req_data: &mut AggregationsSegmentCtx, - node: &AggRefNode, - ) -> crate::Result { - let terms_req_data = req_data.get_term_req_data(node.idx_in_req_data); - let column_type = terms_req_data.column_type; - let accessor_idx = node.idx_in_req_data; - if column_type == ColumnType::Bytes { - return Err(TantivyError::InvalidArgument(format!( - "terms aggregation is not supported for column type {column_type:?}" - ))); - } - let term_buckets = TermBuckets::default(); - - // Validate sub aggregation exists - if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target { - let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name); - - node.get_sub_agg(agg_name, &req_data.per_request) - .ok_or_else(|| { - TantivyError::InvalidArgument(format!( - "could not find aggregation with name {agg_name} in metric \ - sub_aggregations" - )) - })?; - } - - let has_sub_aggregations = !node.children.is_empty(); - let blueprint = if has_sub_aggregations { - let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?; - Some(sub_aggregation) - } else { - None - }; - let terms_req_data = req_data.get_term_req_data_mut(node.idx_in_req_data); - terms_req_data.sub_aggregation_blueprint = blueprint; - - Ok(SegmentTermCollector { - term_buckets, - accessor_idx, - }) + self.term_buckets.get_memory_consumption() } #[inline] pub(crate) fn into_intermediate_bucket_result( - mut self, + self, agg_data: &AggregationsSegmentCtx, ) -> crate::Result { let term_req = agg_data.get_term_req_data(self.accessor_idx); - let mut entries: Vec<(u64, u32)> = self.term_buckets.entries.into_iter().collect(); + let mut entries: Vec<(u64, Bucket)> = self.term_buckets.into_vec(); let order_by_sub_aggregation = matches!(term_req.req.order.target, OrderTarget::SubAggregation(_)); @@ -538,9 +867,9 @@ impl SegmentTermCollector { } OrderTarget::Count => { if term_req.req.order.order == Order::Desc { - entries.sort_unstable_by_key(|bucket| std::cmp::Reverse(bucket.1)); + entries.sort_unstable_by_key(|bucket| std::cmp::Reverse(bucket.1.count)); } else { - entries.sort_unstable_by_key(|bucket| bucket.1); + entries.sort_unstable_by_key(|bucket| bucket.1.count); } } } @@ -554,25 +883,20 @@ impl SegmentTermCollector { let mut dict: FxHashMap = Default::default(); dict.reserve(entries.len()); - let mut into_intermediate_bucket_entry = - |id, doc_count| -> crate::Result { + let into_intermediate_bucket_entry = + |bucket: Bucket| -> crate::Result { let intermediate_entry = if term_req.sub_aggregation_blueprint.as_ref().is_some() { let mut sub_aggregation_res = IntermediateAggregationResults::default(); - self.term_buckets - .sub_aggs - .remove(&id) - .unwrap_or_else(|| { - panic!("Internal Error: could not find subaggregation for id {id}") - }) + // TODO remove box new + Box::new(bucket.sub_agg) .add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?; - IntermediateTermBucketEntry { - doc_count, + doc_count: bucket.count, sub_aggregation: sub_aggregation_res, } } else { IntermediateTermBucketEntry { - doc_count, + doc_count: bucket.count, sub_aggregation: Default::default(), } }; @@ -586,62 +910,32 @@ impl SegmentTermCollector { .as_ref() .map(|el| el.dictionary()) .unwrap_or_else(|| &fallback_dict); - let mut buffer = Vec::new(); - // special case for missing key - if let Some(index) = entries.iter().position(|value| value.0 == u64::MAX) { - let entry = entries[index]; - let intermediate_entry = into_intermediate_bucket_entry(entry.0, entry.1)?; - let missing_key = term_req - .req - .missing - .as_ref() - .expect("Found placeholder term_id but `missing` is None"); - match missing_key { - Key::Str(missing) => { - buffer.clear(); - buffer.extend_from_slice(missing.as_bytes()); - dict.insert( - IntermediateKey::Str( - String::from_utf8(buffer.to_vec()) - .expect("could not convert to String"), - ), - intermediate_entry, - ); - } - Key::F64(val) => { - dict.insert(IntermediateKey::F64(*val), intermediate_entry); - } - Key::U64(val) => { - dict.insert(IntermediateKey::U64(*val), intermediate_entry); - } - Key::I64(val) => { - dict.insert(IntermediateKey::I64(*val), intermediate_entry); - } - } - - entries.swap_remove(index); + if let Some((intermediate_key, bucket)) = extract_missing_value(&mut entries, term_req) + { + let intermediate_entry = into_intermediate_bucket_entry(bucket)?; + dict.insert(intermediate_key, intermediate_entry); } // Sort by term ord entries.sort_unstable_by_key(|bucket| bucket.0); - let mut idx = 0; - term_dict.sorted_ords_to_term_cb( - entries.iter().map(|(term_id, _)| *term_id), - |term| { - let entry = entries[idx]; - let intermediate_entry = into_intermediate_bucket_entry(entry.0, entry.1) - .map_err(io::Error::other)?; - dict.insert( - IntermediateKey::Str( - String::from_utf8(term.to_vec()).expect("could not convert to String"), - ), - intermediate_entry, - ); - idx += 1; - Ok(()) - }, - )?; + + let (term_ids, buckets): (Vec, Vec>) = + entries.into_iter().unzip(); + let mut buckets_it = buckets.into_iter(); + + term_dict.sorted_ords_to_term_cb(term_ids.into_iter(), |term| { + let bucket = buckets_it.next().unwrap(); + let intermediate_entry = + into_intermediate_bucket_entry(bucket).map_err(io::Error::other)?; + dict.insert( + IntermediateKey::Str( + String::from_utf8(term.to_vec()).expect("could not convert to String"), + ), + intermediate_entry, + ); + Ok(()) + })?; if term_req.req.min_doc_count == 0 { // TODO: Handle rev streaming for descending sorting by keys @@ -675,14 +969,14 @@ impl SegmentTermCollector { } } else if term_req.column_type == ColumnType::DateTime { for (val, doc_count) in entries { - let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?; + let intermediate_entry = into_intermediate_bucket_entry(doc_count)?; let val = i64::from_u64(val); let date = format_date(val)?; dict.insert(IntermediateKey::Str(date), intermediate_entry); } } else if term_req.column_type == ColumnType::Bool { for (val, doc_count) in entries { - let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?; + let intermediate_entry = into_intermediate_bucket_entry(doc_count)?; let val = bool::from_u64(val); dict.insert(IntermediateKey::Bool(val), intermediate_entry); } @@ -702,14 +996,14 @@ impl SegmentTermCollector { })?; for (val, doc_count) in entries { - let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?; + let intermediate_entry = into_intermediate_bucket_entry(doc_count)?; let val: u128 = compact_space_accessor.compact_to_u128(val as u32); let val = Ipv6Addr::from_u128(val); dict.insert(IntermediateKey::IpAddr(val), intermediate_entry); } } else { for (val, doc_count) in entries { - let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?; + let intermediate_entry = into_intermediate_bucket_entry(doc_count)?; if term_req.column_type == ColumnType::U64 { dict.insert(IntermediateKey::U64(val), intermediate_entry); } else if term_req.column_type == ColumnType::I64 { @@ -746,17 +1040,19 @@ impl SegmentTermCollector { pub(crate) trait GetDocCount { fn doc_count(&self) -> u64; } -impl GetDocCount for (u64, u32) { - fn doc_count(&self) -> u64 { - self.1 as u64 - } -} + impl GetDocCount for (String, IntermediateTermBucketEntry) { fn doc_count(&self) -> u64 { self.1.doc_count as u64 } } +impl GetDocCount for (u64, Bucket) { + fn doc_count(&self) -> u64 { + self.1.count as u64 + } +} + pub(crate) fn cut_off_buckets( entries: &mut Vec, num_elem: usize, @@ -1101,6 +1397,40 @@ mod tests { Ok(()) } + #[test] + fn test_simple_agg() { + let segment_and_terms = vec![vec![(5.0, "terma".to_string())]]; + let index = get_test_index_from_values_and_terms(true, &segment_and_terms).unwrap(); + + let sub_agg: Aggregations = serde_json::from_value(json!({ + "avg_score": { + "avg": { + "field": "score", + } + } + })) + .unwrap(); + + // sub agg desc + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_texts": { + "terms": { + "field": "string_id", + "order": { + "_count": "asc", + }, + }, + "aggs": sub_agg, + } + })) + .unwrap(); + + let res = exec_request(agg_req, &index).unwrap(); + assert_eq!(res["my_texts"]["buckets"][0]["key"], "terma"); + assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 1); + assert_eq!(res["my_texts"]["buckets"][0]["avg_score"]["value"], 5.0); + } + #[test] fn terms_aggregation_test_order_sub_agg_single_segment() -> crate::Result<()> { terms_aggregation_test_order_sub_agg_merge_segment(true) diff --git a/src/aggregation/buf_collector.rs b/src/aggregation/buf_collector.rs index e34c84760..17bc1ed35 100644 --- a/src/aggregation/buf_collector.rs +++ b/src/aggregation/buf_collector.rs @@ -3,7 +3,12 @@ use super::segment_agg_result::SegmentAggregationCollector; use crate::aggregation::agg_data::AggregationsSegmentCtx; use crate::DocId; +#[cfg(test)] pub(crate) const DOC_BLOCK_SIZE: usize = 64; + +#[cfg(not(test))] +pub(crate) const DOC_BLOCK_SIZE: usize = 256; + pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE]; /// BufAggregationCollector buffers documents before calling collect_block(). @@ -15,7 +20,7 @@ pub(crate) struct BufAggregationCollector { } impl std::fmt::Debug for BufAggregationCollector { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { f.debug_struct("SegmentAggregationResultsCollector") .field("staged_docs", &&self.staged_docs[..self.num_staged_docs]) .field("num_staged_docs", &self.num_staged_docs) @@ -66,7 +71,6 @@ impl SegmentAggregationCollector for BufAggregationCollector { agg_data: &mut AggregationsSegmentCtx, ) -> crate::Result<()> { self.collector.collect_block(docs, agg_data)?; - Ok(()) } diff --git a/src/indexer/mod.rs b/src/indexer/mod.rs index f5e047a7a..b7b1b2340 100644 --- a/src/indexer/mod.rs +++ b/src/indexer/mod.rs @@ -181,6 +181,7 @@ mod tests_mmap { let field_name_out = "."; test_json_field_name(field_name_in, field_name_out); } + #[test] fn test_json_field_dot() { // Test when field name contains a '.' diff --git a/src/query/term_query/term_query.rs b/src/query/term_query/term_query.rs index 648dd9ed6..dae6d7fe7 100644 --- a/src/query/term_query/term_query.rs +++ b/src/query/term_query/term_query.rs @@ -101,7 +101,7 @@ impl TermQuery { EnableScoring::Enabled { statistics_provider, .. - } => Bm25Weight::for_terms(statistics_provider, &[self.term.clone()])?, + } => Bm25Weight::for_terms(statistics_provider, std::slice::from_ref(&self.term))?, EnableScoring::Disabled { .. } => { Bm25Weight::new(Explanation::new("", 1.0f32), 1.0f32) }