Skip to main content

flow/batching_mode/
utils.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! some utils for helping with batching mode
16
17use std::collections::{BTreeSet, HashMap, HashSet};
18use std::sync::Arc;
19
20use catalog::CatalogManagerRef;
21use common_error::ext::BoxedError;
22use common_function::aggrs::aggr_wrapper::get_aggr_func;
23use common_telemetry::debug;
24use datafusion::datasource::DefaultTableSource;
25use datafusion::error::Result as DfResult;
26use datafusion::logical_expr::Expr;
27use datafusion::sql::unparser::Unparser;
28use datafusion_common::tree_node::{
29    Transformed, TreeNode as _, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
30};
31use datafusion_common::{
32    Column, DFSchema, DataFusionError, NullEquality, ScalarValue, TableReference,
33};
34use datafusion_expr::logical_plan::{Aggregate, TableScan};
35use datafusion_expr::{
36    Distinct, ExprSchemable, JoinType, LogicalPlan, LogicalPlanBuilder, Operator, Projection, and,
37    binary_expr, bitwise_and, bitwise_or, bitwise_xor, is_null, or, when,
38};
39use datatypes::prelude::ConcreteDataType;
40use datatypes::schema::{ColumnSchema, SchemaRef};
41use query::QueryEngineRef;
42use query::parser::{DEFAULT_LOOKBACK_STRING, PromQuery, QueryLanguageParser, QueryStatement};
43use session::context::QueryContextRef;
44use snafu::{OptionExt, ResultExt, ensure};
45use sql::parser::{ParseOptions, ParserContext};
46use sql::statements::statement::Statement;
47use sql::statements::tql::Tql;
48use table::TableRef;
49use table::table::adapter::DfTableProviderAdapter;
50
51use crate::adapter::{AUTO_CREATED_PLACEHOLDER_TS_COL, AUTO_CREATED_UPDATE_AT_TS_COL};
52use crate::df_optimizer::apply_df_optimizer;
53use crate::error::{DatafusionSnafu, ExternalSnafu, InvalidQuerySnafu, TableNotFoundSnafu};
54use crate::{Error, TableName};
55
56#[cfg(test)]
57mod test;
58
/// Describes how one aggregate output field should be merged with the
/// corresponding existing field in the sink table.
///
/// `output_field_name` is the final output/sink schema field name produced by
/// the delta plan and read from the sink table. It is not a DataFusion `Column`
/// reference. It may contain dots or other non-identifier characters when the
/// query keeps DataFusion's raw aggregate output name, e.g.
/// `max(numbers_with_ts.number)`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IncrementalAggregateMergeColumn {
    /// Final output/sink field name for the aggregate result/state column.
    pub output_field_name: String,
    /// Operator that combines the delta value with the value stored in the sink.
    pub merge_op: IncrementalAggregateMergeOp,
}

impl IncrementalAggregateMergeColumn {
    /// Creates a merge column for `output_field_name`, merged with `merge_op`.
    pub fn new(output_field_name: String, merge_op: IncrementalAggregateMergeOp) -> Self {
        Self {
            output_field_name,
            merge_op,
        }
    }
}

/// Operator used to combine a freshly computed (delta) aggregate value with
/// the value already stored in the sink table for the same group.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IncrementalAggregateMergeOp {
    /// Addition; used for `sum` and for `count` (partial counts add up).
    Sum,
    /// Keep the smaller value; used for `min`.
    Min,
    /// Keep the larger value; used for `max`.
    Max,
    /// Logical AND; used for `bool_and`.
    BoolAnd,
    /// Logical OR; used for `bool_or`.
    BoolOr,
    /// Bitwise AND; used for `bit_and`.
    BitAnd,
    /// Bitwise OR; used for `bit_or`.
    BitOr,
    /// Bitwise XOR; used for `bit_xor`.
    BitXor,
}

/// Analysis result for an incremental aggregate plan.
///
/// `group_key_names` and each merge column's `output_field_name` are final
/// output/sink schema field names used to project both the delta plan and the
/// sink table before the left-join merge. They are not DataFusion logical-plan
/// `Column` references; callers must attach qualifiers structurally instead of
/// formatting qualified names as strings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IncrementalAggregateAnalysis {
    /// Final output/sink field names for group keys used as merge join keys.
    pub group_key_names: Vec<String>,
    /// One entry per mergeable aggregate output. `analyze_incremental_aggregate_plan`
    /// leaves this empty whenever `unsupported_exprs` is non-empty.
    pub merge_columns: Vec<IncrementalAggregateMergeColumn>,
    /// Literal output fields that can be passed through from the delta plan.
    pub literal_columns: Vec<String>,
    /// Final output field order from the original aggregate plan.
    pub output_field_names: Vec<String>,
    /// Human-readable reasons why the plan cannot be rewritten incrementally;
    /// empty means the plan is supported.
    pub unsupported_exprs: Vec<String>,
}
115
116/// Recursively find all `Expr::Column` names inside an expression tree.
117/// Only recurses into wrappers that are merge-transparent.
118/// Non-transparent wrappers (e.g., `ScalarFunction`, `Negative`, `Cast`) are
119/// intentionally not recursed into since their merge semantics would be
120/// incorrect.
121///
122/// `Cast`/`TryCast` are intentionally opaque: merging already-casted aggregate
123/// outputs is not generally equivalent to casting the final merged aggregate.
124fn find_column_names(expr: &Expr, names: &mut Vec<String>) {
125    match expr {
126        Expr::Column(col) => {
127            names.push(col.name.clone());
128        }
129        Expr::Alias(alias) => find_column_names(&alias.expr, names),
130        _ => {}
131    }
132}
133
134fn unqualified_col(name: impl Into<String>) -> Expr {
135    Expr::Column(Column::from_name(name.into()))
136}
137
138fn qualified_col(qualifier: &str, name: impl Into<String>) -> Expr {
139    Expr::Column(Column::new(Some(qualifier), name.into()))
140}
141
142fn qualified_column(qualifier: &str, name: impl Into<String>) -> Column {
143    Column::new(Some(qualifier), name.into())
144}
145
146fn find_group_key_names(plan: &LogicalPlan) -> Result<Vec<String>, Error> {
147    let mut group_finder = FindGroupByFinalName::default();
148    plan.visit(&mut group_finder)
149        .with_context(|_| DatafusionSnafu {
150            context: format!("Failed to inspect group-by columns from logical plan: {plan:?}"),
151        })?;
152
153    let mut group_key_names = group_finder
154        .get_group_expr_names()
155        .unwrap_or_default()
156        .into_iter()
157        .collect::<Vec<_>>();
158    group_key_names.sort();
159    Ok(group_key_names)
160}
161
162fn has_grouping_set(plan: &LogicalPlan) -> bool {
163    match plan {
164        LogicalPlan::Aggregate(aggregate) => aggregate
165            .group_expr
166            .iter()
167            .any(|expr| matches!(expr, Expr::GroupingSet(_))),
168        _ => plan.inputs().into_iter().any(has_grouping_set),
169    }
170}
171
172fn has_aggregate(plan: &LogicalPlan) -> bool {
173    match plan {
174        LogicalPlan::Aggregate(_) => true,
175        _ => plan.inputs().into_iter().any(has_aggregate),
176    }
177}
178
179fn peel_subquery_aliases(mut plan: &LogicalPlan) -> &LogicalPlan {
180    while let LogicalPlan::SubqueryAlias(alias) = plan {
181        plan = alias.input.as_ref();
182    }
183    plan
184}
185
186fn extract_incremental_aggregate(plan: &LogicalPlan) -> Result<Option<&Aggregate>, String> {
187    // Supported final shape: optional output Projection directly over one
188    // Aggregate. Post-aggregate filters (HAVING), ordering, limits,
189    // distinct/window/union/extension nodes are intentionally not accepted.
190    let plan = match plan {
191        LogicalPlan::Projection(projection) => projection.input.as_ref(),
192        _ => plan,
193    };
194
195    match plan {
196        LogicalPlan::Aggregate(aggregate) => {
197            check_input_plan_shape(aggregate.input.as_ref())?;
198            Ok(Some(aggregate))
199        }
200        LogicalPlan::Filter(filter) if has_aggregate(filter.input.as_ref()) => Err(
201            "unsupported post-aggregate filter (HAVING) in incremental aggregate rewrite"
202                .to_string(),
203        ),
204        _ if has_aggregate(plan) => Err(
205            "unsupported post-aggregate plan shape in incremental aggregate rewrite".to_string(),
206        ),
207        _ => Ok(None),
208    }
209}
210
211fn check_input_plan_shape(plan: &LogicalPlan) -> Result<(), String> {
212    let plan = peel_subquery_aliases(plan);
213    match plan {
214        // Supported aggregate input shape: optional WHERE filter over a table scan.
215        // SubqueryAlias is a transparent naming wrapper for `FROM table AS alias`.
216        LogicalPlan::TableScan(_) => Ok(()),
217        LogicalPlan::Filter(filter) => match peel_subquery_aliases(filter.input.as_ref()) {
218            LogicalPlan::TableScan(_) => Ok(()),
219            _ => Err(
220                "unsupported aggregate input plan shape in incremental aggregate rewrite"
221                    .to_string(),
222            ),
223        },
224        _ => Err(
225            "unsupported aggregate input plan shape in incremental aggregate rewrite".to_string(),
226        ),
227    }
228}
229
/// Naming information about a plan's final output, gathered by
/// `collect_output_projection_info`.
#[derive(Debug, Default)]
struct OutputProjectionInfo {
    /// Whether the plan root is a `Projection`.
    has_top_level_projection: bool,
    /// Column name referenced by a projected expression -> output name it is
    /// exposed under.
    output_aliases: HashMap<String, String>,
    /// Diagnostics for one column exposed under several different aliases.
    duplicate_aggregate_aliases: BTreeSet<String>,
    /// Output names whose values are passed through from the delta plan.
    literal_columns: HashSet<String>,
    /// Final output field names, in schema order.
    output_field_names: Vec<String>,
}

impl OutputProjectionInfo {
    /// Returns the final output field names as a set for membership checks.
    fn output_field_name_set(&self) -> HashSet<String> {
        let mut names = HashSet::with_capacity(self.output_field_names.len());
        names.extend(self.output_field_names.iter().cloned());
        names
    }

    /// Returns every output field name occurring more than once, each
    /// reported once and sorted ascending.
    fn duplicate_output_names(&self) -> Vec<String> {
        let mut occurrences: HashMap<&str, usize> = HashMap::new();
        for name in &self.output_field_names {
            *occurrences.entry(name.as_str()).or_insert(0) += 1;
        }
        occurrences
            .into_iter()
            .filter(|&(_, count)| count > 1)
            .map(|(name, _)| name.to_string())
            .collect::<BTreeSet<_>>()
            .into_iter()
            .collect()
    }
}
255
256fn collect_output_projection_info(plan: &LogicalPlan) -> OutputProjectionInfo {
257    let mut projection_info = OutputProjectionInfo {
258        has_top_level_projection: matches!(plan, LogicalPlan::Projection(_)),
259        output_field_names: plan
260            .schema()
261            .fields()
262            .iter()
263            .map(|field| field.name().clone())
264            .collect(),
265        ..Default::default()
266    };
267
268    let mut output_aliases = HashMap::new();
269    if let LogicalPlan::Projection(projection) = plan {
270        for expr in &projection.expr {
271            match expr {
272                Expr::Alias(alias) => {
273                    // Alias resolution has three cases:
274                    // - 0 Column refs (e.g., literal `42 AS lit`): record literal output
275                    // - 1 Column ref: record the mapping (e.g., `sum(x) AS total`)
276                    // - >1 Column refs (e.g., `COALESCE(sum(x), sum(y))`):
277                    //   skip — ambiguous merge semantics
278                    let alias_name = alias.name.clone();
279                    let mut col_names = Vec::new();
280                    find_column_names(&alias.expr, &mut col_names);
281                    match col_names.len() {
282                        0 if is_passthrough_output_column(&alias_name, alias.expr.as_ref()) => {
283                            projection_info.literal_columns.insert(alias_name);
284                        }
285                        1 => {
286                            if let Some(col_name) = col_names.into_iter().next() {
287                                if let Some(existing_alias) = output_aliases.get(&col_name) {
288                                    if existing_alias != &alias_name {
289                                        projection_info.duplicate_aggregate_aliases.insert(format!(
290                                            "same aggregate output {col_name} is used by multiple aliases: {existing_alias}, {alias_name}"
291                                        ));
292                                    }
293                                } else {
294                                    output_aliases.insert(col_name, alias_name);
295                                }
296                            }
297                        }
298                        _ => {}
299                    }
300
301                    // If >1 column references detected (e.g., COALESCE(sum(x), sum(y))),
302                    // intentionally skip alias mapping — the merge semantics are ambiguous.
303                }
304                Expr::Column(col) => {
305                    output_aliases
306                        .entry(col.name.clone())
307                        .or_insert(col.name.clone());
308                }
309                Expr::Literal(_, _) => {
310                    projection_info
311                        .literal_columns
312                        .insert(expr.qualified_name().1);
313                }
314                _ => {}
315            }
316        }
317    }
318
319    if projection_info
320        .output_field_names
321        .iter()
322        .any(|name| name == AUTO_CREATED_PLACEHOLDER_TS_COL)
323    {
324        projection_info
325            .literal_columns
326            .insert(AUTO_CREATED_PLACEHOLDER_TS_COL.to_string());
327    }
328
329    projection_info.output_aliases = output_aliases;
330    projection_info
331}
332
333fn is_passthrough_output_column(alias_name: &str, expr: &Expr) -> bool {
334    matches!(expr, Expr::Literal(_, _))
335        || match alias_name {
336            AUTO_CREATED_UPDATE_AT_TS_COL => expr == &datafusion::prelude::now(),
337            AUTO_CREATED_PLACEHOLDER_TS_COL => is_literal_or_cast_literal(expr),
338            _ => false,
339        }
340}
341
342fn is_literal_or_cast_literal(expr: &Expr) -> bool {
343    match expr {
344        Expr::Literal(_, _) => true,
345        Expr::Cast(cast) => is_literal_or_cast_literal(cast.expr.as_ref()),
346        Expr::TryCast(cast) => is_literal_or_cast_literal(cast.expr.as_ref()),
347        _ => false,
348    }
349}
350
351fn merge_op_for_aggregate_expr(aggr_expr: &Expr) -> Result<IncrementalAggregateMergeOp, String> {
352    let Some(aggr_func) = get_aggr_func(aggr_expr) else {
353        return Err(aggr_expr.to_string());
354    };
355    if aggr_func.params.distinct {
356        return Err(format!("unsupported DISTINCT aggregate: {aggr_expr}"));
357    }
358    if !aggr_func.params.order_by.is_empty() {
359        return Err(format!("unsupported aggregate ORDER BY: {aggr_expr}"));
360    }
361    if aggr_func.params.null_treatment.is_some() {
362        return Err(format!("unsupported aggregate NULL treatment: {aggr_expr}"));
363    }
364
365    match aggr_func.func.name().to_ascii_lowercase().as_str() {
366        "sum" | "count" => Ok(IncrementalAggregateMergeOp::Sum),
367        "min" => Ok(IncrementalAggregateMergeOp::Min),
368        "max" => Ok(IncrementalAggregateMergeOp::Max),
369        "bool_and" => Ok(IncrementalAggregateMergeOp::BoolAnd),
370        "bool_or" => Ok(IncrementalAggregateMergeOp::BoolOr),
371        "bit_and" => Ok(IncrementalAggregateMergeOp::BitAnd),
372        "bit_or" => Ok(IncrementalAggregateMergeOp::BitOr),
373        "bit_xor" => Ok(IncrementalAggregateMergeOp::BitXor),
374        _ => Err(aggr_expr.to_string()),
375    }
376}
377
378fn resolve_aggregate_output_field_name(
379    aggr_expr: &Expr,
380    projection_info: &OutputProjectionInfo,
381    output_field_name_set: &HashSet<String>,
382) -> Option<String> {
383    // qualified_name() returns (Option<String>, String) where the second
384    // element is the unqualified column/alias name. This relies on
385    // DataFusion's internal naming convention: aggregate expressions
386    // emit a column named after the aggregate itself (e.g. "SUM(x)"),
387    // which matches what the projection aliases reference.
388    let raw_name = aggr_expr.qualified_name().1;
389    if let Some(alias) = projection_info.output_aliases.get(&raw_name) {
390        Some(alias.clone())
391    } else if !projection_info.has_top_level_projection && output_field_name_set.contains(&raw_name)
392    {
393        Some(raw_name)
394    } else {
395        None
396    }
397}
398
399fn find_uncovered_output_fields(
400    projection_info: &OutputProjectionInfo,
401    group_key_names: &[String],
402    merge_columns: &[IncrementalAggregateMergeColumn],
403) -> Vec<String> {
404    let group_key_names = group_key_names.iter().cloned().collect::<HashSet<_>>();
405    let merge_column_names = merge_columns
406        .iter()
407        .map(|c| c.output_field_name.clone())
408        .collect::<HashSet<_>>();
409
410    projection_info
411        .output_field_names
412        .iter()
413        .filter(|name| {
414            !group_key_names.contains(*name)
415                && !merge_column_names.contains(*name)
416                && !projection_info.literal_columns.contains(*name)
417                // Auto-created sink columns injected by ColumnMatcherRewriter
418                // are not part of the original aggregate semantics and must
419                // not prevent incremental aggregate rewrites.
420                && name.as_str() != AUTO_CREATED_UPDATE_AT_TS_COL
421                && name.as_str() != AUTO_CREATED_PLACEHOLDER_TS_COL
422        })
423        .cloned()
424        .collect()
425}
426
427fn find_unsupported_group_key_projection_outputs(
428    plan: &LogicalPlan,
429    aggregate: &Aggregate,
430    group_key_names: &[String],
431) -> Vec<String> {
432    let LogicalPlan::Projection(projection) = plan else {
433        return vec![];
434    };
435
436    let group_key_names = group_key_names.iter().cloned().collect::<HashSet<_>>();
437    let group_expr_names = aggregate
438        .group_expr
439        .iter()
440        .filter_map(|expr| expr.name_for_alias().ok())
441        .collect::<HashSet<_>>();
442    projection
443        .expr
444        .iter()
445        .filter_map(|expr| {
446            let output_name = expr.qualified_name().1;
447            if !group_key_names.contains(&output_name) {
448                return None;
449            }
450
451            let source_name = match expr {
452                Expr::Alias(alias) => alias.expr.name_for_alias().ok(),
453                _ => expr.name_for_alias().ok(),
454            };
455            if source_name.is_some_and(|name| group_expr_names.contains(&name)) {
456                None
457            } else {
458                Some(format!(
459                    "unsupported group key output field is not a transparent group expression: {output_name}"
460                ))
461            }
462        })
463        .collect()
464}
465
466pub fn analyze_incremental_aggregate_plan(
467    plan: &LogicalPlan,
468) -> Result<Option<IncrementalAggregateAnalysis>, Error> {
469    let group_key_names = find_group_key_names(plan)?;
470    let aggregate = match extract_incremental_aggregate(plan) {
471        Ok(Some(aggregate)) => aggregate,
472        Ok(None) => return Ok(None),
473        Err(reason) => {
474            let projection_info = collect_output_projection_info(plan);
475            let mut unsupported_exprs = projection_info
476                .duplicate_output_names()
477                .into_iter()
478                .map(|name| format!("duplicate output field name: {name}"))
479                .collect::<Vec<_>>();
480            unsupported_exprs.push(reason);
481            unsupported_exprs.extend(projection_info.duplicate_aggregate_aliases.iter().cloned());
482            return Ok(Some(IncrementalAggregateAnalysis {
483                group_key_names,
484                merge_columns: vec![],
485                literal_columns: vec![],
486                output_field_names: projection_info.output_field_names,
487                unsupported_exprs,
488            }));
489        }
490    };
491    let aggr_exprs = aggregate.aggr_expr.clone();
492    let projection_info = collect_output_projection_info(plan);
493    let output_field_name_set = projection_info.output_field_name_set();
494
495    let mut merge_columns = Vec::with_capacity(aggr_exprs.len());
496    let mut unsupported_exprs = projection_info
497        .duplicate_output_names()
498        .into_iter()
499        .map(|name| format!("duplicate output field name: {name}"))
500        .collect::<Vec<_>>();
501    if has_grouping_set(plan) {
502        unsupported_exprs.push(
503            "unsupported GROUPING SETS/CUBE/ROLLUP in incremental aggregate rewrite".to_string(),
504        );
505    }
506    if group_key_names.is_empty() {
507        unsupported_exprs
508            .push("unsupported global aggregate in incremental aggregate rewrite".to_string());
509    }
510    unsupported_exprs.extend(find_unsupported_group_key_projection_outputs(
511        plan,
512        aggregate,
513        &group_key_names,
514    ));
515    unsupported_exprs.extend(projection_info.duplicate_aggregate_aliases.iter().cloned());
516    for aggr_expr in aggr_exprs {
517        let merge_op = match merge_op_for_aggregate_expr(&aggr_expr) {
518            Ok(merge_op) => merge_op,
519            Err(reason) => {
520                unsupported_exprs.push(reason);
521                continue;
522            }
523        };
524        let Some(output_field_name) = resolve_aggregate_output_field_name(
525            &aggr_expr,
526            &projection_info,
527            &output_field_name_set,
528        ) else {
529            unsupported_exprs.push(aggr_expr.to_string());
530            continue;
531        };
532        merge_columns.push(IncrementalAggregateMergeColumn::new(
533            output_field_name,
534            merge_op,
535        ));
536    }
537    unsupported_exprs.extend(
538        find_uncovered_output_fields(&projection_info, &group_key_names, &merge_columns)
539            .into_iter()
540            .map(|name| format!("unsupported output field: {name}")),
541    );
542    if !unsupported_exprs.is_empty() {
543        merge_columns.clear();
544    }
545    let mut literal_columns = projection_info
546        .literal_columns
547        .into_iter()
548        .collect::<Vec<_>>();
549    literal_columns.sort();
550
551    Ok(Some(IncrementalAggregateAnalysis {
552        group_key_names,
553        merge_columns,
554        literal_columns,
555        output_field_names: projection_info.output_field_names,
556        unsupported_exprs,
557    }))
558}
559
560/// Rewrites one incremental aggregate delta plan by left-joining it with the
561/// existing sink-table state and projecting merged aggregate outputs.
562///
563/// For a grouped aggregate such as:
564///
565/// ```text
566/// SELECT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts
567/// ```
568///
569/// the rewrite is roughly:
570///
571/// ```text
572/// delta = SELECT ts, number FROM <delta_plan> AS __flow_delta
573/// sink_scan = SELECT * FROM <sink_table> [WHERE <sink_dirty_filter>]
574/// sink  = SELECT ts, number FROM sink_scan AS __flow_sink
575/// SELECT
576///   CASE
577///     WHEN __flow_sink.number IS NULL THEN __flow_delta.number
578///     WHEN __flow_delta.number >= __flow_sink.number THEN __flow_delta.number
579///     ELSE __flow_sink.number
580///   END AS number,
581///   __flow_delta.ts AS ts
582/// FROM delta
583/// LEFT JOIN sink
584///   ON __flow_delta.ts IS NOT DISTINCT FROM __flow_sink.ts
585/// ```
586///
587/// If `sink_dirty_filter` is provided, it is applied to the sink table scan
588/// before projection, aliasing, and the left join. The predicate must reference
589/// raw sink table columns structurally (unqualified), before the `__flow_sink`
590/// alias exists.
591pub async fn rewrite_incremental_aggregate_with_sink_merge(
592    delta_plan: &LogicalPlan,
593    analysis: &IncrementalAggregateAnalysis,
594    sink_table: TableRef,
595    sink_table_name: &TableName,
596    sink_dirty_filter: Option<Expr>,
597) -> Result<LogicalPlan, Error> {
598    ensure!(
599        analysis.unsupported_exprs.is_empty(),
600        InvalidQuerySnafu {
601            reason: format!(
602                "UNSUPPORTED_INCREMENTAL_AGG: unsupported aggregate expressions {:?}",
603                analysis.unsupported_exprs
604            )
605        }
606    );
607
608    ensure!(
609        !analysis.merge_columns.is_empty(),
610        InvalidQuerySnafu {
611            reason:
612                "UNSUPPORTED_INCREMENTAL_AGG: aggregate query has no mergeable aggregate columns"
613                    .to_string()
614        }
615    );
616
617    ensure!(
618        !analysis.group_key_names.is_empty(),
619        InvalidQuerySnafu {
620            reason: "UNSUPPORTED_INCREMENTAL_AGG: global aggregate query is not supported"
621                .to_string()
622        }
623    );
624
625    let delta_alias = "__flow_delta";
626    let sink_alias = "__flow_sink";
627
628    let mut selected_columns = analysis.group_key_names.clone();
629    selected_columns.extend(
630        analysis
631            .merge_columns
632            .iter()
633            .map(|c| c.output_field_name.clone()),
634    );
635    let mut delta_selected_columns = selected_columns.clone();
636    delta_selected_columns.extend(analysis.literal_columns.iter().cloned());
637
638    let delta_selected_exprs = delta_selected_columns
639        .iter()
640        .cloned()
641        .map(unqualified_col)
642        .collect::<Vec<_>>();
643    let delta_selected = LogicalPlanBuilder::from(delta_plan.clone())
644        .project(delta_selected_exprs)
645        .with_context(|_| DatafusionSnafu {
646            context: "Failed to project delta plan for incremental sink merge".to_string(),
647        })?
648        .alias(delta_alias)
649        .with_context(|_| DatafusionSnafu {
650            context: "Failed to alias delta plan for incremental sink merge".to_string(),
651        })?
652        .build()
653        .with_context(|_| DatafusionSnafu {
654            context: "Failed to build projected delta plan for incremental sink merge".to_string(),
655        })?;
656
657    let table_provider = Arc::new(DfTableProviderAdapter::new(sink_table));
658    let table_source = Arc::new(DefaultTableSource::new(table_provider));
659    let sink_scan = LogicalPlan::TableScan(
660        TableScan::try_new(
661            TableReference::Full {
662                catalog: sink_table_name[0].clone().into(),
663                schema: sink_table_name[1].clone().into(),
664                table: sink_table_name[2].clone().into(),
665            },
666            table_source,
667            None,
668            vec![],
669            None,
670        )
671        .with_context(|_| DatafusionSnafu {
672            context: "Failed to build sink table scan for incremental sink merge".to_string(),
673        })?,
674    );
675
676    let sink_selected_exprs = selected_columns
677        .iter()
678        .cloned()
679        .map(unqualified_col)
680        .collect::<Vec<_>>();
681    let sink_input = if let Some(predicate) = sink_dirty_filter {
682        LogicalPlanBuilder::from(sink_scan)
683            .filter(predicate)
684            .with_context(|_| DatafusionSnafu {
685                context: "Failed to filter sink table scan for incremental sink merge".to_string(),
686            })?
687            .build()
688            .with_context(|_| DatafusionSnafu {
689                context: "Failed to build filtered sink plan for incremental sink merge"
690                    .to_string(),
691            })?
692    } else {
693        sink_scan
694    };
695
696    let sink_selected = LogicalPlanBuilder::from(sink_input)
697        .project(sink_selected_exprs)
698        .with_context(|_| DatafusionSnafu {
699            context: "Failed to project sink table scan for incremental sink merge".to_string(),
700        })?
701        .alias(sink_alias)
702        .with_context(|_| DatafusionSnafu {
703            context: "Failed to alias sink plan for incremental sink merge".to_string(),
704        })?
705        .build()
706        .with_context(|_| DatafusionSnafu {
707            context: "Failed to build projected sink plan for incremental sink merge".to_string(),
708        })?;
709
710    let join_keys = (
711        analysis
712            .group_key_names
713            .iter()
714            .cloned()
715            .map(|c| qualified_column(delta_alias, c))
716            .collect::<Vec<_>>(),
717        analysis
718            .group_key_names
719            .iter()
720            .cloned()
721            .map(|c| qualified_column(sink_alias, c))
722            .collect::<Vec<_>>(),
723    );
724
725    let joined = LogicalPlanBuilder::from(delta_selected)
726        .join_detailed(
727            sink_selected,
728            JoinType::Left,
729            join_keys,
730            None,
731            NullEquality::NullEqualsNull,
732        )
733        .with_context(|_| DatafusionSnafu {
734            context: "Failed to left join delta and sink plans for incremental sink merge"
735                .to_string(),
736        })?
737        .build()
738        .with_context(|_| DatafusionSnafu {
739            context: "Failed to build left join plan for incremental sink merge".to_string(),
740        })?;
741
742    let group_key_names = analysis.group_key_names.iter().collect::<HashSet<_>>();
743    let literal_columns = analysis.literal_columns.iter().collect::<HashSet<_>>();
744    let merge_columns = analysis
745        .merge_columns
746        .iter()
747        .map(|c| (&c.output_field_name, c))
748        .collect::<HashMap<_, _>>();
749
750    let mut projection_exprs = Vec::with_capacity(analysis.output_field_names.len());
751    for output_field_name in &analysis.output_field_names {
752        if group_key_names.contains(output_field_name)
753            || literal_columns.contains(output_field_name)
754        {
755            projection_exprs.push(
756                qualified_col(delta_alias, output_field_name.clone()).alias(output_field_name),
757            );
758        } else if let Some(merge_col) = merge_columns.get(output_field_name) {
759            projection_exprs.push(build_left_join_merge_expr(
760                delta_alias,
761                sink_alias,
762                merge_col,
763            )?);
764        } else {
765            return InvalidQuerySnafu {
766                reason: format!(
767                    "UNSUPPORTED_INCREMENTAL_AGG: output field {output_field_name} is not covered by group keys, literals, or merge columns"
768                ),
769            }
770            .fail();
771        }
772    }
773
774    LogicalPlanBuilder::from(joined)
775        .project(projection_exprs)
776        .with_context(|_| DatafusionSnafu {
777            context: "Failed to build projection merge plan for incremental sink merge".to_string(),
778        })?
779        .build()
780        .with_context(|_| DatafusionSnafu {
781            context: "Failed to finalize incremental aggregate sink merge plan".to_string(),
782        })
783}
784
785fn build_left_join_merge_expr(
786    delta_alias: &str,
787    sink_alias: &str,
788    merge_col: &IncrementalAggregateMergeColumn,
789) -> Result<Expr, Error> {
790    let left = qualified_col(delta_alias, merge_col.output_field_name.clone());
791    let right = qualified_col(sink_alias, merge_col.output_field_name.clone());
792    let merged = match merge_col.merge_op {
793        IncrementalAggregateMergeOp::Sum => when(is_null(left.clone()), right.clone())
794            .when(is_null(right.clone()), left.clone())
795            .otherwise(binary_expr(left.clone(), Operator::Plus, right.clone()))
796            .with_context(|_| DatafusionSnafu {
797                context: "Failed to build SUM merge expression".to_string(),
798            })?,
799        IncrementalAggregateMergeOp::Min => when(is_null(right.clone()), left.clone())
800            .when(left.clone().lt_eq(right.clone()), left.clone())
801            .otherwise(right.clone())
802            .with_context(|_| DatafusionSnafu {
803                context: "Failed to build MIN merge expression".to_string(),
804            })?,
805        IncrementalAggregateMergeOp::Max => when(is_null(right.clone()), left.clone())
806            .when(left.clone().gt_eq(right.clone()), left.clone())
807            .otherwise(right.clone())
808            .with_context(|_| DatafusionSnafu {
809                context: "Failed to build MAX merge expression".to_string(),
810            })?,
811        IncrementalAggregateMergeOp::BoolAnd => when(is_null(left.clone()), right.clone())
812            .when(is_null(right.clone()), left.clone())
813            .otherwise(and(left.clone(), right.clone()))
814            .with_context(|_| DatafusionSnafu {
815                context: "Failed to build BOOL_AND merge expression".to_string(),
816            })?,
817        IncrementalAggregateMergeOp::BoolOr => when(is_null(left.clone()), right.clone())
818            .when(is_null(right.clone()), left.clone())
819            .otherwise(or(left.clone(), right.clone()))
820            .with_context(|_| DatafusionSnafu {
821                context: "Failed to build BOOL_OR merge expression".to_string(),
822            })?,
823        IncrementalAggregateMergeOp::BitAnd => when(is_null(left.clone()), right.clone())
824            .when(is_null(right.clone()), left.clone())
825            .otherwise(bitwise_and(left.clone(), right.clone()))
826            .with_context(|_| DatafusionSnafu {
827                context: "Failed to build BIT_AND merge expression".to_string(),
828            })?,
829        IncrementalAggregateMergeOp::BitOr => when(is_null(left.clone()), right.clone())
830            .when(is_null(right.clone()), left.clone())
831            .otherwise(bitwise_or(left.clone(), right.clone()))
832            .with_context(|_| DatafusionSnafu {
833                context: "Failed to build BIT_OR merge expression".to_string(),
834            })?,
835        IncrementalAggregateMergeOp::BitXor => when(is_null(left.clone()), right.clone())
836            .when(is_null(right.clone()), left.clone())
837            .otherwise(bitwise_xor(left.clone(), right.clone()))
838            .with_context(|_| DatafusionSnafu {
839                context: "Failed to build BIT_XOR merge expression".to_string(),
840            })?,
841    };
842    Ok(merged.alias(merge_col.output_field_name.clone()))
843}
844
845pub async fn get_table_info_df_schema(
846    catalog_mr: CatalogManagerRef,
847    table_name: TableName,
848) -> Result<(TableRef, Arc<DFSchema>), Error> {
849    let full_table_name = table_name.clone().join(".");
850    let table = catalog_mr
851        .table(&table_name[0], &table_name[1], &table_name[2], None)
852        .await
853        .map_err(BoxedError::new)
854        .context(ExternalSnafu)?
855        .context(TableNotFoundSnafu {
856            name: &full_table_name,
857        })?;
858    let table_info = table.table_info();
859
860    let schema = table_info.meta.schema.clone();
861
862    let df_schema: Arc<DFSchema> = Arc::new(
863        schema
864            .arrow_schema()
865            .clone()
866            .try_into()
867            .with_context(|_| DatafusionSnafu {
868                context: format!(
869                    "Failed to convert arrow schema to datafusion schema, arrow_schema={:?}",
870                    schema.arrow_schema()
871                ),
872            })?,
873    );
874    Ok((table, df_schema))
875}
876
877/// Convert sql to datafusion logical plan
878/// Also support TQL (but only Eval not Explain or Analyze)
879pub async fn sql_to_df_plan(
880    query_ctx: QueryContextRef,
881    engine: QueryEngineRef,
882    sql: &str,
883    optimize: bool,
884) -> Result<LogicalPlan, Error> {
885    let stmts =
886        ParserContext::create_with_dialect(sql, query_ctx.sql_dialect(), ParseOptions::default())
887            .map_err(BoxedError::new)
888            .context(ExternalSnafu)?;
889
890    ensure!(
891        stmts.len() == 1,
892        InvalidQuerySnafu {
893            reason: format!("Expect only one statement, found {}", stmts.len())
894        }
895    );
896    let stmt = &stmts[0];
897    let query_stmt = match stmt {
898        Statement::Tql(tql) => match tql {
899            Tql::Eval(eval) => {
900                let eval = eval.clone();
901                let promql = PromQuery {
902                    start: eval.start,
903                    end: eval.end,
904                    step: eval.step,
905                    query: eval.query,
906                    lookback: eval
907                        .lookback
908                        .unwrap_or_else(|| DEFAULT_LOOKBACK_STRING.to_string()),
909                    alias: eval.alias.clone(),
910                };
911
912                QueryLanguageParser::parse_promql(&promql, &query_ctx)
913                    .map_err(BoxedError::new)
914                    .context(ExternalSnafu)?
915            }
916            _ => InvalidQuerySnafu {
917                reason: format!("TQL statement {tql:?} is not supported, expect only TQL EVAL"),
918            }
919            .fail()?,
920        },
921        _ => QueryStatement::Sql(stmt.clone()),
922    };
923    let plan = engine
924        .planner()
925        .plan(&query_stmt, query_ctx.clone())
926        .await
927        .map_err(BoxedError::new)
928        .context(ExternalSnafu)?;
929
930    let plan = if optimize {
931        apply_df_optimizer(plan, &query_ctx).await?
932    } else {
933        plan
934    };
935    Ok(plan)
936}
937
938/// Generate a plan that matches the schema of the sink table
939/// from given sql by alias and adding auto columns
940pub(crate) async fn gen_plan_with_matching_schema(
941    sql: &str,
942    query_ctx: QueryContextRef,
943    engine: QueryEngineRef,
944    sink_table_schema: SchemaRef,
945    primary_key_indices: &[usize],
946    allow_partial: bool,
947) -> Result<LogicalPlan, Error> {
948    let plan = sql_to_df_plan(query_ctx.clone(), engine.clone(), sql, false).await?;
949
950    let mut add_auto_column = ColumnMatcherRewriter::new(
951        sink_table_schema,
952        primary_key_indices.to_vec(),
953        allow_partial,
954    );
955    let plan = plan
956        .clone()
957        .rewrite(&mut add_auto_column)
958        .with_context(|_| DatafusionSnafu {
959            context: "Failed to rewrite plan".to_string(),
960        })?
961        .data;
962    Ok(plan)
963}
964
965pub fn df_plan_to_sql(plan: &LogicalPlan) -> Result<String, Error> {
966    /// A dialect that forces identifiers to be quoted when have uppercase
967    struct ForceQuoteIdentifiers;
968    impl datafusion::sql::unparser::dialect::Dialect for ForceQuoteIdentifiers {
969        fn identifier_quote_style(&self, identifier: &str) -> Option<char> {
970            if identifier.to_lowercase() != identifier {
971                Some('`')
972            } else {
973                None
974            }
975        }
976    }
977    let unparser = Unparser::new(&ForceQuoteIdentifiers);
978    // first make all column qualified
979    let sql = unparser
980        .plan_to_sql(plan)
981        .with_context(|_e| DatafusionSnafu {
982            context: format!("Failed to unparse logical plan {plan:?}"),
983        })?;
984    Ok(sql.to_string())
985}
986
987/// Helper to find the innermost group by expr in schema, return None if no group by expr
988#[derive(Debug, Clone, Default)]
989pub struct FindGroupByFinalName {
990    group_exprs: Option<HashSet<datafusion_expr::Expr>>,
991}
992
993impl FindGroupByFinalName {
994    pub fn get_group_expr_names(&self) -> Option<HashSet<String>> {
995        self.group_exprs
996            .as_ref()
997            .map(|exprs| exprs.iter().map(|expr| expr.qualified_name().1).collect())
998    }
999}
1000
/// Collects group-by expressions while walking a plan.
///
/// `f_down` records them from every `Aggregate` / `Distinct` node it meets, with a later node
/// overwriting an earlier one. `f_up` renames them through the aliases of the projections above.
impl TreeNodeVisitor<'_> for FindGroupByFinalName {
    type Node = LogicalPlan;

    /// Records group-by expressions from `Aggregate` and `Distinct` nodes:
    /// - `Aggregate`: its `group_expr` list;
    /// - `Distinct::All` over a `TableScan`: one column per field of the scan's projected schema,
    ///   keeping each field's qualifier;
    /// - `Distinct::All` over any other input: the input node's own `expressions()`;
    /// - `Distinct::On`: its `on_expr` list.
    ///
    /// Each match replaces the previously recorded set. The traversal always continues.
    fn f_down(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
        if let LogicalPlan::Aggregate(aggregate) = node {
            self.group_exprs = Some(aggregate.group_expr.iter().cloned().collect());
            debug!(
                "FindGroupByFinalName: Get Group by exprs from Aggregate: {:?}",
                self.group_exprs
            );
        } else if let LogicalPlan::Distinct(distinct) = node {
            debug!("FindGroupByFinalName: Distinct: {}", node);
            match distinct {
                Distinct::All(input) => {
                    if let LogicalPlan::TableScan(table_scan) = &**input {
                        // Build one column per projected field, keeping its table qualifier.
                        let len = table_scan.projected_schema.fields().len();
                        let columns = (0..len)
                            .map(|f| {
                                let (qualifier, field) =
                                    table_scan.projected_schema.qualified_field(f);
                                datafusion_common::Column::new(qualifier.cloned(), field.name())
                            })
                            .map(datafusion_expr::Expr::Column);
                        self.group_exprs = Some(columns.collect());
                    } else {
                        // NOTE(review): `expressions()` returns the input node's own expressions,
                        // which equal its output columns for a Projection but not e.g. for a
                        // Filter (its predicate); confirm only projection-like inputs reach here.
                        self.group_exprs = Some(input.expressions().iter().cloned().collect())
                    }
                }
                Distinct::On(distinct_on) => {
                    self.group_exprs = Some(distinct_on.on_expr.iter().cloned().collect())
                }
            }
            debug!(
                "FindGroupByFinalName: Get Group by exprs from Distinct: {:?}",
                self.group_exprs
            );
        }

        Ok(TreeNodeRecursion::Continue)
    }

    /// Renames the recorded group-by expressions through projection aliases on the way up.
    ///
    /// For each `Alias` in a `Projection`, the first recorded group expression whose
    /// `name_for_alias()` equals that of the aliased inner expression is replaced by the whole
    /// alias expression. "First" follows `HashSet` iteration order.
    ///
    /// If nothing has been recorded yet, a projection with at least one expression returns
    /// early (without the debug log). Errors from `name_for_alias()` are propagated.
    fn f_up(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
        if let LogicalPlan::Projection(projection) = node {
            for expr in &projection.expr {
                let Some(group_exprs) = &mut self.group_exprs else {
                    return Ok(TreeNodeRecursion::Continue);
                };
                if let datafusion_expr::Expr::Alias(alias) = expr {
                    // The projection renames an expression: if it is one of the recorded group
                    // exprs, swap it for the alias so the final output name is tracked.
                    let mut new_group_exprs = group_exprs.clone();
                    for group_expr in group_exprs.iter() {
                        if group_expr.name_for_alias()? == alias.expr.name_for_alias()? {
                            new_group_exprs.remove(group_expr);
                            new_group_exprs.insert(expr.clone());
                            break;
                        }
                    }
                    *group_exprs = new_group_exprs;
                }
            }
        }
        debug!("Aliased group by exprs: {:?}", self.group_exprs);
        Ok(TreeNodeRecursion::Continue)
    }
}
1068
/// Rewrites the outermost projection of a flow plan so that its output matches the sink table
/// schema.
///
/// When the flow outputs fewer columns than the sink table, it may append auto columns to the
/// final `SELECT`:
/// - an extra timestamp column such as `update_at`, filled with `now()`. The name does not
///   matter; it only has to be a timestamp column.
/// - the `AUTO_CREATED_PLACEHOLDER_TS_COL` time index (`__ts_placeholder`), filled with a
///   timestamp `0`. This exact name is required.
///
/// It also aliases flow output columns to sink column names when needed. With `allow_partial`,
/// the flow output may cover only a subset of the sink columns.
#[derive(Debug)]
pub struct ColumnMatcherRewriter {
    /// Sink table schema that the plan output is matched against.
    pub schema: SchemaRef,
    /// Set once the outermost projection has been rewritten; later nodes are left untouched.
    pub is_rewritten: bool,
    /// Indices into `schema.column_schemas()` of the primary-key columns. Only used in partial
    /// mode, where these columns must be present in the flow output.
    pub primary_key_indices: Vec<usize>,
    /// Whether the flow output may cover only a subset of the sink columns. The error messages
    /// of this mode refer to `merge_mode=last_non_null`.
    pub allow_partial: bool,
}
1082
1083impl ColumnMatcherRewriter {
1084    pub fn new(schema: SchemaRef, primary_key_indices: Vec<usize>, allow_partial: bool) -> Self {
1085        Self {
1086            schema,
1087            is_rewritten: false,
1088            primary_key_indices,
1089            allow_partial,
1090        }
1091    }
1092
1093    /// modify the exprs in place so that it matches the schema and some auto columns are added
1094    fn modify_project_exprs(
1095        &mut self,
1096        mut exprs: Vec<Expr>,
1097        input_schema: &DFSchema,
1098    ) -> DfResult<Vec<Expr>> {
1099        if self.allow_partial {
1100            return self.modify_project_exprs_with_partial(exprs);
1101        }
1102
1103        let original_exprs = exprs.clone();
1104
1105        let all_names = self
1106            .schema
1107            .column_schemas()
1108            .iter()
1109            .map(|c| c.name.clone())
1110            .collect::<BTreeSet<_>>();
1111        // add columns if have different column count
1112        let query_col_cnt = exprs.len();
1113        let table_col_cnt = self.schema.column_schemas().len();
1114        debug!("query_col_cnt={query_col_cnt}, table_col_cnt={table_col_cnt}");
1115
1116        let placeholder_ts_expr =
1117            datafusion::logical_expr::lit(ScalarValue::TimestampMillisecond(Some(0), None))
1118                .alias(AUTO_CREATED_PLACEHOLDER_TS_COL);
1119
1120        if query_col_cnt == table_col_cnt {
1121            // still need to add alias, see below
1122        } else if query_col_cnt + 1 == table_col_cnt {
1123            let last_col_schema = self.schema.column_schemas().last().unwrap();
1124
1125            // if time index column is auto created add it
1126            if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
1127                && self.schema.timestamp_index() == Some(table_col_cnt - 1)
1128            {
1129                exprs.push(placeholder_ts_expr);
1130            } else if last_col_schema.data_type.is_timestamp() {
1131                // is the update at column
1132                exprs.push(datafusion::prelude::now().alias(&last_col_schema.name));
1133            } else {
1134                return Err(DataFusionError::Plan(format_flow_sink_schema_mismatch(
1135                    &original_exprs,
1136                    self.schema.as_ref(),
1137                )));
1138            }
1139        } else if query_col_cnt + 2 == table_col_cnt {
1140            let mut col_iter = self.schema.column_schemas().iter().rev();
1141            let last_col_schema = col_iter.next().unwrap();
1142            let second_last_col_schema = col_iter.next().unwrap();
1143            if second_last_col_schema.data_type.is_timestamp() {
1144                exprs.push(datafusion::prelude::now().alias(&second_last_col_schema.name));
1145            } else {
1146                return Err(DataFusionError::Plan(format!(
1147                    "Expect the second last column in the table to be timestamp column, found column {} with type {:?}",
1148                    second_last_col_schema.name, second_last_col_schema.data_type
1149                )));
1150            }
1151
1152            if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
1153                && self.schema.timestamp_index() == Some(table_col_cnt - 1)
1154            {
1155                exprs.push(placeholder_ts_expr);
1156            } else {
1157                return Err(DataFusionError::Plan(format!(
1158                    "Expect timestamp column {}, found {:?}",
1159                    AUTO_CREATED_PLACEHOLDER_TS_COL, last_col_schema
1160                )));
1161            }
1162        } else {
1163            return Err(DataFusionError::Plan(format_flow_sink_schema_mismatch(
1164                &original_exprs,
1165                self.schema.as_ref(),
1166            )));
1167        }
1168
1169        self.match_extra_output_columns(exprs, input_schema, &original_exprs, &all_names)
1170    }
1171
1172    /// Match flow output columns whose names are not in the sink schema by the same position only.
1173    ///
1174    /// This keeps the legacy "omit output aliases and map by position" behavior, but only when the
1175    /// sink column at the same index is actually missing from the flow output. If the extra output
1176    /// would be aliased to a sink column that already exists elsewhere, report a schema mismatch
1177    /// instead of guessing another sink column by type.
1178    ///
1179    /// In particular, this intentionally rejects cross-position remaps like
1180    /// `record_time_window2 -> record_time_window`: they are easy to confuse with real schema
1181    /// mismatches and should be fixed by giving the flow output the sink column name explicitly.
1182    fn match_extra_output_columns(
1183        &self,
1184        mut exprs: Vec<Expr>,
1185        input_schema: &DFSchema,
1186        original_exprs: &[Expr],
1187        all_names: &BTreeSet<String>,
1188    ) -> DfResult<Vec<Expr>> {
1189        let mut output_names = exprs
1190            .iter()
1191            .map(|expr| expr.qualified_name().1)
1192            .collect::<Vec<_>>();
1193        let output_name_set = output_names.iter().cloned().collect::<BTreeSet<_>>();
1194        let extra_expr_indices = output_names
1195            .iter()
1196            .enumerate()
1197            .filter_map(|(idx, name)| (!all_names.contains(name)).then_some(idx))
1198            .collect::<Vec<_>>();
1199        let missing_sink_indices = self
1200            .schema
1201            .column_schemas()
1202            .iter()
1203            .enumerate()
1204            .filter_map(|(idx, column)| (!output_name_set.contains(&column.name)).then_some(idx))
1205            .collect::<Vec<_>>();
1206
1207        if extra_expr_indices.is_empty() && missing_sink_indices.is_empty() {
1208            return Ok(exprs);
1209        }
1210
1211        if extra_expr_indices.len() != missing_sink_indices.len() {
1212            return Err(DataFusionError::Plan(format_flow_sink_schema_mismatch(
1213                original_exprs,
1214                self.schema.as_ref(),
1215            )));
1216        }
1217
1218        let mut positional_matches = Vec::new();
1219        for expr_idx in extra_expr_indices {
1220            if !missing_sink_indices.contains(&expr_idx) {
1221                return Err(DataFusionError::Plan(format_flow_sink_schema_mismatch(
1222                    original_exprs,
1223                    self.schema.as_ref(),
1224                )));
1225            }
1226
1227            let target_col_schema = &self.schema.column_schemas()[expr_idx];
1228            let expr_type =
1229                ConcreteDataType::from_arrow_type(&exprs[expr_idx].get_type(input_schema)?);
1230            if is_obviously_incompatible_positional_match(&expr_type, &target_col_schema.data_type)
1231            {
1232                return Err(DataFusionError::Plan(format!(
1233                    "Cannot match flow output column '{}' to sink column '{}' by position: incompatible data types, flow output type is {:?}, sink column type is {:?}. {}",
1234                    output_names[expr_idx],
1235                    target_col_schema.name,
1236                    expr_type,
1237                    target_col_schema.data_type,
1238                    format_flow_sink_schema_mismatch(original_exprs, self.schema.as_ref())
1239                )));
1240            }
1241
1242            let target_name = target_col_schema.name.clone();
1243            positional_matches.push(format!(
1244                "{} -> {} (flow output type: {:?}, sink column type: {:?})",
1245                output_names[expr_idx], target_name, expr_type, target_col_schema.data_type
1246            ));
1247            exprs[expr_idx] = exprs[expr_idx].clone().alias(target_name.clone());
1248            output_names[expr_idx] = target_name;
1249        }
1250
1251        if !positional_matches.is_empty() {
1252            debug!(
1253                "Matched flow output columns to sink columns by position: {:?}",
1254                positional_matches
1255            );
1256        }
1257
1258        let duplicated_output_names = duplicate_names(&output_names);
1259        if !duplicated_output_names.is_empty() {
1260            return Err(DataFusionError::Plan(format!(
1261                "Flow output schema contains duplicate column(s) after schema matching {:?}. {}",
1262                duplicated_output_names,
1263                format_flow_sink_schema_mismatch(&exprs, self.schema.as_ref())
1264            )));
1265        }
1266
1267        Ok(exprs)
1268    }
1269
1270    fn modify_project_exprs_with_partial(&mut self, exprs: Vec<Expr>) -> DfResult<Vec<Expr>> {
1271        let table_col_cnt = self.schema.column_schemas().len();
1272        let query_col_cnt = exprs.len();
1273
1274        if query_col_cnt > table_col_cnt {
1275            return Err(DataFusionError::Plan(format_flow_sink_schema_mismatch(
1276                &exprs,
1277                self.schema.as_ref(),
1278            )));
1279        }
1280
1281        let name_to_expr: HashMap<String, Expr> = exprs
1282            .clone()
1283            .into_iter()
1284            .map(|e| (e.qualified_name().1, e))
1285            .collect();
1286
1287        let required_columns = self.required_columns_for_partial();
1288        let missing: Vec<_> = required_columns
1289            .iter()
1290            .filter(|name| !name_to_expr.contains_key(*name))
1291            .cloned()
1292            .collect();
1293        if !missing.is_empty() {
1294            return Err(DataFusionError::Plan(format!(
1295                "Column(s) {:?} required by sink table are missing from flow output when merge_mode=last_non_null. {}",
1296                missing,
1297                format_flow_sink_schema_mismatch(&exprs, self.schema.as_ref())
1298            )));
1299        }
1300
1301        let placeholder_ts_expr =
1302            datafusion::logical_expr::lit(ScalarValue::TimestampMillisecond(Some(0), None))
1303                .alias(AUTO_CREATED_PLACEHOLDER_TS_COL);
1304
1305        let timestamp_index = self.schema.timestamp_index();
1306        let mut remap = name_to_expr;
1307        let mut new_exprs = Vec::with_capacity(table_col_cnt);
1308
1309        for (idx, col_schema) in self.schema.column_schemas().iter().enumerate() {
1310            let col_name = col_schema.name.clone();
1311            if let Some(expr) = remap.remove(&col_name) {
1312                let expr = if expr.qualified_name().1 == col_name {
1313                    expr
1314                } else {
1315                    expr.alias(col_name.clone())
1316                };
1317                new_exprs.push(expr);
1318                continue;
1319            }
1320
1321            if col_name == AUTO_CREATED_PLACEHOLDER_TS_COL && timestamp_index == Some(idx) {
1322                new_exprs.push(placeholder_ts_expr.clone());
1323                continue;
1324            }
1325
1326            if col_name == AUTO_CREATED_UPDATE_AT_TS_COL && col_schema.data_type.is_timestamp() {
1327                new_exprs.push(datafusion::prelude::now().alias(&col_name));
1328                continue;
1329            }
1330
1331            new_exprs.push(Self::null_expr(col_schema));
1332        }
1333
1334        if !remap.is_empty() {
1335            let extra: Vec<_> = remap.keys().cloned().collect();
1336            return Err(DataFusionError::Plan(format!(
1337                "Flow output has extra column(s) {:?} not found in sink schema when merge_mode=last_non_null. {}",
1338                extra,
1339                format_flow_sink_schema_mismatch(&exprs, self.schema.as_ref())
1340            )));
1341        }
1342
1343        Ok(new_exprs)
1344    }
1345
1346    fn null_expr(col_schema: &ColumnSchema) -> Expr {
1347        Expr::Literal(ScalarValue::Null, None).alias(col_schema.name.clone())
1348    }
1349
1350    fn required_columns_for_partial(&self) -> HashSet<String> {
1351        let mut required = HashSet::new();
1352        for idx in &self.primary_key_indices {
1353            if let Some(col) = self.schema.column_schemas().get(*idx) {
1354                required.insert(col.name.clone());
1355            }
1356        }
1357
1358        if let Some(ts_idx) = self.schema.timestamp_index()
1359            && let Some(col) = self.schema.column_schemas().get(ts_idx)
1360            && col.name != AUTO_CREATED_PLACEHOLDER_TS_COL
1361        {
1362            required.insert(col.name.clone());
1363        }
1364
1365        required
1366    }
1367}
1368
1369fn is_obviously_incompatible_positional_match(
1370    expr_type: &ConcreteDataType,
1371    sink_type: &ConcreteDataType,
1372) -> bool {
1373    // This is a coarse type-family guard for legacy positional aliasing, not a strict type equality
1374    // check. For example, numeric width/sign differences are allowed here and left to downstream
1375    // coercion, and untyped NULL can be coerced to any target type. Clearly different families such
1376    // as timestamp vs string are rejected early.
1377    if expr_type.is_null() || expr_type == sink_type {
1378        return false;
1379    }
1380
1381    expr_type.is_timestamp() != sink_type.is_timestamp()
1382        || expr_type.is_string() != sink_type.is_string()
1383        || expr_type.is_boolean() != sink_type.is_boolean()
1384        || expr_type.is_json() != sink_type.is_json()
1385        || expr_type.is_vector() != sink_type.is_vector()
1386}
1387
/// Returns every name that occurs more than once in `names`, each reported once, in ascending
/// order.
fn duplicate_names(names: &[String]) -> Vec<String> {
    let mut seen = HashSet::with_capacity(names.len());
    names
        .iter()
        .map(String::as_str)
        // `insert` returns false for a name already seen, i.e. a repeat.
        .filter(|name| !seen.insert(*name))
        .collect::<BTreeSet<&str>>()
        .into_iter()
        .map(String::from)
        .collect()
}
1398
1399fn format_flow_sink_schema_mismatch(
1400    query_exprs: &[Expr],
1401    sink_schema: &datatypes::schema::Schema,
1402) -> String {
1403    let flow_output_columns = query_exprs
1404        .iter()
1405        .map(|expr| expr.qualified_name().1)
1406        .collect::<Vec<_>>();
1407    let sink_table_columns = sink_schema
1408        .column_schemas()
1409        .iter()
1410        .map(|col| col.name.clone())
1411        .collect::<Vec<_>>();
1412
1413    let flow_output_set = flow_output_columns.iter().cloned().collect::<HashSet<_>>();
1414    let sink_table_set = sink_table_columns.iter().cloned().collect::<HashSet<_>>();
1415
1416    let mut extra_flow_columns = flow_output_columns
1417        .iter()
1418        .filter(|name| !sink_table_set.contains(*name))
1419        .cloned()
1420        .collect::<Vec<_>>();
1421    extra_flow_columns.sort();
1422    extra_flow_columns.dedup();
1423
1424    let mut missing_sink_columns = sink_table_columns
1425        .iter()
1426        .filter(|name| !flow_output_set.contains(*name))
1427        .cloned()
1428        .collect::<Vec<_>>();
1429    missing_sink_columns.sort();
1430    missing_sink_columns.dedup();
1431
1432    format!(
1433        "Flow output schema does not match sink table schema: found {} flow output columns and {} sink table columns. flow output columns: {:?}, sink table columns: {:?}, extra flow columns not in sink: {:?}, missing sink columns from flow output: {:?}",
1434        flow_output_columns.len(),
1435        sink_table_columns.len(),
1436        flow_output_columns,
1437        sink_table_columns,
1438        extra_flow_columns,
1439        missing_sink_columns
1440    )
1441}
1442
1443impl TreeNodeRewriter for ColumnMatcherRewriter {
1444    type Node = LogicalPlan;
1445    fn f_down(&mut self, mut node: Self::Node) -> DfResult<Transformed<Self::Node>> {
1446        if self.is_rewritten {
1447            return Ok(Transformed::no(node));
1448        }
1449
1450        // if is distinct all, wrap it in a projection
1451        if let LogicalPlan::Distinct(Distinct::All(_)) = &node {
1452            let mut exprs = vec![];
1453
1454            for field in node.schema().fields().iter() {
1455                exprs.push(Expr::Column(datafusion::common::Column::new_unqualified(
1456                    field.name(),
1457                )));
1458            }
1459
1460            let projection =
1461                LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(node.clone()))?);
1462
1463            node = projection;
1464        }
1465        // handle table_scan by wrap it in a projection
1466        else if let LogicalPlan::TableScan(table_scan) = node {
1467            let mut exprs = vec![];
1468
1469            for field in table_scan.projected_schema.fields().iter() {
1470                exprs.push(Expr::Column(datafusion::common::Column::new(
1471                    Some(table_scan.table_name.clone()),
1472                    field.name(),
1473                )));
1474            }
1475
1476            let projection = LogicalPlan::Projection(Projection::try_new(
1477                exprs,
1478                Arc::new(LogicalPlan::TableScan(table_scan)),
1479            )?);
1480
1481            node = projection;
1482        }
1483
1484        // only do rewrite if found the outermost projection
1485        // if the outermost node is projection, can rewrite the exprs
1486        // if not, wrap it in a projection
1487        if let LogicalPlan::Projection(project) = &node {
1488            let exprs = project.expr.clone();
1489            let exprs = self.modify_project_exprs(exprs, project.input.schema())?;
1490
1491            self.is_rewritten = true;
1492            let new_plan =
1493                node.with_new_exprs(exprs, node.inputs().into_iter().cloned().collect())?;
1494            Ok(Transformed::yes(new_plan))
1495        } else {
1496            // wrap the logical plan in a projection
1497            let mut exprs = vec![];
1498            for field in node.schema().fields().iter() {
1499                exprs.push(Expr::Column(datafusion::common::Column::new_unqualified(
1500                    field.name(),
1501                )));
1502            }
1503            let exprs = self.modify_project_exprs(exprs, node.schema())?;
1504            self.is_rewritten = true;
1505            let new_plan =
1506                LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(node.clone()))?);
1507            Ok(Transformed::yes(new_plan))
1508        }
1509    }
1510
1511    /// We might add new columns, so we need to recompute the schema
1512    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
1513        node.recompute_schema().map(Transformed::yes)
1514    }
1515}
1516
1517/// Find out the `Filter` Node corresponding to innermost(deepest) `WHERE` and add a new filter expr to it
1518#[derive(Debug)]
1519pub struct AddFilterRewriter {
1520    extra_filter: Expr,
1521    is_rewritten: bool,
1522}
1523
1524impl AddFilterRewriter {
1525    pub fn new(filter: Expr) -> Self {
1526        Self {
1527            extra_filter: filter,
1528            is_rewritten: false,
1529        }
1530    }
1531}
1532
1533impl TreeNodeRewriter for AddFilterRewriter {
1534    type Node = LogicalPlan;
1535    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
1536        if self.is_rewritten {
1537            return Ok(Transformed::no(node));
1538        }
1539        match node {
1540            LogicalPlan::Filter(mut filter) => {
1541                filter.predicate = filter.predicate.and(self.extra_filter.clone());
1542                self.is_rewritten = true;
1543                Ok(Transformed::yes(LogicalPlan::Filter(filter)))
1544            }
1545            LogicalPlan::TableScan(_) => {
1546                // add a new filter
1547                let filter =
1548                    datafusion_expr::Filter::try_new(self.extra_filter.clone(), Arc::new(node))?;
1549                self.is_rewritten = true;
1550                Ok(Transformed::yes(LogicalPlan::Filter(filter)))
1551            }
1552            _ => Ok(Transformed::no(node)),
1553        }
1554    }
1555}