From bb23334724e2b1cbece7a89d3856a430e4188815 Mon Sep 17 00:00:00 2001 From: discord9 Date: Thu, 19 Mar 2026 21:11:14 +0800 Subject: [PATCH] feat: flow join rewriter Signed-off-by: discord9 --- src/flow/src/batching_mode/task.rs | 708 +++++++++++++++++++++++++++- src/flow/src/batching_mode/utils.rs | 372 ++++++++++++++- 2 files changed, 1067 insertions(+), 13 deletions(-) diff --git a/src/flow/src/batching_mode/task.rs b/src/flow/src/batching_mode/task.rs index b0c3e0bc2a..44aa1665d4 100644 --- a/src/flow/src/batching_mode/task.rs +++ b/src/flow/src/batching_mode/task.rs @@ -54,8 +54,9 @@ use crate::batching_mode::frontend_client::FrontendClient; use crate::batching_mode::state::{CheckpointMode, FilterExprInfo, TaskState}; use crate::batching_mode::time_window::TimeWindowExpr; use crate::batching_mode::utils::{ - AddFilterRewriter, ColumnMatcherRewriter, FindGroupByFinalName, gen_plan_with_matching_schema, - get_table_info_df_schema, sql_to_df_plan, + AddFilterRewriter, ColumnMatcherRewriter, FindGroupByFinalName, + analyze_poc_incremental_aggregate_plan, gen_plan_with_matching_schema, + get_table_info_df_schema, rewrite_poc_incremental_aggregate_with_sink_merge, sql_to_df_plan, }; use crate::df_optimizer::apply_df_optimizer; use crate::error::{ @@ -141,6 +142,51 @@ pub struct PlanInfo { } impl BatchingTask { + async fn rewrite_incremental_sql_plan_if_needed( + &self, + plan: LogicalPlan, + ) -> Result { + if self.state.read().unwrap().checkpoint_mode() != CheckpointMode::Incremental { + return Ok(plan); + } + if self.config.query_type != QueryType::Sql { + return Ok(plan); + } + + let Some(analysis) = analyze_poc_incremental_aggregate_plan(&plan)? else { + return Ok(plan); + }; + + if !analysis.unsupported_exprs.is_empty() { + return InvalidQuerySnafu { + reason: format!( + "UNSUPPORTED_INCREMENTAL_AGG: query contains unsupported incremental aggregate expressions {:?}", + analysis.unsupported_exprs + ), + } + .fail(); + } + + let (sink_table, _) = get_table_info_df_schema( + self.config.catalog_manager.clone(), + self.config.sink_table_name.clone(), + ) + .await?; + + let rewritten = rewrite_poc_incremental_aggregate_with_sink_merge( + &plan, + &analysis, + sink_table, + &self.config.sink_table_name, + ) + .await?; + warn!( + "Flow {} rewrote incremental SQL aggregate query with POC sink merge", + self.config.flow_id, + ); + Ok(rewritten) + } + #[allow(clippy::too_many_arguments)] pub fn try_new( TaskArgs { @@ -779,15 +825,25 @@ impl BatchingTask { return Ok(None); } - let plan = gen_plan_with_matching_schema( + let plan = sql_to_df_plan( + query_ctx.clone(), + engine.clone(), &self.config.query, - query_ctx, - engine, - sink_table_schema.clone(), - primary_key_indices, - allow_partial, + false, ) .await?; + let rewritten = self.rewrite_incremental_sql_plan_if_needed(plan).await?; + let mut add_auto_column = ColumnMatcherRewriter::new( + sink_table_schema.clone(), + primary_key_indices.to_vec(), + allow_partial, + ); + let plan = rewritten + .rewrite(&mut add_auto_column) + .with_context(|_| DatafusionSnafu { + context: "Failed to align rewritten plan with sink schema".to_string(), + })? + .data; return Ok(Some(PlanInfo { plan, filter: None })); } @@ -878,14 +934,22 @@ impl BatchingTask { let plan = sql_to_df_plan(query_ctx.clone(), engine.clone(), &self.config.query, false).await?; - let rewrite = plan + let filtered = plan .clone() .rewrite(&mut add_filter) - .and_then(|p| p.data.rewrite(&mut add_auto_column)) .with_context(|_| DatafusionSnafu { context: format!("Failed to rewrite plan:\n {}\n", plan), })? .data; + let rewritten = self + .rewrite_incremental_sql_plan_if_needed(filtered) + .await?; + let rewrite = rewritten + .rewrite(&mut add_auto_column) + .with_context(|_| DatafusionSnafu { + context: "Failed to align rewritten plan with sink schema".to_string(), + })? + .data; // only apply optimize after complex rewrite is done let new_plan = apply_df_optimizer(rewrite, &query_ctx).await?; @@ -1108,19 +1172,106 @@ fn build_pk_from_aggr(plan: &LogicalPlan) -> Result, Error> { #[cfg(test)] mod test { use std::collections::BTreeMap; + use std::sync::Arc; use api::v1::column_def::try_as_column_schema; + use catalog::RegisterTableRequest; + use catalog::memory::MemoryCatalogManager; + use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_error::ext::{BoxedError, PlainError}; use common_error::status_code::StatusCode; - use common_query::Output; + use common_query::{Output, OutputData}; use common_recordbatch::adapter::{RecordBatchMetrics, RegionWatermarkEntry}; + use common_recordbatch::util; + use datatypes::arrow_array::{int_array_value_at_index, timestamp_array_value}; + use datatypes::prelude::{ConcreteDataType, MutableVector, ScalarVectorBuilder}; + use datatypes::schema::Schema; + use datatypes::timestamp::TimestampMillisecond; + use datatypes::vectors::{TimestampMillisecondVectorBuilder, VectorRef}; use pretty_assertions::assert_eq; use session::context::QueryContext; use snafu::GenerateImplicitData; + use table::test_util::MemTable; use super::*; use crate::test_utils::create_test_query_engine; + fn register_test_table( + query_engine: &QueryEngineRef, + table_name: &str, + rows: &[(Option, i64)], + ) { + let schema = Arc::new(Schema::new(vec![ + ColumnSchema::new("number", ConcreteDataType::uint32_datatype(), true), + ColumnSchema::new( + "ts", + ConcreteDataType::timestamp_millisecond_datatype(), + false, + ) + .with_time_index(true), + ])); + + let mut number_builder = datatypes::vectors::UInt32VectorBuilder::with_capacity(rows.len()); + for (number, _) in rows { + number_builder.push(*number); + } + let numbers: VectorRef = number_builder.to_vector_cloned(); + let mut ts_builder = TimestampMillisecondVectorBuilder::with_capacity(rows.len()); + for (_, ts) in rows { + ts_builder.push(Some(TimestampMillisecond::new(*ts))); + } + let timestamps: VectorRef = ts_builder.to_vector_cloned(); + let recordbatch = + common_recordbatch::RecordBatch::new(schema, vec![numbers, timestamps]).unwrap(); + let table = MemTable::table(table_name, recordbatch); + + let memory_catalog_manager = query_engine + .engine_state() + .catalog_manager() + .as_any() + .downcast_ref::() + .unwrap(); + memory_catalog_manager + .register_table_sync(RegisterTableRequest { + catalog: DEFAULT_CATALOG_NAME.to_string(), + schema: DEFAULT_SCHEMA_NAME.to_string(), + table_name: table_name.to_string(), + table_id: 6000, + table, + }) + .unwrap(); + } + + fn register_test_sink_table( + query_engine: &QueryEngineRef, + table_name: &str, + rows: &[(u32, i64)], + ) { + let rows = rows + .iter() + .map(|(number, ts)| (Some(*number), *ts)) + .collect::>(); + register_test_table(query_engine, table_name, &rows); + } + + fn extract_ts_number_rows( + batches: &[common_recordbatch::RecordBatch], + ) -> Vec<(i64, Option)> { + let mut rows = Vec::new(); + for batch in batches { + let ts_col = batch.column_by_name("ts").unwrap(); + let number_col = batch.column_by_name("number").unwrap(); + for row_idx in 0..batch.num_rows() { + rows.push(( + timestamp_array_value(ts_col, row_idx).value(), + int_array_value_at_index(number_col, row_idx), + )); + } + } + rows.sort_unstable(); + rows + } + #[tokio::test] async fn test_gen_create_table_sql() { let query_engine = create_test_query_engine(); @@ -1687,4 +1838,539 @@ mod test { &BTreeMap::from([(1_u64, 30_u64), (2_u64, 40_u64)]) ); } + + #[tokio::test] + async fn test_rewrite_incremental_sql_plan_for_supported_aggregate() { + let query_engine = create_test_query_engine(); + let query_ctx = QueryContext::arc(); + let sql = "SELECT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts"; + let plan = sql_to_df_plan(query_ctx.clone(), query_engine.clone(), sql, true) + .await + .unwrap(); + let (_tx, rx) = tokio::sync::oneshot::channel(); + let task = BatchingTask::try_new(TaskArgs { + flow_id: 49, + query: sql, + plan, + time_window_expr: None, + expire_after: None, + sink_table_name: [ + "greptime".to_string(), + "public".to_string(), + "numbers_with_ts".to_string(), + ], + source_table_names: vec![[ + "greptime".to_string(), + "public".to_string(), + "numbers_with_ts".to_string(), + ]], + query_ctx, + catalog_manager: create_test_query_engine() + .engine_state() + .catalog_manager() + .clone(), + shutdown_rx: rx, + batch_opts: Arc::new(BatchingModeOptions::default()), + flow_eval_interval: None, + }) + .unwrap(); + + { + let mut state = task.state.write().unwrap(); + state.advance_checkpoints(HashMap::from([(1_u64, 10_u64)])); + } + + let raw_plan = sql_to_df_plan( + task.state.read().unwrap().query_ctx.clone(), + query_engine.clone(), + sql, + false, + ) + .await + .unwrap(); + let rewritten = task + .rewrite_incremental_sql_plan_if_needed(raw_plan) + .await + .unwrap(); + let plan_text = format!("{}", rewritten.display_indent()); + assert!(plan_text.contains("Left Join")); + assert!(!plan_text.contains("Union")); + } + + #[tokio::test] + async fn test_rewrite_incremental_sql_plan_rejects_avg() { + let query_engine = create_test_query_engine(); + let query_ctx = QueryContext::arc(); + let sql = "SELECT avg(number) AS avg_num, ts FROM numbers_with_ts GROUP BY ts"; + let plan = sql_to_df_plan(query_ctx.clone(), query_engine.clone(), sql, true) + .await + .unwrap(); + let (_tx, rx) = tokio::sync::oneshot::channel(); + let task = BatchingTask::try_new(TaskArgs { + flow_id: 50, + query: sql, + plan, + time_window_expr: None, + expire_after: None, + sink_table_name: [ + "greptime".to_string(), + "public".to_string(), + "numbers_with_ts".to_string(), + ], + source_table_names: vec![[ + "greptime".to_string(), + "public".to_string(), + "numbers_with_ts".to_string(), + ]], + query_ctx, + catalog_manager: create_test_query_engine() + .engine_state() + .catalog_manager() + .clone(), + shutdown_rx: rx, + batch_opts: Arc::new(BatchingModeOptions::default()), + flow_eval_interval: None, + }) + .unwrap(); + + task.mark_all_windows_as_dirty().unwrap(); + { + let mut state = task.state.write().unwrap(); + state.advance_checkpoints(HashMap::from([(1_u64, 10_u64)])); + } + + match task.gen_insert_plan(&query_engine, None).await { + Err(err) => assert!(format!("{err}").contains("UNSUPPORTED_INCREMENTAL_AGG")), + Ok(_) => panic!("expected UNSUPPORTED_INCREMENTAL_AGG error for avg query"), + } + } + + #[tokio::test] + async fn test_rewrite_incremental_sql_plan_semantics_sum_only_new_and_both_sides() { + let query_engine = create_test_query_engine(); + register_test_sink_table(&query_engine, "sink_semantic", &[(20, 2)]); + + let query_ctx = QueryContext::arc(); + let sql = "SELECT sum(number) AS number, ts FROM numbers_with_ts WHERE ts >= 2 AND ts <= 3 GROUP BY ts"; + let plan = sql_to_df_plan(query_ctx.clone(), query_engine.clone(), sql, true) + .await + .unwrap(); + let (_tx, rx) = tokio::sync::oneshot::channel(); + let task = BatchingTask::try_new(TaskArgs { + flow_id: 51, + query: sql, + plan, + time_window_expr: None, + expire_after: None, + sink_table_name: [ + "greptime".to_string(), + "public".to_string(), + "sink_semantic".to_string(), + ], + source_table_names: vec![[ + "greptime".to_string(), + "public".to_string(), + "numbers_with_ts".to_string(), + ]], + query_ctx, + catalog_manager: query_engine.engine_state().catalog_manager().clone(), + shutdown_rx: rx, + batch_opts: Arc::new(BatchingModeOptions::default()), + flow_eval_interval: None, + }) + .unwrap(); + + { + let mut state = task.state.write().unwrap(); + state.advance_checkpoints(HashMap::from([(1_u64, 10_u64)])); + } + + let raw_plan = sql_to_df_plan( + task.state.read().unwrap().query_ctx.clone(), + query_engine.clone(), + sql, + false, + ) + .await + .unwrap(); + let rewritten = task + .rewrite_incremental_sql_plan_if_needed(raw_plan) + .await + .unwrap(); + + let output = query_engine + .execute(rewritten, task.state.read().unwrap().query_ctx.clone()) + .await + .unwrap(); + let stream = match output.data { + OutputData::Stream(stream) => stream, + OutputData::RecordBatches(batches) => batches.as_stream(), + OutputData::AffectedRows(_) => panic!("expected query output"), + }; + let batches = util::collect(stream).await.unwrap(); + + let rows = extract_ts_number_rows(&batches); + assert_eq!(rows, vec![(2, Some(22)), (3, Some(3))]); + } + + #[tokio::test] + async fn test_rewrite_incremental_sql_plan_semantics_max_only_new_and_both_sides() { + let query_engine = create_test_query_engine(); + register_test_sink_table(&query_engine, "sink_semantic_max", &[(20, 2)]); + + let query_ctx = QueryContext::arc(); + let sql = "SELECT max(number) AS number, ts FROM numbers_with_ts WHERE ts >= 2 AND ts <= 3 GROUP BY ts"; + let plan = sql_to_df_plan(query_ctx.clone(), query_engine.clone(), sql, true) + .await + .unwrap(); + let (_tx, rx) = tokio::sync::oneshot::channel(); + let task = BatchingTask::try_new(TaskArgs { + flow_id: 52, + query: sql, + plan, + time_window_expr: None, + expire_after: None, + sink_table_name: [ + "greptime".to_string(), + "public".to_string(), + "sink_semantic_max".to_string(), + ], + source_table_names: vec![[ + "greptime".to_string(), + "public".to_string(), + "numbers_with_ts".to_string(), + ]], + query_ctx, + catalog_manager: query_engine.engine_state().catalog_manager().clone(), + shutdown_rx: rx, + batch_opts: Arc::new(BatchingModeOptions::default()), + flow_eval_interval: None, + }) + .unwrap(); + + { + let mut state = task.state.write().unwrap(); + state.advance_checkpoints(HashMap::from([(1_u64, 10_u64)])); + } + + let raw_plan = sql_to_df_plan( + task.state.read().unwrap().query_ctx.clone(), + query_engine.clone(), + sql, + false, + ) + .await + .unwrap(); + let rewritten = task + .rewrite_incremental_sql_plan_if_needed(raw_plan) + .await + .unwrap(); + + let output = query_engine + .execute(rewritten, task.state.read().unwrap().query_ctx.clone()) + .await + .unwrap(); + let stream = match output.data { + OutputData::Stream(stream) => stream, + OutputData::RecordBatches(batches) => batches.as_stream(), + OutputData::AffectedRows(_) => panic!("expected query output"), + }; + let batches = util::collect(stream).await.unwrap(); + + let rows = extract_ts_number_rows(&batches); + assert_eq!(rows, vec![(2, Some(20)), (3, Some(3))]); + } + + #[tokio::test] + async fn test_rewrite_incremental_sql_plan_semantics_sum_nullable_delta_keeps_old_state() { + let query_engine = create_test_query_engine(); + register_test_sink_table(&query_engine, "sink_semantic_sum_null", &[(20, 2)]); + register_test_table( + &query_engine, + "numbers_with_nullable_ts", + &[(None, 2), (Some(3), 3)], + ); + + let query_ctx = QueryContext::arc(); + let sql = "SELECT sum(number) AS number, ts FROM numbers_with_nullable_ts GROUP BY ts"; + let plan = sql_to_df_plan(query_ctx.clone(), query_engine.clone(), sql, true) + .await + .unwrap(); + let (_tx, rx) = tokio::sync::oneshot::channel(); + let task = BatchingTask::try_new(TaskArgs { + flow_id: 53, + query: sql, + plan, + time_window_expr: None, + expire_after: None, + sink_table_name: [ + "greptime".to_string(), + "public".to_string(), + "sink_semantic_sum_null".to_string(), + ], + source_table_names: vec![[ + "greptime".to_string(), + "public".to_string(), + "numbers_with_nullable_ts".to_string(), + ]], + query_ctx, + catalog_manager: query_engine.engine_state().catalog_manager().clone(), + shutdown_rx: rx, + batch_opts: Arc::new(BatchingModeOptions::default()), + flow_eval_interval: None, + }) + .unwrap(); + + { + let mut state = task.state.write().unwrap(); + state.advance_checkpoints(HashMap::from([(1_u64, 10_u64)])); + } + + let raw_plan = sql_to_df_plan( + task.state.read().unwrap().query_ctx.clone(), + query_engine.clone(), + sql, + false, + ) + .await + .unwrap(); + let rewritten = task + .rewrite_incremental_sql_plan_if_needed(raw_plan) + .await + .unwrap(); + + let output = query_engine + .execute(rewritten, task.state.read().unwrap().query_ctx.clone()) + .await + .unwrap(); + let stream = match output.data { + OutputData::Stream(stream) => stream, + OutputData::RecordBatches(batches) => batches.as_stream(), + OutputData::AffectedRows(_) => panic!("expected query output"), + }; + let batches = util::collect(stream).await.unwrap(); + + let rows = extract_ts_number_rows(&batches); + assert_eq!(rows, vec![(2, Some(20)), (3, Some(3))]); + } + + #[tokio::test] + async fn test_rewrite_incremental_sql_plan_semantics_max_nullable_delta_keeps_old_state() { + let query_engine = create_test_query_engine(); + register_test_sink_table(&query_engine, "sink_semantic_max_null", &[(20, 2)]); + register_test_table( + &query_engine, + "numbers_with_nullable_ts_max", + &[(None, 2), (Some(3), 3)], + ); + + let query_ctx = QueryContext::arc(); + let sql = "SELECT max(number) AS number, ts FROM numbers_with_nullable_ts_max GROUP BY ts"; + let plan = sql_to_df_plan(query_ctx.clone(), query_engine.clone(), sql, true) + .await + .unwrap(); + let (_tx, rx) = tokio::sync::oneshot::channel(); + let task = BatchingTask::try_new(TaskArgs { + flow_id: 54, + query: sql, + plan, + time_window_expr: None, + expire_after: None, + sink_table_name: [ + "greptime".to_string(), + "public".to_string(), + "sink_semantic_max_null".to_string(), + ], + source_table_names: vec![[ + "greptime".to_string(), + "public".to_string(), + "numbers_with_nullable_ts_max".to_string(), + ]], + query_ctx, + catalog_manager: query_engine.engine_state().catalog_manager().clone(), + shutdown_rx: rx, + batch_opts: Arc::new(BatchingModeOptions::default()), + flow_eval_interval: None, + }) + .unwrap(); + + { + let mut state = task.state.write().unwrap(); + state.advance_checkpoints(HashMap::from([(1_u64, 10_u64)])); + } + + let raw_plan = sql_to_df_plan( + task.state.read().unwrap().query_ctx.clone(), + query_engine.clone(), + sql, + false, + ) + .await + .unwrap(); + let rewritten = task + .rewrite_incremental_sql_plan_if_needed(raw_plan) + .await + .unwrap(); + + let output = query_engine + .execute(rewritten, task.state.read().unwrap().query_ctx.clone()) + .await + .unwrap(); + let stream = match output.data { + OutputData::Stream(stream) => stream, + OutputData::RecordBatches(batches) => batches.as_stream(), + OutputData::AffectedRows(_) => panic!("expected query output"), + }; + let batches = util::collect(stream).await.unwrap(); + + let rows = extract_ts_number_rows(&batches); + assert_eq!(rows, vec![(2, Some(20)), (3, Some(3))]); + } + + #[tokio::test] + async fn test_rewrite_incremental_sql_plan_semantics_sum_sink_null_delta_nonnull_uses_delta() { + let query_engine = create_test_query_engine(); + register_test_table(&query_engine, "sink_semantic_sum_sink_null", &[(None, 2)]); + register_test_table( + &query_engine, + "numbers_with_nullable_ts_sink_null", + &[(Some(7), 2), (Some(3), 3)], + ); + + let query_ctx = QueryContext::arc(); + let sql = + "SELECT sum(number) AS number, ts FROM numbers_with_nullable_ts_sink_null GROUP BY ts"; + let plan = sql_to_df_plan(query_ctx.clone(), query_engine.clone(), sql, true) + .await + .unwrap(); + let (_tx, rx) = tokio::sync::oneshot::channel(); + let task = BatchingTask::try_new(TaskArgs { + flow_id: 55, + query: sql, + plan, + time_window_expr: None, + expire_after: None, + sink_table_name: [ + "greptime".to_string(), + "public".to_string(), + "sink_semantic_sum_sink_null".to_string(), + ], + source_table_names: vec![[ + "greptime".to_string(), + "public".to_string(), + "numbers_with_nullable_ts_sink_null".to_string(), + ]], + query_ctx, + catalog_manager: query_engine.engine_state().catalog_manager().clone(), + shutdown_rx: rx, + batch_opts: Arc::new(BatchingModeOptions::default()), + flow_eval_interval: None, + }) + .unwrap(); + + { + let mut state = task.state.write().unwrap(); + state.advance_checkpoints(HashMap::from([(1_u64, 10_u64)])); + } + + let raw_plan = sql_to_df_plan( + task.state.read().unwrap().query_ctx.clone(), + query_engine.clone(), + sql, + false, + ) + .await + .unwrap(); + let rewritten = task + .rewrite_incremental_sql_plan_if_needed(raw_plan) + .await + .unwrap(); + + let output = query_engine + .execute(rewritten, task.state.read().unwrap().query_ctx.clone()) + .await + .unwrap(); + let stream = match output.data { + OutputData::Stream(stream) => stream, + OutputData::RecordBatches(batches) => batches.as_stream(), + OutputData::AffectedRows(_) => panic!("expected query output"), + }; + let batches = util::collect(stream).await.unwrap(); + + let rows = extract_ts_number_rows(&batches); + assert_eq!(rows, vec![(2, Some(7)), (3, Some(3))]); + } + + #[tokio::test] + async fn test_rewrite_incremental_sql_plan_semantics_sum_double_null_stays_null() { + let query_engine = create_test_query_engine(); + register_test_table(&query_engine, "sink_semantic_sum_double_null", &[(None, 2)]); + register_test_table( + &query_engine, + "numbers_with_nullable_ts_double_null", + &[(None, 2), (Some(3), 3)], + ); + + let query_ctx = QueryContext::arc(); + let sql = "SELECT sum(number) AS number, ts FROM numbers_with_nullable_ts_double_null GROUP BY ts"; + let plan = sql_to_df_plan(query_ctx.clone(), query_engine.clone(), sql, true) + .await + .unwrap(); + let (_tx, rx) = tokio::sync::oneshot::channel(); + let task = BatchingTask::try_new(TaskArgs { + flow_id: 56, + query: sql, + plan, + time_window_expr: None, + expire_after: None, + sink_table_name: [ + "greptime".to_string(), + "public".to_string(), + "sink_semantic_sum_double_null".to_string(), + ], + source_table_names: vec![[ + "greptime".to_string(), + "public".to_string(), + "numbers_with_nullable_ts_double_null".to_string(), + ]], + query_ctx, + catalog_manager: query_engine.engine_state().catalog_manager().clone(), + shutdown_rx: rx, + batch_opts: Arc::new(BatchingModeOptions::default()), + flow_eval_interval: None, + }) + .unwrap(); + + { + let mut state = task.state.write().unwrap(); + state.advance_checkpoints(HashMap::from([(1_u64, 10_u64)])); + } + + let raw_plan = sql_to_df_plan( + task.state.read().unwrap().query_ctx.clone(), + query_engine.clone(), + sql, + false, + ) + .await + .unwrap(); + let rewritten = task + .rewrite_incremental_sql_plan_if_needed(raw_plan) + .await + .unwrap(); + + let output = query_engine + .execute(rewritten, task.state.read().unwrap().query_ctx.clone()) + .await + .unwrap(); + let stream = match output.data { + OutputData::Stream(stream) => stream, + OutputData::RecordBatches(batches) => batches.as_stream(), + OutputData::AffectedRows(_) => panic!("expected query output"), + }; + let batches = util::collect(stream).await.unwrap(); + + let rows = extract_ts_number_rows(&batches); + assert_eq!(rows, vec![(2, None), (3, Some(3))]); + } } diff --git a/src/flow/src/batching_mode/utils.rs b/src/flow/src/batching_mode/utils.rs index dfbadbfc72..c247e5a594 100644 --- a/src/flow/src/batching_mode/utils.rs +++ b/src/flow/src/batching_mode/utils.rs @@ -19,15 +19,21 @@ use std::sync::Arc; use catalog::CatalogManagerRef; use common_error::ext::BoxedError; +use common_function::aggrs::aggr_wrapper::get_aggr_func; use common_telemetry::debug; +use datafusion::datasource::DefaultTableSource; use datafusion::error::Result as DfResult; use datafusion::logical_expr::Expr; use datafusion::sql::unparser::Unparser; use datafusion_common::tree_node::{ Transformed, TreeNode as _, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; -use datafusion_common::{DFSchema, DataFusionError, ScalarValue}; -use datafusion_expr::{Distinct, LogicalPlan, Projection}; +use datafusion_common::{DFSchema, DataFusionError, ScalarValue, TableReference}; +use datafusion_expr::logical_plan::TableScan; +use datafusion_expr::{ + Distinct, JoinType, LogicalPlan, LogicalPlanBuilder, Operator, Projection, and, binary_expr, + bitwise_and, bitwise_or, bitwise_xor, col, is_null, or, when, +}; use datatypes::schema::{ColumnSchema, SchemaRef}; use query::QueryEngineRef; use query::parser::{DEFAULT_LOOKBACK_STRING, PromQuery, QueryLanguageParser, QueryStatement}; @@ -37,12 +43,304 @@ use sql::parser::{ParseOptions, ParserContext}; use sql::statements::statement::Statement; use sql::statements::tql::Tql; use table::TableRef; +use table::table::adapter::DfTableProviderAdapter; use crate::adapter::{AUTO_CREATED_PLACEHOLDER_TS_COL, AUTO_CREATED_UPDATE_AT_TS_COL}; use crate::df_optimizer::apply_df_optimizer; use crate::error::{DatafusionSnafu, ExternalSnafu, InvalidQuerySnafu, TableNotFoundSnafu}; use crate::{Error, TableName}; +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PocIncrementalMergeColumn { + pub output_name: String, + pub merge_function: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PocIncrementalAggregateAnalysis { + pub group_columns: Vec, + pub merge_columns: Vec, + pub unsupported_exprs: Vec, +} + +#[derive(Default)] +struct LastAggregateExprFinder { + aggr_exprs: Option>, +} + +impl TreeNodeVisitor<'_> for LastAggregateExprFinder { + type Node = LogicalPlan; + + fn f_down(&mut self, node: &Self::Node) -> datafusion_common::Result { + if let LogicalPlan::Aggregate(aggregate) = node { + self.aggr_exprs = Some(aggregate.aggr_expr.clone()); + } + Ok(TreeNodeRecursion::Continue) + } +} + +pub fn analyze_poc_incremental_aggregate_plan( + plan: &LogicalPlan, +) -> Result, Error> { + let mut group_finder = FindGroupByFinalName::default(); + plan.visit(&mut group_finder) + .with_context(|_| DatafusionSnafu { + context: format!("Failed to inspect group-by columns from logical plan: {plan:?}"), + })?; + + let mut aggregate_finder = LastAggregateExprFinder::default(); + plan.visit(&mut aggregate_finder) + .with_context(|_| DatafusionSnafu { + context: format!("Failed to inspect aggregate expressions from logical plan: {plan:?}"), + })?; + let Some(aggr_exprs) = aggregate_finder.aggr_exprs else { + return Ok(None); + }; + + let mut output_aliases = HashMap::new(); + if let LogicalPlan::Projection(projection) = plan { + for expr in &projection.expr { + match expr { + Expr::Alias(alias) => { + if let Expr::Column(col) = alias.expr.as_ref() { + output_aliases.insert(col.name.clone(), alias.name.clone()); + } + } + Expr::Column(col) => { + output_aliases.insert(col.name.clone(), col.name.clone()); + } + _ => {} + } + } + } + + let mut group_columns = group_finder + .get_group_expr_names() + .unwrap_or_default() + .into_iter() + .collect::>(); + group_columns.sort(); + + let mut merge_columns = Vec::with_capacity(aggr_exprs.len()); + let mut unsupported_exprs = Vec::new(); + for aggr_expr in aggr_exprs { + let Some(aggr_func) = get_aggr_func(&aggr_expr) else { + unsupported_exprs.push(aggr_expr.to_string()); + continue; + }; + + let aggr_name = aggr_func.func.name().to_ascii_lowercase(); + let merge_function = if aggr_func.params.distinct { + None + } else { + match aggr_name.as_str() { + "sum" => Some("sum"), + "count" => Some("sum"), + "min" => Some("min"), + "max" => Some("max"), + "bool_and" => Some("bool_and"), + "bool_or" => Some("bool_or"), + "bit_and" => Some("bit_and"), + "bit_or" => Some("bit_or"), + "bit_xor" => Some("bit_xor"), + _ => None, + } + }; + + let Some(merge_function) = merge_function else { + unsupported_exprs.push(aggr_expr.to_string()); + continue; + }; + + let raw_name = aggr_expr.qualified_name().1; + let output_name = output_aliases.get(&raw_name).cloned().unwrap_or(raw_name); + merge_columns.push(PocIncrementalMergeColumn { + output_name, + merge_function: merge_function.to_string(), + }); + } + + Ok(Some(PocIncrementalAggregateAnalysis { + group_columns, + merge_columns, + unsupported_exprs, + })) +} + +pub async fn rewrite_poc_incremental_aggregate_with_sink_merge( + delta_plan: &LogicalPlan, + analysis: &PocIncrementalAggregateAnalysis, + sink_table: TableRef, + sink_table_name: &TableName, +) -> Result { + ensure!( + analysis.unsupported_exprs.is_empty(), + InvalidQuerySnafu { + reason: format!( + "UNSUPPORTED_INCREMENTAL_AGG: unsupported aggregate expressions {:?}", + analysis.unsupported_exprs + ) + } + ); + + ensure!( + !analysis.merge_columns.is_empty(), + InvalidQuerySnafu { + reason: + "UNSUPPORTED_INCREMENTAL_AGG: aggregate query has no mergeable aggregate columns" + .to_string() + } + ); + + let delta_alias = "__flow_delta"; + let sink_alias = "__flow_sink"; + + let mut selected_columns = analysis.group_columns.clone(); + selected_columns.extend(analysis.merge_columns.iter().map(|c| c.output_name.clone())); + + let selected_exprs = selected_columns.iter().map(col).collect::>(); + let delta_selected = LogicalPlanBuilder::from(delta_plan.clone()) + .project(selected_exprs.clone()) + .with_context(|_| DatafusionSnafu { + context: "Failed to project delta plan for incremental sink merge".to_string(), + })? + .alias(delta_alias) + .with_context(|_| DatafusionSnafu { + context: "Failed to alias delta plan for incremental sink merge".to_string(), + })? + .build() + .with_context(|_| DatafusionSnafu { + context: "Failed to build projected delta plan for incremental sink merge".to_string(), + })?; + + let table_provider = Arc::new(DfTableProviderAdapter::new(sink_table)); + let table_source = Arc::new(DefaultTableSource::new(table_provider)); + let sink_scan = LogicalPlan::TableScan( + TableScan::try_new( + TableReference::Full { + catalog: sink_table_name[0].clone().into(), + schema: sink_table_name[1].clone().into(), + table: sink_table_name[2].clone().into(), + }, + table_source, + None, + vec![], + None, + ) + .with_context(|_| DatafusionSnafu { + context: "Failed to build sink table scan for incremental sink merge".to_string(), + })?, + ); + + let sink_selected = LogicalPlanBuilder::from(sink_scan) + .project(selected_exprs) + .with_context(|_| DatafusionSnafu { + context: "Failed to project sink table scan for incremental sink merge".to_string(), + })? + .alias(sink_alias) + .with_context(|_| DatafusionSnafu { + context: "Failed to alias sink plan for incremental sink merge".to_string(), + })? + .build() + .with_context(|_| DatafusionSnafu { + context: "Failed to build projected sink plan for incremental sink merge".to_string(), + })?; + + let join_keys = ( + analysis + .group_columns + .iter() + .map(|c| datafusion_common::Column::from_qualified_name(format!("{delta_alias}.{c}"))) + .collect::>(), + analysis + .group_columns + .iter() + .map(|c| datafusion_common::Column::from_qualified_name(format!("{sink_alias}.{c}"))) + .collect::>(), + ); + + let joined = LogicalPlanBuilder::from(delta_selected) + .join(sink_selected, JoinType::Left, join_keys, None) + .with_context(|_| DatafusionSnafu { + context: "Failed to left join delta and sink plans for incremental sink merge" + .to_string(), + })? + .build() + .with_context(|_| DatafusionSnafu { + context: "Failed to build left join plan for incremental sink merge".to_string(), + })?; + + let mut projection_exprs = analysis + .group_columns + .iter() + .map(|c| col(format!("{delta_alias}.{c}")).alias(c.clone())) + .collect::>(); + projection_exprs.extend( + analysis + .merge_columns + .iter() + .map(|merge_col| build_left_join_merge_expr(delta_alias, sink_alias, merge_col)), + ); + + LogicalPlanBuilder::from(joined) + .project(projection_exprs) + .with_context(|_| DatafusionSnafu { + context: "Failed to build projection merge plan for incremental sink merge".to_string(), + })? + .build() + .with_context(|_| DatafusionSnafu { + context: "Failed to finalize incremental aggregate sink merge plan".to_string(), + }) +} + +fn build_left_join_merge_expr( + delta_alias: &str, + sink_alias: &str, + merge_col: &PocIncrementalMergeColumn, +) -> Expr { + let left = col(format!("{delta_alias}.{}", merge_col.output_name)); + let right = col(format!("{sink_alias}.{}", merge_col.output_name)); + let merged = match merge_col.merge_function.as_str() { + "sum" => when(is_null(left.clone()), right.clone()) + .when(is_null(right.clone()), left.clone()) + .otherwise(binary_expr(left.clone(), Operator::Plus, right.clone())) + .unwrap(), + "min" => when(is_null(right.clone()), left.clone()) + .when(left.clone().lt_eq(right.clone()), left.clone()) + .otherwise(right.clone()) + .unwrap(), + "max" => when(is_null(right.clone()), left.clone()) + .when(left.clone().gt_eq(right.clone()), left.clone()) + .otherwise(right.clone()) + .unwrap(), + "bool_and" => when(is_null(left.clone()), right.clone()) + .when(is_null(right.clone()), left.clone()) + .otherwise(and(left.clone(), right.clone())) + .unwrap(), + "bool_or" => when(is_null(left.clone()), right.clone()) + .when(is_null(right.clone()), left.clone()) + .otherwise(or(left.clone(), right.clone())) + .unwrap(), + "bit_and" => when(is_null(left.clone()), right.clone()) + .when(is_null(right.clone()), left.clone()) + .otherwise(bitwise_and(left.clone(), right.clone())) + .unwrap(), + "bit_or" => when(is_null(left.clone()), right.clone()) + .when(is_null(right.clone()), left.clone()) + .otherwise(bitwise_or(left.clone(), right.clone())) + .unwrap(), + "bit_xor" => when(is_null(left.clone()), right.clone()) + .when(is_null(right.clone()), left.clone()) + .otherwise(bitwise_xor(left.clone(), right.clone())) + .unwrap(), + other => Expr::Literal( + ScalarValue::Utf8(Some(format!("UNSUPPORTED_INCREMENTAL_AGG:{other}"))), + None, + ), + }; + merged.alias(merge_col.output_name.clone()) +} + pub async fn get_table_info_df_schema( catalog_mr: CatalogManagerRef, table_name: TableName, @@ -907,6 +1205,76 @@ mod test { } } + #[tokio::test] + async fn test_analyze_poc_incremental_aggregate_plan() { + let query_engine = create_test_query_engine(); + let ctx = QueryContext::arc(); + let sql = "SELECT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts"; + let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); + + let analysis = analyze_poc_incremental_aggregate_plan(&plan) + .unwrap() + .unwrap(); + assert!(analysis.unsupported_exprs.is_empty()); + assert!(analysis.group_columns.contains(&"ts".to_string())); + assert_eq!(analysis.merge_columns.len(), 1); + assert_eq!(analysis.merge_columns[0].output_name, "number"); + assert_eq!(analysis.merge_columns[0].merge_function, "max"); + } + + #[tokio::test] + async fn test_analyze_poc_incremental_aggregate_plan_rejects_avg() { + let query_engine = create_test_query_engine(); + let ctx = QueryContext::arc(); + let sql = "SELECT avg(number) AS avg_num, ts FROM numbers_with_ts GROUP BY ts"; + let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); + + let analysis = analyze_poc_incremental_aggregate_plan(&plan) + .unwrap() + .unwrap(); + assert!(!analysis.unsupported_exprs.is_empty()); + } + + #[tokio::test] + async fn test_rewrite_poc_incremental_aggregate_with_left_join() { + let query_engine = create_test_query_engine(); + let ctx = QueryContext::arc(); + let sql = "SELECT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts"; + let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false) + .await + .unwrap(); + let analysis = analyze_poc_incremental_aggregate_plan(&plan) + .unwrap() + .unwrap(); + let (sink_table, _) = get_table_info_df_schema( + query_engine.engine_state().catalog_manager().clone(), + [ + "greptime".to_string(), + "public".to_string(), + "numbers_with_ts".to_string(), + ], + ) + .await + .unwrap(); + + let rewritten = rewrite_poc_incremental_aggregate_with_sink_merge( + &plan, + &analysis, + sink_table, + &[ + "greptime".to_string(), + "public".to_string(), + "numbers_with_ts".to_string(), + ], + ) + .await + .unwrap(); + + let plan_text = format!("{}", rewritten.display_indent()); + assert!(plan_text.contains("Left Join")); + assert!(!plan_text.contains("Union")); + } + #[tokio::test] async fn test_null_cast() { let query_engine = create_test_query_engine();