From a0fbf436b9aa51111924ab11a021cb9264bdceab Mon Sep 17 00:00:00 2001 From: Mohammad Dashti Date: Thu, 15 Jan 2026 15:43:19 -1000 Subject: [PATCH] Applied PR comment: I would move it outside of the aggregation. You can fetch the fields from the aggregation request and do a validation in a helper function --- src/aggregation/agg_req.rs | 48 ++++++++++++++++++++++++++++++++++++ src/aggregation/agg_tests.rs | 41 ++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+) diff --git a/src/aggregation/agg_req.rs b/src/aggregation/agg_req.rs index 5fa187537..00a9e79c7 100644 --- a/src/aggregation/agg_req.rs +++ b/src/aggregation/agg_req.rs @@ -115,6 +115,54 @@ pub fn get_fast_field_names(aggs: &Aggregations) -> HashSet { fast_field_names } +/// Validates that all fields referenced in the aggregation request exist in the schema +/// and are configured as fast fields. +/// +/// This is a convenience function for upfront validation before executing aggregations. +/// Returns an error if any field doesn't exist or is not a fast field. +/// +/// # Example +/// ``` +/// use tantivy::aggregation::agg_req::{Aggregations, validate_aggregation_fields}; +/// use tantivy::Index; +/// +/// # fn example(index: &Index, agg_req: Aggregations) -> tantivy::Result<()> { +/// let reader = index.reader()?; +/// let searcher = reader.searcher(); +/// +/// // Validate fields before executing +/// for segment_reader in searcher.segment_readers() { +/// validate_aggregation_fields(&agg_req, segment_reader)?; +/// } +/// # Ok(()) +/// # } +/// ``` +pub fn validate_aggregation_fields( + aggs: &Aggregations, + reader: &crate::SegmentReader, +) -> crate::Result<()> { + let field_names = get_fast_field_names(aggs); + let schema = reader.schema(); + + for field_name in field_names { + // Check if the field is either directly in the schema or could be part of a json field + // present in the schema, and verify it's a fast field. + if let Some((field, _path)) = schema.find_field(&field_name) { + let field_type = schema.get_field_entry(field).field_type(); + if !field_type.is_fast() { + return Err(crate::TantivyError::SchemaError(format!( + "Field '{}' is not a fast field. Aggregations require fast fields.", + field_name + ))); + } + } else { + return Err(crate::TantivyError::FieldNotFound(field_name)); + } + } + + Ok(()) +} + #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] /// All aggregation types. pub enum AggregationVariants { diff --git a/src/aggregation/agg_tests.rs b/src/aggregation/agg_tests.rs index 49a8afb37..3042753cb 100644 --- a/src/aggregation/agg_tests.rs +++ b/src/aggregation/agg_tests.rs @@ -1436,3 +1436,44 @@ fn test_aggregation_on_json_object_mixed_numerical_segments() { ) ); } + +#[test] +fn test_aggregation_field_validation_helper() { + // Test the standalone validation helper function for field validation + let index = get_test_index_2_segments(false).unwrap(); + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + let segment_reader = searcher.segment_reader(0); + + // Test with invalid field + let agg_req: Aggregations = serde_json::from_str( + r#"{ + "avg_test": { + "avg": { "field": "nonexistent_field" } + } + }"#, + ) + .unwrap(); + + let result = crate::aggregation::agg_req::validate_aggregation_fields(&agg_req, segment_reader); + assert!(result.is_err()); + match result { + Err(crate::TantivyError::FieldNotFound(field_name)) => { + assert_eq!(field_name, "nonexistent_field"); + } + _ => panic!("Expected FieldNotFound error, got: {:?}", result), + } + + // Test with valid field + let agg_req: Aggregations = serde_json::from_str( + r#"{ + "avg_test": { + "avg": { "field": "score" } + } + }"#, + ) + .unwrap(); + + let result = crate::aggregation::agg_req::validate_aggregation_fields(&agg_req, segment_reader); + assert!(result.is_ok()); +}