mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-18 14:00:39 +00:00
@@ -59,10 +59,27 @@ use crate::{Error, TableName};
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct IncrementalAggregateMergeColumn {
|
||||
/// Final output/sink field name for the aggregate result/state column.
|
||||
///
|
||||
/// Must NOT include a plan/table qualifier (no `.` separator).
|
||||
pub output_field_name: String,
|
||||
pub merge_op: IncrementalAggregateMergeOp,
|
||||
}
|
||||
|
||||
impl IncrementalAggregateMergeColumn {
|
||||
/// Create a new merge column, validating that `output_field_name` does not
|
||||
/// contain a plan/table qualifier.
|
||||
pub fn new(output_field_name: String, merge_op: IncrementalAggregateMergeOp) -> Self {
|
||||
debug_assert!(
|
||||
!output_field_name.contains('.'),
|
||||
"output_field_name must not include a plan/table qualifier, got: {output_field_name}"
|
||||
);
|
||||
Self {
|
||||
output_field_name,
|
||||
merge_op,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum IncrementalAggregateMergeOp {
|
||||
Sum,
|
||||
@@ -89,6 +106,19 @@ pub struct IncrementalAggregateAnalysis {
|
||||
pub unsupported_exprs: Vec<String>,
|
||||
}
|
||||
|
||||
/// Visitor that captures the aggregate expressions from the **innermost**
|
||||
/// `Aggregate` node in the plan tree.
|
||||
///
|
||||
/// Since this records the expressions in `f_down` and continues recursion, it will overwrite
|
||||
/// `aggr_exprs` for each `Aggregate` it encounters, ultimately retaining the
|
||||
/// deepest (innermost) one. This is the intended behavior for the incremental
|
||||
/// aggregate rewrite: nested aggregates are not supported, and the delta plan
|
||||
/// produced by the flow engine places a single `Aggregate` at the bottom.
|
||||
///
|
||||
/// If the plan contains multiple nested `Aggregate` nodes (a subquery with its
|
||||
/// own aggregation), the innermost one is captured, which is conservative and
|
||||
/// safe — it prevents the rewrite from incorrectly operating on the outer
|
||||
/// aggregate.
|
||||
#[derive(Default)]
|
||||
struct LastAggregateExprFinder {
|
||||
aggr_exprs: Option<Vec<Expr>>,
|
||||
@@ -209,15 +239,20 @@ pub fn analyze_incremental_aggregate_plan(
|
||||
continue;
|
||||
};
|
||||
|
||||
// qualified_name() returns (Option<TableReference>, String) where the second
|
||||
// element is the unqualified column/alias name. This relies on
|
||||
// DataFusion's internal naming convention: aggregate expressions
|
||||
// emit a column named after the aggregate itself (e.g. "SUM(x)"),
|
||||
// which matches what the projection aliases reference.
|
||||
let raw_name = aggr_expr.qualified_name().1;
|
||||
let Some(output_field_name) = output_aliases.get(&raw_name).cloned() else {
|
||||
unsupported_exprs.push(aggr_expr.to_string());
|
||||
continue;
|
||||
};
|
||||
merge_columns.push(IncrementalAggregateMergeColumn {
|
||||
merge_columns.push(IncrementalAggregateMergeColumn::new(
|
||||
output_field_name,
|
||||
merge_op,
|
||||
});
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Some(IncrementalAggregateAnalysis {
|
||||
@@ -1293,38 +1328,55 @@ mod test {
|
||||
async fn test_analyze_incremental_aggregate_plan() {
|
||||
let query_engine = create_test_query_engine();
|
||||
let ctx = QueryContext::arc();
|
||||
let testcases = vec![
|
||||
let testcases: Vec<(&str, IncrementalAggregateMergeOp, &str)> = vec![
|
||||
(
|
||||
"SELECT sum(number) AS number, ts FROM numbers_with_ts GROUP BY ts",
|
||||
IncrementalAggregateMergeOp::Sum,
|
||||
"number",
|
||||
),
|
||||
(
|
||||
"SELECT count(number) AS number, ts FROM numbers_with_ts GROUP BY ts",
|
||||
IncrementalAggregateMergeOp::Sum,
|
||||
"number",
|
||||
),
|
||||
(
|
||||
"SELECT min(number) AS number, ts FROM numbers_with_ts GROUP BY ts",
|
||||
IncrementalAggregateMergeOp::Min,
|
||||
"number",
|
||||
),
|
||||
(
|
||||
"SELECT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts",
|
||||
IncrementalAggregateMergeOp::Max,
|
||||
"number",
|
||||
),
|
||||
(
|
||||
"SELECT bit_and(number) AS number, ts FROM numbers_with_ts GROUP BY ts",
|
||||
IncrementalAggregateMergeOp::BitAnd,
|
||||
"number",
|
||||
),
|
||||
(
|
||||
"SELECT bit_or(number) AS number, ts FROM numbers_with_ts GROUP BY ts",
|
||||
IncrementalAggregateMergeOp::BitOr,
|
||||
"number",
|
||||
),
|
||||
(
|
||||
"SELECT bit_xor(number) AS number, ts FROM numbers_with_ts GROUP BY ts",
|
||||
IncrementalAggregateMergeOp::BitXor,
|
||||
"number",
|
||||
),
|
||||
(
|
||||
"SELECT bool_and(number > 5) AS cond, ts FROM numbers_with_ts GROUP BY ts",
|
||||
IncrementalAggregateMergeOp::BoolAnd,
|
||||
"cond",
|
||||
),
|
||||
(
|
||||
"SELECT bool_or(number > 5) AS cond, ts FROM numbers_with_ts GROUP BY ts",
|
||||
IncrementalAggregateMergeOp::BoolOr,
|
||||
"cond",
|
||||
),
|
||||
];
|
||||
|
||||
for (sql, expected_op) in testcases {
|
||||
for (sql, expected_op, expected_field_name) in testcases {
|
||||
let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1333,7 +1385,10 @@ mod test {
|
||||
assert!(analysis.unsupported_exprs.is_empty());
|
||||
assert!(analysis.group_key_names.contains(&"ts".to_string()));
|
||||
assert_eq!(analysis.merge_columns.len(), 1);
|
||||
assert_eq!(analysis.merge_columns[0].output_field_name, "number");
|
||||
assert_eq!(
|
||||
analysis.merge_columns[0].output_field_name,
|
||||
expected_field_name
|
||||
);
|
||||
assert_eq!(analysis.merge_columns[0].merge_op, expected_op);
|
||||
}
|
||||
}
|
||||
@@ -1359,6 +1414,26 @@ mod test {
|
||||
}));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_analyze_incremental_aggregate_plan_multiple_group_keys() {
|
||||
let query_engine = create_test_query_engine();
|
||||
let ctx = QueryContext::arc();
|
||||
let sql = "SELECT sum(number) AS total, ts, number AS bucket FROM numbers_with_ts GROUP BY ts, number";
|
||||
let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap();
|
||||
|
||||
let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap();
|
||||
assert!(analysis.unsupported_exprs.is_empty());
|
||||
assert!(analysis.group_key_names.contains(&"ts".to_string()));
|
||||
assert!(analysis.group_key_names.contains(&"bucket".to_string()));
|
||||
assert_eq!(analysis.group_key_names.len(), 2);
|
||||
assert_eq!(analysis.merge_columns.len(), 1);
|
||||
assert_eq!(analysis.merge_columns[0].output_field_name, "total");
|
||||
assert_eq!(
|
||||
analysis.merge_columns[0].merge_op,
|
||||
IncrementalAggregateMergeOp::Sum
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_analyze_incremental_aggregate_plan_rejects_avg() {
|
||||
let query_engine = create_test_query_engine();
|
||||
@@ -1477,4 +1552,30 @@ mod test {
|
||||
IncrementalAggregateMergeOp::Sum
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_last_aggregate_finder_captures_innermost() {
|
||||
let query_engine = create_test_query_engine();
|
||||
let ctx = QueryContext::arc();
|
||||
// Subquery has an inner aggregate (count), outer query has another aggregate (sum).
|
||||
// LastAggregateExprFinder should capture the innermost one (count).
|
||||
let sql = "SELECT sum(cnt) AS total, ts \
|
||||
FROM (SELECT ts, count(number) AS cnt FROM numbers_with_ts GROUP BY ts) AS sub \
|
||||
GROUP BY ts";
|
||||
let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap();
|
||||
|
||||
let mut finder = LastAggregateExprFinder::default();
|
||||
plan.visit(&mut finder).unwrap();
|
||||
let aggr_exprs = finder.aggr_exprs.unwrap();
|
||||
assert_eq!(
|
||||
aggr_exprs.len(),
|
||||
1,
|
||||
"Expected innermost aggregate to have 1 expression"
|
||||
);
|
||||
let found_name = aggr_exprs[0].qualified_name().1.to_ascii_lowercase();
|
||||
assert!(
|
||||
found_name.contains("count"),
|
||||
"Expected innermost aggregate to be count, got: {found_name}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user