mirror of
https://github.com/quickwit-oss/tantivy.git
synced 2025-12-23 02:29:57 +00:00
feat: added filter aggregation (#2711)
* Initial impl
* Added `Filter` impl in `build_single_agg_segment_collector_with_reader` + Added tests
* Added `Filter(FilterBucketResult)` + Made tests work.
* Fixed type issues.
* Fixed a test.
* 8a7a73a: Pass `segment_reader`
* Added more tests.
* Improved parsing + tests
* refactoring
* Added more tests.
* refactoring: moved parsing code under QueryParser
* Use Tantivy syntax instead of ES
* Added a sanity check test.
* Simplified impl + tests
* Added back tests in a more maintainable way
* nitz.
* nitz
* implemented very simple fast-path
* improved a comment
* implemented fast field support
* Used `BoundsRange`
* Improved fast field impl + tests
* Simplified execution.
* Fixed exports + nitz
* Improved the tests to check the expected result.
* Improved test by checking the whole result JSON
* Removed brittle perf checks.
* Added efficiency verification tests.
* Added one more efficiency check test.
* Improved the efficiency tests.
* Removed unnecessary parsing code + added direct Query obj
* Fixed tests.
* Improved tests
* Fixed code structure
* Fixed lint issues
* nitz.
* nitz
* nitz.
* nitz.
* nitz.
* Added an example
* Fixed PR comments.
* Applied PR comments + nitz
* nitz.
* Improved the code.
* Fixed a perf issue.
* Added batch processing.
* Made the example more interesting
* Fixed bucket count
* Renamed Direct to CustomQuery
* Fixed lint issues.
* No need for scorer to be an `Option`
* nitz
* Used BitSet
* Added an optimization for AllQuery
* Fixed merge issues.
* Fixed lint issues.
* Added benchmark for FILTER
* Removed the Option wrapper.
* nitz.
* Applied PR comments.
* Fixed the AllQuery optimization
* Applied PR comments.
* feat: used `erased_serde` to allow filter query to be serialized
* further improved a comment
* Added back tests.
* removed an unused method
* removed an unused method
* Added documentation
* nitz.
* Added query builder.
* Fixed a comment.
* Applied PR comments.
* Fixed doctest issues.
* Added ser/de
* Removed bench in test
* Fixed a lint issue.
@@ -69,6 +69,7 @@ hyperloglogplus = { version = "0.4.1", features = ["const-loop"] }
 futures-util = { version = "0.3.28", optional = true }
 futures-channel = { version = "0.3.28", optional = true }
 fnv = "1.0.7"
+typetag = "0.2.21"

 [target.'cfg(windows)'.dependencies]
 winapi = "0.3.9"
@@ -87,7 +88,7 @@ more-asserts = "0.3.1"
 rand_distr = "0.4.3"
 time = { version = "0.3.10", features = ["serde-well-known", "macros"] }
 postcard = { version = "1.0.4", features = [
     "use-std",
 ], default-features = false }

 [target.'cfg(not(windows))'.dev-dependencies]
@@ -175,4 +176,3 @@ harness = false
 [[bench]]
 name = "and_or_queries"
 harness = false
-
@@ -74,6 +74,12 @@ fn bench_agg(mut group: InputGroup<Index>) {
     register!(group, histogram_with_term_agg_few);
     register!(group, avg_and_range_with_avg_sub_agg);

+    // Filter aggregation benchmarks
+    register!(group, filter_agg_all_query_count_agg);
+    register!(group, filter_agg_term_query_count_agg);
+    register!(group, filter_agg_all_query_with_sub_aggs);
+    register!(group, filter_agg_term_query_with_sub_aggs);
+
     group.run();
 }
@@ -472,3 +478,61 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {

     Ok(index)
 }
+
+// Filter aggregation benchmarks
+
+fn filter_agg_all_query_count_agg(index: &Index) {
+    let agg_req = json!({
+        "filtered": {
+            "filter": "*",
+            "aggs": {
+                "count": { "value_count": { "field": "score" } }
+            }
+        }
+    });
+    execute_agg(index, agg_req);
+}
+
+fn filter_agg_term_query_count_agg(index: &Index) {
+    let agg_req = json!({
+        "filtered": {
+            "filter": "text:cool",
+            "aggs": {
+                "count": { "value_count": { "field": "score" } }
+            }
+        }
+    });
+    execute_agg(index, agg_req);
+}
+
+fn filter_agg_all_query_with_sub_aggs(index: &Index) {
+    let agg_req = json!({
+        "filtered": {
+            "filter": "*",
+            "aggs": {
+                "avg_score": { "avg": { "field": "score" } },
+                "stats_score": { "stats": { "field": "score_f64" } },
+                "terms_text": {
+                    "terms": { "field": "text_few_terms" }
+                }
+            }
+        }
+    });
+    execute_agg(index, agg_req);
+}
+
+fn filter_agg_term_query_with_sub_aggs(index: &Index) {
+    let agg_req = json!({
+        "filtered": {
+            "filter": "text:cool",
+            "aggs": {
+                "avg_score": { "avg": { "field": "score" } },
+                "stats_score": { "stats": { "field": "score_f64" } },
+                "terms_text": {
+                    "terms": { "field": "text_few_terms" }
+                }
+            }
+        }
+    });
+    execute_agg(index, agg_req);
+}
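The `execute_agg` helper these four functions call lives elsewhere in the bench file and is not part of this diff. A minimal sketch of what such a helper could look like, assuming it simply parses the JSON request and runs it over the whole index (the body and error handling here are illustrative, not the actual bench code):

fn execute_agg(index: &Index, agg_req: serde_json::Value) {
    // Hypothetical sketch: parse the JSON into an aggregation request...
    let agg: Aggregations = serde_json::from_value(agg_req).unwrap();
    let collector = AggregationCollector::from_aggs(agg, Default::default());
    let searcher = index.reader().unwrap().searcher();
    // ...and execute it over all documents, discarding the result.
    let _result = searcher.search(&AllQuery, &collector).unwrap();
}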
examples/filter_aggregation.rs (new file, 212 lines)
@@ -0,0 +1,212 @@
// # Filter Aggregation Example
//
// This example demonstrates filter aggregations - creating buckets of documents
// matching specific queries, with nested aggregations computed on each bucket.
//
// Filter aggregations are useful for computing metrics on different subsets of
// your data in a single query, like "average price overall + average price for
// electronics + count of in-stock items".

use serde_json::json;
use tantivy::aggregation::agg_req::Aggregations;
use tantivy::aggregation::AggregationCollector;
use tantivy::query::AllQuery;
use tantivy::schema::{Schema, FAST, INDEXED, TEXT};
use tantivy::{doc, Index};

fn main() -> tantivy::Result<()> {
    // Create a simple product schema
    let mut schema_builder = Schema::builder();
    schema_builder.add_text_field("category", TEXT | FAST);
    schema_builder.add_text_field("brand", TEXT | FAST);
    schema_builder.add_u64_field("price", FAST);
    schema_builder.add_f64_field("rating", FAST);
    schema_builder.add_bool_field("in_stock", FAST | INDEXED);
    let schema = schema_builder.build();

    // Create index and add sample products
    let index = Index::create_in_ram(schema.clone());
    let mut writer = index.writer(50_000_000)?;

    writer.add_document(doc!(
        schema.get_field("category")? => "electronics",
        schema.get_field("brand")? => "apple",
        schema.get_field("price")? => 999u64,
        schema.get_field("rating")? => 4.5f64,
        schema.get_field("in_stock")? => true
    ))?;
    writer.add_document(doc!(
        schema.get_field("category")? => "electronics",
        schema.get_field("brand")? => "samsung",
        schema.get_field("price")? => 799u64,
        schema.get_field("rating")? => 4.2f64,
        schema.get_field("in_stock")? => true
    ))?;
    writer.add_document(doc!(
        schema.get_field("category")? => "clothing",
        schema.get_field("brand")? => "nike",
        schema.get_field("price")? => 120u64,
        schema.get_field("rating")? => 4.1f64,
        schema.get_field("in_stock")? => false
    ))?;
    writer.add_document(doc!(
        schema.get_field("category")? => "books",
        schema.get_field("brand")? => "penguin",
        schema.get_field("price")? => 25u64,
        schema.get_field("rating")? => 4.8f64,
        schema.get_field("in_stock")? => true
    ))?;

    writer.commit()?;

    let reader = index.reader()?;
    let searcher = reader.searcher();

    // Example 1: Basic filter with metric aggregation
    println!("=== Example 1: Electronics average price ===");
    let agg_req = json!({
        "electronics": {
            "filter": "category:electronics",
            "aggs": {
                "avg_price": { "avg": { "field": "price" } }
            }
        }
    });

    let agg: Aggregations = serde_json::from_value(agg_req)?;
    let collector = AggregationCollector::from_aggs(agg, Default::default());
    let result = searcher.search(&AllQuery, &collector)?;

    let expected = json!({
        "electronics": {
            "doc_count": 2,
            "avg_price": { "value": 899.0 }
        }
    });
    assert_eq!(serde_json::to_value(&result)?, expected);
    println!("{}\n", serde_json::to_string_pretty(&result)?);

    // Example 2: Multiple independent filters
    println!("=== Example 2: Multiple filters in one query ===");
    let agg_req = json!({
        "electronics": {
            "filter": "category:electronics",
            "aggs": { "avg_price": { "avg": { "field": "price" } } }
        },
        "in_stock": {
            "filter": "in_stock:true",
            "aggs": { "count": { "value_count": { "field": "brand" } } }
        },
        "high_rated": {
            "filter": "rating:[4.5 TO *]",
            "aggs": { "count": { "value_count": { "field": "brand" } } }
        }
    });

    let agg: Aggregations = serde_json::from_value(agg_req)?;
    let collector = AggregationCollector::from_aggs(agg, Default::default());
    let result = searcher.search(&AllQuery, &collector)?;

    let expected = json!({
        "electronics": {
            "doc_count": 2,
            "avg_price": { "value": 899.0 }
        },
        "in_stock": {
            "doc_count": 3,
            "count": { "value": 3.0 }
        },
        "high_rated": {
            "doc_count": 2,
            "count": { "value": 2.0 }
        }
    });
    assert_eq!(serde_json::to_value(&result)?, expected);
    println!("{}\n", serde_json::to_string_pretty(&result)?);

    // Example 3: Nested filters - progressive refinement
    println!("=== Example 3: Nested filters ===");
    let agg_req = json!({
        "in_stock": {
            "filter": "in_stock:true",
            "aggs": {
                "electronics": {
                    "filter": "category:electronics",
                    "aggs": {
                        "expensive": {
                            "filter": "price:[800 TO *]",
                            "aggs": {
                                "avg_rating": { "avg": { "field": "rating" } }
                            }
                        }
                    }
                }
            }
        }
    });

    let agg: Aggregations = serde_json::from_value(agg_req)?;
    let collector = AggregationCollector::from_aggs(agg, Default::default());
    let result = searcher.search(&AllQuery, &collector)?;

    let expected = json!({
        "in_stock": {
            "doc_count": 3, // apple, samsung, penguin
            "electronics": {
                "doc_count": 2, // apple, samsung
                "expensive": {
                    "doc_count": 1, // only apple (999)
                    "avg_rating": { "value": 4.5 }
                }
            }
        }
    });
    assert_eq!(serde_json::to_value(&result)?, expected);
    println!("{}\n", serde_json::to_string_pretty(&result)?);

    // Example 4: Filter with sub-aggregation (terms)
    println!("=== Example 4: Filter with terms sub-aggregation ===");
    let agg_req = json!({
        "electronics": {
            "filter": "category:electronics",
            "aggs": {
                "by_brand": {
                    "terms": { "field": "brand" },
                    "aggs": {
                        "avg_price": { "avg": { "field": "price" } }
                    }
                }
            }
        }
    });

    let agg: Aggregations = serde_json::from_value(agg_req)?;
    let collector = AggregationCollector::from_aggs(agg, Default::default());
    let result = searcher.search(&AllQuery, &collector)?;

    let expected = json!({
        "electronics": {
            "doc_count": 2,
            "by_brand": {
                "buckets": [
                    {
                        "key": "samsung",
                        "doc_count": 1,
                        "avg_price": { "value": 799.0 }
                    },
                    {
                        "key": "apple",
                        "doc_count": 1,
                        "avg_price": { "value": 999.0 }
                    }
                ],
                "sum_other_doc_count": 0,
                "doc_count_error_upper_bound": 0
            }
        }
    });
    assert_eq!(serde_json::to_value(&result)?, expected);
    println!("{}", serde_json::to_string_pretty(&result)?);

    Ok(())
}
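Assuming the file is picked up by Cargo's standard example discovery, the example can be run from the repository root with:

cargo run --example filter_aggregation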
@@ -10,9 +10,10 @@ use crate::aggregation::accessor_helpers::{
 };
 use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations};
 use crate::aggregation::bucket::{
-    HistogramAggReqData, HistogramBounds, IncludeExcludeParam, MissingTermAggReqData,
-    RangeAggReqData, SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector,
-    TermMissingAgg, TermsAggReqData, TermsAggregation, TermsAggregationInternal,
+    FilterAggReqData, HistogramAggReqData, HistogramBounds, IncludeExcludeParam,
+    MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector, SegmentHistogramCollector,
+    SegmentRangeCollector, SegmentTermCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
+    TermsAggregationInternal,
 };
 use crate::aggregation::metric::{
     AverageAggregation, CardinalityAggReqData, CardinalityAggregationReq, CountAggregation,
@@ -24,7 +25,7 @@ use crate::aggregation::metric::{
 use crate::aggregation::segment_agg_result::{
     GenericSegmentAggregationResultsCollector, SegmentAggregationCollector,
 };
-use crate::aggregation::{f64_to_fastfield_u64, AggregationLimitsGuard, Key};
+use crate::aggregation::{f64_to_fastfield_u64, AggContextParams, Key};
 use crate::{SegmentOrdinal, SegmentReader};

 #[derive(Default)]
@@ -33,7 +34,7 @@ use crate::{SegmentOrdinal, SegmentReader};
 pub struct AggregationsSegmentCtx {
     /// Request data for each aggregation type.
     pub per_request: PerRequestAggSegCtx,
-    pub limits: AggregationLimitsGuard,
+    pub context: AggContextParams,
 }

 impl AggregationsSegmentCtx {
@@ -67,6 +68,10 @@ impl AggregationsSegmentCtx {
         self.per_request.range_req_data.push(Some(Box::new(data)));
         self.per_request.range_req_data.len() - 1
     }
+    pub(crate) fn push_filter_req_data(&mut self, data: FilterAggReqData) -> usize {
+        self.per_request.filter_req_data.push(Some(Box::new(data)));
+        self.per_request.filter_req_data.len() - 1
+    }

     #[inline]
     pub(crate) fn get_term_req_data(&self, idx: usize) -> &TermsAggReqData {
@@ -102,6 +107,12 @@ impl AggregationsSegmentCtx {
             .as_deref()
             .expect("range_req_data slot is empty (taken)")
     }
+    #[inline]
+    pub(crate) fn get_filter_req_data(&self, idx: usize) -> &FilterAggReqData {
+        self.per_request.filter_req_data[idx]
+            .as_deref()
+            .expect("filter_req_data slot is empty (taken)")
+    }

     // ---------- mutable getters ----------

@@ -179,6 +190,21 @@ impl AggregationsSegmentCtx {
         debug_assert!(self.per_request.range_req_data[idx].is_none());
         self.per_request.range_req_data[idx] = Some(value);
     }
+
+    /// Move out the boxed Filter request at `idx`, leaving `None`.
+    #[inline]
+    pub(crate) fn take_filter_req_data(&mut self, idx: usize) -> Box<FilterAggReqData> {
+        self.per_request.filter_req_data[idx]
+            .take()
+            .expect("filter_req_data slot is empty (taken)")
+    }
+
+    /// Put back a Filter request into an empty slot at `idx`.
+    #[inline]
+    pub(crate) fn put_back_filter_req_data(&mut self, idx: usize, value: Box<FilterAggReqData>) {
+        debug_assert!(self.per_request.filter_req_data[idx].is_none());
+        self.per_request.filter_req_data[idx] = Some(value);
+    }
 }

 /// Each type of aggregation has its own request data struct. This struct holds
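The take/put-back pair above follows the same `Vec<Option<Box<T>>>` slot pattern already used for the other aggregation kinds: a collector temporarily moves its request data out of the shared context (gaining owned access without borrowing the whole context) and returns it when done. A stripped-down, self-contained illustration of the pattern — the types and names here are invented for the sketch, not the actual tantivy code:

struct ReqData {
    name: String,
}

struct Slots {
    slots: Vec<Option<Box<ReqData>>>,
}

impl Slots {
    fn push(&mut self, data: ReqData) -> usize {
        self.slots.push(Some(Box::new(data)));
        self.slots.len() - 1
    }
    fn take(&mut self, idx: usize) -> Box<ReqData> {
        // Leaves `None` behind; the caller is expected to put the value back.
        self.slots[idx].take().expect("slot is empty (taken)")
    }
    fn put_back(&mut self, idx: usize, value: Box<ReqData>) {
        debug_assert!(self.slots[idx].is_none());
        self.slots[idx] = Some(value);
    }
}

fn main() {
    let mut ctx = Slots { slots: Vec::new() };
    let idx = ctx.push(ReqData { name: "filtered".to_string() });
    let mut data = ctx.take(idx); // owned access while `ctx` remains usable
    data.name.push_str("_bucket");
    ctx.put_back(idx, data);
}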
@@ -196,6 +222,8 @@ pub struct PerRequestAggSegCtx {
     pub histogram_req_data: Vec<Option<Box<HistogramAggReqData>>>,
     /// RangeAggReqData contains the request data for a range aggregation.
     pub range_req_data: Vec<Option<Box<RangeAggReqData>>>,
+    /// FilterAggReqData contains the request data for a filter aggregation.
+    pub filter_req_data: Vec<Option<Box<FilterAggReqData>>>,
     /// Shared by avg, min, max, sum, stats, extended_stats, count
     pub stats_metric_req_data: Vec<MetricAggReqData>,
     /// CardinalityAggReqData contains the request data for a cardinality aggregation.
@@ -226,6 +254,11 @@ impl PerRequestAggSegCtx {
                 .iter()
                 .map(|b| b.as_ref().unwrap().get_memory_consumption())
                 .sum::<usize>()
+            + self
+                .filter_req_data
+                .iter()
+                .map(|b| b.as_ref().unwrap().get_memory_consumption())
+                .sum::<usize>()
             + self
                 .stats_metric_req_data
                 .iter()
@@ -277,6 +310,11 @@ impl PerRequestAggSegCtx {
                 .expect("range_req_data slot is empty (taken)")
                 .name
                 .as_str(),
+            AggKind::Filter => self.filter_req_data[idx]
+                .as_deref()
+                .expect("filter_req_data slot is empty (taken)")
+                .name
+                .as_str(),
         }
     }
@@ -319,7 +357,8 @@ pub(crate) fn build_segment_agg_collectors(
         collectors.push(build_segment_agg_collector(req, node)?);
     }

-    req.limits
+    req.context
+        .limits
         .add_memory_consumed(req.per_request.get_memory_consumption() as u64)?;
     // Single collector special case
     if collectors.len() == 1 {
@@ -394,6 +433,9 @@ pub(crate) fn build_segment_agg_collector(
         AggKind::Range => Ok(Box::new(SegmentRangeCollector::from_req_and_validate(
             req, node,
         )?)),
+        AggKind::Filter => Ok(Box::new(SegmentFilterCollector::from_req_and_validate(
+            req, node,
+        )?)),
     }
 }
@@ -423,6 +465,7 @@ pub enum AggKind {
     Histogram,
     DateHistogram,
     Range,
+    Filter,
 }

 impl AggKind {
@@ -437,6 +480,7 @@ impl AggKind {
             AggKind::Histogram => "Histogram",
             AggKind::DateHistogram => "DateHistogram",
             AggKind::Range => "Range",
+            AggKind::Filter => "Filter",
         }
     }
 }
@@ -446,11 +490,11 @@ pub(crate) fn build_aggregations_data_from_req(
     aggs: &Aggregations,
     reader: &SegmentReader,
     segment_ordinal: SegmentOrdinal,
-    limits: AggregationLimitsGuard,
+    context: AggContextParams,
 ) -> crate::Result<AggregationsSegmentCtx> {
     let mut data = AggregationsSegmentCtx {
         per_request: Default::default(),
-        limits,
+        context,
     };

     for (name, agg) in aggs.iter() {
@@ -686,6 +730,36 @@ fn build_nodes(
                 children,
             }])
         }
+        AggregationVariants::Filter(filter_req) => {
+            // Build the query and evaluator upfront
+            let schema = reader.schema();
+            let tokenizers = &data.context.tokenizers;
+            let query = filter_req.parse_query(&schema, tokenizers)?;
+            let evaluator = crate::aggregation::bucket::DocumentQueryEvaluator::new(
+                query,
+                schema.clone(),
+                reader,
+            )?;
+
+            // Pre-allocate buffer for batch filtering
+            let max_doc = reader.max_doc();
+            let buffer_capacity = crate::docset::COLLECT_BLOCK_BUFFER_LEN.min(max_doc as usize);
+            let matching_docs_buffer = Vec::with_capacity(buffer_capacity);
+
+            let idx_in_req_data = data.push_filter_req_data(FilterAggReqData {
+                name: agg_name.to_string(),
+                req: filter_req.clone(),
+                segment_reader: reader.clone(),
+                evaluator,
+                matching_docs_buffer,
+            });
+            let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?;
+            Ok(vec![AggRefNode {
+                kind: AggKind::Filter,
+                idx_in_req_data,
+                children,
+            }])
+        }
     }
 }

@@ -32,7 +32,8 @@ use rustc_hash::FxHashMap;
 use serde::{Deserialize, Serialize};

 use super::bucket::{
-    DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation,
+    DateHistogramAggregationReq, FilterAggregation, HistogramAggregation, RangeAggregation,
+    TermsAggregation,
 };
 use super::metric::{
     AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation,
@@ -130,6 +131,9 @@ pub enum AggregationVariants {
     /// Put data into buckets of terms.
     #[serde(rename = "terms")]
     Terms(TermsAggregation),
+    /// Filter documents into a single bucket.
+    #[serde(rename = "filter")]
+    Filter(FilterAggregation),

     // Metric aggregation types
     /// Computes the average of the extracted values.
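With the `filter` variant registered on `AggregationVariants`, a filter aggregation deserializes from the same JSON shape the example and benches above use. A small sanity sketch (the field and bucket names are illustrative):

let agg_req = serde_json::json!({
    "electronics": {
        "filter": "category:electronics",
        "aggs": {
            "avg_price": { "avg": { "field": "price" } }
        }
    }
});
// The "filter" key selects the `Filter(FilterAggregation)` variant added above.
let aggs: tantivy::aggregation::agg_req::Aggregations =
    serde_json::from_value(agg_req).unwrap();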
@@ -175,6 +179,7 @@ impl AggregationVariants {
             AggregationVariants::Range(range) => vec![range.field.as_str()],
             AggregationVariants::Histogram(histogram) => vec![histogram.field.as_str()],
             AggregationVariants::DateHistogram(histogram) => vec![histogram.field.as_str()],
+            AggregationVariants::Filter(filter) => filter.get_fast_field_names(),
             AggregationVariants::Average(avg) => vec![avg.field_name()],
             AggregationVariants::Count(count) => vec![count.field_name()],
             AggregationVariants::Max(max) => vec![max.field_name()],
@@ -156,6 +156,8 @@ pub enum BucketResult {
         /// The upper bound error for the doc count of each term.
         doc_count_error_upper_bound: Option<u64>,
     },
+    /// This is the filter result - a single bucket with sub-aggregations
+    Filter(FilterBucketResult),
 }

 impl BucketResult {
@@ -172,6 +174,11 @@ impl BucketResult {
                 sum_other_doc_count: _,
                 doc_count_error_upper_bound: _,
             } => buckets.iter().map(|bucket| bucket.get_bucket_count()).sum(),
+            BucketResult::Filter(filter_result) => {
+                // Filter doesn't add to bucket count - it's not a user-facing bucket
+                // Only count sub-aggregation buckets
+                filter_result.sub_aggregations.get_bucket_count()
+            }
         }
     }
 }
@@ -308,3 +315,25 @@ impl RangeBucketEntry {
         1 + self.sub_aggregation.get_bucket_count()
     }
 }
+
+/// This is the filter bucket result, which contains the document count and sub-aggregations.
+///
+/// # JSON Format
+/// ```json
+/// {
+///   "electronics_only": {
+///     "doc_count": 2,
+///     "avg_price": {
+///       "value": 150.0
+///     }
+///   }
+/// }
+/// ```
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+pub struct FilterBucketResult {
+    /// Number of documents in the filter bucket
+    pub doc_count: u64,
+    /// Sub-aggregation results
+    #[serde(flatten)]
+    pub sub_aggregations: AggregationResults,
+}
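The `#[serde(flatten)]` on `sub_aggregations` is what produces the JSON layout in the doc comment: sub-aggregation results appear as siblings of `doc_count` rather than under a nested key. A self-contained sketch of the same mechanism, using a simplified stand-in for `AggregationResults` (the types here are invented for illustration):

use serde::Serialize;
use std::collections::BTreeMap;

#[derive(Serialize)]
struct FilterBucket {
    doc_count: u64,
    // Flattened: the map's keys are emitted at the same level as `doc_count`.
    #[serde(flatten)]
    sub_aggregations: BTreeMap<String, serde_json::Value>,
}

fn main() {
    let mut subs = BTreeMap::new();
    subs.insert("avg_price".to_string(), serde_json::json!({ "value": 150.0 }));
    let bucket = FilterBucket { doc_count: 2, sub_aggregations: subs };
    // Prints: {"doc_count":2,"avg_price":{"value":150.0}}
    println!("{}", serde_json::to_string(&bucket).unwrap());
}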
@@ -5,7 +5,6 @@ use crate::aggregation::agg_result::AggregationResults;
 use crate::aggregation::buf_collector::DOC_BLOCK_SIZE;
 use crate::aggregation::collector::AggregationCollector;
 use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults;
-use crate::aggregation::segment_agg_result::AggregationLimitsGuard;
 use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms};
 use crate::aggregation::DistributedAggregationCollector;
 use crate::query::{AllQuery, TermQuery};
@@ -128,10 +127,8 @@ fn test_aggregation_flushing(
         .unwrap();

     let agg_res: AggregationResults = if use_distributed_collector {
-        let collector = DistributedAggregationCollector::from_aggs(
-            agg_req.clone(),
-            AggregationLimitsGuard::default(),
-        );
+        let collector =
+            DistributedAggregationCollector::from_aggs(agg_req.clone(), Default::default());

         let searcher = reader.searcher();
         let intermediate_agg_result = searcher.search(&AllQuery, &collector).unwrap();
src/aggregation/bucket/filter.rs (new file, 1755 lines)
File diff suppressed because it is too large.
@@ -352,7 +352,10 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {

         let mem_delta = self.get_memory_consumption() - mem_pre;
         if mem_delta > 0 {
-            agg_data.limits.add_memory_consumed(mem_delta as u64)?;
+            agg_data
+                .context
+                .limits
+                .add_memory_consumed(mem_delta as u64)?;
         }

         Ok(())
@@ -22,6 +22,7 @@
 //! - [Range](RangeAggregation)
 //! - [Terms](TermsAggregation)

+mod filter;
 mod histogram;
 mod range;
 mod term_agg;
@@ -30,6 +31,7 @@ mod term_missing_agg;
 use std::collections::HashMap;
 use std::fmt;

+pub use filter::*;
 pub use histogram::*;
 pub use range::*;
 use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
@@ -339,7 +339,7 @@ impl SegmentRangeCollector {
             })
             .collect::<crate::Result<_>>()?;

-        req_data.limits.add_memory_consumed(
+        req_data.context.limits.add_memory_consumed(
             buckets.len() as u64 * std::mem::size_of::<SegmentRangeAndBucketEntry>() as u64,
         )?;

@@ -443,7 +443,10 @@ impl SegmentAggregationCollector for SegmentTermCollector {

         let mem_delta = self.get_memory_consumption() - mem_pre;
         if mem_delta > 0 {
-            agg_data.limits.add_memory_consumed(mem_delta as u64)?;
+            agg_data
+                .context
+                .limits
+                .add_memory_consumed(mem_delta as u64)?;
         }
         agg_data.put_back_term_req_data(self.accessor_idx, req_data);

@@ -2341,10 +2344,8 @@ mod tests {
         let search = |idx: &Index,
                       agg_req: &Aggregations|
          -> crate::Result<IntermediateAggregationResults> {
-            let collector = DistributedAggregationCollector::from_aggs(
-                agg_req.clone(),
-                AggregationLimitsGuard::default(),
-            );
+            let collector =
+                DistributedAggregationCollector::from_aggs(agg_req.clone(), Default::default());
             let reader = idx.reader()?;
             let searcher = reader.searcher();
             let agg_res = searcher.search(&AllQuery, &collector)?;
@@ -2388,9 +2389,8 @@ mod tests {
         let agg_res2 = search(&index2, &agg_req2)?;

         agg_res.merge_fruits(agg_res2).unwrap();
-        let agg_json = serde_json::to_value(
-            &agg_res.into_final_result(agg_req2, AggregationLimitsGuard::default())?,
-        )?;
+        let agg_json =
+            serde_json::to_value(&agg_res.into_final_result(agg_req2, Default::default())?)?;

         // hosts:
         let hosts = &agg_json["hosts"]["buckets"];
@@ -2,7 +2,8 @@ use super::agg_req::Aggregations;
 use super::agg_result::AggregationResults;
 use super::buf_collector::BufAggregationCollector;
 use super::intermediate_agg_result::IntermediateAggregationResults;
-use super::segment_agg_result::{AggregationLimitsGuard, SegmentAggregationCollector};
+use super::segment_agg_result::SegmentAggregationCollector;
+use super::AggContextParams;
 use crate::aggregation::agg_data::{
     build_aggregations_data_from_req, build_segment_agg_collectors_root, AggregationsSegmentCtx,
 };
@@ -21,7 +22,7 @@ pub const DEFAULT_MEMORY_LIMIT: u64 = 500_000_000;
 /// The collector collects all aggregations by the underlying aggregation request.
 pub struct AggregationCollector {
     agg: Aggregations,
-    limits: AggregationLimitsGuard,
+    context: AggContextParams,
 }

 impl AggregationCollector {
@@ -29,8 +30,8 @@ impl AggregationCollector {
     ///
     /// Aggregation fails when the limits in `AggregationLimits` is exceeded. (memory limit and
     /// bucket limit)
-    pub fn from_aggs(agg: Aggregations, limits: AggregationLimitsGuard) -> Self {
-        Self { agg, limits }
+    pub fn from_aggs(agg: Aggregations, context: AggContextParams) -> Self {
+        Self { agg, context }
     }
 }

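Since `AggContextParams` derives `Default` (see its definition further down in this diff), call sites that don't need custom limits or tokenizers keep working unchanged. Both construction styles under the new signature, assuming `agg`, `limits`, and `index` are in scope:

// Default limits and tokenizers, as most call sites updated in this PR do:
let collector = AggregationCollector::from_aggs(agg.clone(), Default::default());

// Explicit limits plus the index's tokenizer manager, so filter aggregations
// can parse their query strings with the same tokenizers as the index:
let collector = AggregationCollector::from_aggs(
    agg,
    AggContextParams::new(limits, index.tokenizers().clone()),
);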
@@ -44,7 +45,7 @@ impl AggregationCollector {
 /// into the final `AggregationResults` via the `into_final_result()` method.
 pub struct DistributedAggregationCollector {
     agg: Aggregations,
-    limits: AggregationLimitsGuard,
+    context: AggContextParams,
 }

 impl DistributedAggregationCollector {
@@ -52,8 +53,8 @@ impl DistributedAggregationCollector {
     ///
     /// Aggregation fails when the limits in `AggregationLimits` is exceeded. (memory limit and
     /// bucket limit)
-    pub fn from_aggs(agg: Aggregations, limits: AggregationLimitsGuard) -> Self {
-        Self { agg, limits }
+    pub fn from_aggs(agg: Aggregations, context: AggContextParams) -> Self {
+        Self { agg, context }
     }
 }

@@ -71,7 +72,7 @@ impl Collector for DistributedAggregationCollector {
             &self.agg,
             reader,
             segment_local_id,
-            &self.limits,
+            &self.context,
         )
     }

@@ -101,7 +102,7 @@ impl Collector for AggregationCollector {
             &self.agg,
             reader,
             segment_local_id,
-            &self.limits,
+            &self.context,
         )
     }

@@ -114,7 +115,7 @@ impl Collector for AggregationCollector {
         segment_fruits: Vec<<Self::Child as SegmentCollector>::Fruit>,
     ) -> crate::Result<Self::Fruit> {
         let res = merge_fruits(segment_fruits)?;
-        res.into_final_result(self.agg.clone(), self.limits.clone())
+        res.into_final_result(self.agg.clone(), self.context.limits.clone())
     }
 }

@@ -146,10 +147,10 @@ impl AggregationSegmentCollector {
         agg: &Aggregations,
         reader: &SegmentReader,
         segment_ordinal: SegmentOrdinal,
-        limits: &AggregationLimitsGuard,
+        context: &AggContextParams,
     ) -> crate::Result<Self> {
         let mut agg_data =
-            build_aggregations_data_from_req(agg, reader, segment_ordinal, limits.clone())?;
+            build_aggregations_data_from_req(agg, reader, segment_ordinal, context.clone())?;
         let result =
             BufAggregationCollector::new(build_segment_agg_collectors_root(&mut agg_data)?);

@@ -24,7 +24,9 @@ use super::metric::{
 };
 use super::segment_agg_result::AggregationLimitsGuard;
 use super::{format_date, AggregationError, Key, SerializedKey};
-use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry};
+use crate::aggregation::agg_result::{
+    AggregationResults, BucketEntries, BucketEntry, FilterBucketResult,
+};
 use crate::aggregation::bucket::TermsAggregationInternal;
 use crate::aggregation::metric::CardinalityCollector;
 use crate::TantivyError;
@@ -246,6 +248,10 @@ pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult
         Cardinality(_) => IntermediateAggregationResult::Metric(
             IntermediateMetricResult::Cardinality(CardinalityCollector::default()),
         ),
+        Filter(_) => IntermediateAggregationResult::Bucket(IntermediateBucketResult::Filter {
+            doc_count: 0,
+            sub_aggregations: IntermediateAggregationResults::default(),
+        }),
     }
 }

@@ -432,6 +438,13 @@ pub enum IntermediateBucketResult {
         /// The term buckets
         buckets: IntermediateTermBucketResult,
     },
+    /// Filter aggregation - a single bucket with sub-aggregations
+    Filter {
+        /// Document count in the filter bucket
+        doc_count: u64,
+        /// Sub-aggregation results
+        sub_aggregations: IntermediateAggregationResults,
+    },
 }

 impl IntermediateBucketResult {
@@ -515,6 +528,18 @@ impl IntermediateBucketResult {
                 req.sub_aggregation(),
                 limits,
             ),
+            IntermediateBucketResult::Filter {
+                doc_count,
+                sub_aggregations,
+            } => {
+                // Convert sub-aggregation results to final format
+                let final_sub_aggregations = sub_aggregations
+                    .into_final_result(req.sub_aggregation().clone(), limits.clone())?;
+                Ok(BucketResult::Filter(FilterBucketResult {
+                    doc_count,
+                    sub_aggregations: final_sub_aggregations,
+                }))
+            }
         }
     }

@@ -568,6 +593,19 @@ impl IntermediateBucketResult {

                 *buckets_left = buckets?;
             }
+            (
+                IntermediateBucketResult::Filter {
+                    doc_count: doc_count_left,
+                    sub_aggregations: sub_aggs_left,
+                },
+                IntermediateBucketResult::Filter {
+                    doc_count: doc_count_right,
+                    sub_aggregations: sub_aggs_right,
+                },
+            ) => {
+                *doc_count_left += doc_count_right;
+                sub_aggs_left.merge_fruits(sub_aggs_right)?;
+            }
             (IntermediateBucketResult::Range(_), _) => {
                 panic!("try merge on different types")
             }
@@ -577,6 +615,9 @@ impl IntermediateBucketResult {
             (IntermediateBucketResult::Terms { .. }, _) => {
                 panic!("try merge on different types")
             }
+            (IntermediateBucketResult::Filter { .. }, _) => {
+                panic!("try merge on different types")
+            }
         }
         Ok(())
     }
@@ -160,6 +160,28 @@ use itertools::Itertools;
 use serde::de::{self, Visitor};
 use serde::{Deserialize, Deserializer, Serialize};

+use crate::tokenizer::TokenizerManager;
+
+/// Context parameters for aggregation execution
+///
+/// This struct holds shared resources needed during aggregation execution:
+/// - `limits`: Memory and bucket limits for the aggregation
+/// - `tokenizers`: TokenizerManager for parsing query strings in filter aggregations
+#[derive(Clone, Default)]
+pub struct AggContextParams {
+    /// Aggregation limits (memory and bucket count)
+    pub limits: AggregationLimitsGuard,
+    /// Tokenizer manager for query string parsing
+    pub tokenizers: TokenizerManager,
+}
+
+impl AggContextParams {
+    /// Create new aggregation context parameters
+    pub fn new(limits: AggregationLimitsGuard, tokenizers: TokenizerManager) -> Self {
+        Self { limits, tokenizers }
+    }
+}
+
 fn parse_str_into_f64<E: de::Error>(value: &str) -> Result<f64, E> {
     let parsed = value
         .parse::<f64>()
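Because the struct's fields are public and it derives `Default`, the context can also be assembled with struct-update syntax; a brief illustration (assuming an `index: Index` is in scope):

// Default limits plus a default TokenizerManager:
let ctx = AggContextParams::default();

// Keep default limits but use the index's own tokenizers, which is what
// filter aggregations need to parse query strings consistently:
let ctx = AggContextParams {
    tokenizers: index.tokenizers().clone(),
    ..Default::default()
};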
@@ -390,7 +412,10 @@ mod tests {
         query: Option<(&str, &str)>,
         limits: AggregationLimitsGuard,
     ) -> crate::Result<Value> {
-        let collector = AggregationCollector::from_aggs(agg_req, limits);
+        let collector = AggregationCollector::from_aggs(
+            agg_req,
+            AggContextParams::new(limits, index.tokenizers().clone()),
+        );

         let reader = index.reader()?;
         let searcher = reader.searcher();