mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-08 14:22:58 +00:00
feat(flow): impl ScalarExpr&Scalar Function (#3283)
* feat: impl for ScalarExpr * feat: plain functions * refactor: simpler trait bound&tests * chore: remove unused imports * chore: fmt * refactor: early ret on first error * refactor: remove abunant match arm * chore: per review * doc: `support` fn * chore: per review more * chore: more per review * fix: extract_bound * chore: per review * refactor: reduce nest
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -3406,6 +3406,7 @@ dependencies = [
|
||||
"datatypes",
|
||||
"hydroflow",
|
||||
"itertools 0.10.5",
|
||||
"num-traits",
|
||||
"serde",
|
||||
"servers",
|
||||
"session",
|
||||
|
||||
@@ -17,6 +17,7 @@ common-time.workspace = true
|
||||
datatypes.workspace = true
|
||||
hydroflow = "0.5.0"
|
||||
itertools.workspace = true
|
||||
num-traits = "0.2"
|
||||
serde.workspace = true
|
||||
servers.workspace = true
|
||||
session.workspace = true
|
||||
|
||||
@@ -58,4 +58,7 @@ pub enum EvalError {
|
||||
|
||||
#[snafu(display("Optimize error: {reason}"))]
|
||||
Optimize { reason: String, location: Location },
|
||||
|
||||
#[snafu(display("Unsupported temporal filter: {reason}"))]
|
||||
UnsupportedTemporalFilter { reason: String, location: Location },
|
||||
}
|
||||
|
||||
@@ -21,14 +21,12 @@ use hydroflow::bincode::Error;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use snafu::ResultExt;
|
||||
|
||||
use super::ScalarExpr;
|
||||
use crate::expr::error::CastValueSnafu;
|
||||
use crate::expr::InvalidArgumentSnafu;
|
||||
// TODO(discord9): more function & eval
|
||||
use crate::{
|
||||
expr::error::{EvalError, TryFromValueSnafu, TypeMismatchSnafu},
|
||||
repr::Row,
|
||||
use crate::expr::error::{
|
||||
CastValueSnafu, DivisionByZeroSnafu, EvalError, InternalSnafu, TryFromValueSnafu,
|
||||
TypeMismatchSnafu,
|
||||
};
|
||||
use crate::expr::{InvalidArgumentSnafu, ScalarExpr};
|
||||
use crate::repr::Row;
|
||||
|
||||
/// UnmaterializableFunc is a function that can't be eval independently,
|
||||
/// and require special handling
|
||||
@@ -47,6 +45,66 @@ pub enum UnaryFunc {
|
||||
StepTimestamp,
|
||||
Cast(ConcreteDataType),
|
||||
}
|
||||
|
||||
impl UnaryFunc {
|
||||
pub fn eval(&self, values: &[Value], expr: &ScalarExpr) -> Result<Value, EvalError> {
|
||||
let arg = expr.eval(values)?;
|
||||
match self {
|
||||
Self::Not => {
|
||||
let bool = if let Value::Boolean(bool) = arg {
|
||||
Ok(bool)
|
||||
} else {
|
||||
TypeMismatchSnafu {
|
||||
expected: ConcreteDataType::boolean_datatype(),
|
||||
actual: arg.data_type(),
|
||||
}
|
||||
.fail()?
|
||||
}?;
|
||||
Ok(Value::from(!bool))
|
||||
}
|
||||
Self::IsNull => Ok(Value::from(arg.is_null())),
|
||||
Self::IsTrue | Self::IsFalse => {
|
||||
let bool = if let Value::Boolean(bool) = arg {
|
||||
Ok(bool)
|
||||
} else {
|
||||
TypeMismatchSnafu {
|
||||
expected: ConcreteDataType::boolean_datatype(),
|
||||
actual: arg.data_type(),
|
||||
}
|
||||
.fail()?
|
||||
}?;
|
||||
if matches!(self, Self::IsTrue) {
|
||||
Ok(Value::from(bool))
|
||||
} else {
|
||||
Ok(Value::from(!bool))
|
||||
}
|
||||
}
|
||||
Self::StepTimestamp => {
|
||||
if let Value::DateTime(datetime) = arg {
|
||||
let datetime = DateTime::from(datetime.val() + 1);
|
||||
Ok(Value::from(datetime))
|
||||
} else {
|
||||
TypeMismatchSnafu {
|
||||
expected: ConcreteDataType::datetime_datatype(),
|
||||
actual: arg.data_type(),
|
||||
}
|
||||
.fail()?
|
||||
}
|
||||
}
|
||||
Self::Cast(to) => {
|
||||
let arg_ty = arg.data_type();
|
||||
let res = cast(arg, to).context({
|
||||
CastValueSnafu {
|
||||
from: arg_ty,
|
||||
to: to.clone(),
|
||||
}
|
||||
})?;
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// TODO(discord9): support more binary functions for more types
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, Hash)]
|
||||
pub enum BinaryFunc {
|
||||
@@ -96,8 +154,232 @@ pub enum BinaryFunc {
|
||||
ModUInt64,
|
||||
}
|
||||
|
||||
impl BinaryFunc {
|
||||
pub fn eval(
|
||||
&self,
|
||||
values: &[Value],
|
||||
expr1: &ScalarExpr,
|
||||
expr2: &ScalarExpr,
|
||||
) -> Result<Value, EvalError> {
|
||||
let left = expr1.eval(values)?;
|
||||
let right = expr2.eval(values)?;
|
||||
match self {
|
||||
Self::Eq => Ok(Value::from(left == right)),
|
||||
Self::NotEq => Ok(Value::from(left != right)),
|
||||
Self::Lt => Ok(Value::from(left < right)),
|
||||
Self::Lte => Ok(Value::from(left <= right)),
|
||||
Self::Gt => Ok(Value::from(left > right)),
|
||||
Self::Gte => Ok(Value::from(left >= right)),
|
||||
|
||||
Self::AddInt16 => Ok(add::<i16>(left, right)?),
|
||||
Self::AddInt32 => Ok(add::<i32>(left, right)?),
|
||||
Self::AddInt64 => Ok(add::<i64>(left, right)?),
|
||||
Self::AddUInt16 => Ok(add::<u16>(left, right)?),
|
||||
Self::AddUInt32 => Ok(add::<u32>(left, right)?),
|
||||
Self::AddUInt64 => Ok(add::<u64>(left, right)?),
|
||||
Self::AddFloat32 => Ok(add::<f32>(left, right)?),
|
||||
Self::AddFloat64 => Ok(add::<f64>(left, right)?),
|
||||
|
||||
Self::SubInt16 => Ok(sub::<i16>(left, right)?),
|
||||
Self::SubInt32 => Ok(sub::<i32>(left, right)?),
|
||||
Self::SubInt64 => Ok(sub::<i64>(left, right)?),
|
||||
Self::SubUInt16 => Ok(sub::<u16>(left, right)?),
|
||||
Self::SubUInt32 => Ok(sub::<u32>(left, right)?),
|
||||
Self::SubUInt64 => Ok(sub::<u64>(left, right)?),
|
||||
Self::SubFloat32 => Ok(sub::<f32>(left, right)?),
|
||||
Self::SubFloat64 => Ok(sub::<f64>(left, right)?),
|
||||
|
||||
Self::MulInt16 => Ok(mul::<i16>(left, right)?),
|
||||
Self::MulInt32 => Ok(mul::<i32>(left, right)?),
|
||||
Self::MulInt64 => Ok(mul::<i64>(left, right)?),
|
||||
Self::MulUInt16 => Ok(mul::<u16>(left, right)?),
|
||||
Self::MulUInt32 => Ok(mul::<u32>(left, right)?),
|
||||
Self::MulUInt64 => Ok(mul::<u64>(left, right)?),
|
||||
Self::MulFloat32 => Ok(mul::<f32>(left, right)?),
|
||||
Self::MulFloat64 => Ok(mul::<f64>(left, right)?),
|
||||
|
||||
Self::DivInt16 => Ok(div::<i16>(left, right)?),
|
||||
Self::DivInt32 => Ok(div::<i32>(left, right)?),
|
||||
Self::DivInt64 => Ok(div::<i64>(left, right)?),
|
||||
Self::DivUInt16 => Ok(div::<u16>(left, right)?),
|
||||
Self::DivUInt32 => Ok(div::<u32>(left, right)?),
|
||||
Self::DivUInt64 => Ok(div::<u64>(left, right)?),
|
||||
Self::DivFloat32 => Ok(div::<f32>(left, right)?),
|
||||
Self::DivFloat64 => Ok(div::<f64>(left, right)?),
|
||||
|
||||
Self::ModInt16 => Ok(rem::<i16>(left, right)?),
|
||||
Self::ModInt32 => Ok(rem::<i32>(left, right)?),
|
||||
Self::ModInt64 => Ok(rem::<i64>(left, right)?),
|
||||
Self::ModUInt16 => Ok(rem::<u16>(left, right)?),
|
||||
Self::ModUInt32 => Ok(rem::<u32>(left, right)?),
|
||||
Self::ModUInt64 => Ok(rem::<u64>(left, right)?),
|
||||
}
|
||||
}
|
||||
|
||||
/// Reverse the comparison operator, i.e. `a < b` becomes `b > a`,
|
||||
/// equal and not equal are unchanged.
|
||||
pub fn reverse_compare(&self) -> Result<Self, EvalError> {
|
||||
let ret = match &self {
|
||||
BinaryFunc::Eq => BinaryFunc::Eq,
|
||||
BinaryFunc::NotEq => BinaryFunc::NotEq,
|
||||
BinaryFunc::Lt => BinaryFunc::Gt,
|
||||
BinaryFunc::Lte => BinaryFunc::Gte,
|
||||
BinaryFunc::Gt => BinaryFunc::Lt,
|
||||
BinaryFunc::Gte => BinaryFunc::Lte,
|
||||
_ => {
|
||||
return InternalSnafu {
|
||||
reason: format!("Expect a comparison operator, found {:?}", self),
|
||||
}
|
||||
.fail();
|
||||
}
|
||||
};
|
||||
Ok(ret)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, Hash)]
|
||||
pub enum VariadicFunc {
|
||||
And,
|
||||
Or,
|
||||
}
|
||||
|
||||
impl VariadicFunc {
|
||||
pub fn eval(&self, values: &[Value], exprs: &[ScalarExpr]) -> Result<Value, EvalError> {
|
||||
match self {
|
||||
VariadicFunc::And => and(values, exprs),
|
||||
VariadicFunc::Or => or(values, exprs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn and(values: &[Value], exprs: &[ScalarExpr]) -> Result<Value, EvalError> {
|
||||
// If any is false, then return false. Else, if any is null, then return null. Else, return true.
|
||||
let mut null = false;
|
||||
for expr in exprs {
|
||||
match expr.eval(values) {
|
||||
Ok(Value::Boolean(true)) => {}
|
||||
Ok(Value::Boolean(false)) => return Ok(Value::Boolean(false)), // short-circuit
|
||||
Ok(Value::Null) => null = true,
|
||||
Err(this_err) => {
|
||||
return Err(this_err);
|
||||
} // retain first error encountered
|
||||
Ok(x) => InvalidArgumentSnafu {
|
||||
reason: format!(
|
||||
"`and()` only support boolean type, found value {:?} of type {:?}",
|
||||
x,
|
||||
x.data_type()
|
||||
),
|
||||
}
|
||||
.fail()?,
|
||||
}
|
||||
}
|
||||
match null {
|
||||
true => Ok(Value::Null),
|
||||
false => Ok(Value::Boolean(true)),
|
||||
}
|
||||
}
|
||||
|
||||
fn or(values: &[Value], exprs: &[ScalarExpr]) -> Result<Value, EvalError> {
|
||||
// If any is false, then return false. Else, if any is null, then return null. Else, return true.
|
||||
let mut null = false;
|
||||
for expr in exprs {
|
||||
match expr.eval(values) {
|
||||
Ok(Value::Boolean(true)) => return Ok(Value::Boolean(true)), // short-circuit
|
||||
Ok(Value::Boolean(false)) => {}
|
||||
Ok(Value::Null) => null = true,
|
||||
Err(this_err) => {
|
||||
return Err(this_err);
|
||||
} // retain first error encountered
|
||||
Ok(x) => InvalidArgumentSnafu {
|
||||
reason: format!(
|
||||
"`or()` only support boolean type, found value {:?} of type {:?}",
|
||||
x,
|
||||
x.data_type()
|
||||
),
|
||||
}
|
||||
.fail()?,
|
||||
}
|
||||
}
|
||||
match null {
|
||||
true => Ok(Value::Null),
|
||||
false => Ok(Value::Boolean(false)),
|
||||
}
|
||||
}
|
||||
|
||||
fn add<T>(left: Value, right: Value) -> Result<Value, EvalError>
|
||||
where
|
||||
T: TryFrom<Value, Error = datatypes::Error> + num_traits::Num,
|
||||
Value: From<T>,
|
||||
{
|
||||
let left = T::try_from(left).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
|
||||
let right = T::try_from(right).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
|
||||
Ok(Value::from(left + right))
|
||||
}
|
||||
|
||||
fn sub<T>(left: Value, right: Value) -> Result<Value, EvalError>
|
||||
where
|
||||
T: TryFrom<Value, Error = datatypes::Error> + num_traits::Num,
|
||||
Value: From<T>,
|
||||
{
|
||||
let left = T::try_from(left).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
|
||||
let right = T::try_from(right).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
|
||||
Ok(Value::from(left - right))
|
||||
}
|
||||
|
||||
fn mul<T>(left: Value, right: Value) -> Result<Value, EvalError>
|
||||
where
|
||||
T: TryFrom<Value, Error = datatypes::Error> + num_traits::Num,
|
||||
Value: From<T>,
|
||||
{
|
||||
let left = T::try_from(left).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
|
||||
let right = T::try_from(right).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
|
||||
Ok(Value::from(left * right))
|
||||
}
|
||||
|
||||
fn div<T>(left: Value, right: Value) -> Result<Value, EvalError>
|
||||
where
|
||||
T: TryFrom<Value, Error = datatypes::Error> + num_traits::Num,
|
||||
<T as TryFrom<Value>>::Error: std::fmt::Debug,
|
||||
Value: From<T>,
|
||||
{
|
||||
let left = T::try_from(left).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
|
||||
let right = T::try_from(right).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
|
||||
if right.is_zero() {
|
||||
return Err(DivisionByZeroSnafu {}.build());
|
||||
}
|
||||
Ok(Value::from(left / right))
|
||||
}
|
||||
|
||||
fn rem<T>(left: Value, right: Value) -> Result<Value, EvalError>
|
||||
where
|
||||
T: TryFrom<Value, Error = datatypes::Error> + num_traits::Num,
|
||||
<T as TryFrom<Value>>::Error: std::fmt::Debug,
|
||||
Value: From<T>,
|
||||
{
|
||||
let left = T::try_from(left).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
|
||||
let right = T::try_from(right).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
|
||||
Ok(Value::from(left % right))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_num_ops() {
|
||||
let left = Value::from(10);
|
||||
let right = Value::from(3);
|
||||
let res = add::<i32>(left.clone(), right.clone()).unwrap();
|
||||
assert_eq!(res, Value::from(13));
|
||||
let res = sub::<i32>(left.clone(), right.clone()).unwrap();
|
||||
assert_eq!(res, Value::from(7));
|
||||
let res = mul::<i32>(left.clone(), right.clone()).unwrap();
|
||||
assert_eq!(res, Value::from(30));
|
||||
let res = div::<i32>(left.clone(), right.clone()).unwrap();
|
||||
assert_eq!(res, Value::from(3));
|
||||
let res = rem::<i32>(left.clone(), right.clone()).unwrap();
|
||||
assert_eq!(res, Value::from(1));
|
||||
|
||||
let values = vec![Value::from(true), Value::from(false)];
|
||||
let exprs = vec![ScalarExpr::Column(0), ScalarExpr::Column(1)];
|
||||
let res = and(&values, &exprs).unwrap();
|
||||
assert_eq!(res, Value::from(false));
|
||||
let res = or(&values, &exprs).unwrap();
|
||||
assert_eq!(res, Value::from(true));
|
||||
}
|
||||
|
||||
@@ -18,7 +18,9 @@ use datatypes::prelude::ConcreteDataType;
|
||||
use datatypes::value::Value;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::expr::error::{EvalError, InvalidArgumentSnafu, OptimizeSnafu};
|
||||
use crate::expr::error::{
|
||||
EvalError, InvalidArgumentSnafu, OptimizeSnafu, UnsupportedTemporalFilterSnafu,
|
||||
};
|
||||
use crate::expr::func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc};
|
||||
|
||||
/// A scalar expression, which can be evaluated to a value.
|
||||
@@ -59,3 +61,338 @@ pub enum ScalarExpr {
|
||||
els: Box<ScalarExpr>,
|
||||
},
|
||||
}
|
||||
|
||||
impl ScalarExpr {
|
||||
pub fn call_unary(self, func: UnaryFunc) -> Self {
|
||||
ScalarExpr::CallUnary {
|
||||
func,
|
||||
expr: Box::new(self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn call_binary(self, other: Self, func: BinaryFunc) -> Self {
|
||||
ScalarExpr::CallBinary {
|
||||
func,
|
||||
expr1: Box::new(self),
|
||||
expr2: Box::new(other),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn eval(&self, values: &[Value]) -> Result<Value, EvalError> {
|
||||
match self {
|
||||
ScalarExpr::Column(index) => Ok(values[*index].clone()),
|
||||
ScalarExpr::Literal(row_res, _ty) => Ok(row_res.clone()),
|
||||
ScalarExpr::CallUnmaterializable(f) => OptimizeSnafu {
|
||||
reason: "Can't eval unmaterializable function".to_string(),
|
||||
}
|
||||
.fail(),
|
||||
ScalarExpr::CallUnary { func, expr } => func.eval(values, expr),
|
||||
ScalarExpr::CallBinary { func, expr1, expr2 } => func.eval(values, expr1, expr2),
|
||||
ScalarExpr::CallVariadic { func, exprs } => func.eval(values, exprs),
|
||||
ScalarExpr::If { cond, then, els } => match cond.eval(values) {
|
||||
Ok(Value::Boolean(true)) => then.eval(values),
|
||||
Ok(Value::Boolean(false)) => els.eval(values),
|
||||
_ => InvalidArgumentSnafu {
|
||||
reason: "if condition must be boolean".to_string(),
|
||||
}
|
||||
.fail(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Rewrites column indices with their value in `permutation`.
|
||||
///
|
||||
/// This method is applicable even when `permutation` is not a
|
||||
/// strict permutation, and it only needs to have entries for
|
||||
/// each column referenced in `self`.
|
||||
pub fn permute(&mut self, permutation: &[usize]) {
|
||||
self.visit_mut_post_nolimit(&mut |e| {
|
||||
if let ScalarExpr::Column(old_i) = e {
|
||||
*old_i = permutation[*old_i];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Rewrites column indices with their value in `permutation`.
|
||||
///
|
||||
/// This method is applicable even when `permutation` is not a
|
||||
/// strict permutation, and it only needs to have entries for
|
||||
/// each column referenced in `self`.
|
||||
pub fn permute_map(&mut self, permutation: &BTreeMap<usize, usize>) {
|
||||
self.visit_mut_post_nolimit(&mut |e| {
|
||||
if let ScalarExpr::Column(old_i) = e {
|
||||
*old_i = permutation[old_i];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Returns the set of columns that are referenced by `self`.
|
||||
pub fn get_all_ref_columns(&self) -> BTreeSet<usize> {
|
||||
let mut support = BTreeSet::new();
|
||||
self.visit_post_nolimit(&mut |e| {
|
||||
if let ScalarExpr::Column(i) = e {
|
||||
support.insert(*i);
|
||||
}
|
||||
});
|
||||
support
|
||||
}
|
||||
|
||||
pub fn as_literal(&self) -> Option<Value> {
|
||||
if let ScalarExpr::Literal(lit, _column_type) = self {
|
||||
Some(lit.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_literal(&self) -> bool {
|
||||
matches!(self, ScalarExpr::Literal(..))
|
||||
}
|
||||
|
||||
pub fn is_literal_true(&self) -> bool {
|
||||
Some(Value::Boolean(true)) == self.as_literal()
|
||||
}
|
||||
|
||||
pub fn is_literal_false(&self) -> bool {
|
||||
Some(Value::Boolean(false)) == self.as_literal()
|
||||
}
|
||||
|
||||
pub fn is_literal_null(&self) -> bool {
|
||||
Some(Value::Null) == self.as_literal()
|
||||
}
|
||||
|
||||
pub fn literal_null() -> Self {
|
||||
ScalarExpr::Literal(Value::Null, ConcreteDataType::null_datatype())
|
||||
}
|
||||
|
||||
pub fn literal(res: Value, typ: ConcreteDataType) -> Self {
|
||||
ScalarExpr::Literal(res, typ)
|
||||
}
|
||||
|
||||
pub fn literal_false() -> Self {
|
||||
ScalarExpr::Literal(Value::Boolean(false), ConcreteDataType::boolean_datatype())
|
||||
}
|
||||
|
||||
pub fn literal_true() -> Self {
|
||||
ScalarExpr::Literal(Value::Boolean(true), ConcreteDataType::boolean_datatype())
|
||||
}
|
||||
}
|
||||
|
||||
impl ScalarExpr {
|
||||
/// visit post-order without stack call limit, but may cause stack overflow
|
||||
fn visit_post_nolimit<F>(&self, f: &mut F)
|
||||
where
|
||||
F: FnMut(&Self),
|
||||
{
|
||||
self.visit_children(|e| e.visit_post_nolimit(f));
|
||||
f(self);
|
||||
}
|
||||
|
||||
fn visit_children<F>(&self, mut f: F)
|
||||
where
|
||||
F: FnMut(&Self),
|
||||
{
|
||||
match self {
|
||||
ScalarExpr::Column(_)
|
||||
| ScalarExpr::Literal(_, _)
|
||||
| ScalarExpr::CallUnmaterializable(_) => (),
|
||||
ScalarExpr::CallUnary { expr, .. } => f(expr),
|
||||
ScalarExpr::CallBinary { expr1, expr2, .. } => {
|
||||
f(expr1);
|
||||
f(expr2);
|
||||
}
|
||||
ScalarExpr::CallVariadic { exprs, .. } => {
|
||||
for expr in exprs {
|
||||
f(expr);
|
||||
}
|
||||
}
|
||||
ScalarExpr::If { cond, then, els } => {
|
||||
f(cond);
|
||||
f(then);
|
||||
f(els);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_mut_post_nolimit<F>(&mut self, f: &mut F)
|
||||
where
|
||||
F: FnMut(&mut Self),
|
||||
{
|
||||
self.visit_mut_children(|e: &mut Self| e.visit_mut_post_nolimit(f));
|
||||
f(self);
|
||||
}
|
||||
|
||||
fn visit_mut_children<F>(&mut self, mut f: F)
|
||||
where
|
||||
F: FnMut(&mut Self),
|
||||
{
|
||||
match self {
|
||||
ScalarExpr::Column(_)
|
||||
| ScalarExpr::Literal(_, _)
|
||||
| ScalarExpr::CallUnmaterializable(_) => (),
|
||||
ScalarExpr::CallUnary { expr, .. } => f(expr),
|
||||
ScalarExpr::CallBinary { expr1, expr2, .. } => {
|
||||
f(expr1);
|
||||
f(expr2);
|
||||
}
|
||||
ScalarExpr::CallVariadic { exprs, .. } => {
|
||||
for expr in exprs {
|
||||
f(expr);
|
||||
}
|
||||
}
|
||||
ScalarExpr::If { cond, then, els } => {
|
||||
f(cond);
|
||||
f(then);
|
||||
f(els);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ScalarExpr {
|
||||
/// if expr contains function `Now`
|
||||
pub fn contains_temporal(&self) -> bool {
|
||||
let mut contains = false;
|
||||
self.visit_post_nolimit(&mut |e| {
|
||||
if let ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now) = e {
|
||||
contains = true;
|
||||
}
|
||||
});
|
||||
contains
|
||||
}
|
||||
|
||||
/// extract lower or upper bound of `Now` for expr, where `lower bound <= expr < upper bound`
|
||||
///
|
||||
/// returned bool indicates whether the bound is upper bound:
|
||||
///
|
||||
/// false for lower bound, true for upper bound
|
||||
/// TODO(discord9): allow simple transform like `now() + a < b` to `now() < b - a`
|
||||
pub fn extract_bound(&self) -> Result<(Option<Self>, Option<Self>), EvalError> {
|
||||
let unsupported_err = |msg: &str| {
|
||||
UnsupportedTemporalFilterSnafu {
|
||||
reason: msg.to_string(),
|
||||
}
|
||||
.fail()
|
||||
};
|
||||
|
||||
let Self::CallBinary {
|
||||
mut func,
|
||||
mut expr1,
|
||||
mut expr2,
|
||||
} = self.clone()
|
||||
else {
|
||||
return unsupported_err("Not a binary expression");
|
||||
};
|
||||
|
||||
// TODO: support simple transform like `now() + a < b` to `now() < b - a`
|
||||
|
||||
let expr1_is_now = *expr1 == ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now);
|
||||
let expr2_is_now = *expr2 == ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now);
|
||||
|
||||
if !(expr1_is_now ^ expr2_is_now) {
|
||||
return unsupported_err("None of the sides of the comparison is `now()`");
|
||||
}
|
||||
|
||||
if expr2_is_now {
|
||||
std::mem::swap(&mut expr1, &mut expr2);
|
||||
func = BinaryFunc::reverse_compare(&func)?;
|
||||
}
|
||||
|
||||
let step = |expr: ScalarExpr| expr.call_unary(UnaryFunc::StepTimestamp);
|
||||
match func {
|
||||
// now == expr2 -> now <= expr2 && now < expr2 + 1
|
||||
BinaryFunc::Eq => Ok((Some(*expr2.clone()), Some(step(*expr2)))),
|
||||
// now < expr2 -> now < expr2
|
||||
BinaryFunc::Lt => Ok((None, Some(*expr2))),
|
||||
// now <= expr2 -> now < expr2 + 1
|
||||
BinaryFunc::Lte => Ok((None, Some(step(*expr2)))),
|
||||
// now > expr2 -> now >= expr2 + 1
|
||||
BinaryFunc::Gt => Ok((Some(step(*expr2)), None)),
|
||||
// now >= expr2 -> now >= expr2
|
||||
BinaryFunc::Gte => Ok((Some(*expr2), None)),
|
||||
_ => unreachable!("Already checked"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
#[test]
|
||||
fn test_extract_bound() {
|
||||
let test_list: [(ScalarExpr, Result<_, EvalError>); 5] = [
|
||||
// col(0) == now
|
||||
(
|
||||
ScalarExpr::CallBinary {
|
||||
func: BinaryFunc::Eq,
|
||||
expr1: Box::new(ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now)),
|
||||
expr2: Box::new(ScalarExpr::Column(0)),
|
||||
},
|
||||
Ok((
|
||||
Some(ScalarExpr::Column(0)),
|
||||
Some(ScalarExpr::CallUnary {
|
||||
func: UnaryFunc::StepTimestamp,
|
||||
expr: Box::new(ScalarExpr::Column(0)),
|
||||
}),
|
||||
)),
|
||||
),
|
||||
// now < col(0)
|
||||
(
|
||||
ScalarExpr::CallBinary {
|
||||
func: BinaryFunc::Lt,
|
||||
expr1: Box::new(ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now)),
|
||||
expr2: Box::new(ScalarExpr::Column(0)),
|
||||
},
|
||||
Ok((None, Some(ScalarExpr::Column(0)))),
|
||||
),
|
||||
// now <= col(0)
|
||||
(
|
||||
ScalarExpr::CallBinary {
|
||||
func: BinaryFunc::Lte,
|
||||
expr1: Box::new(ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now)),
|
||||
expr2: Box::new(ScalarExpr::Column(0)),
|
||||
},
|
||||
Ok((
|
||||
None,
|
||||
Some(ScalarExpr::CallUnary {
|
||||
func: UnaryFunc::StepTimestamp,
|
||||
expr: Box::new(ScalarExpr::Column(0)),
|
||||
}),
|
||||
)),
|
||||
),
|
||||
// now > col(0) -> now >= col(0) + 1
|
||||
(
|
||||
ScalarExpr::CallBinary {
|
||||
func: BinaryFunc::Gt,
|
||||
expr1: Box::new(ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now)),
|
||||
expr2: Box::new(ScalarExpr::Column(0)),
|
||||
},
|
||||
Ok((
|
||||
Some(ScalarExpr::CallUnary {
|
||||
func: UnaryFunc::StepTimestamp,
|
||||
expr: Box::new(ScalarExpr::Column(0)),
|
||||
}),
|
||||
None,
|
||||
)),
|
||||
),
|
||||
// now >= col(0)
|
||||
(
|
||||
ScalarExpr::CallBinary {
|
||||
func: BinaryFunc::Gte,
|
||||
expr1: Box::new(ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now)),
|
||||
expr2: Box::new(ScalarExpr::Column(0)),
|
||||
},
|
||||
Ok((Some(ScalarExpr::Column(0)), None)),
|
||||
),
|
||||
];
|
||||
for (expr, expected) in test_list.into_iter() {
|
||||
let actual = expr.extract_bound();
|
||||
// EvalError is not Eq, so we need to compare the error message
|
||||
match (actual, expected) {
|
||||
(Ok(l), Ok(r)) => assert_eq!(l, r),
|
||||
(Err(l), Err(r)) => assert!(matches!(l, r)),
|
||||
(l, r) => panic!("expected: {:?}, actual: {:?}", r, l),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user