diff --git a/Cargo.toml b/Cargo.toml
index 7217ac0b7..6e7235f9f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -69,6 +69,7 @@ hyperloglogplus = { version = "0.4.1", features = ["const-loop"] }
 futures-util = { version = "0.3.28", optional = true }
 futures-channel = { version = "0.3.28", optional = true }
 fnv = "1.0.7"
+typetag = "0.2.21"
 
 [target.'cfg(windows)'.dependencies]
 winapi = "0.3.9"
@@ -87,7 +88,7 @@ more-asserts = "0.3.1"
 rand_distr = "0.4.3"
 time = { version = "0.3.10", features = ["serde-well-known", "macros"] }
 postcard = { version = "1.0.4", features = [
-    "use-std",
+    "use-std",
 ], default-features = false }
 
 [target.'cfg(not(windows))'.dev-dependencies]
@@ -175,4 +176,3 @@ harness = false
 [[bench]]
 name = "and_or_queries"
 harness = false
-
diff --git a/benches/agg_bench.rs b/benches/agg_bench.rs
index 2f47bc0e4..8df1ba539 100644
--- a/benches/agg_bench.rs
+++ b/benches/agg_bench.rs
@@ -74,6 +74,12 @@ fn bench_agg(mut group: InputGroup<Index>) {
     register!(group, histogram_with_term_agg_few);
     register!(group, avg_and_range_with_avg_sub_agg);
 
+    // Filter aggregation benchmarks
+    register!(group, filter_agg_all_query_count_agg);
+    register!(group, filter_agg_term_query_count_agg);
+    register!(group, filter_agg_all_query_with_sub_aggs);
+    register!(group, filter_agg_term_query_with_sub_aggs);
+
     group.run();
 }
 
@@ -472,3 +478,61 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
 
     Ok(index)
 }
+
+// Filter aggregation benchmarks
+
+fn filter_agg_all_query_count_agg(index: &Index) {
+    let agg_req = json!({
+        "filtered": {
+            "filter": "*",
+            "aggs": {
+                "count": { "value_count": { "field": "score" } }
+            }
+        }
+    });
+    execute_agg(index, agg_req);
+}
+
+fn filter_agg_term_query_count_agg(index: &Index) {
+    let agg_req = json!({
+        "filtered": {
+            "filter": "text:cool",
+            "aggs": {
+                "count": { "value_count": { "field": "score" } }
+            }
+        }
+    });
+    execute_agg(index, agg_req);
+}
+
+fn filter_agg_all_query_with_sub_aggs(index: &Index) {
+    let agg_req = json!({
+        "filtered": {
+            "filter": "*",
+            "aggs": {
+                "avg_score": { "avg": { "field": "score" } },
+                "stats_score": { "stats": { "field": "score_f64" } },
+                "terms_text": {
+                    "terms": { "field": "text_few_terms" }
+                }
+            }
+        }
+    });
+    execute_agg(index, agg_req);
+}
+
+fn filter_agg_term_query_with_sub_aggs(index: &Index) {
+    let agg_req = json!({
+        "filtered": {
+            "filter": "text:cool",
+            "aggs": {
+                "avg_score": { "avg": { "field": "score" } },
+                "stats_score": { "stats": { "field": "score_f64" } },
+                "terms_text": {
+                    "terms": { "field": "text_few_terms" }
+                }
+            }
+        }
+    });
+    execute_agg(index, agg_req);
+}
diff --git a/examples/filter_aggregation.rs b/examples/filter_aggregation.rs
new file mode 100644
index 000000000..6aafec6de
--- /dev/null
+++ b/examples/filter_aggregation.rs
@@ -0,0 +1,212 @@
+// # Filter Aggregation Example
+//
+// This example demonstrates filter aggregations - creating buckets of documents
+// matching specific queries, with nested aggregations computed on each bucket.
+//
+// Filter aggregations are useful for computing metrics on different subsets of
+// your data in a single query, like "average price overall + average price for
+// electronics + count of in-stock items".
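+//
+// For orientation, every request in this example follows the same shape used by
+// the tests in this PR: a query string under "filter", plus nested aggregations
+// under "aggs":
+//
+//   {
+//     "electronics": {
+//       "filter": "category:electronics",
+//       "aggs": { "avg_price": { "avg": { "field": "price" } } }
+//     }
+//   }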
+ +use serde_json::json; +use tantivy::aggregation::agg_req::Aggregations; +use tantivy::aggregation::AggregationCollector; +use tantivy::query::AllQuery; +use tantivy::schema::{Schema, FAST, INDEXED, TEXT}; +use tantivy::{doc, Index}; + +fn main() -> tantivy::Result<()> { + // Create a simple product schema + let mut schema_builder = Schema::builder(); + schema_builder.add_text_field("category", TEXT | FAST); + schema_builder.add_text_field("brand", TEXT | FAST); + schema_builder.add_u64_field("price", FAST); + schema_builder.add_f64_field("rating", FAST); + schema_builder.add_bool_field("in_stock", FAST | INDEXED); + let schema = schema_builder.build(); + + // Create index and add sample products + let index = Index::create_in_ram(schema.clone()); + let mut writer = index.writer(50_000_000)?; + + writer.add_document(doc!( + schema.get_field("category")? => "electronics", + schema.get_field("brand")? => "apple", + schema.get_field("price")? => 999u64, + schema.get_field("rating")? => 4.5f64, + schema.get_field("in_stock")? => true + ))?; + writer.add_document(doc!( + schema.get_field("category")? => "electronics", + schema.get_field("brand")? => "samsung", + schema.get_field("price")? => 799u64, + schema.get_field("rating")? => 4.2f64, + schema.get_field("in_stock")? => true + ))?; + writer.add_document(doc!( + schema.get_field("category")? => "clothing", + schema.get_field("brand")? => "nike", + schema.get_field("price")? => 120u64, + schema.get_field("rating")? => 4.1f64, + schema.get_field("in_stock")? => false + ))?; + writer.add_document(doc!( + schema.get_field("category")? => "books", + schema.get_field("brand")? => "penguin", + schema.get_field("price")? => 25u64, + schema.get_field("rating")? => 4.8f64, + schema.get_field("in_stock")? => true + ))?; + + writer.commit()?; + + let reader = index.reader()?; + let searcher = reader.searcher(); + + // Example 1: Basic filter with metric aggregation + println!("=== Example 1: Electronics average price ==="); + let agg_req = json!({ + "electronics": { + "filter": "category:electronics", + "aggs": { + "avg_price": { "avg": { "field": "price" } } + } + } + }); + + let agg: Aggregations = serde_json::from_value(agg_req)?; + let collector = AggregationCollector::from_aggs(agg, Default::default()); + let result = searcher.search(&AllQuery, &collector)?; + + let expected = json!({ + "electronics": { + "doc_count": 2, + "avg_price": { "value": 899.0 } + } + }); + assert_eq!(serde_json::to_value(&result)?, expected); + println!("{}\n", serde_json::to_string_pretty(&result)?); + + // Example 2: Multiple independent filters + println!("=== Example 2: Multiple filters in one query ==="); + let agg_req = json!({ + "electronics": { + "filter": "category:electronics", + "aggs": { "avg_price": { "avg": { "field": "price" } } } + }, + "in_stock": { + "filter": "in_stock:true", + "aggs": { "count": { "value_count": { "field": "brand" } } } + }, + "high_rated": { + "filter": "rating:[4.5 TO *]", + "aggs": { "count": { "value_count": { "field": "brand" } } } + } + }); + + let agg: Aggregations = serde_json::from_value(agg_req)?; + let collector = AggregationCollector::from_aggs(agg, Default::default()); + let result = searcher.search(&AllQuery, &collector)?; + + let expected = json!({ + "electronics": { + "doc_count": 2, + "avg_price": { "value": 899.0 } + }, + "in_stock": { + "doc_count": 3, + "count": { "value": 3.0 } + }, + "high_rated": { + "doc_count": 2, + "count": { "value": 2.0 } + } + }); + assert_eq!(serde_json::to_value(&result)?, expected); + 
println!("{}\n", serde_json::to_string_pretty(&result)?); + + // Example 3: Nested filters - progressive refinement + println!("=== Example 3: Nested filters ==="); + let agg_req = json!({ + "in_stock": { + "filter": "in_stock:true", + "aggs": { + "electronics": { + "filter": "category:electronics", + "aggs": { + "expensive": { + "filter": "price:[800 TO *]", + "aggs": { + "avg_rating": { "avg": { "field": "rating" } } + } + } + } + } + } + } + }); + + let agg: Aggregations = serde_json::from_value(agg_req)?; + let collector = AggregationCollector::from_aggs(agg, Default::default()); + let result = searcher.search(&AllQuery, &collector)?; + + let expected = json!({ + "in_stock": { + "doc_count": 3, // apple, samsung, penguin + "electronics": { + "doc_count": 2, // apple, samsung + "expensive": { + "doc_count": 1, // only apple (999) + "avg_rating": { "value": 4.5 } + } + } + } + }); + assert_eq!(serde_json::to_value(&result)?, expected); + println!("{}\n", serde_json::to_string_pretty(&result)?); + + // Example 4: Filter with sub-aggregation (terms) + println!("=== Example 4: Filter with terms sub-aggregation ==="); + let agg_req = json!({ + "electronics": { + "filter": "category:electronics", + "aggs": { + "by_brand": { + "terms": { "field": "brand" }, + "aggs": { + "avg_price": { "avg": { "field": "price" } } + } + } + } + } + }); + + let agg: Aggregations = serde_json::from_value(agg_req)?; + let collector = AggregationCollector::from_aggs(agg, Default::default()); + let result = searcher.search(&AllQuery, &collector)?; + + let expected = json!({ + "electronics": { + "doc_count": 2, + "by_brand": { + "buckets": [ + { + "key": "samsung", + "doc_count": 1, + "avg_price": { "value": 799.0 } + }, + { + "key": "apple", + "doc_count": 1, + "avg_price": { "value": 999.0 } + } + ], + "sum_other_doc_count": 0, + "doc_count_error_upper_bound": 0 + } + } + }); + assert_eq!(serde_json::to_value(&result)?, expected); + println!("{}", serde_json::to_string_pretty(&result)?); + + Ok(()) +} diff --git a/src/aggregation/agg_data.rs b/src/aggregation/agg_data.rs index c7dc2e4e6..3b29830a7 100644 --- a/src/aggregation/agg_data.rs +++ b/src/aggregation/agg_data.rs @@ -10,9 +10,10 @@ use crate::aggregation::accessor_helpers::{ }; use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations}; use crate::aggregation::bucket::{ - HistogramAggReqData, HistogramBounds, IncludeExcludeParam, MissingTermAggReqData, - RangeAggReqData, SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector, - TermMissingAgg, TermsAggReqData, TermsAggregation, TermsAggregationInternal, + FilterAggReqData, HistogramAggReqData, HistogramBounds, IncludeExcludeParam, + MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector, SegmentHistogramCollector, + SegmentRangeCollector, SegmentTermCollector, TermMissingAgg, TermsAggReqData, TermsAggregation, + TermsAggregationInternal, }; use crate::aggregation::metric::{ AverageAggregation, CardinalityAggReqData, CardinalityAggregationReq, CountAggregation, @@ -24,7 +25,7 @@ use crate::aggregation::metric::{ use crate::aggregation::segment_agg_result::{ GenericSegmentAggregationResultsCollector, SegmentAggregationCollector, }; -use crate::aggregation::{f64_to_fastfield_u64, AggregationLimitsGuard, Key}; +use crate::aggregation::{f64_to_fastfield_u64, AggContextParams, Key}; use crate::{SegmentOrdinal, SegmentReader}; #[derive(Default)] @@ -33,7 +34,7 @@ use crate::{SegmentOrdinal, SegmentReader}; pub struct AggregationsSegmentCtx { /// Request data for 
each aggregation type.
     pub per_request: PerRequestAggSegCtx,
-    pub limits: AggregationLimitsGuard,
+    pub context: AggContextParams,
 }
 
 impl AggregationsSegmentCtx {
@@ -67,6 +68,10 @@ impl AggregationsSegmentCtx {
         self.per_request.range_req_data.push(Some(Box::new(data)));
         self.per_request.range_req_data.len() - 1
     }
+    pub(crate) fn push_filter_req_data(&mut self, data: FilterAggReqData) -> usize {
+        self.per_request.filter_req_data.push(Some(Box::new(data)));
+        self.per_request.filter_req_data.len() - 1
+    }
 
     #[inline]
     pub(crate) fn get_term_req_data(&self, idx: usize) -> &TermsAggReqData {
@@ -102,6 +107,12 @@
             .as_deref()
             .expect("range_req_data slot is empty (taken)")
     }
+    #[inline]
+    pub(crate) fn get_filter_req_data(&self, idx: usize) -> &FilterAggReqData {
+        self.per_request.filter_req_data[idx]
+            .as_deref()
+            .expect("filter_req_data slot is empty (taken)")
+    }
 
     // ---------- mutable getters ----------
 
@@ -179,6 +190,21 @@
         debug_assert!(self.per_request.range_req_data[idx].is_none());
         self.per_request.range_req_data[idx] = Some(value);
     }
+
+    /// Move out the boxed Filter request at `idx`, leaving `None`.
+    #[inline]
+    pub(crate) fn take_filter_req_data(&mut self, idx: usize) -> Box<FilterAggReqData> {
+        self.per_request.filter_req_data[idx]
+            .take()
+            .expect("filter_req_data slot is empty (taken)")
+    }
+
+    /// Put back a Filter request into an empty slot at `idx`.
+    #[inline]
+    pub(crate) fn put_back_filter_req_data(&mut self, idx: usize, value: Box<FilterAggReqData>) {
+        debug_assert!(self.per_request.filter_req_data[idx].is_none());
+        self.per_request.filter_req_data[idx] = Some(value);
+    }
 }
 
 /// Each type of aggregation has its own request data struct. This struct holds
@@ -196,6 +222,8 @@ pub struct PerRequestAggSegCtx {
     pub histogram_req_data: Vec<Option<Box<HistogramAggReqData>>>,
     /// RangeAggReqData contains the request data for a range aggregation.
     pub range_req_data: Vec<Option<Box<RangeAggReqData>>>,
+    /// FilterAggReqData contains the request data for a filter aggregation.
+    pub filter_req_data: Vec<Option<Box<FilterAggReqData>>>,
     /// Shared by avg, min, max, sum, stats, extended_stats, count
     pub stats_metric_req_data: Vec,
     /// CardinalityAggReqData contains the request data for a cardinality aggregation.
@@ -226,6 +254,11 @@ impl PerRequestAggSegCtx {
             .iter()
             .map(|b| b.as_ref().unwrap().get_memory_consumption())
             .sum::<usize>()
+            + self
+                .filter_req_data
+                .iter()
+                .map(|b| b.as_ref().unwrap().get_memory_consumption())
+                .sum::<usize>()
             + self
                 .stats_metric_req_data
                 .iter()
@@ -277,6 +310,11 @@
                 .expect("range_req_data slot is empty (taken)")
                 .name
                 .as_str(),
+            AggKind::Filter => self.filter_req_data[idx]
+                .as_deref()
+                .expect("filter_req_data slot is empty (taken)")
+                .name
+                .as_str(),
         }
     }
 
@@ -319,7 +357,8 @@ pub(crate) fn build_segment_agg_collectors(
         collectors.push(build_segment_agg_collector(req, node)?);
     }
 
-    req.limits
+    req.context
+        .limits
         .add_memory_consumed(req.per_request.get_memory_consumption() as u64)?;
     // Single collector special case
     if collectors.len() == 1 {
@@ -394,6 +433,9 @@ pub(crate) fn build_segment_agg_collector(
         AggKind::Range => Ok(Box::new(SegmentRangeCollector::from_req_and_validate(
             req, node,
         )?)),
+        AggKind::Filter => Ok(Box::new(SegmentFilterCollector::from_req_and_validate(
+            req, node,
+        )?)),
     }
 }
 
@@ -423,6 +465,7 @@ pub enum AggKind {
     Histogram,
     DateHistogram,
     Range,
+    Filter,
 }
 
 impl AggKind {
@@ -437,6 +480,7 @@
             AggKind::Histogram => "Histogram",
             AggKind::DateHistogram => "DateHistogram",
             AggKind::Range => "Range",
+            AggKind::Filter => "Filter",
         }
     }
 }
@@ -446,11 +490,11 @@ pub(crate) fn build_aggregations_data_from_req(
     aggs: &Aggregations,
     reader: &SegmentReader,
     segment_ordinal: SegmentOrdinal,
-    limits: AggregationLimitsGuard,
+    context: AggContextParams,
 ) -> crate::Result<AggregationsSegmentCtx> {
     let mut data = AggregationsSegmentCtx {
         per_request: Default::default(),
-        limits,
+        context,
     };
 
     for (name, agg) in aggs.iter() {
@@ -686,6 +730,36 @@ fn build_nodes(
             children,
         }])
     }
+        AggregationVariants::Filter(filter_req) => {
+            // Build the query and evaluator upfront
+            let schema = reader.schema();
+            let tokenizers = &data.context.tokenizers;
+            let query = filter_req.parse_query(&schema, tokenizers)?;
+            let evaluator = crate::aggregation::bucket::DocumentQueryEvaluator::new(
+                query,
+                schema.clone(),
+                reader,
+            )?;
+
+            // Pre-allocate buffer for batch filtering
+            let max_doc = reader.max_doc();
+            let buffer_capacity = crate::docset::COLLECT_BLOCK_BUFFER_LEN.min(max_doc as usize);
+            let matching_docs_buffer = Vec::with_capacity(buffer_capacity);
+
+            let idx_in_req_data = data.push_filter_req_data(FilterAggReqData {
+                name: agg_name.to_string(),
+                req: filter_req.clone(),
+                segment_reader: reader.clone(),
+                evaluator,
+                matching_docs_buffer,
+            });
+            let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?;
+            Ok(vec![AggRefNode {
+                kind: AggKind::Filter,
+                idx_in_req_data,
+                children,
+            }])
+        }
     }
 }
diff --git a/src/aggregation/agg_req.rs b/src/aggregation/agg_req.rs
index e5dfed85a..5fa187537 100644
--- a/src/aggregation/agg_req.rs
+++ b/src/aggregation/agg_req.rs
@@ -32,7 +32,8 @@ use rustc_hash::FxHashMap;
 use serde::{Deserialize, Serialize};
 
 use super::bucket::{
-    DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation,
+    DateHistogramAggregationReq, FilterAggregation, HistogramAggregation, RangeAggregation,
+    TermsAggregation,
 };
 use super::metric::{
     AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation,
@@ -130,6 +131,9 @@ pub enum AggregationVariants {
     /// Put data into buckets of terms.
     #[serde(rename = "terms")]
     Terms(TermsAggregation),
+    /// Filter documents into a single bucket.
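+    ///
+    /// A sketch of the JSON request shape (taken from the example and tests in
+    /// this PR): a query string under `"filter"`, sub-aggregations under `"aggs"`:
+    ///
+    /// ```json
+    /// { "electronics": { "filter": "category:electronics", "aggs": { "avg_price": { "avg": { "field": "price" } } } } }
+    /// ```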
+    #[serde(rename = "filter")]
+    Filter(FilterAggregation),
 
     // Metric aggregation types
     /// Computes the average of the extracted values.
@@ -175,6 +179,7 @@ impl AggregationVariants {
             AggregationVariants::Range(range) => vec![range.field.as_str()],
             AggregationVariants::Histogram(histogram) => vec![histogram.field.as_str()],
             AggregationVariants::DateHistogram(histogram) => vec![histogram.field.as_str()],
+            AggregationVariants::Filter(filter) => filter.get_fast_field_names(),
             AggregationVariants::Average(avg) => vec![avg.field_name()],
             AggregationVariants::Count(count) => vec![count.field_name()],
             AggregationVariants::Max(max) => vec![max.field_name()],
diff --git a/src/aggregation/agg_result.rs b/src/aggregation/agg_result.rs
index 005568182..34b5e2043 100644
--- a/src/aggregation/agg_result.rs
+++ b/src/aggregation/agg_result.rs
@@ -156,6 +156,8 @@ pub enum BucketResult {
         /// The upper bound error for the doc count of each term.
         doc_count_error_upper_bound: Option<u64>,
     },
+    /// This is the filter result - a single bucket with sub-aggregations
+    Filter(FilterBucketResult),
 }
 
 impl BucketResult {
@@ -172,6 +174,11 @@
                 sum_other_doc_count: _,
                 doc_count_error_upper_bound: _,
             } => buckets.iter().map(|bucket| bucket.get_bucket_count()).sum(),
+            BucketResult::Filter(filter_result) => {
+                // Filter doesn't add to bucket count - it's not a user-facing bucket
+                // Only count sub-aggregation buckets
+                filter_result.sub_aggregations.get_bucket_count()
+            }
         }
     }
 }
@@ -308,3 +315,25 @@ impl RangeBucketEntry {
         1 + self.sub_aggregation.get_bucket_count()
     }
 }
+
+/// This is the filter bucket result, which contains the document count and sub-aggregations.
+///
+/// # JSON Format
+/// ```json
+/// {
+///   "electronics_only": {
+///     "doc_count": 2,
+///     "avg_price": {
+///       "value": 150.0
+///     }
+///   }
+/// }
+/// ```
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+pub struct FilterBucketResult {
+    /// Number of documents in the filter bucket
+    pub doc_count: u64,
+    /// Sub-aggregation results
+    #[serde(flatten)]
+    pub sub_aggregations: AggregationResults,
+}
diff --git a/src/aggregation/agg_tests.rs b/src/aggregation/agg_tests.rs
index 7a8febfc6..fede0c7c7 100644
--- a/src/aggregation/agg_tests.rs
+++ b/src/aggregation/agg_tests.rs
@@ -5,7 +5,6 @@ use crate::aggregation::agg_result::AggregationResults;
 use crate::aggregation::buf_collector::DOC_BLOCK_SIZE;
 use crate::aggregation::collector::AggregationCollector;
 use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults;
-use crate::aggregation::segment_agg_result::AggregationLimitsGuard;
 use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms};
 use crate::aggregation::DistributedAggregationCollector;
 use crate::query::{AllQuery, TermQuery};
@@ -128,10 +127,8 @@ fn test_aggregation_flushing(
         .unwrap();
 
     let agg_res: AggregationResults = if use_distributed_collector {
-        let collector = DistributedAggregationCollector::from_aggs(
-            agg_req.clone(),
-            AggregationLimitsGuard::default(),
-        );
+        let collector =
+            DistributedAggregationCollector::from_aggs(agg_req.clone(), Default::default());
 
         let searcher = reader.searcher();
         let intermediate_agg_result = searcher.search(&AllQuery, &collector).unwrap();
diff --git a/src/aggregation/bucket/filter.rs b/src/aggregation/bucket/filter.rs
new file mode 100644
index 000000000..18fc40b25
--- /dev/null
+++ b/src/aggregation/bucket/filter.rs
@@ -0,0 +1,1755 @@
+use std::fmt::Debug;
+
+use common::BitSet;
+use serde::{Deserialize,
Deserializer, Serialize, Serializer};
+
+use crate::aggregation::agg_data::{
+    build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
+};
+use crate::aggregation::intermediate_agg_result::{
+    IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
+};
+use crate::aggregation::segment_agg_result::{CollectorClone, SegmentAggregationCollector};
+use crate::docset::DocSet;
+use crate::query::{AllQuery, EnableScoring, Query, QueryParser};
+use crate::schema::Schema;
+use crate::tokenizer::TokenizerManager;
+use crate::{DocId, SegmentReader, TantivyError};
+
+/// A trait for query builders that can build queries programmatically.
+///
+/// This trait enables programmatic query construction for filter aggregations with
+/// full serialization/deserialization support for distributed aggregation scenarios.
+///
+/// # Why This Exists
+///
+/// Filter aggregations need to support both:
+/// - Query strings (simple, always serializable)
+/// - Programmatic query construction (flexible, with serialization support)
+///
+/// This trait provides the programmatic query construction capability with full
+/// serialization support via the `typetag` crate.
+///
+/// # Implementation Requirements
+///
+/// Implementors must:
+/// 1. Derive `Debug`, `Clone`, `Serialize`, and `Deserialize`
+/// 2. Use the `#[typetag::serde]` attribute on the impl block
+/// 3. Implement `build_query()` to construct the query from schema/tokenizers
+/// 4. Implement `box_clone()` to enable cloning (typically just `Box::new(self.clone())`)
+///
+/// # Example
+///
+/// ```rust
+/// use tantivy::aggregation::bucket::QueryBuilder;
+/// use tantivy::query::{Query, TermQuery};
+/// use tantivy::schema::{Schema, IndexRecordOption};
+/// use tantivy::tokenizer::TokenizerManager;
+/// use tantivy::Term;
+/// use serde::{Serialize, Deserialize};
+///
+/// #[derive(Debug, Clone, Serialize, Deserialize)]
+/// struct TermQueryBuilder {
+///     field_name: String,
+///     term_text: String,
+/// }
+///
+/// #[typetag::serde]
+/// impl QueryBuilder for TermQueryBuilder {
+///     fn build_query(
+///         &self,
+///         schema: &Schema,
+///         _tokenizers: &TokenizerManager,
+///     ) -> tantivy::Result<Box<dyn Query>> {
+///         let field = schema.get_field(&self.field_name)?;
+///         let term = Term::from_field_text(field, &self.term_text);
+///         Ok(Box::new(TermQuery::new(term, IndexRecordOption::Basic)))
+///     }
+///
+///     fn box_clone(&self) -> Box<dyn QueryBuilder> {
+///         Box::new(self.clone())
+///     }
+/// }
+///
+/// // Create an instance
+/// let builder = TermQueryBuilder {
+///     field_name: "category".to_string(),
+///     term_text: "electronics".to_string(),
+/// };
+/// ```
+#[typetag::serde(tag = "type")]
+pub trait QueryBuilder: Debug + Send + Sync {
+    /// Build a query from the given schema and tokenizer manager.
+    ///
+    /// This method is called once when creating the FilterAggReqData for a segment.
+    ///
+    /// # Parameters
+    /// - `schema`: The index schema for field lookups
+    /// - `tokenizers`: The tokenizer manager for text analysis
+    ///
+    /// # Returns
+    /// A boxed Query object, or an error if construction fails
+    fn build_query(
+        &self,
+        schema: &Schema,
+        tokenizers: &TokenizerManager,
+    ) -> crate::Result<Box<dyn Query>>;
+
+    /// Clone this builder into a boxed trait object.
+    ///
+    /// Since builders are just data (no state), this simply clones the data.
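+    /// A dedicated `box_clone` method is needed because `Clone` cannot be a
+    /// supertrait of an object-safe trait, so boxed `QueryBuilder` trait objects
+    /// clone through it instead.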
+    /// The typical implementation is:
+    /// ```rust,ignore
+    /// fn box_clone(&self) -> Box<dyn QueryBuilder> {
+    ///     Box::new(self.clone())
+    /// }
+    /// ```
+    fn box_clone(&self) -> Box<dyn QueryBuilder>;
+}
+
+/// Filter aggregation creates a single bucket containing documents that match a query.
+///
+/// # Usage
+///
+/// ## Query String (Recommended)
+/// ```rust
+/// use tantivy::aggregation::bucket::FilterAggregation;
+///
+/// // Query strings are parsed using Tantivy's standard QueryParser
+/// let filter_agg = FilterAggregation::new("category:electronics AND price:[100 TO 500]".to_string());
+/// ```
+///
+/// ## Custom Query Builder
+/// ```rust
+/// use tantivy::aggregation::bucket::{FilterAggregation, QueryBuilder};
+/// use tantivy::query::{Query, TermQuery};
+/// use tantivy::schema::{Schema, IndexRecordOption};
+/// use tantivy::tokenizer::TokenizerManager;
+/// use tantivy::Term;
+/// use serde::{Serialize, Deserialize};
+///
+/// #[derive(Debug, Clone, Serialize, Deserialize)]
+/// struct MyBuilder {
+///     field_name: String,
+///     term_text: String,
+/// }
+///
+/// #[typetag::serde]
+/// impl QueryBuilder for MyBuilder {
+///     fn build_query(
+///         &self,
+///         schema: &Schema,
+///         _tokenizers: &TokenizerManager,
+///     ) -> tantivy::Result<Box<dyn Query>> {
+///         let field = schema.get_field(&self.field_name)?;
+///         let term = Term::from_field_text(field, &self.term_text);
+///         Ok(Box::new(TermQuery::new(term, IndexRecordOption::Basic)))
+///     }
+///
+///     fn box_clone(&self) -> Box<dyn QueryBuilder> {
+///         Box::new(self.clone())
+///     }
+/// }
+///
+/// let builder = MyBuilder {
+///     field_name: "category".to_string(),
+///     term_text: "electronics".to_string(),
+/// };
+/// let filter_agg = FilterAggregation::new_with_builder(Box::new(builder));
+/// ```
+///
+/// # Result
+/// The filter aggregation returns a single bucket with:
+/// - `doc_count`: Number of documents matching the filter
+/// - Sub-aggregation results computed on the filtered document set
+#[derive(Debug, Clone)]
+pub struct FilterAggregation {
+    /// The query for filtering - can be either a query string or a query builder
+    query: FilterQuery,
+}
+
+/// Represents different ways to specify a filter query
+pub enum FilterQuery {
+    /// Query string that will be parsed using Tantivy's standard parsing facilities
+    ///
+    /// This is the recommended approach as it's serializable and doesn't carry runtime state.
+    QueryString(String),
+
+    /// Custom query builder for programmatic query building
+    ///
+    /// This variant stores a builder that builds the query once when creating FilterAggReqData.
+    ///
+    /// This is useful for:
+    /// - Custom query types not expressible as query strings
+    /// - Programmatic query construction based on schema
+    /// - Extension query types
+    ///
+    /// **Note**: The builder is serializable and can be deserialized.
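+    ///
+    /// In serialized form, `#[typetag::serde(tag = "type")]` embeds the concrete
+    /// builder's type name under a `"type"` key (e.g. `{"type": "MyBuilder", ...}`),
+    /// which is how deserialization recovers the right implementation.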
+    CustomBuilder(Box<dyn QueryBuilder>),
+}
+
+impl Debug for FilterQuery {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            FilterQuery::QueryString(s) => f.debug_tuple("QueryString").field(s).finish(),
+            FilterQuery::CustomBuilder(_) => {
+                f.debug_struct("CustomBuilder").finish_non_exhaustive()
+            }
+        }
+    }
+}
+
+impl Clone for FilterQuery {
+    fn clone(&self) -> Self {
+        match self {
+            FilterQuery::QueryString(query_string) => {
+                FilterQuery::QueryString(query_string.clone())
+            }
+            FilterQuery::CustomBuilder(builder) => FilterQuery::CustomBuilder(builder.box_clone()),
+        }
+    }
+}
+
+impl FilterAggregation {
+    /// Create a new filter aggregation with a query string.
+    /// The query string will be parsed using the QueryParser::parse_query() method.
+    pub fn new(query_string: String) -> Self {
+        Self {
+            query: FilterQuery::QueryString(query_string),
+        }
+    }
+
+    /// Create a new filter aggregation with a query builder.
+    ///
+    /// The builder will be called once when creating the FilterAggReqData for each segment.
+    ///
+    /// # Example
+    /// ```rust
+    /// use tantivy::aggregation::bucket::{FilterAggregation, QueryBuilder};
+    /// use tantivy::query::{Query, TermQuery};
+    /// use tantivy::schema::{Schema, IndexRecordOption};
+    /// use tantivy::tokenizer::TokenizerManager;
+    /// use tantivy::Term;
+    /// use serde::{Serialize, Deserialize};
+    ///
+    /// #[derive(Debug, Clone, Serialize, Deserialize)]
+    /// struct MyBuilder {
+    ///     field_name: String,
+    ///     term_text: String,
+    /// }
+    ///
+    /// #[typetag::serde]
+    /// impl QueryBuilder for MyBuilder {
+    ///     fn build_query(
+    ///         &self,
+    ///         schema: &Schema,
+    ///         _tokenizers: &TokenizerManager,
+    ///     ) -> tantivy::Result<Box<dyn Query>> {
+    ///         let field = schema.get_field(&self.field_name)?;
+    ///         let term = Term::from_field_text(field, &self.term_text);
+    ///         Ok(Box::new(TermQuery::new(term, IndexRecordOption::Basic)))
+    ///     }
+    ///
+    ///     fn box_clone(&self) -> Box<dyn QueryBuilder> {
+    ///         Box::new(self.clone())
+    ///     }
+    /// }
+    ///
+    /// let builder = MyBuilder {
+    ///     field_name: "category".to_string(),
+    ///     term_text: "electronics".to_string(),
+    /// };
+    /// let filter_agg = FilterAggregation::new_with_builder(Box::new(builder));
+    /// ```
+    pub fn new_with_builder(builder: Box<dyn QueryBuilder>) -> Self {
+        Self {
+            query: FilterQuery::CustomBuilder(builder),
+        }
+    }
+
+    /// Parse the query into a Tantivy Query object.
+    ///
+    /// For query strings, this uses the QueryParser::parse_query() method.
+    /// For custom builders, it builds the query using the builder.
+    pub(crate) fn parse_query(
+        &self,
+        schema: &Schema,
+        tokenizer_manager: &TokenizerManager,
+    ) -> crate::Result<Box<dyn Query>> {
+        match &self.query {
+            FilterQuery::QueryString(query_str) => {
+                let query_parser =
+                    QueryParser::new(schema.clone(), vec![], tokenizer_manager.clone());
+
+                query_parser
+                    .parse_query(query_str)
+                    .map_err(|e| TantivyError::InvalidArgument(e.to_string()))
+            }
+            FilterQuery::CustomBuilder(builder) => {
+                // Build the query using the builder
+                builder.build_query(schema, tokenizer_manager)
+            }
+        }
+    }
+
+    /// Parse the query with a custom QueryParser.
+    ///
+    /// This method allows using a pre-configured QueryParser with custom settings
+    /// like field boosts, fuzzy matching, default fields, etc.
+    ///
+    /// For custom builders, this method is not supported and will return an error:
+    /// custom builders need the schema and tokenizers, which are not accessible from a QueryParser.
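+    ///
+    /// # Example
+    /// ```rust,ignore
+    /// // Sketch only: assumes an existing `index`. `QueryParser::for_index` is
+    /// // tantivy's standard constructor for a parser bound to an index.
+    /// let parser = QueryParser::for_index(&index, vec![]);
+    /// let filter_agg = FilterAggregation::new("category:electronics".to_string());
+    /// let query = filter_agg.parse_query_with_parser(&parser)?;
+    /// ```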
+    pub fn parse_query_with_parser(
+        &self,
+        query_parser: &QueryParser,
+    ) -> crate::Result<Box<dyn Query>> {
+        match &self.query {
+            FilterQuery::QueryString(query_str) => query_parser
+                .parse_query(query_str)
+                .map_err(|e| TantivyError::InvalidArgument(e.to_string())),
+            FilterQuery::CustomBuilder(_) => Err(TantivyError::InvalidArgument(
+                "parse_query_with_parser is not supported for custom query builders. Use \
+                 parse_query with explicit schema and tokenizers instead."
+                    .to_string(),
+            )),
+        }
+    }
+
+    /// Get the fast field names used by this aggregation (none for filter aggregation).
+    pub fn get_fast_field_names(&self) -> Vec<&str> {
+        // Filter aggregation cannot introspect query fast field dependencies.
+        //
+        // As of PR #2693, queries can fall back to fast fields when fields are not indexed
+        // (e.g., TermQuery falls back to RangeQuery on fast fields). However, the Query
+        // trait has no mechanism to report these dependencies.
+        //
+        // For prefetching optimization, callers must analyze the query themselves to
+        // determine fast field usage. This requires:
+        // 1. Parsing the query string to extract field references
+        // 2. Checking the schema to see if those fields are indexed or fast-only
+        // 3. Collecting fast field names for non-indexed fields
+        //
+        // This limitation exists because:
+        // - Query::weight() is called during execution, not during planning
+        // - The fallback decision is based on schema configuration
+        // - There's no Query trait method to declare potential fast field dependencies
+        vec![]
+    }
+}
+
+// Custom serialization implementation
+impl Serialize for FilterAggregation {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where S: Serializer {
+        match &self.query {
+            FilterQuery::QueryString(query_string) => {
+                // Serialize query strings as plain strings
+                query_string.serialize(serializer)
+            }
+            FilterQuery::CustomBuilder(builder) => {
+                // Serialize custom builders using typetag (includes type information)
+                builder.serialize(serializer)
+            }
+        }
+    }
+}
+
+impl<'de> Deserialize<'de> for FilterAggregation {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where D: Deserializer<'de> {
+        // We need to peek at the value to determine whether it's a string or an object
+        use serde::de::Error;
+        use serde_json::Value;
+
+        let value = Value::deserialize(deserializer)?;
+
+        let query = if let Some(query_string) = value.as_str() {
+            // It's a plain string - a query string
+            FilterQuery::QueryString(query_string.to_string())
+        } else {
+            // It's an object - a custom builder with typetag
+            let builder: Box<dyn QueryBuilder> = serde_json::from_value(value).map_err(|e| {
+                D::Error::custom(format!("Failed to deserialize QueryBuilder: {}", e))
+            })?;
+            FilterQuery::CustomBuilder(builder)
+        };
+
+        Ok(FilterAggregation { query })
+    }
+}
+
+// PartialEq is required because AggregationVariants derives it.
+// We implement it manually to handle custom builders, which cannot be compared.
+impl PartialEq for FilterAggregation {
+    fn eq(&self, other: &Self) -> bool {
+        match (&self.query, &other.query) {
+            (FilterQuery::QueryString(a), FilterQuery::QueryString(b)) => a == b,
+            // Custom builders cannot be compared for equality
+            _ => false,
+        }
+    }
+}
+
+/// Request data for filter aggregation.
+/// This struct holds the per-segment data needed to execute a filter aggregation.
+pub struct FilterAggReqData {
+    /// The name of the filter aggregation
+    pub name: String,
+    /// The filter aggregation
+    pub req: FilterAggregation,
+    /// The segment reader
+    pub segment_reader: SegmentReader,
+    /// Document evaluator for the
filter query (precomputed BitSet).
+    /// This is built once when the request data is created.
+    pub evaluator: DocumentQueryEvaluator,
+    /// Reusable buffer for matching documents to minimize allocations during collection
+    pub matching_docs_buffer: Vec<DocId>,
+}
+
+impl FilterAggReqData {
+    pub(crate) fn get_memory_consumption(&self) -> usize {
+        // Estimate: name + segment reader reference + bitset + buffer capacity
+        self.name.len()
+            + std::mem::size_of::<SegmentReader>()
+            + self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes)
+            + self.matching_docs_buffer.capacity() * std::mem::size_of::<DocId>()
+    }
+}
+
+/// Document evaluator for filter queries using a BitSet.
+pub struct DocumentQueryEvaluator {
+    /// BitSet containing all matching documents for this segment.
+    /// For AllQuery, this is a full BitSet (all bits set).
+    /// For other queries, only matching document bits are set.
+    pub(crate) bitset: BitSet,
+}
+
+impl DocumentQueryEvaluator {
+    /// Create and initialize a document query evaluator for a segment.
+    /// This executes the query upfront and collects the results into a BitSet,
+    /// unless the query is AllQuery, in which case query execution is skipped
+    /// and a full BitSet is used directly.
+    pub(crate) fn new(
+        query: Box<dyn Query>,
+        schema: Schema,
+        segment_reader: &SegmentReader,
+    ) -> crate::Result<Self> {
+        let max_doc = segment_reader.max_doc();
+
+        // Optimization: Detect AllQuery and create a full BitSet
+        if query.as_any().downcast_ref::<AllQuery>().is_some() {
+            return Ok(Self {
+                bitset: BitSet::with_max_value_and_full(max_doc),
+            });
+        }
+
+        // Get the weight for the query
+        let weight = query.weight(EnableScoring::disabled_from_schema(&schema))?;
+
+        // Get a scorer that iterates over matching documents
+        let mut scorer = weight.scorer(segment_reader, 1.0)?;
+
+        // Create a BitSet to hold all matching documents
+        let mut bitset = BitSet::with_max_value(max_doc);
+
+        // Collect all matching documents into the BitSet.
+        // This is the upfront cost, but then lookups are O(1).
+        let mut doc = scorer.doc();
+        while doc != crate::TERMINATED {
+            bitset.insert(doc);
+            doc = scorer.advance();
+        }
+
+        Ok(Self { bitset })
+    }
+
+    /// Evaluate whether a document matches the filter query.
+    /// O(1) lookup in the precomputed BitSet.
+    #[inline]
+    pub fn matches_document(&self, doc: DocId) -> bool {
+        self.bitset.contains(doc)
+    }
+
+    /// Filter a batch of documents.
+    /// Appends the matching documents from the input batch to `output`.
+    #[inline]
+    pub fn filter_batch(&self, docs: &[DocId], output: &mut Vec<DocId>) {
+        for &doc in docs {
+            if self.bitset.contains(doc) {
+                output.push(doc);
+            }
+        }
+    }
+}
+
+impl Debug for DocumentQueryEvaluator {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("DocumentQueryEvaluator")
+            .field("num_matches", &self.bitset.len())
+            .finish()
+    }
+}
+
+/// Segment collector for filter aggregation
+pub struct SegmentFilterCollector {
+    /// Document count in this bucket
+    doc_count: u64,
+    /// Sub-aggregation collectors
+    sub_aggregations: Option<Box<dyn SegmentAggregationCollector>>,
+    /// Accessor index for this filter aggregation (to access FilterAggReqData)
+    accessor_idx: usize,
+}
+
+impl SegmentFilterCollector {
+    /// Create a new filter segment collector following the new agg_data pattern
+    pub(crate) fn from_req_and_validate(
+        req: &mut AggregationsSegmentCtx,
+        node: &AggRefNode,
+    ) -> crate::Result<Self> {
+        // Build sub-aggregation collectors if any
+        let sub_agg_collector = if !node.children.is_empty() {
+            Some(build_segment_agg_collectors(req, &node.children)?)
+        } else {
+            None
+        };
+
+        Ok(SegmentFilterCollector {
+            doc_count: 0,
+            sub_aggregations: sub_agg_collector,
+            accessor_idx: node.idx_in_req_data,
+        })
+    }
+}
+
+impl Debug for SegmentFilterCollector {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("SegmentFilterCollector")
+            .field("doc_count", &self.doc_count)
+            .field("has_sub_aggs", &self.sub_aggregations.is_some())
+            .field("accessor_idx", &self.accessor_idx)
+            .finish()
+    }
+}
+
+impl CollectorClone for SegmentFilterCollector {
+    fn clone_box(&self) -> Box<dyn SegmentAggregationCollector> {
+        // For now, panic - this needs a proper implementation with weight recreation
+        panic!("SegmentFilterCollector cloning not yet implemented - requires weight recreation")
+    }
+}
+
+impl SegmentAggregationCollector for SegmentFilterCollector {
+    fn add_intermediate_aggregation_result(
+        self: Box<Self>,
+        agg_data: &AggregationsSegmentCtx,
+        results: &mut IntermediateAggregationResults,
+    ) -> crate::Result<()> {
+        let mut sub_results = IntermediateAggregationResults::default();
+
+        if let Some(sub_aggs) = self.sub_aggregations {
+            sub_aggs.add_intermediate_aggregation_result(agg_data, &mut sub_results)?;
+        }
+
+        // Create the filter bucket result
+        let filter_bucket_result = IntermediateBucketResult::Filter {
+            doc_count: self.doc_count,
+            sub_aggregations: sub_results,
+        };
+
+        // Get the name of this filter aggregation
+        let name = agg_data.per_request.filter_req_data[self.accessor_idx]
+            .as_ref()
+            .expect("filter_req_data slot is empty")
+            .name
+            .clone();
+
+        results.push(
+            name,
+            IntermediateAggregationResult::Bucket(filter_bucket_result),
+        )?;
+
+        Ok(())
+    }
+
+    fn collect(&mut self, doc: DocId, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
+        // Access the evaluator from FilterAggReqData
+        let req_data = agg_data.get_filter_req_data(self.accessor_idx);
+
+        // O(1) BitSet lookup to check if the document matches the filter
+        if req_data.evaluator.matches_document(doc) {
+            self.doc_count += 1;
+
+            // If we have sub-aggregations, collect on them for this filtered document
+            if let Some(sub_aggs) = &mut self.sub_aggregations {
+                sub_aggs.collect(doc, agg_data)?;
+            }
+        }
+        Ok(())
+    }
+
+    #[inline]
+    fn collect_block(
+        &mut self,
+        docs: &[DocId],
+        agg_data: &mut AggregationsSegmentCtx,
+    ) -> crate::Result<()> {
+        if docs.is_empty() {
+            return Ok(());
+        }
+
+        // Take the request data to avoid borrow checker issues with sub-aggregations
+        let mut req = agg_data.take_filter_req_data(self.accessor_idx);
+
+        // Use batch filtering with O(1) BitSet lookups
+        req.matching_docs_buffer.clear();
+        req.evaluator
+            .filter_batch(docs, &mut req.matching_docs_buffer);
+
+        self.doc_count += req.matching_docs_buffer.len() as u64;
+
+        // Batch process sub-aggregations if we have matches
+        if !req.matching_docs_buffer.is_empty() {
+            if let Some(sub_aggs) = &mut self.sub_aggregations {
+                // Use collect_block for better sub-aggregation performance
+                sub_aggs.collect_block(&req.matching_docs_buffer, agg_data)?;
+            }
+        }
+
+        // Put the request data back
+        agg_data.put_back_filter_req_data(self.accessor_idx, req);
+
+        Ok(())
+    }
+
+    fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
+        if let Some(ref mut sub_aggs) = self.sub_aggregations {
+            sub_aggs.flush(agg_data)?;
+        }
+        Ok(())
+    }
+}
+
+/// Intermediate result for filter aggregation
+#[derive(Debug, Clone, PartialEq)]
+pub struct IntermediateFilterBucketResult {
+    /// Document count in this bucket
+    pub doc_count: u64,
+    /// Sub-aggregation results
+    pub
sub_aggregations: IntermediateAggregationResults, +} + +#[cfg(test)] +mod tests { + use std::time::Instant; + + use serde_json::{json, Value}; + + use super::*; + use crate::aggregation::agg_req::Aggregations; + use crate::aggregation::agg_result::AggregationResults; + use crate::aggregation::{AggContextParams, AggregationCollector}; + use crate::query::{AllQuery, QueryParser, TermQuery}; + use crate::schema::{IndexRecordOption, Schema, Term, FAST, INDEXED, STORED, TEXT}; + use crate::{doc, Index, IndexWriter}; + + // Test helper functions + fn aggregation_results_to_json(results: &AggregationResults) -> Value { + serde_json::to_value(results).expect("Failed to serialize aggregation results") + } + + fn json_values_match(actual: &Value, expected: &Value, tolerance: f64) -> bool { + match (actual, expected) { + (Value::Number(a), Value::Number(e)) => { + let a_f64 = a.as_f64().unwrap_or(0.0); + let e_f64 = e.as_f64().unwrap_or(0.0); + (a_f64 - e_f64).abs() < tolerance + } + (Value::Object(a_map), Value::Object(e_map)) => { + if a_map.len() != e_map.len() { + return false; + } + for (key, expected_val) in e_map { + match a_map.get(key) { + Some(actual_val) => { + if !json_values_match(actual_val, expected_val, tolerance) { + return false; + } + } + None => return false, + } + } + true + } + (Value::Array(a_arr), Value::Array(e_arr)) => { + if a_arr.len() != e_arr.len() { + return false; + } + for (actual_item, expected_item) in a_arr.iter().zip(e_arr.iter()) { + if !json_values_match(actual_item, expected_item, tolerance) { + return false; + } + } + true + } + _ => actual == expected, + } + } + + fn assert_aggregation_results_match( + actual_results: &AggregationResults, + expected_json: Value, + tolerance: f64, + ) { + let actual_json = aggregation_results_to_json(actual_results); + + if !json_values_match(&actual_json, &expected_json, tolerance) { + panic!( + "Aggregation results do not match expected JSON.\nActual:\n{}\nExpected:\n{}", + serde_json::to_string_pretty(&actual_json).unwrap(), + serde_json::to_string_pretty(&expected_json).unwrap() + ); + } + } + + macro_rules! 
assert_agg_results { + ($actual:expr, $expected:expr) => { + assert_aggregation_results_match($actual, $expected, 0.1) + }; + ($actual:expr, $expected:expr, $tolerance:expr) => { + assert_aggregation_results_match($actual, $expected, $tolerance) + }; + } + + fn create_standard_test_index() -> crate::Result { + let mut schema_builder = Schema::builder(); + let category = schema_builder.add_text_field("category", TEXT | FAST); + let brand = schema_builder.add_text_field("brand", TEXT | FAST); + let price = schema_builder.add_u64_field("price", FAST | INDEXED); + let rating = schema_builder.add_f64_field("rating", FAST); + let in_stock = schema_builder.add_bool_field("in_stock", FAST | INDEXED); + + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut writer: IndexWriter = index.writer(50_000_000)?; + + writer.add_document(doc!( + category => "electronics", brand => "apple", + price => 999u64, rating => 4.5f64, in_stock => true + ))?; + writer.add_document(doc!( + category => "electronics", brand => "samsung", + price => 799u64, rating => 4.2f64, in_stock => true + ))?; + writer.add_document(doc!( + category => "clothing", brand => "nike", + price => 120u64, rating => 4.1f64, in_stock => false + ))?; + writer.add_document(doc!( + category => "books", brand => "penguin", + price => 25u64, rating => 4.8f64, in_stock => true + ))?; + + writer.commit()?; + Ok(index) + } + + /// Helper to create aggregation collector with serialization roundtrip + /// This ensures all aggregations can be serialized and deserialized correctly + fn create_collector( + index: &Index, + aggregations: Aggregations, + ) -> crate::Result { + // Serialize and deserialize the aggregations + let serialized = serde_json::to_string(&aggregations)?; + let deserialized: Aggregations = serde_json::from_str(&serialized)?; + + // Create collector with deserialized aggregations + Ok(AggregationCollector::from_aggs( + deserialized, + AggContextParams::new(Default::default(), index.tokenizers().clone()), + )) + } + + #[test] + fn test_basic_filter_with_metric_agg() -> crate::Result<()> { + let index = create_standard_test_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + + let agg = json!({ + "electronics": { + "filter": "category:electronics", + "aggs": { + "avg_price": { "avg": { "field": "price" } } + } + } + }); + + let aggregations: Aggregations = serde_json::from_value(agg)?; + let collector = create_collector(&index, aggregations)?; + let result = searcher.search(&AllQuery, &collector)?; + + let expected = json!({ + "electronics": { + "doc_count": 2, + "avg_price": { "value": 899.0 } // (999 + 799) / 2 + } + }); + + assert_agg_results!(&result, expected); + Ok(()) + } + + #[test] + fn test_filter_with_no_matches() -> crate::Result<()> { + let index = create_standard_test_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + + let agg = json!({ + "furniture": { + "filter": "category:furniture", + "aggs": { + "avg_price": { "avg": { "field": "price" } } + } + } + }); + + let aggregations: Aggregations = serde_json::from_value(agg)?; + let collector = create_collector(&index, aggregations)?; + let result = searcher.search(&AllQuery, &collector)?; + + let expected = json!({ + "furniture": { + "doc_count": 0, + "avg_price": { "value": null } + } + }); + + assert_agg_results!(&result, expected); + Ok(()) + } + + #[test] + fn test_multiple_independent_filters() -> crate::Result<()> { + let index = create_standard_test_index()?; + let reader 
= index.reader()?; + let searcher = reader.searcher(); + + let agg = json!({ + "electronics": { + "filter": "category:electronics", + "aggs": { "avg_price": { "avg": { "field": "price" } } } + }, + "in_stock": { + "filter": "in_stock:true", + "aggs": { "count": { "value_count": { "field": "brand" } } } + }, + "high_rated": { + "filter": "rating:[4.5 TO *]", + "aggs": { "count": { "value_count": { "field": "brand" } } } + } + }); + + let aggregations: Aggregations = serde_json::from_value(agg)?; + let collector = create_collector(&index, aggregations)?; + let result = searcher.search(&AllQuery, &collector)?; + + let expected = json!({ + "electronics": { + "doc_count": 2, + "avg_price": { "value": 899.0 } + }, + "in_stock": { + "doc_count": 3, // apple, samsung, penguin + "count": { "value": 3.0 } + }, + "high_rated": { + "doc_count": 2, // apple (4.5), penguin (4.8) + "count": { "value": 2.0 } + } + }); + + assert_agg_results!(&result, expected); + Ok(()) + } + + // ============================================================================ + // Query Type Tests + // ============================================================================ + + #[test] + fn test_term_query_filter() -> crate::Result<()> { + let index = create_standard_test_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + + let agg = json!({ + "apple_products": { + "filter": "brand:apple", + "aggs": { "max_price": { "max": { "field": "price" } } } + } + }); + + let aggregations: Aggregations = serde_json::from_value(agg)?; + let collector = create_collector(&index, aggregations)?; + let result = searcher.search(&AllQuery, &collector)?; + + let expected = json!({ + "apple_products": { + "doc_count": 1, + "max_price": { "value": 999.0 } + } + }); + + assert_agg_results!(&result, expected); + Ok(()) + } + + #[test] + fn test_range_query_filter() -> crate::Result<()> { + let index = create_standard_test_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + + let agg = json!({ + "mid_price": { + "filter": "price:[100 TO 900]", + "aggs": { "count": { "value_count": { "field": "brand" } } } + } + }); + + let aggregations: Aggregations = serde_json::from_value(agg)?; + let collector = create_collector(&index, aggregations)?; + let result = searcher.search(&AllQuery, &collector)?; + + let expected = json!({ + "mid_price": { + "doc_count": 2, // samsung (799), nike (120) + "count": { "value": 2.0 } + } + }); + + assert_agg_results!(&result, expected); + Ok(()) + } + + #[test] + fn test_boolean_query_filter() -> crate::Result<()> { + let index = create_standard_test_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + + let agg = json!({ + "premium_electronics": { + "filter": "category:electronics AND price:[800 TO *]", + "aggs": { "avg_rating": { "avg": { "field": "rating" } } } + } + }); + + let aggregations: Aggregations = serde_json::from_value(agg)?; + let collector = create_collector(&index, aggregations)?; + let result = searcher.search(&AllQuery, &collector)?; + + let expected = json!({ + "premium_electronics": { + "doc_count": 1, // Only apple (999) is >= 800 in tantivy's range semantics + "avg_rating": { "value": 4.5 } + } + }); + + assert_agg_results!(&result, expected); + Ok(()) + } + + #[test] + fn test_bool_field_filter() -> crate::Result<()> { + let index = create_standard_test_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + + let agg = json!({ + "in_stock": { + "filter": "in_stock:true", + "aggs": { 
"avg_price": { "avg": { "field": "price" } } } + }, + "out_of_stock": { + "filter": "in_stock:false", + "aggs": { "count": { "value_count": { "field": "brand" } } } + } + }); + + let aggregations: Aggregations = serde_json::from_value(agg)?; + let collector = create_collector(&index, aggregations)?; + let result = searcher.search(&AllQuery, &collector)?; + + let expected = json!({ + "in_stock": { + "doc_count": 3, // apple, samsung, penguin + "avg_price": { "value": 607.67 } // (999 + 799 + 25) / 3 ≈ 607.67 + }, + "out_of_stock": { + "doc_count": 1, // nike + "count": { "value": 1.0 } + } + }); + + assert_agg_results!(&result, expected, 1.0); + Ok(()) + } + + // ============================================================================ + // Nested Filter Tests + // ============================================================================ + + #[test] + fn test_two_level_nested_filters() -> crate::Result<()> { + let index = create_standard_test_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + + let agg = json!({ + "all": { + "filter": "*", + "aggs": { + "electronics": { + "filter": "category:electronics", + "aggs": { + "expensive": { + "filter": "price:[900 TO *]", + "aggs": { + "count": { "value_count": { "field": "brand" } } + } + } + } + } + } + } + }); + + let aggregations: Aggregations = serde_json::from_value(agg)?; + let collector = create_collector(&index, aggregations)?; + let result = searcher.search(&AllQuery, &collector)?; + + let expected = json!({ + "all": { + "doc_count": 4, + "electronics": { + "doc_count": 2, + "expensive": { + "doc_count": 1, // Only apple (999) is >= 900 + "count": { "value": 1.0 } + } + } + } + }); + + assert_agg_results!(&result, expected); + Ok(()) + } + + #[test] + fn test_deeply_nested_filters() -> crate::Result<()> { + let index = create_standard_test_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + + let agg = json!({ + "level1": { + "filter": "*", + "aggs": { + "level2": { + "filter": "in_stock:true", + "aggs": { + "level3": { + "filter": "rating:[4.0 TO *]", + "aggs": { + "level4": { + "filter": "price:[500 TO *]", + "aggs": { + "final_count": { "value_count": { "field": "brand" } } + } + } + } + } + } + } + } + } + }); + + let aggregations: Aggregations = serde_json::from_value(agg)?; + let collector = create_collector(&index, aggregations)?; + let result = searcher.search(&AllQuery, &collector)?; + + let expected = json!({ + "level1": { + "doc_count": 4, + "level2": { + "doc_count": 3, // in_stock: apple, samsung, penguin + "level3": { + "doc_count": 3, // all have rating >= 4.0 + "level4": { + "doc_count": 2, // apple (999), samsung (799) + "final_count": { "value": 2.0 } + } + } + } + } + }); + + assert_agg_results!(&result, expected); + Ok(()) + } + + #[test] + fn test_multiple_nested_branches() -> crate::Result<()> { + let index = create_standard_test_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + + let agg = json!({ + "root": { + "filter": "*", + "aggs": { + "electronics_branch": { + "filter": "category:electronics", + "aggs": { + "avg_price": { "avg": { "field": "price" } } + } + }, + "in_stock_branch": { + "filter": "in_stock:true", + "aggs": { + "count": { "value_count": { "field": "brand" } } + } + } + } + } + }); + + let aggregations: Aggregations = serde_json::from_value(agg)?; + let collector = create_collector(&index, aggregations)?; + let result = searcher.search(&AllQuery, &collector)?; + + let expected = json!({ + "root": { + 
"doc_count": 4, + "electronics_branch": { + "doc_count": 2, + "avg_price": { "value": 899.0 } + }, + "in_stock_branch": { + "doc_count": 3, + "count": { "value": 3.0 } + } + } + }); + + assert_agg_results!(&result, expected); + Ok(()) + } + + #[test] + fn test_nested_filters_with_multiple_siblings_at_each_level() -> crate::Result<()> { + let index = create_standard_test_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + + // Test complex nesting: multiple branches at each level + let agg = json!({ + "all": { + "filter": "*", + "aggs": { + // Level 2: Two independent filters + "expensive": { + "filter": "price:[500 TO *]", + "aggs": { + // Level 3: Multiple branches under "expensive" + "electronics": { + "filter": "category:electronics", + "aggs": { + "avg_rating": { "avg": { "field": "rating" } } + } + }, + "in_stock": { + "filter": "in_stock:true", + "aggs": { + "count": { "value_count": { "field": "brand" } } + } + } + } + }, + "affordable": { + "filter": "price:[0 TO 200]", + "aggs": { + // Level 3: Multiple branches under "affordable" + "books": { + "filter": "category:books", + "aggs": { + "max_rating": { "max": { "field": "rating" } } + } + }, + "clothing": { + "filter": "category:clothing", + "aggs": { + "min_price": { "min": { "field": "price" } } + } + } + } + } + } + } + }); + + let aggregations: Aggregations = serde_json::from_value(agg)?; + let collector = create_collector(&index, aggregations)?; + let result = searcher.search(&AllQuery, &collector)?; + + let expected = json!({ + "all": { + "doc_count": 4, + "expensive": { + "doc_count": 2, // apple (999), samsung (799) + "electronics": { + "doc_count": 2, // both are electronics + "avg_rating": { "value": 4.35 } // (4.5 + 4.2) / 2 + }, + "in_stock": { + "doc_count": 2, // both are in stock + "count": { "value": 2.0 } + } + }, + "affordable": { + "doc_count": 2, // nike (120), penguin (25) + "books": { + "doc_count": 1, // penguin (25) + "max_rating": { "value": 4.8 } + }, + "clothing": { + "doc_count": 1, // nike (120) + "min_price": { "value": 120.0 } + } + } + } + }); + + assert_agg_results!(&result, expected); + Ok(()) + } + + // ============================================================================ + // Sub-Aggregation Combination Tests + // ============================================================================ + + #[test] + fn test_filter_with_terms_sub_agg() -> crate::Result<()> { + let index = create_standard_test_index()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + + let agg = json!({ + "electronics": { + "filter": "category:electronics", + "aggs": { + "brands": { + "terms": { "field": "brand" }, + "aggs": { + "avg_price": { "avg": { "field": "price" } } + } + } + } + } + }); + + let aggregations: Aggregations = serde_json::from_value(agg)?; + let collector = create_collector(&index, aggregations)?; + let result = searcher.search(&AllQuery, &collector)?; + + // Verify the structure exists and has expected doc_count + let expected = json!({ + "electronics": { + "doc_count": 2, + "brands": { + "buckets": [ + { + "key": "samsung", + "doc_count": 1, + "avg_price": { "value": 799.0 } + }, + { + "key": "apple", + "doc_count": 1, + "avg_price": { "value": 999.0 } + } + ], + "sum_other_doc_count": 0, + "doc_count_error_upper_bound": 0 + } + } + }); + + assert_agg_results!(&result, expected); + Ok(()) + } + + #[test] + fn test_filter_with_multiple_metric_aggs() -> crate::Result<()> { + let index = create_standard_test_index()?; + let reader = index.reader()?; + 
let searcher = reader.searcher();
+
+        let agg = json!({
+            "electronics": {
+                "filter": "category:electronics",
+                "aggs": {
+                    "price_stats": { "stats": { "field": "price" } },
+                    "rating_avg": { "avg": { "field": "rating" } },
+                    "count": { "value_count": { "field": "brand" } }
+                }
+            }
+        });
+
+        let aggregations: Aggregations = serde_json::from_value(agg)?;
+        let collector = create_collector(&index, aggregations)?;
+        let result = searcher.search(&AllQuery, &collector)?;
+
+        let expected = json!({
+            "electronics": {
+                "doc_count": 2,
+                "price_stats": {
+                    "count": 2,
+                    "min": 799.0,
+                    "max": 999.0,
+                    "sum": 1798.0,
+                    "avg": 899.0
+                },
+                "rating_avg": { "value": 4.35 },
+                "count": { "value": 2.0 }
+            }
+        });
+
+        assert_agg_results!(&result, expected);
+        Ok(())
+    }
+
+    // ============================================================================
+    // Edge Cases and Error Handling
+    // ============================================================================
+
+    #[test]
+    fn test_filter_on_empty_index() -> crate::Result<()> {
+        let mut schema_builder = Schema::builder();
+        let _category = schema_builder.add_text_field("category", TEXT | FAST);
+        let _price = schema_builder.add_u64_field("price", FAST);
+
+        let schema = schema_builder.build();
+        let index = Index::create_in_ram(schema);
+        let mut writer: IndexWriter = index.writer(50_000_000)?;
+        writer.commit()?; // Commit empty index
+
+        let reader = index.reader()?;
+        let searcher = reader.searcher();
+
+        let agg = json!({
+            "electronics": {
+                "filter": "category:electronics",
+                "aggs": { "avg_price": { "avg": { "field": "price" } } }
+            }
+        });
+
+        let aggregations: Aggregations = serde_json::from_value(agg)?;
+        let collector = create_collector(&index, aggregations)?;
+        let result = searcher.search(&AllQuery, &collector)?;
+
+        let expected = json!({
+            "electronics": {
+                "doc_count": 0,
+                "avg_price": { "value": null }
+            }
+        });
+
+        assert_agg_results!(&result, expected);
+        Ok(())
+    }
+
+    #[test]
+    fn test_malformed_query_string() -> crate::Result<()> {
+        let index = create_standard_test_index()?;
+        let reader = index.reader()?;
+        let searcher = reader.searcher();
+
+        // Empty query string
+        let agg = json!({
+            "test": {
+                "filter": "",
+                "aggs": { "count": { "value_count": { "field": "brand" } } }
+            }
+        });
+
+        let result = serde_json::from_value::<Aggregations>(agg)
+            .map_err(|e| crate::TantivyError::InvalidArgument(e.to_string()))
+            .and_then(|aggregations| {
+                let collector = create_collector(&index, aggregations)?;
+                searcher.search(&AllQuery, &collector)
+            });
+
+        // An empty query string should either work (matching nothing) or error
+        // gracefully; either way, building and running the search must not panic.
+        let _ = result;
+        Ok(())
+    }
+
+    #[test]
+    fn test_filter_with_base_query() -> crate::Result<()> {
+        let index = create_standard_test_index()?;
+        let reader = index.reader()?;
+        let searcher = reader.searcher();
+        let schema = index.schema();
+
+        // Use a base query to pre-filter to in_stock items only
+        let in_stock_field = schema.get_field("in_stock").unwrap();
+        let base_query = TermQuery::new(
+            Term::from_field_bool(in_stock_field, true),
+            IndexRecordOption::Basic,
+        );
+
+        let agg = json!({
+            "electronics": {
+                "filter": "category:electronics",
+                "aggs": { "count": { "value_count": { "field": "brand" } } }
+            }
+        });
+
+        let aggregations: Aggregations = serde_json::from_value(agg)?;
+        let collector = create_collector(&index, aggregations)?;
+        let result = searcher.search(&base_query, &collector)?;
+
+        let expected = json!({
+            "electronics": {
+                "doc_count": 2, // Both in-stock
+        let expected = json!({
+            "electronics": {
+                "doc_count": 2, // Both in-stock electronics
+                "count": { "value": 2.0 }
+            }
+        });
+
+        assert_agg_results!(&result, expected);
+        Ok(())
+    }
+
+    // ============================================================================
+    // Custom Query Integration Tests
+    // ============================================================================
+
+    #[test]
+    fn test_custom_query_builder() -> crate::Result<()> {
+        // Define a query builder with full serde support
+        #[derive(Debug, Clone, Serialize, Deserialize)]
+        struct TestTermQueryBuilder {
+            field_name: String,
+            term_text: String,
+        }
+
+        #[typetag::serde(name = "TestTermQueryBuilder")]
+        impl QueryBuilder for TestTermQueryBuilder {
+            fn build_query(
+                &self,
+                schema: &Schema,
+                _tokenizers: &TokenizerManager,
+            ) -> crate::Result<Box<dyn Query>> {
+                let field = schema.get_field(&self.field_name)?;
+                let term = Term::from_field_text(field, &self.term_text);
+                Ok(Box::new(TermQuery::new(term, IndexRecordOption::Basic)))
+            }
+
+            fn box_clone(&self) -> Box<dyn QueryBuilder> {
+                Box::new(self.clone())
+            }
+        }
+
+        let index = create_standard_test_index()?;
+
+        // Create a filter aggregation with a custom query builder
+        let builder = TestTermQueryBuilder {
+            field_name: "category".to_string(),
+            term_text: "electronics".to_string(),
+        };
+        let filter_agg = FilterAggregation::new_with_builder(Box::new(builder));
+
+        // Test that the query can be parsed
+        let schema = index.schema();
+        let tokenizers = index.tokenizers();
+        let query = filter_agg.parse_query(&schema, tokenizers)?;
+
+        // Verify the query was built correctly (it should be a TermQuery)
+        assert!(format!("{:?}", query).contains("TermQuery"));
+
+        // Test that it can be cloned
+        let cloned = filter_agg.clone();
+        let query2 = cloned.parse_query(&schema, tokenizers)?;
+        assert!(format!("{:?}", query2).contains("TermQuery"));
+
+        // Verify that custom builders CAN be serialized with typetag
+        let serialized = serde_json::to_string(&filter_agg)?;
+        assert!(
+            serialized.contains("TestTermQueryBuilder"),
+            "Serialized JSON should contain the type tag"
+        );
+        assert!(
+            serialized.contains("electronics"),
+            "Serialized JSON should contain the field data"
+        );
+
+        // Verify that it can be deserialized
+        let deserialized: FilterAggregation = serde_json::from_str(&serialized)?;
+        let query3 = deserialized.parse_query(&schema, tokenizers)?;
+        assert!(format!("{:?}", query3).contains("TermQuery"));
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_query_string_serialization() -> crate::Result<()> {
+        // Query strings should serialize/deserialize correctly
+        let filter_agg = FilterAggregation::new("category:electronics".to_string());
+
+        let serialized = serde_json::to_string(&filter_agg)?;
+        assert!(serialized.contains("electronics"));
+
+        let deserialized: FilterAggregation = serde_json::from_str(&serialized)?;
+        // Verify it deserializes correctly by using it in an aggregation
+        let index = create_standard_test_index()?;
+        let reader = index.reader()?;
+        let searcher = reader.searcher();
+
+        let agg = json!({
+            "test": {
+                "filter": deserialized,
+                "aggs": { "count": { "value_count": { "field": "brand" } } }
+            }
+        });
+
+        let aggregations: Aggregations = serde_json::from_value(agg)?;
+        let collector = create_collector(&index, aggregations)?;
+        let result = searcher.search(&AllQuery, &collector)?;
+
+        // Should match 2 electronics
+        let result_json = serde_json::to_value(&result)?;
+        assert_eq!(result_json["test"]["doc_count"], 2);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_query_builder_serialization_roundtrip() -> crate::Result<()> {
+        // Define a serializable query builder
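+        // With `#[typetag::serde(name = "...")]`, the trait object serializes in
+        // externally tagged form, roughly
+        // {"RoundtripTermQueryBuilder":{"field_name":"category","term_text":"electronics"}},
+        // which is what the `contains` assertions below rely on.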
+        #[derive(Debug, Clone, Serialize, Deserialize)]
+        struct RoundtripTermQueryBuilder {
+            field_name: String,
+            term_text: String,
+        }
+
+        #[typetag::serde(name = "RoundtripTermQueryBuilder")]
+        impl QueryBuilder for RoundtripTermQueryBuilder {
+            fn build_query(
+                &self,
+                schema: &Schema,
+                _tokenizers: &TokenizerManager,
+            ) -> crate::Result<Box<dyn Query>> {
+                let field = schema.get_field(&self.field_name)?;
+                let term = Term::from_field_text(field, &self.term_text);
+                Ok(Box::new(TermQuery::new(term, IndexRecordOption::Basic)))
+            }
+
+            fn box_clone(&self) -> Box<dyn QueryBuilder> {
+                Box::new(self.clone())
+            }
+        }
+
+        let index = create_standard_test_index()?;
+
+        // Create a filter aggregation with a custom query builder
+        let builder = RoundtripTermQueryBuilder {
+            field_name: "category".to_string(),
+            term_text: "electronics".to_string(),
+        };
+        let filter_agg = FilterAggregation::new_with_builder(Box::new(builder));
+
+        // Serialize the filter aggregation
+        let serialized = serde_json::to_string(&filter_agg)?;
+
+        // Verify the serialized JSON contains the builder data and type tag
+        assert!(
+            serialized.contains("RoundtripTermQueryBuilder"),
+            "Serialized JSON should contain type tag"
+        );
+        assert!(
+            serialized.contains("category"),
+            "Serialized JSON should contain field_name"
+        );
+        assert!(
+            serialized.contains("electronics"),
+            "Serialized JSON should contain term_text"
+        );
+
+        // Deserialize back
+        let deserialized: FilterAggregation = serde_json::from_str(&serialized)?;
+
+        // Verify the aggregation produces correct results
+        let agg = json!({
+            "filtered": {
+                "filter": deserialized
+            }
+        });
+
+        let agg_req: Aggregations = serde_json::from_value(agg)?;
+        let searcher = index.reader()?.searcher();
+        let collector = create_collector(&index, agg_req)?;
+        let agg_res = searcher.search(&AllQuery, &collector)?;
+
+        let result_json = serde_json::to_value(&agg_res)?;
+        assert_eq!(result_json["filtered"]["doc_count"], 2);
+
+        Ok(())
+    }
+
+    // ============================================================================
+    // Correctness Validation Tests
+    // ============================================================================
+
+    #[test]
+    fn test_filter_result_correctness_vs_separate_query() -> crate::Result<()> {
+        let index = create_standard_test_index()?;
+        let reader = index.reader()?;
+        let searcher = reader.searcher();
+        let schema = index.schema();
+
+        // Method 1: Filter aggregation
+        let filter_agg = json!({
+            "electronics": {
+                "filter": "category:electronics",
+                "aggs": { "avg_price": { "avg": { "field": "price" } } }
+            }
+        });
+
+        let aggregations: Aggregations = serde_json::from_value(filter_agg)?;
+        let collector = create_collector(&index, aggregations)?;
+        let filter_result = searcher.search(&AllQuery, &collector)?;
+
+        // Method 2: Separate query
+        let category_field = schema.get_field("category").unwrap();
+        let term = Term::from_field_text(category_field, "electronics");
+        let term_query = TermQuery::new(term, IndexRecordOption::Basic);
+
+        let separate_agg = json!({
+            "result": { "avg": { "field": "price" } }
+        });
+
+        let separate_aggregations: Aggregations = serde_json::from_value(separate_agg)?;
+        let separate_collector =
+            AggregationCollector::from_aggs(separate_aggregations, Default::default());
+        let separate_result = searcher.search(&term_query, &separate_collector)?;
+
+        // Both methods should produce identical results
+        let filter_expected = json!({
+            "electronics": {
+                "doc_count": 2,
+                "avg_price": { "value": 899.0 }
+            }
+        });
+
+        let separate_expected = json!({
"result": { + "value": 899.0 + } + }); + + // Verify filter aggregation result + assert_agg_results!(&filter_result, filter_expected); + + // Verify separate query result matches + assert_agg_results!(&separate_result, separate_expected); + + // This test demonstrates that filter aggregation produces the same results + // as running a separate query with the same condition + Ok(()) + } + + #[test] + fn test_custom_tokenizer_required() -> crate::Result<()> { + use crate::schema::{TextFieldIndexing, TextOptions}; + use crate::tokenizer::{SimpleTokenizer, TextAnalyzer, TokenizerManager}; + + // Create a custom tokenizer that doesn't lowercase (just splits on whitespace) + let custom_tokenizer = TextAnalyzer::builder(SimpleTokenizer::default()).build(); + + // Register tokenizer + let tokenizers = TokenizerManager::default(); + tokenizers.register("my_custom", custom_tokenizer); + + // Create a schema with a text field that uses our custom tokenizer + let mut schema_builder = Schema::builder(); + let text_field_indexing = TextFieldIndexing::default() + .set_tokenizer("my_custom") + .set_index_option(IndexRecordOption::Basic); + let text_options = TextOptions::default() + .set_indexing_options(text_field_indexing) + .set_stored(); + let text_field = schema_builder.add_text_field("text", text_options); + let schema = schema_builder.build(); + + // Build index with custom tokenizer + let index = crate::IndexBuilder::new() + .schema(schema.clone()) + .tokenizers(tokenizers) + .create_in_ram()?; + let mut writer = index.writer(50_000_000)?; + + // Add documents with UPPERCASE text + writer.add_document(doc!(text_field => "HELLO"))?; + writer.add_document(doc!(text_field => "WORLD"))?; + writer.add_document(doc!(text_field => "hello"))?; // lowercase version + writer.commit()?; + + let reader = index.reader()?; + let searcher = reader.searcher(); + + // Test: With correct tokenizer (from index) - should work + let agg = json!({ + "uppercase_hello": { + "filter": "text:HELLO" + } + }); + + let aggregations: Aggregations = serde_json::from_value(agg)?; + let collector_with_tokenizer = create_collector(&index, aggregations.clone())?; + let result_with_tokenizer = searcher.search(&AllQuery, &collector_with_tokenizer)?; + + // Should match only the UPPERCASE "HELLO" (1 document) + let result_json = serde_json::to_value(&result_with_tokenizer)?; + assert_eq!( + result_json["uppercase_hello"]["doc_count"], 1, + "With custom tokenizer from index, should match exactly 1 UPPERCASE document" + ); + + // Test 2: With default tokenizer (wrong!) 
+        // Test 2: With default tokenizer (wrong!) - should fail to parse the query
+        // because "my_custom" tokenizer is not in the default TokenizerManager
+        let collector_with_default = AggregationCollector::from_aggs(
+            aggregations,
+            AggContextParams::new(Default::default(), TokenizerManager::default()),
+        );
+        let result_with_default = searcher.search(&AllQuery, &collector_with_default);
+
+        // This should error because the tokenizer "my_custom" is not registered
+        assert!(
+            result_with_default.is_err(),
+            "Without proper tokenizers, query parsing should fail"
+        );
+        assert!(
+            result_with_default
+                .unwrap_err()
+                .to_string()
+                .contains("my_custom"),
+            "Error should mention the missing tokenizer"
+        );
+
+        Ok(())
+    }
+}
diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs
index 5a52f7cf0..36c0fe57e 100644
--- a/src/aggregation/bucket/histogram/histogram.rs
+++ b/src/aggregation/bucket/histogram/histogram.rs
@@ -352,7 +352,10 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
         let mem_delta = self.get_memory_consumption() - mem_pre;
         if mem_delta > 0 {
-            agg_data.limits.add_memory_consumed(mem_delta as u64)?;
+            agg_data
+                .context
+                .limits
+                .add_memory_consumed(mem_delta as u64)?;
         }
 
         Ok(())
diff --git a/src/aggregation/bucket/mod.rs b/src/aggregation/bucket/mod.rs
index 52a952cb8..cb64fcd75 100644
--- a/src/aggregation/bucket/mod.rs
+++ b/src/aggregation/bucket/mod.rs
@@ -22,6 +22,7 @@
 //! - [Range](RangeAggregation)
 //! - [Terms](TermsAggregation)
 
+mod filter;
 mod histogram;
 mod range;
 mod term_agg;
@@ -30,6 +31,7 @@ mod term_missing_agg;
 use std::collections::HashMap;
 use std::fmt;
 
+pub use filter::*;
 pub use histogram::*;
 pub use range::*;
 use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs
index 121528450..c26872e9b 100644
--- a/src/aggregation/bucket/range.rs
+++ b/src/aggregation/bucket/range.rs
@@ -339,7 +339,7 @@ impl SegmentRangeCollector {
             })
             .collect::<crate::Result<Vec<_>>>()?;
 
-        req_data.limits.add_memory_consumed(
+        req_data.context.limits.add_memory_consumed(
             buckets.len() as u64 * std::mem::size_of::<SegmentRangeAndBucketEntry>() as u64,
         )?;
 
diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs
index e526da9cf..0b18eaa6b 100644
--- a/src/aggregation/bucket/term_agg.rs
+++ b/src/aggregation/bucket/term_agg.rs
@@ -443,7 +443,10 @@ impl SegmentAggregationCollector for SegmentTermCollector {
         let mem_delta = self.get_memory_consumption() - mem_pre;
         if mem_delta > 0 {
-            agg_data.limits.add_memory_consumed(mem_delta as u64)?;
+            agg_data
+                .context
+                .limits
+                .add_memory_consumed(mem_delta as u64)?;
         }
 
         agg_data.put_back_term_req_data(self.accessor_idx, req_data);
@@ -2341,10 +2344,8 @@ mod tests {
         let search = |idx: &Index, agg_req: &Aggregations| -> crate::Result<IntermediateAggregationResults> {
-            let collector = DistributedAggregationCollector::from_aggs(
-                agg_req.clone(),
-                AggregationLimitsGuard::default(),
-            );
+            let collector =
+                DistributedAggregationCollector::from_aggs(agg_req.clone(), Default::default());
             let reader = idx.reader()?;
             let searcher = reader.searcher();
             let agg_res = searcher.search(&AllQuery, &collector)?;
@@ -2388,9 +2389,8 @@ mod tests {
         let agg_res2 = search(&index2, &agg_req2)?;
 
         agg_res.merge_fruits(agg_res2).unwrap();
-        let agg_json = serde_json::to_value(
-            &agg_res.into_final_result(agg_req2, AggregationLimitsGuard::default())?,
-        )?;
+        let agg_json =
+            serde_json::to_value(&agg_res.into_final_result(agg_req2, Default::default())?)?;
 
         // hosts:
         let hosts = &agg_json["hosts"]["buckets"];
diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs
index 10e3ef526..4c4c2c7f1 100644
--- a/src/aggregation/collector.rs
+++ b/src/aggregation/collector.rs
@@ -2,7 +2,8 @@ use super::agg_req::Aggregations;
 use super::agg_result::AggregationResults;
 use super::buf_collector::BufAggregationCollector;
 use super::intermediate_agg_result::IntermediateAggregationResults;
-use super::segment_agg_result::{AggregationLimitsGuard, SegmentAggregationCollector};
+use super::segment_agg_result::SegmentAggregationCollector;
+use super::AggContextParams;
 use crate::aggregation::agg_data::{
     build_aggregations_data_from_req, build_segment_agg_collectors_root, AggregationsSegmentCtx,
 };
@@ -21,7 +22,7 @@ pub const DEFAULT_MEMORY_LIMIT: u64 = 500_000_000;
 /// The collector collects all aggregations by the underlying aggregation request.
 pub struct AggregationCollector {
     agg: Aggregations,
-    limits: AggregationLimitsGuard,
+    context: AggContextParams,
 }
 
 impl AggregationCollector {
@@ -29,8 +30,8 @@ impl AggregationCollector {
     ///
     /// Aggregation fails when the limits in `AggregationLimits` is exceeded. (memory limit and
     /// bucket limit)
-    pub fn from_aggs(agg: Aggregations, limits: AggregationLimitsGuard) -> Self {
-        Self { agg, limits }
+    pub fn from_aggs(agg: Aggregations, context: AggContextParams) -> Self {
+        Self { agg, context }
     }
 }
 
@@ -44,7 +45,7 @@ impl AggregationCollector {
 /// into the final `AggregationResults` via the `into_final_result()` method.
 pub struct DistributedAggregationCollector {
     agg: Aggregations,
-    limits: AggregationLimitsGuard,
+    context: AggContextParams,
 }
 
 impl DistributedAggregationCollector {
@@ -52,8 +53,8 @@ impl DistributedAggregationCollector {
     ///
     /// Aggregation fails when the limits in `AggregationLimits` is exceeded. (memory limit and
     /// bucket limit)
-    pub fn from_aggs(agg: Aggregations, limits: AggregationLimitsGuard) -> Self {
-        Self { agg, limits }
+    pub fn from_aggs(agg: Aggregations, context: AggContextParams) -> Self {
+        Self { agg, context }
     }
 }
 
@@ -71,7 +72,7 @@ impl Collector for DistributedAggregationCollector {
             &self.agg,
             reader,
             segment_local_id,
-            &self.limits,
+            &self.context,
         )
     }
 
@@ -101,7 +102,7 @@ impl Collector for AggregationCollector {
             &self.agg,
             reader,
             segment_local_id,
-            &self.limits,
+            &self.context,
         )
     }
 
@@ -114,7 +115,7 @@
         segment_fruits: Vec<<Self::Child as SegmentCollector>::Fruit>,
     ) -> crate::Result<Self::Fruit> {
         let res = merge_fruits(segment_fruits)?;
-        res.into_final_result(self.agg.clone(), self.limits.clone())
+        res.into_final_result(self.agg.clone(), self.context.limits.clone())
     }
 }
 
@@ -146,10 +147,10 @@ impl AggregationSegmentCollector {
         agg: &Aggregations,
         reader: &SegmentReader,
         segment_ordinal: SegmentOrdinal,
-        limits: &AggregationLimitsGuard,
+        context: &AggContextParams,
     ) -> crate::Result<Self> {
         let mut agg_data =
-            build_aggregations_data_from_req(agg, reader, segment_ordinal, limits.clone())?;
+            build_aggregations_data_from_req(agg, reader, segment_ordinal, context.clone())?;
         let result =
             BufAggregationCollector::new(build_segment_agg_collectors_root(&mut agg_data)?);
diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs
index f5f373bb0..104131461 100644
--- a/src/aggregation/intermediate_agg_result.rs
+++ b/src/aggregation/intermediate_agg_result.rs
@@ -24,7 +24,9 @@ use super::metric::{
 };
 use super::segment_agg_result::AggregationLimitsGuard;
 use super::{format_date, AggregationError, Key, SerializedKey};
-use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry};
+use crate::aggregation::agg_result::{
+    AggregationResults, BucketEntries, BucketEntry, FilterBucketResult,
+};
 use crate::aggregation::bucket::TermsAggregationInternal;
 use crate::aggregation::metric::CardinalityCollector;
 use crate::TantivyError;
@@ -246,6 +248,10 @@ pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult
         Cardinality(_) => IntermediateAggregationResult::Metric(
             IntermediateMetricResult::Cardinality(CardinalityCollector::default()),
         ),
+        Filter(_) => IntermediateAggregationResult::Bucket(IntermediateBucketResult::Filter {
+            doc_count: 0,
+            sub_aggregations: IntermediateAggregationResults::default(),
+        }),
     }
 }
 
@@ -432,6 +438,13 @@ pub enum IntermediateBucketResult {
         /// The term buckets
         buckets: IntermediateTermBucketResult,
     },
+    /// Filter aggregation - a single bucket with sub-aggregations
+    Filter {
+        /// Document count in the filter bucket
+        doc_count: u64,
+        /// Sub-aggregation results
+        sub_aggregations: IntermediateAggregationResults,
+    },
 }
 
 impl IntermediateBucketResult {
@@ -515,6 +528,18 @@ impl IntermediateBucketResult {
                 req.sub_aggregation(),
                 limits,
             ),
+            IntermediateBucketResult::Filter {
+                doc_count,
+                sub_aggregations,
+            } => {
+                // Convert sub-aggregation results to final format
+                let final_sub_aggregations = sub_aggregations
+                    .into_final_result(req.sub_aggregation().clone(), limits.clone())?;
+                Ok(BucketResult::Filter(FilterBucketResult {
+                    doc_count,
+                    sub_aggregations: final_sub_aggregations,
+                }))
+            }
         }
     }
 
@@ -568,6 +593,19 @@ impl IntermediateBucketResult {
                 *buckets_left = buckets?;
             }
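+            // Merging two filter buckets (e.g. from different segments): the doc
+            // counts add up and the sub-aggregation results merge recursively.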
+            (
+                IntermediateBucketResult::Filter {
+                    doc_count: doc_count_left,
+                    sub_aggregations: sub_aggs_left,
+                },
+                IntermediateBucketResult::Filter {
+                    doc_count: doc_count_right,
+                    sub_aggregations: sub_aggs_right,
+                },
+            ) => {
+                *doc_count_left += doc_count_right;
+                sub_aggs_left.merge_fruits(sub_aggs_right)?;
+            }
             (IntermediateBucketResult::Range(_), _) => {
                 panic!("try merge on different types")
             }
@@ -577,6 +615,9 @@ impl IntermediateBucketResult {
             (IntermediateBucketResult::Terms { .. }, _) => {
                 panic!("try merge on different types")
             }
+            (IntermediateBucketResult::Filter { .. }, _) => {
+                panic!("try merge on different types")
+            }
         }
         Ok(())
     }
diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs
index d94f653c6..ddf60ea4c 100644
--- a/src/aggregation/mod.rs
+++ b/src/aggregation/mod.rs
@@ -160,6 +160,28 @@ use itertools::Itertools;
 use serde::de::{self, Visitor};
 use serde::{Deserialize, Deserializer, Serialize};
 
+use crate::tokenizer::TokenizerManager;
+
+/// Context parameters for aggregation execution.
+///
+/// This struct holds shared resources needed during aggregation execution:
+/// - `limits`: Memory and bucket limits for the aggregation
+/// - `tokenizers`: TokenizerManager for parsing query strings in filter aggregations
+#[derive(Clone, Default)]
+pub struct AggContextParams {
+    /// Aggregation limits (memory and bucket count)
+    pub limits: AggregationLimitsGuard,
+    /// Tokenizer manager for query string parsing
+    pub tokenizers: TokenizerManager,
+}
+
+impl AggContextParams {
+    /// Create new aggregation context parameters
+    pub fn new(limits: AggregationLimitsGuard, tokenizers: TokenizerManager) -> Self {
+        Self { limits, tokenizers }
+    }
+}
+
 fn parse_str_into_f64<E: de::Error>(value: &str) -> Result<f64, E> {
     let parsed = value
         .parse::<f64>()
@@ -390,7 +412,10 @@ mod tests {
         query: Option<(&str, &str)>,
         limits: AggregationLimitsGuard,
     ) -> crate::Result<Value> {
-        let collector = AggregationCollector::from_aggs(agg_req, limits);
+        let collector = AggregationCollector::from_aggs(
+            agg_req,
+            AggContextParams::new(limits, index.tokenizers().clone()),
+        );
 
         let reader = index.reader()?;
         let searcher = reader.searcher();