refactor using common-expr

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
This commit is contained in:
Ruihang Xia
2026-04-04 07:59:13 +08:00
parent 444a00c866
commit 9233400f26
4 changed files with 76 additions and 62 deletions

1
Cargo.lock generated
View File

@@ -10650,6 +10650,7 @@ dependencies = [
"datafusion",
"datafusion-common",
"datafusion-expr",
"datafusion-expr-common",
"datafusion-functions",
"datafusion-optimizer",
"datafusion-physical-expr",

View File

@@ -131,6 +131,7 @@ datafusion = "=52.1"
datafusion-common = "=52.1"
datafusion-datasource = "=52.1"
datafusion-expr = "=52.1"
datafusion-expr-common = "=52.1"
datafusion-functions = "=52.1"
datafusion-functions-aggregate-common = "=52.1"
datafusion-functions-window-common = "=52.1"
@@ -334,6 +335,7 @@ rev = "5618e779cf2bb4755b499c630fba4c35e91898cb"
datafusion = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" }
datafusion-common = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" }
datafusion-expr = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" }
datafusion-expr-common = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" }
datafusion-functions = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" }
datafusion-functions-aggregate-common = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" }
datafusion-functions-window-common = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" }

View File

@@ -40,6 +40,7 @@ common-time.workspace = true
datafusion.workspace = true
datafusion-common.workspace = true
datafusion-expr.workspace = true
datafusion-expr-common.workspace = true
datafusion-functions.workspace = true
datafusion-optimizer.workspace = true
datafusion-physical-expr.workspace = true

View File

@@ -16,12 +16,13 @@ use std::sync::Arc;
use arrow_schema::{DataType, TimeUnit as ArrowTimeUnit};
use datafusion::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{DFSchemaRef, Result, ScalarValue};
use datafusion_expr::expr::{Cast, InList, TryCast};
use datafusion_expr::{
Between, BinaryExpr, Expr, ExprSchemable, Filter, LogicalPlan, Operator, TableScan,
Between, BinaryExpr, Expr, ExprSchemable, Filter, LogicalPlan, Operator, TableScan, lit,
};
use datafusion_expr_common::casts::try_cast_literal_to_type;
use datafusion_optimizer::analyzer::AnalyzerRule;
/// ConstNormalizationRule rewrites filter literals to the column-side type
@@ -83,42 +84,34 @@ fn normalize_scan_plan(scan: TableScan) -> Result<Transformed<LogicalPlan>> {
}
fn normalize_filter_expr(expr: Expr, schema: DFSchemaRef) -> Result<Expr> {
expr.rewrite(&mut ConstNormalizer { schema })
expr.transform_up(|expr| normalize_expr_node(expr, &schema))
.map(|x| x.data)
}
struct ConstNormalizer {
schema: DFSchemaRef,
}
impl TreeNodeRewriter for ConstNormalizer {
type Node = Expr;
fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
let original = expr.clone();
let new_expr = match expr {
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
normalize_binary_expr(*left, op, *right, &self.schema)?
}
Expr::Between(Between {
expr,
negated,
low,
high,
}) => normalize_between_expr(*expr, negated, *low, *high, &self.schema)?,
Expr::InList(InList {
expr,
list,
negated,
}) => normalize_in_list_expr(*expr, list, negated, &self.schema)?,
expr => expr,
};
if new_expr != original {
Ok(Transformed::yes(new_expr))
} else {
Ok(Transformed::no(new_expr))
fn normalize_expr_node(expr: Expr, schema: &DFSchemaRef) -> Result<Transformed<Expr>> {
let original = expr.clone();
let new_expr = match expr {
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
normalize_binary_expr(*left, op, *right, schema)?
}
Expr::Between(Between {
expr,
negated,
low,
high,
}) => normalize_between_expr(*expr, negated, *low, *high, schema)?,
Expr::InList(InList {
expr,
list,
negated,
}) => normalize_in_list_expr(*expr, list, negated, schema)?,
expr => expr,
};
if new_expr != original {
Ok(Transformed::yes(new_expr))
} else {
Ok(Transformed::no(new_expr))
}
}
@@ -214,8 +207,8 @@ fn normalize_between_expr(
return Ok(Expr::Between(Between {
expr: Box::new(target.source_expr.clone()),
negated,
low: Box::new(Expr::Literal(low, None)),
high: Box::new(Expr::Literal(high, None)),
low: Box::new(lit(low)),
high: Box::new(lit(high)),
}));
}
@@ -261,7 +254,7 @@ fn normalize_in_list_expr(
negated,
}));
};
new_list.push(Expr::Literal(normalized, None));
new_list.push(lit(normalized));
}
Ok(Expr::InList(InList {
@@ -292,7 +285,7 @@ fn extract_normalization_target(
expr: &Expr,
schema: &DFSchemaRef,
) -> Result<Option<NormalizationTarget>> {
if matches!(expr, Expr::Column(_)) {
if expr.try_as_col().is_some() {
return Ok(Some(NormalizationTarget {
source_expr: expr.clone(),
source_type: expr.get_type(schema)?,
@@ -300,14 +293,11 @@ fn extract_normalization_target(
}));
}
let (source_expr, target_type) = match expr {
Expr::Cast(Cast { expr, data_type }) | Expr::TryCast(TryCast { expr, data_type }) => {
(expr.as_ref(), data_type)
}
_ => return Ok(None),
let Some((_, source_expr, target_type)) = extract_cast_input(expr) else {
return Ok(None);
};
if !matches!(source_expr, Expr::Column(_)) {
if source_expr.try_as_col().is_none() {
return Ok(None);
}
@@ -370,7 +360,7 @@ fn normalize_lossless_binary(
Some(Expr::BinaryExpr(BinaryExpr {
left: Box::new(target.source_expr.clone()),
op,
right: Box::new(Expr::Literal(normalized, None)),
right: Box::new(lit(normalized)),
}))
}
@@ -418,10 +408,7 @@ fn normalize_timestamp_downcast_binary(
Some(Expr::BinaryExpr(BinaryExpr {
left: Box::new(target.source_expr.clone()),
op: normalized_op,
right: Box::new(Expr::Literal(
timestamp_scalar(*source_unit, timezone.clone(), bound),
None,
)),
right: Box::new(lit(timestamp_scalar(*source_unit, timezone.clone(), bound))),
}))
}
@@ -452,41 +439,64 @@ fn normalize_timestamp_downcast_between(
Expr::BinaryExpr(BinaryExpr {
left: Box::new(target.source_expr.clone()),
op: Operator::GtEq,
right: Box::new(Expr::Literal(
timestamp_scalar(*source_unit, timezone.clone(), lower),
None,
)),
right: Box::new(lit(timestamp_scalar(*source_unit, timezone.clone(), lower))),
})
.and(Expr::BinaryExpr(BinaryExpr {
left: Box::new(target.source_expr.clone()),
op: Operator::Lt,
right: Box::new(Expr::Literal(
timestamp_scalar(*source_unit, timezone.clone(), upper),
None,
)),
right: Box::new(lit(timestamp_scalar(*source_unit, timezone.clone(), upper))),
})),
)
}
fn extract_constant_scalar(expr: &Expr) -> Result<Option<ScalarValue>> {
match expr {
Expr::Literal(value, _) => Ok(Some(value.clone())),
Expr::Cast(Cast { expr, data_type }) => extract_constant_scalar(expr)?
if let Some(value) = expr.as_literal() {
return Ok(Some(value.clone()));
}
let Some((kind, expr, data_type)) = extract_cast_input(expr) else {
return Ok(None);
};
match kind {
CastInputKind::Cast => extract_constant_scalar(expr)?
.map(|value| value.cast_to(data_type))
.transpose(),
Expr::TryCast(TryCast { expr, data_type }) => {
CastInputKind::TryCast => {
Ok(extract_constant_scalar(expr)?.and_then(|value| value.cast_to(data_type).ok()))
}
_ => Ok(None),
}
}
/// Converts `value` to `target_type`, returning `None` when no conversion
/// accepted by either strategy below exists.
///
/// The shared `try_cast_literal_to_type` helper is consulted first; only when
/// it yields `None` is the local round-trip check used as a fallback.
fn cast_literal_losslessly(value: &ScalarValue, target_type: &DataType) -> Option<ScalarValue> {
    if let Some(converted) = try_cast_literal_to_type(value, target_type) {
        return Some(converted);
    }

    // The shared helper declined this literal; fall back to cast-and-compare.
    cast_literal_by_round_trip(value, target_type)
}
/// Casts `value` to `target_type` and accepts the result only if casting it
/// back to the original data type reproduces `value` exactly.
///
/// Returns `None` when either cast fails or when the round-tripped value
/// differs from the input.
fn cast_literal_by_round_trip(value: &ScalarValue, target_type: &DataType) -> Option<ScalarValue> {
    let Ok(forward) = value.cast_to(target_type) else {
        return None;
    };
    let Ok(restored) = forward.cast_to(&value.data_type()) else {
        return None;
    };

    // Equality after the return trip is the acceptance criterion.
    if restored == *value {
        Some(forward)
    } else {
        None
    }
}
/// Records which cast expression variant `extract_cast_input` matched, so
/// callers can treat the two variants differently.
#[derive(Copy, Clone)]
enum CastInputKind {
    /// The expression was an `Expr::Cast`.
    Cast,
    /// The expression was an `Expr::TryCast`.
    TryCast,
}
/// Splits a cast-like expression into its parts.
///
/// For `Expr::Cast` and `Expr::TryCast` this returns the matched kind, a
/// borrow of the inner expression being cast, and a borrow of the cast's
/// target data type. Any other expression yields `None`.
fn extract_cast_input(expr: &Expr) -> Option<(CastInputKind, &Expr, &DataType)> {
    let parts = match expr {
        Expr::Cast(cast) => (CastInputKind::Cast, &*cast.expr, &cast.data_type),
        Expr::TryCast(try_cast) => (
            CastInputKind::TryCast,
            &*try_cast.expr,
            &try_cast.data_type,
        ),
        _ => return None,
    };

    Some(parts)
}
fn time_unit_rank(unit: ArrowTimeUnit) -> usize {
match unit {
ArrowTimeUnit::Second => 0,