diff --git a/src/flow/src/batching_mode/utils.rs b/src/flow/src/batching_mode/utils.rs index dfbadbfc72..70d737732a 100644 --- a/src/flow/src/batching_mode/utils.rs +++ b/src/flow/src/batching_mode/utils.rs @@ -19,15 +19,21 @@ use std::sync::Arc; use catalog::CatalogManagerRef; use common_error::ext::BoxedError; +use common_function::aggrs::aggr_wrapper::get_aggr_func; use common_telemetry::debug; +use datafusion::datasource::DefaultTableSource; use datafusion::error::Result as DfResult; use datafusion::logical_expr::Expr; use datafusion::sql::unparser::Unparser; use datafusion_common::tree_node::{ Transformed, TreeNode as _, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; -use datafusion_common::{DFSchema, DataFusionError, ScalarValue}; -use datafusion_expr::{Distinct, LogicalPlan, Projection}; +use datafusion_common::{DFSchema, DataFusionError, ScalarValue, TableReference}; +use datafusion_expr::logical_plan::TableScan; +use datafusion_expr::{ + Distinct, JoinType, LogicalPlan, LogicalPlanBuilder, Operator, Projection, and, binary_expr, + bitwise_and, bitwise_or, bitwise_xor, col, is_null, or, when, +}; use datatypes::schema::{ColumnSchema, SchemaRef}; use query::QueryEngineRef; use query::parser::{DEFAULT_LOOKBACK_STRING, PromQuery, QueryLanguageParser, QueryStatement}; @@ -37,12 +43,330 @@ use sql::parser::{ParseOptions, ParserContext}; use sql::statements::statement::Statement; use sql::statements::tql::Tql; use table::TableRef; +use table::table::adapter::DfTableProviderAdapter; use crate::adapter::{AUTO_CREATED_PLACEHOLDER_TS_COL, AUTO_CREATED_UPDATE_AT_TS_COL}; use crate::df_optimizer::apply_df_optimizer; use crate::error::{DatafusionSnafu, ExternalSnafu, InvalidQuerySnafu, TableNotFoundSnafu}; use crate::{Error, TableName}; +/// Describes how one aggregate output field should be merged with the +/// corresponding existing field in the sink table. +/// +/// `output_field_name` is the final output/sink schema field name produced by +/// the delta plan and read from the sink table. It is not a DataFusion `Column` +/// reference and must not include a plan/table qualifier. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct IncrementalAggregateMergeColumn { + /// Final output/sink field name for the aggregate result/state column. + pub output_field_name: String, + pub merge_op: IncrementalAggregateMergeOp, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum IncrementalAggregateMergeOp { + Sum, + Min, + Max, + BoolAnd, + BoolOr, + BitAnd, + BitOr, + BitXor, +} + +/// Analysis result for an incremental aggregate plan. +/// +/// `group_key_names` and each merge column's `output_field_name` are final +/// output/sink schema field names used to project both the delta plan and the +/// sink table before the left-join merge. They are not DataFusion logical-plan +/// `Column` references and must not be qualified. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct IncrementalAggregateAnalysis { + /// Final output/sink field names for group keys used as merge join keys. + pub group_key_names: Vec, + pub merge_columns: Vec, + pub unsupported_exprs: Vec, +} + +#[derive(Default)] +struct LastAggregateExprFinder { + aggr_exprs: Option>, +} + +impl TreeNodeVisitor<'_> for LastAggregateExprFinder { + type Node = LogicalPlan; + + fn f_down(&mut self, node: &Self::Node) -> datafusion_common::Result { + if let LogicalPlan::Aggregate(aggregate) = node { + self.aggr_exprs = Some(aggregate.aggr_expr.clone()); + } + Ok(TreeNodeRecursion::Continue) + } +} + +pub fn analyze_incremental_aggregate_plan( + plan: &LogicalPlan, +) -> Result, Error> { + let mut group_finder = FindGroupByFinalName::default(); + plan.visit(&mut group_finder) + .with_context(|_| DatafusionSnafu { + context: format!("Failed to inspect group-by columns from logical plan: {plan:?}"), + })?; + + let mut aggregate_finder = LastAggregateExprFinder::default(); + plan.visit(&mut aggregate_finder) + .with_context(|_| DatafusionSnafu { + context: format!("Failed to inspect aggregate expressions from logical plan: {plan:?}"), + })?; + let Some(aggr_exprs) = aggregate_finder.aggr_exprs else { + return Ok(None); + }; + + let mut output_aliases = HashMap::new(); + if let LogicalPlan::Projection(projection) = plan { + for expr in &projection.expr { + match expr { + Expr::Alias(alias) => { + if let Expr::Column(col) = alias.expr.as_ref() { + output_aliases.insert(col.name.clone(), alias.name.clone()); + } + } + Expr::Column(col) => { + output_aliases.insert(col.name.clone(), col.name.clone()); + } + _ => {} + } + } + } + + let mut group_key_names = group_finder + .get_group_expr_names() + .unwrap_or_default() + .into_iter() + .collect::>(); + group_key_names.sort(); + + let mut merge_columns = Vec::with_capacity(aggr_exprs.len()); + let mut unsupported_exprs = Vec::new(); + for aggr_expr in aggr_exprs { + let Some(aggr_func) = get_aggr_func(&aggr_expr) else { + unsupported_exprs.push(aggr_expr.to_string()); + continue; + }; + + let aggr_name = aggr_func.func.name().to_ascii_lowercase(); + let merge_op = if aggr_func.params.distinct { + None + } else { + match aggr_name.as_str() { + "sum" | "count" => Some(IncrementalAggregateMergeOp::Sum), + "min" => Some(IncrementalAggregateMergeOp::Min), + "max" => Some(IncrementalAggregateMergeOp::Max), + "bool_and" => Some(IncrementalAggregateMergeOp::BoolAnd), + "bool_or" => Some(IncrementalAggregateMergeOp::BoolOr), + "bit_and" => Some(IncrementalAggregateMergeOp::BitAnd), + "bit_or" => Some(IncrementalAggregateMergeOp::BitOr), + "bit_xor" => Some(IncrementalAggregateMergeOp::BitXor), + _ => None, + } + }; + + let Some(merge_op) = merge_op else { + unsupported_exprs.push(aggr_expr.to_string()); + continue; + }; + + let raw_name = aggr_expr.qualified_name().1; + let output_field_name = output_aliases.get(&raw_name).cloned().unwrap_or(raw_name); + merge_columns.push(IncrementalAggregateMergeColumn { + output_field_name, + merge_op, + }); + } + + Ok(Some(IncrementalAggregateAnalysis { + group_key_names, + merge_columns, + unsupported_exprs, + })) +} + +pub async fn rewrite_incremental_aggregate_with_sink_merge( + delta_plan: &LogicalPlan, + analysis: &IncrementalAggregateAnalysis, + sink_table: TableRef, + sink_table_name: &TableName, +) -> Result { + ensure!( + analysis.unsupported_exprs.is_empty(), + InvalidQuerySnafu { + reason: format!( + "UNSUPPORTED_INCREMENTAL_AGG: unsupported aggregate expressions {:?}", + analysis.unsupported_exprs + ) + } + ); + + ensure!( + !analysis.merge_columns.is_empty(), + InvalidQuerySnafu { + reason: + "UNSUPPORTED_INCREMENTAL_AGG: aggregate query has no mergeable aggregate columns" + .to_string() + } + ); + + let delta_alias = "__flow_delta"; + let sink_alias = "__flow_sink"; + + let mut selected_columns = analysis.group_key_names.clone(); + selected_columns.extend( + analysis + .merge_columns + .iter() + .map(|c| c.output_field_name.clone()), + ); + + let selected_exprs = selected_columns.iter().map(col).collect::>(); + let delta_selected = LogicalPlanBuilder::from(delta_plan.clone()) + .project(selected_exprs.clone()) + .with_context(|_| DatafusionSnafu { + context: "Failed to project delta plan for incremental sink merge".to_string(), + })? + .alias(delta_alias) + .with_context(|_| DatafusionSnafu { + context: "Failed to alias delta plan for incremental sink merge".to_string(), + })? + .build() + .with_context(|_| DatafusionSnafu { + context: "Failed to build projected delta plan for incremental sink merge".to_string(), + })?; + + let table_provider = Arc::new(DfTableProviderAdapter::new(sink_table)); + let table_source = Arc::new(DefaultTableSource::new(table_provider)); + let sink_scan = LogicalPlan::TableScan( + TableScan::try_new( + TableReference::Full { + catalog: sink_table_name[0].clone().into(), + schema: sink_table_name[1].clone().into(), + table: sink_table_name[2].clone().into(), + }, + table_source, + None, + vec![], + None, + ) + .with_context(|_| DatafusionSnafu { + context: "Failed to build sink table scan for incremental sink merge".to_string(), + })?, + ); + + let sink_selected = LogicalPlanBuilder::from(sink_scan) + .project(selected_exprs) + .with_context(|_| DatafusionSnafu { + context: "Failed to project sink table scan for incremental sink merge".to_string(), + })? + .alias(sink_alias) + .with_context(|_| DatafusionSnafu { + context: "Failed to alias sink plan for incremental sink merge".to_string(), + })? + .build() + .with_context(|_| DatafusionSnafu { + context: "Failed to build projected sink plan for incremental sink merge".to_string(), + })?; + + let join_keys = ( + analysis + .group_key_names + .iter() + .map(|c| datafusion_common::Column::from_qualified_name(format!("{delta_alias}.{c}"))) + .collect::>(), + analysis + .group_key_names + .iter() + .map(|c| datafusion_common::Column::from_qualified_name(format!("{sink_alias}.{c}"))) + .collect::>(), + ); + + let joined = LogicalPlanBuilder::from(delta_selected) + .join(sink_selected, JoinType::Left, join_keys, None) + .with_context(|_| DatafusionSnafu { + context: "Failed to left join delta and sink plans for incremental sink merge" + .to_string(), + })? + .build() + .with_context(|_| DatafusionSnafu { + context: "Failed to build left join plan for incremental sink merge".to_string(), + })?; + + let mut projection_exprs = analysis + .group_key_names + .iter() + .map(|c| col(format!("{delta_alias}.{c}")).alias(c.clone())) + .collect::>(); + projection_exprs.extend( + analysis + .merge_columns + .iter() + .map(|merge_col| build_left_join_merge_expr(delta_alias, sink_alias, merge_col)), + ); + + LogicalPlanBuilder::from(joined) + .project(projection_exprs) + .with_context(|_| DatafusionSnafu { + context: "Failed to build projection merge plan for incremental sink merge".to_string(), + })? + .build() + .with_context(|_| DatafusionSnafu { + context: "Failed to finalize incremental aggregate sink merge plan".to_string(), + }) +} + +fn build_left_join_merge_expr( + delta_alias: &str, + sink_alias: &str, + merge_col: &IncrementalAggregateMergeColumn, +) -> Expr { + let left = col(format!("{delta_alias}.{}", merge_col.output_field_name)); + let right = col(format!("{sink_alias}.{}", merge_col.output_field_name)); + let merged = match merge_col.merge_op { + IncrementalAggregateMergeOp::Sum => when(is_null(left.clone()), right.clone()) + .when(is_null(right.clone()), left.clone()) + .otherwise(binary_expr(left.clone(), Operator::Plus, right.clone())) + .unwrap(), + IncrementalAggregateMergeOp::Min => when(is_null(right.clone()), left.clone()) + .when(left.clone().lt_eq(right.clone()), left.clone()) + .otherwise(right.clone()) + .unwrap(), + IncrementalAggregateMergeOp::Max => when(is_null(right.clone()), left.clone()) + .when(left.clone().gt_eq(right.clone()), left.clone()) + .otherwise(right.clone()) + .unwrap(), + IncrementalAggregateMergeOp::BoolAnd => when(is_null(left.clone()), right.clone()) + .when(is_null(right.clone()), left.clone()) + .otherwise(and(left.clone(), right.clone())) + .unwrap(), + IncrementalAggregateMergeOp::BoolOr => when(is_null(left.clone()), right.clone()) + .when(is_null(right.clone()), left.clone()) + .otherwise(or(left.clone(), right.clone())) + .unwrap(), + IncrementalAggregateMergeOp::BitAnd => when(is_null(left.clone()), right.clone()) + .when(is_null(right.clone()), left.clone()) + .otherwise(bitwise_and(left.clone(), right.clone())) + .unwrap(), + IncrementalAggregateMergeOp::BitOr => when(is_null(left.clone()), right.clone()) + .when(is_null(right.clone()), left.clone()) + .otherwise(bitwise_or(left.clone(), right.clone())) + .unwrap(), + IncrementalAggregateMergeOp::BitXor => when(is_null(left.clone()), right.clone()) + .when(is_null(right.clone()), left.clone()) + .otherwise(bitwise_xor(left.clone(), right.clone())) + .unwrap(), + }; + merged.alias(merge_col.output_field_name.clone()) +} + pub async fn get_table_info_df_schema( catalog_mr: CatalogManagerRef, table_name: TableName, @@ -907,6 +1231,136 @@ mod test { } } + #[tokio::test] + async fn test_analyze_incremental_aggregate_plan() { + let query_engine = create_test_query_engine(); + let ctx = QueryContext::arc(); + let testcases = vec![ + ( + "SELECT sum(number) AS number, ts FROM numbers_with_ts GROUP BY ts", + IncrementalAggregateMergeOp::Sum, + ), + ( + "SELECT count(number) AS number, ts FROM numbers_with_ts GROUP BY ts", + IncrementalAggregateMergeOp::Sum, + ), + ( + "SELECT min(number) AS number, ts FROM numbers_with_ts GROUP BY ts", + IncrementalAggregateMergeOp::Min, + ), + ( + "SELECT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts", + IncrementalAggregateMergeOp::Max, + ), + ( + "SELECT bit_and(number) AS number, ts FROM numbers_with_ts GROUP BY ts", + IncrementalAggregateMergeOp::BitAnd, + ), + ( + "SELECT bit_or(number) AS number, ts FROM numbers_with_ts GROUP BY ts", + IncrementalAggregateMergeOp::BitOr, + ), + ( + "SELECT bit_xor(number) AS number, ts FROM numbers_with_ts GROUP BY ts", + IncrementalAggregateMergeOp::BitXor, + ), + ]; + + for (sql, expected_op) in testcases { + let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false) + .await + .unwrap(); + + let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); + assert!(analysis.unsupported_exprs.is_empty()); + assert!(analysis.group_key_names.contains(&"ts".to_string())); + assert_eq!(analysis.merge_columns.len(), 1); + assert_eq!(analysis.merge_columns[0].output_field_name, "number"); + assert_eq!(analysis.merge_columns[0].merge_op, expected_op); + } + } + + #[tokio::test] + async fn test_analyze_incremental_aggregate_plan_keeps_aliases_for_multiple_aggregates() { + let query_engine = create_test_query_engine(); + let ctx = QueryContext::arc(); + let sql = "SELECT max(number) AS max_number, min(number) AS min_number, ts FROM numbers_with_ts GROUP BY ts"; + let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); + + let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); + assert!(analysis.unsupported_exprs.is_empty()); + assert!(analysis.group_key_names.contains(&"ts".to_string())); + assert_eq!(analysis.merge_columns.len(), 2); + assert!(analysis.merge_columns.iter().any(|merge_col| { + merge_col.output_field_name == "max_number" + && merge_col.merge_op == IncrementalAggregateMergeOp::Max + })); + assert!(analysis.merge_columns.iter().any(|merge_col| { + merge_col.output_field_name == "min_number" + && merge_col.merge_op == IncrementalAggregateMergeOp::Min + })); + } + + #[tokio::test] + async fn test_analyze_incremental_aggregate_plan_rejects_avg() { + let query_engine = create_test_query_engine(); + let ctx = QueryContext::arc(); + let sql = "SELECT avg(number) AS avg_num, ts FROM numbers_with_ts GROUP BY ts"; + let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); + + let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); + assert!(!analysis.unsupported_exprs.is_empty()); + } + + #[tokio::test] + async fn test_analyze_incremental_aggregate_plan_rejects_distinct() { + let query_engine = create_test_query_engine(); + let ctx = QueryContext::arc(); + let sql = "SELECT count(distinct number) AS number, ts FROM numbers_with_ts GROUP BY ts"; + let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); + + let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); + assert!(!analysis.unsupported_exprs.is_empty()); + } + + #[tokio::test] + async fn test_rewrite_incremental_aggregate_with_left_join() { + let query_engine = create_test_query_engine(); + let ctx = QueryContext::arc(); + let sql = "SELECT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts"; + let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false) + .await + .unwrap(); + let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); + let (sink_table, _) = get_table_info_df_schema( + query_engine.engine_state().catalog_manager().clone(), + [ + "greptime".to_string(), + "public".to_string(), + "numbers_with_ts".to_string(), + ], + ) + .await + .unwrap(); + + let rewritten = rewrite_incremental_aggregate_with_sink_merge( + &plan, + &analysis, + sink_table, + &[ + "greptime".to_string(), + "public".to_string(), + "numbers_with_ts".to_string(), + ], + ) + .await + .unwrap(); + + let plan_text = format!("{}", rewritten.display_indent()); + assert!(plan_text.contains("Left Join")); + assert!(!plan_text.contains("Union")); + } + #[tokio::test] async fn test_null_cast() { let query_engine = create_test_query_engine();