diff --git a/src/flow/src/compute/render.rs b/src/flow/src/compute/render.rs
index 0476c8a6e5..bf298e86bc 100644
--- a/src/flow/src/compute/render.rs
+++ b/src/flow/src/compute/render.rs
@@ -238,6 +238,12 @@ mod test {
         for now in time_range {
             state.set_current_ts(now);
             state.run_available_with_schedule(df);
+            if !state.get_err_collector().is_empty() {
+                panic!(
+                    "Errors occur: {:?}",
+                    state.get_err_collector().get_all_blocking()
+                )
+            }
             assert!(state.get_err_collector().is_empty());
             if let Some(expected) = expected.get(&now) {
                 assert_eq!(*output.borrow(), *expected, "at ts={}", now);
diff --git a/src/flow/src/compute/render/reduce.rs b/src/flow/src/compute/render/reduce.rs
index 46b2dc196f..da2bb11f4b 100644
--- a/src/flow/src/compute/render/reduce.rs
+++ b/src/flow/src/compute/render/reduce.rs
@@ -729,15 +729,113 @@ mod test {
     use std::cell::RefCell;
     use std::rc::Rc;
 
-    use datatypes::data_type::ConcreteDataType;
+    use datatypes::data_type::{ConcreteDataType, ConcreteDataType as CDT};
     use hydroflow::scheduled::graph::Hydroflow;
 
     use super::*;
     use crate::compute::render::test::{get_output_handle, harness_test_ctx, run_and_check};
     use crate::compute::state::DataflowState;
-    use crate::expr::{self, AggregateFunc, BinaryFunc, GlobalId, MapFilterProject};
+    use crate::expr::{self, AggregateFunc, BinaryFunc, GlobalId, MapFilterProject, UnaryFunc};
     use crate::repr::{ColumnType, RelationType};
 
+    /// select avg(number) from number;
+    #[test]
+    fn test_avg_eval() {
+        let mut df = Hydroflow::new();
+        let mut state = DataflowState::default();
+        let mut ctx = harness_test_ctx(&mut df, &mut state);
+
+        let rows = vec![
+            (Row::new(vec![1u32.into()]), 1, 1),
+            (Row::new(vec![2u32.into()]), 1, 1),
+            (Row::new(vec![3u32.into()]), 1, 1),
+            (Row::new(vec![1u32.into()]), 1, 1),
+            (Row::new(vec![2u32.into()]), 1, 1),
+            (Row::new(vec![3u32.into()]), 1, 1),
+        ];
+        let collection = ctx.render_constant(rows.clone());
+        ctx.insert_global(GlobalId::User(1), collection);
+
+        let aggr_exprs = vec![
+            AggregateExpr {
+                func: AggregateFunc::SumUInt32,
+                expr: ScalarExpr::Column(0),
+                distinct: false,
+            },
+            AggregateExpr {
+                func: AggregateFunc::Count,
+                expr: ScalarExpr::Column(0),
+                distinct: false,
+            },
+        ];
+        let avg_expr = ScalarExpr::If {
+            cond: Box::new(ScalarExpr::Column(1).call_binary(
+                ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
+                BinaryFunc::NotEq,
+            )),
+            then: Box::new(ScalarExpr::Column(0).call_binary(
+                ScalarExpr::Column(1).call_unary(UnaryFunc::Cast(CDT::uint64_datatype())),
+                BinaryFunc::DivUInt64,
+            )),
+            els: Box::new(ScalarExpr::Literal(Value::Null, CDT::uint64_datatype())),
+        };
+        let expected = TypedPlan {
+            typ: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]),
+            plan: Plan::Mfp {
+                input: Box::new(
+                    Plan::Reduce {
+                        input: Box::new(
+                            Plan::Get {
+                                id: crate::expr::Id::Global(GlobalId::User(1)),
+                            }
+                            .with_types(RelationType::new(vec![
+                                ColumnType::new(ConcreteDataType::uint32_datatype(), false),
+                            ])),
+                        ),
+                        key_val_plan: KeyValPlan {
+                            key_plan: MapFilterProject::new(1)
+                                .project(vec![])
+                                .unwrap()
+                                .into_safe(),
+                            val_plan: MapFilterProject::new(1)
+                                .project(vec![0])
+                                .unwrap()
+                                .into_safe(),
+                        },
+                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
+                            full_aggrs: aggr_exprs.clone(),
+                            simple_aggrs: vec![
+                                AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
+                                AggrWithIndex::new(aggr_exprs[1].clone(), 0, 1),
+                            ],
+                            distinct_aggrs: vec![],
+                        }),
+                    }
+                    .with_types(RelationType::new(vec![
+                        ColumnType::new(ConcreteDataType::uint64_datatype(), true),
+                        ColumnType::new(ConcreteDataType::int64_datatype(), true),
+                    ])),
+                ),
+                mfp: MapFilterProject::new(2)
+                    .map(vec![
+                        avg_expr,
+                        // TODO(discord9): optimize mfp so to remove indirect ref
+                        ScalarExpr::Column(2),
+                    ])
+                    .unwrap()
+                    .project(vec![3])
+                    .unwrap(),
+            },
+        };
+
+        let bundle = ctx.render_plan(expected).unwrap();
+
+        let output = get_output_handle(&mut ctx, bundle);
+        drop(ctx);
+        let expected = BTreeMap::from([(1, vec![(Row::new(vec![2u64.into()]), 1, 1)])]);
+        run_and_check(&mut state, &mut df, 1..2, expected, output);
+    }
+
     /// SELECT DISTINCT col FROM table
     ///
     /// table schema:
diff --git a/src/flow/src/compute/types.rs b/src/flow/src/compute/types.rs
index fa8c7315cb..f2276ba755 100644
--- a/src/flow/src/compute/types.rs
+++ b/src/flow/src/compute/types.rs
@@ -153,6 +153,14 @@ pub struct ErrCollector {
 }
 
 impl ErrCollector {
+    /// Drain and return every collected error, taking the lock synchronously.
+    ///
+    /// Blocking counterpart of `get_all`, usable from non-async code such as tests.
+    /// NOTE(review): `blocking_lock` looks like tokio's `Mutex::blocking_lock`, which panics
+    /// when called inside an async execution context -- confirm before calling from async code.
+    pub fn get_all_blocking(&self) -> Vec {
+        self.inner.blocking_lock().drain(..).collect_vec()
+    }
     pub async fn get_all(&self) -> Vec {
         self.inner.lock().await.drain(..).collect_vec()
     }
diff --git a/src/flow/src/expr/func.rs b/src/flow/src/expr/func.rs
index c177dcd571..12335fdf1f 100644
--- a/src/flow/src/expr/func.rs
+++ b/src/flow/src/expr/func.rs
@@ -375,6 +375,26 @@ impl BinaryFunc {
         )
     }
 
+    /// Get the `Add` specialization of the binary function for the given input type
+    pub fn add(input_type: ConcreteDataType) -> Result {
+        Self::specialization(GenericFn::Add, input_type)
+    }
+
+    /// Get the `Sub` specialization of the binary function for the given input type
+    pub fn sub(input_type: ConcreteDataType) -> Result {
+        Self::specialization(GenericFn::Sub, input_type)
+    }
+
+    /// Get the `Mul` specialization of the binary function for the given input type
+    pub fn mul(input_type: ConcreteDataType) -> Result {
+        Self::specialization(GenericFn::Mul, input_type)
+    }
+
+    /// Get the `Div` specialization of the binary function for the given input type
+    pub fn div(input_type: ConcreteDataType) -> Result {
+        Self::specialization(GenericFn::Div, input_type)
+    }
+
     /// Get the specialization of the binary function based on the generic function and the input type
     pub fn specialization(generic: GenericFn, input_type: ConcreteDataType) -> Result {
         let rule = SPECIALIZATION.get_or_init(|| {
diff --git a/src/flow/src/expr/relation/func.rs b/src/flow/src/expr/relation/func.rs
index 4506bf7a55..6aa53c80ca 100644
--- a/src/flow/src/expr/relation/func.rs
+++ b/src/flow/src/expr/relation/func.rs
@@ -136,27 +136,44 @@ impl AggregateFunc {
 
 /// Generate signature for each aggregate function
 macro_rules! generate_signature {
-    ($value:ident, { $($user_arm:tt)* },
-    [ $(
-        $auto_arm:ident=>($con_type:ident,$generic:ident)
-        ),*
-    ]) => {
+    ($value:ident,
+        { $($user_arm:tt)* },
+        [ $(
+            $auto_arm:ident=>($($arg:ident),*)
+            ),*
+        ]
+    ) => {
         match $value {
             $($user_arm)*,
             $(
-                Self::$auto_arm => Signature {
-                    input: smallvec![
-                        ConcreteDataType::$con_type(),
-                        ConcreteDataType::$con_type(),
-                    ],
-                    output: ConcreteDataType::$con_type(),
-                    generic_fn: GenericFn::$generic,
-                },
+                Self::$auto_arm => gen_one_signature!($($arg),*),
             )*
         }
     };
 }
 
+/// Generate one match arm: `(type, generic)` or `(input type, output type, generic)`
+macro_rules! gen_one_signature {
+    (
+        $con_type:ident, $generic:ident
+    ) => {
+        Signature {
+            input: smallvec![ConcreteDataType::$con_type(), ConcreteDataType::$con_type(),],
+            output: ConcreteDataType::$con_type(),
+            generic_fn: GenericFn::$generic,
+        }
+    };
+    (
+        $in_type:ident, $out_type:ident, $generic:ident
+    ) => {
+        Signature {
+            input: smallvec![ConcreteDataType::$in_type()],
+            output: ConcreteDataType::$out_type(),
+            generic_fn: GenericFn::$generic,
+        }
+    };
+}
+
 static SPECIALIZATION: OnceLock> =
     OnceLock::new();
 
@@ -223,6 +240,8 @@ impl AggregateFunc {
 
     /// all concrete datatypes with precision types will be returned with largest possible variant
     /// as a exception, count have a signature of `null -> i64`, but it's actually `anytype -> i64`
+    ///
+    /// TODO(discord9): fix signature for sum unsign -> u64 sum signed -> i64
     pub fn signature(&self) -> Signature {
         generate_signature!(self, {
             AggregateFunc::Count => Signature {
@@ -263,12 +282,12 @@ impl AggregateFunc {
             MinTime => (time_second_datatype, Min),
             MinDuration => (duration_second_datatype, Min),
             MinInterval => (interval_year_month_datatype, Min),
-            SumInt16 => (int16_datatype, Sum),
-            SumInt32 => (int32_datatype, Sum),
-            SumInt64 => (int64_datatype, Sum),
-            SumUInt16 => (uint16_datatype, Sum),
-            SumUInt32 => (uint32_datatype, Sum),
-            SumUInt64 => (uint64_datatype, Sum),
+            SumInt16 => (int16_datatype, int64_datatype, Sum),
+            SumInt32 => (int32_datatype, int64_datatype, Sum),
+            SumInt64 => (int64_datatype, int64_datatype, Sum),
+            SumUInt16 => (uint16_datatype, uint64_datatype, Sum),
+            SumUInt32 => (uint32_datatype, uint64_datatype, Sum),
+            SumUInt64 => (uint64_datatype, uint64_datatype, Sum),
             SumFloat32 => (float32_datatype, Sum),
             SumFloat64 => (float64_datatype, Sum),
             Any => (boolean_datatype, Any),
diff --git a/src/flow/src/plan.rs b/src/flow/src/plan.rs
index 5b28d8c7d5..1e83d13043 100644
--- a/src/flow/src/plan.rs
+++ b/src/flow/src/plan.rs
@@ -44,7 +44,7 @@ pub struct TypedPlan {
 impl TypedPlan {
     /// directly apply a mfp to the plan
     pub fn mfp(self, mfp: MapFilterProject) -> Result {
-        let new_type = self.typ.apply_mfp(&mfp, &[])?;
+        let new_type = self.typ.apply_mfp(&mfp)?;
         let plan = match self.plan {
             Plan::Mfp {
                 input,
@@ -68,14 +68,14 @@ impl TypedPlan {
     pub fn projection(self, exprs: Vec) -> Result {
         let input_arity = self.typ.column_types.len();
         let output_arity = exprs.len();
-        let (exprs, expr_typs): (Vec<_>, Vec<_>) = exprs
+        let (exprs, _expr_typs): (Vec<_>, Vec<_>) = exprs
             .into_iter()
             .map(|TypedExpr { expr, typ }| (expr, typ))
             .unzip();
         let mfp = MapFilterProject::new(input_arity)
             .map(exprs)?
             .project(input_arity..input_arity + output_arity)?;
-        let out_typ = self.typ.apply_mfp(&mfp, &expr_typs)?;
+        let out_typ = self.typ.apply_mfp(&mfp)?;
         // special case for mfp to compose when the plan is already mfp
         let plan = match self.plan {
             Plan::Mfp {
diff --git a/src/flow/src/repr/relation.rs b/src/flow/src/repr/relation.rs
index b36dfacd44..9494a013bb 100644
--- a/src/flow/src/repr/relation.rs
+++ b/src/flow/src/repr/relation.rs
@@ -111,13 +111,15 @@ impl RelationType {
     /// then new key=`[1]`, new time index=`[0]`
     ///
     /// note that this function will remove empty keys like key=`[]` will be removed
-    pub fn apply_mfp(&self, mfp: &MapFilterProject, expr_typs: &[ColumnType]) -> Result {
-        let all_types = self
-            .column_types
-            .iter()
-            .chain(expr_typs.iter())
-            .cloned()
-            .collect_vec();
+    pub fn apply_mfp(&self, mfp: &MapFilterProject) -> Result {
+        let mut all_types = self.column_types.clone();
+        for expr in &mfp.expressions {
+            // a map expression may refer to the output of an earlier map expression,
+            // so type it against every column known so far, not only the input columns
+            let expr_typ = expr.typ(&all_types)?;
+            all_types.push(expr_typ);
+        }
+        let all_types = all_types;
         let mfp_out_types = mfp
             .projection
             .iter()
@@ -131,6 +133,7 @@ impl RelationType {
                 })
             })
             .try_collect()?;
+
         let old_to_new_col = BTreeMap::from_iter(
             mfp.projection
                 .clone()
diff --git a/src/flow/src/transform/aggr.rs b/src/flow/src/transform/aggr.rs
index c287e98459..3f3bf3fb7c 100644
--- a/src/flow/src/transform/aggr.rs
+++ b/src/flow/src/transform/aggr.rs
@@ -12,13 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use std::collections::HashMap;
+use std::collections::{BTreeMap, HashMap};
 
 use common_decimal::Decimal128;
 use common_time::{Date, Timestamp};
 use datatypes::arrow::compute::kernels::window;
 use datatypes::arrow::ipc::Binary;
-use datatypes::data_type::ConcreteDataType as CDT;
+use datatypes::data_type::{ConcreteDataType as CDT, DataType};
 use datatypes::value::Value;
 use hydroflow::futures::future::Map;
 use itertools::Itertools;
@@ -83,14 +83,18 @@ impl TypedExpr {
 }
 
 impl AggregateExpr {
+    /// Convert a list of `Measure` into Flow's `AggregateExpr`s
+    ///
+    /// Returns both the `AggregateExpr`s and a `MapFilterProject` that computes the final output of every measure from the raw aggregate results
     fn from_substrait_agg_measures(
         ctx: &mut FlownodeContext,
         measures: &[Measure],
         typ: &RelationType,
         extensions: &FunctionExtensions,
-    ) -> Result, Error> {
+    ) -> Result<(Vec, MapFilterProject), Error> {
         let _ = ctx;
-        let mut aggr_exprs = vec![];
+        let mut all_aggr_exprs = vec![];
+        let mut post_maps = vec![];
 
         for m in measures {
             let filter = &m
@@ -99,7 +103,7 @@ impl AggregateExpr {
                 .map(|fil| TypedExpr::from_substrait_rex(fil, typ, extensions))
                 .transpose()?;
 
-            let agg_func = match &m.measure {
+            let (aggr_expr, post_mfp) = match &m.measure {
                 Some(f) => {
                     let distinct = match f.invocation {
                         _ if f.invocation == AggregationInvocation::Distinct as i32 => true,
@@ -113,12 +117,30 @@ impl AggregateExpr {
                 }
                 None => not_impl_err!("Aggregate without aggregate function is not supported"),
             }?;
-            aggr_exprs.push(agg_func);
+            // shift the column references of this measure's post expression so they point at
+            // its aggregates inside the combined list; one mfp is then built at the end
+            let mut post_map = post_mfp.unwrap_or(ScalarExpr::Column(0));
+            let cur_arity = all_aggr_exprs.len();
+            let remap = (0..aggr_expr.len()).map(|i| i + cur_arity).collect_vec();
+            post_map.permute(&remap)?;
+
+            all_aggr_exprs.extend(aggr_expr);
+            post_maps.push(post_map);
         }
-        Ok(aggr_exprs)
+
+        let input_arity = all_aggr_exprs.len();
+        let aggr_arity = post_maps.len();
+        let post_mfp_final = MapFilterProject::new(input_arity)
+            .map(post_maps)?
+            .project(input_arity..input_arity + aggr_arity)?;
+
+        Ok((all_aggr_exprs, post_mfp_final))
     }
 
     /// Convert AggregateFunction into Flow's AggregateExpr
+    ///
+    /// The returned value is a tuple of the `AggregateExpr`s and an optional `ScalarExpr`; if present, it computes the final output of the aggregate function,
+    /// since aggregate functions like `avg` need to be transformed into `sum(x)/cast(count(x) as x_type)`
     pub fn from_substrait_agg_func(
         f: &proto::AggregateFunction,
         input_schema: &RelationType,
@@ -126,7 +148,7 @@ impl AggregateExpr {
         filter: &Option,
         order_by: &Option>,
         distinct: bool,
-    ) -> Result {
+    ) -> Result<(Vec, Option), Error> {
         // TODO(discord9): impl filter
         let _ = filter;
         let _ = order_by;
@@ -141,26 +163,74 @@ impl AggregateExpr {
             args.push(arg_expr);
         }
 
+        if args.len() > 1 {
+            return not_impl_err!("Aggregated function with multiple arguments is not supported");
+        }
+
         let arg = if let Some(first) = args.first() {
             first
         } else {
             return not_impl_err!("Aggregated function without arguments is not supported");
         };
 
-        let func = match extensions.get(&f.function_reference) {
+        let fn_name = extensions
+            .get(&f.function_reference)
+            .cloned()
+            .map(|s| s.to_lowercase());
+
+        match fn_name.as_deref() {
+            Some(Self::AVG_NAME) => AggregateExpr::from_avg_aggr_func(arg),
             Some(function_name) => {
-                AggregateFunc::from_str_and_type(function_name, Some(arg.typ.scalar_type.clone()))
+                let func = AggregateFunc::from_str_and_type(
+                    function_name,
+                    Some(arg.typ.scalar_type.clone()),
+                )?;
+                let exprs = vec![AggregateExpr {
+                    func,
+                    expr: arg.expr.clone(),
+                    distinct,
+                }];
+                let ret_mfp = None;
+                Ok((exprs, ret_mfp))
             }
             None => not_impl_err!(
                 "Aggregated function not found: function anchor = {:?}",
                 f.function_reference
             ),
-        }?;
-        Ok(AggregateExpr {
-            func,
+        }
+    }
+    const AVG_NAME: &str = "avg";
+    /// Convert `avg` into `sum(x)/cast(count(x) as sum_type)`, yielding null when the count is zero. NOTE(review): `distinct` is not forwarded, so `avg(DISTINCT x)` is computed like `avg(x)`
+    fn from_avg_aggr_func(
+        arg: &TypedExpr,
+    ) -> Result<(Vec, Option), Error> {
+        let arg_type = arg.typ.scalar_type.clone();
+        let sum = AggregateExpr {
+            func: AggregateFunc::from_str_and_type("sum", Some(arg_type.clone()))?,
             expr: arg.expr.clone(),
-            distinct,
-        })
+            distinct: false,
+        };
+        let sum_out_type = sum.func.signature().output.clone();
+        let count = AggregateExpr {
+            func: AggregateFunc::Count,
+            expr: arg.expr.clone(),
+            distinct: false,
+        };
+        let count_out_type = count.func.signature().output.clone();
+        let avg_output = ScalarExpr::Column(0).call_binary(
+            ScalarExpr::Column(1).call_unary(UnaryFunc::Cast(sum_out_type.clone())),
+            BinaryFunc::div(sum_out_type.clone())?,
+        );
+        // make sure we wouldn't divide by zero
+        let zero = ScalarExpr::literal(count_out_type.default_value(), count_out_type.clone());
+        let non_zero = ScalarExpr::If {
+            cond: Box::new(ScalarExpr::Column(1).call_binary(zero.clone(), BinaryFunc::NotEq)),
+            then: Box::new(avg_output),
+            els: Box::new(ScalarExpr::literal(Value::Null, sum_out_type.clone())),
+        };
+        let ret_aggr_exprs = vec![sum, count];
+        let ret_mfp = Some(non_zero);
+        Ok((ret_aggr_exprs, ret_mfp))
     }
 }
 
@@ -217,6 +287,10 @@ impl KeyValPlan {
 
 impl TypedPlan {
     /// Convert AggregateRel into Flow's TypedPlan
+    ///
+    /// The output of aggr plan is:
+    ///
+    /// the group key columns first, followed by the aggregate result columns
     pub fn from_substrait_agg_rel(
         ctx: &mut FlownodeContext,
         agg: &proto::AggregateRel,
@@ -231,7 +305,7 @@ impl TypedPlan {
         let group_exprs =
             TypedExpr::from_substrait_agg_grouping(ctx, &agg.groupings, &input.typ, extensions)?;
 
-        let mut aggr_exprs =
+        let (mut aggr_exprs, post_mfp) =
             AggregateExpr::from_substrait_agg_measures(ctx, &agg.measures, &input.typ, extensions)?;
 
         let key_val_plan = KeyValPlan::from_substrait_gen_key_val_plan(
@@ -253,7 +327,11 @@ impl TypedPlan {
                 ));
             }
             // TODO(discord9): try best to get time
-            RelationType::new(output_types).with_key((0..group_exprs.len()).collect_vec())
+            if group_exprs.is_empty() {
+                RelationType::new(output_types)
+            } else {
+                RelationType::new(output_types).with_key((0..group_exprs.len()).collect_vec())
+            }
         };
 
         // copy aggr_exprs to full_aggrs, and split them into simple_aggrs and distinct_aggrs
@@ -289,10 +367,40 @@ impl TypedPlan {
             key_val_plan,
             reduce_plan: ReducePlan::Accumulable(accum_plan),
         };
-        Ok(TypedPlan {
-            typ: output_type,
-            plan,
-        })
+        // FIX(discord9): deal with key first
+        if post_mfp.is_identity() {
+            Ok(TypedPlan {
+                typ: output_type,
+                plan,
+            })
+        } else {
+            // extend post_mfp so that the key columns pass through unchanged
+            let input = TypedPlan {
+                typ: output_type.clone(),
+                plan,
+            };
+            let key_arity = group_exprs.len();
+            let mut post_mfp = post_mfp;
+            let val_arity = post_mfp.input_arity;
+            // offset post_mfp's col ref by `key_arity`
+            let shuffle = BTreeMap::from_iter((0..val_arity).map(|v| (v, v + key_arity)));
+            let new_arity = key_arity + val_arity;
+            post_mfp.permute(shuffle, new_arity)?;
+            // add key projection to post mfp
+            let (m, f, p) = post_mfp.into_map_filter_project();
+            let p = (0..key_arity).chain(p).collect_vec();
+            let post_mfp = MapFilterProject::new(new_arity)
+                .map(m)?
+                .filter(f)?
+                .project(p)?;
+            Ok(TypedPlan {
+                typ: output_type.apply_mfp(&post_mfp)?,
+                plan: Plan::Mfp {
+                    input: Box::new(input),
+                    mfp: post_mfp,
+                },
+            })
+        }
     }
 }
 
@@ -306,6 +414,182 @@ mod test {
     use crate::repr::{self, ColumnType, RelationType};
     use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait};
 
+    #[tokio::test]
+    async fn test_avg_group_by() {
+        let engine = create_test_query_engine();
+        let sql = "SELECT avg(number), number FROM numbers GROUP BY number";
+        let plan = sql_to_substrait(engine.clone(), sql).await;
+
+        let mut ctx = create_test_ctx();
+        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
+
+        let aggr_exprs = vec![
+            AggregateExpr {
+                func: AggregateFunc::SumUInt32,
+                expr: ScalarExpr::Column(0),
+                distinct: false,
+            },
+            AggregateExpr {
+                func: AggregateFunc::Count,
+                expr: ScalarExpr::Column(0),
+                distinct: false,
+            },
+        ];
+        let avg_expr = ScalarExpr::If {
+            cond: Box::new(ScalarExpr::Column(2).call_binary(
+                ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
+                BinaryFunc::NotEq,
+            )),
+            then: Box::new(ScalarExpr::Column(1).call_binary(
+                ScalarExpr::Column(2).call_unary(UnaryFunc::Cast(CDT::uint64_datatype())),
+                BinaryFunc::DivUInt64,
+            )),
+            els: Box::new(ScalarExpr::Literal(Value::Null, CDT::uint64_datatype())),
+        };
+        let expected = TypedPlan {
+            typ: RelationType::new(vec![
+                ColumnType::new(CDT::uint64_datatype(), true), // avg(number) -> u64
+                ColumnType::new(CDT::uint32_datatype(), false), // number
+            ]),
+            plan: Plan::Mfp {
+                input: Box::new(
+                    Plan::Reduce {
+                        input: Box::new(
+                            Plan::Get {
+                                id: crate::expr::Id::Global(GlobalId::User(0)),
+                            }
+                            .with_types(RelationType::new(vec![
+                                ColumnType::new(ConcreteDataType::uint32_datatype(), false),
+                            ])),
+                        ),
+                        key_val_plan: KeyValPlan {
+                            key_plan: MapFilterProject::new(1)
+                                .map(vec![ScalarExpr::Column(0)])
+                                .unwrap()
+                                .project(vec![1])
+                                .unwrap()
+                                .into_safe(),
+                            val_plan: MapFilterProject::new(1)
+                                .project(vec![0])
+                                .unwrap()
+                                .into_safe(),
+                        },
+                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
+                            full_aggrs: aggr_exprs.clone(),
+                            simple_aggrs: vec![
+                                AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
+                                AggrWithIndex::new(aggr_exprs[1].clone(), 0, 1),
+                            ],
+                            distinct_aggrs: vec![],
+                        }),
+                    }
+                    .with_types(
+                        RelationType::new(vec![
+                            ColumnType::new(ConcreteDataType::uint32_datatype(), false), // key: number
+                            ColumnType::new(ConcreteDataType::uint64_datatype(), true), // sum
+                            ColumnType::new(ConcreteDataType::int64_datatype(), true), // count
+                        ])
+                        .with_key(vec![0]),
+                    ),
+                ),
+                mfp: MapFilterProject::new(3)
+                    .map(vec![
+                        avg_expr, // col 3
+                        // TODO(discord9): optimize mfp so to remove indirect ref
+                        ScalarExpr::Column(3), // col 4
+                        ScalarExpr::Column(0), // col 5
+                    ])
+                    .unwrap()
+                    .project(vec![4, 5])
+                    .unwrap(),
+            },
+        };
+        assert_eq!(flow_plan.unwrap(), expected);
+    }
+
+    #[tokio::test]
+    async fn test_avg() {
+        let engine = create_test_query_engine();
+        let sql = "SELECT avg(number) FROM numbers";
+        let plan = sql_to_substrait(engine.clone(), sql).await;
+
+        let mut ctx = create_test_ctx();
+        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
+
+        let aggr_exprs = vec![
+            AggregateExpr {
+                func: AggregateFunc::SumUInt32,
+                expr: ScalarExpr::Column(0),
+                distinct: false,
+            },
+            AggregateExpr {
+                func: AggregateFunc::Count,
+                expr: ScalarExpr::Column(0),
+                distinct: false,
+            },
+        ];
+        let avg_expr = ScalarExpr::If {
+            cond: Box::new(ScalarExpr::Column(1).call_binary(
+                ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
+                BinaryFunc::NotEq,
+            )),
+            then: Box::new(ScalarExpr::Column(0).call_binary(
+                ScalarExpr::Column(1).call_unary(UnaryFunc::Cast(CDT::uint64_datatype())),
+                BinaryFunc::DivUInt64,
+            )),
+            els: Box::new(ScalarExpr::Literal(Value::Null, CDT::uint64_datatype())),
+        };
+        let expected = TypedPlan {
+            typ: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]),
+            plan: Plan::Mfp {
+                input: Box::new(
+                    Plan::Reduce {
+                        input: Box::new(
+                            Plan::Get {
+                                id: crate::expr::Id::Global(GlobalId::User(0)),
+                            }
+                            .with_types(RelationType::new(vec![
+                                ColumnType::new(ConcreteDataType::uint32_datatype(), false),
+                            ])),
+                        ),
+                        key_val_plan: KeyValPlan {
+                            key_plan: MapFilterProject::new(1)
+                                .project(vec![])
+                                .unwrap()
+                                .into_safe(),
+                            val_plan: MapFilterProject::new(1)
+                                .project(vec![0])
+                                .unwrap()
+                                .into_safe(),
+                        },
+                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
+                            full_aggrs: aggr_exprs.clone(),
+                            simple_aggrs: vec![
+                                AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
+                                AggrWithIndex::new(aggr_exprs[1].clone(), 0, 1),
+                            ],
+                            distinct_aggrs: vec![],
+                        }),
+                    }
+                    .with_types(RelationType::new(vec![
+                        ColumnType::new(ConcreteDataType::uint64_datatype(), true),
+                        ColumnType::new(ConcreteDataType::int64_datatype(), true),
+                    ])),
+                ),
+                mfp: MapFilterProject::new(2)
+                    .map(vec![
+                        avg_expr,
+                        // TODO(discord9): optimize mfp so to remove indirect ref
+                        ScalarExpr::Column(2),
+                    ])
+                    .unwrap()
+                    .project(vec![3])
+                    .unwrap(),
+            },
+        };
+        assert_eq!(flow_plan.unwrap(), expected);
+    }
+
     #[tokio::test]
     async fn test_sum() {
         let engine = create_test_query_engine();
@@ -315,7 +599,7 @@ mod test {
         let mut ctx = create_test_ctx();
         let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
         let typ = RelationType::new(vec![ColumnType::new(
-            ConcreteDataType::uint32_datatype(),
+            ConcreteDataType::uint64_datatype(),
             true,
         )]);
         let aggr_expr = AggregateExpr {
@@ -324,7 +608,7 @@ mod test {
             distinct: false,
         };
         let expected = TypedPlan {
-            typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]),
+            typ: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]),
             plan: Plan::Mfp {
                 input: Box::new(
                     Plan::Reduce {
@@ -355,9 +639,9 @@ mod test {
                     .with_types(typ),
                 ),
                 mfp: MapFilterProject::new(1)
-                    .map(vec![ScalarExpr::Column(0)])
+                    .map(vec![ScalarExpr::Column(0), ScalarExpr::Column(1)])
                     .unwrap()
-                    .project(vec![1])
+                    .project(vec![2])
                     .unwrap(),
             },
         };
@@ -380,7 +664,7 @@ mod test {
         };
         let expected = TypedPlan {
             typ: RelationType::new(vec![
-                ColumnType::new(CDT::uint32_datatype(), true), // col sum(number)
+                ColumnType::new(CDT::uint64_datatype(), true), // col sum(number)
                 ColumnType::new(CDT::uint32_datatype(), false), // col number
             ]),
             plan: Plan::Mfp {
@@ -415,15 +699,19 @@ mod test {
                     .with_types(
                         RelationType::new(vec![
                             ColumnType::new(CDT::uint32_datatype(), false), // col number
-                            ColumnType::new(CDT::uint32_datatype(), true), // col sum(number)
+                            ColumnType::new(CDT::uint64_datatype(), true), // col sum(number)
                         ])
                         .with_key(vec![0]),
                     ),
                 ),
                 mfp: MapFilterProject::new(2)
-                    .map(vec![ScalarExpr::Column(1), ScalarExpr::Column(0)])
+                    .map(vec![
+                        ScalarExpr::Column(1),
+                        ScalarExpr::Column(2),
+                        ScalarExpr::Column(0),
+                    ])
                     .unwrap()
-                    .project(vec![2, 3])
+                    .project(vec![3, 4])
                     .unwrap(),
             },
         };
@@ -446,7 +734,7 @@ mod test {
             distinct: false,
         };
         let expected = TypedPlan {
-            typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]),
+            typ: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]),
             plan: Plan::Mfp {
                 input: Box::new(
                     Plan::Reduce {
@@ -478,14 +766,14 @@ mod test {
                         }),
                     }
                     .with_types(RelationType::new(vec![ColumnType::new(
-                        CDT::uint32_datatype(),
+                        CDT::uint64_datatype(),
                         true,
                     )])),
                 ),
                 mfp: MapFilterProject::new(1)
-                    .map(vec![ScalarExpr::Column(0)])
+                    .map(vec![ScalarExpr::Column(0), ScalarExpr::Column(1)])
                     .unwrap()
-                    .project(vec![1])
+                    .project(vec![2])
                     .unwrap(),
             },
         };