From d410a3b0c03d552722b0ea87ff8f434f1d9dbf5f Mon Sep 17 00:00:00 2001 From: PSeitz Date: Wed, 15 Oct 2025 17:39:53 +0200 Subject: [PATCH] Add Filtering for Term Aggregations (#2717) * Add Filtering for Term Aggregations Closes #2702 * add AggregationsSegmentCtx memory consumption --------- Co-authored-by: Pascal Seitz --- src/aggregation/agg_data.rs | 127 +++++++++++- src/aggregation/bucket/histogram/histogram.rs | 6 + src/aggregation/bucket/range.rs | 7 + src/aggregation/bucket/term_agg.rs | 185 +++++++++++++++++- src/aggregation/bucket/term_missing_agg.rs | 7 + src/aggregation/metric/cardinality.rs | 7 + src/aggregation/metric/mod.rs | 7 + src/aggregation/metric/top_hits.rs | 7 + src/query/boolean_query/boolean_weight.rs | 1 - 9 files changed, 341 insertions(+), 13 deletions(-) diff --git a/src/aggregation/agg_data.rs b/src/aggregation/agg_data.rs index 68bb73bf6..c7dc2e4e6 100644 --- a/src/aggregation/agg_data.rs +++ b/src/aggregation/agg_data.rs @@ -1,5 +1,8 @@ -use columnar::{Column, ColumnType}; +use columnar::{Column, ColumnType, StrColumn}; +use common::BitSet; +use rustc_hash::FxHashSet; use serde::Serialize; +use tantivy_fst::Regex; use crate::aggregation::accessor_helpers::{ get_all_ff_reader_or_empty, get_dynamic_columns, get_ff_reader, get_missing_val_as_u64_lenient, @@ -7,9 +10,9 @@ use crate::aggregation::accessor_helpers::{ }; use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations}; use crate::aggregation::bucket::{ - HistogramAggReqData, HistogramBounds, MissingTermAggReqData, RangeAggReqData, - SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector, TermMissingAgg, - TermsAggReqData, TermsAggregation, TermsAggregationInternal, + HistogramAggReqData, HistogramBounds, IncludeExcludeParam, MissingTermAggReqData, + RangeAggReqData, SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector, + TermMissingAgg, TermsAggReqData, TermsAggregation, TermsAggregationInternal, }; use crate::aggregation::metric::{ AverageAggregation, CardinalityAggReqData, CardinalityAggregationReq, CountAggregation, @@ -207,6 +210,45 @@ pub struct PerRequestAggSegCtx { } impl PerRequestAggSegCtx { + /// Estimate the memory consumption of this struct in bytes. + fn get_memory_consumption(&self) -> usize { + self.term_req_data + .iter() + .map(|b| b.as_ref().unwrap().get_memory_consumption()) + .sum::() + + self + .histogram_req_data + .iter() + .map(|b| b.as_ref().unwrap().get_memory_consumption()) + .sum::() + + self + .range_req_data + .iter() + .map(|b| b.as_ref().unwrap().get_memory_consumption()) + .sum::() + + self + .stats_metric_req_data + .iter() + .map(|t| t.get_memory_consumption()) + .sum::() + + self + .cardinality_req_data + .iter() + .map(|t| t.get_memory_consumption()) + .sum::() + + self + .top_hits_req_data + .iter() + .map(|t| t.get_memory_consumption()) + .sum::() + + self + .missing_term_req_data + .iter() + .map(|t| t.get_memory_consumption()) + .sum::() + + self.agg_tree.len() * std::mem::size_of::() + } + pub fn get_name(&self, node: &AggRefNode) -> &str { let idx = node.idx_in_req_data; let kind = node.kind; @@ -277,6 +319,8 @@ pub(crate) fn build_segment_agg_collectors( collectors.push(build_segment_agg_collector(req, node)?); } + req.limits + .add_memory_consumed(req.per_request.get_memory_consumption() as u64)?; // Single collector special case if collectors.len() == 1 { return Ok(collectors.pop().unwrap()); @@ -781,6 +825,19 @@ fn build_terms_or_cardinality_nodes( let children = build_children(sub_aggs, reader, segment_ordinal, data)?; let (idx, kind) = match req { TermsOrCardinalityRequest::Terms(ref req) => { + let mut allowed_term_ids = None; + if req.include.is_some() || req.exclude.is_some() { + if column_type != ColumnType::Str { + // Skip non-string columns entirely when filtering is requested. + // When excluding, the behavior could be to include non-string values + continue; + } + let str_col = str_dict_column + .as_ref() + .expect("str_dict_column must exist for string column"); + allowed_term_ids = + build_allowed_term_ids_for_str(str_col, &req.include, &req.exclude)?; + }; let idx_in_req_data = data.push_term_req_data(TermsAggReqData { accessor, column_type, @@ -788,11 +845,11 @@ fn build_terms_or_cardinality_nodes( missing_value_for_accessor, column_block_accessor: Default::default(), name: agg_name.to_string(), - field_type: column_type, req: TermsAggregationInternal::from_req(req), // Will be filled later when building collectors sub_aggregation_blueprint: None, sug_aggregations: sub_aggs.clone(), + allowed_term_ids, }); (idx_in_req_data, AggKind::Terms) } @@ -819,6 +876,66 @@ fn build_terms_or_cardinality_nodes( Ok(nodes) } +/// Builds a single BitSet of allowed term ordinals for a string dictionary column according to +/// include/exclude parameters. +fn build_allowed_term_ids_for_str( + str_col: &StrColumn, + include: &Option, + exclude: &Option, +) -> crate::Result> { + let mut allowed: Option = None; + let num_terms = str_col.dictionary().num_terms() as u32; + if let Some(include) = include { + // add matches + allowed = Some(BitSet::with_max_value(num_terms)); + let allowed = allowed.as_mut().unwrap(); + for_each_matching_term_ord(str_col, include, |ord| allowed.insert(ord))?; + }; + + if let Some(exclude) = exclude { + if allowed.is_none() { + // Start with all terms allowed + allowed = Some(BitSet::with_max_value_and_full(num_terms)); + } + let allowed = allowed.as_mut().unwrap(); + for_each_matching_term_ord(str_col, exclude, |ord| allowed.remove(ord))?; + } + + Ok(allowed) +} + +/// Apply a callback to each matching term ordinal for the given include/exclude parameter. +fn for_each_matching_term_ord( + str_col: &StrColumn, + param: &IncludeExcludeParam, + mut cb: impl FnMut(u32), +) -> crate::Result<()> { + match param { + IncludeExcludeParam::Regex(pattern) => { + let re = Regex::new(pattern).map_err(|e| { + crate::TantivyError::InvalidArgument(format!("Invalid regex `{}`: {}", pattern, e)) + })?; + // TODO: we can handle patterns like `^prefix.*` more efficiently + let mut stream = str_col.dictionary().search(re).into_stream()?; + while stream.advance() { + cb(stream.term_ord() as u32); + } + } + IncludeExcludeParam::Values(values) => { + let set: FxHashSet<&str> = values.iter().map(|s| s.as_str()).collect(); + let mut stream = str_col.dictionary().stream()?; + while stream.advance() { + if let Ok(key_str) = std::str::from_utf8(stream.key()) { + if set.contains(key_str) { + cb(stream.term_ord() as u32); + } + } + } + } + } + Ok(()) +} + /// Convert the aggregation tree to something serializable and easy to read. #[derive(Serialize, Debug, Clone, PartialEq, Eq)] pub struct AggTreeViewNode { diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs index 428f2a0a6..5a52f7cf0 100644 --- a/src/aggregation/bucket/histogram/histogram.rs +++ b/src/aggregation/bucket/histogram/histogram.rs @@ -42,6 +42,12 @@ pub struct HistogramAggReqData { /// The offset used to calculate the bucket position. pub offset: f64, } +impl HistogramAggReqData { + /// Estimate the memory consumption of this struct in bytes. + pub fn get_memory_consumption(&self) -> usize { + std::mem::size_of::() + } +} /// Histogram is a bucket aggregation, where buckets are created dynamically for given `interval`. /// Each document value is rounded down to its bucket. diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs index a2b092257..121528450 100644 --- a/src/aggregation/bucket/range.rs +++ b/src/aggregation/bucket/range.rs @@ -31,6 +31,13 @@ pub struct RangeAggReqData { pub name: String, } +impl RangeAggReqData { + /// Estimate the memory consumption of this struct in bytes. + pub fn get_memory_consumption(&self) -> usize { + std::mem::size_of::() + } +} + /// Provide user-defined buckets to aggregate on. /// /// Two special buckets will automatically be created to cover the whole range of values. diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index 2582f6a8f..ae2e0d87c 100644 --- a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg.rs @@ -7,6 +7,7 @@ use columnar::{ Column, ColumnBlockAccessor, ColumnType, Dictionary, MonotonicallyMappableToU128, MonotonicallyMappableToU64, NumericalValue, StrColumn, }; +use common::BitSet; use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; @@ -38,8 +39,6 @@ pub struct TermsAggReqData { pub missing_value_for_accessor: Option, /// The column block accessor to access the fast field values. pub column_block_accessor: ColumnBlockAccessor, - /// The type of the fast field. - pub field_type: ColumnType, /// Note: sub_aggregation_blueprint is filled later when building collectors pub sub_aggregation_blueprint: Option>, /// Used to build the correct nested result when we have an empty result. @@ -48,6 +47,21 @@ pub struct TermsAggReqData { pub name: String, /// The normalized term aggregation request. pub req: TermsAggregationInternal, + /// Preloaded allowed term ords (string columns only). If set, only ords present are collected. + pub allowed_term_ids: Option, +} + +impl TermsAggReqData { + /// Estimate the memory consumption of this struct in bytes. + pub fn get_memory_consumption(&self) -> usize { + std::mem::size_of::() + + std::mem::size_of::() + + self + .allowed_term_ids + .as_ref() + .map(|bs| bs.len() / 8) + .unwrap_or(0) + } } /// Creates a bucket for every unique term and counts the number of occurrences. @@ -120,6 +134,68 @@ pub struct TermsAggReqData { /// } /// ``` +#[derive(Clone, Debug, PartialEq)] +pub enum IncludeExcludeParam { + /// A single string pattern is treated as regex. + Regex(String), + /// An array of strings is treated as exact values. + Values(Vec), +} + +impl Serialize for IncludeExcludeParam { + fn serialize(&self, serializer: S) -> Result + where S: serde::Serializer { + match self { + IncludeExcludeParam::Regex(s) => serializer.serialize_str(s), + IncludeExcludeParam::Values(v) => v.serialize(serializer), + } + } +} + +// Custom deserializer to accept either a single string (regex) or an array of strings (values). +impl<'de> Deserialize<'de> for IncludeExcludeParam { + fn deserialize(deserializer: D) -> Result + where D: serde::Deserializer<'de> { + use serde::de::{self, SeqAccess, Visitor}; + struct IncludeExcludeVisitor; + + impl<'de> Visitor<'de> for IncludeExcludeVisitor { + type Value = IncludeExcludeParam; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a string (regex) or an array of strings") + } + + fn visit_str(self, v: &str) -> Result + where E: de::Error { + Ok(IncludeExcludeParam::Regex(v.to_string())) + } + + fn visit_borrowed_str(self, v: &'de str) -> Result + where E: de::Error { + Ok(IncludeExcludeParam::Regex(v.to_string())) + } + + fn visit_string(self, v: String) -> Result + where E: de::Error { + Ok(IncludeExcludeParam::Regex(v)) + } + + fn visit_seq(self, mut seq: A) -> Result + where A: SeqAccess<'de> { + let mut values: Vec = Vec::new(); + while let Some(elem) = seq.next_element::()? { + values.push(elem); + } + Ok(IncludeExcludeParam::Values(values)) + } + } + + deserializer.deserialize_any(IncludeExcludeVisitor) + } +} + +/// The terms aggregation allows you to group documents by unique values of a field. #[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)] pub struct TermsAggregation { /// The field to aggregate on. @@ -189,6 +265,13 @@ pub struct TermsAggregation { /// add text. #[serde(skip_serializing_if = "Option::is_none", default)] pub missing: Option, + + /// Include terms by either regex (single string) or exact values (array). + #[serde(skip_serializing_if = "Option::is_none", default)] + pub include: Option, + /// Exclude terms by either regex (single string) or exact values (array). + #[serde(skip_serializing_if = "Option::is_none", default)] + pub exclude: Option, } /// Same as TermsAggregation, but with populated defaults. @@ -330,6 +413,11 @@ impl SegmentAggregationCollector for SegmentTermCollector { } for term_id in req_data.column_block_accessor.iter_vals() { + if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() { + if !allowed_bs.contains(term_id as u32) { + continue; + } + } let entry = self.term_buckets.entries.entry(term_id).or_default(); *entry += 1; } @@ -339,6 +427,11 @@ impl SegmentAggregationCollector for SegmentTermCollector { .column_block_accessor .iter_docid_vals(docs, &req_data.accessor) { + if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() { + if !allowed_bs.contains(term_id as u32) { + continue; + } + } let sub_aggregations = self .term_buckets .sub_aggs @@ -375,11 +468,11 @@ impl SegmentTermCollector { node: &AggRefNode, ) -> crate::Result { let terms_req_data = req_data.get_term_req_data(node.idx_in_req_data); - let field_type = terms_req_data.field_type; + let column_type = terms_req_data.column_type; let accessor_idx = node.idx_in_req_data; - if field_type == ColumnType::Bytes { + if column_type == ColumnType::Bytes { return Err(TantivyError::InvalidArgument(format!( - "terms aggregation is not supported for column type {field_type:?}" + "terms aggregation is not supported for column type {column_type:?}" ))); } let term_buckets = TermBuckets::default(); @@ -552,13 +645,20 @@ impl SegmentTermCollector { let mut stream = term_dict.stream()?; let empty_sub_aggregation = IntermediateAggregationResults::empty_from_req(&term_req.sug_aggregations); - while let Some((key, _ord)) = stream.next() { + while stream.advance() { if dict.len() >= term_req.req.segment_size as usize { break; } + // Respect allowed filters if present + if let Some(allowed_bs) = term_req.allowed_term_ids.as_ref() { + if !allowed_bs.contains(stream.term_ord() as u32) { + continue; + } + } + let key = IntermediateKey::Str( - std::str::from_utf8(key) + std::str::from_utf8(stream.key()) .map_err(|utf8_err| DataCorruption::comment_only(utf8_err.to_string()))? .to_string(), ); @@ -751,6 +851,77 @@ mod tests { ); assert_eq!(res["my_texts"]["sum_other_doc_count"], 1); + // include filter: only terma and termc + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_texts": { + "terms": { + "field": "string_id", + "include": ["terma", "termc"], + }, + } + })) + .unwrap(); + + let res = exec_request(agg_req, &index)?; + assert_eq!(res["my_texts"]["buckets"][0]["key"], "terma"); + assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 5); + assert_eq!(res["my_texts"]["buckets"][1]["key"], "termc"); + assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 1); + assert_eq!(res["my_texts"]["sum_other_doc_count"], 0); + + // exclude filter: remove termc + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_texts": { + "terms": { + "field": "string_id", + "exclude": ["termc"], + }, + } + })) + .unwrap(); + + let res = exec_request(agg_req, &index)?; + assert_eq!(res["my_texts"]["buckets"][0]["key"], "terma"); + assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 5); + assert_eq!(res["my_texts"]["buckets"][1]["key"], "termb"); + assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 2); + assert_eq!(res["my_texts"]["sum_other_doc_count"], 0); + + // include regex (single string): only termb + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_texts": { + "terms": { + "field": "string_id", + "include": "termb", + }, + } + })) + .unwrap(); + + let res = exec_request(agg_req, &index)?; + assert_eq!(res["my_texts"]["buckets"][0]["key"], "termb"); + assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 2); + assert_eq!(res["my_texts"]["sum_other_doc_count"], 0); + + // include regex (term.*) with exclude regex (termc): expect terma and termb + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_texts": { + "terms": { + "field": "string_id", + "include": "term.*", + "exclude": "termc", + }, + } + })) + .unwrap(); + + let res = exec_request(agg_req, &index)?; + assert_eq!(res["my_texts"]["buckets"][0]["key"], "terma"); + assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 5); + assert_eq!(res["my_texts"]["buckets"][1]["key"], "termb"); + assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 2); + assert_eq!(res["my_texts"]["sum_other_doc_count"], 0); + // test min_doc_count let agg_req: Aggregations = serde_json::from_value(json!({ "my_texts": { diff --git a/src/aggregation/bucket/term_missing_agg.rs b/src/aggregation/bucket/term_missing_agg.rs index 6f99af968..66f39927a 100644 --- a/src/aggregation/bucket/term_missing_agg.rs +++ b/src/aggregation/bucket/term_missing_agg.rs @@ -28,6 +28,13 @@ pub struct MissingTermAggReqData { pub req: TermsAggregation, } +impl MissingTermAggReqData { + /// Estimate the memory consumption of this struct in bytes. + pub fn get_memory_consumption(&self) -> usize { + std::mem::size_of::() + } +} + /// The specialized missing term aggregation. #[derive(Default, Debug, Clone)] pub struct TermMissingAgg { diff --git a/src/aggregation/metric/cardinality.rs b/src/aggregation/metric/cardinality.rs index 8331b3ab3..8f3bdd3e5 100644 --- a/src/aggregation/metric/cardinality.rs +++ b/src/aggregation/metric/cardinality.rs @@ -114,6 +114,13 @@ pub struct CardinalityAggReqData { pub req: CardinalityAggregationReq, } +impl CardinalityAggReqData { + /// Estimate the memory consumption of this struct in bytes. + pub fn get_memory_consumption(&self) -> usize { + std::mem::size_of::() + } +} + impl CardinalityAggregationReq { /// Creates a new [`CardinalityAggregationReq`] instance from a field name. pub fn from_field_name(field_name: String) -> Self { diff --git a/src/aggregation/metric/mod.rs b/src/aggregation/metric/mod.rs index 6342b2045..3537af8a6 100644 --- a/src/aggregation/metric/mod.rs +++ b/src/aggregation/metric/mod.rs @@ -67,6 +67,13 @@ pub struct MetricAggReqData { pub name: String, } +impl MetricAggReqData { + /// Estimate the memory consumption of this struct in bytes. + pub fn get_memory_consumption(&self) -> usize { + std::mem::size_of::() + } +} + /// Single-metric aggregations use this common result structure. /// /// Main reason to wrap it in value is to match elasticsearch output structure. diff --git a/src/aggregation/metric/top_hits.rs b/src/aggregation/metric/top_hits.rs index e59f9b210..8156a1b66 100644 --- a/src/aggregation/metric/top_hits.rs +++ b/src/aggregation/metric/top_hits.rs @@ -37,6 +37,13 @@ pub struct TopHitsAggReqData { pub req: TopHitsAggregationReq, } +impl TopHitsAggReqData { + /// Estimate the memory consumption of this struct in bytes. + pub fn get_memory_consumption(&self) -> usize { + std::mem::size_of::() + } +} + /// # Top Hits /// /// The top hits aggregation is a useful tool to answer questions like: diff --git a/src/query/boolean_query/boolean_weight.rs b/src/query/boolean_query/boolean_weight.rs index 1b72c44d2..5dbd5ea44 100644 --- a/src/query/boolean_query/boolean_weight.rs +++ b/src/query/boolean_query/boolean_weight.rs @@ -1,4 +1,3 @@ -use core::num; use std::collections::HashMap; use crate::docset::COLLECT_BLOCK_BUFFER_LEN;