feat(flow): avg func rewrite to sum/count (#3955)

* feat(WIP): parse avg

* feat: RelationType::apply_mfp no need expr typs

* feat: avg&tests

* fix(WIP): avg eval

* fix: sum ret correct type

* chore: typos
This commit is contained in:
discord9
2024-05-16 18:03:56 +08:00
committed by GitHub
parent 9f4a6c6fe2
commit 93f178f3ad
8 changed files with 495 additions and 64 deletions

View File

@@ -238,6 +238,12 @@ mod test {
for now in time_range {
state.set_current_ts(now);
state.run_available_with_schedule(df);
if !state.get_err_collector().is_empty() {
panic!(
"Errors occur: {:?}",
state.get_err_collector().get_all_blocking()
)
}
assert!(state.get_err_collector().is_empty());
if let Some(expected) = expected.get(&now) {
assert_eq!(*output.borrow(), *expected, "at ts={}", now);

View File

@@ -729,15 +729,113 @@ mod test {
use std::cell::RefCell;
use std::rc::Rc;
use datatypes::data_type::ConcreteDataType;
use datatypes::data_type::{ConcreteDataType, ConcreteDataType as CDT};
use hydroflow::scheduled::graph::Hydroflow;
use super::*;
use crate::compute::render::test::{get_output_handle, harness_test_ctx, run_and_check};
use crate::compute::state::DataflowState;
use crate::expr::{self, AggregateFunc, BinaryFunc, GlobalId, MapFilterProject};
use crate::expr::{self, AggregateFunc, BinaryFunc, GlobalId, MapFilterProject, UnaryFunc};
use crate::repr::{ColumnType, RelationType};
/// Evaluate `select avg(number) from number;` end to end.
///
/// The plan under test is a reduce computing `sum(number)` and `count(number)`,
/// followed by an mfp that divides the sum by the count (guarded against a zero count).
#[test]
fn test_avg_eval() {
    let mut df = Hydroflow::new();
    let mut state = DataflowState::default();
    let mut ctx = harness_test_ctx(&mut df, &mut state);

    // six u32 rows at ts=1 with diff=1: sum = 12, count = 6, so avg = 2
    let rows = vec![
        (Row::new(vec![1u32.into()]), 1, 1),
        (Row::new(vec![2u32.into()]), 1, 1),
        (Row::new(vec![3u32.into()]), 1, 1),
        (Row::new(vec![1u32.into()]), 1, 1),
        (Row::new(vec![2u32.into()]), 1, 1),
        (Row::new(vec![3u32.into()]), 1, 1),
    ];
    let collection = ctx.render_constant(rows);
    ctx.insert_global(GlobalId::User(1), collection);

    let aggr_exprs = vec![
        AggregateExpr {
            func: AggregateFunc::SumUInt32,
            expr: ScalarExpr::Column(0),
            distinct: false,
        },
        AggregateExpr {
            func: AggregateFunc::Count,
            expr: ScalarExpr::Column(0),
            distinct: false,
        },
    ];

    // `count != 0` guard; the literal must be an i64 value to match both its
    // declared type and the type of the count column it is compared with
    let count_is_nonzero = ScalarExpr::Column(1).call_binary(
        ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
        BinaryFunc::NotEq,
    );
    // `sum / cast(count as u64)`
    let sum_div_count = ScalarExpr::Column(0).call_binary(
        ScalarExpr::Column(1).call_unary(UnaryFunc::Cast(CDT::uint64_datatype())),
        BinaryFunc::DivUInt64,
    );
    let avg_expr = ScalarExpr::If {
        cond: Box::new(count_is_nonzero),
        then: Box::new(sum_div_count),
        els: Box::new(ScalarExpr::Literal(Value::Null, CDT::uint64_datatype())),
    };

    let plan = TypedPlan {
        typ: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]),
        plan: Plan::Mfp {
            input: Box::new(
                Plan::Reduce {
                    input: Box::new(
                        Plan::Get {
                            id: crate::expr::Id::Global(GlobalId::User(1)),
                        }
                        // the input rows above hold u32 values
                        .with_types(RelationType::new(vec![
                            ColumnType::new(ConcreteDataType::uint32_datatype(), false),
                        ])),
                    ),
                    key_val_plan: KeyValPlan {
                        key_plan: MapFilterProject::new(1)
                            .project(vec![])
                            .unwrap()
                            .into_safe(),
                        val_plan: MapFilterProject::new(1)
                            .project(vec![0])
                            .unwrap()
                            .into_safe(),
                    },
                    reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
                        full_aggrs: aggr_exprs.clone(),
                        simple_aggrs: vec![
                            AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
                            AggrWithIndex::new(aggr_exprs[1].clone(), 0, 1),
                        ],
                        distinct_aggrs: vec![],
                    }),
                }
                // sum is consumed by `DivUInt64` below, so it is a u64 column; count is i64
                .with_types(RelationType::new(vec![
                    ColumnType::new(ConcreteDataType::uint64_datatype(), true),
                    ColumnType::new(ConcreteDataType::int64_datatype(), true),
                ])),
            ),
            mfp: MapFilterProject::new(2)
                .map(vec![
                    avg_expr, // col 2
                    // TODO(discord9): optimize mfp so to remove indirect ref
                    ScalarExpr::Column(2), // col 3
                ])
                .unwrap()
                .project(vec![3])
                .unwrap(),
        },
    };

    let bundle = ctx.render_plan(plan).unwrap();
    let output = get_output_handle(&mut ctx, bundle);
    // release the borrows on `df` and `state` before running the dataflow
    drop(ctx);

    let expected = BTreeMap::from([(1, vec![(Row::new(vec![2u64.into()]), 1, 1)])]);
    run_and_check(&mut state, &mut df, 1..2, expected, output);
}
/// SELECT DISTINCT col FROM table
///
/// table schema:

View File

@@ -153,6 +153,9 @@ pub struct ErrCollector {
}
impl ErrCollector {
/// Drain and return every error collected so far, blocking the current thread
/// until the lock on the inner buffer is acquired.
///
/// The collector is left empty afterwards.
///
/// NOTE(review): `blocking_lock` looks like the tokio mutex API, which panics when
/// called from within an async execution context — confirm before calling this
/// outside of synchronous code such as tests.
pub fn get_all_blocking(&self) -> Vec<EvalError> {
    let mut errs = self.inner.blocking_lock();
    errs.drain(..).collect_vec()
}
/// Drain and return every error collected so far, awaiting the lock on the
/// inner buffer.
///
/// The collector is left empty afterwards.
pub async fn get_all(&self) -> Vec<EvalError> {
    let mut errs = self.inner.lock().await;
    errs.drain(..).collect_vec()
}

View File

@@ -375,6 +375,22 @@ impl BinaryFunc {
)
}
/// Get the `Add` specialization of the binary function for `input_type`.
///
/// Shorthand for `Self::specialization(GenericFn::Add, input_type)`; any error
/// from that lookup is returned unchanged.
pub fn add(input_type: ConcreteDataType) -> Result<Self, Error> {
Self::specialization(GenericFn::Add, input_type)
}
/// Get the `Sub` specialization of the binary function for `input_type`.
///
/// Shorthand for `Self::specialization(GenericFn::Sub, input_type)`; any error
/// from that lookup is returned unchanged.
pub fn sub(input_type: ConcreteDataType) -> Result<Self, Error> {
Self::specialization(GenericFn::Sub, input_type)
}
/// Get the `Mul` specialization of the binary function for `input_type`.
///
/// Shorthand for `Self::specialization(GenericFn::Mul, input_type)`; any error
/// from that lookup is returned unchanged.
pub fn mul(input_type: ConcreteDataType) -> Result<Self, Error> {
Self::specialization(GenericFn::Mul, input_type)
}
/// Get the `Div` specialization of the binary function for `input_type`.
///
/// Shorthand for `Self::specialization(GenericFn::Div, input_type)`; any error
/// from that lookup is returned unchanged.
pub fn div(input_type: ConcreteDataType) -> Result<Self, Error> {
Self::specialization(GenericFn::Div, input_type)
}
/// Get the specialization of the binary function based on the generic function and the input type
pub fn specialization(generic: GenericFn, input_type: ConcreteDataType) -> Result<Self, Error> {
let rule = SPECIALIZATION.get_or_init(|| {

View File

@@ -136,27 +136,44 @@ impl AggregateFunc {
/// Generate signature for each aggregate function
macro_rules! generate_signature {
($value:ident, { $($user_arm:tt)* },
[ $(
$auto_arm:ident=>($con_type:ident,$generic:ident)
),*
]) => {
($value:ident,
{ $($user_arm:tt)* },
[ $(
$auto_arm:ident=>($($arg:ident),*)
),*
]
) => {
match $value {
$($user_arm)*,
$(
Self::$auto_arm => Signature {
input: smallvec![
ConcreteDataType::$con_type(),
ConcreteDataType::$con_type(),
],
output: ConcreteDataType::$con_type(),
generic_fn: GenericFn::$generic,
},
Self::$auto_arm => gen_one_siginature!($($arg),*),
)*
}
};
}
/// Expand the argument list of one auto-generated arm into a `Signature` value.
///
/// Two argument forms are accepted:
/// - `(type, generic)`: two inputs and the output all use `ConcreteDataType::type()`
/// - `(in_type, out_type, generic)`: a single input of `in_type` and an output of `out_type`
macro_rules! gen_one_siginature {
    // same type for both inputs and the output
    ($ty:ident, $func:ident) => {
        Signature {
            input: smallvec![ConcreteDataType::$ty(), ConcreteDataType::$ty()],
            output: ConcreteDataType::$ty(),
            generic_fn: GenericFn::$func,
        }
    };
    // one input whose type differs from the output type
    ($arg_ty:ident, $ret_ty:ident, $func:ident) => {
        Signature {
            input: smallvec![ConcreteDataType::$arg_ty()],
            output: ConcreteDataType::$ret_ty(),
            generic_fn: GenericFn::$func,
        }
    };
}
static SPECIALIZATION: OnceLock<HashMap<(GenericFn, ConcreteDataType), AggregateFunc>> =
OnceLock::new();
@@ -223,6 +240,8 @@ impl AggregateFunc {
/// all concrete datatypes with precision types will be returned with largest possible variant
/// as an exception, count has a signature of `null -> i64`, but it's actually `anytype -> i64`
///
/// TODO(discord9): fix signature for sum: unsigned -> u64, signed -> i64
pub fn signature(&self) -> Signature {
generate_signature!(self, {
AggregateFunc::Count => Signature {
@@ -263,12 +282,12 @@ impl AggregateFunc {
MinTime => (time_second_datatype, Min),
MinDuration => (duration_second_datatype, Min),
MinInterval => (interval_year_month_datatype, Min),
SumInt16 => (int16_datatype, Sum),
SumInt32 => (int32_datatype, Sum),
SumInt64 => (int64_datatype, Sum),
SumUInt16 => (uint16_datatype, Sum),
SumUInt32 => (uint32_datatype, Sum),
SumUInt64 => (uint64_datatype, Sum),
SumInt16 => (int16_datatype, int64_datatype, Sum),
SumInt32 => (int32_datatype, int64_datatype, Sum),
SumInt64 => (int64_datatype, int64_datatype, Sum),
SumUInt16 => (uint16_datatype, uint64_datatype, Sum),
SumUInt32 => (uint32_datatype, uint64_datatype, Sum),
SumUInt64 => (uint64_datatype, uint64_datatype, Sum),
SumFloat32 => (float32_datatype, Sum),
SumFloat64 => (float64_datatype, Sum),
Any => (boolean_datatype, Any),

View File

@@ -44,7 +44,7 @@ pub struct TypedPlan {
impl TypedPlan {
/// directly apply a mfp to the plan
pub fn mfp(self, mfp: MapFilterProject) -> Result<Self, Error> {
let new_type = self.typ.apply_mfp(&mfp, &[])?;
let new_type = self.typ.apply_mfp(&mfp)?;
let plan = match self.plan {
Plan::Mfp {
input,
@@ -68,14 +68,14 @@ impl TypedPlan {
pub fn projection(self, exprs: Vec<TypedExpr>) -> Result<Self, Error> {
let input_arity = self.typ.column_types.len();
let output_arity = exprs.len();
let (exprs, expr_typs): (Vec<_>, Vec<_>) = exprs
let (exprs, _expr_typs): (Vec<_>, Vec<_>) = exprs
.into_iter()
.map(|TypedExpr { expr, typ }| (expr, typ))
.unzip();
let mfp = MapFilterProject::new(input_arity)
.map(exprs)?
.project(input_arity..input_arity + output_arity)?;
let out_typ = self.typ.apply_mfp(&mfp, &expr_typs)?;
let out_typ = self.typ.apply_mfp(&mfp)?;
// special case for mfp to compose when the plan is already mfp
let plan = match self.plan {
Plan::Mfp {

View File

@@ -111,13 +111,13 @@ impl RelationType {
/// then new key=`[1]`, new time index=`[0]`
///
/// note that this function will remove empty keys, i.e. key=`[]` will be removed
pub fn apply_mfp(&self, mfp: &MapFilterProject, expr_typs: &[ColumnType]) -> Result<Self> {
let all_types = self
.column_types
.iter()
.chain(expr_typs.iter())
.cloned()
.collect_vec();
pub fn apply_mfp(&self, mfp: &MapFilterProject) -> Result<Self> {
let mut all_types = self.column_types.clone();
for expr in &mfp.expressions {
let expr_typ = expr.typ(&self.column_types)?;
all_types.push(expr_typ);
}
let all_types = all_types;
let mfp_out_types = mfp
.projection
.iter()
@@ -131,6 +131,7 @@ impl RelationType {
})
})
.try_collect()?;
let old_to_new_col = BTreeMap::from_iter(
mfp.projection
.clone()

View File

@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::collections::{BTreeMap, HashMap};
use common_decimal::Decimal128;
use common_time::{Date, Timestamp};
use datatypes::arrow::compute::kernels::window;
use datatypes::arrow::ipc::Binary;
use datatypes::data_type::ConcreteDataType as CDT;
use datatypes::data_type::{ConcreteDataType as CDT, DataType};
use datatypes::value::Value;
use hydroflow::futures::future::Map;
use itertools::Itertools;
@@ -83,14 +83,18 @@ impl TypedExpr {
}
impl AggregateExpr {
/// Convert list of `Measure` into Flow's AggregateExpr
///
/// Return both the AggregateExpr and a MapFilterProject that is the final output of the aggregate function
fn from_substrait_agg_measures(
ctx: &mut FlownodeContext,
measures: &[Measure],
typ: &RelationType,
extensions: &FunctionExtensions,
) -> Result<Vec<AggregateExpr>, Error> {
) -> Result<(Vec<AggregateExpr>, MapFilterProject), Error> {
let _ = ctx;
let mut aggr_exprs = vec![];
let mut all_aggr_exprs = vec![];
let mut post_maps = vec![];
for m in measures {
let filter = &m
@@ -99,7 +103,7 @@ impl AggregateExpr {
.map(|fil| TypedExpr::from_substrait_rex(fil, typ, extensions))
.transpose()?;
let agg_func = match &m.measure {
let (aggr_expr, post_mfp) = match &m.measure {
Some(f) => {
let distinct = match f.invocation {
_ if f.invocation == AggregationInvocation::Distinct as i32 => true,
@@ -113,12 +117,30 @@ impl AggregateExpr {
}
None => not_impl_err!("Aggregate without aggregate function is not supported"),
}?;
aggr_exprs.push(agg_func);
// shift the column indices referenced by this measure's post map by the number of
// aggregate exprs collected so far, so all post maps can be combined into one mfp at the end
let mut post_map = post_mfp.unwrap_or(ScalarExpr::Column(0));
let cur_arity = all_aggr_exprs.len();
let remap = (0..aggr_expr.len()).map(|i| i + cur_arity).collect_vec();
post_map.permute(&remap)?;
all_aggr_exprs.extend(aggr_expr);
post_maps.push(post_map);
}
Ok(aggr_exprs)
let input_arity = all_aggr_exprs.len();
let aggr_arity = post_maps.len();
let post_mfp_final = MapFilterProject::new(all_aggr_exprs.len())
.map(post_maps)?
.project(input_arity..input_arity + aggr_arity)?;
Ok((all_aggr_exprs, post_mfp_final))
}
/// Convert AggregateFunction into Flow's AggregateExpr
///
/// the returned value is a tuple of AggregateExprs and an optional ScalarExpr that, if it exists, is the final output of the aggregate function,
/// since aggr functions like `avg` need to be transformed to `sum(x)/cast(count(x) as x_type)`
pub fn from_substrait_agg_func(
f: &proto::AggregateFunction,
input_schema: &RelationType,
@@ -126,7 +148,7 @@ impl AggregateExpr {
filter: &Option<TypedExpr>,
order_by: &Option<Vec<TypedExpr>>,
distinct: bool,
) -> Result<AggregateExpr, Error> {
) -> Result<(Vec<AggregateExpr>, Option<ScalarExpr>), Error> {
// TODO(discord9): impl filter
let _ = filter;
let _ = order_by;
@@ -141,26 +163,74 @@ impl AggregateExpr {
args.push(arg_expr);
}
if args.len() != 1 {
return not_impl_err!("Aggregated function with multiple arguments is not supported");
}
let arg = if let Some(first) = args.first() {
first
} else {
return not_impl_err!("Aggregated function without arguments is not supported");
};
let func = match extensions.get(&f.function_reference) {
let fn_name = extensions
.get(&f.function_reference)
.cloned()
.map(|s| s.to_lowercase());
match fn_name.as_ref().map(|s| s.as_ref()) {
Some(Self::AVG_NAME) => AggregateExpr::from_avg_aggr_func(arg),
Some(function_name) => {
AggregateFunc::from_str_and_type(function_name, Some(arg.typ.scalar_type.clone()))
let func = AggregateFunc::from_str_and_type(
function_name,
Some(arg.typ.scalar_type.clone()),
)?;
let exprs = vec![AggregateExpr {
func,
expr: arg.expr.clone(),
distinct,
}];
let ret_mfp = None;
Ok((exprs, ret_mfp))
}
None => not_impl_err!(
"Aggregated function not found: function anchor = {:?}",
f.function_reference
),
}?;
Ok(AggregateExpr {
func,
}
}
/// Lower-case name under which the `avg` aggregate function is recognized.
const AVG_NAME: &str = "avg";
/// convert `avg` function into `sum(x)/cast(count(x) as x_type)`
fn from_avg_aggr_func(
arg: &TypedExpr,
) -> Result<(Vec<AggregateExpr>, Option<ScalarExpr>), Error> {
let arg_type = arg.typ.scalar_type.clone();
let sum = AggregateExpr {
func: AggregateFunc::from_str_and_type("sum", Some(arg_type.clone()))?,
expr: arg.expr.clone(),
distinct,
})
distinct: false,
};
let sum_out_type = sum.func.signature().output.clone();
let count = AggregateExpr {
func: AggregateFunc::Count,
expr: arg.expr.clone(),
distinct: false,
};
let count_out_type = count.func.signature().output.clone();
let avg_output = ScalarExpr::Column(0).call_binary(
ScalarExpr::Column(1).call_unary(UnaryFunc::Cast(sum_out_type.clone())),
BinaryFunc::div(sum_out_type.clone())?,
);
// make sure we wouldn't divide by zero
let zero = ScalarExpr::literal(count_out_type.default_value(), count_out_type.clone());
let non_zero = ScalarExpr::If {
cond: Box::new(ScalarExpr::Column(1).call_binary(zero.clone(), BinaryFunc::NotEq)),
then: Box::new(avg_output),
els: Box::new(ScalarExpr::literal(Value::Null, sum_out_type.clone())),
};
let ret_aggr_exprs = vec![sum, count];
let ret_mfp = Some(non_zero);
Ok((ret_aggr_exprs, ret_mfp))
}
}
@@ -217,6 +287,10 @@ impl KeyValPlan {
impl TypedPlan {
/// Convert AggregateRel into Flow's TypedPlan
///
/// The output of aggr plan is:
///
/// <group_exprs>..<aggr_exprs>
pub fn from_substrait_agg_rel(
ctx: &mut FlownodeContext,
agg: &proto::AggregateRel,
@@ -231,7 +305,7 @@ impl TypedPlan {
let group_exprs =
TypedExpr::from_substrait_agg_grouping(ctx, &agg.groupings, &input.typ, extensions)?;
let mut aggr_exprs =
let (mut aggr_exprs, post_mfp) =
AggregateExpr::from_substrait_agg_measures(ctx, &agg.measures, &input.typ, extensions)?;
let key_val_plan = KeyValPlan::from_substrait_gen_key_val_plan(
@@ -253,7 +327,11 @@ impl TypedPlan {
));
}
// TODO(discord9): try best to get time
RelationType::new(output_types).with_key((0..group_exprs.len()).collect_vec())
if group_exprs.is_empty() {
RelationType::new(output_types)
} else {
RelationType::new(output_types).with_key((0..group_exprs.len()).collect_vec())
}
};
// copy aggr_exprs to full_aggrs, and split them into simple_aggrs and distinct_aggrs
@@ -289,10 +367,40 @@ impl TypedPlan {
key_val_plan,
reduce_plan: ReducePlan::Accumulable(accum_plan),
};
Ok(TypedPlan {
typ: output_type,
plan,
})
// FIX(discord9): deal with key first
if post_mfp.is_identity() {
Ok(TypedPlan {
typ: output_type,
plan,
})
} else {
// make post_mfp map identical mapping of keys
let input = TypedPlan {
typ: output_type.clone(),
plan,
};
let key_arity = group_exprs.len();
let mut post_mfp = post_mfp;
let val_arity = post_mfp.input_arity;
// offset post_mfp's col ref by `key_arity`
let shuffle = BTreeMap::from_iter((0..val_arity).map(|v| (v, v + key_arity)));
let new_arity = key_arity + val_arity;
post_mfp.permute(shuffle, new_arity)?;
// add key projection to post mfp
let (m, f, p) = post_mfp.into_map_filter_project();
let p = (0..key_arity).chain(p).collect_vec();
let post_mfp = MapFilterProject::new(new_arity)
.map(m)?
.filter(f)?
.project(p)?;
Ok(TypedPlan {
typ: output_type.apply_mfp(&post_mfp)?,
plan: Plan::Mfp {
input: Box::new(input),
mfp: post_mfp,
},
})
}
}
}
@@ -306,6 +414,182 @@ mod test {
use crate::repr::{self, ColumnType, RelationType};
use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait};
/// Planning `avg(number)` grouped by `number` must produce a `sum`/`count` reduce
/// keyed on `number`, followed by an mfp that computes the guarded division and
/// reorders the columns into `avg(number), number`.
#[tokio::test]
async fn test_avg_group_by() {
    let engine = create_test_query_engine();
    let sql = "SELECT avg(number), number FROM numbers GROUP BY number";
    let plan = sql_to_substrait(engine.clone(), sql).await;

    let mut ctx = create_test_ctx();
    let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);

    let sum_aggr = AggregateExpr {
        func: AggregateFunc::SumUInt32,
        expr: ScalarExpr::Column(0),
        distinct: false,
    };
    let count_aggr = AggregateExpr {
        func: AggregateFunc::Count,
        expr: ScalarExpr::Column(0),
        distinct: false,
    };

    // reduce output is `<key: number>, <sum>, <count>`, so sum is col 1 and count is col 2
    let count_is_nonzero = ScalarExpr::Column(2).call_binary(
        ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
        BinaryFunc::NotEq,
    );
    let sum_div_count = ScalarExpr::Column(1).call_binary(
        ScalarExpr::Column(2).call_unary(UnaryFunc::Cast(CDT::uint64_datatype())),
        BinaryFunc::DivUInt64,
    );
    let avg_expr = ScalarExpr::If {
        cond: Box::new(count_is_nonzero),
        then: Box::new(sum_div_count),
        els: Box::new(ScalarExpr::Literal(Value::Null, CDT::uint64_datatype())),
    };

    let source = Plan::Get {
        id: crate::expr::Id::Global(GlobalId::User(0)),
    }
    .with_types(RelationType::new(vec![ColumnType::new(
        ConcreteDataType::uint32_datatype(),
        false,
    )]));

    let key_plan = MapFilterProject::new(1)
        .map(vec![ScalarExpr::Column(0)])
        .unwrap()
        .project(vec![1])
        .unwrap()
        .into_safe();
    let val_plan = MapFilterProject::new(1)
        .project(vec![0])
        .unwrap()
        .into_safe();

    let accumulable = AccumulablePlan {
        full_aggrs: vec![sum_aggr.clone(), count_aggr.clone()],
        simple_aggrs: vec![
            AggrWithIndex::new(sum_aggr, 0, 0),
            AggrWithIndex::new(count_aggr, 0, 1),
        ],
        distinct_aggrs: vec![],
    };

    let reduce_typ = RelationType::new(vec![
        ColumnType::new(ConcreteDataType::uint32_datatype(), false), // key: number
        ColumnType::new(ConcreteDataType::uint64_datatype(), true),  // sum
        ColumnType::new(ConcreteDataType::int64_datatype(), true),   // count
    ])
    .with_key(vec![0]);
    let reduce = Plan::Reduce {
        input: Box::new(source),
        key_val_plan: KeyValPlan { key_plan, val_plan },
        reduce_plan: ReducePlan::Accumulable(accumulable),
    }
    .with_types(reduce_typ);

    let post_mfp = MapFilterProject::new(3)
        .map(vec![
            avg_expr, // col 3
            // TODO(discord9): optimize mfp so to remove indirect ref
            ScalarExpr::Column(3), // col 4
            ScalarExpr::Column(0), // col 5
        ])
        .unwrap()
        .project(vec![4, 5])
        .unwrap();

    let expected = TypedPlan {
        typ: RelationType::new(vec![
            ColumnType::new(CDT::uint64_datatype(), true),  // avg(number) -> u64
            ColumnType::new(CDT::uint32_datatype(), false), // number
        ]),
        plan: Plan::Mfp {
            input: Box::new(reduce),
            mfp: post_mfp,
        },
    };
    assert_eq!(flow_plan.unwrap(), expected);
}
#[tokio::test]
async fn test_avg() {
let engine = create_test_query_engine();
let sql = "SELECT avg(number) FROM numbers";
let plan = sql_to_substrait(engine.clone(), sql).await;
let mut ctx = create_test_ctx();
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
let aggr_exprs = vec![
AggregateExpr {
func: AggregateFunc::SumUInt32,
expr: ScalarExpr::Column(0),
distinct: false,
},
AggregateExpr {
func: AggregateFunc::Count,
expr: ScalarExpr::Column(0),
distinct: false,
},
];
let avg_expr = ScalarExpr::If {
cond: Box::new(ScalarExpr::Column(1).call_binary(
ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
BinaryFunc::NotEq,
)),
then: Box::new(ScalarExpr::Column(0).call_binary(
ScalarExpr::Column(1).call_unary(UnaryFunc::Cast(CDT::uint64_datatype())),
BinaryFunc::DivUInt64,
)),
els: Box::new(ScalarExpr::Literal(Value::Null, CDT::uint64_datatype())),
};
let expected = TypedPlan {
typ: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]),
plan: Plan::Mfp {
input: Box::new(
Plan::Reduce {
input: Box::new(
Plan::Get {
id: crate::expr::Id::Global(GlobalId::User(0)),
}
.with_types(RelationType::new(vec![
ColumnType::new(ConcreteDataType::uint32_datatype(), false),
])),
),
key_val_plan: KeyValPlan {
key_plan: MapFilterProject::new(1)
.project(vec![])
.unwrap()
.into_safe(),
val_plan: MapFilterProject::new(1)
.project(vec![0])
.unwrap()
.into_safe(),
},
reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
full_aggrs: aggr_exprs.clone(),
simple_aggrs: vec![
AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
AggrWithIndex::new(aggr_exprs[1].clone(), 0, 1),
],
distinct_aggrs: vec![],
}),
}
.with_types(RelationType::new(vec![
ColumnType::new(ConcreteDataType::uint64_datatype(), true),
ColumnType::new(ConcreteDataType::int64_datatype(), true),
])),
),
mfp: MapFilterProject::new(2)
.map(vec![
avg_expr,
// TODO(discord9): optimize mfp so to remove indirect ref
ScalarExpr::Column(2),
])
.unwrap()
.project(vec![3])
.unwrap(),
},
};
assert_eq!(flow_plan.unwrap(), expected);
}
#[tokio::test]
async fn test_sum() {
let engine = create_test_query_engine();
@@ -315,7 +599,7 @@ mod test {
let mut ctx = create_test_ctx();
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
let typ = RelationType::new(vec![ColumnType::new(
ConcreteDataType::uint32_datatype(),
ConcreteDataType::uint64_datatype(),
true,
)]);
let aggr_expr = AggregateExpr {
@@ -324,7 +608,7 @@ mod test {
distinct: false,
};
let expected = TypedPlan {
typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]),
typ: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]),
plan: Plan::Mfp {
input: Box::new(
Plan::Reduce {
@@ -355,9 +639,9 @@ mod test {
.with_types(typ),
),
mfp: MapFilterProject::new(1)
.map(vec![ScalarExpr::Column(0)])
.map(vec![ScalarExpr::Column(0), ScalarExpr::Column(1)])
.unwrap()
.project(vec![1])
.project(vec![2])
.unwrap(),
},
};
@@ -380,7 +664,7 @@ mod test {
};
let expected = TypedPlan {
typ: RelationType::new(vec![
ColumnType::new(CDT::uint32_datatype(), true), // col sum(number)
ColumnType::new(CDT::uint64_datatype(), true), // col sum(number)
ColumnType::new(CDT::uint32_datatype(), false), // col number
]),
plan: Plan::Mfp {
@@ -415,15 +699,19 @@ mod test {
.with_types(
RelationType::new(vec![
ColumnType::new(CDT::uint32_datatype(), false), // col number
ColumnType::new(CDT::uint32_datatype(), true), // col sum(number)
ColumnType::new(CDT::uint64_datatype(), true), // col sum(number)
])
.with_key(vec![0]),
),
),
mfp: MapFilterProject::new(2)
.map(vec![ScalarExpr::Column(1), ScalarExpr::Column(0)])
.map(vec![
ScalarExpr::Column(1),
ScalarExpr::Column(2),
ScalarExpr::Column(0),
])
.unwrap()
.project(vec![2, 3])
.project(vec![3, 4])
.unwrap(),
},
};
@@ -446,7 +734,7 @@ mod test {
distinct: false,
};
let expected = TypedPlan {
typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]),
typ: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]),
plan: Plan::Mfp {
input: Box::new(
Plan::Reduce {
@@ -478,14 +766,14 @@ mod test {
}),
}
.with_types(RelationType::new(vec![ColumnType::new(
CDT::uint32_datatype(),
CDT::uint64_datatype(),
true,
)])),
),
mfp: MapFilterProject::new(1)
.map(vec![ScalarExpr::Column(0)])
.map(vec![ScalarExpr::Column(0), ScalarExpr::Column(1)])
.unwrap()
.project(vec![1])
.project(vec![2])
.unwrap(),
},
};