fix: subquery&cte time window expr

This commit is contained in:
discord9
2025-03-04 20:46:11 +08:00
parent 0374c332d2
commit a4c2ada639
2 changed files with 130 additions and 137 deletions

View File

@@ -17,7 +17,7 @@
mod engine;
mod frontend_client;
use std::collections::{BTreeSet, HashSet};
use std::collections::BTreeSet;
use std::sync::Arc;
use api::helper::pb_value_to_value_ref;
@@ -33,11 +33,12 @@ use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
use datafusion::prelude::SessionContext;
use datafusion::sql::unparser::Unparser;
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter};
use datafusion_common::{Column, DFSchema, TableReference};
use datafusion_common::{DFSchema, TableReference};
use datafusion_expr::{ColumnarValue, LogicalPlan};
use datafusion_physical_expr::PhysicalExprRef;
use datatypes::prelude::{ConcreteDataType, DataType};
use datatypes::scalars::ScalarVector;
use datatypes::schema::TIME_INDEX_KEY;
use datatypes::value::Value;
use datatypes::vectors::{
TimestampMicrosecondVector, TimestampMillisecondVector, TimestampNanosecondVector,
@@ -264,6 +265,7 @@ async fn find_time_window_expr(
// TODO(discord9): find the expr that do time window
let mut table_name = None;
// first find the table source in the logical plan
plan.apply(|plan| {
let LogicalPlan::TableScan(table_scan) = plan else {
@@ -316,45 +318,6 @@ async fn find_time_window_expr(
),
})?.unit();
let ts_columns: HashSet<_> = HashSet::from_iter(vec![
format!("{catalog_name}.{schema_name}.{table_name}.{ts_col_name}"),
format!("{schema_name}.{table_name}.{ts_col_name}"),
format!("{table_name}.{ts_col_name}"),
format!("{ts_col_name}"),
]);
let ts_columns: HashSet<_> = ts_columns
.into_iter()
.map(Column::from_qualified_name)
.collect();
let ts_columns_ref: HashSet<&Column> = ts_columns.iter().collect();
// find the time window expr which refers to the time index column
let mut time_window_expr: Option<Expr> = None;
let find_time_window_expr = |plan: &LogicalPlan| {
let LogicalPlan::Aggregate(aggregate) = plan else {
return Ok(TreeNodeRecursion::Continue);
};
for group_expr in &aggregate.group_expr {
let refs = group_expr.column_refs();
if refs.len() != 1 {
continue;
}
let ref_col = refs.iter().next().unwrap();
if ts_columns_ref.contains(ref_col) {
time_window_expr = Some(group_expr.clone());
break;
}
}
Ok(TreeNodeRecursion::Stop)
};
plan.apply(find_time_window_expr)
.with_context(|_| DatafusionSnafu {
context: format!("Can't find time window expr in plan {plan:?}"),
})?;
let arrow_schema = Arc::new(arrow_schema::Schema::new(vec![arrow_schema::Field::new(
ts_col_name.clone(),
ts_index.data_type.as_arrow_type(),
@@ -368,7 +331,75 @@ async fn find_time_window_expr(
.with_context(|_e| DatafusionSnafu {
context: format!("Failed to create DFSchema from arrow schema {arrow_schema:?}"),
})?;
Ok((ts_col_name, time_window_expr, expected_time_unit, df_schema))
// find the time window expr which refers to the time index column
let mut aggr_expr = None;
let mut time_window_expr: Option<Expr> = None;
let find_inner_aggr_expr = |plan: &LogicalPlan| {
if let LogicalPlan::Aggregate(aggregate) = plan {
aggr_expr = Some(aggregate.clone());
};
Ok(TreeNodeRecursion::Continue)
};
plan.apply(find_inner_aggr_expr)
.with_context(|_| DatafusionSnafu {
context: format!("Can't find aggr expr in plan {plan:?}"),
})?;
if let Some(aggregate) = aggr_expr {
for group_expr in &aggregate.group_expr {
let refs = group_expr.column_refs();
if refs.len() != 1 {
continue;
}
let ref_col = refs.iter().next().unwrap();
let index = aggregate.input.schema().maybe_index_of_column(ref_col);
let Some(index) = index else {
continue;
};
let field = aggregate.input.schema().field(index);
let is_time_index = field.metadata().get(TIME_INDEX_KEY) == Some(&"true".to_string());
if is_time_index {
let rewrite_column = group_expr.clone();
let rewritten = rewrite_column
.rewrite(&mut RewriteColumn {
table_name: table_name.to_string(),
})
.with_context(|_| DatafusionSnafu {
context: format!("Rewrite expr failed, expr={:?}", group_expr),
})?
.data;
struct RewriteColumn {
table_name: String,
}
impl TreeNodeRewriter for RewriteColumn {
type Node = Expr;
fn f_down(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
let Expr::Column(mut column) = node else {
return Ok(Transformed::no(node));
};
column.relation = Some(TableReference::bare(self.table_name.clone()));
Ok(Transformed::yes(Expr::Column(column)))
}
}
time_window_expr = Some(rewritten);
break;
}
}
Ok((ts_col_name, time_window_expr, expected_time_unit, df_schema))
} else {
// can't found time window expr, return None
Ok((ts_col_name, None, expected_time_unit, df_schema))
}
}
/// Find nearest lower bound for time `current` in given `plan` for the time window expr.
@@ -393,7 +424,6 @@ pub async fn find_plan_time_window_bound(
let (ts_col_name, time_window_expr, expected_time_unit, df_schema) =
find_time_window_expr(plan, catalog_man.clone(), query_ctx).await?;
// cast current to ts_index's type
let new_current = current
.convert_to(expected_time_unit)
@@ -425,8 +455,6 @@ fn find_expr_time_window_lower_bound(
df_schema: &DFSchema,
current: Timestamp,
) -> Result<Option<Timestamp>, Error> {
use std::cmp::Ordering;
let phy_planner = DefaultPhysicalPlanner::default();
let phy_expr: PhysicalExprRef = phy_planner
@@ -438,91 +466,8 @@ fn find_expr_time_window_lower_bound(
})?;
let cur_time_window = eval_ts_to_ts(&phy_expr, df_schema, current)?;
if cur_time_window == current {
return Ok(Some(current));
}
// search to find the lower bound
let mut offset: i64 = 1;
let lower_bound;
let mut upper_bound = Some(current);
// first expontial probe to found a range for binary search
loop {
let Some(next_val) = current.value().checked_sub(offset) else {
// no lower bound
return Ok(None);
};
let prev_time_probe = common_time::Timestamp::new(next_val, current.unit());
let prev_time_window = eval_ts_to_ts(&phy_expr, df_schema, prev_time_probe)?;
match prev_time_window.cmp(&cur_time_window) {
Ordering::Less => {
lower_bound = Some(prev_time_probe);
break;
}
Ordering::Equal => {
upper_bound = Some(prev_time_probe);
}
Ordering::Greater => {
UnexpectedSnafu {
reason: format!(
"Unsupported time window expression, expect monotonic increasing for time window expression {expr:?}"
),
}
.fail()?
}
}
let Some(new_offset) = offset.checked_mul(2) else {
// no lower bound
return Ok(None);
};
offset = new_offset;
}
// binary search for the exact lower bound
ensure!(lower_bound.map(|v|v.unit())==upper_bound.map(|v|v.unit()), UnexpectedSnafu{
reason: format!(" unit mismatch for time window expression {expr:?}, found {lower_bound:?} and {upper_bound:?}"),
});
let input_time_unit = lower_bound
.context(UnexpectedSnafu {
reason: "should have lower bound",
})?
.unit();
let mut low = lower_bound
.context(UnexpectedSnafu {
reason: "should have lower bound",
})?
.value();
let mut high = upper_bound
.context(UnexpectedSnafu {
reason: "should have upper bound",
})?
.value();
while low < high {
let mid = (low + high) / 2;
let mid_probe = common_time::Timestamp::new(mid, input_time_unit);
let mid_time_window = eval_ts_to_ts(&phy_expr, df_schema, mid_probe)?;
match mid_time_window.cmp(&cur_time_window) {
Ordering::Less => low = mid + 1,
Ordering::Equal => high = mid,
Ordering::Greater => UnexpectedSnafu {
reason: format!("Binary search failed for time window expression {expr:?}"),
}
.fail()?,
}
}
let final_lower_bound_for_time_window = common_time::Timestamp::new(low, input_time_unit);
Ok(Some(final_lower_bound_for_time_window))
let input_time_unit = cur_time_window.unit();
Ok(cur_time_window.convert_to(input_time_unit))
}
/// Find the upper bound for time window expression
@@ -688,7 +633,7 @@ impl AddFilterRewriter {
impl TreeNodeRewriter for AddFilterRewriter {
type Node = LogicalPlan;
fn f_down(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
if self.is_rewritten {
return Ok(Transformed::no(node));
}
@@ -738,7 +683,7 @@ mod test {
),
(
"SELECT number FROM numbers_with_ts WHERE number < 2 OR number >10",
"SELECT numbers_with_ts.number FROM numbers_with_ts WHERE (((numbers_with_ts.number < 2) OR (numbers_with_ts.number > 10)) AND (number > 4))"
"SELECT numbers_with_ts.number FROM numbers_with_ts WHERE ((numbers_with_ts.number < 2) OR (numbers_with_ts.number > 10)) AND (number > 4)"
),
(
"SELECT date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window",
@@ -853,9 +798,53 @@ mod test {
),
"SELECT numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP))) GROUP BY date_bin('5 minutes', numbers_with_ts.ts), numbers_with_ts.number"
),
// subquery
(
"SELECT number, time_window FROM (SELECT number, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window, number);",
Timestamp::new(23, TimeUnit::Millisecond),
(
"ts".to_string(),
Some(Timestamp::new(0, TimeUnit::Millisecond)),
Some(Timestamp::new(300000, TimeUnit::Millisecond)),
),
"SELECT numbers_with_ts.number, time_window FROM (SELECT numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP))) GROUP BY date_bin('5 minutes', numbers_with_ts.ts), numbers_with_ts.number)"
),
// cte
(
"with cte as (select number, date_bin('5 minutes', ts) as time_window from numbers_with_ts GROUP BY time_window, number) select number, time_window from cte;",
Timestamp::new(23, TimeUnit::Millisecond),
(
"ts".to_string(),
Some(Timestamp::new(0, TimeUnit::Millisecond)),
Some(Timestamp::new(300000, TimeUnit::Millisecond)),
),
"SELECT cte.number, cte.time_window FROM (SELECT numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP))) GROUP BY date_bin('5 minutes', numbers_with_ts.ts), numbers_with_ts.number) AS cte"
),
// complex subquery without alias
(
"SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) GROUP BY number, time_window, bucket_name;",
Timestamp::new(23, TimeUnit::Millisecond),
(
"ts".to_string(),
Some(Timestamp::new(0, TimeUnit::Millisecond)),
Some(Timestamp::new(300000, TimeUnit::Millisecond)),
),
"SELECT sum(numbers_with_ts.number), numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window, bucket_name FROM (SELECT numbers_with_ts.number, numbers_with_ts.ts, CASE WHEN (numbers_with_ts.number < 5) THEN 'bucket_0_5' WHEN (numbers_with_ts.number >= 5) THEN 'bucket_5_inf' END AS bucket_name FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP)))) GROUP BY numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts), bucket_name"
),
// complex subquery alias
(
"SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) as cte GROUP BY number, time_window, bucket_name;",
Timestamp::new(23, TimeUnit::Millisecond),
(
"ts".to_string(),
Some(Timestamp::new(0, TimeUnit::Millisecond)),
Some(Timestamp::new(300000, TimeUnit::Millisecond)),
),
"SELECT sum(cte.number), cte.number, date_bin('5 minutes', cte.ts) AS time_window, cte.bucket_name FROM (SELECT numbers_with_ts.number, numbers_with_ts.ts, CASE WHEN (numbers_with_ts.number < 5) THEN 'bucket_0_5' WHEN (numbers_with_ts.number >= 5) THEN 'bucket_5_inf' END AS bucket_name FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP)))) AS cte GROUP BY cte.number, date_bin('5 minutes', cte.ts), cte.bucket_name"
),
];
for (sql, current, expected, unparsed) in testcases {
for (sql, current, expected, expected_unparsed) in testcases {
let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, true)
.await
.unwrap();
@@ -887,7 +876,7 @@ mod test {
} else {
sql.to_string()
};
assert_eq!(unparsed, new_sql);
assert_eq!(expected_unparsed, new_sql);
}
}
}

View File

@@ -434,10 +434,13 @@ impl RecordingRuleTask {
.dirty_time_windows
.gen_filter_exprs(&col_name, lower, window_size)?
}
_ => UnexpectedSnafu {
reason: format!("Can't get window size: lower={lower:?}, upper={upper:?}"),
_ => {
warn!(
"Flow id = {:?}, can't get window size: lower={lower:?}, upper={upper:?}, using the same query", self.flow_id
);
// since no time window lower/upper bound is found, just return the original query
return Ok(Some(self.query.clone()));
}
.fail()?,
}
};
@@ -454,6 +457,7 @@ impl RecordingRuleTask {
let Some(expr) = expr else {
// no new data, hence no need to update
debug!("Flow id={:?}, no new data, not update", self.flow_id);
return Ok(None);
};