feat: flow join rewriter

Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
discord9
2026-03-19 21:11:14 +08:00
parent 751895cf8d
commit bb23334724
2 changed files with 1067 additions and 13 deletions

View File

@@ -54,8 +54,9 @@ use crate::batching_mode::frontend_client::FrontendClient;
use crate::batching_mode::state::{CheckpointMode, FilterExprInfo, TaskState};
use crate::batching_mode::time_window::TimeWindowExpr;
use crate::batching_mode::utils::{
AddFilterRewriter, ColumnMatcherRewriter, FindGroupByFinalName, gen_plan_with_matching_schema,
get_table_info_df_schema, sql_to_df_plan,
AddFilterRewriter, ColumnMatcherRewriter, FindGroupByFinalName,
analyze_poc_incremental_aggregate_plan, gen_plan_with_matching_schema,
get_table_info_df_schema, rewrite_poc_incremental_aggregate_with_sink_merge, sql_to_df_plan,
};
use crate::df_optimizer::apply_df_optimizer;
use crate::error::{
@@ -141,6 +142,51 @@ pub struct PlanInfo {
}
impl BatchingTask {
async fn rewrite_incremental_sql_plan_if_needed(
&self,
plan: LogicalPlan,
) -> Result<LogicalPlan, Error> {
if self.state.read().unwrap().checkpoint_mode() != CheckpointMode::Incremental {
return Ok(plan);
}
if self.config.query_type != QueryType::Sql {
return Ok(plan);
}
let Some(analysis) = analyze_poc_incremental_aggregate_plan(&plan)? else {
return Ok(plan);
};
if !analysis.unsupported_exprs.is_empty() {
return InvalidQuerySnafu {
reason: format!(
"UNSUPPORTED_INCREMENTAL_AGG: query contains unsupported incremental aggregate expressions {:?}",
analysis.unsupported_exprs
),
}
.fail();
}
let (sink_table, _) = get_table_info_df_schema(
self.config.catalog_manager.clone(),
self.config.sink_table_name.clone(),
)
.await?;
let rewritten = rewrite_poc_incremental_aggregate_with_sink_merge(
&plan,
&analysis,
sink_table,
&self.config.sink_table_name,
)
.await?;
warn!(
"Flow {} rewrote incremental SQL aggregate query with POC sink merge",
self.config.flow_id,
);
Ok(rewritten)
}
#[allow(clippy::too_many_arguments)]
pub fn try_new(
TaskArgs {
@@ -779,15 +825,25 @@ impl BatchingTask {
return Ok(None);
}
let plan = gen_plan_with_matching_schema(
let plan = sql_to_df_plan(
query_ctx.clone(),
engine.clone(),
&self.config.query,
query_ctx,
engine,
sink_table_schema.clone(),
primary_key_indices,
allow_partial,
false,
)
.await?;
let rewritten = self.rewrite_incremental_sql_plan_if_needed(plan).await?;
let mut add_auto_column = ColumnMatcherRewriter::new(
sink_table_schema.clone(),
primary_key_indices.to_vec(),
allow_partial,
);
let plan = rewritten
.rewrite(&mut add_auto_column)
.with_context(|_| DatafusionSnafu {
context: "Failed to align rewritten plan with sink schema".to_string(),
})?
.data;
return Ok(Some(PlanInfo { plan, filter: None }));
}
@@ -878,14 +934,22 @@ impl BatchingTask {
let plan =
sql_to_df_plan(query_ctx.clone(), engine.clone(), &self.config.query, false).await?;
let rewrite = plan
let filtered = plan
.clone()
.rewrite(&mut add_filter)
.and_then(|p| p.data.rewrite(&mut add_auto_column))
.with_context(|_| DatafusionSnafu {
context: format!("Failed to rewrite plan:\n {}\n", plan),
})?
.data;
let rewritten = self
.rewrite_incremental_sql_plan_if_needed(filtered)
.await?;
let rewrite = rewritten
.rewrite(&mut add_auto_column)
.with_context(|_| DatafusionSnafu {
context: "Failed to align rewritten plan with sink schema".to_string(),
})?
.data;
// only apply optimize after complex rewrite is done
let new_plan = apply_df_optimizer(rewrite, &query_ctx).await?;
@@ -1108,19 +1172,106 @@ fn build_pk_from_aggr(plan: &LogicalPlan) -> Result<Option<TableDef>, Error> {
#[cfg(test)]
mod test {
use std::collections::BTreeMap;
use std::sync::Arc;
use api::v1::column_def::try_as_column_schema;
use catalog::RegisterTableRequest;
use catalog::memory::MemoryCatalogManager;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_error::ext::{BoxedError, PlainError};
use common_error::status_code::StatusCode;
use common_query::Output;
use common_query::{Output, OutputData};
use common_recordbatch::adapter::{RecordBatchMetrics, RegionWatermarkEntry};
use common_recordbatch::util;
use datatypes::arrow_array::{int_array_value_at_index, timestamp_array_value};
use datatypes::prelude::{ConcreteDataType, MutableVector, ScalarVectorBuilder};
use datatypes::schema::Schema;
use datatypes::timestamp::TimestampMillisecond;
use datatypes::vectors::{TimestampMillisecondVectorBuilder, VectorRef};
use pretty_assertions::assert_eq;
use session::context::QueryContext;
use snafu::GenerateImplicitData;
use table::test_util::MemTable;
use super::*;
use crate::test_utils::create_test_query_engine;
/// Registers an in-memory table called `table_name` in the engine's
/// `MemoryCatalogManager`.
///
/// The table has a nullable `UInt32` column `number` and a non-null
/// millisecond timestamp column `ts` used as the time index; each entry of
/// `rows` is a `(number, ts)` pair.
fn register_test_table(
    query_engine: &QueryEngineRef,
    table_name: &str,
    rows: &[(Option<u32>, i64)],
) {
    let number_column = ColumnSchema::new("number", ConcreteDataType::uint32_datatype(), true);
    let ts_column = ColumnSchema::new(
        "ts",
        ConcreteDataType::timestamp_millisecond_datatype(),
        false,
    )
    .with_time_index(true);
    let schema = Arc::new(Schema::new(vec![number_column, ts_column]));

    // Fill both column builders in a single pass over the rows.
    let mut number_builder = datatypes::vectors::UInt32VectorBuilder::with_capacity(rows.len());
    let mut ts_builder = TimestampMillisecondVectorBuilder::with_capacity(rows.len());
    for &(number, ts) in rows {
        number_builder.push(number);
        ts_builder.push(Some(TimestampMillisecond::new(ts)));
    }
    let numbers: VectorRef = number_builder.to_vector_cloned();
    let timestamps: VectorRef = ts_builder.to_vector_cloned();

    let recordbatch =
        common_recordbatch::RecordBatch::new(schema, vec![numbers, timestamps]).unwrap();
    let request = RegisterTableRequest {
        catalog: DEFAULT_CATALOG_NAME.to_string(),
        schema: DEFAULT_SCHEMA_NAME.to_string(),
        table_name: table_name.to_string(),
        table_id: 6000,
        table: MemTable::table(table_name, recordbatch),
    };

    query_engine
        .engine_state()
        .catalog_manager()
        .as_any()
        .downcast_ref::<MemoryCatalogManager>()
        .unwrap()
        .register_table_sync(request)
        .unwrap();
}
/// Registers a test table whose `number` values are all non-null by wrapping
/// each value in `Some` and delegating to `register_test_table`.
fn register_test_sink_table(
    query_engine: &QueryEngineRef,
    table_name: &str,
    rows: &[(u32, i64)],
) {
    let nullable_rows: Vec<(Option<u32>, i64)> = rows
        .iter()
        .map(|&(number, ts)| (Some(number), ts))
        .collect();
    register_test_table(query_engine, table_name, &nullable_rows);
}
/// Flattens `batches` into `(ts, number)` pairs read from the `ts` and
/// `number` columns and returns them sorted, so callers can compare against an
/// expected list regardless of the order rows were produced in.
///
/// Panics if a batch lacks either column.
fn extract_ts_number_rows(
    batches: &[common_recordbatch::RecordBatch],
) -> Vec<(i64, Option<i64>)> {
    let mut pairs: Vec<(i64, Option<i64>)> = batches
        .iter()
        .flat_map(|batch| {
            let ts_col = batch.column_by_name("ts").unwrap();
            let number_col = batch.column_by_name("number").unwrap();
            (0..batch.num_rows()).map(move |row_idx| {
                (
                    timestamp_array_value(ts_col, row_idx).value(),
                    int_array_value_at_index(number_col, row_idx),
                )
            })
        })
        .collect();
    pairs.sort_unstable();
    pairs
}
#[tokio::test]
async fn test_gen_create_table_sql() {
let query_engine = create_test_query_engine();
@@ -1687,4 +1838,539 @@ mod test {
&BTreeMap::from([(1_u64, 30_u64), (2_u64, 40_u64)])
);
}
/// A `max` aggregate is mergeable, so once a checkpoint has been advanced the
/// rewritten plan must contain a left join against the sink and no union.
#[tokio::test]
async fn test_rewrite_incremental_sql_plan_for_supported_aggregate() {
    let query_engine = create_test_query_engine();
    let query_ctx = QueryContext::arc();
    let sql = "SELECT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts";
    let plan = sql_to_df_plan(query_ctx.clone(), query_engine.clone(), sql, true)
        .await
        .unwrap();
    let (_tx, rx) = tokio::sync::oneshot::channel();
    let task = BatchingTask::try_new(TaskArgs {
        flow_id: 49,
        query: sql,
        plan,
        time_window_expr: None,
        expire_after: None,
        sink_table_name: [
            "greptime".to_string(),
            "public".to_string(),
            "numbers_with_ts".to_string(),
        ],
        source_table_names: vec![[
            "greptime".to_string(),
            "public".to_string(),
            "numbers_with_ts".to_string(),
        ]],
        query_ctx,
        // Use the catalog of the engine that plans the query instead of
        // spinning up a second, unrelated engine.
        catalog_manager: query_engine.engine_state().catalog_manager().clone(),
        shutdown_rx: rx,
        batch_opts: Arc::new(BatchingModeOptions::default()),
        flow_eval_interval: None,
    })
    .unwrap();

    {
        let mut state = task.state.write().unwrap();
        state.advance_checkpoints(HashMap::from([(1_u64, 10_u64)]));
    }

    // Clone the context in its own statement so the `std::sync::RwLock` read
    // guard is dropped before the `.await` below.
    let task_query_ctx = task.state.read().unwrap().query_ctx.clone();
    let raw_plan = sql_to_df_plan(task_query_ctx, query_engine.clone(), sql, false)
        .await
        .unwrap();
    let rewritten = task
        .rewrite_incremental_sql_plan_if_needed(raw_plan)
        .await
        .unwrap();

    let plan_text = format!("{}", rewritten.display_indent());
    assert!(plan_text.contains("Left Join"));
    assert!(!plan_text.contains("Union"));
}
/// `avg` has no merge function, so generating the insert plan for an
/// incremental task must fail with an `UNSUPPORTED_INCREMENTAL_AGG` error.
#[tokio::test]
async fn test_rewrite_incremental_sql_plan_rejects_avg() {
    let query_engine = create_test_query_engine();
    let query_ctx = QueryContext::arc();
    let sql = "SELECT avg(number) AS avg_num, ts FROM numbers_with_ts GROUP BY ts";
    let plan = sql_to_df_plan(query_ctx.clone(), query_engine.clone(), sql, true)
        .await
        .unwrap();
    let (_tx, rx) = tokio::sync::oneshot::channel();
    let task = BatchingTask::try_new(TaskArgs {
        flow_id: 50,
        query: sql,
        plan,
        time_window_expr: None,
        expire_after: None,
        sink_table_name: [
            "greptime".to_string(),
            "public".to_string(),
            "numbers_with_ts".to_string(),
        ],
        source_table_names: vec![[
            "greptime".to_string(),
            "public".to_string(),
            "numbers_with_ts".to_string(),
        ]],
        query_ctx,
        // Use the catalog of the engine that plans the query instead of
        // spinning up a second, unrelated engine.
        catalog_manager: query_engine.engine_state().catalog_manager().clone(),
        shutdown_rx: rx,
        batch_opts: Arc::new(BatchingModeOptions::default()),
        flow_eval_interval: None,
    })
    .unwrap();

    task.mark_all_windows_as_dirty().unwrap();
    {
        let mut state = task.state.write().unwrap();
        state.advance_checkpoints(HashMap::from([(1_u64, 10_u64)]));
    }

    match task.gen_insert_plan(&query_engine, None).await {
        Err(err) => assert!(format!("{err}").contains("UNSUPPORTED_INCREMENTAL_AGG")),
        Ok(_) => panic!("expected UNSUPPORTED_INCREMENTAL_AGG error for avg query"),
    }
}
/// Sum merge with a sink that already stores `number = 20` at `ts = 2`.
///
/// The group present on both sides (`ts = 2`) is expected to come out as 22,
/// while the group only present in the new result (`ts = 3`) keeps its own
/// value of 3.
#[tokio::test]
async fn test_rewrite_incremental_sql_plan_semantics_sum_only_new_and_both_sides() {
    let query_engine = create_test_query_engine();
    register_test_sink_table(&query_engine, "sink_semantic", &[(20, 2)]);
    let query_ctx = QueryContext::arc();
    let sql = "SELECT sum(number) AS number, ts FROM numbers_with_ts WHERE ts >= 2 AND ts <= 3 GROUP BY ts";
    let plan = sql_to_df_plan(query_ctx.clone(), query_engine.clone(), sql, true)
        .await
        .unwrap();
    let (_tx, rx) = tokio::sync::oneshot::channel();
    let task = BatchingTask::try_new(TaskArgs {
        flow_id: 51,
        query: sql,
        plan,
        time_window_expr: None,
        expire_after: None,
        sink_table_name: [
            "greptime".to_string(),
            "public".to_string(),
            "sink_semantic".to_string(),
        ],
        source_table_names: vec![[
            "greptime".to_string(),
            "public".to_string(),
            "numbers_with_ts".to_string(),
        ]],
        query_ctx,
        catalog_manager: query_engine.engine_state().catalog_manager().clone(),
        shutdown_rx: rx,
        batch_opts: Arc::new(BatchingModeOptions::default()),
        flow_eval_interval: None,
    })
    .unwrap();

    {
        let mut state = task.state.write().unwrap();
        state.advance_checkpoints(HashMap::from([(1_u64, 10_u64)]));
    }

    // Clone the context in its own statement so the `std::sync::RwLock` read
    // guard is dropped before any `.await` below.
    let task_query_ctx = task.state.read().unwrap().query_ctx.clone();
    let raw_plan = sql_to_df_plan(task_query_ctx.clone(), query_engine.clone(), sql, false)
        .await
        .unwrap();
    let rewritten = task
        .rewrite_incremental_sql_plan_if_needed(raw_plan)
        .await
        .unwrap();

    let output = query_engine
        .execute(rewritten, task_query_ctx)
        .await
        .unwrap();
    let stream = match output.data {
        OutputData::Stream(stream) => stream,
        OutputData::RecordBatches(batches) => batches.as_stream(),
        OutputData::AffectedRows(_) => panic!("expected query output"),
    };
    let batches = util::collect(stream).await.unwrap();
    let rows = extract_ts_number_rows(&batches);
    assert_eq!(rows, vec![(2, Some(22)), (3, Some(3))]);
}
/// Max merge with a sink that already stores `number = 20` at `ts = 2`.
///
/// The group present on both sides (`ts = 2`) is expected to keep the larger
/// sink value 20, while the group only present in the new result (`ts = 3`)
/// keeps its own value of 3.
#[tokio::test]
async fn test_rewrite_incremental_sql_plan_semantics_max_only_new_and_both_sides() {
    let query_engine = create_test_query_engine();
    register_test_sink_table(&query_engine, "sink_semantic_max", &[(20, 2)]);
    let query_ctx = QueryContext::arc();
    let sql = "SELECT max(number) AS number, ts FROM numbers_with_ts WHERE ts >= 2 AND ts <= 3 GROUP BY ts";
    let plan = sql_to_df_plan(query_ctx.clone(), query_engine.clone(), sql, true)
        .await
        .unwrap();
    let (_tx, rx) = tokio::sync::oneshot::channel();
    let task = BatchingTask::try_new(TaskArgs {
        flow_id: 52,
        query: sql,
        plan,
        time_window_expr: None,
        expire_after: None,
        sink_table_name: [
            "greptime".to_string(),
            "public".to_string(),
            "sink_semantic_max".to_string(),
        ],
        source_table_names: vec![[
            "greptime".to_string(),
            "public".to_string(),
            "numbers_with_ts".to_string(),
        ]],
        query_ctx,
        catalog_manager: query_engine.engine_state().catalog_manager().clone(),
        shutdown_rx: rx,
        batch_opts: Arc::new(BatchingModeOptions::default()),
        flow_eval_interval: None,
    })
    .unwrap();

    {
        let mut state = task.state.write().unwrap();
        state.advance_checkpoints(HashMap::from([(1_u64, 10_u64)]));
    }

    // Clone the context in its own statement so the `std::sync::RwLock` read
    // guard is dropped before any `.await` below.
    let task_query_ctx = task.state.read().unwrap().query_ctx.clone();
    let raw_plan = sql_to_df_plan(task_query_ctx.clone(), query_engine.clone(), sql, false)
        .await
        .unwrap();
    let rewritten = task
        .rewrite_incremental_sql_plan_if_needed(raw_plan)
        .await
        .unwrap();

    let output = query_engine
        .execute(rewritten, task_query_ctx)
        .await
        .unwrap();
    let stream = match output.data {
        OutputData::Stream(stream) => stream,
        OutputData::RecordBatches(batches) => batches.as_stream(),
        OutputData::AffectedRows(_) => panic!("expected query output"),
    };
    let batches = util::collect(stream).await.unwrap();
    let rows = extract_ts_number_rows(&batches);
    assert_eq!(rows, vec![(2, Some(20)), (3, Some(3))]);
}
/// Sum merge where the new result for `ts = 2` is NULL (its only source row
/// has a NULL `number`) and the sink stores 20 for that group.
///
/// The merged row must keep the sink value 20; the `ts = 3` group, absent from
/// the sink, keeps its own value of 3.
#[tokio::test]
async fn test_rewrite_incremental_sql_plan_semantics_sum_nullable_delta_keeps_old_state() {
    let query_engine = create_test_query_engine();
    register_test_sink_table(&query_engine, "sink_semantic_sum_null", &[(20, 2)]);
    register_test_table(
        &query_engine,
        "numbers_with_nullable_ts",
        &[(None, 2), (Some(3), 3)],
    );
    let query_ctx = QueryContext::arc();
    let sql = "SELECT sum(number) AS number, ts FROM numbers_with_nullable_ts GROUP BY ts";
    let plan = sql_to_df_plan(query_ctx.clone(), query_engine.clone(), sql, true)
        .await
        .unwrap();
    let (_tx, rx) = tokio::sync::oneshot::channel();
    let task = BatchingTask::try_new(TaskArgs {
        flow_id: 53,
        query: sql,
        plan,
        time_window_expr: None,
        expire_after: None,
        sink_table_name: [
            "greptime".to_string(),
            "public".to_string(),
            "sink_semantic_sum_null".to_string(),
        ],
        source_table_names: vec![[
            "greptime".to_string(),
            "public".to_string(),
            "numbers_with_nullable_ts".to_string(),
        ]],
        query_ctx,
        catalog_manager: query_engine.engine_state().catalog_manager().clone(),
        shutdown_rx: rx,
        batch_opts: Arc::new(BatchingModeOptions::default()),
        flow_eval_interval: None,
    })
    .unwrap();

    {
        let mut state = task.state.write().unwrap();
        state.advance_checkpoints(HashMap::from([(1_u64, 10_u64)]));
    }

    // Clone the context in its own statement so the `std::sync::RwLock` read
    // guard is dropped before any `.await` below.
    let task_query_ctx = task.state.read().unwrap().query_ctx.clone();
    let raw_plan = sql_to_df_plan(task_query_ctx.clone(), query_engine.clone(), sql, false)
        .await
        .unwrap();
    let rewritten = task
        .rewrite_incremental_sql_plan_if_needed(raw_plan)
        .await
        .unwrap();

    let output = query_engine
        .execute(rewritten, task_query_ctx)
        .await
        .unwrap();
    let stream = match output.data {
        OutputData::Stream(stream) => stream,
        OutputData::RecordBatches(batches) => batches.as_stream(),
        OutputData::AffectedRows(_) => panic!("expected query output"),
    };
    let batches = util::collect(stream).await.unwrap();
    let rows = extract_ts_number_rows(&batches);
    assert_eq!(rows, vec![(2, Some(20)), (3, Some(3))]);
}
/// Max merge where the new result for `ts = 2` is NULL (its only source row
/// has a NULL `number`) and the sink stores 20 for that group.
///
/// The merged row must keep the sink value 20; the `ts = 3` group, absent from
/// the sink, keeps its own value of 3.
#[tokio::test]
async fn test_rewrite_incremental_sql_plan_semantics_max_nullable_delta_keeps_old_state() {
    let query_engine = create_test_query_engine();
    register_test_sink_table(&query_engine, "sink_semantic_max_null", &[(20, 2)]);
    register_test_table(
        &query_engine,
        "numbers_with_nullable_ts_max",
        &[(None, 2), (Some(3), 3)],
    );
    let query_ctx = QueryContext::arc();
    let sql = "SELECT max(number) AS number, ts FROM numbers_with_nullable_ts_max GROUP BY ts";
    let plan = sql_to_df_plan(query_ctx.clone(), query_engine.clone(), sql, true)
        .await
        .unwrap();
    let (_tx, rx) = tokio::sync::oneshot::channel();
    let task = BatchingTask::try_new(TaskArgs {
        flow_id: 54,
        query: sql,
        plan,
        time_window_expr: None,
        expire_after: None,
        sink_table_name: [
            "greptime".to_string(),
            "public".to_string(),
            "sink_semantic_max_null".to_string(),
        ],
        source_table_names: vec![[
            "greptime".to_string(),
            "public".to_string(),
            "numbers_with_nullable_ts_max".to_string(),
        ]],
        query_ctx,
        catalog_manager: query_engine.engine_state().catalog_manager().clone(),
        shutdown_rx: rx,
        batch_opts: Arc::new(BatchingModeOptions::default()),
        flow_eval_interval: None,
    })
    .unwrap();

    {
        let mut state = task.state.write().unwrap();
        state.advance_checkpoints(HashMap::from([(1_u64, 10_u64)]));
    }

    // Clone the context in its own statement so the `std::sync::RwLock` read
    // guard is dropped before any `.await` below.
    let task_query_ctx = task.state.read().unwrap().query_ctx.clone();
    let raw_plan = sql_to_df_plan(task_query_ctx.clone(), query_engine.clone(), sql, false)
        .await
        .unwrap();
    let rewritten = task
        .rewrite_incremental_sql_plan_if_needed(raw_plan)
        .await
        .unwrap();

    let output = query_engine
        .execute(rewritten, task_query_ctx)
        .await
        .unwrap();
    let stream = match output.data {
        OutputData::Stream(stream) => stream,
        OutputData::RecordBatches(batches) => batches.as_stream(),
        OutputData::AffectedRows(_) => panic!("expected query output"),
    };
    let batches = util::collect(stream).await.unwrap();
    let rows = extract_ts_number_rows(&batches);
    assert_eq!(rows, vec![(2, Some(20)), (3, Some(3))]);
}
/// Sum merge where the sink stores NULL for `ts = 2` and the new result for
/// that group is 7.
///
/// The merged row must take the new value 7; the `ts = 3` group, absent from
/// the sink, keeps its own value of 3.
#[tokio::test]
async fn test_rewrite_incremental_sql_plan_semantics_sum_sink_null_delta_nonnull_uses_delta() {
    let query_engine = create_test_query_engine();
    register_test_table(&query_engine, "sink_semantic_sum_sink_null", &[(None, 2)]);
    register_test_table(
        &query_engine,
        "numbers_with_nullable_ts_sink_null",
        &[(Some(7), 2), (Some(3), 3)],
    );
    let query_ctx = QueryContext::arc();
    let sql =
        "SELECT sum(number) AS number, ts FROM numbers_with_nullable_ts_sink_null GROUP BY ts";
    let plan = sql_to_df_plan(query_ctx.clone(), query_engine.clone(), sql, true)
        .await
        .unwrap();
    let (_tx, rx) = tokio::sync::oneshot::channel();
    let task = BatchingTask::try_new(TaskArgs {
        flow_id: 55,
        query: sql,
        plan,
        time_window_expr: None,
        expire_after: None,
        sink_table_name: [
            "greptime".to_string(),
            "public".to_string(),
            "sink_semantic_sum_sink_null".to_string(),
        ],
        source_table_names: vec![[
            "greptime".to_string(),
            "public".to_string(),
            "numbers_with_nullable_ts_sink_null".to_string(),
        ]],
        query_ctx,
        catalog_manager: query_engine.engine_state().catalog_manager().clone(),
        shutdown_rx: rx,
        batch_opts: Arc::new(BatchingModeOptions::default()),
        flow_eval_interval: None,
    })
    .unwrap();

    {
        let mut state = task.state.write().unwrap();
        state.advance_checkpoints(HashMap::from([(1_u64, 10_u64)]));
    }

    // Clone the context in its own statement so the `std::sync::RwLock` read
    // guard is dropped before any `.await` below.
    let task_query_ctx = task.state.read().unwrap().query_ctx.clone();
    let raw_plan = sql_to_df_plan(task_query_ctx.clone(), query_engine.clone(), sql, false)
        .await
        .unwrap();
    let rewritten = task
        .rewrite_incremental_sql_plan_if_needed(raw_plan)
        .await
        .unwrap();

    let output = query_engine
        .execute(rewritten, task_query_ctx)
        .await
        .unwrap();
    let stream = match output.data {
        OutputData::Stream(stream) => stream,
        OutputData::RecordBatches(batches) => batches.as_stream(),
        OutputData::AffectedRows(_) => panic!("expected query output"),
    };
    let batches = util::collect(stream).await.unwrap();
    let rows = extract_ts_number_rows(&batches);
    assert_eq!(rows, vec![(2, Some(7)), (3, Some(3))]);
}
/// Sum merge where both the sink value and the new result for `ts = 2` are
/// NULL.
///
/// The merged row must stay NULL; the `ts = 3` group, absent from the sink,
/// keeps its own value of 3.
#[tokio::test]
async fn test_rewrite_incremental_sql_plan_semantics_sum_double_null_stays_null() {
    let query_engine = create_test_query_engine();
    register_test_table(&query_engine, "sink_semantic_sum_double_null", &[(None, 2)]);
    register_test_table(
        &query_engine,
        "numbers_with_nullable_ts_double_null",
        &[(None, 2), (Some(3), 3)],
    );
    let query_ctx = QueryContext::arc();
    let sql = "SELECT sum(number) AS number, ts FROM numbers_with_nullable_ts_double_null GROUP BY ts";
    let plan = sql_to_df_plan(query_ctx.clone(), query_engine.clone(), sql, true)
        .await
        .unwrap();
    let (_tx, rx) = tokio::sync::oneshot::channel();
    let task = BatchingTask::try_new(TaskArgs {
        flow_id: 56,
        query: sql,
        plan,
        time_window_expr: None,
        expire_after: None,
        sink_table_name: [
            "greptime".to_string(),
            "public".to_string(),
            "sink_semantic_sum_double_null".to_string(),
        ],
        source_table_names: vec![[
            "greptime".to_string(),
            "public".to_string(),
            "numbers_with_nullable_ts_double_null".to_string(),
        ]],
        query_ctx,
        catalog_manager: query_engine.engine_state().catalog_manager().clone(),
        shutdown_rx: rx,
        batch_opts: Arc::new(BatchingModeOptions::default()),
        flow_eval_interval: None,
    })
    .unwrap();

    {
        let mut state = task.state.write().unwrap();
        state.advance_checkpoints(HashMap::from([(1_u64, 10_u64)]));
    }

    // Clone the context in its own statement so the `std::sync::RwLock` read
    // guard is dropped before any `.await` below.
    let task_query_ctx = task.state.read().unwrap().query_ctx.clone();
    let raw_plan = sql_to_df_plan(task_query_ctx.clone(), query_engine.clone(), sql, false)
        .await
        .unwrap();
    let rewritten = task
        .rewrite_incremental_sql_plan_if_needed(raw_plan)
        .await
        .unwrap();

    let output = query_engine
        .execute(rewritten, task_query_ctx)
        .await
        .unwrap();
    let stream = match output.data {
        OutputData::Stream(stream) => stream,
        OutputData::RecordBatches(batches) => batches.as_stream(),
        OutputData::AffectedRows(_) => panic!("expected query output"),
    };
    let batches = util::collect(stream).await.unwrap();
    let rows = extract_ts_number_rows(&batches);
    assert_eq!(rows, vec![(2, None), (3, Some(3))]);
}
}

View File

@@ -19,15 +19,21 @@ use std::sync::Arc;
use catalog::CatalogManagerRef;
use common_error::ext::BoxedError;
use common_function::aggrs::aggr_wrapper::get_aggr_func;
use common_telemetry::debug;
use datafusion::datasource::DefaultTableSource;
use datafusion::error::Result as DfResult;
use datafusion::logical_expr::Expr;
use datafusion::sql::unparser::Unparser;
use datafusion_common::tree_node::{
Transformed, TreeNode as _, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
};
use datafusion_common::{DFSchema, DataFusionError, ScalarValue};
use datafusion_expr::{Distinct, LogicalPlan, Projection};
use datafusion_common::{DFSchema, DataFusionError, ScalarValue, TableReference};
use datafusion_expr::logical_plan::TableScan;
use datafusion_expr::{
Distinct, JoinType, LogicalPlan, LogicalPlanBuilder, Operator, Projection, and, binary_expr,
bitwise_and, bitwise_or, bitwise_xor, col, is_null, or, when,
};
use datatypes::schema::{ColumnSchema, SchemaRef};
use query::QueryEngineRef;
use query::parser::{DEFAULT_LOOKBACK_STRING, PromQuery, QueryLanguageParser, QueryStatement};
@@ -37,12 +43,304 @@ use sql::parser::{ParseOptions, ParserContext};
use sql::statements::statement::Statement;
use sql::statements::tql::Tql;
use table::TableRef;
use table::table::adapter::DfTableProviderAdapter;
use crate::adapter::{AUTO_CREATED_PLACEHOLDER_TS_COL, AUTO_CREATED_UPDATE_AT_TS_COL};
use crate::df_optimizer::apply_df_optimizer;
use crate::error::{DatafusionSnafu, ExternalSnafu, InvalidQuerySnafu, TableNotFoundSnafu};
use crate::{Error, TableName};
/// One aggregate output column together with the function used to merge a
/// newly computed value with the value already stored in the sink table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PocIncrementalMergeColumn {
/// Column name used to look the value up on both the delta and the sink side,
/// and as the alias of the merged result.
pub output_name: String,
/// Name of the merge function. `analyze_poc_incremental_aggregate_plan` emits
/// `"sum"`, `"min"`, `"max"`, `"bool_and"`, `"bool_or"`, `"bit_and"`,
/// `"bit_or"` or `"bit_xor"`.
pub merge_function: String,
}
/// Result of inspecting an aggregate plan for incremental sink merging.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PocIncrementalAggregateAnalysis {
/// Group-by column names, sorted; used as the join keys between the new
/// result and the sink table.
pub group_columns: Vec<String>,
/// Aggregate outputs that have a known merge function.
pub merge_columns: Vec<PocIncrementalMergeColumn>,
/// Display strings of aggregate expressions that cannot be merged
/// (not a recognized aggregate call, `DISTINCT`, or an unlisted function).
pub unsupported_exprs: Vec<String>,
}
/// Plan visitor state that remembers the aggregate expressions of the most
/// recently visited `LogicalPlan::Aggregate` node.
#[derive(Default)]
struct LastAggregateExprFinder {
/// `None` until an aggregate node has been visited; overwritten on each
/// further aggregate node encountered.
aggr_exprs: Option<Vec<Expr>>,
}
impl TreeNodeVisitor<'_> for LastAggregateExprFinder {
    type Node = LogicalPlan;

    /// Records the aggregate expressions of every aggregate node on the way
    /// down, replacing whatever was recorded before, and always continues the
    /// traversal.
    fn f_down(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
        let LogicalPlan::Aggregate(aggregate) = node else {
            return Ok(TreeNodeRecursion::Continue);
        };
        self.aggr_exprs = Some(aggregate.aggr_expr.to_vec());
        Ok(TreeNodeRecursion::Continue)
    }
}
/// Inspects `plan` and describes how its aggregate outputs can be merged with
/// previously stored results.
///
/// Returns `Ok(None)` when the plan contains no aggregate node. Otherwise the
/// analysis lists the sorted group-by column names, one merge column per
/// aggregate with a known merge function, and the display string of every
/// aggregate that cannot be merged.
///
/// Output names follow the aliases of the root projection when the root of
/// `plan` is a projection; otherwise the aggregate's own name is used.
///
/// # Errors
///
/// Fails if walking the plan with either visitor returns an error.
pub fn analyze_poc_incremental_aggregate_plan(
    plan: &LogicalPlan,
) -> Result<Option<PocIncrementalAggregateAnalysis>, Error> {
    let mut group_by_visitor = FindGroupByFinalName::default();
    plan.visit(&mut group_by_visitor)
        .with_context(|_| DatafusionSnafu {
            context: format!("Failed to inspect group-by columns from logical plan: {plan:?}"),
        })?;

    let mut aggr_visitor = LastAggregateExprFinder::default();
    plan.visit(&mut aggr_visitor)
        .with_context(|_| DatafusionSnafu {
            context: format!("Failed to inspect aggregate expressions from logical plan: {plan:?}"),
        })?;
    let aggr_exprs = match aggr_visitor.aggr_exprs {
        Some(exprs) => exprs,
        None => return Ok(None),
    };

    // Raw column name -> name exposed by the root projection. Later entries
    // win when the same column is projected more than once.
    let output_aliases: HashMap<String, String> = match plan {
        LogicalPlan::Projection(projection) => projection
            .expr
            .iter()
            .filter_map(|expr| match expr {
                Expr::Alias(alias) => match alias.expr.as_ref() {
                    Expr::Column(column) => Some((column.name.clone(), alias.name.clone())),
                    _ => None,
                },
                Expr::Column(column) => Some((column.name.clone(), column.name.clone())),
                _ => None,
            })
            .collect(),
        _ => HashMap::new(),
    };

    let mut group_columns: Vec<_> = group_by_visitor
        .get_group_expr_names()
        .unwrap_or_default()
        .into_iter()
        .collect();
    group_columns.sort();

    let mut merge_columns = Vec::with_capacity(aggr_exprs.len());
    let mut unsupported_exprs = Vec::new();
    for aggr_expr in aggr_exprs {
        // DISTINCT aggregates and unknown functions have no merge function.
        let merge_function = get_aggr_func(&aggr_expr)
            .filter(|aggr_func| !aggr_func.params.distinct)
            .and_then(|aggr_func| {
                let aggr_name = aggr_func.func.name().to_ascii_lowercase();
                match aggr_name.as_str() {
                    "sum" | "count" => Some("sum"),
                    "min" => Some("min"),
                    "max" => Some("max"),
                    "bool_and" => Some("bool_and"),
                    "bool_or" => Some("bool_or"),
                    "bit_and" => Some("bit_and"),
                    "bit_or" => Some("bit_or"),
                    "bit_xor" => Some("bit_xor"),
                    _ => None,
                }
            });

        match merge_function {
            None => unsupported_exprs.push(aggr_expr.to_string()),
            Some(merge_function) => {
                let raw_name = aggr_expr.qualified_name().1;
                let output_name = output_aliases.get(&raw_name).cloned().unwrap_or(raw_name);
                merge_columns.push(PocIncrementalMergeColumn {
                    output_name,
                    merge_function: merge_function.to_string(),
                });
            }
        }
    }

    Ok(Some(PocIncrementalAggregateAnalysis {
        group_columns,
        merge_columns,
        unsupported_exprs,
    }))
}
pub async fn rewrite_poc_incremental_aggregate_with_sink_merge(
delta_plan: &LogicalPlan,
analysis: &PocIncrementalAggregateAnalysis,
sink_table: TableRef,
sink_table_name: &TableName,
) -> Result<LogicalPlan, Error> {
ensure!(
analysis.unsupported_exprs.is_empty(),
InvalidQuerySnafu {
reason: format!(
"UNSUPPORTED_INCREMENTAL_AGG: unsupported aggregate expressions {:?}",
analysis.unsupported_exprs
)
}
);
ensure!(
!analysis.merge_columns.is_empty(),
InvalidQuerySnafu {
reason:
"UNSUPPORTED_INCREMENTAL_AGG: aggregate query has no mergeable aggregate columns"
.to_string()
}
);
let delta_alias = "__flow_delta";
let sink_alias = "__flow_sink";
let mut selected_columns = analysis.group_columns.clone();
selected_columns.extend(analysis.merge_columns.iter().map(|c| c.output_name.clone()));
let selected_exprs = selected_columns.iter().map(col).collect::<Vec<_>>();
let delta_selected = LogicalPlanBuilder::from(delta_plan.clone())
.project(selected_exprs.clone())
.with_context(|_| DatafusionSnafu {
context: "Failed to project delta plan for incremental sink merge".to_string(),
})?
.alias(delta_alias)
.with_context(|_| DatafusionSnafu {
context: "Failed to alias delta plan for incremental sink merge".to_string(),
})?
.build()
.with_context(|_| DatafusionSnafu {
context: "Failed to build projected delta plan for incremental sink merge".to_string(),
})?;
let table_provider = Arc::new(DfTableProviderAdapter::new(sink_table));
let table_source = Arc::new(DefaultTableSource::new(table_provider));
let sink_scan = LogicalPlan::TableScan(
TableScan::try_new(
TableReference::Full {
catalog: sink_table_name[0].clone().into(),
schema: sink_table_name[1].clone().into(),
table: sink_table_name[2].clone().into(),
},
table_source,
None,
vec![],
None,
)
.with_context(|_| DatafusionSnafu {
context: "Failed to build sink table scan for incremental sink merge".to_string(),
})?,
);
let sink_selected = LogicalPlanBuilder::from(sink_scan)
.project(selected_exprs)
.with_context(|_| DatafusionSnafu {
context: "Failed to project sink table scan for incremental sink merge".to_string(),
})?
.alias(sink_alias)
.with_context(|_| DatafusionSnafu {
context: "Failed to alias sink plan for incremental sink merge".to_string(),
})?
.build()
.with_context(|_| DatafusionSnafu {
context: "Failed to build projected sink plan for incremental sink merge".to_string(),
})?;
let join_keys = (
analysis
.group_columns
.iter()
.map(|c| datafusion_common::Column::from_qualified_name(format!("{delta_alias}.{c}")))
.collect::<Vec<_>>(),
analysis
.group_columns
.iter()
.map(|c| datafusion_common::Column::from_qualified_name(format!("{sink_alias}.{c}")))
.collect::<Vec<_>>(),
);
let joined = LogicalPlanBuilder::from(delta_selected)
.join(sink_selected, JoinType::Left, join_keys, None)
.with_context(|_| DatafusionSnafu {
context: "Failed to left join delta and sink plans for incremental sink merge"
.to_string(),
})?
.build()
.with_context(|_| DatafusionSnafu {
context: "Failed to build left join plan for incremental sink merge".to_string(),
})?;
let mut projection_exprs = analysis
.group_columns
.iter()
.map(|c| col(format!("{delta_alias}.{c}")).alias(c.clone()))
.collect::<Vec<_>>();
projection_exprs.extend(
analysis
.merge_columns
.iter()
.map(|merge_col| build_left_join_merge_expr(delta_alias, sink_alias, merge_col)),
);
LogicalPlanBuilder::from(joined)
.project(projection_exprs)
.with_context(|_| DatafusionSnafu {
context: "Failed to build projection merge plan for incremental sink merge".to_string(),
})?
.build()
.with_context(|_| DatafusionSnafu {
context: "Failed to finalize incremental aggregate sink merge plan".to_string(),
})
}
fn build_left_join_merge_expr(
delta_alias: &str,
sink_alias: &str,
merge_col: &PocIncrementalMergeColumn,
) -> Expr {
let left = col(format!("{delta_alias}.{}", merge_col.output_name));
let right = col(format!("{sink_alias}.{}", merge_col.output_name));
let merged = match merge_col.merge_function.as_str() {
"sum" => when(is_null(left.clone()), right.clone())
.when(is_null(right.clone()), left.clone())
.otherwise(binary_expr(left.clone(), Operator::Plus, right.clone()))
.unwrap(),
"min" => when(is_null(right.clone()), left.clone())
.when(left.clone().lt_eq(right.clone()), left.clone())
.otherwise(right.clone())
.unwrap(),
"max" => when(is_null(right.clone()), left.clone())
.when(left.clone().gt_eq(right.clone()), left.clone())
.otherwise(right.clone())
.unwrap(),
"bool_and" => when(is_null(left.clone()), right.clone())
.when(is_null(right.clone()), left.clone())
.otherwise(and(left.clone(), right.clone()))
.unwrap(),
"bool_or" => when(is_null(left.clone()), right.clone())
.when(is_null(right.clone()), left.clone())
.otherwise(or(left.clone(), right.clone()))
.unwrap(),
"bit_and" => when(is_null(left.clone()), right.clone())
.when(is_null(right.clone()), left.clone())
.otherwise(bitwise_and(left.clone(), right.clone()))
.unwrap(),
"bit_or" => when(is_null(left.clone()), right.clone())
.when(is_null(right.clone()), left.clone())
.otherwise(bitwise_or(left.clone(), right.clone()))
.unwrap(),
"bit_xor" => when(is_null(left.clone()), right.clone())
.when(is_null(right.clone()), left.clone())
.otherwise(bitwise_xor(left.clone(), right.clone()))
.unwrap(),
other => Expr::Literal(
ScalarValue::Utf8(Some(format!("UNSUPPORTED_INCREMENTAL_AGG:{other}"))),
None,
),
};
merged.alias(merge_col.output_name.clone())
}
pub async fn get_table_info_df_schema(
catalog_mr: CatalogManagerRef,
table_name: TableName,
@@ -907,6 +1205,76 @@ mod test {
}
}
#[tokio::test]
async fn test_analyze_poc_incremental_aggregate_plan() {
let query_engine = create_test_query_engine();
let ctx = QueryContext::arc();
let sql = "SELECT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts";
let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap();
let analysis = analyze_poc_incremental_aggregate_plan(&plan)
.unwrap()
.unwrap();
assert!(analysis.unsupported_exprs.is_empty());
assert!(analysis.group_columns.contains(&"ts".to_string()));
assert_eq!(analysis.merge_columns.len(), 1);
assert_eq!(analysis.merge_columns[0].output_name, "number");
assert_eq!(analysis.merge_columns[0].merge_function, "max");
}
/// `avg` has no merge function, so the analysis must report it as unsupported.
#[tokio::test]
async fn test_analyze_poc_incremental_aggregate_plan_rejects_avg() {
    let engine = create_test_query_engine();
    let query_ctx = QueryContext::arc();
    let query = "SELECT avg(number) AS avg_num, ts FROM numbers_with_ts GROUP BY ts";
    let logical_plan = sql_to_df_plan(query_ctx, engine, query, false)
        .await
        .unwrap();

    let unsupported = analyze_poc_incremental_aggregate_plan(&logical_plan)
        .unwrap()
        .unwrap()
        .unsupported_exprs;
    assert!(!unsupported.is_empty());
}
/// The sink-merge rewrite of a `max` aggregate must produce a plan that
/// contains a left join and no union.
#[tokio::test]
async fn test_rewrite_poc_incremental_aggregate_with_left_join() {
    let engine = create_test_query_engine();
    let query_ctx = QueryContext::arc();
    let query = "SELECT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts";
    let logical_plan = sql_to_df_plan(query_ctx.clone(), engine.clone(), query, false)
        .await
        .unwrap();
    let analysis = analyze_poc_incremental_aggregate_plan(&logical_plan)
        .unwrap()
        .unwrap();

    // The source table doubles as the sink in this test.
    let sink_name = [
        "greptime".to_string(),
        "public".to_string(),
        "numbers_with_ts".to_string(),
    ];
    let catalog_manager = engine.engine_state().catalog_manager().clone();
    let (sink_table, _) = get_table_info_df_schema(catalog_manager, sink_name.clone())
        .await
        .unwrap();

    let merged_plan = rewrite_poc_incremental_aggregate_with_sink_merge(
        &logical_plan,
        &analysis,
        sink_table,
        &sink_name,
    )
    .await
    .unwrap();

    let plan_text = format!("{}", merged_plan.display_indent());
    assert!(plan_text.contains("Left Join"));
    assert!(!plan_text.contains("Union"));
}
#[tokio::test]
async fn test_null_cast() {
let query_engine = create_test_query_engine();