mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-15 04:20:39 +00:00
perf: support group accumulators for state wrapper (#7826)
* perf: support group accumulators for state wrapper
* new tests and avoid clone

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

---------

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
This commit is contained in:
@@ -25,7 +25,7 @@
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow::array::StructArray;
|
||||
use arrow::array::{ArrayRef, BooleanArray, StructArray};
|
||||
use arrow_schema::{FieldRef, Fields};
|
||||
use common_telemetry::debug;
|
||||
use datafusion::functions_aggregate::all_default_aggregate_functions;
|
||||
@@ -38,8 +38,8 @@ use datafusion_common::{Column, ScalarValue};
|
||||
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
|
||||
use datafusion_expr::function::StateFieldsArgs;
|
||||
use datafusion_expr::{
|
||||
Accumulator, Aggregate, AggregateUDF, AggregateUDFImpl, Expr, ExprSchemable, LogicalPlan,
|
||||
Signature,
|
||||
Accumulator, Aggregate, AggregateUDF, AggregateUDFImpl, EmitTo, Expr, ExprSchemable,
|
||||
GroupsAccumulator, LogicalPlan, Signature,
|
||||
};
|
||||
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
|
||||
use datatypes::arrow::datatypes::{DataType, Field};
|
||||
@@ -322,6 +322,14 @@ impl StateWrapper {
|
||||
);
|
||||
})
|
||||
}
|
||||
|
||||
fn fix_inner_acc_args<'b>(
|
||||
&self,
|
||||
mut acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
|
||||
) -> datafusion_common::Result<datafusion_expr::function::AccumulatorArgs<'b>> {
|
||||
acc_args.return_field = self.deduce_aggr_return_type(&acc_args)?;
|
||||
Ok(acc_args)
|
||||
}
|
||||
}
|
||||
|
||||
impl AggregateUDFImpl for StateWrapper {
|
||||
@@ -331,15 +339,32 @@ impl AggregateUDFImpl for StateWrapper {
|
||||
) -> datafusion_common::Result<Box<dyn Accumulator>> {
|
||||
// fix and recover proper acc args for the original aggregate function.
|
||||
let state_type = acc_args.return_type().clone();
|
||||
let inner = {
|
||||
let mut new_acc_args = acc_args.clone();
|
||||
new_acc_args.return_field = self.deduce_aggr_return_type(&acc_args)?;
|
||||
self.inner.accumulator(new_acc_args)?
|
||||
};
|
||||
let inner = self.inner.accumulator(self.fix_inner_acc_args(acc_args)?)?;
|
||||
|
||||
Ok(Box::new(StateAccum::new(inner, state_type)?))
|
||||
}
|
||||
|
||||
fn groups_accumulator_supported(
|
||||
&self,
|
||||
acc_args: datafusion_expr::function::AccumulatorArgs,
|
||||
) -> bool {
|
||||
self.fix_inner_acc_args(acc_args)
|
||||
.map(|args| self.inner.inner().groups_accumulator_supported(args))
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn create_groups_accumulator(
|
||||
&self,
|
||||
acc_args: datafusion_expr::function::AccumulatorArgs,
|
||||
) -> datafusion_common::Result<Box<dyn GroupsAccumulator>> {
|
||||
let state_type = acc_args.return_type().clone();
|
||||
let inner = self
|
||||
.inner
|
||||
.inner()
|
||||
.create_groups_accumulator(self.fix_inner_acc_args(acc_args)?)?;
|
||||
Ok(Box::new(StateGroupsAccum::new(inner, state_type)?))
|
||||
}
|
||||
|
||||
/// Exposes this wrapper as `&dyn Any`.
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
@@ -462,6 +487,118 @@ pub struct StateAccum {
|
||||
state_fields: Fields,
|
||||
}
|
||||
|
||||
/// A [`GroupsAccumulator`] wrapper whose `evaluate` returns the wrapped
/// accumulator's intermediate state, packed into a single struct array,
/// instead of a final aggregate value.
pub struct StateGroupsAccum {
|
||||
/// The wrapped accumulator that performs the actual aggregation.
inner: Box<dyn GroupsAccumulator>,
|
||||
/// Fields of the struct type used when packing the state arrays.
state_fields: Fields,
|
||||
}
|
||||
|
||||
impl StateGroupsAccum {
|
||||
fn new(
|
||||
inner: Box<dyn GroupsAccumulator>,
|
||||
state_type: DataType,
|
||||
) -> datafusion_common::Result<Self> {
|
||||
let DataType::Struct(fields) = state_type else {
|
||||
return Err(datafusion_common::DataFusionError::Internal(format!(
|
||||
"Expected a struct type for state, got: {:?}",
|
||||
state_type
|
||||
)));
|
||||
};
|
||||
Ok(Self {
|
||||
inner,
|
||||
state_fields: fields,
|
||||
})
|
||||
}
|
||||
|
||||
fn wrap_state_arrays(&self, arrays: Vec<ArrayRef>) -> datafusion_common::Result<ArrayRef> {
|
||||
let array_type = arrays
|
||||
.iter()
|
||||
.map(|array| array.data_type().clone())
|
||||
.collect::<Vec<_>>();
|
||||
let expected_type = self
|
||||
.state_fields
|
||||
.iter()
|
||||
.map(|field| field.data_type().clone())
|
||||
.collect::<Vec<_>>();
|
||||
if array_type != expected_type {
|
||||
debug!(
|
||||
"State mismatch, expected: {}, got: {} for expected fields: {:?} and given array types: {:?}",
|
||||
self.state_fields.len(),
|
||||
arrays.len(),
|
||||
self.state_fields,
|
||||
array_type,
|
||||
);
|
||||
let guess_schema = arrays
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(index, array)| {
|
||||
Field::new(
|
||||
format!("col_{index}[mismatch_state]").as_str(),
|
||||
array.data_type().clone(),
|
||||
true,
|
||||
)
|
||||
})
|
||||
.collect::<Fields>();
|
||||
let array = StructArray::try_new(guess_schema, arrays, None)?;
|
||||
return Ok(Arc::new(array));
|
||||
}
|
||||
|
||||
Ok(Arc::new(StructArray::try_new(
|
||||
self.state_fields.clone(),
|
||||
arrays,
|
||||
None,
|
||||
)?))
|
||||
}
|
||||
}
|
||||
|
||||
impl GroupsAccumulator for StateGroupsAccum {
|
||||
fn update_batch(
|
||||
&mut self,
|
||||
values: &[ArrayRef],
|
||||
group_indices: &[usize],
|
||||
opt_filter: Option<&BooleanArray>,
|
||||
total_num_groups: usize,
|
||||
) -> datafusion_common::Result<()> {
|
||||
self.inner
|
||||
.update_batch(values, group_indices, opt_filter, total_num_groups)
|
||||
}
|
||||
|
||||
fn merge_batch(
|
||||
&mut self,
|
||||
values: &[ArrayRef],
|
||||
group_indices: &[usize],
|
||||
opt_filter: Option<&BooleanArray>,
|
||||
total_num_groups: usize,
|
||||
) -> datafusion_common::Result<()> {
|
||||
self.inner
|
||||
.merge_batch(values, group_indices, opt_filter, total_num_groups)
|
||||
}
|
||||
|
||||
fn evaluate(&mut self, emit_to: EmitTo) -> datafusion_common::Result<ArrayRef> {
|
||||
let state = self.inner.state(emit_to)?;
|
||||
self.wrap_state_arrays(state)
|
||||
}
|
||||
|
||||
fn state(&mut self, emit_to: EmitTo) -> datafusion_common::Result<Vec<ArrayRef>> {
|
||||
self.inner.state(emit_to)
|
||||
}
|
||||
|
||||
fn convert_to_state(
|
||||
&self,
|
||||
values: &[ArrayRef],
|
||||
opt_filter: Option<&BooleanArray>,
|
||||
) -> datafusion_common::Result<Vec<ArrayRef>> {
|
||||
self.inner.convert_to_state(values, opt_filter)
|
||||
}
|
||||
|
||||
fn supports_convert_to_state(&self) -> bool {
|
||||
self.inner.supports_convert_to_state()
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.inner.size()
|
||||
}
|
||||
}
|
||||
|
||||
impl StateAccum {
|
||||
pub fn new(
|
||||
inner: Box<dyn Accumulator>,
|
||||
|
||||
@@ -40,10 +40,13 @@ use datafusion_common::arrow::array::AsArray;
|
||||
use datafusion_common::arrow::datatypes::{Float64Type, UInt64Type};
|
||||
use datafusion_common::{Column, TableReference};
|
||||
use datafusion_expr::expr::{AggregateFunction, NullTreatment};
|
||||
use datafusion_expr::function::AccumulatorArgs;
|
||||
use datafusion_expr::{
|
||||
Aggregate, ColumnarValue, Expr, LogicalPlan, ScalarFunctionArgs, SortExpr, TableScan, lit,
|
||||
Aggregate, AggregateUDFImpl, ColumnarValue, Expr, LogicalPlan, ScalarFunctionArgs, SortExpr,
|
||||
TableScan, lit,
|
||||
};
|
||||
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
|
||||
use datafusion_physical_expr::expressions::col;
|
||||
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
|
||||
use datatypes::arrow_array::StringArray;
|
||||
use futures::{Stream, StreamExt as _};
|
||||
@@ -256,6 +259,38 @@ fn dummy_table_scan_with_ts() -> LogicalPlan {
|
||||
)
|
||||
}
|
||||
|
||||
fn create_avg_state_groups_accumulator() -> Box<dyn GroupsAccumulator> {
|
||||
let state_wrapper = StateWrapper::new((*avg_udaf()).clone()).unwrap();
|
||||
let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new(
|
||||
"number",
|
||||
DataType::Float64,
|
||||
true,
|
||||
)]));
|
||||
let expr = col("number", &schema).unwrap();
|
||||
let expr_field = expr.return_field(&schema).unwrap();
|
||||
let return_field = Arc::new(Field::new(
|
||||
"__avg_state(number)",
|
||||
state_wrapper.return_type(&[DataType::Float64]).unwrap(),
|
||||
true,
|
||||
));
|
||||
let exprs = [expr];
|
||||
let expr_fields = [expr_field];
|
||||
let acc_args = AccumulatorArgs {
|
||||
return_field,
|
||||
schema: &schema,
|
||||
ignore_nulls: false,
|
||||
order_bys: &[],
|
||||
is_reversed: false,
|
||||
name: "__avg_state(number)",
|
||||
is_distinct: false,
|
||||
exprs: &exprs,
|
||||
expr_fields: &expr_fields,
|
||||
};
|
||||
|
||||
assert!(state_wrapper.groups_accumulator_supported(acc_args.clone()));
|
||||
state_wrapper.create_groups_accumulator(acc_args).unwrap()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sum_udaf() {
|
||||
let ctx = SessionContext::new();
|
||||
@@ -796,6 +831,95 @@ async fn test_last_value_order_by_udaf() {
|
||||
assert_eq!(merge_eval_res, ScalarValue::Int64(Some(4)));
|
||||
}
|
||||
|
||||
/// `evaluate` on the avg state accumulator must return a struct array whose
/// first column holds per-group counts and whose second column holds
/// per-group sums, with null inputs ignored.
#[test]
fn test_avg_state_groups_accumulator_evaluate() {
    let mut state_accum = create_avg_state_groups_accumulator();

    // Group 0 gets {1.0, NULL, 3.0}, group 1 gets {2.0, 4.0}, group 2 gets {5.0}.
    let input = Float64Array::from(vec![
        Some(1.0),
        Some(2.0),
        None,
        Some(3.0),
        Some(4.0),
        Some(5.0),
    ]);
    let values = vec![Arc::new(input) as ArrayRef];
    let group_indices = [0usize, 1, 0, 0, 1, 2];

    state_accum
        .update_batch(&values, &group_indices, None, 3)
        .unwrap();

    let evaluated = state_accum.evaluate(EmitTo::All).unwrap();
    let state_struct = evaluated.as_any().downcast_ref::<StructArray>().unwrap();

    let counts = state_struct
        .column(0)
        .as_any()
        .downcast_ref::<UInt64Array>()
        .unwrap();
    assert_eq!(counts, &UInt64Array::from(vec![2, 2, 1]));

    let sums = state_struct
        .column(1)
        .as_any()
        .downcast_ref::<Float64Array>()
        .unwrap();
    assert_eq!(sums, &Float64Array::from(vec![4.0, 6.0, 5.0]));
}
|
||||
|
||||
/// State emitted by one avg state accumulator can be merged into another
/// (with remapped group indices), and the merged result is still exposed as
/// a count/sum struct by `evaluate`.
#[test]
fn test_avg_state_groups_accumulator_state_merge_evaluate() {
    // Source accumulator: group 0 gets {1.0, NULL, 3.0}, group 1 gets {2.0, 4.0},
    // group 2 gets {5.0}.
    let mut source_accum = create_avg_state_groups_accumulator();
    let source_input = Float64Array::from(vec![
        Some(1.0),
        Some(2.0),
        None,
        Some(3.0),
        Some(4.0),
        Some(5.0),
    ]);
    let source_values = vec![Arc::new(source_input) as ArrayRef];
    let source_group_indices = [0usize, 1, 0, 0, 1, 2];

    source_accum
        .update_batch(&source_values, &source_group_indices, None, 3)
        .unwrap();
    let source_state = source_accum.state(EmitTo::All).unwrap();

    // Target accumulator: one value per group before merging.
    let mut merged_accum = create_avg_state_groups_accumulator();
    let merged_input = Float64Array::from(vec![Some(10.0), Some(20.0), Some(30.0)]);
    let merged_values = vec![Arc::new(merged_input) as ArrayRef];
    let merged_group_indices = [0usize, 1, 2];

    merged_accum
        .update_batch(&merged_values, &merged_group_indices, None, 3)
        .unwrap();
    // Source groups 0, 1, 2 are folded into target groups 1, 2, 0 respectively.
    merged_accum
        .merge_batch(&source_state, &[1, 2, 0], None, 3)
        .unwrap();

    let evaluated = merged_accum.evaluate(EmitTo::All).unwrap();
    let state_struct = evaluated.as_any().downcast_ref::<StructArray>().unwrap();

    let counts = state_struct
        .column(0)
        .as_any()
        .downcast_ref::<UInt64Array>()
        .unwrap();
    assert_eq!(counts, &UInt64Array::from(vec![2, 3, 3]));

    let sums = state_struct
        .column(1)
        .as_any()
        .downcast_ref::<Float64Array>()
        .unwrap();
    assert_eq!(sums, &Float64Array::from(vec![15.0, 24.0, 36.0]));
}
|
||||
|
||||
/// For testing whether the UDAF state fields are correctly implemented.
|
||||
/// Especially useful for our own custom UDAFs' state fields.
|
||||
/// Works by comparing evaluation results before and after splitting into state/merge functions.
|
||||
|
||||
Reference in New Issue
Block a user