feat/column-partition:

### Add support for DataFusion physical expressions

 - **`Cargo.lock` & `Cargo.toml`**: Added `datafusion-physical-expr` as a dependency to support physical expression creation.
 - **`expr.rs`**: Implemented `try_as_logical_expr` and `try_as_physical_expr` on `Operand` and `PartitionExpr`, converting partition expressions into DataFusion logical and physical expressions (see the usage sketch after this list).
 - **`multi_dim.rs`**: Extended `MultiDimPartitionRule` to evaluate physical expressions for partitioning, adding methods that split record batches into per-region row masks.
 - **Tests**: Added unit tests for the logical/physical expression conversions and the record-batch splitting logic in `expr.rs` and `multi_dim.rs`.
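
For orientation, a minimal sketch of the new conversion API (the `example` wrapper and the `partition::expr` / `partition::error` import paths are illustrative assumptions; the types and methods are the ones added in the diff below):

```rust
use datatypes::value::Value;
// Hypothetical paths; the real module layout is whatever houses `expr.rs` below.
use partition::error;
use partition::expr::{Operand, PartitionExpr, RestrictedOp};

fn example() -> error::Result<()> {
    // Operand tree for `host >= 'server1'`.
    let expr = PartitionExpr::new(
        Operand::Column("host".to_string()),
        RestrictedOp::GtEq,
        Operand::Value(Value::String("server1".into())),
    );

    // Logical form: a DataFusion `Expr`, displayed roughly as `host >= Utf8("server1")`.
    let logical = expr.try_as_logical_expr()?;
    println!("{logical}");
    Ok(())
}
```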
Author: Lei, HUANG
Date: 2025-03-28 06:30:17 +00:00
Parent: 11c5cb44d8
Commit: 404df92a60

4 changed files with 285 additions and 4 deletions

Cargo.lock (generated)

@@ -8055,6 +8055,7 @@ dependencies = [
"common-query",
"datafusion-common",
"datafusion-expr",
"datafusion-physical-expr",
"datatypes",
"itertools 0.14.0",
"serde",

Cargo.toml

@@ -16,6 +16,7 @@ common-meta.workspace = true
common-query.workspace = true
datafusion-common.workspace = true
datafusion-expr.workspace = true
datafusion-physical-expr.workspace = true
datatypes.workspace = true
itertools.workspace = true
serde.workspace = true

expr.rs

@@ -13,12 +13,22 @@
// limitations under the License.

use std::fmt::{Debug, Display, Formatter};
use std::sync::Arc;

-use datatypes::value::Value;
use datafusion_common::{ScalarValue, ToDFSchema};
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::Expr;
use datafusion_physical_expr::{create_physical_expr, PhysicalExpr};
use datatypes::arrow;
use datatypes::value::{
    duration_to_scalar_value, time_to_scalar_value, timestamp_to_scalar_value, Value,
};
use serde::{Deserialize, Serialize};
use sql::statements::value_to_sql_value;
use sqlparser::ast::{BinaryOperator as ParserBinaryOperator, Expr as ParserExpr, Ident};

use crate::error;
/// Struct for partition expression. This can be converted back to sqlparser's [Expr]
/// by [`Self::to_parser_expr`].
///
@@ -37,6 +47,48 @@ pub enum Operand {
    Expr(PartitionExpr),
}

impl Operand {
    pub fn try_as_logical_expr(&self) -> error::Result<Expr> {
        match self {
            Self::Column(c) => Ok(datafusion_expr::col(c)),
            Self::Value(v) => {
                // Map the storage `Value` to the corresponding DataFusion `ScalarValue`.
                let scalar_value = match v {
                    Value::Boolean(v) => ScalarValue::Boolean(Some(*v)),
                    Value::UInt8(v) => ScalarValue::UInt8(Some(*v)),
                    Value::UInt16(v) => ScalarValue::UInt16(Some(*v)),
                    Value::UInt32(v) => ScalarValue::UInt32(Some(*v)),
                    Value::UInt64(v) => ScalarValue::UInt64(Some(*v)),
                    Value::Int8(v) => ScalarValue::Int8(Some(*v)),
                    Value::Int16(v) => ScalarValue::Int16(Some(*v)),
                    Value::Int32(v) => ScalarValue::Int32(Some(*v)),
                    Value::Int64(v) => ScalarValue::Int64(Some(*v)),
                    Value::Float32(v) => ScalarValue::Float32(Some(v.0)),
                    Value::Float64(v) => ScalarValue::Float64(Some(v.0)),
                    Value::String(v) => ScalarValue::Utf8(Some(v.as_utf8().to_string())),
                    Value::Binary(v) => ScalarValue::Binary(Some(v.to_vec())),
                    Value::Date(v) => ScalarValue::Date32(Some(v.val())),
                    Value::Null => ScalarValue::Null,
                    Value::Timestamp(t) => timestamp_to_scalar_value(t.unit(), Some(t.value())),
                    Value::Time(t) => time_to_scalar_value(*t.unit(), Some(t.value())).unwrap(),
                    Value::IntervalYearMonth(v) => ScalarValue::IntervalYearMonth(Some(v.to_i32())),
                    Value::IntervalDayTime(v) => ScalarValue::IntervalDayTime(Some((*v).into())),
                    Value::IntervalMonthDayNano(v) => {
                        ScalarValue::IntervalMonthDayNano(Some((*v).into()))
                    }
                    Value::Duration(d) => duration_to_scalar_value(d.unit(), Some(d.value())),
                    Value::Decimal128(d) => {
                        let (v, p, s) = d.to_scalar_value();
                        ScalarValue::Decimal128(v, p, s)
                    }
                    // Remaining variants (e.g. lists) are not valid partition values.
                    _ => unreachable!(),
                };
                Ok(datafusion_expr::lit(scalar_value))
            }
            Self::Expr(e) => e.try_as_logical_expr(),
        }
    }
}

impl Display for Operand {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
@@ -140,6 +192,33 @@ impl PartitionExpr {
            right: Box::new(rhs),
        }
    }

    pub fn try_as_logical_expr(&self) -> error::Result<Expr> {
        let lhs = self.lhs.try_as_logical_expr()?;
        let rhs = self.rhs.try_as_logical_expr()?;

        let expr = match &self.op {
            RestrictedOp::And => datafusion_expr::and(lhs, rhs),
            RestrictedOp::Or => datafusion_expr::or(lhs, rhs),
            RestrictedOp::Gt => lhs.gt(rhs),
            RestrictedOp::GtEq => lhs.gt_eq(rhs),
            RestrictedOp::Lt => lhs.lt(rhs),
            RestrictedOp::LtEq => lhs.lt_eq(rhs),
            RestrictedOp::Eq => lhs.eq(rhs),
            RestrictedOp::NotEq => lhs.not_eq(rhs),
        };
        Ok(expr)
    }

    pub fn try_as_physical_expr(
        &self,
        schema: &arrow::datatypes::SchemaRef,
    ) -> error::Result<Arc<dyn PhysicalExpr>> {
        // Schema -> DFSchema conversion and physical planning are not expected
        // to fail for expressions produced by `try_as_logical_expr`.
        let df_schema = schema.clone().to_dfschema_ref().unwrap();
        let execution_props = &ExecutionProps::default();
        let expr = self.try_as_logical_expr()?;

        Ok(create_physical_expr(&expr, &df_schema, execution_props).unwrap())
    }
}

impl Display for PartitionExpr {
@@ -150,6 +229,10 @@ impl Display for PartitionExpr {
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::arrow::datatypes::{DataType, Field, Schema};

    use super::*;

    #[test]
@@ -220,4 +303,16 @@ mod tests {
            assert_eq!(case.3, expr.to_string());
        }
    }

    #[test]
    fn test_to_physical_expr() {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let expr = PartitionExpr::new(
            Operand::Column("a".to_string()),
            RestrictedOp::Eq,
            Operand::Value(Value::Int32(10)),
        );
        let physical_expr = expr.try_as_physical_expr(&schema).unwrap();
        println!("{:?}", physical_expr);
    }
}

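The payoff of the physical form is that it evaluates directly against Arrow record batches. A hedged extension of `test_to_physical_expr` above (the test name and expected mask are mine, not part of the commit; `evaluate` and `into_array` are DataFusion's `PhysicalExpr`/`ColumnarValue` methods):

```rust
#[test]
fn test_physical_expr_evaluation() {
    use std::sync::Arc;

    use datatypes::arrow::array::{BooleanArray, Int32Array, RecordBatch};
    use datatypes::arrow::datatypes::{DataType, Field, Schema};

    // Same expression as `test_to_physical_expr`: `a = 10`.
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let expr = PartitionExpr::new(
        Operand::Column("a".to_string()),
        RestrictedOp::Eq,
        Operand::Value(Value::Int32(10)),
    );
    let physical_expr = expr.try_as_physical_expr(&schema).unwrap();

    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![5, 10, 15]))],
    )
    .unwrap();

    // A comparison evaluates to one boolean per row.
    let value = physical_expr.evaluate(&batch).unwrap();
    let array = value.into_array(batch.num_rows()).unwrap();
    let mask = array.as_any().downcast_ref::<BooleanArray>().unwrap();
    assert_eq!(mask, &BooleanArray::from(vec![false, true, false]));
}
```
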
multi_dim.rs

@@ -16,8 +16,9 @@ use std::any::Any;
use std::cmp::Ordering;
use std::collections::HashMap;

use datafusion_expr::ColumnarValue;
use datafusion_physical_expr::PhysicalExpr;
use datatypes::arrow::array::{BooleanArray, BooleanBufferBuilder, RecordBatch};
use datatypes::arrow::buffer::BooleanBuffer;
use datatypes::prelude::Value;
use datatypes::vectors::{Helper, VectorRef};
use serde::{Deserialize, Serialize};
@@ -168,7 +169,11 @@ impl MultiDimPartitionRule {
        let mut result = self
            .regions
            .iter()
-            .map(|region| (*region, BooleanBufferBuilder::new(num_rows)))
+            .map(|region| {
+                // `new` only reserves capacity; fill all bits with `false` so rows
+                // matching no expression stay unset.
+                let mut builder = BooleanBufferBuilder::new(num_rows);
+                builder.append_n(num_rows, false);
+                (*region, builder)
+            })
            .collect::<HashMap<_, _>>();

        let cols = self.record_batch_to_cols(record_batch)?;
@@ -192,7 +197,25 @@
        &self,
        record_batch: &RecordBatch,
    ) -> Result<HashMap<RegionNumber, BooleanArray>> {
-        todo!()
+        Ok(self
+            .exprs
+            .iter()
+            .zip(self.regions.iter())
+            .map(|(expr, region_num)| {
+                // Plan each region's expression against the batch schema and
+                // evaluate it to a per-row boolean mask.
+                let df_expr = expr.try_as_physical_expr(&record_batch.schema()).unwrap();
+                let ColumnarValue::Array(column) = df_expr.evaluate(record_batch).unwrap() else {
+                    unreachable!("Expected an array")
+                };
+                (
+                    *region_num,
+                    column
+                        .as_any()
+                        .downcast_ref::<BooleanArray>()
+                        .unwrap()
+                        .clone(),
+                )
+            })
+            .collect())
    }
}
@@ -343,6 +366,7 @@ impl<'a> RuleChecker<'a> {
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;
    use std::sync::Arc;

    use super::*;
    use crate::error::{self, Error};
@@ -692,4 +716,164 @@ mod tests {
        // check rule
        assert!(rule.is_err());
    }

    #[test]
    fn test_split_record_batch_by_one_column() {
        use datatypes::arrow::array::{Int64Array, StringArray};
        use datatypes::arrow::datatypes::{DataType, Field, Schema};
        use datatypes::arrow::record_batch::RecordBatch;

        // Create a simple MultiDimPartitionRule: region 0 takes host < 'server1',
        // region 1 takes host >= 'server1'
        let rule = MultiDimPartitionRule::try_new(
            vec!["host".to_string(), "value".to_string()],
            vec![0, 1],
            vec![
                PartitionExpr::new(
                    Operand::Column("host".to_string()),
                    RestrictedOp::Lt,
                    Operand::Value(Value::String("server1".into())),
                ),
                PartitionExpr::new(
                    Operand::Column("host".to_string()),
                    RestrictedOp::GtEq,
                    Operand::Value(Value::String("server1".into())),
                ),
            ],
        )
        .unwrap();

        // Create a record batch with test data
        let schema = Arc::new(Schema::new(vec![
            Field::new("host", DataType::Utf8, false),
            Field::new("value", DataType::Int64, false),
        ]));
        let host_array = StringArray::from(vec!["server1", "server2", "server3", "server1"]);
        let value_array = Int64Array::from(vec![10, 20, 30, 40]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)])
            .unwrap();

        // Split the batch and cross-check against the naive row-by-row implementation
        let result = rule.split_record_batch(&batch).unwrap();
        let expected = rule.split_record_batch_naive(&batch).unwrap();
        assert_eq!(result.len(), expected.len());
        for (region, value) in &result {
            assert_eq!(
                value,
                expected.get(region).unwrap(),
                "failed on region: {}",
                region
            );
        }
    }

    #[test]
    fn test_split_record_batch_empty() {
        use datatypes::arrow::array::{Int64Array, StringArray};
        use datatypes::arrow::datatypes::{DataType, Field, Schema};
        use datatypes::arrow::record_batch::RecordBatch;

        // Create a simple MultiDimPartitionRule
        let rule = MultiDimPartitionRule::try_new(
            vec!["host".to_string()],
            vec![1],
            vec![PartitionExpr::new(
                Operand::Column("host".to_string()),
                RestrictedOp::Eq,
                Operand::Value(Value::String("server1".into())),
            )],
        )
        .unwrap();

        // Create an empty record batch
        let schema = Arc::new(Schema::new(vec![
            Field::new("host", DataType::Utf8, false),
            Field::new("value", DataType::Int64, false),
        ]));
        let host_array = StringArray::from(Vec::<&str>::new());
        let value_array = Int64Array::from(Vec::<i64>::new());
        let batch = RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)])
            .unwrap();

        // Split the batch
        let result = rule.split_record_batch(&batch).unwrap();

        // Empty batch should result in empty map
        assert_eq!(result.len(), 0);
    }

    #[test]
    fn test_split_record_batch_complex_condition() {
        use datatypes::arrow::array::{Int64Array, StringArray};
        use datatypes::arrow::datatypes::{DataType, Field, Schema};
        use datatypes::arrow::record_batch::RecordBatch;

        // Create a rule with more complex conditions
        let rule = MultiDimPartitionRule::try_new(
            vec!["host".to_string(), "value".to_string()],
            vec![1, 2],
            vec![
                // Region 1: host = 'server1' AND value > 20
                PartitionExpr::new(
                    Operand::Expr(PartitionExpr::new(
                        Operand::Column("host".to_string()),
                        RestrictedOp::Eq,
                        Operand::Value(Value::String("server1".into())),
                    )),
                    RestrictedOp::And,
                    Operand::Expr(PartitionExpr::new(
                        Operand::Column("value".into()),
                        RestrictedOp::Gt,
                        Operand::Value(Value::Int64(20)),
                    )),
                ),
                // Region 2: host = 'server2'
                PartitionExpr::new(
                    Operand::Column("host".to_string()),
                    RestrictedOp::Eq,
                    Operand::Value(Value::String("server2".into())),
                ),
            ],
        )
        .unwrap();

        // Create a record batch with test data
        let schema = Arc::new(Schema::new(vec![
            Field::new("host", DataType::Utf8, false),
            Field::new("value", DataType::Int64, false),
        ]));
        let host_array = StringArray::from(vec!["server1", "server1", "server2", "server3"]);
        let value_array = Int64Array::from(vec![10, 30, 20, 40]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)])
            .unwrap();

        // Split the batch
        let result = rule.split_record_batch(&batch).unwrap();

        // Check the results
        assert_eq!(result.len(), 3); // 2 defined regions + 1 default region

        // Check region 1 (host = 'server1' AND value > 20)
        let region1_array = result.get(&1).unwrap();
        assert_eq!(region1_array.len(), 1);
        assert_eq!(
            region1_array.iter().map(|b| b.unwrap()).collect::<Vec<_>>(),
            vec![true]
        );

        // Check region 2 (host = 'server2')
        assert!(result.contains_key(&2));
        let region2_batch = result.get(&2).unwrap();
        assert_eq!(region2_batch.len(), 1);

        // Check the default region (server1 with value <= 20, and server3)
        assert!(result.contains_key(&0));
        let default_batch = result.get(&0).unwrap();
        assert_eq!(default_batch.len(), 2);
    }
}
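
`split_record_batch` hands back per-region boolean masks (per the `HashMap<RegionNumber, BooleanArray>` signature above) rather than materialized sub-batches. A sketch of how a caller might turn masks into per-region batches with Arrow's filter kernel, assuming the `rule` and `batch` from `test_split_record_batch_by_one_column`:

```rust
use datatypes::arrow::compute::filter_record_batch;
use datatypes::arrow::record_batch::RecordBatch;

// Hypothetical helper, not part of the commit: materialize one sub-batch per region.
fn materialize_regions(rule: &MultiDimPartitionRule, batch: &RecordBatch) {
    let masks = rule.split_record_batch(batch).unwrap();
    for (region, mask) in &masks {
        // Keep only the rows whose mask bit is set for this region.
        let region_batch = filter_record_batch(batch, mask).unwrap();
        println!("region {region}: {} rows", region_batch.num_rows());
    }
}
```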