Skip to main content

flow/batching_mode/
utils.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! some utils for helping with batching mode
16
17use std::collections::{BTreeSet, HashMap, HashSet};
18use std::sync::Arc;
19
20use catalog::CatalogManagerRef;
21use common_error::ext::BoxedError;
22use common_function::aggrs::aggr_wrapper::get_aggr_func;
23use common_telemetry::debug;
24use datafusion::datasource::DefaultTableSource;
25use datafusion::error::Result as DfResult;
26use datafusion::logical_expr::Expr;
27use datafusion::sql::unparser::Unparser;
28use datafusion_common::tree_node::{
29    Transformed, TreeNode as _, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
30};
31use datafusion_common::{
32    Column, DFSchema, DataFusionError, NullEquality, ScalarValue, TableReference,
33};
34use datafusion_expr::logical_plan::{Aggregate, TableScan};
35use datafusion_expr::{
36    Distinct, JoinType, LogicalPlan, LogicalPlanBuilder, Operator, Projection, and, binary_expr,
37    bitwise_and, bitwise_or, bitwise_xor, is_null, or, when,
38};
39use datatypes::schema::{ColumnSchema, SchemaRef};
40use query::QueryEngineRef;
41use query::parser::{DEFAULT_LOOKBACK_STRING, PromQuery, QueryLanguageParser, QueryStatement};
42use session::context::QueryContextRef;
43use snafu::{OptionExt, ResultExt, ensure};
44use sql::parser::{ParseOptions, ParserContext};
45use sql::statements::statement::Statement;
46use sql::statements::tql::Tql;
47use table::TableRef;
48use table::table::adapter::DfTableProviderAdapter;
49
50use crate::adapter::{AUTO_CREATED_PLACEHOLDER_TS_COL, AUTO_CREATED_UPDATE_AT_TS_COL};
51use crate::df_optimizer::apply_df_optimizer;
52use crate::error::{DatafusionSnafu, ExternalSnafu, InvalidQuerySnafu, TableNotFoundSnafu};
53use crate::{Error, TableName};
54
55#[cfg(test)]
56mod test;
57
/// Describes how one aggregate output field should be merged with the
/// corresponding existing field in the sink table.
///
/// `output_field_name` is the final output/sink schema field name produced by
/// the delta plan and read from the sink table. It is not a DataFusion `Column`
/// reference. It may contain dots or other non-identifier characters when the
/// query keeps DataFusion's raw aggregate output name, e.g.
/// `max(numbers_with_ts.number)`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IncrementalAggregateMergeColumn {
    /// Final output/sink field name for the aggregate result/state column.
    pub output_field_name: String,
    /// Operation used to combine the delta value with the existing sink value
    /// of this field.
    pub merge_op: IncrementalAggregateMergeOp,
}

impl IncrementalAggregateMergeColumn {
    /// Create a new merge column from an output field name and the operation
    /// used to merge it.
    pub fn new(output_field_name: String, merge_op: IncrementalAggregateMergeOp) -> Self {
        Self {
            output_field_name,
            merge_op,
        }
    }
}
83
/// Operation used to merge an aggregate value computed by the delta plan with
/// the value already stored in the sink table.
///
/// The mapping from aggregate function names to variants lives in
/// `merge_op_for_aggregate_expr`; the merge expressions themselves are built
/// by `build_left_join_merge_expr`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IncrementalAggregateMergeOp {
    /// Add delta and sink values; selected for `sum` and `count`.
    Sum,
    /// Keep the smaller of delta and sink values; selected for `min`.
    Min,
    /// Keep the larger of delta and sink values; selected for `max`.
    Max,
    /// Logical AND of delta and sink values; selected for `bool_and`.
    BoolAnd,
    /// Logical OR of delta and sink values; selected for `bool_or`.
    BoolOr,
    /// Bitwise AND of delta and sink values; selected for `bit_and`.
    BitAnd,
    /// Bitwise OR of delta and sink values; selected for `bit_or`.
    BitOr,
    /// Bitwise XOR of delta and sink values; selected for `bit_xor`.
    BitXor,
}
95
/// Analysis result for an incremental aggregate plan.
///
/// `group_key_names` and each merge column's `output_field_name` are final
/// output/sink schema field names used to project both the delta plan and the
/// sink table before the left-join merge. They are not DataFusion logical-plan
/// `Column` references; callers must attach qualifiers structurally instead of
/// formatting qualified names as strings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IncrementalAggregateAnalysis {
    /// Final output/sink field names for group keys used as merge join keys.
    pub group_key_names: Vec<String>,
    /// Aggregate outputs that can be merged with the sink table, one entry per
    /// supported aggregate expression. `analyze_incremental_aggregate_plan`
    /// leaves this empty whenever `unsupported_exprs` is non-empty.
    pub merge_columns: Vec<IncrementalAggregateMergeColumn>,
    /// Literal output fields that can be passed through from the delta plan.
    pub literal_columns: Vec<String>,
    /// Final output field order from the original aggregate plan.
    pub output_field_names: Vec<String>,
    /// Human-readable reasons why the plan cannot be rewritten incrementally.
    /// The sink-merge rewrite rejects any analysis where this is non-empty.
    pub unsupported_exprs: Vec<String>,
}
114
115/// Recursively find all `Expr::Column` names inside an expression tree.
116/// Only recurses into wrappers that are merge-transparent.
117/// Non-transparent wrappers (e.g., `ScalarFunction`, `Negative`, `Cast`) are
118/// intentionally not recursed into since their merge semantics would be
119/// incorrect.
120///
121/// `Cast`/`TryCast` are intentionally opaque: merging already-casted aggregate
122/// outputs is not generally equivalent to casting the final merged aggregate.
123fn find_column_names(expr: &Expr, names: &mut Vec<String>) {
124    match expr {
125        Expr::Column(col) => {
126            names.push(col.name.clone());
127        }
128        Expr::Alias(alias) => find_column_names(&alias.expr, names),
129        _ => {}
130    }
131}
132
133fn unqualified_col(name: impl Into<String>) -> Expr {
134    Expr::Column(Column::from_name(name.into()))
135}
136
137fn qualified_col(qualifier: &str, name: impl Into<String>) -> Expr {
138    Expr::Column(Column::new(Some(qualifier), name.into()))
139}
140
141fn qualified_column(qualifier: &str, name: impl Into<String>) -> Column {
142    Column::new(Some(qualifier), name.into())
143}
144
145fn find_group_key_names(plan: &LogicalPlan) -> Result<Vec<String>, Error> {
146    let mut group_finder = FindGroupByFinalName::default();
147    plan.visit(&mut group_finder)
148        .with_context(|_| DatafusionSnafu {
149            context: format!("Failed to inspect group-by columns from logical plan: {plan:?}"),
150        })?;
151
152    let mut group_key_names = group_finder
153        .get_group_expr_names()
154        .unwrap_or_default()
155        .into_iter()
156        .collect::<Vec<_>>();
157    group_key_names.sort();
158    Ok(group_key_names)
159}
160
161fn has_grouping_set(plan: &LogicalPlan) -> bool {
162    match plan {
163        LogicalPlan::Aggregate(aggregate) => aggregate
164            .group_expr
165            .iter()
166            .any(|expr| matches!(expr, Expr::GroupingSet(_))),
167        _ => plan.inputs().into_iter().any(has_grouping_set),
168    }
169}
170
171fn has_aggregate(plan: &LogicalPlan) -> bool {
172    match plan {
173        LogicalPlan::Aggregate(_) => true,
174        _ => plan.inputs().into_iter().any(has_aggregate),
175    }
176}
177
178fn peel_subquery_aliases(mut plan: &LogicalPlan) -> &LogicalPlan {
179    while let LogicalPlan::SubqueryAlias(alias) = plan {
180        plan = alias.input.as_ref();
181    }
182    plan
183}
184
185fn extract_incremental_aggregate(plan: &LogicalPlan) -> Result<Option<&Aggregate>, String> {
186    // Supported final shape: optional output Projection directly over one
187    // Aggregate. Post-aggregate filters (HAVING), ordering, limits,
188    // distinct/window/union/extension nodes are intentionally not accepted.
189    let plan = match plan {
190        LogicalPlan::Projection(projection) => projection.input.as_ref(),
191        _ => plan,
192    };
193
194    match plan {
195        LogicalPlan::Aggregate(aggregate) => {
196            check_input_plan_shape(aggregate.input.as_ref())?;
197            Ok(Some(aggregate))
198        }
199        LogicalPlan::Filter(filter) if has_aggregate(filter.input.as_ref()) => Err(
200            "unsupported post-aggregate filter (HAVING) in incremental aggregate rewrite"
201                .to_string(),
202        ),
203        _ if has_aggregate(plan) => Err(
204            "unsupported post-aggregate plan shape in incremental aggregate rewrite".to_string(),
205        ),
206        _ => Ok(None),
207    }
208}
209
210fn check_input_plan_shape(plan: &LogicalPlan) -> Result<(), String> {
211    let plan = peel_subquery_aliases(plan);
212    match plan {
213        // Supported aggregate input shape: optional WHERE filter over a table scan.
214        // SubqueryAlias is a transparent naming wrapper for `FROM table AS alias`.
215        LogicalPlan::TableScan(_) => Ok(()),
216        LogicalPlan::Filter(filter) => match peel_subquery_aliases(filter.input.as_ref()) {
217            LogicalPlan::TableScan(_) => Ok(()),
218            _ => Err(
219                "unsupported aggregate input plan shape in incremental aggregate rewrite"
220                    .to_string(),
221            ),
222        },
223        _ => Err(
224            "unsupported aggregate input plan shape in incremental aggregate rewrite".to_string(),
225        ),
226    }
227}
228
/// Facts about a plan's output columns, gathered from its schema and from its
/// top-level projection (when it has one).
#[derive(Debug, Default)]
struct OutputProjectionInfo {
    /// Whether the root node of the inspected plan is a `Projection`.
    has_top_level_projection: bool,
    /// Maps a referenced column name to the output name it is projected as.
    output_aliases: HashMap<String, String>,
    /// Diagnostics for columns that are projected under several aliases.
    duplicate_aggregate_aliases: BTreeSet<String>,
    /// Output names whose projected expression is a literal.
    literal_columns: HashSet<String>,
    /// Output field names in schema order.
    output_field_names: Vec<String>,
}

impl OutputProjectionInfo {
    /// Returns the distinct output field names as a set.
    fn output_field_name_set(&self) -> HashSet<String> {
        let mut distinct = HashSet::with_capacity(self.output_field_names.len());
        distinct.extend(self.output_field_names.iter().cloned());
        distinct
    }

    /// Returns every output field name that occurs more than once, each
    /// reported a single time, in ascending order.
    fn duplicate_output_names(&self) -> Vec<String> {
        let mut first_seen: HashSet<&str> = HashSet::with_capacity(self.output_field_names.len());
        // `insert` returns false for a name that was already recorded, i.e. a
        // repeat; the BTreeSet both de-duplicates and orders the repeats.
        let repeated: BTreeSet<&str> = self
            .output_field_names
            .iter()
            .map(String::as_str)
            .filter(|name| !first_seen.insert(*name))
            .collect();
        repeated.into_iter().map(String::from).collect()
    }
}
254
255fn collect_output_projection_info(plan: &LogicalPlan) -> OutputProjectionInfo {
256    let mut projection_info = OutputProjectionInfo {
257        has_top_level_projection: matches!(plan, LogicalPlan::Projection(_)),
258        output_field_names: plan
259            .schema()
260            .fields()
261            .iter()
262            .map(|field| field.name().clone())
263            .collect(),
264        ..Default::default()
265    };
266
267    let mut output_aliases = HashMap::new();
268    if let LogicalPlan::Projection(projection) = plan {
269        for expr in &projection.expr {
270            match expr {
271                Expr::Alias(alias) => {
272                    // Alias resolution has three cases:
273                    // - 0 Column refs (e.g., literal `42 AS lit`): record literal output
274                    // - 1 Column ref: record the mapping (e.g., `sum(x) AS total`)
275                    // - >1 Column refs (e.g., `COALESCE(sum(x), sum(y))`):
276                    //   skip — ambiguous merge semantics
277                    let alias_name = alias.name.clone();
278                    let mut col_names = Vec::new();
279                    find_column_names(&alias.expr, &mut col_names);
280                    match col_names.len() {
281                        0 if matches!(alias.expr.as_ref(), Expr::Literal(_, _)) => {
282                            projection_info.literal_columns.insert(alias_name);
283                        }
284                        1 => {
285                            if let Some(col_name) = col_names.into_iter().next() {
286                                if let Some(existing_alias) = output_aliases.get(&col_name) {
287                                    if existing_alias != &alias_name {
288                                        projection_info.duplicate_aggregate_aliases.insert(format!(
289                                            "same aggregate output {col_name} is used by multiple aliases: {existing_alias}, {alias_name}"
290                                        ));
291                                    }
292                                } else {
293                                    output_aliases.insert(col_name, alias_name);
294                                }
295                            }
296                        }
297                        _ => {}
298                    }
299
300                    // If >1 column references detected (e.g., COALESCE(sum(x), sum(y))),
301                    // intentionally skip alias mapping — the merge semantics are ambiguous.
302                }
303                Expr::Column(col) => {
304                    output_aliases
305                        .entry(col.name.clone())
306                        .or_insert(col.name.clone());
307                }
308                Expr::Literal(_, _) => {
309                    projection_info
310                        .literal_columns
311                        .insert(expr.qualified_name().1);
312                }
313                _ => {}
314            }
315        }
316    }
317
318    projection_info.output_aliases = output_aliases;
319    projection_info
320}
321
322fn merge_op_for_aggregate_expr(aggr_expr: &Expr) -> Result<IncrementalAggregateMergeOp, String> {
323    let Some(aggr_func) = get_aggr_func(aggr_expr) else {
324        return Err(aggr_expr.to_string());
325    };
326    if aggr_func.params.distinct {
327        return Err(format!("unsupported DISTINCT aggregate: {aggr_expr}"));
328    }
329    if !aggr_func.params.order_by.is_empty() {
330        return Err(format!("unsupported aggregate ORDER BY: {aggr_expr}"));
331    }
332    if aggr_func.params.null_treatment.is_some() {
333        return Err(format!("unsupported aggregate NULL treatment: {aggr_expr}"));
334    }
335
336    match aggr_func.func.name().to_ascii_lowercase().as_str() {
337        "sum" | "count" => Ok(IncrementalAggregateMergeOp::Sum),
338        "min" => Ok(IncrementalAggregateMergeOp::Min),
339        "max" => Ok(IncrementalAggregateMergeOp::Max),
340        "bool_and" => Ok(IncrementalAggregateMergeOp::BoolAnd),
341        "bool_or" => Ok(IncrementalAggregateMergeOp::BoolOr),
342        "bit_and" => Ok(IncrementalAggregateMergeOp::BitAnd),
343        "bit_or" => Ok(IncrementalAggregateMergeOp::BitOr),
344        "bit_xor" => Ok(IncrementalAggregateMergeOp::BitXor),
345        _ => Err(aggr_expr.to_string()),
346    }
347}
348
349fn resolve_aggregate_output_field_name(
350    aggr_expr: &Expr,
351    projection_info: &OutputProjectionInfo,
352    output_field_name_set: &HashSet<String>,
353) -> Option<String> {
354    // qualified_name() returns (Option<String>, String) where the second
355    // element is the unqualified column/alias name. This relies on
356    // DataFusion's internal naming convention: aggregate expressions
357    // emit a column named after the aggregate itself (e.g. "SUM(x)"),
358    // which matches what the projection aliases reference.
359    let raw_name = aggr_expr.qualified_name().1;
360    if let Some(alias) = projection_info.output_aliases.get(&raw_name) {
361        Some(alias.clone())
362    } else if !projection_info.has_top_level_projection && output_field_name_set.contains(&raw_name)
363    {
364        Some(raw_name)
365    } else {
366        None
367    }
368}
369
370fn find_uncovered_output_fields(
371    projection_info: &OutputProjectionInfo,
372    group_key_names: &[String],
373    merge_columns: &[IncrementalAggregateMergeColumn],
374) -> Vec<String> {
375    let group_key_names = group_key_names.iter().cloned().collect::<HashSet<_>>();
376    let merge_column_names = merge_columns
377        .iter()
378        .map(|c| c.output_field_name.clone())
379        .collect::<HashSet<_>>();
380
381    projection_info
382        .output_field_names
383        .iter()
384        .filter(|name| {
385            !group_key_names.contains(*name)
386                && !merge_column_names.contains(*name)
387                && !projection_info.literal_columns.contains(*name)
388        })
389        .cloned()
390        .collect()
391}
392
393fn find_unsupported_group_key_projection_outputs(
394    plan: &LogicalPlan,
395    aggregate: &Aggregate,
396    group_key_names: &[String],
397) -> Vec<String> {
398    let LogicalPlan::Projection(projection) = plan else {
399        return vec![];
400    };
401
402    let group_key_names = group_key_names.iter().cloned().collect::<HashSet<_>>();
403    let group_expr_names = aggregate
404        .group_expr
405        .iter()
406        .filter_map(|expr| expr.name_for_alias().ok())
407        .collect::<HashSet<_>>();
408    projection
409        .expr
410        .iter()
411        .filter_map(|expr| {
412            let output_name = expr.qualified_name().1;
413            if !group_key_names.contains(&output_name) {
414                return None;
415            }
416
417            let source_name = match expr {
418                Expr::Alias(alias) => alias.expr.name_for_alias().ok(),
419                _ => expr.name_for_alias().ok(),
420            };
421            if source_name.is_some_and(|name| group_expr_names.contains(&name)) {
422                None
423            } else {
424                Some(format!(
425                    "unsupported group key output field is not a transparent group expression: {output_name}"
426                ))
427            }
428        })
429        .collect()
430}
431
432pub fn analyze_incremental_aggregate_plan(
433    plan: &LogicalPlan,
434) -> Result<Option<IncrementalAggregateAnalysis>, Error> {
435    let group_key_names = find_group_key_names(plan)?;
436    let aggregate = match extract_incremental_aggregate(plan) {
437        Ok(Some(aggregate)) => aggregate,
438        Ok(None) => return Ok(None),
439        Err(reason) => {
440            let projection_info = collect_output_projection_info(plan);
441            let mut unsupported_exprs = projection_info
442                .duplicate_output_names()
443                .into_iter()
444                .map(|name| format!("duplicate output field name: {name}"))
445                .collect::<Vec<_>>();
446            unsupported_exprs.push(reason);
447            unsupported_exprs.extend(projection_info.duplicate_aggregate_aliases.iter().cloned());
448            return Ok(Some(IncrementalAggregateAnalysis {
449                group_key_names,
450                merge_columns: vec![],
451                literal_columns: vec![],
452                output_field_names: projection_info.output_field_names,
453                unsupported_exprs,
454            }));
455        }
456    };
457    let aggr_exprs = aggregate.aggr_expr.clone();
458    let projection_info = collect_output_projection_info(plan);
459    let output_field_name_set = projection_info.output_field_name_set();
460
461    let mut merge_columns = Vec::with_capacity(aggr_exprs.len());
462    let mut unsupported_exprs = projection_info
463        .duplicate_output_names()
464        .into_iter()
465        .map(|name| format!("duplicate output field name: {name}"))
466        .collect::<Vec<_>>();
467    if has_grouping_set(plan) {
468        unsupported_exprs.push(
469            "unsupported GROUPING SETS/CUBE/ROLLUP in incremental aggregate rewrite".to_string(),
470        );
471    }
472    if group_key_names.is_empty() {
473        unsupported_exprs
474            .push("unsupported global aggregate in incremental aggregate rewrite".to_string());
475    }
476    unsupported_exprs.extend(find_unsupported_group_key_projection_outputs(
477        plan,
478        aggregate,
479        &group_key_names,
480    ));
481    unsupported_exprs.extend(projection_info.duplicate_aggregate_aliases.iter().cloned());
482    for aggr_expr in aggr_exprs {
483        let merge_op = match merge_op_for_aggregate_expr(&aggr_expr) {
484            Ok(merge_op) => merge_op,
485            Err(reason) => {
486                unsupported_exprs.push(reason);
487                continue;
488            }
489        };
490        let Some(output_field_name) = resolve_aggregate_output_field_name(
491            &aggr_expr,
492            &projection_info,
493            &output_field_name_set,
494        ) else {
495            unsupported_exprs.push(aggr_expr.to_string());
496            continue;
497        };
498        merge_columns.push(IncrementalAggregateMergeColumn::new(
499            output_field_name,
500            merge_op,
501        ));
502    }
503    unsupported_exprs.extend(
504        find_uncovered_output_fields(&projection_info, &group_key_names, &merge_columns)
505            .into_iter()
506            .map(|name| format!("unsupported output field: {name}")),
507    );
508    if !unsupported_exprs.is_empty() {
509        merge_columns.clear();
510    }
511    let mut literal_columns = projection_info
512        .literal_columns
513        .into_iter()
514        .collect::<Vec<_>>();
515    literal_columns.sort();
516
517    Ok(Some(IncrementalAggregateAnalysis {
518        group_key_names,
519        merge_columns,
520        literal_columns,
521        output_field_names: projection_info.output_field_names,
522        unsupported_exprs,
523    }))
524}
525
/// Rewrites one incremental aggregate delta plan by left-joining it with the
/// existing sink-table state and projecting merged aggregate outputs.
///
/// For a grouped aggregate such as:
///
/// ```text
/// SELECT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts
/// ```
///
/// the rewrite is roughly:
///
/// ```text
/// delta = SELECT ts, number FROM <delta_plan> AS __flow_delta
/// sink  = SELECT ts, number FROM <sink_table> AS __flow_sink
/// SELECT
///   CASE
///     WHEN __flow_sink.number IS NULL THEN __flow_delta.number
///     WHEN __flow_delta.number >= __flow_sink.number THEN __flow_delta.number
///     ELSE __flow_sink.number
///   END AS number,
///   __flow_delta.ts AS ts
/// FROM delta
/// LEFT JOIN sink
///   ON __flow_delta.ts IS NOT DISTINCT FROM __flow_sink.ts
/// ```
///
/// # Arguments
/// - `delta_plan`: plan producing the aggregate results for the new data.
/// - `analysis`: result of analyzing the original aggregate plan; supplies the
///   group keys, merge columns, literal columns and the output field order.
/// - `sink_table`: table holding the existing aggregate state.
/// - `sink_table_name`: `[catalog, schema, table]` name used to reference the
///   sink table in the scan.
///
/// # Errors
/// Returns an `InvalidQuery` error (prefixed `UNSUPPORTED_INCREMENTAL_AGG`)
/// when `analysis` has unsupported expressions, no merge columns, no group
/// keys, or an output field that is not a group key, literal, or merge column.
/// Returns a `Datafusion` error when any step of building the plan fails.
pub async fn rewrite_incremental_aggregate_with_sink_merge(
    delta_plan: &LogicalPlan,
    analysis: &IncrementalAggregateAnalysis,
    sink_table: TableRef,
    sink_table_name: &TableName,
) -> Result<LogicalPlan, Error> {
    // Preconditions: refuse to rewrite anything the analysis flagged.
    ensure!(
        analysis.unsupported_exprs.is_empty(),
        InvalidQuerySnafu {
            reason: format!(
                "UNSUPPORTED_INCREMENTAL_AGG: unsupported aggregate expressions {:?}",
                analysis.unsupported_exprs
            )
        }
    );

    ensure!(
        !analysis.merge_columns.is_empty(),
        InvalidQuerySnafu {
            reason:
                "UNSUPPORTED_INCREMENTAL_AGG: aggregate query has no mergeable aggregate columns"
                    .to_string()
        }
    );

    // Without group keys there would be nothing to join delta and sink on.
    ensure!(
        !analysis.group_key_names.is_empty(),
        InvalidQuerySnafu {
            reason: "UNSUPPORTED_INCREMENTAL_AGG: global aggregate query is not supported"
                .to_string()
        }
    );

    // Relation aliases that keep both sides of the join distinguishable.
    let delta_alias = "__flow_delta";
    let sink_alias = "__flow_sink";

    // Columns read from both sides: group keys followed by merge columns.
    let mut selected_columns = analysis.group_key_names.clone();
    selected_columns.extend(
        analysis
            .merge_columns
            .iter()
            .map(|c| c.output_field_name.clone()),
    );
    // Literal outputs are only taken from the delta side.
    let mut delta_selected_columns = selected_columns.clone();
    delta_selected_columns.extend(analysis.literal_columns.iter().cloned());

    // Left side of the join: the delta plan projected and aliased.
    let delta_selected_exprs = delta_selected_columns
        .iter()
        .cloned()
        .map(unqualified_col)
        .collect::<Vec<_>>();
    let delta_selected = LogicalPlanBuilder::from(delta_plan.clone())
        .project(delta_selected_exprs)
        .with_context(|_| DatafusionSnafu {
            context: "Failed to project delta plan for incremental sink merge".to_string(),
        })?
        .alias(delta_alias)
        .with_context(|_| DatafusionSnafu {
            context: "Failed to alias delta plan for incremental sink merge".to_string(),
        })?
        .build()
        .with_context(|_| DatafusionSnafu {
            context: "Failed to build projected delta plan for incremental sink merge".to_string(),
        })?;

    // Right side of the join: a scan over the sink table, referenced by its
    // full catalog.schema.table name.
    let table_provider = Arc::new(DfTableProviderAdapter::new(sink_table));
    let table_source = Arc::new(DefaultTableSource::new(table_provider));
    let sink_scan = LogicalPlan::TableScan(
        TableScan::try_new(
            TableReference::Full {
                catalog: sink_table_name[0].clone().into(),
                schema: sink_table_name[1].clone().into(),
                table: sink_table_name[2].clone().into(),
            },
            table_source,
            None,
            vec![],
            None,
        )
        .with_context(|_| DatafusionSnafu {
            context: "Failed to build sink table scan for incremental sink merge".to_string(),
        })?,
    );

    let sink_selected_exprs = selected_columns
        .iter()
        .cloned()
        .map(unqualified_col)
        .collect::<Vec<_>>();
    let sink_selected = LogicalPlanBuilder::from(sink_scan)
        .project(sink_selected_exprs)
        .with_context(|_| DatafusionSnafu {
            context: "Failed to project sink table scan for incremental sink merge".to_string(),
        })?
        .alias(sink_alias)
        .with_context(|_| DatafusionSnafu {
            context: "Failed to alias sink plan for incremental sink merge".to_string(),
        })?
        .build()
        .with_context(|_| DatafusionSnafu {
            context: "Failed to build projected sink plan for incremental sink merge".to_string(),
        })?;

    // Equi-join keys: every group key, qualified once per side. Qualifiers are
    // attached structurally because field names may contain dots.
    let join_keys = (
        analysis
            .group_key_names
            .iter()
            .cloned()
            .map(|c| qualified_column(delta_alias, c))
            .collect::<Vec<_>>(),
        analysis
            .group_key_names
            .iter()
            .cloned()
            .map(|c| qualified_column(sink_alias, c))
            .collect::<Vec<_>>(),
    );

    // Left join keeps every delta row; `NullEqualsNull` makes NULL group keys
    // match each other, i.e. the `IS NOT DISTINCT FROM` comparison shown in
    // the doc comment.
    let joined = LogicalPlanBuilder::from(delta_selected)
        .join_detailed(
            sink_selected,
            JoinType::Left,
            join_keys,
            None,
            NullEquality::NullEqualsNull,
        )
        .with_context(|_| DatafusionSnafu {
            context: "Failed to left join delta and sink plans for incremental sink merge"
                .to_string(),
        })?
        .build()
        .with_context(|_| DatafusionSnafu {
            context: "Failed to build left join plan for incremental sink merge".to_string(),
        })?;

    // Lookup structures for classifying each output field below.
    let group_key_names = analysis.group_key_names.iter().collect::<HashSet<_>>();
    let literal_columns = analysis.literal_columns.iter().collect::<HashSet<_>>();
    let merge_columns = analysis
        .merge_columns
        .iter()
        .map(|c| (&c.output_field_name, c))
        .collect::<HashMap<_, _>>();

    // Final projection, in the original output order: group keys and literals
    // come straight from the delta side, aggregates are merged with the sink.
    let mut projection_exprs = Vec::with_capacity(analysis.output_field_names.len());
    for output_field_name in &analysis.output_field_names {
        if group_key_names.contains(output_field_name)
            || literal_columns.contains(output_field_name)
        {
            projection_exprs.push(
                qualified_col(delta_alias, output_field_name.clone()).alias(output_field_name),
            );
        } else if let Some(merge_col) = merge_columns.get(output_field_name) {
            projection_exprs.push(build_left_join_merge_expr(
                delta_alias,
                sink_alias,
                merge_col,
            )?);
        } else {
            return InvalidQuerySnafu {
                reason: format!(
                    "UNSUPPORTED_INCREMENTAL_AGG: output field {output_field_name} is not covered by group keys, literals, or merge columns"
                ),
            }
            .fail();
        }
    }

    LogicalPlanBuilder::from(joined)
        .project(projection_exprs)
        .with_context(|_| DatafusionSnafu {
            context: "Failed to build projection merge plan for incremental sink merge".to_string(),
        })?
        .build()
        .with_context(|_| DatafusionSnafu {
            context: "Failed to finalize incremental aggregate sink merge plan".to_string(),
        })
}
728
729fn build_left_join_merge_expr(
730    delta_alias: &str,
731    sink_alias: &str,
732    merge_col: &IncrementalAggregateMergeColumn,
733) -> Result<Expr, Error> {
734    let left = qualified_col(delta_alias, merge_col.output_field_name.clone());
735    let right = qualified_col(sink_alias, merge_col.output_field_name.clone());
736    let merged = match merge_col.merge_op {
737        IncrementalAggregateMergeOp::Sum => when(is_null(left.clone()), right.clone())
738            .when(is_null(right.clone()), left.clone())
739            .otherwise(binary_expr(left.clone(), Operator::Plus, right.clone()))
740            .with_context(|_| DatafusionSnafu {
741                context: "Failed to build SUM merge expression".to_string(),
742            })?,
743        IncrementalAggregateMergeOp::Min => when(is_null(right.clone()), left.clone())
744            .when(left.clone().lt_eq(right.clone()), left.clone())
745            .otherwise(right.clone())
746            .with_context(|_| DatafusionSnafu {
747                context: "Failed to build MIN merge expression".to_string(),
748            })?,
749        IncrementalAggregateMergeOp::Max => when(is_null(right.clone()), left.clone())
750            .when(left.clone().gt_eq(right.clone()), left.clone())
751            .otherwise(right.clone())
752            .with_context(|_| DatafusionSnafu {
753                context: "Failed to build MAX merge expression".to_string(),
754            })?,
755        IncrementalAggregateMergeOp::BoolAnd => when(is_null(left.clone()), right.clone())
756            .when(is_null(right.clone()), left.clone())
757            .otherwise(and(left.clone(), right.clone()))
758            .with_context(|_| DatafusionSnafu {
759                context: "Failed to build BOOL_AND merge expression".to_string(),
760            })?,
761        IncrementalAggregateMergeOp::BoolOr => when(is_null(left.clone()), right.clone())
762            .when(is_null(right.clone()), left.clone())
763            .otherwise(or(left.clone(), right.clone()))
764            .with_context(|_| DatafusionSnafu {
765                context: "Failed to build BOOL_OR merge expression".to_string(),
766            })?,
767        IncrementalAggregateMergeOp::BitAnd => when(is_null(left.clone()), right.clone())
768            .when(is_null(right.clone()), left.clone())
769            .otherwise(bitwise_and(left.clone(), right.clone()))
770            .with_context(|_| DatafusionSnafu {
771                context: "Failed to build BIT_AND merge expression".to_string(),
772            })?,
773        IncrementalAggregateMergeOp::BitOr => when(is_null(left.clone()), right.clone())
774            .when(is_null(right.clone()), left.clone())
775            .otherwise(bitwise_or(left.clone(), right.clone()))
776            .with_context(|_| DatafusionSnafu {
777                context: "Failed to build BIT_OR merge expression".to_string(),
778            })?,
779        IncrementalAggregateMergeOp::BitXor => when(is_null(left.clone()), right.clone())
780            .when(is_null(right.clone()), left.clone())
781            .otherwise(bitwise_xor(left.clone(), right.clone()))
782            .with_context(|_| DatafusionSnafu {
783                context: "Failed to build BIT_XOR merge expression".to_string(),
784            })?,
785    };
786    Ok(merged.alias(merge_col.output_field_name.clone()))
787}
788
789pub async fn get_table_info_df_schema(
790    catalog_mr: CatalogManagerRef,
791    table_name: TableName,
792) -> Result<(TableRef, Arc<DFSchema>), Error> {
793    let full_table_name = table_name.clone().join(".");
794    let table = catalog_mr
795        .table(&table_name[0], &table_name[1], &table_name[2], None)
796        .await
797        .map_err(BoxedError::new)
798        .context(ExternalSnafu)?
799        .context(TableNotFoundSnafu {
800            name: &full_table_name,
801        })?;
802    let table_info = table.table_info();
803
804    let schema = table_info.meta.schema.clone();
805
806    let df_schema: Arc<DFSchema> = Arc::new(
807        schema
808            .arrow_schema()
809            .clone()
810            .try_into()
811            .with_context(|_| DatafusionSnafu {
812                context: format!(
813                    "Failed to convert arrow schema to datafusion schema, arrow_schema={:?}",
814                    schema.arrow_schema()
815                ),
816            })?,
817    );
818    Ok((table, df_schema))
819}
820
821/// Convert sql to datafusion logical plan
822/// Also support TQL (but only Eval not Explain or Analyze)
823pub async fn sql_to_df_plan(
824    query_ctx: QueryContextRef,
825    engine: QueryEngineRef,
826    sql: &str,
827    optimize: bool,
828) -> Result<LogicalPlan, Error> {
829    let stmts =
830        ParserContext::create_with_dialect(sql, query_ctx.sql_dialect(), ParseOptions::default())
831            .map_err(BoxedError::new)
832            .context(ExternalSnafu)?;
833
834    ensure!(
835        stmts.len() == 1,
836        InvalidQuerySnafu {
837            reason: format!("Expect only one statement, found {}", stmts.len())
838        }
839    );
840    let stmt = &stmts[0];
841    let query_stmt = match stmt {
842        Statement::Tql(tql) => match tql {
843            Tql::Eval(eval) => {
844                let eval = eval.clone();
845                let promql = PromQuery {
846                    start: eval.start,
847                    end: eval.end,
848                    step: eval.step,
849                    query: eval.query,
850                    lookback: eval
851                        .lookback
852                        .unwrap_or_else(|| DEFAULT_LOOKBACK_STRING.to_string()),
853                    alias: eval.alias.clone(),
854                };
855
856                QueryLanguageParser::parse_promql(&promql, &query_ctx)
857                    .map_err(BoxedError::new)
858                    .context(ExternalSnafu)?
859            }
860            _ => InvalidQuerySnafu {
861                reason: format!("TQL statement {tql:?} is not supported, expect only TQL EVAL"),
862            }
863            .fail()?,
864        },
865        _ => QueryStatement::Sql(stmt.clone()),
866    };
867    let plan = engine
868        .planner()
869        .plan(&query_stmt, query_ctx.clone())
870        .await
871        .map_err(BoxedError::new)
872        .context(ExternalSnafu)?;
873
874    let plan = if optimize {
875        apply_df_optimizer(plan, &query_ctx).await?
876    } else {
877        plan
878    };
879    Ok(plan)
880}
881
882/// Generate a plan that matches the schema of the sink table
883/// from given sql by alias and adding auto columns
884pub(crate) async fn gen_plan_with_matching_schema(
885    sql: &str,
886    query_ctx: QueryContextRef,
887    engine: QueryEngineRef,
888    sink_table_schema: SchemaRef,
889    primary_key_indices: &[usize],
890    allow_partial: bool,
891) -> Result<LogicalPlan, Error> {
892    let plan = sql_to_df_plan(query_ctx.clone(), engine.clone(), sql, false).await?;
893
894    let mut add_auto_column = ColumnMatcherRewriter::new(
895        sink_table_schema,
896        primary_key_indices.to_vec(),
897        allow_partial,
898    );
899    let plan = plan
900        .clone()
901        .rewrite(&mut add_auto_column)
902        .with_context(|_| DatafusionSnafu {
903            context: format!("Failed to rewrite plan:\n {}\n", plan),
904        })?
905        .data;
906    Ok(plan)
907}
908
909pub fn df_plan_to_sql(plan: &LogicalPlan) -> Result<String, Error> {
910    /// A dialect that forces identifiers to be quoted when have uppercase
911    struct ForceQuoteIdentifiers;
912    impl datafusion::sql::unparser::dialect::Dialect for ForceQuoteIdentifiers {
913        fn identifier_quote_style(&self, identifier: &str) -> Option<char> {
914            if identifier.to_lowercase() != identifier {
915                Some('`')
916            } else {
917                None
918            }
919        }
920    }
921    let unparser = Unparser::new(&ForceQuoteIdentifiers);
922    // first make all column qualified
923    let sql = unparser
924        .plan_to_sql(plan)
925        .with_context(|_e| DatafusionSnafu {
926            context: format!("Failed to unparse logical plan {plan:?}"),
927        })?;
928    Ok(sql.to_string())
929}
930
931/// Helper to find the innermost group by expr in schema, return None if no group by expr
932#[derive(Debug, Clone, Default)]
933pub struct FindGroupByFinalName {
934    group_exprs: Option<HashSet<datafusion_expr::Expr>>,
935}
936
937impl FindGroupByFinalName {
938    pub fn get_group_expr_names(&self) -> Option<HashSet<String>> {
939        self.group_exprs
940            .as_ref()
941            .map(|exprs| exprs.iter().map(|expr| expr.qualified_name().1).collect())
942    }
943}
944
945impl TreeNodeVisitor<'_> for FindGroupByFinalName {
946    type Node = LogicalPlan;
947
948    fn f_down(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
949        if let LogicalPlan::Aggregate(aggregate) = node {
950            self.group_exprs = Some(aggregate.group_expr.iter().cloned().collect());
951            debug!(
952                "FindGroupByFinalName: Get Group by exprs from Aggregate: {:?}",
953                self.group_exprs
954            );
955        } else if let LogicalPlan::Distinct(distinct) = node {
956            debug!("FindGroupByFinalName: Distinct: {}", node);
957            match distinct {
958                Distinct::All(input) => {
959                    if let LogicalPlan::TableScan(table_scan) = &**input {
960                        // get column from field_qualifier, projection and projected_schema:
961                        let len = table_scan.projected_schema.fields().len();
962                        let columns = (0..len)
963                            .map(|f| {
964                                let (qualifier, field) =
965                                    table_scan.projected_schema.qualified_field(f);
966                                datafusion_common::Column::new(qualifier.cloned(), field.name())
967                            })
968                            .map(datafusion_expr::Expr::Column);
969                        self.group_exprs = Some(columns.collect());
970                    } else {
971                        self.group_exprs = Some(input.expressions().iter().cloned().collect())
972                    }
973                }
974                Distinct::On(distinct_on) => {
975                    self.group_exprs = Some(distinct_on.on_expr.iter().cloned().collect())
976                }
977            }
978            debug!(
979                "FindGroupByFinalName: Get Group by exprs from Distinct: {:?}",
980                self.group_exprs
981            );
982        }
983
984        Ok(TreeNodeRecursion::Continue)
985    }
986
987    /// deal with projection when going up with group exprs
988    fn f_up(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
989        if let LogicalPlan::Projection(projection) = node {
990            for expr in &projection.expr {
991                let Some(group_exprs) = &mut self.group_exprs else {
992                    return Ok(TreeNodeRecursion::Continue);
993                };
994                if let datafusion_expr::Expr::Alias(alias) = expr {
995                    // if a alias exist, replace with the new alias
996                    let mut new_group_exprs = group_exprs.clone();
997                    for group_expr in group_exprs.iter() {
998                        if group_expr.name_for_alias()? == alias.expr.name_for_alias()? {
999                            new_group_exprs.remove(group_expr);
1000                            new_group_exprs.insert(expr.clone());
1001                            break;
1002                        }
1003                    }
1004                    *group_exprs = new_group_exprs;
1005                }
1006            }
1007        }
1008        debug!("Aliased group by exprs: {:?}", self.group_exprs);
1009        Ok(TreeNodeRecursion::Continue)
1010    }
1011}
1012
1013/// Optionally add to the final select columns like `update_at` if the sink table has such column
1014/// (which doesn't necessary need to have exact name just need to be a extra timestamp column)
1015/// and `__ts_placeholder`(this column need to have exact this name and be a timestamp)
1016/// with values like `now()` and `0`
1017///
1018/// it also give existing columns alias to column in sink table if needed
1019#[derive(Debug)]
1020pub struct ColumnMatcherRewriter {
1021    pub schema: SchemaRef,
1022    pub is_rewritten: bool,
1023    pub primary_key_indices: Vec<usize>,
1024    pub allow_partial: bool,
1025}
1026
1027impl ColumnMatcherRewriter {
1028    pub fn new(schema: SchemaRef, primary_key_indices: Vec<usize>, allow_partial: bool) -> Self {
1029        Self {
1030            schema,
1031            is_rewritten: false,
1032            primary_key_indices,
1033            allow_partial,
1034        }
1035    }
1036
1037    /// modify the exprs in place so that it matches the schema and some auto columns are added
1038    fn modify_project_exprs(&mut self, mut exprs: Vec<Expr>) -> DfResult<Vec<Expr>> {
1039        if self.allow_partial {
1040            return self.modify_project_exprs_with_partial(exprs);
1041        }
1042
1043        let all_names = self
1044            .schema
1045            .column_schemas()
1046            .iter()
1047            .map(|c| c.name.clone())
1048            .collect::<BTreeSet<_>>();
1049        // first match by position
1050        for (idx, expr) in exprs.iter_mut().enumerate() {
1051            if !all_names.contains(&expr.qualified_name().1)
1052                && let Some(col_name) = self
1053                    .schema
1054                    .column_schemas()
1055                    .get(idx)
1056                    .map(|c| c.name.clone())
1057            {
1058                // if the data type mismatched, later check_execute will error out
1059                // hence no need to check it here, beside, optimize pass might be able to cast it
1060                // so checking here is not necessary
1061                *expr = expr.clone().alias(col_name);
1062            }
1063        }
1064
1065        // add columns if have different column count
1066        let query_col_cnt = exprs.len();
1067        let table_col_cnt = self.schema.column_schemas().len();
1068        debug!("query_col_cnt={query_col_cnt}, table_col_cnt={table_col_cnt}");
1069
1070        let placeholder_ts_expr =
1071            datafusion::logical_expr::lit(ScalarValue::TimestampMillisecond(Some(0), None))
1072                .alias(AUTO_CREATED_PLACEHOLDER_TS_COL);
1073
1074        if query_col_cnt == table_col_cnt {
1075            // still need to add alias, see below
1076        } else if query_col_cnt + 1 == table_col_cnt {
1077            let last_col_schema = self.schema.column_schemas().last().unwrap();
1078
1079            // if time index column is auto created add it
1080            if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
1081                && self.schema.timestamp_index() == Some(table_col_cnt - 1)
1082            {
1083                exprs.push(placeholder_ts_expr);
1084            } else if last_col_schema.data_type.is_timestamp() {
1085                // is the update at column
1086                exprs.push(datafusion::prelude::now().alias(&last_col_schema.name));
1087            } else {
1088                // helpful error message
1089                return Err(DataFusionError::Plan(format!(
1090                    "Expect the last column in table to be timestamp column, found column {} with type {:?}",
1091                    last_col_schema.name, last_col_schema.data_type
1092                )));
1093            }
1094        } else if query_col_cnt + 2 == table_col_cnt {
1095            let mut col_iter = self.schema.column_schemas().iter().rev();
1096            let last_col_schema = col_iter.next().unwrap();
1097            let second_last_col_schema = col_iter.next().unwrap();
1098            if second_last_col_schema.data_type.is_timestamp() {
1099                exprs.push(datafusion::prelude::now().alias(&second_last_col_schema.name));
1100            } else {
1101                return Err(DataFusionError::Plan(format!(
1102                    "Expect the second last column in the table to be timestamp column, found column {} with type {:?}",
1103                    second_last_col_schema.name, second_last_col_schema.data_type
1104                )));
1105            }
1106
1107            if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
1108                && self.schema.timestamp_index() == Some(table_col_cnt - 1)
1109            {
1110                exprs.push(placeholder_ts_expr);
1111            } else {
1112                return Err(DataFusionError::Plan(format!(
1113                    "Expect timestamp column {}, found {:?}",
1114                    AUTO_CREATED_PLACEHOLDER_TS_COL, last_col_schema
1115                )));
1116            }
1117        } else {
1118            return Err(DataFusionError::Plan(format!(
1119                "Expect table have 0,1 or 2 columns more than query columns, found {} query columns {:?}, {} table columns {:?}",
1120                query_col_cnt,
1121                exprs,
1122                table_col_cnt,
1123                self.schema.column_schemas()
1124            )));
1125        }
1126        Ok(exprs)
1127    }
1128
1129    fn modify_project_exprs_with_partial(&mut self, exprs: Vec<Expr>) -> DfResult<Vec<Expr>> {
1130        let table_col_cnt = self.schema.column_schemas().len();
1131        let query_col_cnt = exprs.len();
1132
1133        if query_col_cnt > table_col_cnt {
1134            return Err(DataFusionError::Plan(format!(
1135                "Expect query column count <= table column count, found {} query columns {:?}, {} table columns {:?}",
1136                query_col_cnt,
1137                exprs,
1138                table_col_cnt,
1139                self.schema.column_schemas()
1140            )));
1141        }
1142
1143        let name_to_expr: HashMap<String, Expr> = exprs
1144            .clone()
1145            .into_iter()
1146            .map(|e| (e.qualified_name().1, e))
1147            .collect();
1148
1149        let required_columns = self.required_columns_for_partial();
1150        let missing: Vec<_> = required_columns
1151            .iter()
1152            .filter(|name| !name_to_expr.contains_key(*name))
1153            .cloned()
1154            .collect();
1155        if !missing.is_empty() {
1156            return Err(DataFusionError::Plan(format!(
1157                "Column(s) {:?} required by sink table are missing from flow output when merge_mode=last_non_null",
1158                missing
1159            )));
1160        }
1161
1162        let placeholder_ts_expr =
1163            datafusion::logical_expr::lit(ScalarValue::TimestampMillisecond(Some(0), None))
1164                .alias(AUTO_CREATED_PLACEHOLDER_TS_COL);
1165
1166        let timestamp_index = self.schema.timestamp_index();
1167        let mut remap = name_to_expr;
1168        let mut new_exprs = Vec::with_capacity(table_col_cnt);
1169
1170        for (idx, col_schema) in self.schema.column_schemas().iter().enumerate() {
1171            let col_name = col_schema.name.clone();
1172            if let Some(expr) = remap.remove(&col_name) {
1173                let expr = if expr.qualified_name().1 == col_name {
1174                    expr
1175                } else {
1176                    expr.alias(col_name.clone())
1177                };
1178                new_exprs.push(expr);
1179                continue;
1180            }
1181
1182            if col_name == AUTO_CREATED_PLACEHOLDER_TS_COL && timestamp_index == Some(idx) {
1183                new_exprs.push(placeholder_ts_expr.clone());
1184                continue;
1185            }
1186
1187            if col_name == AUTO_CREATED_UPDATE_AT_TS_COL && col_schema.data_type.is_timestamp() {
1188                new_exprs.push(datafusion::prelude::now().alias(&col_name));
1189                continue;
1190            }
1191
1192            new_exprs.push(Self::null_expr(col_schema));
1193        }
1194
1195        if !remap.is_empty() {
1196            let extra: Vec<_> = remap.keys().cloned().collect();
1197            return Err(DataFusionError::Plan(format!(
1198                "Flow output has extra column(s) {:?} not found in sink schema when merge_mode=last_non_null",
1199                extra
1200            )));
1201        }
1202
1203        Ok(new_exprs)
1204    }
1205
1206    fn null_expr(col_schema: &ColumnSchema) -> Expr {
1207        Expr::Literal(ScalarValue::Null, None).alias(col_schema.name.clone())
1208    }
1209
1210    fn required_columns_for_partial(&self) -> HashSet<String> {
1211        let mut required = HashSet::new();
1212        for idx in &self.primary_key_indices {
1213            if let Some(col) = self.schema.column_schemas().get(*idx) {
1214                required.insert(col.name.clone());
1215            }
1216        }
1217
1218        if let Some(ts_idx) = self.schema.timestamp_index()
1219            && let Some(col) = self.schema.column_schemas().get(ts_idx)
1220            && col.name != AUTO_CREATED_PLACEHOLDER_TS_COL
1221        {
1222            required.insert(col.name.clone());
1223        }
1224
1225        required
1226    }
1227}
1228
1229impl TreeNodeRewriter for ColumnMatcherRewriter {
1230    type Node = LogicalPlan;
1231    fn f_down(&mut self, mut node: Self::Node) -> DfResult<Transformed<Self::Node>> {
1232        if self.is_rewritten {
1233            return Ok(Transformed::no(node));
1234        }
1235
1236        // if is distinct all, wrap it in a projection
1237        if let LogicalPlan::Distinct(Distinct::All(_)) = &node {
1238            let mut exprs = vec![];
1239
1240            for field in node.schema().fields().iter() {
1241                exprs.push(Expr::Column(datafusion::common::Column::new_unqualified(
1242                    field.name(),
1243                )));
1244            }
1245
1246            let projection =
1247                LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(node.clone()))?);
1248
1249            node = projection;
1250        }
1251        // handle table_scan by wrap it in a projection
1252        else if let LogicalPlan::TableScan(table_scan) = node {
1253            let mut exprs = vec![];
1254
1255            for field in table_scan.projected_schema.fields().iter() {
1256                exprs.push(Expr::Column(datafusion::common::Column::new(
1257                    Some(table_scan.table_name.clone()),
1258                    field.name(),
1259                )));
1260            }
1261
1262            let projection = LogicalPlan::Projection(Projection::try_new(
1263                exprs,
1264                Arc::new(LogicalPlan::TableScan(table_scan)),
1265            )?);
1266
1267            node = projection;
1268        }
1269
1270        // only do rewrite if found the outermost projection
1271        // if the outermost node is projection, can rewrite the exprs
1272        // if not, wrap it in a projection
1273        if let LogicalPlan::Projection(project) = &node {
1274            let exprs = project.expr.clone();
1275            let exprs = self.modify_project_exprs(exprs)?;
1276
1277            self.is_rewritten = true;
1278            let new_plan =
1279                node.with_new_exprs(exprs, node.inputs().into_iter().cloned().collect())?;
1280            Ok(Transformed::yes(new_plan))
1281        } else {
1282            // wrap the logical plan in a projection
1283            let mut exprs = vec![];
1284            for field in node.schema().fields().iter() {
1285                exprs.push(Expr::Column(datafusion::common::Column::new_unqualified(
1286                    field.name(),
1287                )));
1288            }
1289            let exprs = self.modify_project_exprs(exprs)?;
1290            self.is_rewritten = true;
1291            let new_plan =
1292                LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(node.clone()))?);
1293            Ok(Transformed::yes(new_plan))
1294        }
1295    }
1296
1297    /// We might add new columns, so we need to recompute the schema
1298    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
1299        node.recompute_schema().map(Transformed::yes)
1300    }
1301}
1302
1303/// Find out the `Filter` Node corresponding to innermost(deepest) `WHERE` and add a new filter expr to it
1304#[derive(Debug)]
1305pub struct AddFilterRewriter {
1306    extra_filter: Expr,
1307    is_rewritten: bool,
1308}
1309
1310impl AddFilterRewriter {
1311    pub fn new(filter: Expr) -> Self {
1312        Self {
1313            extra_filter: filter,
1314            is_rewritten: false,
1315        }
1316    }
1317}
1318
1319impl TreeNodeRewriter for AddFilterRewriter {
1320    type Node = LogicalPlan;
1321    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
1322        if self.is_rewritten {
1323            return Ok(Transformed::no(node));
1324        }
1325        match node {
1326            LogicalPlan::Filter(mut filter) => {
1327                filter.predicate = filter.predicate.and(self.extra_filter.clone());
1328                self.is_rewritten = true;
1329                Ok(Transformed::yes(LogicalPlan::Filter(filter)))
1330            }
1331            LogicalPlan::TableScan(_) => {
1332                // add a new filter
1333                let filter =
1334                    datafusion_expr::Filter::try_new(self.extra_filter.clone(), Arc::new(node))?;
1335                self.is_rewritten = true;
1336                Ok(Transformed::yes(LogicalPlan::Filter(filter)))
1337            }
1338            _ => Ok(Transformed::no(node)),
1339        }
1340    }
1341}