From b156225b808f4e82be5debf4d2463cf6e801110a Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Wed, 9 Aug 2023 16:18:17 +0800 Subject: [PATCH] fix: correct the schema used by TypeConversionRule (#2132) * fix: correct the schema used by TypeConversionRule Signed-off-by: Ruihang Xia * specify time zone in UT Signed-off-by: Ruihang Xia --------- Signed-off-by: Ruihang Xia --- src/query/Cargo.toml | 2 +- src/query/src/optimizer/type_conversion.rs | 105 +++++++++++++-------- 2 files changed, 69 insertions(+), 38 deletions(-) diff --git a/src/query/Cargo.toml b/src/query/Cargo.toml index 0f0fa3a11f..1667622b5d 100644 --- a/src/query/Cargo.toml +++ b/src/query/Cargo.toml @@ -53,7 +53,7 @@ tokio.workspace = true [dev-dependencies] approx_eq = "0.1" arrow.workspace = true -catalog = { workspace = true } +catalog = { workspace = true, features = ["testing"] } common-function-macro = { workspace = true } format_num = "0.1" num = "0.4" diff --git a/src/query/src/optimizer/type_conversion.rs b/src/query/src/optimizer/type_conversion.rs index 3dda517451..64f999b50f 100644 --- a/src/query/src/optimizer/type_conversion.rs +++ b/src/query/src/optimizer/type_conversion.rs @@ -34,14 +34,11 @@ use datatypes::arrow::datatypes::DataType; pub struct TypeConversionRule; impl AnalyzerRule for TypeConversionRule { - // TODO(ruihang): fix this warning - #[allow(deprecated)] fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result { - let schemas = plan.all_schemas().into_iter().cloned().collect::>(); plan.transform(&|plan| match plan { LogicalPlan::Filter(filter) => { let mut converter = TypeConverter { - schemas: schemas.clone(), + schema: filter.input.schema().clone(), }; let rewritten = filter.predicate.clone().rewrite(&mut converter)?; Ok(Transformed::Yes(LogicalPlan::Filter(Filter::try_new( @@ -58,7 +55,7 @@ impl AnalyzerRule for TypeConversionRule { fetch, }) => { let mut converter = TypeConverter { - schemas: schemas.clone(), + schema: projected_schema.clone(), }; let rewrite_filters = filters .into_iter() @@ -88,7 +85,7 @@ impl AnalyzerRule for TypeConversionRule { | LogicalPlan::Values { .. } | LogicalPlan::Analyze { .. } => { let mut converter = TypeConverter { - schemas: plan.all_schemas().into_iter().cloned().collect(), + schema: plan.schema().clone(), }; let inputs = plan.inputs().into_iter().cloned().collect::>(); let expr = plan @@ -118,16 +115,14 @@ impl AnalyzerRule for TypeConversionRule { } struct TypeConverter { - schemas: Vec, + schema: DFSchemaRef, } impl TypeConverter { fn column_type(&self, expr: &Expr) -> Option { if let Expr::Column(_) = expr { - for schema in &self.schemas { - if let Ok(v) = expr.get_type(schema) { - return Some(v); - } + if let Ok(v) = expr.get_type(&self.schema) { + return Some(v); } } None @@ -296,65 +291,67 @@ mod tests { use std::collections::HashMap; use std::sync::Arc; + use datafusion::logical_expr::expr::AggregateFunction as AggrExpr; use datafusion_common::{Column, DFField, DFSchema}; + use datafusion_expr::{AggregateFunction, LogicalPlanBuilder}; use datafusion_sql::TableReference; use super::*; #[test] fn test_string_to_timestamp_ms() { - assert!(matches!( + assert_eq!( string_to_timestamp_ms("2022-02-02 19:00:00+08:00").unwrap(), ScalarValue::TimestampMillisecond(Some(1643799600000), None) - )); - assert!(matches!( + ); + assert_eq!( string_to_timestamp_ms("2009-02-13 23:31:30Z").unwrap(), ScalarValue::TimestampMillisecond(Some(1234567890000), None) - )); + ); } #[test] fn test_timestamp_to_timestamp_ms_expr() { - assert!(matches!( + assert_eq!( timestamp_to_timestamp_ms_expr(123, TimeUnit::Second), Expr::Literal(ScalarValue::TimestampMillisecond(Some(123000), None)) - )); + ); - assert!(matches!( + assert_eq!( timestamp_to_timestamp_ms_expr(123, TimeUnit::Millisecond), Expr::Literal(ScalarValue::TimestampMillisecond(Some(123), None)) - )); + ); - assert!(matches!( + assert_eq!( timestamp_to_timestamp_ms_expr(123, TimeUnit::Microsecond), Expr::Literal(ScalarValue::TimestampMillisecond(Some(0), None)) - )); + ); - assert!(matches!( + assert_eq!( timestamp_to_timestamp_ms_expr(1230, TimeUnit::Microsecond), Expr::Literal(ScalarValue::TimestampMillisecond(Some(1), None)) - )); + ); - assert!(matches!( + assert_eq!( timestamp_to_timestamp_ms_expr(123000, TimeUnit::Microsecond), Expr::Literal(ScalarValue::TimestampMillisecond(Some(123), None)) - )); + ); - assert!(matches!( + assert_eq!( timestamp_to_timestamp_ms_expr(1230, TimeUnit::Nanosecond), Expr::Literal(ScalarValue::TimestampMillisecond(Some(0), None)) - )); - assert!(matches!( + ); + assert_eq!( timestamp_to_timestamp_ms_expr(123_000_000, TimeUnit::Nanosecond), Expr::Literal(ScalarValue::TimestampMillisecond(Some(123), None)) - )); + ); } #[test] fn test_convert_timestamp_str() { use datatypes::arrow::datatypes::TimeUnit as ArrowTimeUnit; - let schema_ref = Arc::new( + let schema = Arc::new( DFSchema::new_with_metadata( vec![DFField::new( None::, @@ -366,9 +363,7 @@ mod tests { ) .unwrap(), ); - let mut converter = TypeConverter { - schemas: vec![schema_ref], - }; + let mut converter = TypeConverter { schema }; assert_eq!( Expr::Column(Column::from_name("ts")).gt(Expr::Literal( @@ -387,7 +382,7 @@ mod tests { #[test] fn test_convert_bool() { let col_name = "is_valid"; - let schema_ref = Arc::new( + let schema = Arc::new( DFSchema::new_with_metadata( vec![DFField::new( None::, @@ -399,9 +394,7 @@ mod tests { ) .unwrap(), ); - let mut converter = TypeConverter { - schemas: vec![schema_ref], - }; + let mut converter = TypeConverter { schema }; assert_eq!( Expr::Column(Column::from_name(col_name)) @@ -414,4 +407,42 @@ mod tests { .unwrap() ); } + + #[test] + fn test_retrieve_type_from_aggr_plan() { + let plan = + LogicalPlanBuilder::values(vec![vec![ + Expr::Literal(ScalarValue::Int64(Some(1))), + Expr::Literal(ScalarValue::Float64(Some(1.0))), + Expr::Literal(ScalarValue::TimestampMillisecond(Some(1), None)), + ]]) + .unwrap() + .filter(Expr::Column(Column::from_name("column3")).gt(Expr::Literal( + ScalarValue::Utf8(Some("1970-01-01 00:00:00+08:00".to_string())), + ))) + .unwrap() + .aggregate( + Vec::::new(), + vec![Expr::AggregateFunction(AggrExpr { + fun: AggregateFunction::Count, + args: vec![Expr::Column(Column::from_name("column1"))], + distinct: false, + filter: None, + order_by: None, + })], + ) + .unwrap() + .build() + .unwrap(); + + let transformed_plan = TypeConversionRule + .analyze(plan, &ConfigOptions::default()) + .unwrap(); + let expected = String::from( + "Aggregate: groupBy=[[]], aggr=[[COUNT(column1)]]\ + \n Filter: column3 > TimestampMillisecond(-28800000, None)\ + \n Values: (Int64(1), Float64(1), TimestampMillisecond(1, None))", + ); + assert_eq!(format!("{}", transformed_plan.display_indent()), expected); + } }