Skip to main content

query/optimizer/
const_normalization.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use arrow_schema::{DataType, TimeUnit as ArrowTimeUnit};
18use datafusion::config::ConfigOptions;
19use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter};
20use datafusion_common::{DFSchemaRef, Result, ScalarValue};
21use datafusion_expr::expr::{Cast, InList, Like, TryCast};
22use datafusion_expr::{Between, BinaryExpr, Expr, ExprSchemable, LogicalPlan, Operator, lit};
23use datafusion_expr_common::casts::try_cast_literal_to_type;
24use datafusion_optimizer::analyzer::AnalyzerRule;
25
26use crate::plan::ExtractExpr;
27
28/// ConstNormalizationRule rewrites castable constants against their
29/// non-constant comparison operand ahead of filter pushdown.
30#[derive(Debug)]
31pub struct ConstNormalizationRule;
32
33impl AnalyzerRule for ConstNormalizationRule {
34    fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result<LogicalPlan> {
35        plan.transform(|plan| match plan {
36            LogicalPlan::Filter(filter) => {
37                let schema = filter.input.schema().clone();
38                rewrite_plan_exprs(LogicalPlan::Filter(filter), schema)
39            }
40            LogicalPlan::TableScan(scan) => {
41                let schema = scan.projected_schema.clone();
42                rewrite_plan_exprs(LogicalPlan::TableScan(scan), schema)
43            }
44            _ => Ok(Transformed::no(plan)),
45        })
46        .map(|x| x.data)
47    }
48
49    fn name(&self) -> &str {
50        "ConstNormalizationRule"
51    }
52}
53
54fn rewrite_plan_exprs(plan: LogicalPlan, schema: DFSchemaRef) -> Result<Transformed<LogicalPlan>> {
55    let mut rewriter = ConstNormalizationRewriter {
56        schema,
57        transformed: false,
58    };
59    let exprs = plan
60        .expressions_consider_join()
61        .into_iter()
62        .map(|expr| expr.rewrite(&mut rewriter).map(|rewritten| rewritten.data))
63        .collect::<Result<Vec<_>>>()?;
64    if !rewriter.transformed {
65        return Ok(Transformed::no(plan));
66    }
67
68    let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
69    plan.with_new_exprs(exprs, inputs).map(Transformed::yes)
70}
71
72struct ConstNormalizationRewriter {
73    schema: DFSchemaRef,
74    transformed: bool,
75}
76
77impl TreeNodeRewriter for ConstNormalizationRewriter {
78    type Node = Expr;
79
80    fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
81        let recursion = if matches!(
82            expr,
83            Expr::Exists(_) | Expr::InSubquery(_) | Expr::ScalarSubquery(_)
84        ) {
85            TreeNodeRecursion::Jump
86        } else {
87            TreeNodeRecursion::Continue
88        };
89
90        Ok(Transformed::new(expr, false, recursion))
91    }
92
93    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
94        let rewritten = rewrite_expr_node(expr, &self.schema)?;
95        self.transformed |= rewritten.transformed;
96        Ok(rewritten)
97    }
98}
99
100fn rewrite_expr_node(expr: Expr, schema: &DFSchemaRef) -> Result<Transformed<Expr>> {
101    match expr {
102        Expr::BinaryExpr(binary) => match rewrite_binary_expr(binary.clone(), schema)? {
103            Some(expr) => Ok(Transformed::yes(expr)),
104            None => Ok(Transformed::no(Expr::BinaryExpr(binary))),
105        },
106        Expr::Between(between) => match rewrite_between_expr(between.clone(), schema)? {
107            Some(expr) => Ok(Transformed::yes(expr)),
108            None => Ok(Transformed::no(Expr::Between(between))),
109        },
110        Expr::InList(in_list) => match rewrite_in_list_expr(in_list.clone(), schema)? {
111            Some(expr) => Ok(Transformed::yes(expr)),
112            None => Ok(Transformed::no(Expr::InList(in_list))),
113        },
114        Expr::Like(like) => rewrite_like_expr(like, PatternMatchKind::Like, schema),
115        Expr::SimilarTo(like) => rewrite_like_expr(like, PatternMatchKind::SimilarTo, schema),
116        expr => Ok(Transformed::no(expr)),
117    }
118}
119
120fn rewrite_between_expr(between: Between, schema: &DFSchemaRef) -> Result<Option<Expr>> {
121    let Between {
122        expr,
123        negated,
124        low,
125        high,
126    } = between;
127    let expr = *expr;
128    let low_expr = *low;
129    let high_expr = *high;
130    let Some((target, constants)) =
131        extract_rewrite_operands(&expr, &[low_expr.clone(), high_expr.clone()], schema)?
132    else {
133        return Ok(None);
134    };
135
136    if let Some(mut constants) = target.normalize_constants(&constants) {
137        let high = constants
138            .pop()
139            .expect("between normalization expects high constant");
140        let low = constants
141            .pop()
142            .expect("between normalization expects low constant");
143        return Ok(Some(Expr::Between(Between {
144            expr: Box::new(target.expr.clone()),
145            negated,
146            low: Box::new(lit(low)),
147            high: Box::new(lit(high)),
148        })));
149    }
150
151    Ok((!negated)
152        .then(|| target.normalize_timestamp_between(&constants[0], &constants[1]))
153        .flatten())
154}
155
156fn rewrite_in_list_expr(in_list: InList, schema: &DFSchemaRef) -> Result<Option<Expr>> {
157    let InList {
158        expr,
159        list,
160        negated,
161    } = in_list;
162    let expr = *expr;
163    let Some((target, constants)) = extract_rewrite_operands(&expr, &list, schema)? else {
164        return Ok(None);
165    };
166
167    Ok(target.normalize_constants(&constants).map(|constants| {
168        target
169            .expr
170            .clone()
171            .in_list(constants.into_iter().map(lit).collect(), negated)
172    }))
173}
174
175fn rewrite_like_expr(
176    like: Like,
177    kind: PatternMatchKind,
178    schema: &DFSchemaRef,
179) -> Result<Transformed<Expr>> {
180    let original = match kind {
181        PatternMatchKind::Like => Expr::Like(like.clone()),
182        PatternMatchKind::SimilarTo => Expr::SimilarTo(like.clone()),
183    };
184    let Like {
185        negated,
186        expr,
187        pattern,
188        escape_char,
189        case_insensitive,
190    } = like;
191    let expr = *expr;
192    let pattern = *pattern;
193    let Some((target, constants)) =
194        extract_rewrite_operands(&expr, std::slice::from_ref(&pattern), schema)?
195    else {
196        return Ok(Transformed::no(original));
197    };
198    let Some(mut constants) = target.normalize_constants(&constants) else {
199        return Ok(Transformed::no(original));
200    };
201
202    let pattern = lit(constants
203        .pop()
204        .expect("pattern normalization expects one constant"));
205    let like = Like::new(
206        negated,
207        Box::new(target.expr.clone()),
208        Box::new(pattern),
209        escape_char,
210        case_insensitive,
211    );
212    let rewritten = match kind {
213        PatternMatchKind::Like => Expr::Like(like),
214        PatternMatchKind::SimilarTo => Expr::SimilarTo(like),
215    };
216    Ok(Transformed::yes(rewritten))
217}
218
219fn rewrite_binary_expr(binary: BinaryExpr, schema: &DFSchemaRef) -> Result<Option<Expr>> {
220    if !binary.op.supports_propagation() {
221        return Ok(None);
222    }
223
224    let BinaryExpr { left, op, right } = binary;
225    let left = *left;
226    let right = *right;
227    if let Some(expr) = rewrite_binary_side(left.clone(), op, right.clone(), schema)? {
228        return Ok(Some(expr));
229    }
230
231    let Some(swapped_op) = op.swap() else {
232        return Ok(None);
233    };
234
235    rewrite_binary_side(right, swapped_op, left, schema)
236}
237
238fn rewrite_binary_side(
239    target_expr: Expr,
240    op: Operator,
241    constant_expr: Expr,
242    schema: &DFSchemaRef,
243) -> Result<Option<Expr>> {
244    let Some((target, constants)) =
245        extract_rewrite_operands(&target_expr, std::slice::from_ref(&constant_expr), schema)?
246    else {
247        return Ok(None);
248    };
249
250    if let Some(mut constants) = target.normalize_constants(&constants) {
251        let constant = constants
252            .pop()
253            .expect("binary normalization expects one constant");
254        return Ok(Some(Expr::BinaryExpr(BinaryExpr {
255            left: Box::new(target.expr.clone()),
256            op,
257            right: Box::new(lit(constant)),
258        })));
259    }
260
261    Ok(target.normalize_timestamp_binary(op, &constants[0]))
262}
263
264fn extract_rewrite_operands(
265    target_expr: &Expr,
266    constant_exprs: &[Expr],
267    schema: &DFSchemaRef,
268) -> Result<Option<(NormalizationTarget, Vec<ScalarValue>)>> {
269    let Some(target) = extract_normalization_target(target_expr, schema)? else {
270        return Ok(None);
271    };
272
273    extract_constant_scalars(constant_exprs)
274        .map(|constants| constants.map(|constants| (target, constants)))
275}
276
277#[derive(Clone)]
278struct NormalizationTarget {
279    expr: Expr,
280    data_type: DataType,
281    kind: NormalizationKind,
282}
283
284#[derive(Clone)]
285enum NormalizationKind {
286    /// The cast preserves every source value exactly, so literals can be cast directly.
287    Lossless,
288    /// The cast drops timestamp precision and must widen predicate bounds to preserve semantics.
289    TimestampDowncast {
290        source_unit: ArrowTimeUnit,
291        target_unit: ArrowTimeUnit,
292        timezone: Option<Arc<str>>,
293    },
294}
295
296impl NormalizationTarget {
297    /// Normalizes constants for rewrites that can preserve the original predicate with a direct
298    /// literal cast. Timestamp precision-changing casts are handled by timestamp-specific helpers.
299    fn normalize_constants(&self, constants: &[ScalarValue]) -> Option<Vec<ScalarValue>> {
300        constants
301            .iter()
302            .map(|constant| self.normalize_constant(constant))
303            .collect()
304    }
305
306    fn normalize_constant(&self, constant: &ScalarValue) -> Option<ScalarValue> {
307        match self.kind {
308            NormalizationKind::TimestampDowncast { .. } => None,
309            NormalizationKind::Lossless => try_cast_literal_to_type(constant, &self.data_type),
310        }
311    }
312
313    /// Rewrites predicates over timestamp downcasts into source-side half-open bounds.
314    fn normalize_timestamp_binary(&self, op: Operator, constant: &ScalarValue) -> Option<Expr> {
315        let NormalizationKind::TimestampDowncast {
316            source_unit,
317            target_unit,
318            timezone,
319        } = &self.kind
320        else {
321            return None;
322        };
323
324        let constant = constant
325            .cast_to(&DataType::Timestamp(*target_unit, timezone.clone()))
326            .ok()?;
327        let value = timestamp_scalar_value(&constant)?;
328        let bound = match op {
329            Operator::GtEq => lower_bound_for_ge(value, *source_unit, *target_unit)?,
330            Operator::Gt => lower_bound_for_ge(value.checked_add(1)?, *source_unit, *target_unit)?,
331            Operator::Lt => lower_bound_for_ge(value, *source_unit, *target_unit)?,
332            Operator::LtEq => {
333                lower_bound_for_ge(value.checked_add(1)?, *source_unit, *target_unit)?
334            }
335            _ => return None,
336        };
337
338        let normalized_op = match op {
339            Operator::GtEq | Operator::Gt => Operator::GtEq,
340            Operator::Lt | Operator::LtEq => Operator::Lt,
341            _ => return None,
342        };
343
344        Some(match normalized_op {
345            Operator::GtEq => self.expr.clone().gt_eq(lit(timestamp_scalar(
346                *source_unit,
347                timezone.clone(),
348                bound,
349            ))),
350            Operator::Lt => {
351                self.expr
352                    .clone()
353                    .lt(lit(timestamp_scalar(*source_unit, timezone.clone(), bound)))
354            }
355            _ => unreachable!("timestamp normalization only rewrites to >= or <"),
356        })
357    }
358
359    /// Rewrites `BETWEEN` over timestamp downcasts into an inclusive lower bound and exclusive
360    /// upper bound over the source timestamp unit.
361    fn normalize_timestamp_between(&self, low: &ScalarValue, high: &ScalarValue) -> Option<Expr> {
362        let NormalizationKind::TimestampDowncast {
363            source_unit,
364            target_unit,
365            timezone,
366        } = &self.kind
367        else {
368            return None;
369        };
370
371        let target_type = DataType::Timestamp(*target_unit, timezone.clone());
372        let low = low.cast_to(&target_type).ok()?;
373        let high = high.cast_to(&target_type).ok()?;
374        let low = timestamp_scalar_value(&low)?;
375        let high = timestamp_scalar_value(&high)?;
376
377        let lower = lower_bound_for_ge(low, *source_unit, *target_unit)?;
378        let upper = lower_bound_for_ge(high.checked_add(1)?, *source_unit, *target_unit)?;
379
380        Some(
381            self.expr
382                .clone()
383                .gt_eq(lit(timestamp_scalar(*source_unit, timezone.clone(), lower)))
384                .and(self.expr.clone().lt(lit(timestamp_scalar(
385                    *source_unit,
386                    timezone.clone(),
387                    upper,
388                )))),
389        )
390    }
391}
392
393/// Returns the non-constant side we should normalize against.
394///
395/// Plain expressions normalize literals to their own type. Cast expressions only participate when
396/// the cast is lossless or when timestamp downcasts can be rewritten as wider source-side bounds.
397fn extract_normalization_target(
398    expr: &Expr,
399    schema: &DFSchemaRef,
400) -> Result<Option<NormalizationTarget>> {
401    if extract_constant_scalar(expr)?.is_some() {
402        return Ok(None);
403    }
404
405    let Some((_, source_expr, target_type)) = extract_cast_input(expr) else {
406        return Ok(Some(NormalizationTarget {
407            expr: expr.clone(),
408            data_type: expr.get_type(schema)?,
409            kind: NormalizationKind::Lossless,
410        }));
411    };
412
413    let data_type = source_expr.get_type(schema)?;
414    let Some(kind) = classify_normalization_kind(&data_type, target_type) else {
415        return Ok(None);
416    };
417
418    Ok(Some(NormalizationTarget {
419        expr: source_expr.clone(),
420        data_type,
421        kind,
422    }))
423}
424
425fn classify_normalization_kind(
426    source_type: &DataType,
427    target_type: &DataType,
428) -> Option<NormalizationKind> {
429    // Timestamp casts that change precision need boundary-aware rewrites. A finer target literal
430    // may not map exactly back to the coarser source unit, so the generic lossless path is only
431    // safe for timestamp casts that keep the same unit.
432    if is_lossless_cast(source_type, target_type) {
433        return Some(NormalizationKind::Lossless);
434    }
435
436    match (source_type, target_type) {
437        (
438            DataType::Timestamp(source_unit, source_tz),
439            DataType::Timestamp(target_unit, target_tz),
440        ) if source_tz == target_tz
441            && time_unit_rank(*source_unit) > time_unit_rank(*target_unit) =>
442        {
443            Some(NormalizationKind::TimestampDowncast {
444                source_unit: *source_unit,
445                target_unit: *target_unit,
446                timezone: source_tz.clone(),
447            })
448        }
449        _ => None,
450    }
451}
452
453/// Returns whether every value of `source_type` is representable in `target_type`.
454fn is_lossless_cast(source_type: &DataType, target_type: &DataType) -> bool {
455    match (source_type, target_type) {
456        (DataType::Int8, DataType::Int16 | DataType::Int32 | DataType::Int64)
457        | (DataType::Int16, DataType::Int32 | DataType::Int64)
458        | (DataType::Int32, DataType::Int64)
459        | (DataType::UInt8, DataType::UInt16 | DataType::UInt32 | DataType::UInt64)
460        | (DataType::UInt8, DataType::Int16 | DataType::Int32 | DataType::Int64)
461        | (DataType::UInt16, DataType::UInt32 | DataType::UInt64)
462        | (DataType::UInt16, DataType::Int32 | DataType::Int64)
463        | (DataType::UInt32, DataType::UInt64 | DataType::Int64)
464        | (DataType::Utf8, DataType::Utf8View | DataType::LargeUtf8) => true,
465        (
466            DataType::Timestamp(source_unit, source_tz),
467            DataType::Timestamp(target_unit, target_tz),
468        ) => source_tz == target_tz && source_unit == target_unit,
469        _ => false,
470    }
471}
472
/// Which `Expr` variant a [`Like`] payload was taken from, so a rewrite can re-wrap it in the
/// same variant.
#[derive(Clone, Copy)]
enum PatternMatchKind {
    /// Payload of `Expr::Like`.
    Like,
    /// Payload of `Expr::SimilarTo`.
    SimilarTo,
}
478
479fn extract_constant_scalars(exprs: &[Expr]) -> Result<Option<Vec<ScalarValue>>> {
480    let mut values = Vec::with_capacity(exprs.len());
481    for expr in exprs {
482        let Some(value) = extract_constant_scalar(expr)? else {
483            return Ok(None);
484        };
485        values.push(value);
486    }
487
488    Ok(Some(values))
489}
490
491/// Extracts a literal scalar from an expression, folding constant `CAST` and `TRY_CAST` nodes.
492fn extract_constant_scalar(expr: &Expr) -> Result<Option<ScalarValue>> {
493    if let Some(value) = expr.as_literal() {
494        return Ok(Some(value.clone()));
495    }
496
497    let Some((kind, expr, data_type)) = extract_cast_input(expr) else {
498        return Ok(None);
499    };
500
501    match kind {
502        CastInputKind::Cast => extract_constant_scalar(expr)?
503            .map(|value| value.cast_to(data_type))
504            .transpose(),
505        CastInputKind::TryCast => {
506            Ok(extract_constant_scalar(expr)?.and_then(|value| value.cast_to(data_type).ok()))
507        }
508    }
509}
510
/// Distinguishes `CAST` from `TRY_CAST` when a cast expression is unwrapped.
#[derive(Clone, Copy)]
enum CastInputKind {
    /// Unwrapped from `Expr::Cast`.
    Cast,
    /// Unwrapped from `Expr::TryCast`.
    TryCast,
}
516
517/// Returns the input expression and target type for `CAST` and `TRY_CAST` expressions.
518fn extract_cast_input(expr: &Expr) -> Option<(CastInputKind, &Expr, &DataType)> {
519    match expr {
520        Expr::Cast(Cast { expr, data_type }) => {
521            Some((CastInputKind::Cast, expr.as_ref(), data_type))
522        }
523        Expr::TryCast(TryCast { expr, data_type }) => {
524            Some((CastInputKind::TryCast, expr.as_ref(), data_type))
525        }
526        _ => None,
527    }
528}
529
530fn time_unit_rank(unit: ArrowTimeUnit) -> usize {
531    match unit {
532        ArrowTimeUnit::Second => 0,
533        ArrowTimeUnit::Millisecond => 1,
534        ArrowTimeUnit::Microsecond => 2,
535        ArrowTimeUnit::Nanosecond => 3,
536    }
537}
538
539fn time_unit_scale(unit: ArrowTimeUnit) -> i64 {
540    match unit {
541        ArrowTimeUnit::Second => 1,
542        ArrowTimeUnit::Millisecond => 1_000,
543        ArrowTimeUnit::Microsecond => 1_000_000,
544        ArrowTimeUnit::Nanosecond => 1_000_000_000,
545    }
546}
547
548/// Returns the number of source-unit ticks in one target-unit tick for finer-to-coarser casts.
549fn finer_to_coarser_ratio(source_unit: ArrowTimeUnit, target_unit: ArrowTimeUnit) -> Option<i64> {
550    let source_scale = time_unit_scale(source_unit);
551    let target_scale = time_unit_scale(target_unit);
552    (source_scale >= target_scale).then_some(source_scale / target_scale)
553}
554
555/// Returns the smallest source-unit timestamp whose downcast is greater than or equal to
556/// `target_value`.
557///
558/// DataFusion timestamp downcasts truncate toward zero. For non-positive buckets that means the
559/// bucket starts before `target_value * ratio`, so `<= x` can be rewritten as `< lower_bound(x+1)`
560/// without dropping rows near zero or across negative boundaries.
561fn lower_bound_for_ge(
562    target_value: i64,
563    source_unit: ArrowTimeUnit,
564    target_unit: ArrowTimeUnit,
565) -> Option<i64> {
566    let ratio = finer_to_coarser_ratio(source_unit, target_unit)?;
567    let base = target_value.checked_mul(ratio)?;
568    if target_value <= 0 {
569        base.checked_sub(ratio - 1)
570    } else {
571        Some(base)
572    }
573}
574
575fn timestamp_scalar_value(value: &ScalarValue) -> Option<i64> {
576    match value {
577        ScalarValue::TimestampSecond(Some(value), _)
578        | ScalarValue::TimestampMillisecond(Some(value), _)
579        | ScalarValue::TimestampMicrosecond(Some(value), _)
580        | ScalarValue::TimestampNanosecond(Some(value), _) => Some(*value),
581        _ => None,
582    }
583}
584
585fn timestamp_scalar(unit: ArrowTimeUnit, timezone: Option<Arc<str>>, value: i64) -> ScalarValue {
586    match unit {
587        ArrowTimeUnit::Second => ScalarValue::TimestampSecond(Some(value), timezone),
588        ArrowTimeUnit::Millisecond => ScalarValue::TimestampMillisecond(Some(value), timezone),
589        ArrowTimeUnit::Microsecond => ScalarValue::TimestampMicrosecond(Some(value), timezone),
590        ArrowTimeUnit::Nanosecond => ScalarValue::TimestampNanosecond(Some(value), timezone),
591    }
592}
593
594#[cfg(test)]
595mod tests {
596    use std::sync::Arc;
597
598    use arrow_schema::{DataType, TimeUnit as ArrowTimeUnit};
599    use async_trait::async_trait;
600    use common_time::Timestamp;
601    use common_time::range::TimestampRange;
602    use common_time::timestamp::TimeUnit;
603    use datafusion::catalog::Session;
604    use datafusion::config::ConfigOptions;
605    use datafusion::datasource::{TableProvider, provider_as_source};
606    use datafusion::physical_plan::ExecutionPlan;
607    use datafusion_common::arrow::datatypes::Field;
608    use datafusion_common::{DFSchema, ScalarValue, ToDFSchema};
609    use datafusion_expr::expr::{Between, Like};
610    use datafusion_expr::expr_fn::{cast, col, try_cast};
611    use datafusion_expr::{
612        Expr, LogicalPlan, LogicalPlanBuilder, TableProviderFilterPushDown, TableScan, TableSource,
613        TableType, lit,
614    };
615    use datafusion_optimizer::analyzer::AnalyzerRule;
616    use datafusion_optimizer::optimizer::{Optimizer, OptimizerContext};
617    use datafusion_optimizer::push_down_filter::PushDownFilter;
618    use table::predicate::build_time_range_predicate;
619
620    use super::{
621        ConstNormalizationRule, PatternMatchKind, lower_bound_for_ge, try_cast_literal_to_type,
622    };
623
624    #[test]
625    fn test_normalize_direct_integer_cast_comparison() {
626        assert_filter_plan(
627            vec![Field::new("v", DataType::Int32, false)],
628            cast(col("v"), DataType::Int64).gt_eq(lit(42_i64)),
629            "Filter: t.v >= Int32(42)\n  TableScan: t",
630        );
631    }
632
633    #[test]
634    fn test_normalize_non_column_operand() {
635        assert_filter_plan(
636            vec![Field::new("v", DataType::Int32, false)],
637            cast(col("v") + lit(1_i32), DataType::Int64).gt_eq(lit(42_i64)),
638            "Filter: t.v + Int32(1) >= Int32(42)\n  TableScan: t",
639        );
640    }
641
642    #[test]
643    fn test_normalize_swapped_binary_comparison() {
644        assert_filter_plan(
645            vec![Field::new("v", DataType::Int16, false)],
646            lit(42_i64).lt_eq(cast(col("v"), DataType::Int64)),
647            "Filter: t.v >= Int16(42)\n  TableScan: t",
648        );
649    }
650
651    #[test]
652    fn test_normalize_try_cast_target() {
653        assert_filter_plan(
654            vec![Field::new("v", DataType::Int16, false)],
655            try_cast(col("v"), DataType::Int64).gt_eq(lit(42_i64)),
656            "Filter: t.v >= Int16(42)\n  TableScan: t",
657        );
658    }
659
660    #[test]
661    fn test_normalize_casted_constants() {
662        let fields = vec![Field::new("v", DataType::Int16, false)];
663        let cases = [
664            (
665                col("v").gt_eq(cast(lit(42_i8), DataType::Int64)),
666                "Filter: t.v >= Int16(42)\n  TableScan: t",
667            ),
668            (
669                col("v").in_list(
670                    vec![
671                        cast(lit(1_i8), DataType::Int64),
672                        try_cast(lit(2_i8), DataType::Int64),
673                    ],
674                    false,
675                ),
676                "Filter: t.v IN ([Int16(1), Int16(2)])\n  TableScan: t",
677            ),
678        ];
679
680        for (predicate, expected) in cases {
681            assert_filter_plan(fields.clone(), predicate, expected);
682        }
683    }
684
685    #[test]
686    fn test_normalize_plain_integer_literals() {
687        let fields = vec![Field::new("v", DataType::Int16, false)];
688        let cases = [
689            (
690                col("v").gt_eq(lit(42_i64)),
691                "Filter: t.v >= Int16(42)\n  TableScan: t",
692            ),
693            (
694                col("v").in_list(vec![lit(1_i64), lit(2_i64)], false),
695                "Filter: t.v IN ([Int16(1), Int16(2)])\n  TableScan: t",
696            ),
697            (
698                col("v").between(lit(3_i64), lit(5_i64)),
699                "Filter: t.v BETWEEN Int16(3) AND Int16(5)\n  TableScan: t",
700            ),
701        ];
702
703        for (predicate, expected) in cases {
704            assert_filter_plan(fields.clone(), predicate, expected);
705        }
706    }
707
708    #[test]
709    fn test_normalize_unsigned_to_signed_literals() {
710        let cases = [
711            (
712                vec![Field::new("v", DataType::UInt8, false)],
713                cast(col("v"), DataType::Int16).lt_eq(lit(255_i16)),
714                "Filter: t.v <= UInt8(255)\n  TableScan: t",
715            ),
716            (
717                vec![Field::new("v", DataType::UInt16, false)],
718                cast(col("v"), DataType::Int32).gt_eq(lit(42_i32)),
719                "Filter: t.v >= UInt16(42)\n  TableScan: t",
720            ),
721            (
722                vec![Field::new("v", DataType::UInt32, false)],
723                cast(col("v"), DataType::Int64).between(lit(3_i64), lit(5_i64)),
724                "Filter: t.v BETWEEN UInt32(3) AND UInt32(5)\n  TableScan: t",
725            ),
726        ];
727
728        for (fields, predicate, expected) in cases {
729            assert_filter_plan(fields, predicate, expected);
730        }
731    }
732
733    #[test]
734    fn test_normalize_in_list_and_between() {
735        let fields = vec![Field::new("v", DataType::Int16, false)];
736        let cases = [
737            (
738                cast(col("v"), DataType::Int64).in_list(vec![lit(1_i64), lit(2_i64)], false),
739                "Filter: t.v IN ([Int16(1), Int16(2)])\n  TableScan: t",
740            ),
741            (
742                cast(col("v"), DataType::Int64).between(lit(3_i64), lit(5_i64)),
743                "Filter: t.v BETWEEN Int16(3) AND Int16(5)\n  TableScan: t",
744            ),
745        ];
746
747        for (predicate, expected) in cases {
748            assert_filter_plan(fields.clone(), predicate, expected);
749        }
750    }
751
752    #[test]
753    fn test_keep_non_lossless_literal_unchanged() {
754        assert_filter_plan(
755            vec![Field::new("v", DataType::Int16, false)],
756            col("v").gt_eq(lit(100_000_i64)),
757            "Filter: t.v >= Int64(100000)\n  TableScan: t",
758        );
759    }
760
761    #[test]
762    fn test_normalize_scan_filters() {
763        let scan = build_scan_plan(test_schema(vec![Field::new("v", DataType::Int16, false)]));
764        let LogicalPlan::TableScan(scan) = scan else {
765            panic!("expected table scan");
766        };
767        let plan = LogicalPlan::TableScan(TableScan {
768            filters: vec![cast(col("v"), DataType::Int64).gt_eq(lit(42_i64))],
769            ..scan
770        });
771
772        let analyzed = analyze_plan(plan);
773
774        assert_eq!(
775            vec![col("v").gt_eq(lit(42_i16))],
776            extract_scan_filters(&analyzed)
777        );
778    }
779
780    #[test]
781    fn test_normalize_negated_between() {
782        assert_filter_plan(
783            vec![Field::new("v", DataType::Int16, false)],
784            Expr::Between(Between {
785                expr: Box::new(cast(col("v"), DataType::Int64)),
786                negated: true,
787                low: Box::new(lit(3_i64)),
788                high: Box::new(lit(5_i64)),
789            }),
790            "Filter: t.v NOT BETWEEN Int16(3) AND Int16(5)\n  TableScan: t",
791        );
792    }
793
794    #[test]
795    fn test_normalize_like_literal() {
796        assert_pattern_match_plan(
797            PatternMatchKind::Like,
798            ScalarValue::LargeUtf8(Some("api%".to_string())),
799            "Filter: t.s LIKE Utf8(\"api%\")\n  TableScan: t",
800        );
801    }
802
803    #[test]
804    fn test_normalize_similar_to_literal() {
805        assert_pattern_match_plan(
806            PatternMatchKind::SimilarTo,
807            ScalarValue::LargeUtf8(Some("api.*".to_string())),
808            "Filter: t.s SIMILAR TO Utf8(\"api.*\")\n  TableScan: t",
809        );
810    }
811
812    #[test]
813    fn test_normalize_direct_timestamp_filter() {
814        assert_timestamp_pushdown(
815            vec![
816                Field::new(
817                    "ts",
818                    DataType::Timestamp(ArrowTimeUnit::Nanosecond, None),
819                    false,
820                ),
821                Field::new("tag", DataType::Utf8, true),
822            ],
823            ts_cast_to_ms()
824                .gt_eq(ts_ms_literal(-299_999))
825                .and(ts_cast_to_ms().lt_eq(ts_ms_literal(10_000)))
826                .and(col("tag").eq(lit("api"))),
827            "Filter: t.ts >= TimestampNanosecond(-299999999999, None) AND t.ts < TimestampNanosecond(10001000000, None) AND t.tag = Utf8(\"api\")\n  TableScan: t",
828            "TableScan: t, full_filters=[t.ts >= TimestampNanosecond(-299999999999, None), t.ts < TimestampNanosecond(10001000000, None), t.tag = Utf8(\"api\")]",
829            TimestampRange::new_inclusive(
830                Some(Timestamp::new_nanosecond(-299_999_999_999)),
831                Some(Timestamp::new_nanosecond(10_000_999_999)),
832            ),
833        );
834    }
835
836    #[test]
837    fn test_normalize_timestamp_between_filter() {
838        assert_timestamp_pushdown(
839            vec![Field::new(
840                "ts",
841                DataType::Timestamp(ArrowTimeUnit::Nanosecond, None),
842                false,
843            )],
844            ts_cast_to_ms().between(ts_ms_literal(-299_999), ts_ms_literal(10_000)),
845            "Filter: t.ts >= TimestampNanosecond(-299999999999, None) AND t.ts < TimestampNanosecond(10001000000, None)\n  TableScan: t",
846            "TableScan: t, full_filters=[t.ts >= TimestampNanosecond(-299999999999, None), t.ts < TimestampNanosecond(10001000000, None)]",
847            TimestampRange::new_inclusive(
848                Some(Timestamp::new_nanosecond(-299_999_999_999)),
849                Some(Timestamp::new_nanosecond(10_000_999_999)),
850            ),
851        );
852    }
853
854    #[test]
855    fn test_normalize_strict_timestamp_filter() {
856        assert_timestamp_pushdown(
857            vec![Field::new(
858                "ts",
859                DataType::Timestamp(ArrowTimeUnit::Nanosecond, None),
860                false,
861            )],
862            ts_cast_to_ms()
863                .gt(ts_ms_literal(10_000))
864                .and(ts_cast_to_ms().lt(ts_ms_literal(20_000))),
865            "Filter: t.ts >= TimestampNanosecond(10001000000, None) AND t.ts < TimestampNanosecond(20000000000, None)\n  TableScan: t",
866            "TableScan: t, full_filters=[t.ts >= TimestampNanosecond(10001000000, None), t.ts < TimestampNanosecond(20000000000, None)]",
867            TimestampRange::new_inclusive(
868                Some(Timestamp::new_nanosecond(10_001_000_000)),
869                Some(Timestamp::new_nanosecond(19_999_999_999)),
870            ),
871        );
872    }
873
874    #[test]
875    fn test_normalize_zero_boundary_timestamp_filter() {
876        let fields = vec![Field::new(
877            "ts",
878            DataType::Timestamp(ArrowTimeUnit::Nanosecond, None),
879            false,
880        )];
881
882        assert_timestamp_pushdown(
883            fields.clone(),
884            ts_cast_to_ms().gt_eq(ts_ms_literal(0)),
885            "Filter: t.ts >= TimestampNanosecond(-999999, None)\n  TableScan: t",
886            "TableScan: t, full_filters=[t.ts >= TimestampNanosecond(-999999, None)]",
887            TimestampRange::from_start(Timestamp::new_nanosecond(-999_999)),
888        );
889
890        assert_timestamp_pushdown(
891            fields.clone(),
892            ts_cast_to_ms().lt(ts_ms_literal(0)),
893            "Filter: t.ts < TimestampNanosecond(-999999, None)\n  TableScan: t",
894            "TableScan: t, full_filters=[t.ts < TimestampNanosecond(-999999, None)]",
895            TimestampRange::until_end(Timestamp::new_nanosecond(-999_999), false),
896        );
897
898        assert_timestamp_pushdown(
899            fields,
900            ts_cast_to_ms().between(ts_ms_literal(0), ts_ms_literal(0)),
901            "Filter: t.ts >= TimestampNanosecond(-999999, None) AND t.ts < TimestampNanosecond(1000000, None)\n  TableScan: t",
902            "TableScan: t, full_filters=[t.ts >= TimestampNanosecond(-999999, None), t.ts < TimestampNanosecond(1000000, None)]",
903            TimestampRange::new_inclusive(
904                Some(Timestamp::new_nanosecond(-999_999)),
905                Some(Timestamp::new_nanosecond(999_999)),
906            ),
907        );
908    }
909
910    #[test]
911    fn test_timestamp_downcast_contract_matches_datafusion_casts() {
912        let cases = [
913            (-1_000_001, -1),
914            (-1_000_000, -1),
915            (-999_999, 0),
916            (-1, 0),
917            (0, 0),
918            (999_999, 0),
919            (1_000_000, 1),
920        ];
921
922        for (source, expected) in cases {
923            let casted = try_cast_literal_to_type(
924                &ScalarValue::TimestampNanosecond(Some(source), None),
925                &DataType::Timestamp(ArrowTimeUnit::Millisecond, None),
926            )
927            .unwrap();
928            assert_eq!(
929                ScalarValue::TimestampMillisecond(Some(expected), None),
930                casted
931            );
932        }
933
934        assert_eq!(
935            Some(-1_999_999),
936            lower_bound_for_ge(-1, ArrowTimeUnit::Nanosecond, ArrowTimeUnit::Millisecond)
937        );
938        assert_eq!(
939            Some(-999_999),
940            lower_bound_for_ge(0, ArrowTimeUnit::Nanosecond, ArrowTimeUnit::Millisecond)
941        );
942        assert_eq!(
943            Some(1_000_000),
944            lower_bound_for_ge(1, ArrowTimeUnit::Nanosecond, ArrowTimeUnit::Millisecond)
945        );
946    }
947
948    #[test]
949    fn test_normalize_plain_timestamp_literals() {
950        assert_timestamp_pushdown(
951            vec![Field::new(
952                "ts",
953                DataType::Timestamp(ArrowTimeUnit::Nanosecond, None),
954                false,
955            )],
956            col("ts")
957                .gt_eq(ts_ms_literal(-299_999))
958                .and(col("ts").lt_eq(ts_ms_literal(10_000))),
959            "Filter: t.ts >= TimestampNanosecond(-299999000000, None) AND t.ts <= TimestampNanosecond(10000000000, None)\n  TableScan: t",
960            "TableScan: t, full_filters=[t.ts >= TimestampNanosecond(-299999000000, None), t.ts <= TimestampNanosecond(10000000000, None)]",
961            TimestampRange::new_inclusive(
962                Some(Timestamp::new_nanosecond(-299_999_000_000)),
963                Some(Timestamp::new_nanosecond(10_000_000_000)),
964            ),
965        );
966    }
967
968    #[test]
969    fn test_keep_timestamp_upcast_filter_unchanged() {
970        assert_filter_plan(
971            vec![Field::new(
972                "ts",
973                DataType::Timestamp(ArrowTimeUnit::Millisecond, None),
974                false,
975            )],
976            cast(
977                col("ts"),
978                DataType::Timestamp(ArrowTimeUnit::Nanosecond, None),
979            )
980            .gt_eq(lit(ScalarValue::TimestampNanosecond(Some(1), None))),
981            "Filter: CAST(t.ts AS Timestamp(ns)) >= TimestampNanosecond(1, None)\n  TableScan: t",
982        );
983    }
984
985    fn assert_pattern_match_plan(kind: PatternMatchKind, pattern: ScalarValue, expected: &str) {
986        let predicate = match kind {
987            PatternMatchKind::Like => Expr::Like(Like::new(
988                false,
989                Box::new(cast(col("s"), DataType::LargeUtf8)),
990                Box::new(lit(pattern)),
991                None,
992                false,
993            )),
994            PatternMatchKind::SimilarTo => Expr::SimilarTo(Like::new(
995                false,
996                Box::new(cast(col("s"), DataType::LargeUtf8)),
997                Box::new(lit(pattern)),
998                None,
999                false,
1000            )),
1001        };
1002
1003        assert_filter_plan(
1004            vec![Field::new("s", DataType::Utf8, false)],
1005            predicate,
1006            expected,
1007        );
1008    }
1009
1010    fn assert_filter_plan(fields: Vec<Field>, predicate: Expr, expected: &str) {
1011        assert_eq!(expected, analyze_filter(fields, predicate).to_string());
1012    }
1013
1014    fn assert_timestamp_pushdown(
1015        fields: Vec<Field>,
1016        predicate: Expr,
1017        expected_analyzed: &str,
1018        expected_pushed: &str,
1019        expected_range: TimestampRange,
1020    ) {
1021        let analyzed = analyze_filter(fields, predicate);
1022        assert_eq!(expected_analyzed, analyzed.to_string());
1023
1024        let pushed = push_down_filters(analyzed);
1025        assert_eq!(expected_pushed, pushed.to_string());
1026
1027        let range =
1028            build_time_range_predicate("ts", TimeUnit::Nanosecond, &extract_scan_filters(&pushed));
1029        assert_eq!(expected_range, range);
1030    }
1031
1032    fn analyze_filter(fields: Vec<Field>, predicate: Expr) -> LogicalPlan {
1033        analyze_plan(build_filter_plan(test_schema(fields), predicate))
1034    }
1035
1036    fn analyze_plan(plan: LogicalPlan) -> LogicalPlan {
1037        ConstNormalizationRule
1038            .analyze(plan, &ConfigOptions::default())
1039            .unwrap()
1040    }
1041
1042    fn build_filter_plan(schema: Arc<DFSchema>, predicate: Expr) -> LogicalPlan {
1043        LogicalPlanBuilder::scan("t", test_source(schema), None)
1044            .unwrap()
1045            .filter(predicate)
1046            .unwrap()
1047            .build()
1048            .unwrap()
1049    }
1050
1051    fn build_scan_plan(schema: Arc<DFSchema>) -> LogicalPlan {
1052        LogicalPlanBuilder::scan("t", test_source(schema), None)
1053            .unwrap()
1054            .build()
1055            .unwrap()
1056    }
1057
1058    fn push_down_filters(plan: LogicalPlan) -> LogicalPlan {
1059        Optimizer::with_rules(vec![Arc::new(PushDownFilter::new())])
1060            .optimize(plan, &OptimizerContext::new(), |_, _| {})
1061            .unwrap()
1062    }
1063
1064    fn ts_cast_to_ms() -> Expr {
1065        cast(
1066            col("ts"),
1067            DataType::Timestamp(ArrowTimeUnit::Millisecond, None),
1068        )
1069    }
1070
1071    fn ts_ms_literal(value: i64) -> Expr {
1072        lit(ScalarValue::TimestampMillisecond(Some(value), None))
1073    }
1074
1075    fn extract_scan_filters(plan: &LogicalPlan) -> Vec<Expr> {
1076        match plan {
1077            LogicalPlan::TableScan(scan) => scan.filters.clone(),
1078            _ => plan
1079                .inputs()
1080                .into_iter()
1081                .flat_map(extract_scan_filters)
1082                .collect(),
1083        }
1084    }
1085
1086    fn test_schema(fields: Vec<Field>) -> Arc<DFSchema> {
1087        arrow_schema::Schema::new(fields).to_dfschema_ref().unwrap()
1088    }
1089
1090    fn test_source(schema: Arc<DFSchema>) -> Arc<dyn TableSource> {
1091        let table = ExactPushdownProvider {
1092            schema: Arc::new(schema.as_ref().as_arrow().clone()),
1093        };
1094        provider_as_source(Arc::new(table))
1095    }
1096
1097    #[derive(Debug)]
1098    struct ExactPushdownProvider {
1099        schema: arrow_schema::SchemaRef,
1100    }
1101
1102    #[async_trait]
1103    impl TableProvider for ExactPushdownProvider {
1104        fn as_any(&self) -> &dyn std::any::Any {
1105            self
1106        }
1107
1108        fn schema(&self) -> arrow_schema::SchemaRef {
1109            self.schema.clone()
1110        }
1111
1112        fn table_type(&self) -> TableType {
1113            TableType::Base
1114        }
1115
1116        async fn scan(
1117            &self,
1118            _state: &dyn Session,
1119            _projection: Option<&Vec<usize>>,
1120            _filters: &[Expr],
1121            _limit: Option<usize>,
1122        ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
1123            unreachable!("scan should not be called in const_normalization tests")
1124        }
1125
1126        fn supports_filters_pushdown(
1127            &self,
1128            filters: &[&Expr],
1129        ) -> datafusion::error::Result<Vec<TableProviderFilterPushDown>> {
1130            Ok(vec![TableProviderFilterPushDown::Exact; filters.len()])
1131        }
1132    }
1133}