feat: added filter aggregation (#2711)

* Initial impl

* Added `Filter` impl in `build_single_agg_segment_collector_with_reader` + Added tests

* Added `Filter(FilterBucketResult)` + Made tests work.

* Fixed type issues.

* Fixed a test.

* 8a7a73a: Pass `segment_reader`

* Added more tests.

* Improved parsing + tests

* refactoring

* Added more tests.

* refactoring: moved parsing code under QueryParser

* Use Tantivy syntax instead of ES

* Added a sanity check test.

* Simplified impl + tests

* Added back tests in a more maintainable way

* nits.

* nits

* implemented very simple fast-path

* improved a comment

* implemented fast field support

* Used `BoundsRange`

* Improved fast field impl + tests

* Simplified execution.

* Fixed exports + nits

* Improved the tests to check the expected result.

* Improved test by checking the whole result JSON

* Removed brittle perf checks.

* Added efficiency verification tests.

* Added one more efficiency check test.

* Improved the efficiency tests.

* Removed unnecessary parsing code + added direct Query obj

* Fixed tests.

* Improved tests

* Fixed code structure

* Fixed lint issues

* nits.

* nits

* nits.

* nits.

* nits.

* Added an example

* Fixed PR comments.

* Applied PR comments + nits

* nits.

* Improved the code.

* Fixed a perf issue.

* Added batch processing.

* Made the example more interesting

* Fixed bucket count

* Renamed Direct to CustomQuery

* Fixed lint issues.

* No need for scorer to be an `Option`

* nits

* Used BitSet

* Added an optimization for AllQuery

* Fixed merge issues.

* Fixed lint issues.

* Added benchmark for FILTER

* Removed the Option wrapper.

* nits.

* Applied PR comments.

* Fixed the AllQuery optimization

* Applied PR comments.

* feat: used `erased_serde` to allow filter query to be serialized

* further improved a comment

* Added back tests.

* removed an unused method

* removed an unused method

* Added documentation

* nits.

* Added query builder.

* Fixed a comment.

* Applied PR comments.

* Fixed doctest issues.

* Added ser/de

* Removed bench in test

* Fixed a lint issue.
Author: Moe
Date: 2025-11-18 11:54:31 -08:00
Committed by: GitHub
Parent: 5277367cb0
Commit: 70e591e230
15 changed files with 2248 additions and 40 deletions

View File

@@ -69,6 +69,7 @@ hyperloglogplus = { version = "0.4.1", features = ["const-loop"] }
futures-util = { version = "0.3.28", optional = true }
futures-channel = { version = "0.3.28", optional = true }
fnv = "1.0.7"
typetag = "0.2.21"

[target.'cfg(windows)'.dependencies]
winapi = "0.3.9"
@@ -87,7 +88,7 @@ more-asserts = "0.3.1"
rand_distr = "0.4.3"
time = { version = "0.3.10", features = ["serde-well-known", "macros"] }
postcard = { version = "1.0.4", features = [
"use-std",
], default-features = false }

[target.'cfg(not(windows))'.dev-dependencies]
@@ -175,4 +176,3 @@ harness = false
[[bench]]
name = "and_or_queries"
harness = false
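The new `typetag` dependency backs the change noted in the commit list above ("used `erased_serde` to allow filter query to be serialized"). The trait below is not the one introduced by this PR; it is only a hedged illustration of the general pattern typetag enables, namely serializing and deserializing a boxed trait object by tagging it with its concrete type name.

use serde::{Deserialize, Serialize};

// Illustrative trait only -- the actual filter-query trait in this PR differs;
// this just demonstrates the typetag mechanism.
#[typetag::serde]
trait FilterSource: std::fmt::Debug {
    fn describe(&self) -> String;
}

#[derive(Debug, Serialize, Deserialize)]
struct TermFilter {
    field: String,
    value: String,
}

#[typetag::serde]
impl FilterSource for TermFilter {
    fn describe(&self) -> String {
        format!("{}:{}", self.field, self.value)
    }
}

fn main() {
    let filter: Box<dyn FilterSource> = Box::new(TermFilter {
        field: "category".to_string(),
        value: "electronics".to_string(),
    });
    // Serializes as an externally tagged object:
    // {"TermFilter":{"field":"category","value":"electronics"}}
    let json = serde_json::to_string(&filter).unwrap();
    let back: Box<dyn FilterSource> = serde_json::from_str(&json).unwrap();
    assert_eq!(back.describe(), "category:electronics");
}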

View File

@@ -74,6 +74,12 @@ fn bench_agg(mut group: InputGroup<Index>) {
register!(group, histogram_with_term_agg_few);
register!(group, avg_and_range_with_avg_sub_agg);

// Filter aggregation benchmarks
register!(group, filter_agg_all_query_count_agg);
register!(group, filter_agg_term_query_count_agg);
register!(group, filter_agg_all_query_with_sub_aggs);
register!(group, filter_agg_term_query_with_sub_aggs);

group.run();
}
@@ -472,3 +478,61 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
Ok(index)
}
// Filter aggregation benchmarks
fn filter_agg_all_query_count_agg(index: &Index) {
let agg_req = json!({
"filtered": {
"filter": "*",
"aggs": {
"count": { "value_count": { "field": "score" } }
}
}
});
execute_agg(index, agg_req);
}
fn filter_agg_term_query_count_agg(index: &Index) {
let agg_req = json!({
"filtered": {
"filter": "text:cool",
"aggs": {
"count": { "value_count": { "field": "score" } }
}
}
});
execute_agg(index, agg_req);
}
fn filter_agg_all_query_with_sub_aggs(index: &Index) {
let agg_req = json!({
"filtered": {
"filter": "*",
"aggs": {
"avg_score": { "avg": { "field": "score" } },
"stats_score": { "stats": { "field": "score_f64" } },
"terms_text": {
"terms": { "field": "text_few_terms" }
}
}
}
});
execute_agg(index, agg_req);
}
fn filter_agg_term_query_with_sub_aggs(index: &Index) {
let agg_req = json!({
"filtered": {
"filter": "text:cool",
"aggs": {
"avg_score": { "avg": { "field": "score" } },
"stats_score": { "stats": { "field": "score_f64" } },
"terms_text": {
"terms": { "field": "text_few_terms" }
}
}
}
});
execute_agg(index, agg_req);
}
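These benchmark functions all delegate to an `execute_agg` helper that already exists elsewhere in the benchmark file and is not part of this hunk. As a rough sketch only (the real helper may differ), such a helper would parse the JSON request into `Aggregations`, run an `AggregationCollector` over `AllQuery`, and discard the result:

use serde_json::Value;
use tantivy::aggregation::agg_req::Aggregations;
use tantivy::aggregation::AggregationCollector;
use tantivy::query::AllQuery;
use tantivy::Index;

// Hypothetical stand-in for the helper used by the benchmarks above.
fn execute_agg(index: &Index, agg_req: Value) {
    let agg: Aggregations = serde_json::from_value(agg_req).unwrap();
    // Default::default() supplies default limits and tokenizers.
    let collector = AggregationCollector::from_aggs(agg, Default::default());
    let reader = index.reader().unwrap();
    let searcher = reader.searcher();
    let _res = searcher.search(&AllQuery, &collector).unwrap();
}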

View File

@@ -0,0 +1,212 @@
// # Filter Aggregation Example
//
// This example demonstrates filter aggregations - creating buckets of documents
// matching specific queries, with nested aggregations computed on each bucket.
//
// Filter aggregations are useful for computing metrics on different subsets of
// your data in a single query, like "average price overall + average price for
// electronics + count of in-stock items".
use serde_json::json;
use tantivy::aggregation::agg_req::Aggregations;
use tantivy::aggregation::AggregationCollector;
use tantivy::query::AllQuery;
use tantivy::schema::{Schema, FAST, INDEXED, TEXT};
use tantivy::{doc, Index};
fn main() -> tantivy::Result<()> {
// Create a simple product schema
let mut schema_builder = Schema::builder();
schema_builder.add_text_field("category", TEXT | FAST);
schema_builder.add_text_field("brand", TEXT | FAST);
schema_builder.add_u64_field("price", FAST);
schema_builder.add_f64_field("rating", FAST);
schema_builder.add_bool_field("in_stock", FAST | INDEXED);
let schema = schema_builder.build();
// Create index and add sample products
let index = Index::create_in_ram(schema.clone());
let mut writer = index.writer(50_000_000)?;
writer.add_document(doc!(
schema.get_field("category")? => "electronics",
schema.get_field("brand")? => "apple",
schema.get_field("price")? => 999u64,
schema.get_field("rating")? => 4.5f64,
schema.get_field("in_stock")? => true
))?;
writer.add_document(doc!(
schema.get_field("category")? => "electronics",
schema.get_field("brand")? => "samsung",
schema.get_field("price")? => 799u64,
schema.get_field("rating")? => 4.2f64,
schema.get_field("in_stock")? => true
))?;
writer.add_document(doc!(
schema.get_field("category")? => "clothing",
schema.get_field("brand")? => "nike",
schema.get_field("price")? => 120u64,
schema.get_field("rating")? => 4.1f64,
schema.get_field("in_stock")? => false
))?;
writer.add_document(doc!(
schema.get_field("category")? => "books",
schema.get_field("brand")? => "penguin",
schema.get_field("price")? => 25u64,
schema.get_field("rating")? => 4.8f64,
schema.get_field("in_stock")? => true
))?;
writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
// Example 1: Basic filter with metric aggregation
println!("=== Example 1: Electronics average price ===");
let agg_req = json!({
"electronics": {
"filter": "category:electronics",
"aggs": {
"avg_price": { "avg": { "field": "price" } }
}
}
});
let agg: Aggregations = serde_json::from_value(agg_req)?;
let collector = AggregationCollector::from_aggs(agg, Default::default());
let result = searcher.search(&AllQuery, &collector)?;
let expected = json!({
"electronics": {
"doc_count": 2,
"avg_price": { "value": 899.0 }
}
});
assert_eq!(serde_json::to_value(&result)?, expected);
println!("{}\n", serde_json::to_string_pretty(&result)?);
// Example 2: Multiple independent filters
println!("=== Example 2: Multiple filters in one query ===");
let agg_req = json!({
"electronics": {
"filter": "category:electronics",
"aggs": { "avg_price": { "avg": { "field": "price" } } }
},
"in_stock": {
"filter": "in_stock:true",
"aggs": { "count": { "value_count": { "field": "brand" } } }
},
"high_rated": {
"filter": "rating:[4.5 TO *]",
"aggs": { "count": { "value_count": { "field": "brand" } } }
}
});
let agg: Aggregations = serde_json::from_value(agg_req)?;
let collector = AggregationCollector::from_aggs(agg, Default::default());
let result = searcher.search(&AllQuery, &collector)?;
let expected = json!({
"electronics": {
"doc_count": 2,
"avg_price": { "value": 899.0 }
},
"in_stock": {
"doc_count": 3,
"count": { "value": 3.0 }
},
"high_rated": {
"doc_count": 2,
"count": { "value": 2.0 }
}
});
assert_eq!(serde_json::to_value(&result)?, expected);
println!("{}\n", serde_json::to_string_pretty(&result)?);
// Example 3: Nested filters - progressive refinement
println!("=== Example 3: Nested filters ===");
let agg_req = json!({
"in_stock": {
"filter": "in_stock:true",
"aggs": {
"electronics": {
"filter": "category:electronics",
"aggs": {
"expensive": {
"filter": "price:[800 TO *]",
"aggs": {
"avg_rating": { "avg": { "field": "rating" } }
}
}
}
}
}
}
});
let agg: Aggregations = serde_json::from_value(agg_req)?;
let collector = AggregationCollector::from_aggs(agg, Default::default());
let result = searcher.search(&AllQuery, &collector)?;
let expected = json!({
"in_stock": {
"doc_count": 3, // apple, samsung, penguin
"electronics": {
"doc_count": 2, // apple, samsung
"expensive": {
"doc_count": 1, // only apple (999)
"avg_rating": { "value": 4.5 }
}
}
}
});
assert_eq!(serde_json::to_value(&result)?, expected);
println!("{}\n", serde_json::to_string_pretty(&result)?);
// Example 4: Filter with sub-aggregation (terms)
println!("=== Example 4: Filter with terms sub-aggregation ===");
let agg_req = json!({
"electronics": {
"filter": "category:electronics",
"aggs": {
"by_brand": {
"terms": { "field": "brand" },
"aggs": {
"avg_price": { "avg": { "field": "price" } }
}
}
}
}
});
let agg: Aggregations = serde_json::from_value(agg_req)?;
let collector = AggregationCollector::from_aggs(agg, Default::default());
let result = searcher.search(&AllQuery, &collector)?;
let expected = json!({
"electronics": {
"doc_count": 2,
"by_brand": {
"buckets": [
{
"key": "samsung",
"doc_count": 1,
"avg_price": { "value": 799.0 }
},
{
"key": "apple",
"doc_count": 1,
"avg_price": { "value": 999.0 }
}
],
"sum_other_doc_count": 0,
"doc_count_error_upper_bound": 0
}
}
});
assert_eq!(serde_json::to_value(&result)?, expected);
println!("{}", serde_json::to_string_pretty(&result)?);
Ok(())
}

View File

@@ -10,9 +10,10 @@ use crate::aggregation::accessor_helpers::{
};
use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations};
use crate::aggregation::bucket::{
-HistogramAggReqData, HistogramBounds, IncludeExcludeParam, MissingTermAggReqData,
-RangeAggReqData, SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector,
-TermMissingAgg, TermsAggReqData, TermsAggregation, TermsAggregationInternal,
+FilterAggReqData, HistogramAggReqData, HistogramBounds, IncludeExcludeParam,
+MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector, SegmentHistogramCollector,
+SegmentRangeCollector, SegmentTermCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
+TermsAggregationInternal,
};
use crate::aggregation::metric::{
AverageAggregation, CardinalityAggReqData, CardinalityAggregationReq, CountAggregation,
@@ -24,7 +25,7 @@ use crate::aggregation::metric::{
use crate::aggregation::segment_agg_result::{
GenericSegmentAggregationResultsCollector, SegmentAggregationCollector,
};
-use crate::aggregation::{f64_to_fastfield_u64, AggregationLimitsGuard, Key};
+use crate::aggregation::{f64_to_fastfield_u64, AggContextParams, Key};
use crate::{SegmentOrdinal, SegmentReader};

#[derive(Default)]
@@ -33,7 +34,7 @@ use crate::{SegmentOrdinal, SegmentReader};
pub struct AggregationsSegmentCtx {
/// Request data for each aggregation type.
pub per_request: PerRequestAggSegCtx,
-pub limits: AggregationLimitsGuard,
+pub context: AggContextParams,
}

impl AggregationsSegmentCtx {
@@ -67,6 +68,10 @@ impl AggregationsSegmentCtx {
self.per_request.range_req_data.push(Some(Box::new(data)));
self.per_request.range_req_data.len() - 1
}
pub(crate) fn push_filter_req_data(&mut self, data: FilterAggReqData) -> usize {
self.per_request.filter_req_data.push(Some(Box::new(data)));
self.per_request.filter_req_data.len() - 1
}
#[inline]
pub(crate) fn get_term_req_data(&self, idx: usize) -> &TermsAggReqData {
@@ -102,6 +107,12 @@ impl AggregationsSegmentCtx {
.as_deref()
.expect("range_req_data slot is empty (taken)")
}
#[inline]
pub(crate) fn get_filter_req_data(&self, idx: usize) -> &FilterAggReqData {
self.per_request.filter_req_data[idx]
.as_deref()
.expect("filter_req_data slot is empty (taken)")
}
// ---------- mutable getters ----------
@@ -179,6 +190,21 @@ impl AggregationsSegmentCtx {
debug_assert!(self.per_request.range_req_data[idx].is_none());
self.per_request.range_req_data[idx] = Some(value);
}
/// Move out the boxed Filter request at `idx`, leaving `None`.
#[inline]
pub(crate) fn take_filter_req_data(&mut self, idx: usize) -> Box<FilterAggReqData> {
self.per_request.filter_req_data[idx]
.take()
.expect("filter_req_data slot is empty (taken)")
}
/// Put back a Filter request into an empty slot at `idx`.
#[inline]
pub(crate) fn put_back_filter_req_data(&mut self, idx: usize, value: Box<FilterAggReqData>) {
debug_assert!(self.per_request.filter_req_data[idx].is_none());
self.per_request.filter_req_data[idx] = Some(value);
}
}

/// Each type of aggregation has its own request data struct. This struct holds
@@ -196,6 +222,8 @@ pub struct PerRequestAggSegCtx {
pub histogram_req_data: Vec<Option<Box<HistogramAggReqData>>>,
/// RangeAggReqData contains the request data for a range aggregation.
pub range_req_data: Vec<Option<Box<RangeAggReqData>>>,
/// FilterAggReqData contains the request data for a filter aggregation.
pub filter_req_data: Vec<Option<Box<FilterAggReqData>>>,
/// Shared by avg, min, max, sum, stats, extended_stats, count
pub stats_metric_req_data: Vec<MetricAggReqData>,
/// CardinalityAggReqData contains the request data for a cardinality aggregation.
@@ -226,6 +254,11 @@ impl PerRequestAggSegCtx {
.iter()
.map(|b| b.as_ref().unwrap().get_memory_consumption())
.sum::<usize>()
+ self
.filter_req_data
.iter()
.map(|b| b.as_ref().unwrap().get_memory_consumption())
.sum::<usize>()
+ self
.stats_metric_req_data
.iter()
@@ -277,6 +310,11 @@ impl PerRequestAggSegCtx {
.expect("range_req_data slot is empty (taken)") .expect("range_req_data slot is empty (taken)")
.name .name
.as_str(), .as_str(),
AggKind::Filter => self.filter_req_data[idx]
.as_deref()
.expect("filter_req_data slot is empty (taken)")
.name
.as_str(),
}
}
@@ -319,7 +357,8 @@ pub(crate) fn build_segment_agg_collectors(
collectors.push(build_segment_agg_collector(req, node)?);
}
-req.limits
+req.context
+.limits
.add_memory_consumed(req.per_request.get_memory_consumption() as u64)?;
// Single collector special case
if collectors.len() == 1 {
@@ -394,6 +433,9 @@ pub(crate) fn build_segment_agg_collector(
AggKind::Range => Ok(Box::new(SegmentRangeCollector::from_req_and_validate(
req, node,
)?)),
AggKind::Filter => Ok(Box::new(SegmentFilterCollector::from_req_and_validate(
req, node,
)?)),
}
}
@@ -423,6 +465,7 @@ pub enum AggKind {
Histogram,
DateHistogram,
Range,
Filter,
}

impl AggKind {
@@ -437,6 +480,7 @@ impl AggKind {
AggKind::Histogram => "Histogram",
AggKind::DateHistogram => "DateHistogram",
AggKind::Range => "Range",
AggKind::Filter => "Filter",
}
}
}
@@ -446,11 +490,11 @@ pub(crate) fn build_aggregations_data_from_req(
aggs: &Aggregations,
reader: &SegmentReader,
segment_ordinal: SegmentOrdinal,
-limits: AggregationLimitsGuard,
+context: AggContextParams,
) -> crate::Result<AggregationsSegmentCtx> {
let mut data = AggregationsSegmentCtx {
per_request: Default::default(),
-limits,
+context,
};

for (name, agg) in aggs.iter() {
@@ -686,6 +730,36 @@ fn build_nodes(
children,
}])
}
AggregationVariants::Filter(filter_req) => {
// Build the query and evaluator upfront
let schema = reader.schema();
let tokenizers = &data.context.tokenizers;
let query = filter_req.parse_query(&schema, tokenizers)?;
let evaluator = crate::aggregation::bucket::DocumentQueryEvaluator::new(
query,
schema.clone(),
reader,
)?;
// Pre-allocate buffer for batch filtering
let max_doc = reader.max_doc();
let buffer_capacity = crate::docset::COLLECT_BLOCK_BUFFER_LEN.min(max_doc as usize);
let matching_docs_buffer = Vec::with_capacity(buffer_capacity);
let idx_in_req_data = data.push_filter_req_data(FilterAggReqData {
name: agg_name.to_string(),
req: filter_req.clone(),
segment_reader: reader.clone(),
evaluator,
matching_docs_buffer,
});
let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?;
Ok(vec![AggRefNode {
kind: AggKind::Filter,
idx_in_req_data,
children,
}])
}
}
}

View File

@@ -32,7 +32,8 @@ use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};

use super::bucket::{
-DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation,
+DateHistogramAggregationReq, FilterAggregation, HistogramAggregation, RangeAggregation,
+TermsAggregation,
};
use super::metric::{
AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation,
@@ -130,6 +131,9 @@ pub enum AggregationVariants {
/// Put data into buckets of terms.
#[serde(rename = "terms")]
Terms(TermsAggregation),
/// Filter documents into a single bucket.
#[serde(rename = "filter")]
Filter(FilterAggregation),

// Metric aggregation types
/// Computes the average of the extracted values.
@@ -175,6 +179,7 @@ impl AggregationVariants {
AggregationVariants::Range(range) => vec![range.field.as_str()],
AggregationVariants::Histogram(histogram) => vec![histogram.field.as_str()],
AggregationVariants::DateHistogram(histogram) => vec![histogram.field.as_str()],
AggregationVariants::Filter(filter) => filter.get_fast_field_names(),
AggregationVariants::Average(avg) => vec![avg.field_name()],
AggregationVariants::Count(count) => vec![count.field_name()],
AggregationVariants::Max(max) => vec![max.field_name()],

View File

@@ -156,6 +156,8 @@ pub enum BucketResult {
/// The upper bound error for the doc count of each term.
doc_count_error_upper_bound: Option<u64>,
},
/// This is the filter result - a single bucket with sub-aggregations
Filter(FilterBucketResult),
}

impl BucketResult {
@@ -172,6 +174,11 @@ impl BucketResult {
sum_other_doc_count: _,
doc_count_error_upper_bound: _,
} => buckets.iter().map(|bucket| bucket.get_bucket_count()).sum(),
BucketResult::Filter(filter_result) => {
// Filter doesn't add to bucket count - it's not a user-facing bucket
// Only count sub-aggregation buckets
filter_result.sub_aggregations.get_bucket_count()
}
}
}
}
@@ -308,3 +315,25 @@ impl RangeBucketEntry {
1 + self.sub_aggregation.get_bucket_count()
}
}
/// This is the filter bucket result, which contains the document count and sub-aggregations.
///
/// # JSON Format
/// ```json
/// {
/// "electronics_only": {
/// "doc_count": 2,
/// "avg_price": {
/// "value": 150.0
/// }
/// }
/// }
/// ```
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct FilterBucketResult {
/// Number of documents in the filter bucket
pub doc_count: u64,
/// Sub-aggregation results
#[serde(flatten)]
pub sub_aggregations: AggregationResults,
}

View File

@@ -5,7 +5,6 @@ use crate::aggregation::agg_result::AggregationResults;
use crate::aggregation::buf_collector::DOC_BLOCK_SIZE;
use crate::aggregation::collector::AggregationCollector;
use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults;
-use crate::aggregation::segment_agg_result::AggregationLimitsGuard;
use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms};
use crate::aggregation::DistributedAggregationCollector;
use crate::query::{AllQuery, TermQuery};
@@ -128,10 +127,8 @@ fn test_aggregation_flushing(
.unwrap();
let agg_res: AggregationResults = if use_distributed_collector {
-let collector = DistributedAggregationCollector::from_aggs(
-agg_req.clone(),
-AggregationLimitsGuard::default(),
-);
+let collector =
+DistributedAggregationCollector::from_aggs(agg_req.clone(), Default::default());
let searcher = reader.searcher();
let intermediate_agg_result = searcher.search(&AllQuery, &collector).unwrap();

File diff suppressed because it is too large

View File

@@ -352,7 +352,10 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
let mem_delta = self.get_memory_consumption() - mem_pre;
if mem_delta > 0 {
-agg_data.limits.add_memory_consumed(mem_delta as u64)?;
+agg_data
+.context
+.limits
+.add_memory_consumed(mem_delta as u64)?;
}
Ok(())

View File

@@ -22,6 +22,7 @@
//! - [Range](RangeAggregation)
//! - [Terms](TermsAggregation)

mod filter;
mod histogram;
mod range;
mod term_agg;
@@ -30,6 +31,7 @@ mod term_missing_agg;
use std::collections::HashMap;
use std::fmt;

pub use filter::*;
pub use histogram::*;
pub use range::*;
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};

View File

@@ -339,7 +339,7 @@ impl SegmentRangeCollector {
})
.collect::<crate::Result<_>>()?;
-req_data.limits.add_memory_consumed(
+req_data.context.limits.add_memory_consumed(
buckets.len() as u64 * std::mem::size_of::<SegmentRangeAndBucketEntry>() as u64,
)?;

View File

@@ -443,7 +443,10 @@ impl SegmentAggregationCollector for SegmentTermCollector {
let mem_delta = self.get_memory_consumption() - mem_pre;
if mem_delta > 0 {
-agg_data.limits.add_memory_consumed(mem_delta as u64)?;
+agg_data
+.context
+.limits
+.add_memory_consumed(mem_delta as u64)?;
}
agg_data.put_back_term_req_data(self.accessor_idx, req_data);
@@ -2341,10 +2344,8 @@ mod tests {
let search = |idx: &Index,
agg_req: &Aggregations|
-> crate::Result<IntermediateAggregationResults> {
-let collector = DistributedAggregationCollector::from_aggs(
-agg_req.clone(),
-AggregationLimitsGuard::default(),
-);
+let collector =
+DistributedAggregationCollector::from_aggs(agg_req.clone(), Default::default());
let reader = idx.reader()?;
let searcher = reader.searcher();
let agg_res = searcher.search(&AllQuery, &collector)?;
@@ -2388,9 +2389,8 @@ mod tests {
let agg_res2 = search(&index2, &agg_req2)?;
agg_res.merge_fruits(agg_res2).unwrap();
-let agg_json = serde_json::to_value(
-&agg_res.into_final_result(agg_req2, AggregationLimitsGuard::default())?,
-)?;
+let agg_json =
+serde_json::to_value(&agg_res.into_final_result(agg_req2, Default::default())?)?;
// hosts:
let hosts = &agg_json["hosts"]["buckets"];

View File

@@ -2,7 +2,8 @@ use super::agg_req::Aggregations;
use super::agg_result::AggregationResults;
use super::buf_collector::BufAggregationCollector;
use super::intermediate_agg_result::IntermediateAggregationResults;
-use super::segment_agg_result::{AggregationLimitsGuard, SegmentAggregationCollector};
+use super::segment_agg_result::SegmentAggregationCollector;
+use super::AggContextParams;
use crate::aggregation::agg_data::{
build_aggregations_data_from_req, build_segment_agg_collectors_root, AggregationsSegmentCtx,
};
@@ -21,7 +22,7 @@ pub const DEFAULT_MEMORY_LIMIT: u64 = 500_000_000;
/// The collector collects all aggregations by the underlying aggregation request.
pub struct AggregationCollector {
agg: Aggregations,
-limits: AggregationLimitsGuard,
+context: AggContextParams,
}

impl AggregationCollector {
@@ -29,8 +30,8 @@ impl AggregationCollector {
///
/// Aggregation fails when the limits in `AggregationLimits` is exceeded. (memory limit and
/// bucket limit)
-pub fn from_aggs(agg: Aggregations, limits: AggregationLimitsGuard) -> Self {
-Self { agg, limits }
+pub fn from_aggs(agg: Aggregations, context: AggContextParams) -> Self {
+Self { agg, context }
}
}
@@ -44,7 +45,7 @@ impl AggregationCollector {
/// into the final `AggregationResults` via the `into_final_result()` method.
pub struct DistributedAggregationCollector {
agg: Aggregations,
-limits: AggregationLimitsGuard,
+context: AggContextParams,
}

impl DistributedAggregationCollector {
@@ -52,8 +53,8 @@ impl DistributedAggregationCollector {
///
/// Aggregation fails when the limits in `AggregationLimits` is exceeded. (memory limit and
/// bucket limit)
-pub fn from_aggs(agg: Aggregations, limits: AggregationLimitsGuard) -> Self {
-Self { agg, limits }
+pub fn from_aggs(agg: Aggregations, context: AggContextParams) -> Self {
+Self { agg, context }
}
}
@@ -71,7 +72,7 @@ impl Collector for DistributedAggregationCollector {
&self.agg,
reader,
segment_local_id,
-&self.limits,
+&self.context,
)
}
@@ -101,7 +102,7 @@ impl Collector for AggregationCollector {
&self.agg,
reader,
segment_local_id,
-&self.limits,
+&self.context,
)
}
@@ -114,7 +115,7 @@ impl Collector for AggregationCollector {
segment_fruits: Vec<<Self::Child as SegmentCollector>::Fruit>,
) -> crate::Result<Self::Fruit> {
let res = merge_fruits(segment_fruits)?;
-res.into_final_result(self.agg.clone(), self.limits.clone())
+res.into_final_result(self.agg.clone(), self.context.limits.clone())
}
}
@@ -146,10 +147,10 @@ impl AggregationSegmentCollector {
agg: &Aggregations,
reader: &SegmentReader,
segment_ordinal: SegmentOrdinal,
-limits: &AggregationLimitsGuard,
+context: &AggContextParams,
) -> crate::Result<Self> {
let mut agg_data =
-build_aggregations_data_from_req(agg, reader, segment_ordinal, limits.clone())?;
+build_aggregations_data_from_req(agg, reader, segment_ordinal, context.clone())?;
let result =
BufAggregationCollector::new(build_segment_agg_collectors_root(&mut agg_data)?);
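Because `from_aggs` now takes an `AggContextParams` rather than a bare `AggregationLimitsGuard`, existing callers have to wrap their limits. A minimal sketch of the two call styles, based on the `AggContextParams::new` constructor and `Default` derive shown later in this diff (the public re-export path of `AggContextParams` from the `aggregation` module is assumed here):

use tantivy::aggregation::agg_req::Aggregations;
use tantivy::aggregation::{AggContextParams, AggregationCollector};
use tantivy::Index;

fn build_collectors(index: &Index, agg: Aggregations) -> (AggregationCollector, AggregationCollector) {
    // Easiest migration: default limits plus a default TokenizerManager.
    let with_defaults = AggregationCollector::from_aggs(agg.clone(), Default::default());

    // Explicit context: reuse the index's tokenizers so that filter query
    // strings are parsed with the same analyzers used at indexing time.
    let context = AggContextParams::new(Default::default(), index.tokenizers().clone());
    let with_context = AggregationCollector::from_aggs(agg, context);

    (with_defaults, with_context)
}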

View File

@@ -24,7 +24,9 @@ use super::metric::{
};
use super::segment_agg_result::AggregationLimitsGuard;
use super::{format_date, AggregationError, Key, SerializedKey};
-use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry};
+use crate::aggregation::agg_result::{
+AggregationResults, BucketEntries, BucketEntry, FilterBucketResult,
+};
use crate::aggregation::bucket::TermsAggregationInternal;
use crate::aggregation::metric::CardinalityCollector;
use crate::TantivyError;
@@ -246,6 +248,10 @@ pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult
Cardinality(_) => IntermediateAggregationResult::Metric(
IntermediateMetricResult::Cardinality(CardinalityCollector::default()),
),
Filter(_) => IntermediateAggregationResult::Bucket(IntermediateBucketResult::Filter {
doc_count: 0,
sub_aggregations: IntermediateAggregationResults::default(),
}),
}
}
@@ -432,6 +438,13 @@ pub enum IntermediateBucketResult {
/// The term buckets
buckets: IntermediateTermBucketResult,
},
/// Filter aggregation - a single bucket with sub-aggregations
Filter {
/// Document count in the filter bucket
doc_count: u64,
/// Sub-aggregation results
sub_aggregations: IntermediateAggregationResults,
},
}

impl IntermediateBucketResult {
@@ -515,6 +528,18 @@ impl IntermediateBucketResult {
req.sub_aggregation(),
limits,
),
IntermediateBucketResult::Filter {
doc_count,
sub_aggregations,
} => {
// Convert sub-aggregation results to final format
let final_sub_aggregations = sub_aggregations
.into_final_result(req.sub_aggregation().clone(), limits.clone())?;
Ok(BucketResult::Filter(FilterBucketResult {
doc_count,
sub_aggregations: final_sub_aggregations,
}))
}
}
}
@@ -568,6 +593,19 @@ impl IntermediateBucketResult {
*buckets_left = buckets?;
}
(
IntermediateBucketResult::Filter {
doc_count: doc_count_left,
sub_aggregations: sub_aggs_left,
},
IntermediateBucketResult::Filter {
doc_count: doc_count_right,
sub_aggregations: sub_aggs_right,
},
) => {
*doc_count_left += doc_count_right;
sub_aggs_left.merge_fruits(sub_aggs_right)?;
}
(IntermediateBucketResult::Range(_), _) => {
panic!("try merge on different types")
}
@@ -577,6 +615,9 @@ impl IntermediateBucketResult {
(IntermediateBucketResult::Terms { .. }, _) => {
panic!("try merge on different types")
}
(IntermediateBucketResult::Filter { .. }, _) => {
panic!("try merge on different types")
}
}
Ok(())
}

View File

@@ -160,6 +160,28 @@ use itertools::Itertools;
use serde::de::{self, Visitor};
use serde::{Deserialize, Deserializer, Serialize};
use crate::tokenizer::TokenizerManager;
/// Context parameters for aggregation execution
///
/// This struct holds shared resources needed during aggregation execution:
/// - `limits`: Memory and bucket limits for the aggregation
/// - `tokenizers`: TokenizerManager for parsing query strings in filter aggregations
#[derive(Clone, Default)]
pub struct AggContextParams {
/// Aggregation limits (memory and bucket count)
pub limits: AggregationLimitsGuard,
/// Tokenizer manager for query string parsing
pub tokenizers: TokenizerManager,
}
impl AggContextParams {
/// Create new aggregation context parameters
pub fn new(limits: AggregationLimitsGuard, tokenizers: TokenizerManager) -> Self {
Self { limits, tokenizers }
}
}
fn parse_str_into_f64<E: de::Error>(value: &str) -> Result<f64, E> {
let parsed = value
.parse::<f64>()
@@ -390,7 +412,10 @@ mod tests {
query: Option<(&str, &str)>,
limits: AggregationLimitsGuard,
) -> crate::Result<Value> {
-let collector = AggregationCollector::from_aggs(agg_req, limits);
+let collector = AggregationCollector::from_aggs(
+agg_req,
+AggContextParams::new(limits, index.tokenizers().clone()),
+);
let reader = index.reader()?;
let searcher = reader.searcher();