Skip to main content

flow/batching_mode/
engine.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Batching mode engine
16
17use std::collections::{BTreeMap, HashMap, HashSet};
18use std::sync::Arc;
19use std::time::Duration;
20
21use api::v1::flow::DirtyWindowRequests;
22use catalog::CatalogManagerRef;
23use common_error::ext::BoxedError;
24use common_meta::ddl::create_flow::FlowType;
25use common_meta::key::TableMetadataManagerRef;
26use common_meta::key::flow::FlowMetadataManagerRef;
27use common_meta::key::flow::flow_state::FlowStat;
28use common_meta::key::table_info::{TableInfoManager, TableInfoValue};
29use common_runtime::JoinHandle;
30use common_telemetry::tracing::warn;
31use common_telemetry::{debug, info};
32use common_time::TimeToLive;
33use datafusion_common::tree_node::{TreeNodeRecursion, TreeNodeVisitor};
34use datafusion_expr::LogicalPlan;
35use datatypes::prelude::ConcreteDataType;
36use query::QueryEngineRef;
37use session::context::QueryContext;
38use snafu::{OptionExt, ResultExt, ensure};
39use sql::parsers::utils::is_tql;
40use store_api::metric_engine_consts::is_metric_engine_internal_column;
41use store_api::storage::{RegionId, TableId};
42use table::table_reference::TableReference;
43use tokio::sync::{RwLock, oneshot};
44
45use crate::batching_mode::BatchingModeOptions;
46use crate::batching_mode::frontend_client::FrontendClient;
47use crate::batching_mode::task::{BatchingTask, TaskArgs};
48use crate::batching_mode::time_window::{TimeWindowExpr, find_time_window_expr};
49use crate::batching_mode::utils::sql_to_df_plan;
50use crate::engine::{FlowEngine, FlowStatProvider};
51use crate::error::{
52    CreateFlowSnafu, DatafusionSnafu, ExternalSnafu, FlowAlreadyExistSnafu, FlowNotFoundSnafu,
53    InvalidQuerySnafu, TableNotFoundMetaSnafu, UnexpectedSnafu, UnsupportedSnafu,
54};
55use crate::metrics::METRIC_FLOW_BATCHING_ENGINE_BULK_MARK_TIME_WINDOW;
56use crate::{CreateFlowArgs, Error, FlowId, TableName};
57
/// Batching mode engine, responsible for driving all the batching mode tasks.
///
/// TODO(discord9): determine how to configure refresh rate
pub struct BatchingEngine {
    /// Registered flow tasks together with their shutdown senders, behind an async lock.
    runtime: RwLock<FlowRuntimeRegistry>,
    /// Frontend client for insert requests; also handed to each task's execution loop.
    pub(crate) frontend_client: Arc<FrontendClient>,
    /// Flow metadata manager.
    // NOTE(review): not read by any method visible in this part of the file — confirm usage.
    flow_metadata_manager: FlowMetadataManagerRef,
    /// Table metadata manager, used to resolve table ids to table info and table names.
    table_meta: TableMetadataManagerRef,
    /// Catalog manager passed to every newly created batching task.
    catalog_manager: CatalogManagerRef,
    /// Query engine used to plan flow queries and to run flow tasks.
    query_engine: QueryEngineRef,
    /// Batching mode options that control how batching mode queries work.
    pub(crate) batch_opts: Arc<BatchingModeOptions>,
}
73
/// In-memory registry of the batching tasks known to the engine, keyed by flow id.
///
/// Every registered task is paired with the oneshot sender used to signal its shutdown.
#[derive(Default)]
struct FlowRuntimeRegistry {
    /// Registered batching tasks, keyed by flow id.
    tasks: BTreeMap<FlowId, BatchingTask>,
    /// Shutdown senders of the registered tasks, keyed by flow id.
    shutdown_txs: BTreeMap<FlowId, oneshot::Sender<()>>,
}
79
80impl FlowRuntimeRegistry {
81    fn insert(
82        &mut self,
83        flow_id: FlowId,
84        task: BatchingTask,
85        shutdown_tx: oneshot::Sender<()>,
86    ) -> (Option<BatchingTask>, Option<oneshot::Sender<()>>) {
87        (
88            self.tasks.insert(flow_id, task),
89            self.shutdown_txs.insert(flow_id, shutdown_tx),
90        )
91    }
92
93    fn remove(&mut self, flow_id: FlowId) -> Option<(BatchingTask, Option<oneshot::Sender<()>>)> {
94        let task = self.tasks.remove(&flow_id)?;
95        let shutdown_tx = self.shutdown_txs.remove(&flow_id);
96        Some((task, shutdown_tx))
97    }
98
99    fn remove_if_current(
100        &mut self,
101        flow_id: FlowId,
102        task: &BatchingTask,
103    ) -> (Option<BatchingTask>, Option<oneshot::Sender<()>>) {
104        if self
105            .tasks
106            .get(&flow_id)
107            .is_some_and(|current| Arc::ptr_eq(&current.state, &task.state))
108        {
109            let Some((removed_task, removed_shutdown_tx)) = self.remove(flow_id) else {
110                return (None, None);
111            };
112            (Some(removed_task), removed_shutdown_tx)
113        } else {
114            (None, None)
115        }
116    }
117}
118
119impl BatchingEngine {
120    pub fn new(
121        frontend_client: Arc<FrontendClient>,
122        query_engine: QueryEngineRef,
123        flow_metadata_manager: FlowMetadataManagerRef,
124        table_meta: TableMetadataManagerRef,
125        catalog_manager: CatalogManagerRef,
126        batch_opts: BatchingModeOptions,
127    ) -> Self {
128        Self {
129            runtime: Default::default(),
130            frontend_client,
131            flow_metadata_manager,
132            table_meta,
133            catalog_manager,
134            query_engine,
135            batch_opts: Arc::new(batch_opts),
136        }
137    }
138
139    /// Returns last execution timestamps (millisecond) for all batching flows.
140    pub async fn get_last_exec_time_map(&self) -> BTreeMap<FlowId, i64> {
141        let runtime = self.runtime.read().await;
142        runtime
143            .tasks
144            .iter()
145            .filter_map(|(flow_id, task)| {
146                task.last_execution_time_millis()
147                    .map(|timestamp| (*flow_id, timestamp))
148            })
149            .collect()
150    }
151
152    pub async fn handle_mark_dirty_time_window(
153        &self,
154        reqs: DirtyWindowRequests,
155    ) -> Result<(), Error> {
156        let table_info_mgr = self.table_meta.table_info_manager();
157
158        let mut group_by_table_id: HashMap<u32, Vec<_>> = HashMap::new();
159        for r in reqs.requests {
160            let tid = TableId::from(r.table_id);
161            let entry = group_by_table_id.entry(tid).or_default();
162            entry.extend(r.timestamps);
163        }
164        let tids = group_by_table_id.keys().cloned().collect::<Vec<TableId>>();
165        let table_infos =
166            table_info_mgr
167                .batch_get(&tids)
168                .await
169                .with_context(|_| TableNotFoundMetaSnafu {
170                    msg: format!("Failed to get table info for table ids: {:?}", tids),
171                })?;
172
173        let group_by_table_name = group_by_table_id
174            .into_iter()
175            .filter_map(|(id, timestamps)| {
176                let table_name = table_infos.get(&id).map(|info| info.table_name());
177                let Some(table_name) = table_name else {
178                    warn!("Failed to get table infos for table id: {:?}", id);
179                    return None;
180                };
181                let table_name = [
182                    table_name.catalog_name,
183                    table_name.schema_name,
184                    table_name.table_name,
185                ];
186                let schema = &table_infos.get(&id).unwrap().table_info.meta.schema;
187                let time_index_unit = schema.column_schemas()[schema.timestamp_index().unwrap()]
188                    .data_type
189                    .as_timestamp()
190                    .unwrap()
191                    .unit();
192                Some((table_name, (timestamps, time_index_unit)))
193            })
194            .collect::<HashMap<_, _>>();
195
196        let group_by_table_name = Arc::new(group_by_table_name);
197
198        let tasks = self
199            .runtime
200            .read()
201            .await
202            .tasks
203            .values()
204            .cloned()
205            .collect::<Vec<_>>();
206        let mut handles = Vec::new();
207
208        for task in tasks {
209            let src_table_names = &task.config.source_table_names;
210
211            if src_table_names
212                .iter()
213                .all(|name| !group_by_table_name.contains_key(name))
214            {
215                continue;
216            }
217
218            let group_by_table_name = group_by_table_name.clone();
219            let task = task.clone();
220
221            let handle: JoinHandle<Result<(), Error>> = tokio::spawn(async move {
222                let src_table_names = &task.config.source_table_names;
223                let mut all_dirty_windows = HashSet::new();
224                let mut is_dirty = false;
225                for src_table_name in src_table_names {
226                    if let Some((timestamps, unit)) = group_by_table_name.get(src_table_name) {
227                        let Some(expr) = &task.config.time_window_expr else {
228                            is_dirty = true;
229                            continue;
230                        };
231                        for timestamp in timestamps {
232                            let align_start = expr
233                                .eval(common_time::Timestamp::new(*timestamp, *unit))?
234                                .0
235                                .context(UnexpectedSnafu {
236                                    reason: "Failed to eval start value",
237                                })?;
238                            all_dirty_windows.insert(align_start);
239                        }
240                    }
241                }
242                let mut state = task.state.write().unwrap();
243                if is_dirty {
244                    state.dirty_time_windows.set_dirty();
245                }
246                let flow_id_label = task.config.flow_id.to_string();
247                for timestamp in all_dirty_windows {
248                    state.dirty_time_windows.add_window(timestamp, None);
249                }
250
251                METRIC_FLOW_BATCHING_ENGINE_BULK_MARK_TIME_WINDOW
252                    .with_label_values(&[&flow_id_label])
253                    .set(state.dirty_time_windows.len() as f64);
254                Ok(())
255            });
256            handles.push(handle);
257        }
258        for handle in handles {
259            match handle.await {
260                Err(e) => {
261                    warn!("Failed to handle inserts: {e}");
262                }
263                Ok(Ok(())) => (),
264                Ok(Err(e)) => {
265                    warn!("Failed to handle inserts: {e}");
266                }
267            }
268        }
269
270        Ok(())
271    }
272
273    pub async fn handle_inserts_inner(
274        &self,
275        request: api::v1::region::InsertRequests,
276    ) -> Result<(), Error> {
277        let table_info_mgr = self.table_meta.table_info_manager();
278        let mut group_by_table_id: HashMap<TableId, Vec<api::v1::Rows>> = HashMap::new();
279
280        for r in request.requests {
281            let tid = RegionId::from(r.region_id).table_id();
282            let entry = group_by_table_id.entry(tid).or_default();
283            if let Some(rows) = r.rows {
284                entry.push(rows);
285            }
286        }
287
288        let tids = group_by_table_id.keys().cloned().collect::<Vec<TableId>>();
289        let table_infos =
290            table_info_mgr
291                .batch_get(&tids)
292                .await
293                .with_context(|_| TableNotFoundMetaSnafu {
294                    msg: format!("Failed to get table info for table ids: {:?}", tids),
295                })?;
296
297        let missing_tids = tids
298            .iter()
299            .filter(|id| !table_infos.contains_key(id))
300            .collect::<Vec<_>>();
301        if !missing_tids.is_empty() {
302            warn!(
303                "Failed to get all the table info for table ids, expected table ids: {:?}, those table doesn't exist: {:?}",
304                tids, missing_tids
305            );
306        }
307
308        let group_by_table_name = group_by_table_id
309            .into_iter()
310            .filter_map(|(id, rows)| {
311                let table_name = table_infos.get(&id).map(|info| info.table_name());
312                let Some(table_name) = table_name else {
313                    warn!("Failed to get table infos for table id: {:?}", id);
314                    return None;
315                };
316                let table_name = [
317                    table_name.catalog_name,
318                    table_name.schema_name,
319                    table_name.table_name,
320                ];
321                Some((table_name, rows))
322            })
323            .collect::<HashMap<_, _>>();
324
325        let group_by_table_name = Arc::new(group_by_table_name);
326
327        let tasks = self
328            .runtime
329            .read()
330            .await
331            .tasks
332            .values()
333            .cloned()
334            .collect::<Vec<_>>();
335        let mut handles = Vec::new();
336        for task in tasks {
337            let src_table_names = &task.config.source_table_names;
338
339            if src_table_names
340                .iter()
341                .all(|name| !group_by_table_name.contains_key(name))
342            {
343                continue;
344            }
345
346            let group_by_table_name = group_by_table_name.clone();
347            let task = task.clone();
348
349            let handle: JoinHandle<Result<(), Error>> = tokio::spawn(async move {
350                let src_table_names = &task.config.source_table_names;
351
352                let mut is_dirty = false;
353
354                for src_table_name in src_table_names {
355                    if let Some(entry) = group_by_table_name.get(src_table_name) {
356                        let Some(expr) = &task.config.time_window_expr else {
357                            is_dirty = true;
358                            continue;
359                        };
360                        let involved_time_windows = expr.handle_rows(entry.clone()).await?;
361                        let mut state = task.state.write().unwrap();
362                        state
363                            .dirty_time_windows
364                            .add_lower_bounds(involved_time_windows.into_iter());
365                    }
366                }
367                if is_dirty {
368                    task.state.write().unwrap().dirty_time_windows.set_dirty();
369                }
370
371                Ok(())
372            });
373            handles.push(handle);
374        }
375
376        for handle in handles {
377            match handle.await {
378                Err(e) => {
379                    warn!("Failed to handle inserts: {e}");
380                }
381                Ok(Ok(())) => (),
382                Ok(Err(e)) => {
383                    warn!("Failed to handle inserts: {e}");
384                }
385            }
386        }
387        Ok(())
388    }
389}
390
391impl FlowStatProvider for BatchingEngine {
392    async fn flow_stat(&self) -> FlowStat {
393        FlowStat {
394            state_size: BTreeMap::new(),
395            last_exec_time_map: self
396                .get_last_exec_time_map()
397                .await
398                .into_iter()
399                .map(|(flow_id, timestamp)| (flow_id as u32, timestamp))
400                .collect(),
401        }
402    }
403}
404
405async fn get_table_name(
406    table_info: &TableInfoManager,
407    table_id: &TableId,
408) -> Result<TableName, Error> {
409    get_table_info(table_info, table_id).await.map(|info| {
410        let name = info.table_name();
411        [name.catalog_name, name.schema_name, name.table_name]
412    })
413}
414
415async fn get_table_info(
416    table_info: &TableInfoManager,
417    table_id: &TableId,
418) -> Result<TableInfoValue, Error> {
419    table_info
420        .get(*table_id)
421        .await
422        .map_err(BoxedError::new)
423        .context(ExternalSnafu)?
424        .with_context(|| UnexpectedSnafu {
425            reason: format!("Table id = {:?}, couldn't found table name", table_id),
426        })
427        .map(|info| info.into_inner())
428}
429
430impl BatchingEngine {
    /// Creates (or replaces) a batching flow and registers its task in the engine.
    ///
    /// Returns `Ok(Some(flow_id))` once the task is registered and its execution loop has
    /// been signalled to start, and `Ok(None)` when the flow already exists and
    /// `create_if_not_exists` is set without `or_replace`.
    ///
    /// # Errors
    ///
    /// Fails when:
    /// - the flow already exists and neither `create_if_not_exists` nor `or_replace` is set;
    /// - the query context is missing;
    /// - the query is TQL but no eval interval is given;
    /// - the flow type option is present and is not the batching type;
    /// - a source table cannot be resolved or has an instant TTL;
    /// - planning the query, extracting the time window expression, building the task or
    ///   checking/creating the sink table fails;
    /// - the start signal cannot be delivered to the spawned task.
    pub async fn create_flow_inner(&self, args: CreateFlowArgs) -> Result<Option<FlowId>, Error> {
        let CreateFlowArgs {
            flow_id,
            sink_table_name,
            source_table_ids,
            create_if_not_exists,
            or_replace,
            expire_after,
            eval_interval,
            comment: _,
            sql,
            flow_options,
            query_ctx,
        } = args;

        // or replace logic
        // This is an early check under a read lock only; the decision is re-validated under
        // the write lock right before the task is registered below.
        {
            let is_exist = self.runtime.read().await.tasks.contains_key(&flow_id);
            match (create_if_not_exists, or_replace, is_exist) {
                // if replace, ignore that old flow exists
                (_, true, true) => {
                    info!("Replacing flow with id={}", flow_id);
                }
                (false, false, true) => FlowAlreadyExistSnafu { id: flow_id }.fail()?,
                // already exists, and not replace, return None
                (true, false, true) => {
                    info!("Flow with id={} already exists, do nothing", flow_id);
                    return Ok(None);
                }

                // continue as normal
                (_, _, false) => (),
            }
        }

        let query_ctx = query_ctx.context({
            UnexpectedSnafu {
                reason: "Query context is None".to_string(),
            }
        })?;
        let query_ctx = Arc::new(query_ctx);
        let is_tql = is_tql(query_ctx.sql_dialect(), &sql)
            .map_err(BoxedError::new)
            .context(CreateFlowSnafu { sql: &sql })?;

        // The eval interval is optional in general, but mandatory for TQL queries.
        if eval_interval.is_none() && is_tql {
            InvalidQuerySnafu {
                reason: "TQL query requires EVAL INTERVAL to be set".to_string(),
            }
            .fail()?;
        }

        let flow_type = flow_options.get(FlowType::FLOW_TYPE_KEY);

        // Only an unspecified flow type or the batching flow type is accepted here.
        ensure!(
            match flow_type {
                None => true,
                Some(ty) if ty == FlowType::BATCHING => true,
                _ => false,
            },
            UnexpectedSnafu {
                reason: format!("Flow type is not batching nor None, got {flow_type:?}")
            }
        );

        let mut source_table_names = Vec::with_capacity(2);
        for src_id in source_table_ids {
            // also check table option to see if ttl!=instant
            // NOTE(review): `get_table_name` already goes through `get_table_info`, so the
            // same table info is fetched twice per source table here — consider deriving the
            // name from `table_info` instead.
            let table_name = get_table_name(self.table_meta.table_info_manager(), &src_id).await?;
            let table_info = get_table_info(self.table_meta.table_info_manager(), &src_id).await?;
            ensure!(
                table_info.table_info.meta.options.ttl != Some(TimeToLive::Instant),
                UnsupportedSnafu {
                    reason: format!(
                        "Source table `{}`(id={}) has instant TTL, Instant TTL is not supported under batching mode. Consider using a TTL longer than flush interval",
                        table_name.join("."),
                        src_id
                    ),
                }
            );

            source_table_names.push(table_name);
        }

        // Shutdown channel: `rx` goes into the task, `tx` is stored in the registry.
        let (tx, rx) = oneshot::channel();

        let plan = sql_to_df_plan(query_ctx.clone(), self.query_engine.clone(), &sql, true).await?;

        if is_tql {
            self.check_is_tql_table(&plan, &query_ctx).await?;
        }

        let phy_expr = if !is_tql {
            let (column_name, time_window_expr, _, df_schema) = find_time_window_expr(
                &plan,
                self.query_engine.engine_state().catalog_manager().clone(),
                query_ctx.clone(),
            )
            .await?;
            time_window_expr
                .map(|expr| {
                    TimeWindowExpr::from_expr(
                        &expr,
                        &column_name,
                        &df_schema,
                        &self.query_engine.engine_state().session_state(),
                    )
                })
                .transpose()?
        } else {
            // tql control by `EVAL INTERVAL`, no need to find time window expr
            None
        };

        debug!(
            "Flow id={}, found time window expr={}",
            flow_id,
            phy_expr
                .as_ref()
                .map(|phy_expr| phy_expr.to_string())
                .unwrap_or("None".to_string())
        );

        let task_args = TaskArgs {
            flow_id,
            query: &sql,
            plan,
            time_window_expr: phy_expr,
            expire_after,
            sink_table_name,
            source_table_names,
            query_ctx,
            catalog_manager: self.catalog_manager.clone(),
            shutdown_rx: rx,
            batch_opts: self.batch_opts.clone(),
            flow_eval_interval: eval_interval.map(|secs| Duration::from_secs(secs as u64)),
        };

        let task = BatchingTask::try_new(task_args)?;

        let task_inner = task.clone();
        let engine = self.query_engine.clone();
        let frontend = self.frontend_client.clone();

        // check execute once first to detect any error early
        task.check_or_create_sink_table(&engine, &frontend).await?;

        // Start gate: the spawned task only enters its executing loop after `start_tx` fires,
        // which happens once the task has been registered below.
        let (start_tx, start_rx) = oneshot::channel();

        // TODO(discord9): use time wheel or what for better
        let handle = common_runtime::spawn_global(async move {
            if start_rx.await.is_ok() {
                task_inner.start_executing_loop(engine, frontend).await;
            }
        });
        task.state.write().unwrap().task_handle = Some(handle);
        // Kept so the registration can be undone if the start signal cannot be delivered.
        let task_for_rollback = task.clone();

        // Only replace here, not earlier, because we want the old one intact if
        // something went wrong before this line. Keep the task and shutdown
        // sender in one registry lock so create/remove can't observe one
        // without the other.
        let (replaced_old_task_opt, replaced_old_shutdown_tx) = {
            let mut runtime = self.runtime.write().await;

            // Re-check under the write lock: the registry may have changed since the early
            // check at the top of this function.
            let is_exist = runtime.tasks.contains_key(&flow_id);
            match (create_if_not_exists, or_replace, is_exist) {
                (_, true, true) => {
                    info!(
                        "Replacing flow with id={} after final registry check",
                        flow_id
                    );
                }
                (false, false, true) => {
                    // The freshly spawned task was never registered; abort it before failing.
                    abort_flow_task(flow_id, Some(task), "unregistered");
                    return FlowAlreadyExistSnafu { id: flow_id }.fail();
                }
                (true, false, true) => {
                    info!(
                        "Flow with id={} already exists at final registry check, do nothing",
                        flow_id
                    );
                    abort_flow_task(flow_id, Some(task), "unregistered");
                    return Ok(None);
                }
                (_, _, false) => (),
            }

            runtime.insert(flow_id, task, tx)
        };

        // Shut down whatever task previously occupied this flow id, if any.
        notify_flow_shutdown(flow_id, replaced_old_shutdown_tx, "replaced");
        abort_flow_task(flow_id, replaced_old_task_opt, "replaced");
        // Release the start gate; if the receiver is gone, undo the registration.
        if start_tx.send(()).is_err() {
            self.rollback_flow_runtime_if_current(flow_id, &task_for_rollback)
                .await;
            UnexpectedSnafu {
                reason: format!("Failed to start flow {flow_id} due to task already dropped"),
            }
            .fail()?;
        }

        Ok(Some(flow_id))
    }
636
637    async fn check_is_tql_table(
638        &self,
639        query: &LogicalPlan,
640        query_ctx: &QueryContext,
641    ) -> Result<(), Error> {
642        struct CollectTableRef {
643            table_refs: HashSet<datafusion_common::TableReference>,
644        }
645
646        impl TreeNodeVisitor<'_> for CollectTableRef {
647            type Node = LogicalPlan;
648            fn f_down(
649                &mut self,
650                node: &Self::Node,
651            ) -> datafusion_common::Result<TreeNodeRecursion> {
652                if let LogicalPlan::TableScan(scan) = node {
653                    self.table_refs.insert(scan.table_name.clone());
654                }
655                Ok(TreeNodeRecursion::Continue)
656            }
657        }
658        let mut table_refs = CollectTableRef {
659            table_refs: HashSet::new(),
660        };
661        query
662            .visit_with_subqueries(&mut table_refs)
663            .context(DatafusionSnafu {
664                context: "Checking if all source tables are TQL tables",
665            })?;
666
667        let default_catalog = query_ctx.current_catalog();
668        let default_schema = query_ctx.current_schema();
669        let default_schema = &default_schema;
670
671        for table_ref in table_refs.table_refs {
672            let table_ref = match &table_ref {
673                datafusion_common::TableReference::Bare { table } => {
674                    TableReference::full(default_catalog, default_schema, table)
675                }
676                datafusion_common::TableReference::Partial { schema, table } => {
677                    TableReference::full(default_catalog, schema, table)
678                }
679                datafusion_common::TableReference::Full {
680                    catalog,
681                    schema,
682                    table,
683                } => TableReference::full(catalog, schema, table),
684            };
685
686            let table_id = self
687                .table_meta
688                .table_name_manager()
689                .get(table_ref.into())
690                .await
691                .map_err(BoxedError::new)
692                .context(ExternalSnafu)?
693                .with_context(|| UnexpectedSnafu {
694                    reason: format!("Failed to get table id for table: {}", table_ref),
695                })?
696                .table_id();
697            let table_info =
698                get_table_info(self.table_meta.table_info_manager(), &table_id).await?;
699            // first check if it's only one f64 value column
700            let value_cols = table_info
701                .table_info
702                .meta
703                .schema
704                .column_schemas()
705                .iter()
706                .filter(|col| col.data_type == ConcreteDataType::float64_datatype())
707                .collect::<Vec<_>>();
708            ensure!(
709                value_cols.len() == 1,
710                InvalidQuerySnafu {
711                    reason: format!(
712                        "TQL query only supports one f64 value column, table `{}`(id={}) has {} f64 value columns, columns are: {:?}",
713                        table_ref,
714                        table_id,
715                        value_cols.len(),
716                        value_cols
717                    ),
718                }
719            );
720            // TODO(discord9): do need to check rest columns is string and is tag column?
721            let pk_idxs = table_info
722                .table_info
723                .meta
724                .primary_key_indices
725                .iter()
726                .collect::<HashSet<_>>();
727
728            for (idx, col) in table_info
729                .table_info
730                .meta
731                .schema
732                .column_schemas()
733                .iter()
734                .enumerate()
735            {
736                if is_metric_engine_internal_column(&col.name) {
737                    continue;
738                }
739                // three cases:
740                // 1. val column
741                // 2. timestamp column
742                // 3. tag column (string)
743
744                let is_pk: bool = pk_idxs.contains(&&idx);
745
746                ensure!(
747                    col.data_type == ConcreteDataType::float64_datatype()
748                        || col.data_type.is_timestamp()
749                        || (col.data_type == ConcreteDataType::string_datatype() && is_pk),
750                    InvalidQuerySnafu {
751                        reason: format!(
752                            "TQL query only supports f64 value column, timestamp column and string tag columns, table `{}`(id={}) has column `{}` with type {:?} which is not supported",
753                            table_ref, table_id, col.name, col.data_type
754                        ),
755                    }
756                );
757            }
758        }
759        Ok(())
760    }
761
762    pub async fn remove_flow_inner(&self, flow_id: FlowId) -> Result<(), Error> {
763        let (task, shutdown_tx) = {
764            let mut runtime = self.runtime.write().await;
765            let Some((task, shutdown_tx)) = runtime.remove(flow_id) else {
766                warn!("Flow {flow_id} not found in tasks");
767                FlowNotFoundSnafu { id: flow_id }.fail()?
768            };
769            (task, shutdown_tx)
770        };
771
772        let had_shutdown_tx = notify_flow_shutdown(flow_id, shutdown_tx, "removed");
773        abort_flow_task(flow_id, Some(task), "removed");
774
775        if !had_shutdown_tx {
776            UnexpectedSnafu {
777                reason: format!("Can't found shutdown tx for flow {flow_id}"),
778            }
779            .fail()?
780        }
781
782        Ok(())
783    }
784
785    /// Only flush the dirty windows of the flow task with given flow id, by running the query on it.
786    /// As flush the whole time range is usually prohibitively expensive.
787    pub async fn flush_flow_inner(&self, flow_id: FlowId) -> Result<usize, Error> {
788        debug!("Try flush flow {flow_id}");
789        // need to wait a bit to ensure previous mirror insert is handled
790        // this is only useful for the case when we are flushing the flow right after inserting data into it
791        // TODO(discord9): find a better way to ensure the data is ready, maybe inform flownode from frontend?
792        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
793        let task = self.runtime.read().await.tasks.get(&flow_id).cloned();
794        let task = task.with_context(|| FlowNotFoundSnafu { id: flow_id })?;
795
796        let time_window_size = task
797            .config
798            .time_window_expr
799            .as_ref()
800            .and_then(|expr| *expr.time_window_size());
801
802        let cur_dirty_window_cnt = time_window_size.map(|time_window_size| {
803            task.state
804                .read()
805                .unwrap()
806                .dirty_time_windows
807                .effective_count(&time_window_size)
808        });
809
810        let res = task
811            .gen_exec_once(
812                &self.query_engine,
813                &self.frontend_client,
814                cur_dirty_window_cnt,
815            )
816            .await?;
817
818        let affected_rows = res.map(|(r, _)| r).unwrap_or_default();
819        debug!(
820            "Successfully flush flow {flow_id}, affected rows={}",
821            affected_rows
822        );
823        Ok(affected_rows)
824    }
825
826    /// Determine if the batching mode flow task exists with given flow id
827    pub async fn flow_exist_inner(&self, flow_id: FlowId) -> bool {
828        self.runtime.read().await.tasks.contains_key(&flow_id)
829    }
830
831    async fn rollback_flow_runtime_if_current(&self, flow_id: FlowId, task: &BatchingTask) {
832        let (removed_task, removed_shutdown_tx) = {
833            let mut runtime = self.runtime.write().await;
834            runtime.remove_if_current(flow_id, task)
835        };
836
837        notify_flow_shutdown(flow_id, removed_shutdown_tx, "rolled back");
838        abort_flow_task(flow_id, removed_task, "rolled back");
839    }
840}
841
842fn notify_flow_shutdown(flow_id: FlowId, tx: Option<oneshot::Sender<()>>, action: &str) -> bool {
843    let Some(tx) = tx else {
844        return false;
845    };
846
847    if tx.send(()).is_err() {
848        warn!(
849            "Fail to shutdown {action} flow {flow_id} due to receiver already dropped, maybe flow {flow_id} is already dropped?"
850        );
851    }
852
853    true
854}
855
856fn abort_flow_task(flow_id: FlowId, task: Option<BatchingTask>, action: &str) -> bool {
857    let Some(task) = task else {
858        return false;
859    };
860
861    if let Some(handle) = task.state.write().unwrap().task_handle.take() {
862        handle.abort();
863        debug!("Aborted {action} flow task {flow_id}");
864        return true;
865    }
866
867    false
868}
869
870impl FlowEngine for BatchingEngine {
871    async fn create_flow(&self, args: CreateFlowArgs) -> Result<Option<FlowId>, Error> {
872        self.create_flow_inner(args).await
873    }
874    async fn remove_flow(&self, flow_id: FlowId) -> Result<(), Error> {
875        self.remove_flow_inner(flow_id).await
876    }
877    async fn flush_flow(&self, flow_id: FlowId) -> Result<usize, Error> {
878        self.flush_flow_inner(flow_id).await
879    }
880    async fn flow_exist(&self, flow_id: FlowId) -> Result<bool, Error> {
881        Ok(self.flow_exist_inner(flow_id).await)
882    }
883    async fn list_flows(&self) -> Result<impl IntoIterator<Item = FlowId>, Error> {
884        Ok(self
885            .runtime
886            .read()
887            .await
888            .tasks
889            .keys()
890            .cloned()
891            .collect::<Vec<_>>())
892    }
893    async fn handle_flow_inserts(
894        &self,
895        request: api::v1::region::InsertRequests,
896    ) -> Result<(), Error> {
897        self.handle_inserts_inner(request).await
898    }
899    async fn handle_mark_window_dirty(
900        &self,
901        req: api::v1::flow::DirtyWindowRequests,
902    ) -> Result<(), Error> {
903        self.handle_mark_dirty_time_window(req).await
904    }
905}
906
907#[cfg(test)]
908mod tests {
909    use catalog::memory::new_memory_catalog_manager;
910    use common_meta::key::TableMetadataManager;
911    use common_meta::key::flow::FlowMetadataManager;
912    use common_meta::kv_backend::memory::MemoryKvBackend;
913    use query::options::QueryOptions;
914    use session::context::QueryContext;
915
916    use super::*;
917    use crate::test_utils::create_test_query_engine;
918
919    struct DropNotify(Option<oneshot::Sender<()>>);
920
921    impl Drop for DropNotify {
922        fn drop(&mut self) {
923            if let Some(tx) = self.0.take() {
924                let _ = tx.send(());
925            }
926        }
927    }
928
929    async fn new_test_engine() -> BatchingEngine {
930        let kv_backend = Arc::new(MemoryKvBackend::new());
931        let table_meta = Arc::new(TableMetadataManager::new(kv_backend.clone()));
932        table_meta.init().await.unwrap();
933        let flow_meta = Arc::new(FlowMetadataManager::new(kv_backend));
934        let catalog_manager = new_memory_catalog_manager().unwrap();
935        let query_engine = create_test_query_engine();
936        let (frontend_client, _handler) =
937            FrontendClient::from_empty_grpc_handler(QueryOptions::default());
938
939        BatchingEngine::new(
940            Arc::new(frontend_client),
941            query_engine,
942            flow_meta,
943            table_meta,
944            catalog_manager,
945            BatchingModeOptions::default(),
946        )
947    }
948
949    async fn new_test_task(flow_id: FlowId) -> (BatchingTask, oneshot::Sender<()>) {
950        let query_engine = create_test_query_engine();
951        let ctx = QueryContext::arc();
952        let plan = sql_to_df_plan(
953            ctx.clone(),
954            query_engine.clone(),
955            "SELECT number, ts FROM numbers_with_ts",
956            true,
957        )
958        .await
959        .unwrap();
960        let (tx, rx) = oneshot::channel();
961
962        let task = BatchingTask::try_new(TaskArgs {
963            flow_id,
964            query: "SELECT number, ts FROM numbers_with_ts",
965            plan,
966            time_window_expr: None,
967            expire_after: None,
968            sink_table_name: [
969                "greptime".to_string(),
970                "public".to_string(),
971                "sink".to_string(),
972            ],
973            source_table_names: vec![[
974                "greptime".to_string(),
975                "public".to_string(),
976                "numbers_with_ts".to_string(),
977            ]],
978            query_ctx: ctx,
979            catalog_manager: query_engine.engine_state().catalog_manager().clone(),
980            shutdown_rx: rx,
981            batch_opts: Arc::new(BatchingModeOptions::default()),
982            flow_eval_interval: None,
983        })
984        .unwrap();
985
986        (task, tx)
987    }
988
989    async fn install_abort_observed_handle(task: &BatchingTask) -> oneshot::Receiver<()> {
990        let (drop_tx, drop_rx) = oneshot::channel();
991        let (entered_tx, entered_rx) = oneshot::channel();
992        let handle = tokio::spawn(async move {
993            let _guard = DropNotify(Some(drop_tx));
994            let _ = entered_tx.send(());
995            std::future::pending::<()>().await;
996        });
997        task.state.write().unwrap().task_handle = Some(handle);
998        tokio::time::timeout(Duration::from_secs(1), entered_rx)
999            .await
1000            .expect("test task handle should start")
1001            .expect("test task handle should report start");
1002        drop_rx
1003    }
1004
1005    #[tokio::test]
1006    async fn test_notify_flow_shutdown_sends_signal() {
1007        let (tx, rx) = oneshot::channel();
1008
1009        assert!(notify_flow_shutdown(42, Some(tx), "test"));
1010
1011        rx.await.expect("replaced flow should receive shutdown");
1012    }
1013
1014    #[test]
1015    fn test_notify_flow_shutdown_accepts_missing_sender() {
1016        assert!(!notify_flow_shutdown(42, None, "test"));
1017    }
1018
1019    #[tokio::test]
1020    async fn test_abort_flow_task_aborts_handle() {
1021        let (task, _shutdown_tx) = new_test_task(42).await;
1022        let drop_rx = install_abort_observed_handle(&task).await;
1023
1024        assert!(abort_flow_task(42, Some(task), "test"));
1025
1026        tokio::time::timeout(Duration::from_secs(1), drop_rx)
1027            .await
1028            .expect("aborted task should be dropped")
1029            .expect("drop notifier should fire");
1030    }
1031
1032    #[tokio::test]
1033    async fn test_remove_flow_inner_aborts_registered_task() {
1034        let engine = new_test_engine().await;
1035        let (task, shutdown_tx) = new_test_task(42).await;
1036        let drop_rx = install_abort_observed_handle(&task).await;
1037
1038        engine.runtime.write().await.insert(42, task, shutdown_tx);
1039
1040        engine.remove_flow_inner(42).await.unwrap();
1041
1042        tokio::time::timeout(Duration::from_secs(1), drop_rx)
1043            .await
1044            .expect("removed task should be dropped")
1045            .expect("drop notifier should fire");
1046        assert!(!engine.flow_exist_inner(42).await);
1047        assert!(!engine.runtime.read().await.shutdown_txs.contains_key(&42));
1048    }
1049
1050    #[tokio::test]
1051    async fn test_or_replace_flow_runtime_replaces_old_handles_and_keeps_new_task() {
1052        let engine = new_test_engine().await;
1053        let (old_task, old_shutdown_tx) = new_test_task(42).await;
1054        let old_task_identity = old_task.clone();
1055        let old_drop_rx = install_abort_observed_handle(&old_task).await;
1056        let (new_task, new_shutdown_tx) = new_test_task(42).await;
1057        let new_task_identity = new_task.clone();
1058
1059        engine
1060            .runtime
1061            .write()
1062            .await
1063            .insert(42, old_task, old_shutdown_tx);
1064        let (replaced_old_task, replaced_old_shutdown_tx) =
1065            engine
1066                .runtime
1067                .write()
1068                .await
1069                .insert(42, new_task, new_shutdown_tx);
1070
1071        let replaced_old_task = replaced_old_task.expect("old task should be returned");
1072        assert!(Arc::ptr_eq(
1073            &replaced_old_task.state,
1074            &old_task_identity.state
1075        ));
1076        assert!(notify_flow_shutdown(
1077            42,
1078            replaced_old_shutdown_tx,
1079            "replaced"
1080        ));
1081        old_task_identity
1082            .state
1083            .write()
1084            .unwrap()
1085            .shutdown_rx
1086            .try_recv()
1087            .expect("old shutdown receiver should receive signal");
1088        assert!(abort_flow_task(42, Some(replaced_old_task), "replaced"));
1089
1090        tokio::time::timeout(Duration::from_secs(1), old_drop_rx)
1091            .await
1092            .expect("replaced task should be dropped")
1093            .expect("drop notifier should fire");
1094
1095        let runtime = engine.runtime.read().await;
1096        assert_eq!(1, runtime.tasks.len());
1097        assert_eq!(1, runtime.shutdown_txs.len());
1098        let registered_task = runtime.tasks.get(&42).expect("new task should remain");
1099        assert!(Arc::ptr_eq(
1100            &registered_task.state,
1101            &new_task_identity.state
1102        ));
1103        assert!(runtime.shutdown_txs.contains_key(&42));
1104        assert!(matches!(
1105            new_task_identity
1106                .state
1107                .write()
1108                .unwrap()
1109                .shutdown_rx
1110                .try_recv(),
1111            Err(oneshot::error::TryRecvError::Empty)
1112        ));
1113    }
1114
1115    #[tokio::test]
1116    async fn test_rollback_flow_runtime_if_current_removes_matching_task_only() {
1117        let engine = new_test_engine().await;
1118        let (old_task, _old_shutdown_tx) = new_test_task(42).await;
1119        let (current_task, current_shutdown_tx) = new_test_task(42).await;
1120        let current_task_identity = current_task.clone();
1121
1122        engine
1123            .runtime
1124            .write()
1125            .await
1126            .insert(42, current_task, current_shutdown_tx);
1127
1128        engine.rollback_flow_runtime_if_current(42, &old_task).await;
1129
1130        let registered_task = engine.runtime.read().await.tasks.get(&42).cloned().unwrap();
1131        assert!(Arc::ptr_eq(
1132            &registered_task.state,
1133            &current_task_identity.state
1134        ));
1135        assert!(engine.runtime.read().await.shutdown_txs.contains_key(&42));
1136
1137        engine
1138            .rollback_flow_runtime_if_current(42, &current_task_identity)
1139            .await;
1140        assert!(!engine.flow_exist_inner(42).await);
1141        assert!(!engine.runtime.read().await.shutdown_txs.contains_key(&42));
1142    }
1143}