diff --git a/Cargo.lock b/Cargo.lock
index 26d6e0b96d..7181483de7 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -8055,6 +8055,7 @@ dependencies = [
  "common-query",
  "datafusion-common",
  "datafusion-expr",
+ "datafusion-physical-expr",
  "datatypes",
  "itertools 0.14.0",
  "serde",
diff --git a/src/partition/Cargo.toml b/src/partition/Cargo.toml
index 6402d2feff..084efcff5f 100644
--- a/src/partition/Cargo.toml
+++ b/src/partition/Cargo.toml
@@ -16,6 +16,7 @@ common-meta.workspace = true
 common-query.workspace = true
 datafusion-common.workspace = true
 datafusion-expr.workspace = true
+datafusion-physical-expr.workspace = true
 datatypes.workspace = true
 itertools.workspace = true
 serde.workspace = true
diff --git a/src/partition/src/expr.rs b/src/partition/src/expr.rs
index bec9543e72..0a517cd2a6 100644
--- a/src/partition/src/expr.rs
+++ b/src/partition/src/expr.rs
@@ -13,12 +13,22 @@
 // limitations under the License.
 
 use std::fmt::{Debug, Display, Formatter};
+use std::sync::Arc;
 
-use datatypes::value::Value;
+use datafusion_common::{ScalarValue, ToDFSchema};
+use datafusion_expr::execution_props::ExecutionProps;
+use datafusion_expr::Expr;
+use datafusion_physical_expr::{create_physical_expr, PhysicalExpr};
+use datatypes::arrow;
+use datatypes::value::{
+    duration_to_scalar_value, time_to_scalar_value, timestamp_to_scalar_value, Value,
+};
 use serde::{Deserialize, Serialize};
 use sql::statements::value_to_sql_value;
 use sqlparser::ast::{BinaryOperator as ParserBinaryOperator, Expr as ParserExpr, Ident};
 
+use crate::error;
+
 /// Struct for partition expression. This can be converted back to sqlparser's [Expr].
 /// by [`Self::to_parser_expr`].
@@ -37,6 +47,48 @@ pub enum Operand {
     Expr(PartitionExpr),
 }
 
+impl Operand {
+    pub fn try_as_logical_expr(&self) -> error::Result<Expr> {
+        match self {
+            Self::Column(c) => Ok(datafusion_expr::col(c)),
+            Self::Value(v) => {
+                let scalar_value = match v {
+                    Value::Boolean(v) => ScalarValue::Boolean(Some(*v)),
+                    Value::UInt8(v) => ScalarValue::UInt8(Some(*v)),
+                    Value::UInt16(v) => ScalarValue::UInt16(Some(*v)),
+                    Value::UInt32(v) => ScalarValue::UInt32(Some(*v)),
+                    Value::UInt64(v) => ScalarValue::UInt64(Some(*v)),
+                    Value::Int8(v) => ScalarValue::Int8(Some(*v)),
+                    Value::Int16(v) => ScalarValue::Int16(Some(*v)),
+                    Value::Int32(v) => ScalarValue::Int32(Some(*v)),
+                    Value::Int64(v) => ScalarValue::Int64(Some(*v)),
+                    Value::Float32(v) => ScalarValue::Float32(Some(v.0)),
+                    Value::Float64(v) => ScalarValue::Float64(Some(v.0)),
+                    Value::String(v) => ScalarValue::Utf8(Some(v.as_utf8().to_string())),
+                    Value::Binary(v) => ScalarValue::Binary(Some(v.to_vec())),
+                    Value::Date(v) => ScalarValue::Date32(Some(v.val())),
+                    Value::Null => ScalarValue::Null,
+                    Value::Timestamp(t) => timestamp_to_scalar_value(t.unit(), Some(t.value())),
+                    Value::Time(t) => time_to_scalar_value(*t.unit(), Some(t.value())).unwrap(),
+                    Value::IntervalYearMonth(v) => ScalarValue::IntervalYearMonth(Some(v.to_i32())),
+                    Value::IntervalDayTime(v) => ScalarValue::IntervalDayTime(Some((*v).into())),
+                    Value::IntervalMonthDayNano(v) => {
+                        ScalarValue::IntervalMonthDayNano(Some((*v).into()))
+                    }
+                    Value::Duration(d) => duration_to_scalar_value(d.unit(), Some(d.value())),
+                    Value::Decimal128(d) => {
+                        let (v, p, s) = d.to_scalar_value();
+                        ScalarValue::Decimal128(v, p, s)
+                    }
+                    _ => unreachable!(),
+                };
+                Ok(datafusion_expr::lit(scalar_value))
+            }
+            Self::Expr(e) => e.try_as_logical_expr(),
+        }
+    }
+}
+
 impl Display for Operand {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         match self {
@@ -140,6 +192,33 @@ impl PartitionExpr {
             right: Box::new(rhs),
         }
     }
+
+    pub fn try_as_logical_expr(&self) -> error::Result<Expr> {
+        let lhs = self.lhs.try_as_logical_expr()?;
+        let rhs = self.rhs.try_as_logical_expr()?;
+
+        let expr = match &self.op {
+            RestrictedOp::And => datafusion_expr::and(lhs, rhs),
+            RestrictedOp::Or => datafusion_expr::or(lhs, rhs),
+            RestrictedOp::Gt => lhs.gt(rhs),
+            RestrictedOp::GtEq => lhs.gt_eq(rhs),
+            RestrictedOp::Lt => lhs.lt(rhs),
+            RestrictedOp::LtEq => lhs.lt_eq(rhs),
+            RestrictedOp::Eq => lhs.eq(rhs),
+            RestrictedOp::NotEq => lhs.not_eq(rhs),
+        };
+        Ok(expr)
+    }
+
+    pub fn try_as_physical_expr(
+        &self,
+        schema: &arrow::datatypes::SchemaRef,
+    ) -> error::Result<Arc<dyn PhysicalExpr>> {
+        let df_schema = schema.clone().to_dfschema_ref().unwrap();
+        let execution_props = &ExecutionProps::default();
+        let expr = self.try_as_logical_expr()?;
+        Ok(create_physical_expr(&expr, &df_schema, execution_props).unwrap())
+    }
 }
 
 impl Display for PartitionExpr {
@@ -150,6 +229,10 @@ impl Display for PartitionExpr {
 
 #[cfg(test)]
 mod tests {
+    use std::sync::Arc;
+
+    use datatypes::arrow::datatypes::{DataType, Field, Schema};
+
     use super::*;
 
     #[test]
@@ -220,4 +303,16 @@ mod tests {
             assert_eq!(case.3, expr.to_string());
         }
     }
+
+    #[test]
+    fn test_to_physical_expr() {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+        let expr = PartitionExpr::new(
+            Operand::Column("a".to_string()),
+            RestrictedOp::Eq,
+            Operand::Value(Value::Int32(10)),
+        );
+        let physical_expr = expr.try_as_physical_expr(&schema).unwrap();
+        println!("{:?}", physical_expr);
+    }
 }
diff --git a/src/partition/src/multi_dim.rs b/src/partition/src/multi_dim.rs
index ae4044fe7e..1afdef69e1 100644
--- a/src/partition/src/multi_dim.rs
+++ b/src/partition/src/multi_dim.rs
@@ -16,8 +16,9 @@ use std::any::Any;
 use std::cmp::Ordering;
 use std::collections::HashMap;
 
+use datafusion_expr::ColumnarValue;
+use datafusion_physical_expr::PhysicalExpr;
 use datatypes::arrow::array::{BooleanArray, BooleanBufferBuilder, RecordBatch};
-use datatypes::arrow::buffer::BooleanBuffer;
 use datatypes::prelude::Value;
 use datatypes::vectors::{Helper, VectorRef};
 use serde::{Deserialize, Serialize};
@@ -168,7 +169,11 @@ impl MultiDimPartitionRule {
         let mut result = self
             .regions
             .iter()
-            .map(|region| (*region, BooleanBufferBuilder::new(num_rows)))
+            .map(|region| {
+                let mut builder = BooleanBufferBuilder::new(num_rows);
+                builder.append_n(num_rows, false);
+                (*region, builder)
+            })
             .collect::<HashMap<_, _>>();
 
         let cols = self.record_batch_to_cols(record_batch)?;
@@ -192,7 +197,25 @@ impl MultiDimPartitionRule {
         &self,
         record_batch: &RecordBatch,
     ) -> Result<HashMap<RegionNumber, BooleanArray>> {
-        todo!()
+        Ok(self
+            .exprs
+            .iter()
+            .zip(self.regions.iter())
+            .map(|(expr, region_num)| {
+                let df_expr = expr.try_as_physical_expr(&record_batch.schema()).unwrap();
+                let ColumnarValue::Array(column) = df_expr.evaluate(record_batch).unwrap() else {
+                    unreachable!("Expected an array")
+                };
+                (
+                    *region_num,
+                    column
+                        .as_any()
+                        .downcast_ref::<BooleanArray>()
+                        .unwrap()
+                        .clone(),
+                )
+            })
+            .collect())
     }
 }
 
@@ -343,6 +366,7 @@ impl<'a> RuleChecker<'a> {
 #[cfg(test)]
 mod tests {
     use std::assert_matches::assert_matches;
+    use std::sync::Arc;
 
     use super::*;
     use crate::error::{self, Error};
@@ -692,4 +716,164 @@ mod tests {
         // check rule
         assert!(rule.is_err());
     }
+
+    #[test]
+    fn test_split_record_batch_by_one_column() {
+        use datatypes::arrow::array::{Int64Array, StringArray};
+        use datatypes::arrow::datatypes::{DataType, Field, Schema};
+        use datatypes::arrow::record_batch::RecordBatch;
+
+        // Create a simple MultiDimPartitionRule
+        let rule = MultiDimPartitionRule::try_new(
+            vec!["host".to_string(), "value".to_string()],
+            vec![0, 1],
+            vec![
+                PartitionExpr::new(
+                    Operand::Column("host".to_string()),
+                    RestrictedOp::Lt,
+                    Operand::Value(Value::String("server1".into())),
+                ),
+                PartitionExpr::new(
+                    Operand::Column("host".to_string()),
+                    RestrictedOp::GtEq,
+                    Operand::Value(Value::String("server1".into())),
+                ),
+            ],
+        )
+        .unwrap();
+
+        // Create a record batch with test data
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("host", DataType::Utf8, false),
+            Field::new("value", DataType::Int64, false),
+        ]));
+
+        let host_array = StringArray::from(vec!["server1", "server2", "server3", "server1"]);
+        let value_array = Int64Array::from(vec![10, 20, 30, 40]);
+
+        let batch = RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)])
+            .unwrap();
+
+        // Split the batch
+        let result = rule.split_record_batch(&batch).unwrap();
+        let expected = rule.split_record_batch_naive(&batch).unwrap();
+        assert_eq!(result.len(), expected.len());
+        for (region, value) in &result {
+            assert_eq!(
+                value,
+                expected.get(region).unwrap(),
+                "failed on region: {}",
+                region
+            );
+        }
+    }
+
+    #[test]
+    fn test_split_record_batch_empty() {
+        use datatypes::arrow::array::{Int64Array, StringArray};
+        use datatypes::arrow::datatypes::{DataType, Field, Schema};
+        use datatypes::arrow::record_batch::RecordBatch;
+
+        // Create a simple MultiDimPartitionRule
+        let rule = MultiDimPartitionRule::try_new(
+            vec!["host".to_string()],
+            vec![1],
+            vec![PartitionExpr::new(
+                Operand::Column("host".to_string()),
+                RestrictedOp::Eq,
+                Operand::Value(Value::String("server1".into())),
+            )],
+        )
+        .unwrap();
+
+        // Create an empty record batch
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("host", DataType::Utf8, false),
+            Field::new("value", DataType::Int64, false),
+        ]));
+
+        let host_array = StringArray::from(Vec::<&str>::new());
+        let value_array = Int64Array::from(Vec::<i64>::new());
+
+        let batch = RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)])
+            .unwrap();
+
+        // Split the batch
+        let result = rule.split_record_batch(&batch).unwrap();
+
+        // Empty batch should result in empty map
+        assert_eq!(result.len(), 0);
+    }
+
+    #[test]
+    fn test_split_record_batch_complex_condition() {
+        use datatypes::arrow::array::{Int64Array, StringArray};
+        use datatypes::arrow::datatypes::{DataType, Field, Schema};
+        use datatypes::arrow::record_batch::RecordBatch;
+
+        // Create a rule with more complex conditions
+        let rule = MultiDimPartitionRule::try_new(
+            vec!["host".to_string(), "value".to_string()],
+            vec![1, 2],
+            vec![
+                // Region 1: host = 'server1' AND value > 20
+                PartitionExpr::new(
+                    Operand::Expr(PartitionExpr::new(
+                        Operand::Column("host".to_string()),
+                        RestrictedOp::Eq,
+                        Operand::Value(Value::String("server1".into())),
+                    )),
+                    RestrictedOp::And,
+                    Operand::Expr(PartitionExpr::new(
+                        Operand::Column("value".into()),
+                        RestrictedOp::Gt,
+                        Operand::Value(Value::Int64(20)),
+                    )),
+                ),
+                // Region 2: host = 'server2'
+                PartitionExpr::new(
+                    Operand::Column("host".to_string()),
+                    RestrictedOp::Eq,
+                    Operand::Value(Value::String("server2".into())),
+                ),
+            ],
+        )
+        .unwrap();
+
+        // Create a record batch with test data
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("host", DataType::Utf8, false),
+            Field::new("value", DataType::Int64, false),
+        ]));
+
+        let host_array = StringArray::from(vec!["server1", "server1", "server2", "server3"]);
+        let value_array = Int64Array::from(vec![10, 30, 20, 40]);
+
+        let batch = RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)])
+            .unwrap();
+
+        // Split the batch
+        let result = rule.split_record_batch(&batch).unwrap();
+
+        // Check the results
+        assert_eq!(result.len(), 3); // 2 defined regions + 1 default region
+
+        // Check region 1 (server1 AND value > 20)
+        let region1_array = result.get(&1).unwrap();
+        assert_eq!(region1_array.len(), 1);
+        assert_eq!(
+            region1_array.iter().map(|b| b.unwrap()).collect::<Vec<_>>(),
+            vec![true]
+        );
+
+        // Check region 2 (server2)
+        assert!(result.contains_key(&2));
+        let region2_batch = result.get(&2).unwrap();
+        assert_eq!(region2_batch.len(), 1);
+
+        // Check default region (server1 with value <= 20 and server3)
+        assert!(result.contains_key(&0));
+        let default_batch = result.get(&0).unwrap();
+        assert_eq!(default_batch.len(), 2);
+    }
 }
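
For context, here is a minimal standalone sketch (not part of the patch) of how the pieces above compose: build a PartitionExpr, lower it once with try_as_physical_expr, then evaluate it against a RecordBatch to get a per-row boolean mask — the same shape split_record_batch_naive uses internally, one physical expression per region. The `partition::expr` import path is an assumption based on the crate layout above; the data setup is illustrative.

use std::sync::Arc;

use datafusion_expr::ColumnarValue;
use datatypes::arrow::array::{Int64Array, StringArray};
use datatypes::arrow::datatypes::{DataType, Field, Schema};
use datatypes::arrow::record_batch::RecordBatch;
use datatypes::prelude::Value;
// Assumed public path, based on src/partition/src/expr.rs above.
use partition::expr::{Operand, PartitionExpr, RestrictedOp};

fn main() {
    // Partition expression mirroring the test data: host >= 'server1'.
    let expr = PartitionExpr::new(
        Operand::Column("host".to_string()),
        RestrictedOp::GtEq,
        Operand::Value(Value::String("server1".into())),
    );

    // A batch whose schema contains the partition column.
    let schema = Arc::new(Schema::new(vec![
        Field::new("host", DataType::Utf8, false),
        Field::new("value", DataType::Int64, false),
    ]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(StringArray::from(vec!["server0", "server2"])),
            Arc::new(Int64Array::from(vec![1, 2])),
        ],
    )
    .unwrap();

    // Lower to a DataFusion physical expression once, then evaluate per batch.
    // For a comparison over a column this yields a BooleanArray mask,
    // one flag per row, which split_record_batch uses to route rows to regions.
    let physical = expr.try_as_physical_expr(&schema).unwrap();
    match physical.evaluate(&batch).unwrap() {
        ColumnarValue::Array(mask) => println!("{mask:?}"), // expected: [false, true]
        ColumnarValue::Scalar(s) => println!("scalar: {s:?}"),
    }
}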