Skip to main content

flow/batching_mode/
utils.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! some utils for helping with batching mode
16
17use std::collections::{BTreeSet, HashMap, HashSet};
18use std::sync::Arc;
19
20use catalog::CatalogManagerRef;
21use common_error::ext::BoxedError;
22use common_function::aggrs::aggr_wrapper::get_aggr_func;
23use common_telemetry::debug;
24use datafusion::datasource::DefaultTableSource;
25use datafusion::error::Result as DfResult;
26use datafusion::logical_expr::Expr;
27use datafusion::sql::unparser::Unparser;
28use datafusion_common::tree_node::{
29    Transformed, TreeNode as _, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
30};
31use datafusion_common::{
32    Column, DFSchema, DataFusionError, NullEquality, ScalarValue, TableReference,
33};
34use datafusion_expr::logical_plan::{Aggregate, TableScan};
35use datafusion_expr::{
36    Distinct, JoinType, LogicalPlan, LogicalPlanBuilder, Operator, Projection, and, binary_expr,
37    bitwise_and, bitwise_or, bitwise_xor, is_null, or, when,
38};
39use datatypes::schema::{ColumnSchema, SchemaRef};
40use query::QueryEngineRef;
41use query::parser::{DEFAULT_LOOKBACK_STRING, PromQuery, QueryLanguageParser, QueryStatement};
42use session::context::QueryContextRef;
43use snafu::{OptionExt, ResultExt, ensure};
44use sql::parser::{ParseOptions, ParserContext};
45use sql::statements::statement::Statement;
46use sql::statements::tql::Tql;
47use table::TableRef;
48use table::table::adapter::DfTableProviderAdapter;
49
50use crate::adapter::{AUTO_CREATED_PLACEHOLDER_TS_COL, AUTO_CREATED_UPDATE_AT_TS_COL};
51use crate::df_optimizer::apply_df_optimizer;
52use crate::error::{DatafusionSnafu, ExternalSnafu, InvalidQuerySnafu, TableNotFoundSnafu};
53use crate::{Error, TableName};
54
55#[cfg(test)]
56mod test;
57
/// Merge instruction for a single aggregate output column: which output/sink
/// field it lands in and which operator combines the delta value with the
/// value already stored in the sink table.
///
/// `output_field_name` names a field of the final output/sink schema, i.e. the
/// schema produced by the delta plan and read back from the sink table. It is
/// a plain field name, not a DataFusion `Column` reference, so it may contain
/// dots or other non-identifier characters when the query keeps DataFusion's
/// raw aggregate output name, e.g. `max(numbers_with_ts.number)`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IncrementalAggregateMergeColumn {
    /// Final output/sink field name holding the aggregate result/state.
    pub output_field_name: String,
    /// Operator used to combine the delta value with the stored sink value.
    pub merge_op: IncrementalAggregateMergeOp,
}

impl IncrementalAggregateMergeColumn {
    /// Creates a merge instruction for `output_field_name` using `merge_op`.
    pub fn new(output_field_name: String, merge_op: IncrementalAggregateMergeOp) -> Self {
        Self {
            merge_op,
            output_field_name,
        }
    }
}

/// Operator used to fold an incremental (delta) aggregate value into the value
/// already stored in the sink table.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IncrementalAggregateMergeOp {
    /// Add both values; used for `sum` and `count` aggregates.
    Sum,
    /// Keep the smaller value; used for `min`.
    Min,
    /// Keep the larger value; used for `max`.
    Max,
    /// Logical AND of both values; used for `bool_and`.
    BoolAnd,
    /// Logical OR of both values; used for `bool_or`.
    BoolOr,
    /// Bitwise AND of both values; used for `bit_and`.
    BitAnd,
    /// Bitwise OR of both values; used for `bit_or`.
    BitOr,
    /// Used for `bit_xor` aggregates.
    BitXor,
}

/// Outcome of analyzing a plan for the incremental aggregate rewrite.
///
/// Every name stored here (group keys, each merge column's
/// `output_field_name`, literal columns and output fields) is a field name of
/// the final output/sink schema. These names are used to project both the
/// delta plan and the sink table before the left-join merge. They are not
/// DataFusion logical-plan `Column` references: callers must attach
/// qualifiers structurally instead of formatting qualified names as strings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IncrementalAggregateAnalysis {
    /// Output/sink field names of the group keys; used as merge join keys.
    pub group_key_names: Vec<String>,
    /// How each aggregate output column is merged with the stored sink value.
    pub merge_columns: Vec<IncrementalAggregateMergeColumn>,
    /// Literal output fields that are passed through from the delta plan.
    pub literal_columns: Vec<String>,
    /// Output field names in the order produced by the original aggregate plan.
    pub output_field_names: Vec<String>,
    /// Reasons the plan cannot be rewritten incrementally; empty when supported.
    pub unsupported_exprs: Vec<String>,
}
114
115/// Recursively find all `Expr::Column` names inside an expression tree.
116/// Only recurses into wrappers that are merge-transparent.
117/// Non-transparent wrappers (e.g., `ScalarFunction`, `Negative`, `Cast`) are
118/// intentionally not recursed into since their merge semantics would be
119/// incorrect.
120///
121/// `Cast`/`TryCast` are intentionally opaque: merging already-casted aggregate
122/// outputs is not generally equivalent to casting the final merged aggregate.
123fn find_column_names(expr: &Expr, names: &mut Vec<String>) {
124    match expr {
125        Expr::Column(col) => {
126            names.push(col.name.clone());
127        }
128        Expr::Alias(alias) => find_column_names(&alias.expr, names),
129        _ => {}
130    }
131}
132
133fn unqualified_col(name: impl Into<String>) -> Expr {
134    Expr::Column(Column::from_name(name.into()))
135}
136
137fn qualified_col(qualifier: &str, name: impl Into<String>) -> Expr {
138    Expr::Column(Column::new(Some(qualifier), name.into()))
139}
140
141fn qualified_column(qualifier: &str, name: impl Into<String>) -> Column {
142    Column::new(Some(qualifier), name.into())
143}
144
145fn find_group_key_names(plan: &LogicalPlan) -> Result<Vec<String>, Error> {
146    let mut group_finder = FindGroupByFinalName::default();
147    plan.visit(&mut group_finder)
148        .with_context(|_| DatafusionSnafu {
149            context: format!("Failed to inspect group-by columns from logical plan: {plan:?}"),
150        })?;
151
152    let mut group_key_names = group_finder
153        .get_group_expr_names()
154        .unwrap_or_default()
155        .into_iter()
156        .collect::<Vec<_>>();
157    group_key_names.sort();
158    Ok(group_key_names)
159}
160
161fn has_grouping_set(plan: &LogicalPlan) -> bool {
162    match plan {
163        LogicalPlan::Aggregate(aggregate) => aggregate
164            .group_expr
165            .iter()
166            .any(|expr| matches!(expr, Expr::GroupingSet(_))),
167        _ => plan.inputs().into_iter().any(has_grouping_set),
168    }
169}
170
171fn has_aggregate(plan: &LogicalPlan) -> bool {
172    match plan {
173        LogicalPlan::Aggregate(_) => true,
174        _ => plan.inputs().into_iter().any(has_aggregate),
175    }
176}
177
178fn peel_subquery_aliases(mut plan: &LogicalPlan) -> &LogicalPlan {
179    while let LogicalPlan::SubqueryAlias(alias) = plan {
180        plan = alias.input.as_ref();
181    }
182    plan
183}
184
185fn extract_incremental_aggregate(plan: &LogicalPlan) -> Result<Option<&Aggregate>, String> {
186    // Supported final shape: optional output Projection directly over one
187    // Aggregate. Post-aggregate filters (HAVING), ordering, limits,
188    // distinct/window/union/extension nodes are intentionally not accepted.
189    let plan = match plan {
190        LogicalPlan::Projection(projection) => projection.input.as_ref(),
191        _ => plan,
192    };
193
194    match plan {
195        LogicalPlan::Aggregate(aggregate) => {
196            check_input_plan_shape(aggregate.input.as_ref())?;
197            Ok(Some(aggregate))
198        }
199        LogicalPlan::Filter(filter) if has_aggregate(filter.input.as_ref()) => Err(
200            "unsupported post-aggregate filter (HAVING) in incremental aggregate rewrite"
201                .to_string(),
202        ),
203        _ if has_aggregate(plan) => Err(
204            "unsupported post-aggregate plan shape in incremental aggregate rewrite".to_string(),
205        ),
206        _ => Ok(None),
207    }
208}
209
210fn check_input_plan_shape(plan: &LogicalPlan) -> Result<(), String> {
211    let plan = peel_subquery_aliases(plan);
212    match plan {
213        // Supported aggregate input shape: optional WHERE filter over a table scan.
214        // SubqueryAlias is a transparent naming wrapper for `FROM table AS alias`.
215        LogicalPlan::TableScan(_) => Ok(()),
216        LogicalPlan::Filter(filter) => match peel_subquery_aliases(filter.input.as_ref()) {
217            LogicalPlan::TableScan(_) => Ok(()),
218            _ => Err(
219                "unsupported aggregate input plan shape in incremental aggregate rewrite"
220                    .to_string(),
221            ),
222        },
223        _ => Err(
224            "unsupported aggregate input plan shape in incremental aggregate rewrite".to_string(),
225        ),
226    }
227}
228
/// Facts gathered from the analyzed plan's top-level output, used to map
/// aggregate expressions onto their final output/sink field names.
#[derive(Debug, Default)]
struct OutputProjectionInfo {
    /// Whether the analyzed plan's root node is a `Projection`.
    has_top_level_projection: bool,
    /// Source column name referenced by the projection (e.g. a raw aggregate
    /// output name) mapped to the output field name it is exposed as.
    output_aliases: HashMap<String, String>,
    /// Diagnostics for a source column exposed under more than one alias.
    duplicate_aggregate_aliases: BTreeSet<String>,
    /// Output fields that are passed through from the delta plan as-is.
    literal_columns: HashSet<String>,
    /// Output field names of the analyzed plan, in schema order.
    output_field_names: Vec<String>,
}

impl OutputProjectionInfo {
    /// Returns the output field names as an owned set for membership tests.
    fn output_field_name_set(&self) -> HashSet<String> {
        self.output_field_names.iter().cloned().collect()
    }

    /// Returns every output field name that occurs more than once, sorted and
    /// without repetition.
    fn duplicate_output_names(&self) -> Vec<String> {
        // Borrow names while scanning; only the (usually few) duplicates are
        // copied into owned strings at the end.
        let mut seen = HashSet::with_capacity(self.output_field_names.len());
        let mut duplicates = BTreeSet::new();
        for name in &self.output_field_names {
            if !seen.insert(name.as_str()) {
                duplicates.insert(name.as_str());
            }
        }
        duplicates.into_iter().map(str::to_owned).collect()
    }
}
254
255fn collect_output_projection_info(plan: &LogicalPlan) -> OutputProjectionInfo {
256    let mut projection_info = OutputProjectionInfo {
257        has_top_level_projection: matches!(plan, LogicalPlan::Projection(_)),
258        output_field_names: plan
259            .schema()
260            .fields()
261            .iter()
262            .map(|field| field.name().clone())
263            .collect(),
264        ..Default::default()
265    };
266
267    let mut output_aliases = HashMap::new();
268    if let LogicalPlan::Projection(projection) = plan {
269        for expr in &projection.expr {
270            match expr {
271                Expr::Alias(alias) => {
272                    // Alias resolution has three cases:
273                    // - 0 Column refs (e.g., literal `42 AS lit`): record literal output
274                    // - 1 Column ref: record the mapping (e.g., `sum(x) AS total`)
275                    // - >1 Column refs (e.g., `COALESCE(sum(x), sum(y))`):
276                    //   skip — ambiguous merge semantics
277                    let alias_name = alias.name.clone();
278                    let mut col_names = Vec::new();
279                    find_column_names(&alias.expr, &mut col_names);
280                    match col_names.len() {
281                        0 if is_passthrough_output_column(&alias_name, alias.expr.as_ref()) => {
282                            projection_info.literal_columns.insert(alias_name);
283                        }
284                        1 => {
285                            if let Some(col_name) = col_names.into_iter().next() {
286                                if let Some(existing_alias) = output_aliases.get(&col_name) {
287                                    if existing_alias != &alias_name {
288                                        projection_info.duplicate_aggregate_aliases.insert(format!(
289                                            "same aggregate output {col_name} is used by multiple aliases: {existing_alias}, {alias_name}"
290                                        ));
291                                    }
292                                } else {
293                                    output_aliases.insert(col_name, alias_name);
294                                }
295                            }
296                        }
297                        _ => {}
298                    }
299
300                    // If >1 column references detected (e.g., COALESCE(sum(x), sum(y))),
301                    // intentionally skip alias mapping — the merge semantics are ambiguous.
302                }
303                Expr::Column(col) => {
304                    output_aliases
305                        .entry(col.name.clone())
306                        .or_insert(col.name.clone());
307                }
308                Expr::Literal(_, _) => {
309                    projection_info
310                        .literal_columns
311                        .insert(expr.qualified_name().1);
312                }
313                _ => {}
314            }
315        }
316    }
317
318    if projection_info
319        .output_field_names
320        .iter()
321        .any(|name| name == AUTO_CREATED_PLACEHOLDER_TS_COL)
322    {
323        projection_info
324            .literal_columns
325            .insert(AUTO_CREATED_PLACEHOLDER_TS_COL.to_string());
326    }
327
328    projection_info.output_aliases = output_aliases;
329    projection_info
330}
331
332fn is_passthrough_output_column(alias_name: &str, expr: &Expr) -> bool {
333    matches!(expr, Expr::Literal(_, _))
334        || match alias_name {
335            AUTO_CREATED_UPDATE_AT_TS_COL => expr == &datafusion::prelude::now(),
336            AUTO_CREATED_PLACEHOLDER_TS_COL => is_literal_or_cast_literal(expr),
337            _ => false,
338        }
339}
340
341fn is_literal_or_cast_literal(expr: &Expr) -> bool {
342    match expr {
343        Expr::Literal(_, _) => true,
344        Expr::Cast(cast) => is_literal_or_cast_literal(cast.expr.as_ref()),
345        Expr::TryCast(cast) => is_literal_or_cast_literal(cast.expr.as_ref()),
346        _ => false,
347    }
348}
349
350fn merge_op_for_aggregate_expr(aggr_expr: &Expr) -> Result<IncrementalAggregateMergeOp, String> {
351    let Some(aggr_func) = get_aggr_func(aggr_expr) else {
352        return Err(aggr_expr.to_string());
353    };
354    if aggr_func.params.distinct {
355        return Err(format!("unsupported DISTINCT aggregate: {aggr_expr}"));
356    }
357    if !aggr_func.params.order_by.is_empty() {
358        return Err(format!("unsupported aggregate ORDER BY: {aggr_expr}"));
359    }
360    if aggr_func.params.null_treatment.is_some() {
361        return Err(format!("unsupported aggregate NULL treatment: {aggr_expr}"));
362    }
363
364    match aggr_func.func.name().to_ascii_lowercase().as_str() {
365        "sum" | "count" => Ok(IncrementalAggregateMergeOp::Sum),
366        "min" => Ok(IncrementalAggregateMergeOp::Min),
367        "max" => Ok(IncrementalAggregateMergeOp::Max),
368        "bool_and" => Ok(IncrementalAggregateMergeOp::BoolAnd),
369        "bool_or" => Ok(IncrementalAggregateMergeOp::BoolOr),
370        "bit_and" => Ok(IncrementalAggregateMergeOp::BitAnd),
371        "bit_or" => Ok(IncrementalAggregateMergeOp::BitOr),
372        "bit_xor" => Ok(IncrementalAggregateMergeOp::BitXor),
373        _ => Err(aggr_expr.to_string()),
374    }
375}
376
377fn resolve_aggregate_output_field_name(
378    aggr_expr: &Expr,
379    projection_info: &OutputProjectionInfo,
380    output_field_name_set: &HashSet<String>,
381) -> Option<String> {
382    // qualified_name() returns (Option<String>, String) where the second
383    // element is the unqualified column/alias name. This relies on
384    // DataFusion's internal naming convention: aggregate expressions
385    // emit a column named after the aggregate itself (e.g. "SUM(x)"),
386    // which matches what the projection aliases reference.
387    let raw_name = aggr_expr.qualified_name().1;
388    if let Some(alias) = projection_info.output_aliases.get(&raw_name) {
389        Some(alias.clone())
390    } else if !projection_info.has_top_level_projection && output_field_name_set.contains(&raw_name)
391    {
392        Some(raw_name)
393    } else {
394        None
395    }
396}
397
398fn find_uncovered_output_fields(
399    projection_info: &OutputProjectionInfo,
400    group_key_names: &[String],
401    merge_columns: &[IncrementalAggregateMergeColumn],
402) -> Vec<String> {
403    let group_key_names = group_key_names.iter().cloned().collect::<HashSet<_>>();
404    let merge_column_names = merge_columns
405        .iter()
406        .map(|c| c.output_field_name.clone())
407        .collect::<HashSet<_>>();
408
409    projection_info
410        .output_field_names
411        .iter()
412        .filter(|name| {
413            !group_key_names.contains(*name)
414                && !merge_column_names.contains(*name)
415                && !projection_info.literal_columns.contains(*name)
416                // Auto-created sink columns injected by ColumnMatcherRewriter
417                // are not part of the original aggregate semantics and must
418                // not prevent incremental aggregate rewrites.
419                && name.as_str() != AUTO_CREATED_UPDATE_AT_TS_COL
420                && name.as_str() != AUTO_CREATED_PLACEHOLDER_TS_COL
421        })
422        .cloned()
423        .collect()
424}
425
426fn find_unsupported_group_key_projection_outputs(
427    plan: &LogicalPlan,
428    aggregate: &Aggregate,
429    group_key_names: &[String],
430) -> Vec<String> {
431    let LogicalPlan::Projection(projection) = plan else {
432        return vec![];
433    };
434
435    let group_key_names = group_key_names.iter().cloned().collect::<HashSet<_>>();
436    let group_expr_names = aggregate
437        .group_expr
438        .iter()
439        .filter_map(|expr| expr.name_for_alias().ok())
440        .collect::<HashSet<_>>();
441    projection
442        .expr
443        .iter()
444        .filter_map(|expr| {
445            let output_name = expr.qualified_name().1;
446            if !group_key_names.contains(&output_name) {
447                return None;
448            }
449
450            let source_name = match expr {
451                Expr::Alias(alias) => alias.expr.name_for_alias().ok(),
452                _ => expr.name_for_alias().ok(),
453            };
454            if source_name.is_some_and(|name| group_expr_names.contains(&name)) {
455                None
456            } else {
457                Some(format!(
458                    "unsupported group key output field is not a transparent group expression: {output_name}"
459                ))
460            }
461        })
462        .collect()
463}
464
465pub fn analyze_incremental_aggregate_plan(
466    plan: &LogicalPlan,
467) -> Result<Option<IncrementalAggregateAnalysis>, Error> {
468    let group_key_names = find_group_key_names(plan)?;
469    let aggregate = match extract_incremental_aggregate(plan) {
470        Ok(Some(aggregate)) => aggregate,
471        Ok(None) => return Ok(None),
472        Err(reason) => {
473            let projection_info = collect_output_projection_info(plan);
474            let mut unsupported_exprs = projection_info
475                .duplicate_output_names()
476                .into_iter()
477                .map(|name| format!("duplicate output field name: {name}"))
478                .collect::<Vec<_>>();
479            unsupported_exprs.push(reason);
480            unsupported_exprs.extend(projection_info.duplicate_aggregate_aliases.iter().cloned());
481            return Ok(Some(IncrementalAggregateAnalysis {
482                group_key_names,
483                merge_columns: vec![],
484                literal_columns: vec![],
485                output_field_names: projection_info.output_field_names,
486                unsupported_exprs,
487            }));
488        }
489    };
490    let aggr_exprs = aggregate.aggr_expr.clone();
491    let projection_info = collect_output_projection_info(plan);
492    let output_field_name_set = projection_info.output_field_name_set();
493
494    let mut merge_columns = Vec::with_capacity(aggr_exprs.len());
495    let mut unsupported_exprs = projection_info
496        .duplicate_output_names()
497        .into_iter()
498        .map(|name| format!("duplicate output field name: {name}"))
499        .collect::<Vec<_>>();
500    if has_grouping_set(plan) {
501        unsupported_exprs.push(
502            "unsupported GROUPING SETS/CUBE/ROLLUP in incremental aggregate rewrite".to_string(),
503        );
504    }
505    if group_key_names.is_empty() {
506        unsupported_exprs
507            .push("unsupported global aggregate in incremental aggregate rewrite".to_string());
508    }
509    unsupported_exprs.extend(find_unsupported_group_key_projection_outputs(
510        plan,
511        aggregate,
512        &group_key_names,
513    ));
514    unsupported_exprs.extend(projection_info.duplicate_aggregate_aliases.iter().cloned());
515    for aggr_expr in aggr_exprs {
516        let merge_op = match merge_op_for_aggregate_expr(&aggr_expr) {
517            Ok(merge_op) => merge_op,
518            Err(reason) => {
519                unsupported_exprs.push(reason);
520                continue;
521            }
522        };
523        let Some(output_field_name) = resolve_aggregate_output_field_name(
524            &aggr_expr,
525            &projection_info,
526            &output_field_name_set,
527        ) else {
528            unsupported_exprs.push(aggr_expr.to_string());
529            continue;
530        };
531        merge_columns.push(IncrementalAggregateMergeColumn::new(
532            output_field_name,
533            merge_op,
534        ));
535    }
536    unsupported_exprs.extend(
537        find_uncovered_output_fields(&projection_info, &group_key_names, &merge_columns)
538            .into_iter()
539            .map(|name| format!("unsupported output field: {name}")),
540    );
541    if !unsupported_exprs.is_empty() {
542        merge_columns.clear();
543    }
544    let mut literal_columns = projection_info
545        .literal_columns
546        .into_iter()
547        .collect::<Vec<_>>();
548    literal_columns.sort();
549
550    Ok(Some(IncrementalAggregateAnalysis {
551        group_key_names,
552        merge_columns,
553        literal_columns,
554        output_field_names: projection_info.output_field_names,
555        unsupported_exprs,
556    }))
557}
558
559/// Rewrites one incremental aggregate delta plan by left-joining it with the
560/// existing sink-table state and projecting merged aggregate outputs.
561///
562/// For a grouped aggregate such as:
563///
564/// ```text
565/// SELECT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts
566/// ```
567///
568/// the rewrite is roughly:
569///
570/// ```text
571/// delta = SELECT ts, number FROM <delta_plan> AS __flow_delta
572/// sink_scan = SELECT * FROM <sink_table> [WHERE <sink_dirty_filter>]
573/// sink  = SELECT ts, number FROM sink_scan AS __flow_sink
574/// SELECT
575///   CASE
576///     WHEN __flow_sink.number IS NULL THEN __flow_delta.number
577///     WHEN __flow_delta.number >= __flow_sink.number THEN __flow_delta.number
578///     ELSE __flow_sink.number
579///   END AS number,
580///   __flow_delta.ts AS ts
581/// FROM delta
582/// LEFT JOIN sink
583///   ON __flow_delta.ts IS NOT DISTINCT FROM __flow_sink.ts
584/// ```
585///
586/// If `sink_dirty_filter` is provided, it is applied to the sink table scan
587/// before projection, aliasing, and the left join. The predicate must reference
588/// raw sink table columns structurally (unqualified), before the `__flow_sink`
589/// alias exists.
590pub async fn rewrite_incremental_aggregate_with_sink_merge(
591    delta_plan: &LogicalPlan,
592    analysis: &IncrementalAggregateAnalysis,
593    sink_table: TableRef,
594    sink_table_name: &TableName,
595    sink_dirty_filter: Option<Expr>,
596) -> Result<LogicalPlan, Error> {
597    ensure!(
598        analysis.unsupported_exprs.is_empty(),
599        InvalidQuerySnafu {
600            reason: format!(
601                "UNSUPPORTED_INCREMENTAL_AGG: unsupported aggregate expressions {:?}",
602                analysis.unsupported_exprs
603            )
604        }
605    );
606
607    ensure!(
608        !analysis.merge_columns.is_empty(),
609        InvalidQuerySnafu {
610            reason:
611                "UNSUPPORTED_INCREMENTAL_AGG: aggregate query has no mergeable aggregate columns"
612                    .to_string()
613        }
614    );
615
616    ensure!(
617        !analysis.group_key_names.is_empty(),
618        InvalidQuerySnafu {
619            reason: "UNSUPPORTED_INCREMENTAL_AGG: global aggregate query is not supported"
620                .to_string()
621        }
622    );
623
624    let delta_alias = "__flow_delta";
625    let sink_alias = "__flow_sink";
626
627    let mut selected_columns = analysis.group_key_names.clone();
628    selected_columns.extend(
629        analysis
630            .merge_columns
631            .iter()
632            .map(|c| c.output_field_name.clone()),
633    );
634    let mut delta_selected_columns = selected_columns.clone();
635    delta_selected_columns.extend(analysis.literal_columns.iter().cloned());
636
637    let delta_selected_exprs = delta_selected_columns
638        .iter()
639        .cloned()
640        .map(unqualified_col)
641        .collect::<Vec<_>>();
642    let delta_selected = LogicalPlanBuilder::from(delta_plan.clone())
643        .project(delta_selected_exprs)
644        .with_context(|_| DatafusionSnafu {
645            context: "Failed to project delta plan for incremental sink merge".to_string(),
646        })?
647        .alias(delta_alias)
648        .with_context(|_| DatafusionSnafu {
649            context: "Failed to alias delta plan for incremental sink merge".to_string(),
650        })?
651        .build()
652        .with_context(|_| DatafusionSnafu {
653            context: "Failed to build projected delta plan for incremental sink merge".to_string(),
654        })?;
655
656    let table_provider = Arc::new(DfTableProviderAdapter::new(sink_table));
657    let table_source = Arc::new(DefaultTableSource::new(table_provider));
658    let sink_scan = LogicalPlan::TableScan(
659        TableScan::try_new(
660            TableReference::Full {
661                catalog: sink_table_name[0].clone().into(),
662                schema: sink_table_name[1].clone().into(),
663                table: sink_table_name[2].clone().into(),
664            },
665            table_source,
666            None,
667            vec![],
668            None,
669        )
670        .with_context(|_| DatafusionSnafu {
671            context: "Failed to build sink table scan for incremental sink merge".to_string(),
672        })?,
673    );
674
675    let sink_selected_exprs = selected_columns
676        .iter()
677        .cloned()
678        .map(unqualified_col)
679        .collect::<Vec<_>>();
680    let sink_input = if let Some(predicate) = sink_dirty_filter {
681        LogicalPlanBuilder::from(sink_scan)
682            .filter(predicate)
683            .with_context(|_| DatafusionSnafu {
684                context: "Failed to filter sink table scan for incremental sink merge".to_string(),
685            })?
686            .build()
687            .with_context(|_| DatafusionSnafu {
688                context: "Failed to build filtered sink plan for incremental sink merge"
689                    .to_string(),
690            })?
691    } else {
692        sink_scan
693    };
694
695    let sink_selected = LogicalPlanBuilder::from(sink_input)
696        .project(sink_selected_exprs)
697        .with_context(|_| DatafusionSnafu {
698            context: "Failed to project sink table scan for incremental sink merge".to_string(),
699        })?
700        .alias(sink_alias)
701        .with_context(|_| DatafusionSnafu {
702            context: "Failed to alias sink plan for incremental sink merge".to_string(),
703        })?
704        .build()
705        .with_context(|_| DatafusionSnafu {
706            context: "Failed to build projected sink plan for incremental sink merge".to_string(),
707        })?;
708
709    let join_keys = (
710        analysis
711            .group_key_names
712            .iter()
713            .cloned()
714            .map(|c| qualified_column(delta_alias, c))
715            .collect::<Vec<_>>(),
716        analysis
717            .group_key_names
718            .iter()
719            .cloned()
720            .map(|c| qualified_column(sink_alias, c))
721            .collect::<Vec<_>>(),
722    );
723
724    let joined = LogicalPlanBuilder::from(delta_selected)
725        .join_detailed(
726            sink_selected,
727            JoinType::Left,
728            join_keys,
729            None,
730            NullEquality::NullEqualsNull,
731        )
732        .with_context(|_| DatafusionSnafu {
733            context: "Failed to left join delta and sink plans for incremental sink merge"
734                .to_string(),
735        })?
736        .build()
737        .with_context(|_| DatafusionSnafu {
738            context: "Failed to build left join plan for incremental sink merge".to_string(),
739        })?;
740
741    let group_key_names = analysis.group_key_names.iter().collect::<HashSet<_>>();
742    let literal_columns = analysis.literal_columns.iter().collect::<HashSet<_>>();
743    let merge_columns = analysis
744        .merge_columns
745        .iter()
746        .map(|c| (&c.output_field_name, c))
747        .collect::<HashMap<_, _>>();
748
749    let mut projection_exprs = Vec::with_capacity(analysis.output_field_names.len());
750    for output_field_name in &analysis.output_field_names {
751        if group_key_names.contains(output_field_name)
752            || literal_columns.contains(output_field_name)
753        {
754            projection_exprs.push(
755                qualified_col(delta_alias, output_field_name.clone()).alias(output_field_name),
756            );
757        } else if let Some(merge_col) = merge_columns.get(output_field_name) {
758            projection_exprs.push(build_left_join_merge_expr(
759                delta_alias,
760                sink_alias,
761                merge_col,
762            )?);
763        } else {
764            return InvalidQuerySnafu {
765                reason: format!(
766                    "UNSUPPORTED_INCREMENTAL_AGG: output field {output_field_name} is not covered by group keys, literals, or merge columns"
767                ),
768            }
769            .fail();
770        }
771    }
772
773    LogicalPlanBuilder::from(joined)
774        .project(projection_exprs)
775        .with_context(|_| DatafusionSnafu {
776            context: "Failed to build projection merge plan for incremental sink merge".to_string(),
777        })?
778        .build()
779        .with_context(|_| DatafusionSnafu {
780            context: "Failed to finalize incremental aggregate sink merge plan".to_string(),
781        })
782}
783
784fn build_left_join_merge_expr(
785    delta_alias: &str,
786    sink_alias: &str,
787    merge_col: &IncrementalAggregateMergeColumn,
788) -> Result<Expr, Error> {
789    let left = qualified_col(delta_alias, merge_col.output_field_name.clone());
790    let right = qualified_col(sink_alias, merge_col.output_field_name.clone());
791    let merged = match merge_col.merge_op {
792        IncrementalAggregateMergeOp::Sum => when(is_null(left.clone()), right.clone())
793            .when(is_null(right.clone()), left.clone())
794            .otherwise(binary_expr(left.clone(), Operator::Plus, right.clone()))
795            .with_context(|_| DatafusionSnafu {
796                context: "Failed to build SUM merge expression".to_string(),
797            })?,
798        IncrementalAggregateMergeOp::Min => when(is_null(right.clone()), left.clone())
799            .when(left.clone().lt_eq(right.clone()), left.clone())
800            .otherwise(right.clone())
801            .with_context(|_| DatafusionSnafu {
802                context: "Failed to build MIN merge expression".to_string(),
803            })?,
804        IncrementalAggregateMergeOp::Max => when(is_null(right.clone()), left.clone())
805            .when(left.clone().gt_eq(right.clone()), left.clone())
806            .otherwise(right.clone())
807            .with_context(|_| DatafusionSnafu {
808                context: "Failed to build MAX merge expression".to_string(),
809            })?,
810        IncrementalAggregateMergeOp::BoolAnd => when(is_null(left.clone()), right.clone())
811            .when(is_null(right.clone()), left.clone())
812            .otherwise(and(left.clone(), right.clone()))
813            .with_context(|_| DatafusionSnafu {
814                context: "Failed to build BOOL_AND merge expression".to_string(),
815            })?,
816        IncrementalAggregateMergeOp::BoolOr => when(is_null(left.clone()), right.clone())
817            .when(is_null(right.clone()), left.clone())
818            .otherwise(or(left.clone(), right.clone()))
819            .with_context(|_| DatafusionSnafu {
820                context: "Failed to build BOOL_OR merge expression".to_string(),
821            })?,
822        IncrementalAggregateMergeOp::BitAnd => when(is_null(left.clone()), right.clone())
823            .when(is_null(right.clone()), left.clone())
824            .otherwise(bitwise_and(left.clone(), right.clone()))
825            .with_context(|_| DatafusionSnafu {
826                context: "Failed to build BIT_AND merge expression".to_string(),
827            })?,
828        IncrementalAggregateMergeOp::BitOr => when(is_null(left.clone()), right.clone())
829            .when(is_null(right.clone()), left.clone())
830            .otherwise(bitwise_or(left.clone(), right.clone()))
831            .with_context(|_| DatafusionSnafu {
832                context: "Failed to build BIT_OR merge expression".to_string(),
833            })?,
834        IncrementalAggregateMergeOp::BitXor => when(is_null(left.clone()), right.clone())
835            .when(is_null(right.clone()), left.clone())
836            .otherwise(bitwise_xor(left.clone(), right.clone()))
837            .with_context(|_| DatafusionSnafu {
838                context: "Failed to build BIT_XOR merge expression".to_string(),
839            })?,
840    };
841    Ok(merged.alias(merge_col.output_field_name.clone()))
842}
843
844pub async fn get_table_info_df_schema(
845    catalog_mr: CatalogManagerRef,
846    table_name: TableName,
847) -> Result<(TableRef, Arc<DFSchema>), Error> {
848    let full_table_name = table_name.clone().join(".");
849    let table = catalog_mr
850        .table(&table_name[0], &table_name[1], &table_name[2], None)
851        .await
852        .map_err(BoxedError::new)
853        .context(ExternalSnafu)?
854        .context(TableNotFoundSnafu {
855            name: &full_table_name,
856        })?;
857    let table_info = table.table_info();
858
859    let schema = table_info.meta.schema.clone();
860
861    let df_schema: Arc<DFSchema> = Arc::new(
862        schema
863            .arrow_schema()
864            .clone()
865            .try_into()
866            .with_context(|_| DatafusionSnafu {
867                context: format!(
868                    "Failed to convert arrow schema to datafusion schema, arrow_schema={:?}",
869                    schema.arrow_schema()
870                ),
871            })?,
872    );
873    Ok((table, df_schema))
874}
875
876/// Convert sql to datafusion logical plan
877/// Also support TQL (but only Eval not Explain or Analyze)
878pub async fn sql_to_df_plan(
879    query_ctx: QueryContextRef,
880    engine: QueryEngineRef,
881    sql: &str,
882    optimize: bool,
883) -> Result<LogicalPlan, Error> {
884    let stmts =
885        ParserContext::create_with_dialect(sql, query_ctx.sql_dialect(), ParseOptions::default())
886            .map_err(BoxedError::new)
887            .context(ExternalSnafu)?;
888
889    ensure!(
890        stmts.len() == 1,
891        InvalidQuerySnafu {
892            reason: format!("Expect only one statement, found {}", stmts.len())
893        }
894    );
895    let stmt = &stmts[0];
896    let query_stmt = match stmt {
897        Statement::Tql(tql) => match tql {
898            Tql::Eval(eval) => {
899                let eval = eval.clone();
900                let promql = PromQuery {
901                    start: eval.start,
902                    end: eval.end,
903                    step: eval.step,
904                    query: eval.query,
905                    lookback: eval
906                        .lookback
907                        .unwrap_or_else(|| DEFAULT_LOOKBACK_STRING.to_string()),
908                    alias: eval.alias.clone(),
909                };
910
911                QueryLanguageParser::parse_promql(&promql, &query_ctx)
912                    .map_err(BoxedError::new)
913                    .context(ExternalSnafu)?
914            }
915            _ => InvalidQuerySnafu {
916                reason: format!("TQL statement {tql:?} is not supported, expect only TQL EVAL"),
917            }
918            .fail()?,
919        },
920        _ => QueryStatement::Sql(stmt.clone()),
921    };
922    let plan = engine
923        .planner()
924        .plan(&query_stmt, query_ctx.clone())
925        .await
926        .map_err(BoxedError::new)
927        .context(ExternalSnafu)?;
928
929    let plan = if optimize {
930        apply_df_optimizer(plan, &query_ctx).await?
931    } else {
932        plan
933    };
934    Ok(plan)
935}
936
937/// Generate a plan that matches the schema of the sink table
938/// from given sql by alias and adding auto columns
939pub(crate) async fn gen_plan_with_matching_schema(
940    sql: &str,
941    query_ctx: QueryContextRef,
942    engine: QueryEngineRef,
943    sink_table_schema: SchemaRef,
944    primary_key_indices: &[usize],
945    allow_partial: bool,
946) -> Result<LogicalPlan, Error> {
947    let plan = sql_to_df_plan(query_ctx.clone(), engine.clone(), sql, false).await?;
948
949    let mut add_auto_column = ColumnMatcherRewriter::new(
950        sink_table_schema,
951        primary_key_indices.to_vec(),
952        allow_partial,
953    );
954    let plan = plan
955        .clone()
956        .rewrite(&mut add_auto_column)
957        .with_context(|_| DatafusionSnafu {
958            context: format!("Failed to rewrite plan:\n {}\n", plan),
959        })?
960        .data;
961    Ok(plan)
962}
963
964pub fn df_plan_to_sql(plan: &LogicalPlan) -> Result<String, Error> {
965    /// A dialect that forces identifiers to be quoted when have uppercase
966    struct ForceQuoteIdentifiers;
967    impl datafusion::sql::unparser::dialect::Dialect for ForceQuoteIdentifiers {
968        fn identifier_quote_style(&self, identifier: &str) -> Option<char> {
969            if identifier.to_lowercase() != identifier {
970                Some('`')
971            } else {
972                None
973            }
974        }
975    }
976    let unparser = Unparser::new(&ForceQuoteIdentifiers);
977    // first make all column qualified
978    let sql = unparser
979        .plan_to_sql(plan)
980        .with_context(|_e| DatafusionSnafu {
981            context: format!("Failed to unparse logical plan {plan:?}"),
982        })?;
983    Ok(sql.to_string())
984}
985
986/// Helper to find the innermost group by expr in schema, return None if no group by expr
987#[derive(Debug, Clone, Default)]
988pub struct FindGroupByFinalName {
989    group_exprs: Option<HashSet<datafusion_expr::Expr>>,
990}
991
992impl FindGroupByFinalName {
993    pub fn get_group_expr_names(&self) -> Option<HashSet<String>> {
994        self.group_exprs
995            .as_ref()
996            .map(|exprs| exprs.iter().map(|expr| expr.qualified_name().1).collect())
997    }
998}
999
1000impl TreeNodeVisitor<'_> for FindGroupByFinalName {
1001    type Node = LogicalPlan;
1002
1003    fn f_down(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
1004        if let LogicalPlan::Aggregate(aggregate) = node {
1005            self.group_exprs = Some(aggregate.group_expr.iter().cloned().collect());
1006            debug!(
1007                "FindGroupByFinalName: Get Group by exprs from Aggregate: {:?}",
1008                self.group_exprs
1009            );
1010        } else if let LogicalPlan::Distinct(distinct) = node {
1011            debug!("FindGroupByFinalName: Distinct: {}", node);
1012            match distinct {
1013                Distinct::All(input) => {
1014                    if let LogicalPlan::TableScan(table_scan) = &**input {
1015                        // get column from field_qualifier, projection and projected_schema:
1016                        let len = table_scan.projected_schema.fields().len();
1017                        let columns = (0..len)
1018                            .map(|f| {
1019                                let (qualifier, field) =
1020                                    table_scan.projected_schema.qualified_field(f);
1021                                datafusion_common::Column::new(qualifier.cloned(), field.name())
1022                            })
1023                            .map(datafusion_expr::Expr::Column);
1024                        self.group_exprs = Some(columns.collect());
1025                    } else {
1026                        self.group_exprs = Some(input.expressions().iter().cloned().collect())
1027                    }
1028                }
1029                Distinct::On(distinct_on) => {
1030                    self.group_exprs = Some(distinct_on.on_expr.iter().cloned().collect())
1031                }
1032            }
1033            debug!(
1034                "FindGroupByFinalName: Get Group by exprs from Distinct: {:?}",
1035                self.group_exprs
1036            );
1037        }
1038
1039        Ok(TreeNodeRecursion::Continue)
1040    }
1041
1042    /// deal with projection when going up with group exprs
1043    fn f_up(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
1044        if let LogicalPlan::Projection(projection) = node {
1045            for expr in &projection.expr {
1046                let Some(group_exprs) = &mut self.group_exprs else {
1047                    return Ok(TreeNodeRecursion::Continue);
1048                };
1049                if let datafusion_expr::Expr::Alias(alias) = expr {
1050                    // if a alias exist, replace with the new alias
1051                    let mut new_group_exprs = group_exprs.clone();
1052                    for group_expr in group_exprs.iter() {
1053                        if group_expr.name_for_alias()? == alias.expr.name_for_alias()? {
1054                            new_group_exprs.remove(group_expr);
1055                            new_group_exprs.insert(expr.clone());
1056                            break;
1057                        }
1058                    }
1059                    *group_exprs = new_group_exprs;
1060                }
1061            }
1062        }
1063        debug!("Aliased group by exprs: {:?}", self.group_exprs);
1064        Ok(TreeNodeRecursion::Continue)
1065    }
1066}
1067
1068/// Optionally add to the final select columns like `update_at` if the sink table has such column
1069/// (which doesn't necessary need to have exact name just need to be a extra timestamp column)
1070/// and `__ts_placeholder`(this column need to have exact this name and be a timestamp)
1071/// with values like `now()` and `0`
1072///
1073/// it also give existing columns alias to column in sink table if needed
1074#[derive(Debug)]
1075pub struct ColumnMatcherRewriter {
1076    pub schema: SchemaRef,
1077    pub is_rewritten: bool,
1078    pub primary_key_indices: Vec<usize>,
1079    pub allow_partial: bool,
1080}
1081
1082impl ColumnMatcherRewriter {
1083    pub fn new(schema: SchemaRef, primary_key_indices: Vec<usize>, allow_partial: bool) -> Self {
1084        Self {
1085            schema,
1086            is_rewritten: false,
1087            primary_key_indices,
1088            allow_partial,
1089        }
1090    }
1091
1092    /// modify the exprs in place so that it matches the schema and some auto columns are added
1093    fn modify_project_exprs(&mut self, mut exprs: Vec<Expr>) -> DfResult<Vec<Expr>> {
1094        if self.allow_partial {
1095            return self.modify_project_exprs_with_partial(exprs);
1096        }
1097
1098        let all_names = self
1099            .schema
1100            .column_schemas()
1101            .iter()
1102            .map(|c| c.name.clone())
1103            .collect::<BTreeSet<_>>();
1104        // first match by position
1105        for (idx, expr) in exprs.iter_mut().enumerate() {
1106            if !all_names.contains(&expr.qualified_name().1)
1107                && let Some(col_name) = self
1108                    .schema
1109                    .column_schemas()
1110                    .get(idx)
1111                    .map(|c| c.name.clone())
1112            {
1113                // if the data type mismatched, later check_execute will error out
1114                // hence no need to check it here, beside, optimize pass might be able to cast it
1115                // so checking here is not necessary
1116                *expr = expr.clone().alias(col_name);
1117            }
1118        }
1119
1120        // add columns if have different column count
1121        let query_col_cnt = exprs.len();
1122        let table_col_cnt = self.schema.column_schemas().len();
1123        debug!("query_col_cnt={query_col_cnt}, table_col_cnt={table_col_cnt}");
1124
1125        let placeholder_ts_expr =
1126            datafusion::logical_expr::lit(ScalarValue::TimestampMillisecond(Some(0), None))
1127                .alias(AUTO_CREATED_PLACEHOLDER_TS_COL);
1128
1129        if query_col_cnt == table_col_cnt {
1130            // still need to add alias, see below
1131        } else if query_col_cnt + 1 == table_col_cnt {
1132            let last_col_schema = self.schema.column_schemas().last().unwrap();
1133
1134            // if time index column is auto created add it
1135            if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
1136                && self.schema.timestamp_index() == Some(table_col_cnt - 1)
1137            {
1138                exprs.push(placeholder_ts_expr);
1139            } else if last_col_schema.data_type.is_timestamp() {
1140                // is the update at column
1141                exprs.push(datafusion::prelude::now().alias(&last_col_schema.name));
1142            } else {
1143                // helpful error message
1144                return Err(DataFusionError::Plan(format!(
1145                    "Expect the last column in table to be timestamp column, found column {} with type {:?}",
1146                    last_col_schema.name, last_col_schema.data_type
1147                )));
1148            }
1149        } else if query_col_cnt + 2 == table_col_cnt {
1150            let mut col_iter = self.schema.column_schemas().iter().rev();
1151            let last_col_schema = col_iter.next().unwrap();
1152            let second_last_col_schema = col_iter.next().unwrap();
1153            if second_last_col_schema.data_type.is_timestamp() {
1154                exprs.push(datafusion::prelude::now().alias(&second_last_col_schema.name));
1155            } else {
1156                return Err(DataFusionError::Plan(format!(
1157                    "Expect the second last column in the table to be timestamp column, found column {} with type {:?}",
1158                    second_last_col_schema.name, second_last_col_schema.data_type
1159                )));
1160            }
1161
1162            if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
1163                && self.schema.timestamp_index() == Some(table_col_cnt - 1)
1164            {
1165                exprs.push(placeholder_ts_expr);
1166            } else {
1167                return Err(DataFusionError::Plan(format!(
1168                    "Expect timestamp column {}, found {:?}",
1169                    AUTO_CREATED_PLACEHOLDER_TS_COL, last_col_schema
1170                )));
1171            }
1172        } else {
1173            return Err(DataFusionError::Plan(format!(
1174                "Expect table have 0,1 or 2 columns more than query columns, found {} query columns {:?}, {} table columns {:?}",
1175                query_col_cnt,
1176                exprs,
1177                table_col_cnt,
1178                self.schema.column_schemas()
1179            )));
1180        }
1181        Ok(exprs)
1182    }
1183
1184    fn modify_project_exprs_with_partial(&mut self, exprs: Vec<Expr>) -> DfResult<Vec<Expr>> {
1185        let table_col_cnt = self.schema.column_schemas().len();
1186        let query_col_cnt = exprs.len();
1187
1188        if query_col_cnt > table_col_cnt {
1189            return Err(DataFusionError::Plan(format!(
1190                "Expect query column count <= table column count, found {} query columns {:?}, {} table columns {:?}",
1191                query_col_cnt,
1192                exprs,
1193                table_col_cnt,
1194                self.schema.column_schemas()
1195            )));
1196        }
1197
1198        let name_to_expr: HashMap<String, Expr> = exprs
1199            .clone()
1200            .into_iter()
1201            .map(|e| (e.qualified_name().1, e))
1202            .collect();
1203
1204        let required_columns = self.required_columns_for_partial();
1205        let missing: Vec<_> = required_columns
1206            .iter()
1207            .filter(|name| !name_to_expr.contains_key(*name))
1208            .cloned()
1209            .collect();
1210        if !missing.is_empty() {
1211            return Err(DataFusionError::Plan(format!(
1212                "Column(s) {:?} required by sink table are missing from flow output when merge_mode=last_non_null",
1213                missing
1214            )));
1215        }
1216
1217        let placeholder_ts_expr =
1218            datafusion::logical_expr::lit(ScalarValue::TimestampMillisecond(Some(0), None))
1219                .alias(AUTO_CREATED_PLACEHOLDER_TS_COL);
1220
1221        let timestamp_index = self.schema.timestamp_index();
1222        let mut remap = name_to_expr;
1223        let mut new_exprs = Vec::with_capacity(table_col_cnt);
1224
1225        for (idx, col_schema) in self.schema.column_schemas().iter().enumerate() {
1226            let col_name = col_schema.name.clone();
1227            if let Some(expr) = remap.remove(&col_name) {
1228                let expr = if expr.qualified_name().1 == col_name {
1229                    expr
1230                } else {
1231                    expr.alias(col_name.clone())
1232                };
1233                new_exprs.push(expr);
1234                continue;
1235            }
1236
1237            if col_name == AUTO_CREATED_PLACEHOLDER_TS_COL && timestamp_index == Some(idx) {
1238                new_exprs.push(placeholder_ts_expr.clone());
1239                continue;
1240            }
1241
1242            if col_name == AUTO_CREATED_UPDATE_AT_TS_COL && col_schema.data_type.is_timestamp() {
1243                new_exprs.push(datafusion::prelude::now().alias(&col_name));
1244                continue;
1245            }
1246
1247            new_exprs.push(Self::null_expr(col_schema));
1248        }
1249
1250        if !remap.is_empty() {
1251            let extra: Vec<_> = remap.keys().cloned().collect();
1252            return Err(DataFusionError::Plan(format!(
1253                "Flow output has extra column(s) {:?} not found in sink schema when merge_mode=last_non_null",
1254                extra
1255            )));
1256        }
1257
1258        Ok(new_exprs)
1259    }
1260
1261    fn null_expr(col_schema: &ColumnSchema) -> Expr {
1262        Expr::Literal(ScalarValue::Null, None).alias(col_schema.name.clone())
1263    }
1264
1265    fn required_columns_for_partial(&self) -> HashSet<String> {
1266        let mut required = HashSet::new();
1267        for idx in &self.primary_key_indices {
1268            if let Some(col) = self.schema.column_schemas().get(*idx) {
1269                required.insert(col.name.clone());
1270            }
1271        }
1272
1273        if let Some(ts_idx) = self.schema.timestamp_index()
1274            && let Some(col) = self.schema.column_schemas().get(ts_idx)
1275            && col.name != AUTO_CREATED_PLACEHOLDER_TS_COL
1276        {
1277            required.insert(col.name.clone());
1278        }
1279
1280        required
1281    }
1282}
1283
1284impl TreeNodeRewriter for ColumnMatcherRewriter {
1285    type Node = LogicalPlan;
1286    fn f_down(&mut self, mut node: Self::Node) -> DfResult<Transformed<Self::Node>> {
1287        if self.is_rewritten {
1288            return Ok(Transformed::no(node));
1289        }
1290
1291        // if is distinct all, wrap it in a projection
1292        if let LogicalPlan::Distinct(Distinct::All(_)) = &node {
1293            let mut exprs = vec![];
1294
1295            for field in node.schema().fields().iter() {
1296                exprs.push(Expr::Column(datafusion::common::Column::new_unqualified(
1297                    field.name(),
1298                )));
1299            }
1300
1301            let projection =
1302                LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(node.clone()))?);
1303
1304            node = projection;
1305        }
1306        // handle table_scan by wrap it in a projection
1307        else if let LogicalPlan::TableScan(table_scan) = node {
1308            let mut exprs = vec![];
1309
1310            for field in table_scan.projected_schema.fields().iter() {
1311                exprs.push(Expr::Column(datafusion::common::Column::new(
1312                    Some(table_scan.table_name.clone()),
1313                    field.name(),
1314                )));
1315            }
1316
1317            let projection = LogicalPlan::Projection(Projection::try_new(
1318                exprs,
1319                Arc::new(LogicalPlan::TableScan(table_scan)),
1320            )?);
1321
1322            node = projection;
1323        }
1324
1325        // only do rewrite if found the outermost projection
1326        // if the outermost node is projection, can rewrite the exprs
1327        // if not, wrap it in a projection
1328        if let LogicalPlan::Projection(project) = &node {
1329            let exprs = project.expr.clone();
1330            let exprs = self.modify_project_exprs(exprs)?;
1331
1332            self.is_rewritten = true;
1333            let new_plan =
1334                node.with_new_exprs(exprs, node.inputs().into_iter().cloned().collect())?;
1335            Ok(Transformed::yes(new_plan))
1336        } else {
1337            // wrap the logical plan in a projection
1338            let mut exprs = vec![];
1339            for field in node.schema().fields().iter() {
1340                exprs.push(Expr::Column(datafusion::common::Column::new_unqualified(
1341                    field.name(),
1342                )));
1343            }
1344            let exprs = self.modify_project_exprs(exprs)?;
1345            self.is_rewritten = true;
1346            let new_plan =
1347                LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(node.clone()))?);
1348            Ok(Transformed::yes(new_plan))
1349        }
1350    }
1351
1352    /// We might add new columns, so we need to recompute the schema
1353    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
1354        node.recompute_schema().map(Transformed::yes)
1355    }
1356}
1357
1358/// Find out the `Filter` Node corresponding to innermost(deepest) `WHERE` and add a new filter expr to it
1359#[derive(Debug)]
1360pub struct AddFilterRewriter {
1361    extra_filter: Expr,
1362    is_rewritten: bool,
1363}
1364
1365impl AddFilterRewriter {
1366    pub fn new(filter: Expr) -> Self {
1367        Self {
1368            extra_filter: filter,
1369            is_rewritten: false,
1370        }
1371    }
1372}
1373
1374impl TreeNodeRewriter for AddFilterRewriter {
1375    type Node = LogicalPlan;
1376    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
1377        if self.is_rewritten {
1378            return Ok(Transformed::no(node));
1379        }
1380        match node {
1381            LogicalPlan::Filter(mut filter) => {
1382                filter.predicate = filter.predicate.and(self.extra_filter.clone());
1383                self.is_rewritten = true;
1384                Ok(Transformed::yes(LogicalPlan::Filter(filter)))
1385            }
1386            LogicalPlan::TableScan(_) => {
1387                // add a new filter
1388                let filter =
1389                    datafusion_expr::Filter::try_new(self.extra_filter.clone(), Arc::new(node))?;
1390                self.is_rewritten = true;
1391                Ok(Transformed::yes(LogicalPlan::Filter(filter)))
1392            }
1393            _ => Ok(Transformed::no(node)),
1394        }
1395    }
1396}