mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-19 14:30:43 +00:00
feat(flow): avg func rewrite to sum/count (#3955)
* feat(WIP): parse avg * feat: RelationType::apply_mfp no need expr typs * feat: avg&tests * fix(WIP): avg eval * fix: sum ret correct type * chore: typos
This commit is contained in:
@@ -238,6 +238,12 @@ mod test {
|
||||
for now in time_range {
|
||||
state.set_current_ts(now);
|
||||
state.run_available_with_schedule(df);
|
||||
if !state.get_err_collector().is_empty() {
|
||||
panic!(
|
||||
"Errors occur: {:?}",
|
||||
state.get_err_collector().get_all_blocking()
|
||||
)
|
||||
}
|
||||
assert!(state.get_err_collector().is_empty());
|
||||
if let Some(expected) = expected.get(&now) {
|
||||
assert_eq!(*output.borrow(), *expected, "at ts={}", now);
|
||||
|
||||
@@ -729,15 +729,113 @@ mod test {
|
||||
use std::cell::RefCell;
|
||||
use std::rc::Rc;
|
||||
|
||||
use datatypes::data_type::ConcreteDataType;
|
||||
use datatypes::data_type::{ConcreteDataType, ConcreteDataType as CDT};
|
||||
use hydroflow::scheduled::graph::Hydroflow;
|
||||
|
||||
use super::*;
|
||||
use crate::compute::render::test::{get_output_handle, harness_test_ctx, run_and_check};
|
||||
use crate::compute::state::DataflowState;
|
||||
use crate::expr::{self, AggregateFunc, BinaryFunc, GlobalId, MapFilterProject};
|
||||
use crate::expr::{self, AggregateFunc, BinaryFunc, GlobalId, MapFilterProject, UnaryFunc};
|
||||
use crate::repr::{ColumnType, RelationType};
|
||||
|
||||
/// select avg(number) from number;
|
||||
#[test]
|
||||
fn test_avg_eval() {
|
||||
let mut df = Hydroflow::new();
|
||||
let mut state = DataflowState::default();
|
||||
let mut ctx = harness_test_ctx(&mut df, &mut state);
|
||||
|
||||
let rows = vec![
|
||||
(Row::new(vec![1u32.into()]), 1, 1),
|
||||
(Row::new(vec![2u32.into()]), 1, 1),
|
||||
(Row::new(vec![3u32.into()]), 1, 1),
|
||||
(Row::new(vec![1u32.into()]), 1, 1),
|
||||
(Row::new(vec![2u32.into()]), 1, 1),
|
||||
(Row::new(vec![3u32.into()]), 1, 1),
|
||||
];
|
||||
let collection = ctx.render_constant(rows.clone());
|
||||
ctx.insert_global(GlobalId::User(1), collection);
|
||||
|
||||
let aggr_exprs = vec![
|
||||
AggregateExpr {
|
||||
func: AggregateFunc::SumUInt32,
|
||||
expr: ScalarExpr::Column(0),
|
||||
distinct: false,
|
||||
},
|
||||
AggregateExpr {
|
||||
func: AggregateFunc::Count,
|
||||
expr: ScalarExpr::Column(0),
|
||||
distinct: false,
|
||||
},
|
||||
];
|
||||
let avg_expr = ScalarExpr::If {
|
||||
cond: Box::new(ScalarExpr::Column(1).call_binary(
|
||||
ScalarExpr::Literal(Value::from(0u32), CDT::int64_datatype()),
|
||||
BinaryFunc::NotEq,
|
||||
)),
|
||||
then: Box::new(ScalarExpr::Column(0).call_binary(
|
||||
ScalarExpr::Column(1).call_unary(UnaryFunc::Cast(CDT::uint64_datatype())),
|
||||
BinaryFunc::DivUInt64,
|
||||
)),
|
||||
els: Box::new(ScalarExpr::Literal(Value::Null, CDT::uint64_datatype())),
|
||||
};
|
||||
let expected = TypedPlan {
|
||||
typ: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]),
|
||||
plan: Plan::Mfp {
|
||||
input: Box::new(
|
||||
Plan::Reduce {
|
||||
input: Box::new(
|
||||
Plan::Get {
|
||||
id: crate::expr::Id::Global(GlobalId::User(1)),
|
||||
}
|
||||
.with_types(RelationType::new(vec![
|
||||
ColumnType::new(ConcreteDataType::int64_datatype(), false),
|
||||
])),
|
||||
),
|
||||
key_val_plan: KeyValPlan {
|
||||
key_plan: MapFilterProject::new(1)
|
||||
.project(vec![])
|
||||
.unwrap()
|
||||
.into_safe(),
|
||||
val_plan: MapFilterProject::new(1)
|
||||
.project(vec![0])
|
||||
.unwrap()
|
||||
.into_safe(),
|
||||
},
|
||||
reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
|
||||
full_aggrs: aggr_exprs.clone(),
|
||||
simple_aggrs: vec![
|
||||
AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
|
||||
AggrWithIndex::new(aggr_exprs[1].clone(), 0, 1),
|
||||
],
|
||||
distinct_aggrs: vec![],
|
||||
}),
|
||||
}
|
||||
.with_types(RelationType::new(vec![
|
||||
ColumnType::new(ConcreteDataType::uint32_datatype(), true),
|
||||
ColumnType::new(ConcreteDataType::int64_datatype(), true),
|
||||
])),
|
||||
),
|
||||
mfp: MapFilterProject::new(2)
|
||||
.map(vec![
|
||||
avg_expr,
|
||||
// TODO(discord9): optimize mfp so to remove indirect ref
|
||||
ScalarExpr::Column(2),
|
||||
])
|
||||
.unwrap()
|
||||
.project(vec![3])
|
||||
.unwrap(),
|
||||
},
|
||||
};
|
||||
|
||||
let bundle = ctx.render_plan(expected).unwrap();
|
||||
|
||||
let output = get_output_handle(&mut ctx, bundle);
|
||||
drop(ctx);
|
||||
let expected = BTreeMap::from([(1, vec![(Row::new(vec![2u64.into()]), 1, 1)])]);
|
||||
run_and_check(&mut state, &mut df, 1..2, expected, output);
|
||||
}
|
||||
|
||||
/// SELECT DISTINCT col FROM table
|
||||
///
|
||||
/// table schema:
|
||||
|
||||
@@ -153,6 +153,9 @@ pub struct ErrCollector {
|
||||
}
|
||||
|
||||
impl ErrCollector {
|
||||
pub fn get_all_blocking(&self) -> Vec<EvalError> {
|
||||
self.inner.blocking_lock().drain(..).collect_vec()
|
||||
}
|
||||
pub async fn get_all(&self) -> Vec<EvalError> {
|
||||
self.inner.lock().await.drain(..).collect_vec()
|
||||
}
|
||||
|
||||
@@ -375,6 +375,22 @@ impl BinaryFunc {
|
||||
)
|
||||
}
|
||||
|
||||
pub fn add(input_type: ConcreteDataType) -> Result<Self, Error> {
|
||||
Self::specialization(GenericFn::Add, input_type)
|
||||
}
|
||||
|
||||
pub fn sub(input_type: ConcreteDataType) -> Result<Self, Error> {
|
||||
Self::specialization(GenericFn::Sub, input_type)
|
||||
}
|
||||
|
||||
pub fn mul(input_type: ConcreteDataType) -> Result<Self, Error> {
|
||||
Self::specialization(GenericFn::Mul, input_type)
|
||||
}
|
||||
|
||||
pub fn div(input_type: ConcreteDataType) -> Result<Self, Error> {
|
||||
Self::specialization(GenericFn::Div, input_type)
|
||||
}
|
||||
|
||||
/// Get the specialization of the binary function based on the generic function and the input type
|
||||
pub fn specialization(generic: GenericFn, input_type: ConcreteDataType) -> Result<Self, Error> {
|
||||
let rule = SPECIALIZATION.get_or_init(|| {
|
||||
|
||||
@@ -136,27 +136,44 @@ impl AggregateFunc {
|
||||
|
||||
/// Generate signature for each aggregate function
|
||||
macro_rules! generate_signature {
|
||||
($value:ident, { $($user_arm:tt)* },
|
||||
[ $(
|
||||
$auto_arm:ident=>($con_type:ident,$generic:ident)
|
||||
),*
|
||||
]) => {
|
||||
($value:ident,
|
||||
{ $($user_arm:tt)* },
|
||||
[ $(
|
||||
$auto_arm:ident=>($($arg:ident),*)
|
||||
),*
|
||||
]
|
||||
) => {
|
||||
match $value {
|
||||
$($user_arm)*,
|
||||
$(
|
||||
Self::$auto_arm => Signature {
|
||||
input: smallvec![
|
||||
ConcreteDataType::$con_type(),
|
||||
ConcreteDataType::$con_type(),
|
||||
],
|
||||
output: ConcreteDataType::$con_type(),
|
||||
generic_fn: GenericFn::$generic,
|
||||
},
|
||||
Self::$auto_arm => gen_one_siginature!($($arg),*),
|
||||
)*
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Generate one match arm with optional arguments
|
||||
macro_rules! gen_one_siginature {
|
||||
(
|
||||
$con_type:ident, $generic:ident
|
||||
) => {
|
||||
Signature {
|
||||
input: smallvec![ConcreteDataType::$con_type(), ConcreteDataType::$con_type(),],
|
||||
output: ConcreteDataType::$con_type(),
|
||||
generic_fn: GenericFn::$generic,
|
||||
}
|
||||
};
|
||||
(
|
||||
$in_type:ident, $out_type:ident, $generic:ident
|
||||
) => {
|
||||
Signature {
|
||||
input: smallvec![ConcreteDataType::$in_type()],
|
||||
output: ConcreteDataType::$out_type(),
|
||||
generic_fn: GenericFn::$generic,
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
static SPECIALIZATION: OnceLock<HashMap<(GenericFn, ConcreteDataType), AggregateFunc>> =
|
||||
OnceLock::new();
|
||||
|
||||
@@ -223,6 +240,8 @@ impl AggregateFunc {
|
||||
|
||||
/// all concrete datatypes with precision types will be returned with largest possible variant
|
||||
/// as a exception, count have a signature of `null -> i64`, but it's actually `anytype -> i64`
|
||||
///
|
||||
/// TODO(discorcd9): fix signature for sum unsign -> u64 sum signed -> i64
|
||||
pub fn signature(&self) -> Signature {
|
||||
generate_signature!(self, {
|
||||
AggregateFunc::Count => Signature {
|
||||
@@ -263,12 +282,12 @@ impl AggregateFunc {
|
||||
MinTime => (time_second_datatype, Min),
|
||||
MinDuration => (duration_second_datatype, Min),
|
||||
MinInterval => (interval_year_month_datatype, Min),
|
||||
SumInt16 => (int16_datatype, Sum),
|
||||
SumInt32 => (int32_datatype, Sum),
|
||||
SumInt64 => (int64_datatype, Sum),
|
||||
SumUInt16 => (uint16_datatype, Sum),
|
||||
SumUInt32 => (uint32_datatype, Sum),
|
||||
SumUInt64 => (uint64_datatype, Sum),
|
||||
SumInt16 => (int16_datatype, int64_datatype, Sum),
|
||||
SumInt32 => (int32_datatype, int64_datatype, Sum),
|
||||
SumInt64 => (int64_datatype, int64_datatype, Sum),
|
||||
SumUInt16 => (uint16_datatype, uint64_datatype, Sum),
|
||||
SumUInt32 => (uint32_datatype, uint64_datatype, Sum),
|
||||
SumUInt64 => (uint64_datatype, uint64_datatype, Sum),
|
||||
SumFloat32 => (float32_datatype, Sum),
|
||||
SumFloat64 => (float64_datatype, Sum),
|
||||
Any => (boolean_datatype, Any),
|
||||
|
||||
@@ -44,7 +44,7 @@ pub struct TypedPlan {
|
||||
impl TypedPlan {
|
||||
/// directly apply a mfp to the plan
|
||||
pub fn mfp(self, mfp: MapFilterProject) -> Result<Self, Error> {
|
||||
let new_type = self.typ.apply_mfp(&mfp, &[])?;
|
||||
let new_type = self.typ.apply_mfp(&mfp)?;
|
||||
let plan = match self.plan {
|
||||
Plan::Mfp {
|
||||
input,
|
||||
@@ -68,14 +68,14 @@ impl TypedPlan {
|
||||
pub fn projection(self, exprs: Vec<TypedExpr>) -> Result<Self, Error> {
|
||||
let input_arity = self.typ.column_types.len();
|
||||
let output_arity = exprs.len();
|
||||
let (exprs, expr_typs): (Vec<_>, Vec<_>) = exprs
|
||||
let (exprs, _expr_typs): (Vec<_>, Vec<_>) = exprs
|
||||
.into_iter()
|
||||
.map(|TypedExpr { expr, typ }| (expr, typ))
|
||||
.unzip();
|
||||
let mfp = MapFilterProject::new(input_arity)
|
||||
.map(exprs)?
|
||||
.project(input_arity..input_arity + output_arity)?;
|
||||
let out_typ = self.typ.apply_mfp(&mfp, &expr_typs)?;
|
||||
let out_typ = self.typ.apply_mfp(&mfp)?;
|
||||
// special case for mfp to compose when the plan is already mfp
|
||||
let plan = match self.plan {
|
||||
Plan::Mfp {
|
||||
|
||||
@@ -111,13 +111,13 @@ impl RelationType {
|
||||
/// then new key=`[1]`, new time index=`[0]`
|
||||
///
|
||||
/// note that this function will remove empty keys like key=`[]` will be removed
|
||||
pub fn apply_mfp(&self, mfp: &MapFilterProject, expr_typs: &[ColumnType]) -> Result<Self> {
|
||||
let all_types = self
|
||||
.column_types
|
||||
.iter()
|
||||
.chain(expr_typs.iter())
|
||||
.cloned()
|
||||
.collect_vec();
|
||||
pub fn apply_mfp(&self, mfp: &MapFilterProject) -> Result<Self> {
|
||||
let mut all_types = self.column_types.clone();
|
||||
for expr in &mfp.expressions {
|
||||
let expr_typ = expr.typ(&self.column_types)?;
|
||||
all_types.push(expr_typ);
|
||||
}
|
||||
let all_types = all_types;
|
||||
let mfp_out_types = mfp
|
||||
.projection
|
||||
.iter()
|
||||
@@ -131,6 +131,7 @@ impl RelationType {
|
||||
})
|
||||
})
|
||||
.try_collect()?;
|
||||
|
||||
let old_to_new_col = BTreeMap::from_iter(
|
||||
mfp.projection
|
||||
.clone()
|
||||
|
||||
@@ -12,13 +12,13 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
|
||||
use common_decimal::Decimal128;
|
||||
use common_time::{Date, Timestamp};
|
||||
use datatypes::arrow::compute::kernels::window;
|
||||
use datatypes::arrow::ipc::Binary;
|
||||
use datatypes::data_type::ConcreteDataType as CDT;
|
||||
use datatypes::data_type::{ConcreteDataType as CDT, DataType};
|
||||
use datatypes::value::Value;
|
||||
use hydroflow::futures::future::Map;
|
||||
use itertools::Itertools;
|
||||
@@ -83,14 +83,18 @@ impl TypedExpr {
|
||||
}
|
||||
|
||||
impl AggregateExpr {
|
||||
/// Convert list of `Measure` into Flow's AggregateExpr
|
||||
///
|
||||
/// Return both the AggregateExpr and a MapFilterProject that is the final output of the aggregate function
|
||||
fn from_substrait_agg_measures(
|
||||
ctx: &mut FlownodeContext,
|
||||
measures: &[Measure],
|
||||
typ: &RelationType,
|
||||
extensions: &FunctionExtensions,
|
||||
) -> Result<Vec<AggregateExpr>, Error> {
|
||||
) -> Result<(Vec<AggregateExpr>, MapFilterProject), Error> {
|
||||
let _ = ctx;
|
||||
let mut aggr_exprs = vec![];
|
||||
let mut all_aggr_exprs = vec![];
|
||||
let mut post_maps = vec![];
|
||||
|
||||
for m in measures {
|
||||
let filter = &m
|
||||
@@ -99,7 +103,7 @@ impl AggregateExpr {
|
||||
.map(|fil| TypedExpr::from_substrait_rex(fil, typ, extensions))
|
||||
.transpose()?;
|
||||
|
||||
let agg_func = match &m.measure {
|
||||
let (aggr_expr, post_mfp) = match &m.measure {
|
||||
Some(f) => {
|
||||
let distinct = match f.invocation {
|
||||
_ if f.invocation == AggregationInvocation::Distinct as i32 => true,
|
||||
@@ -113,12 +117,30 @@ impl AggregateExpr {
|
||||
}
|
||||
None => not_impl_err!("Aggregate without aggregate function is not supported"),
|
||||
}?;
|
||||
aggr_exprs.push(agg_func);
|
||||
// permute col index refer to the output of post_mfp,
|
||||
// so to help construct a mfp at the end
|
||||
let mut post_map = post_mfp.unwrap_or(ScalarExpr::Column(0));
|
||||
let cur_arity = all_aggr_exprs.len();
|
||||
let remap = (0..aggr_expr.len()).map(|i| i + cur_arity).collect_vec();
|
||||
post_map.permute(&remap)?;
|
||||
|
||||
all_aggr_exprs.extend(aggr_expr);
|
||||
post_maps.push(post_map);
|
||||
}
|
||||
Ok(aggr_exprs)
|
||||
|
||||
let input_arity = all_aggr_exprs.len();
|
||||
let aggr_arity = post_maps.len();
|
||||
let post_mfp_final = MapFilterProject::new(all_aggr_exprs.len())
|
||||
.map(post_maps)?
|
||||
.project(input_arity..input_arity + aggr_arity)?;
|
||||
|
||||
Ok((all_aggr_exprs, post_mfp_final))
|
||||
}
|
||||
|
||||
/// Convert AggregateFunction into Flow's AggregateExpr
|
||||
///
|
||||
/// the returned value is a tuple of AggregateExpr and a optional ScalarExpr that if exist is the final output of the aggregate function
|
||||
/// since aggr functions like `avg` need to be transform to `sum(x)/cast(count(x) as x_type)`
|
||||
pub fn from_substrait_agg_func(
|
||||
f: &proto::AggregateFunction,
|
||||
input_schema: &RelationType,
|
||||
@@ -126,7 +148,7 @@ impl AggregateExpr {
|
||||
filter: &Option<TypedExpr>,
|
||||
order_by: &Option<Vec<TypedExpr>>,
|
||||
distinct: bool,
|
||||
) -> Result<AggregateExpr, Error> {
|
||||
) -> Result<(Vec<AggregateExpr>, Option<ScalarExpr>), Error> {
|
||||
// TODO(discord9): impl filter
|
||||
let _ = filter;
|
||||
let _ = order_by;
|
||||
@@ -141,26 +163,74 @@ impl AggregateExpr {
|
||||
args.push(arg_expr);
|
||||
}
|
||||
|
||||
if args.len() != 1 {
|
||||
return not_impl_err!("Aggregated function with multiple arguments is not supported");
|
||||
}
|
||||
|
||||
let arg = if let Some(first) = args.first() {
|
||||
first
|
||||
} else {
|
||||
return not_impl_err!("Aggregated function without arguments is not supported");
|
||||
};
|
||||
|
||||
let func = match extensions.get(&f.function_reference) {
|
||||
let fn_name = extensions
|
||||
.get(&f.function_reference)
|
||||
.cloned()
|
||||
.map(|s| s.to_lowercase());
|
||||
|
||||
match fn_name.as_ref().map(|s| s.as_ref()) {
|
||||
Some(Self::AVG_NAME) => AggregateExpr::from_avg_aggr_func(arg),
|
||||
Some(function_name) => {
|
||||
AggregateFunc::from_str_and_type(function_name, Some(arg.typ.scalar_type.clone()))
|
||||
let func = AggregateFunc::from_str_and_type(
|
||||
function_name,
|
||||
Some(arg.typ.scalar_type.clone()),
|
||||
)?;
|
||||
let exprs = vec![AggregateExpr {
|
||||
func,
|
||||
expr: arg.expr.clone(),
|
||||
distinct,
|
||||
}];
|
||||
let ret_mfp = None;
|
||||
Ok((exprs, ret_mfp))
|
||||
}
|
||||
None => not_impl_err!(
|
||||
"Aggregated function not found: function anchor = {:?}",
|
||||
f.function_reference
|
||||
),
|
||||
}?;
|
||||
Ok(AggregateExpr {
|
||||
func,
|
||||
}
|
||||
}
|
||||
const AVG_NAME: &'static str = "avg";
|
||||
/// convert `avg` function into `sum(x)/cast(count(x) as x_type)`
|
||||
fn from_avg_aggr_func(
|
||||
arg: &TypedExpr,
|
||||
) -> Result<(Vec<AggregateExpr>, Option<ScalarExpr>), Error> {
|
||||
let arg_type = arg.typ.scalar_type.clone();
|
||||
let sum = AggregateExpr {
|
||||
func: AggregateFunc::from_str_and_type("sum", Some(arg_type.clone()))?,
|
||||
expr: arg.expr.clone(),
|
||||
distinct,
|
||||
})
|
||||
distinct: false,
|
||||
};
|
||||
let sum_out_type = sum.func.signature().output.clone();
|
||||
let count = AggregateExpr {
|
||||
func: AggregateFunc::Count,
|
||||
expr: arg.expr.clone(),
|
||||
distinct: false,
|
||||
};
|
||||
let count_out_type = count.func.signature().output.clone();
|
||||
let avg_output = ScalarExpr::Column(0).call_binary(
|
||||
ScalarExpr::Column(1).call_unary(UnaryFunc::Cast(sum_out_type.clone())),
|
||||
BinaryFunc::div(sum_out_type.clone())?,
|
||||
);
|
||||
// make sure we wouldn't divide by zero
|
||||
let zero = ScalarExpr::literal(count_out_type.default_value(), count_out_type.clone());
|
||||
let non_zero = ScalarExpr::If {
|
||||
cond: Box::new(ScalarExpr::Column(1).call_binary(zero.clone(), BinaryFunc::NotEq)),
|
||||
then: Box::new(avg_output),
|
||||
els: Box::new(ScalarExpr::literal(Value::Null, sum_out_type.clone())),
|
||||
};
|
||||
let ret_aggr_exprs = vec![sum, count];
|
||||
let ret_mfp = Some(non_zero);
|
||||
Ok((ret_aggr_exprs, ret_mfp))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -217,6 +287,10 @@ impl KeyValPlan {
|
||||
|
||||
impl TypedPlan {
|
||||
/// Convert AggregateRel into Flow's TypedPlan
|
||||
///
|
||||
/// The output of aggr plan is:
|
||||
///
|
||||
/// <group_exprs>..<aggr_exprs>
|
||||
pub fn from_substrait_agg_rel(
|
||||
ctx: &mut FlownodeContext,
|
||||
agg: &proto::AggregateRel,
|
||||
@@ -231,7 +305,7 @@ impl TypedPlan {
|
||||
let group_exprs =
|
||||
TypedExpr::from_substrait_agg_grouping(ctx, &agg.groupings, &input.typ, extensions)?;
|
||||
|
||||
let mut aggr_exprs =
|
||||
let (mut aggr_exprs, post_mfp) =
|
||||
AggregateExpr::from_substrait_agg_measures(ctx, &agg.measures, &input.typ, extensions)?;
|
||||
|
||||
let key_val_plan = KeyValPlan::from_substrait_gen_key_val_plan(
|
||||
@@ -253,7 +327,11 @@ impl TypedPlan {
|
||||
));
|
||||
}
|
||||
// TODO(discord9): try best to get time
|
||||
RelationType::new(output_types).with_key((0..group_exprs.len()).collect_vec())
|
||||
if group_exprs.is_empty() {
|
||||
RelationType::new(output_types)
|
||||
} else {
|
||||
RelationType::new(output_types).with_key((0..group_exprs.len()).collect_vec())
|
||||
}
|
||||
};
|
||||
|
||||
// copy aggr_exprs to full_aggrs, and split them into simple_aggrs and distinct_aggrs
|
||||
@@ -289,10 +367,40 @@ impl TypedPlan {
|
||||
key_val_plan,
|
||||
reduce_plan: ReducePlan::Accumulable(accum_plan),
|
||||
};
|
||||
Ok(TypedPlan {
|
||||
typ: output_type,
|
||||
plan,
|
||||
})
|
||||
// FIX(discord9): deal with key first
|
||||
if post_mfp.is_identity() {
|
||||
Ok(TypedPlan {
|
||||
typ: output_type,
|
||||
plan,
|
||||
})
|
||||
} else {
|
||||
// make post_mfp map identical mapping of keys
|
||||
let input = TypedPlan {
|
||||
typ: output_type.clone(),
|
||||
plan,
|
||||
};
|
||||
let key_arity = group_exprs.len();
|
||||
let mut post_mfp = post_mfp;
|
||||
let val_arity = post_mfp.input_arity;
|
||||
// offset post_mfp's col ref by `key_arity`
|
||||
let shuffle = BTreeMap::from_iter((0..val_arity).map(|v| (v, v + key_arity)));
|
||||
let new_arity = key_arity + val_arity;
|
||||
post_mfp.permute(shuffle, new_arity)?;
|
||||
// add key projection to post mfp
|
||||
let (m, f, p) = post_mfp.into_map_filter_project();
|
||||
let p = (0..key_arity).chain(p).collect_vec();
|
||||
let post_mfp = MapFilterProject::new(new_arity)
|
||||
.map(m)?
|
||||
.filter(f)?
|
||||
.project(p)?;
|
||||
Ok(TypedPlan {
|
||||
typ: output_type.apply_mfp(&post_mfp)?,
|
||||
plan: Plan::Mfp {
|
||||
input: Box::new(input),
|
||||
mfp: post_mfp,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -306,6 +414,182 @@ mod test {
|
||||
use crate::repr::{self, ColumnType, RelationType};
|
||||
use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_avg_group_by() {
|
||||
let engine = create_test_query_engine();
|
||||
let sql = "SELECT avg(number), number FROM numbers GROUP BY number";
|
||||
let plan = sql_to_substrait(engine.clone(), sql).await;
|
||||
|
||||
let mut ctx = create_test_ctx();
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
|
||||
|
||||
let aggr_exprs = vec![
|
||||
AggregateExpr {
|
||||
func: AggregateFunc::SumUInt32,
|
||||
expr: ScalarExpr::Column(0),
|
||||
distinct: false,
|
||||
},
|
||||
AggregateExpr {
|
||||
func: AggregateFunc::Count,
|
||||
expr: ScalarExpr::Column(0),
|
||||
distinct: false,
|
||||
},
|
||||
];
|
||||
let avg_expr = ScalarExpr::If {
|
||||
cond: Box::new(ScalarExpr::Column(2).call_binary(
|
||||
ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
|
||||
BinaryFunc::NotEq,
|
||||
)),
|
||||
then: Box::new(ScalarExpr::Column(1).call_binary(
|
||||
ScalarExpr::Column(2).call_unary(UnaryFunc::Cast(CDT::uint64_datatype())),
|
||||
BinaryFunc::DivUInt64,
|
||||
)),
|
||||
els: Box::new(ScalarExpr::Literal(Value::Null, CDT::uint64_datatype())),
|
||||
};
|
||||
let expected = TypedPlan {
|
||||
typ: RelationType::new(vec![
|
||||
ColumnType::new(CDT::uint64_datatype(), true), // sum(number) -> u64
|
||||
ColumnType::new(CDT::uint32_datatype(), false), // number
|
||||
]),
|
||||
plan: Plan::Mfp {
|
||||
input: Box::new(
|
||||
Plan::Reduce {
|
||||
input: Box::new(
|
||||
Plan::Get {
|
||||
id: crate::expr::Id::Global(GlobalId::User(0)),
|
||||
}
|
||||
.with_types(RelationType::new(vec![
|
||||
ColumnType::new(ConcreteDataType::uint32_datatype(), false),
|
||||
])),
|
||||
),
|
||||
key_val_plan: KeyValPlan {
|
||||
key_plan: MapFilterProject::new(1)
|
||||
.map(vec![ScalarExpr::Column(0)])
|
||||
.unwrap()
|
||||
.project(vec![1])
|
||||
.unwrap()
|
||||
.into_safe(),
|
||||
val_plan: MapFilterProject::new(1)
|
||||
.project(vec![0])
|
||||
.unwrap()
|
||||
.into_safe(),
|
||||
},
|
||||
reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
|
||||
full_aggrs: aggr_exprs.clone(),
|
||||
simple_aggrs: vec![
|
||||
AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
|
||||
AggrWithIndex::new(aggr_exprs[1].clone(), 0, 1),
|
||||
],
|
||||
distinct_aggrs: vec![],
|
||||
}),
|
||||
}
|
||||
.with_types(
|
||||
RelationType::new(vec![
|
||||
ColumnType::new(ConcreteDataType::uint32_datatype(), false), // key: number
|
||||
ColumnType::new(ConcreteDataType::uint64_datatype(), true), // sum
|
||||
ColumnType::new(ConcreteDataType::int64_datatype(), true), // count
|
||||
])
|
||||
.with_key(vec![0]),
|
||||
),
|
||||
),
|
||||
mfp: MapFilterProject::new(3)
|
||||
.map(vec![
|
||||
avg_expr, // col 3
|
||||
// TODO(discord9): optimize mfp so to remove indirect ref
|
||||
ScalarExpr::Column(3), // col 4
|
||||
ScalarExpr::Column(0), // col 5
|
||||
])
|
||||
.unwrap()
|
||||
.project(vec![4, 5])
|
||||
.unwrap(),
|
||||
},
|
||||
};
|
||||
assert_eq!(flow_plan.unwrap(), expected);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_avg() {
|
||||
let engine = create_test_query_engine();
|
||||
let sql = "SELECT avg(number) FROM numbers";
|
||||
let plan = sql_to_substrait(engine.clone(), sql).await;
|
||||
|
||||
let mut ctx = create_test_ctx();
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
|
||||
|
||||
let aggr_exprs = vec![
|
||||
AggregateExpr {
|
||||
func: AggregateFunc::SumUInt32,
|
||||
expr: ScalarExpr::Column(0),
|
||||
distinct: false,
|
||||
},
|
||||
AggregateExpr {
|
||||
func: AggregateFunc::Count,
|
||||
expr: ScalarExpr::Column(0),
|
||||
distinct: false,
|
||||
},
|
||||
];
|
||||
let avg_expr = ScalarExpr::If {
|
||||
cond: Box::new(ScalarExpr::Column(1).call_binary(
|
||||
ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
|
||||
BinaryFunc::NotEq,
|
||||
)),
|
||||
then: Box::new(ScalarExpr::Column(0).call_binary(
|
||||
ScalarExpr::Column(1).call_unary(UnaryFunc::Cast(CDT::uint64_datatype())),
|
||||
BinaryFunc::DivUInt64,
|
||||
)),
|
||||
els: Box::new(ScalarExpr::Literal(Value::Null, CDT::uint64_datatype())),
|
||||
};
|
||||
let expected = TypedPlan {
|
||||
typ: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]),
|
||||
plan: Plan::Mfp {
|
||||
input: Box::new(
|
||||
Plan::Reduce {
|
||||
input: Box::new(
|
||||
Plan::Get {
|
||||
id: crate::expr::Id::Global(GlobalId::User(0)),
|
||||
}
|
||||
.with_types(RelationType::new(vec![
|
||||
ColumnType::new(ConcreteDataType::uint32_datatype(), false),
|
||||
])),
|
||||
),
|
||||
key_val_plan: KeyValPlan {
|
||||
key_plan: MapFilterProject::new(1)
|
||||
.project(vec![])
|
||||
.unwrap()
|
||||
.into_safe(),
|
||||
val_plan: MapFilterProject::new(1)
|
||||
.project(vec![0])
|
||||
.unwrap()
|
||||
.into_safe(),
|
||||
},
|
||||
reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
|
||||
full_aggrs: aggr_exprs.clone(),
|
||||
simple_aggrs: vec![
|
||||
AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
|
||||
AggrWithIndex::new(aggr_exprs[1].clone(), 0, 1),
|
||||
],
|
||||
distinct_aggrs: vec![],
|
||||
}),
|
||||
}
|
||||
.with_types(RelationType::new(vec![
|
||||
ColumnType::new(ConcreteDataType::uint64_datatype(), true),
|
||||
ColumnType::new(ConcreteDataType::int64_datatype(), true),
|
||||
])),
|
||||
),
|
||||
mfp: MapFilterProject::new(2)
|
||||
.map(vec![
|
||||
avg_expr,
|
||||
// TODO(discord9): optimize mfp so to remove indirect ref
|
||||
ScalarExpr::Column(2),
|
||||
])
|
||||
.unwrap()
|
||||
.project(vec![3])
|
||||
.unwrap(),
|
||||
},
|
||||
};
|
||||
assert_eq!(flow_plan.unwrap(), expected);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sum() {
|
||||
let engine = create_test_query_engine();
|
||||
@@ -315,7 +599,7 @@ mod test {
|
||||
let mut ctx = create_test_ctx();
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
|
||||
let typ = RelationType::new(vec![ColumnType::new(
|
||||
ConcreteDataType::uint32_datatype(),
|
||||
ConcreteDataType::uint64_datatype(),
|
||||
true,
|
||||
)]);
|
||||
let aggr_expr = AggregateExpr {
|
||||
@@ -324,7 +608,7 @@ mod test {
|
||||
distinct: false,
|
||||
};
|
||||
let expected = TypedPlan {
|
||||
typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]),
|
||||
typ: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]),
|
||||
plan: Plan::Mfp {
|
||||
input: Box::new(
|
||||
Plan::Reduce {
|
||||
@@ -355,9 +639,9 @@ mod test {
|
||||
.with_types(typ),
|
||||
),
|
||||
mfp: MapFilterProject::new(1)
|
||||
.map(vec![ScalarExpr::Column(0)])
|
||||
.map(vec![ScalarExpr::Column(0), ScalarExpr::Column(1)])
|
||||
.unwrap()
|
||||
.project(vec![1])
|
||||
.project(vec![2])
|
||||
.unwrap(),
|
||||
},
|
||||
};
|
||||
@@ -380,7 +664,7 @@ mod test {
|
||||
};
|
||||
let expected = TypedPlan {
|
||||
typ: RelationType::new(vec![
|
||||
ColumnType::new(CDT::uint32_datatype(), true), // col sum(number)
|
||||
ColumnType::new(CDT::uint64_datatype(), true), // col sum(number)
|
||||
ColumnType::new(CDT::uint32_datatype(), false), // col number
|
||||
]),
|
||||
plan: Plan::Mfp {
|
||||
@@ -415,15 +699,19 @@ mod test {
|
||||
.with_types(
|
||||
RelationType::new(vec![
|
||||
ColumnType::new(CDT::uint32_datatype(), false), // col number
|
||||
ColumnType::new(CDT::uint32_datatype(), true), // col sum(number)
|
||||
ColumnType::new(CDT::uint64_datatype(), true), // col sum(number)
|
||||
])
|
||||
.with_key(vec![0]),
|
||||
),
|
||||
),
|
||||
mfp: MapFilterProject::new(2)
|
||||
.map(vec![ScalarExpr::Column(1), ScalarExpr::Column(0)])
|
||||
.map(vec![
|
||||
ScalarExpr::Column(1),
|
||||
ScalarExpr::Column(2),
|
||||
ScalarExpr::Column(0),
|
||||
])
|
||||
.unwrap()
|
||||
.project(vec![2, 3])
|
||||
.project(vec![3, 4])
|
||||
.unwrap(),
|
||||
},
|
||||
};
|
||||
@@ -446,7 +734,7 @@ mod test {
|
||||
distinct: false,
|
||||
};
|
||||
let expected = TypedPlan {
|
||||
typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]),
|
||||
typ: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]),
|
||||
plan: Plan::Mfp {
|
||||
input: Box::new(
|
||||
Plan::Reduce {
|
||||
@@ -478,14 +766,14 @@ mod test {
|
||||
}),
|
||||
}
|
||||
.with_types(RelationType::new(vec![ColumnType::new(
|
||||
CDT::uint32_datatype(),
|
||||
CDT::uint64_datatype(),
|
||||
true,
|
||||
)])),
|
||||
),
|
||||
mfp: MapFilterProject::new(1)
|
||||
.map(vec![ScalarExpr::Column(0)])
|
||||
.map(vec![ScalarExpr::Column(0), ScalarExpr::Column(1)])
|
||||
.unwrap()
|
||||
.project(vec![1])
|
||||
.project(vec![2])
|
||||
.unwrap(),
|
||||
},
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user