From 444a00c866caff32647e846e0cb9edcf4ca015af Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Sat, 4 Apr 2026 03:50:44 +0800 Subject: [PATCH] handle no cast Signed-off-by: Ruihang Xia --- .../src/optimizer/const_normalization.rs | 181 ++++++++++++++---- .../common/tql-explain-analyze/explain.result | 43 +++++ .../time_index_filter_pushdown.result | 23 ++- .../optimizer/time_index_filter_pushdown.sql | 10 +- 4 files changed, 213 insertions(+), 44 deletions(-) diff --git a/src/query/src/optimizer/const_normalization.rs b/src/query/src/optimizer/const_normalization.rs index 8c8dc60089..871c88c921 100644 --- a/src/query/src/optimizer/const_normalization.rs +++ b/src/query/src/optimizer/const_normalization.rs @@ -24,9 +24,9 @@ use datafusion_expr::{ }; use datafusion_optimizer::analyzer::AnalyzerRule; -/// ConstNormalizationRule removes casts on column operands in filters when the -/// constant side can be normalized to the source column type ahead of filter -/// pushdown. +/// ConstNormalizationRule rewrites filter literals to the column-side type +/// ahead of filter pushdown, and also removes lossless casts on column +/// operands when the literal can be normalized to the source column type. #[derive(Debug)] pub struct ConstNormalizationRule; @@ -128,15 +128,13 @@ fn normalize_binary_expr( right: Expr, schema: &DFSchemaRef, ) -> Result { - if let Some(expr) = - normalize_binary_with_casted_column(left.clone(), op, right.clone(), schema)? - { + if let Some(expr) = normalize_binary_with_target(left.clone(), op, right.clone(), schema)? { return Ok(expr); } if let Some(swapped_op) = op.swap() && let Some(expr) = - normalize_binary_with_casted_column(right.clone(), swapped_op, left.clone(), schema)? + normalize_binary_with_target(right.clone(), swapped_op, left.clone(), schema)? { return Ok(expr); } @@ -148,28 +146,24 @@ fn normalize_binary_expr( })) } -fn normalize_binary_with_casted_column( - casted_column: Expr, +fn normalize_binary_with_target( + target: Expr, op: Operator, constant: Expr, schema: &DFSchemaRef, ) -> Result> { - let Some(casted_column) = extract_casted_column(&casted_column, schema)? else { + let Some(target) = extract_normalization_target(&target, schema)? else { return Ok(None); }; let Some(constant) = extract_constant_scalar(&constant)? else { return Ok(None); }; - if let Some(expr) = normalize_lossless_binary(&casted_column, op, &constant) { + if let Some(expr) = normalize_lossless_binary(&target, op, &constant) { return Ok(Some(expr)); } - Ok(normalize_timestamp_downcast_binary( - &casted_column, - op, - &constant, - )) + Ok(normalize_timestamp_downcast_binary(&target, op, &constant)) } fn normalize_between_expr( @@ -188,7 +182,7 @@ fn normalize_between_expr( })); } - let Some(casted_column) = extract_casted_column(&expr, schema)? else { + let Some(target) = extract_normalization_target(&expr, schema)? else { return Ok(Expr::Between(Between { expr: Box::new(expr), negated, @@ -214,20 +208,18 @@ fn normalize_between_expr( }; if let (Some(low), Some(high)) = ( - normalize_lossless_literal(&casted_column, &low_value), - normalize_lossless_literal(&casted_column, &high_value), + normalize_lossless_literal(&target, &low_value), + normalize_lossless_literal(&target, &high_value), ) { return Ok(Expr::Between(Between { - expr: Box::new(casted_column.source_expr.clone()), + expr: Box::new(target.source_expr.clone()), negated, low: Box::new(Expr::Literal(low, None)), high: Box::new(Expr::Literal(high, None)), })); } - if let Some(expr) = - normalize_timestamp_downcast_between(&casted_column, &low_value, &high_value) - { + if let Some(expr) = normalize_timestamp_downcast_between(&target, &low_value, &high_value) { return Ok(expr); } @@ -245,7 +237,7 @@ fn normalize_in_list_expr( negated: bool, schema: &DFSchemaRef, ) -> Result { - let Some(casted_column) = extract_casted_column(&expr, schema)? else { + let Some(target) = extract_normalization_target(&expr, schema)? else { return Ok(Expr::InList(InList { expr: Box::new(expr), list, @@ -262,7 +254,7 @@ fn normalize_in_list_expr( negated, })); }; - let Some(normalized) = normalize_lossless_literal(&casted_column, &value) else { + let Some(normalized) = normalize_lossless_literal(&target, &value) else { return Ok(Expr::InList(InList { expr: Box::new(expr), list, @@ -273,14 +265,14 @@ fn normalize_in_list_expr( } Ok(Expr::InList(InList { - expr: Box::new(casted_column.source_expr.clone()), + expr: Box::new(target.source_expr.clone()), list: new_list, negated, })) } #[derive(Clone)] -struct CastedColumn { +struct NormalizationTarget { source_expr: Expr, source_type: DataType, kind: CastKind, @@ -296,7 +288,18 @@ enum CastKind { }, } -fn extract_casted_column(expr: &Expr, schema: &DFSchemaRef) -> Result> { +fn extract_normalization_target( + expr: &Expr, + schema: &DFSchemaRef, +) -> Result> { + if matches!(expr, Expr::Column(_)) { + return Ok(Some(NormalizationTarget { + source_expr: expr.clone(), + source_type: expr.get_type(schema)?, + kind: CastKind::Lossless, + })); + } + let (source_expr, target_type) = match expr { Expr::Cast(Cast { expr, data_type }) | Expr::TryCast(TryCast { expr, data_type }) => { (expr.as_ref(), data_type) @@ -313,7 +316,7 @@ fn extract_casted_column(expr: &Expr, schema: &DFSchemaRef) -> Result bo } fn normalize_lossless_binary( - casted_column: &CastedColumn, + target: &NormalizationTarget, op: Operator, constant: &ScalarValue, ) -> Option { - let normalized = normalize_lossless_literal(casted_column, constant)?; + let normalized = normalize_lossless_literal(target, constant)?; Some(Expr::BinaryExpr(BinaryExpr { - left: Box::new(casted_column.source_expr.clone()), + left: Box::new(target.source_expr.clone()), op, right: Box::new(Expr::Literal(normalized, None)), })) } fn normalize_lossless_literal( - casted_column: &CastedColumn, + target: &NormalizationTarget, constant: &ScalarValue, ) -> Option { - matches!(casted_column.kind, CastKind::Lossless) + matches!(target.kind, CastKind::Lossless) .then_some(()) - .and_then(|_| cast_literal_losslessly(constant, &casted_column.source_type)) + .and_then(|_| cast_literal_losslessly(constant, &target.source_type)) } fn normalize_timestamp_downcast_binary( - casted_column: &CastedColumn, + target: &NormalizationTarget, op: Operator, constant: &ScalarValue, ) -> Option { @@ -389,7 +392,7 @@ fn normalize_timestamp_downcast_binary( source_unit, target_unit, timezone, - } = &casted_column.kind + } = &target.kind else { return None; }; @@ -413,7 +416,7 @@ fn normalize_timestamp_downcast_binary( }; Some(Expr::BinaryExpr(BinaryExpr { - left: Box::new(casted_column.source_expr.clone()), + left: Box::new(target.source_expr.clone()), op: normalized_op, right: Box::new(Expr::Literal( timestamp_scalar(*source_unit, timezone.clone(), bound), @@ -423,7 +426,7 @@ fn normalize_timestamp_downcast_binary( } fn normalize_timestamp_downcast_between( - casted_column: &CastedColumn, + target: &NormalizationTarget, low: &ScalarValue, high: &ScalarValue, ) -> Option { @@ -431,7 +434,7 @@ fn normalize_timestamp_downcast_between( source_unit, target_unit, timezone, - } = &casted_column.kind + } = &target.kind else { return None; }; @@ -447,7 +450,7 @@ fn normalize_timestamp_downcast_between( Some( Expr::BinaryExpr(BinaryExpr { - left: Box::new(casted_column.source_expr.clone()), + left: Box::new(target.source_expr.clone()), op: Operator::GtEq, right: Box::new(Expr::Literal( timestamp_scalar(*source_unit, timezone.clone(), lower), @@ -455,7 +458,7 @@ fn normalize_timestamp_downcast_between( )), }) .and(Expr::BinaryExpr(BinaryExpr { - left: Box::new(casted_column.source_expr.clone()), + left: Box::new(target.source_expr.clone()), op: Operator::Lt, right: Box::new(Expr::Literal( timestamp_scalar(*source_unit, timezone.clone(), upper), @@ -588,6 +591,54 @@ mod tests { ); } + #[test] + fn test_normalize_plain_integer_literals() { + let schema = test_schema(vec![Field::new("v", DataType::Int16, false)]); + let comparison_plan = LogicalPlanBuilder::scan("t", test_source(schema.clone()), None) + .unwrap() + .filter(col("v").gt_eq(lit(42_i64))) + .unwrap() + .build() + .unwrap(); + + let in_list_plan = LogicalPlanBuilder::scan("t", test_source(schema.clone()), None) + .unwrap() + .filter(col("v").in_list(vec![lit(1_i64), lit(2_i64)], false)) + .unwrap() + .build() + .unwrap(); + + let between_plan = LogicalPlanBuilder::scan("t", test_source(schema), None) + .unwrap() + .filter(col("v").between(lit(3_i64), lit(5_i64))) + .unwrap() + .build() + .unwrap(); + + let comparison = ConstNormalizationRule + .analyze(comparison_plan, &ConfigOptions::default()) + .unwrap(); + let in_list = ConstNormalizationRule + .analyze(in_list_plan, &ConfigOptions::default()) + .unwrap(); + let between = ConstNormalizationRule + .analyze(between_plan, &ConfigOptions::default()) + .unwrap(); + + assert_eq!( + "Filter: t.v >= Int16(42)\n TableScan: t", + comparison.to_string() + ); + assert_eq!( + "Filter: t.v IN ([Int16(1), Int16(2)])\n TableScan: t", + in_list.to_string() + ); + assert_eq!( + "Filter: t.v BETWEEN Int16(3) AND Int16(5)\n TableScan: t", + between.to_string() + ); + } + #[test] fn test_normalize_in_list_and_between() { let schema = test_schema(vec![Field::new("v", DataType::Int16, false)]); @@ -679,6 +730,52 @@ mod tests { ); } + #[test] + fn test_normalize_plain_timestamp_literals() { + let schema = test_schema(vec![Field::new( + "ts", + DataType::Timestamp(ArrowTimeUnit::Nanosecond, None), + false, + )]); + let plan = LogicalPlanBuilder::scan("t", test_source(schema), None) + .unwrap() + .filter( + col("ts") + .gt_eq(lit(ScalarValue::TimestampMillisecond(Some(-299_999), None))) + .and( + col("ts").lt_eq(lit(ScalarValue::TimestampMillisecond(Some(10_000), None))), + ), + ) + .unwrap() + .build() + .unwrap(); + + let analyzed = ConstNormalizationRule + .analyze(plan, &ConfigOptions::default()) + .unwrap(); + let expected = "\ +\nFilter: t.ts >= TimestampNanosecond(-299999000000, None) AND t.ts <= TimestampNanosecond(10000000000, None)\ +\n TableScan: t"; + assert_eq!(expected.trim_start(), analyzed.to_string()); + + let pushed = Optimizer::with_rules(vec![Arc::new(PushDownFilter::new())]) + .optimize(analyzed, &OptimizerContext::new(), |_, _| {}) + .unwrap(); + let expected = "\ +\nTableScan: t, full_filters=[t.ts >= TimestampNanosecond(-299999000000, None), t.ts <= TimestampNanosecond(10000000000, None)]"; + assert_eq!(expected.trim_start(), pushed.to_string()); + + let filters = extract_scan_filters(&pushed); + let range = build_time_range_predicate("ts", TimeUnit::Nanosecond, &filters); + assert_eq!( + TimestampRange::new_inclusive( + Some(Timestamp::new_nanosecond(-299_999_000_000)), + Some(Timestamp::new_nanosecond(10_000_000_000)) + ), + range + ); + } + fn extract_scan_filters(plan: &LogicalPlan) -> Vec { match plan { LogicalPlan::TableScan(scan) => scan.filters.clone(), diff --git a/tests/cases/standalone/common/tql-explain-analyze/explain.result b/tests/cases/standalone/common/tql-explain-analyze/explain.result index 1e4cf18b40..7133f1cb91 100644 --- a/tests/cases/standalone/common/tql-explain-analyze/explain.result +++ b/tests/cases/standalone/common/tql-explain-analyze/explain.result @@ -90,6 +90,7 @@ TQL EXPLAIN VERBOSE (0, 10, '5s') test; | logical_plan after TranscribeAtatRule_| SAME TEXT AS ABOVE_| | logical_plan after resolve_grouping_function_| SAME TEXT AS ABOVE_| | logical_plan after type_coercion_| SAME TEXT AS ABOVE_| +| logical_plan after ConstNormalizationRule_| SAME TEXT AS ABOVE_| | logical_plan after DistPlannerAnalyzer_| Projection: test.i, test.j, test.k_| |_|_MergeScan [is_placeholder=false, remote_input=[_| |_| PromInstantManipulate: range=[0..10000], lookback=[300000], interval=[5000], time index=[j]_| @@ -233,6 +234,7 @@ TQL EXPLAIN VERBOSE (0, 10, '5s') test AS series; | logical_plan after TranscribeAtatRule_| SAME TEXT AS ABOVE_| | logical_plan after resolve_grouping_function_| SAME TEXT AS ABOVE_| | logical_plan after type_coercion_| SAME TEXT AS ABOVE_| +| logical_plan after ConstNormalizationRule_| SAME TEXT AS ABOVE_| | logical_plan after DistPlannerAnalyzer_| Projection: series, test.k, test.j_| |_|_MergeScan [is_placeholder=false, remote_input=[_| |_| Projection: test.i AS series, test.k, test.j_| @@ -356,6 +358,47 @@ TQL EXPLAIN VERBOSE (0, 10, '5s') test AS series; |_|_| +-+-+ +CREATE TABLE test_nano(i DOUBLE, j TIMESTAMP(9) TIME INDEX, k STRING PRIMARY KEY); + +Affected Rows: 0 + +INSERT INTO test_nano VALUES (1, 1000000, "a"), (1, 1000000, "b"), (2, 2000000, "a"); + +Affected Rows: 3 + +-- explain at 0s, 5s and 10s for a nanosecond time index. +-- SQLNESS REPLACE (RoundRobinBatch.*) REDACTED +-- SQLNESS REPLACE (peers.*) REDACTED +TQL EXPLAIN (0, 10, '5s') test_nano; + ++---------------+-----------------------------------------------------------------------------------------------------------------------------------+ +| plan_type | plan | ++---------------+-----------------------------------------------------------------------------------------------------------------------------------+ +| logical_plan | PromInstantManipulate: range=[0..10000], lookback=[300000], interval=[5000], time index=[j] | +| | PromSeriesDivide: tags=["k"] | +| | Sort: test_nano.k ASC NULLS FIRST, test_nano.j ASC NULLS FIRST | +| | Projection: test_nano.i, test_nano.k, CAST(test_nano.j AS Timestamp(ms)) AS j | +| | Projection: test_nano.i, test_nano.j, test_nano.k | +| | Filter: __common_expr_3 >= TimestampMillisecond(-299999, None) AND __common_expr_3 <= TimestampMillisecond(10000, None) | +| | Projection: CAST(test_nano.j AS Timestamp(ms)) AS __common_expr_3, test_nano.i, test_nano.j, test_nano.k | +| | MergeScan [is_placeholder=false, remote_input=[ | +| | TableScan: test_nano | +| | ]] | +| physical_plan | PromInstantManipulateExec: range=[0..10000], lookback=[300000], interval=[5000], time index=[j] | +| | PromSeriesDivideExec: tags=["k"] | +| | SortExec: expr=[k@1 ASC, j@2 ASC], preserve_partitioning=[true] | +| | RepartitionExec: partitioning=Hash([k@1], 32), input_partitions=32 | +| | ProjectionExec: expr=[i@0 as i, k@2 as k, CAST(j@1 AS Timestamp(ms)) as j] | +| | FilterExec: __common_expr_3@0 >= -299999 AND __common_expr_3@0 <= 10000, projection=[i@1, j@2, k@3] | +| | ProjectionExec: expr=[CAST(j@1 AS Timestamp(ms)) as __common_expr_3, i@0 as i, j@1 as j, k@2 as k] | +| | MergeScanExec: REDACTED +| | | ++---------------+-----------------------------------------------------------------------------------------------------------------------------------+ + +DROP TABLE test_nano; + +Affected Rows: 0 + DROP TABLE test; Affected Rows: 0 diff --git a/tests/cases/standalone/optimizer/time_index_filter_pushdown.result b/tests/cases/standalone/optimizer/time_index_filter_pushdown.result index 40948080a1..efce2008ad 100644 --- a/tests/cases/standalone/optimizer/time_index_filter_pushdown.result +++ b/tests/cases/standalone/optimizer/time_index_filter_pushdown.result @@ -59,7 +59,7 @@ EXPLAIN SELECT FROM cpu WHERE - usage_small IN (CAST(10 AS BIGINT), CAST(20 AS BIGINT)); + usage_small IN (10, 20); +---------------+---------------------------------------------------------------------------------------------------------------------------------------------------+ | plan_type | plan | @@ -74,6 +74,27 @@ WHERE | | | +---------------+---------------------------------------------------------------------------------------------------------------------------------------------------+ +-- SQLNESS REPLACE (peers.*) REDACTED +EXPLAIN SELECT + rack +FROM + cpu +WHERE + usage_small BETWEEN 10 AND 20; + ++---------------+---------------------------------------------------------------------------------------------------------------------------------------------------+ +| plan_type | plan | ++---------------+---------------------------------------------------------------------------------------------------------------------------------------------------+ +| logical_plan | MergeScan [is_placeholder=false, remote_input=[ | +| | Projection: cpu.rack | +| | Filter: cpu.usage_small >= Int16(10) AND cpu.usage_small <= Int16(20) | +| | TableScan: cpu | +| | ]] | +| physical_plan | CooperativeExec | +| | MergeScanExec: REDACTED +| | | ++---------------+---------------------------------------------------------------------------------------------------------------------------------------------------+ + -- SQLNESS SORT_RESULT 3 1 select count(*) diff --git a/tests/cases/standalone/optimizer/time_index_filter_pushdown.sql b/tests/cases/standalone/optimizer/time_index_filter_pushdown.sql index de16d8e607..777284af80 100644 --- a/tests/cases/standalone/optimizer/time_index_filter_pushdown.sql +++ b/tests/cases/standalone/optimizer/time_index_filter_pushdown.sql @@ -47,7 +47,15 @@ EXPLAIN SELECT FROM cpu WHERE - usage_small IN (CAST(10 AS BIGINT), CAST(20 AS BIGINT)); + usage_small IN (10, 20); + +-- SQLNESS REPLACE (peers.*) REDACTED +EXPLAIN SELECT + rack +FROM + cpu +WHERE + usage_small BETWEEN 10 AND 20; -- SQLNESS SORT_RESULT 3 1 select