mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-21 07:20:41 +00:00
@@ -28,11 +28,13 @@ use datafusion::sql::unparser::Unparser;
|
||||
use datafusion_common::tree_node::{
|
||||
Transformed, TreeNode as _, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
|
||||
};
|
||||
use datafusion_common::{DFSchema, DataFusionError, NullEquality, ScalarValue, TableReference};
|
||||
use datafusion_common::{
|
||||
Column, DFSchema, DataFusionError, NullEquality, ScalarValue, TableReference,
|
||||
};
|
||||
use datafusion_expr::logical_plan::TableScan;
|
||||
use datafusion_expr::{
|
||||
Distinct, JoinType, LogicalPlan, LogicalPlanBuilder, Operator, Projection, and, binary_expr,
|
||||
bitwise_and, bitwise_or, bitwise_xor, col, is_null, or, when,
|
||||
bitwise_and, bitwise_or, bitwise_xor, is_null, or, when,
|
||||
};
|
||||
use datatypes::schema::{ColumnSchema, SchemaRef};
|
||||
use query::QueryEngineRef;
|
||||
@@ -55,24 +57,20 @@ use crate::{Error, TableName};
|
||||
///
|
||||
/// `output_field_name` is the final output/sink schema field name produced by
|
||||
/// the delta plan and read from the sink table. It is not a DataFusion `Column`
/// reference, so it carries no plan/table qualifier. It may contain dots or
/// other non-identifier characters when the query keeps DataFusion's raw
/// aggregate output name, e.g. `max(numbers_with_ts.number)`; such dots are part
/// of the field name itself.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct IncrementalAggregateMergeColumn {
|
||||
/// Final output/sink field name for the aggregate result/state column.
|
||||
///
|
||||
/// Must NOT include a plan/table qualifier (no `.` separator).
|
||||
pub output_field_name: String,
|
||||
pub merge_op: IncrementalAggregateMergeOp,
|
||||
}
|
||||
|
||||
impl IncrementalAggregateMergeColumn {
|
||||
/// Create a new merge column, validating that `output_field_name` does not
|
||||
/// contain a plan/table qualifier.
|
||||
/// Create a new merge column.
|
||||
pub fn new(output_field_name: String, merge_op: IncrementalAggregateMergeOp) -> Self {
|
||||
debug_assert!(
|
||||
!output_field_name.contains('.'),
|
||||
"output_field_name must not include a plan/table qualifier, got: {output_field_name}"
|
||||
);
|
||||
Self {
|
||||
output_field_name,
|
||||
merge_op,
|
||||
@@ -97,7 +95,8 @@ pub enum IncrementalAggregateMergeOp {
|
||||
/// `group_key_names` and each merge column's `output_field_name` are final
|
||||
/// output/sink schema field names used to project both the delta plan and the
|
||||
/// sink table before the left-join merge. They are not DataFusion logical-plan
/// `Column` references and must not be pre-qualified; callers must attach
/// qualifiers structurally instead of formatting qualified names as strings.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct IncrementalAggregateAnalysis {
|
||||
/// Final output/sink field names for group keys used as merge join keys.
|
||||
@@ -152,25 +151,46 @@ fn find_column_names(expr: &Expr, names: &mut Vec<String>) {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn analyze_incremental_aggregate_plan(
|
||||
plan: &LogicalPlan,
|
||||
) -> Result<Option<IncrementalAggregateAnalysis>, Error> {
|
||||
fn unqualified_col(name: impl Into<String>) -> Expr {
|
||||
Expr::Column(Column::from_name(name.into()))
|
||||
}
|
||||
|
||||
fn qualified_col(qualifier: &str, name: impl Into<String>) -> Expr {
|
||||
Expr::Column(Column::new(Some(qualifier), name.into()))
|
||||
}
|
||||
|
||||
fn qualified_column(qualifier: &str, name: impl Into<String>) -> Column {
|
||||
Column::new(Some(qualifier), name.into())
|
||||
}
|
||||
|
||||
fn find_group_key_names(plan: &LogicalPlan) -> Result<Vec<String>, Error> {
|
||||
let mut group_finder = FindGroupByFinalName::default();
|
||||
plan.visit(&mut group_finder)
|
||||
.with_context(|_| DatafusionSnafu {
|
||||
context: format!("Failed to inspect group-by columns from logical plan: {plan:?}"),
|
||||
})?;
|
||||
|
||||
let mut group_key_names = group_finder
|
||||
.get_group_expr_names()
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.collect::<Vec<_>>();
|
||||
group_key_names.sort();
|
||||
Ok(group_key_names)
|
||||
}
|
||||
|
||||
fn find_aggregate_exprs(plan: &LogicalPlan) -> Result<Option<Vec<Expr>>, Error> {
|
||||
let mut aggregate_finder = LastAggregateExprFinder::default();
|
||||
plan.visit(&mut aggregate_finder)
|
||||
.with_context(|_| DatafusionSnafu {
|
||||
context: format!("Failed to inspect aggregate expressions from logical plan: {plan:?}"),
|
||||
})?;
|
||||
let Some(aggr_exprs) = aggregate_finder.aggr_exprs else {
|
||||
return Ok(None);
|
||||
};
|
||||
Ok(aggregate_finder.aggr_exprs)
|
||||
}
|
||||
|
||||
fn collect_output_aliases(plan: &LogicalPlan) -> (bool, HashMap<String, String>, HashSet<String>) {
|
||||
let mut output_aliases = HashMap::new();
|
||||
let has_top_level_projection = matches!(plan, LogicalPlan::Projection(_));
|
||||
if let LogicalPlan::Projection(projection) = plan {
|
||||
for expr in &projection.expr {
|
||||
match expr {
|
||||
@@ -202,50 +222,79 @@ pub fn analyze_incremental_aggregate_plan(
|
||||
}
|
||||
}
|
||||
|
||||
let mut group_key_names = group_finder
|
||||
.get_group_expr_names()
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.collect::<Vec<_>>();
|
||||
group_key_names.sort();
|
||||
let output_field_names = plan
|
||||
.schema()
|
||||
.fields()
|
||||
.iter()
|
||||
.map(|field| field.name().clone())
|
||||
.collect::<HashSet<_>>();
|
||||
|
||||
(has_top_level_projection, output_aliases, output_field_names)
|
||||
}
|
||||
|
||||
fn merge_op_for_aggregate_expr(aggr_expr: &Expr) -> Option<IncrementalAggregateMergeOp> {
|
||||
let aggr_func = get_aggr_func(aggr_expr)?;
|
||||
if aggr_func.params.distinct {
|
||||
return None;
|
||||
}
|
||||
|
||||
match aggr_func.func.name().to_ascii_lowercase().as_str() {
|
||||
"sum" | "count" => Some(IncrementalAggregateMergeOp::Sum),
|
||||
"min" => Some(IncrementalAggregateMergeOp::Min),
|
||||
"max" => Some(IncrementalAggregateMergeOp::Max),
|
||||
"bool_and" => Some(IncrementalAggregateMergeOp::BoolAnd),
|
||||
"bool_or" => Some(IncrementalAggregateMergeOp::BoolOr),
|
||||
"bit_and" => Some(IncrementalAggregateMergeOp::BitAnd),
|
||||
"bit_or" => Some(IncrementalAggregateMergeOp::BitOr),
|
||||
"bit_xor" => Some(IncrementalAggregateMergeOp::BitXor),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_aggregate_output_field_name(
|
||||
aggr_expr: &Expr,
|
||||
has_top_level_projection: bool,
|
||||
output_aliases: &HashMap<String, String>,
|
||||
output_field_names: &HashSet<String>,
|
||||
) -> Option<String> {
|
||||
// qualified_name() returns (Option<String>, String) where the second
|
||||
// element is the unqualified column/alias name. This relies on
|
||||
// DataFusion's internal naming convention: aggregate expressions
|
||||
// emit a column named after the aggregate itself (e.g. "SUM(x)"),
|
||||
// which matches what the projection aliases reference.
|
||||
let raw_name = aggr_expr.qualified_name().1;
|
||||
if let Some(alias) = output_aliases.get(&raw_name) {
|
||||
Some(alias.clone())
|
||||
} else if !has_top_level_projection && output_field_names.contains(&raw_name) {
|
||||
Some(raw_name)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn analyze_incremental_aggregate_plan(
|
||||
plan: &LogicalPlan,
|
||||
) -> Result<Option<IncrementalAggregateAnalysis>, Error> {
|
||||
let group_key_names = find_group_key_names(plan)?;
|
||||
let Some(aggr_exprs) = find_aggregate_exprs(plan)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
let (has_top_level_projection, output_aliases, output_field_names) =
|
||||
collect_output_aliases(plan);
|
||||
|
||||
let mut merge_columns = Vec::with_capacity(aggr_exprs.len());
|
||||
let mut unsupported_exprs = Vec::new();
|
||||
for aggr_expr in aggr_exprs {
|
||||
let Some(aggr_func) = get_aggr_func(&aggr_expr) else {
|
||||
let Some(merge_op) = merge_op_for_aggregate_expr(&aggr_expr) else {
|
||||
unsupported_exprs.push(aggr_expr.to_string());
|
||||
continue;
|
||||
};
|
||||
|
||||
let aggr_name = aggr_func.func.name().to_ascii_lowercase();
|
||||
let merge_op = if aggr_func.params.distinct {
|
||||
None
|
||||
} else {
|
||||
match aggr_name.as_str() {
|
||||
"sum" | "count" => Some(IncrementalAggregateMergeOp::Sum),
|
||||
"min" => Some(IncrementalAggregateMergeOp::Min),
|
||||
"max" => Some(IncrementalAggregateMergeOp::Max),
|
||||
"bool_and" => Some(IncrementalAggregateMergeOp::BoolAnd),
|
||||
"bool_or" => Some(IncrementalAggregateMergeOp::BoolOr),
|
||||
"bit_and" => Some(IncrementalAggregateMergeOp::BitAnd),
|
||||
"bit_or" => Some(IncrementalAggregateMergeOp::BitOr),
|
||||
"bit_xor" => Some(IncrementalAggregateMergeOp::BitXor),
|
||||
_ => None,
|
||||
}
|
||||
};
|
||||
|
||||
let Some(merge_op) = merge_op else {
|
||||
unsupported_exprs.push(aggr_expr.to_string());
|
||||
continue;
|
||||
};
|
||||
|
||||
// qualified_name() returns (Option<String>, String) where the second
|
||||
// element is the unqualified column/alias name. This relies on
|
||||
// DataFusion's internal naming convention: aggregate expressions
|
||||
// emit a column named after the aggregate itself (e.g. "SUM(x)"),
|
||||
// which matches what the projection aliases reference.
|
||||
let raw_name = aggr_expr.qualified_name().1;
|
||||
let Some(output_field_name) = output_aliases.get(&raw_name).cloned() else {
|
||||
let Some(output_field_name) = resolve_aggregate_output_field_name(
|
||||
&aggr_expr,
|
||||
has_top_level_projection,
|
||||
&output_aliases,
|
||||
&output_field_names,
|
||||
) else {
|
||||
unsupported_exprs.push(aggr_expr.to_string());
|
||||
continue;
|
||||
};
|
||||
@@ -298,7 +347,11 @@ pub async fn rewrite_incremental_aggregate_with_sink_merge(
|
||||
.map(|c| c.output_field_name.clone()),
|
||||
);
|
||||
|
||||
let selected_exprs = selected_columns.iter().map(col).collect::<Vec<_>>();
|
||||
let selected_exprs = selected_columns
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(unqualified_col)
|
||||
.collect::<Vec<_>>();
|
||||
let delta_selected = LogicalPlanBuilder::from(delta_plan.clone())
|
||||
.project(selected_exprs.clone())
|
||||
.with_context(|_| DatafusionSnafu {
|
||||
@@ -350,12 +403,14 @@ pub async fn rewrite_incremental_aggregate_with_sink_merge(
|
||||
analysis
|
||||
.group_key_names
|
||||
.iter()
|
||||
.map(|c| datafusion_common::Column::from_qualified_name(format!("{delta_alias}.{c}")))
|
||||
.cloned()
|
||||
.map(|c| qualified_column(delta_alias, c))
|
||||
.collect::<Vec<_>>(),
|
||||
analysis
|
||||
.group_key_names
|
||||
.iter()
|
||||
.map(|c| datafusion_common::Column::from_qualified_name(format!("{sink_alias}.{c}")))
|
||||
.cloned()
|
||||
.map(|c| qualified_column(sink_alias, c))
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
@@ -379,7 +434,8 @@ pub async fn rewrite_incremental_aggregate_with_sink_merge(
|
||||
let mut projection_exprs = analysis
|
||||
.group_key_names
|
||||
.iter()
|
||||
.map(|c| col(format!("{delta_alias}.{c}")).alias(c.clone()))
|
||||
.cloned()
|
||||
.map(|c| qualified_col(delta_alias, c.clone()).alias(c))
|
||||
.collect::<Vec<_>>();
|
||||
for merge_col in &analysis.merge_columns {
|
||||
projection_exprs.push(build_left_join_merge_expr(
|
||||
@@ -405,8 +461,8 @@ fn build_left_join_merge_expr(
|
||||
sink_alias: &str,
|
||||
merge_col: &IncrementalAggregateMergeColumn,
|
||||
) -> Result<Expr, Error> {
|
||||
let left = col(format!("{delta_alias}.{}", merge_col.output_field_name));
|
||||
let right = col(format!("{sink_alias}.{}", merge_col.output_field_name));
|
||||
let left = qualified_col(delta_alias, merge_col.output_field_name.clone());
|
||||
let right = qualified_col(sink_alias, merge_col.output_field_name.clone());
|
||||
let merged = match merge_col.merge_op {
|
||||
IncrementalAggregateMergeOp::Sum => when(is_null(left.clone()), right.clone())
|
||||
.when(is_null(right.clone()), left.clone())
|
||||
@@ -1414,6 +1470,117 @@ mod test {
|
||||
}));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_analyze_incremental_aggregate_plan_preserves_raw_aggregate_name() {
|
||||
let query_engine = create_test_query_engine();
|
||||
let ctx = QueryContext::arc();
|
||||
let sql = "SELECT max(number), ts FROM numbers_with_ts GROUP BY ts";
|
||||
let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap();
|
||||
|
||||
let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap();
|
||||
assert!(analysis.unsupported_exprs.is_empty());
|
||||
assert_eq!(analysis.merge_columns.len(), 1);
|
||||
assert_eq!(
|
||||
analysis.merge_columns[0].output_field_name,
|
||||
"max(numbers_with_ts.number)"
|
||||
);
|
||||
assert_eq!(
|
||||
analysis.merge_columns[0].merge_op,
|
||||
IncrementalAggregateMergeOp::Max
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_analyze_incremental_aggregate_plan_rejects_wrapper_aliased_as_raw_name() {
|
||||
let query_engine = create_test_query_engine();
|
||||
let ctx = QueryContext::arc();
|
||||
let sql = r#"SELECT COALESCE(max(number), 0) AS "max(numbers_with_ts.number)", ts FROM numbers_with_ts GROUP BY ts"#;
|
||||
let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap();
|
||||
|
||||
let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap();
|
||||
assert!(
|
||||
!analysis.unsupported_exprs.is_empty(),
|
||||
"wrapper aliased to a raw aggregate field name must not bypass analysis"
|
||||
);
|
||||
assert!(analysis.merge_columns.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_analyze_incremental_aggregate_plan_supports_count_star() {
|
||||
let query_engine = create_test_query_engine();
|
||||
let ctx = QueryContext::arc();
|
||||
let sql = "SELECT count(*) AS wildcard, ts FROM numbers_with_ts GROUP BY ts";
|
||||
let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap();
|
||||
|
||||
let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap();
|
||||
assert!(analysis.unsupported_exprs.is_empty());
|
||||
assert_eq!(analysis.merge_columns.len(), 1);
|
||||
assert_eq!(analysis.merge_columns[0].output_field_name, "wildcard");
|
||||
assert_eq!(
|
||||
analysis.merge_columns[0].merge_op,
|
||||
IncrementalAggregateMergeOp::Sum
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_analyze_incremental_aggregate_plan_supports_aggregate_input_exprs() {
|
||||
let query_engine = create_test_query_engine();
|
||||
let ctx = QueryContext::arc();
|
||||
let testcases = [
|
||||
"SELECT sum(abs(number)) AS sum_abs, ts FROM numbers_with_ts GROUP BY ts",
|
||||
"SELECT sum(CASE WHEN number > 5 THEN 1 ELSE 0 END) AS above_five, ts FROM numbers_with_ts GROUP BY ts",
|
||||
];
|
||||
|
||||
for sql in testcases {
|
||||
let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false)
|
||||
.await
|
||||
.unwrap();
|
||||
let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap();
|
||||
assert!(
|
||||
analysis.unsupported_exprs.is_empty(),
|
||||
"aggregate input expressions should be mergeable for SQL: {sql}"
|
||||
);
|
||||
assert_eq!(analysis.merge_columns.len(), 1);
|
||||
assert_eq!(
|
||||
analysis.merge_columns[0].merge_op,
|
||||
IncrementalAggregateMergeOp::Sum
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_analyze_incremental_aggregate_plan_rejects_output_expr_wrappers() {
|
||||
let query_engine = create_test_query_engine();
|
||||
let ctx = QueryContext::arc();
|
||||
let testcases = [
|
||||
"SELECT abs(sum(number)) AS abs_sum, ts FROM numbers_with_ts GROUP BY ts",
|
||||
"SELECT max(number) - min(number) AS maxmin, ts FROM numbers_with_ts GROUP BY ts",
|
||||
"SELECT count(number) + 123 AS total_count, ts FROM numbers_with_ts GROUP BY ts",
|
||||
"SELECT sum(CASE WHEN number > 5 THEN 1 ELSE 0 END) / count(number) AS ratio, ts FROM numbers_with_ts GROUP BY ts",
|
||||
];
|
||||
|
||||
for sql in testcases {
|
||||
let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false)
|
||||
.await
|
||||
.unwrap();
|
||||
let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap();
|
||||
assert!(
|
||||
!analysis.unsupported_exprs.is_empty(),
|
||||
"aggregate output wrappers should be rejected for SQL: {sql}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// `qualified_col` keeps a dotted, parenthesized field name as the column name
/// and stores the qualifier separately as the relation.
#[test]
fn test_qualified_col_preserves_non_identifier_field_name() {
    let expr = qualified_col("__flow_delta", "max(numbers_with_ts.number)");
    let Expr::Column(column) = expr else {
        panic!("expected column expression");
    };
    assert_eq!(column.name, "max(numbers_with_ts.number)");
    assert_eq!(column.relation.unwrap().to_string(), "__flow_delta");
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_analyze_incremental_aggregate_plan_multiple_group_keys() {
|
||||
let query_engine = create_test_query_engine();
|
||||
|
||||
Reference in New Issue
Block a user