1use std::collections::{BTreeSet, HashMap, HashSet};
18use std::sync::Arc;
19
20use catalog::CatalogManagerRef;
21use common_error::ext::BoxedError;
22use common_function::aggrs::aggr_wrapper::get_aggr_func;
23use common_telemetry::debug;
24use datafusion::datasource::DefaultTableSource;
25use datafusion::error::Result as DfResult;
26use datafusion::logical_expr::Expr;
27use datafusion::sql::unparser::Unparser;
28use datafusion_common::tree_node::{
29 Transformed, TreeNode as _, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
30};
31use datafusion_common::{
32 Column, DFSchema, DataFusionError, NullEquality, ScalarValue, TableReference,
33};
34use datafusion_expr::logical_plan::{Aggregate, TableScan};
35use datafusion_expr::{
36 Distinct, JoinType, LogicalPlan, LogicalPlanBuilder, Operator, Projection, and, binary_expr,
37 bitwise_and, bitwise_or, bitwise_xor, is_null, or, when,
38};
39use datatypes::schema::{ColumnSchema, SchemaRef};
40use query::QueryEngineRef;
41use query::parser::{DEFAULT_LOOKBACK_STRING, PromQuery, QueryLanguageParser, QueryStatement};
42use session::context::QueryContextRef;
43use snafu::{OptionExt, ResultExt, ensure};
44use sql::parser::{ParseOptions, ParserContext};
45use sql::statements::statement::Statement;
46use sql::statements::tql::Tql;
47use table::TableRef;
48use table::table::adapter::DfTableProviderAdapter;
49
50use crate::adapter::{AUTO_CREATED_PLACEHOLDER_TS_COL, AUTO_CREATED_UPDATE_AT_TS_COL};
51use crate::df_optimizer::apply_df_optimizer;
52use crate::error::{DatafusionSnafu, ExternalSnafu, InvalidQuerySnafu, TableNotFoundSnafu};
53use crate::{Error, TableName};
54
55#[cfg(test)]
56mod test;
57
58#[derive(Debug, Clone, PartialEq, Eq)]
67pub struct IncrementalAggregateMergeColumn {
68 pub output_field_name: String,
71 pub merge_op: IncrementalAggregateMergeOp,
72}
73
74impl IncrementalAggregateMergeColumn {
75 pub fn new(output_field_name: String, merge_op: IncrementalAggregateMergeOp) -> Self {
77 Self {
78 output_field_name,
79 merge_op,
80 }
81 }
82}
83
/// Operator used to fold a delta aggregate value into the value already stored
/// in the sink table.
///
/// In the generated merge expression, a NULL on either side is treated as
/// absent and the other side is kept.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum IncrementalAggregateMergeOp {
    /// Add both values; used for `SUM` and `COUNT`.
    Sum,
    /// Keep the smaller value (`MIN`).
    Min,
    /// Keep the larger value (`MAX`).
    Max,
    /// Logical AND of both values (`BOOL_AND`).
    BoolAnd,
    /// Logical OR of both values (`BOOL_OR`).
    BoolOr,
    /// Bitwise AND of both values (`BIT_AND`).
    BitAnd,
    /// Bitwise OR of both values (`BIT_OR`).
    BitOr,
    /// Bitwise XOR of both values (`BIT_XOR`).
    BitXor,
}
95
96#[derive(Debug, Clone, PartialEq, Eq)]
104pub struct IncrementalAggregateAnalysis {
105 pub group_key_names: Vec<String>,
107 pub merge_columns: Vec<IncrementalAggregateMergeColumn>,
108 pub literal_columns: Vec<String>,
110 pub output_field_names: Vec<String>,
112 pub unsupported_exprs: Vec<String>,
113}
114
115fn find_column_names(expr: &Expr, names: &mut Vec<String>) {
124 match expr {
125 Expr::Column(col) => {
126 names.push(col.name.clone());
127 }
128 Expr::Alias(alias) => find_column_names(&alias.expr, names),
129 _ => {}
130 }
131}
132
133fn unqualified_col(name: impl Into<String>) -> Expr {
134 Expr::Column(Column::from_name(name.into()))
135}
136
137fn qualified_col(qualifier: &str, name: impl Into<String>) -> Expr {
138 Expr::Column(Column::new(Some(qualifier), name.into()))
139}
140
141fn qualified_column(qualifier: &str, name: impl Into<String>) -> Column {
142 Column::new(Some(qualifier), name.into())
143}
144
145fn find_group_key_names(plan: &LogicalPlan) -> Result<Vec<String>, Error> {
146 let mut group_finder = FindGroupByFinalName::default();
147 plan.visit(&mut group_finder)
148 .with_context(|_| DatafusionSnafu {
149 context: format!("Failed to inspect group-by columns from logical plan: {plan:?}"),
150 })?;
151
152 let mut group_key_names = group_finder
153 .get_group_expr_names()
154 .unwrap_or_default()
155 .into_iter()
156 .collect::<Vec<_>>();
157 group_key_names.sort();
158 Ok(group_key_names)
159}
160
161fn has_grouping_set(plan: &LogicalPlan) -> bool {
162 match plan {
163 LogicalPlan::Aggregate(aggregate) => aggregate
164 .group_expr
165 .iter()
166 .any(|expr| matches!(expr, Expr::GroupingSet(_))),
167 _ => plan.inputs().into_iter().any(has_grouping_set),
168 }
169}
170
171fn has_aggregate(plan: &LogicalPlan) -> bool {
172 match plan {
173 LogicalPlan::Aggregate(_) => true,
174 _ => plan.inputs().into_iter().any(has_aggregate),
175 }
176}
177
178fn peel_subquery_aliases(mut plan: &LogicalPlan) -> &LogicalPlan {
179 while let LogicalPlan::SubqueryAlias(alias) = plan {
180 plan = alias.input.as_ref();
181 }
182 plan
183}
184
185fn extract_incremental_aggregate(plan: &LogicalPlan) -> Result<Option<&Aggregate>, String> {
186 let plan = match plan {
190 LogicalPlan::Projection(projection) => projection.input.as_ref(),
191 _ => plan,
192 };
193
194 match plan {
195 LogicalPlan::Aggregate(aggregate) => {
196 check_input_plan_shape(aggregate.input.as_ref())?;
197 Ok(Some(aggregate))
198 }
199 LogicalPlan::Filter(filter) if has_aggregate(filter.input.as_ref()) => Err(
200 "unsupported post-aggregate filter (HAVING) in incremental aggregate rewrite"
201 .to_string(),
202 ),
203 _ if has_aggregate(plan) => Err(
204 "unsupported post-aggregate plan shape in incremental aggregate rewrite".to_string(),
205 ),
206 _ => Ok(None),
207 }
208}
209
210fn check_input_plan_shape(plan: &LogicalPlan) -> Result<(), String> {
211 let plan = peel_subquery_aliases(plan);
212 match plan {
213 LogicalPlan::TableScan(_) => Ok(()),
216 LogicalPlan::Filter(filter) => match peel_subquery_aliases(filter.input.as_ref()) {
217 LogicalPlan::TableScan(_) => Ok(()),
218 _ => Err(
219 "unsupported aggregate input plan shape in incremental aggregate rewrite"
220 .to_string(),
221 ),
222 },
223 _ => Err(
224 "unsupported aggregate input plan shape in incremental aggregate rewrite".to_string(),
225 ),
226 }
227}
228
/// Facts about a plan's top-level output that the incremental aggregate
/// analysis relies on; filled in by `collect_output_projection_info`.
#[derive(Default, Debug)]
struct OutputProjectionInfo {
    /// Whether the plan root is a `Projection`.
    has_top_level_projection: bool,
    /// Projected column name -> name it is exposed under in the output.
    output_aliases: HashMap<String, String>,
    /// Reason strings for columns that are exposed under more than one alias.
    duplicate_aggregate_aliases: BTreeSet<String>,
    /// Output names of literal projection expressions.
    literal_columns: HashSet<String>,
    /// Output field names of the plan, in schema order.
    output_field_names: Vec<String>,
}
237
238impl OutputProjectionInfo {
239 fn output_field_name_set(&self) -> HashSet<String> {
240 self.output_field_names.iter().cloned().collect()
241 }
242
243 fn duplicate_output_names(&self) -> Vec<String> {
244 let mut seen = HashSet::new();
245 let mut duplicates = BTreeSet::new();
246 for name in &self.output_field_names {
247 if !seen.insert(name.clone()) {
248 duplicates.insert(name.clone());
249 }
250 }
251 duplicates.into_iter().collect()
252 }
253}
254
/// Summarises the top-level output of `plan` for incremental aggregate analysis.
///
/// `output_field_names` always comes from `plan.schema()`. Only when `plan` is a
/// `Projection` are its expressions walked to fill in:
/// - `output_aliases`: projected column name -> the name it is exposed under
///   (the alias, or the column's own name when projected bare);
/// - `duplicate_aggregate_aliases`: a reason string when a column that already
///   has an alias is aliased again under a different name;
/// - `literal_columns`: output names of literal expressions, bare or aliased.
fn collect_output_projection_info(plan: &LogicalPlan) -> OutputProjectionInfo {
    let mut projection_info = OutputProjectionInfo {
        has_top_level_projection: matches!(plan, LogicalPlan::Projection(_)),
        output_field_names: plan
            .schema()
            .fields()
            .iter()
            .map(|field| field.name().clone())
            .collect(),
        ..Default::default()
    };

    let mut output_aliases = HashMap::new();
    if let LogicalPlan::Projection(projection) = plan {
        for expr in &projection.expr {
            match expr {
                Expr::Alias(alias) => {
                    let alias_name = alias.name.clone();
                    // `find_column_names` yields a name only for a bare (possibly
                    // aliased) column; computed expressions yield none.
                    let mut col_names = Vec::new();
                    find_column_names(&alias.expr, &mut col_names);
                    match col_names.len() {
                        0 if matches!(alias.expr.as_ref(), Expr::Literal(_, _)) => {
                            projection_info.literal_columns.insert(alias_name);
                        }
                        1 => {
                            if let Some(col_name) = col_names.into_iter().next() {
                                if let Some(existing_alias) = output_aliases.get(&col_name) {
                                    // The same column re-exposed under another name is
                                    // reported; the first alias seen is kept.
                                    if existing_alias != &alias_name {
                                        projection_info.duplicate_aggregate_aliases.insert(format!(
                                            "same aggregate output {col_name} is used by multiple aliases: {existing_alias}, {alias_name}"
                                        ));
                                    }
                                } else {
                                    output_aliases.insert(col_name, alias_name);
                                }
                            }
                        }
                        // Aliased computed expressions are not recorded here.
                        _ => {}
                    }

                }
                Expr::Column(col) => {
                    // A bare column maps to its own name unless an earlier alias already
                    // claimed it (`or_insert` keeps the existing entry, and no duplicate
                    // reason is recorded in that order).
                    output_aliases
                        .entry(col.name.clone())
                        .or_insert(col.name.clone());
                }
                Expr::Literal(_, _) => {
                    projection_info
                        .literal_columns
                        .insert(expr.qualified_name().1);
                }
                _ => {}
            }
        }
    }

    projection_info.output_aliases = output_aliases;
    projection_info
}
321
322fn merge_op_for_aggregate_expr(aggr_expr: &Expr) -> Result<IncrementalAggregateMergeOp, String> {
323 let Some(aggr_func) = get_aggr_func(aggr_expr) else {
324 return Err(aggr_expr.to_string());
325 };
326 if aggr_func.params.distinct {
327 return Err(format!("unsupported DISTINCT aggregate: {aggr_expr}"));
328 }
329 if !aggr_func.params.order_by.is_empty() {
330 return Err(format!("unsupported aggregate ORDER BY: {aggr_expr}"));
331 }
332 if aggr_func.params.null_treatment.is_some() {
333 return Err(format!("unsupported aggregate NULL treatment: {aggr_expr}"));
334 }
335
336 match aggr_func.func.name().to_ascii_lowercase().as_str() {
337 "sum" | "count" => Ok(IncrementalAggregateMergeOp::Sum),
338 "min" => Ok(IncrementalAggregateMergeOp::Min),
339 "max" => Ok(IncrementalAggregateMergeOp::Max),
340 "bool_and" => Ok(IncrementalAggregateMergeOp::BoolAnd),
341 "bool_or" => Ok(IncrementalAggregateMergeOp::BoolOr),
342 "bit_and" => Ok(IncrementalAggregateMergeOp::BitAnd),
343 "bit_or" => Ok(IncrementalAggregateMergeOp::BitOr),
344 "bit_xor" => Ok(IncrementalAggregateMergeOp::BitXor),
345 _ => Err(aggr_expr.to_string()),
346 }
347}
348
349fn resolve_aggregate_output_field_name(
350 aggr_expr: &Expr,
351 projection_info: &OutputProjectionInfo,
352 output_field_name_set: &HashSet<String>,
353) -> Option<String> {
354 let raw_name = aggr_expr.qualified_name().1;
360 if let Some(alias) = projection_info.output_aliases.get(&raw_name) {
361 Some(alias.clone())
362 } else if !projection_info.has_top_level_projection && output_field_name_set.contains(&raw_name)
363 {
364 Some(raw_name)
365 } else {
366 None
367 }
368}
369
370fn find_uncovered_output_fields(
371 projection_info: &OutputProjectionInfo,
372 group_key_names: &[String],
373 merge_columns: &[IncrementalAggregateMergeColumn],
374) -> Vec<String> {
375 let group_key_names = group_key_names.iter().cloned().collect::<HashSet<_>>();
376 let merge_column_names = merge_columns
377 .iter()
378 .map(|c| c.output_field_name.clone())
379 .collect::<HashSet<_>>();
380
381 projection_info
382 .output_field_names
383 .iter()
384 .filter(|name| {
385 !group_key_names.contains(*name)
386 && !merge_column_names.contains(*name)
387 && !projection_info.literal_columns.contains(*name)
388 })
389 .cloned()
390 .collect()
391}
392
393fn find_unsupported_group_key_projection_outputs(
394 plan: &LogicalPlan,
395 aggregate: &Aggregate,
396 group_key_names: &[String],
397) -> Vec<String> {
398 let LogicalPlan::Projection(projection) = plan else {
399 return vec![];
400 };
401
402 let group_key_names = group_key_names.iter().cloned().collect::<HashSet<_>>();
403 let group_expr_names = aggregate
404 .group_expr
405 .iter()
406 .filter_map(|expr| expr.name_for_alias().ok())
407 .collect::<HashSet<_>>();
408 projection
409 .expr
410 .iter()
411 .filter_map(|expr| {
412 let output_name = expr.qualified_name().1;
413 if !group_key_names.contains(&output_name) {
414 return None;
415 }
416
417 let source_name = match expr {
418 Expr::Alias(alias) => alias.expr.name_for_alias().ok(),
419 _ => expr.name_for_alias().ok(),
420 };
421 if source_name.is_some_and(|name| group_expr_names.contains(&name)) {
422 None
423 } else {
424 Some(format!(
425 "unsupported group key output field is not a transparent group expression: {output_name}"
426 ))
427 }
428 })
429 .collect()
430}
431
432pub fn analyze_incremental_aggregate_plan(
433 plan: &LogicalPlan,
434) -> Result<Option<IncrementalAggregateAnalysis>, Error> {
435 let group_key_names = find_group_key_names(plan)?;
436 let aggregate = match extract_incremental_aggregate(plan) {
437 Ok(Some(aggregate)) => aggregate,
438 Ok(None) => return Ok(None),
439 Err(reason) => {
440 let projection_info = collect_output_projection_info(plan);
441 let mut unsupported_exprs = projection_info
442 .duplicate_output_names()
443 .into_iter()
444 .map(|name| format!("duplicate output field name: {name}"))
445 .collect::<Vec<_>>();
446 unsupported_exprs.push(reason);
447 unsupported_exprs.extend(projection_info.duplicate_aggregate_aliases.iter().cloned());
448 return Ok(Some(IncrementalAggregateAnalysis {
449 group_key_names,
450 merge_columns: vec![],
451 literal_columns: vec![],
452 output_field_names: projection_info.output_field_names,
453 unsupported_exprs,
454 }));
455 }
456 };
457 let aggr_exprs = aggregate.aggr_expr.clone();
458 let projection_info = collect_output_projection_info(plan);
459 let output_field_name_set = projection_info.output_field_name_set();
460
461 let mut merge_columns = Vec::with_capacity(aggr_exprs.len());
462 let mut unsupported_exprs = projection_info
463 .duplicate_output_names()
464 .into_iter()
465 .map(|name| format!("duplicate output field name: {name}"))
466 .collect::<Vec<_>>();
467 if has_grouping_set(plan) {
468 unsupported_exprs.push(
469 "unsupported GROUPING SETS/CUBE/ROLLUP in incremental aggregate rewrite".to_string(),
470 );
471 }
472 if group_key_names.is_empty() {
473 unsupported_exprs
474 .push("unsupported global aggregate in incremental aggregate rewrite".to_string());
475 }
476 unsupported_exprs.extend(find_unsupported_group_key_projection_outputs(
477 plan,
478 aggregate,
479 &group_key_names,
480 ));
481 unsupported_exprs.extend(projection_info.duplicate_aggregate_aliases.iter().cloned());
482 for aggr_expr in aggr_exprs {
483 let merge_op = match merge_op_for_aggregate_expr(&aggr_expr) {
484 Ok(merge_op) => merge_op,
485 Err(reason) => {
486 unsupported_exprs.push(reason);
487 continue;
488 }
489 };
490 let Some(output_field_name) = resolve_aggregate_output_field_name(
491 &aggr_expr,
492 &projection_info,
493 &output_field_name_set,
494 ) else {
495 unsupported_exprs.push(aggr_expr.to_string());
496 continue;
497 };
498 merge_columns.push(IncrementalAggregateMergeColumn::new(
499 output_field_name,
500 merge_op,
501 ));
502 }
503 unsupported_exprs.extend(
504 find_uncovered_output_fields(&projection_info, &group_key_names, &merge_columns)
505 .into_iter()
506 .map(|name| format!("unsupported output field: {name}")),
507 );
508 if !unsupported_exprs.is_empty() {
509 merge_columns.clear();
510 }
511 let mut literal_columns = projection_info
512 .literal_columns
513 .into_iter()
514 .collect::<Vec<_>>();
515 literal_columns.sort();
516
517 Ok(Some(IncrementalAggregateAnalysis {
518 group_key_names,
519 merge_columns,
520 literal_columns,
521 output_field_names: projection_info.output_field_names,
522 unsupported_exprs,
523 }))
524}
525
526pub async fn rewrite_incremental_aggregate_with_sink_merge(
552 delta_plan: &LogicalPlan,
553 analysis: &IncrementalAggregateAnalysis,
554 sink_table: TableRef,
555 sink_table_name: &TableName,
556) -> Result<LogicalPlan, Error> {
557 ensure!(
558 analysis.unsupported_exprs.is_empty(),
559 InvalidQuerySnafu {
560 reason: format!(
561 "UNSUPPORTED_INCREMENTAL_AGG: unsupported aggregate expressions {:?}",
562 analysis.unsupported_exprs
563 )
564 }
565 );
566
567 ensure!(
568 !analysis.merge_columns.is_empty(),
569 InvalidQuerySnafu {
570 reason:
571 "UNSUPPORTED_INCREMENTAL_AGG: aggregate query has no mergeable aggregate columns"
572 .to_string()
573 }
574 );
575
576 ensure!(
577 !analysis.group_key_names.is_empty(),
578 InvalidQuerySnafu {
579 reason: "UNSUPPORTED_INCREMENTAL_AGG: global aggregate query is not supported"
580 .to_string()
581 }
582 );
583
584 let delta_alias = "__flow_delta";
585 let sink_alias = "__flow_sink";
586
587 let mut selected_columns = analysis.group_key_names.clone();
588 selected_columns.extend(
589 analysis
590 .merge_columns
591 .iter()
592 .map(|c| c.output_field_name.clone()),
593 );
594 let mut delta_selected_columns = selected_columns.clone();
595 delta_selected_columns.extend(analysis.literal_columns.iter().cloned());
596
597 let delta_selected_exprs = delta_selected_columns
598 .iter()
599 .cloned()
600 .map(unqualified_col)
601 .collect::<Vec<_>>();
602 let delta_selected = LogicalPlanBuilder::from(delta_plan.clone())
603 .project(delta_selected_exprs)
604 .with_context(|_| DatafusionSnafu {
605 context: "Failed to project delta plan for incremental sink merge".to_string(),
606 })?
607 .alias(delta_alias)
608 .with_context(|_| DatafusionSnafu {
609 context: "Failed to alias delta plan for incremental sink merge".to_string(),
610 })?
611 .build()
612 .with_context(|_| DatafusionSnafu {
613 context: "Failed to build projected delta plan for incremental sink merge".to_string(),
614 })?;
615
616 let table_provider = Arc::new(DfTableProviderAdapter::new(sink_table));
617 let table_source = Arc::new(DefaultTableSource::new(table_provider));
618 let sink_scan = LogicalPlan::TableScan(
619 TableScan::try_new(
620 TableReference::Full {
621 catalog: sink_table_name[0].clone().into(),
622 schema: sink_table_name[1].clone().into(),
623 table: sink_table_name[2].clone().into(),
624 },
625 table_source,
626 None,
627 vec![],
628 None,
629 )
630 .with_context(|_| DatafusionSnafu {
631 context: "Failed to build sink table scan for incremental sink merge".to_string(),
632 })?,
633 );
634
635 let sink_selected_exprs = selected_columns
636 .iter()
637 .cloned()
638 .map(unqualified_col)
639 .collect::<Vec<_>>();
640 let sink_selected = LogicalPlanBuilder::from(sink_scan)
641 .project(sink_selected_exprs)
642 .with_context(|_| DatafusionSnafu {
643 context: "Failed to project sink table scan for incremental sink merge".to_string(),
644 })?
645 .alias(sink_alias)
646 .with_context(|_| DatafusionSnafu {
647 context: "Failed to alias sink plan for incremental sink merge".to_string(),
648 })?
649 .build()
650 .with_context(|_| DatafusionSnafu {
651 context: "Failed to build projected sink plan for incremental sink merge".to_string(),
652 })?;
653
654 let join_keys = (
655 analysis
656 .group_key_names
657 .iter()
658 .cloned()
659 .map(|c| qualified_column(delta_alias, c))
660 .collect::<Vec<_>>(),
661 analysis
662 .group_key_names
663 .iter()
664 .cloned()
665 .map(|c| qualified_column(sink_alias, c))
666 .collect::<Vec<_>>(),
667 );
668
669 let joined = LogicalPlanBuilder::from(delta_selected)
670 .join_detailed(
671 sink_selected,
672 JoinType::Left,
673 join_keys,
674 None,
675 NullEquality::NullEqualsNull,
676 )
677 .with_context(|_| DatafusionSnafu {
678 context: "Failed to left join delta and sink plans for incremental sink merge"
679 .to_string(),
680 })?
681 .build()
682 .with_context(|_| DatafusionSnafu {
683 context: "Failed to build left join plan for incremental sink merge".to_string(),
684 })?;
685
686 let group_key_names = analysis.group_key_names.iter().collect::<HashSet<_>>();
687 let literal_columns = analysis.literal_columns.iter().collect::<HashSet<_>>();
688 let merge_columns = analysis
689 .merge_columns
690 .iter()
691 .map(|c| (&c.output_field_name, c))
692 .collect::<HashMap<_, _>>();
693
694 let mut projection_exprs = Vec::with_capacity(analysis.output_field_names.len());
695 for output_field_name in &analysis.output_field_names {
696 if group_key_names.contains(output_field_name)
697 || literal_columns.contains(output_field_name)
698 {
699 projection_exprs.push(
700 qualified_col(delta_alias, output_field_name.clone()).alias(output_field_name),
701 );
702 } else if let Some(merge_col) = merge_columns.get(output_field_name) {
703 projection_exprs.push(build_left_join_merge_expr(
704 delta_alias,
705 sink_alias,
706 merge_col,
707 )?);
708 } else {
709 return InvalidQuerySnafu {
710 reason: format!(
711 "UNSUPPORTED_INCREMENTAL_AGG: output field {output_field_name} is not covered by group keys, literals, or merge columns"
712 ),
713 }
714 .fail();
715 }
716 }
717
718 LogicalPlanBuilder::from(joined)
719 .project(projection_exprs)
720 .with_context(|_| DatafusionSnafu {
721 context: "Failed to build projection merge plan for incremental sink merge".to_string(),
722 })?
723 .build()
724 .with_context(|_| DatafusionSnafu {
725 context: "Failed to finalize incremental aggregate sink merge plan".to_string(),
726 })
727}
728
729fn build_left_join_merge_expr(
730 delta_alias: &str,
731 sink_alias: &str,
732 merge_col: &IncrementalAggregateMergeColumn,
733) -> Result<Expr, Error> {
734 let left = qualified_col(delta_alias, merge_col.output_field_name.clone());
735 let right = qualified_col(sink_alias, merge_col.output_field_name.clone());
736 let merged = match merge_col.merge_op {
737 IncrementalAggregateMergeOp::Sum => when(is_null(left.clone()), right.clone())
738 .when(is_null(right.clone()), left.clone())
739 .otherwise(binary_expr(left.clone(), Operator::Plus, right.clone()))
740 .with_context(|_| DatafusionSnafu {
741 context: "Failed to build SUM merge expression".to_string(),
742 })?,
743 IncrementalAggregateMergeOp::Min => when(is_null(right.clone()), left.clone())
744 .when(left.clone().lt_eq(right.clone()), left.clone())
745 .otherwise(right.clone())
746 .with_context(|_| DatafusionSnafu {
747 context: "Failed to build MIN merge expression".to_string(),
748 })?,
749 IncrementalAggregateMergeOp::Max => when(is_null(right.clone()), left.clone())
750 .when(left.clone().gt_eq(right.clone()), left.clone())
751 .otherwise(right.clone())
752 .with_context(|_| DatafusionSnafu {
753 context: "Failed to build MAX merge expression".to_string(),
754 })?,
755 IncrementalAggregateMergeOp::BoolAnd => when(is_null(left.clone()), right.clone())
756 .when(is_null(right.clone()), left.clone())
757 .otherwise(and(left.clone(), right.clone()))
758 .with_context(|_| DatafusionSnafu {
759 context: "Failed to build BOOL_AND merge expression".to_string(),
760 })?,
761 IncrementalAggregateMergeOp::BoolOr => when(is_null(left.clone()), right.clone())
762 .when(is_null(right.clone()), left.clone())
763 .otherwise(or(left.clone(), right.clone()))
764 .with_context(|_| DatafusionSnafu {
765 context: "Failed to build BOOL_OR merge expression".to_string(),
766 })?,
767 IncrementalAggregateMergeOp::BitAnd => when(is_null(left.clone()), right.clone())
768 .when(is_null(right.clone()), left.clone())
769 .otherwise(bitwise_and(left.clone(), right.clone()))
770 .with_context(|_| DatafusionSnafu {
771 context: "Failed to build BIT_AND merge expression".to_string(),
772 })?,
773 IncrementalAggregateMergeOp::BitOr => when(is_null(left.clone()), right.clone())
774 .when(is_null(right.clone()), left.clone())
775 .otherwise(bitwise_or(left.clone(), right.clone()))
776 .with_context(|_| DatafusionSnafu {
777 context: "Failed to build BIT_OR merge expression".to_string(),
778 })?,
779 IncrementalAggregateMergeOp::BitXor => when(is_null(left.clone()), right.clone())
780 .when(is_null(right.clone()), left.clone())
781 .otherwise(bitwise_xor(left.clone(), right.clone()))
782 .with_context(|_| DatafusionSnafu {
783 context: "Failed to build BIT_XOR merge expression".to_string(),
784 })?,
785 };
786 Ok(merged.alias(merge_col.output_field_name.clone()))
787}
788
789pub async fn get_table_info_df_schema(
790 catalog_mr: CatalogManagerRef,
791 table_name: TableName,
792) -> Result<(TableRef, Arc<DFSchema>), Error> {
793 let full_table_name = table_name.clone().join(".");
794 let table = catalog_mr
795 .table(&table_name[0], &table_name[1], &table_name[2], None)
796 .await
797 .map_err(BoxedError::new)
798 .context(ExternalSnafu)?
799 .context(TableNotFoundSnafu {
800 name: &full_table_name,
801 })?;
802 let table_info = table.table_info();
803
804 let schema = table_info.meta.schema.clone();
805
806 let df_schema: Arc<DFSchema> = Arc::new(
807 schema
808 .arrow_schema()
809 .clone()
810 .try_into()
811 .with_context(|_| DatafusionSnafu {
812 context: format!(
813 "Failed to convert arrow schema to datafusion schema, arrow_schema={:?}",
814 schema.arrow_schema()
815 ),
816 })?,
817 );
818 Ok((table, df_schema))
819}
820
821pub async fn sql_to_df_plan(
824 query_ctx: QueryContextRef,
825 engine: QueryEngineRef,
826 sql: &str,
827 optimize: bool,
828) -> Result<LogicalPlan, Error> {
829 let stmts =
830 ParserContext::create_with_dialect(sql, query_ctx.sql_dialect(), ParseOptions::default())
831 .map_err(BoxedError::new)
832 .context(ExternalSnafu)?;
833
834 ensure!(
835 stmts.len() == 1,
836 InvalidQuerySnafu {
837 reason: format!("Expect only one statement, found {}", stmts.len())
838 }
839 );
840 let stmt = &stmts[0];
841 let query_stmt = match stmt {
842 Statement::Tql(tql) => match tql {
843 Tql::Eval(eval) => {
844 let eval = eval.clone();
845 let promql = PromQuery {
846 start: eval.start,
847 end: eval.end,
848 step: eval.step,
849 query: eval.query,
850 lookback: eval
851 .lookback
852 .unwrap_or_else(|| DEFAULT_LOOKBACK_STRING.to_string()),
853 alias: eval.alias.clone(),
854 };
855
856 QueryLanguageParser::parse_promql(&promql, &query_ctx)
857 .map_err(BoxedError::new)
858 .context(ExternalSnafu)?
859 }
860 _ => InvalidQuerySnafu {
861 reason: format!("TQL statement {tql:?} is not supported, expect only TQL EVAL"),
862 }
863 .fail()?,
864 },
865 _ => QueryStatement::Sql(stmt.clone()),
866 };
867 let plan = engine
868 .planner()
869 .plan(&query_stmt, query_ctx.clone())
870 .await
871 .map_err(BoxedError::new)
872 .context(ExternalSnafu)?;
873
874 let plan = if optimize {
875 apply_df_optimizer(plan, &query_ctx).await?
876 } else {
877 plan
878 };
879 Ok(plan)
880}
881
882pub(crate) async fn gen_plan_with_matching_schema(
885 sql: &str,
886 query_ctx: QueryContextRef,
887 engine: QueryEngineRef,
888 sink_table_schema: SchemaRef,
889 primary_key_indices: &[usize],
890 allow_partial: bool,
891) -> Result<LogicalPlan, Error> {
892 let plan = sql_to_df_plan(query_ctx.clone(), engine.clone(), sql, false).await?;
893
894 let mut add_auto_column = ColumnMatcherRewriter::new(
895 sink_table_schema,
896 primary_key_indices.to_vec(),
897 allow_partial,
898 );
899 let plan = plan
900 .clone()
901 .rewrite(&mut add_auto_column)
902 .with_context(|_| DatafusionSnafu {
903 context: format!("Failed to rewrite plan:\n {}\n", plan),
904 })?
905 .data;
906 Ok(plan)
907}
908
909pub fn df_plan_to_sql(plan: &LogicalPlan) -> Result<String, Error> {
910 struct ForceQuoteIdentifiers;
912 impl datafusion::sql::unparser::dialect::Dialect for ForceQuoteIdentifiers {
913 fn identifier_quote_style(&self, identifier: &str) -> Option<char> {
914 if identifier.to_lowercase() != identifier {
915 Some('`')
916 } else {
917 None
918 }
919 }
920 }
921 let unparser = Unparser::new(&ForceQuoteIdentifiers);
922 let sql = unparser
924 .plan_to_sql(plan)
925 .with_context(|_e| DatafusionSnafu {
926 context: format!("Failed to unparse logical plan {plan:?}"),
927 })?;
928 Ok(sql.to_string())
929}
930
931#[derive(Debug, Clone, Default)]
933pub struct FindGroupByFinalName {
934 group_exprs: Option<HashSet<datafusion_expr::Expr>>,
935}
936
937impl FindGroupByFinalName {
938 pub fn get_group_expr_names(&self) -> Option<HashSet<String>> {
939 self.group_exprs
940 .as_ref()
941 .map(|exprs| exprs.iter().map(|expr| expr.qualified_name().1).collect())
942 }
943}
944
945impl TreeNodeVisitor<'_> for FindGroupByFinalName {
946 type Node = LogicalPlan;
947
948 fn f_down(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
949 if let LogicalPlan::Aggregate(aggregate) = node {
950 self.group_exprs = Some(aggregate.group_expr.iter().cloned().collect());
951 debug!(
952 "FindGroupByFinalName: Get Group by exprs from Aggregate: {:?}",
953 self.group_exprs
954 );
955 } else if let LogicalPlan::Distinct(distinct) = node {
956 debug!("FindGroupByFinalName: Distinct: {}", node);
957 match distinct {
958 Distinct::All(input) => {
959 if let LogicalPlan::TableScan(table_scan) = &**input {
960 let len = table_scan.projected_schema.fields().len();
962 let columns = (0..len)
963 .map(|f| {
964 let (qualifier, field) =
965 table_scan.projected_schema.qualified_field(f);
966 datafusion_common::Column::new(qualifier.cloned(), field.name())
967 })
968 .map(datafusion_expr::Expr::Column);
969 self.group_exprs = Some(columns.collect());
970 } else {
971 self.group_exprs = Some(input.expressions().iter().cloned().collect())
972 }
973 }
974 Distinct::On(distinct_on) => {
975 self.group_exprs = Some(distinct_on.on_expr.iter().cloned().collect())
976 }
977 }
978 debug!(
979 "FindGroupByFinalName: Get Group by exprs from Distinct: {:?}",
980 self.group_exprs
981 );
982 }
983
984 Ok(TreeNodeRecursion::Continue)
985 }
986
987 fn f_up(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
989 if let LogicalPlan::Projection(projection) = node {
990 for expr in &projection.expr {
991 let Some(group_exprs) = &mut self.group_exprs else {
992 return Ok(TreeNodeRecursion::Continue);
993 };
994 if let datafusion_expr::Expr::Alias(alias) = expr {
995 let mut new_group_exprs = group_exprs.clone();
997 for group_expr in group_exprs.iter() {
998 if group_expr.name_for_alias()? == alias.expr.name_for_alias()? {
999 new_group_exprs.remove(group_expr);
1000 new_group_exprs.insert(expr.clone());
1001 break;
1002 }
1003 }
1004 *group_exprs = new_group_exprs;
1005 }
1006 }
1007 }
1008 debug!("Aliased group by exprs: {:?}", self.group_exprs);
1009 Ok(TreeNodeRecursion::Continue)
1010 }
1011}
1012
1013#[derive(Debug)]
1020pub struct ColumnMatcherRewriter {
1021 pub schema: SchemaRef,
1022 pub is_rewritten: bool,
1023 pub primary_key_indices: Vec<usize>,
1024 pub allow_partial: bool,
1025}
1026
1027impl ColumnMatcherRewriter {
1028 pub fn new(schema: SchemaRef, primary_key_indices: Vec<usize>, allow_partial: bool) -> Self {
1029 Self {
1030 schema,
1031 is_rewritten: false,
1032 primary_key_indices,
1033 allow_partial,
1034 }
1035 }
1036
1037 fn modify_project_exprs(&mut self, mut exprs: Vec<Expr>) -> DfResult<Vec<Expr>> {
1039 if self.allow_partial {
1040 return self.modify_project_exprs_with_partial(exprs);
1041 }
1042
1043 let all_names = self
1044 .schema
1045 .column_schemas()
1046 .iter()
1047 .map(|c| c.name.clone())
1048 .collect::<BTreeSet<_>>();
1049 for (idx, expr) in exprs.iter_mut().enumerate() {
1051 if !all_names.contains(&expr.qualified_name().1)
1052 && let Some(col_name) = self
1053 .schema
1054 .column_schemas()
1055 .get(idx)
1056 .map(|c| c.name.clone())
1057 {
1058 *expr = expr.clone().alias(col_name);
1062 }
1063 }
1064
1065 let query_col_cnt = exprs.len();
1067 let table_col_cnt = self.schema.column_schemas().len();
1068 debug!("query_col_cnt={query_col_cnt}, table_col_cnt={table_col_cnt}");
1069
1070 let placeholder_ts_expr =
1071 datafusion::logical_expr::lit(ScalarValue::TimestampMillisecond(Some(0), None))
1072 .alias(AUTO_CREATED_PLACEHOLDER_TS_COL);
1073
1074 if query_col_cnt == table_col_cnt {
1075 } else if query_col_cnt + 1 == table_col_cnt {
1077 let last_col_schema = self.schema.column_schemas().last().unwrap();
1078
1079 if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
1081 && self.schema.timestamp_index() == Some(table_col_cnt - 1)
1082 {
1083 exprs.push(placeholder_ts_expr);
1084 } else if last_col_schema.data_type.is_timestamp() {
1085 exprs.push(datafusion::prelude::now().alias(&last_col_schema.name));
1087 } else {
1088 return Err(DataFusionError::Plan(format!(
1090 "Expect the last column in table to be timestamp column, found column {} with type {:?}",
1091 last_col_schema.name, last_col_schema.data_type
1092 )));
1093 }
1094 } else if query_col_cnt + 2 == table_col_cnt {
1095 let mut col_iter = self.schema.column_schemas().iter().rev();
1096 let last_col_schema = col_iter.next().unwrap();
1097 let second_last_col_schema = col_iter.next().unwrap();
1098 if second_last_col_schema.data_type.is_timestamp() {
1099 exprs.push(datafusion::prelude::now().alias(&second_last_col_schema.name));
1100 } else {
1101 return Err(DataFusionError::Plan(format!(
1102 "Expect the second last column in the table to be timestamp column, found column {} with type {:?}",
1103 second_last_col_schema.name, second_last_col_schema.data_type
1104 )));
1105 }
1106
1107 if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
1108 && self.schema.timestamp_index() == Some(table_col_cnt - 1)
1109 {
1110 exprs.push(placeholder_ts_expr);
1111 } else {
1112 return Err(DataFusionError::Plan(format!(
1113 "Expect timestamp column {}, found {:?}",
1114 AUTO_CREATED_PLACEHOLDER_TS_COL, last_col_schema
1115 )));
1116 }
1117 } else {
1118 return Err(DataFusionError::Plan(format!(
1119 "Expect table have 0,1 or 2 columns more than query columns, found {} query columns {:?}, {} table columns {:?}",
1120 query_col_cnt,
1121 exprs,
1122 table_col_cnt,
1123 self.schema.column_schemas()
1124 )));
1125 }
1126 Ok(exprs)
1127 }
1128
1129 fn modify_project_exprs_with_partial(&mut self, exprs: Vec<Expr>) -> DfResult<Vec<Expr>> {
1130 let table_col_cnt = self.schema.column_schemas().len();
1131 let query_col_cnt = exprs.len();
1132
1133 if query_col_cnt > table_col_cnt {
1134 return Err(DataFusionError::Plan(format!(
1135 "Expect query column count <= table column count, found {} query columns {:?}, {} table columns {:?}",
1136 query_col_cnt,
1137 exprs,
1138 table_col_cnt,
1139 self.schema.column_schemas()
1140 )));
1141 }
1142
1143 let name_to_expr: HashMap<String, Expr> = exprs
1144 .clone()
1145 .into_iter()
1146 .map(|e| (e.qualified_name().1, e))
1147 .collect();
1148
1149 let required_columns = self.required_columns_for_partial();
1150 let missing: Vec<_> = required_columns
1151 .iter()
1152 .filter(|name| !name_to_expr.contains_key(*name))
1153 .cloned()
1154 .collect();
1155 if !missing.is_empty() {
1156 return Err(DataFusionError::Plan(format!(
1157 "Column(s) {:?} required by sink table are missing from flow output when merge_mode=last_non_null",
1158 missing
1159 )));
1160 }
1161
1162 let placeholder_ts_expr =
1163 datafusion::logical_expr::lit(ScalarValue::TimestampMillisecond(Some(0), None))
1164 .alias(AUTO_CREATED_PLACEHOLDER_TS_COL);
1165
1166 let timestamp_index = self.schema.timestamp_index();
1167 let mut remap = name_to_expr;
1168 let mut new_exprs = Vec::with_capacity(table_col_cnt);
1169
1170 for (idx, col_schema) in self.schema.column_schemas().iter().enumerate() {
1171 let col_name = col_schema.name.clone();
1172 if let Some(expr) = remap.remove(&col_name) {
1173 let expr = if expr.qualified_name().1 == col_name {
1174 expr
1175 } else {
1176 expr.alias(col_name.clone())
1177 };
1178 new_exprs.push(expr);
1179 continue;
1180 }
1181
1182 if col_name == AUTO_CREATED_PLACEHOLDER_TS_COL && timestamp_index == Some(idx) {
1183 new_exprs.push(placeholder_ts_expr.clone());
1184 continue;
1185 }
1186
1187 if col_name == AUTO_CREATED_UPDATE_AT_TS_COL && col_schema.data_type.is_timestamp() {
1188 new_exprs.push(datafusion::prelude::now().alias(&col_name));
1189 continue;
1190 }
1191
1192 new_exprs.push(Self::null_expr(col_schema));
1193 }
1194
1195 if !remap.is_empty() {
1196 let extra: Vec<_> = remap.keys().cloned().collect();
1197 return Err(DataFusionError::Plan(format!(
1198 "Flow output has extra column(s) {:?} not found in sink schema when merge_mode=last_non_null",
1199 extra
1200 )));
1201 }
1202
1203 Ok(new_exprs)
1204 }
1205
1206 fn null_expr(col_schema: &ColumnSchema) -> Expr {
1207 Expr::Literal(ScalarValue::Null, None).alias(col_schema.name.clone())
1208 }
1209
1210 fn required_columns_for_partial(&self) -> HashSet<String> {
1211 let mut required = HashSet::new();
1212 for idx in &self.primary_key_indices {
1213 if let Some(col) = self.schema.column_schemas().get(*idx) {
1214 required.insert(col.name.clone());
1215 }
1216 }
1217
1218 if let Some(ts_idx) = self.schema.timestamp_index()
1219 && let Some(col) = self.schema.column_schemas().get(ts_idx)
1220 && col.name != AUTO_CREATED_PLACEHOLDER_TS_COL
1221 {
1222 required.insert(col.name.clone());
1223 }
1224
1225 required
1226 }
1227}
1228
1229impl TreeNodeRewriter for ColumnMatcherRewriter {
1230 type Node = LogicalPlan;
1231 fn f_down(&mut self, mut node: Self::Node) -> DfResult<Transformed<Self::Node>> {
1232 if self.is_rewritten {
1233 return Ok(Transformed::no(node));
1234 }
1235
1236 if let LogicalPlan::Distinct(Distinct::All(_)) = &node {
1238 let mut exprs = vec![];
1239
1240 for field in node.schema().fields().iter() {
1241 exprs.push(Expr::Column(datafusion::common::Column::new_unqualified(
1242 field.name(),
1243 )));
1244 }
1245
1246 let projection =
1247 LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(node.clone()))?);
1248
1249 node = projection;
1250 }
1251 else if let LogicalPlan::TableScan(table_scan) = node {
1253 let mut exprs = vec![];
1254
1255 for field in table_scan.projected_schema.fields().iter() {
1256 exprs.push(Expr::Column(datafusion::common::Column::new(
1257 Some(table_scan.table_name.clone()),
1258 field.name(),
1259 )));
1260 }
1261
1262 let projection = LogicalPlan::Projection(Projection::try_new(
1263 exprs,
1264 Arc::new(LogicalPlan::TableScan(table_scan)),
1265 )?);
1266
1267 node = projection;
1268 }
1269
1270 if let LogicalPlan::Projection(project) = &node {
1274 let exprs = project.expr.clone();
1275 let exprs = self.modify_project_exprs(exprs)?;
1276
1277 self.is_rewritten = true;
1278 let new_plan =
1279 node.with_new_exprs(exprs, node.inputs().into_iter().cloned().collect())?;
1280 Ok(Transformed::yes(new_plan))
1281 } else {
1282 let mut exprs = vec![];
1284 for field in node.schema().fields().iter() {
1285 exprs.push(Expr::Column(datafusion::common::Column::new_unqualified(
1286 field.name(),
1287 )));
1288 }
1289 let exprs = self.modify_project_exprs(exprs)?;
1290 self.is_rewritten = true;
1291 let new_plan =
1292 LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(node.clone()))?);
1293 Ok(Transformed::yes(new_plan))
1294 }
1295 }
1296
1297 fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
1299 node.recompute_schema().map(Transformed::yes)
1300 }
1301}
1302
1303#[derive(Debug)]
1305pub struct AddFilterRewriter {
1306 extra_filter: Expr,
1307 is_rewritten: bool,
1308}
1309
1310impl AddFilterRewriter {
1311 pub fn new(filter: Expr) -> Self {
1312 Self {
1313 extra_filter: filter,
1314 is_rewritten: false,
1315 }
1316 }
1317}
1318
1319impl TreeNodeRewriter for AddFilterRewriter {
1320 type Node = LogicalPlan;
1321 fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
1322 if self.is_rewritten {
1323 return Ok(Transformed::no(node));
1324 }
1325 match node {
1326 LogicalPlan::Filter(mut filter) => {
1327 filter.predicate = filter.predicate.and(self.extra_filter.clone());
1328 self.is_rewritten = true;
1329 Ok(Transformed::yes(LogicalPlan::Filter(filter)))
1330 }
1331 LogicalPlan::TableScan(_) => {
1332 let filter =
1334 datafusion_expr::Filter::try_new(self.extra_filter.clone(), Arc::new(node))?;
1335 self.is_rewritten = true;
1336 Ok(Transformed::yes(LogicalPlan::Filter(filter)))
1337 }
1338 _ => Ok(Transformed::no(node)),
1339 }
1340 }
1341}