test: add expected plan test

Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
discord9
2026-05-19 22:14:37 +08:00
parent d4e9ec264b
commit 90a119cead

View File

@@ -199,6 +199,7 @@ fn find_aggregate_exprs(plan: &LogicalPlan) -> Result<Option<Vec<Expr>>, Error>
/// Facts about a plan's output projection, as returned by
/// `collect_output_projection_info`.
struct OutputProjectionInfo {
// NOTE(review): presumably true when the analyzed plan has a projection at its
// root; the code that sets it is not visible here — confirm.
has_top_level_projection: bool,
/// Maps a projected column name to the alias it is output under.
// NOTE(review): the collector keeps the first alias seen for a column — confirm
// that no caller expects the last one.
output_aliases: HashMap<String, String>,
/// Human-readable messages, one per aggregate output column that is projected
/// under more than one distinct alias. The analysis appends these to its
/// unsupported expressions.
duplicate_aggregate_aliases: BTreeSet<String>,
// NOTE(review): presumably the names of output columns produced by literal
// expressions — confirm against the collector.
literal_columns: HashSet<String>,
// NOTE(review): presumably the output field names in projection order; the
// analysis also checks this list for duplicate names — confirm ordering.
output_field_names: Vec<String>,
}
@@ -251,7 +252,15 @@ fn collect_output_projection_info(plan: &LogicalPlan) -> OutputProjectionInfo {
}
1 => {
if let Some(col_name) = col_names.into_iter().next() {
output_aliases.entry(col_name).or_insert(alias_name);
if let Some(existing_alias) = output_aliases.get(&col_name) {
if existing_alias != &alias_name {
projection_info.duplicate_aggregate_aliases.insert(format!(
"same aggregate output {col_name} is used by multiple aliases: {existing_alias}, {alias_name}"
));
}
} else {
output_aliases.insert(col_name, alias_name);
}
}
}
_ => {}
@@ -358,6 +367,7 @@ pub fn analyze_incremental_aggregate_plan(
.into_iter()
.map(|name| format!("duplicate output field name: {name}"))
.collect::<Vec<_>>();
unsupported_exprs.extend(projection_info.duplicate_aggregate_aliases.iter().cloned());
if group_key_names.is_empty()
&& projection_info
.output_field_names
@@ -391,6 +401,9 @@ pub fn analyze_incremental_aggregate_plan(
.into_iter()
.map(|name| format!("unsupported output field: {name}")),
);
if !unsupported_exprs.is_empty() {
merge_columns.clear();
}
let mut literal_columns = projection_info
.literal_columns
.into_iter()
@@ -406,6 +419,36 @@ pub fn analyze_incremental_aggregate_plan(
}))
}
/// Rewrites one incremental aggregate delta plan by left-joining it with the
/// existing sink-table state and projecting merged aggregate outputs.
///
/// For a grouped aggregate such as:
///
/// ```text
/// SELECT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts
/// ```
///
/// the rewrite is roughly:
///
/// ```text
/// delta = SELECT ts, number FROM <delta_plan> AS __flow_delta
/// sink = SELECT ts, number FROM <sink_table> AS __flow_sink
/// SELECT
/// CASE
/// WHEN __flow_sink.number IS NULL THEN __flow_delta.number
/// WHEN __flow_delta.number >= __flow_sink.number THEN __flow_delta.number
/// ELSE __flow_sink.number
/// END AS number,
/// __flow_delta.ts AS ts
/// FROM delta
/// LEFT JOIN sink
/// ON __flow_delta.ts IS NOT DISTINCT FROM __flow_sink.ts
/// ```
///
/// For a global aggregate without group keys, DataFusion still requires a
/// non-empty join condition. We add `__flow_global_aggregate_join_key = 1` to
/// both sides and join on it. This relies on the global aggregate sink keeping a
/// single state row; multiple sink rows would fan out the single delta row.
pub async fn rewrite_incremental_aggregate_with_sink_merge(
delta_plan: &LogicalPlan,
analysis: &IncrementalAggregateAnalysis,
@@ -1249,6 +1292,95 @@ mod test {
u32_table(table_name, columns, 0)
}
/// Asserts that `actual` and `expected` render to the same indented plan text.
///
/// The comparison is textual (`display_indent`), so two plans are considered
/// equal when their indented renderings match line for line.
///
/// # Panics
///
/// Panics when the renderings differ. Both plans are printed verbatim and
/// labelled; `assert_eq!` on the strings would print them `Debug`-escaped on a
/// single line, which makes multi-line plans hard to compare.
fn assert_same_logical_plan(actual: &LogicalPlan, expected: &LogicalPlan) {
let expected_text = expected.display_indent().to_string();
let actual_text = actual.display_indent().to_string();
assert!(
expected_text == actual_text,
"logical plan mismatch\n--- expected ---\n{expected_text}\n--- actual ---\n{actual_text}"
);
}
/// Builds a `TableScan` over `sink_table`, registered under the fully qualified
/// name given by `sink_table_name` (`[catalog, schema, table]`).
///
/// The scan has no projection, no filters and no fetch limit.
///
/// # Panics
///
/// Panics if the scan cannot be constructed.
fn test_sink_scan(sink_table: TableRef, sink_table_name: &TableName) -> LogicalPlan {
let [catalog, schema, table] = sink_table_name;
let table_reference = TableReference::Full {
catalog: catalog.clone().into(),
schema: schema.clone().into(),
table: table.clone().into(),
};
let provider = Arc::new(DfTableProviderAdapter::new(sink_table));
let source = Arc::new(DefaultTableSource::new(provider));
let scan = TableScan::try_new(table_reference, source, None, vec![], None).unwrap();
LogicalPlan::TableScan(scan)
}
/// Hand-builds the plan the sink-merge rewrite is expected to produce.
///
/// The delta plan is projected onto `delta_selected_exprs` and aliased as
/// `__flow_delta`; a scan of the sink table is projected onto
/// `sink_selected_exprs` and aliased as `__flow_sink`. The two sides are
/// left-joined on `join_keys` with NULL keys comparing equal, and the join is
/// finally projected onto `projection_exprs`.
///
/// # Panics
///
/// Panics if any plan-building step fails.
fn expected_left_join_rewrite(
delta_plan: &LogicalPlan,
sink_table: TableRef,
sink_table_name: &TableName,
delta_selected_exprs: Vec<Expr>,
sink_selected_exprs: Vec<Expr>,
join_keys: (Vec<Column>, Vec<Column>),
projection_exprs: Vec<Expr>,
) -> LogicalPlan {
/// Projects `input` onto `exprs` and wraps the result in the alias `alias`.
fn project_as(input: LogicalPlan, exprs: Vec<Expr>, alias: &str) -> LogicalPlan {
LogicalPlanBuilder::from(input)
.project(exprs)
.and_then(|builder| builder.alias(alias))
.and_then(|builder| builder.build())
.unwrap()
}

let delta_side = project_as(delta_plan.clone(), delta_selected_exprs, "__flow_delta");
let sink_side = project_as(
test_sink_scan(sink_table, sink_table_name),
sink_selected_exprs,
"__flow_sink",
);

LogicalPlanBuilder::from(delta_side)
.join_detailed(
sink_side,
JoinType::Left,
join_keys,
None,
NullEquality::NullEqualsNull,
)
.and_then(|builder| builder.project(projection_exprs))
.and_then(|builder| builder.build())
.unwrap()
}
/// Builds the expected merge expression for a max-style column `field_name`.
///
/// The result is a `CASE` aliased as `field_name` that yields the
/// `__flow_delta` value when the `__flow_sink` value is NULL or when the delta
/// value is greater than or equal to the sink value, and the sink value
/// otherwise.
fn max_merge_expr(field_name: &str) -> Expr {
let delta = qualified_col("__flow_delta", field_name);
let sink = qualified_col("__flow_sink", field_name);
let sink_is_missing = is_null(sink.clone());
let delta_is_not_smaller = delta.clone().gt_eq(sink.clone());
let merged = when(sink_is_missing, delta.clone())
.when(delta_is_not_smaller, delta)
.otherwise(sink)
.unwrap();
merged.alias(field_name)
}
/// Builds the expected merge expression for an additive column `field_name`.
///
/// The result is a `CASE` aliased as `field_name`: if the `__flow_delta` value
/// is NULL the `__flow_sink` value is used, if the sink value is NULL the delta
/// value is used, and otherwise the two are added.
fn sum_merge_expr(field_name: &str) -> Expr {
let delta = qualified_col("__flow_delta", field_name);
let sink = qualified_col("__flow_sink", field_name);
let delta_is_missing = is_null(delta.clone());
let sink_is_missing = is_null(sink.clone());
let combined = binary_expr(delta.clone(), Operator::Plus, sink.clone());
let merged = when(delta_is_missing, sink)
.when(sink_is_missing, delta)
.otherwise(combined)
.unwrap();
merged.alias(field_name)
}
/// test if uppercase are handled correctly(with quote)
#[tokio::test]
async fn test_sql_plan_convert() {
@@ -1752,25 +1884,22 @@ mod test {
vec!["number".to_string(), "ts".to_string(), "lit".to_string()]
);
let sink_table_name = [
"greptime".to_string(),
"public".to_string(),
"numbers_with_ts".to_string(),
];
let (sink_table, _) = get_table_info_df_schema(
query_engine.engine_state().catalog_manager().clone(),
[
"greptime".to_string(),
"public".to_string(),
"numbers_with_ts".to_string(),
],
sink_table_name.clone(),
)
.await
.unwrap();
let rewritten = rewrite_incremental_aggregate_with_sink_merge(
&plan,
&analysis,
sink_table,
&[
"greptime".to_string(),
"public".to_string(),
"numbers_with_ts".to_string(),
],
sink_table.clone(),
&sink_table_name,
)
.await
.unwrap();
@@ -1782,6 +1911,27 @@ mod test {
.map(|field| field.name().clone())
.collect::<Vec<_>>();
assert_eq!(rewritten_fields, analysis.output_field_names);
let expected = expected_left_join_rewrite(
&plan,
sink_table,
&sink_table_name,
vec![
unqualified_col("ts"),
unqualified_col("number"),
unqualified_col("lit"),
],
vec![unqualified_col("ts"), unqualified_col("number")],
(
vec![qualified_column("__flow_delta", "ts")],
vec![qualified_column("__flow_sink", "ts")],
),
vec![
max_merge_expr("number"),
qualified_col("__flow_delta", "ts").alias("ts"),
qualified_col("__flow_delta", "lit").alias("lit"),
],
);
assert_same_logical_plan(&rewritten, &expected);
}
#[tokio::test]
@@ -1813,6 +1963,54 @@ mod test {
analysis.output_field_names,
vec!["number".to_string(), "label".to_string()]
);
let sink_table = single_row_u32_table("string_literal_sink", vec!["number"]);
let sink_table_name = [
"greptime".to_string(),
"public".to_string(),
"string_literal_sink".to_string(),
];
let rewritten = rewrite_incremental_aggregate_with_sink_merge(
&plan,
&analysis,
sink_table.clone(),
&sink_table_name,
)
.await
.unwrap();
assert_eq!(
rewritten
.schema()
.fields()
.iter()
.map(|field| field.name().clone())
.collect::<Vec<_>>(),
vec!["number".to_string(), "label".to_string()]
);
let expected = expected_left_join_rewrite(
&plan,
sink_table,
&sink_table_name,
vec![
unqualified_col("number"),
unqualified_col("label"),
lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY),
],
vec![
unqualified_col("number"),
lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY),
],
(
vec![qualified_column("__flow_delta", GLOBAL_AGGREGATE_JOIN_KEY)],
vec![qualified_column("__flow_sink", GLOBAL_AGGREGATE_JOIN_KEY)],
),
vec![
max_merge_expr("number"),
qualified_col("__flow_delta", "label").alias("label"),
],
);
assert_same_logical_plan(&rewritten, &expected);
}
#[tokio::test]
@@ -1924,10 +2122,13 @@ mod test {
analysis
.unsupported_exprs
.iter()
.any(|expr| expr.contains("unsupported output field: b")),
.any(|expr| expr.contains("same aggregate output")
&& expr.contains("a")
&& expr.contains("b")),
"same aggregate with multiple aliases should be unsupported until explicit reproduction is implemented: {:?}",
analysis.unsupported_exprs
);
assert!(analysis.merge_columns.is_empty());
}
#[test]
@@ -2013,13 +2214,14 @@ mod test {
.await
.unwrap();
let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap();
let sink_table_name = [
"greptime".to_string(),
"public".to_string(),
"numbers_with_ts".to_string(),
];
let (sink_table, _) = get_table_info_df_schema(
query_engine.engine_state().catalog_manager().clone(),
[
"greptime".to_string(),
"public".to_string(),
"numbers_with_ts".to_string(),
],
sink_table_name.clone(),
)
.await
.unwrap();
@@ -2027,19 +2229,28 @@ mod test {
let rewritten = rewrite_incremental_aggregate_with_sink_merge(
&plan,
&analysis,
sink_table,
&[
"greptime".to_string(),
"public".to_string(),
"numbers_with_ts".to_string(),
],
sink_table.clone(),
&sink_table_name,
)
.await
.unwrap();
let plan_text = format!("{}", rewritten.display_indent());
assert!(plan_text.contains("Left Join"));
assert!(!plan_text.contains("Union"));
let expected = expected_left_join_rewrite(
&plan,
sink_table,
&sink_table_name,
vec![unqualified_col("ts"), unqualified_col("number")],
vec![unqualified_col("ts"), unqualified_col("number")],
(
vec![qualified_column("__flow_delta", "ts")],
vec![qualified_column("__flow_sink", "ts")],
),
vec![
max_merge_expr("number"),
qualified_col("__flow_delta", "ts").alias("ts"),
],
);
assert_same_logical_plan(&rewritten, &expected);
}
#[tokio::test]
@@ -2054,21 +2265,39 @@ mod test {
assert_eq!(analysis.merge_columns.len(), 1);
let sink_table = single_row_u32_table("global_sink", vec!["number"]);
let sink_table_name = [
"greptime".to_string(),
"public".to_string(),
"global_sink".to_string(),
];
let rewritten = rewrite_incremental_aggregate_with_sink_merge(
&plan,
&analysis,
sink_table,
&[
"greptime".to_string(),
"public".to_string(),
"global_sink".to_string(),
],
sink_table.clone(),
&sink_table_name,
)
.await
.unwrap();
let plan_text = format!("{}", rewritten.display_indent());
assert!(plan_text.contains("Left Join"));
let expected = expected_left_join_rewrite(
&plan,
sink_table,
&sink_table_name,
vec![
unqualified_col("number"),
lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY),
],
vec![
unqualified_col("number"),
lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY),
],
(
vec![qualified_column("__flow_delta", GLOBAL_AGGREGATE_JOIN_KEY)],
vec![qualified_column("__flow_sink", GLOBAL_AGGREGATE_JOIN_KEY)],
),
vec![max_merge_expr("number")],
);
assert_same_logical_plan(&rewritten, &expected);
assert_eq!(
rewritten
.schema()
@@ -2090,22 +2319,39 @@ mod test {
assert!(analysis.unsupported_exprs.is_empty());
let sink_table = empty_u32_table("empty_global_sink", vec!["number"]);
let sink_table_name = [
"greptime".to_string(),
"public".to_string(),
"empty_global_sink".to_string(),
];
let rewritten = rewrite_incremental_aggregate_with_sink_merge(
&plan,
&analysis,
sink_table,
&[
"greptime".to_string(),
"public".to_string(),
"empty_global_sink".to_string(),
],
sink_table.clone(),
&sink_table_name,
)
.await
.unwrap();
let plan_text = format!("{}", rewritten.display_indent());
assert!(plan_text.contains("Left Join"));
assert!(plan_text.contains(GLOBAL_AGGREGATE_JOIN_KEY));
let expected = expected_left_join_rewrite(
&plan,
sink_table,
&sink_table_name,
vec![
unqualified_col("number"),
lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY),
],
vec![
unqualified_col("number"),
lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY),
],
(
vec![qualified_column("__flow_delta", GLOBAL_AGGREGATE_JOIN_KEY)],
vec![qualified_column("__flow_sink", GLOBAL_AGGREGATE_JOIN_KEY)],
),
vec![max_merge_expr("number")],
);
assert_same_logical_plan(&rewritten, &expected);
assert_eq!(
rewritten
.schema()
@@ -2128,15 +2374,16 @@ mod test {
assert_eq!(analysis.literal_columns, vec!["lit".to_string()]);
let sink_table = single_row_u32_table("global_literal_sink", vec!["number"]);
let sink_table_name = [
"greptime".to_string(),
"public".to_string(),
"global_literal_sink".to_string(),
];
let rewritten = rewrite_incremental_aggregate_with_sink_merge(
&plan,
&analysis,
sink_table,
&[
"greptime".to_string(),
"public".to_string(),
"global_literal_sink".to_string(),
],
sink_table.clone(),
&sink_table_name,
)
.await
.unwrap();
@@ -2150,6 +2397,29 @@ mod test {
.collect::<Vec<_>>(),
vec!["number".to_string(), "lit".to_string()]
);
let expected = expected_left_join_rewrite(
&plan,
sink_table,
&sink_table_name,
vec![
unqualified_col("number"),
unqualified_col("lit"),
lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY),
],
vec![
unqualified_col("number"),
lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY),
],
(
vec![qualified_column("__flow_delta", GLOBAL_AGGREGATE_JOIN_KEY)],
vec![qualified_column("__flow_sink", GLOBAL_AGGREGATE_JOIN_KEY)],
),
vec![
max_merge_expr("number"),
qualified_col("__flow_delta", "lit").alias("lit"),
],
);
assert_same_logical_plan(&rewritten, &expected);
}
#[tokio::test]
@@ -2163,15 +2433,16 @@ mod test {
assert_eq!(analysis.merge_columns.len(), 2);
let sink_table = single_row_u32_table("global_multi_merge_sink", vec!["cnt", "total"]);
let sink_table_name = [
"greptime".to_string(),
"public".to_string(),
"global_multi_merge_sink".to_string(),
];
let rewritten = rewrite_incremental_aggregate_with_sink_merge(
&plan,
&analysis,
sink_table,
&[
"greptime".to_string(),
"public".to_string(),
"global_multi_merge_sink".to_string(),
],
sink_table.clone(),
&sink_table_name,
)
.await
.unwrap();
@@ -2185,6 +2456,27 @@ mod test {
.collect::<Vec<_>>(),
vec!["cnt".to_string(), "total".to_string()]
);
let expected = expected_left_join_rewrite(
&plan,
sink_table,
&sink_table_name,
vec![
unqualified_col("cnt"),
unqualified_col("total"),
lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY),
],
vec![
unqualified_col("cnt"),
unqualified_col("total"),
lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY),
],
(
vec![qualified_column("__flow_delta", GLOBAL_AGGREGATE_JOIN_KEY)],
vec![qualified_column("__flow_sink", GLOBAL_AGGREGATE_JOIN_KEY)],
),
vec![sum_merge_expr("cnt"), sum_merge_expr("total")],
);
assert_same_logical_plan(&rewritten, &expected);
}
#[tokio::test]
@@ -2198,15 +2490,16 @@ mod test {
let raw_field_name = "max(numbers_with_ts.number)";
let sink_table = single_row_u32_table("raw_aggregate_sink", vec!["number", raw_field_name]);
let sink_table_name = [
"greptime".to_string(),
"public".to_string(),
"raw_aggregate_sink".to_string(),
];
let rewritten = rewrite_incremental_aggregate_with_sink_merge(
&plan,
&analysis,
sink_table,
&[
"greptime".to_string(),
"public".to_string(),
"raw_aggregate_sink".to_string(),
],
sink_table.clone(),
&sink_table_name,
)
.await
.unwrap();
@@ -2218,8 +2511,22 @@ mod test {
.map(|field| field.name().clone())
.collect::<Vec<_>>();
assert!(rewritten_fields.contains(&raw_field_name.to_string()));
let plan_text = format!("{}", rewritten.display_indent());
assert!(plan_text.contains(raw_field_name));
let expected = expected_left_join_rewrite(
&plan,
sink_table,
&sink_table_name,
vec![unqualified_col("number"), unqualified_col(raw_field_name)],
vec![unqualified_col("number"), unqualified_col(raw_field_name)],
(
vec![qualified_column("__flow_delta", "number")],
vec![qualified_column("__flow_sink", "number")],
),
vec![
max_merge_expr(raw_field_name),
qualified_col("__flow_delta", "number").alias("number"),
],
);
assert_same_logical_plan(&rewritten, &expected);
}
#[tokio::test]