feat(flow): impl ScalarExpr&Scalar Function (#3283)

* feat: impl for ScalarExpr

* feat: plain functions

* refactor: simpler trait bound&tests

* chore: remove unused imports

* chore: fmt

* refactor: early ret on first error

* refactor: remove abunant match arm

* chore: per review

* doc: `support` fn

* chore: per review more

* chore: more per review

* fix: extract_bound

* chore: per review

* refactor: reduce nest
This commit is contained in:
discord9
2024-02-21 20:53:16 +08:00
committed by GitHub
parent 7c88d721c2
commit 860b1e9d9e
5 changed files with 632 additions and 8 deletions

1
Cargo.lock generated
View File

@@ -3406,6 +3406,7 @@ dependencies = [
"datatypes",
"hydroflow",
"itertools 0.10.5",
"num-traits",
"serde",
"servers",
"session",

View File

@@ -17,6 +17,7 @@ common-time.workspace = true
datatypes.workspace = true
hydroflow = "0.5.0"
itertools.workspace = true
num-traits = "0.2"
serde.workspace = true
servers.workspace = true
session.workspace = true

View File

@@ -58,4 +58,7 @@ pub enum EvalError {
#[snafu(display("Optimize error: {reason}"))]
Optimize { reason: String, location: Location },
#[snafu(display("Unsupported temporal filter: {reason}"))]
UnsupportedTemporalFilter { reason: String, location: Location },
}

View File

@@ -21,14 +21,12 @@ use hydroflow::bincode::Error;
use serde::{Deserialize, Serialize};
use snafu::ResultExt;
use super::ScalarExpr;
use crate::expr::error::CastValueSnafu;
use crate::expr::InvalidArgumentSnafu;
// TODO(discord9): more function & eval
use crate::{
expr::error::{EvalError, TryFromValueSnafu, TypeMismatchSnafu},
repr::Row,
use crate::expr::error::{
CastValueSnafu, DivisionByZeroSnafu, EvalError, InternalSnafu, TryFromValueSnafu,
TypeMismatchSnafu,
};
use crate::expr::{InvalidArgumentSnafu, ScalarExpr};
use crate::repr::Row;
/// UnmaterializableFunc is a function that can't be eval independently,
/// and require special handling
@@ -47,6 +45,66 @@ pub enum UnaryFunc {
StepTimestamp,
Cast(ConcreteDataType),
}
impl UnaryFunc {
pub fn eval(&self, values: &[Value], expr: &ScalarExpr) -> Result<Value, EvalError> {
let arg = expr.eval(values)?;
match self {
Self::Not => {
let bool = if let Value::Boolean(bool) = arg {
Ok(bool)
} else {
TypeMismatchSnafu {
expected: ConcreteDataType::boolean_datatype(),
actual: arg.data_type(),
}
.fail()?
}?;
Ok(Value::from(!bool))
}
Self::IsNull => Ok(Value::from(arg.is_null())),
Self::IsTrue | Self::IsFalse => {
let bool = if let Value::Boolean(bool) = arg {
Ok(bool)
} else {
TypeMismatchSnafu {
expected: ConcreteDataType::boolean_datatype(),
actual: arg.data_type(),
}
.fail()?
}?;
if matches!(self, Self::IsTrue) {
Ok(Value::from(bool))
} else {
Ok(Value::from(!bool))
}
}
Self::StepTimestamp => {
if let Value::DateTime(datetime) = arg {
let datetime = DateTime::from(datetime.val() + 1);
Ok(Value::from(datetime))
} else {
TypeMismatchSnafu {
expected: ConcreteDataType::datetime_datatype(),
actual: arg.data_type(),
}
.fail()?
}
}
Self::Cast(to) => {
let arg_ty = arg.data_type();
let res = cast(arg, to).context({
CastValueSnafu {
from: arg_ty,
to: to.clone(),
}
})?;
Ok(res)
}
}
}
}
/// TODO(discord9): support more binary functions for more types
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, Hash)]
pub enum BinaryFunc {
@@ -96,8 +154,232 @@ pub enum BinaryFunc {
ModUInt64,
}
impl BinaryFunc {
pub fn eval(
&self,
values: &[Value],
expr1: &ScalarExpr,
expr2: &ScalarExpr,
) -> Result<Value, EvalError> {
let left = expr1.eval(values)?;
let right = expr2.eval(values)?;
match self {
Self::Eq => Ok(Value::from(left == right)),
Self::NotEq => Ok(Value::from(left != right)),
Self::Lt => Ok(Value::from(left < right)),
Self::Lte => Ok(Value::from(left <= right)),
Self::Gt => Ok(Value::from(left > right)),
Self::Gte => Ok(Value::from(left >= right)),
Self::AddInt16 => Ok(add::<i16>(left, right)?),
Self::AddInt32 => Ok(add::<i32>(left, right)?),
Self::AddInt64 => Ok(add::<i64>(left, right)?),
Self::AddUInt16 => Ok(add::<u16>(left, right)?),
Self::AddUInt32 => Ok(add::<u32>(left, right)?),
Self::AddUInt64 => Ok(add::<u64>(left, right)?),
Self::AddFloat32 => Ok(add::<f32>(left, right)?),
Self::AddFloat64 => Ok(add::<f64>(left, right)?),
Self::SubInt16 => Ok(sub::<i16>(left, right)?),
Self::SubInt32 => Ok(sub::<i32>(left, right)?),
Self::SubInt64 => Ok(sub::<i64>(left, right)?),
Self::SubUInt16 => Ok(sub::<u16>(left, right)?),
Self::SubUInt32 => Ok(sub::<u32>(left, right)?),
Self::SubUInt64 => Ok(sub::<u64>(left, right)?),
Self::SubFloat32 => Ok(sub::<f32>(left, right)?),
Self::SubFloat64 => Ok(sub::<f64>(left, right)?),
Self::MulInt16 => Ok(mul::<i16>(left, right)?),
Self::MulInt32 => Ok(mul::<i32>(left, right)?),
Self::MulInt64 => Ok(mul::<i64>(left, right)?),
Self::MulUInt16 => Ok(mul::<u16>(left, right)?),
Self::MulUInt32 => Ok(mul::<u32>(left, right)?),
Self::MulUInt64 => Ok(mul::<u64>(left, right)?),
Self::MulFloat32 => Ok(mul::<f32>(left, right)?),
Self::MulFloat64 => Ok(mul::<f64>(left, right)?),
Self::DivInt16 => Ok(div::<i16>(left, right)?),
Self::DivInt32 => Ok(div::<i32>(left, right)?),
Self::DivInt64 => Ok(div::<i64>(left, right)?),
Self::DivUInt16 => Ok(div::<u16>(left, right)?),
Self::DivUInt32 => Ok(div::<u32>(left, right)?),
Self::DivUInt64 => Ok(div::<u64>(left, right)?),
Self::DivFloat32 => Ok(div::<f32>(left, right)?),
Self::DivFloat64 => Ok(div::<f64>(left, right)?),
Self::ModInt16 => Ok(rem::<i16>(left, right)?),
Self::ModInt32 => Ok(rem::<i32>(left, right)?),
Self::ModInt64 => Ok(rem::<i64>(left, right)?),
Self::ModUInt16 => Ok(rem::<u16>(left, right)?),
Self::ModUInt32 => Ok(rem::<u32>(left, right)?),
Self::ModUInt64 => Ok(rem::<u64>(left, right)?),
}
}
/// Reverse the comparison operator, i.e. `a < b` becomes `b > a`,
/// equal and not equal are unchanged.
pub fn reverse_compare(&self) -> Result<Self, EvalError> {
let ret = match &self {
BinaryFunc::Eq => BinaryFunc::Eq,
BinaryFunc::NotEq => BinaryFunc::NotEq,
BinaryFunc::Lt => BinaryFunc::Gt,
BinaryFunc::Lte => BinaryFunc::Gte,
BinaryFunc::Gt => BinaryFunc::Lt,
BinaryFunc::Gte => BinaryFunc::Lte,
_ => {
return InternalSnafu {
reason: format!("Expect a comparison operator, found {:?}", self),
}
.fail();
}
};
Ok(ret)
}
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, Hash)]
pub enum VariadicFunc {
And,
Or,
}
impl VariadicFunc {
pub fn eval(&self, values: &[Value], exprs: &[ScalarExpr]) -> Result<Value, EvalError> {
match self {
VariadicFunc::And => and(values, exprs),
VariadicFunc::Or => or(values, exprs),
}
}
}
fn and(values: &[Value], exprs: &[ScalarExpr]) -> Result<Value, EvalError> {
// If any is false, then return false. Else, if any is null, then return null. Else, return true.
let mut null = false;
for expr in exprs {
match expr.eval(values) {
Ok(Value::Boolean(true)) => {}
Ok(Value::Boolean(false)) => return Ok(Value::Boolean(false)), // short-circuit
Ok(Value::Null) => null = true,
Err(this_err) => {
return Err(this_err);
} // retain first error encountered
Ok(x) => InvalidArgumentSnafu {
reason: format!(
"`and()` only support boolean type, found value {:?} of type {:?}",
x,
x.data_type()
),
}
.fail()?,
}
}
match null {
true => Ok(Value::Null),
false => Ok(Value::Boolean(true)),
}
}
fn or(values: &[Value], exprs: &[ScalarExpr]) -> Result<Value, EvalError> {
// If any is false, then return false. Else, if any is null, then return null. Else, return true.
let mut null = false;
for expr in exprs {
match expr.eval(values) {
Ok(Value::Boolean(true)) => return Ok(Value::Boolean(true)), // short-circuit
Ok(Value::Boolean(false)) => {}
Ok(Value::Null) => null = true,
Err(this_err) => {
return Err(this_err);
} // retain first error encountered
Ok(x) => InvalidArgumentSnafu {
reason: format!(
"`or()` only support boolean type, found value {:?} of type {:?}",
x,
x.data_type()
),
}
.fail()?,
}
}
match null {
true => Ok(Value::Null),
false => Ok(Value::Boolean(false)),
}
}
fn add<T>(left: Value, right: Value) -> Result<Value, EvalError>
where
T: TryFrom<Value, Error = datatypes::Error> + num_traits::Num,
Value: From<T>,
{
let left = T::try_from(left).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
let right = T::try_from(right).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
Ok(Value::from(left + right))
}
fn sub<T>(left: Value, right: Value) -> Result<Value, EvalError>
where
T: TryFrom<Value, Error = datatypes::Error> + num_traits::Num,
Value: From<T>,
{
let left = T::try_from(left).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
let right = T::try_from(right).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
Ok(Value::from(left - right))
}
fn mul<T>(left: Value, right: Value) -> Result<Value, EvalError>
where
T: TryFrom<Value, Error = datatypes::Error> + num_traits::Num,
Value: From<T>,
{
let left = T::try_from(left).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
let right = T::try_from(right).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
Ok(Value::from(left * right))
}
fn div<T>(left: Value, right: Value) -> Result<Value, EvalError>
where
T: TryFrom<Value, Error = datatypes::Error> + num_traits::Num,
<T as TryFrom<Value>>::Error: std::fmt::Debug,
Value: From<T>,
{
let left = T::try_from(left).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
let right = T::try_from(right).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
if right.is_zero() {
return Err(DivisionByZeroSnafu {}.build());
}
Ok(Value::from(left / right))
}
fn rem<T>(left: Value, right: Value) -> Result<Value, EvalError>
where
T: TryFrom<Value, Error = datatypes::Error> + num_traits::Num,
<T as TryFrom<Value>>::Error: std::fmt::Debug,
Value: From<T>,
{
let left = T::try_from(left).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
let right = T::try_from(right).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
Ok(Value::from(left % right))
}
#[test]
fn test_num_ops() {
let left = Value::from(10);
let right = Value::from(3);
let res = add::<i32>(left.clone(), right.clone()).unwrap();
assert_eq!(res, Value::from(13));
let res = sub::<i32>(left.clone(), right.clone()).unwrap();
assert_eq!(res, Value::from(7));
let res = mul::<i32>(left.clone(), right.clone()).unwrap();
assert_eq!(res, Value::from(30));
let res = div::<i32>(left.clone(), right.clone()).unwrap();
assert_eq!(res, Value::from(3));
let res = rem::<i32>(left.clone(), right.clone()).unwrap();
assert_eq!(res, Value::from(1));
let values = vec![Value::from(true), Value::from(false)];
let exprs = vec![ScalarExpr::Column(0), ScalarExpr::Column(1)];
let res = and(&values, &exprs).unwrap();
assert_eq!(res, Value::from(false));
let res = or(&values, &exprs).unwrap();
assert_eq!(res, Value::from(true));
}

View File

@@ -18,7 +18,9 @@ use datatypes::prelude::ConcreteDataType;
use datatypes::value::Value;
use serde::{Deserialize, Serialize};
use crate::expr::error::{EvalError, InvalidArgumentSnafu, OptimizeSnafu};
use crate::expr::error::{
EvalError, InvalidArgumentSnafu, OptimizeSnafu, UnsupportedTemporalFilterSnafu,
};
use crate::expr::func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc};
/// A scalar expression, which can be evaluated to a value.
@@ -59,3 +61,338 @@ pub enum ScalarExpr {
els: Box<ScalarExpr>,
},
}
impl ScalarExpr {
pub fn call_unary(self, func: UnaryFunc) -> Self {
ScalarExpr::CallUnary {
func,
expr: Box::new(self),
}
}
pub fn call_binary(self, other: Self, func: BinaryFunc) -> Self {
ScalarExpr::CallBinary {
func,
expr1: Box::new(self),
expr2: Box::new(other),
}
}
pub fn eval(&self, values: &[Value]) -> Result<Value, EvalError> {
match self {
ScalarExpr::Column(index) => Ok(values[*index].clone()),
ScalarExpr::Literal(row_res, _ty) => Ok(row_res.clone()),
ScalarExpr::CallUnmaterializable(f) => OptimizeSnafu {
reason: "Can't eval unmaterializable function".to_string(),
}
.fail(),
ScalarExpr::CallUnary { func, expr } => func.eval(values, expr),
ScalarExpr::CallBinary { func, expr1, expr2 } => func.eval(values, expr1, expr2),
ScalarExpr::CallVariadic { func, exprs } => func.eval(values, exprs),
ScalarExpr::If { cond, then, els } => match cond.eval(values) {
Ok(Value::Boolean(true)) => then.eval(values),
Ok(Value::Boolean(false)) => els.eval(values),
_ => InvalidArgumentSnafu {
reason: "if condition must be boolean".to_string(),
}
.fail(),
},
}
}
/// Rewrites column indices with their value in `permutation`.
///
/// This method is applicable even when `permutation` is not a
/// strict permutation, and it only needs to have entries for
/// each column referenced in `self`.
pub fn permute(&mut self, permutation: &[usize]) {
self.visit_mut_post_nolimit(&mut |e| {
if let ScalarExpr::Column(old_i) = e {
*old_i = permutation[*old_i];
}
});
}
/// Rewrites column indices with their value in `permutation`.
///
/// This method is applicable even when `permutation` is not a
/// strict permutation, and it only needs to have entries for
/// each column referenced in `self`.
pub fn permute_map(&mut self, permutation: &BTreeMap<usize, usize>) {
self.visit_mut_post_nolimit(&mut |e| {
if let ScalarExpr::Column(old_i) = e {
*old_i = permutation[old_i];
}
});
}
/// Returns the set of columns that are referenced by `self`.
pub fn get_all_ref_columns(&self) -> BTreeSet<usize> {
let mut support = BTreeSet::new();
self.visit_post_nolimit(&mut |e| {
if let ScalarExpr::Column(i) = e {
support.insert(*i);
}
});
support
}
pub fn as_literal(&self) -> Option<Value> {
if let ScalarExpr::Literal(lit, _column_type) = self {
Some(lit.clone())
} else {
None
}
}
pub fn is_literal(&self) -> bool {
matches!(self, ScalarExpr::Literal(..))
}
pub fn is_literal_true(&self) -> bool {
Some(Value::Boolean(true)) == self.as_literal()
}
pub fn is_literal_false(&self) -> bool {
Some(Value::Boolean(false)) == self.as_literal()
}
pub fn is_literal_null(&self) -> bool {
Some(Value::Null) == self.as_literal()
}
pub fn literal_null() -> Self {
ScalarExpr::Literal(Value::Null, ConcreteDataType::null_datatype())
}
pub fn literal(res: Value, typ: ConcreteDataType) -> Self {
ScalarExpr::Literal(res, typ)
}
pub fn literal_false() -> Self {
ScalarExpr::Literal(Value::Boolean(false), ConcreteDataType::boolean_datatype())
}
pub fn literal_true() -> Self {
ScalarExpr::Literal(Value::Boolean(true), ConcreteDataType::boolean_datatype())
}
}
impl ScalarExpr {
/// visit post-order without stack call limit, but may cause stack overflow
fn visit_post_nolimit<F>(&self, f: &mut F)
where
F: FnMut(&Self),
{
self.visit_children(|e| e.visit_post_nolimit(f));
f(self);
}
fn visit_children<F>(&self, mut f: F)
where
F: FnMut(&Self),
{
match self {
ScalarExpr::Column(_)
| ScalarExpr::Literal(_, _)
| ScalarExpr::CallUnmaterializable(_) => (),
ScalarExpr::CallUnary { expr, .. } => f(expr),
ScalarExpr::CallBinary { expr1, expr2, .. } => {
f(expr1);
f(expr2);
}
ScalarExpr::CallVariadic { exprs, .. } => {
for expr in exprs {
f(expr);
}
}
ScalarExpr::If { cond, then, els } => {
f(cond);
f(then);
f(els);
}
}
}
fn visit_mut_post_nolimit<F>(&mut self, f: &mut F)
where
F: FnMut(&mut Self),
{
self.visit_mut_children(|e: &mut Self| e.visit_mut_post_nolimit(f));
f(self);
}
fn visit_mut_children<F>(&mut self, mut f: F)
where
F: FnMut(&mut Self),
{
match self {
ScalarExpr::Column(_)
| ScalarExpr::Literal(_, _)
| ScalarExpr::CallUnmaterializable(_) => (),
ScalarExpr::CallUnary { expr, .. } => f(expr),
ScalarExpr::CallBinary { expr1, expr2, .. } => {
f(expr1);
f(expr2);
}
ScalarExpr::CallVariadic { exprs, .. } => {
for expr in exprs {
f(expr);
}
}
ScalarExpr::If { cond, then, els } => {
f(cond);
f(then);
f(els);
}
}
}
}
impl ScalarExpr {
/// if expr contains function `Now`
pub fn contains_temporal(&self) -> bool {
let mut contains = false;
self.visit_post_nolimit(&mut |e| {
if let ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now) = e {
contains = true;
}
});
contains
}
/// extract lower or upper bound of `Now` for expr, where `lower bound <= expr < upper bound`
///
/// returned bool indicates whether the bound is upper bound:
///
/// false for lower bound, true for upper bound
/// TODO(discord9): allow simple transform like `now() + a < b` to `now() < b - a`
pub fn extract_bound(&self) -> Result<(Option<Self>, Option<Self>), EvalError> {
let unsupported_err = |msg: &str| {
UnsupportedTemporalFilterSnafu {
reason: msg.to_string(),
}
.fail()
};
let Self::CallBinary {
mut func,
mut expr1,
mut expr2,
} = self.clone()
else {
return unsupported_err("Not a binary expression");
};
// TODO: support simple transform like `now() + a < b` to `now() < b - a`
let expr1_is_now = *expr1 == ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now);
let expr2_is_now = *expr2 == ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now);
if !(expr1_is_now ^ expr2_is_now) {
return unsupported_err("None of the sides of the comparison is `now()`");
}
if expr2_is_now {
std::mem::swap(&mut expr1, &mut expr2);
func = BinaryFunc::reverse_compare(&func)?;
}
let step = |expr: ScalarExpr| expr.call_unary(UnaryFunc::StepTimestamp);
match func {
// now == expr2 -> now <= expr2 && now < expr2 + 1
BinaryFunc::Eq => Ok((Some(*expr2.clone()), Some(step(*expr2)))),
// now < expr2 -> now < expr2
BinaryFunc::Lt => Ok((None, Some(*expr2))),
// now <= expr2 -> now < expr2 + 1
BinaryFunc::Lte => Ok((None, Some(step(*expr2)))),
// now > expr2 -> now >= expr2 + 1
BinaryFunc::Gt => Ok((Some(step(*expr2)), None)),
// now >= expr2 -> now >= expr2
BinaryFunc::Gte => Ok((Some(*expr2), None)),
_ => unreachable!("Already checked"),
}
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_extract_bound() {
let test_list: [(ScalarExpr, Result<_, EvalError>); 5] = [
// col(0) == now
(
ScalarExpr::CallBinary {
func: BinaryFunc::Eq,
expr1: Box::new(ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now)),
expr2: Box::new(ScalarExpr::Column(0)),
},
Ok((
Some(ScalarExpr::Column(0)),
Some(ScalarExpr::CallUnary {
func: UnaryFunc::StepTimestamp,
expr: Box::new(ScalarExpr::Column(0)),
}),
)),
),
// now < col(0)
(
ScalarExpr::CallBinary {
func: BinaryFunc::Lt,
expr1: Box::new(ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now)),
expr2: Box::new(ScalarExpr::Column(0)),
},
Ok((None, Some(ScalarExpr::Column(0)))),
),
// now <= col(0)
(
ScalarExpr::CallBinary {
func: BinaryFunc::Lte,
expr1: Box::new(ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now)),
expr2: Box::new(ScalarExpr::Column(0)),
},
Ok((
None,
Some(ScalarExpr::CallUnary {
func: UnaryFunc::StepTimestamp,
expr: Box::new(ScalarExpr::Column(0)),
}),
)),
),
// now > col(0) -> now >= col(0) + 1
(
ScalarExpr::CallBinary {
func: BinaryFunc::Gt,
expr1: Box::new(ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now)),
expr2: Box::new(ScalarExpr::Column(0)),
},
Ok((
Some(ScalarExpr::CallUnary {
func: UnaryFunc::StepTimestamp,
expr: Box::new(ScalarExpr::Column(0)),
}),
None,
)),
),
// now >= col(0)
(
ScalarExpr::CallBinary {
func: BinaryFunc::Gte,
expr1: Box::new(ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now)),
expr2: Box::new(ScalarExpr::Column(0)),
},
Ok((Some(ScalarExpr::Column(0)), None)),
),
];
for (expr, expected) in test_list.into_iter() {
let actual = expr.extract_bound();
// EvalError is not Eq, so we need to compare the error message
match (actual, expected) {
(Ok(l), Ok(r)) => assert_eq!(l, r),
(Err(l), Err(r)) => assert!(matches!(l, r)),
(l, r) => panic!("expected: {:?}, actual: {:?}", r, l),
}
}
}
}