per review

Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
discord9
2026-05-15 18:14:55 +08:00
parent 4b3efba805
commit 0251113fc7

View File

@@ -59,10 +59,27 @@ use crate::{Error, TableName};
/// Pairs an output field name with the [`IncrementalAggregateMergeOp`] associated with it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IncrementalAggregateMergeColumn {
/// Final output/sink field name for the aggregate result/state column.
///
/// Must NOT include a plan/table qualifier (no `.` separator).
pub output_field_name: String,
/// Merge operation associated with `output_field_name`.
pub merge_op: IncrementalAggregateMergeOp,
}
impl IncrementalAggregateMergeColumn {
    /// Builds a merge column from a bare output field name and its merge operation.
    ///
    /// `output_field_name` is expected to carry no plan/table qualifier, i.e. it
    /// must not contain a `.` separator.
    ///
    /// # Panics
    ///
    /// When debug assertions are enabled, panics if `output_field_name` contains
    /// a `.`. The check is a `debug_assert!`, so it is not evaluated in builds
    /// where debug assertions are disabled.
    pub fn new(output_field_name: String, merge_op: IncrementalAggregateMergeOp) -> Self {
        debug_assert!(
            output_field_name.find('.').is_none(),
            "output_field_name must not include a plan/table qualifier, got: {output_field_name}"
        );
        Self {
            output_field_name,
            merge_op,
        }
    }
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IncrementalAggregateMergeOp {
Sum,
@@ -89,6 +106,19 @@ pub struct IncrementalAggregateAnalysis {
pub unsupported_exprs: Vec<String>,
}
/// Visitor that captures the aggregate expressions from the **innermost**
/// `Aggregate` node in the plan tree.
///
/// Because the visitor records the expressions in `f_down` and then keeps
/// recursing, it overwrites `aggr_exprs` for every `Aggregate` it encounters
/// and ultimately retains the deepest (innermost) one. This is the intended
/// behavior for the incremental aggregate rewrite: nested aggregates are not
/// supported, and the delta plan produced by the flow engine places a single
/// `Aggregate` at the bottom.
///
/// If the plan contains multiple nested `Aggregate` nodes (e.g. a subquery with
/// its own aggregation), the innermost one is captured, which is conservative
/// and safe — it prevents the rewrite from incorrectly operating on the outer
/// aggregate.
#[derive(Default)]
struct LastAggregateExprFinder {
aggr_exprs: Option<Vec<Expr>>,
@@ -209,15 +239,20 @@ pub fn analyze_incremental_aggregate_plan(
continue;
};
// qualified_name() returns (Option<TableReference>, String) where the second
// element is the unqualified column/alias name. This relies on
// DataFusion's internal naming convention: aggregate expressions
// emit a column named after the aggregate itself (e.g. "SUM(x)"),
// which matches what the projection aliases reference.
let raw_name = aggr_expr.qualified_name().1;
let Some(output_field_name) = output_aliases.get(&raw_name).cloned() else {
unsupported_exprs.push(aggr_expr.to_string());
continue;
};
merge_columns.push(IncrementalAggregateMergeColumn {
merge_columns.push(IncrementalAggregateMergeColumn::new(
output_field_name,
merge_op,
});
));
}
Ok(Some(IncrementalAggregateAnalysis {
@@ -1293,38 +1328,55 @@ mod test {
async fn test_analyze_incremental_aggregate_plan() {
let query_engine = create_test_query_engine();
let ctx = QueryContext::arc();
let testcases = vec![
let testcases: Vec<(&str, IncrementalAggregateMergeOp, &str)> = vec![
(
"SELECT sum(number) AS number, ts FROM numbers_with_ts GROUP BY ts",
IncrementalAggregateMergeOp::Sum,
"number",
),
(
"SELECT count(number) AS number, ts FROM numbers_with_ts GROUP BY ts",
IncrementalAggregateMergeOp::Sum,
"number",
),
(
"SELECT min(number) AS number, ts FROM numbers_with_ts GROUP BY ts",
IncrementalAggregateMergeOp::Min,
"number",
),
(
"SELECT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts",
IncrementalAggregateMergeOp::Max,
"number",
),
(
"SELECT bit_and(number) AS number, ts FROM numbers_with_ts GROUP BY ts",
IncrementalAggregateMergeOp::BitAnd,
"number",
),
(
"SELECT bit_or(number) AS number, ts FROM numbers_with_ts GROUP BY ts",
IncrementalAggregateMergeOp::BitOr,
"number",
),
(
"SELECT bit_xor(number) AS number, ts FROM numbers_with_ts GROUP BY ts",
IncrementalAggregateMergeOp::BitXor,
"number",
),
(
"SELECT bool_and(number > 5) AS cond, ts FROM numbers_with_ts GROUP BY ts",
IncrementalAggregateMergeOp::BoolAnd,
"cond",
),
(
"SELECT bool_or(number > 5) AS cond, ts FROM numbers_with_ts GROUP BY ts",
IncrementalAggregateMergeOp::BoolOr,
"cond",
),
];
for (sql, expected_op) in testcases {
for (sql, expected_op, expected_field_name) in testcases {
let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false)
.await
.unwrap();
@@ -1333,7 +1385,10 @@ mod test {
assert!(analysis.unsupported_exprs.is_empty());
assert!(analysis.group_key_names.contains(&"ts".to_string()));
assert_eq!(analysis.merge_columns.len(), 1);
assert_eq!(analysis.merge_columns[0].output_field_name, "number");
assert_eq!(
analysis.merge_columns[0].output_field_name,
expected_field_name
);
assert_eq!(analysis.merge_columns[0].merge_op, expected_op);
}
}
@@ -1359,6 +1414,26 @@ mod test {
}));
}
#[tokio::test]
async fn test_analyze_incremental_aggregate_plan_multiple_group_keys() {
    let engine = create_test_query_engine();
    let query_ctx = QueryContext::arc();
    // Two GROUP BY keys; the second one is exposed through the alias `bucket`.
    let sql = "SELECT sum(number) AS total, ts, number AS bucket FROM numbers_with_ts GROUP BY ts, number";
    let plan = sql_to_df_plan(query_ctx, engine, sql, false)
        .await
        .unwrap();

    let analysis = analyze_incremental_aggregate_plan(&plan)
        .unwrap()
        .unwrap();

    assert!(analysis.unsupported_exprs.is_empty());

    // Both group keys must be reported, and nothing besides them.
    for expected_key in ["ts", "bucket"] {
        assert!(analysis.group_key_names.contains(&expected_key.to_string()));
    }
    assert_eq!(analysis.group_key_names.len(), 2);

    // Exactly one aggregate column, reported under its alias with the `Sum` merge op.
    assert_eq!(analysis.merge_columns.len(), 1);
    let merge_column = &analysis.merge_columns[0];
    assert_eq!(merge_column.output_field_name, "total");
    assert_eq!(merge_column.merge_op, IncrementalAggregateMergeOp::Sum);
}
#[tokio::test]
async fn test_analyze_incremental_aggregate_plan_rejects_avg() {
let query_engine = create_test_query_engine();
@@ -1477,4 +1552,30 @@ mod test {
IncrementalAggregateMergeOp::Sum
);
}
#[tokio::test]
async fn test_last_aggregate_finder_captures_innermost() {
    let engine = create_test_query_engine();
    let query_ctx = QueryContext::arc();
    // The subquery aggregates with `count`, and the outer query aggregates that
    // result with `sum`. The finder is expected to end up holding the
    // expressions of the inner (`count`) aggregate.
    let sql = "SELECT sum(cnt) AS total, ts \
    FROM (SELECT ts, count(number) AS cnt FROM numbers_with_ts GROUP BY ts) AS sub \
    GROUP BY ts";
    let plan = sql_to_df_plan(query_ctx, engine, sql, false)
        .await
        .unwrap();

    let mut visitor = LastAggregateExprFinder::default();
    plan.visit(&mut visitor).unwrap();

    let aggr_exprs = visitor.aggr_exprs.unwrap();
    assert_eq!(
        aggr_exprs.len(),
        1,
        "Expected innermost aggregate to have 1 expression"
    );

    // Compare case-insensitively: only the function name matters here.
    let (_, raw_name) = aggr_exprs[0].qualified_name();
    let found_name = raw_name.to_ascii_lowercase();
    assert!(
        found_name.contains("count"),
        "Expected innermost aggregate to be count, got: {found_name}"
    );
}
}