feat(flow): tumble window func (#3968)

* feat(WIP): tumble window rewrite parser

* tests: tumble func

* feat: add `update_at` column for all flow output

* chore: cleanup per review

* fix: update_at not as time index

* fix: demo tumble

* fix: tests&tumble signature&accept both ts&datetime

* refactor: update_at now ts millis type

* chore: per review advices
discord9
2024-05-17 20:10:28 +08:00
committed by GitHub
parent 9baa431656
commit 3477fde0e5
15 changed files with 816 additions and 91 deletions
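At its core, the new `TumbleWindowFloor`/`TumbleWindowCeiling` functions bucket a timestamp into fixed windows. A minimal sketch of that arithmetic in plain `i64` milliseconds (not the crate's `Timestamp`/`Interval` types; the constants mirror the tests further down):

fn tumble_floor(ts_ms: i64, window_ms: i64, start_ms: i64) -> i64 {
    // align ts to the start of its window, measured from the window origin
    start_ms + (ts_ms - start_ms) / window_ms * window_ms
}

fn tumble_ceiling(ts_ms: i64, window_ms: i64, start_ms: i64) -> i64 {
    // the exclusive end of the same window
    tumble_floor(ts_ms, window_ms, start_ms) + window_ms
}

fn main() {
    const START: i64 = 1_625_097_600_000; // 2021-07-01 00:00:00 UTC, in ms
    // with a 1 s window, START + 1500 ms falls into [START + 1000, START + 2000)
    assert_eq!(tumble_floor(START + 1_500, 1_000, START), START + 1_000);
    assert_eq!(tumble_ceiling(START + 1_500, 1_000, START), START + 2_000);
}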

Cargo.lock generated
View File

@@ -3835,8 +3835,11 @@ dependencies = [
"common-decimal",
"common-error",
"common-frontend",
"common-function",
"common-macro",
"common-meta",
"common-query",
"common-recordbatch",
"common-runtime",
"common-telemetry",
"common-time",

View File

@@ -119,12 +119,11 @@ impl CreateFlowProcedure {
&sink_table_name.table_name,
))
.await?;
ensure!(
!exists,
error::TableAlreadyExistsSnafu {
table_name: sink_table_name.to_string(),
}
);
// TODO(discord9): due to undefined behavior in how flow's plan transforms types in mfp, sometimes flow can't deduce the correct schema
// and requires manually creating the sink table
if exists {
common_telemetry::warn!("Table already exists, table: {}", sink_table_name);
}
self.collect_source_tables().await?;
self.allocate_flow_id().await?;

View File

@@ -26,7 +26,10 @@ futures = "0.3"
# This fork simply keeps our dependency in our org and pins the version;
# it is the same as the upstream repo
async-trait.workspace = true
common-function.workspace = true
common-meta.workspace = true
common-query.workspace = true
common-recordbatch.workspace = true
enum-as-inner = "0.6.0"
greptime-proto.workspace = true
hydroflow = { git = "https://github.com/GreptimeTeam/hydroflow.git", branch = "main" }

View File

@@ -18,7 +18,7 @@
use std::collections::{BTreeMap, HashMap};
use std::sync::Arc;
use std::time::Instant;
use std::time::{Instant, SystemTime};
use api::v1::{RowDeleteRequest, RowDeleteRequests, RowInsertRequest, RowInsertRequests};
use catalog::CatalogManagerRef;
@@ -49,7 +49,7 @@ use crate::adapter::worker::{create_worker, Worker, WorkerHandle};
use crate::compute::ErrCollector;
use crate::expr::GlobalId;
use crate::repr::{self, DiffRow, Row};
use crate::transform::sql_to_flow_plan;
use crate::transform::{register_function_to_query_engine, sql_to_flow_plan};
pub(crate) mod error;
mod flownode_impl;
@@ -120,6 +120,8 @@ impl FlownodeBuilder {
);
let query_engine = query_engine_factory.query_engine();
register_function_to_query_engine(&query_engine);
let (tx, rx) = oneshot::channel();
let node_id = Some(self.flow_node_id);
@@ -261,7 +263,7 @@ impl FlownodeManager {
let ctx = Arc::new(QueryContext::with(&catalog, &schema));
// TODO(discord9): instead of auto build table from request schema, actually build table
// before `create flow` to be able to assign pk and ts etc.
let (primary_keys, schema) = if let Some(table_id) = self
let (primary_keys, schema, is_auto_create) = if let Some(table_id) = self
.table_info_source
.get_table_id_from_name(&table_name)
.await?
@@ -278,54 +280,65 @@ impl FlownodeManager {
.map(|i| meta.schema.column_schemas[i].name.clone())
.collect_vec();
let schema = meta.schema.column_schemas;
(primary_keys, schema)
let is_auto_create = schema
.last()
.map(|s| s.name == "__ts_placeholder")
.unwrap_or(false);
(primary_keys, schema, is_auto_create)
} else {
// TODO(discord9): get ts column from `RelationType` once we are done rewriting flow plan to attach ts
let (primary_keys, schema) = {
let node_ctx = self.node_context.lock().await;
let gid: GlobalId = node_ctx
.table_repr
.get_by_name(&table_name)
.map(|x| x.1)
.unwrap();
let schema = node_ctx
.schema
.get(&gid)
.with_context(|| TableNotFoundSnafu {
name: format!("Table name = {:?}", table_name),
})?
.clone();
// TODO(discord9): use default key from schema
let primary_keys = schema
.keys
.first()
.map(|v| {
v.column_indices
.iter()
.map(|i| format!("Col_{i}"))
.collect_vec()
})
.unwrap_or_default();
let ts_col = ColumnSchema::new(
"ts",
ConcreteDataType::timestamp_millisecond_datatype(),
true,
)
.with_time_index(true);
// TODO(discord9): consider removing the buggy auto-create-by-schema path
let wout_ts = schema
.column_types
.into_iter()
.enumerate()
.map(|(idx, typ)| {
ColumnSchema::new(format!("Col_{idx}"), typ.scalar_type, typ.nullable)
})
.collect_vec();
let mut with_ts = wout_ts.clone();
with_ts.push(ts_col);
(primary_keys, with_ts)
};
(primary_keys, schema)
let node_ctx = self.node_context.lock().await;
let gid: GlobalId = node_ctx
.table_repr
.get_by_name(&table_name)
.map(|x| x.1)
.unwrap();
let schema = node_ctx
.schema
.get(&gid)
.with_context(|| TableNotFoundSnafu {
name: format!("Table name = {:?}", table_name),
})?
.clone();
// TODO(discord9): use default key from schema
let primary_keys = schema
.keys
.first()
.map(|v| {
v.column_indices
.iter()
.map(|i| format!("Col_{i}"))
.collect_vec()
})
.unwrap_or_default();
let update_at = ColumnSchema::new(
"update_at",
ConcreteDataType::timestamp_millisecond_datatype(),
true,
);
// TODO(discord9): the time index can't be inferred from the flow plan yet, so we have to set one manually
let ts_col = ColumnSchema::new(
"__ts_placeholder",
ConcreteDataType::timestamp_millisecond_datatype(),
true,
)
.with_time_index(true);
let wout_ts = schema
.column_types
.into_iter()
.enumerate()
.map(|(idx, typ)| {
ColumnSchema::new(format!("Col_{idx}"), typ.scalar_type, typ.nullable)
})
.collect_vec();
let mut with_ts = wout_ts.clone();
with_ts.push(update_at);
with_ts.push(ts_col);
(primary_keys, with_ts, true)
};
let proto_schema = column_schemas_to_proto(schema, &primary_keys)?;
@@ -336,16 +349,32 @@ impl FlownodeManager {
table_name.join("."),
reqs
);
let now = SystemTime::now();
let now = now
.duration_since(SystemTime::UNIX_EPOCH)
.map(|s| s.as_millis() as repr::Timestamp)
.unwrap_or_else(|_| {
-(SystemTime::UNIX_EPOCH
.duration_since(now)
.unwrap()
.as_millis() as repr::Timestamp)
});
for req in reqs {
match req {
DiffRequest::Insert(insert) => {
let rows_proto: Vec<v1::Row> = insert
.into_iter()
.map(|(mut row, _ts)| {
row.extend(Some(Value::from(
common_time::Timestamp::new_millisecond(0),
)));
// `update_at` col
row.extend([Value::from(common_time::Timestamp::new_millisecond(
now,
))]);
// ts col, if auto create
if is_auto_create {
row.extend([Value::from(
common_time::Timestamp::new_millisecond(0),
)]);
}
row.into()
})
.collect::<Vec<_>>();
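The `update_at` value above is the flow node's current wall-clock time in milliseconds. A small standalone sketch of that conversion (plain `i64` instead of `repr::Timestamp`), which yields a negative offset if the system clock is somehow before the Unix epoch:

use std::time::SystemTime;

fn now_millis() -> i64 {
    let now = SystemTime::now();
    match now.duration_since(SystemTime::UNIX_EPOCH) {
        Ok(d) => d.as_millis() as i64,
        // the clock is before the epoch: express it as a negative millisecond offset
        Err(_) => -(SystemTime::UNIX_EPOCH
            .duration_since(now)
            .expect("now is before the epoch in this branch")
            .as_millis() as i64),
    }
}

fn main() {
    println!("update_at = {} ms since the Unix epoch", now_millis());
}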

View File

@@ -30,7 +30,7 @@ use crate::expr::GlobalId;
use crate::repr::{DiffRow, RelationType, BROADCAST_CAP};
/// A context that holds the information of the dataflow
#[derive(Default)]
#[derive(Default, Debug)]
pub struct FlownodeContext {
/// mapping from source table to tasks, useful for schedule which task to run when a source table is updated
pub source_to_tasks: BTreeMap<TableId, BTreeSet<FlowId>>,
@@ -64,6 +64,7 @@ pub struct FlownodeContext {
///
/// the receiver still uses a tokio broadcast channel, since only the sender side needs to know
/// about backpressure and adjust the dataflow's running duration to avoid blocking
#[derive(Debug)]
pub struct SourceSender {
sender: broadcast::Sender<DiffRow>,
send_buf: VecDeque<DiffRow>,

View File

@@ -223,11 +223,11 @@ mod test {
use hydroflow::scheduled::graph::Hydroflow;
use hydroflow::scheduled::graph_ext::GraphExt;
use hydroflow::scheduled::handoff::VecHandoff;
use pretty_assertions::{assert_eq, assert_ne};
use super::*;
use crate::expr::BinaryFunc;
use crate::repr::Row;
pub fn run_and_check(
state: &mut DataflowState,
df: &mut Hydroflow,

View File

@@ -739,6 +739,7 @@ mod test {
use std::cell::RefCell;
use std::rc::Rc;
use common_time::{DateTime, Interval, Timestamp};
use datatypes::data_type::{ConcreteDataType, ConcreteDataType as CDT};
use hydroflow::scheduled::graph::Hydroflow;
@@ -748,6 +749,165 @@ mod test {
use crate::expr::{self, AggregateFunc, BinaryFunc, GlobalId, MapFilterProject, UnaryFunc};
use crate::repr::{ColumnType, RelationType};
/// SELECT sum(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 second', '2021-07-01 00:00:00')
/// input table columns: number, ts
/// expected: sum(number), window_start, window_end
#[test]
fn test_tumble_group_by() {
let mut df = Hydroflow::new();
let mut state = DataflowState::default();
let mut ctx = harness_test_ctx(&mut df, &mut state);
const START: i64 = 1625097600000;
let rows = vec![
(1u32, START + 1000),
(2u32, START + 1500),
(3u32, START + 2000),
(1u32, START + 2500),
(2u32, START + 3000),
(3u32, START + 3500),
];
let rows = rows
.into_iter()
.map(|(number, ts)| {
(
Row::new(vec![number.into(), Timestamp::new_millisecond(ts).into()]),
1,
1,
)
})
.collect_vec();
let collection = ctx.render_constant(rows.clone());
ctx.insert_global(GlobalId::User(1), collection);
let aggr_expr = AggregateExpr {
func: AggregateFunc::SumUInt32,
expr: ScalarExpr::Column(0),
distinct: false,
};
let expected = TypedPlan {
typ: RelationType::new(vec![
ColumnType::new(CDT::uint64_datatype(), true), // sum(number)
ColumnType::new(CDT::datetime_datatype(), false), // window start
ColumnType::new(CDT::datetime_datatype(), false), // window end
]),
// TODO(discord9): mfp indirectly ref to key columns
/*
.with_key(vec![1])
.with_time_index(Some(0)),*/
plan: Plan::Mfp {
input: Box::new(
Plan::Reduce {
input: Box::new(
Plan::Get {
id: crate::expr::Id::Global(GlobalId::User(1)),
}
.with_types(RelationType::new(vec![
ColumnType::new(ConcreteDataType::uint32_datatype(), false),
ColumnType::new(ConcreteDataType::datetime_datatype(), false),
])),
),
key_val_plan: KeyValPlan {
key_plan: MapFilterProject::new(2)
.map(vec![
ScalarExpr::Column(1).call_unary(
UnaryFunc::TumbleWindowFloor {
window_size: Interval::from_month_day_nano(
0,
0,
1_000_000_000,
),
start_time: Some(DateTime::new(1625097600000)),
},
),
ScalarExpr::Column(1).call_unary(
UnaryFunc::TumbleWindowCeiling {
window_size: Interval::from_month_day_nano(
0,
0,
1_000_000_000,
),
start_time: Some(DateTime::new(1625097600000)),
},
),
])
.unwrap()
.project(vec![2, 3])
.unwrap()
.into_safe(),
val_plan: MapFilterProject::new(2)
.project(vec![0, 1])
.unwrap()
.into_safe(),
},
reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
full_aggrs: vec![aggr_expr.clone()],
simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
distinct_aggrs: vec![],
}),
}
.with_types(
RelationType::new(vec![
ColumnType::new(CDT::datetime_datatype(), false), // window start
ColumnType::new(CDT::datetime_datatype(), false), // window end
ColumnType::new(CDT::uint64_datatype(), true), //sum(number)
])
.with_key(vec![1])
.with_time_index(Some(0)),
),
),
mfp: MapFilterProject::new(3)
.map(vec![
ScalarExpr::Column(2),
ScalarExpr::Column(3),
ScalarExpr::Column(0),
ScalarExpr::Column(1),
])
.unwrap()
.project(vec![4, 5, 6])
.unwrap(),
},
};
let bundle = ctx.render_plan(expected).unwrap();
let output = get_output_handle(&mut ctx, bundle);
drop(ctx);
let expected = BTreeMap::from([(
1,
vec![
(
Row::new(vec![
3u64.into(),
Timestamp::new_millisecond(START + 1000).into(),
Timestamp::new_millisecond(START + 2000).into(),
]),
1,
1,
),
(
Row::new(vec![
4u64.into(),
Timestamp::new_millisecond(START + 2000).into(),
Timestamp::new_millisecond(START + 3000).into(),
]),
1,
1,
),
(
Row::new(vec![
5u64.into(),
Timestamp::new_millisecond(START + 3000).into(),
Timestamp::new_millisecond(START + 4000).into(),
]),
1,
1,
),
],
)]);
run_and_check(&mut state, &mut df, 1..2, expected, output);
}
/// select avg(number) from number;
#[test]
fn test_avg_eval() {

View File

@@ -17,8 +17,10 @@
use std::collections::HashMap;
use std::sync::OnceLock;
use common_error::ext::BoxedError;
use common_telemetry::debug;
use common_time::DateTime;
use common_time::timestamp::TimeUnit;
use common_time::{DateTime, Timestamp};
use datafusion_expr::Operator;
use datatypes::data_type::ConcreteDataType;
use datatypes::types::cast;
@@ -30,14 +32,14 @@ use snafu::{ensure, OptionExt, ResultExt};
use strum::{EnumIter, IntoEnumIterator};
use substrait::df_logical_plan::consumer::name_to_op;
use crate::adapter::error::{Error, InvalidQuerySnafu, PlanSnafu};
use crate::adapter::error::{Error, ExternalSnafu, InvalidQuerySnafu, PlanSnafu};
use crate::expr::error::{
CastValueSnafu, DivisionByZeroSnafu, EvalError, InternalSnafu, TryFromValueSnafu,
TypeMismatchSnafu,
CastValueSnafu, DivisionByZeroSnafu, EvalError, InternalSnafu, OverflowSnafu,
TryFromValueSnafu, TypeMismatchSnafu,
};
use crate::expr::signature::{GenericFn, Signature};
use crate::expr::{InvalidArgumentSnafu, ScalarExpr};
use crate::repr::{value_to_internal_ts, Row};
use crate::expr::{InvalidArgumentSnafu, ScalarExpr, TypedExpr};
use crate::repr::{self, value_to_internal_ts, Row};
/// UnmaterializableFunc is a function that can't be eval independently,
/// and require special handling
@@ -45,6 +47,11 @@ use crate::repr::{value_to_internal_ts, Row};
pub enum UnmaterializableFunc {
Now,
CurrentSchema,
TumbleWindow {
ts: Box<TypedExpr>,
window_size: common_time::Interval,
start_time: Option<DateTime>,
},
}
impl UnmaterializableFunc {
@@ -61,14 +68,51 @@ impl UnmaterializableFunc {
output: ConcreteDataType::string_datatype(),
generic_fn: GenericFn::CurrentSchema,
},
Self::TumbleWindow { .. } => Signature {
input: smallvec![ConcreteDataType::timestamp_millisecond_datatype()],
output: ConcreteDataType::timestamp_millisecond_datatype(),
generic_fn: GenericFn::TumbleWindow,
},
}
}
/// Create a UnmaterializableFunc from a string of the function name
pub fn from_str(name: &str) -> Result<Self, Error> {
match name {
pub fn from_str_args(name: &str, args: Vec<TypedExpr>) -> Result<Self, Error> {
match name.to_lowercase().as_str() {
"now" => Ok(Self::Now),
"current_schema" => Ok(Self::CurrentSchema),
"tumble" => {
let ts = args.first().context(InvalidQuerySnafu {
reason: "Tumble window function requires a timestamp argument",
})?;
let window_size = args
.get(1)
.and_then(|expr| expr.expr.as_literal())
.context(InvalidQuerySnafu {
reason: "Tumble window function requires a window size argument"
})?.as_string() // TODO(discord9): since the df-to-substrait converter does not support the interval type yet, we take a string and cast it to an interval instead
.map(|s|cast(Value::from(s), &ConcreteDataType::interval_month_day_nano_datatype())).transpose().map_err(BoxedError::new).context(
ExternalSnafu
)?.and_then(|v|v.as_interval())
.with_context(||InvalidQuerySnafu {
reason: format!("Tumble window function requires window size argument to be a string describe a interval, found {:?}", args.get(1))
})?;
let start_time = match args.get(2) {
Some(start_time) => start_time.expr.as_literal(),
None => None,
}
.map(|s| cast(s.clone(), &ConcreteDataType::datetime_datatype())).transpose().map_err(BoxedError::new).context(ExternalSnafu)?.map(|v|v.as_datetime().with_context(
||InvalidQuerySnafu {
reason: format!("Tumble window function requires start time argument to be a datetime describe in string, found {:?}", args.get(2))
}
)).transpose()?;
Ok(Self::TumbleWindow {
ts: Box::new(ts.clone()),
window_size,
start_time,
})
}
_ => InvalidQuerySnafu {
reason: format!("Unknown unmaterializable function: {}", name),
}
@@ -87,6 +131,14 @@ pub enum UnaryFunc {
IsFalse,
StepTimestamp,
Cast(ConcreteDataType),
TumbleWindowFloor {
window_size: common_time::Interval,
start_time: Option<DateTime>,
},
TumbleWindowCeiling {
window_size: common_time::Interval,
start_time: Option<DateTime>,
},
}
impl UnaryFunc {
@@ -118,6 +170,16 @@ impl UnaryFunc {
output: to.clone(),
generic_fn: GenericFn::Cast,
},
Self::TumbleWindowFloor { .. } => Signature {
input: smallvec![ConcreteDataType::timestamp_millisecond_datatype()],
output: ConcreteDataType::timestamp_millisecond_datatype(),
generic_fn: GenericFn::TumbleWindow,
},
Self::TumbleWindowCeiling { .. } => Signature {
input: smallvec![ConcreteDataType::timestamp_millisecond_datatype()],
output: ConcreteDataType::timestamp_millisecond_datatype(),
generic_fn: GenericFn::TumbleWindow,
},
}
}
@@ -211,10 +273,51 @@ impl UnaryFunc {
debug!("Cast to type: {to:?}, result: {:?}", res);
res
}
Self::TumbleWindowFloor {
window_size,
start_time,
} => {
let ts = get_ts_as_millisecond(arg)?;
let start_time = start_time.map(|t| t.val()).unwrap_or(0);
let window_size = (window_size.to_nanosecond() / 1_000_000) as repr::Duration; // nanosecond to millisecond
let window_start = start_time + (ts - start_time) / window_size * window_size;
let ret = Timestamp::new_millisecond(window_start);
Ok(Value::from(ret))
}
Self::TumbleWindowCeiling {
window_size,
start_time,
} => {
let ts = get_ts_as_millisecond(arg)?;
let start_time = start_time.map(|t| t.val()).unwrap_or(0);
let window_size = (window_size.to_nanosecond() / 1_000_000) as repr::Duration; // nanosecond to millisecond
let window_start = start_time + (ts - start_time) / window_size * window_size;
let window_end = window_start + window_size;
let ret = Timestamp::new_millisecond(window_end);
Ok(Value::from(ret))
}
}
}
}
fn get_ts_as_millisecond(arg: Value) -> Result<repr::Timestamp, EvalError> {
let ts = if let Some(ts) = arg.as_timestamp() {
ts.convert_to(TimeUnit::Millisecond)
.context(OverflowSnafu)?
.value()
} else if let Some(ts) = arg.as_datetime() {
ts.val()
} else {
InvalidArgumentSnafu {
reason: "Expect input to be timestamp or datetime type",
}
.fail()?
};
Ok(ts)
}
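`get_ts_as_millisecond` is what lets tumble accept both timestamp and datetime arguments (per the commit notes). A simplified sketch of the normalisation with plain enums instead of `Value`, where scaling coarser units up to milliseconds can overflow and therefore stays fallible:

#[derive(Debug)]
enum TsArg {
    /// a timestamp stored at second precision
    TimestampSeconds(i64),
    /// a timestamp already at millisecond precision
    TimestampMillis(i64),
    /// a datetime, internally milliseconds since the epoch (as in the diff above)
    DateTime(i64),
}

fn ts_as_millisecond(arg: TsArg) -> Result<i64, String> {
    match arg {
        // scaling seconds up to milliseconds may overflow i64
        TsArg::TimestampSeconds(s) => s.checked_mul(1_000).ok_or_else(|| "overflow".to_string()),
        TsArg::TimestampMillis(ms) => Ok(ms),
        TsArg::DateTime(ms) => Ok(ms),
    }
}

fn main() {
    assert_eq!(ts_as_millisecond(TsArg::TimestampSeconds(1_625_097_600)), Ok(1_625_097_600_000));
    assert_eq!(ts_as_millisecond(TsArg::TimestampMillis(42)), Ok(42));
    assert_eq!(ts_as_millisecond(TsArg::DateTime(1_625_097_600_000)), Ok(1_625_097_600_000));
}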
/// BinaryFunc is a function that takes two arguments.
/// Also notice this enum doesn't contain function arguments, since the arguments are stored in the expression.
///

View File

@@ -26,10 +26,10 @@ use crate::adapter::error::{
};
use crate::expr::error::{EvalError, InvalidArgumentSnafu, OptimizeSnafu};
use crate::expr::func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc};
use crate::repr::ColumnType;
use crate::repr::{ColumnType, RelationType};
/// A scalar expression with a known type.
#[derive(Debug, Clone)]
#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
pub struct TypedExpr {
/// The expression.
pub expr: ScalarExpr,
@@ -43,7 +43,73 @@ impl TypedExpr {
}
}
/// TODO(discord9): add tumble function here
impl TypedExpr {
/// expand multi-value expression to multiple expressions with new indices
pub fn expand_multi_value(
input_typ: &RelationType,
exprs: &[TypedExpr],
) -> Result<Vec<TypedExpr>, Error> {
// old indices in mfp, expanded expr
let mut ret = vec![];
let input_arity = input_typ.column_types.len();
for (old_idx, expr) in exprs.iter().enumerate() {
if let ScalarExpr::CallUnmaterializable(UnmaterializableFunc::TumbleWindow {
ts,
window_size,
start_time,
}) = &expr.expr
{
let floor = UnaryFunc::TumbleWindowFloor {
window_size: *window_size,
start_time: *start_time,
};
let ceil = UnaryFunc::TumbleWindowCeiling {
window_size: *window_size,
start_time: *start_time,
};
let floor = ScalarExpr::CallUnary {
func: floor,
expr: Box::new(ts.expr.clone()),
}
.with_type(ts.typ.clone());
ret.push((None, floor));
let ceil = ScalarExpr::CallUnary {
func: ceil,
expr: Box::new(ts.expr.clone()),
}
.with_type(ts.typ.clone());
ret.push((None, ceil));
} else {
ret.push((Some(input_arity + old_idx), expr.clone()))
}
}
// get shuffled index(old_idx -> new_idx)
// note: the index is offset by input_arity because mfp is designed to include the input columns first, then the intermediate columns
let shuffle = ret
.iter()
.map(|(old_idx, _)| *old_idx) // [Option<opt_idx>]
.enumerate()
.map(|(new, old)| (old, new + input_arity))
.flat_map(|(old, new)| old.map(|o| (o, new)))
.chain((0..input_arity).map(|i| (i, i))) // also remember to chain the input columns as not changed
.collect::<BTreeMap<_, _>>();
// shuffle expr's index
let exprs = ret
.into_iter()
.map(|(_, mut expr)| {
// invariant: it is expected that no expr will try to refer to the column being expanded
expr.expr.permute_map(&shuffle)?;
Ok(expr)
})
.collect::<Result<Vec<_>, _>>()?;
Ok(dbg!(exprs))
}
}
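A worked example of the index shuffle built above, as a minimal sketch with bare indices instead of `ScalarExpr`: with `input_arity = 2`, one tumble call that expands into floor/ceiling plus one ordinary expression originally at mfp index 3, the shuffle keeps the input columns in place and moves the ordinary expression from 3 to 4.

use std::collections::BTreeMap;

fn main() {
    let input_arity = 2usize;
    // None marks a freshly expanded expression (tumble floor / ceiling);
    // Some(i) marks an ordinary expression that used to live at mfp index i.
    let ret: Vec<Option<usize>> = vec![None, None, Some(input_arity + 1)];
    let shuffle: BTreeMap<usize, usize> = ret
        .into_iter()
        .enumerate()
        .map(|(new, old)| (old, new + input_arity))
        .flat_map(|(old, new)| old.map(|o| (o, new)))
        .chain((0..input_arity).map(|i| (i, i)))
        .collect();
    // input columns 0 and 1 are unchanged; the ordinary expression moved from 3 to 4
    assert_eq!(shuffle, BTreeMap::from([(0, 0), (1, 1), (3, 4)]));
}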
/// A scalar expression, which can be evaluated to a value.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ScalarExpr {
@@ -84,6 +150,10 @@ pub enum ScalarExpr {
}
impl ScalarExpr {
pub fn with_type(self, typ: ColumnType) -> TypedExpr {
TypedExpr::new(self, typ)
}
/// try to determine the type of the expression
pub fn typ(&self, context: &[ColumnType]) -> Result<ColumnType, Error> {
match self {

View File

@@ -64,4 +64,5 @@ pub enum GenericFn {
// unmaterized func
Now,
CurrentSchema,
TumbleWindow,
}

View File

@@ -206,6 +206,15 @@ impl RelationType {
self
}
/// will also remove the time index from any key set it appears in
pub fn with_time_index(mut self, time_index: Option<usize>) -> Self {
self.time_index = time_index;
for key in &mut self.keys {
key.remove_col(time_index.unwrap_or(usize::MAX));
}
self
}
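A quick behavioural sketch of the key removal this performs, using plain collections instead of `RelationType`/`Key` (hypothetical simplified types): setting column 0 as the time index drops it from an existing key set `{0, 1}`.

#[derive(Debug, PartialEq)]
struct Relation {
    keys: Vec<Vec<usize>>,
    time_index: Option<usize>,
}

impl Relation {
    fn with_time_index(mut self, time_index: Option<usize>) -> Self {
        self.time_index = time_index;
        if let Some(ti) = time_index {
            for key in &mut self.keys {
                // drop the time index column from every key set
                key.retain(|&col| col != ti);
            }
        }
        self
    }
}

fn main() {
    let rel = Relation { keys: vec![vec![0, 1]], time_index: None }.with_time_index(Some(0));
    assert_eq!(rel.keys, vec![vec![1]]);
    assert_eq!(rel.time_index, Some(0));
}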
/// Computes the number of columns in the relation.
pub fn arity(&self) -> usize {
self.column_types.len()

View File

@@ -130,12 +130,60 @@ pub async fn sql_to_flow_plan(
Ok(flow_plan)
}
/// register flow-specific functions to the query engine
pub fn register_function_to_query_engine(engine: &Arc<dyn QueryEngine>) {
engine.register_function(Arc::new(TumbleFunction {}));
}
#[derive(Debug)]
pub struct TumbleFunction {}
const TUMBLE_NAME: &str = "tumble";
impl std::fmt::Display for TumbleFunction {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}", TUMBLE_NAME.to_ascii_uppercase())
}
}
impl common_function::function::Function for TumbleFunction {
fn name(&self) -> &str {
TUMBLE_NAME
}
fn return_type(&self, _input_types: &[CDT]) -> common_query::error::Result<CDT> {
Ok(CDT::datetime_datatype())
}
fn signature(&self) -> common_query::prelude::Signature {
common_query::prelude::Signature::variadic_any(common_query::prelude::Volatility::Immutable)
}
fn eval(
&self,
_func_ctx: common_function::function::FunctionContext,
_columns: &[datatypes::prelude::VectorRef],
) -> common_query::error::Result<datatypes::prelude::VectorRef> {
UnexpectedSnafu {
reason: "Tumbler function is not implemented for datafusion executor",
}
.fail()
.map_err(BoxedError::new)
.context(common_query::error::ExecuteSnafu)
}
}
#[cfg(test)]
mod test {
use std::sync::Arc;
use catalog::RegisterTableRequest;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, NUMBERS_TABLE_ID};
use common_time::{Date, DateTime};
use datatypes::prelude::*;
use datatypes::schema::Schema;
use datatypes::vectors::VectorRef;
use itertools::Itertools;
use prost::Message;
use query::parser::QueryLanguageParser;
use query::plan::LogicalPlan;
@@ -144,23 +192,45 @@ mod test {
use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
use substrait_proto::proto;
use table::table::numbers::{NumbersTable, NUMBERS_TABLE_NAME};
use table::test_util::MemTable;
use super::*;
use crate::adapter::node_context::IdToNameMap;
use crate::repr::ColumnType;
pub fn create_test_ctx() -> FlownodeContext {
let gid = GlobalId::User(0);
let name = [
"greptime".to_string(),
"public".to_string(),
"numbers".to_string(),
];
let schema = RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)]);
let mut schemas = HashMap::new();
let mut tri_map = IdToNameMap::new();
tri_map.insert(Some(name.clone()), Some(0), gid);
{
let gid = GlobalId::User(0);
let name = [
"greptime".to_string(),
"public".to_string(),
"numbers".to_string(),
];
let schema = RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)]);
tri_map.insert(Some(name.clone()), Some(1024), gid);
schemas.insert(gid, schema);
}
{
let gid = GlobalId::User(1);
let name = [
"greptime".to_string(),
"public".to_string(),
"numbers_with_ts".to_string(),
];
let schema = RelationType::new(vec![
ColumnType::new(CDT::uint32_datatype(), false),
ColumnType::new(CDT::datetime_datatype(), false),
]);
schemas.insert(gid, schema);
tri_map.insert(Some(name.clone()), Some(1025), gid);
}
FlownodeContext {
schema: HashMap::from([(gid, schema)]),
schema: schemas,
table_repr: tri_map,
query_context: Some(Arc::new(QueryContext::with("greptime", "public"))),
..Default::default()
@@ -177,9 +247,37 @@ mod test {
table: NumbersTable::table(NUMBERS_TABLE_ID),
};
catalog_list.register_table_sync(req).unwrap();
let schema = vec![
datatypes::schema::ColumnSchema::new("number", CDT::uint32_datatype(), false),
datatypes::schema::ColumnSchema::new("ts", CDT::datetime_datatype(), false),
];
let mut columns = vec![];
let numbers = (1..=10).collect_vec();
let column: VectorRef = Arc::new(<u32 as Scalar>::VectorType::from_vec(numbers));
columns.push(column);
let ts = (1..=10).collect_vec();
let column: VectorRef = Arc::new(<DateTime as Scalar>::VectorType::from_vec(ts));
columns.push(column);
let schema = Arc::new(Schema::new(schema));
let recordbatch = common_recordbatch::RecordBatch::new(schema, columns).unwrap();
let table = MemTable::table("numbers_with_ts", recordbatch);
let req_with_ts = RegisterTableRequest {
catalog: DEFAULT_CATALOG_NAME.to_string(),
schema: DEFAULT_SCHEMA_NAME.to_string(),
table_name: "numbers_with_ts".to_string(),
table_id: 1024,
table,
};
catalog_list.register_table_sync(req_with_ts).unwrap();
let factory = query::QueryEngineFactory::new(catalog_list, None, None, None, false);
let engine = factory.query_engine();
engine.register_function(Arc::new(TumbleFunction {}));
assert_eq!("datafusion", engine.name());
engine

View File

@@ -302,8 +302,26 @@ impl TypedPlan {
return not_impl_err!("Aggregate without an input is not supported");
};
let group_exprs =
TypedExpr::from_substrait_agg_grouping(ctx, &agg.groupings, &input.typ, extensions)?;
let group_exprs = {
let group_exprs = TypedExpr::from_substrait_agg_grouping(
ctx,
&agg.groupings,
&input.typ,
extensions,
)?;
TypedExpr::expand_multi_value(&input.typ, &group_exprs)?
};
let time_index = group_exprs.iter().position(|expr| {
matches!(
&expr.expr,
ScalarExpr::CallUnary {
func: UnaryFunc::TumbleWindowFloor { .. },
expr: _
}
)
});
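The time index is simply the position of the floor expression among the group keys. A bare-bones sketch of that detection with a plain enum as a hypothetical stand-in for `ScalarExpr`:

#[derive(Debug)]
enum GroupExpr {
    Column(usize),
    TumbleFloor,
    TumbleCeiling,
}

fn main() {
    let group_exprs = vec![GroupExpr::TumbleFloor, GroupExpr::TumbleCeiling, GroupExpr::Column(0)];
    // the window-start (floor) column becomes the time index of the reduce output
    let time_index = group_exprs.iter().position(|e| matches!(e, GroupExpr::TumbleFloor));
    assert_eq!(time_index, Some(0));
}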
let (mut aggr_exprs, post_mfp) =
AggregateExpr::from_substrait_agg_measures(ctx, &agg.measures, &input.typ, extensions)?;
@@ -314,6 +332,7 @@ impl TypedPlan {
input.typ.column_types.len(),
)?;
// output type is group_exprs + aggr_exprs
let output_type = {
let mut output_types = Vec::new();
// first append group_expr as key, then aggr_expr as value
@@ -332,7 +351,8 @@ impl TypedPlan {
} else {
RelationType::new(output_types).with_key((0..group_exprs.len()).collect_vec())
}
};
}
.with_time_index(time_index);
// copy aggr_exprs to full_aggrs, and split them into simple_aggrs and distinct_aggrs
// also set them input/output column
@@ -406,6 +426,7 @@ impl TypedPlan {
#[cfg(test)]
mod test {
use common_time::{DateTime, Interval};
use datatypes::prelude::ConcreteDataType;
use pretty_assertions::{assert_eq, assert_ne};
@@ -414,6 +435,106 @@ mod test {
use crate::repr::{self, ColumnType, RelationType};
use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait};
#[tokio::test]
async fn test_tumble_parse() {
let engine = create_test_query_engine();
let sql = "SELECT sum(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour', '2021-07-01 00:00:00')";
let plan = sql_to_substrait(engine.clone(), sql).await;
let mut ctx = create_test_ctx();
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap();
let aggr_expr = AggregateExpr {
func: AggregateFunc::SumUInt32,
expr: ScalarExpr::Column(0),
distinct: false,
};
let expected = TypedPlan {
typ: RelationType::new(vec![
ColumnType::new(CDT::uint64_datatype(), true), // sum(number)
ColumnType::new(CDT::datetime_datatype(), false), // window start
ColumnType::new(CDT::datetime_datatype(), false), // window end
]),
// TODO(discord9): mfp indirectly ref to key columns
/*
.with_key(vec![1])
.with_time_index(Some(0)),*/
plan: Plan::Mfp {
input: Box::new(
Plan::Reduce {
input: Box::new(
Plan::Get {
id: crate::expr::Id::Global(GlobalId::User(1)),
}
.with_types(RelationType::new(vec![
ColumnType::new(ConcreteDataType::uint32_datatype(), false),
ColumnType::new(ConcreteDataType::datetime_datatype(), false),
])),
),
key_val_plan: KeyValPlan {
key_plan: MapFilterProject::new(2)
.map(vec![
ScalarExpr::Column(1).call_unary(
UnaryFunc::TumbleWindowFloor {
window_size: Interval::from_month_day_nano(
0,
0,
3_600_000_000_000,
),
start_time: Some(DateTime::new(1625097600000)),
},
),
ScalarExpr::Column(1).call_unary(
UnaryFunc::TumbleWindowCeiling {
window_size: Interval::from_month_day_nano(
0,
0,
3_600_000_000_000,
),
start_time: Some(DateTime::new(1625097600000)),
},
),
])
.unwrap()
.project(vec![2, 3])
.unwrap()
.into_safe(),
val_plan: MapFilterProject::new(2)
.project(vec![0, 1])
.unwrap()
.into_safe(),
},
reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
full_aggrs: vec![aggr_expr.clone()],
simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
distinct_aggrs: vec![],
}),
}
.with_types(
RelationType::new(vec![
ColumnType::new(CDT::datetime_datatype(), false), // window start
ColumnType::new(CDT::datetime_datatype(), false), // window end
ColumnType::new(CDT::uint64_datatype(), true), //sum(number)
])
.with_key(vec![1])
.with_time_index(Some(0)),
),
),
mfp: MapFilterProject::new(3)
.map(vec![
ScalarExpr::Column(2),
ScalarExpr::Column(3),
ScalarExpr::Column(0),
ScalarExpr::Column(1),
])
.unwrap()
.project(vec![4, 5, 6])
.unwrap(),
},
};
assert_eq!(flow_plan, expected);
}
#[tokio::test]
async fn test_avg_group_by() {
let engine = create_test_query_engine();
@@ -514,7 +635,8 @@ mod test {
let plan = sql_to_substrait(engine.clone(), sql).await;
let mut ctx = create_test_ctx();
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap();
let aggr_exprs = vec![
AggregateExpr {
@@ -587,7 +709,7 @@ mod test {
.unwrap(),
},
};
assert_eq!(flow_plan.unwrap(), expected);
assert_eq!(flow_plan, expected);
}
#[tokio::test]

View File

@@ -71,7 +71,7 @@ impl TypedExpr {
),
})?;
let arg_len = f.arguments.len();
let arg_exprs: Vec<TypedExpr> = f
let arg_typed_exprs: Vec<TypedExpr> = f
.arguments
.iter()
.map(|arg| match &arg.arg_type {
@@ -83,7 +83,8 @@ impl TypedExpr {
.try_collect()?;
// literal's type is determined by the function and type of other args
let (arg_exprs, arg_types): (Vec<_>, Vec<_>) = arg_exprs
let (arg_exprs, arg_types): (Vec<_>, Vec<_>) = arg_typed_exprs
.clone()
.into_iter()
.map(
|TypedExpr {
@@ -174,7 +175,9 @@ impl TypedExpr {
};
expr.optimize();
Ok(TypedExpr::new(expr, ret_type))
} else if let Ok(func) = UnmaterializableFunc::from_str(fn_name) {
} else if let Ok(func) =
UnmaterializableFunc::from_str_args(fn_name, arg_typed_exprs)
{
let ret_type = ColumnType::new_nullable(func.signature().output.clone());
Ok(TypedExpr::new(
ScalarExpr::CallUnmaterializable(func),

View File

@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::BTreeMap;
use itertools::Itertools;
use snafu::OptionExt;
use substrait_proto::proto::expression::MaskExpression;
@@ -22,8 +24,8 @@ use substrait_proto::proto::{plan_rel, Plan as SubPlan, Rel};
use crate::adapter::error::{
Error, InvalidQuerySnafu, NotImplementedSnafu, PlanSnafu, UnexpectedSnafu,
};
use crate::expr::{MapFilterProject, TypedExpr};
use crate::plan::{Plan, TypedPlan};
use crate::expr::{MapFilterProject, ScalarExpr, TypedExpr, UnaryFunc};
use crate::plan::{KeyValPlan, Plan, ReducePlan, TypedPlan};
use crate::repr::{self, RelationType};
use crate::transform::{substrait_proto, FlownodeContext, FunctionExtensions};
@@ -75,6 +77,7 @@ impl TypedPlan {
} else {
return not_impl_err!("Projection without an input is not supported");
};
let mut exprs: Vec<TypedExpr> = vec![];
for e in &p.expressions {
let expr = TypedExpr::from_substrait_rex(e, &input.typ, extensions)?;
@@ -97,6 +100,127 @@ impl TypedPlan {
};
Ok(TypedPlan { typ, plan })
} else {
/// if the reduce plan's keys contain special functions like tumble floor/ceiling, add them to the proj_exprs
fn rewrite_projection_after_reduce(
key_val_plan: KeyValPlan,
_reduce_plan: ReducePlan,
reduce_output_type: &RelationType,
proj_exprs: &mut Vec<TypedExpr>,
) -> Result<(), Error> {
// TODO: get keys correctly
let key_exprs = key_val_plan
.key_plan
.projection
.clone()
.into_iter()
.map(|i| {
if i < key_val_plan.key_plan.input_arity {
ScalarExpr::Column(i)
} else {
key_val_plan.key_plan.expressions
[i - key_val_plan.key_plan.input_arity]
.clone()
}
})
.collect_vec();
let mut shift_offset = 0;
let special_keys = key_exprs
.into_iter()
.enumerate()
.filter(|(_idx, p)| {
if matches!(
p,
ScalarExpr::CallUnary {
func: UnaryFunc::TumbleWindowFloor { .. },
..
} | ScalarExpr::CallUnary {
func: UnaryFunc::TumbleWindowCeiling { .. },
..
}
) {
if matches!(
p,
ScalarExpr::CallUnary {
func: UnaryFunc::TumbleWindowFloor { .. },
..
}
) {
shift_offset += 1;
}
true
} else {
false
}
})
.collect_vec();
let spec_key_arity = special_keys.len();
if spec_key_arity == 0 {
return Ok(());
}
{
// shift proj_exprs to the right by spec_key_arity
let max_used_col_in_proj = proj_exprs
.iter()
.map(|expr| {
expr.expr
.get_all_ref_columns()
.into_iter()
.max()
.unwrap_or_default()
})
.max()
.unwrap_or_default();
let shuffle = (0..=max_used_col_in_proj)
.map(|col| (col, col + shift_offset))
.collect::<BTreeMap<_, _>>();
for proj_expr in proj_exprs.iter_mut() {
proj_expr.expr.permute_map(&shuffle)?;
}
// add the special key columns to the end
for (key_idx, _key_expr) in special_keys {
// here we assume the output of the reduce operator is the key columns first, followed by the value columns
proj_exprs.push(
ScalarExpr::Column(key_idx).with_type(
reduce_output_type.column_types[key_idx].clone(),
),
);
}
}
Ok(())
}
match input.plan.clone() {
Plan::Reduce {
key_val_plan,
reduce_plan,
..
} => {
rewrite_projection_after_reduce(
key_val_plan,
reduce_plan,
&input.typ,
&mut exprs,
)?;
}
Plan::Mfp { input, mfp: _ } => {
if let Plan::Reduce {
key_val_plan,
reduce_plan,
..
} = input.plan
{
rewrite_projection_after_reduce(
key_val_plan,
reduce_plan,
&input.typ,
&mut exprs,
)?;
}
}
_ => (),
}
input.projection(exprs)
}
}