Skip to main content

flow/batching_mode/
engine.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Batching mode engine
16
17use std::collections::{BTreeMap, HashMap, HashSet};
18use std::sync::Arc;
19use std::time::Duration;
20
21use api::v1::flow::DirtyWindowRequests;
22use catalog::CatalogManagerRef;
23use common_error::ext::BoxedError;
24use common_meta::ddl::create_flow::{FLOW_EXPERIMENTAL_ENABLE_INCREMENTAL_READ_KEY, FlowType};
25use common_meta::key::TableMetadataManagerRef;
26use common_meta::key::flow::FlowMetadataManagerRef;
27use common_meta::key::flow::flow_state::FlowStat;
28use common_meta::key::table_info::{TableInfoManager, TableInfoValue};
29use common_runtime::JoinHandle;
30use common_telemetry::tracing::warn;
31use common_telemetry::{debug, info};
32use common_time::TimeToLive;
33use datafusion_common::tree_node::{TreeNodeRecursion, TreeNodeVisitor};
34use datafusion_expr::LogicalPlan;
35use datatypes::prelude::ConcreteDataType;
36use query::QueryEngineRef;
37use session::context::QueryContext;
38use snafu::{OptionExt, ResultExt, ensure};
39use sql::parsers::utils::is_tql;
40use store_api::metric_engine_consts::is_metric_engine_internal_column;
41use store_api::mito_engine_options::APPEND_MODE_KEY;
42use store_api::storage::{RegionId, TableId};
43use table::table_reference::TableReference;
44use tokio::sync::{RwLock, oneshot};
45
46use crate::batching_mode::BatchingModeOptions;
47use crate::batching_mode::frontend_client::FrontendClient;
48use crate::batching_mode::task::{BatchingTask, TaskArgs};
49use crate::batching_mode::time_window::{TimeWindowExpr, find_time_window_expr};
50use crate::batching_mode::utils::sql_to_df_plan;
51use crate::engine::{FlowEngine, FlowStatProvider};
52use crate::error::{
53    CreateFlowSnafu, DatafusionSnafu, ExternalSnafu, FlowAlreadyExistSnafu, FlowNotFoundSnafu,
54    InvalidQuerySnafu, TableNotFoundMetaSnafu, UnexpectedSnafu, UnsupportedSnafu,
55};
56use crate::metrics::METRIC_FLOW_BATCHING_ENGINE_BULK_MARK_TIME_WINDOW;
57use crate::{CreateFlowArgs, Error, FlowId, TableName};
58
59/// Batching mode Engine, responsible for driving all the batching mode tasks
60///
61/// TODO(discord9): determine how to configure refresh rate
62pub struct BatchingEngine {
63    runtime: RwLock<FlowRuntimeRegistry>,
64    /// frontend client for insert request
65    pub(crate) frontend_client: Arc<FrontendClient>,
66    flow_metadata_manager: FlowMetadataManagerRef,
67    table_meta: TableMetadataManagerRef,
68    catalog_manager: CatalogManagerRef,
69    query_engine: QueryEngineRef,
70    /// Batching mode options for control how batching mode query works
71    ///
72    pub(crate) batch_opts: Arc<BatchingModeOptions>,
73}
74
75#[derive(Default)]
76struct FlowRuntimeRegistry {
77    tasks: BTreeMap<FlowId, BatchingTask>,
78    shutdown_txs: BTreeMap<FlowId, oneshot::Sender<()>>,
79}
80
81impl FlowRuntimeRegistry {
82    fn insert(
83        &mut self,
84        flow_id: FlowId,
85        task: BatchingTask,
86        shutdown_tx: oneshot::Sender<()>,
87    ) -> (Option<BatchingTask>, Option<oneshot::Sender<()>>) {
88        (
89            self.tasks.insert(flow_id, task),
90            self.shutdown_txs.insert(flow_id, shutdown_tx),
91        )
92    }
93
94    fn remove(&mut self, flow_id: FlowId) -> Option<(BatchingTask, Option<oneshot::Sender<()>>)> {
95        let task = self.tasks.remove(&flow_id)?;
96        let shutdown_tx = self.shutdown_txs.remove(&flow_id);
97        Some((task, shutdown_tx))
98    }
99
100    fn remove_if_current(
101        &mut self,
102        flow_id: FlowId,
103        task: &BatchingTask,
104    ) -> (Option<BatchingTask>, Option<oneshot::Sender<()>>) {
105        if self
106            .tasks
107            .get(&flow_id)
108            .is_some_and(|current| Arc::ptr_eq(&current.state, &task.state))
109        {
110            let Some((removed_task, removed_shutdown_tx)) = self.remove(flow_id) else {
111                return (None, None);
112            };
113            (Some(removed_task), removed_shutdown_tx)
114        } else {
115            (None, None)
116        }
117    }
118}
119
120impl BatchingEngine {
121    pub fn new(
122        frontend_client: Arc<FrontendClient>,
123        query_engine: QueryEngineRef,
124        flow_metadata_manager: FlowMetadataManagerRef,
125        table_meta: TableMetadataManagerRef,
126        catalog_manager: CatalogManagerRef,
127        batch_opts: BatchingModeOptions,
128    ) -> Self {
129        Self {
130            runtime: Default::default(),
131            frontend_client,
132            flow_metadata_manager,
133            table_meta,
134            catalog_manager,
135            query_engine,
136            batch_opts: Arc::new(batch_opts),
137        }
138    }
139
140    /// Returns last execution timestamps (millisecond) for all batching flows.
141    pub async fn get_last_exec_time_map(&self) -> BTreeMap<FlowId, i64> {
142        let runtime = self.runtime.read().await;
143        runtime
144            .tasks
145            .iter()
146            .filter_map(|(flow_id, task)| {
147                task.last_execution_time_millis()
148                    .map(|timestamp| (*flow_id, timestamp))
149            })
150            .collect()
151    }
152
153    pub async fn handle_mark_dirty_time_window(
154        &self,
155        reqs: DirtyWindowRequests,
156    ) -> Result<(), Error> {
157        let table_info_mgr = self.table_meta.table_info_manager();
158
159        let mut group_by_table_id: HashMap<u32, Vec<_>> = HashMap::new();
160        for r in reqs.requests {
161            let tid = TableId::from(r.table_id);
162            let entry = group_by_table_id.entry(tid).or_default();
163            entry.extend(r.timestamps);
164        }
165        let tids = group_by_table_id.keys().cloned().collect::<Vec<TableId>>();
166        let table_infos =
167            table_info_mgr
168                .batch_get(&tids)
169                .await
170                .with_context(|_| TableNotFoundMetaSnafu {
171                    msg: format!("Failed to get table info for table ids: {:?}", tids),
172                })?;
173
174        let group_by_table_name = group_by_table_id
175            .into_iter()
176            .filter_map(|(id, timestamps)| {
177                let table_name = table_infos.get(&id).map(|info| info.table_name());
178                let Some(table_name) = table_name else {
179                    warn!("Failed to get table infos for table id: {:?}", id);
180                    return None;
181                };
182                let table_name = [
183                    table_name.catalog_name,
184                    table_name.schema_name,
185                    table_name.table_name,
186                ];
187                let schema = &table_infos.get(&id).unwrap().table_info.meta.schema;
188                let time_index_unit = schema.column_schemas()[schema.timestamp_index().unwrap()]
189                    .data_type
190                    .as_timestamp()
191                    .unwrap()
192                    .unit();
193                Some((table_name, (timestamps, time_index_unit)))
194            })
195            .collect::<HashMap<_, _>>();
196
197        let group_by_table_name = Arc::new(group_by_table_name);
198
199        let tasks = self
200            .runtime
201            .read()
202            .await
203            .tasks
204            .values()
205            .cloned()
206            .collect::<Vec<_>>();
207        let mut handles = Vec::new();
208
209        for task in tasks {
210            let src_table_names = &task.config.source_table_names;
211
212            if src_table_names
213                .iter()
214                .all(|name| !group_by_table_name.contains_key(name))
215            {
216                continue;
217            }
218
219            let group_by_table_name = group_by_table_name.clone();
220            let task = task.clone();
221
222            let handle: JoinHandle<Result<(), Error>> = tokio::spawn(async move {
223                let src_table_names = &task.config.source_table_names;
224                let mut all_dirty_windows = HashSet::new();
225                let mut is_dirty = false;
226                for src_table_name in src_table_names {
227                    if let Some((timestamps, unit)) = group_by_table_name.get(src_table_name) {
228                        let Some(expr) = &task.config.time_window_expr else {
229                            is_dirty = true;
230                            continue;
231                        };
232                        for timestamp in timestamps {
233                            let align_start = expr
234                                .eval(common_time::Timestamp::new(*timestamp, *unit))?
235                                .0
236                                .context(UnexpectedSnafu {
237                                    reason: "Failed to eval start value",
238                                })?;
239                            all_dirty_windows.insert(align_start);
240                        }
241                    }
242                }
243                let mut state = task.state.write().unwrap();
244                if is_dirty {
245                    state.dirty_time_windows.set_dirty();
246                }
247                let flow_id_label = task.config.flow_id.to_string();
248                for timestamp in all_dirty_windows {
249                    state.dirty_time_windows.add_window(timestamp, None);
250                }
251
252                METRIC_FLOW_BATCHING_ENGINE_BULK_MARK_TIME_WINDOW
253                    .with_label_values(&[&flow_id_label])
254                    .set(state.dirty_time_windows.len() as f64);
255                Ok(())
256            });
257            handles.push(handle);
258        }
259        for handle in handles {
260            match handle.await {
261                Err(e) => {
262                    warn!("Failed to handle inserts: {e}");
263                }
264                Ok(Ok(())) => (),
265                Ok(Err(e)) => {
266                    warn!("Failed to handle inserts: {e}");
267                }
268            }
269        }
270
271        Ok(())
272    }
273
274    pub async fn handle_inserts_inner(
275        &self,
276        request: api::v1::region::InsertRequests,
277    ) -> Result<(), Error> {
278        let table_info_mgr = self.table_meta.table_info_manager();
279        let mut group_by_table_id: HashMap<TableId, Vec<api::v1::Rows>> = HashMap::new();
280
281        for r in request.requests {
282            let tid = RegionId::from(r.region_id).table_id();
283            let entry = group_by_table_id.entry(tid).or_default();
284            if let Some(rows) = r.rows {
285                entry.push(rows);
286            }
287        }
288
289        let tids = group_by_table_id.keys().cloned().collect::<Vec<TableId>>();
290        let table_infos =
291            table_info_mgr
292                .batch_get(&tids)
293                .await
294                .with_context(|_| TableNotFoundMetaSnafu {
295                    msg: format!("Failed to get table info for table ids: {:?}", tids),
296                })?;
297
298        let missing_tids = tids
299            .iter()
300            .filter(|id| !table_infos.contains_key(id))
301            .collect::<Vec<_>>();
302        if !missing_tids.is_empty() {
303            warn!(
304                "Failed to get all the table info for table ids, expected table ids: {:?}, those table doesn't exist: {:?}",
305                tids, missing_tids
306            );
307        }
308
309        let group_by_table_name = group_by_table_id
310            .into_iter()
311            .filter_map(|(id, rows)| {
312                let table_name = table_infos.get(&id).map(|info| info.table_name());
313                let Some(table_name) = table_name else {
314                    warn!("Failed to get table infos for table id: {:?}", id);
315                    return None;
316                };
317                let table_name = [
318                    table_name.catalog_name,
319                    table_name.schema_name,
320                    table_name.table_name,
321                ];
322                Some((table_name, rows))
323            })
324            .collect::<HashMap<_, _>>();
325
326        let group_by_table_name = Arc::new(group_by_table_name);
327
328        let tasks = self
329            .runtime
330            .read()
331            .await
332            .tasks
333            .values()
334            .cloned()
335            .collect::<Vec<_>>();
336        let mut handles = Vec::new();
337        for task in tasks {
338            let src_table_names = &task.config.source_table_names;
339
340            if src_table_names
341                .iter()
342                .all(|name| !group_by_table_name.contains_key(name))
343            {
344                continue;
345            }
346
347            let group_by_table_name = group_by_table_name.clone();
348            let task = task.clone();
349
350            let handle: JoinHandle<Result<(), Error>> = tokio::spawn(async move {
351                let src_table_names = &task.config.source_table_names;
352
353                let mut is_dirty = false;
354
355                for src_table_name in src_table_names {
356                    if let Some(entry) = group_by_table_name.get(src_table_name) {
357                        let Some(expr) = &task.config.time_window_expr else {
358                            is_dirty = true;
359                            continue;
360                        };
361                        let involved_time_windows = expr.handle_rows(entry.clone()).await?;
362                        let mut state = task.state.write().unwrap();
363                        state
364                            .dirty_time_windows
365                            .add_lower_bounds(involved_time_windows.into_iter());
366                    }
367                }
368                if is_dirty {
369                    task.state.write().unwrap().dirty_time_windows.set_dirty();
370                }
371
372                Ok(())
373            });
374            handles.push(handle);
375        }
376
377        for handle in handles {
378            match handle.await {
379                Err(e) => {
380                    warn!("Failed to handle inserts: {e}");
381                }
382                Ok(Ok(())) => (),
383                Ok(Err(e)) => {
384                    warn!("Failed to handle inserts: {e}");
385                }
386            }
387        }
388        Ok(())
389    }
390}
391
392impl FlowStatProvider for BatchingEngine {
393    async fn flow_stat(&self) -> FlowStat {
394        FlowStat {
395            state_size: BTreeMap::new(),
396            last_exec_time_map: self
397                .get_last_exec_time_map()
398                .await
399                .into_iter()
400                .map(|(flow_id, timestamp)| (flow_id as u32, timestamp))
401                .collect(),
402        }
403    }
404}
405
406async fn get_table_name(
407    table_info: &TableInfoManager,
408    table_id: &TableId,
409) -> Result<TableName, Error> {
410    get_table_info(table_info, table_id).await.map(|info| {
411        let name = info.table_name();
412        [name.catalog_name, name.schema_name, name.table_name]
413    })
414}
415
416async fn get_table_info(
417    table_info: &TableInfoManager,
418    table_id: &TableId,
419) -> Result<TableInfoValue, Error> {
420    table_info
421        .get(*table_id)
422        .await
423        .map_err(BoxedError::new)
424        .context(ExternalSnafu)?
425        .with_context(|| UnexpectedSnafu {
426            reason: format!("Table id = {:?}, couldn't found table name", table_id),
427        })
428        .map(|info| info.into_inner())
429}
430
431impl BatchingEngine {
432    fn batch_opts_for_flow_options(
433        &self,
434        flow_options: &HashMap<String, String>,
435    ) -> Result<Arc<BatchingModeOptions>, Error> {
436        let mut batch_opts = (*self.batch_opts).clone();
437        if let Some(enable_incremental_read) =
438            flow_options.get(FLOW_EXPERIMENTAL_ENABLE_INCREMENTAL_READ_KEY)
439        {
440            batch_opts.experimental_enable_incremental_read = enable_incremental_read
441                .parse::<bool>()
442                .map_err(|_| {
443                    InvalidQuerySnafu {
444                        reason: format!(
445                            "Invalid flow option {FLOW_EXPERIMENTAL_ENABLE_INCREMENTAL_READ_KEY}: {enable_incremental_read}"
446                        ),
447                    }
448                    .build()
449                })?;
450        }
451
452        Ok(Arc::new(batch_opts))
453    }
454
455    fn table_options_enable_append_mode(extra_options: &HashMap<String, String>) -> bool {
456        extra_options
457            .get(APPEND_MODE_KEY)
458            .is_some_and(|value| value.eq_ignore_ascii_case("true"))
459    }
460
461    fn ensure_incremental_source_append_only(
462        batch_opts: &BatchingModeOptions,
463        table_name: &[String; 3],
464        extra_options: &HashMap<String, String>,
465    ) -> Result<(), Error> {
466        if batch_opts.experimental_enable_incremental_read {
467            ensure!(
468                Self::table_options_enable_append_mode(extra_options),
469                UnsupportedSnafu {
470                    reason: format!(
471                        "Flow incremental read requires append-only source table, but source table `{}` is not append-only. Consider setting append_mode='true' on the source table or disabling experimental_enable_incremental_read",
472                        table_name.join(".")
473                    ),
474                }
475            );
476        }
477
478        Ok(())
479    }
480
481    pub async fn create_flow_inner(&self, args: CreateFlowArgs) -> Result<Option<FlowId>, Error> {
482        let CreateFlowArgs {
483            flow_id,
484            sink_table_name,
485            source_table_ids,
486            create_if_not_exists,
487            or_replace,
488            expire_after,
489            eval_interval,
490            comment: _,
491            sql,
492            flow_options,
493            query_ctx,
494        } = args;
495
496        // or replace logic
497        {
498            let is_exist = self.runtime.read().await.tasks.contains_key(&flow_id);
499            match (create_if_not_exists, or_replace, is_exist) {
500                // if replace, ignore that old flow exists
501                (_, true, true) => {
502                    info!("Replacing flow with id={}", flow_id);
503                }
504                (false, false, true) => FlowAlreadyExistSnafu { id: flow_id }.fail()?,
505                // already exists, and not replace, return None
506                (true, false, true) => {
507                    info!("Flow with id={} already exists, do nothing", flow_id);
508                    return Ok(None);
509                }
510
511                // continue as normal
512                (_, _, false) => (),
513            }
514        }
515
516        let query_ctx = query_ctx.context({
517            UnexpectedSnafu {
518                reason: "Query context is None".to_string(),
519            }
520        })?;
521        let query_ctx = Arc::new(query_ctx);
522        let is_tql = is_tql(query_ctx.sql_dialect(), &sql)
523            .map_err(BoxedError::new)
524            .context(CreateFlowSnafu { sql: &sql })?;
525
526        // optionally set a eval interval for the flow
527        if eval_interval.is_none() && is_tql {
528            InvalidQuerySnafu {
529                reason: "TQL query requires EVAL INTERVAL to be set".to_string(),
530            }
531            .fail()?;
532        }
533
534        let flow_type = flow_options.get(FlowType::FLOW_TYPE_KEY);
535
536        ensure!(
537            match flow_type {
538                None => true,
539                Some(ty) if ty == FlowType::BATCHING => true,
540                _ => false,
541            },
542            UnexpectedSnafu {
543                reason: format!("Flow type is not batching nor None, got {flow_type:?}")
544            }
545        );
546
547        let batch_opts = self.batch_opts_for_flow_options(&flow_options)?;
548
549        let mut source_table_names = Vec::with_capacity(2);
550        for src_id in source_table_ids {
551            // also check table option to see if ttl!=instant
552            let table_name = get_table_name(self.table_meta.table_info_manager(), &src_id).await?;
553            let table_info = get_table_info(self.table_meta.table_info_manager(), &src_id).await?;
554            ensure!(
555                table_info.table_info.meta.options.ttl != Some(TimeToLive::Instant),
556                UnsupportedSnafu {
557                    reason: format!(
558                        "Source table `{}`(id={}) has instant TTL, Instant TTL is not supported under batching mode. Consider using a TTL longer than flush interval",
559                        table_name.join("."),
560                        src_id
561                    ),
562                }
563            );
564            Self::ensure_incremental_source_append_only(
565                &batch_opts,
566                &table_name,
567                &table_info.table_info.meta.options.extra_options,
568            )?;
569
570            source_table_names.push(table_name);
571        }
572
573        let (tx, rx) = oneshot::channel();
574
575        let plan = sql_to_df_plan(query_ctx.clone(), self.query_engine.clone(), &sql, true).await?;
576
577        if is_tql {
578            self.check_is_tql_table(&plan, &query_ctx).await?;
579        }
580
581        let phy_expr = if !is_tql {
582            let (column_name, time_window_expr, _, df_schema) = find_time_window_expr(
583                &plan,
584                self.query_engine.engine_state().catalog_manager().clone(),
585                query_ctx.clone(),
586            )
587            .await?;
588            time_window_expr
589                .map(|expr| {
590                    TimeWindowExpr::from_expr(
591                        &expr,
592                        &column_name,
593                        &df_schema,
594                        &self.query_engine.engine_state().session_state(),
595                    )
596                })
597                .transpose()?
598        } else {
599            // tql control by `EVAL INTERVAL`, no need to find time window expr
600            None
601        };
602
603        debug!(
604            "Flow id={}, found time window expr={}",
605            flow_id,
606            phy_expr
607                .as_ref()
608                .map(|phy_expr| phy_expr.to_string())
609                .unwrap_or("None".to_string())
610        );
611
612        let task_args = TaskArgs {
613            flow_id,
614            query: &sql,
615            plan,
616            time_window_expr: phy_expr,
617            expire_after,
618            sink_table_name,
619            source_table_names,
620            query_ctx,
621            catalog_manager: self.catalog_manager.clone(),
622            shutdown_rx: rx,
623            batch_opts,
624            flow_eval_interval: eval_interval.map(|secs| Duration::from_secs(secs as u64)),
625        };
626
627        let task = BatchingTask::try_new(task_args)?;
628
629        let task_inner = task.clone();
630        let engine = self.query_engine.clone();
631        let frontend = self.frontend_client.clone();
632
633        // check execute once first to detect any error early
634        task.check_or_create_sink_table(&engine, &frontend).await?;
635
636        let (start_tx, start_rx) = oneshot::channel();
637
638        // TODO(discord9): use time wheel or what for better
639        let handle = common_runtime::spawn_global(async move {
640            if start_rx.await.is_ok() {
641                task_inner.start_executing_loop(engine, frontend).await;
642            }
643        });
644        task.state.write().unwrap().task_handle = Some(handle);
645        let task_for_rollback = task.clone();
646
647        // Only replace here, not earlier, because we want the old one intact if
648        // something went wrong before this line. Keep the task and shutdown
649        // sender in one registry lock so create/remove can't observe one
650        // without the other.
651        let (replaced_old_task_opt, replaced_old_shutdown_tx) = {
652            let mut runtime = self.runtime.write().await;
653
654            let is_exist = runtime.tasks.contains_key(&flow_id);
655            match (create_if_not_exists, or_replace, is_exist) {
656                (_, true, true) => {
657                    info!(
658                        "Replacing flow with id={} after final registry check",
659                        flow_id
660                    );
661                }
662                (false, false, true) => {
663                    abort_flow_task(flow_id, Some(task), "unregistered");
664                    return FlowAlreadyExistSnafu { id: flow_id }.fail();
665                }
666                (true, false, true) => {
667                    info!(
668                        "Flow with id={} already exists at final registry check, do nothing",
669                        flow_id
670                    );
671                    abort_flow_task(flow_id, Some(task), "unregistered");
672                    return Ok(None);
673                }
674                (_, _, false) => (),
675            }
676
677            runtime.insert(flow_id, task, tx)
678        };
679
680        notify_flow_shutdown(flow_id, replaced_old_shutdown_tx, "replaced");
681        abort_flow_task(flow_id, replaced_old_task_opt, "replaced");
682        if start_tx.send(()).is_err() {
683            self.rollback_flow_runtime_if_current(flow_id, &task_for_rollback)
684                .await;
685            UnexpectedSnafu {
686                reason: format!("Failed to start flow {flow_id} due to task already dropped"),
687            }
688            .fail()?;
689        }
690
691        Ok(Some(flow_id))
692    }
693
694    async fn check_is_tql_table(
695        &self,
696        query: &LogicalPlan,
697        query_ctx: &QueryContext,
698    ) -> Result<(), Error> {
699        struct CollectTableRef {
700            table_refs: HashSet<datafusion_common::TableReference>,
701        }
702
703        impl TreeNodeVisitor<'_> for CollectTableRef {
704            type Node = LogicalPlan;
705            fn f_down(
706                &mut self,
707                node: &Self::Node,
708            ) -> datafusion_common::Result<TreeNodeRecursion> {
709                if let LogicalPlan::TableScan(scan) = node {
710                    self.table_refs.insert(scan.table_name.clone());
711                }
712                Ok(TreeNodeRecursion::Continue)
713            }
714        }
715        let mut table_refs = CollectTableRef {
716            table_refs: HashSet::new(),
717        };
718        query
719            .visit_with_subqueries(&mut table_refs)
720            .context(DatafusionSnafu {
721                context: "Checking if all source tables are TQL tables",
722            })?;
723
724        let default_catalog = query_ctx.current_catalog();
725        let default_schema = query_ctx.current_schema();
726        let default_schema = &default_schema;
727
728        for table_ref in table_refs.table_refs {
729            let table_ref = match &table_ref {
730                datafusion_common::TableReference::Bare { table } => {
731                    TableReference::full(default_catalog, default_schema, table)
732                }
733                datafusion_common::TableReference::Partial { schema, table } => {
734                    TableReference::full(default_catalog, schema, table)
735                }
736                datafusion_common::TableReference::Full {
737                    catalog,
738                    schema,
739                    table,
740                } => TableReference::full(catalog, schema, table),
741            };
742
743            let table_id = self
744                .table_meta
745                .table_name_manager()
746                .get(table_ref.into())
747                .await
748                .map_err(BoxedError::new)
749                .context(ExternalSnafu)?
750                .with_context(|| UnexpectedSnafu {
751                    reason: format!("Failed to get table id for table: {}", table_ref),
752                })?
753                .table_id();
754            let table_info =
755                get_table_info(self.table_meta.table_info_manager(), &table_id).await?;
756            // first check if it's only one f64 value column
757            let value_cols = table_info
758                .table_info
759                .meta
760                .schema
761                .column_schemas()
762                .iter()
763                .filter(|col| col.data_type == ConcreteDataType::float64_datatype())
764                .collect::<Vec<_>>();
765            ensure!(
766                value_cols.len() == 1,
767                InvalidQuerySnafu {
768                    reason: format!(
769                        "TQL query only supports one f64 value column, table `{}`(id={}) has {} f64 value columns, columns are: {:?}",
770                        table_ref,
771                        table_id,
772                        value_cols.len(),
773                        value_cols
774                    ),
775                }
776            );
777            // TODO(discord9): do need to check rest columns is string and is tag column?
778            let pk_idxs = table_info
779                .table_info
780                .meta
781                .primary_key_indices
782                .iter()
783                .collect::<HashSet<_>>();
784
785            for (idx, col) in table_info
786                .table_info
787                .meta
788                .schema
789                .column_schemas()
790                .iter()
791                .enumerate()
792            {
793                if is_metric_engine_internal_column(&col.name) {
794                    continue;
795                }
796                // three cases:
797                // 1. val column
798                // 2. timestamp column
799                // 3. tag column (string)
800
801                let is_pk: bool = pk_idxs.contains(&&idx);
802
803                ensure!(
804                    col.data_type == ConcreteDataType::float64_datatype()
805                        || col.data_type.is_timestamp()
806                        || (col.data_type == ConcreteDataType::string_datatype() && is_pk),
807                    InvalidQuerySnafu {
808                        reason: format!(
809                            "TQL query only supports f64 value column, timestamp column and string tag columns, table `{}`(id={}) has column `{}` with type {:?} which is not supported",
810                            table_ref, table_id, col.name, col.data_type
811                        ),
812                    }
813                );
814            }
815        }
816        Ok(())
817    }
818
819    pub async fn remove_flow_inner(&self, flow_id: FlowId) -> Result<(), Error> {
820        let (task, shutdown_tx) = {
821            let mut runtime = self.runtime.write().await;
822            let Some((task, shutdown_tx)) = runtime.remove(flow_id) else {
823                warn!("Flow {flow_id} not found in tasks");
824                FlowNotFoundSnafu { id: flow_id }.fail()?
825            };
826            (task, shutdown_tx)
827        };
828
829        let had_shutdown_tx = notify_flow_shutdown(flow_id, shutdown_tx, "removed");
830        abort_flow_task(flow_id, Some(task), "removed");
831
832        if !had_shutdown_tx {
833            UnexpectedSnafu {
834                reason: format!("Can't found shutdown tx for flow {flow_id}"),
835            }
836            .fail()?
837        }
838
839        Ok(())
840    }
841
842    /// Only flush the dirty windows of the flow task with given flow id, by running the query on it.
843    /// As flush the whole time range is usually prohibitively expensive.
844    pub async fn flush_flow_inner(&self, flow_id: FlowId) -> Result<usize, Error> {
845        debug!("Try flush flow {flow_id}");
846        // need to wait a bit to ensure previous mirror insert is handled
847        // this is only useful for the case when we are flushing the flow right after inserting data into it
848        // TODO(discord9): find a better way to ensure the data is ready, maybe inform flownode from frontend?
849        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
850        let task = self.runtime.read().await.tasks.get(&flow_id).cloned();
851        let task = task.with_context(|| FlowNotFoundSnafu { id: flow_id })?;
852
853        let time_window_size = task
854            .config
855            .time_window_expr
856            .as_ref()
857            .and_then(|expr| *expr.time_window_size());
858
859        let cur_dirty_window_cnt = time_window_size.map(|time_window_size| {
860            task.state
861                .read()
862                .unwrap()
863                .dirty_time_windows
864                .effective_count(&time_window_size)
865        });
866
867        let res = task
868            .execute_once_serialized(
869                &self.query_engine,
870                &self.frontend_client,
871                cur_dirty_window_cnt,
872            )
873            .await?;
874
875        let affected_rows = res.map(|(r, _)| r).unwrap_or_default();
876        debug!(
877            "Successfully flush flow {flow_id}, affected rows={}",
878            affected_rows
879        );
880        Ok(affected_rows)
881    }
882
883    /// Determine if the batching mode flow task exists with given flow id
884    pub async fn flow_exist_inner(&self, flow_id: FlowId) -> bool {
885        self.runtime.read().await.tasks.contains_key(&flow_id)
886    }
887
888    async fn rollback_flow_runtime_if_current(&self, flow_id: FlowId, task: &BatchingTask) {
889        let (removed_task, removed_shutdown_tx) = {
890            let mut runtime = self.runtime.write().await;
891            runtime.remove_if_current(flow_id, task)
892        };
893
894        notify_flow_shutdown(flow_id, removed_shutdown_tx, "rolled back");
895        abort_flow_task(flow_id, removed_task, "rolled back");
896    }
897}
898
899fn notify_flow_shutdown(flow_id: FlowId, tx: Option<oneshot::Sender<()>>, action: &str) -> bool {
900    let Some(tx) = tx else {
901        return false;
902    };
903
904    if tx.send(()).is_err() {
905        warn!(
906            "Fail to shutdown {action} flow {flow_id} due to receiver already dropped, maybe flow {flow_id} is already dropped?"
907        );
908    }
909
910    true
911}
912
913fn abort_flow_task(flow_id: FlowId, task: Option<BatchingTask>, action: &str) -> bool {
914    let Some(task) = task else {
915        return false;
916    };
917
918    if let Some(handle) = task.state.write().unwrap().task_handle.take() {
919        handle.abort();
920        debug!("Aborted {action} flow task {flow_id}");
921        return true;
922    }
923
924    false
925}
926
927impl FlowEngine for BatchingEngine {
928    async fn create_flow(&self, args: CreateFlowArgs) -> Result<Option<FlowId>, Error> {
929        self.create_flow_inner(args).await
930    }
931    async fn remove_flow(&self, flow_id: FlowId) -> Result<(), Error> {
932        self.remove_flow_inner(flow_id).await
933    }
934    async fn flush_flow(&self, flow_id: FlowId) -> Result<usize, Error> {
935        self.flush_flow_inner(flow_id).await
936    }
937    async fn flow_exist(&self, flow_id: FlowId) -> Result<bool, Error> {
938        Ok(self.flow_exist_inner(flow_id).await)
939    }
940    async fn list_flows(&self) -> Result<impl IntoIterator<Item = FlowId>, Error> {
941        Ok(self
942            .runtime
943            .read()
944            .await
945            .tasks
946            .keys()
947            .cloned()
948            .collect::<Vec<_>>())
949    }
950    async fn handle_flow_inserts(
951        &self,
952        request: api::v1::region::InsertRequests,
953    ) -> Result<(), Error> {
954        self.handle_inserts_inner(request).await
955    }
956    async fn handle_mark_window_dirty(
957        &self,
958        req: api::v1::flow::DirtyWindowRequests,
959    ) -> Result<(), Error> {
960        self.handle_mark_dirty_time_window(req).await
961    }
962}
963
// Unit tests for the batching engine: flow-option handling, append-only source checks, and
// the runtime bookkeeping helpers (shutdown signalling, task abort, remove and rollback).
#[cfg(test)]
mod tests {
    use catalog::memory::new_memory_catalog_manager;
    use common_meta::key::TableMetadataManager;
    use common_meta::key::flow::FlowMetadataManager;
    use common_meta::kv_backend::memory::MemoryKvBackend;
    use query::options::QueryOptions;
    use session::context::QueryContext;

    use super::*;
    use crate::test_utils::create_test_query_engine;

    /// Drop guard that sends `()` on its sender when it is dropped.
    ///
    /// Held inside a spawned future so a test can observe that the future was dropped,
    /// e.g. because its task handle was aborted.
    struct DropNotify(Option<oneshot::Sender<()>>);

    impl Drop for DropNotify {
        fn drop(&mut self) {
            // A missing receiver is not an error for a test notifier, so the send result is ignored.
            if let Some(tx) = self.0.take() {
                let _ = tx.send(());
            }
        }
    }

    /// Builds a `BatchingEngine` from in-memory components only: a memory kv backend for the
    /// table and flow metadata managers, a memory catalog, the test query engine, and a
    /// frontend client created from an empty gRPC handler. Options are the defaults.
    async fn new_test_engine() -> BatchingEngine {
        let kv_backend = Arc::new(MemoryKvBackend::new());
        let table_meta = Arc::new(TableMetadataManager::new(kv_backend.clone()));
        table_meta.init().await.unwrap();
        let flow_meta = Arc::new(FlowMetadataManager::new(kv_backend));
        let catalog_manager = new_memory_catalog_manager().unwrap();
        let query_engine = create_test_query_engine();
        let (frontend_client, _handler) =
            FrontendClient::from_empty_grpc_handler(QueryOptions::default());

        BatchingEngine::new(
            Arc::new(frontend_client),
            query_engine,
            flow_meta,
            table_meta,
            catalog_manager,
            BatchingModeOptions::default(),
        )
    }

    /// Incremental read is off when no flow options are given, and the per-flow option key
    /// set to `"true"` turns it on.
    #[tokio::test]
    async fn test_flow_option_overrides_incremental_read_switch() {
        let engine = new_test_engine().await;

        let default_opts = engine.batch_opts_for_flow_options(&HashMap::new()).unwrap();
        assert!(!default_opts.experimental_enable_incremental_read);

        let enabled_opts = engine
            .batch_opts_for_flow_options(&HashMap::from([(
                FLOW_EXPERIMENTAL_ENABLE_INCREMENTAL_READ_KEY.to_string(),
                "true".to_string(),
            )]))
            .unwrap();
        assert!(enabled_opts.experimental_enable_incremental_read);
    }

    /// Append mode is off when the key is absent or `"false"`. An upper-case `"TRUE"`
    /// also enables it.
    #[test]
    fn test_table_options_enable_append_mode() {
        assert!(!BatchingEngine::table_options_enable_append_mode(
            &HashMap::new()
        ));
        assert!(!BatchingEngine::table_options_enable_append_mode(
            &HashMap::from([(APPEND_MODE_KEY.to_string(), "false".to_string())])
        ));
        assert!(BatchingEngine::table_options_enable_append_mode(
            &HashMap::from([(APPEND_MODE_KEY.to_string(), "TRUE".to_string())])
        ));
    }

    /// The append-only requirement on source tables applies only when incremental read is
    /// enabled. A non-append source is then rejected with a descriptive error.
    #[test]
    fn test_incremental_source_append_only_enforcement() {
        let table_name = [
            "greptime".to_string(),
            "public".to_string(),
            "numbers".to_string(),
        ];
        let disabled_opts = BatchingModeOptions::default();
        let enabled_opts = BatchingModeOptions {
            experimental_enable_incremental_read: true,
            ..Default::default()
        };
        let non_append_options = HashMap::new();
        let append_options = HashMap::from([(APPEND_MODE_KEY.to_string(), "true".to_string())]);

        BatchingEngine::ensure_incremental_source_append_only(
            &disabled_opts,
            &table_name,
            &non_append_options,
        )
        .expect("disabled incremental read should not require append-only source");
        BatchingEngine::ensure_incremental_source_append_only(
            &enabled_opts,
            &table_name,
            &append_options,
        )
        .expect("append-only source should be accepted when incremental read is enabled");

        let err = BatchingEngine::ensure_incremental_source_append_only(
            &enabled_opts,
            &table_name,
            &non_append_options,
        )
        .expect_err("non-append source should be rejected when incremental read is enabled");
        assert!(
            err.to_string()
                .contains("Flow incremental read requires append-only source table"),
            "{err}"
        );
    }

    /// Creates a standalone `BatchingTask` for `flow_id` that selects from `numbers_with_ts`,
    /// has no time-window expression, and uses default batching options.
    ///
    /// Returns the task together with the sender half of the channel whose receiver was
    /// handed to the task as `shutdown_rx`.
    async fn new_test_task(flow_id: FlowId) -> (BatchingTask, oneshot::Sender<()>) {
        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        let plan = sql_to_df_plan(
            ctx.clone(),
            query_engine.clone(),
            "SELECT number, ts FROM numbers_with_ts",
            true,
        )
        .await
        .unwrap();
        let (tx, rx) = oneshot::channel();

        let task = BatchingTask::try_new(TaskArgs {
            flow_id,
            query: "SELECT number, ts FROM numbers_with_ts",
            plan,
            time_window_expr: None,
            expire_after: None,
            sink_table_name: [
                "greptime".to_string(),
                "public".to_string(),
                "sink".to_string(),
            ],
            source_table_names: vec![[
                "greptime".to_string(),
                "public".to_string(),
                "numbers_with_ts".to_string(),
            ]],
            query_ctx: ctx,
            catalog_manager: query_engine.engine_state().catalog_manager().clone(),
            shutdown_rx: rx,
            batch_opts: Arc::new(BatchingModeOptions::default()),
            flow_eval_interval: None,
        })
        .unwrap();

        (task, tx)
    }

    /// Spawns a never-finishing tokio task, installs its handle as `task`'s `task_handle`,
    /// and returns a receiver that yields `()` once that spawned future is dropped, for
    /// example after `abort()`.
    async fn install_abort_observed_handle(task: &BatchingTask) -> oneshot::Receiver<()> {
        let (drop_tx, drop_rx) = oneshot::channel();
        let (entered_tx, entered_rx) = oneshot::channel();
        let handle = tokio::spawn(async move {
            let _guard = DropNotify(Some(drop_tx));
            let _ = entered_tx.send(());
            std::future::pending::<()>().await;
        });
        task.state.write().unwrap().task_handle = Some(handle);
        // Wait until the future has started and built its `DropNotify` guard. If it were
        // aborted before its first poll, `drop_tx` would be dropped without sending, and
        // `drop_rx` would resolve to an error instead of `()`.
        tokio::time::timeout(Duration::from_secs(1), entered_rx)
            .await
            .expect("test task handle should start")
            .expect("test task handle should report start");
        drop_rx
    }

    /// A present sender is reported as `true` and its receiver gets the signal.
    #[tokio::test]
    async fn test_notify_flow_shutdown_sends_signal() {
        let (tx, rx) = oneshot::channel();

        assert!(notify_flow_shutdown(42, Some(tx), "test"));

        rx.await.expect("replaced flow should receive shutdown");
    }

    /// A missing sender is tolerated and reported as `false`.
    #[test]
    fn test_notify_flow_shutdown_accepts_missing_sender() {
        assert!(!notify_flow_shutdown(42, None, "test"));
    }

    /// Aborting a task that owns a handle returns `true` and drops the spawned future.
    #[tokio::test]
    async fn test_abort_flow_task_aborts_handle() {
        let (task, _shutdown_tx) = new_test_task(42).await;
        let drop_rx = install_abort_observed_handle(&task).await;

        assert!(abort_flow_task(42, Some(task), "test"));

        tokio::time::timeout(Duration::from_secs(1), drop_rx)
            .await
            .expect("aborted task should be dropped")
            .expect("drop notifier should fire");
    }

    /// `remove_flow_inner` aborts the registered task's handle and clears both the task
    /// entry and its shutdown sender from the runtime.
    #[tokio::test]
    async fn test_remove_flow_inner_aborts_registered_task() {
        let engine = new_test_engine().await;
        let (task, shutdown_tx) = new_test_task(42).await;
        let drop_rx = install_abort_observed_handle(&task).await;

        engine.runtime.write().await.insert(42, task, shutdown_tx);

        engine.remove_flow_inner(42).await.unwrap();

        tokio::time::timeout(Duration::from_secs(1), drop_rx)
            .await
            .expect("removed task should be dropped")
            .expect("drop notifier should fire");
        assert!(!engine.flow_exist_inner(42).await);
        assert!(!engine.runtime.read().await.shutdown_txs.contains_key(&42));
    }

    /// Inserting a second task under the same id returns the previous task and its shutdown
    /// sender. Shutting those down affects only the old task; the new task stays registered
    /// with its own shutdown sender, and its receiver stays empty.
    #[tokio::test]
    async fn test_or_replace_flow_runtime_replaces_old_handles_and_keeps_new_task() {
        let engine = new_test_engine().await;
        let (old_task, old_shutdown_tx) = new_test_task(42).await;
        let old_task_identity = old_task.clone();
        let old_drop_rx = install_abort_observed_handle(&old_task).await;
        let (new_task, new_shutdown_tx) = new_test_task(42).await;
        let new_task_identity = new_task.clone();

        engine
            .runtime
            .write()
            .await
            .insert(42, old_task, old_shutdown_tx);
        let (replaced_old_task, replaced_old_shutdown_tx) =
            engine
                .runtime
                .write()
                .await
                .insert(42, new_task, new_shutdown_tx);

        // Task identity is compared via the shared `state` Arc, which clones of a task share.
        let replaced_old_task = replaced_old_task.expect("old task should be returned");
        assert!(Arc::ptr_eq(
            &replaced_old_task.state,
            &old_task_identity.state
        ));
        assert!(notify_flow_shutdown(
            42,
            replaced_old_shutdown_tx,
            "replaced"
        ));
        old_task_identity
            .state
            .write()
            .unwrap()
            .shutdown_rx
            .try_recv()
            .expect("old shutdown receiver should receive signal");
        assert!(abort_flow_task(42, Some(replaced_old_task), "replaced"));

        tokio::time::timeout(Duration::from_secs(1), old_drop_rx)
            .await
            .expect("replaced task should be dropped")
            .expect("drop notifier should fire");

        let runtime = engine.runtime.read().await;
        assert_eq!(1, runtime.tasks.len());
        assert_eq!(1, runtime.shutdown_txs.len());
        let registered_task = runtime.tasks.get(&42).expect("new task should remain");
        assert!(Arc::ptr_eq(
            &registered_task.state,
            &new_task_identity.state
        ));
        assert!(runtime.shutdown_txs.contains_key(&42));
        assert!(matches!(
            new_task_identity
                .state
                .write()
                .unwrap()
                .shutdown_rx
                .try_recv(),
            Err(oneshot::error::TryRecvError::Empty)
        ));
    }

    /// Rolling back with a task that is not the registered one leaves the runtime unchanged.
    /// Rolling back with the registered task removes it and its shutdown sender.
    #[tokio::test]
    async fn test_rollback_flow_runtime_if_current_removes_matching_task_only() {
        let engine = new_test_engine().await;
        let (old_task, _old_shutdown_tx) = new_test_task(42).await;
        let (current_task, current_shutdown_tx) = new_test_task(42).await;
        let current_task_identity = current_task.clone();

        engine
            .runtime
            .write()
            .await
            .insert(42, current_task, current_shutdown_tx);

        engine.rollback_flow_runtime_if_current(42, &old_task).await;

        let registered_task = engine.runtime.read().await.tasks.get(&42).cloned().unwrap();
        assert!(Arc::ptr_eq(
            &registered_task.state,
            &current_task_identity.state
        ));
        assert!(engine.runtime.read().await.shutdown_txs.contains_key(&42));

        engine
            .rollback_flow_runtime_if_current(42, &current_task_identity)
            .await;
        assert!(!engine.flow_exist_inner(42).await);
        assert!(!engine.runtime.read().await.shutdown_txs.contains_key(&42));
    }
}