1use std::collections::{BTreeSet, HashMap, HashSet};
18use std::sync::Arc;
19
20use catalog::CatalogManagerRef;
21use common_error::ext::BoxedError;
22use common_function::aggrs::aggr_wrapper::get_aggr_func;
23use common_telemetry::debug;
24use datafusion::datasource::DefaultTableSource;
25use datafusion::error::Result as DfResult;
26use datafusion::logical_expr::Expr;
27use datafusion::sql::unparser::Unparser;
28use datafusion_common::tree_node::{
29 Transformed, TreeNode as _, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
30};
31use datafusion_common::{
32 Column, DFSchema, DataFusionError, NullEquality, ScalarValue, TableReference,
33};
34use datafusion_expr::logical_plan::{Aggregate, TableScan};
35use datafusion_expr::{
36 Distinct, ExprSchemable, JoinType, LogicalPlan, LogicalPlanBuilder, Operator, Projection, and,
37 binary_expr, bitwise_and, bitwise_or, bitwise_xor, is_null, or, when,
38};
39use datatypes::prelude::ConcreteDataType;
40use datatypes::schema::{ColumnSchema, SchemaRef};
41use query::QueryEngineRef;
42use query::parser::{DEFAULT_LOOKBACK_STRING, PromQuery, QueryLanguageParser, QueryStatement};
43use session::context::QueryContextRef;
44use snafu::{OptionExt, ResultExt, ensure};
45use sql::parser::{ParseOptions, ParserContext};
46use sql::statements::statement::Statement;
47use sql::statements::tql::Tql;
48use table::TableRef;
49use table::table::adapter::DfTableProviderAdapter;
50
51use crate::adapter::{AUTO_CREATED_PLACEHOLDER_TS_COL, AUTO_CREATED_UPDATE_AT_TS_COL};
52use crate::df_optimizer::apply_df_optimizer;
53use crate::error::{DatafusionSnafu, ExternalSnafu, InvalidQuerySnafu, TableNotFoundSnafu};
54use crate::{Error, TableName};
55
56#[cfg(test)]
57mod test;
58
/// Describes how a single aggregate output column is merged by the incremental aggregate
/// rewrite.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IncrementalAggregateMergeColumn {
    /// Name of the output field holding the aggregate value. The same name is looked up on
    /// both the delta side and the sink side when the merge expression is built.
    pub output_field_name: String,
    /// Operation used to combine the delta value with the value read from the sink.
    pub merge_op: IncrementalAggregateMergeOp,
}

impl IncrementalAggregateMergeColumn {
    /// Creates a merge column for `output_field_name` that is combined using `merge_op`.
    pub fn new(output_field_name: String, merge_op: IncrementalAggregateMergeOp) -> Self {
        Self {
            output_field_name,
            merge_op,
        }
    }
}
84
/// Operation used to combine a freshly computed (delta) aggregate value with the value that
/// is already stored in the sink table.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IncrementalAggregateMergeOp {
    /// Add both sides; chosen for `sum` and `count` aggregates.
    Sum,
    /// Keep the smaller side; chosen for `min`.
    Min,
    /// Keep the larger side; chosen for `max`.
    Max,
    /// Logical AND of both sides; chosen for `bool_and`.
    BoolAnd,
    /// Logical OR of both sides; chosen for `bool_or`.
    BoolOr,
    /// Bitwise AND of both sides; chosen for `bit_and`.
    BitAnd,
    /// Bitwise OR of both sides; chosen for `bit_or`.
    BitOr,
    /// Bitwise XOR of both sides; chosen for `bit_xor`.
    BitXor,
}
96
/// Result of inspecting a logical plan for the incremental aggregate rewrite.
///
/// The rewrite is only attempted when `unsupported_exprs` is empty.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IncrementalAggregateAnalysis {
    /// Final output names of the group-by expressions, sorted.
    pub group_key_names: Vec<String>,
    /// One entry per mergeable aggregate output column; emptied when anything unsupported
    /// was found.
    pub merge_columns: Vec<IncrementalAggregateMergeColumn>,
    /// Names of output columns that are passed through from the delta side without merging
    /// (literals and auto-created timestamp columns).
    pub literal_columns: Vec<String>,
    /// Field names of the analyzed plan's output schema, in schema order.
    pub output_field_names: Vec<String>,
    /// Human-readable reasons why the plan cannot be rewritten incrementally.
    pub unsupported_exprs: Vec<String>,
}
115
116fn find_column_names(expr: &Expr, names: &mut Vec<String>) {
125 match expr {
126 Expr::Column(col) => {
127 names.push(col.name.clone());
128 }
129 Expr::Alias(alias) => find_column_names(&alias.expr, names),
130 _ => {}
131 }
132}
133
134fn unqualified_col(name: impl Into<String>) -> Expr {
135 Expr::Column(Column::from_name(name.into()))
136}
137
138fn qualified_col(qualifier: &str, name: impl Into<String>) -> Expr {
139 Expr::Column(Column::new(Some(qualifier), name.into()))
140}
141
142fn qualified_column(qualifier: &str, name: impl Into<String>) -> Column {
143 Column::new(Some(qualifier), name.into())
144}
145
146fn find_group_key_names(plan: &LogicalPlan) -> Result<Vec<String>, Error> {
147 let mut group_finder = FindGroupByFinalName::default();
148 plan.visit(&mut group_finder)
149 .with_context(|_| DatafusionSnafu {
150 context: format!("Failed to inspect group-by columns from logical plan: {plan:?}"),
151 })?;
152
153 let mut group_key_names = group_finder
154 .get_group_expr_names()
155 .unwrap_or_default()
156 .into_iter()
157 .collect::<Vec<_>>();
158 group_key_names.sort();
159 Ok(group_key_names)
160}
161
162fn has_grouping_set(plan: &LogicalPlan) -> bool {
163 match plan {
164 LogicalPlan::Aggregate(aggregate) => aggregate
165 .group_expr
166 .iter()
167 .any(|expr| matches!(expr, Expr::GroupingSet(_))),
168 _ => plan.inputs().into_iter().any(has_grouping_set),
169 }
170}
171
172fn has_aggregate(plan: &LogicalPlan) -> bool {
173 match plan {
174 LogicalPlan::Aggregate(_) => true,
175 _ => plan.inputs().into_iter().any(has_aggregate),
176 }
177}
178
179fn peel_subquery_aliases(mut plan: &LogicalPlan) -> &LogicalPlan {
180 while let LogicalPlan::SubqueryAlias(alias) = plan {
181 plan = alias.input.as_ref();
182 }
183 plan
184}
185
186fn extract_incremental_aggregate(plan: &LogicalPlan) -> Result<Option<&Aggregate>, String> {
187 let plan = match plan {
191 LogicalPlan::Projection(projection) => projection.input.as_ref(),
192 _ => plan,
193 };
194
195 match plan {
196 LogicalPlan::Aggregate(aggregate) => {
197 check_input_plan_shape(aggregate.input.as_ref())?;
198 Ok(Some(aggregate))
199 }
200 LogicalPlan::Filter(filter) if has_aggregate(filter.input.as_ref()) => Err(
201 "unsupported post-aggregate filter (HAVING) in incremental aggregate rewrite"
202 .to_string(),
203 ),
204 _ if has_aggregate(plan) => Err(
205 "unsupported post-aggregate plan shape in incremental aggregate rewrite".to_string(),
206 ),
207 _ => Ok(None),
208 }
209}
210
211fn check_input_plan_shape(plan: &LogicalPlan) -> Result<(), String> {
212 let plan = peel_subquery_aliases(plan);
213 match plan {
214 LogicalPlan::TableScan(_) => Ok(()),
217 LogicalPlan::Filter(filter) => match peel_subquery_aliases(filter.input.as_ref()) {
218 LogicalPlan::TableScan(_) => Ok(()),
219 _ => Err(
220 "unsupported aggregate input plan shape in incremental aggregate rewrite"
221 .to_string(),
222 ),
223 },
224 _ => Err(
225 "unsupported aggregate input plan shape in incremental aggregate rewrite".to_string(),
226 ),
227 }
228}
229
/// Facts gathered about the output columns of a plan (see `collect_output_projection_info`).
#[derive(Debug, Default)]
struct OutputProjectionInfo {
    /// Whether the analyzed plan's root node is a projection.
    has_top_level_projection: bool,
    /// Maps a source column name to the name it is exposed under in the output.
    output_aliases: HashMap<String, String>,
    /// Messages describing source columns that are exposed under more than one alias.
    duplicate_aggregate_aliases: BTreeSet<String>,
    /// Output column names that are passed through without merging.
    literal_columns: HashSet<String>,
    /// Output field names in schema order (may contain duplicates).
    output_field_names: Vec<String>,
}

impl OutputProjectionInfo {
    /// Returns the output field names as a set, dropping duplicates and ordering.
    fn output_field_name_set(&self) -> HashSet<String> {
        let mut names = HashSet::with_capacity(self.output_field_names.len());
        names.extend(self.output_field_names.iter().cloned());
        names
    }

    /// Returns every output field name that occurs more than once, each listed once and in
    /// ascending order.
    fn duplicate_output_names(&self) -> Vec<String> {
        let mut seen: HashSet<&str> = HashSet::with_capacity(self.output_field_names.len());
        // `insert` returns false on a repeat, which is exactly the set of names to report.
        let repeated: BTreeSet<&str> = self
            .output_field_names
            .iter()
            .map(String::as_str)
            .filter(|name| !seen.insert(*name))
            .collect();
        repeated.into_iter().map(str::to_owned).collect()
    }
}
255
256fn collect_output_projection_info(plan: &LogicalPlan) -> OutputProjectionInfo {
257 let mut projection_info = OutputProjectionInfo {
258 has_top_level_projection: matches!(plan, LogicalPlan::Projection(_)),
259 output_field_names: plan
260 .schema()
261 .fields()
262 .iter()
263 .map(|field| field.name().clone())
264 .collect(),
265 ..Default::default()
266 };
267
268 let mut output_aliases = HashMap::new();
269 if let LogicalPlan::Projection(projection) = plan {
270 for expr in &projection.expr {
271 match expr {
272 Expr::Alias(alias) => {
273 let alias_name = alias.name.clone();
279 let mut col_names = Vec::new();
280 find_column_names(&alias.expr, &mut col_names);
281 match col_names.len() {
282 0 if is_passthrough_output_column(&alias_name, alias.expr.as_ref()) => {
283 projection_info.literal_columns.insert(alias_name);
284 }
285 1 => {
286 if let Some(col_name) = col_names.into_iter().next() {
287 if let Some(existing_alias) = output_aliases.get(&col_name) {
288 if existing_alias != &alias_name {
289 projection_info.duplicate_aggregate_aliases.insert(format!(
290 "same aggregate output {col_name} is used by multiple aliases: {existing_alias}, {alias_name}"
291 ));
292 }
293 } else {
294 output_aliases.insert(col_name, alias_name);
295 }
296 }
297 }
298 _ => {}
299 }
300
301 }
304 Expr::Column(col) => {
305 output_aliases
306 .entry(col.name.clone())
307 .or_insert(col.name.clone());
308 }
309 Expr::Literal(_, _) => {
310 projection_info
311 .literal_columns
312 .insert(expr.qualified_name().1);
313 }
314 _ => {}
315 }
316 }
317 }
318
319 if projection_info
320 .output_field_names
321 .iter()
322 .any(|name| name == AUTO_CREATED_PLACEHOLDER_TS_COL)
323 {
324 projection_info
325 .literal_columns
326 .insert(AUTO_CREATED_PLACEHOLDER_TS_COL.to_string());
327 }
328
329 projection_info.output_aliases = output_aliases;
330 projection_info
331}
332
333fn is_passthrough_output_column(alias_name: &str, expr: &Expr) -> bool {
334 matches!(expr, Expr::Literal(_, _))
335 || match alias_name {
336 AUTO_CREATED_UPDATE_AT_TS_COL => expr == &datafusion::prelude::now(),
337 AUTO_CREATED_PLACEHOLDER_TS_COL => is_literal_or_cast_literal(expr),
338 _ => false,
339 }
340}
341
342fn is_literal_or_cast_literal(expr: &Expr) -> bool {
343 match expr {
344 Expr::Literal(_, _) => true,
345 Expr::Cast(cast) => is_literal_or_cast_literal(cast.expr.as_ref()),
346 Expr::TryCast(cast) => is_literal_or_cast_literal(cast.expr.as_ref()),
347 _ => false,
348 }
349}
350
351fn merge_op_for_aggregate_expr(aggr_expr: &Expr) -> Result<IncrementalAggregateMergeOp, String> {
352 let Some(aggr_func) = get_aggr_func(aggr_expr) else {
353 return Err(aggr_expr.to_string());
354 };
355 if aggr_func.params.distinct {
356 return Err(format!("unsupported DISTINCT aggregate: {aggr_expr}"));
357 }
358 if !aggr_func.params.order_by.is_empty() {
359 return Err(format!("unsupported aggregate ORDER BY: {aggr_expr}"));
360 }
361 if aggr_func.params.null_treatment.is_some() {
362 return Err(format!("unsupported aggregate NULL treatment: {aggr_expr}"));
363 }
364
365 match aggr_func.func.name().to_ascii_lowercase().as_str() {
366 "sum" | "count" => Ok(IncrementalAggregateMergeOp::Sum),
367 "min" => Ok(IncrementalAggregateMergeOp::Min),
368 "max" => Ok(IncrementalAggregateMergeOp::Max),
369 "bool_and" => Ok(IncrementalAggregateMergeOp::BoolAnd),
370 "bool_or" => Ok(IncrementalAggregateMergeOp::BoolOr),
371 "bit_and" => Ok(IncrementalAggregateMergeOp::BitAnd),
372 "bit_or" => Ok(IncrementalAggregateMergeOp::BitOr),
373 "bit_xor" => Ok(IncrementalAggregateMergeOp::BitXor),
374 _ => Err(aggr_expr.to_string()),
375 }
376}
377
378fn resolve_aggregate_output_field_name(
379 aggr_expr: &Expr,
380 projection_info: &OutputProjectionInfo,
381 output_field_name_set: &HashSet<String>,
382) -> Option<String> {
383 let raw_name = aggr_expr.qualified_name().1;
389 if let Some(alias) = projection_info.output_aliases.get(&raw_name) {
390 Some(alias.clone())
391 } else if !projection_info.has_top_level_projection && output_field_name_set.contains(&raw_name)
392 {
393 Some(raw_name)
394 } else {
395 None
396 }
397}
398
399fn find_uncovered_output_fields(
400 projection_info: &OutputProjectionInfo,
401 group_key_names: &[String],
402 merge_columns: &[IncrementalAggregateMergeColumn],
403) -> Vec<String> {
404 let group_key_names = group_key_names.iter().cloned().collect::<HashSet<_>>();
405 let merge_column_names = merge_columns
406 .iter()
407 .map(|c| c.output_field_name.clone())
408 .collect::<HashSet<_>>();
409
410 projection_info
411 .output_field_names
412 .iter()
413 .filter(|name| {
414 !group_key_names.contains(*name)
415 && !merge_column_names.contains(*name)
416 && !projection_info.literal_columns.contains(*name)
417 && name.as_str() != AUTO_CREATED_UPDATE_AT_TS_COL
421 && name.as_str() != AUTO_CREATED_PLACEHOLDER_TS_COL
422 })
423 .cloned()
424 .collect()
425}
426
427fn find_unsupported_group_key_projection_outputs(
428 plan: &LogicalPlan,
429 aggregate: &Aggregate,
430 group_key_names: &[String],
431) -> Vec<String> {
432 let LogicalPlan::Projection(projection) = plan else {
433 return vec![];
434 };
435
436 let group_key_names = group_key_names.iter().cloned().collect::<HashSet<_>>();
437 let group_expr_names = aggregate
438 .group_expr
439 .iter()
440 .filter_map(|expr| expr.name_for_alias().ok())
441 .collect::<HashSet<_>>();
442 projection
443 .expr
444 .iter()
445 .filter_map(|expr| {
446 let output_name = expr.qualified_name().1;
447 if !group_key_names.contains(&output_name) {
448 return None;
449 }
450
451 let source_name = match expr {
452 Expr::Alias(alias) => alias.expr.name_for_alias().ok(),
453 _ => expr.name_for_alias().ok(),
454 };
455 if source_name.is_some_and(|name| group_expr_names.contains(&name)) {
456 None
457 } else {
458 Some(format!(
459 "unsupported group key output field is not a transparent group expression: {output_name}"
460 ))
461 }
462 })
463 .collect()
464}
465
466pub fn analyze_incremental_aggregate_plan(
467 plan: &LogicalPlan,
468) -> Result<Option<IncrementalAggregateAnalysis>, Error> {
469 let group_key_names = find_group_key_names(plan)?;
470 let aggregate = match extract_incremental_aggregate(plan) {
471 Ok(Some(aggregate)) => aggregate,
472 Ok(None) => return Ok(None),
473 Err(reason) => {
474 let projection_info = collect_output_projection_info(plan);
475 let mut unsupported_exprs = projection_info
476 .duplicate_output_names()
477 .into_iter()
478 .map(|name| format!("duplicate output field name: {name}"))
479 .collect::<Vec<_>>();
480 unsupported_exprs.push(reason);
481 unsupported_exprs.extend(projection_info.duplicate_aggregate_aliases.iter().cloned());
482 return Ok(Some(IncrementalAggregateAnalysis {
483 group_key_names,
484 merge_columns: vec![],
485 literal_columns: vec![],
486 output_field_names: projection_info.output_field_names,
487 unsupported_exprs,
488 }));
489 }
490 };
491 let aggr_exprs = aggregate.aggr_expr.clone();
492 let projection_info = collect_output_projection_info(plan);
493 let output_field_name_set = projection_info.output_field_name_set();
494
495 let mut merge_columns = Vec::with_capacity(aggr_exprs.len());
496 let mut unsupported_exprs = projection_info
497 .duplicate_output_names()
498 .into_iter()
499 .map(|name| format!("duplicate output field name: {name}"))
500 .collect::<Vec<_>>();
501 if has_grouping_set(plan) {
502 unsupported_exprs.push(
503 "unsupported GROUPING SETS/CUBE/ROLLUP in incremental aggregate rewrite".to_string(),
504 );
505 }
506 if group_key_names.is_empty() {
507 unsupported_exprs
508 .push("unsupported global aggregate in incremental aggregate rewrite".to_string());
509 }
510 unsupported_exprs.extend(find_unsupported_group_key_projection_outputs(
511 plan,
512 aggregate,
513 &group_key_names,
514 ));
515 unsupported_exprs.extend(projection_info.duplicate_aggregate_aliases.iter().cloned());
516 for aggr_expr in aggr_exprs {
517 let merge_op = match merge_op_for_aggregate_expr(&aggr_expr) {
518 Ok(merge_op) => merge_op,
519 Err(reason) => {
520 unsupported_exprs.push(reason);
521 continue;
522 }
523 };
524 let Some(output_field_name) = resolve_aggregate_output_field_name(
525 &aggr_expr,
526 &projection_info,
527 &output_field_name_set,
528 ) else {
529 unsupported_exprs.push(aggr_expr.to_string());
530 continue;
531 };
532 merge_columns.push(IncrementalAggregateMergeColumn::new(
533 output_field_name,
534 merge_op,
535 ));
536 }
537 unsupported_exprs.extend(
538 find_uncovered_output_fields(&projection_info, &group_key_names, &merge_columns)
539 .into_iter()
540 .map(|name| format!("unsupported output field: {name}")),
541 );
542 if !unsupported_exprs.is_empty() {
543 merge_columns.clear();
544 }
545 let mut literal_columns = projection_info
546 .literal_columns
547 .into_iter()
548 .collect::<Vec<_>>();
549 literal_columns.sort();
550
551 Ok(Some(IncrementalAggregateAnalysis {
552 group_key_names,
553 merge_columns,
554 literal_columns,
555 output_field_names: projection_info.output_field_names,
556 unsupported_exprs,
557 }))
558}
559
560pub async fn rewrite_incremental_aggregate_with_sink_merge(
592 delta_plan: &LogicalPlan,
593 analysis: &IncrementalAggregateAnalysis,
594 sink_table: TableRef,
595 sink_table_name: &TableName,
596 sink_dirty_filter: Option<Expr>,
597) -> Result<LogicalPlan, Error> {
598 ensure!(
599 analysis.unsupported_exprs.is_empty(),
600 InvalidQuerySnafu {
601 reason: format!(
602 "UNSUPPORTED_INCREMENTAL_AGG: unsupported aggregate expressions {:?}",
603 analysis.unsupported_exprs
604 )
605 }
606 );
607
608 ensure!(
609 !analysis.merge_columns.is_empty(),
610 InvalidQuerySnafu {
611 reason:
612 "UNSUPPORTED_INCREMENTAL_AGG: aggregate query has no mergeable aggregate columns"
613 .to_string()
614 }
615 );
616
617 ensure!(
618 !analysis.group_key_names.is_empty(),
619 InvalidQuerySnafu {
620 reason: "UNSUPPORTED_INCREMENTAL_AGG: global aggregate query is not supported"
621 .to_string()
622 }
623 );
624
625 let delta_alias = "__flow_delta";
626 let sink_alias = "__flow_sink";
627
628 let mut selected_columns = analysis.group_key_names.clone();
629 selected_columns.extend(
630 analysis
631 .merge_columns
632 .iter()
633 .map(|c| c.output_field_name.clone()),
634 );
635 let mut delta_selected_columns = selected_columns.clone();
636 delta_selected_columns.extend(analysis.literal_columns.iter().cloned());
637
638 let delta_selected_exprs = delta_selected_columns
639 .iter()
640 .cloned()
641 .map(unqualified_col)
642 .collect::<Vec<_>>();
643 let delta_selected = LogicalPlanBuilder::from(delta_plan.clone())
644 .project(delta_selected_exprs)
645 .with_context(|_| DatafusionSnafu {
646 context: "Failed to project delta plan for incremental sink merge".to_string(),
647 })?
648 .alias(delta_alias)
649 .with_context(|_| DatafusionSnafu {
650 context: "Failed to alias delta plan for incremental sink merge".to_string(),
651 })?
652 .build()
653 .with_context(|_| DatafusionSnafu {
654 context: "Failed to build projected delta plan for incremental sink merge".to_string(),
655 })?;
656
657 let table_provider = Arc::new(DfTableProviderAdapter::new(sink_table));
658 let table_source = Arc::new(DefaultTableSource::new(table_provider));
659 let sink_scan = LogicalPlan::TableScan(
660 TableScan::try_new(
661 TableReference::Full {
662 catalog: sink_table_name[0].clone().into(),
663 schema: sink_table_name[1].clone().into(),
664 table: sink_table_name[2].clone().into(),
665 },
666 table_source,
667 None,
668 vec![],
669 None,
670 )
671 .with_context(|_| DatafusionSnafu {
672 context: "Failed to build sink table scan for incremental sink merge".to_string(),
673 })?,
674 );
675
676 let sink_selected_exprs = selected_columns
677 .iter()
678 .cloned()
679 .map(unqualified_col)
680 .collect::<Vec<_>>();
681 let sink_input = if let Some(predicate) = sink_dirty_filter {
682 LogicalPlanBuilder::from(sink_scan)
683 .filter(predicate)
684 .with_context(|_| DatafusionSnafu {
685 context: "Failed to filter sink table scan for incremental sink merge".to_string(),
686 })?
687 .build()
688 .with_context(|_| DatafusionSnafu {
689 context: "Failed to build filtered sink plan for incremental sink merge"
690 .to_string(),
691 })?
692 } else {
693 sink_scan
694 };
695
696 let sink_selected = LogicalPlanBuilder::from(sink_input)
697 .project(sink_selected_exprs)
698 .with_context(|_| DatafusionSnafu {
699 context: "Failed to project sink table scan for incremental sink merge".to_string(),
700 })?
701 .alias(sink_alias)
702 .with_context(|_| DatafusionSnafu {
703 context: "Failed to alias sink plan for incremental sink merge".to_string(),
704 })?
705 .build()
706 .with_context(|_| DatafusionSnafu {
707 context: "Failed to build projected sink plan for incremental sink merge".to_string(),
708 })?;
709
710 let join_keys = (
711 analysis
712 .group_key_names
713 .iter()
714 .cloned()
715 .map(|c| qualified_column(delta_alias, c))
716 .collect::<Vec<_>>(),
717 analysis
718 .group_key_names
719 .iter()
720 .cloned()
721 .map(|c| qualified_column(sink_alias, c))
722 .collect::<Vec<_>>(),
723 );
724
725 let joined = LogicalPlanBuilder::from(delta_selected)
726 .join_detailed(
727 sink_selected,
728 JoinType::Left,
729 join_keys,
730 None,
731 NullEquality::NullEqualsNull,
732 )
733 .with_context(|_| DatafusionSnafu {
734 context: "Failed to left join delta and sink plans for incremental sink merge"
735 .to_string(),
736 })?
737 .build()
738 .with_context(|_| DatafusionSnafu {
739 context: "Failed to build left join plan for incremental sink merge".to_string(),
740 })?;
741
742 let group_key_names = analysis.group_key_names.iter().collect::<HashSet<_>>();
743 let literal_columns = analysis.literal_columns.iter().collect::<HashSet<_>>();
744 let merge_columns = analysis
745 .merge_columns
746 .iter()
747 .map(|c| (&c.output_field_name, c))
748 .collect::<HashMap<_, _>>();
749
750 let mut projection_exprs = Vec::with_capacity(analysis.output_field_names.len());
751 for output_field_name in &analysis.output_field_names {
752 if group_key_names.contains(output_field_name)
753 || literal_columns.contains(output_field_name)
754 {
755 projection_exprs.push(
756 qualified_col(delta_alias, output_field_name.clone()).alias(output_field_name),
757 );
758 } else if let Some(merge_col) = merge_columns.get(output_field_name) {
759 projection_exprs.push(build_left_join_merge_expr(
760 delta_alias,
761 sink_alias,
762 merge_col,
763 )?);
764 } else {
765 return InvalidQuerySnafu {
766 reason: format!(
767 "UNSUPPORTED_INCREMENTAL_AGG: output field {output_field_name} is not covered by group keys, literals, or merge columns"
768 ),
769 }
770 .fail();
771 }
772 }
773
774 LogicalPlanBuilder::from(joined)
775 .project(projection_exprs)
776 .with_context(|_| DatafusionSnafu {
777 context: "Failed to build projection merge plan for incremental sink merge".to_string(),
778 })?
779 .build()
780 .with_context(|_| DatafusionSnafu {
781 context: "Failed to finalize incremental aggregate sink merge plan".to_string(),
782 })
783}
784
785fn build_left_join_merge_expr(
786 delta_alias: &str,
787 sink_alias: &str,
788 merge_col: &IncrementalAggregateMergeColumn,
789) -> Result<Expr, Error> {
790 let left = qualified_col(delta_alias, merge_col.output_field_name.clone());
791 let right = qualified_col(sink_alias, merge_col.output_field_name.clone());
792 let merged = match merge_col.merge_op {
793 IncrementalAggregateMergeOp::Sum => when(is_null(left.clone()), right.clone())
794 .when(is_null(right.clone()), left.clone())
795 .otherwise(binary_expr(left.clone(), Operator::Plus, right.clone()))
796 .with_context(|_| DatafusionSnafu {
797 context: "Failed to build SUM merge expression".to_string(),
798 })?,
799 IncrementalAggregateMergeOp::Min => when(is_null(right.clone()), left.clone())
800 .when(left.clone().lt_eq(right.clone()), left.clone())
801 .otherwise(right.clone())
802 .with_context(|_| DatafusionSnafu {
803 context: "Failed to build MIN merge expression".to_string(),
804 })?,
805 IncrementalAggregateMergeOp::Max => when(is_null(right.clone()), left.clone())
806 .when(left.clone().gt_eq(right.clone()), left.clone())
807 .otherwise(right.clone())
808 .with_context(|_| DatafusionSnafu {
809 context: "Failed to build MAX merge expression".to_string(),
810 })?,
811 IncrementalAggregateMergeOp::BoolAnd => when(is_null(left.clone()), right.clone())
812 .when(is_null(right.clone()), left.clone())
813 .otherwise(and(left.clone(), right.clone()))
814 .with_context(|_| DatafusionSnafu {
815 context: "Failed to build BOOL_AND merge expression".to_string(),
816 })?,
817 IncrementalAggregateMergeOp::BoolOr => when(is_null(left.clone()), right.clone())
818 .when(is_null(right.clone()), left.clone())
819 .otherwise(or(left.clone(), right.clone()))
820 .with_context(|_| DatafusionSnafu {
821 context: "Failed to build BOOL_OR merge expression".to_string(),
822 })?,
823 IncrementalAggregateMergeOp::BitAnd => when(is_null(left.clone()), right.clone())
824 .when(is_null(right.clone()), left.clone())
825 .otherwise(bitwise_and(left.clone(), right.clone()))
826 .with_context(|_| DatafusionSnafu {
827 context: "Failed to build BIT_AND merge expression".to_string(),
828 })?,
829 IncrementalAggregateMergeOp::BitOr => when(is_null(left.clone()), right.clone())
830 .when(is_null(right.clone()), left.clone())
831 .otherwise(bitwise_or(left.clone(), right.clone()))
832 .with_context(|_| DatafusionSnafu {
833 context: "Failed to build BIT_OR merge expression".to_string(),
834 })?,
835 IncrementalAggregateMergeOp::BitXor => when(is_null(left.clone()), right.clone())
836 .when(is_null(right.clone()), left.clone())
837 .otherwise(bitwise_xor(left.clone(), right.clone()))
838 .with_context(|_| DatafusionSnafu {
839 context: "Failed to build BIT_XOR merge expression".to_string(),
840 })?,
841 };
842 Ok(merged.alias(merge_col.output_field_name.clone()))
843}
844
845pub async fn get_table_info_df_schema(
846 catalog_mr: CatalogManagerRef,
847 table_name: TableName,
848) -> Result<(TableRef, Arc<DFSchema>), Error> {
849 let full_table_name = table_name.clone().join(".");
850 let table = catalog_mr
851 .table(&table_name[0], &table_name[1], &table_name[2], None)
852 .await
853 .map_err(BoxedError::new)
854 .context(ExternalSnafu)?
855 .context(TableNotFoundSnafu {
856 name: &full_table_name,
857 })?;
858 let table_info = table.table_info();
859
860 let schema = table_info.meta.schema.clone();
861
862 let df_schema: Arc<DFSchema> = Arc::new(
863 schema
864 .arrow_schema()
865 .clone()
866 .try_into()
867 .with_context(|_| DatafusionSnafu {
868 context: format!(
869 "Failed to convert arrow schema to datafusion schema, arrow_schema={:?}",
870 schema.arrow_schema()
871 ),
872 })?,
873 );
874 Ok((table, df_schema))
875}
876
877pub async fn sql_to_df_plan(
880 query_ctx: QueryContextRef,
881 engine: QueryEngineRef,
882 sql: &str,
883 optimize: bool,
884) -> Result<LogicalPlan, Error> {
885 let stmts =
886 ParserContext::create_with_dialect(sql, query_ctx.sql_dialect(), ParseOptions::default())
887 .map_err(BoxedError::new)
888 .context(ExternalSnafu)?;
889
890 ensure!(
891 stmts.len() == 1,
892 InvalidQuerySnafu {
893 reason: format!("Expect only one statement, found {}", stmts.len())
894 }
895 );
896 let stmt = &stmts[0];
897 let query_stmt = match stmt {
898 Statement::Tql(tql) => match tql {
899 Tql::Eval(eval) => {
900 let eval = eval.clone();
901 let promql = PromQuery {
902 start: eval.start,
903 end: eval.end,
904 step: eval.step,
905 query: eval.query,
906 lookback: eval
907 .lookback
908 .unwrap_or_else(|| DEFAULT_LOOKBACK_STRING.to_string()),
909 alias: eval.alias.clone(),
910 };
911
912 QueryLanguageParser::parse_promql(&promql, &query_ctx)
913 .map_err(BoxedError::new)
914 .context(ExternalSnafu)?
915 }
916 _ => InvalidQuerySnafu {
917 reason: format!("TQL statement {tql:?} is not supported, expect only TQL EVAL"),
918 }
919 .fail()?,
920 },
921 _ => QueryStatement::Sql(stmt.clone()),
922 };
923 let plan = engine
924 .planner()
925 .plan(&query_stmt, query_ctx.clone())
926 .await
927 .map_err(BoxedError::new)
928 .context(ExternalSnafu)?;
929
930 let plan = if optimize {
931 apply_df_optimizer(plan, &query_ctx).await?
932 } else {
933 plan
934 };
935 Ok(plan)
936}
937
938pub(crate) async fn gen_plan_with_matching_schema(
941 sql: &str,
942 query_ctx: QueryContextRef,
943 engine: QueryEngineRef,
944 sink_table_schema: SchemaRef,
945 primary_key_indices: &[usize],
946 allow_partial: bool,
947) -> Result<LogicalPlan, Error> {
948 let plan = sql_to_df_plan(query_ctx.clone(), engine.clone(), sql, false).await?;
949
950 let mut add_auto_column = ColumnMatcherRewriter::new(
951 sink_table_schema,
952 primary_key_indices.to_vec(),
953 allow_partial,
954 );
955 let plan = plan
956 .clone()
957 .rewrite(&mut add_auto_column)
958 .with_context(|_| DatafusionSnafu {
959 context: "Failed to rewrite plan".to_string(),
960 })?
961 .data;
962 Ok(plan)
963}
964
965pub fn df_plan_to_sql(plan: &LogicalPlan) -> Result<String, Error> {
966 struct ForceQuoteIdentifiers;
968 impl datafusion::sql::unparser::dialect::Dialect for ForceQuoteIdentifiers {
969 fn identifier_quote_style(&self, identifier: &str) -> Option<char> {
970 if identifier.to_lowercase() != identifier {
971 Some('`')
972 } else {
973 None
974 }
975 }
976 }
977 let unparser = Unparser::new(&ForceQuoteIdentifiers);
978 let sql = unparser
980 .plan_to_sql(plan)
981 .with_context(|_e| DatafusionSnafu {
982 context: format!("Failed to unparse logical plan {plan:?}"),
983 })?;
984 Ok(sql.to_string())
985}
986
987#[derive(Debug, Clone, Default)]
989pub struct FindGroupByFinalName {
990 group_exprs: Option<HashSet<datafusion_expr::Expr>>,
991}
992
993impl FindGroupByFinalName {
994 pub fn get_group_expr_names(&self) -> Option<HashSet<String>> {
995 self.group_exprs
996 .as_ref()
997 .map(|exprs| exprs.iter().map(|expr| expr.qualified_name().1).collect())
998 }
999}
1000
1001impl TreeNodeVisitor<'_> for FindGroupByFinalName {
1002 type Node = LogicalPlan;
1003
1004 fn f_down(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
1005 if let LogicalPlan::Aggregate(aggregate) = node {
1006 self.group_exprs = Some(aggregate.group_expr.iter().cloned().collect());
1007 debug!(
1008 "FindGroupByFinalName: Get Group by exprs from Aggregate: {:?}",
1009 self.group_exprs
1010 );
1011 } else if let LogicalPlan::Distinct(distinct) = node {
1012 debug!("FindGroupByFinalName: Distinct: {}", node);
1013 match distinct {
1014 Distinct::All(input) => {
1015 if let LogicalPlan::TableScan(table_scan) = &**input {
1016 let len = table_scan.projected_schema.fields().len();
1018 let columns = (0..len)
1019 .map(|f| {
1020 let (qualifier, field) =
1021 table_scan.projected_schema.qualified_field(f);
1022 datafusion_common::Column::new(qualifier.cloned(), field.name())
1023 })
1024 .map(datafusion_expr::Expr::Column);
1025 self.group_exprs = Some(columns.collect());
1026 } else {
1027 self.group_exprs = Some(input.expressions().iter().cloned().collect())
1028 }
1029 }
1030 Distinct::On(distinct_on) => {
1031 self.group_exprs = Some(distinct_on.on_expr.iter().cloned().collect())
1032 }
1033 }
1034 debug!(
1035 "FindGroupByFinalName: Get Group by exprs from Distinct: {:?}",
1036 self.group_exprs
1037 );
1038 }
1039
1040 Ok(TreeNodeRecursion::Continue)
1041 }
1042
1043 fn f_up(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
1045 if let LogicalPlan::Projection(projection) = node {
1046 for expr in &projection.expr {
1047 let Some(group_exprs) = &mut self.group_exprs else {
1048 return Ok(TreeNodeRecursion::Continue);
1049 };
1050 if let datafusion_expr::Expr::Alias(alias) = expr {
1051 let mut new_group_exprs = group_exprs.clone();
1053 for group_expr in group_exprs.iter() {
1054 if group_expr.name_for_alias()? == alias.expr.name_for_alias()? {
1055 new_group_exprs.remove(group_expr);
1056 new_group_exprs.insert(expr.clone());
1057 break;
1058 }
1059 }
1060 *group_exprs = new_group_exprs;
1061 }
1062 }
1063 }
1064 debug!("Aliased group by exprs: {:?}", self.group_exprs);
1065 Ok(TreeNodeRecursion::Continue)
1066 }
1067}
1068
/// State for rewriting a query plan so that its projected columns line up with the columns
/// of a target (sink) table schema.
#[derive(Debug)]
pub struct ColumnMatcherRewriter {
    /// Schema of the table the query output is matched against.
    pub schema: SchemaRef,
    /// Starts out `false`; presumably set once a rewrite has been applied — TODO confirm
    /// against the rewriter implementation.
    pub is_rewritten: bool,
    /// Primary key column indices supplied by the caller (presumably indices into `schema`
    /// — TODO confirm).
    pub primary_key_indices: Vec<usize>,
    /// When `true`, projection expressions are handled by the partial-matching path instead
    /// of the strict column-count matching.
    pub allow_partial: bool,
}
1082
1083impl ColumnMatcherRewriter {
1084 pub fn new(schema: SchemaRef, primary_key_indices: Vec<usize>, allow_partial: bool) -> Self {
1085 Self {
1086 schema,
1087 is_rewritten: false,
1088 primary_key_indices,
1089 allow_partial,
1090 }
1091 }
1092
1093 fn modify_project_exprs(
1095 &mut self,
1096 mut exprs: Vec<Expr>,
1097 input_schema: &DFSchema,
1098 ) -> DfResult<Vec<Expr>> {
1099 if self.allow_partial {
1100 return self.modify_project_exprs_with_partial(exprs);
1101 }
1102
1103 let original_exprs = exprs.clone();
1104
1105 let all_names = self
1106 .schema
1107 .column_schemas()
1108 .iter()
1109 .map(|c| c.name.clone())
1110 .collect::<BTreeSet<_>>();
1111 let query_col_cnt = exprs.len();
1113 let table_col_cnt = self.schema.column_schemas().len();
1114 debug!("query_col_cnt={query_col_cnt}, table_col_cnt={table_col_cnt}");
1115
1116 let placeholder_ts_expr =
1117 datafusion::logical_expr::lit(ScalarValue::TimestampMillisecond(Some(0), None))
1118 .alias(AUTO_CREATED_PLACEHOLDER_TS_COL);
1119
1120 if query_col_cnt == table_col_cnt {
1121 } else if query_col_cnt + 1 == table_col_cnt {
1123 let last_col_schema = self.schema.column_schemas().last().unwrap();
1124
1125 if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
1127 && self.schema.timestamp_index() == Some(table_col_cnt - 1)
1128 {
1129 exprs.push(placeholder_ts_expr);
1130 } else if last_col_schema.data_type.is_timestamp() {
1131 exprs.push(datafusion::prelude::now().alias(&last_col_schema.name));
1133 } else {
1134 return Err(DataFusionError::Plan(format_flow_sink_schema_mismatch(
1135 &original_exprs,
1136 self.schema.as_ref(),
1137 )));
1138 }
1139 } else if query_col_cnt + 2 == table_col_cnt {
1140 let mut col_iter = self.schema.column_schemas().iter().rev();
1141 let last_col_schema = col_iter.next().unwrap();
1142 let second_last_col_schema = col_iter.next().unwrap();
1143 if second_last_col_schema.data_type.is_timestamp() {
1144 exprs.push(datafusion::prelude::now().alias(&second_last_col_schema.name));
1145 } else {
1146 return Err(DataFusionError::Plan(format!(
1147 "Expect the second last column in the table to be timestamp column, found column {} with type {:?}",
1148 second_last_col_schema.name, second_last_col_schema.data_type
1149 )));
1150 }
1151
1152 if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
1153 && self.schema.timestamp_index() == Some(table_col_cnt - 1)
1154 {
1155 exprs.push(placeholder_ts_expr);
1156 } else {
1157 return Err(DataFusionError::Plan(format!(
1158 "Expect timestamp column {}, found {:?}",
1159 AUTO_CREATED_PLACEHOLDER_TS_COL, last_col_schema
1160 )));
1161 }
1162 } else {
1163 return Err(DataFusionError::Plan(format_flow_sink_schema_mismatch(
1164 &original_exprs,
1165 self.schema.as_ref(),
1166 )));
1167 }
1168
1169 self.match_extra_output_columns(exprs, input_schema, &original_exprs, &all_names)
1170 }
1171
1172 fn match_extra_output_columns(
1183 &self,
1184 mut exprs: Vec<Expr>,
1185 input_schema: &DFSchema,
1186 original_exprs: &[Expr],
1187 all_names: &BTreeSet<String>,
1188 ) -> DfResult<Vec<Expr>> {
1189 let mut output_names = exprs
1190 .iter()
1191 .map(|expr| expr.qualified_name().1)
1192 .collect::<Vec<_>>();
1193 let output_name_set = output_names.iter().cloned().collect::<BTreeSet<_>>();
1194 let extra_expr_indices = output_names
1195 .iter()
1196 .enumerate()
1197 .filter_map(|(idx, name)| (!all_names.contains(name)).then_some(idx))
1198 .collect::<Vec<_>>();
1199 let missing_sink_indices = self
1200 .schema
1201 .column_schemas()
1202 .iter()
1203 .enumerate()
1204 .filter_map(|(idx, column)| (!output_name_set.contains(&column.name)).then_some(idx))
1205 .collect::<Vec<_>>();
1206
1207 if extra_expr_indices.is_empty() && missing_sink_indices.is_empty() {
1208 return Ok(exprs);
1209 }
1210
1211 if extra_expr_indices.len() != missing_sink_indices.len() {
1212 return Err(DataFusionError::Plan(format_flow_sink_schema_mismatch(
1213 original_exprs,
1214 self.schema.as_ref(),
1215 )));
1216 }
1217
1218 let mut positional_matches = Vec::new();
1219 for expr_idx in extra_expr_indices {
1220 if !missing_sink_indices.contains(&expr_idx) {
1221 return Err(DataFusionError::Plan(format_flow_sink_schema_mismatch(
1222 original_exprs,
1223 self.schema.as_ref(),
1224 )));
1225 }
1226
1227 let target_col_schema = &self.schema.column_schemas()[expr_idx];
1228 let expr_type =
1229 ConcreteDataType::from_arrow_type(&exprs[expr_idx].get_type(input_schema)?);
1230 if is_obviously_incompatible_positional_match(&expr_type, &target_col_schema.data_type)
1231 {
1232 return Err(DataFusionError::Plan(format!(
1233 "Cannot match flow output column '{}' to sink column '{}' by position: incompatible data types, flow output type is {:?}, sink column type is {:?}. {}",
1234 output_names[expr_idx],
1235 target_col_schema.name,
1236 expr_type,
1237 target_col_schema.data_type,
1238 format_flow_sink_schema_mismatch(original_exprs, self.schema.as_ref())
1239 )));
1240 }
1241
1242 let target_name = target_col_schema.name.clone();
1243 positional_matches.push(format!(
1244 "{} -> {} (flow output type: {:?}, sink column type: {:?})",
1245 output_names[expr_idx], target_name, expr_type, target_col_schema.data_type
1246 ));
1247 exprs[expr_idx] = exprs[expr_idx].clone().alias(target_name.clone());
1248 output_names[expr_idx] = target_name;
1249 }
1250
1251 if !positional_matches.is_empty() {
1252 debug!(
1253 "Matched flow output columns to sink columns by position: {:?}",
1254 positional_matches
1255 );
1256 }
1257
1258 let duplicated_output_names = duplicate_names(&output_names);
1259 if !duplicated_output_names.is_empty() {
1260 return Err(DataFusionError::Plan(format!(
1261 "Flow output schema contains duplicate column(s) after schema matching {:?}. {}",
1262 duplicated_output_names,
1263 format_flow_sink_schema_mismatch(&exprs, self.schema.as_ref())
1264 )));
1265 }
1266
1267 Ok(exprs)
1268 }
1269
1270 fn modify_project_exprs_with_partial(&mut self, exprs: Vec<Expr>) -> DfResult<Vec<Expr>> {
1271 let table_col_cnt = self.schema.column_schemas().len();
1272 let query_col_cnt = exprs.len();
1273
1274 if query_col_cnt > table_col_cnt {
1275 return Err(DataFusionError::Plan(format_flow_sink_schema_mismatch(
1276 &exprs,
1277 self.schema.as_ref(),
1278 )));
1279 }
1280
1281 let name_to_expr: HashMap<String, Expr> = exprs
1282 .clone()
1283 .into_iter()
1284 .map(|e| (e.qualified_name().1, e))
1285 .collect();
1286
1287 let required_columns = self.required_columns_for_partial();
1288 let missing: Vec<_> = required_columns
1289 .iter()
1290 .filter(|name| !name_to_expr.contains_key(*name))
1291 .cloned()
1292 .collect();
1293 if !missing.is_empty() {
1294 return Err(DataFusionError::Plan(format!(
1295 "Column(s) {:?} required by sink table are missing from flow output when merge_mode=last_non_null. {}",
1296 missing,
1297 format_flow_sink_schema_mismatch(&exprs, self.schema.as_ref())
1298 )));
1299 }
1300
1301 let placeholder_ts_expr =
1302 datafusion::logical_expr::lit(ScalarValue::TimestampMillisecond(Some(0), None))
1303 .alias(AUTO_CREATED_PLACEHOLDER_TS_COL);
1304
1305 let timestamp_index = self.schema.timestamp_index();
1306 let mut remap = name_to_expr;
1307 let mut new_exprs = Vec::with_capacity(table_col_cnt);
1308
1309 for (idx, col_schema) in self.schema.column_schemas().iter().enumerate() {
1310 let col_name = col_schema.name.clone();
1311 if let Some(expr) = remap.remove(&col_name) {
1312 let expr = if expr.qualified_name().1 == col_name {
1313 expr
1314 } else {
1315 expr.alias(col_name.clone())
1316 };
1317 new_exprs.push(expr);
1318 continue;
1319 }
1320
1321 if col_name == AUTO_CREATED_PLACEHOLDER_TS_COL && timestamp_index == Some(idx) {
1322 new_exprs.push(placeholder_ts_expr.clone());
1323 continue;
1324 }
1325
1326 if col_name == AUTO_CREATED_UPDATE_AT_TS_COL && col_schema.data_type.is_timestamp() {
1327 new_exprs.push(datafusion::prelude::now().alias(&col_name));
1328 continue;
1329 }
1330
1331 new_exprs.push(Self::null_expr(col_schema));
1332 }
1333
1334 if !remap.is_empty() {
1335 let extra: Vec<_> = remap.keys().cloned().collect();
1336 return Err(DataFusionError::Plan(format!(
1337 "Flow output has extra column(s) {:?} not found in sink schema when merge_mode=last_non_null. {}",
1338 extra,
1339 format_flow_sink_schema_mismatch(&exprs, self.schema.as_ref())
1340 )));
1341 }
1342
1343 Ok(new_exprs)
1344 }
1345
1346 fn null_expr(col_schema: &ColumnSchema) -> Expr {
1347 Expr::Literal(ScalarValue::Null, None).alias(col_schema.name.clone())
1348 }
1349
1350 fn required_columns_for_partial(&self) -> HashSet<String> {
1351 let mut required = HashSet::new();
1352 for idx in &self.primary_key_indices {
1353 if let Some(col) = self.schema.column_schemas().get(*idx) {
1354 required.insert(col.name.clone());
1355 }
1356 }
1357
1358 if let Some(ts_idx) = self.schema.timestamp_index()
1359 && let Some(col) = self.schema.column_schemas().get(ts_idx)
1360 && col.name != AUTO_CREATED_PLACEHOLDER_TS_COL
1361 {
1362 required.insert(col.name.clone());
1363 }
1364
1365 required
1366 }
1367}
1368
1369fn is_obviously_incompatible_positional_match(
1370 expr_type: &ConcreteDataType,
1371 sink_type: &ConcreteDataType,
1372) -> bool {
1373 if expr_type.is_null() || expr_type == sink_type {
1378 return false;
1379 }
1380
1381 expr_type.is_timestamp() != sink_type.is_timestamp()
1382 || expr_type.is_string() != sink_type.is_string()
1383 || expr_type.is_boolean() != sink_type.is_boolean()
1384 || expr_type.is_json() != sink_type.is_json()
1385 || expr_type.is_vector() != sink_type.is_vector()
1386}
1387
/// Returns every name that occurs more than once in `names`.
///
/// Each duplicated name is reported once and the result is sorted ascending.
fn duplicate_names(names: &[String]) -> Vec<String> {
    let mut first_seen: HashSet<&str> = HashSet::new();
    names
        .iter()
        .map(String::as_str)
        // `insert` returns `false` for a name that was already recorded.
        .filter(|name| !first_seen.insert(*name))
        .collect::<BTreeSet<&str>>()
        .into_iter()
        .map(str::to_string)
        .collect()
}
1398
1399fn format_flow_sink_schema_mismatch(
1400 query_exprs: &[Expr],
1401 sink_schema: &datatypes::schema::Schema,
1402) -> String {
1403 let flow_output_columns = query_exprs
1404 .iter()
1405 .map(|expr| expr.qualified_name().1)
1406 .collect::<Vec<_>>();
1407 let sink_table_columns = sink_schema
1408 .column_schemas()
1409 .iter()
1410 .map(|col| col.name.clone())
1411 .collect::<Vec<_>>();
1412
1413 let flow_output_set = flow_output_columns.iter().cloned().collect::<HashSet<_>>();
1414 let sink_table_set = sink_table_columns.iter().cloned().collect::<HashSet<_>>();
1415
1416 let mut extra_flow_columns = flow_output_columns
1417 .iter()
1418 .filter(|name| !sink_table_set.contains(*name))
1419 .cloned()
1420 .collect::<Vec<_>>();
1421 extra_flow_columns.sort();
1422 extra_flow_columns.dedup();
1423
1424 let mut missing_sink_columns = sink_table_columns
1425 .iter()
1426 .filter(|name| !flow_output_set.contains(*name))
1427 .cloned()
1428 .collect::<Vec<_>>();
1429 missing_sink_columns.sort();
1430 missing_sink_columns.dedup();
1431
1432 format!(
1433 "Flow output schema does not match sink table schema: found {} flow output columns and {} sink table columns. flow output columns: {:?}, sink table columns: {:?}, extra flow columns not in sink: {:?}, missing sink columns from flow output: {:?}",
1434 flow_output_columns.len(),
1435 sink_table_columns.len(),
1436 flow_output_columns,
1437 sink_table_columns,
1438 extra_flow_columns,
1439 missing_sink_columns
1440 )
1441}
1442
1443impl TreeNodeRewriter for ColumnMatcherRewriter {
1444 type Node = LogicalPlan;
1445 fn f_down(&mut self, mut node: Self::Node) -> DfResult<Transformed<Self::Node>> {
1446 if self.is_rewritten {
1447 return Ok(Transformed::no(node));
1448 }
1449
1450 if let LogicalPlan::Distinct(Distinct::All(_)) = &node {
1452 let mut exprs = vec![];
1453
1454 for field in node.schema().fields().iter() {
1455 exprs.push(Expr::Column(datafusion::common::Column::new_unqualified(
1456 field.name(),
1457 )));
1458 }
1459
1460 let projection =
1461 LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(node.clone()))?);
1462
1463 node = projection;
1464 }
1465 else if let LogicalPlan::TableScan(table_scan) = node {
1467 let mut exprs = vec![];
1468
1469 for field in table_scan.projected_schema.fields().iter() {
1470 exprs.push(Expr::Column(datafusion::common::Column::new(
1471 Some(table_scan.table_name.clone()),
1472 field.name(),
1473 )));
1474 }
1475
1476 let projection = LogicalPlan::Projection(Projection::try_new(
1477 exprs,
1478 Arc::new(LogicalPlan::TableScan(table_scan)),
1479 )?);
1480
1481 node = projection;
1482 }
1483
1484 if let LogicalPlan::Projection(project) = &node {
1488 let exprs = project.expr.clone();
1489 let exprs = self.modify_project_exprs(exprs, project.input.schema())?;
1490
1491 self.is_rewritten = true;
1492 let new_plan =
1493 node.with_new_exprs(exprs, node.inputs().into_iter().cloned().collect())?;
1494 Ok(Transformed::yes(new_plan))
1495 } else {
1496 let mut exprs = vec![];
1498 for field in node.schema().fields().iter() {
1499 exprs.push(Expr::Column(datafusion::common::Column::new_unqualified(
1500 field.name(),
1501 )));
1502 }
1503 let exprs = self.modify_project_exprs(exprs, node.schema())?;
1504 self.is_rewritten = true;
1505 let new_plan =
1506 LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(node.clone()))?);
1507 Ok(Transformed::yes(new_plan))
1508 }
1509 }
1510
1511 fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
1513 node.recompute_schema().map(Transformed::yes)
1514 }
1515}
1516
/// Plan rewriter that applies one extra filter predicate to a plan, exactly once.
///
/// On the way up, the first `Filter` node encountered has the predicate
/// AND-ed onto its own, or the first `TableScan` encountered is wrapped in a
/// new `Filter`, whichever is visited first.
#[derive(Debug)]
pub struct AddFilterRewriter {
    /// Predicate to add to the plan.
    extra_filter: Expr,
    /// Set to `true` once the predicate has been applied.
    is_rewritten: bool,
}
1523
1524impl AddFilterRewriter {
1525 pub fn new(filter: Expr) -> Self {
1526 Self {
1527 extra_filter: filter,
1528 is_rewritten: false,
1529 }
1530 }
1531}
1532
1533impl TreeNodeRewriter for AddFilterRewriter {
1534 type Node = LogicalPlan;
1535 fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
1536 if self.is_rewritten {
1537 return Ok(Transformed::no(node));
1538 }
1539 match node {
1540 LogicalPlan::Filter(mut filter) => {
1541 filter.predicate = filter.predicate.and(self.extra_filter.clone());
1542 self.is_rewritten = true;
1543 Ok(Transformed::yes(LogicalPlan::Filter(filter)))
1544 }
1545 LogicalPlan::TableScan(_) => {
1546 let filter =
1548 datafusion_expr::Filter::try_new(self.extra_filter.clone(), Arc::new(node))?;
1549 self.is_rewritten = true;
1550 Ok(Transformed::yes(LogicalPlan::Filter(filter)))
1551 }
1552 _ => Ok(Transformed::no(node)),
1553 }
1554 }
1555}