handle no cast

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
This commit is contained in:
Ruihang Xia
2026-04-04 03:50:44 +08:00
parent b6cd91f446
commit 444a00c866
4 changed files with 213 additions and 44 deletions

View File

@@ -24,9 +24,9 @@ use datafusion_expr::{
};
use datafusion_optimizer::analyzer::AnalyzerRule;
/// ConstNormalizationRule removes casts on column operands in filters when the
/// constant side can be normalized to the source column type ahead of filter
/// pushdown.
/// ConstNormalizationRule rewrites filter literals to the column-side type
/// ahead of filter pushdown, and also removes lossless casts on column
/// operands when the literal can be normalized to the source column type.
#[derive(Debug)]
pub struct ConstNormalizationRule;
@@ -128,15 +128,13 @@ fn normalize_binary_expr(
right: Expr,
schema: &DFSchemaRef,
) -> Result<Expr> {
if let Some(expr) =
normalize_binary_with_casted_column(left.clone(), op, right.clone(), schema)?
{
if let Some(expr) = normalize_binary_with_target(left.clone(), op, right.clone(), schema)? {
return Ok(expr);
}
if let Some(swapped_op) = op.swap()
&& let Some(expr) =
normalize_binary_with_casted_column(right.clone(), swapped_op, left.clone(), schema)?
normalize_binary_with_target(right.clone(), swapped_op, left.clone(), schema)?
{
return Ok(expr);
}
@@ -148,28 +146,24 @@ fn normalize_binary_expr(
}))
}
/// Tries to rewrite `<column-side operand> <op> <constant>` so that the
/// constant is expressed in the source column's type.
///
/// Returns `Ok(None)` when the first operand is not recognized by the
/// extractor or when the second operand does not yield a constant scalar, so
/// the caller can keep the original expression. Otherwise the lossless
/// rewrite is attempted first and the timestamp-downcast rewrite is used as
/// the fallback (which may itself yield `None`).
fn normalize_binary_with_casted_column(
casted_column: Expr,
fn normalize_binary_with_target(
target: Expr,
op: Operator,
constant: Expr,
schema: &DFSchemaRef,
) -> Result<Option<Expr>> {
let Some(casted_column) = extract_casted_column(&casted_column, schema)? else {
let Some(target) = extract_normalization_target(&target, schema)? else {
return Ok(None);
};
let Some(constant) = extract_constant_scalar(&constant)? else {
return Ok(None);
};
// Preferred path: the literal can be re-expressed in the source type as-is.
if let Some(expr) = normalize_lossless_binary(&casted_column, op, &constant) {
if let Some(expr) = normalize_lossless_binary(&target, op, &constant) {
return Ok(Some(expr));
}
// Fallback: timestamp downcast handling; `None` here means "no rewrite".
Ok(normalize_timestamp_downcast_binary(
&casted_column,
op,
&constant,
))
Ok(normalize_timestamp_downcast_binary(&target, op, &constant))
}
fn normalize_between_expr(
@@ -188,7 +182,7 @@ fn normalize_between_expr(
}));
}
let Some(casted_column) = extract_casted_column(&expr, schema)? else {
let Some(target) = extract_normalization_target(&expr, schema)? else {
return Ok(Expr::Between(Between {
expr: Box::new(expr),
negated,
@@ -214,20 +208,18 @@ fn normalize_between_expr(
};
if let (Some(low), Some(high)) = (
normalize_lossless_literal(&casted_column, &low_value),
normalize_lossless_literal(&casted_column, &high_value),
normalize_lossless_literal(&target, &low_value),
normalize_lossless_literal(&target, &high_value),
) {
return Ok(Expr::Between(Between {
expr: Box::new(casted_column.source_expr.clone()),
expr: Box::new(target.source_expr.clone()),
negated,
low: Box::new(Expr::Literal(low, None)),
high: Box::new(Expr::Literal(high, None)),
}));
}
if let Some(expr) =
normalize_timestamp_downcast_between(&casted_column, &low_value, &high_value)
{
if let Some(expr) = normalize_timestamp_downcast_between(&target, &low_value, &high_value) {
return Ok(expr);
}
@@ -245,7 +237,7 @@ fn normalize_in_list_expr(
negated: bool,
schema: &DFSchemaRef,
) -> Result<Expr> {
let Some(casted_column) = extract_casted_column(&expr, schema)? else {
let Some(target) = extract_normalization_target(&expr, schema)? else {
return Ok(Expr::InList(InList {
expr: Box::new(expr),
list,
@@ -262,7 +254,7 @@ fn normalize_in_list_expr(
negated,
}));
};
let Some(normalized) = normalize_lossless_literal(&casted_column, &value) else {
let Some(normalized) = normalize_lossless_literal(&target, &value) else {
return Ok(Expr::InList(InList {
expr: Box::new(expr),
list,
@@ -273,14 +265,14 @@ fn normalize_in_list_expr(
}
Ok(Expr::InList(InList {
expr: Box::new(casted_column.source_expr.clone()),
expr: Box::new(target.source_expr.clone()),
list: new_list,
negated,
}))
}
#[derive(Clone)]
struct CastedColumn {
struct NormalizationTarget {
source_expr: Expr,
source_type: DataType,
kind: CastKind,
@@ -296,7 +288,18 @@ enum CastKind {
},
}
fn extract_casted_column(expr: &Expr, schema: &DFSchemaRef) -> Result<Option<CastedColumn>> {
fn extract_normalization_target(
expr: &Expr,
schema: &DFSchemaRef,
) -> Result<Option<NormalizationTarget>> {
if matches!(expr, Expr::Column(_)) {
return Ok(Some(NormalizationTarget {
source_expr: expr.clone(),
source_type: expr.get_type(schema)?,
kind: CastKind::Lossless,
}));
}
let (source_expr, target_type) = match expr {
Expr::Cast(Cast { expr, data_type }) | Expr::TryCast(TryCast { expr, data_type }) => {
(expr.as_ref(), data_type)
@@ -313,7 +316,7 @@ fn extract_casted_column(expr: &Expr, schema: &DFSchemaRef) -> Result<Option<Cas
return Ok(None);
};
Ok(Some(CastedColumn {
Ok(Some(NormalizationTarget {
source_expr: source_expr.clone(),
source_type,
kind,
@@ -359,29 +362,29 @@ fn is_lossless_column_cast(source_type: &DataType, target_type: &DataType) -> bo
}
/// Builds `<source expr> <op> <normalized literal>` for the lossless case.
///
/// Returns `None` when `normalize_lossless_literal` cannot produce a literal
/// for `constant`; otherwise the source expression becomes the left operand,
/// `op` is kept unchanged, and the normalized literal becomes the right
/// operand.
fn normalize_lossless_binary(
casted_column: &CastedColumn,
target: &NormalizationTarget,
op: Operator,
constant: &ScalarValue,
) -> Option<Expr> {
let normalized = normalize_lossless_literal(casted_column, constant)?;
let normalized = normalize_lossless_literal(target, constant)?;
Some(Expr::BinaryExpr(BinaryExpr {
left: Box::new(casted_column.source_expr.clone()),
left: Box::new(target.source_expr.clone()),
op,
right: Box::new(Expr::Literal(normalized, None)),
}))
}
/// Converts `constant` to the source column type, but only for operands whose
/// kind is `CastKind::Lossless`.
///
/// Any other kind yields `None` without attempting a conversion; for the
/// lossless kind the result is whatever `cast_literal_losslessly` returns for
/// the source type.
fn normalize_lossless_literal(
casted_column: &CastedColumn,
target: &NormalizationTarget,
constant: &ScalarValue,
) -> Option<ScalarValue> {
matches!(casted_column.kind, CastKind::Lossless)
matches!(target.kind, CastKind::Lossless)
.then_some(())
.and_then(|_| cast_literal_losslessly(constant, &casted_column.source_type))
.and_then(|_| cast_literal_losslessly(constant, &target.source_type))
}
fn normalize_timestamp_downcast_binary(
casted_column: &CastedColumn,
target: &NormalizationTarget,
op: Operator,
constant: &ScalarValue,
) -> Option<Expr> {
@@ -389,7 +392,7 @@ fn normalize_timestamp_downcast_binary(
source_unit,
target_unit,
timezone,
} = &casted_column.kind
} = &target.kind
else {
return None;
};
@@ -413,7 +416,7 @@ fn normalize_timestamp_downcast_binary(
};
Some(Expr::BinaryExpr(BinaryExpr {
left: Box::new(casted_column.source_expr.clone()),
left: Box::new(target.source_expr.clone()),
op: normalized_op,
right: Box::new(Expr::Literal(
timestamp_scalar(*source_unit, timezone.clone(), bound),
@@ -423,7 +426,7 @@ fn normalize_timestamp_downcast_binary(
}
fn normalize_timestamp_downcast_between(
casted_column: &CastedColumn,
target: &NormalizationTarget,
low: &ScalarValue,
high: &ScalarValue,
) -> Option<Expr> {
@@ -431,7 +434,7 @@ fn normalize_timestamp_downcast_between(
source_unit,
target_unit,
timezone,
} = &casted_column.kind
} = &target.kind
else {
return None;
};
@@ -447,7 +450,7 @@ fn normalize_timestamp_downcast_between(
Some(
Expr::BinaryExpr(BinaryExpr {
left: Box::new(casted_column.source_expr.clone()),
left: Box::new(target.source_expr.clone()),
op: Operator::GtEq,
right: Box::new(Expr::Literal(
timestamp_scalar(*source_unit, timezone.clone(), lower),
@@ -455,7 +458,7 @@ fn normalize_timestamp_downcast_between(
)),
})
.and(Expr::BinaryExpr(BinaryExpr {
left: Box::new(casted_column.source_expr.clone()),
left: Box::new(target.source_expr.clone()),
op: Operator::Lt,
right: Box::new(Expr::Literal(
timestamp_scalar(*source_unit, timezone.clone(), upper),
@@ -588,6 +591,54 @@ mod tests {
);
}
/// An uncast `Int16` column filtered with `Int64` literals: the comparison,
/// IN-list and BETWEEN predicates are each expected to come out of the rule
/// with `Int16` literals.
#[test]
fn test_normalize_plain_integer_literals() {
    let schema = test_schema(vec![Field::new("v", DataType::Int16, false)]);

    // One predicate per shape; each is wrapped in `scan(t) -> filter` and
    // run through the rule on its own.
    let [comparison, in_list, between] = [
        col("v").gt_eq(lit(42_i64)),
        col("v").in_list(vec![lit(1_i64), lit(2_i64)], false),
        col("v").between(lit(3_i64), lit(5_i64)),
    ]
    .map(|predicate| {
        let plan = LogicalPlanBuilder::scan("t", test_source(schema.clone()), None)
            .unwrap()
            .filter(predicate)
            .unwrap()
            .build()
            .unwrap();
        ConstNormalizationRule
            .analyze(plan, &ConfigOptions::default())
            .unwrap()
    });

    assert_eq!(
        "Filter: t.v >= Int16(42)\n TableScan: t",
        comparison.to_string()
    );
    assert_eq!(
        "Filter: t.v IN ([Int16(1), Int16(2)])\n TableScan: t",
        in_list.to_string()
    );
    assert_eq!(
        "Filter: t.v BETWEEN Int16(3) AND Int16(5)\n TableScan: t",
        between.to_string()
    );
}
#[test]
fn test_normalize_in_list_and_between() {
let schema = test_schema(vec![Field::new("v", DataType::Int16, false)]);
@@ -679,6 +730,52 @@ mod tests {
);
}
/// A nanosecond timestamp column filtered with millisecond literals and no
/// cast on the column: checks the plan after the rule, after filter pushdown,
/// and the time range derived from the pushed-down filters.
#[test]
fn test_normalize_plain_timestamp_literals() {
    let ts_field = Field::new(
        "ts",
        DataType::Timestamp(ArrowTimeUnit::Nanosecond, None),
        false,
    );
    let schema = test_schema(vec![ts_field]);

    let lower_ms = lit(ScalarValue::TimestampMillisecond(Some(-299_999), None));
    let upper_ms = lit(ScalarValue::TimestampMillisecond(Some(10_000), None));
    let predicate = col("ts").gt_eq(lower_ms).and(col("ts").lt_eq(upper_ms));

    let plan = LogicalPlanBuilder::scan("t", test_source(schema), None)
        .unwrap()
        .filter(predicate)
        .unwrap()
        .build()
        .unwrap();

    // Step 1: the rule alone.
    let normalized = ConstNormalizationRule
        .analyze(plan, &ConfigOptions::default())
        .unwrap();
    let expected = "\
        \nFilter: t.ts >= TimestampNanosecond(-299999000000, None) AND t.ts <= TimestampNanosecond(10000000000, None)\
        \n TableScan: t";
    assert_eq!(expected.trim_start(), normalized.to_string());

    // Step 2: filter pushdown over the normalized plan.
    let pushdown_only = Optimizer::with_rules(vec![Arc::new(PushDownFilter::new())]);
    let optimized = pushdown_only
        .optimize(normalized, &OptimizerContext::new(), |_, _| {})
        .unwrap();
    let expected = "\
        \nTableScan: t, full_filters=[t.ts >= TimestampNanosecond(-299999000000, None), t.ts <= TimestampNanosecond(10000000000, None)]";
    assert_eq!(expected.trim_start(), optimized.to_string());

    // Step 3: the time range computed from the scan's filters.
    let scan_filters = extract_scan_filters(&optimized);
    let actual_range = build_time_range_predicate("ts", TimeUnit::Nanosecond, &scan_filters);
    let expected_range = TimestampRange::new_inclusive(
        Some(Timestamp::new_nanosecond(-299_999_000_000)),
        Some(Timestamp::new_nanosecond(10_000_000_000)),
    );
    assert_eq!(expected_range, actual_range);
}
fn extract_scan_filters(plan: &LogicalPlan) -> Vec<Expr> {
match plan {
LogicalPlan::TableScan(scan) => scan.filters.clone(),

View File

@@ -90,6 +90,7 @@ TQL EXPLAIN VERBOSE (0, 10, '5s') test;
| logical_plan after TranscribeAtatRule_| SAME TEXT AS ABOVE_|
| logical_plan after resolve_grouping_function_| SAME TEXT AS ABOVE_|
| logical_plan after type_coercion_| SAME TEXT AS ABOVE_|
| logical_plan after ConstNormalizationRule_| SAME TEXT AS ABOVE_|
| logical_plan after DistPlannerAnalyzer_| Projection: test.i, test.j, test.k_|
|_|_MergeScan [is_placeholder=false, remote_input=[_|
|_| PromInstantManipulate: range=[0..10000], lookback=[300000], interval=[5000], time index=[j]_|
@@ -233,6 +234,7 @@ TQL EXPLAIN VERBOSE (0, 10, '5s') test AS series;
| logical_plan after TranscribeAtatRule_| SAME TEXT AS ABOVE_|
| logical_plan after resolve_grouping_function_| SAME TEXT AS ABOVE_|
| logical_plan after type_coercion_| SAME TEXT AS ABOVE_|
| logical_plan after ConstNormalizationRule_| SAME TEXT AS ABOVE_|
| logical_plan after DistPlannerAnalyzer_| Projection: series, test.k, test.j_|
|_|_MergeScan [is_placeholder=false, remote_input=[_|
|_| Projection: test.i AS series, test.k, test.j_|
@@ -356,6 +358,47 @@ TQL EXPLAIN VERBOSE (0, 10, '5s') test AS series;
|_|_|
+-+-+
CREATE TABLE test_nano(i DOUBLE, j TIMESTAMP(9) TIME INDEX, k STRING PRIMARY KEY);
Affected Rows: 0
INSERT INTO test_nano VALUES (1, 1000000, "a"), (1, 1000000, "b"), (2, 2000000, "a");
Affected Rows: 3
-- explain at 0s, 5s and 10s for a nanosecond time index.
-- SQLNESS REPLACE (RoundRobinBatch.*) REDACTED
-- SQLNESS REPLACE (peers.*) REDACTED
TQL EXPLAIN (0, 10, '5s') test_nano;
+---------------+-----------------------------------------------------------------------------------------------------------------------------------+
| plan_type | plan |
+---------------+-----------------------------------------------------------------------------------------------------------------------------------+
| logical_plan | PromInstantManipulate: range=[0..10000], lookback=[300000], interval=[5000], time index=[j] |
| | PromSeriesDivide: tags=["k"] |
| | Sort: test_nano.k ASC NULLS FIRST, test_nano.j ASC NULLS FIRST |
| | Projection: test_nano.i, test_nano.k, CAST(test_nano.j AS Timestamp(ms)) AS j |
| | Projection: test_nano.i, test_nano.j, test_nano.k |
| | Filter: __common_expr_3 >= TimestampMillisecond(-299999, None) AND __common_expr_3 <= TimestampMillisecond(10000, None) |
| | Projection: CAST(test_nano.j AS Timestamp(ms)) AS __common_expr_3, test_nano.i, test_nano.j, test_nano.k |
| | MergeScan [is_placeholder=false, remote_input=[ |
| | TableScan: test_nano |
| | ]] |
| physical_plan | PromInstantManipulateExec: range=[0..10000], lookback=[300000], interval=[5000], time index=[j] |
| | PromSeriesDivideExec: tags=["k"] |
| | SortExec: expr=[k@1 ASC, j@2 ASC], preserve_partitioning=[true] |
| | RepartitionExec: partitioning=Hash([k@1], 32), input_partitions=32 |
| | ProjectionExec: expr=[i@0 as i, k@2 as k, CAST(j@1 AS Timestamp(ms)) as j] |
| | FilterExec: __common_expr_3@0 >= -299999 AND __common_expr_3@0 <= 10000, projection=[i@1, j@2, k@3] |
| | ProjectionExec: expr=[CAST(j@1 AS Timestamp(ms)) as __common_expr_3, i@0 as i, j@1 as j, k@2 as k] |
| | MergeScanExec: REDACTED
| | |
+---------------+-----------------------------------------------------------------------------------------------------------------------------------+
DROP TABLE test_nano;
Affected Rows: 0
DROP TABLE test;
Affected Rows: 0

View File

@@ -59,7 +59,7 @@ EXPLAIN SELECT
FROM
cpu
WHERE
usage_small IN (CAST(10 AS BIGINT), CAST(20 AS BIGINT));
usage_small IN (10, 20);
+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------+
| plan_type | plan |
@@ -74,6 +74,27 @@ WHERE
| | |
+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------+
-- SQLNESS REPLACE (peers.*) REDACTED
EXPLAIN SELECT
rack
FROM
cpu
WHERE
usage_small BETWEEN 10 AND 20;
+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------+
| plan_type | plan |
+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------+
| logical_plan | MergeScan [is_placeholder=false, remote_input=[ |
| | Projection: cpu.rack |
| | Filter: cpu.usage_small >= Int16(10) AND cpu.usage_small <= Int16(20) |
| | TableScan: cpu |
| | ]] |
| physical_plan | CooperativeExec |
| | MergeScanExec: REDACTED
| | |
+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------+
-- SQLNESS SORT_RESULT 3 1
select
count(*)

View File

@@ -47,7 +47,15 @@ EXPLAIN SELECT
FROM
cpu
WHERE
usage_small IN (CAST(10 AS BIGINT), CAST(20 AS BIGINT));
usage_small IN (10, 20);
-- SQLNESS REPLACE (peers.*) REDACTED
EXPLAIN SELECT
rack
FROM
cpu
WHERE
usage_small BETWEEN 10 AND 20;
-- SQLNESS SORT_RESULT 3 1
select