mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-19 22:40:40 +00:00
@@ -28,7 +28,7 @@ use datafusion::sql::unparser::Unparser;
|
||||
use datafusion_common::tree_node::{
|
||||
Transformed, TreeNode as _, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
|
||||
};
|
||||
use datafusion_common::{DFSchema, DataFusionError, ScalarValue, TableReference};
|
||||
use datafusion_common::{DFSchema, DataFusionError, NullEquality, ScalarValue, TableReference};
|
||||
use datafusion_expr::logical_plan::TableScan;
|
||||
use datafusion_expr::{
|
||||
Distinct, JoinType, LogicalPlan, LogicalPlanBuilder, Operator, Projection, and, binary_expr,
|
||||
@@ -105,6 +105,23 @@ impl TreeNodeVisitor<'_> for LastAggregateExprFinder {
|
||||
}
|
||||
}
|
||||
|
||||
/// Recursively find all `Expr::Column` names inside an expression tree.
|
||||
/// Only recurses into wrappers that are merge-transparent (type casts).
|
||||
/// Non-transparent wrappers (e.g., `ScalarFunction`, `Negative`) are
|
||||
/// intentionally not recursed into since their merge semantics would be
|
||||
/// incorrect — the caller will fall back to the raw aggregate name.
|
||||
fn find_column_names(expr: &Expr, names: &mut Vec<String>) {
|
||||
match expr {
|
||||
Expr::Column(col) => {
|
||||
names.push(col.name.clone());
|
||||
}
|
||||
Expr::Alias(alias) => find_column_names(&alias.expr, names),
|
||||
Expr::Cast(cast) => find_column_names(&cast.expr, names),
|
||||
Expr::TryCast(try_cast) => find_column_names(&try_cast.expr, names),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn analyze_incremental_aggregate_plan(
|
||||
plan: &LogicalPlan,
|
||||
) -> Result<Option<IncrementalAggregateAnalysis>, Error> {
|
||||
@@ -128,12 +145,26 @@ pub fn analyze_incremental_aggregate_plan(
|
||||
for expr in &projection.expr {
|
||||
match expr {
|
||||
Expr::Alias(alias) => {
|
||||
if let Expr::Column(col) = alias.expr.as_ref() {
|
||||
output_aliases.insert(col.name.clone(), alias.name.clone());
|
||||
// Alias resolution has three cases:
|
||||
// - 0 Column refs (e.g., literal `42 AS lit`): skip — no mapping
|
||||
// - 1 Column ref: record the mapping (e.g., `CAST(sum(x)) AS total`)
|
||||
// - >1 Column refs (e.g., `COALESCE(sum(x), sum(y))`):
|
||||
// skip — ambiguous merge semantics, fall back to raw agg name
|
||||
let alias_name = alias.name.clone();
|
||||
let mut col_names = Vec::new();
|
||||
find_column_names(&alias.expr, &mut col_names);
|
||||
if col_names.len() == 1 {
|
||||
if let Some(col_name) = col_names.into_iter().next() {
|
||||
output_aliases.entry(col_name).or_insert(alias_name);
|
||||
}
|
||||
}
|
||||
// If >1 column references detected (e.g., COALESCE(sum(x), sum(y))),
|
||||
// intentionally skip alias mapping — the merge semantics are ambiguous.
|
||||
}
|
||||
Expr::Column(col) => {
|
||||
output_aliases.insert(col.name.clone(), col.name.clone());
|
||||
output_aliases
|
||||
.entry(col.name.clone())
|
||||
.or_insert(col.name.clone());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
@@ -178,7 +209,10 @@ pub fn analyze_incremental_aggregate_plan(
|
||||
};
|
||||
|
||||
let raw_name = aggr_expr.qualified_name().1;
|
||||
let output_field_name = output_aliases.get(&raw_name).cloned().unwrap_or(raw_name);
|
||||
let Some(output_field_name) = output_aliases.get(&raw_name).cloned() else {
|
||||
unsupported_exprs.push(aggr_expr.to_string());
|
||||
continue;
|
||||
};
|
||||
merge_columns.push(IncrementalAggregateMergeColumn {
|
||||
output_field_name,
|
||||
merge_op,
|
||||
@@ -290,7 +324,13 @@ pub async fn rewrite_incremental_aggregate_with_sink_merge(
|
||||
);
|
||||
|
||||
let joined = LogicalPlanBuilder::from(delta_selected)
|
||||
.join(sink_selected, JoinType::Left, join_keys, None)
|
||||
.join_detailed(
|
||||
sink_selected,
|
||||
JoinType::Left,
|
||||
join_keys,
|
||||
None,
|
||||
NullEquality::NullEqualsNull,
|
||||
)
|
||||
.with_context(|_| DatafusionSnafu {
|
||||
context: "Failed to left join delta and sink plans for incremental sink merge"
|
||||
.to_string(),
|
||||
@@ -1340,6 +1380,28 @@ mod test {
|
||||
assert!(!analysis.unsupported_exprs.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_analyze_incremental_aggregate_plan_rejects_coalesce_wrapped_aggregate() {
|
||||
let query_engine = create_test_query_engine();
|
||||
let ctx = QueryContext::arc();
|
||||
// COALESCE wraps the aggregate output — the wrapper is not merge-transparent,
|
||||
// so the analyzer should mark the aggregate as unsupported rather than
|
||||
// attempting an unsafe incremental rewrite.
|
||||
let sql =
|
||||
"SELECT COALESCE(max(number), 0) AS coalesced_max, ts FROM numbers_with_ts GROUP BY ts";
|
||||
let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap();
|
||||
let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap();
|
||||
// Non-transparent wrapper → alias unresolvable → unsupported
|
||||
assert!(
|
||||
!analysis.unsupported_exprs.is_empty(),
|
||||
"COALESCE-wrapped aggregate should be unsupported"
|
||||
);
|
||||
assert!(
|
||||
analysis.merge_columns.is_empty(),
|
||||
"COALESCE-wrapped aggregate should have no merge columns"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rewrite_incremental_aggregate_with_left_join() {
|
||||
let query_engine = create_test_query_engine();
|
||||
@@ -1391,4 +1453,27 @@ mod test {
|
||||
.encode(&plan, DefaultSerializer)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_analyze_incremental_aggregate_plan_handles_cast_wrapped_alias() {
|
||||
let query_engine = create_test_query_engine();
|
||||
let ctx = QueryContext::arc();
|
||||
// CAST wraps the aggregate output — the analyzer should still find the alias
|
||||
let sql =
|
||||
"SELECT CAST(sum(number) AS BIGINT) AS total, ts FROM numbers_with_ts GROUP BY ts";
|
||||
let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap();
|
||||
let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap();
|
||||
assert!(analysis.unsupported_exprs.is_empty());
|
||||
assert!(analysis.group_key_names.contains(&"ts".to_string()));
|
||||
assert_eq!(analysis.merge_columns.len(), 1);
|
||||
assert_eq!(
|
||||
analysis.merge_columns[0].output_field_name, "total",
|
||||
"Expected alias 'total' for CAST-wrapped aggregate, but got '{}'",
|
||||
analysis.merge_columns[0].output_field_name
|
||||
);
|
||||
assert_eq!(
|
||||
analysis.merge_columns[0].merge_op,
|
||||
IncrementalAggregateMergeOp::Sum
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user