Compare commits

...

1 Commits

Author SHA1 Message Date
trinity.pointard
d7f69903fa first swab at multi-terms impl 2026-06-26 21:48:18 +00:00
7 changed files with 1466 additions and 6 deletions

View File

@@ -81,6 +81,11 @@ fn bench_agg(mut group: InputGroup<Index>) {
register!(group, composite_histogram);
register!(group, composite_histogram_calendar);
// multi_terms aggregation benchmarks
register!(group, multi_terms_status_with_zipf_1000);
register!(group, multi_terms_zipf_1000_with_status);
register!(group, multi_terms_status_with_zipf_1000_sub_agg);
register!(group, cardinality_agg);
register!(group, cardinality_agg_high_card);
register!(group, cardinality_agg_low_card);
@@ -568,6 +573,58 @@ fn composite_histogram_calendar(index: &Index) {
execute_agg(index, agg_req);
}
/// Benchmarks a flat `multi_terms` group-by over `(status, zipf_1000)` with `size` 10.
///
/// Counterpart of `terms_status_with_terms_zipf_1000_sub_agg`, which expresses the same
/// grouping as nested terms(status) -> terms(zipf_1000).
fn multi_terms_status_with_zipf_1000(index: &Index) {
    // Fields are listed in grouping order: status first, then the zipf-distributed terms.
    let grouped_fields = json!([
        {"field": "text_few_terms_status"},
        {"field": "text_1000_terms_zipf"}
    ]);
    let request = json!({
        "mt": {
            "multi_terms": {
                "terms": grouped_fields,
                "size": 10
            }
        }
    });
    execute_agg(index, request);
}
/// Benchmarks a flat `multi_terms` group-by over `(zipf_1000, status)` with `size` 100.
///
/// Counterpart of `terms_zipf_1000_with_terms_status_sub_agg`, which expresses the same
/// grouping as nested terms(zipf_1000) -> terms(status).
fn multi_terms_zipf_1000_with_status(index: &Index) {
    // Fields are listed in grouping order: the zipf-distributed terms first, then status.
    let grouped_fields = json!([
        {"field": "text_1000_terms_zipf"},
        {"field": "text_few_terms_status"}
    ]);
    let request = json!({
        "mt": {
            "multi_terms": {
                "terms": grouped_fields,
                "size": 100
            }
        }
    });
    execute_agg(index, request);
}
/// Benchmarks `multi_terms` over `(status, zipf_1000)` -- the same field pair as the nested
/// terms benchmarks -- with an `avg` sub-aggregation on `score_f64` attached to every bucket.
///
/// No `size` is given, so the aggregation runs with whatever default the request type applies.
fn multi_terms_status_with_zipf_1000_sub_agg(index: &Index) {
    let grouped_fields = json!([
        {"field": "text_few_terms_status"},
        {"field": "text_1000_terms_zipf"}
    ]);
    let avg_sub_agg = json!({
        "average_f64": { "avg": { "field": "score_f64" } }
    });
    let request = json!({
        "mt": {
            "multi_terms": {
                "terms": grouped_fields
            },
            "aggs": avg_sub_agg
        }
    });
    execute_agg(index, request);
}
fn execute_agg(index: &Index, agg_req: serde_json::Value) {
let agg_req: Aggregations = serde_json::from_value(agg_req).unwrap();
let collector = get_collector(agg_req);

View File

@@ -13,8 +13,9 @@ use crate::aggregation::bucket::{
build_segment_filter_collector, build_segment_histogram_collector,
build_segment_range_collector, CompositeAggReqData, CompositeAggregation,
CompositeSourceAccessors, FilterAggReqData, HistogramAggReqData, HistogramBounds,
IncludeExcludeParam, MissingTermAggReqData, RangeAggReqData, TermMissingAgg, TermsAggReqData,
TermsAggregation, TermsAggregationInternal,
IncludeExcludeParam, MissingTermAggReqData, MultiTermsAggReqData, MultiTermsAggregation,
MultiTermsFieldAccessors, RangeAggReqData, SegmentMultiTermsCollector, TermMissingAgg,
TermsAggReqData, TermsAggregation, TermsAggregationInternal,
};
use crate::aggregation::metric::{
build_segment_stats_collector, AverageAggregation, CardinalityAggReqData,
@@ -76,6 +77,10 @@ impl AggregationsSegmentCtx {
self.per_request.composite_req_data.push(data);
self.per_request.composite_req_data.len() - 1
}
/// Stores the request data of a multi_terms aggregation and returns the index at which it
/// was inserted into `per_request.multi_terms_req_data`.
pub(crate) fn push_multi_terms_req_data(&mut self, data: MultiTermsAggReqData) -> usize {
    let slots = &mut self.per_request.multi_terms_req_data;
    // The new element lands at the current length, so read it before pushing.
    let new_idx = slots.len();
    slots.push(data);
    new_idx
}
#[inline]
pub(crate) fn get_term_req_data(&self, idx: usize) -> &TermsAggReqData {
@@ -125,6 +130,8 @@ pub struct PerRequestAggSegCtx {
pub missing_term_req_data: Vec<MissingTermAggReqData>,
/// CompositeAggReqData contains the request data for a composite aggregation.
pub composite_req_data: Vec<CompositeAggReqData>,
/// MultiTermsAggReqData contains the request data for a multi_terms aggregation.
pub multi_terms_req_data: Vec<MultiTermsAggReqData>,
/// Request tree used to build collectors.
pub agg_tree: Vec<AggRefNode>,
@@ -177,6 +184,11 @@ impl PerRequestAggSegCtx {
.iter()
.map(|t| t.get_memory_consumption())
.sum::<usize>()
+ self
.multi_terms_req_data
.iter()
.map(|t| t.get_memory_consumption())
.sum::<usize>()
+ self.agg_tree.len() * std::mem::size_of::<AggRefNode>()
}
@@ -194,6 +206,7 @@ impl PerRequestAggSegCtx {
AggKind::Range => self.range_req_data[idx].name.as_str(),
AggKind::Filter => self.filter_req_data[idx].name.as_str(),
AggKind::Composite => self.composite_req_data[idx].name.as_str(),
AggKind::MultiTerms => self.multi_terms_req_data[idx].name.as_str(),
}
}
@@ -347,6 +360,9 @@ pub(crate) fn build_segment_agg_collector(
req, node,
)?,
)),
AggKind::MultiTerms => Ok(Box::new(SegmentMultiTermsCollector::from_req_and_validate(
req, node,
)?)),
}
}
@@ -378,6 +394,7 @@ pub enum AggKind {
Range,
Filter,
Composite,
MultiTerms,
}
impl AggKind {
@@ -394,6 +411,7 @@ impl AggKind {
AggKind::Range => "Range",
AggKind::Filter => "Filter",
AggKind::Composite => "Composite",
AggKind::MultiTerms => "MultiTerms",
}
}
}
@@ -649,6 +667,14 @@ fn build_nodes(
&req.sub_aggregation,
composite_req,
)?]),
AggregationVariants::MultiTerms(multi_terms_req) => Ok(vec![build_multi_terms_node(
agg_name,
reader,
segment_ordinal,
data,
&req.sub_aggregation,
multi_terms_req,
)?]),
AggregationVariants::Filter(filter_req) => {
// Build the query and evaluator upfront
let schema = reader.schema();
@@ -707,6 +733,111 @@ fn build_composite_node(
})
}
/// Builds the [`AggRefNode`] of a `multi_terms` aggregation for one segment.
///
/// For every field listed in `req.terms` this opens the optional string dictionary column
/// and the term accessors, and precomputes the `KeyElem` that stands in for documents
/// without a value when a `missing` value is configured. The per-field data is registered in
/// `data`, and the sub-aggregations are built as children of the returned node.
///
/// # Errors
///
/// Returns `TantivyError::InvalidArgument` when `req.terms` is empty, and propagates any
/// error raised while opening the columns, looking up the dictionary, converting the
/// `missing` value, or building the children.
fn build_multi_terms_node(
    agg_name: &str,
    reader: &SegmentReader,
    segment_ordinal: SegmentOrdinal,
    data: &mut AggregationsSegmentCtx,
    sub_aggs: &Aggregations,
    req: &MultiTermsAggregation,
) -> crate::Result<AggRefNode> {
    use crate::aggregation::bucket::KeyElem;

    if req.terms.is_empty() {
        return Err(crate::TantivyError::InvalidArgument(
            "multi_terms aggregation requires at least one field".to_string(),
        ));
    }

    let mut fields = Vec::with_capacity(req.terms.len());
    for field_def in &req.terms {
        let field_name = &field_def.field;
        let str_dict_column = reader.fast_fields().str(field_name)?;
        // Collect all columns for this field (handles JSON multi-type fields).
        let columns = get_term_agg_accessors(reader, field_name, &field_def.missing)?;

        // Precompute the missing KeyElem (or None -> drop combo).
        let missing_key_elem = if let Some(missing) = &field_def.missing {
            match missing {
                Key::Str(missing_str) => {
                    let str_column_idx =
                        columns.iter().position(|(_, ct)| *ct == ColumnType::Str);
                    // A `Str`-typed accessor does not prove that a dictionary column was
                    // found. NOTE(review): the accessor list presumably may contain a
                    // placeholder column typed after `missing` when the segment lacks the
                    // field -- confirm in `get_term_agg_accessors`. Without a dictionary the
                    // missing term cannot have an ordinal, so it is handled exactly like a
                    // term that is absent from the dictionary instead of panicking.
                    match (str_column_idx, str_dict_column.as_ref()) {
                        (Some(idx), Some(str_column)) => {
                            match str_column.dictionary().term_ord(missing_str.as_bytes())? {
                                Some(ord) => Some(KeyElem::new(idx as u32, ord)),
                                None => Some(KeyElem::synthetic_missing()),
                            }
                        }
                        _ => Some(KeyElem::synthetic_missing()),
                    }
                }
                _ => {
                    // Non-string missing: find the column whose type best matches the
                    // missing key. Prefer an exact-type match; fall back to any numeric
                    // column so cross-type coercions (e.g. Key::F64 on an I64 column)
                    // still work.
                    let preferred_type = match missing {
                        Key::F64(_) => ColumnType::F64,
                        Key::I64(_) => ColumnType::I64,
                        Key::U64(_) => ColumnType::U64,
                        Key::Str(_) => unreachable!("handled by Key::Str arm"),
                    };
                    let idx = columns
                        .iter()
                        .position(|(_, ct)| *ct == preferred_type)
                        .or_else(|| {
                            columns
                                .iter()
                                .position(|(_, ct)| ct.numerical_type().is_some())
                        });
                    match idx {
                        Some(idx) => {
                            let (col, col_type) = &columns[idx];
                            // A `None` from the conversion leaves the field without a
                            // missing key element.
                            get_missing_val_as_u64_lenient(
                                *col_type,
                                col.max_value(),
                                missing,
                                field_name,
                            )?
                            .map(|sentinel| KeyElem::new(idx as u32, sentinel))
                        }
                        None => Some(KeyElem::synthetic_missing()),
                    }
                }
            }
        } else {
            None
        };

        fields.push(MultiTermsFieldAccessors {
            columns,
            str_dict_column,
            missing: field_def.missing.clone(),
            missing_key_elem,
            field: field_name.clone(),
        });
    }

    let idx = data.push_multi_terms_req_data(MultiTermsAggReqData {
        name: agg_name.to_string(),
        req: req.clone(),
        fields,
        sub_aggregations: sub_aggs.clone(),
    });
    let children = build_children(sub_aggs, reader, segment_ordinal, data)?;
    Ok(AggRefNode {
        kind: AggKind::MultiTerms,
        idx_in_req_data: idx,
        children,
    })
}
fn build_children(
aggs: &Aggregations,
reader: &SegmentReader,

View File

@@ -33,7 +33,7 @@ use serde::{Deserialize, Serialize};
use super::bucket::{
CompositeAggregation, DateHistogramAggregationReq, FilterAggregation, HistogramAggregation,
RangeAggregation, TermsAggregation,
MultiTermsAggregation, RangeAggregation, TermsAggregation,
};
use super::metric::{
AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation,
@@ -202,6 +202,9 @@ pub enum AggregationVariants {
/// Multi-dimensional, paginable bucket aggregation.
#[serde(rename = "composite")]
Composite(CompositeAggregation),
/// Bucket aggregation over unique combinations of values across multiple term fields.
#[serde(rename = "multi_terms")]
MultiTerms(MultiTermsAggregation),
// Metric aggregation types
/// Computes the average of the extracted values.
@@ -253,6 +256,9 @@ impl AggregationVariants {
.iter()
.map(|source| source.field())
.collect(),
AggregationVariants::MultiTerms(mt) => {
mt.terms.iter().map(|t| t.field.as_str()).collect()
}
AggregationVariants::Average(avg) => vec![avg.field_name()],
AggregationVariants::Count(count) => vec![count.field_name()],
AggregationVariants::Max(max) => vec![max.field_name()],
@@ -293,6 +299,12 @@ impl AggregationVariants {
_ => None,
}
}
/// Returns the wrapped [`MultiTermsAggregation`] when this is the `MultiTerms` variant,
/// and `None` for every other variant.
pub(crate) fn as_multi_terms(&self) -> Option<&MultiTermsAggregation> {
    if let AggregationVariants::MultiTerms(multi_terms_req) = self {
        Some(multi_terms_req)
    } else {
        None
    }
}
pub(crate) fn as_percentile(&self) -> Option<&PercentilesAggregationReq> {
match &self {
AggregationVariants::Percentiles(percentile_req) => Some(percentile_req),

View File

@@ -152,12 +152,25 @@ pub enum BucketResult {
///
/// See [`TermsAggregation`](super::bucket::TermsAggregation)
buckets: Vec<BucketEntry>,
/// The number of documents that didnt make it into to TOP N due to shard_size or size
/// The number of documents that didn't make it into the TOP N due to shard_size or size
sum_other_doc_count: u64,
#[serde(skip_serializing_if = "Option::is_none")]
/// The upper bound error for the doc count of each term.
doc_count_error_upper_bound: Option<u64>,
},
/// This is the multi_terms result -- placed AFTER Terms so that a zero-bucket result
/// deserializes as Terms (the more common case). Non-empty MultiTerms still deserializes
/// correctly because its array `key` fails Terms' scalar `key` check first. The only known
/// ambiguity is an empty MultiTerms result decoding as Terms (deserialization only).
MultiTerms {
/// The buckets (one per unique combination of field values).
buckets: Vec<MultiTermsBucketEntry>,
/// The number of documents that didn't make it into the TOP N.
sum_other_doc_count: u64,
#[serde(skip_serializing_if = "Option::is_none")]
/// The upper bound error for the doc count of each term combination.
doc_count_error_upper_bound: Option<u64>,
},
/// This is the filter result - a single bucket with sub-aggregations
Filter(FilterBucketResult),
/// This is the composite result
@@ -179,6 +192,11 @@ impl BucketResult {
BucketResult::Histogram { buckets } => {
buckets.iter().map(|bucket| bucket.get_bucket_count()).sum()
}
BucketResult::MultiTerms {
buckets,
sum_other_doc_count: _,
doc_count_error_upper_bound: _,
} => buckets.iter().map(|bucket| bucket.get_bucket_count()).sum(),
BucketResult::Terms {
buckets,
sum_other_doc_count: _,
@@ -272,6 +290,35 @@ impl GetDocCount for BucketEntry {
}
}
/// Bucket entry for a [`multi_terms`](super::bucket::MultiTermsAggregation) aggregation.
///
/// Each entry describes one combination of field values: `key` holds one value per declared
/// field, and `key_as_string` is the pipe-joined representation of that key.
/// Sub-aggregation results are flattened into the entry when it is (de)serialized.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct MultiTermsBucketEntry {
/// Pipe-joined string representation of all key elements, e.g. `"rock|Product A"`.
pub key_as_string: String,
/// The composite key: one [`Key`] per field in declaration order.
pub key: Vec<Key>,
/// Number of documents in this bucket.
pub doc_count: u64,
/// Sub-aggregation results, serialized inline next to the fields above
/// (`#[serde(flatten)]`) rather than under a nested `sub_aggregation` key.
#[serde(flatten)]
pub sub_aggregation: AggregationResults,
}
impl MultiTermsBucketEntry {
    /// Counts this bucket plus every bucket reported by its sub-aggregations.
    pub(crate) fn get_bucket_count(&self) -> u64 {
        let nested_buckets = self.sub_aggregation.get_bucket_count();
        nested_buckets + 1
    }
}
impl GetDocCount for MultiTermsBucketEntry {
    /// Returns the number of documents recorded for this bucket.
    fn doc_count(&self) -> u64 {
        let Self { doc_count, .. } = self;
        *doc_count
    }
}
/// This is the range entry for a bucket, which contains a key, count, and optionally
/// sub-aggregations.
///

View File

@@ -25,6 +25,7 @@
mod composite;
mod filter;
mod histogram;
mod multi_terms;
mod range;
mod term_agg;
mod term_missing_agg;
@@ -35,6 +36,7 @@ use std::fmt;
pub use composite::*;
pub use filter::*;
pub use histogram::*;
pub use multi_terms::*;
pub use range::*;
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
pub use term_agg::*;

File diff suppressed because it is too large Load Diff

View File

@@ -28,7 +28,7 @@ use super::{format_date, AggregationError, Key, SerializedKey};
use crate::aggregation::agg_result::{
AggregationResults, BucketEntries, BucketEntry, CompositeBucketEntry, FilterBucketResult,
};
use crate::aggregation::bucket::TermsAggregationInternal;
use crate::aggregation::bucket::{IntermediateMultiTermsBucketResult, TermsAggregationInternal};
use crate::aggregation::metric::CardinalityCollector;
use crate::TantivyError;
@@ -82,7 +82,7 @@ impl From<IntermediateKey> for Key {
}
}
IntermediateKey::F64(f) => Self::F64(f),
IntermediateKey::Bool(f) => Self::U64(f as u64),
IntermediateKey::Bool(f) => Self::Str(f.to_string()),
IntermediateKey::U64(f) => Self::U64(f),
IntermediateKey::I64(f) => Self::I64(f),
}
@@ -286,6 +286,11 @@ pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult
buckets: IntermediateCompositeBucketResult::default(),
})
}
MultiTerms(_) => {
IntermediateAggregationResult::Bucket(IntermediateBucketResult::MultiTerms {
buckets: Default::default(),
})
}
}
}
@@ -499,6 +504,11 @@ pub enum IntermediateBucketResult {
/// The composite buckets
buckets: IntermediateCompositeBucketResult,
},
/// Multi-terms aggregation
MultiTerms {
/// The multi-terms buckets
buckets: IntermediateMultiTermsBucketResult,
},
}
impl IntermediateBucketResult {
@@ -601,6 +611,13 @@ impl IntermediateBucketResult {
.expect("unexpected aggregation, expected composite aggregation");
buckets.into_final_result(composite_req, req.sub_aggregation(), limits)
}
IntermediateBucketResult::MultiTerms { buckets } => {
let multi_terms_req = req
.agg
.as_multi_terms()
.expect("unexpected aggregation, expected multi_terms aggregation");
buckets.into_final_result(multi_terms_req, req.sub_aggregation(), limits)
}
}
}
@@ -677,6 +694,14 @@ impl IntermediateBucketResult {
) => {
composite_left.merge_fruits(composite_right)?;
}
(
IntermediateBucketResult::MultiTerms { buckets: mt_left },
IntermediateBucketResult::MultiTerms { buckets: mt_right },
) => {
merge_maps(&mut mt_left.entries, mt_right.entries)?;
mt_left.sum_other_doc_count += mt_right.sum_other_doc_count;
mt_left.doc_count_error_upper_bound += mt_right.doc_count_error_upper_bound;
}
(IntermediateBucketResult::Range(_), _) => {
panic!("try merge on different types")
}
@@ -692,6 +717,9 @@ impl IntermediateBucketResult {
(IntermediateBucketResult::Composite { .. }, _) => {
panic!("try merge on different types")
}
(IntermediateBucketResult::MultiTerms { .. }, _) => {
panic!("try merge on different types")
}
}
Ok(())
}