refactor: per review

Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
discord9
2026-05-15 13:27:22 +08:00
parent 3094a7868d
commit c3f378dc68

View File

@@ -28,7 +28,7 @@ use datafusion::sql::unparser::Unparser;
use datafusion_common::tree_node::{
Transformed, TreeNode as _, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
};
use datafusion_common::{DFSchema, DataFusionError, ScalarValue, TableReference};
use datafusion_common::{DFSchema, DataFusionError, NullEquality, ScalarValue, TableReference};
use datafusion_expr::logical_plan::TableScan;
use datafusion_expr::{
Distinct, JoinType, LogicalPlan, LogicalPlanBuilder, Operator, Projection, and, binary_expr,
@@ -105,6 +105,23 @@ impl TreeNodeVisitor<'_> for LastAggregateExprFinder {
}
}
/// Recursively find all `Expr::Column` names inside an expression tree.
/// Only recurses into wrappers that are merge-transparent (type casts).
/// Non-transparent wrappers (e.g., `ScalarFunction`, `Negative`) are
/// intentionally not recursed into since their merge semantics would be
/// incorrect — the caller will fall back to the raw aggregate name.
fn find_column_names(expr: &Expr, names: &mut Vec<String>) {
match expr {
Expr::Column(col) => {
names.push(col.name.clone());
}
Expr::Alias(alias) => find_column_names(&alias.expr, names),
Expr::Cast(cast) => find_column_names(&cast.expr, names),
Expr::TryCast(try_cast) => find_column_names(&try_cast.expr, names),
_ => {}
}
}
pub fn analyze_incremental_aggregate_plan(
plan: &LogicalPlan,
) -> Result<Option<IncrementalAggregateAnalysis>, Error> {
@@ -128,12 +145,26 @@ pub fn analyze_incremental_aggregate_plan(
for expr in &projection.expr {
match expr {
Expr::Alias(alias) => {
if let Expr::Column(col) = alias.expr.as_ref() {
output_aliases.insert(col.name.clone(), alias.name.clone());
// Alias resolution has three cases:
// - 0 Column refs (e.g., literal `42 AS lit`): skip — no mapping
// - 1 Column ref: record the mapping (e.g., `CAST(sum(x)) AS total`)
// - >1 Column refs (e.g., `COALESCE(sum(x), sum(y))`):
// skip — ambiguous merge semantics, fall back to raw agg name
let alias_name = alias.name.clone();
let mut col_names = Vec::new();
find_column_names(&alias.expr, &mut col_names);
if col_names.len() == 1 {
if let Some(col_name) = col_names.into_iter().next() {
output_aliases.entry(col_name).or_insert(alias_name);
}
}
// If >1 column references detected (e.g., COALESCE(sum(x), sum(y))),
// intentionally skip alias mapping — the merge semantics are ambiguous.
}
Expr::Column(col) => {
output_aliases.insert(col.name.clone(), col.name.clone());
output_aliases
.entry(col.name.clone())
.or_insert(col.name.clone());
}
_ => {}
}
@@ -178,7 +209,10 @@ pub fn analyze_incremental_aggregate_plan(
};
let raw_name = aggr_expr.qualified_name().1;
let output_field_name = output_aliases.get(&raw_name).cloned().unwrap_or(raw_name);
let Some(output_field_name) = output_aliases.get(&raw_name).cloned() else {
unsupported_exprs.push(aggr_expr.to_string());
continue;
};
merge_columns.push(IncrementalAggregateMergeColumn {
output_field_name,
merge_op,
@@ -290,7 +324,13 @@ pub async fn rewrite_incremental_aggregate_with_sink_merge(
);
let joined = LogicalPlanBuilder::from(delta_selected)
.join(sink_selected, JoinType::Left, join_keys, None)
.join_detailed(
sink_selected,
JoinType::Left,
join_keys,
None,
NullEquality::NullEqualsNull,
)
.with_context(|_| DatafusionSnafu {
context: "Failed to left join delta and sink plans for incremental sink merge"
.to_string(),
@@ -1340,6 +1380,28 @@ mod test {
assert!(!analysis.unsupported_exprs.is_empty());
}
#[tokio::test]
async fn test_analyze_incremental_aggregate_plan_rejects_coalesce_wrapped_aggregate() {
let query_engine = create_test_query_engine();
let ctx = QueryContext::arc();
// COALESCE wraps the aggregate output — the wrapper is not merge-transparent,
// so the analyzer should mark the aggregate as unsupported rather than
// attempting an unsafe incremental rewrite.
let sql =
"SELECT COALESCE(max(number), 0) AS coalesced_max, ts FROM numbers_with_ts GROUP BY ts";
let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap();
let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap();
// Non-transparent wrapper → alias unresolvable → unsupported
assert!(
!analysis.unsupported_exprs.is_empty(),
"COALESCE-wrapped aggregate should be unsupported"
);
assert!(
analysis.merge_columns.is_empty(),
"COALESCE-wrapped aggregate should have no merge columns"
);
}
#[tokio::test]
async fn test_rewrite_incremental_aggregate_with_left_join() {
let query_engine = create_test_query_engine();
@@ -1391,4 +1453,27 @@ mod test {
.encode(&plan, DefaultSerializer)
.unwrap();
}
#[tokio::test]
async fn test_analyze_incremental_aggregate_plan_handles_cast_wrapped_alias() {
let query_engine = create_test_query_engine();
let ctx = QueryContext::arc();
// CAST wraps the aggregate output — the analyzer should still find the alias
let sql =
"SELECT CAST(sum(number) AS BIGINT) AS total, ts FROM numbers_with_ts GROUP BY ts";
let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap();
let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap();
assert!(analysis.unsupported_exprs.is_empty());
assert!(analysis.group_key_names.contains(&"ts".to_string()));
assert_eq!(analysis.merge_columns.len(), 1);
assert_eq!(
analysis.merge_columns[0].output_field_name, "total",
"Expected alias 'total' for CAST-wrapped aggregate, but got '{}'",
analysis.merge_columns[0].output_field_name
);
assert_eq!(
analysis.merge_columns[0].merge_op,
IncrementalAggregateMergeOp::Sum
);
}
}