1use std::collections::{BTreeMap, HashMap, HashSet};
18use std::sync::Arc;
19use std::time::Duration;
20
21use api::v1::flow::DirtyWindowRequests;
22use catalog::CatalogManagerRef;
23use common_error::ext::BoxedError;
24use common_meta::ddl::create_flow::FlowType;
25use common_meta::key::TableMetadataManagerRef;
26use common_meta::key::flow::FlowMetadataManagerRef;
27use common_meta::key::flow::flow_state::FlowStat;
28use common_meta::key::table_info::{TableInfoManager, TableInfoValue};
29use common_runtime::JoinHandle;
30use common_telemetry::tracing::warn;
31use common_telemetry::{debug, info};
32use common_time::TimeToLive;
33use datafusion_common::tree_node::{TreeNodeRecursion, TreeNodeVisitor};
34use datafusion_expr::LogicalPlan;
35use datatypes::prelude::ConcreteDataType;
36use query::QueryEngineRef;
37use session::context::QueryContext;
38use snafu::{OptionExt, ResultExt, ensure};
39use sql::parsers::utils::is_tql;
40use store_api::metric_engine_consts::is_metric_engine_internal_column;
41use store_api::storage::{RegionId, TableId};
42use table::table_reference::TableReference;
43use tokio::sync::{RwLock, oneshot};
44
45use crate::batching_mode::BatchingModeOptions;
46use crate::batching_mode::frontend_client::FrontendClient;
47use crate::batching_mode::task::{BatchingTask, TaskArgs};
48use crate::batching_mode::time_window::{TimeWindowExpr, find_time_window_expr};
49use crate::batching_mode::utils::sql_to_df_plan;
50use crate::engine::{FlowEngine, FlowStatProvider};
51use crate::error::{
52 CreateFlowSnafu, DatafusionSnafu, ExternalSnafu, FlowAlreadyExistSnafu, FlowNotFoundSnafu,
53 InvalidQuerySnafu, TableNotFoundMetaSnafu, UnexpectedSnafu, UnsupportedSnafu,
54};
55use crate::metrics::METRIC_FLOW_BATCHING_ENGINE_BULK_MARK_TIME_WINDOW;
56use crate::{CreateFlowArgs, Error, FlowId, TableName};
57
/// Flow engine that runs flows in batching mode.
///
/// It owns the registry of live flow tasks and the shared handles (frontend
/// client, metadata managers, query engine, options) that are handed to tasks.
pub struct BatchingEngine {
    /// Registered flow tasks and their shutdown senders, guarded by an async lock.
    runtime: RwLock<FlowRuntimeRegistry>,
    /// Client passed to tasks when checking/creating sink tables and executing queries.
    pub(crate) frontend_client: Arc<FrontendClient>,
    /// Flow metadata handle.
    /// NOTE(review): not read anywhere in the visible part of this file.
    flow_metadata_manager: FlowMetadataManagerRef,
    /// Table metadata handle used to resolve table ids into table info and names.
    table_meta: TableMetadataManagerRef,
    /// Catalog manager cloned into the arguments of every created task.
    catalog_manager: CatalogManagerRef,
    /// Query engine used to plan flow SQL and passed to task execution.
    query_engine: QueryEngineRef,
    /// Batching-mode options shared with every created task.
    pub(crate) batch_opts: Arc<BatchingModeOptions>,
}
73
/// In-memory bookkeeping for registered flows: the task itself and the
/// one-shot sender used to signal it to shut down, both keyed by flow id.
#[derive(Default)]
struct FlowRuntimeRegistry {
    /// Registered batching tasks.
    tasks: BTreeMap<FlowId, BatchingTask>,
    /// Shutdown senders belonging to the registered tasks.
    shutdown_txs: BTreeMap<FlowId, oneshot::Sender<()>>,
}
79
80impl FlowRuntimeRegistry {
81 fn insert(
82 &mut self,
83 flow_id: FlowId,
84 task: BatchingTask,
85 shutdown_tx: oneshot::Sender<()>,
86 ) -> (Option<BatchingTask>, Option<oneshot::Sender<()>>) {
87 (
88 self.tasks.insert(flow_id, task),
89 self.shutdown_txs.insert(flow_id, shutdown_tx),
90 )
91 }
92
93 fn remove(&mut self, flow_id: FlowId) -> Option<(BatchingTask, Option<oneshot::Sender<()>>)> {
94 let task = self.tasks.remove(&flow_id)?;
95 let shutdown_tx = self.shutdown_txs.remove(&flow_id);
96 Some((task, shutdown_tx))
97 }
98
99 fn remove_if_current(
100 &mut self,
101 flow_id: FlowId,
102 task: &BatchingTask,
103 ) -> (Option<BatchingTask>, Option<oneshot::Sender<()>>) {
104 if self
105 .tasks
106 .get(&flow_id)
107 .is_some_and(|current| Arc::ptr_eq(¤t.state, &task.state))
108 {
109 let Some((removed_task, removed_shutdown_tx)) = self.remove(flow_id) else {
110 return (None, None);
111 };
112 (Some(removed_task), removed_shutdown_tx)
113 } else {
114 (None, None)
115 }
116 }
117}
118
119impl BatchingEngine {
120 pub fn new(
121 frontend_client: Arc<FrontendClient>,
122 query_engine: QueryEngineRef,
123 flow_metadata_manager: FlowMetadataManagerRef,
124 table_meta: TableMetadataManagerRef,
125 catalog_manager: CatalogManagerRef,
126 batch_opts: BatchingModeOptions,
127 ) -> Self {
128 Self {
129 runtime: Default::default(),
130 frontend_client,
131 flow_metadata_manager,
132 table_meta,
133 catalog_manager,
134 query_engine,
135 batch_opts: Arc::new(batch_opts),
136 }
137 }
138
139 pub async fn get_last_exec_time_map(&self) -> BTreeMap<FlowId, i64> {
141 let runtime = self.runtime.read().await;
142 runtime
143 .tasks
144 .iter()
145 .filter_map(|(flow_id, task)| {
146 task.last_execution_time_millis()
147 .map(|timestamp| (*flow_id, timestamp))
148 })
149 .collect()
150 }
151
152 pub async fn handle_mark_dirty_time_window(
153 &self,
154 reqs: DirtyWindowRequests,
155 ) -> Result<(), Error> {
156 let table_info_mgr = self.table_meta.table_info_manager();
157
158 let mut group_by_table_id: HashMap<u32, Vec<_>> = HashMap::new();
159 for r in reqs.requests {
160 let tid = TableId::from(r.table_id);
161 let entry = group_by_table_id.entry(tid).or_default();
162 entry.extend(r.timestamps);
163 }
164 let tids = group_by_table_id.keys().cloned().collect::<Vec<TableId>>();
165 let table_infos =
166 table_info_mgr
167 .batch_get(&tids)
168 .await
169 .with_context(|_| TableNotFoundMetaSnafu {
170 msg: format!("Failed to get table info for table ids: {:?}", tids),
171 })?;
172
173 let group_by_table_name = group_by_table_id
174 .into_iter()
175 .filter_map(|(id, timestamps)| {
176 let table_name = table_infos.get(&id).map(|info| info.table_name());
177 let Some(table_name) = table_name else {
178 warn!("Failed to get table infos for table id: {:?}", id);
179 return None;
180 };
181 let table_name = [
182 table_name.catalog_name,
183 table_name.schema_name,
184 table_name.table_name,
185 ];
186 let schema = &table_infos.get(&id).unwrap().table_info.meta.schema;
187 let time_index_unit = schema.column_schemas()[schema.timestamp_index().unwrap()]
188 .data_type
189 .as_timestamp()
190 .unwrap()
191 .unit();
192 Some((table_name, (timestamps, time_index_unit)))
193 })
194 .collect::<HashMap<_, _>>();
195
196 let group_by_table_name = Arc::new(group_by_table_name);
197
198 let tasks = self
199 .runtime
200 .read()
201 .await
202 .tasks
203 .values()
204 .cloned()
205 .collect::<Vec<_>>();
206 let mut handles = Vec::new();
207
208 for task in tasks {
209 let src_table_names = &task.config.source_table_names;
210
211 if src_table_names
212 .iter()
213 .all(|name| !group_by_table_name.contains_key(name))
214 {
215 continue;
216 }
217
218 let group_by_table_name = group_by_table_name.clone();
219 let task = task.clone();
220
221 let handle: JoinHandle<Result<(), Error>> = tokio::spawn(async move {
222 let src_table_names = &task.config.source_table_names;
223 let mut all_dirty_windows = HashSet::new();
224 let mut is_dirty = false;
225 for src_table_name in src_table_names {
226 if let Some((timestamps, unit)) = group_by_table_name.get(src_table_name) {
227 let Some(expr) = &task.config.time_window_expr else {
228 is_dirty = true;
229 continue;
230 };
231 for timestamp in timestamps {
232 let align_start = expr
233 .eval(common_time::Timestamp::new(*timestamp, *unit))?
234 .0
235 .context(UnexpectedSnafu {
236 reason: "Failed to eval start value",
237 })?;
238 all_dirty_windows.insert(align_start);
239 }
240 }
241 }
242 let mut state = task.state.write().unwrap();
243 if is_dirty {
244 state.dirty_time_windows.set_dirty();
245 }
246 let flow_id_label = task.config.flow_id.to_string();
247 for timestamp in all_dirty_windows {
248 state.dirty_time_windows.add_window(timestamp, None);
249 }
250
251 METRIC_FLOW_BATCHING_ENGINE_BULK_MARK_TIME_WINDOW
252 .with_label_values(&[&flow_id_label])
253 .set(state.dirty_time_windows.len() as f64);
254 Ok(())
255 });
256 handles.push(handle);
257 }
258 for handle in handles {
259 match handle.await {
260 Err(e) => {
261 warn!("Failed to handle inserts: {e}");
262 }
263 Ok(Ok(())) => (),
264 Ok(Err(e)) => {
265 warn!("Failed to handle inserts: {e}");
266 }
267 }
268 }
269
270 Ok(())
271 }
272
273 pub async fn handle_inserts_inner(
274 &self,
275 request: api::v1::region::InsertRequests,
276 ) -> Result<(), Error> {
277 let table_info_mgr = self.table_meta.table_info_manager();
278 let mut group_by_table_id: HashMap<TableId, Vec<api::v1::Rows>> = HashMap::new();
279
280 for r in request.requests {
281 let tid = RegionId::from(r.region_id).table_id();
282 let entry = group_by_table_id.entry(tid).or_default();
283 if let Some(rows) = r.rows {
284 entry.push(rows);
285 }
286 }
287
288 let tids = group_by_table_id.keys().cloned().collect::<Vec<TableId>>();
289 let table_infos =
290 table_info_mgr
291 .batch_get(&tids)
292 .await
293 .with_context(|_| TableNotFoundMetaSnafu {
294 msg: format!("Failed to get table info for table ids: {:?}", tids),
295 })?;
296
297 let missing_tids = tids
298 .iter()
299 .filter(|id| !table_infos.contains_key(id))
300 .collect::<Vec<_>>();
301 if !missing_tids.is_empty() {
302 warn!(
303 "Failed to get all the table info for table ids, expected table ids: {:?}, those table doesn't exist: {:?}",
304 tids, missing_tids
305 );
306 }
307
308 let group_by_table_name = group_by_table_id
309 .into_iter()
310 .filter_map(|(id, rows)| {
311 let table_name = table_infos.get(&id).map(|info| info.table_name());
312 let Some(table_name) = table_name else {
313 warn!("Failed to get table infos for table id: {:?}", id);
314 return None;
315 };
316 let table_name = [
317 table_name.catalog_name,
318 table_name.schema_name,
319 table_name.table_name,
320 ];
321 Some((table_name, rows))
322 })
323 .collect::<HashMap<_, _>>();
324
325 let group_by_table_name = Arc::new(group_by_table_name);
326
327 let tasks = self
328 .runtime
329 .read()
330 .await
331 .tasks
332 .values()
333 .cloned()
334 .collect::<Vec<_>>();
335 let mut handles = Vec::new();
336 for task in tasks {
337 let src_table_names = &task.config.source_table_names;
338
339 if src_table_names
340 .iter()
341 .all(|name| !group_by_table_name.contains_key(name))
342 {
343 continue;
344 }
345
346 let group_by_table_name = group_by_table_name.clone();
347 let task = task.clone();
348
349 let handle: JoinHandle<Result<(), Error>> = tokio::spawn(async move {
350 let src_table_names = &task.config.source_table_names;
351
352 let mut is_dirty = false;
353
354 for src_table_name in src_table_names {
355 if let Some(entry) = group_by_table_name.get(src_table_name) {
356 let Some(expr) = &task.config.time_window_expr else {
357 is_dirty = true;
358 continue;
359 };
360 let involved_time_windows = expr.handle_rows(entry.clone()).await?;
361 let mut state = task.state.write().unwrap();
362 state
363 .dirty_time_windows
364 .add_lower_bounds(involved_time_windows.into_iter());
365 }
366 }
367 if is_dirty {
368 task.state.write().unwrap().dirty_time_windows.set_dirty();
369 }
370
371 Ok(())
372 });
373 handles.push(handle);
374 }
375
376 for handle in handles {
377 match handle.await {
378 Err(e) => {
379 warn!("Failed to handle inserts: {e}");
380 }
381 Ok(Ok(())) => (),
382 Ok(Err(e)) => {
383 warn!("Failed to handle inserts: {e}");
384 }
385 }
386 }
387 Ok(())
388 }
389}
390
391impl FlowStatProvider for BatchingEngine {
392 async fn flow_stat(&self) -> FlowStat {
393 FlowStat {
394 state_size: BTreeMap::new(),
395 last_exec_time_map: self
396 .get_last_exec_time_map()
397 .await
398 .into_iter()
399 .map(|(flow_id, timestamp)| (flow_id as u32, timestamp))
400 .collect(),
401 }
402 }
403}
404
405async fn get_table_name(
406 table_info: &TableInfoManager,
407 table_id: &TableId,
408) -> Result<TableName, Error> {
409 get_table_info(table_info, table_id).await.map(|info| {
410 let name = info.table_name();
411 [name.catalog_name, name.schema_name, name.table_name]
412 })
413}
414
415async fn get_table_info(
416 table_info: &TableInfoManager,
417 table_id: &TableId,
418) -> Result<TableInfoValue, Error> {
419 table_info
420 .get(*table_id)
421 .await
422 .map_err(BoxedError::new)
423 .context(ExternalSnafu)?
424 .with_context(|| UnexpectedSnafu {
425 reason: format!("Table id = {:?}, couldn't found table name", table_id),
426 })
427 .map(|info| info.into_inner())
428}
429
impl BatchingEngine {
    /// Creates (or replaces) a batching flow and registers its task.
    ///
    /// Returns `Ok(Some(flow_id))` when a task was registered and started, and
    /// `Ok(None)` when the flow already exists and `create_if_not_exists` is set
    /// without `or_replace`.
    ///
    /// # Errors
    ///
    /// Fails when the flow already exists and neither flag allows it, when the
    /// query context is missing, when a TQL query has no eval interval, when the
    /// flow type option is neither absent nor batching, when a source table uses
    /// instant TTL, or when planning, task construction, sink table preparation
    /// or task start-up fails.
    pub async fn create_flow_inner(&self, args: CreateFlowArgs) -> Result<Option<FlowId>, Error> {
        let CreateFlowArgs {
            flow_id,
            sink_table_name,
            source_table_ids,
            create_if_not_exists,
            or_replace,
            expire_after,
            eval_interval,
            comment: _,
            sql,
            flow_options,
            query_ctx,
        } = args;

        // Early existence check under a read lock; it is repeated under the write
        // lock right before insertion because the registry can change in between.
        {
            let is_exist = self.runtime.read().await.tasks.contains_key(&flow_id);
            match (create_if_not_exists, or_replace, is_exist) {
                (_, true, true) => {
                    info!("Replacing flow with id={}", flow_id);
                }
                (false, false, true) => FlowAlreadyExistSnafu { id: flow_id }.fail()?,
                (true, false, true) => {
                    info!("Flow with id={} already exists, do nothing", flow_id);
                    return Ok(None);
                }

                (_, _, false) => (),
            }
        }

        let query_ctx = query_ctx.context({
            UnexpectedSnafu {
                reason: "Query context is None".to_string(),
            }
        })?;
        let query_ctx = Arc::new(query_ctx);
        let is_tql = is_tql(query_ctx.sql_dialect(), &sql)
            .map_err(BoxedError::new)
            .context(CreateFlowSnafu { sql: &sql })?;

        // A TQL flow must be given an explicit evaluation interval.
        if eval_interval.is_none() && is_tql {
            InvalidQuerySnafu {
                reason: "TQL query requires EVAL INTERVAL to be set".to_string(),
            }
            .fail()?;
        }

        let flow_type = flow_options.get(FlowType::FLOW_TYPE_KEY);

        // Only an unset flow type or the batching type is accepted by this engine.
        ensure!(
            match flow_type {
                None => true,
                Some(ty) if ty == FlowType::BATCHING => true,
                _ => false,
            },
            UnexpectedSnafu {
                reason: format!("Flow type is not batching nor None, got {flow_type:?}")
            }
        );

        let mut source_table_names = Vec::with_capacity(2);
        for src_id in source_table_ids {
            let table_name = get_table_name(self.table_meta.table_info_manager(), &src_id).await?;
            // The table options are also inspected: instant TTL is rejected.
            let table_info = get_table_info(self.table_meta.table_info_manager(), &src_id).await?;
            ensure!(
                table_info.table_info.meta.options.ttl != Some(TimeToLive::Instant),
                UnsupportedSnafu {
                    reason: format!(
                        "Source table `{}`(id={}) has instant TTL, Instant TTL is not supported under batching mode. Consider using a TTL longer than flush interval",
                        table_name.join("."),
                        src_id
                    ),
                }
            );

            source_table_names.push(table_name);
        }

        // Shutdown channel: the sender goes into the registry, the receiver into the task.
        let (tx, rx) = oneshot::channel();

        let plan = sql_to_df_plan(query_ctx.clone(), self.query_engine.clone(), &sql, true).await?;

        if is_tql {
            self.check_is_tql_table(&plan, &query_ctx).await?;
        }

        // A time window expression is only looked for in non-TQL queries.
        let phy_expr = if !is_tql {
            let (column_name, time_window_expr, _, df_schema) = find_time_window_expr(
                &plan,
                self.query_engine.engine_state().catalog_manager().clone(),
                query_ctx.clone(),
            )
            .await?;
            time_window_expr
                .map(|expr| {
                    TimeWindowExpr::from_expr(
                        &expr,
                        &column_name,
                        &df_schema,
                        &self.query_engine.engine_state().session_state(),
                    )
                })
                .transpose()?
        } else {
            None
        };

        debug!(
            "Flow id={}, found time window expr={}",
            flow_id,
            phy_expr
                .as_ref()
                .map(|phy_expr| phy_expr.to_string())
                .unwrap_or("None".to_string())
        );

        let task_args = TaskArgs {
            flow_id,
            query: &sql,
            plan,
            time_window_expr: phy_expr,
            expire_after,
            sink_table_name,
            source_table_names,
            query_ctx,
            catalog_manager: self.catalog_manager.clone(),
            shutdown_rx: rx,
            batch_opts: self.batch_opts.clone(),
            flow_eval_interval: eval_interval.map(|secs| Duration::from_secs(secs as u64)),
        };

        let task = BatchingTask::try_new(task_args)?;

        let task_inner = task.clone();
        let engine = self.query_engine.clone();
        let frontend = self.frontend_client.clone();

        task.check_or_create_sink_table(&engine, &frontend).await?;

        // Start gate: the spawned loop only begins once `start_tx` fires, which
        // happens after the task has been inserted into the registry below.
        let (start_tx, start_rx) = oneshot::channel();

        let handle = common_runtime::spawn_global(async move {
            // If `start_tx` is dropped without sending, the loop is never entered.
            if start_rx.await.is_ok() {
                task_inner.start_executing_loop(engine, frontend).await;
            }
        });
        task.state.write().unwrap().task_handle = Some(handle);
        let task_for_rollback = task.clone();

        let (replaced_old_task_opt, replaced_old_shutdown_tx) = {
            // Final existence check and insertion happen under one write lock.
            let mut runtime = self.runtime.write().await;

            let is_exist = runtime.tasks.contains_key(&flow_id);
            match (create_if_not_exists, or_replace, is_exist) {
                (_, true, true) => {
                    info!(
                        "Replacing flow with id={} after final registry check",
                        flow_id
                    );
                }
                (false, false, true) => {
                    // The freshly built task never got registered: abort its handle.
                    abort_flow_task(flow_id, Some(task), "unregistered");
                    return FlowAlreadyExistSnafu { id: flow_id }.fail();
                }
                (true, false, true) => {
                    info!(
                        "Flow with id={} already exists at final registry check, do nothing",
                        flow_id
                    );
                    abort_flow_task(flow_id, Some(task), "unregistered");
                    return Ok(None);
                }
                (_, _, false) => (),
            }

            runtime.insert(flow_id, task, tx)
        };

        // Stop whatever task was replaced, outside of the registry lock.
        notify_flow_shutdown(flow_id, replaced_old_shutdown_tx, "replaced");
        abort_flow_task(flow_id, replaced_old_task_opt, "replaced");
        if start_tx.send(()).is_err() {
            // The spawned task is already gone: undo the registration, but only
            // if the registry still holds this very task.
            self.rollback_flow_runtime_if_current(flow_id, &task_for_rollback)
                .await;
            UnexpectedSnafu {
                reason: format!("Failed to start flow {flow_id} due to task already dropped"),
            }
            .fail()?;
        }

        Ok(Some(flow_id))
    }

    /// Validates that every table scanned by a TQL flow query has a supported shape.
    ///
    /// For each scanned table (subqueries included) this requires exactly one
    /// `float64` column, and every column other than metric-engine internal
    /// columns must be `float64`, a timestamp, or a string primary-key column.
    ///
    /// # Errors
    ///
    /// Returns an invalid-query error when a table violates these rules, and
    /// other errors when walking the plan or resolving table metadata fails.
    async fn check_is_tql_table(
        &self,
        query: &LogicalPlan,
        query_ctx: &QueryContext,
    ) -> Result<(), Error> {
        /// Plan visitor that records the table reference of every table scan.
        struct CollectTableRef {
            table_refs: HashSet<datafusion_common::TableReference>,
        }

        impl TreeNodeVisitor<'_> for CollectTableRef {
            type Node = LogicalPlan;
            fn f_down(
                &mut self,
                node: &Self::Node,
            ) -> datafusion_common::Result<TreeNodeRecursion> {
                if let LogicalPlan::TableScan(scan) = node {
                    self.table_refs.insert(scan.table_name.clone());
                }
                Ok(TreeNodeRecursion::Continue)
            }
        }
        let mut table_refs = CollectTableRef {
            table_refs: HashSet::new(),
        };
        query
            .visit_with_subqueries(&mut table_refs)
            .context(DatafusionSnafu {
                context: "Checking if all source tables are TQL tables",
            })?;

        let default_catalog = query_ctx.current_catalog();
        let default_schema = query_ctx.current_schema();
        let default_schema = &default_schema;

        for table_ref in table_refs.table_refs {
            // Fill in missing catalog/schema parts from the query context.
            let table_ref = match &table_ref {
                datafusion_common::TableReference::Bare { table } => {
                    TableReference::full(default_catalog, default_schema, table)
                }
                datafusion_common::TableReference::Partial { schema, table } => {
                    TableReference::full(default_catalog, schema, table)
                }
                datafusion_common::TableReference::Full {
                    catalog,
                    schema,
                    table,
                } => TableReference::full(catalog, schema, table),
            };

            let table_id = self
                .table_meta
                .table_name_manager()
                .get(table_ref.into())
                .await
                .map_err(BoxedError::new)
                .context(ExternalSnafu)?
                .with_context(|| UnexpectedSnafu {
                    reason: format!("Failed to get table id for table: {}", table_ref),
                })?
                .table_id();
            let table_info =
                get_table_info(self.table_meta.table_info_manager(), &table_id).await?;
            // Exactly one float64 column is allowed per table.
            let value_cols = table_info
                .table_info
                .meta
                .schema
                .column_schemas()
                .iter()
                .filter(|col| col.data_type == ConcreteDataType::float64_datatype())
                .collect::<Vec<_>>();
            ensure!(
                value_cols.len() == 1,
                InvalidQuerySnafu {
                    reason: format!(
                        "TQL query only supports one f64 value column, table `{}`(id={}) has {} f64 value columns, columns are: {:?}",
                        table_ref,
                        table_id,
                        value_cols.len(),
                        value_cols
                    ),
                }
            );
            // Primary key column indices, used to tell string tags from other strings.
            let pk_idxs = table_info
                .table_info
                .meta
                .primary_key_indices
                .iter()
                .collect::<HashSet<_>>();

            for (idx, col) in table_info
                .table_info
                .meta
                .schema
                .column_schemas()
                .iter()
                .enumerate()
            {
                // Metric-engine internal columns are exempt from the type check.
                if is_metric_engine_internal_column(&col.name) {
                    continue;
                }
                let is_pk: bool = pk_idxs.contains(&&idx);

                ensure!(
                    col.data_type == ConcreteDataType::float64_datatype()
                        || col.data_type.is_timestamp()
                        || (col.data_type == ConcreteDataType::string_datatype() && is_pk),
                    InvalidQuerySnafu {
                        reason: format!(
                            "TQL query only supports f64 value column, timestamp column and string tag columns, table `{}`(id={}) has column `{}` with type {:?} which is not supported",
                            table_ref, table_id, col.name, col.data_type
                        ),
                    }
                );
            }
        }
        Ok(())
    }

    /// Unregisters the flow, signals its shutdown channel and aborts its task handle.
    ///
    /// # Errors
    ///
    /// Returns a flow-not-found error when no task is registered under
    /// `flow_id`, and an unexpected error when the task had no shutdown sender
    /// (the task is still removed and aborted in that case).
    pub async fn remove_flow_inner(&self, flow_id: FlowId) -> Result<(), Error> {
        // Keep the write lock only for the removal itself.
        let (task, shutdown_tx) = {
            let mut runtime = self.runtime.write().await;
            let Some((task, shutdown_tx)) = runtime.remove(flow_id) else {
                warn!("Flow {flow_id} not found in tasks");
                FlowNotFoundSnafu { id: flow_id }.fail()?
            };
            (task, shutdown_tx)
        };

        let had_shutdown_tx = notify_flow_shutdown(flow_id, shutdown_tx, "removed");
        abort_flow_task(flow_id, Some(task), "removed");

        if !had_shutdown_tx {
            UnexpectedSnafu {
                reason: format!("Can't found shutdown tx for flow {flow_id}"),
            }
            .fail()?
        }

        Ok(())
    }

    /// Executes the flow once right now and returns the number of affected rows
    /// (zero when the execution produced no result).
    ///
    /// # Errors
    ///
    /// Returns a flow-not-found error when no task is registered under
    /// `flow_id`, or whatever error the one-off execution reports.
    pub async fn flush_flow_inner(&self, flow_id: FlowId) -> Result<usize, Error> {
        debug!("Try flush flow {flow_id}");
        // NOTE(review): fixed 100ms pause before the task is looked up; presumably
        // it gives in-flight mirrored inserts time to be handled — confirm.
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        let task = self.runtime.read().await.tasks.get(&flow_id).cloned();
        let task = task.with_context(|| FlowNotFoundSnafu { id: flow_id })?;

        let time_window_size = task
            .config
            .time_window_expr
            .as_ref()
            .and_then(|expr| *expr.time_window_size());

        // Only computed when the task has a time window size.
        let cur_dirty_window_cnt = time_window_size.map(|time_window_size| {
            task.state
                .read()
                .unwrap()
                .dirty_time_windows
                .effective_count(&time_window_size)
        });

        let res = task
            .gen_exec_once(
                &self.query_engine,
                &self.frontend_client,
                cur_dirty_window_cnt,
            )
            .await?;

        let affected_rows = res.map(|(r, _)| r).unwrap_or_default();
        debug!(
            "Successfully flush flow {flow_id}, affected rows={}",
            affected_rows
        );
        Ok(affected_rows)
    }

    /// Returns whether a task is registered under `flow_id`.
    pub async fn flow_exist_inner(&self, flow_id: FlowId) -> bool {
        self.runtime.read().await.tasks.contains_key(&flow_id)
    }

    /// Removes `flow_id` from the registry only if the registered task is `task`
    /// itself, then signals shutdown and aborts the handle of what was removed.
    async fn rollback_flow_runtime_if_current(&self, flow_id: FlowId, task: &BatchingTask) {
        // The write lock is released before notifying/aborting.
        let (removed_task, removed_shutdown_tx) = {
            let mut runtime = self.runtime.write().await;
            runtime.remove_if_current(flow_id, task)
        };

        notify_flow_shutdown(flow_id, removed_shutdown_tx, "rolled back");
        abort_flow_task(flow_id, removed_task, "rolled back");
    }
}
841
842fn notify_flow_shutdown(flow_id: FlowId, tx: Option<oneshot::Sender<()>>, action: &str) -> bool {
843 let Some(tx) = tx else {
844 return false;
845 };
846
847 if tx.send(()).is_err() {
848 warn!(
849 "Fail to shutdown {action} flow {flow_id} due to receiver already dropped, maybe flow {flow_id} is already dropped?"
850 );
851 }
852
853 true
854}
855
856fn abort_flow_task(flow_id: FlowId, task: Option<BatchingTask>, action: &str) -> bool {
857 let Some(task) = task else {
858 return false;
859 };
860
861 if let Some(handle) = task.state.write().unwrap().task_handle.take() {
862 handle.abort();
863 debug!("Aborted {action} flow task {flow_id}");
864 return true;
865 }
866
867 false
868}
869
870impl FlowEngine for BatchingEngine {
871 async fn create_flow(&self, args: CreateFlowArgs) -> Result<Option<FlowId>, Error> {
872 self.create_flow_inner(args).await
873 }
874 async fn remove_flow(&self, flow_id: FlowId) -> Result<(), Error> {
875 self.remove_flow_inner(flow_id).await
876 }
877 async fn flush_flow(&self, flow_id: FlowId) -> Result<usize, Error> {
878 self.flush_flow_inner(flow_id).await
879 }
880 async fn flow_exist(&self, flow_id: FlowId) -> Result<bool, Error> {
881 Ok(self.flow_exist_inner(flow_id).await)
882 }
883 async fn list_flows(&self) -> Result<impl IntoIterator<Item = FlowId>, Error> {
884 Ok(self
885 .runtime
886 .read()
887 .await
888 .tasks
889 .keys()
890 .cloned()
891 .collect::<Vec<_>>())
892 }
893 async fn handle_flow_inserts(
894 &self,
895 request: api::v1::region::InsertRequests,
896 ) -> Result<(), Error> {
897 self.handle_inserts_inner(request).await
898 }
899 async fn handle_mark_window_dirty(
900 &self,
901 req: api::v1::flow::DirtyWindowRequests,
902 ) -> Result<(), Error> {
903 self.handle_mark_dirty_time_window(req).await
904 }
905}
906
#[cfg(test)]
mod tests {
    use catalog::memory::new_memory_catalog_manager;
    use common_meta::key::TableMetadataManager;
    use common_meta::key::flow::FlowMetadataManager;
    use common_meta::kv_backend::memory::MemoryKvBackend;
    use query::options::QueryOptions;
    use session::context::QueryContext;

    use super::*;
    use crate::test_utils::create_test_query_engine;

    /// Fires its sender when dropped, so a test can observe that the future
    /// owning it was dropped (e.g. because its task was aborted).
    struct DropNotify(Option<oneshot::Sender<()>>);

    impl Drop for DropNotify {
        fn drop(&mut self) {
            if let Some(tx) = self.0.take() {
                let _ = tx.send(());
            }
        }
    }

    /// Builds an engine backed by in-memory metadata and an empty gRPC handler.
    async fn new_test_engine() -> BatchingEngine {
        let kv_backend = Arc::new(MemoryKvBackend::new());
        let table_meta = Arc::new(TableMetadataManager::new(kv_backend.clone()));
        table_meta.init().await.unwrap();
        let flow_meta = Arc::new(FlowMetadataManager::new(kv_backend));
        let catalog_manager = new_memory_catalog_manager().unwrap();
        let query_engine = create_test_query_engine();
        let (frontend_client, _handler) =
            FrontendClient::from_empty_grpc_handler(QueryOptions::default());

        BatchingEngine::new(
            Arc::new(frontend_client),
            query_engine,
            flow_meta,
            table_meta,
            catalog_manager,
            BatchingModeOptions::default(),
        )
    }

    /// Builds a task for `flow_id` without a time window expression and returns
    /// it together with the sender side of its shutdown channel.
    async fn new_test_task(flow_id: FlowId) -> (BatchingTask, oneshot::Sender<()>) {
        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        let plan = sql_to_df_plan(
            ctx.clone(),
            query_engine.clone(),
            "SELECT number, ts FROM numbers_with_ts",
            true,
        )
        .await
        .unwrap();
        let (tx, rx) = oneshot::channel();

        let task = BatchingTask::try_new(TaskArgs {
            flow_id,
            query: "SELECT number, ts FROM numbers_with_ts",
            plan,
            time_window_expr: None,
            expire_after: None,
            sink_table_name: [
                "greptime".to_string(),
                "public".to_string(),
                "sink".to_string(),
            ],
            source_table_names: vec![[
                "greptime".to_string(),
                "public".to_string(),
                "numbers_with_ts".to_string(),
            ]],
            query_ctx: ctx,
            catalog_manager: query_engine.engine_state().catalog_manager().clone(),
            shutdown_rx: rx,
            batch_opts: Arc::new(BatchingModeOptions::default()),
            flow_eval_interval: None,
        })
        .unwrap();

        (task, tx)
    }

    /// Installs a never-finishing task handle into `task` and waits until it
    /// has started. The returned receiver fires once that future is dropped.
    async fn install_abort_observed_handle(task: &BatchingTask) -> oneshot::Receiver<()> {
        let (drop_tx, drop_rx) = oneshot::channel();
        let (entered_tx, entered_rx) = oneshot::channel();
        let handle = tokio::spawn(async move {
            let _guard = DropNotify(Some(drop_tx));
            let _ = entered_tx.send(());
            std::future::pending::<()>().await;
        });
        task.state.write().unwrap().task_handle = Some(handle);
        tokio::time::timeout(Duration::from_secs(1), entered_rx)
            .await
            .expect("test task handle should start")
            .expect("test task handle should report start");
        drop_rx
    }

    #[tokio::test]
    async fn test_notify_flow_shutdown_sends_signal() {
        let (tx, rx) = oneshot::channel();

        assert!(notify_flow_shutdown(42, Some(tx), "test"));

        rx.await.expect("replaced flow should receive shutdown");
    }

    #[test]
    fn test_notify_flow_shutdown_accepts_missing_sender() {
        assert!(!notify_flow_shutdown(42, None, "test"));
    }

    #[tokio::test]
    async fn test_abort_flow_task_aborts_handle() {
        let (task, _shutdown_tx) = new_test_task(42).await;
        let drop_rx = install_abort_observed_handle(&task).await;

        assert!(abort_flow_task(42, Some(task), "test"));

        tokio::time::timeout(Duration::from_secs(1), drop_rx)
            .await
            .expect("aborted task should be dropped")
            .expect("drop notifier should fire");
    }

    #[tokio::test]
    async fn test_remove_flow_inner_aborts_registered_task() {
        let engine = new_test_engine().await;
        let (task, shutdown_tx) = new_test_task(42).await;
        let drop_rx = install_abort_observed_handle(&task).await;

        engine.runtime.write().await.insert(42, task, shutdown_tx);

        engine.remove_flow_inner(42).await.unwrap();

        tokio::time::timeout(Duration::from_secs(1), drop_rx)
            .await
            .expect("removed task should be dropped")
            .expect("drop notifier should fire");
        assert!(!engine.flow_exist_inner(42).await);
        assert!(!engine.runtime.read().await.shutdown_txs.contains_key(&42));
    }

    #[tokio::test]
    async fn test_or_replace_flow_runtime_replaces_old_handles_and_keeps_new_task() {
        let engine = new_test_engine().await;
        let (old_task, old_shutdown_tx) = new_test_task(42).await;
        let old_task_identity = old_task.clone();
        let old_drop_rx = install_abort_observed_handle(&old_task).await;
        let (new_task, new_shutdown_tx) = new_test_task(42).await;
        let new_task_identity = new_task.clone();

        engine
            .runtime
            .write()
            .await
            .insert(42, old_task, old_shutdown_tx);
        // Inserting under the same id hands back the previously registered pair.
        let (replaced_old_task, replaced_old_shutdown_tx) =
            engine
                .runtime
                .write()
                .await
                .insert(42, new_task, new_shutdown_tx);

        let replaced_old_task = replaced_old_task.expect("old task should be returned");
        assert!(Arc::ptr_eq(
            &replaced_old_task.state,
            &old_task_identity.state
        ));
        assert!(notify_flow_shutdown(
            42,
            replaced_old_shutdown_tx,
            "replaced"
        ));
        old_task_identity
            .state
            .write()
            .unwrap()
            .shutdown_rx
            .try_recv()
            .expect("old shutdown receiver should receive signal");
        assert!(abort_flow_task(42, Some(replaced_old_task), "replaced"));

        tokio::time::timeout(Duration::from_secs(1), old_drop_rx)
            .await
            .expect("replaced task should be dropped")
            .expect("drop notifier should fire");

        let runtime = engine.runtime.read().await;
        assert_eq!(1, runtime.tasks.len());
        assert_eq!(1, runtime.shutdown_txs.len());
        let registered_task = runtime.tasks.get(&42).expect("new task should remain");
        assert!(Arc::ptr_eq(
            &registered_task.state,
            &new_task_identity.state
        ));
        assert!(runtime.shutdown_txs.contains_key(&42));
        // The new task must not have been signalled to shut down.
        assert!(matches!(
            new_task_identity
                .state
                .write()
                .unwrap()
                .shutdown_rx
                .try_recv(),
            Err(oneshot::error::TryRecvError::Empty)
        ));
    }

    #[tokio::test]
    async fn test_rollback_flow_runtime_if_current_removes_matching_task_only() {
        let engine = new_test_engine().await;
        let (old_task, _old_shutdown_tx) = new_test_task(42).await;
        let (current_task, current_shutdown_tx) = new_test_task(42).await;
        let current_task_identity = current_task.clone();

        engine
            .runtime
            .write()
            .await
            .insert(42, current_task, current_shutdown_tx);

        // Rolling back with a task that is not the registered one is a no-op.
        engine.rollback_flow_runtime_if_current(42, &old_task).await;

        let registered_task = engine.runtime.read().await.tasks.get(&42).cloned().unwrap();
        assert!(Arc::ptr_eq(
            &registered_task.state,
            &current_task_identity.state
        ));
        assert!(engine.runtime.read().await.shutdown_txs.contains_key(&42));

        // Rolling back with the registered task removes it.
        engine
            .rollback_flow_runtime_if_current(42, &current_task_identity)
            .await;
        assert!(!engine.flow_exist_inner(42).await);
        assert!(!engine.runtime.read().await.shutdown_txs.contains_key(&42));
    }
}
1143}