fix: correct the schema used by TypeConversionRule (#2132)

* fix: correct the schema used by TypeConversionRule

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* specify time zone in UT

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

---------

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
This commit is contained in:
Ruihang Xia
2023-08-09 16:18:17 +08:00
committed by GitHub
parent 2ac51c6348
commit b156225b80
2 changed files with 69 additions and 38 deletions

View File

@@ -53,7 +53,7 @@ tokio.workspace = true
[dev-dependencies]
approx_eq = "0.1"
arrow.workspace = true
catalog = { workspace = true }
catalog = { workspace = true, features = ["testing"] }
common-function-macro = { workspace = true }
format_num = "0.1"
num = "0.4"

View File

@@ -34,14 +34,11 @@ use datatypes::arrow::datatypes::DataType;
pub struct TypeConversionRule;
impl AnalyzerRule for TypeConversionRule {
// TODO(ruihang): fix this warning
#[allow(deprecated)]
fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result<LogicalPlan> {
let schemas = plan.all_schemas().into_iter().cloned().collect::<Vec<_>>();
plan.transform(&|plan| match plan {
LogicalPlan::Filter(filter) => {
let mut converter = TypeConverter {
schemas: schemas.clone(),
schema: filter.input.schema().clone(),
};
let rewritten = filter.predicate.clone().rewrite(&mut converter)?;
Ok(Transformed::Yes(LogicalPlan::Filter(Filter::try_new(
@@ -58,7 +55,7 @@ impl AnalyzerRule for TypeConversionRule {
fetch,
}) => {
let mut converter = TypeConverter {
schemas: schemas.clone(),
schema: projected_schema.clone(),
};
let rewrite_filters = filters
.into_iter()
@@ -88,7 +85,7 @@ impl AnalyzerRule for TypeConversionRule {
| LogicalPlan::Values { .. }
| LogicalPlan::Analyze { .. } => {
let mut converter = TypeConverter {
schemas: plan.all_schemas().into_iter().cloned().collect(),
schema: plan.schema().clone(),
};
let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
let expr = plan
@@ -118,16 +115,14 @@ impl AnalyzerRule for TypeConversionRule {
}
struct TypeConverter {
schemas: Vec<DFSchemaRef>,
schema: DFSchemaRef,
}
impl TypeConverter {
fn column_type(&self, expr: &Expr) -> Option<DataType> {
if let Expr::Column(_) = expr {
for schema in &self.schemas {
if let Ok(v) = expr.get_type(schema) {
return Some(v);
}
if let Ok(v) = expr.get_type(&self.schema) {
return Some(v);
}
}
None
@@ -296,65 +291,67 @@ mod tests {
use std::collections::HashMap;
use std::sync::Arc;
use datafusion::logical_expr::expr::AggregateFunction as AggrExpr;
use datafusion_common::{Column, DFField, DFSchema};
use datafusion_expr::{AggregateFunction, LogicalPlanBuilder};
use datafusion_sql::TableReference;
use super::*;
#[test]
fn test_string_to_timestamp_ms() {
assert!(matches!(
assert_eq!(
string_to_timestamp_ms("2022-02-02 19:00:00+08:00").unwrap(),
ScalarValue::TimestampMillisecond(Some(1643799600000), None)
));
assert!(matches!(
);
assert_eq!(
string_to_timestamp_ms("2009-02-13 23:31:30Z").unwrap(),
ScalarValue::TimestampMillisecond(Some(1234567890000), None)
));
);
}
#[test]
fn test_timestamp_to_timestamp_ms_expr() {
assert!(matches!(
assert_eq!(
timestamp_to_timestamp_ms_expr(123, TimeUnit::Second),
Expr::Literal(ScalarValue::TimestampMillisecond(Some(123000), None))
));
);
assert!(matches!(
assert_eq!(
timestamp_to_timestamp_ms_expr(123, TimeUnit::Millisecond),
Expr::Literal(ScalarValue::TimestampMillisecond(Some(123), None))
));
);
assert!(matches!(
assert_eq!(
timestamp_to_timestamp_ms_expr(123, TimeUnit::Microsecond),
Expr::Literal(ScalarValue::TimestampMillisecond(Some(0), None))
));
);
assert!(matches!(
assert_eq!(
timestamp_to_timestamp_ms_expr(1230, TimeUnit::Microsecond),
Expr::Literal(ScalarValue::TimestampMillisecond(Some(1), None))
));
);
assert!(matches!(
assert_eq!(
timestamp_to_timestamp_ms_expr(123000, TimeUnit::Microsecond),
Expr::Literal(ScalarValue::TimestampMillisecond(Some(123), None))
));
);
assert!(matches!(
assert_eq!(
timestamp_to_timestamp_ms_expr(1230, TimeUnit::Nanosecond),
Expr::Literal(ScalarValue::TimestampMillisecond(Some(0), None))
));
assert!(matches!(
);
assert_eq!(
timestamp_to_timestamp_ms_expr(123_000_000, TimeUnit::Nanosecond),
Expr::Literal(ScalarValue::TimestampMillisecond(Some(123), None))
));
);
}
#[test]
fn test_convert_timestamp_str() {
use datatypes::arrow::datatypes::TimeUnit as ArrowTimeUnit;
let schema_ref = Arc::new(
let schema = Arc::new(
DFSchema::new_with_metadata(
vec![DFField::new(
None::<TableReference>,
@@ -366,9 +363,7 @@ mod tests {
)
.unwrap(),
);
let mut converter = TypeConverter {
schemas: vec![schema_ref],
};
let mut converter = TypeConverter { schema };
assert_eq!(
Expr::Column(Column::from_name("ts")).gt(Expr::Literal(
@@ -387,7 +382,7 @@ mod tests {
#[test]
fn test_convert_bool() {
let col_name = "is_valid";
let schema_ref = Arc::new(
let schema = Arc::new(
DFSchema::new_with_metadata(
vec![DFField::new(
None::<TableReference>,
@@ -399,9 +394,7 @@ mod tests {
)
.unwrap(),
);
let mut converter = TypeConverter {
schemas: vec![schema_ref],
};
let mut converter = TypeConverter { schema };
assert_eq!(
Expr::Column(Column::from_name(col_name))
@@ -414,4 +407,42 @@ mod tests {
.unwrap()
);
}
#[test]
fn test_retrieve_type_from_aggr_plan() {
let plan =
LogicalPlanBuilder::values(vec![vec![
Expr::Literal(ScalarValue::Int64(Some(1))),
Expr::Literal(ScalarValue::Float64(Some(1.0))),
Expr::Literal(ScalarValue::TimestampMillisecond(Some(1), None)),
]])
.unwrap()
.filter(Expr::Column(Column::from_name("column3")).gt(Expr::Literal(
ScalarValue::Utf8(Some("1970-01-01 00:00:00+08:00".to_string())),
)))
.unwrap()
.aggregate(
Vec::<Expr>::new(),
vec![Expr::AggregateFunction(AggrExpr {
fun: AggregateFunction::Count,
args: vec![Expr::Column(Column::from_name("column1"))],
distinct: false,
filter: None,
order_by: None,
})],
)
.unwrap()
.build()
.unwrap();
let transformed_plan = TypeConversionRule
.analyze(plan, &ConfigOptions::default())
.unwrap();
let expected = String::from(
"Aggregate: groupBy=[[]], aggr=[[COUNT(column1)]]\
\n Filter: column3 > TimestampMillisecond(-28800000, None)\
\n Values: (Int64(1), Float64(1), TimestampMillisecond(1, None))",
);
assert_eq!(format!("{}", transformed_plan.display_indent()), expected);
}
}