diff --git a/src/query/src/dist_plan/analyzer.rs b/src/query/src/dist_plan/analyzer.rs index 45fef00a1b..dd929f1427 100644 --- a/src/query/src/dist_plan/analyzer.rs +++ b/src/query/src/dist_plan/analyzer.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::{HashMap, HashSet}; +use std::collections::{BTreeSet, HashSet}; use std::sync::Arc; use common_telemetry::debug; @@ -32,6 +32,7 @@ use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan}; use table::metadata::TableType; use table::table::adapter::DfTableProviderAdapter; +use crate::dist_plan::analyzer::utils::aliased_columns_for; use crate::dist_plan::commutativity::{ Categorizer, Commutativity, partial_commutative_transformer, }; @@ -46,7 +47,7 @@ mod test; mod fallback; mod utils; -pub(crate) use utils::{AliasMapping, AliasTracker}; +pub(crate) use utils::AliasMapping; #[derive(Debug, Clone)] pub struct DistPlannerOptions { @@ -229,8 +230,7 @@ struct PlanRewriter { stage: Vec, status: RewriterStatus, /// Partition columns of the table in current pass - partition_cols: Option>, - alias_tracker: Option, + partition_cols: Option, /// use stack count as scope to determine column requirements is needed or not /// i.e for a logical plan like: /// ```ignore @@ -311,7 +311,7 @@ impl PlanRewriter { } if self.expand_on_next_part_cond_trans_commutative { - let comm = Categorizer::check_plan(plan, self.get_aliased_partition_columns()); + let comm = Categorizer::check_plan(plan, self.partition_cols.clone()); match comm { Commutativity::PartialCommutative => { // a small difference is that for partial commutative, we still need to @@ -333,7 +333,7 @@ impl PlanRewriter { } } - match Categorizer::check_plan(plan, self.get_aliased_partition_columns()) { + match Categorizer::check_plan(plan, self.partition_cols.clone()) { Commutativity::Commutative => {} Commutativity::PartialCommutative => { if let Some(plan) = partial_commutative_transformer(plan) { @@ -427,49 +427,31 @@ impl PlanRewriter { self.status = RewriterStatus::Unexpanded; } - /// Maybe update alias for original table columns in the plan - fn maybe_update_alias(&mut self, node: &LogicalPlan) { - if let Some(alias_tracker) = &mut self.alias_tracker { - alias_tracker.update_alias(node); - debug!( - "Current partition columns are: {:?}", - self.get_aliased_partition_columns() - ); - } else if let LogicalPlan::TableScan(table_scan) = node { - self.alias_tracker = AliasTracker::new(table_scan); - debug!( - "Initialize partition columns: {:?} with table={}", - self.get_aliased_partition_columns(), - table_scan.table_name - ); - } - } + fn maybe_set_partitions(&mut self, plan: &LogicalPlan) -> DfResult<()> { + if let Some(part_cols) = &mut self.partition_cols { + // update partition alias + let child = plan.inputs().first().cloned().ok_or_else(|| { + datafusion_common::DataFusionError::Internal(format!( + "PlanRewriter: maybe_set_partitions: plan has no child: {plan}" + )) + })?; - fn get_aliased_partition_columns(&self) -> Option { - if let Some(part_cols) = self.partition_cols.as_ref() { - let Some(alias_tracker) = &self.alias_tracker else { - // no alias tracker meaning no table scan encountered - return None; - }; - let mut aliased = HashMap::new(); - for part_col in part_cols { - let all_alias = alias_tracker - .get_all_alias_for_col(part_col) - .cloned() - .unwrap_or_default(); - - aliased.insert(part_col.clone(), all_alias); + for (_col_name, alias_set) in part_cols.iter_mut() { + let 
aliased_cols = aliased_columns_for( + &alias_set.clone().into_iter().collect(), + plan, + Some(child), + )?; + *alias_set = aliased_cols.into_values().flatten().collect(); } - Some(aliased) - } else { - None - } - } - fn maybe_set_partitions(&mut self, plan: &LogicalPlan) { - if self.partition_cols.is_some() { - // only need to set once - return; + debug!( + "PlanRewriter: maybe_set_partitions: updated partition columns: {:?} at plan: {}", + part_cols, + plan.display() + ); + + return Ok(()); } if let LogicalPlan::TableScan(table_scan) = plan @@ -507,9 +489,31 @@ impl PlanRewriter { // This helps with distinguishing between non-partitioned table and partitioned table with all phy part cols not in logical table partition_cols.push("__OTHER_PHYSICAL_PART_COLS_PLACEHOLDER__".to_string()); } - self.partition_cols = Some(partition_cols); + self.partition_cols = Some( + partition_cols + .into_iter() + .map(|c| { + let index = + plan.schema().index_of_column_by_name(None, &c).ok_or_else(|| { + datafusion_common::DataFusionError::Internal( + format!( + "PlanRewriter: maybe_set_partitions: column {c} not found in schema of plan: {plan}" + ), + ) + })?; + let column = plan.schema().columns().get(index).cloned().ok_or_else(|| { + datafusion_common::DataFusionError::Internal(format!( + "PlanRewriter: maybe_set_partitions: column index {index} out of bounds in schema of plan: {plan}" + )) + })?; + Ok((c.clone(), BTreeSet::from([column]))) + }) + .collect::>()?, + ); } } + + Ok(()) } /// pop one stack item and reduce the level by 1 @@ -537,6 +541,11 @@ impl PlanRewriter { "PlanRewriter: after enforced column requirements with rewriter: {rewriter:?} for node:\n{on_node}" ); + debug!( + "PlanRewriter: expand on node: {on_node} with partition col alias mapping: {:?}", + self.partition_cols + ); + // add merge scan as the new root let mut node = MergeScanLogicalPlan::new( on_node, @@ -677,7 +686,6 @@ impl TreeNodeRewriter for PlanRewriter { self.stage.clear(); self.set_unexpanded(); self.partition_cols = None; - self.alias_tracker = None; Ok(Transformed::no(node)) } @@ -698,9 +706,7 @@ impl TreeNodeRewriter for PlanRewriter { return Ok(Transformed::no(node)); } - self.maybe_set_partitions(&node); - - self.maybe_update_alias(&node); + self.maybe_set_partitions(&node)?; let Some(parent) = self.get_parent() else { debug!("Plan Rewriter: expand now for no parent found for node: {node}"); diff --git a/src/query/src/dist_plan/analyzer/fallback.rs b/src/query/src/dist_plan/analyzer/fallback.rs index 2a6e098caa..327296f67c 100644 --- a/src/query/src/dist_plan/analyzer/fallback.rs +++ b/src/query/src/dist_plan/analyzer/fallback.rs @@ -17,14 +17,18 @@ //! This is a temporary solution, and will be removed once we have a more robust plan rewriter //! 
+use std::collections::BTreeSet; + use common_telemetry::debug; use datafusion::datasource::DefaultTableSource; +use datafusion_common::Result as DfResult; use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_expr::LogicalPlan; use table::metadata::TableType; use table::table::adapter::DfTableProviderAdapter; use crate::dist_plan::MergeScanLogicalPlan; +use crate::dist_plan::analyzer::AliasMapping; /// FallbackPlanRewriter is a plan rewriter that will only push down table scan node /// This is used when `PlanRewriter` produce errors when trying to rewrite the plan @@ -38,9 +42,9 @@ impl TreeNodeRewriter for FallbackPlanRewriter { fn f_down( &mut self, - node: Self::Node, - ) -> datafusion_common::Result> { - if let LogicalPlan::TableScan(table_scan) = &node { + plan: Self::Node, + ) -> DfResult> { + if let LogicalPlan::TableScan(table_scan) = &plan { let partition_cols = if let Some(source) = table_scan .source .as_any() @@ -63,7 +67,25 @@ impl TreeNodeRewriter for FallbackPlanRewriter { "FallbackPlanRewriter: table {} has partition columns: {:?}", info.name, partition_cols ); - Some(partition_cols) + Some(partition_cols + .into_iter() + .map(|c| { + let index = + plan.schema().index_of_column_by_name(None, &c).ok_or_else(|| { + datafusion_common::DataFusionError::Internal( + format!( + "PlanRewriter: maybe_set_partitions: column {c} not found in schema of plan: {plan}" + ), + ) + })?; + let column = plan.schema().columns().get(index).cloned().ok_or_else(|| { + datafusion_common::DataFusionError::Internal(format!( + "PlanRewriter: maybe_set_partitions: column index {index} out of bounds in schema of plan: {plan}" + )) + })?; + Ok((c.clone(), BTreeSet::from([column]))) + }) + .collect::>()?) } else { None } @@ -74,7 +96,7 @@ impl TreeNodeRewriter for FallbackPlanRewriter { None }; let node = MergeScanLogicalPlan::new( - node, + plan, false, // at this stage, the partition cols should be set // treat it as non-partitioned if None @@ -83,7 +105,7 @@ impl TreeNodeRewriter for FallbackPlanRewriter { .into_logical_plan(); Ok(Transformed::yes(node)) } else { - Ok(Transformed::no(node)) + Ok(Transformed::no(plan)) } } } diff --git a/src/query/src/dist_plan/analyzer/test.rs b/src/query/src/dist_plan/analyzer/test.rs index 2dafc91395..888cd48432 100644 --- a/src/query/src/dist_plan/analyzer/test.rs +++ b/src/query/src/dist_plan/analyzer/test.rs @@ -156,7 +156,7 @@ impl Stream for EmptyStream { fn expand_proj_sort_proj() { // use logging for better debugging init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); + let test_table = TestTable::table_with_name(0, "t".to_string()); let table_source = Arc::new(DefaultTableSource::new(Arc::new( DfTableProviderAdapter::new(test_table), ))); @@ -203,7 +203,7 @@ fn expand_proj_sort_proj() { fn expand_sort_limit() { // use logging for better debugging init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); + let test_table = TestTable::table_with_name(0, "t".to_string()); let table_source = Arc::new(DefaultTableSource::new(Arc::new( DfTableProviderAdapter::new(test_table), ))); @@ -237,7 +237,7 @@ fn expand_sort_limit() { fn expand_sort_alias_limit() { // use logging for better debugging init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); + let test_table = TestTable::table_with_name(0, "t".to_string()); let table_source = Arc::new(DefaultTableSource::new(Arc::new( 
DfTableProviderAdapter::new(test_table), ))); @@ -276,7 +276,7 @@ fn expand_sort_alias_limit() { fn expand_sort_alias_conflict_limit() { // use logging for better debugging init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); + let test_table = TestTable::table_with_name(0, "t".to_string()); let table_source = Arc::new(DefaultTableSource::new(Arc::new( DfTableProviderAdapter::new(test_table), ))); @@ -318,7 +318,7 @@ fn expand_sort_alias_conflict_limit() { fn expand_sort_alias_conflict_but_not_really_limit() { // use logging for better debugging init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); + let test_table = TestTable::table_with_name(0, "t".to_string()); let table_source = Arc::new(DefaultTableSource::new(Arc::new( DfTableProviderAdapter::new(test_table), ))); @@ -358,7 +358,7 @@ fn expand_sort_alias_conflict_but_not_really_limit() { fn expand_limit_sort() { // use logging for better debugging init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); + let test_table = TestTable::table_with_name(0, "t".to_string()); let table_source = Arc::new(DefaultTableSource::new(Arc::new( DfTableProviderAdapter::new(test_table), ))); @@ -391,7 +391,7 @@ fn expand_limit_sort() { fn expand_sort_limit_sort() { // use logging for better debugging init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); + let test_table = TestTable::table_with_name(0, "t".to_string()); let table_source = Arc::new(DefaultTableSource::new(Arc::new( DfTableProviderAdapter::new(test_table), ))); @@ -438,7 +438,7 @@ fn expand_sort_limit_sort() { fn expand_proj_step_aggr() { // use logging for better debugging init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); + let test_table = TestTable::table_with_name(0, "t".to_string()); let table_source = Arc::new(DefaultTableSource::new(Arc::new( DfTableProviderAdapter::new(test_table), ))); @@ -473,7 +473,7 @@ fn expand_proj_step_aggr() { fn expand_proj_alias_fake_part_col_aggr() { // use logging for better debugging init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); + let test_table = TestTable::table_with_name(0, "t".to_string()); let table_source = Arc::new(DefaultTableSource::new(Arc::new( DfTableProviderAdapter::new(test_table), ))); @@ -517,7 +517,7 @@ fn expand_proj_alias_fake_part_col_aggr() { fn expand_proj_alias_aliased_part_col_aggr() { // use logging for better debugging init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); + let test_table = TestTable::table_with_name(0, "t".to_string()); let table_source = Arc::new(DefaultTableSource::new(Arc::new( DfTableProviderAdapter::new(test_table), ))); @@ -563,7 +563,7 @@ fn expand_proj_alias_aliased_part_col_aggr() { fn expand_part_col_aggr_step_aggr() { // use logging for better debugging init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); + let test_table = TestTable::table_with_name(0, "t".to_string()); let table_source = Arc::new(DefaultTableSource::new(Arc::new( DfTableProviderAdapter::new(test_table), ))); @@ -596,7 +596,7 @@ fn expand_part_col_aggr_step_aggr() { fn expand_step_aggr_step_aggr() { // use logging for better debugging init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); + let test_table = 
TestTable::table_with_name(0, "t".to_string()); let table_source = Arc::new(DefaultTableSource::new(Arc::new( DfTableProviderAdapter::new(test_table), ))); @@ -629,7 +629,7 @@ fn expand_step_aggr_step_aggr() { fn expand_part_col_aggr_part_col_aggr() { // use logging for better debugging init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); + let test_table = TestTable::table_with_name(0, "t".to_string()); let table_source = Arc::new(DefaultTableSource::new(Arc::new( DfTableProviderAdapter::new(test_table), ))); @@ -673,7 +673,7 @@ fn expand_part_col_aggr_part_col_aggr() { fn expand_step_aggr_proj() { // use logging for better debugging init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); + let test_table = TestTable::table_with_name(0, "t".to_string()); let table_source = Arc::new(DefaultTableSource::new(Arc::new( DfTableProviderAdapter::new(test_table), ))); @@ -709,7 +709,7 @@ fn expand_step_aggr_proj() { fn expand_proj_sort_step_aggr_limit() { // use logging for better debugging init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); + let test_table = TestTable::table_with_name(0, "t".to_string()); let table_source = Arc::new(DefaultTableSource::new(Arc::new( DfTableProviderAdapter::new(test_table), ))); @@ -750,7 +750,7 @@ fn expand_proj_sort_step_aggr_limit() { fn expand_proj_sort_limit_step_aggr() { // use logging for better debugging init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); + let test_table = TestTable::table_with_name(0, "t".to_string()); let table_source = Arc::new(DefaultTableSource::new(Arc::new( DfTableProviderAdapter::new(test_table), ))); @@ -792,7 +792,7 @@ fn expand_proj_sort_limit_step_aggr() { fn expand_proj_limit_step_aggr_sort() { // use logging for better debugging init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); + let test_table = TestTable::table_with_name(0, "t".to_string()); let table_source = Arc::new(DefaultTableSource::new(Arc::new( DfTableProviderAdapter::new(test_table), ))); @@ -833,7 +833,7 @@ fn expand_proj_limit_step_aggr_sort() { fn expand_proj_sort_part_col_aggr_limit() { // use logging for better debugging init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); + let test_table = TestTable::table_with_name(0, "t".to_string()); let table_source = Arc::new(DefaultTableSource::new(Arc::new( DfTableProviderAdapter::new(test_table), ))); @@ -875,7 +875,7 @@ fn expand_proj_sort_part_col_aggr_limit() { fn expand_proj_sort_limit_part_col_aggr() { // use logging for better debugging init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); + let test_table = TestTable::table_with_name(0, "t".to_string()); let table_source = Arc::new(DefaultTableSource::new(Arc::new( DfTableProviderAdapter::new(test_table), ))); @@ -917,7 +917,7 @@ fn expand_proj_sort_limit_part_col_aggr() { fn expand_proj_part_col_aggr_limit_sort() { // use logging for better debugging init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); + let test_table = TestTable::table_with_name(0, "t".to_string()); let table_source = Arc::new(DefaultTableSource::new(Arc::new( DfTableProviderAdapter::new(test_table), ))); @@ -959,7 +959,7 @@ fn expand_proj_part_col_aggr_limit_sort() { fn expand_proj_part_col_aggr_sort_limit() { // use 
logging for better debugging init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); + let test_table = TestTable::table_with_name(0, "t".to_string()); let table_source = Arc::new(DefaultTableSource::new(Arc::new( DfTableProviderAdapter::new(test_table), ))); @@ -1002,7 +1002,7 @@ fn expand_proj_part_col_aggr_sort_limit() { fn expand_proj_limit_part_col_aggr_sort() { // use logging for better debugging init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); + let test_table = TestTable::table_with_name(0, "t".to_string()); let table_source = Arc::new(DefaultTableSource::new(Arc::new( DfTableProviderAdapter::new(test_table), ))); @@ -1044,7 +1044,7 @@ fn expand_proj_limit_part_col_aggr_sort() { fn expand_proj_limit_sort_part_col_aggr() { // use logging for better debugging init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); + let test_table = TestTable::table_with_name(0, "t".to_string()); let table_source = Arc::new(DefaultTableSource::new(Arc::new( DfTableProviderAdapter::new(test_table), ))); @@ -1087,7 +1087,7 @@ fn expand_proj_limit_sort_part_col_aggr() { fn expand_step_aggr_limit() { // use logging for better debugging init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); + let test_table = TestTable::table_with_name(0, "t".to_string()); let table_source = Arc::new(DefaultTableSource::new(Arc::new( DfTableProviderAdapter::new(test_table), ))); @@ -1120,7 +1120,7 @@ fn expand_step_aggr_limit() { fn expand_step_aggr_avg_limit() { // use logging for better debugging init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); + let test_table = TestTable::table_with_name(0, "t".to_string()); let table_source = Arc::new(DefaultTableSource::new(Arc::new( DfTableProviderAdapter::new(test_table), ))); @@ -1153,7 +1153,7 @@ fn expand_step_aggr_avg_limit() { fn expand_part_col_aggr_limit() { // use logging for better debugging init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); + let test_table = TestTable::table_with_name(0, "t".to_string()); let table_source = Arc::new(DefaultTableSource::new(Arc::new( DfTableProviderAdapter::new(test_table), ))); diff --git a/src/query/src/dist_plan/analyzer/utils.rs b/src/query/src/dist_plan/analyzer/utils.rs index e064e8fcea..652e7a34e1 100644 --- a/src/query/src/dist_plan/analyzer/utils.rs +++ b/src/query/src/dist_plan/analyzer/utils.rs @@ -12,247 +12,374 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use std::collections::{HashMap, HashSet}; +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; +use std::sync::Arc; -use datafusion::datasource::DefaultTableSource; +use datafusion::error::Result as DfResult; use datafusion_common::Column; -use datafusion_expr::{Expr, LogicalPlan, TableScan}; -use table::metadata::TableType; -use table::table::adapter::DfTableProviderAdapter; +use datafusion_expr::expr::Alias; +use datafusion_expr::{Expr, LogicalPlan}; -/// Mapping of original column in table to all the alias at current node -pub type AliasMapping = HashMap<String, HashSet<Column>>; +/// Return all the original columns (at the original node) for the given aliased columns at the aliased node +/// +/// If `original_node` is `None`, the original columns are resolved against the leaf node +#[allow(unused)] +pub fn original_column_for( + aliased_columns: &HashSet<Column>, + aliased_node: LogicalPlan, + original_node: Option<Arc<LogicalPlan>>, +) -> DfResult<HashMap<Column, Column>> { + let schema_cols: HashSet<Column> = aliased_node.schema().columns().iter().cloned().collect(); + let cur_aliases: HashMap<Column, Column> = aliased_columns + .iter() + .filter(|c| schema_cols.contains(c)) + .map(|c| (c.clone(), c.clone())) + .collect(); -/// tracking aliases for the source table columns in the plan -#[derive(Debug, Clone)] -pub struct AliasTracker { - /// mapping from the original table name to the alias used in the plan - /// notice how one column might have multiple aliases in the plan - /// - pub mapping: AliasMapping, -} - -impl AliasTracker { - pub fn new(table_scan: &TableScan) -> Option<Self> { - if let Some(source) = table_scan - .source - .as_any() - .downcast_ref::<DefaultTableSource>() - && let Some(provider) = source - .table_provider - .as_any() - .downcast_ref::<DfTableProviderAdapter>() - && provider.table().table_type() == TableType::Base - { - let info = provider.table().table_info(); - let schema = info.meta.schema.clone(); - let col_schema = schema.column_schemas(); - let mapping = col_schema - .iter() - .map(|col| { - ( - col.name.clone(), - HashSet::from_iter(std::iter::once(Column::new_unqualified( - col.name.clone(), - ))), - ) - }) - .collect(); - return Some(Self { mapping }); - } - - None + if cur_aliases.is_empty() { + return Ok(HashMap::new()); } - /// update alias for original columns - /// - /// only handle `Alias` with column in `Projection` node - pub fn update_alias(&mut self, node: &LogicalPlan) { - if let LogicalPlan::Projection(projection) = node { - // first collect all the alias mapping, i.e. 
the col_a AS b AS c AS d become `a->d` - // notice one column might have multiple aliases - let mut alias_mapping: AliasMapping = HashMap::new(); - for expr in &projection.expr { - if let Expr::Alias(alias) = expr { - let outer_alias = alias.clone(); - let mut cur_alias = alias.clone(); - while let Expr::Alias(alias) = *cur_alias.expr { - cur_alias = alias; - } - if let Expr::Column(column) = *cur_alias.expr { - alias_mapping - .entry(column.name.clone()) - .or_default() - .insert(Column::new(outer_alias.relation, outer_alias.name)); - } - } else if let Expr::Column(column) = expr { - // identity mapping - alias_mapping - .entry(column.name.clone()) - .or_default() - .insert(column.clone()); - } + original_column_for_inner(cur_aliases, &aliased_node, &original_node) +} + +fn original_column_for_inner( + mut cur_aliases: HashMap, + node: &LogicalPlan, + original_node: &Option>, +) -> DfResult> { + let mut current_node = node; + + loop { + // Base case: check if we've reached the target node + if let Some(original_node) = original_node { + if *current_node == **original_node { + return Ok(cur_aliases); } + } else if current_node.inputs().is_empty() { + // leaf node reached + return Ok(cur_aliases); + } - // update mapping using `alias_mapping` - let mut new_mapping = HashMap::new(); - for (table_col_name, cur_columns) in std::mem::take(&mut self.mapping) { - let new_aliases = { - let mut new_aliases = HashSet::new(); - for cur_column in &cur_columns { - let new_alias_for_cur_column = alias_mapping - .get(cur_column.name()) + // Validate node has exactly one child + if current_node.inputs().len() != 1 { + return Err(datafusion::error::DataFusionError::Internal( + "only accept plan with at most one child".to_string(), + )); + } + + // Get alias layer and update aliases + let layer = get_alias_layer_from_node(current_node)?; + let mut new_aliases = HashMap::new(); + for (start_alias, cur_alias) in cur_aliases { + if let Some(old_column) = layer.get_old_from_new(cur_alias.clone()) { + new_aliases.insert(start_alias, old_column); + } + } + + // Move to child node and continue iteration + cur_aliases = new_aliases; + current_node = current_node.inputs()[0]; + } +} + +/// Return all the aliased columns(at aliased node) for the given original columns(at original node) +/// +/// if `original_node` is None, it means original columns are from leaf node +pub fn aliased_columns_for( + original_columns: &HashSet, + aliased_node: &LogicalPlan, + original_node: Option<&LogicalPlan>, +) -> DfResult>> { + let initial_aliases: HashMap> = { + if let Some(original) = &original_node { + let schema_cols: HashSet = original.schema().columns().into_iter().collect(); + original_columns + .iter() + .filter(|c| schema_cols.contains(c)) + .map(|c| (c.clone(), HashSet::from([c.clone()]))) + .collect() + } else { + original_columns + .iter() + .map(|c| (c.clone(), HashSet::from([c.clone()]))) + .collect() + } + }; + + if initial_aliases.is_empty() { + return Ok(HashMap::new()); + } + + aliased_columns_for_inner(initial_aliases, aliased_node, original_node) +} + +fn aliased_columns_for_inner( + cur_aliases: HashMap>, + node: &LogicalPlan, + original_node: Option<&LogicalPlan>, +) -> DfResult>> { + // First, collect the path from current node to the target node + let mut path = Vec::new(); + let mut current_node = node; + + // Descend to the target node, collecting nodes along the way + loop { + // Base case: check if we've reached the target node + if let Some(original_node) = original_node { + if *current_node == 
*original_node { + break; + } + } else if current_node.inputs().is_empty() { + // leaf node reached + break; + } + + // Validate node has exactly one child + if current_node.inputs().len() != 1 { + return Err(datafusion::error::DataFusionError::Internal( + "only accept plan with at most one child".to_string(), + )); + } + + // Add current node to path and move to child + path.push(current_node); + current_node = current_node.inputs()[0]; + } + + // Now apply alias layers in reverse order (from original to aliased) + let mut result = cur_aliases; + for &node_in_path in path.iter().rev() { + let layer = get_alias_layer_from_node(node_in_path)?; + let mut new_aliases = HashMap::new(); + for (original_column, cur_alias_set) in result { + let mut new_alias_set = HashSet::new(); + for cur_alias in cur_alias_set { + new_alias_set.extend(layer.get_new_from_old(cur_alias.clone())); + } + if !new_alias_set.is_empty() { + new_aliases.insert(original_column, new_alias_set); + } + } + result = new_aliases; + } + + Ok(result) +} + +/// Return a mapping of original column to all the aliased columns in current node of the plan +fn get_alias_layer_from_node(node: &LogicalPlan) -> DfResult { + match node { + LogicalPlan::Projection(proj) => Ok(get_alias_layer_from_exprs(&proj.expr)), + LogicalPlan::Aggregate(aggr) => Ok(get_alias_layer_from_exprs(&aggr.group_expr)), + LogicalPlan::SubqueryAlias(subquery_alias) => { + let mut layer = AliasLayer::default(); + let old_columns = subquery_alias.input.schema().columns(); + for old_column in old_columns { + let new_column = Column::new( + Some(subquery_alias.alias.clone()), + old_column.name().to_string(), + ); + // mapping from old_column to new_column + layer.insert_alias(old_column, HashSet::from([new_column])); + } + Ok(layer) + } + LogicalPlan::TableScan(scan) => { + let columns = scan.projected_schema.columns(); + let mut layer = AliasLayer::default(); + for col in columns { + layer.insert_alias(col.clone(), HashSet::from([col.clone()])); + } + Ok(layer) + } + _ => { + let input_schema = node + .inputs() + .first() + .ok_or_else(|| { + datafusion::error::DataFusionError::Internal( + "only accept plan with at most one child".to_string(), + ) + })? + .schema(); + let output_schema = node.schema(); + // only accept at most one child plan, and if not one of the above nodes, + // also shouldn't modify the schema or else alias scope tracker can't support them + if node.inputs().len() > 1 { + Err(datafusion::error::DataFusionError::Internal(format!( + "only accept plan with at most one child, found: {}", + node + ))) + } else if node.inputs().len() == 1 { + if input_schema != output_schema { + let input_columns = input_schema.columns(); + let all_input_is_in_output = input_columns + .iter() + .all(|c| output_schema.is_column_from_schema(c)); + if all_input_is_in_output { + // all input is in output, so it's just adding some columns, we can do identity mapping for input columns + let mut layer = AliasLayer::default(); + for col in input_columns { + layer.insert_alias(col.clone(), HashSet::from([col.clone()])); + } + Ok(layer) + } else { + // otherwise use the intersection of input and output + // TODO(discord9): maybe just make this case unsupported for now? 
+ common_telemetry::warn!( + "Possibly unsupported plan for alias tracking, tracking aliases anyway: {}", + node + ); + let input_columns = input_schema.columns(); + let output_columns = + output_schema.columns().into_iter().collect::<HashSet<_>>(); + let common_columns: HashSet<Column> = input_columns + .iter() + .filter(|c| output_columns.contains(c)) .cloned() - .unwrap_or_default(); + .collect(); - for new_alias in new_alias_for_cur_column { - let is_table_ref_eq = match (&new_alias.relation, &cur_column.relation) - { - (Some(o), Some(c)) => o.resolved_eq(c), - _ => true, - }; - // is the same column if both name and table ref is eq - if is_table_ref_eq { - new_aliases.insert(new_alias.clone()); - } + let mut layer = AliasLayer::default(); + for col in &common_columns { + layer.insert_alias(col.clone(), HashSet::from([col.clone()])); + } + Ok(layer) + } + } else { + // identity mapping + let mut layer = AliasLayer::default(); + for col in output_schema.columns() { + layer.insert_alias(col.clone(), HashSet::from([col.clone()])); + } + Ok(layer) + } + } else { + // unknown plan with no input, error msg + Err(datafusion::error::DataFusionError::Internal(format!( + "Unsupported plan with no input: {}", + node + ))) + } + } + } +} + +fn get_alias_layer_from_exprs(exprs: &[Expr]) -> AliasLayer { + let mut alias_mapping: HashMap<Column, HashSet<Column>> = HashMap::new(); + for expr in exprs { + if let Expr::Alias(alias) = expr { + if let Some(column) = get_alias_original_column(alias) { + alias_mapping + .entry(column.clone()) + .or_default() + .insert(Column::new(alias.relation.clone(), alias.name.clone())); + } + } else if let Expr::Column(column) = expr { + // identity mapping + alias_mapping + .entry(column.clone()) + .or_default() + .insert(column.clone()); + } + } + let mut layer = AliasLayer::default(); + for (old_column, new_columns) in alias_mapping { + layer.insert_alias(old_column, new_columns); + } + layer +} + +#[derive(Default, Debug, Clone)] +struct AliasLayer { + /// for convenience of querying, keyed by the old (original) column + old_to_new: BTreeMap<Column, HashSet<Column>>, +} + +impl AliasLayer { + pub fn insert_alias(&mut self, old_column: Column, new_columns: HashSet<Column>) { + self.old_to_new + .entry(old_column) + .or_default() + .extend(new_columns); + } + + pub fn get_new_from_old(&self, old_column: Column) -> HashSet<Column> { + let mut res_cols = HashSet::new(); + for (old, new_cols) in self.old_to_new.iter() { + if old.name() == old_column.name() { + match (&old.relation, &old_column.relation) { + (Some(o), Some(c)) => { + if o.resolved_eq(c) { + res_cols.extend(new_cols.clone()); } } - new_aliases - }; - - new_mapping.insert(table_col_name, new_aliases); + _ => { + // if either relation is None (not fully qualified), match by name only + res_cols.extend(new_cols.clone()); + } + } } - - self.mapping = new_mapping; - common_telemetry::debug!( - "Updating alias tracker to {:?} using node: \n{node}", - self.mapping - ); + } + res_cols } - pub fn get_all_alias_for_col(&self, col_name: &str) -> Option<&HashSet<Column>> { - self.mapping.get(col_name) - } - - #[allow(unused)] - pub fn is_alias_for(&self, original_col: &str, cur_col: &Column) -> bool { - self.mapping - .get(original_col) - .map(|cols| cols.contains(cur_col)) - .unwrap_or(false) + pub fn get_old_from_new(&self, new_column: Column) -> Option<Column> { + for (old, new_set) in &self.old_to_new { + if new_set.iter().any(|n| { + if n.name() != new_column.name() { + return false; + } + match (&n.relation, &new_column.relation) { + (Some(r1), Some(r2)) => r1.resolved_eq(r2), + _ => true, + } + }) { + return 
Some(old.clone()); + } + } + None + } } +fn get_alias_original_column(alias: &Alias) -> Option<Column> { + let mut cur_alias = alias; + while let Expr::Alias(inner_alias) = cur_alias.expr.as_ref() { + cur_alias = inner_alias; + } + if let Expr::Column(column) = cur_alias.expr.as_ref() { + return Some(column.clone()); + } + + None +} + +/// Mapping of original column in table to all the alias at current node +pub type AliasMapping = BTreeMap<String, BTreeSet<Column>>; + #[cfg(test)] mod tests { use std::sync::Arc; use common_telemetry::init_default_ut_logging; - use datafusion::error::Result as DfResult; - use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; + use datafusion::datasource::DefaultTableSource; + use datafusion::functions_aggregate::min_max::{max, min}; use datafusion_expr::{LogicalPlanBuilder, col}; + use pretty_assertions::assert_eq; + use table::table::adapter::DfTableProviderAdapter; use super::*; use crate::dist_plan::analyzer::test::TestTable; - #[derive(Debug)] - struct TrackerTester { - alias_tracker: Option<AliasTracker>, - mapping_at_each_level: Vec<AliasMapping>, - } - - impl TreeNodeVisitor<'_> for TrackerTester { - type Node = LogicalPlan; - - fn f_up(&mut self, node: &LogicalPlan) -> DfResult<TreeNodeRecursion> { - if let Some(alias_tracker) = &mut self.alias_tracker { - alias_tracker.update_alias(node); - self.mapping_at_each_level.push( - self.alias_tracker - .as_ref() - .map(|a| a.mapping.clone()) - .unwrap_or_default() - .clone(), - ); - } else if let LogicalPlan::TableScan(table_scan) = node { - self.alias_tracker = AliasTracker::new(table_scan); - self.mapping_at_each_level.push( - self.alias_tracker - .as_ref() - .map(|a| a.mapping.clone()) - .unwrap_or_default() - .clone(), - ); - } - Ok(TreeNodeRecursion::Continue) - } + fn qcol(name: &str) -> Column { + Column::from_qualified_name(name) } #[test] - fn proj_alias_tracker() { + fn proj_multi_layered_alias_tracker() { // use logging for better debugging init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); - let table_source = Arc::new(DefaultTableSource::new(Arc::new( - DfTableProviderAdapter::new(test_table), - ))); - let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![]) - .unwrap() - .project(vec![ - col("number"), - col("pk3").alias("pk1"), - col("pk2").alias("pk3"), - ]) - .unwrap() - .project(vec![ - col("number"), - col("pk1").alias("pk2"), - col("pk3").alias("pk1"), - ]) - .unwrap() - .build() - .unwrap(); - - let mut tracker_tester = TrackerTester { - alias_tracker: None, - mapping_at_each_level: Vec::new(), - }; - plan.visit(&mut tracker_tester).unwrap(); - - assert_eq!( - tracker_tester.mapping_at_each_level, - vec![ - HashMap::from([ - ("number".to_string(), HashSet::from(["number".into()])), - ("pk1".to_string(), HashSet::from(["pk1".into()])), - ("pk2".to_string(), HashSet::from(["pk2".into()])), - ("pk3".to_string(), HashSet::from(["pk3".into()])), - ("ts".to_string(), HashSet::from(["ts".into()])) - ]), - HashMap::from([ - ("number".to_string(), HashSet::from(["t.number".into()])), - ("pk1".to_string(), HashSet::from([])), - ("pk2".to_string(), HashSet::from(["pk3".into()])), - ("pk3".to_string(), HashSet::from(["pk1".into()])), - ("ts".to_string(), HashSet::from([])) - ]), - HashMap::from([ - ("number".to_string(), HashSet::from(["t.number".into()])), - ("pk1".to_string(), HashSet::from([])), - ("pk2".to_string(), HashSet::from(["pk1".into()])), - ("pk3".to_string(), HashSet::from(["pk2".into()])), - ("ts".to_string(), HashSet::from([])) - ]) - ] - ); - } - - #[test] - 
fn proj_multi_alias_tracker() { - // use logging for better debugging - init_default_ut_logging(); - let test_table = TestTable::table_with_name(0, "numbers".to_string()); + let test_table = TestTable::table_with_name(0, "t".to_string()); let table_source = Arc::new(DefaultTableSource::new(Arc::new( DfTableProviderAdapter::new(test_table), ))); @@ -273,43 +400,417 @@ mod tests { .build() .unwrap(); - let mut tracker_tester = TrackerTester { - alias_tracker: None, - mapping_at_each_level: Vec::new(), - }; - plan.visit(&mut tracker_tester).unwrap(); + let child = plan.inputs()[0].clone(); assert_eq!( - tracker_tester.mapping_at_each_level, - vec![ - HashMap::from([ - ("number".to_string(), HashSet::from(["number".into()])), - ("pk1".to_string(), HashSet::from(["pk1".into()])), - ("pk2".to_string(), HashSet::from(["pk2".into()])), - ("pk3".to_string(), HashSet::from(["pk3".into()])), - ("ts".to_string(), HashSet::from(["ts".into()])) - ]), - HashMap::from([ - ("number".to_string(), HashSet::from(["t.number".into()])), - ("pk1".to_string(), HashSet::from([])), - ("pk2".to_string(), HashSet::from([])), - ( - "pk3".to_string(), - HashSet::from(["pk1".into(), "pk2".into()]) - ), - ("ts".to_string(), HashSet::from([])) - ]), - HashMap::from([ - ("number".to_string(), HashSet::from(["t.number".into()])), - ("pk1".to_string(), HashSet::from([])), - ("pk2".to_string(), HashSet::from([])), - ( - "pk3".to_string(), - HashSet::from(["pk4".into(), "pk5".into()]) - ), - ("ts".to_string(), HashSet::from([])) - ]) - ] + aliased_columns_for( + &HashSet::from([qcol("pk1"), qcol("pk2")]), + &plan, + Some(&child) + ) + .unwrap(), + HashMap::from([ + (qcol("pk1"), HashSet::from([qcol("pk5")])), + (qcol("pk2"), HashSet::from([qcol("pk4")])) + ]) + ); + + // columns not in the plan should return empty mapping + assert_eq!( + aliased_columns_for( + &HashSet::from([qcol("pk1"), qcol("pk2")]), + &plan, + Some(&plan) + ) + .unwrap(), + HashMap::from([]) + ); + + assert_eq!( + aliased_columns_for(&HashSet::from([qcol("t.pk3")]), &plan, Some(&child)).unwrap(), + HashMap::from([]) + ); + + assert_eq!( + original_column_for( + &HashSet::from([qcol("pk5"), qcol("pk4")]), + plan.clone(), + None + ) + .unwrap(), + HashMap::from([(qcol("pk5"), qcol("t.pk3")), (qcol("pk4"), qcol("t.pk3"))]) + ); + + assert_eq!( + aliased_columns_for(&HashSet::from([qcol("pk3")]), &plan, None).unwrap(), + HashMap::from([(qcol("pk3"), HashSet::from([qcol("pk5"), qcol("pk4")]))]) + ); + assert_eq!( + original_column_for( + &HashSet::from([qcol("pk1"), qcol("pk2")]), + child.clone(), + None + ) + .unwrap(), + HashMap::from([(qcol("pk1"), qcol("t.pk3")), (qcol("pk2"), qcol("t.pk3"))]) + ); + + assert_eq!( + aliased_columns_for(&HashSet::from([qcol("pk3")]), &child, None).unwrap(), + HashMap::from([(qcol("pk3"), HashSet::from([qcol("pk1"), qcol("pk2")]))]) + ); + + assert_eq!( + original_column_for( + &HashSet::from([qcol("pk4"), qcol("pk5")]), + plan.clone(), + Some(Arc::new(child.clone())) + ) + .unwrap(), + HashMap::from([(qcol("pk4"), qcol("pk2")), (qcol("pk5"), qcol("pk1"))]) + ); + } + + #[test] + fn sort_subquery_alias_layered_tracker() { + let test_table = TestTable::table_with_name(0, "t".to_string()); + let table_source = Arc::new(DefaultTableSource::new(Arc::new( + DfTableProviderAdapter::new(test_table), + ))); + + let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![]) + .unwrap() + .sort(vec![col("t.number").sort(true, false)]) + .unwrap() + .alias("a") + .unwrap() + .build() + .unwrap(); + + let sort_plan = 
plan.inputs()[0].clone(); + let scan_plan = sort_plan.inputs()[0].clone(); + + // Test aliased_columns_for from scan to final plan + assert_eq!( + aliased_columns_for(&HashSet::from([qcol("t.number")]), &plan, Some(&scan_plan)) + .unwrap(), + HashMap::from([(qcol("t.number"), HashSet::from([qcol("a.number")]))]) + ); + + // Test aliased_columns_for from sort to final plan + assert_eq!( + aliased_columns_for(&HashSet::from([qcol("t.number")]), &plan, Some(&sort_plan)) + .unwrap(), + HashMap::from([(qcol("t.number"), HashSet::from([qcol("a.number")]))]) + ); + + // Test aliased_columns_for from leaf to final plan + assert_eq!( + aliased_columns_for(&HashSet::from([qcol("t.number")]), &plan, None).unwrap(), + HashMap::from([(qcol("t.number"), HashSet::from([qcol("a.number")]))]) + ); + + // Test original_column_for from final plan to scan + assert_eq!( + original_column_for( + &HashSet::from([qcol("a.number")]), + plan.clone(), + Some(Arc::new(scan_plan.clone())) + ) + .unwrap(), + HashMap::from([(qcol("a.number"), qcol("t.number"))]) + ); + + // Test original_column_for from final plan to sort + assert_eq!( + original_column_for( + &HashSet::from([qcol("a.number")]), + plan.clone(), + Some(Arc::new(sort_plan.clone())) + ) + .unwrap(), + HashMap::from([(qcol("a.number"), qcol("t.number"))]) + ); + } + + #[test] + fn proj_alias_layered_tracker() { + // use logging for better debugging + init_default_ut_logging(); + let test_table = TestTable::table_with_name(0, "t".to_string()); + let table_source = Arc::new(DefaultTableSource::new(Arc::new( + DfTableProviderAdapter::new(test_table), + ))); + let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![]) + .unwrap() + .project(vec![ + col("number"), + col("pk3").alias("pk1"), + col("pk2").alias("pk3"), + ]) + .unwrap() + .project(vec![ + col("number"), + col("pk1").alias("pk2"), + col("pk3").alias("pk1"), + ]) + .unwrap() + .build() + .unwrap(); + + let first_proj = plan.inputs()[0].clone(); + let scan_plan = first_proj.inputs()[0].clone(); + + // Test original_column_for from final plan to scan + assert_eq!( + original_column_for( + &HashSet::from([qcol("pk1")]), + plan.clone(), + Some(Arc::new(scan_plan.clone())) + ) + .unwrap(), + HashMap::from([(qcol("pk1"), qcol("t.pk2"))]) + ); + + // Test original_column_for from final plan to first projection + assert_eq!( + original_column_for( + &HashSet::from([qcol("pk1")]), + plan.clone(), + Some(Arc::new(first_proj.clone())) + ) + .unwrap(), + HashMap::from([(qcol("pk1"), qcol("pk3"))]) + ); + + // Test original_column_for from final plan to leaf + assert_eq!( + original_column_for( + &HashSet::from([qcol("pk1")]), + plan.clone(), + Some(Arc::new(plan.clone())) + ) + .unwrap(), + HashMap::from([(qcol("pk1"), qcol("pk1"))]) + ); + + // Test aliased_columns_for from scan to first projection + assert_eq!( + aliased_columns_for( + &HashSet::from([qcol("t.pk2")]), + &first_proj, + Some(&scan_plan) + ) + .unwrap(), + HashMap::from([(qcol("t.pk2"), HashSet::from([qcol("pk3")]))]) + ); + + // Test aliased_columns_for from first projection to final plan + assert_eq!( + aliased_columns_for(&HashSet::from([qcol("pk3")]), &plan, Some(&first_proj)).unwrap(), + HashMap::from([(qcol("pk3"), HashSet::from([qcol("pk1")]))]) + ); + + // Test aliased_columns_for from scan to final plan + assert_eq!( + aliased_columns_for(&HashSet::from([qcol("t.pk2")]), &plan, Some(&scan_plan)).unwrap(), + HashMap::from([(qcol("t.pk2"), HashSet::from([qcol("pk1")]))]) + ); + + // Test aliased_columns_for from leaf 
to final plan + assert_eq!( + aliased_columns_for(&HashSet::from([qcol("pk2")]), &plan, None).unwrap(), + HashMap::from([(qcol("pk2"), HashSet::from([qcol("pk1")]))]) + ); + } + + #[test] + fn proj_alias_relation_layered_tracker() { + // use logging for better debugging + init_default_ut_logging(); + let test_table = TestTable::table_with_name(0, "t".to_string()); + let table_source = Arc::new(DefaultTableSource::new(Arc::new( + DfTableProviderAdapter::new(test_table), + ))); + let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![]) + .unwrap() + .project(vec![ + col("number"), + col("pk3").alias_qualified(Some("b"), "pk1"), + col("pk2").alias_qualified(Some("a"), "pk1"), + ]) + .unwrap() + .build() + .unwrap(); + + let scan_plan = plan.inputs()[0].clone(); + + // Test aliased_columns_for from scan to projection + assert_eq!( + aliased_columns_for(&HashSet::from([qcol("t.pk2")]), &plan, Some(&scan_plan)).unwrap(), + HashMap::from([(qcol("t.pk2"), HashSet::from([qcol("a.pk1")]))]) + ); + } + + #[test] + fn proj_alias_aliased_aggr() { + // use logging for better debugging + init_default_ut_logging(); + let test_table = TestTable::table_with_name(0, "t".to_string()); + let table_source = Arc::new(DefaultTableSource::new(Arc::new( + DfTableProviderAdapter::new(test_table), + ))); + let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![]) + .unwrap() + .project(vec![ + col("number"), + col("pk1").alias("pk3"), + col("pk2").alias("pk4"), + ]) + .unwrap() + .project(vec![ + col("number"), + col("pk3").alias("pk42"), + col("pk4").alias("pk43"), + ]) + .unwrap() + .aggregate(vec![col("pk42"), col("pk43")], vec![min(col("number"))]) + .unwrap() + .build() + .unwrap(); + + let aggr_plan = plan.clone(); + let second_proj = aggr_plan.inputs()[0].clone(); + let first_proj = second_proj.inputs()[0].clone(); + let scan_plan = first_proj.inputs()[0].clone(); + + // Test aliased_columns_for from scan to final plan + assert_eq!( + aliased_columns_for(&HashSet::from([qcol("t.pk1")]), &plan, Some(&scan_plan)).unwrap(), + HashMap::from([(qcol("t.pk1"), HashSet::from([qcol("pk42")]))]) + ); + + // Test aliased_columns_for from scan to first projection + assert_eq!( + aliased_columns_for( + &HashSet::from([Column::from_name("pk1")]), + &first_proj, + None + ) + .unwrap(), + HashMap::from([(Column::from_name("pk1"), HashSet::from([qcol("pk3")]))]) + ); + } + + #[test] + fn aggr_aggr_alias() { + // use logging for better debugging + init_default_ut_logging(); + let test_table = TestTable::table_with_name(0, "t".to_string()); + let table_source = Arc::new(DefaultTableSource::new(Arc::new( + DfTableProviderAdapter::new(test_table), + ))); + let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![]) + .unwrap() + .aggregate(vec![col("pk1"), col("pk2")], vec![max(col("number"))]) + .unwrap() + .aggregate( + vec![col("pk1"), col("pk2")], + vec![min(col("max(t.number)"))], + ) + .unwrap() + .build() + .unwrap(); + + let second_aggr = plan.clone(); + let first_aggr = second_aggr.inputs()[0].clone(); + let scan_plan = first_aggr.inputs()[0].clone(); + + // Test aliased_columns_for from scan to final plan (identity mapping for aggregates) + assert_eq!( + aliased_columns_for(&HashSet::from([qcol("t.pk1")]), &plan, Some(&scan_plan)).unwrap(), + HashMap::from([(qcol("t.pk1"), HashSet::from([qcol("t.pk1")]))]) + ); + + // Test aliased_columns_for from scan to first aggregate + assert_eq!( + aliased_columns_for( + &HashSet::from([qcol("t.pk1")]), + 
&first_aggr, + Some(&scan_plan) + ) + .unwrap(), + HashMap::from([(qcol("t.pk1"), HashSet::from([qcol("t.pk1")]))]) + ); + + // Test aliased_columns_for from first aggregate to final plan + assert_eq!( + aliased_columns_for(&HashSet::from([qcol("t.pk1")]), &plan, Some(&first_aggr)).unwrap(), + HashMap::from([(qcol("t.pk1"), HashSet::from([qcol("t.pk1")]))]) + ); + + // Test aliased_columns_for from leaf to final plan + assert_eq!( + aliased_columns_for(&HashSet::from([Column::from_name("pk1")]), &plan, None).unwrap(), + HashMap::from([(Column::from_name("pk1"), HashSet::from([qcol("t.pk1")]))]) + ); + } + + #[test] + fn aggr_aggr_alias_projection() { + // use logging for better debugging + init_default_ut_logging(); + let test_table = TestTable::table_with_name(0, "t".to_string()); + let table_source = Arc::new(DefaultTableSource::new(Arc::new( + DfTableProviderAdapter::new(test_table), + ))); + let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![]) + .unwrap() + .aggregate(vec![col("pk1"), col("pk2")], vec![max(col("number"))]) + .unwrap() + .aggregate( + vec![col("pk1"), col("pk2")], + vec![min(col("max(t.number)"))], + ) + .unwrap() + .project(vec![ + col("pk1").alias("pk11"), + col("pk2").alias("pk22"), + col("min(max(t.number))").alias("min_max_number"), + ]) + .unwrap() + .build() + .unwrap(); + + let proj_plan = plan.clone(); + let second_aggr = proj_plan.inputs()[0].clone(); + + // Test original_column_for from projection to second aggregate for aggr gen column + assert_eq!( + original_column_for( + &HashSet::from([Column::from_name("min_max_number")]), + plan.clone(), + Some(Arc::new(second_aggr.clone())) + ) + .unwrap(), + HashMap::from([( + Column::from_name("min_max_number"), + Column::from_name("min(max(t.number))") + )]) + ); + + // Test aliased_columns_for from second aggregate to projection + assert_eq!( + aliased_columns_for( + &HashSet::from([Column::from_name("min(max(t.number))")]), + &plan, + Some(&second_aggr) + ) + .unwrap(), + HashMap::from([( + Column::from_name("min(max(t.number))"), + HashSet::from([Column::from_name("min_max_number")]) + )]) ); } } diff --git a/src/query/src/dist_plan/commutativity.rs b/src/query/src/dist_plan/commutativity.rs index 84102a1b66..b44d12d101 100644 --- a/src/query/src/dist_plan/commutativity.rs +++ b/src/query/src/dist_plan/commutativity.rs @@ -302,6 +302,10 @@ impl Categorizer { /// Return true if the given expr and partition cols satisfied the rule. /// In this case the plan can be treated as fully commutative. + /// + /// So only if all partition columns show up in `exprs`, return true. + /// Otherwise return false. 
+ /// fn check_partition(exprs: &[Expr], partition_cols: &AliasMapping) -> bool { let mut ref_cols = HashSet::new(); for expr in exprs { diff --git a/src/query/src/dist_plan/merge_scan.rs b/src/query/src/dist_plan/merge_scan.rs index e6d9ecfdc8..aebf9a457d 100644 --- a/src/query/src/dist_plan/merge_scan.rs +++ b/src/query/src/dist_plan/merge_scan.rs @@ -52,6 +52,7 @@ use store_api::storage::RegionId; use table::table_name::TableName; use tokio::time::Instant; +use crate::dist_plan::analyzer::AliasMapping; use crate::error::ConvertSchemaSnafu; use crate::metrics::{MERGE_SCAN_ERRORS_TOTAL, MERGE_SCAN_POLL_ELAPSED, MERGE_SCAN_REGIONS}; use crate::region_query::RegionQueryHandlerRef; @@ -62,7 +63,7 @@ pub struct MergeScanLogicalPlan { input: LogicalPlan, /// If this plan is a placeholder is_placeholder: bool, - partition_cols: Vec, + partition_cols: AliasMapping, } impl UserDefinedLogicalNodeCore for MergeScanLogicalPlan { @@ -103,7 +104,7 @@ impl UserDefinedLogicalNodeCore for MergeScanLogicalPlan { } impl MergeScanLogicalPlan { - pub fn new(input: LogicalPlan, is_placeholder: bool, partition_cols: Vec) -> Self { + pub fn new(input: LogicalPlan, is_placeholder: bool, partition_cols: AliasMapping) -> Self { Self { input, is_placeholder, @@ -130,7 +131,7 @@ impl MergeScanLogicalPlan { &self.input } - pub fn partition_cols(&self) -> &[String] { + pub fn partition_cols(&self) -> &AliasMapping { &self.partition_cols } } @@ -150,7 +151,7 @@ pub struct MergeScanExec { partition_metrics: Arc>>, query_ctx: QueryContextRef, target_partition: usize, - partition_cols: Vec, + partition_cols: AliasMapping, } impl std::fmt::Debug for MergeScanExec { @@ -175,7 +176,7 @@ impl MergeScanExec { region_query_handler: RegionQueryHandlerRef, query_ctx: QueryContextRef, target_partition: usize, - partition_cols: Vec, + partition_cols: AliasMapping, ) -> Result { // TODO(CookiePieWw): Initially we removed the metadata from the schema in #2000, but we have to // keep it for #4619 to identify json type in src/datatypes/src/schema/column_schema.rs. 
@@ -215,12 +216,18 @@ impl MergeScanExec { let partition_exprs = partition_cols .iter() .filter_map(|col| { - session_state - .create_physical_expr( - Expr::Column(ColumnExpr::new_unqualified(col)), - plan.schema(), - ) - .ok() + if let Some(first_alias) = col.1.first() { + session_state + .create_physical_expr( + Expr::Column(ColumnExpr::new_unqualified( + first_alias.name().to_string(), + )), + plan.schema(), + ) + .ok() + } else { + None + } }) .collect(); let partitioning = Partitioning::Hash(partition_exprs, target_partition); @@ -424,20 +431,20 @@ impl MergeScanExec { return None; } - let partition_cols = self + let all_partition_col_aliases: HashSet<_> = self .partition_cols - .iter() - .map(|x| x.as_str()) - .collect::>(); + .values() + .flat_map(|aliases| aliases.iter().map(|c| c.name())) + .collect(); let mut overlaps = vec![]; for expr in &hash_exprs { - // TODO(ruihang): tracking aliases if let Some(col_expr) = expr.as_any().downcast_ref::() - && partition_cols.contains(col_expr.name()) + && all_partition_col_aliases.contains(col_expr.name()) { overlaps.push(expr.clone()); } } + if overlaps.is_empty() { return None; } diff --git a/src/query/src/dist_plan/planner.rs b/src/query/src/dist_plan/planner.rs index 74567b727d..433d905ce7 100644 --- a/src/query/src/dist_plan/planner.rs +++ b/src/query/src/dist_plan/planner.rs @@ -177,7 +177,7 @@ impl ExtensionPlanner for DistExtensionPlanner { self.region_query_handler.clone(), query_ctx, session_state.config().target_partitions(), - merge_scan.partition_cols().to_vec(), + merge_scan.partition_cols().clone(), )?; Ok(Some(Arc::new(merge_scan_plan) as _)) }
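For readers of this patch: the new `utils.rs` helpers (`aliased_columns_for`, `original_column_for`, `AliasLayer`) all rely on the same idea — each single-child plan node contributes one alias layer mapping its input (old) columns to its output (new) columns, and layers compose along the path between two nodes, forward to find every name a partition column is visible under, backward to resolve a name to its source column. Below is a minimal, self-contained sketch of that composition using plain strings instead of DataFusion `Column`s; the names (`Layer`, `aliases_at_top`, `original_at_bottom`) are illustrative only and are not part of this patch.

```rust
use std::collections::{BTreeMap, BTreeSet};

/// One plan node's alias layer: old (input) column name -> all new (output) names.
/// Hypothetical stand-in for the patch's `AliasLayer`, which keys by DataFusion `Column`.
type Layer = BTreeMap<String, BTreeSet<String>>;

/// Forward direction (cf. `aliased_columns_for`): follow a column from the scan upward
/// through every layer and collect all names it is visible under at the top of the plan.
fn aliases_at_top(original: &str, layers_bottom_up: &[Layer]) -> BTreeSet<String> {
    let mut current: BTreeSet<String> = BTreeSet::from([original.to_string()]);
    for layer in layers_bottom_up {
        let mut next = BTreeSet::new();
        for name in &current {
            if let Some(new_names) = layer.get(name) {
                next.extend(new_names.iter().cloned());
            }
        }
        // A column dropped by a projection simply disappears here; the real helpers
        // also insert identity entries for columns that pass through unchanged.
        current = next;
    }
    current
}

/// Backward direction (cf. `original_column_for`): resolve a top-level name back to the
/// column it came from by walking the layers top-down.
fn original_at_bottom(top: &str, layers_bottom_up: &[Layer]) -> Option<String> {
    let mut current = top.to_string();
    for layer in layers_bottom_up.iter().rev() {
        current = layer
            .iter()
            .find(|(_, news)| news.contains(&current))
            .map(|(old, _)| old.clone())?;
    }
    Some(current)
}

fn main() {
    // Mirrors the `proj_alias_layered_tracker` test:
    // SELECT pk1 AS pk2, pk3 AS pk1 FROM (SELECT pk3 AS pk1, pk2 AS pk3 FROM t)
    let lower: Layer = BTreeMap::from([
        ("pk3".to_string(), BTreeSet::from(["pk1".to_string()])),
        ("pk2".to_string(), BTreeSet::from(["pk3".to_string()])),
    ]);
    let upper: Layer = BTreeMap::from([
        ("pk1".to_string(), BTreeSet::from(["pk2".to_string()])),
        ("pk3".to_string(), BTreeSet::from(["pk1".to_string()])),
    ]);
    let layers = [lower, upper];

    // t.pk3 ends up visible as "pk2" at the top of the plan.
    assert_eq!(
        aliases_at_top("pk3", &layers),
        BTreeSet::from(["pk2".to_string()])
    );
    // And "pk1" at the top traces back to t.pk2.
    assert_eq!(original_at_bottom("pk1", &layers), Some("pk2".to_string()));
    println!("alias layers compose as expected");
}
```

The patch's real helpers additionally match table qualifiers via `Column::relation` and `resolved_eq`, build layers per node kind (`Projection` and `Aggregate` from their expressions, `SubqueryAlias` by requalifying, `TableScan` and schema-preserving nodes as identity), and return an error for multi-child plans, which is why `maybe_set_partitions` only propagates aliases along single-child chains.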