fix: tighten const normalization casts

This commit is contained in:
Ruihang Xia
2026-04-15 09:01:06 +08:00
parent c06b0c53d2
commit f6379aaf4d

View File

@@ -58,12 +58,8 @@ fn rewrite_plan_exprs(plan: LogicalPlan, schema: DFSchemaRef) -> Result<Transfor
};
let exprs = plan
.expressions_consider_join()
.iter()
.map(|expr| {
expr.clone()
.rewrite(&mut rewriter)
.map(|rewritten| rewritten.data)
})
.into_iter()
.map(|expr| expr.rewrite(&mut rewriter).map(|rewritten| rewritten.data))
.collect::<Result<Vec<_>>>()?;
if !rewriter.transformed {
return Ok(Transformed::no(plan));
@@ -287,7 +283,9 @@ struct NormalizationTarget {
#[derive(Clone)]
enum NormalizationKind {
/// The cast preserves every source value exactly, so literals can be cast directly.
Lossless,
/// The cast drops timestamp precision and must widen predicate bounds to preserve semantics.
TimestampDowncast {
source_unit: ArrowTimeUnit,
target_unit: ArrowTimeUnit,
@@ -387,6 +385,10 @@ impl NormalizationTarget {
}
}
/// Returns the non-constant side we should normalize against.
///
/// Plain expressions normalize literals to their own type. Cast expressions only participate when
/// the cast is lossless or when timestamp downcasts can be rewritten as wider source-side bounds.
fn extract_normalization_target(
expr: &Expr,
schema: &DFSchemaRef,
@@ -440,14 +442,17 @@ fn classify_normalization_kind(
}
}
/// Returns whether every value of `source_type` is representable in `target_type`.
fn is_lossless_cast(source_type: &DataType, target_type: &DataType) -> bool {
match (source_type, target_type) {
(DataType::Int8, DataType::Int16 | DataType::Int32 | DataType::Int64)
| (DataType::Int16, DataType::Int32 | DataType::Int64)
| (DataType::Int32, DataType::Int64)
| (DataType::UInt8, DataType::UInt16 | DataType::UInt32 | DataType::UInt64)
| (DataType::UInt8, DataType::Int16 | DataType::Int32 | DataType::Int64)
| (DataType::UInt16, DataType::UInt32 | DataType::UInt64)
| (DataType::UInt32, DataType::UInt64)
| (DataType::UInt16, DataType::Int32 | DataType::Int64)
| (DataType::UInt32, DataType::UInt64 | DataType::Int64)
| (DataType::Utf8, DataType::Utf8View | DataType::LargeUtf8) => true,
(
DataType::Timestamp(source_unit, source_tz),
@@ -496,13 +501,6 @@ fn extract_constant_scalar(expr: &Expr) -> Result<Option<ScalarValue>> {
/// Casts the literal `value` to `target_type`, yielding `None` when no strategy produces a result.
///
/// `try_cast_literal_to_type` is attempted first. Only when it returns `None` does the
/// `or_else` fallback run, delegating to `cast_literal_by_round_trip`.
///
/// NOTE(review): whether the first path is truly lossless depends on `try_cast_literal_to_type`,
/// which is defined outside this view — confirm before relying on the name.
fn cast_literal_losslessly(value: &ScalarValue, target_type: &DataType) -> Option<ScalarValue> {
try_cast_literal_to_type(value, target_type)
// Fallback: accept a plain cast only if it survives a round trip back to the source type.
.or_else(|| cast_literal_by_round_trip(value, target_type))
}
/// Casts `value` to `target_type` and keeps the result only if casting it back to the literal's
/// original data type reproduces `value` exactly.
///
/// Returns `None` when either cast fails or when the round-tripped value differs from `value`.
fn cast_literal_by_round_trip(value: &ScalarValue, target_type: &DataType) -> Option<ScalarValue> {
// A failed forward cast is treated as "not castable" rather than surfaced as an error.
let casted = value.cast_to(target_type).ok()?;
// Cast back to the literal's own type so the comparison below is made in the source domain.
let round_trip = casted.cast_to(&value.data_type()).ok()?;
// Any difference means the forward cast changed the value, so the result is rejected.
(round_trip == *value).then_some(casted)
}
#[derive(Clone, Copy)]
@@ -547,6 +545,12 @@ fn finer_to_coarser_ratio(source_unit: ArrowTimeUnit, target_unit: ArrowTimeUnit
(source_scale >= target_scale).then_some(source_scale / target_scale)
}
/// Returns the smallest source-unit timestamp whose downcast is greater than or equal to
/// `target_value`.
///
/// DataFusion timestamp downcasts truncate toward zero. For non-positive buckets that means the
/// bucket starts before `target_value * ratio`, so `<= x` can be rewritten as `< lower_bound(x+1)`
/// without dropping rows near zero or across negative boundaries.
fn lower_bound_for_ge(
target_value: i64,
source_unit: ArrowTimeUnit,
@@ -606,7 +610,9 @@ mod tests {
use datafusion_optimizer::push_down_filter::PushDownFilter;
use table::predicate::build_time_range_predicate;
use super::{ConstNormalizationRule, PatternMatchKind};
use super::{
ConstNormalizationRule, PatternMatchKind, cast_literal_losslessly, lower_bound_for_ge,
};
#[test]
fn test_normalize_direct_integer_cast_comparison() {
@@ -692,6 +698,31 @@ mod tests {
}
}
#[test]
/// Checks that predicates over an unsigned column cast to a wider signed type are expected to
/// come out with the cast removed and the literal retyped to the column's unsigned type.
fn test_normalize_unsigned_to_signed_literals() {
    // Each scenario: (type of column `v`, predicate over the casted column, expected plan text).
    let scenarios = [
        (
            DataType::UInt8,
            cast(col("v"), DataType::Int16).lt_eq(lit(255_i16)),
            "Filter: t.v <= UInt8(255)\n TableScan: t",
        ),
        (
            DataType::UInt16,
            cast(col("v"), DataType::Int32).gt_eq(lit(42_i32)),
            "Filter: t.v >= UInt16(42)\n TableScan: t",
        ),
        (
            DataType::UInt32,
            cast(col("v"), DataType::Int64).between(lit(3_i64), lit(5_i64)),
            "Filter: t.v BETWEEN UInt32(3) AND UInt32(5)\n TableScan: t",
        ),
    ];

    for (column_type, predicate, expected_plan) in scenarios {
        // Every scenario uses a single non-nullable column named `v`.
        let fields = vec![Field::new("v", column_type, false)];
        assert_filter_plan(fields, predicate, expected_plan);
    }
}
#[test]
fn test_normalize_in_list_and_between() {
let fields = vec![Field::new("v", DataType::Int16, false)];
@@ -869,6 +900,44 @@ mod tests {
);
}
#[test]
/// Pins the nanosecond-to-millisecond downcast results this module relies on, then checks that
/// `lower_bound_for_ge` returns the expected bucket starts for the same unit pair.
fn test_timestamp_downcast_contract_matches_datafusion_casts() {
    let millisecond_type = DataType::Timestamp(ArrowTimeUnit::Millisecond, None);

    // (nanosecond input, expected millisecond output) around zero and the bucket edges.
    let truncations = [
        (-1_000_001, -1),
        (-1_000_000, -1),
        (-999_999, 0),
        (-1, 0),
        (0, 0),
        (999_999, 0),
        (1_000_000, 1),
    ];
    for (nanos, expected_millis) in truncations {
        let literal = ScalarValue::TimestampNanosecond(Some(nanos), None);
        let casted = cast_literal_losslessly(&literal, &millisecond_type).unwrap();
        assert_eq!(
            ScalarValue::TimestampMillisecond(Some(expected_millis), None),
            casted
        );
    }

    // (millisecond target, expected smallest nanosecond value whose downcast is >= target).
    let bucket_starts = [(-1, -1_999_999), (0, -999_999), (1, 1_000_000)];
    for (target_millis, expected_nanos) in bucket_starts {
        assert_eq!(
            Some(expected_nanos),
            lower_bound_for_ge(
                target_millis,
                ArrowTimeUnit::Nanosecond,
                ArrowTimeUnit::Millisecond
            )
        );
    }
}
#[test]
fn test_normalize_plain_timestamp_literals() {
assert_timestamp_pushdown(