low cardinality optimisation

This commit is contained in:
Paul Masurel
2025-11-19 18:41:10 +01:00
parent 70e591e230
commit b2573a3b16
6 changed files with 697 additions and 420 deletions

View File

@@ -10,10 +10,10 @@ use crate::aggregation::accessor_helpers::{
};
use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations};
use crate::aggregation::bucket::{
FilterAggReqData, HistogramAggReqData, HistogramBounds, IncludeExcludeParam,
MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector, SegmentHistogramCollector,
SegmentRangeCollector, SegmentTermCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
TermsAggregationInternal,
build_segment_aggregation_collector, FilterAggReqData, HistogramAggReqData, HistogramBounds,
IncludeExcludeParam, MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector,
SegmentHistogramCollector, SegmentRangeCollector, TermMissingAgg, TermsAggReqData,
TermsAggregation, TermsAggregationInternal,
};
use crate::aggregation::metric::{
AverageAggregation, CardinalityAggReqData, CardinalityAggregationReq, CountAggregation,
@@ -373,9 +373,7 @@ pub(crate) fn build_segment_agg_collector(
node: &AggRefNode,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
match node.kind {
AggKind::Terms => Ok(Box::new(SegmentTermCollector::from_req_and_validate(
req, node,
)?)),
AggKind::Terms => build_segment_aggregation_collector(req, node),
AggKind::MissingTerm => {
let req_data = &mut req.per_request.missing_term_req_data[node.idx_in_req_data];
if req_data.accessors.is_empty() {

View File

@@ -0,0 +1,196 @@
use std::fmt::Debug;
use columnar::ColumnType;
use rustc_hash::FxHashMap;
use super::OrderTarget;
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::agg_limits::MemoryConsumption;
use crate::aggregation::bucket::get_agg_name_and_property;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::TantivyError;
#[derive(Clone, Debug, Default)]
/// Container to store term_ids/or u64 values and their buckets.
struct TermBuckets {
pub(crate) entries: FxHashMap<u64, u32>,
pub(crate) sub_aggs: FxHashMap<u64, Box<dyn SegmentAggregationCollector>>,
}
impl TermBuckets {
fn get_memory_consumption(&self) -> usize {
let sub_aggs_mem = self.sub_aggs.memory_consumption();
let buckets_mem = self.entries.memory_consumption();
sub_aggs_mem + buckets_mem
}
fn force_flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
for sub_aggregations in &mut self.sub_aggs.values_mut() {
sub_aggregations.as_mut().flush(agg_data)?;
}
Ok(())
}
}
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
#[derive(Clone, Debug)]
pub struct SegmentTermCollector {
/// The buckets containing the aggregation data.
term_buckets: TermBuckets,
accessor_idx: usize,
}
impl SegmentAggregationCollector for SegmentTermCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
let name = agg_data.get_term_req_data(self.accessor_idx).name.clone();
let entries: Vec<(u64, u32)> = self.term_buckets.entries.into_iter().collect();
let bucket = super::into_intermediate_bucket_result(
self.accessor_idx,
entries,
self.term_buckets.sub_aggs,
agg_data,
)?;
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
Ok(())
}
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let mut req_data = agg_data.take_term_req_data(self.accessor_idx);
let mem_pre = self.get_memory_consumption();
if let Some(missing) = req_data.missing_value_for_accessor {
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&req_data.accessor,
missing,
);
} else {
req_data
.column_block_accessor
.fetch_block(docs, &req_data.accessor);
}
for term_id in req_data.column_block_accessor.iter_vals() {
if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() {
if !allowed_bs.contains(term_id as u32) {
continue;
}
}
let entry = self.term_buckets.entries.entry(term_id).or_default();
*entry += 1;
}
// has subagg
if let Some(blueprint) = req_data.sub_aggregation_blueprint.as_ref() {
for (doc, term_id) in req_data
.column_block_accessor
.iter_docid_vals(docs, &req_data.accessor)
{
if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() {
if !allowed_bs.contains(term_id as u32) {
continue;
}
}
let sub_aggregations = self
.term_buckets
.sub_aggs
.entry(term_id)
.or_insert_with(|| blueprint.clone());
sub_aggregations.collect(doc, agg_data)?;
}
}
let mem_delta = self.get_memory_consumption() - mem_pre;
if mem_delta > 0 {
agg_data
.context
.limits
.add_memory_consumed(mem_delta as u64)?;
}
agg_data.put_back_term_req_data(self.accessor_idx, req_data);
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
self.term_buckets.force_flush(agg_data)?;
Ok(())
}
}
impl SegmentTermCollector {
pub fn from_req_and_validate(
req_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let terms_req_data = req_data.get_term_req_data(node.idx_in_req_data);
let column_type = terms_req_data.column_type;
let accessor_idx = node.idx_in_req_data;
if column_type == ColumnType::Bytes {
return Err(TantivyError::InvalidArgument(format!(
"terms aggregation is not supported for column type {column_type:?}"
)));
}
let term_buckets = TermBuckets::default();
// Validate sub aggregation exists
if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target {
let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name);
node.get_sub_agg(agg_name, &req_data.per_request)
.ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"could not find aggregation with name {agg_name} in metric \
sub_aggregations"
))
})?;
}
let has_sub_aggregations = !node.children.is_empty();
let blueprint = if has_sub_aggregations {
let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?;
Some(sub_aggregation)
} else {
None
};
let terms_req_data = req_data.get_term_req_data_mut(node.idx_in_req_data);
terms_req_data.sub_aggregation_blueprint = blueprint;
Ok(SegmentTermCollector {
term_buckets,
accessor_idx,
})
}
fn get_memory_consumption(&self) -> usize {
let self_mem = std::mem::size_of::<Self>();
let term_buckets_mem = self.term_buckets.get_memory_consumption();
self_mem + term_buckets_mem
}
}

View File

@@ -0,0 +1,228 @@
use std::vec;
use rustc_hash::FxHashMap;
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::bucket::{get_agg_name_and_property, OrderTarget};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::{DocId, TantivyError};
const MAX_BATCH_SIZE: usize = 1_024;
#[derive(Debug, Clone)]
struct LowCardTermBuckets {
entries: Box<[u32]>,
sub_aggs: Vec<Box<dyn SegmentAggregationCollector>>,
doc_buffers: Box<[Vec<DocId>]>,
}
impl LowCardTermBuckets {
pub fn with_num_buckets(
num_buckets: usize,
sub_aggs_blueprint_opt: Option<&Box<dyn SegmentAggregationCollector>>,
) -> Self {
let sub_aggs = sub_aggs_blueprint_opt
.as_ref()
.map(|blueprint| {
std::iter::repeat_with(|| blueprint.clone_box())
.take(num_buckets)
.collect::<Vec<_>>()
})
.unwrap_or_default();
Self {
entries: vec![0; num_buckets].into_boxed_slice(),
sub_aggs,
doc_buffers: std::iter::repeat_with(|| Vec::with_capacity(MAX_BATCH_SIZE))
.take(num_buckets)
.collect::<Vec<_>>()
.into_boxed_slice(),
}
}
fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
+ self.entries.len() * std::mem::size_of::<u32>()
+ self.doc_buffers.len()
* (std::mem::size_of::<Vec<DocId>>()
+ std::mem::size_of::<DocId>() * MAX_BATCH_SIZE)
}
}
#[derive(Debug, Clone)]
pub struct LowCardSegmentTermCollector {
term_buckets: LowCardTermBuckets,
accessor_idx: usize,
}
impl LowCardSegmentTermCollector {
pub fn from_req_and_validate(
req_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let terms_req_data = req_data.get_term_req_data(node.idx_in_req_data);
let accessor_idx = node.idx_in_req_data;
let cardinality = terms_req_data
.accessor
.max_value()
.max(terms_req_data.missing_value_for_accessor.unwrap_or(0))
+ 1;
assert!(cardinality <= super::LOW_CARDINALITY_THRESHOLD);
// Validate sub aggregation exists
if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target {
let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name);
node.get_sub_agg(agg_name, &req_data.per_request)
.ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"could not find aggregation with name {agg_name} in metric \
sub_aggregations"
))
})?;
}
let has_sub_aggregations = !node.children.is_empty();
let blueprint = if has_sub_aggregations {
let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?;
Some(sub_aggregation)
} else {
None
};
let terms_req_data = req_data.get_term_req_data_mut(node.idx_in_req_data);
let term_buckets =
LowCardTermBuckets::with_num_buckets(cardinality as usize, blueprint.as_ref());
terms_req_data.sub_aggregation_blueprint = blueprint;
Ok(LowCardSegmentTermCollector {
term_buckets,
accessor_idx,
})
}
fn get_memory_consumption(&self) -> usize {
let self_mem = std::mem::size_of::<Self>();
let term_buckets_mem = self.term_buckets.get_memory_consumption();
self_mem + term_buckets_mem
}
}
impl SegmentAggregationCollector for LowCardSegmentTermCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
let name = agg_data.get_term_req_data(self.accessor_idx).name.clone();
let sub_aggs: FxHashMap<u64, Box<dyn SegmentAggregationCollector>> = self
.term_buckets
.sub_aggs
.into_iter()
.enumerate()
.filter(|(bucket_id, _sub_agg)| self.term_buckets.entries[*bucket_id] > 0)
.map(|(bucket_id, sub_agg)| (bucket_id as u64, sub_agg))
.collect();
let entries: Vec<(u64, u32)> = self
.term_buckets
.entries
.iter()
.enumerate()
.filter(|(_, count)| **count > 0)
.map(|(bucket_id, count)| (bucket_id as u64, *count))
.collect();
let bucket =
super::into_intermediate_bucket_result(self.accessor_idx, entries, sub_aggs, agg_data)?;
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
Ok(())
}
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
if docs.len() > MAX_BATCH_SIZE {
for batch in docs.chunks(MAX_BATCH_SIZE) {
self.collect_block(batch, agg_data)?;
}
}
let mut req_data = agg_data.take_term_req_data(self.accessor_idx);
let mem_pre = self.get_memory_consumption();
if let Some(missing) = req_data.missing_value_for_accessor {
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&req_data.accessor,
missing,
);
} else {
req_data
.column_block_accessor
.fetch_block(docs, &req_data.accessor);
}
// has subagg
if req_data.sub_aggregation_blueprint.is_some() {
for (doc, term_id) in req_data
.column_block_accessor
.iter_docid_vals(docs, &req_data.accessor)
{
if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() {
if !allowed_bs.contains(term_id as u32) {
continue;
}
}
self.term_buckets.doc_buffers[term_id as usize].push(doc);
}
for (bucket_id, docs) in self.term_buckets.doc_buffers.iter_mut().enumerate() {
self.term_buckets.entries[bucket_id] += docs.len() as u32;
self.term_buckets.sub_aggs[bucket_id].collect_block(&docs[..], agg_data)?;
docs.clear();
}
} else {
for term_id in req_data.column_block_accessor.iter_vals() {
if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() {
if !allowed_bs.contains(term_id as u32) {
continue;
}
}
self.term_buckets.entries[term_id as usize] += 1;
}
}
let mem_delta = self.get_memory_consumption() - mem_pre;
if mem_delta > 0 {
agg_data
.context
.limits
.add_memory_consumed(mem_delta as u64)?;
}
agg_data.put_back_term_req_data(self.accessor_idx, req_data);
Ok(())
}
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
for sub_aggregations in &mut self.term_buckets.sub_aggs.iter_mut() {
sub_aggregations.as_mut().flush(agg_data)?;
}
Ok(())
}
}

View File

@@ -1,3 +1,6 @@
mod default_impl;
mod low_cardinality_impl;
use std::fmt::Debug;
use std::io;
use std::net::Ipv6Addr;
@@ -12,20 +15,24 @@ use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use super::{CustomOrder, Order, OrderTarget};
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::agg_limits::MemoryConsumption;
use crate::aggregation::agg_data::{AggRefNode, AggregationsSegmentCtx};
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::bucket::term_agg::default_impl::SegmentTermCollector;
use crate::aggregation::bucket::term_agg::low_cardinality_impl::LowCardSegmentTermCollector;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult,
IntermediateAggregationResults, IntermediateBucketResult, IntermediateKey,
IntermediateTermBucketEntry, IntermediateTermBucketResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::{format_date, Key};
use crate::error::DataCorruption;
use crate::TantivyError;
pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) {
let (agg_name, agg_property) = name.split_once('.').unwrap_or((name, ""));
(agg_name, agg_property)
}
/// Contains all information required by the SegmentTermCollector to perform the
/// terms aggregation on a segment.
pub struct TermsAggReqData {
@@ -331,191 +338,76 @@ impl TermsAggregationInternal {
}
}
#[derive(Clone, Debug, Default)]
/// Container to store term_ids/or u64 values and their buckets.
struct TermBuckets {
pub(crate) entries: FxHashMap<u64, u32>,
pub(crate) sub_aggs: FxHashMap<u64, Box<dyn SegmentAggregationCollector>>,
}
const LOW_CARDINALITY_THRESHOLD: u64 = 10;
impl TermBuckets {
fn get_memory_consumption(&self) -> usize {
let sub_aggs_mem = self.sub_aggs.memory_consumption();
let buckets_mem = self.entries.memory_consumption();
sub_aggs_mem + buckets_mem
}
fn force_flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
for sub_aggregations in &mut self.sub_aggs.values_mut() {
sub_aggregations.as_mut().flush(agg_data)?;
}
Ok(())
}
}
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
#[derive(Clone, Debug)]
pub struct SegmentTermCollector {
/// The buckets containing the aggregation data.
term_buckets: TermBuckets,
accessor_idx: usize,
}
pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) {
let (agg_name, agg_property) = name.split_once('.').unwrap_or((name, ""));
(agg_name, agg_property)
}
impl SegmentAggregationCollector for SegmentTermCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
let name = agg_data.get_term_req_data(self.accessor_idx).name.clone();
let bucket = self.into_intermediate_bucket_result(agg_data)?;
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
Ok(())
}
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let mut req_data = agg_data.take_term_req_data(self.accessor_idx);
let mem_pre = self.get_memory_consumption();
if let Some(missing) = req_data.missing_value_for_accessor {
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&req_data.accessor,
missing,
);
} else {
req_data
.column_block_accessor
.fetch_block(docs, &req_data.accessor);
}
for term_id in req_data.column_block_accessor.iter_vals() {
if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() {
if !allowed_bs.contains(term_id as u32) {
continue;
}
}
let entry = self.term_buckets.entries.entry(term_id).or_default();
*entry += 1;
}
// has subagg
if let Some(blueprint) = req_data.sub_aggregation_blueprint.as_ref() {
for (doc, term_id) in req_data
.column_block_accessor
.iter_docid_vals(docs, &req_data.accessor)
{
if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() {
if !allowed_bs.contains(term_id as u32) {
continue;
}
}
let sub_aggregations = self
.term_buckets
.sub_aggs
.entry(term_id)
.or_insert_with(|| blueprint.clone());
sub_aggregations.collect(doc, agg_data)?;
}
}
let mem_delta = self.get_memory_consumption() - mem_pre;
if mem_delta > 0 {
agg_data
.context
.limits
.add_memory_consumed(mem_delta as u64)?;
}
agg_data.put_back_term_req_data(self.accessor_idx, req_data);
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
self.term_buckets.force_flush(agg_data)?;
Ok(())
}
}
impl SegmentTermCollector {
fn get_memory_consumption(&self) -> usize {
let self_mem = std::mem::size_of::<Self>();
let term_buckets_mem = self.term_buckets.get_memory_consumption();
self_mem + term_buckets_mem
}
pub(crate) fn from_req_and_validate(
req_data: &mut AggregationsSegmentCtx,
pub(crate) fn build_segment_aggregation_collector(
req: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let terms_req_data = req_data.get_term_req_data(node.idx_in_req_data);
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
let terms_req_data = req.get_term_req_data(node.idx_in_req_data);
let column_type = terms_req_data.column_type;
let accessor_idx = node.idx_in_req_data;
if column_type == ColumnType::Bytes {
return Err(TantivyError::InvalidArgument(format!(
"terms aggregation is not supported for column type {column_type:?}"
)));
}
let term_buckets = TermBuckets::default();
// Validate sub aggregation exists
if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target {
let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name);
let cardinality = terms_req_data
.accessor
.max_value()
.max(terms_req_data.missing_value_for_accessor.unwrap_or(0u64))
.saturating_add(1);
node.get_sub_agg(agg_name, &req_data.per_request)
.ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"could not find aggregation with name {agg_name} in metric \
sub_aggregations"
if cardinality <= LOW_CARDINALITY_THRESHOLD {
Ok(Box::new(
LowCardSegmentTermCollector::from_req_and_validate(req, node)?,
))
})?;
}
let has_sub_aggregations = !node.children.is_empty();
let blueprint = if has_sub_aggregations {
let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?;
Some(sub_aggregation)
} else {
None
};
let terms_req_data = req_data.get_term_req_data_mut(node.idx_in_req_data);
terms_req_data.sub_aggregation_blueprint = blueprint;
Ok(SegmentTermCollector {
term_buckets,
accessor_idx,
})
Ok(Box::new(SegmentTermCollector::from_req_and_validate(
req, node,
)?))
}
}
#[inline]
pub(crate) fn into_intermediate_bucket_result(
mut self,
pub(crate) trait GetDocCount {
fn doc_count(&self) -> u64;
}
impl GetDocCount for (u64, u32) {
fn doc_count(&self) -> u64 {
self.1 as u64
}
}
impl GetDocCount for (String, IntermediateTermBucketEntry) {
fn doc_count(&self) -> u64 {
self.1.doc_count as u64
}
}
pub(crate) fn cut_off_buckets<T: GetDocCount + Debug>(
entries: &mut Vec<T>,
num_elem: usize,
) -> (u64, u64) {
let term_doc_count_before_cutoff = entries
.get(num_elem)
.map(|entry| entry.doc_count())
.unwrap_or(0);
let sum_other_doc_count = entries
.get(num_elem..)
.map(|cut_off_range| cut_off_range.iter().map(|entry| entry.doc_count()).sum())
.unwrap_or(0);
entries.truncate(num_elem);
(term_doc_count_before_cutoff, sum_other_doc_count)
}
fn into_intermediate_bucket_result(
accessor_idx: usize,
mut entries: Vec<(u64, u32)>,
mut sub_aggs: FxHashMap<u64, Box<dyn SegmentAggregationCollector>>,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateBucketResult> {
let term_req = agg_data.get_term_req_data(self.accessor_idx);
let mut entries: Vec<(u64, u32)> = self.term_buckets.entries.into_iter().collect();
let term_req = agg_data.get_term_req_data(accessor_idx);
let order_by_sub_aggregation =
matches!(term_req.req.order.target, OrderTarget::SubAggregation(_));
@@ -558,8 +450,7 @@ impl SegmentTermCollector {
|id, doc_count| -> crate::Result<IntermediateTermBucketEntry> {
let intermediate_entry = if term_req.sub_aggregation_blueprint.as_ref().is_some() {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
self.term_buckets
.sub_aggs
sub_aggs
.remove(&id)
.unwrap_or_else(|| {
panic!("Internal Error: could not find subaggregation for id {id}")
@@ -626,12 +517,10 @@ impl SegmentTermCollector {
// Sort by term ord
entries.sort_unstable_by_key(|bucket| bucket.0);
let mut idx = 0;
term_dict.sorted_ords_to_term_cb(
entries.iter().map(|(term_id, _)| *term_id),
|term| {
term_dict.sorted_ords_to_term_cb(entries.iter().map(|(term_id, _)| *term_id), |term| {
let entry = entries[idx];
let intermediate_entry = into_intermediate_bucket_entry(entry.0, entry.1)
.map_err(io::Error::other)?;
let intermediate_entry =
into_intermediate_bucket_entry(entry.0, entry.1).map_err(io::Error::other)?;
dict.insert(
IntermediateKey::Str(
String::from_utf8(term.to_vec()).expect("could not convert to String"),
@@ -640,8 +529,7 @@ impl SegmentTermCollector {
);
idx += 1;
Ok(())
},
)?;
})?;
if term_req.req.min_doc_count == 0 {
// TODO: Handle rev streaming for descending sorting by keys
@@ -693,12 +581,9 @@ impl SegmentTermCollector {
.clone()
.downcast_arc::<CompactSpaceU64Accessor>()
.map_err(|_| {
TantivyError::AggregationError(
crate::aggregation::AggregationError::InternalError(
"Type mismatch: Could not downcast to CompactSpaceU64Accessor"
.to_string(),
),
)
TantivyError::AggregationError(crate::aggregation::AggregationError::InternalError(
"Type mismatch: Could not downcast to CompactSpaceU64Accessor".to_string(),
))
})?;
for (val, doc_count) in entries {
@@ -741,39 +626,6 @@ impl SegmentTermCollector {
},
})
}
}
pub(crate) trait GetDocCount {
fn doc_count(&self) -> u64;
}
impl GetDocCount for (u64, u32) {
fn doc_count(&self) -> u64 {
self.1 as u64
}
}
impl GetDocCount for (String, IntermediateTermBucketEntry) {
fn doc_count(&self) -> u64 {
self.1.doc_count as u64
}
}
pub(crate) fn cut_off_buckets<T: GetDocCount + Debug>(
entries: &mut Vec<T>,
num_elem: usize,
) -> (u64, u64) {
let term_doc_count_before_cutoff = entries
.get(num_elem)
.map(|entry| entry.doc_count())
.unwrap_or(0);
let sum_other_doc_count = entries
.get(num_elem..)
.map(|cut_off_range| cut_off_range.iter().map(|entry| entry.doc_count()).sum())
.unwrap_or(0);
entries.truncate(num_elem);
(term_doc_count_before_cutoff, sum_other_doc_count)
}
#[cfg(test)]
mod tests {

View File

@@ -17,11 +17,14 @@ pub trait SegmentAggregationCollector: CollectorClone + Debug {
results: &mut IntermediateAggregationResults,
) -> crate::Result<()>;
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()>;
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
fn collect_block(
&mut self,