feat: inc query join rewrite helper

Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
discord9
2026-05-13 15:46:23 +08:00
parent d709fd29ef
commit efd98df923

View File

@@ -19,15 +19,21 @@ use std::sync::Arc;
use catalog::CatalogManagerRef;
use common_error::ext::BoxedError;
use common_function::aggrs::aggr_wrapper::get_aggr_func;
use common_telemetry::debug;
use datafusion::datasource::DefaultTableSource;
use datafusion::error::Result as DfResult;
use datafusion::logical_expr::Expr;
use datafusion::sql::unparser::Unparser;
use datafusion_common::tree_node::{
Transformed, TreeNode as _, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
};
use datafusion_common::{DFSchema, DataFusionError, ScalarValue};
use datafusion_expr::{Distinct, LogicalPlan, Projection};
use datafusion_common::{DFSchema, DataFusionError, ScalarValue, TableReference};
use datafusion_expr::logical_plan::TableScan;
use datafusion_expr::{
Distinct, JoinType, LogicalPlan, LogicalPlanBuilder, Operator, Projection, and, binary_expr,
bitwise_and, bitwise_or, bitwise_xor, col, is_null, or, when,
};
use datatypes::schema::{ColumnSchema, SchemaRef};
use query::QueryEngineRef;
use query::parser::{DEFAULT_LOOKBACK_STRING, PromQuery, QueryLanguageParser, QueryStatement};
@@ -37,12 +43,330 @@ use sql::parser::{ParseOptions, ParserContext};
use sql::statements::statement::Statement;
use sql::statements::tql::Tql;
use table::TableRef;
use table::table::adapter::DfTableProviderAdapter;
use crate::adapter::{AUTO_CREATED_PLACEHOLDER_TS_COL, AUTO_CREATED_UPDATE_AT_TS_COL};
use crate::df_optimizer::apply_df_optimizer;
use crate::error::{DatafusionSnafu, ExternalSnafu, InvalidQuerySnafu, TableNotFoundSnafu};
use crate::{Error, TableName};
/// Describes how one aggregate output field should be merged with the
/// corresponding existing field in the sink table.
///
/// `output_field_name` is the final output/sink schema field name produced by
/// the delta plan and read from the sink table. It is not a DataFusion `Column`
/// reference and must not include a plan/table qualifier.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IncrementalAggregateMergeColumn {
/// Final output/sink field name for the aggregate result/state column.
pub output_field_name: String,
/// Operation used to combine the delta value of this field with the value
/// already stored in the sink table.
pub merge_op: IncrementalAggregateMergeOp,
}
/// Binary operation used to fold a freshly computed (delta) aggregate value
/// into the value already present in the sink table.
///
/// `analyze_incremental_aggregate_plan` picks the variant from the aggregate
/// function name, and `build_left_join_merge_expr` turns it into a `CASE`
/// expression over the delta and sink columns.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IncrementalAggregateMergeOp {
/// Add both values; selected for `sum` and `count` aggregates.
Sum,
/// Keep the smaller value; selected for `min`.
Min,
/// Keep the larger value; selected for `max`.
Max,
/// Logical AND of both values; selected for `bool_and`.
BoolAnd,
/// Logical OR of both values; selected for `bool_or`.
BoolOr,
/// Bitwise AND of both values; selected for `bit_and`.
BitAnd,
/// Bitwise OR of both values; selected for `bit_or`.
BitOr,
/// Bitwise XOR of both values; selected for `bit_xor`.
BitXor,
}
/// Analysis result for an incremental aggregate plan.
///
/// `group_key_names` and each merge column's `output_field_name` are final
/// output/sink schema field names used to project both the delta plan and the
/// sink table before the left-join merge. They are not DataFusion logical-plan
/// `Column` references and must not be qualified.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IncrementalAggregateAnalysis {
/// Final output/sink field names for group keys used as merge join keys.
pub group_key_names: Vec<String>,
/// One entry per aggregate expression that has a known merge operation.
pub merge_columns: Vec<IncrementalAggregateMergeColumn>,
/// Display strings of aggregate expressions that cannot be merged
/// incrementally (unrecognized function, `DISTINCT`, or no known merge
/// operation). `rewrite_incremental_aggregate_with_sink_merge` rejects the
/// analysis when this list is not empty.
pub unsupported_exprs: Vec<String>,
}
/// Plan visitor that remembers the aggregate expressions of the most recently
/// visited `LogicalPlan::Aggregate` node.
#[derive(Default)]
struct LastAggregateExprFinder {
/// `None` until an aggregate node has been visited; afterwards a copy of
/// that node's `aggr_expr` list (overwritten by every later aggregate node).
aggr_exprs: Option<Vec<Expr>>,
}
impl TreeNodeVisitor<'_> for LastAggregateExprFinder {
    type Node = LogicalPlan;

    /// Records the aggregate expressions of `node` when it is an aggregate
    /// node, replacing whatever was recorded before, and always asks the
    /// traversal to continue.
    fn f_down(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
        let LogicalPlan::Aggregate(aggregate) = node else {
            // Not an aggregate: nothing to record, keep walking.
            return Ok(TreeNodeRecursion::Continue);
        };
        self.aggr_exprs = Some(aggregate.aggr_expr.to_vec());
        Ok(TreeNodeRecursion::Continue)
    }
}
/// Inspects `plan` and describes how its aggregate outputs could be merged
/// with rows already present in a sink table.
///
/// Returns `Ok(None)` when no aggregate node is found in the plan. Otherwise
/// the result lists the sorted group-key field names, one merge column per
/// aggregate that has a known merge operation, and the display form of every
/// aggregate that does not (unrecognized function, `DISTINCT`, or a function
/// name without a merge operation).
///
/// Output field names honor aliases only when the root of `plan` is a
/// projection whose expressions are plain columns or aliases of plain columns.
///
/// # Errors
///
/// Returns a DataFusion-context error when visiting the plan fails.
pub fn analyze_incremental_aggregate_plan(
    plan: &LogicalPlan,
) -> Result<Option<IncrementalAggregateAnalysis>, Error> {
    let mut group_by_visitor = FindGroupByFinalName::default();
    plan.visit(&mut group_by_visitor)
        .with_context(|_| DatafusionSnafu {
            context: format!("Failed to inspect group-by columns from logical plan: {plan:?}"),
        })?;

    let mut aggregate_visitor = LastAggregateExprFinder::default();
    plan.visit(&mut aggregate_visitor)
        .with_context(|_| DatafusionSnafu {
            context: format!("Failed to inspect aggregate expressions from logical plan: {plan:?}"),
        })?;
    let Some(aggr_exprs) = aggregate_visitor.aggr_exprs else {
        return Ok(None);
    };

    // Maps the column name referenced by the root projection to the name it
    // is exposed under in the final output.
    let mut alias_by_column: HashMap<String, String> = HashMap::new();
    if let LogicalPlan::Projection(projection) = plan {
        for expr in &projection.expr {
            let (source_name, exposed_name) = match expr {
                Expr::Alias(alias) => match alias.expr.as_ref() {
                    Expr::Column(column) => (&column.name, &alias.name),
                    _ => continue,
                },
                Expr::Column(column) => (&column.name, &column.name),
                _ => continue,
            };
            alias_by_column.insert(source_name.clone(), exposed_name.clone());
        }
    }

    let mut group_key_names: Vec<String> = group_by_visitor
        .get_group_expr_names()
        .unwrap_or_default()
        .into_iter()
        .collect();
    group_key_names.sort();

    let mut merge_columns = Vec::with_capacity(aggr_exprs.len());
    let mut unsupported_exprs = Vec::new();
    for aggr_expr in aggr_exprs {
        // `None` means the expression cannot be merged incrementally.
        let merge_op = match get_aggr_func(&aggr_expr) {
            Some(aggr_func) if !aggr_func.params.distinct => {
                match aggr_func.func.name().to_ascii_lowercase().as_str() {
                    "sum" | "count" => Some(IncrementalAggregateMergeOp::Sum),
                    "min" => Some(IncrementalAggregateMergeOp::Min),
                    "max" => Some(IncrementalAggregateMergeOp::Max),
                    "bool_and" => Some(IncrementalAggregateMergeOp::BoolAnd),
                    "bool_or" => Some(IncrementalAggregateMergeOp::BoolOr),
                    "bit_and" => Some(IncrementalAggregateMergeOp::BitAnd),
                    "bit_or" => Some(IncrementalAggregateMergeOp::BitOr),
                    "bit_xor" => Some(IncrementalAggregateMergeOp::BitXor),
                    _ => None,
                }
            }
            _ => None,
        };

        match merge_op {
            Some(merge_op) => {
                let raw_name = aggr_expr.qualified_name().1;
                let output_field_name = match alias_by_column.get(&raw_name) {
                    Some(exposed_name) => exposed_name.clone(),
                    None => raw_name,
                };
                merge_columns.push(IncrementalAggregateMergeColumn {
                    output_field_name,
                    merge_op,
                });
            }
            None => unsupported_exprs.push(aggr_expr.to_string()),
        }
    }

    Ok(Some(IncrementalAggregateAnalysis {
        group_key_names,
        merge_columns,
        unsupported_exprs,
    }))
}
pub async fn rewrite_incremental_aggregate_with_sink_merge(
delta_plan: &LogicalPlan,
analysis: &IncrementalAggregateAnalysis,
sink_table: TableRef,
sink_table_name: &TableName,
) -> Result<LogicalPlan, Error> {
ensure!(
analysis.unsupported_exprs.is_empty(),
InvalidQuerySnafu {
reason: format!(
"UNSUPPORTED_INCREMENTAL_AGG: unsupported aggregate expressions {:?}",
analysis.unsupported_exprs
)
}
);
ensure!(
!analysis.merge_columns.is_empty(),
InvalidQuerySnafu {
reason:
"UNSUPPORTED_INCREMENTAL_AGG: aggregate query has no mergeable aggregate columns"
.to_string()
}
);
let delta_alias = "__flow_delta";
let sink_alias = "__flow_sink";
let mut selected_columns = analysis.group_key_names.clone();
selected_columns.extend(
analysis
.merge_columns
.iter()
.map(|c| c.output_field_name.clone()),
);
let selected_exprs = selected_columns.iter().map(col).collect::<Vec<_>>();
let delta_selected = LogicalPlanBuilder::from(delta_plan.clone())
.project(selected_exprs.clone())
.with_context(|_| DatafusionSnafu {
context: "Failed to project delta plan for incremental sink merge".to_string(),
})?
.alias(delta_alias)
.with_context(|_| DatafusionSnafu {
context: "Failed to alias delta plan for incremental sink merge".to_string(),
})?
.build()
.with_context(|_| DatafusionSnafu {
context: "Failed to build projected delta plan for incremental sink merge".to_string(),
})?;
let table_provider = Arc::new(DfTableProviderAdapter::new(sink_table));
let table_source = Arc::new(DefaultTableSource::new(table_provider));
let sink_scan = LogicalPlan::TableScan(
TableScan::try_new(
TableReference::Full {
catalog: sink_table_name[0].clone().into(),
schema: sink_table_name[1].clone().into(),
table: sink_table_name[2].clone().into(),
},
table_source,
None,
vec![],
None,
)
.with_context(|_| DatafusionSnafu {
context: "Failed to build sink table scan for incremental sink merge".to_string(),
})?,
);
let sink_selected = LogicalPlanBuilder::from(sink_scan)
.project(selected_exprs)
.with_context(|_| DatafusionSnafu {
context: "Failed to project sink table scan for incremental sink merge".to_string(),
})?
.alias(sink_alias)
.with_context(|_| DatafusionSnafu {
context: "Failed to alias sink plan for incremental sink merge".to_string(),
})?
.build()
.with_context(|_| DatafusionSnafu {
context: "Failed to build projected sink plan for incremental sink merge".to_string(),
})?;
let join_keys = (
analysis
.group_key_names
.iter()
.map(|c| datafusion_common::Column::from_qualified_name(format!("{delta_alias}.{c}")))
.collect::<Vec<_>>(),
analysis
.group_key_names
.iter()
.map(|c| datafusion_common::Column::from_qualified_name(format!("{sink_alias}.{c}")))
.collect::<Vec<_>>(),
);
let joined = LogicalPlanBuilder::from(delta_selected)
.join(sink_selected, JoinType::Left, join_keys, None)
.with_context(|_| DatafusionSnafu {
context: "Failed to left join delta and sink plans for incremental sink merge"
.to_string(),
})?
.build()
.with_context(|_| DatafusionSnafu {
context: "Failed to build left join plan for incremental sink merge".to_string(),
})?;
let mut projection_exprs = analysis
.group_key_names
.iter()
.map(|c| col(format!("{delta_alias}.{c}")).alias(c.clone()))
.collect::<Vec<_>>();
projection_exprs.extend(
analysis
.merge_columns
.iter()
.map(|merge_col| build_left_join_merge_expr(delta_alias, sink_alias, merge_col)),
);
LogicalPlanBuilder::from(joined)
.project(projection_exprs)
.with_context(|_| DatafusionSnafu {
context: "Failed to build projection merge plan for incremental sink merge".to_string(),
})?
.build()
.with_context(|_| DatafusionSnafu {
context: "Failed to finalize incremental aggregate sink merge plan".to_string(),
})
}
fn build_left_join_merge_expr(
delta_alias: &str,
sink_alias: &str,
merge_col: &IncrementalAggregateMergeColumn,
) -> Expr {
let left = col(format!("{delta_alias}.{}", merge_col.output_field_name));
let right = col(format!("{sink_alias}.{}", merge_col.output_field_name));
let merged = match merge_col.merge_op {
IncrementalAggregateMergeOp::Sum => when(is_null(left.clone()), right.clone())
.when(is_null(right.clone()), left.clone())
.otherwise(binary_expr(left.clone(), Operator::Plus, right.clone()))
.unwrap(),
IncrementalAggregateMergeOp::Min => when(is_null(right.clone()), left.clone())
.when(left.clone().lt_eq(right.clone()), left.clone())
.otherwise(right.clone())
.unwrap(),
IncrementalAggregateMergeOp::Max => when(is_null(right.clone()), left.clone())
.when(left.clone().gt_eq(right.clone()), left.clone())
.otherwise(right.clone())
.unwrap(),
IncrementalAggregateMergeOp::BoolAnd => when(is_null(left.clone()), right.clone())
.when(is_null(right.clone()), left.clone())
.otherwise(and(left.clone(), right.clone()))
.unwrap(),
IncrementalAggregateMergeOp::BoolOr => when(is_null(left.clone()), right.clone())
.when(is_null(right.clone()), left.clone())
.otherwise(or(left.clone(), right.clone()))
.unwrap(),
IncrementalAggregateMergeOp::BitAnd => when(is_null(left.clone()), right.clone())
.when(is_null(right.clone()), left.clone())
.otherwise(bitwise_and(left.clone(), right.clone()))
.unwrap(),
IncrementalAggregateMergeOp::BitOr => when(is_null(left.clone()), right.clone())
.when(is_null(right.clone()), left.clone())
.otherwise(bitwise_or(left.clone(), right.clone()))
.unwrap(),
IncrementalAggregateMergeOp::BitXor => when(is_null(left.clone()), right.clone())
.when(is_null(right.clone()), left.clone())
.otherwise(bitwise_xor(left.clone(), right.clone()))
.unwrap(),
};
merged.alias(merge_col.output_field_name.clone())
}
pub async fn get_table_info_df_schema(
catalog_mr: CatalogManagerRef,
table_name: TableName,
@@ -907,6 +1231,136 @@ mod test {
}
}
#[tokio::test]
async fn test_analyze_incremental_aggregate_plan() {
let query_engine = create_test_query_engine();
let ctx = QueryContext::arc();
let testcases = vec![
(
"SELECT sum(number) AS number, ts FROM numbers_with_ts GROUP BY ts",
IncrementalAggregateMergeOp::Sum,
),
(
"SELECT count(number) AS number, ts FROM numbers_with_ts GROUP BY ts",
IncrementalAggregateMergeOp::Sum,
),
(
"SELECT min(number) AS number, ts FROM numbers_with_ts GROUP BY ts",
IncrementalAggregateMergeOp::Min,
),
(
"SELECT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts",
IncrementalAggregateMergeOp::Max,
),
(
"SELECT bit_and(number) AS number, ts FROM numbers_with_ts GROUP BY ts",
IncrementalAggregateMergeOp::BitAnd,
),
(
"SELECT bit_or(number) AS number, ts FROM numbers_with_ts GROUP BY ts",
IncrementalAggregateMergeOp::BitOr,
),
(
"SELECT bit_xor(number) AS number, ts FROM numbers_with_ts GROUP BY ts",
IncrementalAggregateMergeOp::BitXor,
),
];
for (sql, expected_op) in testcases {
let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false)
.await
.unwrap();
let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap();
assert!(analysis.unsupported_exprs.is_empty());
assert!(analysis.group_key_names.contains(&"ts".to_string()));
assert_eq!(analysis.merge_columns.len(), 1);
assert_eq!(analysis.merge_columns[0].output_field_name, "number");
assert_eq!(analysis.merge_columns[0].merge_op, expected_op);
}
}
#[tokio::test]
async fn test_analyze_incremental_aggregate_plan_keeps_aliases_for_multiple_aggregates() {
let query_engine = create_test_query_engine();
let ctx = QueryContext::arc();
let sql = "SELECT max(number) AS max_number, min(number) AS min_number, ts FROM numbers_with_ts GROUP BY ts";
let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap();
let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap();
assert!(analysis.unsupported_exprs.is_empty());
assert!(analysis.group_key_names.contains(&"ts".to_string()));
assert_eq!(analysis.merge_columns.len(), 2);
assert!(analysis.merge_columns.iter().any(|merge_col| {
merge_col.output_field_name == "max_number"
&& merge_col.merge_op == IncrementalAggregateMergeOp::Max
}));
assert!(analysis.merge_columns.iter().any(|merge_col| {
merge_col.output_field_name == "min_number"
&& merge_col.merge_op == IncrementalAggregateMergeOp::Min
}));
}
#[tokio::test]
async fn test_analyze_incremental_aggregate_plan_rejects_avg() {
let query_engine = create_test_query_engine();
let ctx = QueryContext::arc();
let sql = "SELECT avg(number) AS avg_num, ts FROM numbers_with_ts GROUP BY ts";
let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap();
let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap();
assert!(!analysis.unsupported_exprs.is_empty());
}
#[tokio::test]
async fn test_analyze_incremental_aggregate_plan_rejects_distinct() {
let query_engine = create_test_query_engine();
let ctx = QueryContext::arc();
let sql = "SELECT count(distinct number) AS number, ts FROM numbers_with_ts GROUP BY ts";
let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap();
let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap();
assert!(!analysis.unsupported_exprs.is_empty());
}
#[tokio::test]
async fn test_rewrite_incremental_aggregate_with_left_join() {
let query_engine = create_test_query_engine();
let ctx = QueryContext::arc();
let sql = "SELECT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts";
let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false)
.await
.unwrap();
let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap();
let (sink_table, _) = get_table_info_df_schema(
query_engine.engine_state().catalog_manager().clone(),
[
"greptime".to_string(),
"public".to_string(),
"numbers_with_ts".to_string(),
],
)
.await
.unwrap();
let rewritten = rewrite_incremental_aggregate_with_sink_merge(
&plan,
&analysis,
sink_table,
&[
"greptime".to_string(),
"public".to_string(),
"numbers_with_ts".to_string(),
],
)
.await
.unwrap();
let plan_text = format!("{}", rewritten.display_indent());
assert!(plan_text.contains("Left Join"));
assert!(!plan_text.contains("Union"));
}
#[tokio::test]
async fn test_null_cast() {
let query_engine = create_test_query_engine();