diff --git a/src/query/src/optimizer/const_normalization.rs b/src/query/src/optimizer/const_normalization.rs index 606e5ee4d9..3bc7d3b8f8 100644 --- a/src/query/src/optimizer/const_normalization.rs +++ b/src/query/src/optimizer/const_normalization.rs @@ -58,12 +58,8 @@ fn rewrite_plan_exprs(plan: LogicalPlan, schema: DFSchemaRef) -> Result>>()?; if !rewriter.transformed { return Ok(Transformed::no(plan)); @@ -287,7 +283,9 @@ struct NormalizationTarget { #[derive(Clone)] enum NormalizationKind { + /// The cast preserves every source value exactly, so literals can be cast directly. Lossless, + /// The cast drops timestamp precision and must widen predicate bounds to preserve semantics. TimestampDowncast { source_unit: ArrowTimeUnit, target_unit: ArrowTimeUnit, @@ -387,6 +385,10 @@ impl NormalizationTarget { } } +/// Returns the non-constant side we should normalize against. +/// +/// Plain expressions normalize literals to their own type. Cast expressions only participate when +/// the cast is lossless or when timestamp downcasts can be rewritten as wider source-side bounds. fn extract_normalization_target( expr: &Expr, schema: &DFSchemaRef, @@ -440,14 +442,17 @@ fn classify_normalization_kind( } } +/// Returns whether every value of `source_type` is representable in `target_type`. fn is_lossless_cast(source_type: &DataType, target_type: &DataType) -> bool { match (source_type, target_type) { (DataType::Int8, DataType::Int16 | DataType::Int32 | DataType::Int64) | (DataType::Int16, DataType::Int32 | DataType::Int64) | (DataType::Int32, DataType::Int64) | (DataType::UInt8, DataType::UInt16 | DataType::UInt32 | DataType::UInt64) + | (DataType::UInt8, DataType::Int16 | DataType::Int32 | DataType::Int64) | (DataType::UInt16, DataType::UInt32 | DataType::UInt64) - | (DataType::UInt32, DataType::UInt64) + | (DataType::UInt16, DataType::Int32 | DataType::Int64) + | (DataType::UInt32, DataType::UInt64 | DataType::Int64) | (DataType::Utf8, DataType::Utf8View | DataType::LargeUtf8) => true, ( DataType::Timestamp(source_unit, source_tz), @@ -496,13 +501,6 @@ fn extract_constant_scalar(expr: &Expr) -> Result> { fn cast_literal_losslessly(value: &ScalarValue, target_type: &DataType) -> Option { try_cast_literal_to_type(value, target_type) - .or_else(|| cast_literal_by_round_trip(value, target_type)) -} - -fn cast_literal_by_round_trip(value: &ScalarValue, target_type: &DataType) -> Option { - let casted = value.cast_to(target_type).ok()?; - let round_trip = casted.cast_to(&value.data_type()).ok()?; - (round_trip == *value).then_some(casted) } #[derive(Clone, Copy)] @@ -547,6 +545,12 @@ fn finer_to_coarser_ratio(source_unit: ArrowTimeUnit, target_unit: ArrowTimeUnit (source_scale >= target_scale).then_some(source_scale / target_scale) } +/// Returns the smallest source-unit timestamp whose downcast is greater than or equal to +/// `target_value`. +/// +/// DataFusion timestamp downcasts truncate toward zero. For non-positive buckets that means the +/// bucket starts before `target_value * ratio`, so `<= x` can be rewritten as `< lower_bound(x+1)` +/// without dropping rows near zero or across negative boundaries. fn lower_bound_for_ge( target_value: i64, source_unit: ArrowTimeUnit, @@ -606,7 +610,9 @@ mod tests { use datafusion_optimizer::push_down_filter::PushDownFilter; use table::predicate::build_time_range_predicate; - use super::{ConstNormalizationRule, PatternMatchKind}; + use super::{ + ConstNormalizationRule, PatternMatchKind, cast_literal_losslessly, lower_bound_for_ge, + }; #[test] fn test_normalize_direct_integer_cast_comparison() { @@ -692,6 +698,31 @@ mod tests { } } + #[test] + fn test_normalize_unsigned_to_signed_literals() { + let cases = [ + ( + vec![Field::new("v", DataType::UInt8, false)], + cast(col("v"), DataType::Int16).lt_eq(lit(255_i16)), + "Filter: t.v <= UInt8(255)\n TableScan: t", + ), + ( + vec![Field::new("v", DataType::UInt16, false)], + cast(col("v"), DataType::Int32).gt_eq(lit(42_i32)), + "Filter: t.v >= UInt16(42)\n TableScan: t", + ), + ( + vec![Field::new("v", DataType::UInt32, false)], + cast(col("v"), DataType::Int64).between(lit(3_i64), lit(5_i64)), + "Filter: t.v BETWEEN UInt32(3) AND UInt32(5)\n TableScan: t", + ), + ]; + + for (fields, predicate, expected) in cases { + assert_filter_plan(fields, predicate, expected); + } + } + #[test] fn test_normalize_in_list_and_between() { let fields = vec![Field::new("v", DataType::Int16, false)]; @@ -869,6 +900,44 @@ mod tests { ); } + #[test] + fn test_timestamp_downcast_contract_matches_datafusion_casts() { + let cases = [ + (-1_000_001, -1), + (-1_000_000, -1), + (-999_999, 0), + (-1, 0), + (0, 0), + (999_999, 0), + (1_000_000, 1), + ]; + + for (source, expected) in cases { + let casted = cast_literal_losslessly( + &ScalarValue::TimestampNanosecond(Some(source), None), + &DataType::Timestamp(ArrowTimeUnit::Millisecond, None), + ) + .unwrap(); + assert_eq!( + ScalarValue::TimestampMillisecond(Some(expected), None), + casted + ); + } + + assert_eq!( + Some(-1_999_999), + lower_bound_for_ge(-1, ArrowTimeUnit::Nanosecond, ArrowTimeUnit::Millisecond) + ); + assert_eq!( + Some(-999_999), + lower_bound_for_ge(0, ArrowTimeUnit::Nanosecond, ArrowTimeUnit::Millisecond) + ); + assert_eq!( + Some(1_000_000), + lower_bound_for_ge(1, ArrowTimeUnit::Nanosecond, ArrowTimeUnit::Millisecond) + ); + } + #[test] fn test_normalize_plain_timestamp_literals() { assert_timestamp_pushdown(