diff --git a/src/flow/src/recording_rules.rs b/src/flow/src/recording_rules.rs index be2645fac6..00c67cf96c 100644 --- a/src/flow/src/recording_rules.rs +++ b/src/flow/src/recording_rules.rs @@ -17,7 +17,7 @@ mod engine; mod frontend_client; -use std::collections::{BTreeSet, HashSet}; +use std::collections::BTreeSet; use std::sync::Arc; use api::helper::pb_value_to_value_ref; @@ -33,11 +33,12 @@ use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}; use datafusion::prelude::SessionContext; use datafusion::sql::unparser::Unparser; use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter}; -use datafusion_common::{Column, DFSchema, TableReference}; +use datafusion_common::{DFSchema, TableReference}; use datafusion_expr::{ColumnarValue, LogicalPlan}; use datafusion_physical_expr::PhysicalExprRef; use datatypes::prelude::{ConcreteDataType, DataType}; use datatypes::scalars::ScalarVector; +use datatypes::schema::TIME_INDEX_KEY; use datatypes::value::Value; use datatypes::vectors::{ TimestampMicrosecondVector, TimestampMillisecondVector, TimestampNanosecondVector, @@ -264,6 +265,7 @@ async fn find_time_window_expr( // TODO(discord9): find the expr that do time window let mut table_name = None; + // first find the table source in the logical plan plan.apply(|plan| { let LogicalPlan::TableScan(table_scan) = plan else { @@ -316,45 +318,6 @@ async fn find_time_window_expr( ), })?.unit(); - let ts_columns: HashSet<_> = HashSet::from_iter(vec![ - format!("{catalog_name}.{schema_name}.{table_name}.{ts_col_name}"), - format!("{schema_name}.{table_name}.{ts_col_name}"), - format!("{table_name}.{ts_col_name}"), - format!("{ts_col_name}"), - ]); - let ts_columns: HashSet<_> = ts_columns - .into_iter() - .map(Column::from_qualified_name) - .collect(); - - let ts_columns_ref: HashSet<&Column> = ts_columns.iter().collect(); - - // find the time window expr which refers to the time index column - let mut time_window_expr: Option = None; - let find_time_window_expr = |plan: &LogicalPlan| { - let LogicalPlan::Aggregate(aggregate) = plan else { - return Ok(TreeNodeRecursion::Continue); - }; - - for group_expr in &aggregate.group_expr { - let refs = group_expr.column_refs(); - if refs.len() != 1 { - continue; - } - let ref_col = refs.iter().next().unwrap(); - if ts_columns_ref.contains(ref_col) { - time_window_expr = Some(group_expr.clone()); - break; - } - } - - Ok(TreeNodeRecursion::Stop) - }; - plan.apply(find_time_window_expr) - .with_context(|_| DatafusionSnafu { - context: format!("Can't find time window expr in plan {plan:?}"), - })?; - let arrow_schema = Arc::new(arrow_schema::Schema::new(vec![arrow_schema::Field::new( ts_col_name.clone(), ts_index.data_type.as_arrow_type(), @@ -368,7 +331,75 @@ async fn find_time_window_expr( .with_context(|_e| DatafusionSnafu { context: format!("Failed to create DFSchema from arrow schema {arrow_schema:?}"), })?; - Ok((ts_col_name, time_window_expr, expected_time_unit, df_schema)) + + // find the time window expr which refers to the time index column + let mut aggr_expr = None; + let mut time_window_expr: Option = None; + + let find_inner_aggr_expr = |plan: &LogicalPlan| { + if let LogicalPlan::Aggregate(aggregate) = plan { + aggr_expr = Some(aggregate.clone()); + }; + + Ok(TreeNodeRecursion::Continue) + }; + plan.apply(find_inner_aggr_expr) + .with_context(|_| DatafusionSnafu { + context: format!("Can't find aggr expr in plan {plan:?}"), + })?; + + if let Some(aggregate) = aggr_expr { + for group_expr in &aggregate.group_expr { + let refs = group_expr.column_refs(); + if refs.len() != 1 { + continue; + } + let ref_col = refs.iter().next().unwrap(); + + let index = aggregate.input.schema().maybe_index_of_column(ref_col); + let Some(index) = index else { + continue; + }; + let field = aggregate.input.schema().field(index); + + let is_time_index = field.metadata().get(TIME_INDEX_KEY) == Some(&"true".to_string()); + + if is_time_index { + let rewrite_column = group_expr.clone(); + let rewritten = rewrite_column + .rewrite(&mut RewriteColumn { + table_name: table_name.to_string(), + }) + .with_context(|_| DatafusionSnafu { + context: format!("Rewrite expr failed, expr={:?}", group_expr), + })? + .data; + struct RewriteColumn { + table_name: String, + } + + impl TreeNodeRewriter for RewriteColumn { + type Node = Expr; + fn f_down(&mut self, node: Self::Node) -> DfResult> { + let Expr::Column(mut column) = node else { + return Ok(Transformed::no(node)); + }; + + column.relation = Some(TableReference::bare(self.table_name.clone())); + + Ok(Transformed::yes(Expr::Column(column))) + } + } + + time_window_expr = Some(rewritten); + break; + } + } + Ok((ts_col_name, time_window_expr, expected_time_unit, df_schema)) + } else { + // can't found time window expr, return None + Ok((ts_col_name, None, expected_time_unit, df_schema)) + } } /// Find nearest lower bound for time `current` in given `plan` for the time window expr. @@ -393,7 +424,6 @@ pub async fn find_plan_time_window_bound( let (ts_col_name, time_window_expr, expected_time_unit, df_schema) = find_time_window_expr(plan, catalog_man.clone(), query_ctx).await?; - // cast current to ts_index's type let new_current = current .convert_to(expected_time_unit) @@ -425,8 +455,6 @@ fn find_expr_time_window_lower_bound( df_schema: &DFSchema, current: Timestamp, ) -> Result, Error> { - use std::cmp::Ordering; - let phy_planner = DefaultPhysicalPlanner::default(); let phy_expr: PhysicalExprRef = phy_planner @@ -438,91 +466,8 @@ fn find_expr_time_window_lower_bound( })?; let cur_time_window = eval_ts_to_ts(&phy_expr, df_schema, current)?; - if cur_time_window == current { - return Ok(Some(current)); - } - - // search to find the lower bound - let mut offset: i64 = 1; - let lower_bound; - let mut upper_bound = Some(current); - // first expontial probe to found a range for binary search - loop { - let Some(next_val) = current.value().checked_sub(offset) else { - // no lower bound - return Ok(None); - }; - - let prev_time_probe = common_time::Timestamp::new(next_val, current.unit()); - - let prev_time_window = eval_ts_to_ts(&phy_expr, df_schema, prev_time_probe)?; - - match prev_time_window.cmp(&cur_time_window) { - Ordering::Less => { - lower_bound = Some(prev_time_probe); - break; - } - Ordering::Equal => { - upper_bound = Some(prev_time_probe); - } - Ordering::Greater => { - UnexpectedSnafu { - reason: format!( - "Unsupported time window expression, expect monotonic increasing for time window expression {expr:?}" - ), - } - .fail()? - } - } - - let Some(new_offset) = offset.checked_mul(2) else { - // no lower bound - return Ok(None); - }; - offset = new_offset; - } - - // binary search for the exact lower bound - - ensure!(lower_bound.map(|v|v.unit())==upper_bound.map(|v|v.unit()), UnexpectedSnafu{ - reason: format!(" unit mismatch for time window expression {expr:?}, found {lower_bound:?} and {upper_bound:?}"), - }); - - let input_time_unit = lower_bound - .context(UnexpectedSnafu { - reason: "should have lower bound", - })? - .unit(); - - let mut low = lower_bound - .context(UnexpectedSnafu { - reason: "should have lower bound", - })? - .value(); - let mut high = upper_bound - .context(UnexpectedSnafu { - reason: "should have upper bound", - })? - .value(); - - while low < high { - let mid = (low + high) / 2; - let mid_probe = common_time::Timestamp::new(mid, input_time_unit); - let mid_time_window = eval_ts_to_ts(&phy_expr, df_schema, mid_probe)?; - - match mid_time_window.cmp(&cur_time_window) { - Ordering::Less => low = mid + 1, - Ordering::Equal => high = mid, - Ordering::Greater => UnexpectedSnafu { - reason: format!("Binary search failed for time window expression {expr:?}"), - } - .fail()?, - } - } - - let final_lower_bound_for_time_window = common_time::Timestamp::new(low, input_time_unit); - - Ok(Some(final_lower_bound_for_time_window)) + let input_time_unit = cur_time_window.unit(); + Ok(cur_time_window.convert_to(input_time_unit)) } /// Find the upper bound for time window expression @@ -688,7 +633,7 @@ impl AddFilterRewriter { impl TreeNodeRewriter for AddFilterRewriter { type Node = LogicalPlan; - fn f_down(&mut self, node: Self::Node) -> DfResult> { + fn f_up(&mut self, node: Self::Node) -> DfResult> { if self.is_rewritten { return Ok(Transformed::no(node)); } @@ -738,7 +683,7 @@ mod test { ), ( "SELECT number FROM numbers_with_ts WHERE number < 2 OR number >10", - "SELECT numbers_with_ts.number FROM numbers_with_ts WHERE (((numbers_with_ts.number < 2) OR (numbers_with_ts.number > 10)) AND (number > 4))" + "SELECT numbers_with_ts.number FROM numbers_with_ts WHERE ((numbers_with_ts.number < 2) OR (numbers_with_ts.number > 10)) AND (number > 4)" ), ( "SELECT date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window", @@ -853,9 +798,53 @@ mod test { ), "SELECT numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP))) GROUP BY date_bin('5 minutes', numbers_with_ts.ts), numbers_with_ts.number" ), + // subquery + ( + "SELECT number, time_window FROM (SELECT number, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window, number);", + Timestamp::new(23, TimeUnit::Millisecond), + ( + "ts".to_string(), + Some(Timestamp::new(0, TimeUnit::Millisecond)), + Some(Timestamp::new(300000, TimeUnit::Millisecond)), + ), + "SELECT numbers_with_ts.number, time_window FROM (SELECT numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP))) GROUP BY date_bin('5 minutes', numbers_with_ts.ts), numbers_with_ts.number)" + ), + // cte + ( + "with cte as (select number, date_bin('5 minutes', ts) as time_window from numbers_with_ts GROUP BY time_window, number) select number, time_window from cte;", + Timestamp::new(23, TimeUnit::Millisecond), + ( + "ts".to_string(), + Some(Timestamp::new(0, TimeUnit::Millisecond)), + Some(Timestamp::new(300000, TimeUnit::Millisecond)), + ), + "SELECT cte.number, cte.time_window FROM (SELECT numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP))) GROUP BY date_bin('5 minutes', numbers_with_ts.ts), numbers_with_ts.number) AS cte" + ), + // complex subquery without alias + ( + "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) GROUP BY number, time_window, bucket_name;", + Timestamp::new(23, TimeUnit::Millisecond), + ( + "ts".to_string(), + Some(Timestamp::new(0, TimeUnit::Millisecond)), + Some(Timestamp::new(300000, TimeUnit::Millisecond)), + ), + "SELECT sum(numbers_with_ts.number), numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window, bucket_name FROM (SELECT numbers_with_ts.number, numbers_with_ts.ts, CASE WHEN (numbers_with_ts.number < 5) THEN 'bucket_0_5' WHEN (numbers_with_ts.number >= 5) THEN 'bucket_5_inf' END AS bucket_name FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP)))) GROUP BY numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts), bucket_name" + ), + // complex subquery alias + ( + "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) as cte GROUP BY number, time_window, bucket_name;", + Timestamp::new(23, TimeUnit::Millisecond), + ( + "ts".to_string(), + Some(Timestamp::new(0, TimeUnit::Millisecond)), + Some(Timestamp::new(300000, TimeUnit::Millisecond)), + ), + "SELECT sum(cte.number), cte.number, date_bin('5 minutes', cte.ts) AS time_window, cte.bucket_name FROM (SELECT numbers_with_ts.number, numbers_with_ts.ts, CASE WHEN (numbers_with_ts.number < 5) THEN 'bucket_0_5' WHEN (numbers_with_ts.number >= 5) THEN 'bucket_5_inf' END AS bucket_name FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP)))) AS cte GROUP BY cte.number, date_bin('5 minutes', cte.ts), cte.bucket_name" + ), ]; - for (sql, current, expected, unparsed) in testcases { + for (sql, current, expected, expected_unparsed) in testcases { let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, true) .await .unwrap(); @@ -887,7 +876,7 @@ mod test { } else { sql.to_string() }; - assert_eq!(unparsed, new_sql); + assert_eq!(expected_unparsed, new_sql); } } } diff --git a/src/flow/src/recording_rules/engine.rs b/src/flow/src/recording_rules/engine.rs index 37944c5b55..a9f8c791af 100644 --- a/src/flow/src/recording_rules/engine.rs +++ b/src/flow/src/recording_rules/engine.rs @@ -434,10 +434,13 @@ impl RecordingRuleTask { .dirty_time_windows .gen_filter_exprs(&col_name, lower, window_size)? } - _ => UnexpectedSnafu { - reason: format!("Can't get window size: lower={lower:?}, upper={upper:?}"), + _ => { + warn!( + "Flow id = {:?}, can't get window size: lower={lower:?}, upper={upper:?}, using the same query", self.flow_id + ); + // since no time window lower/upper bound is found, just return the original query + return Ok(Some(self.query.clone())); } - .fail()?, } }; @@ -454,6 +457,7 @@ impl RecordingRuleTask { let Some(expr) = expr else { // no new data, hence no need to update + debug!("Flow id={:?}, no new data, not update", self.flow_id); return Ok(None); };