fix: count_wildcard_to_time_index_rule doesn't handle table reference properly (#3847)

* validate time index col

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* use TableReference instead

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* add more tests

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

---------

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
This commit is contained in:
Ruihang Xia
2024-04-30 23:59:56 +08:00
committed by GitHub
parent e84b1eefdf
commit e6eca8ca0c
3 changed files with 145 additions and 9 deletions

View File

@@ -16,12 +16,13 @@ use datafusion::datasource::DefaultTableSource;
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor,
};
use datafusion_common::Result as DataFusionResult;
use datafusion_common::{Column, Result as DataFusionResult};
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, WindowFunction};
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::{col, lit, Expr, LogicalPlan, WindowFunctionDefinition};
use datafusion_optimizer::utils::NamePreserver;
use datafusion_optimizer::AnalyzerRule;
use datafusion_sql::TableReference;
use table::table::adapter::DfTableProviderAdapter;
/// A replacement to DataFusion's [`CountWildcardRule`]. This rule
@@ -77,11 +78,27 @@ impl CountWildcardToTimeIndexRule {
})
}
fn try_find_time_index_col(plan: &LogicalPlan) -> Option<String> {
fn try_find_time_index_col(plan: &LogicalPlan) -> Option<Column> {
let mut finder = TimeIndexFinder::default();
// Safety: `TimeIndexFinder` never returns an error, so this `unwrap` cannot panic.
plan.visit(&mut finder).unwrap();
finder.time_index
let col = finder.into_column();
// Check that the time index column is actually present in one of the current plan's inputs.
if let Some(col) = &col {
let mut is_valid = false;
for input in plan.inputs() {
if input.schema().has_column(col) {
is_valid = true;
break;
}
}
if !is_valid {
return None;
}
}
col
}
}
@@ -114,8 +131,8 @@ impl CountWildcardToTimeIndexRule {
#[derive(Default)]
struct TimeIndexFinder {
time_index: Option<String>,
table_alias: Option<String>,
time_index_col: Option<String>,
table_alias: Option<TableReference>,
}
impl TreeNodeVisitor for TimeIndexFinder {
@@ -123,7 +140,7 @@ impl TreeNodeVisitor for TimeIndexFinder {
fn f_down(&mut self, node: &Self::Node) -> DataFusionResult<TreeNodeRecursion> {
if let LogicalPlan::SubqueryAlias(subquery_alias) = node {
self.table_alias = Some(subquery_alias.alias.to_string());
self.table_alias = Some(subquery_alias.alias.clone());
}
if let LogicalPlan::TableScan(table_scan) = &node {
@@ -138,9 +155,13 @@ impl TreeNodeVisitor for TimeIndexFinder {
.downcast_ref::<DfTableProviderAdapter>()
{
let table_info = adapter.table().table_info();
let col_name = table_info.meta.schema.timestamp_column().map(|c| &c.name);
let table_name = self.table_alias.as_ref().unwrap_or(&table_info.name);
self.time_index = col_name.map(|s| format!("{}.{}", table_name, s));
self.table_alias
.get_or_insert(TableReference::bare(table_info.name.clone()));
self.time_index_col = table_info
.meta
.schema
.timestamp_column()
.map(|c| c.name.clone());
return Ok(TreeNodeRecursion::Stop);
}
@@ -154,3 +175,43 @@ impl TreeNodeVisitor for TimeIndexFinder {
Ok(TreeNodeRecursion::Stop)
}
}
impl TimeIndexFinder {
    /// Consumes the finder and builds the column it located.
    ///
    /// Returns `None` when no time index column name was recorded.
    /// Otherwise the recorded name is combined with the recorded table
    /// alias (which may itself be absent) into a [`Column`].
    fn into_column(self) -> Option<Column> {
        // Bail out early when the visitor never recorded a column name.
        let name = self.time_index_col?;
        Some(Column::new(self.table_alias, name))
    }
}
// Unit tests for `TimeIndexFinder`.
#[cfg(test)]
mod test {
use std::sync::Arc;
use datafusion_expr::{count, wildcard, LogicalPlanBuilder};
use table::table::numbers::NumbersTable;
use super::*;
// Verifies that a quoted, mixed-case subquery alias is recorded by the
// finder as a bare table reference with its case preserved.
#[test]
fn uppercase_table_name() {
// A numbers table created under the mixed-case name "AbCdE" ...
let numbers_table = NumbersTable::table_with_name(0, "AbCdE".to_string());
let table_source = Arc::new(DefaultTableSource::new(Arc::new(
DfTableProviderAdapter::new(numbers_table),
)));
// ... scanned as "t", aggregated with `count(*)`, and then wrapped in an
// alias whose text includes literal double quotes.
let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
.unwrap()
.aggregate(Vec::<Expr>::new(), vec![count(wildcard())])
.unwrap()
.alias(r#""FgHiJ""#)
.unwrap()
.build()
.unwrap();
let mut finder = TimeIndexFinder::default();
plan.visit(&mut finder).unwrap();
// The alias is expected without the surrounding quotes and with its
// original casing.
assert_eq!(finder.table_alias, Some(TableReference::bare("FgHiJ")));
// No time index column is expected here.
// NOTE(review): presumably the visitor stops before reaching the table
// scan (or the numbers table has no timestamp column) — confirm against
// the full `f_down` implementation.
assert!(finder.time_index_col.is_none());
}
}

View File

@@ -0,0 +1,56 @@
create table "HelloWorld" (a string, b timestamp time index);
Affected Rows: 0
insert into "HelloWorld" values ("a", 1) ,("b", 2);
Affected Rows: 2
select count(*) from "HelloWorld";
+----------+
| COUNT(*) |
+----------+
| 2 |
+----------+
create table test (a string, "BbB" timestamp time index);
Affected Rows: 0
insert into test values ("c", 1) ;
Affected Rows: 1
select count(*) from test;
+----------+
| COUNT(*) |
+----------+
| 1 |
+----------+
select count(*) from (select count(*) from test where a = 'a');
+----------+
| COUNT(*) |
+----------+
| 1 |
+----------+
select count(*) from (select * from test cross join "HelloWorld");
+----------+
| COUNT(*) |
+----------+
| 2 |
+----------+
drop table "HelloWorld";
Affected Rows: 0
drop table test;
Affected Rows: 0

View File

@@ -0,0 +1,19 @@
create table "HelloWorld" (a string, b timestamp time index);
insert into "HelloWorld" values ("a", 1) ,("b", 2);
select count(*) from "HelloWorld";
create table test (a string, "BbB" timestamp time index);
insert into test values ("c", 1) ;
select count(*) from test;
select count(*) from (select count(*) from test where a = 'a');
select count(*) from (select * from test cross join "HelloWorld");
drop table "HelloWorld";
drop table test;