diff --git a/src/query/src/optimizer/count_wildcard.rs b/src/query/src/optimizer/count_wildcard.rs index 359d333c25..b8a491003b 100644 --- a/src/query/src/optimizer/count_wildcard.rs +++ b/src/query/src/optimizer/count_wildcard.rs @@ -16,12 +16,13 @@ use datafusion::datasource::DefaultTableSource; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor, }; -use datafusion_common::Result as DataFusionResult; +use datafusion_common::{Column, Result as DataFusionResult}; use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, WindowFunction}; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::{col, lit, Expr, LogicalPlan, WindowFunctionDefinition}; use datafusion_optimizer::utils::NamePreserver; use datafusion_optimizer::AnalyzerRule; +use datafusion_sql::TableReference; use table::table::adapter::DfTableProviderAdapter; /// A replacement to DataFusion's [`CountWildcardRule`]. This rule @@ -77,11 +78,27 @@ impl CountWildcardToTimeIndexRule { }) } - fn try_find_time_index_col(plan: &LogicalPlan) -> Option { + fn try_find_time_index_col(plan: &LogicalPlan) -> Option { let mut finder = TimeIndexFinder::default(); // Safety: `TimeIndexFinder` won't throw error. plan.visit(&mut finder).unwrap(); - finder.time_index + let col = finder.into_column(); + + // check if the time index is a valid column as for current plan + if let Some(col) = &col { + let mut is_valid = false; + for input in plan.inputs() { + if input.schema().has_column(col) { + is_valid = true; + break; + } + } + if !is_valid { + return None; + } + } + + col } } @@ -114,8 +131,8 @@ impl CountWildcardToTimeIndexRule { #[derive(Default)] struct TimeIndexFinder { - time_index: Option, - table_alias: Option, + time_index_col: Option, + table_alias: Option, } impl TreeNodeVisitor for TimeIndexFinder { @@ -123,7 +140,7 @@ impl TreeNodeVisitor for TimeIndexFinder { fn f_down(&mut self, node: &Self::Node) -> DataFusionResult { if let LogicalPlan::SubqueryAlias(subquery_alias) = node { - self.table_alias = Some(subquery_alias.alias.to_string()); + self.table_alias = Some(subquery_alias.alias.clone()); } if let LogicalPlan::TableScan(table_scan) = &node { @@ -138,9 +155,13 @@ impl TreeNodeVisitor for TimeIndexFinder { .downcast_ref::() { let table_info = adapter.table().table_info(); - let col_name = table_info.meta.schema.timestamp_column().map(|c| &c.name); - let table_name = self.table_alias.as_ref().unwrap_or(&table_info.name); - self.time_index = col_name.map(|s| format!("{}.{}", table_name, s)); + self.table_alias + .get_or_insert(TableReference::bare(table_info.name.clone())); + self.time_index_col = table_info + .meta + .schema + .timestamp_column() + .map(|c| c.name.clone()); return Ok(TreeNodeRecursion::Stop); } @@ -154,3 +175,43 @@ impl TreeNodeVisitor for TimeIndexFinder { Ok(TreeNodeRecursion::Stop) } } + +impl TimeIndexFinder { + fn into_column(self) -> Option { + self.time_index_col + .map(|c| Column::new(self.table_alias, c)) + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use datafusion_expr::{count, wildcard, LogicalPlanBuilder}; + use table::table::numbers::NumbersTable; + + use super::*; + + #[test] + fn uppercase_table_name() { + let numbers_table = NumbersTable::table_with_name(0, "AbCdE".to_string()); + let table_source = Arc::new(DefaultTableSource::new(Arc::new( + DfTableProviderAdapter::new(numbers_table), + ))); + + let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![]) + .unwrap() + .aggregate(Vec::::new(), vec![count(wildcard())]) + .unwrap() + .alias(r#""FgHiJ""#) + .unwrap() + .build() + .unwrap(); + + let mut finder = TimeIndexFinder::default(); + plan.visit(&mut finder).unwrap(); + + assert_eq!(finder.table_alias, Some(TableReference::bare("FgHiJ"))); + assert!(finder.time_index_col.is_none()); + } +} diff --git a/tests/cases/standalone/common/aggregate/count.result b/tests/cases/standalone/common/aggregate/count.result new file mode 100644 index 0000000000..4523118d18 --- /dev/null +++ b/tests/cases/standalone/common/aggregate/count.result @@ -0,0 +1,56 @@ +create table "HelloWorld" (a string, b timestamp time index); + +Affected Rows: 0 + +insert into "HelloWorld" values ("a", 1) ,("b", 2); + +Affected Rows: 2 + +select count(*) from "HelloWorld"; + ++----------+ +| COUNT(*) | ++----------+ +| 2 | ++----------+ + +create table test (a string, "BbB" timestamp time index); + +Affected Rows: 0 + +insert into test values ("c", 1) ; + +Affected Rows: 1 + +select count(*) from test; + ++----------+ +| COUNT(*) | ++----------+ +| 1 | ++----------+ + +select count(*) from (select count(*) from test where a = 'a'); + ++----------+ +| COUNT(*) | ++----------+ +| 1 | ++----------+ + +select count(*) from (select * from test cross join "HelloWorld"); + ++----------+ +| COUNT(*) | ++----------+ +| 2 | ++----------+ + +drop table "HelloWorld"; + +Affected Rows: 0 + +drop table test; + +Affected Rows: 0 + diff --git a/tests/cases/standalone/common/aggregate/count.sql b/tests/cases/standalone/common/aggregate/count.sql new file mode 100644 index 0000000000..80100c96ae --- /dev/null +++ b/tests/cases/standalone/common/aggregate/count.sql @@ -0,0 +1,19 @@ +create table "HelloWorld" (a string, b timestamp time index); + +insert into "HelloWorld" values ("a", 1) ,("b", 2); + +select count(*) from "HelloWorld"; + +create table test (a string, "BbB" timestamp time index); + +insert into test values ("c", 1) ; + +select count(*) from test; + +select count(*) from (select count(*) from test where a = 'a'); + +select count(*) from (select * from test cross join "HelloWorld"); + +drop table "HelloWorld"; + +drop table test;