mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-24 00:40:40 +00:00
fix: tighten const normalization casts
This commit is contained in:
@@ -58,12 +58,8 @@ fn rewrite_plan_exprs(plan: LogicalPlan, schema: DFSchemaRef) -> Result<Transfor
|
||||
};
|
||||
let exprs = plan
|
||||
.expressions_consider_join()
|
||||
.iter()
|
||||
.map(|expr| {
|
||||
expr.clone()
|
||||
.rewrite(&mut rewriter)
|
||||
.map(|rewritten| rewritten.data)
|
||||
})
|
||||
.into_iter()
|
||||
.map(|expr| expr.rewrite(&mut rewriter).map(|rewritten| rewritten.data))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
if !rewriter.transformed {
|
||||
return Ok(Transformed::no(plan));
|
||||
@@ -287,7 +283,9 @@ struct NormalizationTarget {
|
||||
|
||||
#[derive(Clone)]
|
||||
enum NormalizationKind {
|
||||
/// The cast preserves every source value exactly, so literals can be cast directly.
|
||||
Lossless,
|
||||
/// The cast drops timestamp precision and must widen predicate bounds to preserve semantics.
|
||||
TimestampDowncast {
|
||||
source_unit: ArrowTimeUnit,
|
||||
target_unit: ArrowTimeUnit,
|
||||
@@ -387,6 +385,10 @@ impl NormalizationTarget {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the non-constant side we should normalize against.
|
||||
///
|
||||
/// Plain expressions normalize literals to their own type. Cast expressions only participate when
|
||||
/// the cast is lossless or when timestamp downcasts can be rewritten as wider source-side bounds.
|
||||
fn extract_normalization_target(
|
||||
expr: &Expr,
|
||||
schema: &DFSchemaRef,
|
||||
@@ -440,14 +442,17 @@ fn classify_normalization_kind(
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns whether every value of `source_type` is representable in `target_type`.
|
||||
fn is_lossless_cast(source_type: &DataType, target_type: &DataType) -> bool {
|
||||
match (source_type, target_type) {
|
||||
(DataType::Int8, DataType::Int16 | DataType::Int32 | DataType::Int64)
|
||||
| (DataType::Int16, DataType::Int32 | DataType::Int64)
|
||||
| (DataType::Int32, DataType::Int64)
|
||||
| (DataType::UInt8, DataType::UInt16 | DataType::UInt32 | DataType::UInt64)
|
||||
| (DataType::UInt8, DataType::Int16 | DataType::Int32 | DataType::Int64)
|
||||
| (DataType::UInt16, DataType::UInt32 | DataType::UInt64)
|
||||
| (DataType::UInt32, DataType::UInt64)
|
||||
| (DataType::UInt16, DataType::Int32 | DataType::Int64)
|
||||
| (DataType::UInt32, DataType::UInt64 | DataType::Int64)
|
||||
| (DataType::Utf8, DataType::Utf8View | DataType::LargeUtf8) => true,
|
||||
(
|
||||
DataType::Timestamp(source_unit, source_tz),
|
||||
@@ -496,13 +501,6 @@ fn extract_constant_scalar(expr: &Expr) -> Result<Option<ScalarValue>> {
|
||||
|
||||
fn cast_literal_losslessly(value: &ScalarValue, target_type: &DataType) -> Option<ScalarValue> {
|
||||
try_cast_literal_to_type(value, target_type)
|
||||
.or_else(|| cast_literal_by_round_trip(value, target_type))
|
||||
}
|
||||
|
||||
fn cast_literal_by_round_trip(value: &ScalarValue, target_type: &DataType) -> Option<ScalarValue> {
|
||||
let casted = value.cast_to(target_type).ok()?;
|
||||
let round_trip = casted.cast_to(&value.data_type()).ok()?;
|
||||
(round_trip == *value).then_some(casted)
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
@@ -547,6 +545,12 @@ fn finer_to_coarser_ratio(source_unit: ArrowTimeUnit, target_unit: ArrowTimeUnit
|
||||
(source_scale >= target_scale).then_some(source_scale / target_scale)
|
||||
}
|
||||
|
||||
/// Returns the smallest source-unit timestamp whose downcast is greater than or equal to
|
||||
/// `target_value`.
|
||||
///
|
||||
/// DataFusion timestamp downcasts truncate toward zero. For non-positive buckets that means the
|
||||
/// bucket starts before `target_value * ratio`, so `<= x` can be rewritten as `< lower_bound(x+1)`
|
||||
/// without dropping rows near zero or across negative boundaries.
|
||||
fn lower_bound_for_ge(
|
||||
target_value: i64,
|
||||
source_unit: ArrowTimeUnit,
|
||||
@@ -606,7 +610,9 @@ mod tests {
|
||||
use datafusion_optimizer::push_down_filter::PushDownFilter;
|
||||
use table::predicate::build_time_range_predicate;
|
||||
|
||||
use super::{ConstNormalizationRule, PatternMatchKind};
|
||||
use super::{
|
||||
ConstNormalizationRule, PatternMatchKind, cast_literal_losslessly, lower_bound_for_ge,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_normalize_direct_integer_cast_comparison() {
|
||||
@@ -692,6 +698,31 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalize_unsigned_to_signed_literals() {
|
||||
let cases = [
|
||||
(
|
||||
vec![Field::new("v", DataType::UInt8, false)],
|
||||
cast(col("v"), DataType::Int16).lt_eq(lit(255_i16)),
|
||||
"Filter: t.v <= UInt8(255)\n TableScan: t",
|
||||
),
|
||||
(
|
||||
vec![Field::new("v", DataType::UInt16, false)],
|
||||
cast(col("v"), DataType::Int32).gt_eq(lit(42_i32)),
|
||||
"Filter: t.v >= UInt16(42)\n TableScan: t",
|
||||
),
|
||||
(
|
||||
vec![Field::new("v", DataType::UInt32, false)],
|
||||
cast(col("v"), DataType::Int64).between(lit(3_i64), lit(5_i64)),
|
||||
"Filter: t.v BETWEEN UInt32(3) AND UInt32(5)\n TableScan: t",
|
||||
),
|
||||
];
|
||||
|
||||
for (fields, predicate, expected) in cases {
|
||||
assert_filter_plan(fields, predicate, expected);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalize_in_list_and_between() {
|
||||
let fields = vec![Field::new("v", DataType::Int16, false)];
|
||||
@@ -869,6 +900,44 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timestamp_downcast_contract_matches_datafusion_casts() {
|
||||
let cases = [
|
||||
(-1_000_001, -1),
|
||||
(-1_000_000, -1),
|
||||
(-999_999, 0),
|
||||
(-1, 0),
|
||||
(0, 0),
|
||||
(999_999, 0),
|
||||
(1_000_000, 1),
|
||||
];
|
||||
|
||||
for (source, expected) in cases {
|
||||
let casted = cast_literal_losslessly(
|
||||
&ScalarValue::TimestampNanosecond(Some(source), None),
|
||||
&DataType::Timestamp(ArrowTimeUnit::Millisecond, None),
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
ScalarValue::TimestampMillisecond(Some(expected), None),
|
||||
casted
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
Some(-1_999_999),
|
||||
lower_bound_for_ge(-1, ArrowTimeUnit::Nanosecond, ArrowTimeUnit::Millisecond)
|
||||
);
|
||||
assert_eq!(
|
||||
Some(-999_999),
|
||||
lower_bound_for_ge(0, ArrowTimeUnit::Nanosecond, ArrowTimeUnit::Millisecond)
|
||||
);
|
||||
assert_eq!(
|
||||
Some(1_000_000),
|
||||
lower_bound_for_ge(1, ArrowTimeUnit::Nanosecond, ArrowTimeUnit::Millisecond)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalize_plain_timestamp_literals() {
|
||||
assert_timestamp_pushdown(
|
||||
|
||||
Reference in New Issue
Block a user