mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-21 07:20:41 +00:00
@@ -111,29 +111,25 @@ pub struct IncrementalAggregateAnalysis {
|
||||
pub unsupported_exprs: Vec<String>,
|
||||
}
|
||||
|
||||
/// Visitor that captures the aggregate expressions from the **innermost**
|
||||
/// `Aggregate` node in the plan tree.
|
||||
/// Visitor that captures aggregate expressions and counts `Aggregate` nodes in
|
||||
/// the plan tree.
|
||||
///
|
||||
/// Since this visits `f_down` and continues recursion, it will overwrite
|
||||
/// `aggr_exprs` for each `Aggregate` it encounters, ultimately retaining the
|
||||
/// deepest (innermost) one. This is the intended behavior for the incremental
|
||||
/// aggregate rewrite: nested aggregates are not supported, and the delta plan
|
||||
/// produced by the flow engine places a single `Aggregate` at the bottom.
|
||||
///
|
||||
/// If the plan contains multiple nested `Aggregate` nodes (a subquery with its
|
||||
/// own aggregation), the innermost one is captured, which is conservative and
|
||||
/// safe — it prevents the rewrite from incorrectly operating on the outer
|
||||
/// aggregate.
|
||||
/// Incremental aggregate rewrite only supports plans with exactly one aggregate
|
||||
/// node. The count lets the analyzer reject nested/sibling aggregate plans
|
||||
/// instead of accidentally rewriting against whichever aggregate was visited
|
||||
/// last.
|
||||
#[derive(Default)]
|
||||
struct LastAggregateExprFinder {
|
||||
struct AggregateExprFinder {
|
||||
aggr_exprs: Option<Vec<Expr>>,
|
||||
aggregate_count: usize,
|
||||
}
|
||||
|
||||
impl TreeNodeVisitor<'_> for LastAggregateExprFinder {
|
||||
impl TreeNodeVisitor<'_> for AggregateExprFinder {
|
||||
type Node = LogicalPlan;
|
||||
|
||||
fn f_down(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
|
||||
if let LogicalPlan::Aggregate(aggregate) = node {
|
||||
self.aggregate_count += 1;
|
||||
self.aggr_exprs = Some(aggregate.aggr_expr.clone());
|
||||
}
|
||||
Ok(TreeNodeRecursion::Continue)
|
||||
@@ -186,27 +182,51 @@ fn find_group_key_names(plan: &LogicalPlan) -> Result<Vec<String>, Error> {
|
||||
Ok(group_key_names)
|
||||
}
|
||||
|
||||
fn find_aggregate_exprs(plan: &LogicalPlan) -> Result<Option<Vec<Expr>>, Error> {
|
||||
let mut aggregate_finder = LastAggregateExprFinder::default();
|
||||
fn find_aggregate_exprs(plan: &LogicalPlan) -> Result<(usize, Option<Vec<Expr>>), Error> {
|
||||
let mut aggregate_finder = AggregateExprFinder::default();
|
||||
plan.visit(&mut aggregate_finder)
|
||||
.with_context(|_| DatafusionSnafu {
|
||||
context: format!("Failed to inspect aggregate expressions from logical plan: {plan:?}"),
|
||||
})?;
|
||||
Ok(aggregate_finder.aggr_exprs)
|
||||
Ok((
|
||||
aggregate_finder.aggregate_count,
|
||||
aggregate_finder.aggr_exprs,
|
||||
))
|
||||
}
|
||||
|
||||
fn contains_aggregate(plan: &LogicalPlan) -> bool {
|
||||
matches!(plan, LogicalPlan::Aggregate(_)) || plan.inputs().into_iter().any(contains_aggregate)
|
||||
}
|
||||
fn check_inc_aggr_plan_shape(plan: &LogicalPlan) -> Result<(), String> {
|
||||
// Supported final shape: optional output Projection directly over one
|
||||
// Aggregate. Post-aggregate filters (HAVING), ordering, limits,
|
||||
// distinct/window/union/extension nodes are intentionally not accepted.
|
||||
let plan = match plan {
|
||||
LogicalPlan::Projection(projection) => projection.input.as_ref(),
|
||||
_ => plan,
|
||||
};
|
||||
|
||||
fn has_filter_above_aggregate(plan: &LogicalPlan) -> bool {
|
||||
match plan {
|
||||
// HAVING and other post-aggregate filters appear as `Filter` nodes above
|
||||
// an `Aggregate`. Applying them before the sink-merge would filter on
|
||||
// the delta aggregate rather than the final merged aggregate, so reject
|
||||
// them until the rewrite can rebuild the predicate after merging.
|
||||
LogicalPlan::Filter(filter) if contains_aggregate(filter.input.as_ref()) => true,
|
||||
_ => plan.inputs().into_iter().any(has_filter_above_aggregate),
|
||||
LogicalPlan::Aggregate(aggregate) => check_input_plan_shape(aggregate.input.as_ref()),
|
||||
LogicalPlan::Filter(_) => Err(
|
||||
"unsupported post-aggregate filter (HAVING) in incremental aggregate rewrite"
|
||||
.to_string(),
|
||||
),
|
||||
_ => Err(
|
||||
"unsupported post-aggregate plan shape in incremental aggregate rewrite".to_string(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn check_input_plan_shape(plan: &LogicalPlan) -> Result<(), String> {
|
||||
match plan {
|
||||
// Supported aggregate input shape: optional WHERE filter over a table scan.
|
||||
LogicalPlan::TableScan(_) => Ok(()),
|
||||
LogicalPlan::Filter(filter)
|
||||
if matches!(filter.input.as_ref(), LogicalPlan::TableScan(_)) =>
|
||||
{
|
||||
Ok(())
|
||||
}
|
||||
_ => Err(
|
||||
"unsupported aggregate input plan shape in incremental aggregate rewrite".to_string(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -303,22 +323,30 @@ fn collect_output_projection_info(plan: &LogicalPlan) -> OutputProjectionInfo {
|
||||
projection_info
|
||||
}
|
||||
|
||||
fn merge_op_for_aggregate_expr(aggr_expr: &Expr) -> Option<IncrementalAggregateMergeOp> {
|
||||
let aggr_func = get_aggr_func(aggr_expr)?;
|
||||
fn merge_op_for_aggregate_expr(aggr_expr: &Expr) -> Result<IncrementalAggregateMergeOp, String> {
|
||||
let Some(aggr_func) = get_aggr_func(aggr_expr) else {
|
||||
return Err(aggr_expr.to_string());
|
||||
};
|
||||
if aggr_func.params.distinct {
|
||||
return None;
|
||||
return Err(format!("unsupported DISTINCT aggregate: {aggr_expr}"));
|
||||
}
|
||||
if !aggr_func.params.order_by.is_empty() {
|
||||
return Err(format!("unsupported aggregate ORDER BY: {aggr_expr}"));
|
||||
}
|
||||
if aggr_func.params.null_treatment.is_some() {
|
||||
return Err(format!("unsupported aggregate NULL treatment: {aggr_expr}"));
|
||||
}
|
||||
|
||||
match aggr_func.func.name().to_ascii_lowercase().as_str() {
|
||||
"sum" | "count" => Some(IncrementalAggregateMergeOp::Sum),
|
||||
"min" => Some(IncrementalAggregateMergeOp::Min),
|
||||
"max" => Some(IncrementalAggregateMergeOp::Max),
|
||||
"bool_and" => Some(IncrementalAggregateMergeOp::BoolAnd),
|
||||
"bool_or" => Some(IncrementalAggregateMergeOp::BoolOr),
|
||||
"bit_and" => Some(IncrementalAggregateMergeOp::BitAnd),
|
||||
"bit_or" => Some(IncrementalAggregateMergeOp::BitOr),
|
||||
"bit_xor" => Some(IncrementalAggregateMergeOp::BitXor),
|
||||
_ => None,
|
||||
"sum" | "count" => Ok(IncrementalAggregateMergeOp::Sum),
|
||||
"min" => Ok(IncrementalAggregateMergeOp::Min),
|
||||
"max" => Ok(IncrementalAggregateMergeOp::Max),
|
||||
"bool_and" => Ok(IncrementalAggregateMergeOp::BoolAnd),
|
||||
"bool_or" => Ok(IncrementalAggregateMergeOp::BoolOr),
|
||||
"bit_and" => Ok(IncrementalAggregateMergeOp::BitAnd),
|
||||
"bit_or" => Ok(IncrementalAggregateMergeOp::BitOr),
|
||||
"bit_xor" => Ok(IncrementalAggregateMergeOp::BitXor),
|
||||
_ => Err(aggr_expr.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -370,7 +398,7 @@ pub fn analyze_incremental_aggregate_plan(
|
||||
plan: &LogicalPlan,
|
||||
) -> Result<Option<IncrementalAggregateAnalysis>, Error> {
|
||||
let group_key_names = find_group_key_names(plan)?;
|
||||
let Some(aggr_exprs) = find_aggregate_exprs(plan)? else {
|
||||
let (aggregate_count, Some(aggr_exprs)) = find_aggregate_exprs(plan)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
let projection_info = collect_output_projection_info(plan);
|
||||
@@ -382,11 +410,13 @@ pub fn analyze_incremental_aggregate_plan(
|
||||
.into_iter()
|
||||
.map(|name| format!("duplicate output field name: {name}"))
|
||||
.collect::<Vec<_>>();
|
||||
if has_filter_above_aggregate(plan) {
|
||||
unsupported_exprs.push(
|
||||
"unsupported post-aggregate filter (HAVING) in incremental aggregate rewrite"
|
||||
.to_string(),
|
||||
);
|
||||
if aggregate_count != 1 {
|
||||
unsupported_exprs.push(format!(
|
||||
"unsupported aggregate plan contains {aggregate_count} Aggregate nodes"
|
||||
));
|
||||
}
|
||||
if let Err(reason) = check_inc_aggr_plan_shape(plan) {
|
||||
unsupported_exprs.push(reason);
|
||||
}
|
||||
unsupported_exprs.extend(projection_info.duplicate_aggregate_aliases.iter().cloned());
|
||||
if group_key_names.is_empty()
|
||||
@@ -400,9 +430,12 @@ pub fn analyze_incremental_aggregate_plan(
|
||||
));
|
||||
}
|
||||
for aggr_expr in aggr_exprs {
|
||||
let Some(merge_op) = merge_op_for_aggregate_expr(&aggr_expr) else {
|
||||
unsupported_exprs.push(aggr_expr.to_string());
|
||||
continue;
|
||||
let merge_op = match merge_op_for_aggregate_expr(&aggr_expr) {
|
||||
Ok(merge_op) => merge_op,
|
||||
Err(reason) => {
|
||||
unsupported_exprs.push(reason);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let Some(output_field_name) = resolve_aggregate_output_field_name(
|
||||
&aggr_expr,
|
||||
@@ -1402,6 +1435,28 @@ mod test {
|
||||
.alias(field_name)
|
||||
}
|
||||
|
||||
async fn analyze_test_sql(sql: &str) -> IncrementalAggregateAnalysis {
|
||||
let query_engine = create_test_query_engine();
|
||||
let ctx = QueryContext::arc();
|
||||
let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap();
|
||||
analyze_incremental_aggregate_plan(&plan).unwrap().unwrap()
|
||||
}
|
||||
|
||||
fn assert_unsupported(analysis: &IncrementalAggregateAnalysis, reason: &str) {
|
||||
assert!(
|
||||
analysis
|
||||
.unsupported_exprs
|
||||
.iter()
|
||||
.any(|expr| expr.contains(reason)),
|
||||
"expected unsupported reason containing {reason:?}, got {:?}",
|
||||
analysis.unsupported_exprs
|
||||
);
|
||||
assert!(
|
||||
analysis.merge_columns.is_empty(),
|
||||
"unsupported analysis should disable merge columns"
|
||||
);
|
||||
}
|
||||
|
||||
/// Test whether uppercase is handled correctly (with quotes).
|
||||
#[tokio::test]
|
||||
async fn test_sql_plan_convert() {
|
||||
@@ -1808,26 +1863,83 @@ mod test {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_analyze_incremental_aggregate_plan_rejects_having_filter() {
|
||||
let query_engine = create_test_query_engine();
|
||||
let ctx = QueryContext::arc();
|
||||
let sql = "SELECT sum(number) AS number, ts FROM numbers_with_ts GROUP BY ts HAVING sum(number) > 10";
|
||||
let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap();
|
||||
let analysis = analyze_test_sql(sql).await;
|
||||
assert_unsupported(&analysis, "post-aggregate filter");
|
||||
}
|
||||
|
||||
let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap();
|
||||
assert!(
|
||||
analysis
|
||||
.unsupported_exprs
|
||||
.iter()
|
||||
.any(|expr| expr.contains("post-aggregate filter")),
|
||||
"HAVING/post-aggregate filter should be unsupported: {:?}",
|
||||
analysis.unsupported_exprs
|
||||
);
|
||||
assert!(
|
||||
analysis.merge_columns.is_empty(),
|
||||
"unsupported HAVING should disable merge columns"
|
||||
#[tokio::test]
|
||||
async fn test_analyze_incremental_aggregate_plan_allows_aggregate_filter() {
|
||||
let sql = "SELECT sum(number) FILTER (WHERE number > 10) AS number, ts FROM numbers_with_ts GROUP BY ts";
|
||||
let analysis = analyze_test_sql(sql).await;
|
||||
|
||||
assert!(analysis.unsupported_exprs.is_empty());
|
||||
assert!(analysis.group_key_names.contains(&"ts".to_string()));
|
||||
assert_eq!(analysis.merge_columns.len(), 1);
|
||||
assert_eq!(analysis.merge_columns[0].output_field_name, "number");
|
||||
assert_eq!(
|
||||
analysis.merge_columns[0].merge_op,
|
||||
IncrementalAggregateMergeOp::Sum
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_analyze_incremental_aggregate_plan_rejects_aggregate_order_by() {
|
||||
let sql = "SELECT sum(number ORDER BY ts) AS number, ts FROM numbers_with_ts GROUP BY ts";
|
||||
let analysis = analyze_test_sql(sql).await;
|
||||
assert_unsupported(&analysis, "aggregate ORDER BY");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_analyze_incremental_aggregate_plan_rejects_sort_above_aggregate() {
|
||||
let sql = "SELECT sum(number) AS number, ts FROM numbers_with_ts GROUP BY ts ORDER BY number DESC";
|
||||
let analysis = analyze_test_sql(sql).await;
|
||||
assert_unsupported(&analysis, "post-aggregate plan shape");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_analyze_incremental_aggregate_plan_rejects_limit_above_aggregate() {
|
||||
let sql = "SELECT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts LIMIT 1";
|
||||
let analysis = analyze_test_sql(sql).await;
|
||||
assert_unsupported(&analysis, "post-aggregate plan shape");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_analyze_incremental_aggregate_plan_rejects_distinct_above_aggregate() {
|
||||
let sql = "SELECT DISTINCT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts";
|
||||
let analysis = analyze_test_sql(sql).await;
|
||||
assert_unsupported(&analysis, "post-aggregate plan shape");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_analyze_incremental_aggregate_plan_rejects_nested_aggregates() {
|
||||
let sql = "SELECT sum(cnt) AS total FROM (SELECT count(*) AS cnt, ts FROM numbers_with_ts GROUP BY ts) s";
|
||||
let analysis = analyze_test_sql(sql).await;
|
||||
assert_unsupported(&analysis, "Aggregate nodes");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_analyze_incremental_aggregate_plan_rejects_union_aggregate_branches() {
|
||||
let sql = "SELECT sum(number) AS number, ts FROM numbers_with_ts GROUP BY ts UNION ALL SELECT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts";
|
||||
let analysis = analyze_test_sql(sql).await;
|
||||
assert_unsupported(&analysis, "Aggregate nodes");
|
||||
assert_unsupported(&analysis, "post-aggregate plan shape");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_analyze_incremental_aggregate_plan_rejects_window_above_aggregate() {
|
||||
let sql = "SELECT sum(number) AS number, ts, row_number() OVER (ORDER BY sum(number)) AS rn FROM numbers_with_ts GROUP BY ts";
|
||||
let analysis = analyze_test_sql(sql).await;
|
||||
assert_unsupported(&analysis, "post-aggregate plan shape");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_analyze_incremental_aggregate_plan_rejects_join_below_aggregate() {
|
||||
let sql = "SELECT sum(lhs.number) AS number, lhs.ts FROM numbers_with_ts AS lhs JOIN numbers_with_ts AS rhs ON lhs.ts = rhs.ts GROUP BY lhs.ts";
|
||||
let analysis = analyze_test_sql(sql).await;
|
||||
assert_unsupported(&analysis, "aggregate input plan shape");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_analyze_incremental_aggregate_plan_preserves_raw_aggregate_name() {
|
||||
let query_engine = create_test_query_engine();
|
||||
@@ -2627,28 +2739,20 @@ mod test {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_last_aggregate_finder_captures_innermost() {
|
||||
async fn test_aggregate_expr_finder_counts_multiple_aggregates() {
|
||||
let query_engine = create_test_query_engine();
|
||||
let ctx = QueryContext::arc();
|
||||
// Subquery has an inner aggregate (count), outer query has another aggregate (sum).
|
||||
// LastAggregateExprFinder should capture the innermost one (count).
|
||||
let sql = "SELECT sum(cnt) AS total, ts \
|
||||
FROM (SELECT ts, count(number) AS cnt FROM numbers_with_ts GROUP BY ts) AS sub \
|
||||
GROUP BY ts";
|
||||
let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap();
|
||||
|
||||
let mut finder = LastAggregateExprFinder::default();
|
||||
let mut finder = AggregateExprFinder::default();
|
||||
plan.visit(&mut finder).unwrap();
|
||||
let aggr_exprs = finder.aggr_exprs.unwrap();
|
||||
assert_eq!(
|
||||
aggr_exprs.len(),
|
||||
1,
|
||||
"Expected innermost aggregate to have 1 expression"
|
||||
);
|
||||
let found_name = aggr_exprs[0].qualified_name().1.to_ascii_lowercase();
|
||||
assert!(
|
||||
found_name.contains("count"),
|
||||
"Expected innermost aggregate to be count, got: {found_name}"
|
||||
finder.aggregate_count > 1,
|
||||
"nested aggregate plans should be identifiable as unsupported"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user