Optimize term aggregation on low-cardinality fields + some refactoring (#2740)

This introduces an optimization for top-level term aggregations on fields with low cardinality.

In that case we use a Vec, indexed by term ordinal, as the underlying map.
In addition, we buffer the collection of sub-aggregations.
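
A minimal sketch of the idea (illustrative, not the actual API): when the column's maximum term ordinal is small, counting into a Vec indexed by term ordinal avoids hashing entirely.

fn count_terms_dense(term_ords: &[u64], max_term_ord: u64) -> Vec<u32> {
    // One slot per possible term ordinal, including unseen ones.
    let mut counts = vec![0u32; (max_term_ord + 1) as usize];
    for &ord in term_ords {
        counts[ord as usize] += 1; // direct index, no hash lookup
    }
    counts
}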

---------

Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
Co-authored-by: Paul Masurel <paul@quickwit.io>
Paul Masurel
2025-11-21 14:46:29 +01:00
committed by GitHub
parent 70e591e230
commit c363bbd23d
9 changed files with 535 additions and 175 deletions

View File

@@ -59,6 +59,8 @@ fn bench_agg(mut group: InputGroup<Index>) {
register!(group, terms_many_order_by_term);
register!(group, terms_many_with_top_hits);
register!(group, terms_many_with_avg_sub_agg);
register!(group, terms_few_with_avg_sub_agg);
register!(group, terms_many_json_mixed_type_with_avg_sub_agg);
register!(group, cardinality_agg);
@@ -220,6 +222,19 @@ fn terms_many_with_avg_sub_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn terms_few_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms" },
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
},
});
execute_agg(index, agg_req);
}
fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {

View File

@@ -16,15 +16,16 @@ use crate::index::SegmentReader;
/// That way we can use it the same way as if it came from the fast field.
pub(crate) fn get_missing_val_as_u64_lenient(
column_type: ColumnType,
column_max_value: u64,
missing: &Key,
field_name: &str,
) -> crate::Result<Option<u64>> {
let missing_val = match missing {
Key::Str(_) if column_type == ColumnType::Str => Some(u64::MAX),
Key::Str(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
// Allow fallback to number on text fields
Key::F64(_) if column_type == ColumnType::Str => Some(u64::MAX),
Key::U64(_) if column_type == ColumnType::Str => Some(u64::MAX),
Key::I64(_) if column_type == ColumnType::Str => Some(u64::MAX),
Key::F64(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
Key::U64(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
Key::I64(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
Key::F64(val) if column_type.numerical_type().is_some() => {
f64_to_fastfield_u64(*val, &column_type)
}
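
The sentinel moves from u64::MAX to column_max_value + 1 so that, with the Vec-backed buckets introduced in this change, the missing bucket is a valid index just past the real ordinal range. A tiny sketch (values are made up; column_max_value stands in for accessor.max_value()):

let column_max_value = 41u64;
let missing_sentinel = column_max_value + 1; // 42: first ordinal past the real terms
// A Vec sized (max_value + 2) now holds every bucket, including "missing".
// The old u64::MAX sentinel could never serve as a Vec index.
let buckets = vec![0u32; (column_max_value + 2) as usize];
assert!((missing_sentinel as usize) < buckets.len());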

View File

@@ -12,7 +12,7 @@ use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations
use crate::aggregation::bucket::{
FilterAggReqData, HistogramAggReqData, HistogramBounds, IncludeExcludeParam,
MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector, SegmentHistogramCollector,
SegmentRangeCollector, SegmentTermCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
SegmentRangeCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
TermsAggregationInternal,
};
use crate::aggregation::metric::{
@@ -373,9 +373,7 @@ pub(crate) fn build_segment_agg_collector(
node: &AggRefNode,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
match node.kind {
AggKind::Terms => Ok(Box::new(SegmentTermCollector::from_req_and_validate(
req, node,
)?)),
AggKind::Terms => crate::aggregation::bucket::build_segment_term_collector(req, node),
AggKind::MissingTerm => {
let req_data = &mut req.per_request.missing_term_req_data[node.idx_in_req_data];
if req_data.accessors.is_empty() {
@@ -498,7 +496,7 @@ pub(crate) fn build_aggregations_data_from_req(
};
for (name, agg) in aggs.iter() {
let nodes = build_nodes(name, agg, reader, segment_ordinal, &mut data)?;
let nodes = build_nodes(name, agg, reader, segment_ordinal, &mut data, true)?;
data.per_request.agg_tree.extend(nodes);
}
Ok(data)
@@ -510,6 +508,7 @@ fn build_nodes(
reader: &SegmentReader,
segment_ordinal: SegmentOrdinal,
data: &mut AggregationsSegmentCtx,
is_top_level: bool,
) -> crate::Result<Vec<AggRefNode>> {
use AggregationVariants::*;
match &req.agg {
@@ -596,6 +595,7 @@ fn build_nodes(
data,
&req.sub_aggregation,
TermsOrCardinalityRequest::Terms(terms_req.clone()),
is_top_level,
),
Cardinality(card_req) => build_terms_or_cardinality_nodes(
agg_name,
@@ -606,6 +606,7 @@ fn build_nodes(
data,
&req.sub_aggregation,
TermsOrCardinalityRequest::Cardinality(card_req.clone()),
is_top_level,
),
Average(AverageAggregation { field, missing, .. })
| Max(MaxAggregation { field, missing, .. })
@@ -734,7 +735,7 @@ fn build_nodes(
// Build the query and evaluator upfront
let schema = reader.schema();
let tokenizers = &data.context.tokenizers;
let query = filter_req.parse_query(&schema, tokenizers)?;
let query = filter_req.parse_query(schema, tokenizers)?;
let evaluator = crate::aggregation::bucket::DocumentQueryEvaluator::new(
query,
schema.clone(),
@@ -771,7 +772,14 @@ fn build_children(
) -> crate::Result<Vec<AggRefNode>> {
let mut children = Vec::new();
for (name, agg) in aggs.iter() {
children.extend(build_nodes(name, agg, reader, segment_ordinal, data)?);
children.extend(build_nodes(
name,
agg,
reader,
segment_ordinal,
data,
false,
)?);
}
Ok(children)
}
@@ -835,6 +843,7 @@ fn build_terms_or_cardinality_nodes(
data: &mut AggregationsSegmentCtx,
sub_aggs: &Aggregations,
req: TermsOrCardinalityRequest,
is_top_level: bool,
) -> crate::Result<Vec<AggRefNode>> {
let mut nodes = Vec::new();
@@ -891,7 +900,7 @@ fn build_terms_or_cardinality_nodes(
let missing_value_for_accessor = if use_special_missing_agg {
None
} else if let Some(m) = missing.as_ref() {
get_missing_val_as_u64_lenient(column_type, m, field_name)?
get_missing_val_as_u64_lenient(column_type, accessor.max_value(), m, field_name)?
} else {
None
};
@@ -924,6 +933,7 @@ fn build_terms_or_cardinality_nodes(
sub_aggregation_blueprint: None,
sug_aggregations: sub_aggs.clone(),
allowed_term_ids,
is_top_level,
});
(idx_in_req_data, AggKind::Terms)
}
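
Only aggregations built directly from the request root get is_top_level = true; children built through build_children always pass false. An illustrative request (field names are made up) showing which terms aggregation is eligible for the Vec path:

// `by_tag` is at the root of the aggregation tree, so it is eligible for the
// Vec-backed storage; `by_tag_nested` sits under a bucket aggregation and
// keeps the HashMap-backed path.
let agg_req = serde_json::json!({
    "by_tag": { "terms": { "field": "tag" } },
    "by_range": {
        "range": { "field": "score", "ranges": [{ "to": 10.0 }] },
        "aggs": { "by_tag_nested": { "terms": { "field": "tag" } } }
    }
});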

View File

@@ -35,6 +35,7 @@ pub struct AggregationLimitsGuard {
/// Allocated memory with this guard.
allocated_with_the_guard: u64,
}
impl Clone for AggregationLimitsGuard {
fn clone(&self) -> Self {
Self {

View File

@@ -639,16 +639,14 @@ pub struct IntermediateFilterBucketResult {
#[cfg(test)]
mod tests {
use std::time::Instant;
use serde_json::{json, Value};
use super::*;
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::agg_result::AggregationResults;
use crate::aggregation::{AggContextParams, AggregationCollector};
use crate::query::{AllQuery, QueryParser, TermQuery};
use crate::schema::{IndexRecordOption, Schema, Term, FAST, INDEXED, STORED, TEXT};
use crate::query::{AllQuery, TermQuery};
use crate::schema::{IndexRecordOption, Schema, Term, FAST, INDEXED, TEXT};
use crate::{doc, Index, IndexWriter};
// Test helper functions

View File

@@ -17,6 +17,7 @@ use crate::aggregation::agg_data::{
};
use crate::aggregation::agg_limits::MemoryConsumption;
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::buf_collector::BufAggregationCollector;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult,
@@ -49,6 +50,8 @@ pub struct TermsAggReqData {
pub req: TermsAggregationInternal,
/// Preloaded allowed term ords (string columns only). If set, only ords present are collected.
pub allowed_term_ids: Option<BitSet>,
/// True if this terms aggregation is at the top level of the aggregation tree (not nested).
pub is_top_level: bool,
}
impl TermsAggReqData {
@@ -331,34 +334,371 @@ impl TermsAggregationInternal {
}
}
#[derive(Clone, Debug, Default)]
/// Container to store term_ids/or u64 values and their buckets.
struct TermBuckets {
pub(crate) entries: FxHashMap<u64, u32>,
pub(crate) sub_aggs: FxHashMap<u64, Box<dyn SegmentAggregationCollector>>,
impl<'a> From<&'a dyn SegmentAggregationCollector> for BufAggregationCollector {
#[inline(always)]
fn from(sub_agg_blueprint_opt: &'a dyn SegmentAggregationCollector) -> Self {
let sub_agg = sub_agg_blueprint_opt.clone_box();
BufAggregationCollector::new(sub_agg)
}
}
impl TermBuckets {
fn get_memory_consumption(&self) -> usize {
let sub_aggs_mem = self.sub_aggs.memory_consumption();
let buckets_mem = self.entries.memory_consumption();
sub_aggs_mem + buckets_mem
#[derive(Debug, Clone)]
struct BoxedAggregation(Box<dyn SegmentAggregationCollector>);
impl<'a> From<&'a dyn SegmentAggregationCollector> for BoxedAggregation {
#[inline(always)]
fn from(sub_agg_blueprint: &'a dyn SegmentAggregationCollector) -> Self {
BoxedAggregation(sub_agg_blueprint.clone_box())
}
}
impl SegmentAggregationCollector for BoxedAggregation {
#[inline(always)]
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
self.0
.add_intermediate_aggregation_result(agg_data, results)
}
fn force_flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
for sub_aggregations in &mut self.sub_aggs.values_mut() {
sub_aggregations.as_mut().flush(agg_data)?;
#[inline(always)]
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.0.collect(doc, agg_data)
}
#[inline(always)]
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.0.collect_block(docs, agg_data)
}
}
#[derive(Debug, Clone, Copy)]
struct NoSubAgg;
impl SegmentAggregationCollector for NoSubAgg {
#[inline(always)]
fn add_intermediate_aggregation_result(
self: Box<Self>,
_agg_data: &AggregationsSegmentCtx,
_results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
Ok(())
}
#[inline(always)]
fn collect(
&mut self,
_doc: crate::DocId,
_agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
Ok(())
}
#[inline(always)]
fn collect_block(
&mut self,
_docs: &[crate::DocId],
_agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
Ok(())
}
}
/// Build a concrete `SegmentTermCollector` with either a Vec- or HashMap-backed
/// bucket storage, depending on the column's term cardinality and whether the
/// aggregation sits at the top level of the aggregation tree.
pub(crate) fn build_segment_term_collector(
req_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
let accessor_idx = node.idx_in_req_data;
let column_type = {
let terms_req_data = req_data.get_term_req_data(accessor_idx);
terms_req_data.column_type
};
if column_type == ColumnType::Bytes {
return Err(TantivyError::InvalidArgument(format!(
"terms aggregation is not supported for column type {column_type:?}"
)));
}
// Validate sub aggregation exists when ordering by sub-aggregation.
{
let terms_req_data = req_data.get_term_req_data(accessor_idx);
if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target {
let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name);
node.get_sub_agg(agg_name, &req_data.per_request)
.ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"could not find aggregation with name {agg_name} in metric \
sub_aggregations"
))
})?;
}
}
// Build sub-aggregation blueprint if there are children.
let has_sub_aggregations = !node.children.is_empty();
let blueprint = if has_sub_aggregations {
let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?;
Some(sub_aggregation)
} else {
None
};
{
let terms_req_data_mut = req_data.get_term_req_data_mut(accessor_idx);
terms_req_data_mut.sub_aggregation_blueprint = blueprint;
}
// Decide whether to use a Vec-backed or HashMap-backed bucket storage.
let terms_req_data = req_data.get_term_req_data(accessor_idx);
// TODO: a better criterion than `is_top_level` would be the expected number of buckets.
// E.g. if the term agg is not top level but its parent is a bucket agg with
// fewer than 10 buckets, we could still use a Vec.
let can_use_vec = terms_req_data.is_top_level;
// TODO: Benchmark to validate the threshold
const MAX_NUM_TERMS_FOR_VEC: usize = 100;
// Let's see if we can use a vec to aggregate our data
// instead of a hashmap.
let col_max_value = terms_req_data.accessor.max_value();
let max_term: usize =
col_max_value.max(terms_req_data.missing_value_for_accessor.unwrap_or(0u64)) as usize;
// If the term space is small enough, we:
// - use a Vec instead of a hashmap for our aggregation.
// - buffer the collection of our child aggregations (if any).
#[allow(clippy::collapsible_else_if)]
if can_use_vec && max_term < MAX_NUM_TERMS_FOR_VEC {
if has_sub_aggregations {
let sub_agg_blueprint = &req_data
.get_term_req_data_mut(accessor_idx)
.sub_aggregation_blueprint
.as_ref()
.ok_or_else(|| {
// The blueprint was stored just above, so this is unreachable in practice.
TantivyError::InternalError("Sub-aggregation blueprint not found".to_string())
})?;
let term_buckets = VecTermBuckets::new(max_term + 1, || {
let collector_clone = sub_agg_blueprint.clone_box();
BufAggregationCollector::new(collector_clone)
});
let collector = SegmentTermCollector {
term_buckets,
accessor_idx,
};
Ok(Box::new(collector))
} else {
let term_buckets = VecTermBuckets::new(max_term + 1, || NoSubAgg);
let collector = SegmentTermCollector {
term_buckets,
accessor_idx,
};
Ok(Box::new(collector))
}
} else {
if has_sub_aggregations {
let term_buckets: HashMapTermBuckets<BoxedAggregation> = HashMapTermBuckets::default();
let collector: SegmentTermCollector<HashMapTermBuckets<BoxedAggregation>> =
SegmentTermCollector {
term_buckets,
accessor_idx,
};
Ok(Box::new(collector))
} else {
let term_buckets: HashMapTermBuckets<NoSubAgg> = HashMapTermBuckets::default();
let collector: SegmentTermCollector<HashMapTermBuckets<NoSubAgg>> =
SegmentTermCollector {
term_buckets,
accessor_idx,
};
Ok(Box::new(collector))
}
}
}
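// The builder above produces one of four concrete collector types (Vec- or
// HashMap-backed storage, with or without sub-aggregations), so the hot
// collect loop is monomorphized per combination and type erasure happens only
// once, at the `Box<dyn SegmentAggregationCollector>` boundary. A reduced,
// self-contained sketch of that shape, with stand-in types:
trait Storage: 'static {
    fn bump(&mut self, term: u64);
}
struct DenseVec(Vec<u32>); // low cardinality, top level
impl Storage for DenseVec {
    fn bump(&mut self, term: u64) {
        self.0[term as usize] += 1;
    }
}
struct SparseMap(std::collections::HashMap<u64, u32>); // everything else
impl Storage for SparseMap {
    fn bump(&mut self, term: u64) {
        *self.0.entry(term).or_default() += 1;
    }
}
struct Collector<S: Storage>(S);
impl<S: Storage> Collector<S> {
    fn collect_block(&mut self, terms: &[u64]) {
        // Statically dispatched per concrete S: no branch inside the loop.
        for &t in terms {
            self.0.bump(t);
        }
    }
}
// The concrete type is chosen once, at build time, as in
// build_segment_term_collector above.
fn build(dense: bool, max_term: usize) -> Box<dyn FnMut(&[u64])> {
    if dense {
        let mut c = Collector(DenseVec(vec![0; max_term + 1]));
        Box::new(move |terms: &[u64]| c.collect_block(terms))
    } else {
        let mut c = Collector(SparseMap(std::collections::HashMap::new()));
        Box::new(move |terms: &[u64]| c.collect_block(terms))
    }
}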
#[derive(Debug, Clone)]
struct Bucket<SubAgg> {
pub count: u32,
pub sub_agg: SubAgg,
}
impl<SubAgg> Bucket<SubAgg> {
#[inline(always)]
fn new(sub_agg: SubAgg) -> Self {
Self { count: 0, sub_agg }
}
}
/// Abstraction over the storage used for term buckets (counts only).
trait TermAggregationMap: Clone + Debug + 'static {
type SubAggregation: SegmentAggregationCollector + Debug + Clone + 'static;
/// Estimate the memory consumption of this struct in bytes.
fn get_memory_consumption(&self) -> usize;
/// Returns the bucket associated with a given term_id.
fn term_entry(
&mut self,
term_id: u64,
blue_print: &dyn SegmentAggregationCollector,
) -> &mut Bucket<Self::SubAggregation>;
/// If the tree of aggregations contains buffered aggregations, flush them.
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()>;
/// Returns the term aggregation as a vector of (term_id, bucket) pairs,
/// in any order.
fn into_vec(self) -> Vec<(u64, Bucket<Self::SubAggregation>)>;
}
#[derive(Clone, Debug)]
struct HashMapTermBuckets<SubAgg> {
bucket_map: FxHashMap<u64, Bucket<SubAgg>>,
}
impl<SubAgg> Default for HashMapTermBuckets<SubAgg> {
#[inline(always)]
fn default() -> Self {
Self {
bucket_map: FxHashMap::default(),
}
}
}
impl<
SubAgg: Debug
+ Clone
+ SegmentAggregationCollector
+ for<'a> From<&'a dyn SegmentAggregationCollector>
+ 'static,
> TermAggregationMap for HashMapTermBuckets<SubAgg>
{
type SubAggregation = SubAgg;
#[inline]
fn get_memory_consumption(&self) -> usize {
self.bucket_map.memory_consumption()
}
#[inline(always)]
fn term_entry(
&mut self,
term_id: u64,
sub_agg_blueprint: &dyn SegmentAggregationCollector,
) -> &mut Bucket<SubAgg> {
self.bucket_map
.entry(term_id)
.or_insert_with(|| Bucket::new(SubAgg::from(sub_agg_blueprint)))
}
#[inline(always)]
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
for bucket in self.bucket_map.values_mut() {
bucket.sub_agg.flush(agg_data)?;
}
Ok(())
}
fn into_vec(self) -> Vec<(u64, Bucket<SubAgg>)> {
self.bucket_map.into_iter().collect()
}
}
/// An optimized term map implementation for a compact set of term ordinals.
#[derive(Clone, Debug)]
struct VecTermBuckets<SubAgg> {
buckets: Vec<Bucket<SubAgg>>,
}
impl<SubAgg> VecTermBuckets<SubAgg> {
fn new(num_terms: usize, item_factory_fn: impl Fn() -> SubAgg) -> Self {
VecTermBuckets {
buckets: std::iter::repeat_with(item_factory_fn)
.map(Bucket::new)
.take(num_terms)
.collect(),
}
}
}
impl<SubAgg: Debug + Clone + SegmentAggregationCollector + 'static> TermAggregationMap
for VecTermBuckets<SubAgg>
{
type SubAggregation = SubAgg;
/// Estimate the memory consumption of this struct in bytes.
fn get_memory_consumption(&self) -> usize {
// We do not include `std::mem::size_of::<Self>()`:
// it is already measured by the parent aggregation.
//
// The root aggregation's own size is not measured, but that is negligible.
self.buckets.capacity() * std::mem::size_of::<Bucket<SubAgg>>()
}
/// Returns the bucket for the given term id.
#[inline(always)]
fn term_entry(
&mut self,
term_id: u64,
_sub_agg_blueprint: &dyn SegmentAggregationCollector,
) -> &mut Bucket<SubAgg> {
let term_id_usize = term_id as usize;
debug_assert!(
term_id_usize < self.buckets.len(),
"term_id {} out of bounds for VecTermBuckets (len={})",
term_id,
self.buckets.len()
);
unsafe { self.buckets.get_unchecked_mut(term_id_usize) }
}
#[inline(always)]
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
for bucket in &mut self.buckets {
if bucket.count > 0 {
bucket.sub_agg.flush(agg_data)?;
}
}
Ok(())
}
fn into_vec(self) -> Vec<(u64, Bucket<SubAgg>)> {
self.buckets
.into_iter()
.enumerate()
.filter(|(_, bucket)| bucket.count > 0)
.map(|(term_id, bucket)| (term_id as u64, bucket))
.collect()
}
}
impl<'a> From<&'a dyn SegmentAggregationCollector> for NoSubAgg {
#[inline(always)]
fn from(_: &'a dyn SegmentAggregationCollector) -> Self {
Self
}
}
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
#[derive(Clone, Debug)]
pub struct SegmentTermCollector {
struct SegmentTermCollector<TermMap> {
/// The buckets containing the aggregation data.
term_buckets: TermBuckets,
term_buckets: TermMap,
accessor_idx: usize,
}
@@ -367,17 +707,19 @@ pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) {
(agg_name, agg_property)
}
impl SegmentAggregationCollector for SegmentTermCollector {
impl<TermMap> SegmentAggregationCollector for SegmentTermCollector<TermMap>
where
TermMap: TermAggregationMap,
TermMap::SubAggregation: for<'a> From<&'a dyn SegmentAggregationCollector>,
{
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
let name = agg_data.get_term_req_data(self.accessor_idx).name.clone();
let bucket = self.into_intermediate_bucket_result(agg_data)?;
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
Ok(())
}
@@ -412,17 +754,23 @@ impl SegmentAggregationCollector for SegmentTermCollector {
.fetch_block(docs, &req_data.accessor);
}
for term_id in req_data.column_block_accessor.iter_vals() {
if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() {
if !allowed_bs.contains(term_id as u32) {
continue;
if std::any::TypeId::of::<NoSubAgg>() == std::any::TypeId::of::<TermMap::SubAggregation>() {
for term_id in req_data.column_block_accessor.iter_vals() {
if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() {
if !allowed_bs.contains(term_id as u32) {
continue;
}
}
let bucket = self.term_buckets.term_entry(term_id, &NoSubAgg);
bucket.count += 1;
}
let entry = self.term_buckets.entries.entry(term_id).or_default();
*entry += 1;
}
// has subagg
if let Some(blueprint) = req_data.sub_aggregation_blueprint.as_ref() {
} else {
let Some(sub_aggregation_blueprint) = req_data.sub_aggregation_blueprint.as_deref()
else {
return Err(TantivyError::InternalError(
"Could not find sub-aggregation blueprint".to_string(),
));
};
for (doc, term_id) in req_data
.column_block_accessor
.iter_docid_vals(docs, &req_data.accessor)
@@ -432,12 +780,11 @@ impl SegmentAggregationCollector for SegmentTermCollector {
continue;
}
}
let sub_aggregations = self
let bucket = self
.term_buckets
.sub_aggs
.entry(term_id)
.or_insert_with(|| blueprint.clone());
sub_aggregations.collect(doc, agg_data)?;
.term_entry(term_id, sub_aggregation_blueprint);
bucket.count += 1;
bucket.sub_agg.collect(doc, agg_data)?;
}
}
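// For the `NoSubAgg` instantiation, both sides of the `TypeId` comparison in
// `collect_block` above are compile-time constants, so the optimizer drops the
// sub-aggregation branch from that monomorphization entirely. A minimal,
// self-contained illustration of the pattern (stand-in names):
use std::any::TypeId;
// Stand-in for NoSubAgg: a zero-sized marker type.
struct NoWork;
fn process<T: 'static>(vals: &[u32]) -> u32 {
    // Once the generic is instantiated, only one branch survives per T.
    if TypeId::of::<T>() == TypeId::of::<NoWork>() {
        vals.iter().copied().sum() // cheap path: no per-value extra work
    } else {
        vals.iter().map(|v| v + 1).sum() // path with per-value work
    }
}
fn main() {
    assert_eq!(process::<NoWork>(&[1, 2, 3]), 6);
    assert_eq!(process::<u64>(&[1, 2, 3]), 9);
}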
@@ -453,69 +800,51 @@ impl SegmentAggregationCollector for SegmentTermCollector {
Ok(())
}
#[inline(always)]
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
self.term_buckets.force_flush(agg_data)?;
self.term_buckets.flush(agg_data)?;
Ok(())
}
}
impl SegmentTermCollector {
/// Missing values are represented as a sentinel value in the column.
///
/// This function extracts the missing-value entry from the entries vector,
/// computes the intermediate key, and returns the key and the bucket
/// in an Option.
fn extract_missing_value<T>(
entries: &mut Vec<(u64, T)>,
term_req: &TermsAggReqData,
) -> Option<(IntermediateKey, T)> {
let missing_sentinel = term_req.missing_value_for_accessor?;
let missing_value_entry_pos = entries
.iter()
.position(|(term_id, _)| *term_id == missing_sentinel)?;
let (_term_id, bucket) = entries.swap_remove(missing_value_entry_pos);
let missing_key = term_req.req.missing.as_ref()?;
let key = match missing_key {
Key::Str(missing) => IntermediateKey::Str(missing.clone()),
Key::F64(val) => IntermediateKey::F64(*val),
Key::U64(val) => IntermediateKey::U64(*val),
Key::I64(val) => IntermediateKey::I64(*val),
};
Some((key, bucket))
}
impl<TermMap> SegmentTermCollector<TermMap>
where TermMap: TermAggregationMap
{
fn get_memory_consumption(&self) -> usize {
let self_mem = std::mem::size_of::<Self>();
let term_buckets_mem = self.term_buckets.get_memory_consumption();
self_mem + term_buckets_mem
}
pub(crate) fn from_req_and_validate(
req_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let terms_req_data = req_data.get_term_req_data(node.idx_in_req_data);
let column_type = terms_req_data.column_type;
let accessor_idx = node.idx_in_req_data;
if column_type == ColumnType::Bytes {
return Err(TantivyError::InvalidArgument(format!(
"terms aggregation is not supported for column type {column_type:?}"
)));
}
let term_buckets = TermBuckets::default();
// Validate sub aggregation exists
if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target {
let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name);
node.get_sub_agg(agg_name, &req_data.per_request)
.ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"could not find aggregation with name {agg_name} in metric \
sub_aggregations"
))
})?;
}
let has_sub_aggregations = !node.children.is_empty();
let blueprint = if has_sub_aggregations {
let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?;
Some(sub_aggregation)
} else {
None
};
let terms_req_data = req_data.get_term_req_data_mut(node.idx_in_req_data);
terms_req_data.sub_aggregation_blueprint = blueprint;
Ok(SegmentTermCollector {
term_buckets,
accessor_idx,
})
self.term_buckets.get_memory_consumption()
}
#[inline]
pub(crate) fn into_intermediate_bucket_result(
mut self,
self,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateBucketResult> {
let term_req = agg_data.get_term_req_data(self.accessor_idx);
let mut entries: Vec<(u64, u32)> = self.term_buckets.entries.into_iter().collect();
let mut entries: Vec<(u64, Bucket<TermMap::SubAggregation>)> = self.term_buckets.into_vec();
let order_by_sub_aggregation =
matches!(term_req.req.order.target, OrderTarget::SubAggregation(_));
@@ -538,9 +867,9 @@ impl SegmentTermCollector {
}
OrderTarget::Count => {
if term_req.req.order.order == Order::Desc {
entries.sort_unstable_by_key(|bucket| std::cmp::Reverse(bucket.1));
entries.sort_unstable_by_key(|bucket| std::cmp::Reverse(bucket.1.count));
} else {
entries.sort_unstable_by_key(|bucket| bucket.1);
entries.sort_unstable_by_key(|bucket| bucket.1.count);
}
}
}
@@ -554,25 +883,20 @@ impl SegmentTermCollector {
let mut dict: FxHashMap<IntermediateKey, IntermediateTermBucketEntry> = Default::default();
dict.reserve(entries.len());
let mut into_intermediate_bucket_entry =
|id, doc_count| -> crate::Result<IntermediateTermBucketEntry> {
let into_intermediate_bucket_entry =
|bucket: Bucket<TermMap::SubAggregation>| -> crate::Result<IntermediateTermBucketEntry> {
let intermediate_entry = if term_req.sub_aggregation_blueprint.as_ref().is_some() {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
self.term_buckets
.sub_aggs
.remove(&id)
.unwrap_or_else(|| {
panic!("Internal Error: could not find subaggregation for id {id}")
})
// TODO: remove this Box::new
Box::new(bucket.sub_agg)
.add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?;
IntermediateTermBucketEntry {
doc_count,
doc_count: bucket.count,
sub_aggregation: sub_aggregation_res,
}
} else {
IntermediateTermBucketEntry {
doc_count,
doc_count: bucket.count,
sub_aggregation: Default::default(),
}
};
@@ -586,62 +910,32 @@ impl SegmentTermCollector {
.as_ref()
.map(|el| el.dictionary())
.unwrap_or_else(|| &fallback_dict);
let mut buffer = Vec::new();
// special case for missing key
if let Some(index) = entries.iter().position(|value| value.0 == u64::MAX) {
let entry = entries[index];
let intermediate_entry = into_intermediate_bucket_entry(entry.0, entry.1)?;
let missing_key = term_req
.req
.missing
.as_ref()
.expect("Found placeholder term_id but `missing` is None");
match missing_key {
Key::Str(missing) => {
buffer.clear();
buffer.extend_from_slice(missing.as_bytes());
dict.insert(
IntermediateKey::Str(
String::from_utf8(buffer.to_vec())
.expect("could not convert to String"),
),
intermediate_entry,
);
}
Key::F64(val) => {
dict.insert(IntermediateKey::F64(*val), intermediate_entry);
}
Key::U64(val) => {
dict.insert(IntermediateKey::U64(*val), intermediate_entry);
}
Key::I64(val) => {
dict.insert(IntermediateKey::I64(*val), intermediate_entry);
}
}
entries.swap_remove(index);
if let Some((intermediate_key, bucket)) = extract_missing_value(&mut entries, term_req)
{
let intermediate_entry = into_intermediate_bucket_entry(bucket)?;
dict.insert(intermediate_key, intermediate_entry);
}
// Sort by term ord
entries.sort_unstable_by_key(|bucket| bucket.0);
let mut idx = 0;
term_dict.sorted_ords_to_term_cb(
entries.iter().map(|(term_id, _)| *term_id),
|term| {
let entry = entries[idx];
let intermediate_entry = into_intermediate_bucket_entry(entry.0, entry.1)
.map_err(io::Error::other)?;
dict.insert(
IntermediateKey::Str(
String::from_utf8(term.to_vec()).expect("could not convert to String"),
),
intermediate_entry,
);
idx += 1;
Ok(())
},
)?;
let (term_ids, buckets): (Vec<u64>, Vec<Bucket<TermMap::SubAggregation>>) =
entries.into_iter().unzip();
let mut buckets_it = buckets.into_iter();
term_dict.sorted_ords_to_term_cb(term_ids.into_iter(), |term| {
let bucket = buckets_it.next().unwrap();
let intermediate_entry =
into_intermediate_bucket_entry(bucket).map_err(io::Error::other)?;
dict.insert(
IntermediateKey::Str(
String::from_utf8(term.to_vec()).expect("could not convert to String"),
),
intermediate_entry,
);
Ok(())
})?;
if term_req.req.min_doc_count == 0 {
// TODO: Handle rev streaming for descending sorting by keys
@@ -675,14 +969,14 @@ impl SegmentTermCollector {
}
} else if term_req.column_type == ColumnType::DateTime {
for (val, doc_count) in entries {
let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?;
let intermediate_entry = into_intermediate_bucket_entry(doc_count)?;
let val = i64::from_u64(val);
let date = format_date(val)?;
dict.insert(IntermediateKey::Str(date), intermediate_entry);
}
} else if term_req.column_type == ColumnType::Bool {
for (val, doc_count) in entries {
let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?;
let intermediate_entry = into_intermediate_bucket_entry(doc_count)?;
let val = bool::from_u64(val);
dict.insert(IntermediateKey::Bool(val), intermediate_entry);
}
@@ -702,14 +996,14 @@ impl SegmentTermCollector {
})?;
for (val, doc_count) in entries {
let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?;
let intermediate_entry = into_intermediate_bucket_entry(doc_count)?;
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
let val = Ipv6Addr::from_u128(val);
dict.insert(IntermediateKey::IpAddr(val), intermediate_entry);
}
} else {
for (val, doc_count) in entries {
let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?;
let intermediate_entry = into_intermediate_bucket_entry(doc_count)?;
if term_req.column_type == ColumnType::U64 {
dict.insert(IntermediateKey::U64(val), intermediate_entry);
} else if term_req.column_type == ColumnType::I64 {
@@ -746,17 +1040,19 @@ impl SegmentTermCollector {
pub(crate) trait GetDocCount {
fn doc_count(&self) -> u64;
}
impl GetDocCount for (u64, u32) {
fn doc_count(&self) -> u64 {
self.1 as u64
}
}
impl GetDocCount for (String, IntermediateTermBucketEntry) {
fn doc_count(&self) -> u64 {
self.1.doc_count as u64
}
}
impl<SubAgg> GetDocCount for (u64, Bucket<SubAgg>) {
fn doc_count(&self) -> u64 {
self.1.count as u64
}
}
pub(crate) fn cut_off_buckets<T: GetDocCount + Debug>(
entries: &mut Vec<T>,
num_elem: usize,
@@ -1101,6 +1397,40 @@ mod tests {
Ok(())
}
#[test]
fn test_simple_agg() {
let segment_and_terms = vec![vec![(5.0, "terma".to_string())]];
let index = get_test_index_from_values_and_terms(true, &segment_and_terms).unwrap();
let sub_agg: Aggregations = serde_json::from_value(json!({
"avg_score": {
"avg": {
"field": "score",
}
}
}))
.unwrap();
// order by count asc
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": {
"_count": "asc",
},
},
"aggs": sub_agg,
}
}))
.unwrap();
let res = exec_request(agg_req, &index).unwrap();
assert_eq!(res["my_texts"]["buckets"][0]["key"], "terma");
assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 1);
assert_eq!(res["my_texts"]["buckets"][0]["avg_score"]["value"], 5.0);
}
#[test]
fn terms_aggregation_test_order_sub_agg_single_segment() -> crate::Result<()> {
terms_aggregation_test_order_sub_agg_merge_segment(true)

View File

@@ -3,7 +3,12 @@ use super::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::DocId;
#[cfg(test)]
pub(crate) const DOC_BLOCK_SIZE: usize = 64;
#[cfg(not(test))]
pub(crate) const DOC_BLOCK_SIZE: usize = 256;
pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE];
/// BufAggregationCollector buffers documents before calling collect_block().
@@ -15,7 +20,7 @@ pub(crate) struct BufAggregationCollector {
}
impl std::fmt::Debug for BufAggregationCollector {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("SegmentAggregationResultsCollector")
.field("staged_docs", &&self.staged_docs[..self.num_staged_docs])
.field("num_staged_docs", &self.num_staged_docs)
@@ -66,7 +71,6 @@ impl SegmentAggregationCollector for BufAggregationCollector {
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collector.collect_block(docs, agg_data)?;
Ok(())
}
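
A minimal sketch of the buffering pattern BufAggregationCollector implements: stage doc ids in a fixed-size array and forward them to the wrapped collector in blocks (256 docs in release builds, 64 under #[cfg(test)] so block boundaries are exercised). Names are stand-ins, not the actual API:

const BLOCK: usize = 256;

// `sink` stands in for the wrapped SegmentAggregationCollector.
struct Buffered<F: FnMut(&[u32])> {
    staged: [u32; BLOCK],
    len: usize,
    sink: F,
}

impl<F: FnMut(&[u32])> Buffered<F> {
    fn collect(&mut self, doc: u32) {
        self.staged[self.len] = doc;
        self.len += 1;
        if self.len == BLOCK {
            // Amortizes per-call overhead across a whole block.
            (self.sink)(&self.staged);
            self.len = 0;
        }
    }

    fn flush(&mut self) {
        (self.sink)(&self.staged[..self.len]);
        self.len = 0;
    }
}

fn main() {
    let mut seen: Vec<u32> = Vec::new();
    let mut buf = Buffered {
        staged: [0u32; BLOCK],
        len: 0,
        sink: |docs: &[u32]| seen.extend_from_slice(docs),
    };
    for doc in 0..600 {
        buf.collect(doc); // triggers two full-block flushes along the way
    }
    buf.flush(); // forward the remaining staged docs
    drop(buf);
    assert_eq!(seen.len(), 600);
}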

View File

@@ -181,6 +181,7 @@ mod tests_mmap {
let field_name_out = ".";
test_json_field_name(field_name_in, field_name_out);
}
#[test]
fn test_json_field_dot() {
// Test when field name contains a '.'

View File

@@ -101,7 +101,7 @@ impl TermQuery {
EnableScoring::Enabled {
statistics_provider,
..
} => Bm25Weight::for_terms(statistics_provider, &[self.term.clone()])?,
} => Bm25Weight::for_terms(statistics_provider, std::slice::from_ref(&self.term))?,
EnableScoring::Disabled { .. } => {
Bm25Weight::new(Explanation::new("<no score>", 1.0f32), 1.0f32)
}
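
The last hunk swaps `&[self.term.clone()]` for `std::slice::from_ref(&self.term)`: `from_ref` produces a one-element `&[T]` view without cloning or allocating. A tiny sketch (stand-in value):

let term = String::from("example"); // stand-in for self.term
let terms: &[String] = std::slice::from_ref(&term);
assert_eq!(terms.len(), 1); // borrowed view: no clone, no allocation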