1use std::sync::Arc;
16
17use arrow_schema::{DataType, TimeUnit as ArrowTimeUnit};
18use datafusion::config::ConfigOptions;
19use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter};
20use datafusion_common::{DFSchemaRef, Result, ScalarValue};
21use datafusion_expr::expr::{Cast, InList, Like, TryCast};
22use datafusion_expr::{Between, BinaryExpr, Expr, ExprSchemable, LogicalPlan, Operator, lit};
23use datafusion_expr_common::casts::try_cast_literal_to_type;
24use datafusion_optimizer::analyzer::AnalyzerRule;
25
26use crate::plan::ExtractExpr;
27
/// Analyzer rule that normalizes constant comparisons: lossless casts around
/// the compared expression are stripped and literals are rewritten into the
/// expression's native type, and finer→coarser timestamp casts are turned
/// into equivalent range predicates in the source unit.
#[derive(Debug)]
pub struct ConstNormalizationRule;
32
33impl AnalyzerRule for ConstNormalizationRule {
34 fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result<LogicalPlan> {
35 plan.transform(|plan| match plan {
36 LogicalPlan::Filter(filter) => {
37 let schema = filter.input.schema().clone();
38 rewrite_plan_exprs(LogicalPlan::Filter(filter), schema)
39 }
40 LogicalPlan::TableScan(scan) => {
41 let schema = scan.projected_schema.clone();
42 rewrite_plan_exprs(LogicalPlan::TableScan(scan), schema)
43 }
44 _ => Ok(Transformed::no(plan)),
45 })
46 .map(|x| x.data)
47 }
48
49 fn name(&self) -> &str {
50 "ConstNormalizationRule"
51 }
52}
53
54fn rewrite_plan_exprs(plan: LogicalPlan, schema: DFSchemaRef) -> Result<Transformed<LogicalPlan>> {
55 let mut rewriter = ConstNormalizationRewriter {
56 schema,
57 transformed: false,
58 };
59 let exprs = plan
60 .expressions_consider_join()
61 .into_iter()
62 .map(|expr| expr.rewrite(&mut rewriter).map(|rewritten| rewritten.data))
63 .collect::<Result<Vec<_>>>()?;
64 if !rewriter.transformed {
65 return Ok(Transformed::no(plan));
66 }
67
68 let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
69 plan.with_new_exprs(exprs, inputs).map(Transformed::yes)
70}
71
/// Expression-tree rewriter that records whether any node was rewritten.
struct ConstNormalizationRewriter {
    // Schema used to resolve the data types of target expressions.
    schema: DFSchemaRef,
    // Set once at least one expression node has been rewritten.
    transformed: bool,
}
76
77impl TreeNodeRewriter for ConstNormalizationRewriter {
78 type Node = Expr;
79
80 fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
81 let recursion = if matches!(
82 expr,
83 Expr::Exists(_) | Expr::InSubquery(_) | Expr::ScalarSubquery(_)
84 ) {
85 TreeNodeRecursion::Jump
86 } else {
87 TreeNodeRecursion::Continue
88 };
89
90 Ok(Transformed::new(expr, false, recursion))
91 }
92
93 fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
94 let rewritten = rewrite_expr_node(expr, &self.schema)?;
95 self.transformed |= rewritten.transformed;
96 Ok(rewritten)
97 }
98}
99
100fn rewrite_expr_node(expr: Expr, schema: &DFSchemaRef) -> Result<Transformed<Expr>> {
101 match expr {
102 Expr::BinaryExpr(binary) => match rewrite_binary_expr(binary.clone(), schema)? {
103 Some(expr) => Ok(Transformed::yes(expr)),
104 None => Ok(Transformed::no(Expr::BinaryExpr(binary))),
105 },
106 Expr::Between(between) => match rewrite_between_expr(between.clone(), schema)? {
107 Some(expr) => Ok(Transformed::yes(expr)),
108 None => Ok(Transformed::no(Expr::Between(between))),
109 },
110 Expr::InList(in_list) => match rewrite_in_list_expr(in_list.clone(), schema)? {
111 Some(expr) => Ok(Transformed::yes(expr)),
112 None => Ok(Transformed::no(Expr::InList(in_list))),
113 },
114 Expr::Like(like) => rewrite_like_expr(like, PatternMatchKind::Like, schema),
115 Expr::SimilarTo(like) => rewrite_like_expr(like, PatternMatchKind::SimilarTo, schema),
116 expr => Ok(Transformed::no(expr)),
117 }
118}
119
/// Rewrites `expr BETWEEN low AND high` when `expr` is a normalization target
/// and both bounds are constants; `Ok(None)` leaves the node unchanged.
fn rewrite_between_expr(between: Between, schema: &DFSchemaRef) -> Result<Option<Expr>> {
    let Between {
        expr,
        negated,
        low,
        high,
    } = between;
    let expr = *expr;
    let low_expr = *low;
    let high_expr = *high;
    let Some((target, constants)) =
        extract_rewrite_operands(&expr, &[low_expr.clone(), high_expr.clone()], schema)?
    else {
        return Ok(None);
    };

    // Lossless path: both bounds fit in the target type, so the BETWEEN shape
    // (including NOT BETWEEN) is kept with narrowed literals.
    if let Some(mut constants) = target.normalize_constants(&constants) {
        // `constants` preserves input order [low, high], so pop yields high first.
        let high = constants
            .pop()
            .expect("between normalization expects high constant");
        let low = constants
            .pop()
            .expect("between normalization expects low constant");
        return Ok(Some(Expr::Between(Between {
            expr: Box::new(target.expr.clone()),
            negated,
            low: Box::new(lit(low)),
            high: Box::new(lit(high)),
        })));
    }

    // Timestamp-downcast path: only the non-negated form is rewritten, into a
    // half-open `>= lower AND < upper` range in the source unit.
    Ok((!negated)
        .then(|| target.normalize_timestamp_between(&constants[0], &constants[1]))
        .flatten())
}
155
156fn rewrite_in_list_expr(in_list: InList, schema: &DFSchemaRef) -> Result<Option<Expr>> {
157 let InList {
158 expr,
159 list,
160 negated,
161 } = in_list;
162 let expr = *expr;
163 let Some((target, constants)) = extract_rewrite_operands(&expr, &list, schema)? else {
164 return Ok(None);
165 };
166
167 Ok(target.normalize_constants(&constants).map(|constants| {
168 target
169 .expr
170 .clone()
171 .in_list(constants.into_iter().map(lit).collect(), negated)
172 }))
173}
174
175fn rewrite_like_expr(
176 like: Like,
177 kind: PatternMatchKind,
178 schema: &DFSchemaRef,
179) -> Result<Transformed<Expr>> {
180 let original = match kind {
181 PatternMatchKind::Like => Expr::Like(like.clone()),
182 PatternMatchKind::SimilarTo => Expr::SimilarTo(like.clone()),
183 };
184 let Like {
185 negated,
186 expr,
187 pattern,
188 escape_char,
189 case_insensitive,
190 } = like;
191 let expr = *expr;
192 let pattern = *pattern;
193 let Some((target, constants)) =
194 extract_rewrite_operands(&expr, std::slice::from_ref(&pattern), schema)?
195 else {
196 return Ok(Transformed::no(original));
197 };
198 let Some(mut constants) = target.normalize_constants(&constants) else {
199 return Ok(Transformed::no(original));
200 };
201
202 let pattern = lit(constants
203 .pop()
204 .expect("pattern normalization expects one constant"));
205 let like = Like::new(
206 negated,
207 Box::new(target.expr.clone()),
208 Box::new(pattern),
209 escape_char,
210 case_insensitive,
211 );
212 let rewritten = match kind {
213 PatternMatchKind::Like => Expr::Like(like),
214 PatternMatchKind::SimilarTo => Expr::SimilarTo(like),
215 };
216 Ok(Transformed::yes(rewritten))
217}
218
219fn rewrite_binary_expr(binary: BinaryExpr, schema: &DFSchemaRef) -> Result<Option<Expr>> {
220 if !binary.op.supports_propagation() {
221 return Ok(None);
222 }
223
224 let BinaryExpr { left, op, right } = binary;
225 let left = *left;
226 let right = *right;
227 if let Some(expr) = rewrite_binary_side(left.clone(), op, right.clone(), schema)? {
228 return Ok(Some(expr));
229 }
230
231 let Some(swapped_op) = op.swap() else {
232 return Ok(None);
233 };
234
235 rewrite_binary_side(right, swapped_op, left, schema)
236}
237
/// Attempts to rewrite `target_expr op constant_expr` in this orientation
/// only; the caller handles the swapped orientation separately.
fn rewrite_binary_side(
    target_expr: Expr,
    op: Operator,
    constant_expr: Expr,
    schema: &DFSchemaRef,
) -> Result<Option<Expr>> {
    let Some((target, constants)) =
        extract_rewrite_operands(&target_expr, std::slice::from_ref(&constant_expr), schema)?
    else {
        return Ok(None);
    };

    // Preferred path: the constant casts losslessly into the target's type,
    // so the operator is preserved with a narrowed literal.
    if let Some(mut constants) = target.normalize_constants(&constants) {
        let constant = constants
            .pop()
            .expect("binary normalization expects one constant");
        return Ok(Some(Expr::BinaryExpr(BinaryExpr {
            left: Box::new(target.expr.clone()),
            op,
            right: Box::new(lit(constant)),
        })));
    }

    // Fallback: timestamp downcasts become truncation-aware range comparisons
    // in the source unit (only for <, <=, >, >=).
    Ok(target.normalize_timestamp_binary(op, &constants[0]))
}
263
264fn extract_rewrite_operands(
265 target_expr: &Expr,
266 constant_exprs: &[Expr],
267 schema: &DFSchemaRef,
268) -> Result<Option<(NormalizationTarget, Vec<ScalarValue>)>> {
269 let Some(target) = extract_normalization_target(target_expr, schema)? else {
270 return Ok(None);
271 };
272
273 extract_constant_scalars(constant_exprs)
274 .map(|constants| constants.map(|constants| (target, constants)))
275}
276
/// The expression (with any qualifying cast stripped) that constants are
/// normalized against.
#[derive(Clone)]
struct NormalizationTarget {
    // The unwrapped target expression, e.g. the column beneath a cast.
    expr: Expr,
    // Native data type of `expr` per the plan schema.
    data_type: DataType,
    // How constants compared against `expr` may be normalized.
    kind: NormalizationKind,
}
283
/// Classification of the cast wrapping a normalization target.
#[derive(Clone)]
enum NormalizationKind {
    // The cast (or absence of one) is lossless: constants can simply be cast
    // into the target's native type.
    Lossless,
    // The target is cast from a finer to a coarser timestamp unit; constants
    // must be expanded into truncation-aware range bounds in the source unit.
    TimestampDowncast {
        source_unit: ArrowTimeUnit,
        target_unit: ArrowTimeUnit,
        timezone: Option<Arc<str>>,
    },
}
295
impl NormalizationTarget {
    /// Casts every constant into `self.data_type`; `None` if any single
    /// constant cannot be represented losslessly.
    fn normalize_constants(&self, constants: &[ScalarValue]) -> Option<Vec<ScalarValue>> {
        constants
            .iter()
            .map(|constant| self.normalize_constant(constant))
            .collect()
    }

    /// Casts one constant into the target type. Timestamp downcasts never
    /// normalize here (the cast is lossy); they take the
    /// `normalize_timestamp_*` range-rewrite paths instead.
    fn normalize_constant(&self, constant: &ScalarValue) -> Option<ScalarValue> {
        match self.kind {
            NormalizationKind::TimestampDowncast { .. } => None,
            NormalizationKind::Lossless => try_cast_literal_to_type(constant, &self.data_type),
        }
    }

    /// Rewrites `cast(expr to coarser unit) op constant` into a comparison on
    /// `expr` in its source unit, using `lower_bound_for_ge` for
    /// truncation-aware bounds. Only <, <=, >, >= are supported.
    fn normalize_timestamp_binary(&self, op: Operator, constant: &ScalarValue) -> Option<Expr> {
        let NormalizationKind::TimestampDowncast {
            source_unit,
            target_unit,
            timezone,
        } = &self.kind
        else {
            return None;
        };

        // Bring the constant into the cast's target unit before computing bounds.
        let constant = constant
            .cast_to(&DataType::Timestamp(*target_unit, timezone.clone()))
            .ok()?;
        let value = timestamp_scalar_value(&constant)?;
        // `>` and `<=` are expressed via `>=` / `<` on value + 1 in the coarser unit.
        let bound = match op {
            Operator::GtEq => lower_bound_for_ge(value, *source_unit, *target_unit)?,
            Operator::Gt => lower_bound_for_ge(value.checked_add(1)?, *source_unit, *target_unit)?,
            Operator::Lt => lower_bound_for_ge(value, *source_unit, *target_unit)?,
            Operator::LtEq => {
                lower_bound_for_ge(value.checked_add(1)?, *source_unit, *target_unit)?
            }
            _ => return None,
        };

        // Every supported operator collapses to `>=` or `<` on the bound.
        let normalized_op = match op {
            Operator::GtEq | Operator::Gt => Operator::GtEq,
            Operator::Lt | Operator::LtEq => Operator::Lt,
            _ => return None,
        };

        Some(match normalized_op {
            Operator::GtEq => self.expr.clone().gt_eq(lit(timestamp_scalar(
                *source_unit,
                timezone.clone(),
                bound,
            ))),
            Operator::Lt => {
                self.expr
                    .clone()
                    .lt(lit(timestamp_scalar(*source_unit, timezone.clone(), bound)))
            }
            _ => unreachable!("timestamp normalization only rewrites to >= or <"),
        })
    }

    /// Rewrites `cast(expr) BETWEEN low AND high` into the half-open range
    /// `expr >= lower AND expr < upper` in the source unit.
    fn normalize_timestamp_between(&self, low: &ScalarValue, high: &ScalarValue) -> Option<Expr> {
        let NormalizationKind::TimestampDowncast {
            source_unit,
            target_unit,
            timezone,
        } = &self.kind
        else {
            return None;
        };

        let target_type = DataType::Timestamp(*target_unit, timezone.clone());
        let low = low.cast_to(&target_type).ok()?;
        let high = high.cast_to(&target_type).ok()?;
        let low = timestamp_scalar_value(&low)?;
        let high = timestamp_scalar_value(&high)?;

        let lower = lower_bound_for_ge(low, *source_unit, *target_unit)?;
        // BETWEEN is inclusive, so the exclusive upper bound comes from high + 1.
        let upper = lower_bound_for_ge(high.checked_add(1)?, *source_unit, *target_unit)?;

        Some(
            self.expr
                .clone()
                .gt_eq(lit(timestamp_scalar(*source_unit, timezone.clone(), lower)))
                .and(self.expr.clone().lt(lit(timestamp_scalar(
                    *source_unit,
                    timezone.clone(),
                    upper,
                )))),
        )
    }
}
392
393fn extract_normalization_target(
398 expr: &Expr,
399 schema: &DFSchemaRef,
400) -> Result<Option<NormalizationTarget>> {
401 if extract_constant_scalar(expr)?.is_some() {
402 return Ok(None);
403 }
404
405 let Some((_, source_expr, target_type)) = extract_cast_input(expr) else {
406 return Ok(Some(NormalizationTarget {
407 expr: expr.clone(),
408 data_type: expr.get_type(schema)?,
409 kind: NormalizationKind::Lossless,
410 }));
411 };
412
413 let data_type = source_expr.get_type(schema)?;
414 let Some(kind) = classify_normalization_kind(&data_type, target_type) else {
415 return Ok(None);
416 };
417
418 Ok(Some(NormalizationTarget {
419 expr: source_expr.clone(),
420 data_type,
421 kind,
422 }))
423}
424
425fn classify_normalization_kind(
426 source_type: &DataType,
427 target_type: &DataType,
428) -> Option<NormalizationKind> {
429 if is_lossless_cast(source_type, target_type) {
433 return Some(NormalizationKind::Lossless);
434 }
435
436 match (source_type, target_type) {
437 (
438 DataType::Timestamp(source_unit, source_tz),
439 DataType::Timestamp(target_unit, target_tz),
440 ) if source_tz == target_tz
441 && time_unit_rank(*source_unit) > time_unit_rank(*target_unit) =>
442 {
443 Some(NormalizationKind::TimestampDowncast {
444 source_unit: *source_unit,
445 target_unit: *target_unit,
446 timezone: source_tz.clone(),
447 })
448 }
449 _ => None,
450 }
451}
452
453fn is_lossless_cast(source_type: &DataType, target_type: &DataType) -> bool {
455 match (source_type, target_type) {
456 (DataType::Int8, DataType::Int16 | DataType::Int32 | DataType::Int64)
457 | (DataType::Int16, DataType::Int32 | DataType::Int64)
458 | (DataType::Int32, DataType::Int64)
459 | (DataType::UInt8, DataType::UInt16 | DataType::UInt32 | DataType::UInt64)
460 | (DataType::UInt8, DataType::Int16 | DataType::Int32 | DataType::Int64)
461 | (DataType::UInt16, DataType::UInt32 | DataType::UInt64)
462 | (DataType::UInt16, DataType::Int32 | DataType::Int64)
463 | (DataType::UInt32, DataType::UInt64 | DataType::Int64)
464 | (DataType::Utf8, DataType::Utf8View | DataType::LargeUtf8) => true,
465 (
466 DataType::Timestamp(source_unit, source_tz),
467 DataType::Timestamp(target_unit, target_tz),
468 ) => source_tz == target_tz && source_unit == target_unit,
469 _ => false,
470 }
471}
472
/// Distinguishes LIKE from SIMILAR TO so a rewritten node keeps its variant.
#[derive(Clone, Copy)]
enum PatternMatchKind {
    Like,
    SimilarTo,
}
478
479fn extract_constant_scalars(exprs: &[Expr]) -> Result<Option<Vec<ScalarValue>>> {
480 let mut values = Vec::with_capacity(exprs.len());
481 for expr in exprs {
482 let Some(value) = extract_constant_scalar(expr)? else {
483 return Ok(None);
484 };
485 values.push(value);
486 }
487
488 Ok(Some(values))
489}
490
/// Evaluates `expr` to a scalar if it is a literal or a (try-)cast chain over
/// a literal; `Ok(None)` for anything non-constant.
fn extract_constant_scalar(expr: &Expr) -> Result<Option<ScalarValue>> {
    if let Some(value) = expr.as_literal() {
        return Ok(Some(value.clone()));
    }

    let Some((kind, expr, data_type)) = extract_cast_input(expr) else {
        return Ok(None);
    };

    match kind {
        // CAST: a failing constant cast is a real error and is propagated.
        CastInputKind::Cast => extract_constant_scalar(expr)?
            .map(|value| value.cast_to(data_type))
            .transpose(),
        // TRY_CAST: failures are swallowed, mirroring try_cast's semantics.
        CastInputKind::TryCast => {
            Ok(extract_constant_scalar(expr)?.and_then(|value| value.cast_to(data_type).ok()))
        }
    }
}
510
/// Whether a cast expression is a strict CAST or a TRY_CAST.
#[derive(Clone, Copy)]
enum CastInputKind {
    Cast,
    TryCast,
}
516
517fn extract_cast_input(expr: &Expr) -> Option<(CastInputKind, &Expr, &DataType)> {
519 match expr {
520 Expr::Cast(Cast { expr, data_type }) => {
521 Some((CastInputKind::Cast, expr.as_ref(), data_type))
522 }
523 Expr::TryCast(TryCast { expr, data_type }) => {
524 Some((CastInputKind::TryCast, expr.as_ref(), data_type))
525 }
526 _ => None,
527 }
528}
529
/// Ordering of time units from coarsest (0 = seconds) to finest (3 = nanos).
fn time_unit_rank(unit: ArrowTimeUnit) -> usize {
    match unit {
        ArrowTimeUnit::Second => 0,
        ArrowTimeUnit::Millisecond => 1,
        ArrowTimeUnit::Microsecond => 2,
        ArrowTimeUnit::Nanosecond => 3,
    }
}
538
539fn time_unit_scale(unit: ArrowTimeUnit) -> i64 {
540 match unit {
541 ArrowTimeUnit::Second => 1,
542 ArrowTimeUnit::Millisecond => 1_000,
543 ArrowTimeUnit::Microsecond => 1_000_000,
544 ArrowTimeUnit::Nanosecond => 1_000_000_000,
545 }
546}
547
548fn finer_to_coarser_ratio(source_unit: ArrowTimeUnit, target_unit: ArrowTimeUnit) -> Option<i64> {
550 let source_scale = time_unit_scale(source_unit);
551 let target_scale = time_unit_scale(target_unit);
552 (source_scale >= target_scale).then_some(source_scale / target_scale)
553}
554
/// Returns the smallest source-unit value whose cast down to `target_unit`
/// is `>= target_value`. The cast truncates toward zero (demonstrated by
/// `test_timestamp_downcast_contract_matches_datafusion_casts` below).
fn lower_bound_for_ge(
    target_value: i64,
    source_unit: ArrowTimeUnit,
    target_unit: ArrowTimeUnit,
) -> Option<i64> {
    let ratio = finer_to_coarser_ratio(source_unit, target_unit)?;
    let base = target_value.checked_mul(ratio)?;
    if target_value <= 0 {
        // Truncation toward zero rounds negatives upward, so values as small
        // as `base - (ratio - 1)` still truncate to `target_value` or more.
        base.checked_sub(ratio - 1)
    } else {
        Some(base)
    }
}
574
575fn timestamp_scalar_value(value: &ScalarValue) -> Option<i64> {
576 match value {
577 ScalarValue::TimestampSecond(Some(value), _)
578 | ScalarValue::TimestampMillisecond(Some(value), _)
579 | ScalarValue::TimestampMicrosecond(Some(value), _)
580 | ScalarValue::TimestampNanosecond(Some(value), _) => Some(*value),
581 _ => None,
582 }
583}
584
585fn timestamp_scalar(unit: ArrowTimeUnit, timezone: Option<Arc<str>>, value: i64) -> ScalarValue {
586 match unit {
587 ArrowTimeUnit::Second => ScalarValue::TimestampSecond(Some(value), timezone),
588 ArrowTimeUnit::Millisecond => ScalarValue::TimestampMillisecond(Some(value), timezone),
589 ArrowTimeUnit::Microsecond => ScalarValue::TimestampMicrosecond(Some(value), timezone),
590 ArrowTimeUnit::Nanosecond => ScalarValue::TimestampNanosecond(Some(value), timezone),
591 }
592}
593
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::{DataType, TimeUnit as ArrowTimeUnit};
    use async_trait::async_trait;
    use common_time::Timestamp;
    use common_time::range::TimestampRange;
    use common_time::timestamp::TimeUnit;
    use datafusion::catalog::Session;
    use datafusion::config::ConfigOptions;
    use datafusion::datasource::{TableProvider, provider_as_source};
    use datafusion::physical_plan::ExecutionPlan;
    use datafusion_common::arrow::datatypes::Field;
    use datafusion_common::{DFSchema, ScalarValue, ToDFSchema};
    use datafusion_expr::expr::{Between, Like};
    use datafusion_expr::expr_fn::{cast, col, try_cast};
    use datafusion_expr::{
        Expr, LogicalPlan, LogicalPlanBuilder, TableProviderFilterPushDown, TableScan, TableSource,
        TableType, lit,
    };
    use datafusion_optimizer::analyzer::AnalyzerRule;
    use datafusion_optimizer::optimizer::{Optimizer, OptimizerContext};
    use datafusion_optimizer::push_down_filter::PushDownFilter;
    use table::predicate::build_time_range_predicate;

    use super::{
        ConstNormalizationRule, PatternMatchKind, lower_bound_for_ge, try_cast_literal_to_type,
    };

    // `cast(v AS Int64) >= 42i64` collapses to `v >= 42i32`.
    #[test]
    fn test_normalize_direct_integer_cast_comparison() {
        assert_filter_plan(
            vec![Field::new("v", DataType::Int32, false)],
            cast(col("v"), DataType::Int64).gt_eq(lit(42_i64)),
            "Filter: t.v >= Int32(42)\n TableScan: t",
        );
    }

    // The target does not have to be a bare column: `v + 1` also normalizes.
    #[test]
    fn test_normalize_non_column_operand() {
        assert_filter_plan(
            vec![Field::new("v", DataType::Int32, false)],
            cast(col("v") + lit(1_i32), DataType::Int64).gt_eq(lit(42_i64)),
            "Filter: t.v + Int32(1) >= Int32(42)\n TableScan: t",
        );
    }

    // `constant <= cast(v)` is handled by swapping the operator.
    #[test]
    fn test_normalize_swapped_binary_comparison() {
        assert_filter_plan(
            vec![Field::new("v", DataType::Int16, false)],
            lit(42_i64).lt_eq(cast(col("v"), DataType::Int64)),
            "Filter: t.v >= Int16(42)\n TableScan: t",
        );
    }

    // TRY_CAST targets are unwrapped the same way as CAST targets.
    #[test]
    fn test_normalize_try_cast_target() {
        assert_filter_plan(
            vec![Field::new("v", DataType::Int16, false)],
            try_cast(col("v"), DataType::Int64).gt_eq(lit(42_i64)),
            "Filter: t.v >= Int16(42)\n TableScan: t",
        );
    }

    // Constants wrapped in CAST/TRY_CAST are folded before normalization.
    #[test]
    fn test_normalize_casted_constants() {
        let fields = vec![Field::new("v", DataType::Int16, false)];
        let cases = [
            (
                col("v").gt_eq(cast(lit(42_i8), DataType::Int64)),
                "Filter: t.v >= Int16(42)\n TableScan: t",
            ),
            (
                col("v").in_list(
                    vec![
                        cast(lit(1_i8), DataType::Int64),
                        try_cast(lit(2_i8), DataType::Int64),
                    ],
                    false,
                ),
                "Filter: t.v IN ([Int16(1), Int16(2)])\n TableScan: t",
            ),
        ];

        for (predicate, expected) in cases {
            assert_filter_plan(fields.clone(), predicate, expected);
        }
    }

    // Even without an explicit cast on the column, wide literals narrow to
    // the column type when the value fits.
    #[test]
    fn test_normalize_plain_integer_literals() {
        let fields = vec![Field::new("v", DataType::Int16, false)];
        let cases = [
            (
                col("v").gt_eq(lit(42_i64)),
                "Filter: t.v >= Int16(42)\n TableScan: t",
            ),
            (
                col("v").in_list(vec![lit(1_i64), lit(2_i64)], false),
                "Filter: t.v IN ([Int16(1), Int16(2)])\n TableScan: t",
            ),
            (
                col("v").between(lit(3_i64), lit(5_i64)),
                "Filter: t.v BETWEEN Int16(3) AND Int16(5)\n TableScan: t",
            ),
        ];

        for (predicate, expected) in cases {
            assert_filter_plan(fields.clone(), predicate, expected);
        }
    }

    // Unsigned→wider-signed casts are lossless and normalize back to the
    // unsigned column type.
    #[test]
    fn test_normalize_unsigned_to_signed_literals() {
        let cases = [
            (
                vec![Field::new("v", DataType::UInt8, false)],
                cast(col("v"), DataType::Int16).lt_eq(lit(255_i16)),
                "Filter: t.v <= UInt8(255)\n TableScan: t",
            ),
            (
                vec![Field::new("v", DataType::UInt16, false)],
                cast(col("v"), DataType::Int32).gt_eq(lit(42_i32)),
                "Filter: t.v >= UInt16(42)\n TableScan: t",
            ),
            (
                vec![Field::new("v", DataType::UInt32, false)],
                cast(col("v"), DataType::Int64).between(lit(3_i64), lit(5_i64)),
                "Filter: t.v BETWEEN UInt32(3) AND UInt32(5)\n TableScan: t",
            ),
        ];

        for (fields, predicate, expected) in cases {
            assert_filter_plan(fields, predicate, expected);
        }
    }

    // IN-lists and BETWEEN with a cast target normalize every constant.
    #[test]
    fn test_normalize_in_list_and_between() {
        let fields = vec![Field::new("v", DataType::Int16, false)];
        let cases = [
            (
                cast(col("v"), DataType::Int64).in_list(vec![lit(1_i64), lit(2_i64)], false),
                "Filter: t.v IN ([Int16(1), Int16(2)])\n TableScan: t",
            ),
            (
                cast(col("v"), DataType::Int64).between(lit(3_i64), lit(5_i64)),
                "Filter: t.v BETWEEN Int16(3) AND Int16(5)\n TableScan: t",
            ),
        ];

        for (predicate, expected) in cases {
            assert_filter_plan(fields.clone(), predicate, expected);
        }
    }

    // 100_000 does not fit in Int16, so the literal is left as Int64.
    #[test]
    fn test_keep_non_lossless_literal_unchanged() {
        assert_filter_plan(
            vec![Field::new("v", DataType::Int16, false)],
            col("v").gt_eq(lit(100_000_i64)),
            "Filter: t.v >= Int64(100000)\n TableScan: t",
        );
    }

    // Filters already pushed into a TableScan node are rewritten too.
    #[test]
    fn test_normalize_scan_filters() {
        let scan = build_scan_plan(test_schema(vec![Field::new("v", DataType::Int16, false)]));
        let LogicalPlan::TableScan(scan) = scan else {
            panic!("expected table scan");
        };
        let plan = LogicalPlan::TableScan(TableScan {
            filters: vec![cast(col("v"), DataType::Int64).gt_eq(lit(42_i64))],
            ..scan
        });

        let analyzed = analyze_plan(plan);

        assert_eq!(
            vec![col("v").gt_eq(lit(42_i16))],
            extract_scan_filters(&analyzed)
        );
    }

    // NOT BETWEEN keeps its shape in the lossless path.
    #[test]
    fn test_normalize_negated_between() {
        assert_filter_plan(
            vec![Field::new("v", DataType::Int16, false)],
            Expr::Between(Between {
                expr: Box::new(cast(col("v"), DataType::Int64)),
                negated: true,
                low: Box::new(lit(3_i64)),
                high: Box::new(lit(5_i64)),
            }),
            "Filter: t.v NOT BETWEEN Int16(3) AND Int16(5)\n TableScan: t",
        );
    }

    // LIKE pattern narrows from LargeUtf8 to Utf8 alongside the column.
    #[test]
    fn test_normalize_like_literal() {
        assert_pattern_match_plan(
            PatternMatchKind::Like,
            ScalarValue::LargeUtf8(Some("api%".to_string())),
            "Filter: t.s LIKE Utf8(\"api%\")\n TableScan: t",
        );
    }

    // SIMILAR TO takes the same normalization path as LIKE.
    #[test]
    fn test_normalize_similar_to_literal() {
        assert_pattern_match_plan(
            PatternMatchKind::SimilarTo,
            ScalarValue::LargeUtf8(Some("api.*".to_string())),
            "Filter: t.s SIMILAR TO Utf8(\"api.*\")\n TableScan: t",
        );
    }

    // ns→ms downcast comparisons become half-open ranges in nanoseconds that
    // survive pushdown and produce the exact time range.
    #[test]
    fn test_normalize_direct_timestamp_filter() {
        assert_timestamp_pushdown(
            vec![
                Field::new(
                    "ts",
                    DataType::Timestamp(ArrowTimeUnit::Nanosecond, None),
                    false,
                ),
                Field::new("tag", DataType::Utf8, true),
            ],
            ts_cast_to_ms()
                .gt_eq(ts_ms_literal(-299_999))
                .and(ts_cast_to_ms().lt_eq(ts_ms_literal(10_000)))
                .and(col("tag").eq(lit("api"))),
            "Filter: t.ts >= TimestampNanosecond(-299999999999, None) AND t.ts < TimestampNanosecond(10001000000, None) AND t.tag = Utf8(\"api\")\n TableScan: t",
            "TableScan: t, full_filters=[t.ts >= TimestampNanosecond(-299999999999, None), t.ts < TimestampNanosecond(10001000000, None), t.tag = Utf8(\"api\")]",
            TimestampRange::new_inclusive(
                Some(Timestamp::new_nanosecond(-299_999_999_999)),
                Some(Timestamp::new_nanosecond(10_000_999_999)),
            ),
        );
    }

    // BETWEEN over a downcast becomes `>= lower AND < upper` in nanoseconds.
    #[test]
    fn test_normalize_timestamp_between_filter() {
        assert_timestamp_pushdown(
            vec![Field::new(
                "ts",
                DataType::Timestamp(ArrowTimeUnit::Nanosecond, None),
                false,
            )],
            ts_cast_to_ms().between(ts_ms_literal(-299_999), ts_ms_literal(10_000)),
            "Filter: t.ts >= TimestampNanosecond(-299999999999, None) AND t.ts < TimestampNanosecond(10001000000, None)\n TableScan: t",
            "TableScan: t, full_filters=[t.ts >= TimestampNanosecond(-299999999999, None), t.ts < TimestampNanosecond(10001000000, None)]",
            TimestampRange::new_inclusive(
                Some(Timestamp::new_nanosecond(-299_999_999_999)),
                Some(Timestamp::new_nanosecond(10_000_999_999)),
            ),
        );
    }

    // Strict > and < operators shift the bound by one coarser-unit tick.
    #[test]
    fn test_normalize_strict_timestamp_filter() {
        assert_timestamp_pushdown(
            vec![Field::new(
                "ts",
                DataType::Timestamp(ArrowTimeUnit::Nanosecond, None),
                false,
            )],
            ts_cast_to_ms()
                .gt(ts_ms_literal(10_000))
                .and(ts_cast_to_ms().lt(ts_ms_literal(20_000))),
            "Filter: t.ts >= TimestampNanosecond(10001000000, None) AND t.ts < TimestampNanosecond(20000000000, None)\n TableScan: t",
            "TableScan: t, full_filters=[t.ts >= TimestampNanosecond(10001000000, None), t.ts < TimestampNanosecond(20000000000, None)]",
            TimestampRange::new_inclusive(
                Some(Timestamp::new_nanosecond(10_001_000_000)),
                Some(Timestamp::new_nanosecond(19_999_999_999)),
            ),
        );
    }

    // Around zero, truncation toward zero makes the lower bound negative:
    // every ns value down to -999_999 still casts to 0 ms.
    #[test]
    fn test_normalize_zero_boundary_timestamp_filter() {
        let fields = vec![Field::new(
            "ts",
            DataType::Timestamp(ArrowTimeUnit::Nanosecond, None),
            false,
        )];

        assert_timestamp_pushdown(
            fields.clone(),
            ts_cast_to_ms().gt_eq(ts_ms_literal(0)),
            "Filter: t.ts >= TimestampNanosecond(-999999, None)\n TableScan: t",
            "TableScan: t, full_filters=[t.ts >= TimestampNanosecond(-999999, None)]",
            TimestampRange::from_start(Timestamp::new_nanosecond(-999_999)),
        );

        assert_timestamp_pushdown(
            fields.clone(),
            ts_cast_to_ms().lt(ts_ms_literal(0)),
            "Filter: t.ts < TimestampNanosecond(-999999, None)\n TableScan: t",
            "TableScan: t, full_filters=[t.ts < TimestampNanosecond(-999999, None)]",
            TimestampRange::until_end(Timestamp::new_nanosecond(-999_999), false),
        );

        assert_timestamp_pushdown(
            fields,
            ts_cast_to_ms().between(ts_ms_literal(0), ts_ms_literal(0)),
            "Filter: t.ts >= TimestampNanosecond(-999999, None) AND t.ts < TimestampNanosecond(1000000, None)\n TableScan: t",
            "TableScan: t, full_filters=[t.ts >= TimestampNanosecond(-999999, None), t.ts < TimestampNanosecond(1000000, None)]",
            TimestampRange::new_inclusive(
                Some(Timestamp::new_nanosecond(-999_999)),
                Some(Timestamp::new_nanosecond(999_999)),
            ),
        );
    }

    // Contract test: DataFusion's timestamp downcast truncates toward zero,
    // and `lower_bound_for_ge` must agree with that rounding behavior.
    #[test]
    fn test_timestamp_downcast_contract_matches_datafusion_casts() {
        let cases = [
            (-1_000_001, -1),
            (-1_000_000, -1),
            (-999_999, 0),
            (-1, 0),
            (0, 0),
            (999_999, 0),
            (1_000_000, 1),
        ];

        for (source, expected) in cases {
            let casted = try_cast_literal_to_type(
                &ScalarValue::TimestampNanosecond(Some(source), None),
                &DataType::Timestamp(ArrowTimeUnit::Millisecond, None),
            )
            .unwrap();
            assert_eq!(
                ScalarValue::TimestampMillisecond(Some(expected), None),
                casted
            );
        }

        assert_eq!(
            Some(-1_999_999),
            lower_bound_for_ge(-1, ArrowTimeUnit::Nanosecond, ArrowTimeUnit::Millisecond)
        );
        assert_eq!(
            Some(-999_999),
            lower_bound_for_ge(0, ArrowTimeUnit::Nanosecond, ArrowTimeUnit::Millisecond)
        );
        assert_eq!(
            Some(1_000_000),
            lower_bound_for_ge(1, ArrowTimeUnit::Nanosecond, ArrowTimeUnit::Millisecond)
        );
    }

    // Bare ms literals against an ns column use the lossless path (ms→ns
    // widening), preserving the original operators.
    #[test]
    fn test_normalize_plain_timestamp_literals() {
        assert_timestamp_pushdown(
            vec![Field::new(
                "ts",
                DataType::Timestamp(ArrowTimeUnit::Nanosecond, None),
                false,
            )],
            col("ts")
                .gt_eq(ts_ms_literal(-299_999))
                .and(col("ts").lt_eq(ts_ms_literal(10_000))),
            "Filter: t.ts >= TimestampNanosecond(-299999000000, None) AND t.ts <= TimestampNanosecond(10000000000, None)\n TableScan: t",
            "TableScan: t, full_filters=[t.ts >= TimestampNanosecond(-299999000000, None), t.ts <= TimestampNanosecond(10000000000, None)]",
            TimestampRange::new_inclusive(
                Some(Timestamp::new_nanosecond(-299_999_000_000)),
                Some(Timestamp::new_nanosecond(10_000_000_000)),
            ),
        );
    }

    // Upcasting (ms column cast to ns) is neither lossless nor a downcast,
    // so the expression stays untouched.
    #[test]
    fn test_keep_timestamp_upcast_filter_unchanged() {
        assert_filter_plan(
            vec![Field::new(
                "ts",
                DataType::Timestamp(ArrowTimeUnit::Millisecond, None),
                false,
            )],
            cast(
                col("ts"),
                DataType::Timestamp(ArrowTimeUnit::Nanosecond, None),
            )
            .gt_eq(lit(ScalarValue::TimestampNanosecond(Some(1), None))),
            "Filter: CAST(t.ts AS Timestamp(ns)) >= TimestampNanosecond(1, None)\n TableScan: t",
        );
    }

    // Builds a LIKE/SIMILAR TO predicate over `cast(s AS LargeUtf8)` and
    // asserts the analyzed plan.
    fn assert_pattern_match_plan(kind: PatternMatchKind, pattern: ScalarValue, expected: &str) {
        let predicate = match kind {
            PatternMatchKind::Like => Expr::Like(Like::new(
                false,
                Box::new(cast(col("s"), DataType::LargeUtf8)),
                Box::new(lit(pattern)),
                None,
                false,
            )),
            PatternMatchKind::SimilarTo => Expr::SimilarTo(Like::new(
                false,
                Box::new(cast(col("s"), DataType::LargeUtf8)),
                Box::new(lit(pattern)),
                None,
                false,
            )),
        };

        assert_filter_plan(
            vec![Field::new("s", DataType::Utf8, false)],
            predicate,
            expected,
        );
    }

    // Analyzes a Filter-over-Scan plan and compares its display string.
    fn assert_filter_plan(fields: Vec<Field>, predicate: Expr, expected: &str) {
        assert_eq!(expected, analyze_filter(fields, predicate).to_string());
    }

    // Full pipeline check: analyze, push filters into the scan, then derive
    // the time range from the pushed filters.
    fn assert_timestamp_pushdown(
        fields: Vec<Field>,
        predicate: Expr,
        expected_analyzed: &str,
        expected_pushed: &str,
        expected_range: TimestampRange,
    ) {
        let analyzed = analyze_filter(fields, predicate);
        assert_eq!(expected_analyzed, analyzed.to_string());

        let pushed = push_down_filters(analyzed);
        assert_eq!(expected_pushed, pushed.to_string());

        let range =
            build_time_range_predicate("ts", TimeUnit::Nanosecond, &extract_scan_filters(&pushed));
        assert_eq!(expected_range, range);
    }

    fn analyze_filter(fields: Vec<Field>, predicate: Expr) -> LogicalPlan {
        analyze_plan(build_filter_plan(test_schema(fields), predicate))
    }

    fn analyze_plan(plan: LogicalPlan) -> LogicalPlan {
        ConstNormalizationRule
            .analyze(plan, &ConfigOptions::default())
            .unwrap()
    }

    fn build_filter_plan(schema: Arc<DFSchema>, predicate: Expr) -> LogicalPlan {
        LogicalPlanBuilder::scan("t", test_source(schema), None)
            .unwrap()
            .filter(predicate)
            .unwrap()
            .build()
            .unwrap()
    }

    fn build_scan_plan(schema: Arc<DFSchema>) -> LogicalPlan {
        LogicalPlanBuilder::scan("t", test_source(schema), None)
            .unwrap()
            .build()
            .unwrap()
    }

    fn push_down_filters(plan: LogicalPlan) -> LogicalPlan {
        Optimizer::with_rules(vec![Arc::new(PushDownFilter::new())])
            .optimize(plan, &OptimizerContext::new(), |_, _| {})
            .unwrap()
    }

    // `cast(ts AS Timestamp(ms))` — the downcast shape under test.
    fn ts_cast_to_ms() -> Expr {
        cast(
            col("ts"),
            DataType::Timestamp(ArrowTimeUnit::Millisecond, None),
        )
    }

    fn ts_ms_literal(value: i64) -> Expr {
        lit(ScalarValue::TimestampMillisecond(Some(value), None))
    }

    // Collects `filters` from every TableScan in the plan tree.
    fn extract_scan_filters(plan: &LogicalPlan) -> Vec<Expr> {
        match plan {
            LogicalPlan::TableScan(scan) => scan.filters.clone(),
            _ => plan
                .inputs()
                .into_iter()
                .flat_map(extract_scan_filters)
                .collect(),
        }
    }

    fn test_schema(fields: Vec<Field>) -> Arc<DFSchema> {
        arrow_schema::Schema::new(fields).to_dfschema_ref().unwrap()
    }

    fn test_source(schema: Arc<DFSchema>) -> Arc<dyn TableSource> {
        let table = ExactPushdownProvider {
            schema: Arc::new(schema.as_ref().as_arrow().clone()),
        };
        provider_as_source(Arc::new(table))
    }

    // Minimal provider that reports Exact pushdown for every filter so
    // PushDownFilter moves predicates fully into the scan.
    #[derive(Debug)]
    struct ExactPushdownProvider {
        schema: arrow_schema::SchemaRef,
    }

    #[async_trait]
    impl TableProvider for ExactPushdownProvider {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }

        fn schema(&self) -> arrow_schema::SchemaRef {
            self.schema.clone()
        }

        fn table_type(&self) -> TableType {
            TableType::Base
        }

        async fn scan(
            &self,
            _state: &dyn Session,
            _projection: Option<&Vec<usize>>,
            _filters: &[Expr],
            _limit: Option<usize>,
        ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
            unreachable!("scan should not be called in const_normalization tests")
        }

        fn supports_filters_pushdown(
            &self,
            filters: &[&Expr],
        ) -> datafusion::error::Result<Vec<TableProviderFilterPushDown>> {
            Ok(vec![TableProviderFilterPushDown::Exact; filters.len()])
        }
    }
}