Skip to main content

flow/batching_mode/
engine.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Batching mode engine
16
17use std::collections::{BTreeMap, HashMap, HashSet};
18use std::sync::Arc;
19use std::time::Duration;
20
21use api::v1::flow::DirtyWindowRequests;
22use catalog::CatalogManagerRef;
23use common_error::ext::BoxedError;
24use common_meta::ddl::create_flow::{FLOW_EXPERIMENTAL_ENABLE_INCREMENTAL_READ_KEY, FlowType};
25use common_meta::key::TableMetadataManagerRef;
26use common_meta::key::flow::FlowMetadataManagerRef;
27use common_meta::key::flow::flow_state::FlowStat;
28use common_meta::key::table_info::{TableInfoManager, TableInfoValue};
29use common_runtime::JoinHandle;
30use common_telemetry::tracing::warn;
31use common_telemetry::{debug, info};
32use common_time::TimeToLive;
33use datafusion_common::tree_node::{TreeNodeRecursion, TreeNodeVisitor};
34use datafusion_expr::LogicalPlan;
35use datatypes::prelude::ConcreteDataType;
36use query::QueryEngineRef;
37use session::context::QueryContext;
38use snafu::{OptionExt, ResultExt, ensure};
39use sql::parsers::utils::is_tql;
40use store_api::metric_engine_consts::is_metric_engine_internal_column;
41use store_api::mito_engine_options::APPEND_MODE_KEY;
42use store_api::storage::{RegionId, TableId};
43use table::table_reference::TableReference;
44use tokio::sync::{RwLock, oneshot};
45
46use crate::batching_mode::BatchingModeOptions;
47use crate::batching_mode::eval_schedule::EvalSchedule;
48use crate::batching_mode::frontend_client::FrontendClient;
49use crate::batching_mode::task::{BatchingTask, TaskArgs};
50use crate::batching_mode::time_window::{TimeWindowExpr, find_time_window_expr};
51use crate::batching_mode::utils::sql_to_df_plan;
52use crate::engine::{FlowEngine, FlowStatProvider};
53use crate::error::{
54    CreateFlowSnafu, DatafusionSnafu, ExternalSnafu, FlowAlreadyExistSnafu, FlowNotFoundSnafu,
55    InvalidQuerySnafu, TableNotFoundMetaSnafu, UnexpectedSnafu, UnsupportedSnafu,
56};
57use crate::metrics::METRIC_FLOW_BATCHING_ENGINE_BULK_MARK_TIME_WINDOW;
58use crate::{CreateFlowArgs, Error, FlowId, TableName};
59
/// Batching mode Engine, responsible for driving all the batching mode tasks
///
/// The engine keeps a registry of the tasks it has created (one per flow id) together
/// with the handles it needs to plan queries, look up table metadata and talk to the
/// frontend.
///
/// TODO(discord9): determine how to configure refresh rate
pub struct BatchingEngine {
    /// Registered tasks and their shutdown senders, guarded by one async lock so both
    /// maps are always observed together.
    runtime: RwLock<FlowRuntimeRegistry>,
    /// frontend client for insert request
    pub(crate) frontend_client: Arc<FrontendClient>,
    /// Flow metadata manager handle.
    // NOTE(review): not referenced by the code visible in this part of the file.
    flow_metadata_manager: FlowMetadataManagerRef,
    /// Table metadata manager, used to resolve table ids/names into table info.
    table_meta: TableMetadataManagerRef,
    /// Catalog manager that is handed to every task created by this engine.
    catalog_manager: CatalogManagerRef,
    /// Query engine used to plan flow queries; also passed to the task execution loop.
    query_engine: QueryEngineRef,
    /// Batching mode options that control how batching mode queries work.
    ///
    /// These are the engine-wide defaults; `batch_opts_for_flow_options` derives a
    /// per-flow copy from them when a flow overrides an option.
    pub(crate) batch_opts: Arc<BatchingModeOptions>,
}
75
/// Runtime bookkeeping for the flows owned by a `BatchingEngine`.
///
/// Both maps are keyed by flow id; the methods on this type update them together.
#[derive(Default)]
struct FlowRuntimeRegistry {
    /// The batching task registered for each flow id.
    tasks: BTreeMap<FlowId, BatchingTask>,
    /// The one-shot sender paired with each registered task (its receiver is handed to
    /// the task as `shutdown_rx` at creation time).
    shutdown_txs: BTreeMap<FlowId, oneshot::Sender<()>>,
}
81
82impl FlowRuntimeRegistry {
83    fn insert(
84        &mut self,
85        flow_id: FlowId,
86        task: BatchingTask,
87        shutdown_tx: oneshot::Sender<()>,
88    ) -> (Option<BatchingTask>, Option<oneshot::Sender<()>>) {
89        (
90            self.tasks.insert(flow_id, task),
91            self.shutdown_txs.insert(flow_id, shutdown_tx),
92        )
93    }
94
95    fn remove(&mut self, flow_id: FlowId) -> Option<(BatchingTask, Option<oneshot::Sender<()>>)> {
96        let task = self.tasks.remove(&flow_id)?;
97        let shutdown_tx = self.shutdown_txs.remove(&flow_id);
98        Some((task, shutdown_tx))
99    }
100
101    fn remove_if_current(
102        &mut self,
103        flow_id: FlowId,
104        task: &BatchingTask,
105    ) -> (Option<BatchingTask>, Option<oneshot::Sender<()>>) {
106        if self
107            .tasks
108            .get(&flow_id)
109            .is_some_and(|current| Arc::ptr_eq(&current.state, &task.state))
110        {
111            let Some((removed_task, removed_shutdown_tx)) = self.remove(flow_id) else {
112                return (None, None);
113            };
114            (Some(removed_task), removed_shutdown_tx)
115        } else {
116            (None, None)
117        }
118    }
119}
120
121impl BatchingEngine {
122    pub fn new(
123        frontend_client: Arc<FrontendClient>,
124        query_engine: QueryEngineRef,
125        flow_metadata_manager: FlowMetadataManagerRef,
126        table_meta: TableMetadataManagerRef,
127        catalog_manager: CatalogManagerRef,
128        batch_opts: BatchingModeOptions,
129    ) -> Self {
130        Self {
131            runtime: Default::default(),
132            frontend_client,
133            flow_metadata_manager,
134            table_meta,
135            catalog_manager,
136            query_engine,
137            batch_opts: Arc::new(batch_opts),
138        }
139    }
140
141    /// Returns last execution timestamps (millisecond) for all batching flows.
142    pub async fn get_last_exec_time_map(&self) -> BTreeMap<FlowId, i64> {
143        let runtime = self.runtime.read().await;
144        runtime
145            .tasks
146            .iter()
147            .filter_map(|(flow_id, task)| {
148                task.last_execution_time_millis()
149                    .map(|timestamp| (*flow_id, timestamp))
150            })
151            .collect()
152    }
153
154    pub async fn handle_mark_dirty_time_window(
155        &self,
156        reqs: DirtyWindowRequests,
157    ) -> Result<(), Error> {
158        let table_info_mgr = self.table_meta.table_info_manager();
159
160        let mut group_by_table_id: HashMap<u32, Vec<_>> = HashMap::new();
161        for r in reqs.requests {
162            let tid = TableId::from(r.table_id);
163            let entry = group_by_table_id.entry(tid).or_default();
164            entry.extend(r.timestamps);
165        }
166        let tids = group_by_table_id.keys().cloned().collect::<Vec<TableId>>();
167        let table_infos =
168            table_info_mgr
169                .batch_get(&tids)
170                .await
171                .with_context(|_| TableNotFoundMetaSnafu {
172                    msg: format!("Failed to get table info for table ids: {:?}", tids),
173                })?;
174
175        let group_by_table_name = group_by_table_id
176            .into_iter()
177            .filter_map(|(id, timestamps)| {
178                let table_name = table_infos.get(&id).map(|info| info.table_name());
179                let Some(table_name) = table_name else {
180                    warn!("Failed to get table infos for table id: {:?}", id);
181                    return None;
182                };
183                let table_name = [
184                    table_name.catalog_name,
185                    table_name.schema_name,
186                    table_name.table_name,
187                ];
188                let schema = &table_infos.get(&id).unwrap().table_info.meta.schema;
189                let time_index_unit = schema.column_schemas()[schema.timestamp_index().unwrap()]
190                    .data_type
191                    .as_timestamp()
192                    .unwrap()
193                    .unit();
194                Some((table_name, (timestamps, time_index_unit)))
195            })
196            .collect::<HashMap<_, _>>();
197
198        let group_by_table_name = Arc::new(group_by_table_name);
199
200        let tasks = self
201            .runtime
202            .read()
203            .await
204            .tasks
205            .values()
206            .cloned()
207            .collect::<Vec<_>>();
208        let mut handles = Vec::new();
209
210        for task in tasks {
211            let src_table_names = &task.config.source_table_names;
212
213            if src_table_names
214                .iter()
215                .all(|name| !group_by_table_name.contains_key(name))
216            {
217                continue;
218            }
219
220            let group_by_table_name = group_by_table_name.clone();
221            let task = task.clone();
222
223            let handle: JoinHandle<Result<(), Error>> = tokio::spawn(async move {
224                let src_table_names = &task.config.source_table_names;
225                let mut all_dirty_windows = HashSet::new();
226                let mut is_dirty = false;
227                for src_table_name in src_table_names {
228                    if let Some((timestamps, unit)) = group_by_table_name.get(src_table_name) {
229                        let Some(expr) = &task.config.time_window_expr else {
230                            is_dirty = true;
231                            continue;
232                        };
233                        for timestamp in timestamps {
234                            let align_start = expr
235                                .eval(common_time::Timestamp::new(*timestamp, *unit))?
236                                .0
237                                .context(UnexpectedSnafu {
238                                    reason: "Failed to eval start value",
239                                })?;
240                            all_dirty_windows.insert(align_start);
241                        }
242                    }
243                }
244                let mut state = task.state.write().unwrap();
245                if is_dirty {
246                    state.dirty_time_windows.set_dirty();
247                }
248                let flow_id_label = task.config.flow_id.to_string();
249                for timestamp in all_dirty_windows {
250                    state.dirty_time_windows.add_window(timestamp, None);
251                }
252
253                METRIC_FLOW_BATCHING_ENGINE_BULK_MARK_TIME_WINDOW
254                    .with_label_values(&[&flow_id_label])
255                    .set(state.dirty_time_windows.len() as f64);
256                Ok(())
257            });
258            handles.push(handle);
259        }
260        for handle in handles {
261            match handle.await {
262                Err(e) => {
263                    warn!("Failed to handle inserts: {e}");
264                }
265                Ok(Ok(())) => (),
266                Ok(Err(e)) => {
267                    warn!("Failed to handle inserts: {e}");
268                }
269            }
270        }
271
272        Ok(())
273    }
274
275    pub async fn handle_inserts_inner(
276        &self,
277        request: api::v1::region::InsertRequests,
278    ) -> Result<(), Error> {
279        let table_info_mgr = self.table_meta.table_info_manager();
280        let mut group_by_table_id: HashMap<TableId, Vec<api::v1::Rows>> = HashMap::new();
281
282        for r in request.requests {
283            let tid = RegionId::from(r.region_id).table_id();
284            let entry = group_by_table_id.entry(tid).or_default();
285            if let Some(rows) = r.rows {
286                entry.push(rows);
287            }
288        }
289
290        let tids = group_by_table_id.keys().cloned().collect::<Vec<TableId>>();
291        let table_infos =
292            table_info_mgr
293                .batch_get(&tids)
294                .await
295                .with_context(|_| TableNotFoundMetaSnafu {
296                    msg: format!("Failed to get table info for table ids: {:?}", tids),
297                })?;
298
299        let missing_tids = tids
300            .iter()
301            .filter(|id| !table_infos.contains_key(id))
302            .collect::<Vec<_>>();
303        if !missing_tids.is_empty() {
304            warn!(
305                "Failed to get all the table info for table ids, expected table ids: {:?}, those table doesn't exist: {:?}",
306                tids, missing_tids
307            );
308        }
309
310        let group_by_table_name = group_by_table_id
311            .into_iter()
312            .filter_map(|(id, rows)| {
313                let table_name = table_infos.get(&id).map(|info| info.table_name());
314                let Some(table_name) = table_name else {
315                    warn!("Failed to get table infos for table id: {:?}", id);
316                    return None;
317                };
318                let table_name = [
319                    table_name.catalog_name,
320                    table_name.schema_name,
321                    table_name.table_name,
322                ];
323                Some((table_name, rows))
324            })
325            .collect::<HashMap<_, _>>();
326
327        let group_by_table_name = Arc::new(group_by_table_name);
328
329        let tasks = self
330            .runtime
331            .read()
332            .await
333            .tasks
334            .values()
335            .cloned()
336            .collect::<Vec<_>>();
337        let mut handles = Vec::new();
338        for task in tasks {
339            let src_table_names = &task.config.source_table_names;
340
341            if src_table_names
342                .iter()
343                .all(|name| !group_by_table_name.contains_key(name))
344            {
345                continue;
346            }
347
348            let group_by_table_name = group_by_table_name.clone();
349            let task = task.clone();
350
351            let handle: JoinHandle<Result<(), Error>> = tokio::spawn(async move {
352                let src_table_names = &task.config.source_table_names;
353
354                let mut is_dirty = false;
355
356                for src_table_name in src_table_names {
357                    if let Some(entry) = group_by_table_name.get(src_table_name) {
358                        let Some(expr) = &task.config.time_window_expr else {
359                            is_dirty = true;
360                            continue;
361                        };
362                        let involved_time_windows = expr.handle_rows(entry.clone()).await?;
363                        let mut state = task.state.write().unwrap();
364                        state
365                            .dirty_time_windows
366                            .add_lower_bounds(involved_time_windows.into_iter());
367                    }
368                }
369                if is_dirty {
370                    task.state.write().unwrap().dirty_time_windows.set_dirty();
371                }
372
373                Ok(())
374            });
375            handles.push(handle);
376        }
377
378        for handle in handles {
379            match handle.await {
380                Err(e) => {
381                    warn!("Failed to handle inserts: {e}");
382                }
383                Ok(Ok(())) => (),
384                Ok(Err(e)) => {
385                    warn!("Failed to handle inserts: {e}");
386                }
387            }
388        }
389        Ok(())
390    }
391}
392
393impl FlowStatProvider for BatchingEngine {
394    async fn flow_stat(&self) -> FlowStat {
395        FlowStat {
396            state_size: BTreeMap::new(),
397            last_exec_time_map: self
398                .get_last_exec_time_map()
399                .await
400                .into_iter()
401                .map(|(flow_id, timestamp)| (flow_id as u32, timestamp))
402                .collect(),
403        }
404    }
405}
406
407async fn get_table_name(
408    table_info: &TableInfoManager,
409    table_id: &TableId,
410) -> Result<TableName, Error> {
411    get_table_info(table_info, table_id).await.map(|info| {
412        let name = info.table_name();
413        [name.catalog_name, name.schema_name, name.table_name]
414    })
415}
416
417async fn get_table_info(
418    table_info: &TableInfoManager,
419    table_id: &TableId,
420) -> Result<TableInfoValue, Error> {
421    table_info
422        .get(*table_id)
423        .await
424        .map_err(BoxedError::new)
425        .context(ExternalSnafu)?
426        .with_context(|| UnexpectedSnafu {
427            reason: format!("Table id = {:?}, couldn't found table name", table_id),
428        })
429        .map(|info| info.into_inner())
430}
431
432impl BatchingEngine {
433    fn batch_opts_for_flow_options(
434        &self,
435        flow_options: &HashMap<String, String>,
436    ) -> Result<Arc<BatchingModeOptions>, Error> {
437        let mut batch_opts = (*self.batch_opts).clone();
438        if let Some(enable_incremental_read) =
439            flow_options.get(FLOW_EXPERIMENTAL_ENABLE_INCREMENTAL_READ_KEY)
440        {
441            batch_opts.experimental_enable_incremental_read = enable_incremental_read
442                .parse::<bool>()
443                .map_err(|_| {
444                    InvalidQuerySnafu {
445                        reason: format!(
446                            "Invalid flow option {FLOW_EXPERIMENTAL_ENABLE_INCREMENTAL_READ_KEY}: {enable_incremental_read}"
447                        ),
448                    }
449                    .build()
450                })?;
451        }
452
453        Ok(Arc::new(batch_opts))
454    }
455
456    fn table_options_enable_append_mode(extra_options: &HashMap<String, String>) -> bool {
457        extra_options
458            .get(APPEND_MODE_KEY)
459            .is_some_and(|value| value.eq_ignore_ascii_case("true"))
460    }
461
462    /// SQL flows without a usable time-window expression can only run as an
463    /// explicit full-query flow, so require `EVAL INTERVAL` at creation time.
464    fn ensure_sql_flow_has_twe_or_eval_interval(
465        eval_interval: Option<i64>,
466        has_time_window_expr: bool,
467    ) -> Result<(), Error> {
468        ensure!(
469            eval_interval.is_some() || has_time_window_expr,
470            InvalidQuerySnafu {
471                reason: "SQL batching flow without a time-window expression must specify EVAL INTERVAL to run as an explicit full-query flow"
472                    .to_string(),
473            }
474        );
475        Ok(())
476    }
477
478    fn ensure_incremental_source_append_only(
479        batch_opts: &BatchingModeOptions,
480        table_name: &[String; 3],
481        extra_options: &HashMap<String, String>,
482    ) -> Result<(), Error> {
483        if batch_opts.experimental_enable_incremental_read {
484            ensure!(
485                Self::table_options_enable_append_mode(extra_options),
486                UnsupportedSnafu {
487                    reason: format!(
488                        "Flow incremental read requires append-only source table, but source table `{}` is not append-only. Consider setting append_mode='true' on the source table or disabling experimental_enable_incremental_read",
489                        table_name.join(".")
490                    ),
491                }
492            );
493        }
494
495        Ok(())
496    }
497
498    pub async fn create_flow_inner(&self, args: CreateFlowArgs) -> Result<Option<FlowId>, Error> {
499        let CreateFlowArgs {
500            flow_id,
501            sink_table_name,
502            source_table_ids,
503            create_if_not_exists,
504            or_replace,
505            expire_after,
506            eval_interval,
507            comment: _,
508            sql,
509            flow_options,
510            query_ctx,
511            eval_schedule: eval_schedule_config,
512        } = args;
513
514        // or replace logic
515        {
516            let is_exist = self.runtime.read().await.tasks.contains_key(&flow_id);
517            match (create_if_not_exists, or_replace, is_exist) {
518                // if replace, ignore that old flow exists
519                (_, true, true) => {
520                    info!("Replacing flow with id={}", flow_id);
521                }
522                (false, false, true) => FlowAlreadyExistSnafu { id: flow_id }.fail()?,
523                // already exists, and not replace, return None
524                (true, false, true) => {
525                    info!("Flow with id={} already exists, do nothing", flow_id);
526                    return Ok(None);
527                }
528
529                // continue as normal
530                (_, _, false) => (),
531            }
532        }
533
534        let query_ctx = query_ctx.context({
535            UnexpectedSnafu {
536                reason: "Query context is None".to_string(),
537            }
538        })?;
539        let query_ctx = Arc::new(query_ctx);
540        let is_tql = is_tql(query_ctx.sql_dialect(), &sql)
541            .map_err(BoxedError::new)
542            .context(CreateFlowSnafu { sql: &sql })?;
543
544        // optionally set a eval interval for the flow
545        if eval_interval.is_none() && is_tql {
546            InvalidQuerySnafu {
547                reason: "TQL query requires EVAL INTERVAL to be set".to_string(),
548            }
549            .fail()?;
550        }
551
552        let flow_type = flow_options.get(FlowType::FLOW_TYPE_KEY);
553
554        ensure!(
555            match flow_type {
556                None => true,
557                Some(ty) if ty == FlowType::BATCHING => true,
558                _ => false,
559            },
560            UnexpectedSnafu {
561                reason: format!("Flow type is not batching nor None, got {flow_type:?}")
562            }
563        );
564
565        let batch_opts = self.batch_opts_for_flow_options(&flow_options)?;
566
567        let mut source_table_names = Vec::with_capacity(2);
568        for src_id in source_table_ids {
569            // also check table option to see if ttl!=instant
570            let table_name = get_table_name(self.table_meta.table_info_manager(), &src_id).await?;
571            let table_info = get_table_info(self.table_meta.table_info_manager(), &src_id).await?;
572            ensure!(
573                table_info.table_info.meta.options.ttl != Some(TimeToLive::Instant),
574                UnsupportedSnafu {
575                    reason: format!(
576                        "Source table `{}`(id={}) has instant TTL, Instant TTL is not supported under batching mode. Consider using a TTL longer than flush interval",
577                        table_name.join("."),
578                        src_id
579                    ),
580                }
581            );
582            Self::ensure_incremental_source_append_only(
583                &batch_opts,
584                &table_name,
585                &table_info.table_info.meta.options.extra_options,
586            )?;
587
588            source_table_names.push(table_name);
589        }
590
591        let (tx, rx) = oneshot::channel();
592
593        let plan = sql_to_df_plan(query_ctx.clone(), self.query_engine.clone(), &sql, true).await?;
594
595        if is_tql {
596            self.check_is_tql_table(&plan, &query_ctx).await?;
597        }
598
599        let phy_expr = if !is_tql {
600            let (column_name, time_window_expr, _, df_schema) = find_time_window_expr(
601                &plan,
602                self.query_engine.engine_state().catalog_manager().clone(),
603                query_ctx.clone(),
604            )
605            .await?;
606            time_window_expr
607                .map(|expr| {
608                    TimeWindowExpr::from_expr(
609                        &expr,
610                        &column_name,
611                        &df_schema,
612                        &self.query_engine.engine_state().session_state(),
613                    )
614                })
615                .transpose()?
616        } else {
617            // tql control by `EVAL INTERVAL`, no need to find time window expr
618            None
619        };
620
621        debug!(
622            "Flow id={}, found time window expr={}",
623            flow_id,
624            phy_expr
625                .as_ref()
626                .map(|phy_expr| phy_expr.to_string())
627                .unwrap_or("None".to_string())
628        );
629
630        if !is_tql {
631            Self::ensure_sql_flow_has_twe_or_eval_interval(eval_interval, phy_expr.is_some())?;
632        }
633
634        // Compute typed EvalSchedule from FlowScheduleConfig.
635        let eval_schedule = {
636            let interval = eval_interval;
637            let config = eval_schedule_config.as_ref();
638            match EvalSchedule::from_config(interval, config) {
639                Ok(s) => s,
640                Err(e) => {
641                    return UnexpectedSnafu {
642                        reason: format!(
643                            "Failed to build eval schedule for flow {}: {}",
644                            flow_id, e
645                        ),
646                    }
647                    .fail();
648                }
649            }
650        };
651
652        let task_args = TaskArgs {
653            flow_id,
654            query: &sql,
655            plan,
656            time_window_expr: phy_expr,
657            expire_after,
658            sink_table_name,
659            source_table_names,
660            query_ctx,
661            catalog_manager: self.catalog_manager.clone(),
662            shutdown_rx: rx,
663            batch_opts,
664            flow_eval_interval: eval_interval.map(|secs| Duration::from_secs(secs as u64)),
665            eval_schedule,
666        };
667
668        let task = BatchingTask::try_new(task_args)?;
669
670        let task_inner = task.clone();
671        let engine = self.query_engine.clone();
672        let frontend = self.frontend_client.clone();
673
674        // Create sink table if needed, then validate an existing/created sink schema before
675        // spawning the background task. This catches user-created sink schema mismatches at
676        // CREATE FLOW time instead of surfacing them later in the execution loop.
677        task.check_or_create_sink_table(&engine, &frontend).await?;
678        task.validate_sink_table_schema(&engine).await?;
679
680        let (start_tx, start_rx) = oneshot::channel();
681
682        // TODO(discord9): use time wheel or what for better
683        let handle = common_runtime::spawn_global(async move {
684            if start_rx.await.is_ok() {
685                task_inner.start_executing_loop(engine, frontend).await;
686            }
687        });
688        task.state.write().unwrap().task_handle = Some(handle);
689        let task_for_rollback = task.clone();
690
691        // Only replace here, not earlier, because we want the old one intact if
692        // something went wrong before this line. Keep the task and shutdown
693        // sender in one registry lock so create/remove can't observe one
694        // without the other.
695        let (replaced_old_task_opt, replaced_old_shutdown_tx) = {
696            let mut runtime = self.runtime.write().await;
697
698            let is_exist = runtime.tasks.contains_key(&flow_id);
699            match (create_if_not_exists, or_replace, is_exist) {
700                (_, true, true) => {
701                    info!(
702                        "Replacing flow with id={} after final registry check",
703                        flow_id
704                    );
705                }
706                (false, false, true) => {
707                    abort_flow_task(flow_id, Some(task), "unregistered");
708                    return FlowAlreadyExistSnafu { id: flow_id }.fail();
709                }
710                (true, false, true) => {
711                    info!(
712                        "Flow with id={} already exists at final registry check, do nothing",
713                        flow_id
714                    );
715                    abort_flow_task(flow_id, Some(task), "unregistered");
716                    return Ok(None);
717                }
718                (_, _, false) => (),
719            }
720
721            runtime.insert(flow_id, task, tx)
722        };
723
724        notify_flow_shutdown(flow_id, replaced_old_shutdown_tx, "replaced");
725        abort_flow_task(flow_id, replaced_old_task_opt, "replaced");
726        if start_tx.send(()).is_err() {
727            self.rollback_flow_runtime_if_current(flow_id, &task_for_rollback)
728                .await;
729            UnexpectedSnafu {
730                reason: format!("Failed to start flow {flow_id} due to task already dropped"),
731            }
732            .fail()?;
733        }
734
735        Ok(Some(flow_id))
736    }
737
    /// Verifies that every table scanned by a TQL flow query has a schema this check
    /// accepts.
    ///
    /// All `TableScan` nodes of `query` (subqueries included) are collected and each
    /// referenced table is resolved against the metadata store, using the catalog and
    /// schema of `query_ctx` for partially qualified names. A table passes when:
    /// - it has exactly one `float64` column, and
    /// - every column that is not a metric-engine internal column is either `float64`,
    ///   a timestamp, or a string column that is part of the primary key.
    ///
    /// # Errors
    ///
    /// - A `Datafusion` error when walking the plan fails.
    /// - An `External`/`Unexpected` error when a table id or table info cannot be
    ///   resolved.
    /// - An `InvalidQuery` error when a table violates one of the rules above.
    async fn check_is_tql_table(
        &self,
        query: &LogicalPlan,
        query_ctx: &QueryContext,
    ) -> Result<(), Error> {
        /// Plan visitor that records the name of every scanned table.
        struct CollectTableRef {
            table_refs: HashSet<datafusion_common::TableReference>,
        }

        impl TreeNodeVisitor<'_> for CollectTableRef {
            type Node = LogicalPlan;
            fn f_down(
                &mut self,
                node: &Self::Node,
            ) -> datafusion_common::Result<TreeNodeRecursion> {
                if let LogicalPlan::TableScan(scan) = node {
                    self.table_refs.insert(scan.table_name.clone());
                }
                // Always keep descending: scans can appear anywhere in the plan.
                Ok(TreeNodeRecursion::Continue)
            }
        }
        let mut table_refs = CollectTableRef {
            table_refs: HashSet::new(),
        };
        query
            .visit_with_subqueries(&mut table_refs)
            .context(DatafusionSnafu {
                context: "Checking if all source tables are TQL tables",
            })?;

        // Defaults used to complete bare / partially qualified table references.
        let default_catalog = query_ctx.current_catalog();
        let default_schema = query_ctx.current_schema();
        let default_schema = &default_schema;

        for table_ref in table_refs.table_refs {
            // Normalize the DataFusion reference into a fully qualified table reference.
            let table_ref = match &table_ref {
                datafusion_common::TableReference::Bare { table } => {
                    TableReference::full(default_catalog, default_schema, table)
                }
                datafusion_common::TableReference::Partial { schema, table } => {
                    TableReference::full(default_catalog, schema, table)
                }
                datafusion_common::TableReference::Full {
                    catalog,
                    schema,
                    table,
                } => TableReference::full(catalog, schema, table),
            };

            let table_id = self
                .table_meta
                .table_name_manager()
                .get(table_ref.into())
                .await
                .map_err(BoxedError::new)
                .context(ExternalSnafu)?
                .with_context(|| UnexpectedSnafu {
                    reason: format!("Failed to get table id for table: {}", table_ref),
                })?
                .table_id();
            let table_info =
                get_table_info(self.table_meta.table_info_manager(), &table_id).await?;
            // First rule: the table must have exactly one f64 column. Note this counts
            // every float64 column of the schema, whatever its role.
            let value_cols = table_info
                .table_info
                .meta
                .schema
                .column_schemas()
                .iter()
                .filter(|col| col.data_type == ConcreteDataType::float64_datatype())
                .collect::<Vec<_>>();
            ensure!(
                value_cols.len() == 1,
                InvalidQuerySnafu {
                    reason: format!(
                        "TQL query only supports one f64 value column, table `{}`(id={}) has {} f64 value columns, columns are: {:?}",
                        table_ref,
                        table_id,
                        value_cols.len(),
                        value_cols
                    ),
                }
            );
            // TODO(discord9): do need to check rest columns is string and is tag column?
            // Column indices that belong to the primary key, for the tag check below.
            let pk_idxs = table_info
                .table_info
                .meta
                .primary_key_indices
                .iter()
                .collect::<HashSet<_>>();

            for (idx, col) in table_info
                .table_info
                .meta
                .schema
                .column_schemas()
                .iter()
                .enumerate()
            {
                // Metric-engine internal columns are exempt from the type rules.
                if is_metric_engine_internal_column(&col.name) {
                    continue;
                }
                // three cases:
                // 1. val column
                // 2. timestamp column
                // 3. tag column (string)

                // `pk_idxs` stores `&usize`, hence the double reference in the lookup.
                let is_pk: bool = pk_idxs.contains(&&idx);

                ensure!(
                    col.data_type == ConcreteDataType::float64_datatype()
                        || col.data_type.is_timestamp()
                        || (col.data_type == ConcreteDataType::string_datatype() && is_pk),
                    InvalidQuerySnafu {
                        reason: format!(
                            "TQL query only supports f64 value column, timestamp column and string tag columns, table `{}`(id={}) has column `{}` with type {:?} which is not supported",
                            table_ref, table_id, col.name, col.data_type
                        ),
                    }
                );
            }
        }
        Ok(())
    }
862
863    pub async fn remove_flow_inner(&self, flow_id: FlowId) -> Result<(), Error> {
864        let (task, shutdown_tx) = {
865            let mut runtime = self.runtime.write().await;
866            let Some((task, shutdown_tx)) = runtime.remove(flow_id) else {
867                warn!("Flow {flow_id} not found in tasks");
868                FlowNotFoundSnafu { id: flow_id }.fail()?
869            };
870            (task, shutdown_tx)
871        };
872
873        let had_shutdown_tx = notify_flow_shutdown(flow_id, shutdown_tx, "removed");
874        abort_flow_task(flow_id, Some(task), "removed");
875
876        if !had_shutdown_tx {
877            UnexpectedSnafu {
878                reason: format!("Can't found shutdown tx for flow {flow_id}"),
879            }
880            .fail()?
881        }
882
883        Ok(())
884    }
885
886    /// Only flush the dirty windows of the flow task with given flow id, by running the query on it.
887    /// As flush the whole time range is usually prohibitively expensive.
888    pub async fn flush_flow_inner(&self, flow_id: FlowId) -> Result<usize, Error> {
889        debug!("Try flush flow {flow_id}");
890        // need to wait a bit to ensure previous mirror insert is handled
891        // this is only useful for the case when we are flushing the flow right after inserting data into it
892        // TODO(discord9): find a better way to ensure the data is ready, maybe inform flownode from frontend?
893        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
894        let task = self.runtime.read().await.tasks.get(&flow_id).cloned();
895        let task = task.with_context(|| FlowNotFoundSnafu { id: flow_id })?;
896
897        let time_window_size = task
898            .config
899            .time_window_expr
900            .as_ref()
901            .and_then(|expr| *expr.time_window_size());
902
903        let cur_dirty_window_cnt = time_window_size.map(|time_window_size| {
904            task.state
905                .read()
906                .unwrap()
907                .dirty_time_windows
908                .effective_count(&time_window_size)
909        });
910
911        let res = task
912            .execute_once_serialized(
913                &self.query_engine,
914                &self.frontend_client,
915                cur_dirty_window_cnt,
916            )
917            .await?;
918
919        let affected_rows = res.map(|(r, _)| r).unwrap_or_default();
920        debug!(
921            "Successfully flush flow {flow_id}, affected rows={}",
922            affected_rows
923        );
924        Ok(affected_rows)
925    }
926
927    /// Determine if the batching mode flow task exists with given flow id
928    pub async fn flow_exist_inner(&self, flow_id: FlowId) -> bool {
929        self.runtime.read().await.tasks.contains_key(&flow_id)
930    }
931
932    async fn rollback_flow_runtime_if_current(&self, flow_id: FlowId, task: &BatchingTask) {
933        let (removed_task, removed_shutdown_tx) = {
934            let mut runtime = self.runtime.write().await;
935            runtime.remove_if_current(flow_id, task)
936        };
937
938        notify_flow_shutdown(flow_id, removed_shutdown_tx, "rolled back");
939        abort_flow_task(flow_id, removed_task, "rolled back");
940    }
941}
942
943fn notify_flow_shutdown(flow_id: FlowId, tx: Option<oneshot::Sender<()>>, action: &str) -> bool {
944    let Some(tx) = tx else {
945        return false;
946    };
947
948    if tx.send(()).is_err() {
949        warn!(
950            "Fail to shutdown {action} flow {flow_id} due to receiver already dropped, maybe flow {flow_id} is already dropped?"
951        );
952    }
953
954    true
955}
956
957fn abort_flow_task(flow_id: FlowId, task: Option<BatchingTask>, action: &str) -> bool {
958    let Some(task) = task else {
959        return false;
960    };
961
962    if let Some(handle) = task.state.write().unwrap().task_handle.take() {
963        handle.abort();
964        debug!("Aborted {action} flow task {flow_id}");
965        return true;
966    }
967
968    false
969}
970
971impl FlowEngine for BatchingEngine {
972    async fn create_flow(&self, args: CreateFlowArgs) -> Result<Option<FlowId>, Error> {
973        self.create_flow_inner(args).await
974    }
975    async fn remove_flow(&self, flow_id: FlowId) -> Result<(), Error> {
976        self.remove_flow_inner(flow_id).await
977    }
978    async fn flush_flow(&self, flow_id: FlowId) -> Result<usize, Error> {
979        self.flush_flow_inner(flow_id).await
980    }
981    async fn flow_exist(&self, flow_id: FlowId) -> Result<bool, Error> {
982        Ok(self.flow_exist_inner(flow_id).await)
983    }
984    async fn list_flows(&self) -> Result<impl IntoIterator<Item = FlowId>, Error> {
985        Ok(self
986            .runtime
987            .read()
988            .await
989            .tasks
990            .keys()
991            .cloned()
992            .collect::<Vec<_>>())
993    }
994    async fn handle_flow_inserts(
995        &self,
996        request: api::v1::region::InsertRequests,
997    ) -> Result<(), Error> {
998        self.handle_inserts_inner(request).await
999    }
1000    async fn handle_mark_window_dirty(
1001        &self,
1002        req: api::v1::flow::DirtyWindowRequests,
1003    ) -> Result<(), Error> {
1004        self.handle_mark_dirty_time_window(req).await
1005    }
1006}
1007
1008#[cfg(test)]
1009mod tests {
1010    use catalog::memory::new_memory_catalog_manager;
1011    use common_meta::key::TableMetadataManager;
1012    use common_meta::key::flow::FlowMetadataManager;
1013    use common_meta::kv_backend::memory::MemoryKvBackend;
1014    use query::options::QueryOptions;
1015    use session::context::QueryContext;
1016
1017    use super::*;
1018    use crate::test_utils::create_test_query_engine;
1019
1020    struct DropNotify(Option<oneshot::Sender<()>>);
1021
1022    impl Drop for DropNotify {
1023        fn drop(&mut self) {
1024            if let Some(tx) = self.0.take() {
1025                let _ = tx.send(());
1026            }
1027        }
1028    }
1029
1030    async fn new_test_engine() -> BatchingEngine {
1031        let kv_backend = Arc::new(MemoryKvBackend::new());
1032        let table_meta = Arc::new(TableMetadataManager::new(kv_backend.clone()));
1033        table_meta.init().await.unwrap();
1034        let flow_meta = Arc::new(FlowMetadataManager::new(kv_backend));
1035        let catalog_manager = new_memory_catalog_manager().unwrap();
1036        let query_engine = create_test_query_engine();
1037        let (frontend_client, _handler) =
1038            FrontendClient::from_empty_grpc_handler(QueryOptions::default());
1039
1040        BatchingEngine::new(
1041            Arc::new(frontend_client),
1042            query_engine,
1043            flow_meta,
1044            table_meta,
1045            catalog_manager,
1046            BatchingModeOptions::default(),
1047        )
1048    }
1049
1050    #[tokio::test]
1051    async fn test_flow_option_overrides_incremental_read_switch() {
1052        let engine = new_test_engine().await;
1053
1054        let default_opts = engine.batch_opts_for_flow_options(&HashMap::new()).unwrap();
1055        assert!(!default_opts.experimental_enable_incremental_read);
1056
1057        let enabled_opts = engine
1058            .batch_opts_for_flow_options(&HashMap::from([(
1059                FLOW_EXPERIMENTAL_ENABLE_INCREMENTAL_READ_KEY.to_string(),
1060                "true".to_string(),
1061            )]))
1062            .unwrap();
1063        assert!(enabled_opts.experimental_enable_incremental_read);
1064    }
1065
1066    #[test]
1067    fn test_table_options_enable_append_mode() {
1068        assert!(!BatchingEngine::table_options_enable_append_mode(
1069            &HashMap::new()
1070        ));
1071        assert!(!BatchingEngine::table_options_enable_append_mode(
1072            &HashMap::from([(APPEND_MODE_KEY.to_string(), "false".to_string())])
1073        ));
1074        assert!(BatchingEngine::table_options_enable_append_mode(
1075            &HashMap::from([(APPEND_MODE_KEY.to_string(), "TRUE".to_string())])
1076        ));
1077    }
1078
1079    #[test]
1080    fn test_sql_flow_requires_time_window_or_eval_interval() {
1081        BatchingEngine::ensure_sql_flow_has_twe_or_eval_interval(None, true)
1082            .expect("SQL flow with a time-window expression should be accepted");
1083        BatchingEngine::ensure_sql_flow_has_twe_or_eval_interval(Some(10), false).expect(
1084            "SQL flow with EVAL INTERVAL should be accepted as an explicit full-query flow",
1085        );
1086
1087        let err = BatchingEngine::ensure_sql_flow_has_twe_or_eval_interval(None, false)
1088            .expect_err("SQL flow without a time-window expression or EVAL INTERVAL should fail");
1089        assert!(matches!(err, Error::InvalidQuery { .. }), "{err}");
1090        assert!(
1091            err.to_string().contains("must specify EVAL INTERVAL"),
1092            "{err}"
1093        );
1094    }
1095
1096    #[tokio::test]
1097    async fn test_complex_sql_without_eval_interval_is_rejected_as_no_twe() {
1098        let query_engine = create_test_query_engine();
1099        let ctx = QueryContext::arc();
1100        let plan = sql_to_df_plan(
1101            ctx.clone(),
1102            query_engine.clone(),
1103            r#"
1104SELECT
1105    l.number,
1106    date_bin('5 minutes', l.ts) AS time_window
1107FROM numbers_with_ts l
1108JOIN numbers_with_ts r ON l.number = r.number
1109GROUP BY l.number, time_window
1110"#,
1111            true,
1112        )
1113        .await
1114        .unwrap();
1115
1116        let (_, time_window_expr, _, _) = find_time_window_expr(
1117            &plan,
1118            query_engine.engine_state().catalog_manager().clone(),
1119            ctx,
1120        )
1121        .await
1122        .unwrap();
1123        assert!(
1124            time_window_expr.is_none(),
1125            "complex SQL should be classified as having no safe TWE"
1126        );
1127
1128        BatchingEngine::ensure_sql_flow_has_twe_or_eval_interval(Some(10), false)
1129            .expect("complex SQL can run as an explicit full-query flow when EVAL INTERVAL is set");
1130        let err = BatchingEngine::ensure_sql_flow_has_twe_or_eval_interval(None, false)
1131            .expect_err("complex SQL without EVAL INTERVAL should fail creation");
1132        assert!(matches!(err, Error::InvalidQuery { .. }), "{err}");
1133    }
1134
1135    #[test]
1136    fn test_incremental_source_append_only_enforcement() {
1137        let table_name = [
1138            "greptime".to_string(),
1139            "public".to_string(),
1140            "numbers".to_string(),
1141        ];
1142        let disabled_opts = BatchingModeOptions::default();
1143        let enabled_opts = BatchingModeOptions {
1144            experimental_enable_incremental_read: true,
1145            ..Default::default()
1146        };
1147        let non_append_options = HashMap::new();
1148        let append_options = HashMap::from([(APPEND_MODE_KEY.to_string(), "true".to_string())]);
1149
1150        BatchingEngine::ensure_incremental_source_append_only(
1151            &disabled_opts,
1152            &table_name,
1153            &non_append_options,
1154        )
1155        .expect("disabled incremental read should not require append-only source");
1156        BatchingEngine::ensure_incremental_source_append_only(
1157            &enabled_opts,
1158            &table_name,
1159            &append_options,
1160        )
1161        .expect("append-only source should be accepted when incremental read is enabled");
1162
1163        let err = BatchingEngine::ensure_incremental_source_append_only(
1164            &enabled_opts,
1165            &table_name,
1166            &non_append_options,
1167        )
1168        .expect_err("non-append source should be rejected when incremental read is enabled");
1169        assert!(
1170            err.to_string()
1171                .contains("Flow incremental read requires append-only source table"),
1172            "{err}"
1173        );
1174    }
1175
1176    async fn new_test_task(flow_id: FlowId) -> (BatchingTask, oneshot::Sender<()>) {
1177        let query_engine = create_test_query_engine();
1178        let ctx = QueryContext::arc();
1179        let plan = sql_to_df_plan(
1180            ctx.clone(),
1181            query_engine.clone(),
1182            "SELECT number, ts FROM numbers_with_ts",
1183            true,
1184        )
1185        .await
1186        .unwrap();
1187        let (tx, rx) = oneshot::channel();
1188
1189        let task = BatchingTask::try_new(TaskArgs {
1190            flow_id,
1191            query: "SELECT number, ts FROM numbers_with_ts",
1192            plan,
1193            time_window_expr: None,
1194            expire_after: None,
1195            sink_table_name: [
1196                "greptime".to_string(),
1197                "public".to_string(),
1198                "sink".to_string(),
1199            ],
1200            source_table_names: vec![[
1201                "greptime".to_string(),
1202                "public".to_string(),
1203                "numbers_with_ts".to_string(),
1204            ]],
1205            query_ctx: ctx,
1206            catalog_manager: query_engine.engine_state().catalog_manager().clone(),
1207            shutdown_rx: rx,
1208            batch_opts: Arc::new(BatchingModeOptions::default()),
1209            flow_eval_interval: None,
1210            eval_schedule: None,
1211        })
1212        .unwrap();
1213
1214        (task, tx)
1215    }
1216
1217    async fn install_abort_observed_handle(task: &BatchingTask) -> oneshot::Receiver<()> {
1218        let (drop_tx, drop_rx) = oneshot::channel();
1219        let (entered_tx, entered_rx) = oneshot::channel();
1220        let handle = tokio::spawn(async move {
1221            let _guard = DropNotify(Some(drop_tx));
1222            let _ = entered_tx.send(());
1223            std::future::pending::<()>().await;
1224        });
1225        task.state.write().unwrap().task_handle = Some(handle);
1226        tokio::time::timeout(Duration::from_secs(1), entered_rx)
1227            .await
1228            .expect("test task handle should start")
1229            .expect("test task handle should report start");
1230        drop_rx
1231    }
1232
1233    #[tokio::test]
1234    async fn test_notify_flow_shutdown_sends_signal() {
1235        let (tx, rx) = oneshot::channel();
1236
1237        assert!(notify_flow_shutdown(42, Some(tx), "test"));
1238
1239        rx.await.expect("replaced flow should receive shutdown");
1240    }
1241
1242    #[test]
1243    fn test_notify_flow_shutdown_accepts_missing_sender() {
1244        assert!(!notify_flow_shutdown(42, None, "test"));
1245    }
1246
1247    #[tokio::test]
1248    async fn test_abort_flow_task_aborts_handle() {
1249        let (task, _shutdown_tx) = new_test_task(42).await;
1250        let drop_rx = install_abort_observed_handle(&task).await;
1251
1252        assert!(abort_flow_task(42, Some(task), "test"));
1253
1254        tokio::time::timeout(Duration::from_secs(1), drop_rx)
1255            .await
1256            .expect("aborted task should be dropped")
1257            .expect("drop notifier should fire");
1258    }
1259
1260    #[tokio::test]
1261    async fn test_remove_flow_inner_aborts_registered_task() {
1262        let engine = new_test_engine().await;
1263        let (task, shutdown_tx) = new_test_task(42).await;
1264        let drop_rx = install_abort_observed_handle(&task).await;
1265
1266        engine.runtime.write().await.insert(42, task, shutdown_tx);
1267
1268        engine.remove_flow_inner(42).await.unwrap();
1269
1270        tokio::time::timeout(Duration::from_secs(1), drop_rx)
1271            .await
1272            .expect("removed task should be dropped")
1273            .expect("drop notifier should fire");
1274        assert!(!engine.flow_exist_inner(42).await);
1275        assert!(!engine.runtime.read().await.shutdown_txs.contains_key(&42));
1276    }
1277
1278    #[tokio::test]
1279    async fn test_or_replace_flow_runtime_replaces_old_handles_and_keeps_new_task() {
1280        let engine = new_test_engine().await;
1281        let (old_task, old_shutdown_tx) = new_test_task(42).await;
1282        let old_task_identity = old_task.clone();
1283        let old_drop_rx = install_abort_observed_handle(&old_task).await;
1284        let (new_task, new_shutdown_tx) = new_test_task(42).await;
1285        let new_task_identity = new_task.clone();
1286
1287        engine
1288            .runtime
1289            .write()
1290            .await
1291            .insert(42, old_task, old_shutdown_tx);
1292        let (replaced_old_task, replaced_old_shutdown_tx) =
1293            engine
1294                .runtime
1295                .write()
1296                .await
1297                .insert(42, new_task, new_shutdown_tx);
1298
1299        let replaced_old_task = replaced_old_task.expect("old task should be returned");
1300        assert!(Arc::ptr_eq(
1301            &replaced_old_task.state,
1302            &old_task_identity.state
1303        ));
1304        assert!(notify_flow_shutdown(
1305            42,
1306            replaced_old_shutdown_tx,
1307            "replaced"
1308        ));
1309        old_task_identity
1310            .state
1311            .write()
1312            .unwrap()
1313            .shutdown_rx
1314            .try_recv()
1315            .expect("old shutdown receiver should receive signal");
1316        assert!(abort_flow_task(42, Some(replaced_old_task), "replaced"));
1317
1318        tokio::time::timeout(Duration::from_secs(1), old_drop_rx)
1319            .await
1320            .expect("replaced task should be dropped")
1321            .expect("drop notifier should fire");
1322
1323        let runtime = engine.runtime.read().await;
1324        assert_eq!(1, runtime.tasks.len());
1325        assert_eq!(1, runtime.shutdown_txs.len());
1326        let registered_task = runtime.tasks.get(&42).expect("new task should remain");
1327        assert!(Arc::ptr_eq(
1328            &registered_task.state,
1329            &new_task_identity.state
1330        ));
1331        assert!(runtime.shutdown_txs.contains_key(&42));
1332        assert!(matches!(
1333            new_task_identity
1334                .state
1335                .write()
1336                .unwrap()
1337                .shutdown_rx
1338                .try_recv(),
1339            Err(oneshot::error::TryRecvError::Empty)
1340        ));
1341    }
1342
1343    #[tokio::test]
1344    async fn test_rollback_flow_runtime_if_current_removes_matching_task_only() {
1345        let engine = new_test_engine().await;
1346        let (old_task, _old_shutdown_tx) = new_test_task(42).await;
1347        let (current_task, current_shutdown_tx) = new_test_task(42).await;
1348        let current_task_identity = current_task.clone();
1349
1350        engine
1351            .runtime
1352            .write()
1353            .await
1354            .insert(42, current_task, current_shutdown_tx);
1355
1356        engine.rollback_flow_runtime_if_current(42, &old_task).await;
1357
1358        let registered_task = engine.runtime.read().await.tasks.get(&42).cloned().unwrap();
1359        assert!(Arc::ptr_eq(
1360            &registered_task.state,
1361            &current_task_identity.state
1362        ));
1363        assert!(engine.runtime.read().await.shutdown_txs.contains_key(&42));
1364
1365        engine
1366            .rollback_flow_runtime_if_current(42, &current_task_identity)
1367            .await;
1368        assert!(!engine.flow_exist_inner(42).await);
1369        assert!(!engine.runtime.read().await.shutdown_txs.contains_key(&42));
1370    }
1371}