mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-07 13:52:59 +00:00
fix: count_wildcard_to_time_index_rule doesn't handle table reference properly (#3847)
* validate time index col Signed-off-by: Ruihang Xia <waynestxia@gmail.com> * use TableReference instead Signed-off-by: Ruihang Xia <waynestxia@gmail.com> * add more tests Signed-off-by: Ruihang Xia <waynestxia@gmail.com> --------- Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
This commit is contained in:
@@ -16,12 +16,13 @@ use datafusion::datasource::DefaultTableSource;
|
||||
use datafusion_common::tree_node::{
|
||||
Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor,
|
||||
};
|
||||
use datafusion_common::Result as DataFusionResult;
|
||||
use datafusion_common::{Column, Result as DataFusionResult};
|
||||
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, WindowFunction};
|
||||
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
|
||||
use datafusion_expr::{col, lit, Expr, LogicalPlan, WindowFunctionDefinition};
|
||||
use datafusion_optimizer::utils::NamePreserver;
|
||||
use datafusion_optimizer::AnalyzerRule;
|
||||
use datafusion_sql::TableReference;
|
||||
use table::table::adapter::DfTableProviderAdapter;
|
||||
|
||||
/// A replacement to DataFusion's [`CountWildcardRule`]. This rule
|
||||
@@ -77,11 +78,27 @@ impl CountWildcardToTimeIndexRule {
|
||||
})
|
||||
}
|
||||
|
||||
fn try_find_time_index_col(plan: &LogicalPlan) -> Option<String> {
|
||||
fn try_find_time_index_col(plan: &LogicalPlan) -> Option<Column> {
|
||||
let mut finder = TimeIndexFinder::default();
|
||||
// Safety: `TimeIndexFinder` won't throw error.
|
||||
plan.visit(&mut finder).unwrap();
|
||||
finder.time_index
|
||||
let col = finder.into_column();
|
||||
|
||||
// check if the time index is a valid column as for current plan
|
||||
if let Some(col) = &col {
|
||||
let mut is_valid = false;
|
||||
for input in plan.inputs() {
|
||||
if input.schema().has_column(col) {
|
||||
is_valid = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if !is_valid {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
col
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,8 +131,8 @@ impl CountWildcardToTimeIndexRule {
|
||||
|
||||
#[derive(Default)]
|
||||
struct TimeIndexFinder {
|
||||
time_index: Option<String>,
|
||||
table_alias: Option<String>,
|
||||
time_index_col: Option<String>,
|
||||
table_alias: Option<TableReference>,
|
||||
}
|
||||
|
||||
impl TreeNodeVisitor for TimeIndexFinder {
|
||||
@@ -123,7 +140,7 @@ impl TreeNodeVisitor for TimeIndexFinder {
|
||||
|
||||
fn f_down(&mut self, node: &Self::Node) -> DataFusionResult<TreeNodeRecursion> {
|
||||
if let LogicalPlan::SubqueryAlias(subquery_alias) = node {
|
||||
self.table_alias = Some(subquery_alias.alias.to_string());
|
||||
self.table_alias = Some(subquery_alias.alias.clone());
|
||||
}
|
||||
|
||||
if let LogicalPlan::TableScan(table_scan) = &node {
|
||||
@@ -138,9 +155,13 @@ impl TreeNodeVisitor for TimeIndexFinder {
|
||||
.downcast_ref::<DfTableProviderAdapter>()
|
||||
{
|
||||
let table_info = adapter.table().table_info();
|
||||
let col_name = table_info.meta.schema.timestamp_column().map(|c| &c.name);
|
||||
let table_name = self.table_alias.as_ref().unwrap_or(&table_info.name);
|
||||
self.time_index = col_name.map(|s| format!("{}.{}", table_name, s));
|
||||
self.table_alias
|
||||
.get_or_insert(TableReference::bare(table_info.name.clone()));
|
||||
self.time_index_col = table_info
|
||||
.meta
|
||||
.schema
|
||||
.timestamp_column()
|
||||
.map(|c| c.name.clone());
|
||||
|
||||
return Ok(TreeNodeRecursion::Stop);
|
||||
}
|
||||
@@ -154,3 +175,43 @@ impl TreeNodeVisitor for TimeIndexFinder {
|
||||
Ok(TreeNodeRecursion::Stop)
|
||||
}
|
||||
}
|
||||
|
||||
impl TimeIndexFinder {
|
||||
fn into_column(self) -> Option<Column> {
|
||||
self.time_index_col
|
||||
.map(|c| Column::new(self.table_alias, c))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_expr::{count, wildcard, LogicalPlanBuilder};
|
||||
use table::table::numbers::NumbersTable;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn uppercase_table_name() {
|
||||
let numbers_table = NumbersTable::table_with_name(0, "AbCdE".to_string());
|
||||
let table_source = Arc::new(DefaultTableSource::new(Arc::new(
|
||||
DfTableProviderAdapter::new(numbers_table),
|
||||
)));
|
||||
|
||||
let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
|
||||
.unwrap()
|
||||
.aggregate(Vec::<Expr>::new(), vec![count(wildcard())])
|
||||
.unwrap()
|
||||
.alias(r#""FgHiJ""#)
|
||||
.unwrap()
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let mut finder = TimeIndexFinder::default();
|
||||
plan.visit(&mut finder).unwrap();
|
||||
|
||||
assert_eq!(finder.table_alias, Some(TableReference::bare("FgHiJ")));
|
||||
assert!(finder.time_index_col.is_none());
|
||||
}
|
||||
}
|
||||
|
||||
56
tests/cases/standalone/common/aggregate/count.result
Normal file
56
tests/cases/standalone/common/aggregate/count.result
Normal file
@@ -0,0 +1,56 @@
|
||||
create table "HelloWorld" (a string, b timestamp time index);
|
||||
|
||||
Affected Rows: 0
|
||||
|
||||
insert into "HelloWorld" values ("a", 1) ,("b", 2);
|
||||
|
||||
Affected Rows: 2
|
||||
|
||||
select count(*) from "HelloWorld";
|
||||
|
||||
+----------+
|
||||
| COUNT(*) |
|
||||
+----------+
|
||||
| 2 |
|
||||
+----------+
|
||||
|
||||
create table test (a string, "BbB" timestamp time index);
|
||||
|
||||
Affected Rows: 0
|
||||
|
||||
insert into test values ("c", 1) ;
|
||||
|
||||
Affected Rows: 1
|
||||
|
||||
select count(*) from test;
|
||||
|
||||
+----------+
|
||||
| COUNT(*) |
|
||||
+----------+
|
||||
| 1 |
|
||||
+----------+
|
||||
|
||||
select count(*) from (select count(*) from test where a = 'a');
|
||||
|
||||
+----------+
|
||||
| COUNT(*) |
|
||||
+----------+
|
||||
| 1 |
|
||||
+----------+
|
||||
|
||||
select count(*) from (select * from test cross join "HelloWorld");
|
||||
|
||||
+----------+
|
||||
| COUNT(*) |
|
||||
+----------+
|
||||
| 2 |
|
||||
+----------+
|
||||
|
||||
drop table "HelloWorld";
|
||||
|
||||
Affected Rows: 0
|
||||
|
||||
drop table test;
|
||||
|
||||
Affected Rows: 0
|
||||
|
||||
19
tests/cases/standalone/common/aggregate/count.sql
Normal file
19
tests/cases/standalone/common/aggregate/count.sql
Normal file
@@ -0,0 +1,19 @@
|
||||
create table "HelloWorld" (a string, b timestamp time index);
|
||||
|
||||
insert into "HelloWorld" values ("a", 1) ,("b", 2);
|
||||
|
||||
select count(*) from "HelloWorld";
|
||||
|
||||
create table test (a string, "BbB" timestamp time index);
|
||||
|
||||
insert into test values ("c", 1) ;
|
||||
|
||||
select count(*) from test;
|
||||
|
||||
select count(*) from (select count(*) from test where a = 'a');
|
||||
|
||||
select count(*) from (select * from test cross join "HelloWorld");
|
||||
|
||||
drop table "HelloWorld";
|
||||
|
||||
drop table test;
|
||||
Reference in New Issue
Block a user