From 6aaecbc9f694f068ff60c42d9b518e8be6edf6dd Mon Sep 17 00:00:00 2001 From: discord9 Date: Fri, 17 Apr 2026 14:57:17 +0800 Subject: [PATCH] feat: stats optimize Signed-off-by: discord9 --- src/query/src/optimizer/aggr_stats.rs | 337 +++++++- src/query/src/optimizer/aggr_stats/check.rs | 28 +- src/query/src/optimizer/aggr_stats/split.rs | 19 + src/query/src/optimizer/aggr_stats/tests.rs | 856 +++++++++++++++++++- 4 files changed, 1198 insertions(+), 42 deletions(-) diff --git a/src/query/src/optimizer/aggr_stats.rs b/src/query/src/optimizer/aggr_stats.rs index 0ab4432c47..ed0255939b 100644 --- a/src/query/src/optimizer/aggr_stats.rs +++ b/src/query/src/optimizer/aggr_stats.rs @@ -14,13 +14,21 @@ use std::sync::Arc; +use arrow::array::{Array, ArrayRef, StructArray}; +use arrow::record_batch::RecordBatch; +use common_error::ext::BoxedError; use common_telemetry::debug; use datafusion::config::ConfigOptions; +use datafusion::datasource::memory::MemorySourceConfig; +use datafusion::datasource::source::DataSourceExec; use datafusion::physical_optimizer::PhysicalOptimizerRule; -use datafusion::physical_plan::ExecutionPlan; -use datafusion::physical_plan::aggregates::AggregateExec; -use datafusion_common::Result as DfResult; +use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode}; +use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion::physical_plan::union::UnionExec; +use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{DataFusionError, Result as DfResult, ScalarValue}; +use datafusion_physical_expr::aggregate::AggregateFunctionExpr; use datatypes::arrow::datatypes::DataType; use table::table::scan::RegionScanExec; @@ -30,6 +38,90 @@ mod split; mod tests; use check::RewriteCheck; +use split::{common_stats_file_ordinals, filter_stats_by_file_ordinals, partial_state_from_stats}; + +enum RewriteTarget<'a> { + SingleStage { + aggregate_exec: &'a AggregateExec, + region_scan: &'a RegionScanExec, + }, + FinalOverPartial { + final_exec: &'a AggregateExec, + partial_exec: &'a AggregateExec, + region_scan: &'a RegionScanExec, + keep_coalesce: bool, + }, +} + +impl<'a> RewriteTarget<'a> { + fn extract(plan: &'a Arc) -> Option { + let aggregate_exec = plan.as_any().downcast_ref::()?; + + if matches!( + aggregate_exec.mode(), + AggregateMode::Single | AggregateMode::SinglePartitioned + ) { + let region_scan = AggregateStats::extract_region_scan(aggregate_exec.input())?; + return Some(Self::SingleStage { + aggregate_exec, + region_scan, + }); + } + + if !matches!(aggregate_exec.mode(), AggregateMode::Final) { + return None; + } + + if let Some(coalesce) = aggregate_exec + .input() + .as_any() + .downcast_ref::() + { + let partial_exec = coalesce.input().as_any().downcast_ref::()?; + if !matches!(partial_exec.mode(), AggregateMode::Partial) { + return None; + } + + let region_scan = AggregateStats::extract_region_scan(partial_exec.input())?; + return Some(Self::FinalOverPartial { + final_exec: aggregate_exec, + partial_exec, + region_scan, + keep_coalesce: true, + }); + } + + let partial_exec = aggregate_exec + .input() + .as_any() + .downcast_ref::()?; + if !matches!(partial_exec.mode(), AggregateMode::Partial) { + return None; + } + + let region_scan = AggregateStats::extract_region_scan(partial_exec.input())?; + Some(Self::FinalOverPartial { + final_exec: aggregate_exec, + partial_exec, + region_scan, + keep_coalesce: false, + }) + } + + fn first_stage_aggregate(&self) -> &'a AggregateExec { + match self { + RewriteTarget::SingleStage { aggregate_exec, .. } => aggregate_exec, + RewriteTarget::FinalOverPartial { partial_exec, .. } => partial_exec, + } + } + + fn region_scan(&self) -> &'a RegionScanExec { + match self { + RewriteTarget::SingleStage { region_scan, .. } + | RewriteTarget::FinalOverPartial { region_scan, .. } => region_scan, + } + } +} #[derive(Debug)] pub struct AggregateStats; @@ -83,32 +175,245 @@ impl AggregateStats { fn do_optimize(plan: Arc) -> DfResult> { let result = plan .transform_down(|plan| { - let Some(aggregate_exec) = plan.as_any().downcast_ref::() else { + let Some(target) = RewriteTarget::extract(&plan) else { return Ok(Transformed::no(plan)); }; - let Some(region_scan) = Self::extract_region_scan(aggregate_exec) else { - return Ok(Transformed::no(plan)); - }; - - let check = RewriteCheck::new(aggregate_exec, region_scan); + let check = RewriteCheck::new(target.first_stage_aggregate(), target.region_scan()); if let Some(reason) = check.skip_reason()? { debug!("Skip aggregate stats optimization: {reason}"); return Ok(Transformed::no(plan)); } - // Subtask 03 only adds the scan-side exclusion plumbing. The optimizer must not - // exclude stats-covered files until subtask 04 also materializes their - // stats-derived partial state and merges it back into the aggregate result. - Ok(Transformed::no(plan)) + let aggs = check.parse_aggs().map_err(|reason| { + DataFusionError::Internal(format!( + "aggregate stats rewrite became ineligible after eligibility check: {reason}" + )) + })?; + let Some(scan_input_stats) = target.region_scan().scan_input_stats()? else { + return Ok(Transformed::no(plan)); + }; + + let excluded_file_ordinals = common_stats_file_ordinals(&aggs, &scan_input_stats); + if excluded_file_ordinals.is_empty() { + debug!( + "Skip aggregate stats optimization: no shared stats-covered files across aggregates" + ); + return Ok(Transformed::no(plan)); + } + + let rewritten = Self::rewrite_aggregate( + &target, + &aggs, + &scan_input_stats, + &excluded_file_ordinals, + )?; + + Ok(Transformed::yes(rewritten)) })? .data; Ok(result) } - fn extract_region_scan(aggregate_exec: &AggregateExec) -> Option<&RegionScanExec> { - let child = aggregate_exec.children().into_iter().next()?; - child.as_any().downcast_ref::() + fn extract_region_scan(plan: &Arc) -> Option<&RegionScanExec> { + plan.as_any().downcast_ref::() + } + + fn rewrite_aggregate( + target: &RewriteTarget<'_>, + aggs: &[StatsAgg], + scan_input_stats: &store_api::scan_stats::RegionScanStats, + excluded_file_ordinals: &[usize], + ) -> DfResult> { + match target { + RewriteTarget::SingleStage { + aggregate_exec, + region_scan, + } => Self::rewrite_single_stage( + aggregate_exec, + region_scan, + aggs, + scan_input_stats, + excluded_file_ordinals, + ), + RewriteTarget::FinalOverPartial { + final_exec, + partial_exec, + region_scan, + keep_coalesce, + } => Self::rewrite_final_over_partial( + final_exec, + partial_exec, + region_scan, + *keep_coalesce, + aggs, + scan_input_stats, + excluded_file_ordinals, + ), + } + } + + fn rewrite_single_stage( + aggregate_exec: &AggregateExec, + region_scan: &RegionScanExec, + aggs: &[StatsAgg], + scan_input_stats: &store_api::scan_stats::RegionScanStats, + excluded_file_ordinals: &[usize], + ) -> DfResult> { + let union = Self::build_partial_union_source( + aggregate_exec, + region_scan, + aggs, + scan_input_stats, + excluded_file_ordinals, + )?; + let union = Self::coalesce_if_needed(union); + + Ok(Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + aggregate_exec.group_expr().clone(), + aggregate_exec.aggr_expr().to_vec(), + vec![None; aggregate_exec.aggr_expr().len()], + union, + aggregate_exec.input_schema(), + )? + .with_limit_options(aggregate_exec.limit_options()), + )) + } + + fn rewrite_final_over_partial( + final_exec: &AggregateExec, + partial_exec: &AggregateExec, + region_scan: &RegionScanExec, + keep_coalesce: bool, + aggs: &[StatsAgg], + scan_input_stats: &store_api::scan_stats::RegionScanStats, + excluded_file_ordinals: &[usize], + ) -> DfResult> { + let partial_source = Self::build_partial_union_source( + partial_exec, + region_scan, + aggs, + scan_input_stats, + excluded_file_ordinals, + )?; + let final_input = if keep_coalesce { + Arc::new(CoalescePartitionsExec::new(partial_source)) as Arc + } else if partial_source.output_partitioning().partition_count() > 1 { + Arc::new(CoalescePartitionsExec::new(partial_source)) as Arc + } else { + partial_source + }; + + Ok(Arc::new( + AggregateExec::try_new( + *final_exec.mode(), + final_exec.group_expr().clone(), + final_exec.aggr_expr().to_vec(), + final_exec.filter_expr().to_vec(), + final_input, + final_exec.input_schema(), + )? + .with_limit_options(final_exec.limit_options()), + )) + } + + fn build_partial_union_source( + aggregate_exec: &AggregateExec, + region_scan: &RegionScanExec, + aggs: &[StatsAgg], + scan_input_stats: &store_api::scan_stats::RegionScanStats, + excluded_file_ordinals: &[usize], + ) -> DfResult> { + let stats_scan_input = + filter_stats_by_file_ordinals(scan_input_stats, excluded_file_ordinals); + let stats_states = aggs + .iter() + .map(|agg| { + partial_state_from_stats(agg, &stats_scan_input)?.ok_or_else(|| { + DataFusionError::Internal( + "missing stats-derived partial state for excluded files".to_string(), + ) + }) + }) + .collect::>>()?; + + let filtered_scan = Arc::new( + region_scan + .with_excluded_file_ordinals(excluded_file_ordinals.to_vec()) + .map_err(boxed_external)?, + ); + + let partial_scan = Arc::new( + AggregateExec::try_new( + AggregateMode::Partial, + aggregate_exec.group_expr().clone(), + aggregate_exec.aggr_expr().to_vec(), + aggregate_exec.filter_expr().to_vec(), + filtered_scan, + aggregate_exec.input_schema(), + )? + .with_limit_options(aggregate_exec.limit_options()), + ); + + let stats_input = Self::build_stats_input(aggregate_exec.aggr_expr(), stats_states)?; + UnionExec::try_new(vec![partial_scan, stats_input]) + } + + fn coalesce_if_needed(plan: Arc) -> Arc { + if plan.output_partitioning().partition_count() > 1 { + Arc::new(CoalescePartitionsExec::new(plan)) + } else { + plan + } + } + + fn build_stats_input( + aggr_exprs: &[Arc], + stats_states: Vec, + ) -> DfResult> { + let fields = aggr_exprs.iter().try_fold(Vec::new(), |mut fields, expr| { + fields.extend(expr.state_fields()?); + Ok::<_, DataFusionError>(fields) + })?; + let schema = Arc::new(arrow::datatypes::Schema::new(fields)); + + let columns = stats_states + .into_iter() + .try_fold(Vec::new(), |mut columns, state| { + columns.extend(Self::state_columns(state)?); + Ok::<_, DataFusionError>(columns) + })?; + let batch = RecordBatch::try_new(schema.clone(), columns) + .map_err(|err| DataFusionError::ArrowError(Box::new(err), None))?; + + Ok(DataSourceExec::from_data_source( + MemorySourceConfig::try_new(&[vec![batch]], schema, None)?, + )) + } + + fn state_columns(state: ScalarValue) -> DfResult> { + let ScalarValue::Struct(array) = state else { + return Err(DataFusionError::Internal( + "aggregate stats rewrite expected a struct partial state".to_string(), + )); + }; + + let struct_array = array + .as_ref() + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "aggregate stats rewrite expected a struct array partial state".to_string(), + ) + })?; + Ok(struct_array.columns().to_vec()) } } + +fn boxed_external(err: BoxedError) -> DataFusionError { + DataFusionError::External(Box::new(err)) +} diff --git a/src/query/src/optimizer/aggr_stats/check.rs b/src/query/src/optimizer/aggr_stats/check.rs index 3dbeef2315..e5b61a83e6 100644 --- a/src/query/src/optimizer/aggr_stats/check.rs +++ b/src/query/src/optimizer/aggr_stats/check.rs @@ -13,7 +13,7 @@ // limitations under the License. use common_telemetry::debug; -use datafusion::physical_plan::aggregates::AggregateExec; +use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode}; use datafusion_common::Result as DfResult; use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; use datafusion_physical_expr::PhysicalExpr; @@ -39,9 +39,20 @@ impl<'a> RewriteCheck<'a> { } pub(super) fn skip_reason(&self) -> DfResult> { - // MVP only handles global aggregates over append-mode region scans. Anything - // else falls back to the normal execution path. - if !self.region_scan.append_mode() || !self.aggregate_exec.group_expr().is_empty() { + // MVP only handles global first-stage aggregates over append-mode region + // scans. Anything else falls back to the normal execution path. + if !self.region_scan.append_mode() + || !self.aggregate_exec.group_expr().is_empty() + || !matches!( + self.aggregate_exec.mode(), + AggregateMode::Partial | AggregateMode::Single | AggregateMode::SinglePartitioned + ) + || self + .aggregate_exec + .filter_expr() + .iter() + .any(|expr| expr.is_some()) + { return Ok(Some(RejectReason::UnsupportedPlan)); } @@ -77,7 +88,7 @@ impl<'a> RewriteCheck<'a> { } } - fn parse_aggs(&self) -> Result, RejectReason> { + pub(super) fn parse_aggs(&self) -> Result, RejectReason> { let aggr_exprs = self.aggregate_exec.aggr_expr(); if aggr_exprs.is_empty() { return Err(RejectReason::UnsupportedAggregate); @@ -95,6 +106,7 @@ impl<'a> RewriteCheck<'a> { let inputs = expr.expressions(); let name = expr.fun().name().to_ascii_lowercase(); + // COUNT(*) is usually rewrite to COUNT(time-index) // before this physical optimizer runs, so CountStar is mostly a defensive fallback if name == "count" && is_count_star_expr(&inputs) { @@ -139,11 +151,7 @@ impl<'a> RewriteCheck<'a> { } pub(super) fn check_agg_shape(expr: &AggregateFunctionExpr) -> Result<(), RejectReason> { - if expr.is_distinct() - || expr.ignore_nulls() - || expr.is_reversed() - || !expr.order_bys().is_empty() - { + if expr.is_distinct() || expr.is_reversed() || !expr.order_bys().is_empty() { return Err(RejectReason::UnsupportedAggregate); } diff --git a/src/query/src/optimizer/aggr_stats/split.rs b/src/query/src/optimizer/aggr_stats/split.rs index cd2c44ed84..d5d9bb98bc 100644 --- a/src/query/src/optimizer/aggr_stats/split.rs +++ b/src/query/src/optimizer/aggr_stats/split.rs @@ -123,6 +123,25 @@ pub(super) fn common_stats_file_ordinals( .collect() } +pub(super) fn filter_stats_by_file_ordinals( + scan_input_stats: &RegionScanStats, + file_ordinals: &[usize], +) -> RegionScanStats { + let selected = file_ordinals + .iter() + .copied() + .collect::>(); + + RegionScanStats { + files: scan_input_stats + .files + .iter() + .filter(|file| selected.contains(&file.file_ordinal)) + .cloned() + .collect(), + } +} + pub(super) trait StatsAggExt { fn has_stats_files(&self, scan_input_stats: &RegionScanStats) -> bool; } diff --git a/src/query/src/optimizer/aggr_stats/tests.rs b/src/query/src/optimizer/aggr_stats/tests.rs index 7fbbac1197..5ae194d1a3 100644 --- a/src/query/src/optimizer/aggr_stats/tests.rs +++ b/src/query/src/optimizer/aggr_stats/tests.rs @@ -13,13 +13,27 @@ // limitations under the License. use std::collections::HashMap; +use std::fmt; use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; use arrow::array::{Int64Array, TimestampMillisecondArray}; use arrow::datatypes::{Field, Schema}; +use common_error::ext::BoxedError; +use common_recordbatch::{ + EmptyRecordBatchStream, RecordBatch as CommonRecordBatch, RecordBatches, + SendableRecordBatchStream, +}; use common_time::Timestamp; use common_time::timestamp::TimeUnit as TimestampUnit; +use datafusion::execution::context::SessionContext; use datafusion::functions_aggregate::count::count_udaf; +use datafusion::functions_aggregate::min_max::{max_udaf, min_udaf}; +use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; +use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; +use datafusion::physical_plan::union::UnionExec; +use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams}; use datafusion_expr::{Expr, LogicalPlan}; @@ -32,20 +46,26 @@ use datatypes::prelude::ConcreteDataType; use datatypes::schema::{ColumnSchema, Schema as GreptimeSchema}; use datatypes::value::Value; use session::context::QueryContext; +use store_api::metadata::{ColumnMetadata, RegionMetadataBuilder, RegionMetadataRef}; +use store_api::region_engine::{ + PrepareRequest, QueryScanContext, RegionScanner, ScannerProperties, +}; use store_api::scan_stats::{ RegionScanColumnStats as RegionScanColumnInputStats, RegionScanFileStats as RegionScanFileInputStats, RegionScanStats as RegionScanInputStats, }; +use store_api::storage::{RegionId, ScanRequest}; use table::metadata::{TableInfoBuilder, TableMetaBuilder}; +use table::table::scan::RegionScanExec; use table::test_util::EmptyTable; -use super::StatsAgg; use super::check::{RejectReason, RewriteCheck, is_supported_aggregate_name}; use super::split::{ - FileStatsRequirement, StatsAggExt, common_stats_file_ordinals, has_partition_expr_mismatch, - partial_state_from_stats, split_count_field_files, split_count_star_files, - split_min_max_field_files, split_time_files, + FileStatsRequirement, StatsAggExt, common_stats_file_ordinals, filter_stats_by_file_ordinals, + has_partition_expr_mismatch, partial_state_from_stats, split_count_field_files, + split_count_star_files, split_min_max_field_files, split_time_files, }; +use super::{AggregateStats, StatsAgg}; use crate::parser::QueryLanguageParser; use crate::tests::new_query_engine_with_table; @@ -68,17 +88,180 @@ fn field_stats( )]) } +#[derive(Debug)] +struct StatsRecordingScanner { + schema: datatypes::schema::SchemaRef, + metadata: RegionMetadataRef, + properties: ScannerProperties, + base_stats: RegionScanInputStats, + excluded_count: Arc, + file_batches: Vec<(usize, CommonRecordBatch)>, + excluded_file_ordinals: Vec, +} + +impl StatsRecordingScanner { + fn new( + schema: datatypes::schema::SchemaRef, + metadata: RegionMetadataRef, + base_stats: RegionScanInputStats, + excluded_count: Arc, + ) -> Self { + Self { + schema, + metadata, + properties: ScannerProperties::default().with_append_mode(true), + base_stats, + excluded_count, + file_batches: Vec::new(), + excluded_file_ordinals: Vec::new(), + } + } + + fn with_file_batches(mut self, file_batches: Vec<(usize, CommonRecordBatch)>) -> Self { + self.file_batches = file_batches; + self + } +} + +impl DisplayAs for StatsRecordingScanner { + fn fmt_as(&self, _: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "StatsRecordingScanner") + } +} + +impl RegionScanner for StatsRecordingScanner { + fn name(&self) -> &str { + "StatsRecordingScanner" + } + + fn properties(&self) -> &ScannerProperties { + &self.properties + } + + fn schema(&self) -> datatypes::schema::SchemaRef { + self.schema.clone() + } + + fn metadata(&self) -> RegionMetadataRef { + self.metadata.clone() + } + + fn prepare(&mut self, request: PrepareRequest) -> Result<(), common_error::ext::BoxedError> { + request.validate()?; + if let Some(excluded_file_ordinals) = request.excluded_file_ordinals.as_ref() { + self.excluded_count + .store(excluded_file_ordinals.len(), Ordering::Relaxed); + self.excluded_file_ordinals = excluded_file_ordinals.clone(); + } + self.properties.prepare(request); + Ok(()) + } + + fn scan_partition( + &self, + _: &QueryScanContext, + _: &ExecutionPlanMetricsSet, + _: usize, + ) -> Result { + let batches = self + .file_batches + .iter() + .filter(|(file_ordinal, _)| !self.excluded_file_ordinals.contains(file_ordinal)) + .map(|(_, batch)| batch.clone()) + .collect::>(); + + if batches.is_empty() { + Ok(Box::pin(EmptyRecordBatchStream::new(self.schema.clone()))) + } else { + Ok(RecordBatches::try_new(self.schema.clone(), batches) + .map_err(BoxedError::new)? + .as_stream()) + } + } + + fn has_predicate_without_region(&self) -> bool { + false + } + + fn scan_input_stats( + &self, + ) -> Result, common_error::ext::BoxedError> { + Ok(Some(filter_stats_by_file_ordinals( + &self.base_stats, + &self + .base_stats + .files + .iter() + .filter_map(|file| { + (!self.excluded_file_ordinals.contains(&file.file_ordinal)) + .then_some(file.file_ordinal) + }) + .collect::>(), + ))) + } + + fn add_dyn_filter_to_predicate(&mut self, _: Vec>) -> Vec { + Vec::new() + } + + fn set_logical_region(&mut self, logical_region: bool) { + self.properties.set_logical_region(logical_region); + } +} + +fn test_region_metadata() -> RegionMetadataRef { + let mut builder = RegionMetadataBuilder::new(RegionId::new(1024, 1)); + builder + .push_column_metadata(ColumnMetadata { + column_schema: ColumnSchema::new( + "ts", + ConcreteDataType::timestamp_millisecond_datatype(), + false, + ) + .with_time_index(true), + semantic_type: api::v1::SemanticType::Timestamp, + column_id: 1, + }) + .push_column_metadata(ColumnMetadata { + column_schema: ColumnSchema::new("value", ConcreteDataType::int64_datatype(), true), + semantic_type: api::v1::SemanticType::Field, + column_id: 2, + }) + .primary_key(vec![]); + Arc::new(builder.build().unwrap()) +} + +fn scan_batch( + schema: datatypes::schema::SchemaRef, + timestamps: Vec, + values: Vec>, +) -> CommonRecordBatch { + let df_record_batch = datatypes::arrow::record_batch::RecordBatch::try_new( + schema.arrow_schema().clone(), + vec![ + Arc::new(TimestampMillisecondArray::from_iter_values(timestamps)), + Arc::new(Int64Array::from(values)), + ], + ) + .unwrap(); + + CommonRecordBatch::from_df_record_batch(schema, df_record_batch) +} + fn build_test_aggr_expr( distinct: bool, ignore_nulls: bool, order_by: bool, ) -> datafusion_common::Result { - let schema = Arc::new(Schema::new(vec![Field::new( - "value", - DataType::Int64, - true, - )])); - let args = vec![Arc::new(Column::new("value", 0)) as Arc]; + let schema = Arc::new(Schema::new(vec![ + Field::new( + "ts", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ), + Field::new("value", DataType::Int64, true), + ])); + let args = vec![Arc::new(Column::new("value", 1)) as Arc]; let mut builder = AggregateExprBuilder::new(Arc::new((*count_udaf()).clone()), args) .schema(schema) .alias("count(value)"); @@ -99,6 +282,50 @@ fn build_test_aggr_expr( builder.build() } +fn build_min_field_aggr_expr_with_ignore_nulls( + ignore_nulls: bool, +) -> datafusion_common::Result { + let schema = Arc::new(Schema::new(vec![ + Field::new( + "ts", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ), + Field::new("value", DataType::Int64, true), + ])); + let args = vec![Arc::new(Column::new("value", 1)) as Arc]; + let mut builder = AggregateExprBuilder::new(Arc::new((*min_udaf()).clone()), args) + .schema(schema) + .alias("min(value)"); + if ignore_nulls { + builder = builder.ignore_nulls(); + } + + builder.build() +} + +fn build_max_field_aggr_expr_with_ignore_nulls( + ignore_nulls: bool, +) -> datafusion_common::Result { + let schema = Arc::new(Schema::new(vec![ + Field::new( + "ts", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ), + Field::new("value", DataType::Int64, true), + ])); + let args = vec![Arc::new(Column::new("value", 1)) as Arc]; + let mut builder = AggregateExprBuilder::new(Arc::new((*max_udaf()).clone()), args) + .schema(schema) + .alias("max(value)"); + if ignore_nulls { + builder = builder.ignore_nulls(); + } + + builder.build() +} + fn build_count_star_aggr_expr() -> datafusion_common::Result { let schema = Arc::new(Schema::empty()); let args = vec![Arc::new(Literal::new(COUNT_STAR_EXPANSION)) as Arc]; @@ -270,19 +497,50 @@ fn test_split_count_star_files_keeps_zero_row_files_stats_eligible() { } #[test] -fn test_supported_aggregate_rejects_distinct_ignore_nulls_and_order_by_shapes() { +fn test_supported_aggregate_shape_allows_ignore_nulls_for_count_min_max() { let distinct = build_test_aggr_expr(true, false, false).unwrap(); - let ignore_nulls = build_test_aggr_expr(false, true, false).unwrap(); + let count_ignore_nulls = build_test_aggr_expr(false, true, false).unwrap(); + let min_ignore_nulls = build_min_field_aggr_expr_with_ignore_nulls(true).unwrap(); + let max_ignore_nulls = build_max_field_aggr_expr_with_ignore_nulls(true).unwrap(); let order_by = build_test_aggr_expr(false, false, true).unwrap(); + let reject_reason = |expr: AggregateFunctionExpr| { + let schema = execution_test_schema(); + let region_scan = Arc::new( + RegionScanExec::new( + Box::new(StatsRecordingScanner::new( + schema.clone(), + test_region_metadata(), + execution_test_stats(), + Arc::new(AtomicUsize::new(0)), + )), + ScanRequest::default(), + None, + ) + .unwrap(), + ); + let aggregate = Arc::new( + AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(vec![]), + vec![Arc::new(expr)], + vec![None], + region_scan.clone(), + schema.arrow_schema().clone(), + ) + .unwrap(), + ); + let check = RewriteCheck::new(aggregate.as_ref(), region_scan.as_ref()); + check.skip_reason().unwrap() + }; + assert!(matches!( RewriteCheck::check_agg_shape(&distinct), Err(RejectReason::UnsupportedAggregate) )); - assert!(matches!( - RewriteCheck::check_agg_shape(&ignore_nulls), - Err(RejectReason::UnsupportedAggregate) - )); + assert!(reject_reason(count_ignore_nulls).is_none()); + assert!(reject_reason(min_ignore_nulls).is_none()); + assert!(reject_reason(max_ignore_nulls).is_none()); assert!(matches!( RewriteCheck::check_agg_shape(&order_by), Err(RejectReason::UnsupportedAggregate) @@ -632,6 +890,32 @@ fn test_common_stats_file_ordinals_returns_only_shared_stats_eligible_files() { assert_eq!(common_stats_file_ordinals(&aggregates, &stats), vec![0]); } +#[test] +fn test_filter_stats_by_file_ordinals_keeps_only_selected_files() { + let stats = RegionScanInputStats { + files: vec![ + RegionScanFileInputStats { + file_ordinal: 0, + exact_num_rows: Some(3), + time_range: Some((test_timestamp(10), test_timestamp(20))), + field_stats: HashMap::new(), + partition_expr_matches_region: true, + }, + RegionScanFileInputStats { + file_ordinal: 1, + exact_num_rows: Some(4), + time_range: Some((test_timestamp(30), test_timestamp(40))), + field_stats: HashMap::new(), + partition_expr_matches_region: true, + }, + ], + }; + + let filtered = filter_stats_by_file_ordinals(&stats, &[1]); + assert_eq!(filtered.files.len(), 1); + assert_eq!(filtered.files[0].file_ordinal, 1); +} + #[test] fn test_partial_state_from_stats_count_star() { let aggregate = StatsAgg::CountStar; @@ -779,3 +1063,543 @@ fn test_partial_state_from_stats_max_field() { .unwrap(); assert_eq!(max_values.value(0), 9); } + +#[test] +fn test_optimizer_rewrites_into_final_union_partial_scan_and_stats() { + let schema = Arc::new(GreptimeSchema::new(vec![ + ColumnSchema::new( + "ts", + ConcreteDataType::timestamp_millisecond_datatype(), + false, + ), + ColumnSchema::new("value", ConcreteDataType::int64_datatype(), true), + ])); + let metadata = test_region_metadata(); + let excluded_count = Arc::new(AtomicUsize::new(0)); + let base_stats = RegionScanInputStats { + files: vec![ + RegionScanFileInputStats { + file_ordinal: 0, + exact_num_rows: Some(3), + time_range: Some((test_timestamp(10), test_timestamp(20))), + field_stats: field_stats(Some(2), Some(Value::Int64(1)), Some(Value::Int64(3))), + partition_expr_matches_region: true, + }, + RegionScanFileInputStats { + file_ordinal: 1, + exact_num_rows: Some(4), + time_range: Some((test_timestamp(30), test_timestamp(40))), + field_stats: HashMap::new(), + partition_expr_matches_region: true, + }, + RegionScanFileInputStats { + file_ordinal: 2, + exact_num_rows: Some(5), + time_range: Some((test_timestamp(50), test_timestamp(60))), + field_stats: field_stats(Some(4), Some(Value::Int64(5)), Some(Value::Int64(9))), + partition_expr_matches_region: true, + }, + ], + }; + let scanner = Box::new(StatsRecordingScanner::new( + schema.clone(), + metadata, + base_stats, + excluded_count.clone(), + )); + let scan = Arc::new(RegionScanExec::new(scanner, ScanRequest::default(), None).unwrap()); + let aggr_expr = Arc::new(build_test_aggr_expr(false, false, false).unwrap()); + let plan: Arc = Arc::new( + AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new_single(vec![]), + vec![aggr_expr], + vec![None], + scan, + schema.arrow_schema().clone(), + ) + .unwrap(), + ); + + let optimized = AggregateStats::do_optimize(plan).unwrap(); + let final_agg = optimized.as_any().downcast_ref::().unwrap(); + assert_eq!(final_agg.mode(), &AggregateMode::Final); + + let coalesce = final_agg + .input() + .as_any() + .downcast_ref::() + .unwrap(); + let union = coalesce + .input() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(union.inputs().len(), 2); + + let partial_agg = union.inputs()[0] + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(partial_agg.mode(), &AggregateMode::Partial); + let partial_scan = partial_agg + .input() + .as_any() + .downcast_ref::() + .unwrap(); + let remaining = partial_scan.scan_input_stats().unwrap().unwrap(); + assert_eq!( + remaining + .files + .iter() + .map(|file| file.file_ordinal) + .collect::>(), + vec![1] + ); + assert_eq!(excluded_count.load(Ordering::Relaxed), 2); +} + +#[test] +fn test_optimizer_rewrites_final_partial_plan_by_replacing_partial_input() { + let schema = Arc::new(GreptimeSchema::new(vec![ + ColumnSchema::new( + "ts", + ConcreteDataType::timestamp_millisecond_datatype(), + false, + ), + ColumnSchema::new("value", ConcreteDataType::int64_datatype(), true), + ])); + let metadata = test_region_metadata(); + let excluded_count = Arc::new(AtomicUsize::new(0)); + let base_stats = RegionScanInputStats { + files: vec![ + RegionScanFileInputStats { + file_ordinal: 0, + exact_num_rows: Some(3), + time_range: Some((test_timestamp(10), test_timestamp(20))), + field_stats: field_stats(Some(2), Some(Value::Int64(1)), Some(Value::Int64(3))), + partition_expr_matches_region: true, + }, + RegionScanFileInputStats { + file_ordinal: 1, + exact_num_rows: Some(4), + time_range: Some((test_timestamp(30), test_timestamp(40))), + field_stats: HashMap::new(), + partition_expr_matches_region: true, + }, + RegionScanFileInputStats { + file_ordinal: 2, + exact_num_rows: Some(5), + time_range: Some((test_timestamp(50), test_timestamp(60))), + field_stats: field_stats(Some(4), Some(Value::Int64(5)), Some(Value::Int64(9))), + partition_expr_matches_region: true, + }, + ], + }; + let scanner = Box::new(StatsRecordingScanner::new( + schema.clone(), + metadata, + base_stats, + excluded_count.clone(), + )); + let scan = Arc::new(RegionScanExec::new(scanner, ScanRequest::default(), None).unwrap()); + let aggr_expr = Arc::new(build_test_aggr_expr(false, false, false).unwrap()); + let partial: Arc = Arc::new( + AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(vec![]), + vec![aggr_expr.clone()], + vec![None], + scan, + schema.arrow_schema().clone(), + ) + .unwrap(), + ); + let plan: Arc = Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::new_single(vec![]), + vec![aggr_expr], + vec![None], + Arc::new(CoalescePartitionsExec::new(partial.clone())), + partial.schema(), + ) + .unwrap(), + ); + + let optimized = AggregateStats::do_optimize(plan).unwrap(); + let final_agg = optimized.as_any().downcast_ref::().unwrap(); + assert_eq!(final_agg.mode(), &AggregateMode::Final); + + let coalesce = final_agg + .input() + .as_any() + .downcast_ref::() + .unwrap(); + let union = coalesce + .input() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(union.inputs().len(), 2); + + let partial_agg = union.inputs()[0] + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(partial_agg.mode(), &AggregateMode::Partial); + let partial_scan = partial_agg + .input() + .as_any() + .downcast_ref::() + .unwrap(); + let remaining = partial_scan.scan_input_stats().unwrap().unwrap(); + assert_eq!( + remaining + .files + .iter() + .map(|file| file.file_ordinal) + .collect::>(), + vec![1] + ); + assert_eq!(excluded_count.load(Ordering::Relaxed), 2); +} + +#[derive(Clone, Copy)] +enum ExecutionAggExprCase { + CountValue { ignore_nulls: bool }, + CountStar, + MinValue { ignore_nulls: bool }, + MaxValue { ignore_nulls: bool }, +} + +impl ExecutionAggExprCase { + fn build(self) -> AggregateFunctionExpr { + match self { + ExecutionAggExprCase::CountValue { ignore_nulls } => { + build_test_aggr_expr(false, ignore_nulls, false).unwrap() + } + ExecutionAggExprCase::CountStar => build_count_star_aggr_expr().unwrap(), + ExecutionAggExprCase::MinValue { ignore_nulls } => { + build_min_field_aggr_expr_with_ignore_nulls(ignore_nulls).unwrap() + } + ExecutionAggExprCase::MaxValue { ignore_nulls } => { + build_max_field_aggr_expr_with_ignore_nulls(ignore_nulls).unwrap() + } + } + } +} + +fn execution_test_schema() -> datatypes::schema::SchemaRef { + Arc::new(GreptimeSchema::new(vec![ + ColumnSchema::new( + "ts", + ConcreteDataType::timestamp_millisecond_datatype(), + false, + ) + .with_time_index(true), + ColumnSchema::new("value", ConcreteDataType::int64_datatype(), true), + ])) +} + +fn execution_test_stats() -> RegionScanInputStats { + RegionScanInputStats { + files: vec![ + RegionScanFileInputStats { + file_ordinal: 0, + exact_num_rows: Some(3), + time_range: Some((test_timestamp(10), test_timestamp(12))), + field_stats: field_stats(Some(2), Some(Value::Int64(1)), Some(Value::Int64(3))), + partition_expr_matches_region: true, + }, + RegionScanFileInputStats { + file_ordinal: 1, + exact_num_rows: Some(4), + time_range: Some((test_timestamp(30), test_timestamp(33))), + field_stats: HashMap::new(), + partition_expr_matches_region: true, + }, + RegionScanFileInputStats { + file_ordinal: 2, + exact_num_rows: Some(5), + time_range: Some((test_timestamp(50), test_timestamp(54))), + field_stats: field_stats(Some(4), Some(Value::Int64(7)), Some(Value::Int64(10))), + partition_expr_matches_region: true, + }, + ], + } +} + +fn execution_test_file_batches( + schema: datatypes::schema::SchemaRef, +) -> Vec<(usize, CommonRecordBatch)> { + vec![ + ( + 0, + scan_batch( + schema.clone(), + vec![10, 11, 12], + vec![Some(1), None, Some(3)], + ), + ), + ( + 1, + scan_batch( + schema.clone(), + vec![30, 31, 32, 33], + vec![Some(4), Some(5), None, Some(6)], + ), + ), + ( + 2, + scan_batch( + schema, + vec![50, 51, 52, 53, 54], + vec![Some(7), Some(8), Some(9), Some(10), None], + ), + ), + ] +} + +fn build_execution_test_plan( + schema: datatypes::schema::SchemaRef, + metadata: RegionMetadataRef, + base_stats: RegionScanInputStats, + file_batches: Vec<(usize, CommonRecordBatch)>, + aggr_expr: AggregateFunctionExpr, + excluded_count: Arc, +) -> Arc { + let scanner = Box::new( + StatsRecordingScanner::new(schema.clone(), metadata, base_stats, excluded_count) + .with_file_batches(file_batches), + ); + let scan = Arc::new(RegionScanExec::new(scanner, ScanRequest::default(), None).unwrap()); + let aggr_expr = Arc::new(aggr_expr); + let partial: Arc = Arc::new( + AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(vec![]), + vec![aggr_expr.clone()], + vec![None], + scan, + schema.arrow_schema().clone(), + ) + .unwrap(), + ); + + Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::new_single(vec![]), + vec![aggr_expr], + vec![None], + Arc::new(CoalescePartitionsExec::new(partial.clone())), + partial.schema(), + ) + .unwrap(), + ) +} + +fn optimized_plan_uses_stats_union(plan: &Arc) -> bool { + let Some(final_agg) = plan.as_any().downcast_ref::() else { + return false; + }; + + let input = final_agg.input(); + if let Some(coalesce) = input.as_any().downcast_ref::() { + return coalesce + .input() + .as_any() + .downcast_ref::() + .is_some(); + } + + input.as_any().downcast_ref::().is_some() +} + +#[tokio::test] +async fn test_optimizer_execution_matrix() { + struct Case { + name: &'static str, + expr: ExecutionAggExprCase, + expect_rewrite: bool, + expected_excluded_count: usize, + expected: String, + } + + let cases = [ + Case { + name: "count value", + expr: ExecutionAggExprCase::CountValue { + ignore_nulls: false, + }, + expect_rewrite: true, + expected_excluded_count: 2, + expected: [ + "+--------------+", + "| count(value) |", + "+--------------+", + "| 9 |", + "+--------------+", + ] + .join("\n"), + }, + Case { + name: "count value ignore nulls", + expr: ExecutionAggExprCase::CountValue { ignore_nulls: true }, + expect_rewrite: true, + expected_excluded_count: 2, + expected: [ + "+--------------+", + "| count(value) |", + "+--------------+", + "| 9 |", + "+--------------+", + ] + .join("\n"), + }, + Case { + name: "count star", + expr: ExecutionAggExprCase::CountStar, + expect_rewrite: true, + expected_excluded_count: 3, + expected: [ + "+----------+", + "| count(*) |", + "+----------+", + "| 12 |", + "+----------+", + ] + .join("\n"), + }, + Case { + name: "min value", + expr: ExecutionAggExprCase::MinValue { + ignore_nulls: false, + }, + expect_rewrite: true, + expected_excluded_count: 2, + expected: [ + "+------------+", + "| min(value) |", + "+------------+", + "| 1 |", + "+------------+", + ] + .join("\n"), + }, + Case { + name: "min value ignore nulls", + expr: ExecutionAggExprCase::MinValue { ignore_nulls: true }, + expect_rewrite: true, + expected_excluded_count: 2, + expected: [ + "+------------+", + "| min(value) |", + "+------------+", + "| 1 |", + "+------------+", + ] + .join("\n"), + }, + Case { + name: "max value", + expr: ExecutionAggExprCase::MaxValue { + ignore_nulls: false, + }, + expect_rewrite: true, + expected_excluded_count: 2, + expected: [ + "+------------+", + "| max(value) |", + "+------------+", + "| 10 |", + "+------------+", + ] + .join("\n"), + }, + Case { + name: "max value ignore nulls", + expr: ExecutionAggExprCase::MaxValue { ignore_nulls: true }, + expect_rewrite: true, + expected_excluded_count: 2, + expected: [ + "+------------+", + "| max(value) |", + "+------------+", + "| 10 |", + "+------------+", + ] + .join("\n"), + }, + ]; + + let schema = execution_test_schema(); + let metadata = test_region_metadata(); + let base_stats = execution_test_stats(); + let file_batches = execution_test_file_batches(schema.clone()); + + for case in cases { + let unoptimized = build_execution_test_plan( + schema.clone(), + metadata.clone(), + base_stats.clone(), + file_batches.clone(), + case.expr.build(), + Arc::new(AtomicUsize::new(0)), + ); + let unoptimized_result = + datafusion::physical_plan::collect(unoptimized, SessionContext::default().task_ctx()) + .await + .unwrap(); + + let optimized_excluded_count = Arc::new(AtomicUsize::new(0)); + let optimized = AggregateStats::do_optimize(build_execution_test_plan( + schema.clone(), + metadata.clone(), + base_stats.clone(), + file_batches.clone(), + case.expr.build(), + optimized_excluded_count.clone(), + )) + .unwrap(); + let optimized_result = datafusion::physical_plan::collect( + optimized.clone(), + SessionContext::default().task_ctx(), + ) + .await + .unwrap(); + + let unoptimized_pretty = arrow::util::pretty::pretty_format_batches(&unoptimized_result) + .unwrap() + .to_string(); + let optimized_pretty = arrow::util::pretty::pretty_format_batches(&optimized_result) + .unwrap() + .to_string(); + + assert_eq!( + unoptimized_pretty, + case.expected.as_str(), + "case: {}", + case.name + ); + assert_eq!( + optimized_pretty, + case.expected.as_str(), + "case: {}", + case.name + ); + assert_eq!( + optimized_plan_uses_stats_union(&optimized), + case.expect_rewrite, + "case: {}", + case.name + ); + assert_eq!( + optimized_excluded_count.load(Ordering::Relaxed), + case.expected_excluded_count, + "case: {}", + case.name + ); + } +}