feat: stats optimize

Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
discord9
2026-04-17 14:57:17 +08:00
parent ca0a9a2d5d
commit 6aaecbc9f6
4 changed files with 1198 additions and 42 deletions

View File

@@ -14,13 +14,21 @@
use std::sync::Arc;
use arrow::array::{Array, ArrayRef, StructArray};
use arrow::record_batch::RecordBatch;
use common_error::ext::BoxedError;
use common_telemetry::debug;
use datafusion::config::ConfigOptions;
use datafusion::datasource::memory::MemorySourceConfig;
use datafusion::datasource::source::DataSourceExec;
use datafusion::physical_optimizer::PhysicalOptimizerRule;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::aggregates::AggregateExec;
use datafusion_common::Result as DfResult;
use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode};
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::union::UnionExec;
use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties};
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{DataFusionError, Result as DfResult, ScalarValue};
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
use datatypes::arrow::datatypes::DataType;
use table::table::scan::RegionScanExec;
@@ -30,6 +38,90 @@ mod split;
mod tests;
use check::RewriteCheck;
use split::{common_stats_file_ordinals, filter_stats_by_file_ordinals, partial_state_from_stats};
enum RewriteTarget<'a> {
SingleStage {
aggregate_exec: &'a AggregateExec,
region_scan: &'a RegionScanExec,
},
FinalOverPartial {
final_exec: &'a AggregateExec,
partial_exec: &'a AggregateExec,
region_scan: &'a RegionScanExec,
keep_coalesce: bool,
},
}
impl<'a> RewriteTarget<'a> {
fn extract(plan: &'a Arc<dyn ExecutionPlan>) -> Option<Self> {
let aggregate_exec = plan.as_any().downcast_ref::<AggregateExec>()?;
if matches!(
aggregate_exec.mode(),
AggregateMode::Single | AggregateMode::SinglePartitioned
) {
let region_scan = AggregateStats::extract_region_scan(aggregate_exec.input())?;
return Some(Self::SingleStage {
aggregate_exec,
region_scan,
});
}
if !matches!(aggregate_exec.mode(), AggregateMode::Final) {
return None;
}
if let Some(coalesce) = aggregate_exec
.input()
.as_any()
.downcast_ref::<CoalescePartitionsExec>()
{
let partial_exec = coalesce.input().as_any().downcast_ref::<AggregateExec>()?;
if !matches!(partial_exec.mode(), AggregateMode::Partial) {
return None;
}
let region_scan = AggregateStats::extract_region_scan(partial_exec.input())?;
return Some(Self::FinalOverPartial {
final_exec: aggregate_exec,
partial_exec,
region_scan,
keep_coalesce: true,
});
}
let partial_exec = aggregate_exec
.input()
.as_any()
.downcast_ref::<AggregateExec>()?;
if !matches!(partial_exec.mode(), AggregateMode::Partial) {
return None;
}
let region_scan = AggregateStats::extract_region_scan(partial_exec.input())?;
Some(Self::FinalOverPartial {
final_exec: aggregate_exec,
partial_exec,
region_scan,
keep_coalesce: false,
})
}
fn first_stage_aggregate(&self) -> &'a AggregateExec {
match self {
RewriteTarget::SingleStage { aggregate_exec, .. } => aggregate_exec,
RewriteTarget::FinalOverPartial { partial_exec, .. } => partial_exec,
}
}
fn region_scan(&self) -> &'a RegionScanExec {
match self {
RewriteTarget::SingleStage { region_scan, .. }
| RewriteTarget::FinalOverPartial { region_scan, .. } => region_scan,
}
}
}
#[derive(Debug)]
pub struct AggregateStats;
@@ -83,32 +175,245 @@ impl AggregateStats {
fn do_optimize(plan: Arc<dyn ExecutionPlan>) -> DfResult<Arc<dyn ExecutionPlan>> {
let result = plan
.transform_down(|plan| {
let Some(aggregate_exec) = plan.as_any().downcast_ref::<AggregateExec>() else {
let Some(target) = RewriteTarget::extract(&plan) else {
return Ok(Transformed::no(plan));
};
let Some(region_scan) = Self::extract_region_scan(aggregate_exec) else {
return Ok(Transformed::no(plan));
};
let check = RewriteCheck::new(aggregate_exec, region_scan);
let check = RewriteCheck::new(target.first_stage_aggregate(), target.region_scan());
if let Some(reason) = check.skip_reason()? {
debug!("Skip aggregate stats optimization: {reason}");
return Ok(Transformed::no(plan));
}
// Subtask 03 only adds the scan-side exclusion plumbing. The optimizer must not
// exclude stats-covered files until subtask 04 also materializes their
// stats-derived partial state and merges it back into the aggregate result.
Ok(Transformed::no(plan))
let aggs = check.parse_aggs().map_err(|reason| {
DataFusionError::Internal(format!(
"aggregate stats rewrite became ineligible after eligibility check: {reason}"
))
})?;
let Some(scan_input_stats) = target.region_scan().scan_input_stats()? else {
return Ok(Transformed::no(plan));
};
let excluded_file_ordinals = common_stats_file_ordinals(&aggs, &scan_input_stats);
if excluded_file_ordinals.is_empty() {
debug!(
"Skip aggregate stats optimization: no shared stats-covered files across aggregates"
);
return Ok(Transformed::no(plan));
}
let rewritten = Self::rewrite_aggregate(
&target,
&aggs,
&scan_input_stats,
&excluded_file_ordinals,
)?;
Ok(Transformed::yes(rewritten))
})?
.data;
Ok(result)
}
fn extract_region_scan(aggregate_exec: &AggregateExec) -> Option<&RegionScanExec> {
let child = aggregate_exec.children().into_iter().next()?;
child.as_any().downcast_ref::<RegionScanExec>()
fn extract_region_scan(plan: &Arc<dyn ExecutionPlan>) -> Option<&RegionScanExec> {
plan.as_any().downcast_ref::<RegionScanExec>()
}
fn rewrite_aggregate(
target: &RewriteTarget<'_>,
aggs: &[StatsAgg],
scan_input_stats: &store_api::scan_stats::RegionScanStats,
excluded_file_ordinals: &[usize],
) -> DfResult<Arc<dyn ExecutionPlan>> {
match target {
RewriteTarget::SingleStage {
aggregate_exec,
region_scan,
} => Self::rewrite_single_stage(
aggregate_exec,
region_scan,
aggs,
scan_input_stats,
excluded_file_ordinals,
),
RewriteTarget::FinalOverPartial {
final_exec,
partial_exec,
region_scan,
keep_coalesce,
} => Self::rewrite_final_over_partial(
final_exec,
partial_exec,
region_scan,
*keep_coalesce,
aggs,
scan_input_stats,
excluded_file_ordinals,
),
}
}
fn rewrite_single_stage(
aggregate_exec: &AggregateExec,
region_scan: &RegionScanExec,
aggs: &[StatsAgg],
scan_input_stats: &store_api::scan_stats::RegionScanStats,
excluded_file_ordinals: &[usize],
) -> DfResult<Arc<dyn ExecutionPlan>> {
let union = Self::build_partial_union_source(
aggregate_exec,
region_scan,
aggs,
scan_input_stats,
excluded_file_ordinals,
)?;
let union = Self::coalesce_if_needed(union);
Ok(Arc::new(
AggregateExec::try_new(
AggregateMode::Final,
aggregate_exec.group_expr().clone(),
aggregate_exec.aggr_expr().to_vec(),
vec![None; aggregate_exec.aggr_expr().len()],
union,
aggregate_exec.input_schema(),
)?
.with_limit_options(aggregate_exec.limit_options()),
))
}
fn rewrite_final_over_partial(
final_exec: &AggregateExec,
partial_exec: &AggregateExec,
region_scan: &RegionScanExec,
keep_coalesce: bool,
aggs: &[StatsAgg],
scan_input_stats: &store_api::scan_stats::RegionScanStats,
excluded_file_ordinals: &[usize],
) -> DfResult<Arc<dyn ExecutionPlan>> {
let partial_source = Self::build_partial_union_source(
partial_exec,
region_scan,
aggs,
scan_input_stats,
excluded_file_ordinals,
)?;
let final_input = if keep_coalesce {
Arc::new(CoalescePartitionsExec::new(partial_source)) as Arc<dyn ExecutionPlan>
} else if partial_source.output_partitioning().partition_count() > 1 {
Arc::new(CoalescePartitionsExec::new(partial_source)) as Arc<dyn ExecutionPlan>
} else {
partial_source
};
Ok(Arc::new(
AggregateExec::try_new(
*final_exec.mode(),
final_exec.group_expr().clone(),
final_exec.aggr_expr().to_vec(),
final_exec.filter_expr().to_vec(),
final_input,
final_exec.input_schema(),
)?
.with_limit_options(final_exec.limit_options()),
))
}
fn build_partial_union_source(
aggregate_exec: &AggregateExec,
region_scan: &RegionScanExec,
aggs: &[StatsAgg],
scan_input_stats: &store_api::scan_stats::RegionScanStats,
excluded_file_ordinals: &[usize],
) -> DfResult<Arc<dyn ExecutionPlan>> {
let stats_scan_input =
filter_stats_by_file_ordinals(scan_input_stats, excluded_file_ordinals);
let stats_states = aggs
.iter()
.map(|agg| {
partial_state_from_stats(agg, &stats_scan_input)?.ok_or_else(|| {
DataFusionError::Internal(
"missing stats-derived partial state for excluded files".to_string(),
)
})
})
.collect::<DfResult<Vec<_>>>()?;
let filtered_scan = Arc::new(
region_scan
.with_excluded_file_ordinals(excluded_file_ordinals.to_vec())
.map_err(boxed_external)?,
);
let partial_scan = Arc::new(
AggregateExec::try_new(
AggregateMode::Partial,
aggregate_exec.group_expr().clone(),
aggregate_exec.aggr_expr().to_vec(),
aggregate_exec.filter_expr().to_vec(),
filtered_scan,
aggregate_exec.input_schema(),
)?
.with_limit_options(aggregate_exec.limit_options()),
);
let stats_input = Self::build_stats_input(aggregate_exec.aggr_expr(), stats_states)?;
UnionExec::try_new(vec![partial_scan, stats_input])
}
fn coalesce_if_needed(plan: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
if plan.output_partitioning().partition_count() > 1 {
Arc::new(CoalescePartitionsExec::new(plan))
} else {
plan
}
}
fn build_stats_input(
aggr_exprs: &[Arc<AggregateFunctionExpr>],
stats_states: Vec<ScalarValue>,
) -> DfResult<Arc<dyn ExecutionPlan>> {
let fields = aggr_exprs.iter().try_fold(Vec::new(), |mut fields, expr| {
fields.extend(expr.state_fields()?);
Ok::<_, DataFusionError>(fields)
})?;
let schema = Arc::new(arrow::datatypes::Schema::new(fields));
let columns = stats_states
.into_iter()
.try_fold(Vec::new(), |mut columns, state| {
columns.extend(Self::state_columns(state)?);
Ok::<_, DataFusionError>(columns)
})?;
let batch = RecordBatch::try_new(schema.clone(), columns)
.map_err(|err| DataFusionError::ArrowError(Box::new(err), None))?;
Ok(DataSourceExec::from_data_source(
MemorySourceConfig::try_new(&[vec![batch]], schema, None)?,
))
}
fn state_columns(state: ScalarValue) -> DfResult<Vec<ArrayRef>> {
let ScalarValue::Struct(array) = state else {
return Err(DataFusionError::Internal(
"aggregate stats rewrite expected a struct partial state".to_string(),
));
};
let struct_array = array
.as_ref()
.as_any()
.downcast_ref::<StructArray>()
.ok_or_else(|| {
DataFusionError::Internal(
"aggregate stats rewrite expected a struct array partial state".to_string(),
)
})?;
Ok(struct_array.columns().to_vec())
}
}
fn boxed_external(err: BoxedError) -> DataFusionError {
DataFusionError::External(Box::new(err))
}

View File

@@ -13,7 +13,7 @@
// limitations under the License.
use common_telemetry::debug;
use datafusion::physical_plan::aggregates::AggregateExec;
use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode};
use datafusion_common::Result as DfResult;
use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
use datafusion_physical_expr::PhysicalExpr;
@@ -39,9 +39,20 @@ impl<'a> RewriteCheck<'a> {
}
pub(super) fn skip_reason(&self) -> DfResult<Option<RejectReason>> {
// MVP only handles global aggregates over append-mode region scans. Anything
// else falls back to the normal execution path.
if !self.region_scan.append_mode() || !self.aggregate_exec.group_expr().is_empty() {
// MVP only handles global first-stage aggregates over append-mode region
// scans. Anything else falls back to the normal execution path.
if !self.region_scan.append_mode()
|| !self.aggregate_exec.group_expr().is_empty()
|| !matches!(
self.aggregate_exec.mode(),
AggregateMode::Partial | AggregateMode::Single | AggregateMode::SinglePartitioned
)
|| self
.aggregate_exec
.filter_expr()
.iter()
.any(|expr| expr.is_some())
{
return Ok(Some(RejectReason::UnsupportedPlan));
}
@@ -77,7 +88,7 @@ impl<'a> RewriteCheck<'a> {
}
}
fn parse_aggs(&self) -> Result<Vec<StatsAgg>, RejectReason> {
pub(super) fn parse_aggs(&self) -> Result<Vec<StatsAgg>, RejectReason> {
let aggr_exprs = self.aggregate_exec.aggr_expr();
if aggr_exprs.is_empty() {
return Err(RejectReason::UnsupportedAggregate);
@@ -95,6 +106,7 @@ impl<'a> RewriteCheck<'a> {
let inputs = expr.expressions();
let name = expr.fun().name().to_ascii_lowercase();
// COUNT(*) is usually rewrite to COUNT(time-index)
// before this physical optimizer runs, so CountStar is mostly a defensive fallback
if name == "count" && is_count_star_expr(&inputs) {
@@ -139,11 +151,7 @@ impl<'a> RewriteCheck<'a> {
}
pub(super) fn check_agg_shape(expr: &AggregateFunctionExpr) -> Result<(), RejectReason> {
if expr.is_distinct()
|| expr.ignore_nulls()
|| expr.is_reversed()
|| !expr.order_bys().is_empty()
{
if expr.is_distinct() || expr.is_reversed() || !expr.order_bys().is_empty() {
return Err(RejectReason::UnsupportedAggregate);
}

View File

@@ -123,6 +123,25 @@ pub(super) fn common_stats_file_ordinals(
.collect()
}
pub(super) fn filter_stats_by_file_ordinals(
scan_input_stats: &RegionScanStats,
file_ordinals: &[usize],
) -> RegionScanStats {
let selected = file_ordinals
.iter()
.copied()
.collect::<std::collections::BTreeSet<_>>();
RegionScanStats {
files: scan_input_stats
.files
.iter()
.filter(|file| selected.contains(&file.file_ordinal))
.cloned()
.collect(),
}
}
pub(super) trait StatsAggExt {
fn has_stats_files(&self, scan_input_stats: &RegionScanStats) -> bool;
}

View File

@@ -13,13 +13,27 @@
// limitations under the License.
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use arrow::array::{Int64Array, TimestampMillisecondArray};
use arrow::datatypes::{Field, Schema};
use common_error::ext::BoxedError;
use common_recordbatch::{
EmptyRecordBatchStream, RecordBatch as CommonRecordBatch, RecordBatches,
SendableRecordBatchStream,
};
use common_time::Timestamp;
use common_time::timestamp::TimeUnit as TimestampUnit;
use datafusion::execution::context::SessionContext;
use datafusion::functions_aggregate::count::count_udaf;
use datafusion::functions_aggregate::min_max::{max_udaf, min_udaf};
use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;
use datafusion::physical_plan::union::UnionExec;
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan};
use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
use datafusion_expr::{Expr, LogicalPlan};
@@ -32,20 +46,26 @@ use datatypes::prelude::ConcreteDataType;
use datatypes::schema::{ColumnSchema, Schema as GreptimeSchema};
use datatypes::value::Value;
use session::context::QueryContext;
use store_api::metadata::{ColumnMetadata, RegionMetadataBuilder, RegionMetadataRef};
use store_api::region_engine::{
PrepareRequest, QueryScanContext, RegionScanner, ScannerProperties,
};
use store_api::scan_stats::{
RegionScanColumnStats as RegionScanColumnInputStats,
RegionScanFileStats as RegionScanFileInputStats, RegionScanStats as RegionScanInputStats,
};
use store_api::storage::{RegionId, ScanRequest};
use table::metadata::{TableInfoBuilder, TableMetaBuilder};
use table::table::scan::RegionScanExec;
use table::test_util::EmptyTable;
use super::StatsAgg;
use super::check::{RejectReason, RewriteCheck, is_supported_aggregate_name};
use super::split::{
FileStatsRequirement, StatsAggExt, common_stats_file_ordinals, has_partition_expr_mismatch,
partial_state_from_stats, split_count_field_files, split_count_star_files,
split_min_max_field_files, split_time_files,
FileStatsRequirement, StatsAggExt, common_stats_file_ordinals, filter_stats_by_file_ordinals,
has_partition_expr_mismatch, partial_state_from_stats, split_count_field_files,
split_count_star_files, split_min_max_field_files, split_time_files,
};
use super::{AggregateStats, StatsAgg};
use crate::parser::QueryLanguageParser;
use crate::tests::new_query_engine_with_table;
@@ -68,17 +88,180 @@ fn field_stats(
)])
}
#[derive(Debug)]
struct StatsRecordingScanner {
schema: datatypes::schema::SchemaRef,
metadata: RegionMetadataRef,
properties: ScannerProperties,
base_stats: RegionScanInputStats,
excluded_count: Arc<AtomicUsize>,
file_batches: Vec<(usize, CommonRecordBatch)>,
excluded_file_ordinals: Vec<usize>,
}
impl StatsRecordingScanner {
fn new(
schema: datatypes::schema::SchemaRef,
metadata: RegionMetadataRef,
base_stats: RegionScanInputStats,
excluded_count: Arc<AtomicUsize>,
) -> Self {
Self {
schema,
metadata,
properties: ScannerProperties::default().with_append_mode(true),
base_stats,
excluded_count,
file_batches: Vec::new(),
excluded_file_ordinals: Vec::new(),
}
}
fn with_file_batches(mut self, file_batches: Vec<(usize, CommonRecordBatch)>) -> Self {
self.file_batches = file_batches;
self
}
}
impl DisplayAs for StatsRecordingScanner {
fn fmt_as(&self, _: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "StatsRecordingScanner")
}
}
impl RegionScanner for StatsRecordingScanner {
fn name(&self) -> &str {
"StatsRecordingScanner"
}
fn properties(&self) -> &ScannerProperties {
&self.properties
}
fn schema(&self) -> datatypes::schema::SchemaRef {
self.schema.clone()
}
fn metadata(&self) -> RegionMetadataRef {
self.metadata.clone()
}
fn prepare(&mut self, request: PrepareRequest) -> Result<(), common_error::ext::BoxedError> {
request.validate()?;
if let Some(excluded_file_ordinals) = request.excluded_file_ordinals.as_ref() {
self.excluded_count
.store(excluded_file_ordinals.len(), Ordering::Relaxed);
self.excluded_file_ordinals = excluded_file_ordinals.clone();
}
self.properties.prepare(request);
Ok(())
}
fn scan_partition(
&self,
_: &QueryScanContext,
_: &ExecutionPlanMetricsSet,
_: usize,
) -> Result<SendableRecordBatchStream, common_error::ext::BoxedError> {
let batches = self
.file_batches
.iter()
.filter(|(file_ordinal, _)| !self.excluded_file_ordinals.contains(file_ordinal))
.map(|(_, batch)| batch.clone())
.collect::<Vec<_>>();
if batches.is_empty() {
Ok(Box::pin(EmptyRecordBatchStream::new(self.schema.clone())))
} else {
Ok(RecordBatches::try_new(self.schema.clone(), batches)
.map_err(BoxedError::new)?
.as_stream())
}
}
fn has_predicate_without_region(&self) -> bool {
false
}
fn scan_input_stats(
&self,
) -> Result<Option<RegionScanInputStats>, common_error::ext::BoxedError> {
Ok(Some(filter_stats_by_file_ordinals(
&self.base_stats,
&self
.base_stats
.files
.iter()
.filter_map(|file| {
(!self.excluded_file_ordinals.contains(&file.file_ordinal))
.then_some(file.file_ordinal)
})
.collect::<Vec<_>>(),
)))
}
fn add_dyn_filter_to_predicate(&mut self, _: Vec<Arc<dyn PhysicalExpr>>) -> Vec<bool> {
Vec::new()
}
fn set_logical_region(&mut self, logical_region: bool) {
self.properties.set_logical_region(logical_region);
}
}
fn test_region_metadata() -> RegionMetadataRef {
let mut builder = RegionMetadataBuilder::new(RegionId::new(1024, 1));
builder
.push_column_metadata(ColumnMetadata {
column_schema: ColumnSchema::new(
"ts",
ConcreteDataType::timestamp_millisecond_datatype(),
false,
)
.with_time_index(true),
semantic_type: api::v1::SemanticType::Timestamp,
column_id: 1,
})
.push_column_metadata(ColumnMetadata {
column_schema: ColumnSchema::new("value", ConcreteDataType::int64_datatype(), true),
semantic_type: api::v1::SemanticType::Field,
column_id: 2,
})
.primary_key(vec![]);
Arc::new(builder.build().unwrap())
}
fn scan_batch(
schema: datatypes::schema::SchemaRef,
timestamps: Vec<i64>,
values: Vec<Option<i64>>,
) -> CommonRecordBatch {
let df_record_batch = datatypes::arrow::record_batch::RecordBatch::try_new(
schema.arrow_schema().clone(),
vec![
Arc::new(TimestampMillisecondArray::from_iter_values(timestamps)),
Arc::new(Int64Array::from(values)),
],
)
.unwrap();
CommonRecordBatch::from_df_record_batch(schema, df_record_batch)
}
fn build_test_aggr_expr(
distinct: bool,
ignore_nulls: bool,
order_by: bool,
) -> datafusion_common::Result<AggregateFunctionExpr> {
let schema = Arc::new(Schema::new(vec![Field::new(
"value",
DataType::Int64,
true,
)]));
let args = vec![Arc::new(Column::new("value", 0)) as Arc<dyn PhysicalExpr>];
let schema = Arc::new(Schema::new(vec![
Field::new(
"ts",
DataType::Timestamp(TimeUnit::Millisecond, None),
false,
),
Field::new("value", DataType::Int64, true),
]));
let args = vec![Arc::new(Column::new("value", 1)) as Arc<dyn PhysicalExpr>];
let mut builder = AggregateExprBuilder::new(Arc::new((*count_udaf()).clone()), args)
.schema(schema)
.alias("count(value)");
@@ -99,6 +282,50 @@ fn build_test_aggr_expr(
builder.build()
}
fn build_min_field_aggr_expr_with_ignore_nulls(
ignore_nulls: bool,
) -> datafusion_common::Result<AggregateFunctionExpr> {
let schema = Arc::new(Schema::new(vec![
Field::new(
"ts",
DataType::Timestamp(TimeUnit::Millisecond, None),
false,
),
Field::new("value", DataType::Int64, true),
]));
let args = vec![Arc::new(Column::new("value", 1)) as Arc<dyn PhysicalExpr>];
let mut builder = AggregateExprBuilder::new(Arc::new((*min_udaf()).clone()), args)
.schema(schema)
.alias("min(value)");
if ignore_nulls {
builder = builder.ignore_nulls();
}
builder.build()
}
fn build_max_field_aggr_expr_with_ignore_nulls(
ignore_nulls: bool,
) -> datafusion_common::Result<AggregateFunctionExpr> {
let schema = Arc::new(Schema::new(vec![
Field::new(
"ts",
DataType::Timestamp(TimeUnit::Millisecond, None),
false,
),
Field::new("value", DataType::Int64, true),
]));
let args = vec![Arc::new(Column::new("value", 1)) as Arc<dyn PhysicalExpr>];
let mut builder = AggregateExprBuilder::new(Arc::new((*max_udaf()).clone()), args)
.schema(schema)
.alias("max(value)");
if ignore_nulls {
builder = builder.ignore_nulls();
}
builder.build()
}
fn build_count_star_aggr_expr() -> datafusion_common::Result<AggregateFunctionExpr> {
let schema = Arc::new(Schema::empty());
let args = vec![Arc::new(Literal::new(COUNT_STAR_EXPANSION)) as Arc<dyn PhysicalExpr>];
@@ -270,19 +497,50 @@ fn test_split_count_star_files_keeps_zero_row_files_stats_eligible() {
}
#[test]
fn test_supported_aggregate_rejects_distinct_ignore_nulls_and_order_by_shapes() {
fn test_supported_aggregate_shape_allows_ignore_nulls_for_count_min_max() {
let distinct = build_test_aggr_expr(true, false, false).unwrap();
let ignore_nulls = build_test_aggr_expr(false, true, false).unwrap();
let count_ignore_nulls = build_test_aggr_expr(false, true, false).unwrap();
let min_ignore_nulls = build_min_field_aggr_expr_with_ignore_nulls(true).unwrap();
let max_ignore_nulls = build_max_field_aggr_expr_with_ignore_nulls(true).unwrap();
let order_by = build_test_aggr_expr(false, false, true).unwrap();
let reject_reason = |expr: AggregateFunctionExpr| {
let schema = execution_test_schema();
let region_scan = Arc::new(
RegionScanExec::new(
Box::new(StatsRecordingScanner::new(
schema.clone(),
test_region_metadata(),
execution_test_stats(),
Arc::new(AtomicUsize::new(0)),
)),
ScanRequest::default(),
None,
)
.unwrap(),
);
let aggregate = Arc::new(
AggregateExec::try_new(
AggregateMode::Partial,
PhysicalGroupBy::new_single(vec![]),
vec![Arc::new(expr)],
vec![None],
region_scan.clone(),
schema.arrow_schema().clone(),
)
.unwrap(),
);
let check = RewriteCheck::new(aggregate.as_ref(), region_scan.as_ref());
check.skip_reason().unwrap()
};
assert!(matches!(
RewriteCheck::check_agg_shape(&distinct),
Err(RejectReason::UnsupportedAggregate)
));
assert!(matches!(
RewriteCheck::check_agg_shape(&ignore_nulls),
Err(RejectReason::UnsupportedAggregate)
));
assert!(reject_reason(count_ignore_nulls).is_none());
assert!(reject_reason(min_ignore_nulls).is_none());
assert!(reject_reason(max_ignore_nulls).is_none());
assert!(matches!(
RewriteCheck::check_agg_shape(&order_by),
Err(RejectReason::UnsupportedAggregate)
@@ -632,6 +890,32 @@ fn test_common_stats_file_ordinals_returns_only_shared_stats_eligible_files() {
assert_eq!(common_stats_file_ordinals(&aggregates, &stats), vec![0]);
}
#[test]
fn test_filter_stats_by_file_ordinals_keeps_only_selected_files() {
let stats = RegionScanInputStats {
files: vec![
RegionScanFileInputStats {
file_ordinal: 0,
exact_num_rows: Some(3),
time_range: Some((test_timestamp(10), test_timestamp(20))),
field_stats: HashMap::new(),
partition_expr_matches_region: true,
},
RegionScanFileInputStats {
file_ordinal: 1,
exact_num_rows: Some(4),
time_range: Some((test_timestamp(30), test_timestamp(40))),
field_stats: HashMap::new(),
partition_expr_matches_region: true,
},
],
};
let filtered = filter_stats_by_file_ordinals(&stats, &[1]);
assert_eq!(filtered.files.len(), 1);
assert_eq!(filtered.files[0].file_ordinal, 1);
}
#[test]
fn test_partial_state_from_stats_count_star() {
let aggregate = StatsAgg::CountStar;
@@ -779,3 +1063,543 @@ fn test_partial_state_from_stats_max_field() {
.unwrap();
assert_eq!(max_values.value(0), 9);
}
#[test]
fn test_optimizer_rewrites_into_final_union_partial_scan_and_stats() {
let schema = Arc::new(GreptimeSchema::new(vec![
ColumnSchema::new(
"ts",
ConcreteDataType::timestamp_millisecond_datatype(),
false,
),
ColumnSchema::new("value", ConcreteDataType::int64_datatype(), true),
]));
let metadata = test_region_metadata();
let excluded_count = Arc::new(AtomicUsize::new(0));
let base_stats = RegionScanInputStats {
files: vec![
RegionScanFileInputStats {
file_ordinal: 0,
exact_num_rows: Some(3),
time_range: Some((test_timestamp(10), test_timestamp(20))),
field_stats: field_stats(Some(2), Some(Value::Int64(1)), Some(Value::Int64(3))),
partition_expr_matches_region: true,
},
RegionScanFileInputStats {
file_ordinal: 1,
exact_num_rows: Some(4),
time_range: Some((test_timestamp(30), test_timestamp(40))),
field_stats: HashMap::new(),
partition_expr_matches_region: true,
},
RegionScanFileInputStats {
file_ordinal: 2,
exact_num_rows: Some(5),
time_range: Some((test_timestamp(50), test_timestamp(60))),
field_stats: field_stats(Some(4), Some(Value::Int64(5)), Some(Value::Int64(9))),
partition_expr_matches_region: true,
},
],
};
let scanner = Box::new(StatsRecordingScanner::new(
schema.clone(),
metadata,
base_stats,
excluded_count.clone(),
));
let scan = Arc::new(RegionScanExec::new(scanner, ScanRequest::default(), None).unwrap());
let aggr_expr = Arc::new(build_test_aggr_expr(false, false, false).unwrap());
let plan: Arc<dyn ExecutionPlan> = Arc::new(
AggregateExec::try_new(
AggregateMode::Single,
PhysicalGroupBy::new_single(vec![]),
vec![aggr_expr],
vec![None],
scan,
schema.arrow_schema().clone(),
)
.unwrap(),
);
let optimized = AggregateStats::do_optimize(plan).unwrap();
let final_agg = optimized.as_any().downcast_ref::<AggregateExec>().unwrap();
assert_eq!(final_agg.mode(), &AggregateMode::Final);
let coalesce = final_agg
.input()
.as_any()
.downcast_ref::<CoalescePartitionsExec>()
.unwrap();
let union = coalesce
.input()
.as_any()
.downcast_ref::<UnionExec>()
.unwrap();
assert_eq!(union.inputs().len(), 2);
let partial_agg = union.inputs()[0]
.as_any()
.downcast_ref::<AggregateExec>()
.unwrap();
assert_eq!(partial_agg.mode(), &AggregateMode::Partial);
let partial_scan = partial_agg
.input()
.as_any()
.downcast_ref::<RegionScanExec>()
.unwrap();
let remaining = partial_scan.scan_input_stats().unwrap().unwrap();
assert_eq!(
remaining
.files
.iter()
.map(|file| file.file_ordinal)
.collect::<Vec<_>>(),
vec![1]
);
assert_eq!(excluded_count.load(Ordering::Relaxed), 2);
}
#[test]
fn test_optimizer_rewrites_final_partial_plan_by_replacing_partial_input() {
let schema = Arc::new(GreptimeSchema::new(vec![
ColumnSchema::new(
"ts",
ConcreteDataType::timestamp_millisecond_datatype(),
false,
),
ColumnSchema::new("value", ConcreteDataType::int64_datatype(), true),
]));
let metadata = test_region_metadata();
let excluded_count = Arc::new(AtomicUsize::new(0));
let base_stats = RegionScanInputStats {
files: vec![
RegionScanFileInputStats {
file_ordinal: 0,
exact_num_rows: Some(3),
time_range: Some((test_timestamp(10), test_timestamp(20))),
field_stats: field_stats(Some(2), Some(Value::Int64(1)), Some(Value::Int64(3))),
partition_expr_matches_region: true,
},
RegionScanFileInputStats {
file_ordinal: 1,
exact_num_rows: Some(4),
time_range: Some((test_timestamp(30), test_timestamp(40))),
field_stats: HashMap::new(),
partition_expr_matches_region: true,
},
RegionScanFileInputStats {
file_ordinal: 2,
exact_num_rows: Some(5),
time_range: Some((test_timestamp(50), test_timestamp(60))),
field_stats: field_stats(Some(4), Some(Value::Int64(5)), Some(Value::Int64(9))),
partition_expr_matches_region: true,
},
],
};
let scanner = Box::new(StatsRecordingScanner::new(
schema.clone(),
metadata,
base_stats,
excluded_count.clone(),
));
let scan = Arc::new(RegionScanExec::new(scanner, ScanRequest::default(), None).unwrap());
let aggr_expr = Arc::new(build_test_aggr_expr(false, false, false).unwrap());
let partial: Arc<dyn ExecutionPlan> = Arc::new(
AggregateExec::try_new(
AggregateMode::Partial,
PhysicalGroupBy::new_single(vec![]),
vec![aggr_expr.clone()],
vec![None],
scan,
schema.arrow_schema().clone(),
)
.unwrap(),
);
let plan: Arc<dyn ExecutionPlan> = Arc::new(
AggregateExec::try_new(
AggregateMode::Final,
PhysicalGroupBy::new_single(vec![]),
vec![aggr_expr],
vec![None],
Arc::new(CoalescePartitionsExec::new(partial.clone())),
partial.schema(),
)
.unwrap(),
);
let optimized = AggregateStats::do_optimize(plan).unwrap();
let final_agg = optimized.as_any().downcast_ref::<AggregateExec>().unwrap();
assert_eq!(final_agg.mode(), &AggregateMode::Final);
let coalesce = final_agg
.input()
.as_any()
.downcast_ref::<CoalescePartitionsExec>()
.unwrap();
let union = coalesce
.input()
.as_any()
.downcast_ref::<UnionExec>()
.unwrap();
assert_eq!(union.inputs().len(), 2);
let partial_agg = union.inputs()[0]
.as_any()
.downcast_ref::<AggregateExec>()
.unwrap();
assert_eq!(partial_agg.mode(), &AggregateMode::Partial);
let partial_scan = partial_agg
.input()
.as_any()
.downcast_ref::<RegionScanExec>()
.unwrap();
let remaining = partial_scan.scan_input_stats().unwrap().unwrap();
assert_eq!(
remaining
.files
.iter()
.map(|file| file.file_ordinal)
.collect::<Vec<_>>(),
vec![1]
);
assert_eq!(excluded_count.load(Ordering::Relaxed), 2);
}
#[derive(Clone, Copy)]
enum ExecutionAggExprCase {
CountValue { ignore_nulls: bool },
CountStar,
MinValue { ignore_nulls: bool },
MaxValue { ignore_nulls: bool },
}
impl ExecutionAggExprCase {
fn build(self) -> AggregateFunctionExpr {
match self {
ExecutionAggExprCase::CountValue { ignore_nulls } => {
build_test_aggr_expr(false, ignore_nulls, false).unwrap()
}
ExecutionAggExprCase::CountStar => build_count_star_aggr_expr().unwrap(),
ExecutionAggExprCase::MinValue { ignore_nulls } => {
build_min_field_aggr_expr_with_ignore_nulls(ignore_nulls).unwrap()
}
ExecutionAggExprCase::MaxValue { ignore_nulls } => {
build_max_field_aggr_expr_with_ignore_nulls(ignore_nulls).unwrap()
}
}
}
}
fn execution_test_schema() -> datatypes::schema::SchemaRef {
Arc::new(GreptimeSchema::new(vec![
ColumnSchema::new(
"ts",
ConcreteDataType::timestamp_millisecond_datatype(),
false,
)
.with_time_index(true),
ColumnSchema::new("value", ConcreteDataType::int64_datatype(), true),
]))
}
fn execution_test_stats() -> RegionScanInputStats {
RegionScanInputStats {
files: vec![
RegionScanFileInputStats {
file_ordinal: 0,
exact_num_rows: Some(3),
time_range: Some((test_timestamp(10), test_timestamp(12))),
field_stats: field_stats(Some(2), Some(Value::Int64(1)), Some(Value::Int64(3))),
partition_expr_matches_region: true,
},
RegionScanFileInputStats {
file_ordinal: 1,
exact_num_rows: Some(4),
time_range: Some((test_timestamp(30), test_timestamp(33))),
field_stats: HashMap::new(),
partition_expr_matches_region: true,
},
RegionScanFileInputStats {
file_ordinal: 2,
exact_num_rows: Some(5),
time_range: Some((test_timestamp(50), test_timestamp(54))),
field_stats: field_stats(Some(4), Some(Value::Int64(7)), Some(Value::Int64(10))),
partition_expr_matches_region: true,
},
],
}
}
fn execution_test_file_batches(
schema: datatypes::schema::SchemaRef,
) -> Vec<(usize, CommonRecordBatch)> {
vec![
(
0,
scan_batch(
schema.clone(),
vec![10, 11, 12],
vec![Some(1), None, Some(3)],
),
),
(
1,
scan_batch(
schema.clone(),
vec![30, 31, 32, 33],
vec![Some(4), Some(5), None, Some(6)],
),
),
(
2,
scan_batch(
schema,
vec![50, 51, 52, 53, 54],
vec![Some(7), Some(8), Some(9), Some(10), None],
),
),
]
}
fn build_execution_test_plan(
schema: datatypes::schema::SchemaRef,
metadata: RegionMetadataRef,
base_stats: RegionScanInputStats,
file_batches: Vec<(usize, CommonRecordBatch)>,
aggr_expr: AggregateFunctionExpr,
excluded_count: Arc<AtomicUsize>,
) -> Arc<dyn ExecutionPlan> {
let scanner = Box::new(
StatsRecordingScanner::new(schema.clone(), metadata, base_stats, excluded_count)
.with_file_batches(file_batches),
);
let scan = Arc::new(RegionScanExec::new(scanner, ScanRequest::default(), None).unwrap());
let aggr_expr = Arc::new(aggr_expr);
let partial: Arc<dyn ExecutionPlan> = Arc::new(
AggregateExec::try_new(
AggregateMode::Partial,
PhysicalGroupBy::new_single(vec![]),
vec![aggr_expr.clone()],
vec![None],
scan,
schema.arrow_schema().clone(),
)
.unwrap(),
);
Arc::new(
AggregateExec::try_new(
AggregateMode::Final,
PhysicalGroupBy::new_single(vec![]),
vec![aggr_expr],
vec![None],
Arc::new(CoalescePartitionsExec::new(partial.clone())),
partial.schema(),
)
.unwrap(),
)
}
fn optimized_plan_uses_stats_union(plan: &Arc<dyn ExecutionPlan>) -> bool {
let Some(final_agg) = plan.as_any().downcast_ref::<AggregateExec>() else {
return false;
};
let input = final_agg.input();
if let Some(coalesce) = input.as_any().downcast_ref::<CoalescePartitionsExec>() {
return coalesce
.input()
.as_any()
.downcast_ref::<UnionExec>()
.is_some();
}
input.as_any().downcast_ref::<UnionExec>().is_some()
}
#[tokio::test]
async fn test_optimizer_execution_matrix() {
struct Case {
name: &'static str,
expr: ExecutionAggExprCase,
expect_rewrite: bool,
expected_excluded_count: usize,
expected: String,
}
let cases = [
Case {
name: "count value",
expr: ExecutionAggExprCase::CountValue {
ignore_nulls: false,
},
expect_rewrite: true,
expected_excluded_count: 2,
expected: [
"+--------------+",
"| count(value) |",
"+--------------+",
"| 9 |",
"+--------------+",
]
.join("\n"),
},
Case {
name: "count value ignore nulls",
expr: ExecutionAggExprCase::CountValue { ignore_nulls: true },
expect_rewrite: true,
expected_excluded_count: 2,
expected: [
"+--------------+",
"| count(value) |",
"+--------------+",
"| 9 |",
"+--------------+",
]
.join("\n"),
},
Case {
name: "count star",
expr: ExecutionAggExprCase::CountStar,
expect_rewrite: true,
expected_excluded_count: 3,
expected: [
"+----------+",
"| count(*) |",
"+----------+",
"| 12 |",
"+----------+",
]
.join("\n"),
},
Case {
name: "min value",
expr: ExecutionAggExprCase::MinValue {
ignore_nulls: false,
},
expect_rewrite: true,
expected_excluded_count: 2,
expected: [
"+------------+",
"| min(value) |",
"+------------+",
"| 1 |",
"+------------+",
]
.join("\n"),
},
Case {
name: "min value ignore nulls",
expr: ExecutionAggExprCase::MinValue { ignore_nulls: true },
expect_rewrite: true,
expected_excluded_count: 2,
expected: [
"+------------+",
"| min(value) |",
"+------------+",
"| 1 |",
"+------------+",
]
.join("\n"),
},
Case {
name: "max value",
expr: ExecutionAggExprCase::MaxValue {
ignore_nulls: false,
},
expect_rewrite: true,
expected_excluded_count: 2,
expected: [
"+------------+",
"| max(value) |",
"+------------+",
"| 10 |",
"+------------+",
]
.join("\n"),
},
Case {
name: "max value ignore nulls",
expr: ExecutionAggExprCase::MaxValue { ignore_nulls: true },
expect_rewrite: true,
expected_excluded_count: 2,
expected: [
"+------------+",
"| max(value) |",
"+------------+",
"| 10 |",
"+------------+",
]
.join("\n"),
},
];
let schema = execution_test_schema();
let metadata = test_region_metadata();
let base_stats = execution_test_stats();
let file_batches = execution_test_file_batches(schema.clone());
for case in cases {
let unoptimized = build_execution_test_plan(
schema.clone(),
metadata.clone(),
base_stats.clone(),
file_batches.clone(),
case.expr.build(),
Arc::new(AtomicUsize::new(0)),
);
let unoptimized_result =
datafusion::physical_plan::collect(unoptimized, SessionContext::default().task_ctx())
.await
.unwrap();
let optimized_excluded_count = Arc::new(AtomicUsize::new(0));
let optimized = AggregateStats::do_optimize(build_execution_test_plan(
schema.clone(),
metadata.clone(),
base_stats.clone(),
file_batches.clone(),
case.expr.build(),
optimized_excluded_count.clone(),
))
.unwrap();
let optimized_result = datafusion::physical_plan::collect(
optimized.clone(),
SessionContext::default().task_ctx(),
)
.await
.unwrap();
let unoptimized_pretty = arrow::util::pretty::pretty_format_batches(&unoptimized_result)
.unwrap()
.to_string();
let optimized_pretty = arrow::util::pretty::pretty_format_batches(&optimized_result)
.unwrap()
.to_string();
assert_eq!(
unoptimized_pretty,
case.expected.as_str(),
"case: {}",
case.name
);
assert_eq!(
optimized_pretty,
case.expected.as_str(),
"case: {}",
case.name
);
assert_eq!(
optimized_plan_uses_stats_union(&optimized),
case.expect_rewrite,
"case: {}",
case.name
);
assert_eq!(
optimized_excluded_count.load(Ordering::Relaxed),
case.expected_excluded_count,
"case: {}",
case.name
);
}
}