mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-06 05:12:54 +00:00
refactor(flow): make from_substrait_* async& worker handle refactor (#4210)
* refactor: use oneshot to receive result * refactor: make from_substrait_* async * refacrot: remove serde for plan&expr
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -3949,6 +3949,7 @@ version = "0.8.2"
|
||||
dependencies = [
|
||||
"api",
|
||||
"arrow-schema",
|
||||
"async-recursion",
|
||||
"async-trait",
|
||||
"bytes",
|
||||
"catalog",
|
||||
|
||||
@@ -10,6 +10,7 @@ workspace = true
|
||||
[dependencies]
|
||||
api.workspace = true
|
||||
arrow-schema.workspace = true
|
||||
async-recursion = "1.0"
|
||||
async-trait.workspace = true
|
||||
bytes.workspace = true
|
||||
catalog.workspace = true
|
||||
|
||||
@@ -15,14 +15,14 @@
|
||||
//! For single-thread flow worker
|
||||
|
||||
use std::collections::{BTreeMap, VecDeque};
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_telemetry::info;
|
||||
use enum_as_inner::EnumAsInner;
|
||||
use hydroflow::scheduled::graph::Hydroflow;
|
||||
use snafu::{ensure, OptionExt};
|
||||
use tokio::sync::{broadcast, mpsc, Mutex};
|
||||
use snafu::ensure;
|
||||
use tokio::sync::{broadcast, mpsc, oneshot, Mutex};
|
||||
|
||||
use crate::adapter::error::{Error, FlowAlreadyExistSnafu, InternalSnafu, UnexpectedSnafu};
|
||||
use crate::adapter::FlowId;
|
||||
@@ -39,7 +39,7 @@ type ReqId = usize;
|
||||
pub fn create_worker<'a>() -> (WorkerHandle, Worker<'a>) {
|
||||
let (itc_client, itc_server) = create_inter_thread_call();
|
||||
let worker_handle = WorkerHandle {
|
||||
itc_client: Mutex::new(itc_client),
|
||||
itc_client,
|
||||
shutdown: AtomicBool::new(false),
|
||||
};
|
||||
let worker = Worker {
|
||||
@@ -106,7 +106,7 @@ impl<'subgraph> ActiveDataflowState<'subgraph> {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct WorkerHandle {
|
||||
itc_client: Mutex<InterThreadCallClient>,
|
||||
itc_client: InterThreadCallClient,
|
||||
shutdown: AtomicBool,
|
||||
}
|
||||
|
||||
@@ -122,12 +122,7 @@ impl WorkerHandle {
|
||||
}
|
||||
);
|
||||
|
||||
let ret = self
|
||||
.itc_client
|
||||
.lock()
|
||||
.await
|
||||
.call_with_resp(create_reqs)
|
||||
.await?;
|
||||
let ret = self.itc_client.call_with_resp(create_reqs).await?;
|
||||
ret.into_create().map_err(|ret| {
|
||||
InternalSnafu {
|
||||
reason: format!(
|
||||
@@ -141,7 +136,8 @@ impl WorkerHandle {
|
||||
/// remove task, return task id
|
||||
pub async fn remove_flow(&self, flow_id: FlowId) -> Result<bool, Error> {
|
||||
let req = Request::Remove { flow_id };
|
||||
let ret = self.itc_client.lock().await.call_with_resp(req).await?;
|
||||
|
||||
let ret = self.itc_client.call_with_resp(req).await?;
|
||||
|
||||
ret.into_remove().map_err(|ret| {
|
||||
InternalSnafu {
|
||||
@@ -157,15 +153,12 @@ impl WorkerHandle {
|
||||
///
|
||||
/// the returned error is unrecoverable, and the worker should be shutdown/rebooted
|
||||
pub async fn run_available(&self, now: repr::Timestamp) -> Result<(), Error> {
|
||||
self.itc_client
|
||||
.lock()
|
||||
.await
|
||||
.call_no_resp(Request::RunAvail { now })
|
||||
self.itc_client.call_no_resp(Request::RunAvail { now })
|
||||
}
|
||||
|
||||
pub async fn contains_flow(&self, flow_id: FlowId) -> Result<bool, Error> {
|
||||
let req = Request::ContainTask { flow_id };
|
||||
let ret = self.itc_client.lock().await.call_with_resp(req).await?;
|
||||
let ret = self.itc_client.call_with_resp(req).await?;
|
||||
|
||||
ret.into_contain_task().map_err(|ret| {
|
||||
InternalSnafu {
|
||||
@@ -178,23 +171,9 @@ impl WorkerHandle {
|
||||
}
|
||||
|
||||
/// shutdown the worker
|
||||
pub async fn shutdown(&self) -> Result<(), Error> {
|
||||
pub fn shutdown(&self) -> Result<(), Error> {
|
||||
if !self.shutdown.fetch_or(true, Ordering::SeqCst) {
|
||||
self.itc_client.lock().await.call_no_resp(Request::Shutdown)
|
||||
} else {
|
||||
UnexpectedSnafu {
|
||||
reason: "Worker already shutdown",
|
||||
}
|
||||
.fail()
|
||||
}
|
||||
}
|
||||
|
||||
/// shutdown the worker
|
||||
pub fn shutdown_blocking(&self) -> Result<(), Error> {
|
||||
if !self.shutdown.fetch_or(true, Ordering::SeqCst) {
|
||||
self.itc_client
|
||||
.blocking_lock()
|
||||
.call_no_resp(Request::Shutdown)
|
||||
self.itc_client.call_no_resp(Request::Shutdown)
|
||||
} else {
|
||||
UnexpectedSnafu {
|
||||
reason: "Worker already shutdown",
|
||||
@@ -206,8 +185,7 @@ impl WorkerHandle {
|
||||
|
||||
impl Drop for WorkerHandle {
|
||||
fn drop(&mut self) {
|
||||
let ret = futures::executor::block_on(async { self.shutdown().await });
|
||||
if let Err(ret) = ret {
|
||||
if let Err(ret) = self.shutdown() {
|
||||
common_telemetry::error!(
|
||||
ret;
|
||||
"While dropping Worker Handle, failed to shutdown worker, worker might be in inconsistent state."
|
||||
@@ -276,7 +254,7 @@ impl<'s> Worker<'s> {
|
||||
/// Run the worker, blocking, until shutdown signal is received
|
||||
pub fn run(&mut self) {
|
||||
loop {
|
||||
let (req_id, req) = if let Some(ret) = self.itc_server.blocking_lock().blocking_recv() {
|
||||
let (req, ret_tx) = if let Some(ret) = self.itc_server.blocking_lock().blocking_recv() {
|
||||
ret
|
||||
} else {
|
||||
common_telemetry::error!(
|
||||
@@ -285,19 +263,26 @@ impl<'s> Worker<'s> {
|
||||
break;
|
||||
};
|
||||
|
||||
let ret = self.handle_req(req_id, req);
|
||||
match ret {
|
||||
Ok(Some((id, resp))) => {
|
||||
if let Err(err) = self.itc_server.blocking_lock().resp(id, resp) {
|
||||
let ret = self.handle_req(req);
|
||||
match (ret, ret_tx) {
|
||||
(Ok(Some(resp)), Some(ret_tx)) => {
|
||||
if let Err(err) = ret_tx.send(resp) {
|
||||
common_telemetry::error!(
|
||||
err;
|
||||
"Worker's itc server has been closed unexpectedly, shutting down worker"
|
||||
"Result receiver is dropped, can't send result"
|
||||
);
|
||||
break;
|
||||
};
|
||||
}
|
||||
Ok(None) => continue,
|
||||
Err(()) => {
|
||||
(Ok(None), None) => continue,
|
||||
(Ok(Some(resp)), None) => {
|
||||
common_telemetry::error!(
|
||||
"Expect no result for current request, but found {resp:?}"
|
||||
)
|
||||
}
|
||||
(Ok(None), Some(_)) => {
|
||||
common_telemetry::error!("Expect result for current request, but found nothing")
|
||||
}
|
||||
(Err(()), _) => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -315,7 +300,7 @@ impl<'s> Worker<'s> {
|
||||
/// handle request, return response if any, Err if receive shutdown signal
|
||||
///
|
||||
/// return `Err(())` if receive shutdown request
|
||||
fn handle_req(&mut self, req_id: ReqId, req: Request) -> Result<Option<(ReqId, Response)>, ()> {
|
||||
fn handle_req(&mut self, req: Request) -> Result<Option<Response>, ()> {
|
||||
let ret = match req {
|
||||
Request::Create {
|
||||
flow_id,
|
||||
@@ -339,16 +324,13 @@ impl<'s> Worker<'s> {
|
||||
create_if_not_exists,
|
||||
err_collector,
|
||||
);
|
||||
Some((
|
||||
req_id,
|
||||
Response::Create {
|
||||
result: task_create_result,
|
||||
},
|
||||
))
|
||||
Some(Response::Create {
|
||||
result: task_create_result,
|
||||
})
|
||||
}
|
||||
Request::Remove { flow_id } => {
|
||||
let ret = self.remove_flow(flow_id);
|
||||
Some((req_id, Response::Remove { result: ret }))
|
||||
Some(Response::Remove { result: ret })
|
||||
}
|
||||
Request::RunAvail { now } => {
|
||||
self.run_tick(now);
|
||||
@@ -356,7 +338,7 @@ impl<'s> Worker<'s> {
|
||||
}
|
||||
Request::ContainTask { flow_id } => {
|
||||
let ret = self.task_states.contains_key(&flow_id);
|
||||
Some((req_id, Response::ContainTask { result: ret }))
|
||||
Some(Response::ContainTask { result: ret })
|
||||
}
|
||||
Request::Shutdown => return Err(()),
|
||||
};
|
||||
@@ -406,83 +388,50 @@ enum Response {
|
||||
|
||||
fn create_inter_thread_call() -> (InterThreadCallClient, InterThreadCallServer) {
|
||||
let (arg_send, arg_recv) = mpsc::unbounded_channel();
|
||||
let (ret_send, ret_recv) = mpsc::unbounded_channel();
|
||||
let client = InterThreadCallClient {
|
||||
call_id: AtomicUsize::new(0),
|
||||
arg_sender: arg_send,
|
||||
ret_recv,
|
||||
};
|
||||
let server = InterThreadCallServer {
|
||||
arg_recv,
|
||||
ret_sender: ret_send,
|
||||
};
|
||||
let server = InterThreadCallServer { arg_recv };
|
||||
(client, server)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct InterThreadCallClient {
|
||||
call_id: AtomicUsize,
|
||||
arg_sender: mpsc::UnboundedSender<(ReqId, Request)>,
|
||||
ret_recv: mpsc::UnboundedReceiver<(ReqId, Response)>,
|
||||
arg_sender: mpsc::UnboundedSender<(Request, Option<oneshot::Sender<Response>>)>,
|
||||
}
|
||||
|
||||
impl InterThreadCallClient {
|
||||
/// call without expecting responses or blocking
|
||||
fn call_no_resp(&self, req: Request) -> Result<(), Error> {
|
||||
// TODO(discord9): relax memory order later
|
||||
let call_id = self.call_id.fetch_add(1, Ordering::SeqCst);
|
||||
self.arg_sender
|
||||
.send((call_id, req))
|
||||
.map_err(from_send_error)
|
||||
self.arg_sender.send((req, None)).map_err(from_send_error)
|
||||
}
|
||||
|
||||
/// call blocking, and return the result
|
||||
async fn call_with_resp(&mut self, req: Request) -> Result<Response, Error> {
|
||||
// TODO(discord9): relax memory order later
|
||||
let call_id = self.call_id.fetch_add(1, Ordering::SeqCst);
|
||||
async fn call_with_resp(&self, req: Request) -> Result<Response, Error> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.arg_sender
|
||||
.send((call_id, req))
|
||||
.send((req, Some(tx)))
|
||||
.map_err(from_send_error)?;
|
||||
|
||||
// TODO(discord9): better inter thread call impl, i.e. support multiple client(also consider if it's necessary)
|
||||
// since one node manger might manage multiple worker, but one worker should only belong to one node manager
|
||||
let (ret_call_id, ret) = self
|
||||
.ret_recv
|
||||
.recv()
|
||||
.await
|
||||
.context(InternalSnafu { reason: "InterThreadCallClient call_blocking failed, ret_recv has been closed and there are no remaining messages in the channel's buffer" })?;
|
||||
|
||||
ensure!(
|
||||
ret_call_id == call_id,
|
||||
rx.await.map_err(|_| {
|
||||
InternalSnafu {
|
||||
reason: "call id mismatch, worker/worker handler should be in sync",
|
||||
reason: "Sender is dropped",
|
||||
}
|
||||
);
|
||||
Ok(ret)
|
||||
.build()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct InterThreadCallServer {
|
||||
pub arg_recv: mpsc::UnboundedReceiver<(ReqId, Request)>,
|
||||
pub ret_sender: mpsc::UnboundedSender<(ReqId, Response)>,
|
||||
pub arg_recv: mpsc::UnboundedReceiver<(Request, Option<oneshot::Sender<Response>>)>,
|
||||
}
|
||||
|
||||
impl InterThreadCallServer {
|
||||
pub async fn recv(&mut self) -> Option<(usize, Request)> {
|
||||
pub async fn recv(&mut self) -> Option<(Request, Option<oneshot::Sender<Response>>)> {
|
||||
self.arg_recv.recv().await
|
||||
}
|
||||
|
||||
pub fn blocking_recv(&mut self) -> Option<(usize, Request)> {
|
||||
pub fn blocking_recv(&mut self) -> Option<(Request, Option<oneshot::Sender<Response>>)> {
|
||||
self.arg_recv.blocking_recv()
|
||||
}
|
||||
|
||||
/// Send response back to the client
|
||||
pub fn resp(&self, call_id: ReqId, resp: Response) -> Result<(), Error> {
|
||||
self.ret_sender
|
||||
.send((call_id, resp))
|
||||
.map_err(from_send_error)
|
||||
}
|
||||
}
|
||||
|
||||
fn from_send_error<T>(err: mpsc::error::SendError<T>) -> Error {
|
||||
@@ -546,7 +495,10 @@ mod test {
|
||||
create_if_not_exists: true,
|
||||
err_collector: ErrCollector::default(),
|
||||
};
|
||||
handle.create_flow(create_reqs).await.unwrap();
|
||||
assert_eq!(
|
||||
handle.create_flow(create_reqs).await.unwrap(),
|
||||
Some(flow_id)
|
||||
);
|
||||
tx.send((Row::empty(), 0, 0)).unwrap();
|
||||
handle.run_available(0).await.unwrap();
|
||||
assert_eq!(sink_rx.recv().await.unwrap().0, Row::empty());
|
||||
|
||||
@@ -43,7 +43,7 @@ use crate::repr::{self, value_to_internal_ts, Row};
|
||||
|
||||
/// UnmaterializableFunc is a function that can't be eval independently,
|
||||
/// and require special handling
|
||||
#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
|
||||
#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Hash)]
|
||||
pub enum UnmaterializableFunc {
|
||||
Now,
|
||||
CurrentSchema,
|
||||
|
||||
@@ -49,7 +49,7 @@ use crate::repr::{self, value_to_internal_ts, Diff, Row};
|
||||
/// expressions in `self.expressions`, even though this is not something
|
||||
/// we can directly evaluate. The plan creation methods will defensively
|
||||
/// ensure that the right thing happens.
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
|
||||
pub struct MapFilterProject {
|
||||
/// A sequence of expressions that should be appended to the row.
|
||||
///
|
||||
@@ -462,7 +462,7 @@ impl MapFilterProject {
|
||||
}
|
||||
|
||||
/// A wrapper type which indicates it is safe to simply evaluate all expressions.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Ord, PartialOrd)]
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
|
||||
pub struct SafeMfpPlan {
|
||||
/// the inner `MapFilterProject` that is safe to evaluate.
|
||||
pub(crate) mfp: MapFilterProject,
|
||||
|
||||
@@ -23,7 +23,7 @@ mod accum;
|
||||
mod func;
|
||||
|
||||
/// Describes an aggregation expression.
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
|
||||
pub struct AggregateExpr {
|
||||
/// Names the aggregation function.
|
||||
pub func: AggregateFunc,
|
||||
@@ -32,6 +32,5 @@ pub struct AggregateExpr {
|
||||
/// so it only used in generate KeyValPlan from AggregateExpr
|
||||
pub expr: ScalarExpr,
|
||||
/// Should the aggregation be applied only to distinct results in each group.
|
||||
#[serde(default)]
|
||||
pub distinct: bool,
|
||||
}
|
||||
|
||||
@@ -43,7 +43,7 @@ use crate::expr::func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFun
|
||||
use crate::repr::{ColumnType, RelationDesc, RelationType};
|
||||
use crate::transform::{from_scalar_fn_to_df_fn_impl, FunctionExtensions};
|
||||
/// A scalar expression with a known type.
|
||||
#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
|
||||
#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Hash)]
|
||||
pub struct TypedExpr {
|
||||
/// The expression.
|
||||
pub expr: ScalarExpr,
|
||||
@@ -129,7 +129,7 @@ impl TypedExpr {
|
||||
}
|
||||
|
||||
/// A scalar expression, which can be evaluated to a value.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
pub enum ScalarExpr {
|
||||
/// A column of the input row
|
||||
Column(usize),
|
||||
@@ -191,9 +191,9 @@ impl DfScalarFunction {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn try_from_raw_fn(raw_fn: RawDfScalarFn) -> Result<Self, Error> {
|
||||
pub async fn try_from_raw_fn(raw_fn: RawDfScalarFn) -> Result<Self, Error> {
|
||||
Ok(Self {
|
||||
fn_impl: raw_fn.get_fn_impl()?,
|
||||
fn_impl: raw_fn.get_fn_impl().await?,
|
||||
df_schema: Arc::new(raw_fn.input_schema.to_df_schema()?),
|
||||
raw_fn,
|
||||
})
|
||||
@@ -264,27 +264,7 @@ impl DfScalarFunction {
|
||||
}
|
||||
}
|
||||
|
||||
// simply serialize the raw_fn instead of derive to avoid complex deserialize of struct
|
||||
impl Serialize for DfScalarFunction {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
self.raw_fn.serialize(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> serde::de::Deserialize<'de> for DfScalarFunction {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::de::Deserializer<'de>,
|
||||
{
|
||||
let raw_fn = RawDfScalarFn::deserialize(deserializer)?;
|
||||
DfScalarFunction::try_from_raw_fn(raw_fn).map_err(serde::de::Error::custom)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
pub struct RawDfScalarFn {
|
||||
/// The raw bytes encoded datafusion scalar function
|
||||
pub(crate) f: bytes::BytesMut,
|
||||
@@ -311,7 +291,7 @@ impl RawDfScalarFn {
|
||||
extensions,
|
||||
})
|
||||
}
|
||||
fn get_fn_impl(&self) -> Result<Arc<dyn PhysicalExpr>, Error> {
|
||||
async fn get_fn_impl(&self) -> Result<Arc<dyn PhysicalExpr>, Error> {
|
||||
let f = ScalarFunction::decode(&mut self.f.as_ref())
|
||||
.context(DecodeRelSnafu)
|
||||
.map_err(BoxedError::new)
|
||||
@@ -320,7 +300,7 @@ impl RawDfScalarFn {
|
||||
let input_schema = &self.input_schema;
|
||||
let extensions = &self.extensions;
|
||||
|
||||
from_scalar_fn_to_df_fn_impl(&f, input_schema, extensions)
|
||||
from_scalar_fn_to_df_fn_impl(&f, input_schema, extensions).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -894,10 +874,7 @@ mod test {
|
||||
.unwrap();
|
||||
let extensions = FunctionExtensions::from_iter(vec![(0, "abs")]);
|
||||
let raw_fn = RawDfScalarFn::from_proto(&raw_scalar_func, input_schema, extensions).unwrap();
|
||||
let df_func = DfScalarFunction::try_from_raw_fn(raw_fn).unwrap();
|
||||
let as_str = serde_json::to_string(&df_func).unwrap();
|
||||
let from_str: DfScalarFunction = serde_json::from_str(&as_str).unwrap();
|
||||
assert_eq!(df_func, from_str);
|
||||
let df_func = DfScalarFunction::try_from_raw_fn(raw_fn).await.unwrap();
|
||||
assert_eq!(
|
||||
df_func
|
||||
.eval(&[Value::Null], &[ScalarExpr::Column(0)])
|
||||
|
||||
@@ -33,7 +33,7 @@ pub(crate) use crate::plan::reduce::{AccumulablePlan, AggrWithIndex, KeyValPlan,
|
||||
use crate::repr::{ColumnType, DiffRow, RelationDesc, RelationType};
|
||||
|
||||
/// A plan for a dataflow component. But with type to indicate the output type of the relation.
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)]
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
|
||||
pub struct TypedPlan {
|
||||
/// output type of the relation
|
||||
pub schema: RelationDesc,
|
||||
@@ -121,7 +121,7 @@ impl TypedPlan {
|
||||
|
||||
/// TODO(discord9): support `TableFunc`(by define FlatMap that map 1 to n)
|
||||
/// Plan describe how to transform data in dataflow
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)]
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
|
||||
pub enum Plan {
|
||||
/// A constant collection of rows.
|
||||
Constant { rows: Vec<DiffRow> },
|
||||
|
||||
@@ -18,13 +18,13 @@ use crate::expr::ScalarExpr;
|
||||
use crate::plan::SafeMfpPlan;
|
||||
|
||||
/// TODO(discord9): consider impl more join strategies
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Ord, PartialOrd)]
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
|
||||
pub enum JoinPlan {
|
||||
Linear(LinearJoinPlan),
|
||||
}
|
||||
|
||||
/// Determine if a given row should stay in the output. And apply a map filter project before output the row
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Ord, PartialOrd)]
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
|
||||
pub struct JoinFilter {
|
||||
/// each element in the outer vector will check if each expr in itself can be eval to same value
|
||||
/// if not, the row will be filtered out. Useful for equi-join(join based on equality of some columns)
|
||||
@@ -37,7 +37,7 @@ pub struct JoinFilter {
|
||||
///
|
||||
/// A linear join is a sequence of stages, each of which introduces
|
||||
/// a new collection. Each stage is represented by a [LinearStagePlan].
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Ord, PartialOrd)]
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
|
||||
pub struct LinearJoinPlan {
|
||||
/// The source relation from which we start the join.
|
||||
pub source_relation: usize,
|
||||
@@ -60,7 +60,7 @@ pub struct LinearJoinPlan {
|
||||
/// Each stage is a binary join between the current accumulated
|
||||
/// join results, and a new collection. The former is referred to
|
||||
/// as the "stream" and the latter the "lookup".
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Ord, PartialOrd)]
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
|
||||
pub struct LinearStagePlan {
|
||||
/// The index of the relation into which we will look up.
|
||||
pub lookup_relation: usize,
|
||||
|
||||
@@ -17,7 +17,7 @@ use serde::{Deserialize, Serialize};
|
||||
use crate::expr::{AggregateExpr, Id, LocalId, MapFilterProject, SafeMfpPlan, ScalarExpr};
|
||||
|
||||
/// Describe how to extract key-value pair from a `Row`
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)]
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
|
||||
pub struct KeyValPlan {
|
||||
/// Extract key from row
|
||||
pub key_plan: SafeMfpPlan,
|
||||
@@ -27,7 +27,7 @@ pub struct KeyValPlan {
|
||||
|
||||
/// TODO(discord9): def&impl of Hierarchical aggregates(for min/max with support to deletion) and
|
||||
/// basic aggregates(for other aggregate functions) and mixed aggregate
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)]
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
|
||||
pub enum ReducePlan {
|
||||
/// Plan for not computing any aggregations, just determining the set of
|
||||
/// distinct keys.
|
||||
@@ -38,7 +38,7 @@ pub enum ReducePlan {
|
||||
}
|
||||
|
||||
/// Accumulable plan for the execution of a reduction.
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)]
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
|
||||
pub struct AccumulablePlan {
|
||||
/// All of the aggregations we were asked to compute, stored
|
||||
/// in order.
|
||||
@@ -57,7 +57,7 @@ pub struct AccumulablePlan {
|
||||
|
||||
/// Invariant: the output index is the index of the aggregation in `full_aggrs`
|
||||
/// which means output index is always smaller than the length of `full_aggrs`
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)]
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
|
||||
pub struct AggrWithIndex {
|
||||
/// aggregation expression
|
||||
pub expr: AggregateExpr,
|
||||
|
||||
@@ -140,7 +140,7 @@ pub async fn sql_to_flow_plan(
|
||||
.map_err(BoxedError::new)
|
||||
.context(ExternalSnafu)?;
|
||||
|
||||
let flow_plan = TypedPlan::from_substrait_plan(ctx, &sub_plan)?;
|
||||
let flow_plan = TypedPlan::from_substrait_plan(ctx, &sub_plan).await?;
|
||||
|
||||
Ok(flow_plan)
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ use crate::repr::{self, ColumnType, RelationDesc, RelationType};
|
||||
use crate::transform::{substrait_proto, FlownodeContext, FunctionExtensions};
|
||||
|
||||
impl TypedExpr {
|
||||
fn from_substrait_agg_grouping(
|
||||
async fn from_substrait_agg_grouping(
|
||||
ctx: &mut FlownodeContext,
|
||||
groupings: &[Grouping],
|
||||
typ: &RelationDesc,
|
||||
@@ -69,7 +69,7 @@ impl TypedExpr {
|
||||
match groupings.len() {
|
||||
1 => {
|
||||
for e in &groupings[0].grouping_expressions {
|
||||
let x = TypedExpr::from_substrait_rex(e, typ, extensions)?;
|
||||
let x = TypedExpr::from_substrait_rex(e, typ, extensions).await?;
|
||||
group_expr.push(x);
|
||||
}
|
||||
}
|
||||
@@ -87,7 +87,7 @@ impl AggregateExpr {
|
||||
/// Convert list of `Measure` into Flow's AggregateExpr
|
||||
///
|
||||
/// Return both the AggregateExpr and a MapFilterProject that is the final output of the aggregate function
|
||||
fn from_substrait_agg_measures(
|
||||
async fn from_substrait_agg_measures(
|
||||
ctx: &mut FlownodeContext,
|
||||
measures: &[Measure],
|
||||
typ: &RelationDesc,
|
||||
@@ -98,11 +98,15 @@ impl AggregateExpr {
|
||||
let mut post_maps = vec![];
|
||||
|
||||
for m in measures {
|
||||
let filter = &m
|
||||
let filter = match m
|
||||
.filter
|
||||
.as_ref()
|
||||
.map(|fil| TypedExpr::from_substrait_rex(fil, typ, extensions))
|
||||
.transpose()?;
|
||||
{
|
||||
Some(fut) => Some(fut.await),
|
||||
None => None,
|
||||
}
|
||||
.transpose()?;
|
||||
|
||||
let (aggr_expr, post_mfp) = match &m.measure {
|
||||
Some(f) => {
|
||||
@@ -112,9 +116,10 @@ impl AggregateExpr {
|
||||
_ => false,
|
||||
};
|
||||
AggregateExpr::from_substrait_agg_func(
|
||||
f, typ, extensions, filter, // TODO(discord9): impl order_by
|
||||
f, typ, extensions, &filter, // TODO(discord9): impl order_by
|
||||
&None, distinct,
|
||||
)
|
||||
.await
|
||||
}
|
||||
None => not_impl_err!("Aggregate without aggregate function is not supported"),
|
||||
}?;
|
||||
@@ -142,7 +147,7 @@ impl AggregateExpr {
|
||||
///
|
||||
/// the returned value is a tuple of AggregateExpr and a optional ScalarExpr that if exist is the final output of the aggregate function
|
||||
/// since aggr functions like `avg` need to be transform to `sum(x)/cast(count(x) as x_type)`
|
||||
pub fn from_substrait_agg_func(
|
||||
pub async fn from_substrait_agg_func(
|
||||
f: &proto::AggregateFunction,
|
||||
input_schema: &RelationDesc,
|
||||
extensions: &FunctionExtensions,
|
||||
@@ -157,7 +162,7 @@ impl AggregateExpr {
|
||||
for arg in &f.arguments {
|
||||
let arg_expr = match &arg.arg_type {
|
||||
Some(ArgType::Value(e)) => {
|
||||
TypedExpr::from_substrait_rex(e, input_schema, extensions)
|
||||
TypedExpr::from_substrait_rex(e, input_schema, extensions).await
|
||||
}
|
||||
_ => not_impl_err!("Aggregated function argument non-Value type not supported"),
|
||||
}?;
|
||||
@@ -306,13 +311,14 @@ impl TypedPlan {
|
||||
/// The output of aggr plan is:
|
||||
///
|
||||
/// <group_exprs>..<aggr_exprs>
|
||||
pub fn from_substrait_agg_rel(
|
||||
#[async_recursion::async_recursion]
|
||||
pub async fn from_substrait_agg_rel(
|
||||
ctx: &mut FlownodeContext,
|
||||
agg: &proto::AggregateRel,
|
||||
extensions: &FunctionExtensions,
|
||||
) -> Result<TypedPlan, Error> {
|
||||
let input = if let Some(input) = agg.input.as_ref() {
|
||||
TypedPlan::from_substrait_rel(ctx, input, extensions)?
|
||||
TypedPlan::from_substrait_rel(ctx, input, extensions).await?
|
||||
} else {
|
||||
return not_impl_err!("Aggregate without an input is not supported");
|
||||
};
|
||||
@@ -323,7 +329,8 @@ impl TypedPlan {
|
||||
&agg.groupings,
|
||||
&input.schema,
|
||||
extensions,
|
||||
)?;
|
||||
)
|
||||
.await?;
|
||||
|
||||
TypedExpr::expand_multi_value(&input.schema.typ, &group_exprs)?
|
||||
};
|
||||
@@ -335,7 +342,8 @@ impl TypedPlan {
|
||||
&agg.measures,
|
||||
&input.schema,
|
||||
extensions,
|
||||
)?;
|
||||
)
|
||||
.await?;
|
||||
|
||||
let key_val_plan = KeyValPlan::from_substrait_gen_key_val_plan(
|
||||
&mut aggr_exprs,
|
||||
@@ -479,7 +487,9 @@ mod test {
|
||||
let plan = sql_to_substrait(engine.clone(), sql).await;
|
||||
|
||||
let mut ctx = create_test_ctx();
|
||||
assert!(TypedPlan::from_substrait_plan(&mut ctx, &plan).is_err());
|
||||
assert!(TypedPlan::from_substrait_plan(&mut ctx, &plan)
|
||||
.await
|
||||
.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -489,7 +499,9 @@ mod test {
|
||||
let plan = sql_to_substrait(engine.clone(), sql).await;
|
||||
|
||||
let mut ctx = create_test_ctx();
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap();
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let aggr_expr = AggregateExpr {
|
||||
func: AggregateFunc::SumUInt32,
|
||||
@@ -578,6 +590,7 @@ mod test {
|
||||
},
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap(),
|
||||
exprs: vec![ScalarExpr::Column(0)],
|
||||
}])
|
||||
@@ -630,7 +643,9 @@ mod test {
|
||||
let plan = sql_to_substrait(engine.clone(), sql).await;
|
||||
|
||||
let mut ctx = create_test_ctx();
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap();
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let aggr_expr = AggregateExpr {
|
||||
func: AggregateFunc::SumUInt32,
|
||||
@@ -743,6 +758,7 @@ mod test {
|
||||
]),
|
||||
},
|
||||
})
|
||||
.await
|
||||
.unwrap(),
|
||||
exprs: vec![ScalarExpr::Column(3)],
|
||||
},
|
||||
@@ -766,7 +782,9 @@ mod test {
|
||||
let plan = sql_to_substrait(engine.clone(), sql).await;
|
||||
|
||||
let mut ctx = create_test_ctx();
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap();
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let aggr_exprs = vec![
|
||||
AggregateExpr {
|
||||
@@ -913,7 +931,9 @@ mod test {
|
||||
let plan = sql_to_substrait(engine.clone(), sql).await;
|
||||
|
||||
let mut ctx = create_test_ctx();
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap();
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let aggr_expr = AggregateExpr {
|
||||
func: AggregateFunc::SumUInt32,
|
||||
@@ -1029,7 +1049,9 @@ mod test {
|
||||
let plan = sql_to_substrait(engine.clone(), sql).await;
|
||||
|
||||
let mut ctx = create_test_ctx();
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap();
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let aggr_expr = AggregateExpr {
|
||||
func: AggregateFunc::SumUInt32,
|
||||
@@ -1145,7 +1167,7 @@ mod test {
|
||||
let plan = sql_to_substrait(engine.clone(), sql).await;
|
||||
|
||||
let mut ctx = create_test_ctx();
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
|
||||
|
||||
let aggr_exprs = vec![
|
||||
AggregateExpr {
|
||||
@@ -1250,7 +1272,9 @@ mod test {
|
||||
|
||||
let mut ctx = create_test_ctx();
|
||||
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap();
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let aggr_exprs = vec![
|
||||
AggregateExpr {
|
||||
@@ -1341,7 +1365,7 @@ mod test {
|
||||
let plan = sql_to_substrait(engine.clone(), sql).await;
|
||||
|
||||
let mut ctx = create_test_ctx();
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
|
||||
let typ = RelationType::new(vec![ColumnType::new(
|
||||
ConcreteDataType::uint64_datatype(),
|
||||
true,
|
||||
@@ -1404,7 +1428,9 @@ mod test {
|
||||
let plan = sql_to_substrait(engine.clone(), sql).await;
|
||||
|
||||
let mut ctx = create_test_ctx();
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap();
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let aggr_expr = AggregateExpr {
|
||||
func: AggregateFunc::SumUInt32,
|
||||
@@ -1482,7 +1508,7 @@ mod test {
|
||||
let plan = sql_to_substrait(engine.clone(), sql).await;
|
||||
|
||||
let mut ctx = create_test_ctx();
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
|
||||
|
||||
let aggr_expr = AggregateExpr {
|
||||
func: AggregateFunc::SumUInt32,
|
||||
|
||||
@@ -18,7 +18,6 @@ use std::sync::Arc;
|
||||
|
||||
use datafusion_physical_expr::PhysicalExpr;
|
||||
use datatypes::data_type::ConcreteDataType as CDT;
|
||||
use itertools::Itertools;
|
||||
use snafu::{OptionExt, ResultExt};
|
||||
use substrait_proto::proto::expression::field_reference::ReferenceType::DirectReference;
|
||||
use substrait_proto::proto::expression::reference_segment::ReferenceType::StructField;
|
||||
@@ -60,7 +59,7 @@ fn typename_to_cdt(name: &str) -> CDT {
|
||||
}
|
||||
|
||||
/// Convert [`ScalarFunction`] to corresponding Datafusion's [`PhysicalExpr`]
|
||||
pub(crate) fn from_scalar_fn_to_df_fn_impl(
|
||||
pub(crate) async fn from_scalar_fn_to_df_fn_impl(
|
||||
f: &ScalarFunction,
|
||||
input_schema: &RelationDesc,
|
||||
extensions: &FunctionExtensions,
|
||||
@@ -70,7 +69,7 @@ pub(crate) fn from_scalar_fn_to_df_fn_impl(
|
||||
};
|
||||
let schema = input_schema.to_df_schema()?;
|
||||
|
||||
let df_expr = futures::executor::block_on(async {
|
||||
let df_expr =
|
||||
// TODO(discord9): consider coloring everything async....
|
||||
substrait::df_logical_plan::consumer::from_substrait_rex(
|
||||
&datafusion::prelude::SessionContext::new(),
|
||||
@@ -79,7 +78,7 @@ pub(crate) fn from_scalar_fn_to_df_fn_impl(
|
||||
&extensions.inner_ref(),
|
||||
)
|
||||
.await
|
||||
});
|
||||
;
|
||||
let expr = df_expr.map_err(|err| {
|
||||
DatafusionSnafu {
|
||||
raw: err,
|
||||
@@ -138,7 +137,7 @@ fn rewrite_scalar_function(f: &ScalarFunction) -> ScalarFunction {
|
||||
}
|
||||
|
||||
impl TypedExpr {
|
||||
pub fn from_substrait_to_datafusion_scalar_func(
|
||||
pub async fn from_substrait_to_datafusion_scalar_func(
|
||||
f: &ScalarFunction,
|
||||
arg_exprs_typed: Vec<TypedExpr>,
|
||||
extensions: &FunctionExtensions,
|
||||
@@ -152,7 +151,7 @@ impl TypedExpr {
|
||||
let raw_fn =
|
||||
RawDfScalarFn::from_proto(&f_rewrite, input_schema.clone(), extensions.clone())?;
|
||||
|
||||
let df_func = DfScalarFunction::try_from_raw_fn(raw_fn)?;
|
||||
let df_func = DfScalarFunction::try_from_raw_fn(raw_fn).await?;
|
||||
let expr = ScalarExpr::CallDf {
|
||||
df_scalar_fn: df_func,
|
||||
exprs: arg_exprs,
|
||||
@@ -163,7 +162,7 @@ impl TypedExpr {
|
||||
}
|
||||
|
||||
/// Convert ScalarFunction into Flow's ScalarExpr
|
||||
pub fn from_substrait_scalar_func(
|
||||
pub async fn from_substrait_scalar_func(
|
||||
f: &ScalarFunction,
|
||||
input_schema: &RelationDesc,
|
||||
extensions: &FunctionExtensions,
|
||||
@@ -178,16 +177,19 @@ impl TypedExpr {
|
||||
),
|
||||
})?;
|
||||
let arg_len = f.arguments.len();
|
||||
let arg_typed_exprs: Vec<TypedExpr> = f
|
||||
.arguments
|
||||
.iter()
|
||||
.map(|arg| match &arg.arg_type {
|
||||
Some(ArgType::Value(e)) => {
|
||||
TypedExpr::from_substrait_rex(e, input_schema, extensions)
|
||||
}
|
||||
_ => not_impl_err!("Aggregated function argument non-Value type not supported"),
|
||||
})
|
||||
.try_collect()?;
|
||||
let arg_typed_exprs: Vec<TypedExpr> = {
|
||||
let mut rets = Vec::new();
|
||||
for arg in f.arguments.iter() {
|
||||
let ret = match &arg.arg_type {
|
||||
Some(ArgType::Value(e)) => {
|
||||
TypedExpr::from_substrait_rex(e, input_schema, extensions).await
|
||||
}
|
||||
_ => not_impl_err!("Aggregated function argument non-Value type not supported"),
|
||||
}?;
|
||||
rets.push(ret);
|
||||
}
|
||||
rets
|
||||
};
|
||||
|
||||
// literal's type is determined by the function and type of other args
|
||||
let (arg_exprs, arg_types): (Vec<_>, Vec<_>) = arg_typed_exprs
|
||||
@@ -293,7 +295,8 @@ impl TypedExpr {
|
||||
f,
|
||||
arg_typed_exprs,
|
||||
extensions,
|
||||
)?;
|
||||
)
|
||||
.await?;
|
||||
Ok(try_as_df)
|
||||
}
|
||||
}
|
||||
@@ -301,38 +304,44 @@ impl TypedExpr {
|
||||
}
|
||||
|
||||
/// Convert IfThen into Flow's ScalarExpr
|
||||
pub fn from_substrait_ifthen_rex(
|
||||
pub async fn from_substrait_ifthen_rex(
|
||||
if_then: &IfThen,
|
||||
input_schema: &RelationDesc,
|
||||
extensions: &FunctionExtensions,
|
||||
) -> Result<TypedExpr, Error> {
|
||||
let ifs: Vec<_> = if_then
|
||||
.ifs
|
||||
.iter()
|
||||
.map(|if_clause| {
|
||||
let ifs: Vec<_> = {
|
||||
let mut ifs = Vec::new();
|
||||
for if_clause in if_then.ifs.iter() {
|
||||
let proto_if = if_clause.r#if.as_ref().with_context(|| InvalidQuerySnafu {
|
||||
reason: "IfThen clause without if",
|
||||
})?;
|
||||
let proto_then = if_clause.then.as_ref().with_context(|| InvalidQuerySnafu {
|
||||
reason: "IfThen clause without then",
|
||||
})?;
|
||||
let cond = TypedExpr::from_substrait_rex(proto_if, input_schema, extensions)?;
|
||||
let then = TypedExpr::from_substrait_rex(proto_then, input_schema, extensions)?;
|
||||
Ok((cond, then))
|
||||
})
|
||||
.try_collect()?;
|
||||
let cond =
|
||||
TypedExpr::from_substrait_rex(proto_if, input_schema, extensions).await?;
|
||||
let then =
|
||||
TypedExpr::from_substrait_rex(proto_then, input_schema, extensions).await?;
|
||||
ifs.push((cond, then));
|
||||
}
|
||||
ifs
|
||||
};
|
||||
// if no else is presented
|
||||
let els = if_then
|
||||
let els = match if_then
|
||||
.r#else
|
||||
.as_ref()
|
||||
.map(|e| TypedExpr::from_substrait_rex(e, input_schema, extensions))
|
||||
.transpose()?
|
||||
.unwrap_or_else(|| {
|
||||
TypedExpr::new(
|
||||
ScalarExpr::literal_null(),
|
||||
ColumnType::new_nullable(CDT::null_datatype()),
|
||||
)
|
||||
});
|
||||
{
|
||||
Some(fut) => Some(fut.await),
|
||||
None => None,
|
||||
}
|
||||
.transpose()?
|
||||
.unwrap_or_else(|| {
|
||||
TypedExpr::new(
|
||||
ScalarExpr::literal_null(),
|
||||
ColumnType::new_nullable(CDT::null_datatype()),
|
||||
)
|
||||
});
|
||||
|
||||
fn build_if_then_recur(
|
||||
mut next_if_then: impl Iterator<Item = (TypedExpr, TypedExpr)>,
|
||||
@@ -356,7 +365,8 @@ impl TypedExpr {
|
||||
Ok(expr_if)
|
||||
}
|
||||
/// Convert Substrait Rex into Flow's ScalarExpr
|
||||
pub fn from_substrait_rex(
|
||||
#[async_recursion::async_recursion]
|
||||
pub async fn from_substrait_rex(
|
||||
e: &Expression,
|
||||
input_schema: &RelationDesc,
|
||||
extensions: &FunctionExtensions,
|
||||
@@ -377,7 +387,7 @@ impl TypedExpr {
|
||||
if !s.options.is_empty() {
|
||||
return not_impl_err!("In list expression is not supported");
|
||||
}
|
||||
TypedExpr::from_substrait_rex(substrait_expr, input_schema, extensions)
|
||||
TypedExpr::from_substrait_rex(substrait_expr, input_schema, extensions).await
|
||||
}
|
||||
Some(RexType::Selection(field_ref)) => match &field_ref.reference_type {
|
||||
Some(DirectReference(direct)) => match &direct.reference_type.as_ref() {
|
||||
@@ -400,16 +410,16 @@ impl TypedExpr {
|
||||
_ => not_impl_err!("unsupported field ref type"),
|
||||
},
|
||||
Some(RexType::ScalarFunction(f)) => {
|
||||
TypedExpr::from_substrait_scalar_func(f, input_schema, extensions)
|
||||
TypedExpr::from_substrait_scalar_func(f, input_schema, extensions).await
|
||||
}
|
||||
Some(RexType::IfThen(if_then)) => {
|
||||
TypedExpr::from_substrait_ifthen_rex(if_then, input_schema, extensions)
|
||||
TypedExpr::from_substrait_ifthen_rex(if_then, input_schema, extensions).await
|
||||
}
|
||||
Some(RexType::Cast(cast)) => {
|
||||
let input = cast.input.as_ref().with_context(|| InvalidQuerySnafu {
|
||||
reason: "Cast expression without input",
|
||||
})?;
|
||||
let input = TypedExpr::from_substrait_rex(input, input_schema, extensions)?;
|
||||
let input = TypedExpr::from_substrait_rex(input, input_schema, extensions).await?;
|
||||
let cast_type = from_substrait_type(cast.r#type.as_ref().with_context(|| {
|
||||
InvalidQuerySnafu {
|
||||
reason: "Cast expression without type",
|
||||
@@ -453,7 +463,7 @@ mod test {
|
||||
let plan = sql_to_substrait(engine.clone(), sql).await;
|
||||
|
||||
let mut ctx = create_test_ctx();
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
|
||||
|
||||
// optimize binary and to variadic and
|
||||
let filter = ScalarExpr::CallVariadic {
|
||||
@@ -509,7 +519,7 @@ mod test {
|
||||
let plan = sql_to_substrait(engine.clone(), sql).await;
|
||||
|
||||
let mut ctx = create_test_ctx();
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
|
||||
|
||||
let expected = TypedPlan {
|
||||
schema: RelationType::new(vec![ColumnType::new(CDT::boolean_datatype(), true)])
|
||||
@@ -534,7 +544,7 @@ mod test {
|
||||
let plan = sql_to_substrait(engine.clone(), sql).await;
|
||||
|
||||
let mut ctx = create_test_ctx();
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
|
||||
|
||||
let expected = TypedPlan {
|
||||
schema: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)])
|
||||
@@ -572,7 +582,7 @@ mod test {
|
||||
let plan = sql_to_substrait(engine.clone(), sql).await;
|
||||
|
||||
let mut ctx = create_test_ctx();
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
|
||||
|
||||
let expected = TypedPlan {
|
||||
schema: RelationType::new(vec![ColumnType::new(CDT::int16_datatype(), true)])
|
||||
@@ -611,7 +621,7 @@ mod test {
|
||||
let plan = sql_to_substrait(engine.clone(), sql).await;
|
||||
|
||||
let mut ctx = create_test_ctx();
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
|
||||
|
||||
let expected = TypedPlan {
|
||||
schema: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)])
|
||||
@@ -641,8 +651,8 @@ mod test {
|
||||
assert_eq!(flow_plan.unwrap(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_func_sig() {
|
||||
#[tokio::test]
|
||||
async fn test_func_sig() {
|
||||
fn lit(v: impl ToString) -> substrait_proto::proto::FunctionArgument {
|
||||
use substrait_proto::proto::expression;
|
||||
let expr = Expression {
|
||||
@@ -669,7 +679,9 @@ mod test {
|
||||
let input_schema =
|
||||
RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)]).into_unnamed();
|
||||
let extensions = FunctionExtensions::from_iter([(0, "is_null".to_string())]);
|
||||
let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions).unwrap();
|
||||
let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
res,
|
||||
@@ -695,7 +707,9 @@ mod test {
|
||||
])
|
||||
.into_unnamed();
|
||||
let extensions = FunctionExtensions::from_iter([(0, "add".to_string())]);
|
||||
let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions).unwrap();
|
||||
let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
res,
|
||||
@@ -722,7 +736,9 @@ mod test {
|
||||
])
|
||||
.into_unnamed();
|
||||
let extensions = FunctionExtensions::from_iter(vec![(0, "tumble".to_string())]);
|
||||
let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions).unwrap();
|
||||
let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
res,
|
||||
@@ -750,7 +766,9 @@ mod test {
|
||||
])
|
||||
.into_unnamed();
|
||||
let extensions = FunctionExtensions::from_iter(vec![(0, "tumble".to_string())]);
|
||||
let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions).unwrap();
|
||||
let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
res,
|
||||
|
||||
@@ -172,7 +172,7 @@ mod test {
|
||||
let plan = sql_to_substrait(engine.clone(), sql).await;
|
||||
|
||||
let mut ctx = create_test_ctx();
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
|
||||
|
||||
let expected = TypedPlan {
|
||||
schema: RelationType::new(vec![ColumnType::new(CDT::int64_datatype(), true)])
|
||||
|
||||
@@ -32,7 +32,7 @@ use crate::transform::{substrait_proto, FlownodeContext, FunctionExtensions};
|
||||
|
||||
impl TypedPlan {
|
||||
/// Convert Substrait Plan into Flow's TypedPlan
|
||||
pub fn from_substrait_plan(
|
||||
pub async fn from_substrait_plan(
|
||||
ctx: &mut FlownodeContext,
|
||||
plan: &SubPlan,
|
||||
) -> Result<TypedPlan, Error> {
|
||||
@@ -45,13 +45,13 @@ impl TypedPlan {
|
||||
match plan.relations[0].rel_type.as_ref() {
|
||||
Some(rt) => match rt {
|
||||
plan_rel::RelType::Rel(rel) => {
|
||||
Ok(TypedPlan::from_substrait_rel(ctx, rel, &function_extension)?)
|
||||
Ok(TypedPlan::from_substrait_rel(ctx, rel, &function_extension).await?)
|
||||
},
|
||||
plan_rel::RelType::Root(root) => {
|
||||
let input = root.input.as_ref().with_context(|| InvalidQuerySnafu {
|
||||
reason: "Root relation without input",
|
||||
})?;
|
||||
Ok(TypedPlan::from_substrait_rel(ctx, input, &function_extension)?)
|
||||
Ok(TypedPlan::from_substrait_rel(ctx, input, &function_extension).await?)
|
||||
}
|
||||
},
|
||||
None => plan_err!("Cannot parse plan relation: None")
|
||||
@@ -64,13 +64,14 @@ impl TypedPlan {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_substrait_project(
|
||||
#[async_recursion::async_recursion]
|
||||
pub async fn from_substrait_project(
|
||||
ctx: &mut FlownodeContext,
|
||||
p: &ProjectRel,
|
||||
extensions: &FunctionExtensions,
|
||||
) -> Result<TypedPlan, Error> {
|
||||
let input = if let Some(input) = p.input.as_ref() {
|
||||
TypedPlan::from_substrait_rel(ctx, input, extensions)?
|
||||
TypedPlan::from_substrait_rel(ctx, input, extensions).await?
|
||||
} else {
|
||||
return not_impl_err!("Projection without an input is not supported");
|
||||
};
|
||||
@@ -93,7 +94,7 @@ impl TypedPlan {
|
||||
|
||||
let mut exprs: Vec<TypedExpr> = Vec::with_capacity(p.expressions.len());
|
||||
for e in &p.expressions {
|
||||
let expr = TypedExpr::from_substrait_rex(e, &schema_before_expand, extensions)?;
|
||||
let expr = TypedExpr::from_substrait_rex(e, &schema_before_expand, extensions).await?;
|
||||
exprs.push(expr);
|
||||
}
|
||||
let is_literal = exprs.iter().all(|expr| expr.expr.is_literal());
|
||||
@@ -131,26 +132,27 @@ impl TypedPlan {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_substrait_filter(
|
||||
#[async_recursion::async_recursion]
|
||||
pub async fn from_substrait_filter(
|
||||
ctx: &mut FlownodeContext,
|
||||
filter: &FilterRel,
|
||||
extensions: &FunctionExtensions,
|
||||
) -> Result<TypedPlan, Error> {
|
||||
let input = if let Some(input) = filter.input.as_ref() {
|
||||
TypedPlan::from_substrait_rel(ctx, input, extensions)?
|
||||
TypedPlan::from_substrait_rel(ctx, input, extensions).await?
|
||||
} else {
|
||||
return not_impl_err!("Filter without an input is not supported");
|
||||
};
|
||||
|
||||
let expr = if let Some(condition) = filter.condition.as_ref() {
|
||||
TypedExpr::from_substrait_rex(condition, &input.schema, extensions)?
|
||||
TypedExpr::from_substrait_rex(condition, &input.schema, extensions).await?
|
||||
} else {
|
||||
return not_impl_err!("Filter without an condition is not valid");
|
||||
};
|
||||
input.filter(expr)
|
||||
}
|
||||
|
||||
pub fn from_substrait_read(
|
||||
pub async fn from_substrait_read(
|
||||
ctx: &mut FlownodeContext,
|
||||
read: &ReadRel,
|
||||
_extensions: &FunctionExtensions,
|
||||
@@ -212,16 +214,22 @@ impl TypedPlan {
|
||||
|
||||
/// Convert Substrait Rel into Flow's TypedPlan
|
||||
/// TODO(discord9): SELECT DISTINCT(does it get compile with something else?)
|
||||
pub fn from_substrait_rel(
|
||||
pub async fn from_substrait_rel(
|
||||
ctx: &mut FlownodeContext,
|
||||
rel: &Rel,
|
||||
extensions: &FunctionExtensions,
|
||||
) -> Result<TypedPlan, Error> {
|
||||
match &rel.rel_type {
|
||||
Some(RelType::Project(p)) => Self::from_substrait_project(ctx, p.as_ref(), extensions),
|
||||
Some(RelType::Filter(filter)) => Self::from_substrait_filter(ctx, filter, extensions),
|
||||
Some(RelType::Read(read)) => Self::from_substrait_read(ctx, read, extensions),
|
||||
Some(RelType::Aggregate(agg)) => Self::from_substrait_agg_rel(ctx, agg, extensions),
|
||||
Some(RelType::Project(p)) => {
|
||||
Self::from_substrait_project(ctx, p.as_ref(), extensions).await
|
||||
}
|
||||
Some(RelType::Filter(filter)) => {
|
||||
Self::from_substrait_filter(ctx, filter, extensions).await
|
||||
}
|
||||
Some(RelType::Read(read)) => Self::from_substrait_read(ctx, read, extensions).await,
|
||||
Some(RelType::Aggregate(agg)) => {
|
||||
Self::from_substrait_agg_rel(ctx, agg, extensions).await
|
||||
}
|
||||
_ => not_impl_err!("Unsupported relation type: {:?}", rel.rel_type),
|
||||
}
|
||||
}
|
||||
@@ -353,7 +361,7 @@ mod test {
|
||||
let plan = sql_to_substrait(engine.clone(), sql).await;
|
||||
|
||||
let mut ctx = create_test_ctx();
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
|
||||
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
|
||||
|
||||
let expected = TypedPlan {
|
||||
schema: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)])
|
||||
|
||||
@@ -40,7 +40,7 @@ pub type Spine = BTreeMap<Timestamp, Batch>;
|
||||
/// If a key is expired, any future updates to it should be ignored.
|
||||
///
|
||||
/// Note that key is expired by it's event timestamp (contained in the key), not by the time it's inserted (system timestamp).
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)]
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
|
||||
pub struct KeyExpiryManager {
|
||||
/// A map from event timestamp to key, used for expire keys.
|
||||
event_ts_to_key: BTreeMap<Timestamp, BTreeSet<Row>>,
|
||||
@@ -157,7 +157,7 @@ impl KeyExpiryManager {
|
||||
///
|
||||
/// Note the two way arrow between reduce operator and arrange, it's because reduce operator need to query existing state
|
||||
/// and also need to update existing state.
|
||||
#[derive(Debug, Clone, Default, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)]
|
||||
#[derive(Debug, Clone, Default, Eq, PartialEq, Ord, PartialOrd)]
|
||||
pub struct Arrangement {
|
||||
/// A name or identifier for the arrangement which can be used for debugging or logging purposes.
|
||||
/// This field is not critical to the functionality but aids in monitoring and management of arrangements.
|
||||
|
||||
Reference in New Issue
Block a user