Add Filtering for Term Aggregations (#2717)

* Add Filtering for Term Aggregations

Closes #2702

* add AggregationsSegmentCtx memory consumption

---------

Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
This commit is contained in:
PSeitz
2025-10-15 17:39:53 +02:00
committed by GitHub
parent fc93391d0e
commit d410a3b0c0
9 changed files with 341 additions and 13 deletions

View File

@@ -1,5 +1,8 @@
use columnar::{Column, ColumnType}; use columnar::{Column, ColumnType, StrColumn};
use common::BitSet;
use rustc_hash::FxHashSet;
use serde::Serialize; use serde::Serialize;
use tantivy_fst::Regex;
use crate::aggregation::accessor_helpers::{ use crate::aggregation::accessor_helpers::{
get_all_ff_reader_or_empty, get_dynamic_columns, get_ff_reader, get_missing_val_as_u64_lenient, get_all_ff_reader_or_empty, get_dynamic_columns, get_ff_reader, get_missing_val_as_u64_lenient,
@@ -7,9 +10,9 @@ use crate::aggregation::accessor_helpers::{
}; };
use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations}; use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations};
use crate::aggregation::bucket::{ use crate::aggregation::bucket::{
HistogramAggReqData, HistogramBounds, MissingTermAggReqData, RangeAggReqData, HistogramAggReqData, HistogramBounds, IncludeExcludeParam, MissingTermAggReqData,
SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector, TermMissingAgg, RangeAggReqData, SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector,
TermsAggReqData, TermsAggregation, TermsAggregationInternal, TermMissingAgg, TermsAggReqData, TermsAggregation, TermsAggregationInternal,
}; };
use crate::aggregation::metric::{ use crate::aggregation::metric::{
AverageAggregation, CardinalityAggReqData, CardinalityAggregationReq, CountAggregation, AverageAggregation, CardinalityAggReqData, CardinalityAggregationReq, CountAggregation,
@@ -207,6 +210,45 @@ pub struct PerRequestAggSegCtx {
} }
impl PerRequestAggSegCtx { impl PerRequestAggSegCtx {
/// Estimate the memory consumption of this struct in bytes.
fn get_memory_consumption(&self) -> usize {
    // NOTE(review): the boxed request-data slots are unwrapped here, so this
    // assumes they are all populated by the time the estimate is taken.
    let terms: usize = self
        .term_req_data
        .iter()
        .map(|slot| slot.as_ref().unwrap().get_memory_consumption())
        .sum();
    let histograms: usize = self
        .histogram_req_data
        .iter()
        .map(|slot| slot.as_ref().unwrap().get_memory_consumption())
        .sum();
    let ranges: usize = self
        .range_req_data
        .iter()
        .map(|slot| slot.as_ref().unwrap().get_memory_consumption())
        .sum();
    let stats: usize = self
        .stats_metric_req_data
        .iter()
        .map(|req| req.get_memory_consumption())
        .sum();
    let cardinality: usize = self
        .cardinality_req_data
        .iter()
        .map(|req| req.get_memory_consumption())
        .sum();
    let top_hits: usize = self
        .top_hits_req_data
        .iter()
        .map(|req| req.get_memory_consumption())
        .sum();
    let missing_terms: usize = self
        .missing_term_req_data
        .iter()
        .map(|req| req.get_memory_consumption())
        .sum();
    // The aggregation tree itself: one `AggRefNode` per node.
    let tree = self.agg_tree.len() * std::mem::size_of::<AggRefNode>();
    terms + histograms + ranges + stats + cardinality + top_hits + missing_terms + tree
}
pub fn get_name(&self, node: &AggRefNode) -> &str { pub fn get_name(&self, node: &AggRefNode) -> &str {
let idx = node.idx_in_req_data; let idx = node.idx_in_req_data;
let kind = node.kind; let kind = node.kind;
@@ -277,6 +319,8 @@ pub(crate) fn build_segment_agg_collectors(
collectors.push(build_segment_agg_collector(req, node)?); collectors.push(build_segment_agg_collector(req, node)?);
} }
req.limits
.add_memory_consumed(req.per_request.get_memory_consumption() as u64)?;
// Single collector special case // Single collector special case
if collectors.len() == 1 { if collectors.len() == 1 {
return Ok(collectors.pop().unwrap()); return Ok(collectors.pop().unwrap());
@@ -781,6 +825,19 @@ fn build_terms_or_cardinality_nodes(
let children = build_children(sub_aggs, reader, segment_ordinal, data)?; let children = build_children(sub_aggs, reader, segment_ordinal, data)?;
let (idx, kind) = match req { let (idx, kind) = match req {
TermsOrCardinalityRequest::Terms(ref req) => { TermsOrCardinalityRequest::Terms(ref req) => {
let mut allowed_term_ids = None;
if req.include.is_some() || req.exclude.is_some() {
if column_type != ColumnType::Str {
// Skip non-string columns entirely when filtering is requested.
// When excluding, the behavior could be to include non-string values
continue;
}
let str_col = str_dict_column
.as_ref()
.expect("str_dict_column must exist for string column");
allowed_term_ids =
build_allowed_term_ids_for_str(str_col, &req.include, &req.exclude)?;
};
let idx_in_req_data = data.push_term_req_data(TermsAggReqData { let idx_in_req_data = data.push_term_req_data(TermsAggReqData {
accessor, accessor,
column_type, column_type,
@@ -788,11 +845,11 @@ fn build_terms_or_cardinality_nodes(
missing_value_for_accessor, missing_value_for_accessor,
column_block_accessor: Default::default(), column_block_accessor: Default::default(),
name: agg_name.to_string(), name: agg_name.to_string(),
field_type: column_type,
req: TermsAggregationInternal::from_req(req), req: TermsAggregationInternal::from_req(req),
// Will be filled later when building collectors // Will be filled later when building collectors
sub_aggregation_blueprint: None, sub_aggregation_blueprint: None,
sug_aggregations: sub_aggs.clone(), sug_aggregations: sub_aggs.clone(),
allowed_term_ids,
}); });
(idx_in_req_data, AggKind::Terms) (idx_in_req_data, AggKind::Terms)
} }
@@ -819,6 +876,66 @@ fn build_terms_or_cardinality_nodes(
Ok(nodes) Ok(nodes)
} }
/// Builds a single BitSet of allowed term ordinals for a string dictionary column according to
/// include/exclude parameters.
///
/// Returns `Ok(None)` when neither filter is set, i.e. all terms are allowed.
fn build_allowed_term_ids_for_str(
    str_col: &StrColumn,
    include: &Option<IncludeExcludeParam>,
    exclude: &Option<IncludeExcludeParam>,
) -> crate::Result<Option<BitSet>> {
    let num_terms = str_col.dictionary().num_terms() as u32;
    // With an `include` filter we start from an empty set and add matches;
    // with only an `exclude` filter we start from a full set and remove matches.
    let mut allowed = match (include.as_ref(), exclude.as_ref()) {
        (None, None) => return Ok(None),
        (Some(_), _) => BitSet::with_max_value(num_terms),
        (None, Some(_)) => BitSet::with_max_value_and_full(num_terms),
    };
    if let Some(include) = include {
        for_each_matching_term_ord(str_col, include, |ord| allowed.insert(ord))?;
    }
    if let Some(exclude) = exclude {
        for_each_matching_term_ord(str_col, exclude, |ord| allowed.remove(ord))?;
    }
    Ok(Some(allowed))
}
/// Apply a callback to each matching term ordinal for the given include/exclude parameter.
fn for_each_matching_term_ord(
    str_col: &StrColumn,
    param: &IncludeExcludeParam,
    mut cb: impl FnMut(u32),
) -> crate::Result<()> {
    let dict = str_col.dictionary();
    match param {
        IncludeExcludeParam::Regex(pattern) => {
            // Compile the pattern, then stream every dictionary term matching it.
            let re = Regex::new(pattern).map_err(|e| {
                crate::TantivyError::InvalidArgument(format!("Invalid regex `{}`: {}", pattern, e))
            })?;
            // TODO: we can handle patterns like `^prefix.*` more efficiently
            let mut stream = dict.search(re).into_stream()?;
            while stream.advance() {
                cb(stream.term_ord() as u32);
            }
        }
        IncludeExcludeParam::Values(values) => {
            // Exact-value filter: walk the full dictionary and report every term
            // whose UTF-8 form appears in the requested set. Terms that are not
            // valid UTF-8 can never match and are skipped.
            let set: FxHashSet<&str> = values.iter().map(String::as_str).collect();
            let mut stream = dict.stream()?;
            while stream.advance() {
                let is_match = std::str::from_utf8(stream.key())
                    .map(|key| set.contains(key))
                    .unwrap_or(false);
                if is_match {
                    cb(stream.term_ord() as u32);
                }
            }
        }
    }
    Ok(())
}
/// Convert the aggregation tree to something serializable and easy to read. /// Convert the aggregation tree to something serializable and easy to read.
#[derive(Serialize, Debug, Clone, PartialEq, Eq)] #[derive(Serialize, Debug, Clone, PartialEq, Eq)]
pub struct AggTreeViewNode { pub struct AggTreeViewNode {

View File

@@ -42,6 +42,12 @@ pub struct HistogramAggReqData {
/// The offset used to calculate the bucket position. /// The offset used to calculate the bucket position.
pub offset: f64, pub offset: f64,
} }
impl HistogramAggReqData {
/// Estimate the memory consumption of this struct in bytes.
///
/// Counts only the inline size of the struct (`size_of::<Self>()`);
/// heap allocations owned by its fields, if any, are not included.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
}
}
/// Histogram is a bucket aggregation, where buckets are created dynamically for given `interval`. /// Histogram is a bucket aggregation, where buckets are created dynamically for given `interval`.
/// Each document value is rounded down to its bucket. /// Each document value is rounded down to its bucket.

View File

@@ -31,6 +31,13 @@ pub struct RangeAggReqData {
pub name: String, pub name: String,
} }
impl RangeAggReqData {
/// Estimate the memory consumption of this struct in bytes.
///
/// Counts only the inline size of the struct (`size_of::<Self>()`);
/// heap allocations owned by fields (e.g. the `name: String` buffer)
/// are not included.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
}
}
/// Provide user-defined buckets to aggregate on. /// Provide user-defined buckets to aggregate on.
/// ///
/// Two special buckets will automatically be created to cover the whole range of values. /// Two special buckets will automatically be created to cover the whole range of values.

View File

@@ -7,6 +7,7 @@ use columnar::{
Column, ColumnBlockAccessor, ColumnType, Dictionary, MonotonicallyMappableToU128, Column, ColumnBlockAccessor, ColumnType, Dictionary, MonotonicallyMappableToU128,
MonotonicallyMappableToU64, NumericalValue, StrColumn, MonotonicallyMappableToU64, NumericalValue, StrColumn,
}; };
use common::BitSet;
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@@ -38,8 +39,6 @@ pub struct TermsAggReqData {
pub missing_value_for_accessor: Option<u64>, pub missing_value_for_accessor: Option<u64>,
/// The column block accessor to access the fast field values. /// The column block accessor to access the fast field values.
pub column_block_accessor: ColumnBlockAccessor<u64>, pub column_block_accessor: ColumnBlockAccessor<u64>,
/// The type of the fast field.
pub field_type: ColumnType,
/// Note: sub_aggregation_blueprint is filled later when building collectors /// Note: sub_aggregation_blueprint is filled later when building collectors
pub sub_aggregation_blueprint: Option<Box<dyn SegmentAggregationCollector>>, pub sub_aggregation_blueprint: Option<Box<dyn SegmentAggregationCollector>>,
/// Used to build the correct nested result when we have an empty result. /// Used to build the correct nested result when we have an empty result.
@@ -48,6 +47,21 @@ pub struct TermsAggReqData {
pub name: String, pub name: String,
/// The normalized term aggregation request. /// The normalized term aggregation request.
pub req: TermsAggregationInternal, pub req: TermsAggregationInternal,
/// Preloaded allowed term ords (string columns only). If set, only ords present are collected.
pub allowed_term_ids: Option<BitSet>,
}
impl TermsAggReqData {
/// Estimate the memory consumption of this struct in bytes.
///
/// NOTE(review): `std::mem::size_of::<Self>()` already includes the inline
/// `req: TermsAggregationInternal` field, so the explicit
/// `size_of::<TermsAggregationInternal>()` term counts it a second time —
/// possibly intended as a rough proxy for heap data owned by `req`; confirm.
/// NOTE(review): `bs.len() / 8` assumes `BitSet::len()` reflects the capacity
/// in bits; if it returns the number of *set* bits instead, the backing
/// storage (sized by the max value) is underestimated — TODO confirm.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
+ std::mem::size_of::<TermsAggregationInternal>()
+ self
.allowed_term_ids
.as_ref()
.map(|bs| bs.len() / 8)
.unwrap_or(0)
}
} }
/// Creates a bucket for every unique term and counts the number of occurrences. /// Creates a bucket for every unique term and counts the number of occurrences.
@@ -120,6 +134,68 @@ pub struct TermsAggReqData {
/// } /// }
/// ``` /// ```
/// Parameter for the `include`/`exclude` options of the terms aggregation.
///
/// In the serialized (JSON) form, a single string is interpreted as a regex
/// and an array of strings as a list of exact values; see the custom
/// `Serialize`/`Deserialize` implementations.
#[derive(Clone, Debug, PartialEq)]
pub enum IncludeExcludeParam {
/// A single string pattern is treated as regex.
Regex(String),
/// An array of strings is treated as exact values.
Values(Vec<String>),
}
// Mirrors the custom deserializer: a regex round-trips as a plain string,
// exact values as an array of strings.
impl Serialize for IncludeExcludeParam {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where S: serde::Serializer {
match self {
IncludeExcludeParam::Regex(s) => serializer.serialize_str(s),
IncludeExcludeParam::Values(v) => v.serialize(serializer),
}
}
}
// Custom deserializer to accept either a single string (regex) or an array of strings (values).
impl<'de> Deserialize<'de> for IncludeExcludeParam {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where D: serde::Deserializer<'de> {
use serde::de::{self, SeqAccess, Visitor};
struct IncludeExcludeVisitor;
impl<'de> Visitor<'de> for IncludeExcludeVisitor {
type Value = IncludeExcludeParam;
// Message shown in the error when the input is neither a string nor a sequence.
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a string (regex) or an array of strings")
}
// All three string hooks map to `Regex`; the owned variant reuses the
// already-allocated `String` without copying.
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where E: de::Error {
Ok(IncludeExcludeParam::Regex(v.to_string()))
}
fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
where E: de::Error {
Ok(IncludeExcludeParam::Regex(v.to_string()))
}
fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
where E: de::Error {
Ok(IncludeExcludeParam::Regex(v))
}
// A sequence of strings becomes the exact-values variant.
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where A: SeqAccess<'de> {
let mut values: Vec<String> = Vec::new();
while let Some(elem) = seq.next_element::<String>()? {
values.push(elem);
}
Ok(IncludeExcludeParam::Values(values))
}
}
// NOTE: `deserialize_any` requires a self-describing format (JSON is fine);
// non-self-describing formats cannot drive this visitor.
deserializer.deserialize_any(IncludeExcludeVisitor)
}
}
/// The terms aggregation allows you to group documents by unique values of a field.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)] #[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct TermsAggregation { pub struct TermsAggregation {
/// The field to aggregate on. /// The field to aggregate on.
@@ -189,6 +265,13 @@ pub struct TermsAggregation {
/// add text. /// add text.
#[serde(skip_serializing_if = "Option::is_none", default)] #[serde(skip_serializing_if = "Option::is_none", default)]
pub missing: Option<Key>, pub missing: Option<Key>,
/// Include terms by either regex (single string) or exact values (array).
#[serde(skip_serializing_if = "Option::is_none", default)]
pub include: Option<IncludeExcludeParam>,
/// Exclude terms by either regex (single string) or exact values (array).
#[serde(skip_serializing_if = "Option::is_none", default)]
pub exclude: Option<IncludeExcludeParam>,
} }
/// Same as TermsAggregation, but with populated defaults. /// Same as TermsAggregation, but with populated defaults.
@@ -330,6 +413,11 @@ impl SegmentAggregationCollector for SegmentTermCollector {
} }
for term_id in req_data.column_block_accessor.iter_vals() { for term_id in req_data.column_block_accessor.iter_vals() {
if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() {
if !allowed_bs.contains(term_id as u32) {
continue;
}
}
let entry = self.term_buckets.entries.entry(term_id).or_default(); let entry = self.term_buckets.entries.entry(term_id).or_default();
*entry += 1; *entry += 1;
} }
@@ -339,6 +427,11 @@ impl SegmentAggregationCollector for SegmentTermCollector {
.column_block_accessor .column_block_accessor
.iter_docid_vals(docs, &req_data.accessor) .iter_docid_vals(docs, &req_data.accessor)
{ {
if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() {
if !allowed_bs.contains(term_id as u32) {
continue;
}
}
let sub_aggregations = self let sub_aggregations = self
.term_buckets .term_buckets
.sub_aggs .sub_aggs
@@ -375,11 +468,11 @@ impl SegmentTermCollector {
node: &AggRefNode, node: &AggRefNode,
) -> crate::Result<Self> { ) -> crate::Result<Self> {
let terms_req_data = req_data.get_term_req_data(node.idx_in_req_data); let terms_req_data = req_data.get_term_req_data(node.idx_in_req_data);
let field_type = terms_req_data.field_type; let column_type = terms_req_data.column_type;
let accessor_idx = node.idx_in_req_data; let accessor_idx = node.idx_in_req_data;
if field_type == ColumnType::Bytes { if column_type == ColumnType::Bytes {
return Err(TantivyError::InvalidArgument(format!( return Err(TantivyError::InvalidArgument(format!(
"terms aggregation is not supported for column type {field_type:?}" "terms aggregation is not supported for column type {column_type:?}"
))); )));
} }
let term_buckets = TermBuckets::default(); let term_buckets = TermBuckets::default();
@@ -552,13 +645,20 @@ impl SegmentTermCollector {
let mut stream = term_dict.stream()?; let mut stream = term_dict.stream()?;
let empty_sub_aggregation = let empty_sub_aggregation =
IntermediateAggregationResults::empty_from_req(&term_req.sug_aggregations); IntermediateAggregationResults::empty_from_req(&term_req.sug_aggregations);
while let Some((key, _ord)) = stream.next() { while stream.advance() {
if dict.len() >= term_req.req.segment_size as usize { if dict.len() >= term_req.req.segment_size as usize {
break; break;
} }
// Respect allowed filters if present
if let Some(allowed_bs) = term_req.allowed_term_ids.as_ref() {
if !allowed_bs.contains(stream.term_ord() as u32) {
continue;
}
}
let key = IntermediateKey::Str( let key = IntermediateKey::Str(
std::str::from_utf8(key) std::str::from_utf8(stream.key())
.map_err(|utf8_err| DataCorruption::comment_only(utf8_err.to_string()))? .map_err(|utf8_err| DataCorruption::comment_only(utf8_err.to_string()))?
.to_string(), .to_string(),
); );
@@ -751,6 +851,77 @@ mod tests {
); );
assert_eq!(res["my_texts"]["sum_other_doc_count"], 1); assert_eq!(res["my_texts"]["sum_other_doc_count"], 1);
// include filter: only terma and termc
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"include": ["terma", "termc"],
},
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "terma");
assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 5);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "termc");
assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 1);
assert_eq!(res["my_texts"]["sum_other_doc_count"], 0);
// exclude filter: remove termc
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"exclude": ["termc"],
},
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "terma");
assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 5);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "termb");
assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 2);
assert_eq!(res["my_texts"]["sum_other_doc_count"], 0);
// include regex (single string): only termb
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"include": "termb",
},
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "termb");
assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 2);
assert_eq!(res["my_texts"]["sum_other_doc_count"], 0);
// include regex (term.*) with exclude regex (termc): expect terma and termb
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"include": "term.*",
"exclude": "termc",
},
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "terma");
assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 5);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "termb");
assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 2);
assert_eq!(res["my_texts"]["sum_other_doc_count"], 0);
// test min_doc_count // test min_doc_count
let agg_req: Aggregations = serde_json::from_value(json!({ let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": { "my_texts": {

View File

@@ -28,6 +28,13 @@ pub struct MissingTermAggReqData {
pub req: TermsAggregation, pub req: TermsAggregation,
} }
impl MissingTermAggReqData {
/// Estimate the memory consumption of this struct in bytes.
///
/// Counts only the inline size of the struct (`size_of::<Self>()`);
/// heap allocations owned by fields (e.g. strings inside the embedded
/// `TermsAggregation` request) are not included.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
}
}
/// The specialized missing term aggregation. /// The specialized missing term aggregation.
#[derive(Default, Debug, Clone)] #[derive(Default, Debug, Clone)]
pub struct TermMissingAgg { pub struct TermMissingAgg {

View File

@@ -114,6 +114,13 @@ pub struct CardinalityAggReqData {
pub req: CardinalityAggregationReq, pub req: CardinalityAggregationReq,
} }
impl CardinalityAggReqData {
/// Estimate the memory consumption of this struct in bytes.
///
/// Counts only the inline size of the struct (`size_of::<Self>()`);
/// heap allocations owned by its fields, if any, are not included.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
}
}
impl CardinalityAggregationReq { impl CardinalityAggregationReq {
/// Creates a new [`CardinalityAggregationReq`] instance from a field name. /// Creates a new [`CardinalityAggregationReq`] instance from a field name.
pub fn from_field_name(field_name: String) -> Self { pub fn from_field_name(field_name: String) -> Self {

View File

@@ -67,6 +67,13 @@ pub struct MetricAggReqData {
pub name: String, pub name: String,
} }
impl MetricAggReqData {
/// Estimate the memory consumption of this struct in bytes.
///
/// Counts only the inline size of the struct (`size_of::<Self>()`);
/// heap allocations owned by fields (e.g. the `name: String` buffer)
/// are not included.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
}
}
/// Single-metric aggregations use this common result structure. /// Single-metric aggregations use this common result structure.
/// ///
/// Main reason to wrap it in value is to match elasticsearch output structure. /// Main reason to wrap it in value is to match elasticsearch output structure.

View File

@@ -37,6 +37,13 @@ pub struct TopHitsAggReqData {
pub req: TopHitsAggregationReq, pub req: TopHitsAggregationReq,
} }
impl TopHitsAggReqData {
/// Estimate the memory consumption of this struct in bytes.
///
/// Counts only the inline size of the struct (`size_of::<Self>()`);
/// heap allocations owned by the embedded request, if any, are not included.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
}
}
/// # Top Hits /// # Top Hits
/// ///
/// The top hits aggregation is a useful tool to answer questions like: /// The top hits aggregation is a useful tool to answer questions like:

View File

@@ -1,4 +1,3 @@
use core::num;
use std::collections::HashMap; use std::collections::HashMap;
use crate::docset::COLLECT_BLOCK_BUFFER_LEN; use crate::docset::COLLECT_BLOCK_BUFFER_LEN;