From ffcd41adf8bc5bf952be327f5efb2df1c4ac9579 Mon Sep 17 00:00:00 2001 From: "Lei, HUANG" <6406592+v0y4g3r@users.noreply.github.com> Date: Wed, 25 Feb 2026 10:33:49 +0800 Subject: [PATCH] fix: handle scalar result in MultiDimPartitionRule (#7715) * fix: handle scalar result in MultiDimPartitionRule Signed-off-by: Lei, HUANG * add more complex test Signed-off-by: Lei, HUANG --------- Signed-off-by: Lei, HUANG --- src/partition/src/multi_dim.rs | 133 ++++++++++++++++++++++++--------- 1 file changed, 98 insertions(+), 35 deletions(-) diff --git a/src/partition/src/multi_dim.rs b/src/partition/src/multi_dim.rs index a8c11d1ae8..7b1f7aa3dd 100644 --- a/src/partition/src/multi_dim.rs +++ b/src/partition/src/multi_dim.rs @@ -265,30 +265,18 @@ impl MultiDimPartitionRule { return Some(Err(e)); } }; - let ColumnarValue::Array(column) = col_val else { - unreachable!("Expected an array") + let array = match columnar_value_to_boolean_array(col_val, num_rows) { + Ok(array) => array, + Err(e) => { + return Some(Err(e)); + } }; - let array = - match column - .as_any() - .downcast_ref::() - .with_context(|| error::UnexpectedColumnTypeSnafu { - data_type: column.data_type().clone(), - }) { - Ok(array) => array, - Err(e) => { - return Some(Err(e)); - } - }; let selected_rows = array.true_count(); if selected_rows == 0 { // skip empty region in results. return None; } - Some(Ok(( - *region_num, - RegionMask::new(array.clone(), selected_rows), - ))) + Some(Ok((*region_num, RegionMask::new(array, selected_rows)))) }) .collect::>()?; @@ -329,6 +317,22 @@ impl MultiDimPartitionRule { } } +fn columnar_value_to_boolean_array( + col_val: ColumnarValue, + num_rows: usize, +) -> Result { + let column = col_val + .into_array(num_rows) + .context(error::EvaluateRecordBatchSnafu)?; + let array = column + .as_any() + .downcast_ref::() + .with_context(|| error::UnexpectedColumnTypeSnafu { + data_type: column.data_type().clone(), + })?; + Ok(array.clone()) +} + impl PartitionRule for MultiDimPartitionRule { fn as_any(&self) -> &dyn Any { self @@ -469,12 +473,12 @@ mod tests { } /// ```ignore - /// │ │ - /// │ │ + /// │ │ + /// │ │ /// ─────────┼──────────┼────────────► b - /// │ │ - /// │ │ - /// b <= h b >= s + /// │ │ + /// │ │ + /// b <= h b >= s /// ``` #[test] fn empty_expr_case_1() { @@ -505,18 +509,18 @@ mod tests { } /// ``` - /// a - /// ▲ - /// │ ‖ - /// │ ‖ - /// 200 │ ┌─────────┤ - /// │ │ │ - /// │ │ │ - /// │ │ │ - /// 100 │ ======┴─────────┘ - /// │ + /// a + /// ▲ + /// │ ‖ + /// │ ‖ + /// 200 │ ┌─────────┤ + /// │ │ │ + /// │ │ │ + /// │ │ │ + /// 100 │ ======┴─────────┘ + /// │ /// └──────────────────────────►b - /// 10 20 + /// 10 20 /// ``` #[test] fn empty_expr_case_2() { @@ -744,13 +748,14 @@ mod tests { mod test_split_record_batch { use std::sync::Arc; + use datafusion_common::ScalarValue; use datatypes::arrow::array::{Int64Array, StringArray}; use datatypes::arrow::datatypes::{DataType, Field, Schema}; use datatypes::arrow::record_batch::RecordBatch; use rand::Rng; use super::*; - use crate::expr::col; + use crate::expr::{Operand, col}; fn test_schema() -> Arc { Arc::new(Schema::new(vec![ @@ -889,4 +894,62 @@ mod test_split_record_batch { assert_eq!(result.get(&1).unwrap().selected_rows(), 2); // values < 30 assert_eq!(result.get(&2).unwrap().selected_rows(), 2); // values >= 30 } + + #[test] + fn test_split_record_batch_with_scalar_predicate() { + // Ensure split handles conjunctive/disjunctive predicates on the same column. + let rule = MultiDimPartitionRule::try_new( + vec!["host".to_string()], + vec![0, 1], + vec![ + PartitionExpr::new( + Operand::Column("host".to_string()), + RestrictedOp::Lt, + Operand::Value(Value::String("never_happen_1".into())), + ), + PartitionExpr::new( + Operand::Expr(PartitionExpr::new( + Operand::Column("host".to_string()), + RestrictedOp::GtEq, + Operand::Value(Value::String("never_happen_1".into())), + )), + RestrictedOp::And, + Operand::Value(Value::Boolean(false)), + ), + ], + false, + ) + .unwrap(); + + let batch = generate_random_record_batch(8); + let result = rule.split_record_batch(&batch).unwrap(); + + assert_eq!(result.len(), 1); + assert!(result.contains_key(&0)); + + let total_rows = result.get(&0).unwrap().selected_rows(); + assert_eq!(total_rows, batch.num_rows()); + } + + #[test] + fn test_columnar_value_to_boolean_array_scalar_false() { + let result = columnar_value_to_boolean_array( + ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))), + 4, + ) + .unwrap(); + assert_eq!(result.len(), 4); + assert_eq!(result.true_count(), 0); + } + + #[test] + fn test_columnar_value_to_boolean_array_scalar_true() { + let result = columnar_value_to_boolean_array( + ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))), + 4, + ) + .unwrap(); + assert_eq!(result.len(), 4); + assert_eq!(result.true_count(), 4); + } }