feat: Column-wise partition rule implementation (#5804)

* wip: naive impl

* feat/column-partition:
 ### Add support for DataFusion physical expressions

 - **`Cargo.lock` & `Cargo.toml`**: Added `datafusion-physical-expr` as a dependency to support physical expression creation.
 - **`expr.rs`**: Implemented conversion methods `try_as_logical_expr` and `try_as_physical_expr` for `Operand` and `PartitionExpr` to facilitate logical and physical expression handling.
 - **`multi_dim.rs`**: Enhanced `MultiDimPartitionRule` to utilize physical expressions for partitioning logic, including new methods for evaluating record batches.
 - **Tests**: Added unit tests for logical and physical expression conversions and partitioning logic in `expr.rs` and `multi_dim.rs`.
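
 As a minimal usage sketch of the conversions described above (illustrative, not part of the diff; it leans on the `col` helper and `try_as_physical_expr` shown in the `expr.rs` diff below), a `PartitionExpr` can be lowered to a DataFusion physical expression and evaluated against a `RecordBatch` to produce a per-row boolean mask:

```rust
// Illustrative sketch: lower a PartitionExpr to a DataFusion PhysicalExpr and
// evaluate it against a RecordBatch. Schema and data are made up for the demo.
use std::sync::Arc;

use datafusion_expr::ColumnarValue;
use datatypes::arrow::array::{ArrayRef, BooleanArray, Int32Array};
use datatypes::arrow::datatypes::{DataType, Field, Schema};
use datatypes::arrow::record_batch::RecordBatch;
use datatypes::value::Value;
use partition::expr::col;

fn demo_mask() -> BooleanArray {
    let schema = Arc::new(Schema::new(vec![Field::new("a0", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema,
        vec![Arc::new(Int32Array::from(vec![10, 60, 30])) as ArrayRef],
    )
    .unwrap();

    // `a0 < 50`, built with the fluent helpers from this PR.
    let expr = col("a0").lt(Value::Int32(50));
    let physical = expr.try_as_physical_expr(batch.schema_ref()).unwrap();

    match physical.evaluate(&batch).unwrap() {
        // One boolean per row: [true, false, true] for the data above.
        ColumnarValue::Array(mask) => mask
            .as_any()
            .downcast_ref::<BooleanArray>()
            .expect("comparison yields a boolean array")
            .clone(),
        ColumnarValue::Scalar(_) => unreachable!("evaluated per row"),
    }
}
```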

* feat/column-partition:
 ### Refactor and Enhance Partition Handling

 - **Refactor Partition Parsing Logic**: Moved partition parsing logic from `src/operator/src/statement/ddl.rs` to a new utility module `src/partition/src/utils.rs`. This includes functions like `parse_partitions`, `find_partition_bounds`, and `convert_one_expr`.
 - **Error Handling Improvements**: Added new error variants `ColumnNotFound`, `InvalidPartitionRule`, and `ParseSqlValue` in `src/partition/src/error.rs` to improve error reporting for partition-related operations.
 - **Dependency Updates**: Updated `Cargo.lock` and `Cargo.toml` to include new dependencies `common-time` and `session`.
 - **Code Cleanup**: Removed redundant partition parsing functions from `src/operator/src/error.rs` and `src/operator/src/statement/ddl.rs`.

* feat/column-partition:
 ## Refactor and Enhance SQL and Table Handling

 - **Refactor Column Definitions and Error Handling**
   - Made `FULLTEXT_GRPC_KEY`, `INVERTED_INDEX_GRPC_KEY`, and `SKIPPING_INDEX_GRPC_KEY` public in `column_def.rs`.
   - Removed `IllegalPrimaryKeysDef` error from `error.rs` and moved it to `sql/src/error.rs`.
   - Updated error handling in `fill_impure_default.rs` and `expr_helper.rs`.

 - **Enhance SQL Utility Functions**
   - Moved and refactored functions like `create_to_expr`, `find_primary_keys`, and `validate_create_expr` to `sql/src/util.rs`.
   - Added new utility functions for SQL parsing and validation in `sql/src/util.rs`.

 - **Improve Partition Handling**
   - Added `parse_partition_columns_and_exprs` function in `partition/src/utils.rs`.
   - Updated partition rule tests in `partition/src/multi_dim.rs` to use SQL-based partitioning.

 - **Simplify Table Name Handling**
   - Re-exported `table_idents_to_full_name` from `sql::util` in `session/src/table_name.rs`.

 - **Test Enhancements**
   - Updated tests in `partition/src/multi_dim.rs` to use SQL for partition rule creation.

* feat/column-partition:
 **Add Benchmarking and Enhance Partitioning Logic**

 - **Benchmarking**: Introduced a new benchmark for `split_record_batch` in `bench_split_record_batch.rs`, adding `criterion` and `rand` as development dependencies in `Cargo.toml`.
 - **Partitioning Logic**: Enhanced `MultiDimPartitionRule` in `multi_dim.rs` to route rows that match no partition expression to a default region, and optimized the `split_record_batch` method.
 - **Refactoring**: Moved `sql_to_partition_rule` function to a public scope for reuse in `multi_dim.rs`.
 - **Testing**: Added new test module `test_split_record_batch` to validate the partitioning logic.
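
 The benchmark can be run with `cargo bench -p partition --bench bench_split_record_batch` (assuming the crate's package name is `partition`, which matches the `partition::` imports in the bench source below).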

* Revert "feat/column-partition:  ### Refactor and Enhance Partition Handling"

This reverts commit 183fa19f

* fix: revert refactoring parse_partition

* revert some refactor

* feat/column-partition:
 ### Enhance Partitioning and Error Handling

 - **Benchmark Enhancements**: Added new benchmark `bench_split_record_batch_vs_row` in `bench_split_record_batch.rs` to compare row-based and column-based splitting.
 - **Error Handling Improvements**: Introduced new error variants in `error.rs` for better error reporting related to record batch evaluation and arrow kernel computation.
 - **Expression Handling**: Updated `expr.rs` to improve error context when converting schemas and creating physical expressions.
 - **Partition Rule Enhancements**: Made `row_at` and `record_batch_to_cols` methods public in `multi_dim.rs` and improved error handling for physical expression evaluation and boolean operations.
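
 With `row_at` and `record_batch_to_cols` now public, callers can drive the row-by-row path themselves. A hedged sketch (assuming the crate exposes an `error::Result` alias, as the diff suggests) that mirrors the helper used by the benchmark:

```rust
// Sketch only: iterate a RecordBatch row by row with the newly public helpers
// and resolve each row's target region.
use datatypes::arrow::record_batch::RecordBatch;
use datatypes::prelude::Value;
use partition::multi_dim::MultiDimPartitionRule;
use partition::PartitionRule;
use store_api::storage::RegionNumber;

fn regions_by_row(
    rule: &MultiDimPartitionRule,
    batch: &RecordBatch,
) -> partition::error::Result<Vec<RegionNumber>> {
    // Project the partition columns out of the batch as vectors.
    let cols = rule.record_batch_to_cols(batch)?;
    let mut row = vec![Value::Null; cols.len()];
    (0..batch.num_rows())
        .map(|idx| {
            // Fill `row` with the values at `idx`, then look up its region.
            rule.row_at(&cols, idx, &mut row)?;
            rule.find_region(&row)
        })
        .collect()
}
```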

* feat/column-partition:
 ### Add `eq` Method and Optimize Expression Caching

 - **`expr.rs`**: Added a new `eq` method to the `Operand` struct for equality comparisons.
 - **`multi_dim.rs`**: Introduced a caching mechanism for physical expressions using `RwLock` to improve performance in `MultiDimPartitionRule`.
 - **`lib.rs`**: Enabled the `let_chains` feature for more concise code.
 - **`multi_dim.rs` Tests**: Enhanced test coverage with new test cases for multi-dimensional partitioning, including random record batch generation and default region handling.
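
 A small sketch of the fluent builders in use (the column and split value are illustrative); the physical expressions derived from these `PartitionExpr`s are what the new `RwLock`-guarded cache stores, so they are rebuilt only when the incoming batch schema changes:

```rust
// Sketch: build a two-region rule with the fluent Operand helpers (col/lt/gt_eq,
// with eq added in this commit). Column names and values are made up for the demo.
use datatypes::value::Value;
use partition::expr::col;
use partition::multi_dim::MultiDimPartitionRule;

fn example_rule() -> MultiDimPartitionRule {
    MultiDimPartitionRule::try_new(
        vec!["host".to_string()],
        vec![0, 1],
        vec![
            // host < "server10"  -> region 0
            col("host").lt(Value::String("server10".into())),
            // host >= "server10" -> region 1
            col("host").gt_eq(Value::String("server10".into())),
        ],
    )
    .expect("expressions cover the whole value space")
}
```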

* feat/column-partition:
 ### Add `split_record_batch` Method to `PartitionRule` Trait

 - **Files Modified**:
   - `src/partition/src/multi_dim.rs`
   - `src/partition/src/partition.rs`
   - `src/partition/src/splitter.rs`

 Added a new method `split_record_batch` to the `PartitionRule` trait, allowing record batches to be split into multiple regions based on partition values. Implemented this method in `MultiDimPartitionRule` and provided unimplemented stubs in test modules.
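
 A hedged sketch of how a caller might consume the new trait method (the `filter_record_batch` step is an assumed downstream use, not part of this diff): each region's boolean mask selects exactly the rows belonging to that region.

```rust
// Sketch only: split a batch into per-region masks via the new trait method,
// then materialize each region's rows with arrow's filter kernel.
use std::collections::HashMap;

use datatypes::arrow::array::BooleanArray;
use datatypes::arrow::compute::filter_record_batch;
use datatypes::arrow::record_batch::RecordBatch;
use partition::PartitionRule;
use store_api::storage::RegionNumber;

fn batches_per_region(
    rule: &dyn PartitionRule,
    batch: &RecordBatch,
) -> partition::error::Result<HashMap<RegionNumber, RecordBatch>> {
    let masks: HashMap<RegionNumber, BooleanArray> = rule.split_record_batch(batch)?;
    Ok(masks
        .into_iter()
        .map(|(region, mask)| {
            // Keep only the rows whose mask bit is set for this region.
            let rows = filter_record_batch(batch, &mask).expect("mask length matches batch");
            (region, rows)
        })
        .collect())
}
```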

 ### Dependency Update

 - **File Modified**:
   - `src/operator/src/expr_helper.rs`

 Removed unused imports `ColumnDataType` and `Timezone` from the test module.

 ### Miscellaneous

 - **File Modified**:
   - `src/partition/Cargo.toml`

 No functional changes; only minor formatting adjustments.

* chore: add license header

* chore: remove useless files

* feat/column-partition:
 Add support for handling unsupported partition expression values

 - **`error.rs`**: Introduced a new error variant `UnsupportedPartitionExprValue` to handle unsupported partition expression values, and updated `ErrorExt` to map this error to `StatusCode::InvalidArguments`.
 - **`expr.rs`**: Modified the `Operand` implementation to return the new error when encountering unsupported partition expression values.
 - **`multi_dim.rs`**: Added a fast path to optimize the selection process when all rows are selected.
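
 An illustrative reduction of that fast path (names are assumed; the real logic lives in `split_record_batch` in the diff below): the per-region masks are OR-ed together, and only when some row remains unclaimed does the rule compute the complement and route it to the default region.

```rust
// Sketch of the fast-path idea using plain arrow kernels.
use datatypes::arrow::array::BooleanArray;
use datatypes::arrow::compute::kernels::boolean::{not, or};

/// Returns the mask of rows matched by no region, or None when every row is
/// already claimed (the fast path: no default-region work is needed).
fn unmatched_rows(region_masks: &[BooleanArray], num_rows: usize) -> Option<BooleanArray> {
    let mut selected = BooleanArray::from(vec![false; num_rows]);
    for mask in region_masks {
        selected = or(&selected, mask).expect("masks share the batch's row count");
    }
    if selected.true_count() == num_rows {
        None
    } else {
        Some(not(&selected).expect("boolean negation"))
    }
}
```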

* feat/column-partition: Add validation for expression and region length in MultiDimPartitionRule constructor

 • Ensure the lengths of exprs and regions match to prevent mismatches.
 • Introduce error handling for length discrepancies with a descriptive error message.

* chore: add debug log

* feat/column-partition: Removed the validation check for matching lengths between exprs and regions in MultiDimPartitionRule constructor, simplifying the initialization process.

* fix: unit tests

Author: Lei, HUANG (2025-04-15 18:42:07 +08:00), committed by GitHub
Parent: 032df4c533
Commit: 6700c0762d
9 changed files with 753 additions and 5 deletions

Cargo.lock (generated)

@@ -8138,10 +8138,13 @@ dependencies = [
"common-macro",
"common-meta",
"common-query",
"criterion 0.5.1",
"datafusion-common",
"datafusion-expr",
"datafusion-physical-expr",
"datatypes",
"itertools 0.14.0",
"rand 0.8.5",
"serde",
"serde_json",
"session",

src/partition/Cargo.toml

@@ -16,6 +16,7 @@ common-meta.workspace = true
common-query.workspace = true
datafusion-common.workspace = true
datafusion-expr.workspace = true
datafusion-physical-expr.workspace = true
datatypes.workspace = true
itertools.workspace = true
serde.workspace = true
@@ -26,3 +27,11 @@ sql.workspace = true
sqlparser.workspace = true
store-api.workspace = true
table.workspace = true
[dev-dependencies]
criterion = "0.5"
rand = "0.8"
[[bench]]
name = "bench_split_record_batch"
harness = false

src/partition/benches/bench_split_record_batch.rs

@@ -0,0 +1,226 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use std::vec;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use datatypes::arrow::array::{ArrayRef, Int32Array, StringArray, TimestampMillisecondArray};
use datatypes::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datatypes::arrow::record_batch::RecordBatch;
use datatypes::value::Value;
use partition::expr::{col, Operand};
use partition::multi_dim::MultiDimPartitionRule;
use partition::PartitionRule;
use rand::Rng;
use store_api::storage::RegionNumber;
fn table_schema() -> Arc<Schema> {
Arc::new(Schema::new(vec![
Field::new("a0", DataType::Int32, false),
Field::new("a1", DataType::Utf8, false),
Field::new("a2", DataType::Int32, false),
Field::new(
"ts",
DataType::Timestamp(TimeUnit::Millisecond, None),
false,
),
]))
}
fn create_test_rule(num_columns: usize) -> MultiDimPartitionRule {
let (columns, exprs) = match num_columns {
1 => {
let exprs = vec![
col("a0").lt(Value::Int32(50)),
col("a0").gt_eq(Value::Int32(50)),
];
(vec!["a0".to_string()], exprs)
}
2 => {
let exprs = vec![
col("a0")
.lt(Value::Int32(50))
.and(col("a1").lt(Value::String("server50".into()))),
col("a0")
.lt(Value::Int32(50))
.and(col("a1").gt_eq(Value::String("server50".into()))),
col("a0")
.gt_eq(Value::Int32(50))
.and(col("a1").lt(Value::String("server50".into()))),
col("a0")
.gt_eq(Value::Int32(50))
.and(col("a1").gt_eq(Value::String("server50".into()))),
];
(vec!["a0".to_string(), "a1".to_string()], exprs)
}
3 => {
let expr = vec![
col("a0")
.lt(Value::Int32(50))
.and(col("a1").lt(Value::String("server50".into())))
.and(col("a2").lt(Value::Int32(50))),
col("a0")
.lt(Operand::Value(Value::Int32(50)))
.and(col("a1").lt(Value::String("server50".into())))
.and(col("a2").gt_eq(Value::Int32(50))),
col("a0")
.lt(Value::Int32(50))
.and(col("a1").gt_eq(Value::String("server50".into())))
.and(col("a2").lt(Value::Int32(50))),
col("a0")
.lt(Value::Int32(50))
.and(col("a1").gt_eq(Value::String("server50".into())))
.and(col("a2").gt_eq(Value::Int32(50))),
col("a0")
.gt_eq(Value::Int32(50))
.and(col("a1").lt(Value::String("server50".into())))
.and(col("a2").lt(Value::Int32(50))),
col("a0")
.gt_eq(Operand::Value(Value::Int32(50)))
.and(col("a1").lt(Value::String("server50".into())))
.and(col("a2").gt_eq(Value::Int32(50))),
col("a0")
.gt_eq(Value::Int32(50))
.and(col("a1").gt_eq(Value::String("server50".into())))
.and(col("a2").lt(Value::Int32(50))),
col("a0")
.gt_eq(Value::Int32(50))
.and(col("a1").gt_eq(Value::String("server50".into())))
.and(col("a2").gt_eq(Value::Int32(50))),
];
(
vec!["a0".to_string(), "a1".to_string(), "a2".to_string()],
expr,
)
}
_ => {
panic!("invalid number of columns, only 1-3 are supported");
}
};
let regions = (0..exprs.len()).map(|v| v as u32).collect();
MultiDimPartitionRule::try_new(columns, regions, exprs).unwrap()
}
fn create_test_batch(size: usize) -> RecordBatch {
let mut rng = rand::thread_rng();
let schema = table_schema();
let arrays: Vec<ArrayRef> = (0..3)
.map(|col_idx| {
if col_idx % 2 == 0 {
// Integer columns (a0, a2)
Arc::new(Int32Array::from_iter_values(
(0..size).map(|_| rng.gen_range(0..100)),
)) as ArrayRef
} else {
// String columns (a1)
let values: Vec<String> = (0..size)
.map(|_| {
let server_id: i32 = rng.gen_range(0..100);
format!("server{}", server_id)
})
.collect();
Arc::new(StringArray::from(values)) as ArrayRef
}
})
.chain(std::iter::once({
// Timestamp column (ts)
Arc::new(TimestampMillisecondArray::from_iter_values(
(0..size).map(|idx| idx as i64),
)) as ArrayRef
}))
.collect();
RecordBatch::try_new(schema, arrays).unwrap()
}
fn bench_split_record_batch_naive_vs_optimized(c: &mut Criterion) {
let mut group = c.benchmark_group("split_record_batch");
for num_columns in [1, 2, 3].iter() {
for num_rows in [100, 1000, 10000, 100000].iter() {
let rule = create_test_rule(*num_columns);
let batch = create_test_batch(*num_rows);
group.bench_function(format!("naive_{}_{}", num_columns, num_rows), |b| {
b.iter(|| {
black_box(rule.split_record_batch_naive(black_box(&batch))).unwrap();
});
});
group.bench_function(format!("optimized_{}_{}", num_columns, num_rows), |b| {
b.iter(|| {
black_box(rule.split_record_batch(black_box(&batch))).unwrap();
});
});
}
}
group.finish();
}
fn record_batch_to_rows(
rule: &MultiDimPartitionRule,
record_batch: &RecordBatch,
) -> Vec<Vec<Value>> {
let num_rows = record_batch.num_rows();
let vectors = rule.record_batch_to_cols(record_batch).unwrap();
let mut res = Vec::with_capacity(num_rows);
let mut current_row = vec![Value::Null; vectors.len()];
for row in 0..num_rows {
rule.row_at(&vectors, row, &mut current_row).unwrap();
res.push(current_row.clone());
}
res
}
fn find_all_regions(rule: &MultiDimPartitionRule, rows: &[Vec<Value>]) -> Vec<RegionNumber> {
rows.iter()
.map(|row| rule.find_region(row).unwrap())
.collect()
}
fn bench_split_record_batch_vs_row(c: &mut Criterion) {
let mut group = c.benchmark_group("bench_split_record_batch_vs_row");
for num_columns in [1, 2, 3].iter() {
for num_rows in [100, 1000, 10000, 100000].iter() {
let rule = create_test_rule(*num_columns);
let batch = create_test_batch(*num_rows);
let rows = record_batch_to_rows(&rule, &batch);
group.bench_function(format!("split_by_row_{}_{}", num_columns, num_rows), |b| {
b.iter(|| {
black_box(find_all_regions(&rule, &rows));
});
});
group.bench_function(format!("split_by_col_{}_{}", num_columns, num_rows), |b| {
b.iter(|| {
black_box(rule.split_record_batch(black_box(&batch))).unwrap();
});
});
}
}
group.finish();
}
criterion_group!(
benches,
bench_split_record_batch_naive_vs_optimized,
bench_split_record_batch_vs_row
);
criterion_main!(benches);

src/partition/src/error.rs

@@ -18,6 +18,8 @@ use common_error::ext::ErrorExt;
use common_error::status_code::StatusCode;
use common_macro::stack_trace_debug;
use datafusion_common::ScalarValue;
use datatypes::arrow;
use datatypes::prelude::Value;
use snafu::{Location, Snafu};
use store_api::storage::RegionId;
use table::metadata::TableId;
@@ -173,6 +175,59 @@ pub enum Error {
#[snafu(implicit)]
location: Location,
},
#[snafu(display("Failed to convert to vector"))]
ConvertToVector {
source: datatypes::error::Error,
#[snafu(implicit)]
location: Location,
},
#[snafu(display("Failed to evaluate record batch"))]
EvaluateRecordBatch {
#[snafu(source)]
error: datafusion_common::error::DataFusionError,
#[snafu(implicit)]
location: Location,
},
#[snafu(display("Failed to compute arrow kernel"))]
ComputeArrowKernel {
#[snafu(source)]
error: arrow::error::ArrowError,
#[snafu(implicit)]
location: Location,
},
#[snafu(display("Unexpected evaluation result column type: {}", data_type))]
UnexpectedColumnType {
data_type: arrow::datatypes::DataType,
#[snafu(implicit)]
location: Location,
},
#[snafu(display("Failed to convert to DataFusion's Schema"))]
ToDFSchema {
#[snafu(source)]
error: datafusion_common::error::DataFusionError,
#[snafu(implicit)]
location: Location,
},
#[snafu(display("Failed to create physical expression"))]
CreatePhysicalExpr {
#[snafu(source)]
error: datafusion_common::error::DataFusionError,
#[snafu(implicit)]
location: Location,
},
#[snafu(display("Partition expr value is not supported: {:?}", value))]
UnsupportedPartitionExprValue {
value: Value,
#[snafu(implicit)]
location: Location,
},
}
impl ErrorExt for Error {
@@ -201,6 +256,13 @@ impl ErrorExt for Error {
Error::TableRouteNotFound { .. } => StatusCode::TableNotFound,
Error::TableRouteManager { source, .. } => source.status_code(),
Error::UnexpectedLogicalRouteTable { source, .. } => source.status_code(),
Error::ConvertToVector { source, .. } => source.status_code(),
Error::EvaluateRecordBatch { .. } => StatusCode::Internal,
Error::ComputeArrowKernel { .. } => StatusCode::Internal,
Error::UnexpectedColumnType { .. } => StatusCode::Internal,
Error::ToDFSchema { .. } => StatusCode::Internal,
Error::CreatePhysicalExpr { .. } => StatusCode::Internal,
Error::UnsupportedPartitionExprValue { .. } => StatusCode::InvalidArguments,
}
}

src/partition/src/expr.rs

@@ -13,12 +13,23 @@
// limitations under the License.
use std::fmt::{Debug, Display, Formatter};
use std::sync::Arc;
use datatypes::value::Value;
use datafusion_common::{ScalarValue, ToDFSchema};
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::Expr;
use datafusion_physical_expr::{create_physical_expr, PhysicalExpr};
use datatypes::arrow;
use datatypes::value::{
duration_to_scalar_value, time_to_scalar_value, timestamp_to_scalar_value, Value,
};
use serde::{Deserialize, Serialize};
use snafu::ResultExt;
use sql::statements::value_to_sql_value;
use sqlparser::ast::{BinaryOperator as ParserBinaryOperator, Expr as ParserExpr, Ident};
use crate::error;
/// Struct for partition expression. This can be converted back to sqlparser's [Expr].
/// by [`Self::to_parser_expr`].
///
@@ -37,6 +48,75 @@ pub enum Operand {
Expr(PartitionExpr),
}
pub fn col(column_name: impl Into<String>) -> Operand {
Operand::Column(column_name.into())
}
impl From<Value> for Operand {
fn from(value: Value) -> Self {
Operand::Value(value)
}
}
impl Operand {
pub fn try_as_logical_expr(&self) -> error::Result<Expr> {
match self {
Self::Column(c) => Ok(datafusion_expr::col(c)),
Self::Value(v) => {
let scalar_value = match v {
Value::Boolean(v) => ScalarValue::Boolean(Some(*v)),
Value::UInt8(v) => ScalarValue::UInt8(Some(*v)),
Value::UInt16(v) => ScalarValue::UInt16(Some(*v)),
Value::UInt32(v) => ScalarValue::UInt32(Some(*v)),
Value::UInt64(v) => ScalarValue::UInt64(Some(*v)),
Value::Int8(v) => ScalarValue::Int8(Some(*v)),
Value::Int16(v) => ScalarValue::Int16(Some(*v)),
Value::Int32(v) => ScalarValue::Int32(Some(*v)),
Value::Int64(v) => ScalarValue::Int64(Some(*v)),
Value::Float32(v) => ScalarValue::Float32(Some(v.0)),
Value::Float64(v) => ScalarValue::Float64(Some(v.0)),
Value::String(v) => ScalarValue::Utf8(Some(v.as_utf8().to_string())),
Value::Binary(v) => ScalarValue::Binary(Some(v.to_vec())),
Value::Date(v) => ScalarValue::Date32(Some(v.val())),
Value::Null => ScalarValue::Null,
Value::Timestamp(t) => timestamp_to_scalar_value(t.unit(), Some(t.value())),
Value::Time(t) => time_to_scalar_value(*t.unit(), Some(t.value())).unwrap(),
Value::IntervalYearMonth(v) => ScalarValue::IntervalYearMonth(Some(v.to_i32())),
Value::IntervalDayTime(v) => ScalarValue::IntervalDayTime(Some((*v).into())),
Value::IntervalMonthDayNano(v) => {
ScalarValue::IntervalMonthDayNano(Some((*v).into()))
}
Value::Duration(d) => duration_to_scalar_value(d.unit(), Some(d.value())),
Value::Decimal128(d) => {
let (v, p, s) = d.to_scalar_value();
ScalarValue::Decimal128(v, p, s)
}
other => {
return error::UnsupportedPartitionExprValueSnafu {
value: other.clone(),
}
.fail()
}
};
Ok(datafusion_expr::lit(scalar_value))
}
Self::Expr(e) => e.try_as_logical_expr(),
}
}
pub fn lt(self, rhs: impl Into<Self>) -> PartitionExpr {
PartitionExpr::new(self, RestrictedOp::Lt, rhs.into())
}
pub fn gt_eq(self, rhs: impl Into<Self>) -> PartitionExpr {
PartitionExpr::new(self, RestrictedOp::GtEq, rhs.into())
}
pub fn eq(self, rhs: impl Into<Self>) -> PartitionExpr {
PartitionExpr::new(self, RestrictedOp::Eq, rhs.into())
}
}
impl Display for Operand {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
@@ -140,6 +220,41 @@ impl PartitionExpr {
right: Box::new(rhs),
}
}
pub fn try_as_logical_expr(&self) -> error::Result<Expr> {
let lhs = self.lhs.try_as_logical_expr()?;
let rhs = self.rhs.try_as_logical_expr()?;
let expr = match &self.op {
RestrictedOp::And => datafusion_expr::and(lhs, rhs),
RestrictedOp::Or => datafusion_expr::or(lhs, rhs),
RestrictedOp::Gt => lhs.gt(rhs),
RestrictedOp::GtEq => lhs.gt_eq(rhs),
RestrictedOp::Lt => lhs.lt(rhs),
RestrictedOp::LtEq => lhs.lt_eq(rhs),
RestrictedOp::Eq => lhs.eq(rhs),
RestrictedOp::NotEq => lhs.not_eq(rhs),
};
Ok(expr)
}
pub fn try_as_physical_expr(
&self,
schema: &arrow::datatypes::SchemaRef,
) -> error::Result<Arc<dyn PhysicalExpr>> {
let df_schema = schema
.clone()
.to_dfschema_ref()
.context(error::ToDFSchemaSnafu)?;
let execution_props = &ExecutionProps::default();
let expr = self.try_as_logical_expr()?;
create_physical_expr(&expr, &df_schema, execution_props)
.context(error::CreatePhysicalExprSnafu)
}
pub fn and(self, rhs: PartitionExpr) -> PartitionExpr {
PartitionExpr::new(Operand::Expr(self), RestrictedOp::And, Operand::Expr(rhs))
}
}
impl Display for PartitionExpr {

src/partition/src/lib.rs

@@ -13,7 +13,7 @@
// limitations under the License.
#![feature(assert_matches)]
#![feature(let_chains)]
//! Structs and traits for partitioning rule.
pub mod error;

src/partition/src/multi_dim.rs

@@ -15,10 +15,18 @@
use std::any::Any;
use std::cmp::Ordering;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use datafusion_expr::ColumnarValue;
use datafusion_physical_expr::PhysicalExpr;
use datatypes::arrow;
use datatypes::arrow::array::{BooleanArray, BooleanBufferBuilder, RecordBatch};
use datatypes::arrow::buffer::BooleanBuffer;
use datatypes::arrow::datatypes::Schema;
use datatypes::prelude::Value;
use datatypes::vectors::{Helper, VectorRef};
use serde::{Deserialize, Serialize};
use snafu::{ensure, OptionExt};
use snafu::{ensure, OptionExt, ResultExt};
use store_api::storage::RegionNumber;
use crate::error::{
@@ -28,6 +36,11 @@ use crate::error::{
use crate::expr::{Operand, PartitionExpr, RestrictedOp};
use crate::PartitionRule;
/// The default region number when no partition exprs are matched.
const DEFAULT_REGION: RegionNumber = 0;
type PhysicalExprCache = Option<(Vec<Arc<dyn PhysicalExpr>>, Arc<Schema>)>;
/// Multi-Dimiension partition rule. RFC [here](https://github.com/GreptimeTeam/greptimedb/blob/main/docs/rfcs/2024-02-21-multi-dimension-partition-rule/rfc.md)
///
/// This partition rule is defined by a set of simple expressions on the partition
@@ -44,6 +57,9 @@ pub struct MultiDimPartitionRule {
regions: Vec<RegionNumber>,
/// Partition expressions.
exprs: Vec<PartitionExpr>,
/// Cache of physical expressions.
#[serde(skip)]
physical_expr_cache: RwLock<PhysicalExprCache>,
}
impl MultiDimPartitionRule {
@@ -63,6 +79,7 @@ impl MultiDimPartitionRule {
name_to_index,
regions,
exprs,
physical_expr_cache: RwLock::new(None),
};
let mut checker = RuleChecker::new(&rule);
@@ -87,7 +104,7 @@ impl MultiDimPartitionRule {
}
// return the default region number
Ok(0)
Ok(DEFAULT_REGION)
}
fn evaluate_expr(&self, expr: &PartitionExpr, values: &[Value]) -> Result<bool> {
@@ -134,6 +151,133 @@ impl MultiDimPartitionRule {
Ok(result)
}
pub fn row_at(&self, cols: &[VectorRef], index: usize, row: &mut [Value]) -> Result<()> {
for (col_idx, col) in cols.iter().enumerate() {
row[col_idx] = col.get(index);
}
Ok(())
}
pub fn record_batch_to_cols(&self, record_batch: &RecordBatch) -> Result<Vec<VectorRef>> {
self.partition_columns
.iter()
.map(|col_name| {
record_batch
.column_by_name(col_name)
.context(error::UndefinedColumnSnafu { column: col_name })
.and_then(|array| {
Helper::try_into_vector(array).context(error::ConvertToVectorSnafu)
})
})
.collect::<Result<Vec<_>>>()
}
pub fn split_record_batch_naive(
&self,
record_batch: &RecordBatch,
) -> Result<HashMap<RegionNumber, BooleanArray>> {
let num_rows = record_batch.num_rows();
let mut result = self
.regions
.iter()
.map(|region| {
let mut builder = BooleanBufferBuilder::new(num_rows);
builder.append_n(num_rows, false);
(*region, builder)
})
.collect::<HashMap<_, _>>();
let cols = self.record_batch_to_cols(record_batch)?;
let mut current_row = vec![Value::Null; self.partition_columns.len()];
for row_idx in 0..num_rows {
self.row_at(&cols, row_idx, &mut current_row)?;
let current_region = self.find_region(&current_row)?;
let region_mask = result
.get_mut(&current_region)
.unwrap_or_else(|| panic!("Region {} must be initialized", current_region));
region_mask.set_bit(row_idx, true);
}
Ok(result
.into_iter()
.map(|(region, mut mask)| (region, BooleanArray::new(mask.finish(), None)))
.collect())
}
pub fn split_record_batch(
&self,
record_batch: &RecordBatch,
) -> Result<HashMap<RegionNumber, BooleanArray>> {
let num_rows = record_batch.num_rows();
let physical_exprs = {
let cache_read_guard = self.physical_expr_cache.read().unwrap();
if let Some((cached_exprs, schema)) = cache_read_guard.as_ref()
&& schema == record_batch.schema_ref()
{
cached_exprs.clone()
} else {
drop(cache_read_guard); // Release the read lock before acquiring write lock
let schema = record_batch.schema();
let new_cache = self
.exprs
.iter()
.map(|e| e.try_as_physical_expr(&schema))
.collect::<Result<Vec<_>>>()?;
let mut cache_write_guard = self.physical_expr_cache.write().unwrap();
cache_write_guard.replace((new_cache.clone(), schema));
new_cache
}
};
let mut result: HashMap<u32, BooleanArray> = physical_exprs
.iter()
.zip(self.regions.iter())
.map(|(expr, region_num)| {
let ColumnarValue::Array(column) = expr
.evaluate(record_batch)
.context(error::EvaluateRecordBatchSnafu)?
else {
unreachable!("Expected an array")
};
Ok((
*region_num,
column
.as_any()
.downcast_ref::<BooleanArray>()
.with_context(|| error::UnexpectedColumnTypeSnafu {
data_type: column.data_type().clone(),
})?
.clone(),
))
})
.collect::<error::Result<_>>()?;
let mut selected = BooleanArray::new(BooleanBuffer::new_unset(num_rows), None);
for region_selection in result.values() {
selected = arrow::compute::kernels::boolean::or(&selected, region_selection)
.context(error::ComputeArrowKernelSnafu)?;
}
// fast path: all rows are selected
if selected.true_count() == num_rows {
return Ok(result);
}
// find unselected rows and assign to default region
let unselected = arrow::compute::kernels::boolean::not(&selected)
.context(error::ComputeArrowKernelSnafu)?;
let default_region_selection = result
.entry(DEFAULT_REGION)
.or_insert_with(|| unselected.clone());
*default_region_selection =
arrow::compute::kernels::boolean::or(default_region_selection, &unselected)
.context(error::ComputeArrowKernelSnafu)?;
Ok(result)
}
}
impl PartitionRule for MultiDimPartitionRule {
@@ -148,6 +292,13 @@ impl PartitionRule for MultiDimPartitionRule {
fn find_region(&self, values: &[Value]) -> Result<RegionNumber> {
self.find_region(values)
}
fn split_record_batch(
&self,
record_batch: &RecordBatch,
) -> Result<HashMap<RegionNumber, BooleanArray>> {
self.split_record_batch(record_batch)
}
}
/// Helper for [RuleChecker]
@@ -633,3 +784,155 @@ mod tests {
assert!(rule.is_err());
}
}
#[cfg(test)]
mod test_split_record_batch {
use std::sync::Arc;
use datatypes::arrow::array::{Int64Array, StringArray};
use datatypes::arrow::datatypes::{DataType, Field, Schema};
use datatypes::arrow::record_batch::RecordBatch;
use rand::Rng;
use super::*;
use crate::expr::col;
fn test_schema() -> Arc<Schema> {
Arc::new(Schema::new(vec![
Field::new("host", DataType::Utf8, false),
Field::new("value", DataType::Int64, false),
]))
}
fn generate_random_record_batch(num_rows: usize) -> RecordBatch {
let schema = test_schema();
let mut rng = rand::thread_rng();
let mut host_array = Vec::with_capacity(num_rows);
let mut value_array = Vec::with_capacity(num_rows);
for _ in 0..num_rows {
host_array.push(format!("server{}", rng.gen_range(0..20)));
value_array.push(rng.gen_range(0..20));
}
let host_array = StringArray::from(host_array);
let value_array = Int64Array::from(value_array);
RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)]).unwrap()
}
#[test]
fn test_split_record_batch_by_one_column() {
// Create a simple MultiDimPartitionRule
let rule = MultiDimPartitionRule::try_new(
vec!["host".to_string(), "value".to_string()],
vec![0, 1],
vec![
col("host").lt(Value::String("server1".into())),
col("host").gt_eq(Value::String("server1".into())),
],
)
.unwrap();
let batch = generate_random_record_batch(1000);
// Split the batch
let result = rule.split_record_batch(&batch).unwrap();
let expected = rule.split_record_batch_naive(&batch).unwrap();
assert_eq!(result.len(), expected.len());
for (region, value) in &result {
assert_eq!(
value,
expected.get(region).unwrap(),
"failed on region: {}",
region
);
}
}
#[test]
fn test_split_record_batch_empty() {
// Create a simple MultiDimPartitionRule
let rule = MultiDimPartitionRule::try_new(
vec!["host".to_string()],
vec![1],
vec![PartitionExpr::new(
Operand::Column("host".to_string()),
RestrictedOp::Eq,
Operand::Value(Value::String("server1".into())),
)],
)
.unwrap();
let schema = test_schema();
let host_array = StringArray::from(Vec::<&str>::new());
let value_array = Int64Array::from(Vec::<i64>::new());
let batch = RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)])
.unwrap();
let result = rule.split_record_batch(&batch).unwrap();
assert_eq!(result.len(), 1);
}
#[test]
fn test_split_record_batch_by_two_columns() {
let rule = MultiDimPartitionRule::try_new(
vec!["host".to_string(), "value".to_string()],
vec![0, 1, 2, 3],
vec![
col("host")
.lt(Value::String("server10".into()))
.and(col("value").lt(Value::Int64(10))),
col("host")
.lt(Value::String("server10".into()))
.and(col("value").gt_eq(Value::Int64(10))),
col("host")
.gt_eq(Value::String("server10".into()))
.and(col("value").lt(Value::Int64(10))),
col("host")
.gt_eq(Value::String("server10".into()))
.and(col("value").gt_eq(Value::Int64(10))),
],
)
.unwrap();
let batch = generate_random_record_batch(1000);
let result = rule.split_record_batch(&batch).unwrap();
let expected = rule.split_record_batch_naive(&batch).unwrap();
assert_eq!(result.len(), expected.len());
for (region, value) in &result {
assert_eq!(value, expected.get(region).unwrap());
}
}
#[test]
fn test_default_region() {
let rule = MultiDimPartitionRule::try_new(
vec!["host".to_string(), "value".to_string()],
vec![0, 1, 2, 3],
vec![
col("host")
.lt(Value::String("server10".into()))
.and(col("value").eq(Value::Int64(10))),
col("host")
.lt(Value::String("server10".into()))
.and(col("value").eq(Value::Int64(20))),
col("host")
.gt_eq(Value::String("server10".into()))
.and(col("value").eq(Value::Int64(10))),
col("host")
.gt_eq(Value::String("server10".into()))
.and(col("value").eq(Value::Int64(20))),
],
)
.unwrap();
let schema = test_schema();
let host_array = StringArray::from(vec!["server1", "server1", "server1", "server100"]);
let value_array = Int64Array::from(vec![10, 20, 30, 10]);
let batch = RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)])
.unwrap();
let result = rule.split_record_batch(&batch).unwrap();
let expected = rule.split_record_batch_naive(&batch).unwrap();
assert_eq!(result.len(), expected.len());
for (region, value) in &result {
assert_eq!(value, expected.get(region).unwrap());
}
}
}

src/partition/src/partition.rs

@@ -13,11 +13,13 @@
// limitations under the License.
use std::any::Any;
use std::collections::HashMap;
use std::fmt::{Debug, Display, Formatter};
use std::sync::Arc;
use common_meta::rpc::router::Partition as MetaPartition;
use datafusion_expr::Operator;
use datatypes::arrow::array::{BooleanArray, RecordBatch};
use datatypes::prelude::Value;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
@@ -37,6 +39,13 @@ pub trait PartitionRule: Sync + Send {
///
/// Note that the `values` should have the same length as the `partition_columns`.
fn find_region(&self, values: &[Value]) -> Result<RegionNumber>;
/// Split the record batch into multiple regions by the partition values.
/// The result is a map from region number to a boolean array, where the boolean array is true for the rows that match the partition values.
fn split_record_batch(
&self,
record_batch: &RecordBatch,
) -> Result<HashMap<RegionNumber, BooleanArray>>;
}
/// The right bound(exclusive) of partition range.

src/partition/src/splitter.rs

@@ -136,6 +136,7 @@ mod tests {
use api::v1::value::ValueData;
use api::v1::{ColumnDataType, SemanticType};
use datatypes::arrow::array::BooleanArray;
use serde::{Deserialize, Serialize};
use super::*;
@@ -209,6 +210,13 @@ mod tests {
Ok(val.parse::<u32>().unwrap() % 2)
}
fn split_record_batch(
&self,
_record_batch: &datatypes::arrow::array::RecordBatch,
) -> Result<HashMap<RegionNumber, BooleanArray>> {
unimplemented!()
}
}
#[derive(Debug, Serialize, Deserialize)]
@@ -232,6 +240,13 @@ mod tests {
Ok(val)
}
fn split_record_batch(
&self,
_record_batch: &datatypes::arrow::array::RecordBatch,
) -> Result<HashMap<RegionNumber, BooleanArray>> {
unimplemented!()
}
}
#[derive(Debug, Serialize, Deserialize)]
@@ -249,8 +264,14 @@ mod tests {
fn find_region(&self, _values: &[Value]) -> Result<RegionNumber> {
Ok(0)
}
}
fn split_record_batch(
&self,
_record_batch: &datatypes::arrow::array::RecordBatch,
) -> Result<HashMap<RegionNumber, BooleanArray>> {
unimplemented!()
}
}
#[test]
fn test_writer_splitter() {
let rows = mock_rows();