Skip to main content

flow/batching_mode/
engine.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Batching mode engine
16
17use std::collections::{BTreeMap, HashMap, HashSet};
18use std::sync::Arc;
19use std::time::Duration;
20
21use api::v1::flow::DirtyWindowRequests;
22use catalog::CatalogManagerRef;
23use common_error::ext::BoxedError;
24use common_meta::ddl::create_flow::{FLOW_EXPERIMENTAL_ENABLE_INCREMENTAL_READ_KEY, FlowType};
25use common_meta::key::TableMetadataManagerRef;
26use common_meta::key::flow::FlowMetadataManagerRef;
27use common_meta::key::flow::flow_state::FlowStat;
28use common_meta::key::table_info::{TableInfoManager, TableInfoValue};
29use common_runtime::JoinHandle;
30use common_telemetry::tracing::warn;
31use common_telemetry::{debug, info};
32use common_time::TimeToLive;
33use datafusion_common::tree_node::{TreeNodeRecursion, TreeNodeVisitor};
34use datafusion_expr::LogicalPlan;
35use datatypes::prelude::ConcreteDataType;
36use query::QueryEngineRef;
37use session::context::QueryContext;
38use snafu::{OptionExt, ResultExt, ensure};
39use sql::parsers::utils::is_tql;
40use store_api::metric_engine_consts::is_metric_engine_internal_column;
41use store_api::mito_engine_options::APPEND_MODE_KEY;
42use store_api::storage::{RegionId, TableId};
43use table::table_reference::TableReference;
44use tokio::sync::{RwLock, oneshot};
45
46use crate::batching_mode::BatchingModeOptions;
47use crate::batching_mode::frontend_client::FrontendClient;
48use crate::batching_mode::task::{BatchingTask, TaskArgs};
49use crate::batching_mode::time_window::{TimeWindowExpr, find_time_window_expr};
50use crate::batching_mode::utils::sql_to_df_plan;
51use crate::engine::{FlowEngine, FlowStatProvider};
52use crate::error::{
53    CreateFlowSnafu, DatafusionSnafu, ExternalSnafu, FlowAlreadyExistSnafu, FlowNotFoundSnafu,
54    InvalidQuerySnafu, TableNotFoundMetaSnafu, UnexpectedSnafu, UnsupportedSnafu,
55};
56use crate::metrics::METRIC_FLOW_BATCHING_ENGINE_BULK_MARK_TIME_WINDOW;
57use crate::{CreateFlowArgs, Error, FlowId, TableName};
58
59/// Batching mode Engine, responsible for driving all the batching mode tasks
60///
61/// TODO(discord9): determine how to configure refresh rate
62pub struct BatchingEngine {
63    runtime: RwLock<FlowRuntimeRegistry>,
64    /// frontend client for insert request
65    pub(crate) frontend_client: Arc<FrontendClient>,
66    flow_metadata_manager: FlowMetadataManagerRef,
67    table_meta: TableMetadataManagerRef,
68    catalog_manager: CatalogManagerRef,
69    query_engine: QueryEngineRef,
70    /// Batching mode options for control how batching mode query works
71    ///
72    pub(crate) batch_opts: Arc<BatchingModeOptions>,
73}
74
75#[derive(Default)]
76struct FlowRuntimeRegistry {
77    tasks: BTreeMap<FlowId, BatchingTask>,
78    shutdown_txs: BTreeMap<FlowId, oneshot::Sender<()>>,
79}
80
81impl FlowRuntimeRegistry {
82    fn insert(
83        &mut self,
84        flow_id: FlowId,
85        task: BatchingTask,
86        shutdown_tx: oneshot::Sender<()>,
87    ) -> (Option<BatchingTask>, Option<oneshot::Sender<()>>) {
88        (
89            self.tasks.insert(flow_id, task),
90            self.shutdown_txs.insert(flow_id, shutdown_tx),
91        )
92    }
93
94    fn remove(&mut self, flow_id: FlowId) -> Option<(BatchingTask, Option<oneshot::Sender<()>>)> {
95        let task = self.tasks.remove(&flow_id)?;
96        let shutdown_tx = self.shutdown_txs.remove(&flow_id);
97        Some((task, shutdown_tx))
98    }
99
100    fn remove_if_current(
101        &mut self,
102        flow_id: FlowId,
103        task: &BatchingTask,
104    ) -> (Option<BatchingTask>, Option<oneshot::Sender<()>>) {
105        if self
106            .tasks
107            .get(&flow_id)
108            .is_some_and(|current| Arc::ptr_eq(&current.state, &task.state))
109        {
110            let Some((removed_task, removed_shutdown_tx)) = self.remove(flow_id) else {
111                return (None, None);
112            };
113            (Some(removed_task), removed_shutdown_tx)
114        } else {
115            (None, None)
116        }
117    }
118}
119
120impl BatchingEngine {
121    pub fn new(
122        frontend_client: Arc<FrontendClient>,
123        query_engine: QueryEngineRef,
124        flow_metadata_manager: FlowMetadataManagerRef,
125        table_meta: TableMetadataManagerRef,
126        catalog_manager: CatalogManagerRef,
127        batch_opts: BatchingModeOptions,
128    ) -> Self {
129        Self {
130            runtime: Default::default(),
131            frontend_client,
132            flow_metadata_manager,
133            table_meta,
134            catalog_manager,
135            query_engine,
136            batch_opts: Arc::new(batch_opts),
137        }
138    }
139
140    /// Returns last execution timestamps (millisecond) for all batching flows.
141    pub async fn get_last_exec_time_map(&self) -> BTreeMap<FlowId, i64> {
142        let runtime = self.runtime.read().await;
143        runtime
144            .tasks
145            .iter()
146            .filter_map(|(flow_id, task)| {
147                task.last_execution_time_millis()
148                    .map(|timestamp| (*flow_id, timestamp))
149            })
150            .collect()
151    }
152
153    pub async fn handle_mark_dirty_time_window(
154        &self,
155        reqs: DirtyWindowRequests,
156    ) -> Result<(), Error> {
157        let table_info_mgr = self.table_meta.table_info_manager();
158
159        let mut group_by_table_id: HashMap<u32, Vec<_>> = HashMap::new();
160        for r in reqs.requests {
161            let tid = TableId::from(r.table_id);
162            let entry = group_by_table_id.entry(tid).or_default();
163            entry.extend(r.timestamps);
164        }
165        let tids = group_by_table_id.keys().cloned().collect::<Vec<TableId>>();
166        let table_infos =
167            table_info_mgr
168                .batch_get(&tids)
169                .await
170                .with_context(|_| TableNotFoundMetaSnafu {
171                    msg: format!("Failed to get table info for table ids: {:?}", tids),
172                })?;
173
174        let group_by_table_name = group_by_table_id
175            .into_iter()
176            .filter_map(|(id, timestamps)| {
177                let table_name = table_infos.get(&id).map(|info| info.table_name());
178                let Some(table_name) = table_name else {
179                    warn!("Failed to get table infos for table id: {:?}", id);
180                    return None;
181                };
182                let table_name = [
183                    table_name.catalog_name,
184                    table_name.schema_name,
185                    table_name.table_name,
186                ];
187                let schema = &table_infos.get(&id).unwrap().table_info.meta.schema;
188                let time_index_unit = schema.column_schemas()[schema.timestamp_index().unwrap()]
189                    .data_type
190                    .as_timestamp()
191                    .unwrap()
192                    .unit();
193                Some((table_name, (timestamps, time_index_unit)))
194            })
195            .collect::<HashMap<_, _>>();
196
197        let group_by_table_name = Arc::new(group_by_table_name);
198
199        let tasks = self
200            .runtime
201            .read()
202            .await
203            .tasks
204            .values()
205            .cloned()
206            .collect::<Vec<_>>();
207        let mut handles = Vec::new();
208
209        for task in tasks {
210            let src_table_names = &task.config.source_table_names;
211
212            if src_table_names
213                .iter()
214                .all(|name| !group_by_table_name.contains_key(name))
215            {
216                continue;
217            }
218
219            let group_by_table_name = group_by_table_name.clone();
220            let task = task.clone();
221
222            let handle: JoinHandle<Result<(), Error>> = tokio::spawn(async move {
223                let src_table_names = &task.config.source_table_names;
224                let mut all_dirty_windows = HashSet::new();
225                let mut is_dirty = false;
226                for src_table_name in src_table_names {
227                    if let Some((timestamps, unit)) = group_by_table_name.get(src_table_name) {
228                        let Some(expr) = &task.config.time_window_expr else {
229                            is_dirty = true;
230                            continue;
231                        };
232                        for timestamp in timestamps {
233                            let align_start = expr
234                                .eval(common_time::Timestamp::new(*timestamp, *unit))?
235                                .0
236                                .context(UnexpectedSnafu {
237                                    reason: "Failed to eval start value",
238                                })?;
239                            all_dirty_windows.insert(align_start);
240                        }
241                    }
242                }
243                let mut state = task.state.write().unwrap();
244                if is_dirty {
245                    state.dirty_time_windows.set_dirty();
246                }
247                let flow_id_label = task.config.flow_id.to_string();
248                for timestamp in all_dirty_windows {
249                    state.dirty_time_windows.add_window(timestamp, None);
250                }
251
252                METRIC_FLOW_BATCHING_ENGINE_BULK_MARK_TIME_WINDOW
253                    .with_label_values(&[&flow_id_label])
254                    .set(state.dirty_time_windows.len() as f64);
255                Ok(())
256            });
257            handles.push(handle);
258        }
259        for handle in handles {
260            match handle.await {
261                Err(e) => {
262                    warn!("Failed to handle inserts: {e}");
263                }
264                Ok(Ok(())) => (),
265                Ok(Err(e)) => {
266                    warn!("Failed to handle inserts: {e}");
267                }
268            }
269        }
270
271        Ok(())
272    }
273
274    pub async fn handle_inserts_inner(
275        &self,
276        request: api::v1::region::InsertRequests,
277    ) -> Result<(), Error> {
278        let table_info_mgr = self.table_meta.table_info_manager();
279        let mut group_by_table_id: HashMap<TableId, Vec<api::v1::Rows>> = HashMap::new();
280
281        for r in request.requests {
282            let tid = RegionId::from(r.region_id).table_id();
283            let entry = group_by_table_id.entry(tid).or_default();
284            if let Some(rows) = r.rows {
285                entry.push(rows);
286            }
287        }
288
289        let tids = group_by_table_id.keys().cloned().collect::<Vec<TableId>>();
290        let table_infos =
291            table_info_mgr
292                .batch_get(&tids)
293                .await
294                .with_context(|_| TableNotFoundMetaSnafu {
295                    msg: format!("Failed to get table info for table ids: {:?}", tids),
296                })?;
297
298        let missing_tids = tids
299            .iter()
300            .filter(|id| !table_infos.contains_key(id))
301            .collect::<Vec<_>>();
302        if !missing_tids.is_empty() {
303            warn!(
304                "Failed to get all the table info for table ids, expected table ids: {:?}, those table doesn't exist: {:?}",
305                tids, missing_tids
306            );
307        }
308
309        let group_by_table_name = group_by_table_id
310            .into_iter()
311            .filter_map(|(id, rows)| {
312                let table_name = table_infos.get(&id).map(|info| info.table_name());
313                let Some(table_name) = table_name else {
314                    warn!("Failed to get table infos for table id: {:?}", id);
315                    return None;
316                };
317                let table_name = [
318                    table_name.catalog_name,
319                    table_name.schema_name,
320                    table_name.table_name,
321                ];
322                Some((table_name, rows))
323            })
324            .collect::<HashMap<_, _>>();
325
326        let group_by_table_name = Arc::new(group_by_table_name);
327
328        let tasks = self
329            .runtime
330            .read()
331            .await
332            .tasks
333            .values()
334            .cloned()
335            .collect::<Vec<_>>();
336        let mut handles = Vec::new();
337        for task in tasks {
338            let src_table_names = &task.config.source_table_names;
339
340            if src_table_names
341                .iter()
342                .all(|name| !group_by_table_name.contains_key(name))
343            {
344                continue;
345            }
346
347            let group_by_table_name = group_by_table_name.clone();
348            let task = task.clone();
349
350            let handle: JoinHandle<Result<(), Error>> = tokio::spawn(async move {
351                let src_table_names = &task.config.source_table_names;
352
353                let mut is_dirty = false;
354
355                for src_table_name in src_table_names {
356                    if let Some(entry) = group_by_table_name.get(src_table_name) {
357                        let Some(expr) = &task.config.time_window_expr else {
358                            is_dirty = true;
359                            continue;
360                        };
361                        let involved_time_windows = expr.handle_rows(entry.clone()).await?;
362                        let mut state = task.state.write().unwrap();
363                        state
364                            .dirty_time_windows
365                            .add_lower_bounds(involved_time_windows.into_iter());
366                    }
367                }
368                if is_dirty {
369                    task.state.write().unwrap().dirty_time_windows.set_dirty();
370                }
371
372                Ok(())
373            });
374            handles.push(handle);
375        }
376
377        for handle in handles {
378            match handle.await {
379                Err(e) => {
380                    warn!("Failed to handle inserts: {e}");
381                }
382                Ok(Ok(())) => (),
383                Ok(Err(e)) => {
384                    warn!("Failed to handle inserts: {e}");
385                }
386            }
387        }
388        Ok(())
389    }
390}
391
392impl FlowStatProvider for BatchingEngine {
393    async fn flow_stat(&self) -> FlowStat {
394        FlowStat {
395            state_size: BTreeMap::new(),
396            last_exec_time_map: self
397                .get_last_exec_time_map()
398                .await
399                .into_iter()
400                .map(|(flow_id, timestamp)| (flow_id as u32, timestamp))
401                .collect(),
402        }
403    }
404}
405
406async fn get_table_name(
407    table_info: &TableInfoManager,
408    table_id: &TableId,
409) -> Result<TableName, Error> {
410    get_table_info(table_info, table_id).await.map(|info| {
411        let name = info.table_name();
412        [name.catalog_name, name.schema_name, name.table_name]
413    })
414}
415
416async fn get_table_info(
417    table_info: &TableInfoManager,
418    table_id: &TableId,
419) -> Result<TableInfoValue, Error> {
420    table_info
421        .get(*table_id)
422        .await
423        .map_err(BoxedError::new)
424        .context(ExternalSnafu)?
425        .with_context(|| UnexpectedSnafu {
426            reason: format!("Table id = {:?}, couldn't found table name", table_id),
427        })
428        .map(|info| info.into_inner())
429}
430
431impl BatchingEngine {
432    fn batch_opts_for_flow_options(
433        &self,
434        flow_options: &HashMap<String, String>,
435    ) -> Result<Arc<BatchingModeOptions>, Error> {
436        let mut batch_opts = (*self.batch_opts).clone();
437        if let Some(enable_incremental_read) =
438            flow_options.get(FLOW_EXPERIMENTAL_ENABLE_INCREMENTAL_READ_KEY)
439        {
440            batch_opts.experimental_enable_incremental_read = enable_incremental_read
441                .parse::<bool>()
442                .map_err(|_| {
443                    InvalidQuerySnafu {
444                        reason: format!(
445                            "Invalid flow option {FLOW_EXPERIMENTAL_ENABLE_INCREMENTAL_READ_KEY}: {enable_incremental_read}"
446                        ),
447                    }
448                    .build()
449                })?;
450        }
451
452        Ok(Arc::new(batch_opts))
453    }
454
455    fn table_options_enable_append_mode(extra_options: &HashMap<String, String>) -> bool {
456        extra_options
457            .get(APPEND_MODE_KEY)
458            .is_some_and(|value| value.eq_ignore_ascii_case("true"))
459    }
460
461    /// SQL flows without a usable time-window expression can only run as an
462    /// explicit full-query flow, so require `EVAL INTERVAL` at creation time.
463    fn ensure_sql_flow_has_twe_or_eval_interval(
464        eval_interval: Option<i64>,
465        has_time_window_expr: bool,
466    ) -> Result<(), Error> {
467        ensure!(
468            eval_interval.is_some() || has_time_window_expr,
469            InvalidQuerySnafu {
470                reason: "SQL batching flow without a time-window expression must specify EVAL INTERVAL to run as an explicit full-query flow"
471                    .to_string(),
472            }
473        );
474        Ok(())
475    }
476
477    fn ensure_incremental_source_append_only(
478        batch_opts: &BatchingModeOptions,
479        table_name: &[String; 3],
480        extra_options: &HashMap<String, String>,
481    ) -> Result<(), Error> {
482        if batch_opts.experimental_enable_incremental_read {
483            ensure!(
484                Self::table_options_enable_append_mode(extra_options),
485                UnsupportedSnafu {
486                    reason: format!(
487                        "Flow incremental read requires append-only source table, but source table `{}` is not append-only. Consider setting append_mode='true' on the source table or disabling experimental_enable_incremental_read",
488                        table_name.join(".")
489                    ),
490                }
491            );
492        }
493
494        Ok(())
495    }
496
497    pub async fn create_flow_inner(&self, args: CreateFlowArgs) -> Result<Option<FlowId>, Error> {
498        let CreateFlowArgs {
499            flow_id,
500            sink_table_name,
501            source_table_ids,
502            create_if_not_exists,
503            or_replace,
504            expire_after,
505            eval_interval,
506            comment: _,
507            sql,
508            flow_options,
509            query_ctx,
510        } = args;
511
512        // or replace logic
513        {
514            let is_exist = self.runtime.read().await.tasks.contains_key(&flow_id);
515            match (create_if_not_exists, or_replace, is_exist) {
516                // if replace, ignore that old flow exists
517                (_, true, true) => {
518                    info!("Replacing flow with id={}", flow_id);
519                }
520                (false, false, true) => FlowAlreadyExistSnafu { id: flow_id }.fail()?,
521                // already exists, and not replace, return None
522                (true, false, true) => {
523                    info!("Flow with id={} already exists, do nothing", flow_id);
524                    return Ok(None);
525                }
526
527                // continue as normal
528                (_, _, false) => (),
529            }
530        }
531
532        let query_ctx = query_ctx.context({
533            UnexpectedSnafu {
534                reason: "Query context is None".to_string(),
535            }
536        })?;
537        let query_ctx = Arc::new(query_ctx);
538        let is_tql = is_tql(query_ctx.sql_dialect(), &sql)
539            .map_err(BoxedError::new)
540            .context(CreateFlowSnafu { sql: &sql })?;
541
542        // optionally set a eval interval for the flow
543        if eval_interval.is_none() && is_tql {
544            InvalidQuerySnafu {
545                reason: "TQL query requires EVAL INTERVAL to be set".to_string(),
546            }
547            .fail()?;
548        }
549
550        let flow_type = flow_options.get(FlowType::FLOW_TYPE_KEY);
551
552        ensure!(
553            match flow_type {
554                None => true,
555                Some(ty) if ty == FlowType::BATCHING => true,
556                _ => false,
557            },
558            UnexpectedSnafu {
559                reason: format!("Flow type is not batching nor None, got {flow_type:?}")
560            }
561        );
562
563        let batch_opts = self.batch_opts_for_flow_options(&flow_options)?;
564
565        let mut source_table_names = Vec::with_capacity(2);
566        for src_id in source_table_ids {
567            // also check table option to see if ttl!=instant
568            let table_name = get_table_name(self.table_meta.table_info_manager(), &src_id).await?;
569            let table_info = get_table_info(self.table_meta.table_info_manager(), &src_id).await?;
570            ensure!(
571                table_info.table_info.meta.options.ttl != Some(TimeToLive::Instant),
572                UnsupportedSnafu {
573                    reason: format!(
574                        "Source table `{}`(id={}) has instant TTL, Instant TTL is not supported under batching mode. Consider using a TTL longer than flush interval",
575                        table_name.join("."),
576                        src_id
577                    ),
578                }
579            );
580            Self::ensure_incremental_source_append_only(
581                &batch_opts,
582                &table_name,
583                &table_info.table_info.meta.options.extra_options,
584            )?;
585
586            source_table_names.push(table_name);
587        }
588
589        let (tx, rx) = oneshot::channel();
590
591        let plan = sql_to_df_plan(query_ctx.clone(), self.query_engine.clone(), &sql, true).await?;
592
593        if is_tql {
594            self.check_is_tql_table(&plan, &query_ctx).await?;
595        }
596
597        let phy_expr = if !is_tql {
598            let (column_name, time_window_expr, _, df_schema) = find_time_window_expr(
599                &plan,
600                self.query_engine.engine_state().catalog_manager().clone(),
601                query_ctx.clone(),
602            )
603            .await?;
604            time_window_expr
605                .map(|expr| {
606                    TimeWindowExpr::from_expr(
607                        &expr,
608                        &column_name,
609                        &df_schema,
610                        &self.query_engine.engine_state().session_state(),
611                    )
612                })
613                .transpose()?
614        } else {
615            // tql control by `EVAL INTERVAL`, no need to find time window expr
616            None
617        };
618
619        debug!(
620            "Flow id={}, found time window expr={}",
621            flow_id,
622            phy_expr
623                .as_ref()
624                .map(|phy_expr| phy_expr.to_string())
625                .unwrap_or("None".to_string())
626        );
627
628        if !is_tql {
629            Self::ensure_sql_flow_has_twe_or_eval_interval(eval_interval, phy_expr.is_some())?;
630        }
631
632        let task_args = TaskArgs {
633            flow_id,
634            query: &sql,
635            plan,
636            time_window_expr: phy_expr,
637            expire_after,
638            sink_table_name,
639            source_table_names,
640            query_ctx,
641            catalog_manager: self.catalog_manager.clone(),
642            shutdown_rx: rx,
643            batch_opts,
644            flow_eval_interval: eval_interval.map(|secs| Duration::from_secs(secs as u64)),
645        };
646
647        let task = BatchingTask::try_new(task_args)?;
648
649        let task_inner = task.clone();
650        let engine = self.query_engine.clone();
651        let frontend = self.frontend_client.clone();
652
653        // Create sink table if needed, then validate an existing/created sink schema before
654        // spawning the background task. This catches user-created sink schema mismatches at
655        // CREATE FLOW time instead of surfacing them later in the execution loop.
656        task.check_or_create_sink_table(&engine, &frontend).await?;
657        task.validate_sink_table_schema(&engine).await?;
658
659        let (start_tx, start_rx) = oneshot::channel();
660
661        // TODO(discord9): use time wheel or what for better
662        let handle = common_runtime::spawn_global(async move {
663            if start_rx.await.is_ok() {
664                task_inner.start_executing_loop(engine, frontend).await;
665            }
666        });
667        task.state.write().unwrap().task_handle = Some(handle);
668        let task_for_rollback = task.clone();
669
670        // Only replace here, not earlier, because we want the old one intact if
671        // something went wrong before this line. Keep the task and shutdown
672        // sender in one registry lock so create/remove can't observe one
673        // without the other.
674        let (replaced_old_task_opt, replaced_old_shutdown_tx) = {
675            let mut runtime = self.runtime.write().await;
676
677            let is_exist = runtime.tasks.contains_key(&flow_id);
678            match (create_if_not_exists, or_replace, is_exist) {
679                (_, true, true) => {
680                    info!(
681                        "Replacing flow with id={} after final registry check",
682                        flow_id
683                    );
684                }
685                (false, false, true) => {
686                    abort_flow_task(flow_id, Some(task), "unregistered");
687                    return FlowAlreadyExistSnafu { id: flow_id }.fail();
688                }
689                (true, false, true) => {
690                    info!(
691                        "Flow with id={} already exists at final registry check, do nothing",
692                        flow_id
693                    );
694                    abort_flow_task(flow_id, Some(task), "unregistered");
695                    return Ok(None);
696                }
697                (_, _, false) => (),
698            }
699
700            runtime.insert(flow_id, task, tx)
701        };
702
703        notify_flow_shutdown(flow_id, replaced_old_shutdown_tx, "replaced");
704        abort_flow_task(flow_id, replaced_old_task_opt, "replaced");
705        if start_tx.send(()).is_err() {
706            self.rollback_flow_runtime_if_current(flow_id, &task_for_rollback)
707                .await;
708            UnexpectedSnafu {
709                reason: format!("Failed to start flow {flow_id} due to task already dropped"),
710            }
711            .fail()?;
712        }
713
714        Ok(Some(flow_id))
715    }
716
717    async fn check_is_tql_table(
718        &self,
719        query: &LogicalPlan,
720        query_ctx: &QueryContext,
721    ) -> Result<(), Error> {
722        struct CollectTableRef {
723            table_refs: HashSet<datafusion_common::TableReference>,
724        }
725
726        impl TreeNodeVisitor<'_> for CollectTableRef {
727            type Node = LogicalPlan;
728            fn f_down(
729                &mut self,
730                node: &Self::Node,
731            ) -> datafusion_common::Result<TreeNodeRecursion> {
732                if let LogicalPlan::TableScan(scan) = node {
733                    self.table_refs.insert(scan.table_name.clone());
734                }
735                Ok(TreeNodeRecursion::Continue)
736            }
737        }
738        let mut table_refs = CollectTableRef {
739            table_refs: HashSet::new(),
740        };
741        query
742            .visit_with_subqueries(&mut table_refs)
743            .context(DatafusionSnafu {
744                context: "Checking if all source tables are TQL tables",
745            })?;
746
747        let default_catalog = query_ctx.current_catalog();
748        let default_schema = query_ctx.current_schema();
749        let default_schema = &default_schema;
750
751        for table_ref in table_refs.table_refs {
752            let table_ref = match &table_ref {
753                datafusion_common::TableReference::Bare { table } => {
754                    TableReference::full(default_catalog, default_schema, table)
755                }
756                datafusion_common::TableReference::Partial { schema, table } => {
757                    TableReference::full(default_catalog, schema, table)
758                }
759                datafusion_common::TableReference::Full {
760                    catalog,
761                    schema,
762                    table,
763                } => TableReference::full(catalog, schema, table),
764            };
765
766            let table_id = self
767                .table_meta
768                .table_name_manager()
769                .get(table_ref.into())
770                .await
771                .map_err(BoxedError::new)
772                .context(ExternalSnafu)?
773                .with_context(|| UnexpectedSnafu {
774                    reason: format!("Failed to get table id for table: {}", table_ref),
775                })?
776                .table_id();
777            let table_info =
778                get_table_info(self.table_meta.table_info_manager(), &table_id).await?;
779            // first check if it's only one f64 value column
780            let value_cols = table_info
781                .table_info
782                .meta
783                .schema
784                .column_schemas()
785                .iter()
786                .filter(|col| col.data_type == ConcreteDataType::float64_datatype())
787                .collect::<Vec<_>>();
788            ensure!(
789                value_cols.len() == 1,
790                InvalidQuerySnafu {
791                    reason: format!(
792                        "TQL query only supports one f64 value column, table `{}`(id={}) has {} f64 value columns, columns are: {:?}",
793                        table_ref,
794                        table_id,
795                        value_cols.len(),
796                        value_cols
797                    ),
798                }
799            );
800            // TODO(discord9): do need to check rest columns is string and is tag column?
801            let pk_idxs = table_info
802                .table_info
803                .meta
804                .primary_key_indices
805                .iter()
806                .collect::<HashSet<_>>();
807
808            for (idx, col) in table_info
809                .table_info
810                .meta
811                .schema
812                .column_schemas()
813                .iter()
814                .enumerate()
815            {
816                if is_metric_engine_internal_column(&col.name) {
817                    continue;
818                }
819                // three cases:
820                // 1. val column
821                // 2. timestamp column
822                // 3. tag column (string)
823
824                let is_pk: bool = pk_idxs.contains(&&idx);
825
826                ensure!(
827                    col.data_type == ConcreteDataType::float64_datatype()
828                        || col.data_type.is_timestamp()
829                        || (col.data_type == ConcreteDataType::string_datatype() && is_pk),
830                    InvalidQuerySnafu {
831                        reason: format!(
832                            "TQL query only supports f64 value column, timestamp column and string tag columns, table `{}`(id={}) has column `{}` with type {:?} which is not supported",
833                            table_ref, table_id, col.name, col.data_type
834                        ),
835                    }
836                );
837            }
838        }
839        Ok(())
840    }
841
842    pub async fn remove_flow_inner(&self, flow_id: FlowId) -> Result<(), Error> {
843        let (task, shutdown_tx) = {
844            let mut runtime = self.runtime.write().await;
845            let Some((task, shutdown_tx)) = runtime.remove(flow_id) else {
846                warn!("Flow {flow_id} not found in tasks");
847                FlowNotFoundSnafu { id: flow_id }.fail()?
848            };
849            (task, shutdown_tx)
850        };
851
852        let had_shutdown_tx = notify_flow_shutdown(flow_id, shutdown_tx, "removed");
853        abort_flow_task(flow_id, Some(task), "removed");
854
855        if !had_shutdown_tx {
856            UnexpectedSnafu {
857                reason: format!("Can't found shutdown tx for flow {flow_id}"),
858            }
859            .fail()?
860        }
861
862        Ok(())
863    }
864
865    /// Only flush the dirty windows of the flow task with given flow id, by running the query on it.
866    /// As flush the whole time range is usually prohibitively expensive.
867    pub async fn flush_flow_inner(&self, flow_id: FlowId) -> Result<usize, Error> {
868        debug!("Try flush flow {flow_id}");
869        // need to wait a bit to ensure previous mirror insert is handled
870        // this is only useful for the case when we are flushing the flow right after inserting data into it
871        // TODO(discord9): find a better way to ensure the data is ready, maybe inform flownode from frontend?
872        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
873        let task = self.runtime.read().await.tasks.get(&flow_id).cloned();
874        let task = task.with_context(|| FlowNotFoundSnafu { id: flow_id })?;
875
876        let time_window_size = task
877            .config
878            .time_window_expr
879            .as_ref()
880            .and_then(|expr| *expr.time_window_size());
881
882        let cur_dirty_window_cnt = time_window_size.map(|time_window_size| {
883            task.state
884                .read()
885                .unwrap()
886                .dirty_time_windows
887                .effective_count(&time_window_size)
888        });
889
890        let res = task
891            .execute_once_serialized(
892                &self.query_engine,
893                &self.frontend_client,
894                cur_dirty_window_cnt,
895            )
896            .await?;
897
898        let affected_rows = res.map(|(r, _)| r).unwrap_or_default();
899        debug!(
900            "Successfully flush flow {flow_id}, affected rows={}",
901            affected_rows
902        );
903        Ok(affected_rows)
904    }
905
906    /// Determine if the batching mode flow task exists with given flow id
907    pub async fn flow_exist_inner(&self, flow_id: FlowId) -> bool {
908        self.runtime.read().await.tasks.contains_key(&flow_id)
909    }
910
911    async fn rollback_flow_runtime_if_current(&self, flow_id: FlowId, task: &BatchingTask) {
912        let (removed_task, removed_shutdown_tx) = {
913            let mut runtime = self.runtime.write().await;
914            runtime.remove_if_current(flow_id, task)
915        };
916
917        notify_flow_shutdown(flow_id, removed_shutdown_tx, "rolled back");
918        abort_flow_task(flow_id, removed_task, "rolled back");
919    }
920}
921
922fn notify_flow_shutdown(flow_id: FlowId, tx: Option<oneshot::Sender<()>>, action: &str) -> bool {
923    let Some(tx) = tx else {
924        return false;
925    };
926
927    if tx.send(()).is_err() {
928        warn!(
929            "Fail to shutdown {action} flow {flow_id} due to receiver already dropped, maybe flow {flow_id} is already dropped?"
930        );
931    }
932
933    true
934}
935
936fn abort_flow_task(flow_id: FlowId, task: Option<BatchingTask>, action: &str) -> bool {
937    let Some(task) = task else {
938        return false;
939    };
940
941    if let Some(handle) = task.state.write().unwrap().task_handle.take() {
942        handle.abort();
943        debug!("Aborted {action} flow task {flow_id}");
944        return true;
945    }
946
947    false
948}
949
950impl FlowEngine for BatchingEngine {
951    async fn create_flow(&self, args: CreateFlowArgs) -> Result<Option<FlowId>, Error> {
952        self.create_flow_inner(args).await
953    }
954    async fn remove_flow(&self, flow_id: FlowId) -> Result<(), Error> {
955        self.remove_flow_inner(flow_id).await
956    }
957    async fn flush_flow(&self, flow_id: FlowId) -> Result<usize, Error> {
958        self.flush_flow_inner(flow_id).await
959    }
960    async fn flow_exist(&self, flow_id: FlowId) -> Result<bool, Error> {
961        Ok(self.flow_exist_inner(flow_id).await)
962    }
963    async fn list_flows(&self) -> Result<impl IntoIterator<Item = FlowId>, Error> {
964        Ok(self
965            .runtime
966            .read()
967            .await
968            .tasks
969            .keys()
970            .cloned()
971            .collect::<Vec<_>>())
972    }
973    async fn handle_flow_inserts(
974        &self,
975        request: api::v1::region::InsertRequests,
976    ) -> Result<(), Error> {
977        self.handle_inserts_inner(request).await
978    }
979    async fn handle_mark_window_dirty(
980        &self,
981        req: api::v1::flow::DirtyWindowRequests,
982    ) -> Result<(), Error> {
983        self.handle_mark_dirty_time_window(req).await
984    }
985}
986
987#[cfg(test)]
988mod tests {
989    use catalog::memory::new_memory_catalog_manager;
990    use common_meta::key::TableMetadataManager;
991    use common_meta::key::flow::FlowMetadataManager;
992    use common_meta::kv_backend::memory::MemoryKvBackend;
993    use query::options::QueryOptions;
994    use session::context::QueryContext;
995
996    use super::*;
997    use crate::test_utils::create_test_query_engine;
998
999    struct DropNotify(Option<oneshot::Sender<()>>);
1000
1001    impl Drop for DropNotify {
1002        fn drop(&mut self) {
1003            if let Some(tx) = self.0.take() {
1004                let _ = tx.send(());
1005            }
1006        }
1007    }
1008
1009    async fn new_test_engine() -> BatchingEngine {
1010        let kv_backend = Arc::new(MemoryKvBackend::new());
1011        let table_meta = Arc::new(TableMetadataManager::new(kv_backend.clone()));
1012        table_meta.init().await.unwrap();
1013        let flow_meta = Arc::new(FlowMetadataManager::new(kv_backend));
1014        let catalog_manager = new_memory_catalog_manager().unwrap();
1015        let query_engine = create_test_query_engine();
1016        let (frontend_client, _handler) =
1017            FrontendClient::from_empty_grpc_handler(QueryOptions::default());
1018
1019        BatchingEngine::new(
1020            Arc::new(frontend_client),
1021            query_engine,
1022            flow_meta,
1023            table_meta,
1024            catalog_manager,
1025            BatchingModeOptions::default(),
1026        )
1027    }
1028
1029    #[tokio::test]
1030    async fn test_flow_option_overrides_incremental_read_switch() {
1031        let engine = new_test_engine().await;
1032
1033        let default_opts = engine.batch_opts_for_flow_options(&HashMap::new()).unwrap();
1034        assert!(!default_opts.experimental_enable_incremental_read);
1035
1036        let enabled_opts = engine
1037            .batch_opts_for_flow_options(&HashMap::from([(
1038                FLOW_EXPERIMENTAL_ENABLE_INCREMENTAL_READ_KEY.to_string(),
1039                "true".to_string(),
1040            )]))
1041            .unwrap();
1042        assert!(enabled_opts.experimental_enable_incremental_read);
1043    }
1044
1045    #[test]
1046    fn test_table_options_enable_append_mode() {
1047        assert!(!BatchingEngine::table_options_enable_append_mode(
1048            &HashMap::new()
1049        ));
1050        assert!(!BatchingEngine::table_options_enable_append_mode(
1051            &HashMap::from([(APPEND_MODE_KEY.to_string(), "false".to_string())])
1052        ));
1053        assert!(BatchingEngine::table_options_enable_append_mode(
1054            &HashMap::from([(APPEND_MODE_KEY.to_string(), "TRUE".to_string())])
1055        ));
1056    }
1057
1058    #[test]
1059    fn test_sql_flow_requires_time_window_or_eval_interval() {
1060        BatchingEngine::ensure_sql_flow_has_twe_or_eval_interval(None, true)
1061            .expect("SQL flow with a time-window expression should be accepted");
1062        BatchingEngine::ensure_sql_flow_has_twe_or_eval_interval(Some(10), false).expect(
1063            "SQL flow with EVAL INTERVAL should be accepted as an explicit full-query flow",
1064        );
1065
1066        let err = BatchingEngine::ensure_sql_flow_has_twe_or_eval_interval(None, false)
1067            .expect_err("SQL flow without a time-window expression or EVAL INTERVAL should fail");
1068        assert!(matches!(err, Error::InvalidQuery { .. }), "{err}");
1069        assert!(
1070            err.to_string().contains("must specify EVAL INTERVAL"),
1071            "{err}"
1072        );
1073    }
1074
1075    #[tokio::test]
1076    async fn test_complex_sql_without_eval_interval_is_rejected_as_no_twe() {
1077        let query_engine = create_test_query_engine();
1078        let ctx = QueryContext::arc();
1079        let plan = sql_to_df_plan(
1080            ctx.clone(),
1081            query_engine.clone(),
1082            r#"
1083SELECT
1084    l.number,
1085    date_bin('5 minutes', l.ts) AS time_window
1086FROM numbers_with_ts l
1087JOIN numbers_with_ts r ON l.number = r.number
1088GROUP BY l.number, time_window
1089"#,
1090            true,
1091        )
1092        .await
1093        .unwrap();
1094
1095        let (_, time_window_expr, _, _) = find_time_window_expr(
1096            &plan,
1097            query_engine.engine_state().catalog_manager().clone(),
1098            ctx,
1099        )
1100        .await
1101        .unwrap();
1102        assert!(
1103            time_window_expr.is_none(),
1104            "complex SQL should be classified as having no safe TWE"
1105        );
1106
1107        BatchingEngine::ensure_sql_flow_has_twe_or_eval_interval(Some(10), false)
1108            .expect("complex SQL can run as an explicit full-query flow when EVAL INTERVAL is set");
1109        let err = BatchingEngine::ensure_sql_flow_has_twe_or_eval_interval(None, false)
1110            .expect_err("complex SQL without EVAL INTERVAL should fail creation");
1111        assert!(matches!(err, Error::InvalidQuery { .. }), "{err}");
1112    }
1113
1114    #[test]
1115    fn test_incremental_source_append_only_enforcement() {
1116        let table_name = [
1117            "greptime".to_string(),
1118            "public".to_string(),
1119            "numbers".to_string(),
1120        ];
1121        let disabled_opts = BatchingModeOptions::default();
1122        let enabled_opts = BatchingModeOptions {
1123            experimental_enable_incremental_read: true,
1124            ..Default::default()
1125        };
1126        let non_append_options = HashMap::new();
1127        let append_options = HashMap::from([(APPEND_MODE_KEY.to_string(), "true".to_string())]);
1128
1129        BatchingEngine::ensure_incremental_source_append_only(
1130            &disabled_opts,
1131            &table_name,
1132            &non_append_options,
1133        )
1134        .expect("disabled incremental read should not require append-only source");
1135        BatchingEngine::ensure_incremental_source_append_only(
1136            &enabled_opts,
1137            &table_name,
1138            &append_options,
1139        )
1140        .expect("append-only source should be accepted when incremental read is enabled");
1141
1142        let err = BatchingEngine::ensure_incremental_source_append_only(
1143            &enabled_opts,
1144            &table_name,
1145            &non_append_options,
1146        )
1147        .expect_err("non-append source should be rejected when incremental read is enabled");
1148        assert!(
1149            err.to_string()
1150                .contains("Flow incremental read requires append-only source table"),
1151            "{err}"
1152        );
1153    }
1154
1155    async fn new_test_task(flow_id: FlowId) -> (BatchingTask, oneshot::Sender<()>) {
1156        let query_engine = create_test_query_engine();
1157        let ctx = QueryContext::arc();
1158        let plan = sql_to_df_plan(
1159            ctx.clone(),
1160            query_engine.clone(),
1161            "SELECT number, ts FROM numbers_with_ts",
1162            true,
1163        )
1164        .await
1165        .unwrap();
1166        let (tx, rx) = oneshot::channel();
1167
1168        let task = BatchingTask::try_new(TaskArgs {
1169            flow_id,
1170            query: "SELECT number, ts FROM numbers_with_ts",
1171            plan,
1172            time_window_expr: None,
1173            expire_after: None,
1174            sink_table_name: [
1175                "greptime".to_string(),
1176                "public".to_string(),
1177                "sink".to_string(),
1178            ],
1179            source_table_names: vec![[
1180                "greptime".to_string(),
1181                "public".to_string(),
1182                "numbers_with_ts".to_string(),
1183            ]],
1184            query_ctx: ctx,
1185            catalog_manager: query_engine.engine_state().catalog_manager().clone(),
1186            shutdown_rx: rx,
1187            batch_opts: Arc::new(BatchingModeOptions::default()),
1188            flow_eval_interval: None,
1189        })
1190        .unwrap();
1191
1192        (task, tx)
1193    }
1194
1195    async fn install_abort_observed_handle(task: &BatchingTask) -> oneshot::Receiver<()> {
1196        let (drop_tx, drop_rx) = oneshot::channel();
1197        let (entered_tx, entered_rx) = oneshot::channel();
1198        let handle = tokio::spawn(async move {
1199            let _guard = DropNotify(Some(drop_tx));
1200            let _ = entered_tx.send(());
1201            std::future::pending::<()>().await;
1202        });
1203        task.state.write().unwrap().task_handle = Some(handle);
1204        tokio::time::timeout(Duration::from_secs(1), entered_rx)
1205            .await
1206            .expect("test task handle should start")
1207            .expect("test task handle should report start");
1208        drop_rx
1209    }
1210
1211    #[tokio::test]
1212    async fn test_notify_flow_shutdown_sends_signal() {
1213        let (tx, rx) = oneshot::channel();
1214
1215        assert!(notify_flow_shutdown(42, Some(tx), "test"));
1216
1217        rx.await.expect("replaced flow should receive shutdown");
1218    }
1219
1220    #[test]
1221    fn test_notify_flow_shutdown_accepts_missing_sender() {
1222        assert!(!notify_flow_shutdown(42, None, "test"));
1223    }
1224
1225    #[tokio::test]
1226    async fn test_abort_flow_task_aborts_handle() {
1227        let (task, _shutdown_tx) = new_test_task(42).await;
1228        let drop_rx = install_abort_observed_handle(&task).await;
1229
1230        assert!(abort_flow_task(42, Some(task), "test"));
1231
1232        tokio::time::timeout(Duration::from_secs(1), drop_rx)
1233            .await
1234            .expect("aborted task should be dropped")
1235            .expect("drop notifier should fire");
1236    }
1237
1238    #[tokio::test]
1239    async fn test_remove_flow_inner_aborts_registered_task() {
1240        let engine = new_test_engine().await;
1241        let (task, shutdown_tx) = new_test_task(42).await;
1242        let drop_rx = install_abort_observed_handle(&task).await;
1243
1244        engine.runtime.write().await.insert(42, task, shutdown_tx);
1245
1246        engine.remove_flow_inner(42).await.unwrap();
1247
1248        tokio::time::timeout(Duration::from_secs(1), drop_rx)
1249            .await
1250            .expect("removed task should be dropped")
1251            .expect("drop notifier should fire");
1252        assert!(!engine.flow_exist_inner(42).await);
1253        assert!(!engine.runtime.read().await.shutdown_txs.contains_key(&42));
1254    }
1255
1256    #[tokio::test]
1257    async fn test_or_replace_flow_runtime_replaces_old_handles_and_keeps_new_task() {
1258        let engine = new_test_engine().await;
1259        let (old_task, old_shutdown_tx) = new_test_task(42).await;
1260        let old_task_identity = old_task.clone();
1261        let old_drop_rx = install_abort_observed_handle(&old_task).await;
1262        let (new_task, new_shutdown_tx) = new_test_task(42).await;
1263        let new_task_identity = new_task.clone();
1264
1265        engine
1266            .runtime
1267            .write()
1268            .await
1269            .insert(42, old_task, old_shutdown_tx);
1270        let (replaced_old_task, replaced_old_shutdown_tx) =
1271            engine
1272                .runtime
1273                .write()
1274                .await
1275                .insert(42, new_task, new_shutdown_tx);
1276
1277        let replaced_old_task = replaced_old_task.expect("old task should be returned");
1278        assert!(Arc::ptr_eq(
1279            &replaced_old_task.state,
1280            &old_task_identity.state
1281        ));
1282        assert!(notify_flow_shutdown(
1283            42,
1284            replaced_old_shutdown_tx,
1285            "replaced"
1286        ));
1287        old_task_identity
1288            .state
1289            .write()
1290            .unwrap()
1291            .shutdown_rx
1292            .try_recv()
1293            .expect("old shutdown receiver should receive signal");
1294        assert!(abort_flow_task(42, Some(replaced_old_task), "replaced"));
1295
1296        tokio::time::timeout(Duration::from_secs(1), old_drop_rx)
1297            .await
1298            .expect("replaced task should be dropped")
1299            .expect("drop notifier should fire");
1300
1301        let runtime = engine.runtime.read().await;
1302        assert_eq!(1, runtime.tasks.len());
1303        assert_eq!(1, runtime.shutdown_txs.len());
1304        let registered_task = runtime.tasks.get(&42).expect("new task should remain");
1305        assert!(Arc::ptr_eq(
1306            &registered_task.state,
1307            &new_task_identity.state
1308        ));
1309        assert!(runtime.shutdown_txs.contains_key(&42));
1310        assert!(matches!(
1311            new_task_identity
1312                .state
1313                .write()
1314                .unwrap()
1315                .shutdown_rx
1316                .try_recv(),
1317            Err(oneshot::error::TryRecvError::Empty)
1318        ));
1319    }
1320
1321    #[tokio::test]
1322    async fn test_rollback_flow_runtime_if_current_removes_matching_task_only() {
1323        let engine = new_test_engine().await;
1324        let (old_task, _old_shutdown_tx) = new_test_task(42).await;
1325        let (current_task, current_shutdown_tx) = new_test_task(42).await;
1326        let current_task_identity = current_task.clone();
1327
1328        engine
1329            .runtime
1330            .write()
1331            .await
1332            .insert(42, current_task, current_shutdown_tx);
1333
1334        engine.rollback_flow_runtime_if_current(42, &old_task).await;
1335
1336        let registered_task = engine.runtime.read().await.tasks.get(&42).cloned().unwrap();
1337        assert!(Arc::ptr_eq(
1338            &registered_task.state,
1339            &current_task_identity.state
1340        ));
1341        assert!(engine.runtime.read().await.shutdown_txs.contains_key(&42));
1342
1343        engine
1344            .rollback_flow_runtime_if_current(42, &current_task_identity)
1345            .await;
1346        assert!(!engine.flow_exist_inner(42).await);
1347        assert!(!engine.runtime.read().await.shutdown_txs.contains_key(&42));
1348    }
1349}