diff --git a/Cargo.lock b/Cargo.lock
index f1bf4eb3d0..5182582e82 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -8138,10 +8138,13 @@ dependencies = [
  "common-macro",
  "common-meta",
  "common-query",
+ "criterion 0.5.1",
  "datafusion-common",
  "datafusion-expr",
+ "datafusion-physical-expr",
  "datatypes",
  "itertools 0.14.0",
+ "rand 0.8.5",
  "serde",
  "serde_json",
  "session",
diff --git a/src/partition/Cargo.toml b/src/partition/Cargo.toml
index ebb7d68f8d..6a0904f8f2 100644
--- a/src/partition/Cargo.toml
+++ b/src/partition/Cargo.toml
@@ -16,6 +16,7 @@ common-meta.workspace = true
 common-query.workspace = true
 datafusion-common.workspace = true
 datafusion-expr.workspace = true
+datafusion-physical-expr.workspace = true
 datatypes.workspace = true
 itertools.workspace = true
 serde.workspace = true
@@ -26,3 +27,11 @@ sql.workspace = true
 sqlparser.workspace = true
 store-api.workspace = true
 table.workspace = true
+
+[dev-dependencies]
+criterion = "0.5"
+rand = "0.8"
+
+[[bench]]
+name = "bench_split_record_batch"
+harness = false
diff --git a/src/partition/benches/bench_split_record_batch.rs b/src/partition/benches/bench_split_record_batch.rs
new file mode 100644
index 0000000000..f6c1bd69d4
--- /dev/null
+++ b/src/partition/benches/bench_split_record_batch.rs
@@ -0,0 +1,226 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
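+
+// Benchmarks for `MultiDimPartitionRule`: compares the row-at-a-time
+// `split_record_batch_naive` path against the vectorized `split_record_batch`
+// path for splitting record batches across partition regions.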
+
+use std::sync::Arc;
+use std::vec;
+
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+use datatypes::arrow::array::{ArrayRef, Int32Array, StringArray, TimestampMillisecondArray};
+use datatypes::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
+use datatypes::arrow::record_batch::RecordBatch;
+use datatypes::value::Value;
+use partition::expr::{col, Operand};
+use partition::multi_dim::MultiDimPartitionRule;
+use partition::PartitionRule;
+use rand::Rng;
+use store_api::storage::RegionNumber;
+
+fn table_schema() -> Arc<Schema> {
+    Arc::new(Schema::new(vec![
+        Field::new("a0", DataType::Int32, false),
+        Field::new("a1", DataType::Utf8, false),
+        Field::new("a2", DataType::Int32, false),
+        Field::new(
+            "ts",
+            DataType::Timestamp(TimeUnit::Millisecond, None),
+            false,
+        ),
+    ]))
+}
+
+fn create_test_rule(num_columns: usize) -> MultiDimPartitionRule {
+    let (columns, exprs) = match num_columns {
+        1 => {
+            let exprs = vec![
+                col("a0").lt(Value::Int32(50)),
+                col("a0").gt_eq(Value::Int32(50)),
+            ];
+            (vec!["a0".to_string()], exprs)
+        }
+        2 => {
+            let exprs = vec![
+                col("a0")
+                    .lt(Value::Int32(50))
+                    .and(col("a1").lt(Value::String("server50".into()))),
+                col("a0")
+                    .lt(Value::Int32(50))
+                    .and(col("a1").gt_eq(Value::String("server50".into()))),
+                col("a0")
+                    .gt_eq(Value::Int32(50))
+                    .and(col("a1").lt(Value::String("server50".into()))),
+                col("a0")
+                    .gt_eq(Value::Int32(50))
+                    .and(col("a1").gt_eq(Value::String("server50".into()))),
+            ];
+            (vec!["a0".to_string(), "a1".to_string()], exprs)
+        }
+        3 => {
+            let expr = vec![
+                col("a0")
+                    .lt(Value::Int32(50))
+                    .and(col("a1").lt(Value::String("server50".into())))
+                    .and(col("a2").lt(Value::Int32(50))),
+                col("a0")
+                    .lt(Operand::Value(Value::Int32(50)))
+                    .and(col("a1").lt(Value::String("server50".into())))
+                    .and(col("a2").gt_eq(Value::Int32(50))),
+                col("a0")
+                    .lt(Value::Int32(50))
+                    .and(col("a1").gt_eq(Value::String("server50".into())))
+                    .and(col("a2").lt(Value::Int32(50))),
+                col("a0")
+                    .lt(Value::Int32(50))
+                    .and(col("a1").gt_eq(Value::String("server50".into())))
+                    .and(col("a2").gt_eq(Value::Int32(50))),
+                col("a0")
+                    .gt_eq(Value::Int32(50))
+                    .and(col("a1").lt(Value::String("server50".into())))
+                    .and(col("a2").lt(Value::Int32(50))),
+                col("a0")
+                    .gt_eq(Operand::Value(Value::Int32(50)))
+                    .and(col("a1").lt(Value::String("server50".into())))
+                    .and(col("a2").gt_eq(Value::Int32(50))),
+                col("a0")
+                    .gt_eq(Value::Int32(50))
+                    .and(col("a1").gt_eq(Value::String("server50".into())))
+                    .and(col("a2").lt(Value::Int32(50))),
+                col("a0")
+                    .gt_eq(Value::Int32(50))
+                    .and(col("a1").gt_eq(Value::String("server50".into())))
+                    .and(col("a2").gt_eq(Value::Int32(50))),
+            ];
+
+            (
+                vec!["a0".to_string(), "a1".to_string(), "a2".to_string()],
+                expr,
+            )
+        }
+        _ => {
+            panic!("invalid number of columns, only 1-3 are supported");
+        }
+    };
+
+    let regions = (0..exprs.len()).map(|v| v as u32).collect();
+    MultiDimPartitionRule::try_new(columns, regions, exprs).unwrap()
+}
+
+fn create_test_batch(size: usize) -> RecordBatch {
+    let mut rng = rand::thread_rng();
+
+    let schema = table_schema();
+    let arrays: Vec<ArrayRef> = (0..3)
+        .map(|col_idx| {
+            if col_idx % 2 == 0 {
+                // Integer columns (a0, a2)
+                Arc::new(Int32Array::from_iter_values(
+                    (0..size).map(|_| rng.gen_range(0..100)),
+                )) as ArrayRef
+            } else {
+                // String columns (a1)
+                let values: Vec<String> = (0..size)
+                    .map(|_| {
+                        let server_id: i32 = rng.gen_range(0..100);
+                        format!("server{}", server_id)
+                    })
+                    .collect();
+                Arc::new(StringArray::from(values)) as ArrayRef
+            }
+        })
+        .chain(std::iter::once({
+            // Timestamp column (ts)
+            Arc::new(TimestampMillisecondArray::from_iter_values(
+                (0..size).map(|idx| idx as i64),
+            )) as ArrayRef
+        }))
+        .collect();
+    RecordBatch::try_new(schema, arrays).unwrap()
+}
+
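+// Times both split strategies over 1-3 partition columns and batch sizes
+// ranging from 100 to 100k rows.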
+fn bench_split_record_batch_naive_vs_optimized(c: &mut Criterion) {
+    let mut group = c.benchmark_group("split_record_batch");
+
+    for num_columns in [1, 2, 3].iter() {
+        for num_rows in [100, 1000, 10000, 100000].iter() {
+            let rule = create_test_rule(*num_columns);
+            let batch = create_test_batch(*num_rows);
+
+            group.bench_function(format!("naive_{}_{}", num_columns, num_rows), |b| {
+                b.iter(|| {
+                    black_box(rule.split_record_batch_naive(black_box(&batch))).unwrap();
+                });
+            });
+            group.bench_function(format!("optimized_{}_{}", num_columns, num_rows), |b| {
+                b.iter(|| {
+                    black_box(rule.split_record_batch(black_box(&batch))).unwrap();
+                });
+            });
+        }
+    }
+
+    group.finish();
+}
+
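+// Converts the partition columns of a batch into owned `Value` rows up front,
+// so the row-based path can be timed without per-iteration conversion cost.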
column type: {}", data_type))] + UnexpectedColumnType { + data_type: arrow::datatypes::DataType, + #[snafu(implicit)] + location: Location, + }, + + #[snafu(display("Failed to convert to DataFusion's Schema"))] + ToDFSchema { + #[snafu(source)] + error: datafusion_common::error::DataFusionError, + #[snafu(implicit)] + location: Location, + }, + + #[snafu(display("Failed to create physical expression"))] + CreatePhysicalExpr { + #[snafu(source)] + error: datafusion_common::error::DataFusionError, + #[snafu(implicit)] + location: Location, + }, + + #[snafu(display("Partition expr value is not supported: {:?}", value))] + UnsupportedPartitionExprValue { + value: Value, + #[snafu(implicit)] + location: Location, + }, } impl ErrorExt for Error { @@ -201,6 +256,13 @@ impl ErrorExt for Error { Error::TableRouteNotFound { .. } => StatusCode::TableNotFound, Error::TableRouteManager { source, .. } => source.status_code(), Error::UnexpectedLogicalRouteTable { source, .. } => source.status_code(), + Error::ConvertToVector { source, .. } => source.status_code(), + Error::EvaluateRecordBatch { .. } => StatusCode::Internal, + Error::ComputeArrowKernel { .. } => StatusCode::Internal, + Error::UnexpectedColumnType { .. } => StatusCode::Internal, + Error::ToDFSchema { .. } => StatusCode::Internal, + Error::CreatePhysicalExpr { .. } => StatusCode::Internal, + Error::UnsupportedPartitionExprValue { .. } => StatusCode::InvalidArguments, } } diff --git a/src/partition/src/expr.rs b/src/partition/src/expr.rs index bec9543e72..b758d6dcba 100644 --- a/src/partition/src/expr.rs +++ b/src/partition/src/expr.rs @@ -13,12 +13,23 @@ // limitations under the License. use std::fmt::{Debug, Display, Formatter}; +use std::sync::Arc; -use datatypes::value::Value; +use datafusion_common::{ScalarValue, ToDFSchema}; +use datafusion_expr::execution_props::ExecutionProps; +use datafusion_expr::Expr; +use datafusion_physical_expr::{create_physical_expr, PhysicalExpr}; +use datatypes::arrow; +use datatypes::value::{ + duration_to_scalar_value, time_to_scalar_value, timestamp_to_scalar_value, Value, +}; use serde::{Deserialize, Serialize}; +use snafu::ResultExt; use sql::statements::value_to_sql_value; use sqlparser::ast::{BinaryOperator as ParserBinaryOperator, Expr as ParserExpr, Ident}; +use crate::error; + /// Struct for partition expression. This can be converted back to sqlparser's [Expr]. /// by [`Self::to_parser_expr`]. 
+pub fn col(column_name: impl Into<String>) -> Operand {
+    Operand::Column(column_name.into())
+}
+
+impl From<Value> for Operand {
+    fn from(value: Value) -> Self {
+        Operand::Value(value)
+    }
+}
+
+impl Operand {
+    pub fn try_as_logical_expr(&self) -> error::Result<Expr> {
+        match self {
+            Self::Column(c) => Ok(datafusion_expr::col(c)),
+            Self::Value(v) => {
+                let scalar_value = match v {
+                    Value::Boolean(v) => ScalarValue::Boolean(Some(*v)),
+                    Value::UInt8(v) => ScalarValue::UInt8(Some(*v)),
+                    Value::UInt16(v) => ScalarValue::UInt16(Some(*v)),
+                    Value::UInt32(v) => ScalarValue::UInt32(Some(*v)),
+                    Value::UInt64(v) => ScalarValue::UInt64(Some(*v)),
+                    Value::Int8(v) => ScalarValue::Int8(Some(*v)),
+                    Value::Int16(v) => ScalarValue::Int16(Some(*v)),
+                    Value::Int32(v) => ScalarValue::Int32(Some(*v)),
+                    Value::Int64(v) => ScalarValue::Int64(Some(*v)),
+                    Value::Float32(v) => ScalarValue::Float32(Some(v.0)),
+                    Value::Float64(v) => ScalarValue::Float64(Some(v.0)),
+                    Value::String(v) => ScalarValue::Utf8(Some(v.as_utf8().to_string())),
+                    Value::Binary(v) => ScalarValue::Binary(Some(v.to_vec())),
+                    Value::Date(v) => ScalarValue::Date32(Some(v.val())),
+                    Value::Null => ScalarValue::Null,
+                    Value::Timestamp(t) => timestamp_to_scalar_value(t.unit(), Some(t.value())),
+                    Value::Time(t) => time_to_scalar_value(*t.unit(), Some(t.value())).unwrap(),
+                    Value::IntervalYearMonth(v) => ScalarValue::IntervalYearMonth(Some(v.to_i32())),
+                    Value::IntervalDayTime(v) => ScalarValue::IntervalDayTime(Some((*v).into())),
+                    Value::IntervalMonthDayNano(v) => {
+                        ScalarValue::IntervalMonthDayNano(Some((*v).into()))
+                    }
+                    Value::Duration(d) => duration_to_scalar_value(d.unit(), Some(d.value())),
+                    Value::Decimal128(d) => {
+                        let (v, p, s) = d.to_scalar_value();
+                        ScalarValue::Decimal128(v, p, s)
+                    }
+                    other => {
+                        return error::UnsupportedPartitionExprValueSnafu {
+                            value: other.clone(),
+                        }
+                        .fail()
+                    }
+                };
+                Ok(datafusion_expr::lit(scalar_value))
+            }
+            Self::Expr(e) => e.try_as_logical_expr(),
+        }
+    }
+
+    pub fn lt(self, rhs: impl Into<Operand>) -> PartitionExpr {
+        PartitionExpr::new(self, RestrictedOp::Lt, rhs.into())
+    }
+
+    pub fn gt_eq(self, rhs: impl Into<Operand>) -> PartitionExpr {
+        PartitionExpr::new(self, RestrictedOp::GtEq, rhs.into())
+    }
+
+    pub fn eq(self, rhs: impl Into<Operand>) -> PartitionExpr {
+        PartitionExpr::new(self, RestrictedOp::Eq, rhs.into())
+    }
+}
+
 impl Display for Operand {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         match self {
@@ -140,6 +220,41 @@ impl PartitionExpr {
             right: Box::new(rhs),
         }
     }
+
+    pub fn try_as_logical_expr(&self) -> error::Result<Expr> {
+        let lhs = self.lhs.try_as_logical_expr()?;
+        let rhs = self.rhs.try_as_logical_expr()?;
+
+        let expr = match &self.op {
+            RestrictedOp::And => datafusion_expr::and(lhs, rhs),
+            RestrictedOp::Or => datafusion_expr::or(lhs, rhs),
+            RestrictedOp::Gt => lhs.gt(rhs),
+            RestrictedOp::GtEq => lhs.gt_eq(rhs),
+            RestrictedOp::Lt => lhs.lt(rhs),
+            RestrictedOp::LtEq => lhs.lt_eq(rhs),
+            RestrictedOp::Eq => lhs.eq(rhs),
+            RestrictedOp::NotEq => lhs.not_eq(rhs),
+        };
+        Ok(expr)
+    }
+
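+    /// Compiles this expression into a DataFusion [`PhysicalExpr`] bound to
+    /// the given Arrow schema.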
+    pub fn try_as_physical_expr(
+        &self,
+        schema: &arrow::datatypes::SchemaRef,
+    ) -> error::Result<Arc<dyn PhysicalExpr>> {
+        let df_schema = schema
+            .clone()
+            .to_dfschema_ref()
+            .context(error::ToDFSchemaSnafu)?;
+        let execution_props = &ExecutionProps::default();
+        let expr = self.try_as_logical_expr()?;
+        create_physical_expr(&expr, &df_schema, execution_props)
+            .context(error::CreatePhysicalExprSnafu)
+    }
+
+    pub fn and(self, rhs: PartitionExpr) -> PartitionExpr {
+        PartitionExpr::new(Operand::Expr(self), RestrictedOp::And, Operand::Expr(rhs))
+    }
 }
 
 impl Display for PartitionExpr {
diff --git a/src/partition/src/lib.rs b/src/partition/src/lib.rs
index b1843a1093..bc56edc584 100644
--- a/src/partition/src/lib.rs
+++ b/src/partition/src/lib.rs
@@ -13,7 +13,7 @@
 // limitations under the License.
 
 #![feature(assert_matches)]
-
+#![feature(let_chains)]
 //! Structs and traits for partitioning rule.
 
 pub mod error;
diff --git a/src/partition/src/multi_dim.rs b/src/partition/src/multi_dim.rs
index f47d71f98b..551fb6a8de 100644
--- a/src/partition/src/multi_dim.rs
+++ b/src/partition/src/multi_dim.rs
@@ -15,10 +15,18 @@
 use std::any::Any;
 use std::cmp::Ordering;
 use std::collections::HashMap;
+use std::sync::{Arc, RwLock};
 
+use datafusion_expr::ColumnarValue;
+use datafusion_physical_expr::PhysicalExpr;
+use datatypes::arrow;
+use datatypes::arrow::array::{BooleanArray, BooleanBufferBuilder, RecordBatch};
+use datatypes::arrow::buffer::BooleanBuffer;
+use datatypes::arrow::datatypes::Schema;
 use datatypes::prelude::Value;
+use datatypes::vectors::{Helper, VectorRef};
 use serde::{Deserialize, Serialize};
-use snafu::{ensure, OptionExt};
+use snafu::{ensure, OptionExt, ResultExt};
 use store_api::storage::RegionNumber;
 
 use crate::error::{
@@ -28,6 +36,11 @@ use crate::error::{
 use crate::expr::{Operand, PartitionExpr, RestrictedOp};
 use crate::PartitionRule;
 
+/// The default region number when no partition exprs are matched.
+const DEFAULT_REGION: RegionNumber = 0;
+
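+/// Compiled physical expressions cached together with the schema they were
+/// built against; rebuilt whenever the schema of the incoming batch changes.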
+type PhysicalExprCache = Option<(Vec<Arc<dyn PhysicalExpr>>, Arc<Schema>)>;
+
 /// Multi-Dimension partition rule. RFC [here](https://github.com/GreptimeTeam/greptimedb/blob/main/docs/rfcs/2024-02-21-multi-dimension-partition-rule/rfc.md)
 ///
 /// This partition rule is defined by a set of simple expressions on the partition
@@ -44,6 +57,9 @@ pub struct MultiDimPartitionRule {
     regions: Vec<RegionNumber>,
     /// Partition expressions.
     exprs: Vec<PartitionExpr>,
+    /// Cache of physical expressions.
+    #[serde(skip)]
+    physical_expr_cache: RwLock<PhysicalExprCache>,
 }
 
 impl MultiDimPartitionRule {
@@ -63,6 +79,7 @@ impl MultiDimPartitionRule {
             name_to_index,
             regions,
             exprs,
+            physical_expr_cache: RwLock::new(None),
         };
 
         let mut checker = RuleChecker::new(&rule);
@@ -87,7 +104,7 @@ impl MultiDimPartitionRule {
         }
 
         // return the default region number
-        Ok(0)
+        Ok(DEFAULT_REGION)
     }
 
     fn evaluate_expr(&self, expr: &PartitionExpr, values: &[Value]) -> Result<bool> {
@@ -134,6 +151,133 @@ impl MultiDimPartitionRule {
 
         Ok(result)
     }
+
+    pub fn row_at(&self, cols: &[VectorRef], index: usize, row: &mut [Value]) -> Result<()> {
+        for (col_idx, col) in cols.iter().enumerate() {
+            row[col_idx] = col.get(index);
+        }
+        Ok(())
+    }
+
+    pub fn record_batch_to_cols(&self, record_batch: &RecordBatch) -> Result<Vec<VectorRef>> {
+        self.partition_columns
+            .iter()
+            .map(|col_name| {
+                record_batch
+                    .column_by_name(col_name)
+                    .context(error::UndefinedColumnSnafu { column: col_name })
+                    .and_then(|array| {
+                        Helper::try_into_vector(array).context(error::ConvertToVectorSnafu)
+                    })
+            })
+            .collect::<Result<Vec<_>>>()
+    }
+
+    pub fn split_record_batch_naive(
+        &self,
+        record_batch: &RecordBatch,
+    ) -> Result<HashMap<RegionNumber, BooleanArray>> {
+        let num_rows = record_batch.num_rows();
+
+        let mut result = self
+            .regions
+            .iter()
+            .map(|region| {
+                let mut builder = BooleanBufferBuilder::new(num_rows);
+                builder.append_n(num_rows, false);
+                (*region, builder)
+            })
+            .collect::<HashMap<_, _>>();
+
+        let cols = self.record_batch_to_cols(record_batch)?;
+        let mut current_row = vec![Value::Null; self.partition_columns.len()];
+        for row_idx in 0..num_rows {
+            self.row_at(&cols, row_idx, &mut current_row)?;
+            let current_region = self.find_region(&current_row)?;
+            let region_mask = result
+                .get_mut(&current_region)
+                .unwrap_or_else(|| panic!("Region {} must be initialized", current_region));
+            region_mask.set_bit(row_idx, true);
+        }
+
+        Ok(result
+            .into_iter()
+            .map(|(region, mut mask)| (region, BooleanArray::new(mask.finish(), None)))
+            .collect())
+    }
+
+    pub fn split_record_batch(
+        &self,
+        record_batch: &RecordBatch,
+    ) -> Result<HashMap<RegionNumber, BooleanArray>> {
+        let num_rows = record_batch.num_rows();
+        let physical_exprs = {
+            let cache_read_guard = self.physical_expr_cache.read().unwrap();
+            if let Some((cached_exprs, schema)) = cache_read_guard.as_ref()
+                && schema == record_batch.schema_ref()
+            {
+                cached_exprs.clone()
+            } else {
+                drop(cache_read_guard); // Release the read lock before acquiring write lock
+
+                let schema = record_batch.schema();
+                let new_cache = self
+                    .exprs
+                    .iter()
+                    .map(|e| e.try_as_physical_expr(&schema))
+                    .collect::<Result<Vec<_>>>()?;
+
+                let mut cache_write_guard = self.physical_expr_cache.write().unwrap();
+                cache_write_guard.replace((new_cache.clone(), schema));
+                new_cache
+            }
+        };
+
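+        // Evaluate every partition expression against the whole batch at once;
+        // each expression yields a boolean mask selecting the rows of its region.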
+        let mut result: HashMap<RegionNumber, BooleanArray> = physical_exprs
+            .iter()
+            .zip(self.regions.iter())
+            .map(|(expr, region_num)| {
+                let ColumnarValue::Array(column) = expr
+                    .evaluate(record_batch)
+                    .context(error::EvaluateRecordBatchSnafu)?
+                else {
+                    unreachable!("Expected an array")
+                };
+                Ok((
+                    *region_num,
+                    column
+                        .as_any()
+                        .downcast_ref::<BooleanArray>()
+                        .with_context(|| error::UnexpectedColumnTypeSnafu {
+                            data_type: column.data_type().clone(),
+                        })?
+                        .clone(),
+                ))
+            })
+            .collect::<Result<HashMap<_, _>>>()?;
+
+        let mut selected = BooleanArray::new(BooleanBuffer::new_unset(num_rows), None);
+        for region_selection in result.values() {
+            selected = arrow::compute::kernels::boolean::or(&selected, region_selection)
+                .context(error::ComputeArrowKernelSnafu)?;
+        }
+
+        // fast path: all rows are selected
+        if selected.true_count() == num_rows {
+            return Ok(result);
+        }
+
+        // find unselected rows and assign to default region
+        let unselected = arrow::compute::kernels::boolean::not(&selected)
+            .context(error::ComputeArrowKernelSnafu)?;
+        let default_region_selection = result
+            .entry(DEFAULT_REGION)
+            .or_insert_with(|| unselected.clone());
+        *default_region_selection =
+            arrow::compute::kernels::boolean::or(default_region_selection, &unselected)
+                .context(error::ComputeArrowKernelSnafu)?;
+        Ok(result)
+    }
 }
 
 impl PartitionRule for MultiDimPartitionRule {
@@ -148,6 +292,13 @@ impl PartitionRule for MultiDimPartitionRule {
     fn find_region(&self, values: &[Value]) -> Result<RegionNumber> {
         self.find_region(values)
     }
+
+    fn split_record_batch(
+        &self,
+        record_batch: &RecordBatch,
+    ) -> Result<HashMap<RegionNumber, BooleanArray>> {
+        self.split_record_batch(record_batch)
+    }
 }
 
 /// Helper for [RuleChecker]
@@ -633,3 +784,155 @@ mod tests {
         assert!(rule.is_err());
     }
 }
+
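+// Cross-checks the vectorized `split_record_batch` against the row-based
+// `split_record_batch_naive` reference implementation.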
+#[cfg(test)]
+mod test_split_record_batch {
+    use std::sync::Arc;
+
+    use datatypes::arrow::array::{Int64Array, StringArray};
+    use datatypes::arrow::datatypes::{DataType, Field, Schema};
+    use datatypes::arrow::record_batch::RecordBatch;
+    use rand::Rng;
+
+    use super::*;
+    use crate::expr::col;
+
+    fn test_schema() -> Arc<Schema> {
+        Arc::new(Schema::new(vec![
+            Field::new("host", DataType::Utf8, false),
+            Field::new("value", DataType::Int64, false),
+        ]))
+    }
+
+    fn generate_random_record_batch(num_rows: usize) -> RecordBatch {
+        let schema = test_schema();
+        let mut rng = rand::thread_rng();
+        let mut host_array = Vec::with_capacity(num_rows);
+        let mut value_array = Vec::with_capacity(num_rows);
+        for _ in 0..num_rows {
+            host_array.push(format!("server{}", rng.gen_range(0..20)));
+            value_array.push(rng.gen_range(0..20));
+        }
+        let host_array = StringArray::from(host_array);
+        let value_array = Int64Array::from(value_array);
+        RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)]).unwrap()
+    }
+
+    #[test]
+    fn test_split_record_batch_by_one_column() {
+        // Create a simple MultiDimPartitionRule
+        let rule = MultiDimPartitionRule::try_new(
+            vec!["host".to_string(), "value".to_string()],
+            vec![0, 1],
+            vec![
+                col("host").lt(Value::String("server1".into())),
+                col("host").gt_eq(Value::String("server1".into())),
+            ],
+        )
+        .unwrap();
+
+        let batch = generate_random_record_batch(1000);
+        // Split the batch
+        let result = rule.split_record_batch(&batch).unwrap();
+        let expected = rule.split_record_batch_naive(&batch).unwrap();
+        assert_eq!(result.len(), expected.len());
+        for (region, value) in &result {
+            assert_eq!(
+                value,
+                expected.get(region).unwrap(),
+                "failed on region: {}",
+                region
+            );
+        }
+    }
+
+    #[test]
+    fn test_split_record_batch_empty() {
+        // Create a simple MultiDimPartitionRule
+        let rule = MultiDimPartitionRule::try_new(
+            vec!["host".to_string()],
+            vec![1],
+            vec![PartitionExpr::new(
+                Operand::Column("host".to_string()),
+                RestrictedOp::Eq,
+                Operand::Value(Value::String("server1".into())),
+            )],
+        )
+        .unwrap();
+
+        let schema = test_schema();
+        let host_array = StringArray::from(Vec::<&str>::new());
+        let value_array = Int64Array::from(Vec::<i64>::new());
+        let batch =
+            RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)])
+                .unwrap();
+
+        let result = rule.split_record_batch(&batch).unwrap();
+        assert_eq!(result.len(), 1);
+    }
+
+    #[test]
+    fn test_split_record_batch_by_two_columns() {
+        let rule = MultiDimPartitionRule::try_new(
+            vec!["host".to_string(), "value".to_string()],
+            vec![0, 1, 2, 3],
+            vec![
+                col("host")
+                    .lt(Value::String("server10".into()))
+                    .and(col("value").lt(Value::Int64(10))),
+                col("host")
+                    .lt(Value::String("server10".into()))
+                    .and(col("value").gt_eq(Value::Int64(10))),
+                col("host")
+                    .gt_eq(Value::String("server10".into()))
+                    .and(col("value").lt(Value::Int64(10))),
+                col("host")
+                    .gt_eq(Value::String("server10".into()))
+                    .and(col("value").gt_eq(Value::Int64(10))),
+            ],
+        )
+        .unwrap();
+
+        let batch = generate_random_record_batch(1000);
+        let result = rule.split_record_batch(&batch).unwrap();
+        let expected = rule.split_record_batch_naive(&batch).unwrap();
+        assert_eq!(result.len(), expected.len());
+        for (region, value) in &result {
+            assert_eq!(value, expected.get(region).unwrap());
+        }
+    }
+
+    #[test]
+    fn test_default_region() {
+        let rule = MultiDimPartitionRule::try_new(
+            vec!["host".to_string(), "value".to_string()],
+            vec![0, 1, 2, 3],
+            vec![
+                col("host")
+                    .lt(Value::String("server10".into()))
+                    .and(col("value").eq(Value::Int64(10))),
+                col("host")
+                    .lt(Value::String("server10".into()))
+                    .and(col("value").eq(Value::Int64(20))),
+                col("host")
+                    .gt_eq(Value::String("server10".into()))
+                    .and(col("value").eq(Value::Int64(10))),
+                col("host")
+                    .gt_eq(Value::String("server10".into()))
+                    .and(col("value").eq(Value::Int64(20))),
+            ],
+        )
+        .unwrap();
+
+        let schema = test_schema();
+        let host_array = StringArray::from(vec!["server1", "server1", "server1", "server100"]);
+        let value_array = Int64Array::from(vec![10, 20, 30, 10]);
+        let batch =
+            RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)])
+                .unwrap();
+        let result = rule.split_record_batch(&batch).unwrap();
+        let expected = rule.split_record_batch_naive(&batch).unwrap();
+        assert_eq!(result.len(), expected.len());
+        for (region, value) in &result {
+            assert_eq!(value, expected.get(region).unwrap());
+        }
+    }
+}
diff --git a/src/partition/src/partition.rs b/src/partition/src/partition.rs
index ac965034c6..a190d33eca 100644
--- a/src/partition/src/partition.rs
+++ b/src/partition/src/partition.rs
@@ -13,11 +13,13 @@
 // limitations under the License.
 
 use std::any::Any;
+use std::collections::HashMap;
 use std::fmt::{Debug, Display, Formatter};
 use std::sync::Arc;
 
 use common_meta::rpc::router::Partition as MetaPartition;
 use datafusion_expr::Operator;
+use datatypes::arrow::array::{BooleanArray, RecordBatch};
 use datatypes::prelude::Value;
 use itertools::Itertools;
 use serde::{Deserialize, Serialize};
@@ -37,6 +39,13 @@ pub trait PartitionRule: Sync + Send {
     ///
     /// Note that the `values` should have the same length as the `partition_columns`.
     fn find_region(&self, values: &[Value]) -> Result<RegionNumber>;
+
+    /// Splits the record batch into regions by the partition expressions.
+    /// The result maps each region number to a boolean array with one entry per row; an entry is `true` when that row belongs to the region.
+    fn split_record_batch(
+        &self,
+        record_batch: &RecordBatch,
+    ) -> Result<HashMap<RegionNumber, BooleanArray>>;
 }
 
 /// The right bound(exclusive) of partition range.
diff --git a/src/partition/src/splitter.rs b/src/partition/src/splitter.rs
index f62210a6b5..87c04a4942 100644
--- a/src/partition/src/splitter.rs
+++ b/src/partition/src/splitter.rs
@@ -136,6 +136,7 @@ mod tests {
 
     use api::v1::value::ValueData;
     use api::v1::{ColumnDataType, SemanticType};
+    use datatypes::arrow::array::BooleanArray;
     use serde::{Deserialize, Serialize};
 
     use super::*;
@@ -209,6 +210,13 @@ mod tests {
 
             Ok(val.parse::<u32>().unwrap() % 2)
         }
+
+        fn split_record_batch(
+            &self,
+            _record_batch: &datatypes::arrow::array::RecordBatch,
+        ) -> Result<HashMap<RegionNumber, BooleanArray>> {
+            unimplemented!()
+        }
     }
 
     #[derive(Debug, Serialize, Deserialize)]
@@ -232,6 +240,13 @@ mod tests {
 
             Ok(val)
         }
+
+        fn split_record_batch(
+            &self,
+            _record_batch: &datatypes::arrow::array::RecordBatch,
+        ) -> Result<HashMap<RegionNumber, BooleanArray>> {
+            unimplemented!()
+        }
     }
 
     #[derive(Debug, Serialize, Deserialize)]
@@ -249,8 +264,14 @@ mod tests {
         fn find_region(&self, _values: &[Value]) -> Result<RegionNumber> {
             Ok(0)
         }
-    }
 
+        fn split_record_batch(
+            &self,
+            _record_batch: &datatypes::arrow::array::RecordBatch,
+        ) -> Result<HashMap<RegionNumber, BooleanArray>> {
+            unimplemented!()
+        }
+    }
     #[test]
     fn test_writer_splitter() {
         let rows = mock_rows();
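Usage note (editor's sketch, not part of this diff): a downstream caller would pair each returned region mask with Arrow's standard `filter_record_batch` kernel to materialize the per-region sub-batches. The rule and schema below mirror the tests above; the crate paths are the ones the benchmark imports.

// Minimal sketch of consuming the new `split_record_batch` API.
use std::sync::Arc;

use datatypes::arrow::array::{Int64Array, StringArray};
use datatypes::arrow::compute::filter_record_batch;
use datatypes::arrow::datatypes::{DataType, Field, Schema};
use datatypes::arrow::record_batch::RecordBatch;
use datatypes::value::Value;
use partition::expr::col;
use partition::multi_dim::MultiDimPartitionRule;

fn main() {
    // Two regions keyed on `host`, exactly as in the tests above.
    let rule = MultiDimPartitionRule::try_new(
        vec!["host".to_string()],
        vec![0, 1],
        vec![
            col("host").lt(Value::String("server1".into())),
            col("host").gt_eq(Value::String("server1".into())),
        ],
    )
    .unwrap();

    let schema = Arc::new(Schema::new(vec![
        Field::new("host", DataType::Utf8, false),
        Field::new("value", DataType::Int64, false),
    ]));
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(StringArray::from(vec!["server0", "server2"])),
            Arc::new(Int64Array::from(vec![1, 2])),
        ],
    )
    .unwrap();

    // One boolean mask per region; rows matched by no expression fall back to
    // region 0 (DEFAULT_REGION).
    for (region, mask) in rule.split_record_batch(&batch).unwrap() {
        let sub_batch = filter_record_batch(&batch, &mask).unwrap();
        println!("region {region}: {} rows", sub_batch.num_rows());
    }
}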