From 9233400f26d240b2c6cb75c0f4364879ced92a39 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Sat, 4 Apr 2026 07:59:13 +0800 Subject: [PATCH] refactor using common-expr Signed-off-by: Ruihang Xia --- Cargo.lock | 1 + Cargo.toml | 2 + src/query/Cargo.toml | 1 + .../src/optimizer/const_normalization.rs | 134 ++++++++++-------- 4 files changed, 76 insertions(+), 62 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 695f19b072..f0707e4e7f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10650,6 +10650,7 @@ dependencies = [ "datafusion", "datafusion-common", "datafusion-expr", + "datafusion-expr-common", "datafusion-functions", "datafusion-optimizer", "datafusion-physical-expr", diff --git a/Cargo.toml b/Cargo.toml index 5041f167c3..88fa323163 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -131,6 +131,7 @@ datafusion = "=52.1" datafusion-common = "=52.1" datafusion-datasource = "=52.1" datafusion-expr = "=52.1" +datafusion-expr-common = "=52.1" datafusion-functions = "=52.1" datafusion-functions-aggregate-common = "=52.1" datafusion-functions-window-common = "=52.1" @@ -334,6 +335,7 @@ rev = "5618e779cf2bb4755b499c630fba4c35e91898cb" datafusion = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" } datafusion-common = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" } datafusion-expr = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" } +datafusion-expr-common = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" } datafusion-functions = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" } datafusion-functions-aggregate-common = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" } datafusion-functions-window-common = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" } diff --git a/src/query/Cargo.toml b/src/query/Cargo.toml index af81325aa9..126b47ab63 100644 --- a/src/query/Cargo.toml +++ b/src/query/Cargo.toml @@ -40,6 +40,7 @@ common-time.workspace = true datafusion.workspace = true datafusion-common.workspace = true datafusion-expr.workspace = true +datafusion-expr-common.workspace = true datafusion-functions.workspace = true datafusion-optimizer.workspace = true datafusion-physical-expr.workspace = true diff --git a/src/query/src/optimizer/const_normalization.rs b/src/query/src/optimizer/const_normalization.rs index 871c88c921..3a110a4599 100644 --- a/src/query/src/optimizer/const_normalization.rs +++ b/src/query/src/optimizer/const_normalization.rs @@ -16,12 +16,13 @@ use std::sync::Arc; use arrow_schema::{DataType, TimeUnit as ArrowTimeUnit}; use datafusion::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{Cast, InList, TryCast}; use datafusion_expr::{ - Between, BinaryExpr, Expr, ExprSchemable, Filter, LogicalPlan, Operator, TableScan, + Between, BinaryExpr, Expr, ExprSchemable, Filter, LogicalPlan, Operator, TableScan, lit, }; +use datafusion_expr_common::casts::try_cast_literal_to_type; use datafusion_optimizer::analyzer::AnalyzerRule; /// ConstNormalizationRule rewrites filter literals to the column-side type @@ -83,42 +84,34 @@ fn normalize_scan_plan(scan: TableScan) -> Result> { } fn normalize_filter_expr(expr: Expr, schema: DFSchemaRef) -> Result { - expr.rewrite(&mut ConstNormalizer { schema }) + expr.transform_up(|expr| normalize_expr_node(expr, &schema)) .map(|x| x.data) } -struct ConstNormalizer { - schema: DFSchemaRef, -} - -impl TreeNodeRewriter for ConstNormalizer { - type Node = Expr; - - fn f_up(&mut self, expr: Expr) -> Result> { - let original = expr.clone(); - let new_expr = match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - normalize_binary_expr(*left, op, *right, &self.schema)? - } - Expr::Between(Between { - expr, - negated, - low, - high, - }) => normalize_between_expr(*expr, negated, *low, *high, &self.schema)?, - Expr::InList(InList { - expr, - list, - negated, - }) => normalize_in_list_expr(*expr, list, negated, &self.schema)?, - expr => expr, - }; - - if new_expr != original { - Ok(Transformed::yes(new_expr)) - } else { - Ok(Transformed::no(new_expr)) +fn normalize_expr_node(expr: Expr, schema: &DFSchemaRef) -> Result> { + let original = expr.clone(); + let new_expr = match expr { + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + normalize_binary_expr(*left, op, *right, schema)? } + Expr::Between(Between { + expr, + negated, + low, + high, + }) => normalize_between_expr(*expr, negated, *low, *high, schema)?, + Expr::InList(InList { + expr, + list, + negated, + }) => normalize_in_list_expr(*expr, list, negated, schema)?, + expr => expr, + }; + + if new_expr != original { + Ok(Transformed::yes(new_expr)) + } else { + Ok(Transformed::no(new_expr)) } } @@ -214,8 +207,8 @@ fn normalize_between_expr( return Ok(Expr::Between(Between { expr: Box::new(target.source_expr.clone()), negated, - low: Box::new(Expr::Literal(low, None)), - high: Box::new(Expr::Literal(high, None)), + low: Box::new(lit(low)), + high: Box::new(lit(high)), })); } @@ -261,7 +254,7 @@ fn normalize_in_list_expr( negated, })); }; - new_list.push(Expr::Literal(normalized, None)); + new_list.push(lit(normalized)); } Ok(Expr::InList(InList { @@ -292,7 +285,7 @@ fn extract_normalization_target( expr: &Expr, schema: &DFSchemaRef, ) -> Result> { - if matches!(expr, Expr::Column(_)) { + if expr.try_as_col().is_some() { return Ok(Some(NormalizationTarget { source_expr: expr.clone(), source_type: expr.get_type(schema)?, @@ -300,14 +293,11 @@ fn extract_normalization_target( })); } - let (source_expr, target_type) = match expr { - Expr::Cast(Cast { expr, data_type }) | Expr::TryCast(TryCast { expr, data_type }) => { - (expr.as_ref(), data_type) - } - _ => return Ok(None), + let Some((_, source_expr, target_type)) = extract_cast_input(expr) else { + return Ok(None); }; - if !matches!(source_expr, Expr::Column(_)) { + if source_expr.try_as_col().is_none() { return Ok(None); } @@ -370,7 +360,7 @@ fn normalize_lossless_binary( Some(Expr::BinaryExpr(BinaryExpr { left: Box::new(target.source_expr.clone()), op, - right: Box::new(Expr::Literal(normalized, None)), + right: Box::new(lit(normalized)), })) } @@ -418,10 +408,7 @@ fn normalize_timestamp_downcast_binary( Some(Expr::BinaryExpr(BinaryExpr { left: Box::new(target.source_expr.clone()), op: normalized_op, - right: Box::new(Expr::Literal( - timestamp_scalar(*source_unit, timezone.clone(), bound), - None, - )), + right: Box::new(lit(timestamp_scalar(*source_unit, timezone.clone(), bound))), })) } @@ -452,41 +439,64 @@ fn normalize_timestamp_downcast_between( Expr::BinaryExpr(BinaryExpr { left: Box::new(target.source_expr.clone()), op: Operator::GtEq, - right: Box::new(Expr::Literal( - timestamp_scalar(*source_unit, timezone.clone(), lower), - None, - )), + right: Box::new(lit(timestamp_scalar(*source_unit, timezone.clone(), lower))), }) .and(Expr::BinaryExpr(BinaryExpr { left: Box::new(target.source_expr.clone()), op: Operator::Lt, - right: Box::new(Expr::Literal( - timestamp_scalar(*source_unit, timezone.clone(), upper), - None, - )), + right: Box::new(lit(timestamp_scalar(*source_unit, timezone.clone(), upper))), })), ) } fn extract_constant_scalar(expr: &Expr) -> Result> { - match expr { - Expr::Literal(value, _) => Ok(Some(value.clone())), - Expr::Cast(Cast { expr, data_type }) => extract_constant_scalar(expr)? + if let Some(value) = expr.as_literal() { + return Ok(Some(value.clone())); + } + + let Some((kind, expr, data_type)) = extract_cast_input(expr) else { + return Ok(None); + }; + + match kind { + CastInputKind::Cast => extract_constant_scalar(expr)? .map(|value| value.cast_to(data_type)) .transpose(), - Expr::TryCast(TryCast { expr, data_type }) => { + CastInputKind::TryCast => { Ok(extract_constant_scalar(expr)?.and_then(|value| value.cast_to(data_type).ok())) } - _ => Ok(None), } } fn cast_literal_losslessly(value: &ScalarValue, target_type: &DataType) -> Option { + try_cast_literal_to_type(value, target_type) + .or_else(|| cast_literal_by_round_trip(value, target_type)) +} + +fn cast_literal_by_round_trip(value: &ScalarValue, target_type: &DataType) -> Option { let casted = value.cast_to(target_type).ok()?; let round_trip = casted.cast_to(&value.data_type()).ok()?; (round_trip == *value).then_some(casted) } +#[derive(Clone, Copy)] +enum CastInputKind { + Cast, + TryCast, +} + +fn extract_cast_input(expr: &Expr) -> Option<(CastInputKind, &Expr, &DataType)> { + match expr { + Expr::Cast(Cast { expr, data_type }) => { + Some((CastInputKind::Cast, expr.as_ref(), data_type)) + } + Expr::TryCast(TryCast { expr, data_type }) => { + Some((CastInputKind::TryCast, expr.as_ref(), data_type)) + } + _ => None, + } +} + fn time_unit_rank(unit: ArrowTimeUnit) -> usize { match unit { ArrowTimeUnit::Second => 0,