Skip to main content

query/
datafusion.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Planner, QueryEngine implementations based on DataFusion.
16
17mod error;
18mod json_expr_planner;
19mod planner;
20
21use std::any::Any;
22use std::collections::HashMap;
23use std::sync::Arc;
24
25use async_trait::async_trait;
26use common_base::Plugins;
27use common_catalog::consts::is_readonly_schema;
28use common_error::ext::BoxedError;
29use common_function::function::FunctionContext;
30use common_function::function_factory::ScalarFunctionFactory;
31use common_query::{Output, OutputData, OutputMeta};
32use common_recordbatch::adapter::RecordBatchStreamAdapter;
33use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream};
34use common_telemetry::tracing;
35use datafusion::catalog::TableFunction;
36use datafusion::dataframe::DataFrame;
37use datafusion::physical_plan::ExecutionPlan;
38use datafusion::physical_plan::analyze::AnalyzeExec;
39use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
40use datafusion_common::ResolvedTableReference;
41use datafusion_expr::{
42    AggregateUDF, DmlStatement, LogicalPlan as DfLogicalPlan, LogicalPlan, WindowUDF, WriteOp,
43};
44use datatypes::prelude::VectorRef;
45use datatypes::schema::Schema;
46use futures_util::StreamExt;
47use session::context::QueryContextRef;
48use snafu::{OptionExt, ResultExt, ensure};
49use sqlparser::ast::AnalyzeFormat;
50use table::TableRef;
51use table::requests::{DeleteRequest, InsertRequest};
52use table::table::scan::{REGION_SCAN_EXEC_NAME, RegionScanExec};
53use tracing::Span;
54
55use crate::analyze::DistAnalyzeExec;
56pub use crate::datafusion::planner::DfContextProviderAdapter;
57use crate::dist_plan::{
58    DistPlannerOptions, MergeScanLogicalPlan, RemoteDynFilterReceiverInjectorRef,
59};
60use crate::error::{
61    CatalogSnafu, CreateRecordBatchSnafu, MissingTableMutationHandlerSnafu,
62    MissingTimestampColumnSnafu, QueryExecutionSnafu, Result, TableMutationSnafu,
63    TableNotFoundSnafu, TableReadOnlySnafu, UnsupportedExprSnafu,
64};
65use crate::executor::QueryExecutor;
66use crate::metrics::{
67    OnDone, QUERY_STAGE_ELAPSED, maybe_attach_region_watermark_metrics,
68    should_collect_region_watermark_from_query_ctx,
69};
70use crate::physical_wrapper::PhysicalPlanWrapperRef;
71use crate::planner::{DfLogicalPlanner, LogicalPlanner};
72use crate::query_engine::{DescribeResult, QueryEngineContext, QueryEngineState};
73use crate::{QueryEngine, metrics};
74
/// Query parallelism hint key.
///
/// When present in the query context extensions, the value is parsed as a
/// positive integer and used as the target partition count of the query's
/// DataFusion session. Zero or unparsable values keep the default parallelism.
pub const QUERY_PARALLELISM_HINT: &str = "query_parallelism";

/// Query fallback hint key.
///
/// Setting this hint to `true` (case-insensitive) allows the distributed
/// planner to fall back to the original plan when push-down fails.
pub const QUERY_FALLBACK_HINT: &str = "query_fallback";
81
82fn query_load_region_id(plan: &Arc<dyn ExecutionPlan>) -> Option<u64> {
83    let mut region_id = None;
84    let mut stack = vec![plan.clone()];
85
86    while let Some(plan) = stack.pop() {
87        if plan.name() == REGION_SCAN_EXEC_NAME
88            && let Some(scan) = plan.as_any().downcast_ref::<RegionScanExec>()
89            && let Some(scan_region_id) = scan.query_load_region_id()
90        {
91            match region_id {
92                Some(region_id) if region_id != scan_region_id => return None,
93                Some(_) => {}
94                None => region_id = Some(scan_region_id),
95            }
96        }
97        stack.extend(plan.children().into_iter().cloned());
98    }
99
100    region_id
101}
102
103pub struct DatafusionQueryEngine {
104    state: Arc<QueryEngineState>,
105    plugins: Plugins,
106}
107
108impl DatafusionQueryEngine {
109    pub fn new(state: Arc<QueryEngineState>, plugins: Plugins) -> Self {
110        Self { state, plugins }
111    }
112
113    #[tracing::instrument(skip_all)]
114    async fn exec_query_plan(
115        &self,
116        plan: LogicalPlan,
117        query_ctx: QueryContextRef,
118    ) -> Result<Output> {
119        let mut ctx = self.engine_context(query_ctx.clone());
120        let plan = if let Some(receiver_injector) =
121            self.plugins.get::<RemoteDynFilterReceiverInjectorRef>()
122        {
123            receiver_injector.maybe_inject(plan, query_ctx.clone())
124        } else {
125            plan
126        };
127
128        // `create_physical_plan` will optimize logical plan internally
129        let physical_plan = self.create_physical_plan(&mut ctx, &plan).await?;
130        let physical_plan = self.optimize_physical_plan(&mut ctx, physical_plan)?;
131        let physical_plan = if let Some(wrapper) = self.plugins.get::<PhysicalPlanWrapperRef>() {
132            wrapper.wrap(physical_plan, query_ctx)
133        } else {
134            physical_plan
135        };
136
137        let stream = self.execute_stream(&ctx, &physical_plan)?;
138
139        Ok(Output::new(
140            OutputData::Stream(stream),
141            OutputMeta::new_with_plan(physical_plan),
142        ))
143    }
144
145    #[tracing::instrument(skip_all)]
146    async fn exec_dml_statement(
147        &self,
148        dml: DmlStatement,
149        query_ctx: QueryContextRef,
150    ) -> Result<Output> {
151        ensure!(
152            matches!(dml.op, WriteOp::Insert(_) | WriteOp::Delete),
153            UnsupportedExprSnafu {
154                name: format!("DML op {}", dml.op),
155            }
156        );
157
158        let _timer = QUERY_STAGE_ELAPSED
159            .with_label_values(&[dml.op.name()])
160            .start_timer();
161
162        let default_catalog = &query_ctx.current_catalog().to_owned();
163        let default_schema = &query_ctx.current_schema();
164        let table_name = dml.table_name.resolve(default_catalog, default_schema);
165        let table = self.find_table(&table_name, &query_ctx).await?;
166
167        let Output { data, meta } = self
168            .exec_query_plan((*dml.input).clone(), query_ctx.clone())
169            .await?;
170        let mut stream = match data {
171            OutputData::RecordBatches(batches) => batches.as_stream(),
172            OutputData::Stream(stream) => stream,
173            _ => unreachable!(),
174        };
175
176        let mut affected_rows = 0;
177        let mut insert_cost = 0;
178
179        while let Some(batch) = stream.next().await {
180            let batch = batch.context(CreateRecordBatchSnafu)?;
181            let column_vectors = batch
182                .column_vectors(&table_name.to_string(), table.schema())
183                .map_err(BoxedError::new)
184                .context(QueryExecutionSnafu)?;
185
186            match dml.op {
187                WriteOp::Insert(_) => {
188                    // We ignore the insert op.
189                    let output = self
190                        .insert(&table_name, column_vectors, query_ctx.clone())
191                        .await?;
192                    let (rows, cost) = output.extract_rows_and_cost();
193                    affected_rows += rows;
194                    insert_cost += cost;
195                }
196                WriteOp::Delete => {
197                    affected_rows += self
198                        .delete(&table_name, &table, column_vectors, query_ctx.clone())
199                        .await?;
200                }
201                _ => unreachable!("guarded by the 'ensure!' at the beginning"),
202            }
203        }
204        Ok(Output::new(
205            OutputData::AffectedRows(affected_rows),
206            OutputMeta::new(meta.plan, insert_cost),
207        ))
208    }
209
210    #[tracing::instrument(skip_all)]
211    async fn delete(
212        &self,
213        table_name: &ResolvedTableReference,
214        table: &TableRef,
215        column_vectors: HashMap<String, VectorRef>,
216        query_ctx: QueryContextRef,
217    ) -> Result<usize> {
218        let catalog_name = table_name.catalog.to_string();
219        let schema_name = table_name.schema.to_string();
220        let table_name = table_name.table.to_string();
221        let table_schema = table.schema();
222
223        ensure!(
224            !is_readonly_schema(&schema_name),
225            TableReadOnlySnafu { table: table_name }
226        );
227
228        let ts_column = table_schema
229            .timestamp_column()
230            .map(|x| &x.name)
231            .with_context(|| MissingTimestampColumnSnafu {
232                table_name: table_name.clone(),
233            })?;
234
235        let table_info = table.table_info();
236        let rowkey_columns = table_info
237            .meta
238            .row_key_column_names()
239            .collect::<Vec<&String>>();
240        let column_vectors = column_vectors
241            .into_iter()
242            .filter(|x| &x.0 == ts_column || rowkey_columns.contains(&&x.0))
243            .collect::<HashMap<_, _>>();
244
245        let request = DeleteRequest {
246            catalog_name,
247            schema_name,
248            table_name,
249            key_column_values: column_vectors,
250        };
251
252        self.state
253            .table_mutation_handler()
254            .context(MissingTableMutationHandlerSnafu)?
255            .delete(request, query_ctx)
256            .await
257            .context(TableMutationSnafu)
258    }
259
260    #[tracing::instrument(skip_all)]
261    async fn insert(
262        &self,
263        table_name: &ResolvedTableReference,
264        column_vectors: HashMap<String, VectorRef>,
265        query_ctx: QueryContextRef,
266    ) -> Result<Output> {
267        let catalog_name = table_name.catalog.to_string();
268        let schema_name = table_name.schema.to_string();
269        let table_name = table_name.table.to_string();
270
271        ensure!(
272            !is_readonly_schema(&schema_name),
273            TableReadOnlySnafu { table: table_name }
274        );
275
276        let request = InsertRequest {
277            catalog_name,
278            schema_name,
279            table_name,
280            columns_values: column_vectors,
281        };
282
283        self.state
284            .table_mutation_handler()
285            .context(MissingTableMutationHandlerSnafu)?
286            .insert(request, query_ctx)
287            .await
288            .context(TableMutationSnafu)
289    }
290
291    async fn find_table(
292        &self,
293        table_name: &ResolvedTableReference,
294        query_context: &QueryContextRef,
295    ) -> Result<TableRef> {
296        let catalog_name = table_name.catalog.as_ref();
297        let schema_name = table_name.schema.as_ref();
298        let table_name = table_name.table.as_ref();
299
300        self.state
301            .catalog_manager()
302            .table(catalog_name, schema_name, table_name, Some(query_context))
303            .await
304            .context(CatalogSnafu)?
305            .with_context(|| TableNotFoundSnafu { table: table_name })
306    }
307
308    #[tracing::instrument(skip_all)]
309    async fn create_physical_plan(
310        &self,
311        ctx: &mut QueryEngineContext,
312        logical_plan: &LogicalPlan,
313    ) -> Result<Arc<dyn ExecutionPlan>> {
314        /// Only print context on panic, to avoid cluttering logs.
315        ///
316        /// TODO(discord9): remove this once we catch the bug
317        #[derive(Debug)]
318        struct PanicLogger<'a> {
319            input_logical_plan: &'a LogicalPlan,
320            after_analyze: Option<LogicalPlan>,
321            after_optimize: Option<LogicalPlan>,
322            phy_plan: Option<Arc<dyn ExecutionPlan>>,
323        }
324        impl Drop for PanicLogger<'_> {
325            fn drop(&mut self) {
326                if std::thread::panicking() {
327                    common_telemetry::error!(
328                        "Panic while creating physical plan, input logical plan: {:?}, after analyze: {:?}, after optimize: {:?}, final physical plan: {:?}",
329                        self.input_logical_plan,
330                        self.after_analyze,
331                        self.after_optimize,
332                        self.phy_plan
333                    );
334                }
335            }
336        }
337
338        let mut logger = PanicLogger {
339            input_logical_plan: logical_plan,
340            after_analyze: None,
341            after_optimize: None,
342            phy_plan: None,
343        };
344
345        let _timer = metrics::CREATE_PHYSICAL_ELAPSED.start_timer();
346        let state = ctx.state();
347
348        common_telemetry::debug!("Create physical plan, input plan: {logical_plan}");
349
350        // special handle EXPLAIN plan
351        if matches!(logical_plan, DfLogicalPlan::Explain(_)) {
352            return state
353                .create_physical_plan(logical_plan)
354                .await
355                .map_err(Into::into);
356        }
357
358        // analyze first
359        let analyzed_plan = state.analyzer().execute_and_check(
360            logical_plan.clone(),
361            state.config_options(),
362            |_, _| {},
363        )?;
364
365        logger.after_analyze = Some(analyzed_plan.clone());
366
367        common_telemetry::debug!("Create physical plan, analyzed plan: {analyzed_plan}");
368
369        // skip optimize for MergeScan
370        let optimized_plan = if let DfLogicalPlan::Extension(ext) = &analyzed_plan
371            && ext.node.name() == MergeScanLogicalPlan::name()
372        {
373            analyzed_plan.clone()
374        } else {
375            state
376                .optimizer()
377                .optimize(analyzed_plan, state, |_, _| {})?
378        };
379
380        common_telemetry::debug!("Create physical plan, optimized plan: {optimized_plan}");
381        logger.after_optimize = Some(optimized_plan.clone());
382
383        let physical_plan = state
384            .query_planner()
385            .create_physical_plan(&optimized_plan, state)
386            .await?;
387
388        logger.phy_plan = Some(physical_plan.clone());
389        drop(logger);
390        Ok(physical_plan)
391    }
392
393    #[tracing::instrument(skip_all)]
394    fn optimize_physical_plan(
395        &self,
396        ctx: &mut QueryEngineContext,
397        plan: Arc<dyn ExecutionPlan>,
398    ) -> Result<Arc<dyn ExecutionPlan>> {
399        let _timer = metrics::OPTIMIZE_PHYSICAL_ELAPSED.start_timer();
400
401        // TODO(ruihang): `self.create_physical_plan()` already optimize the plan, check
402        // if we need to optimize it again here.
403        // let state = ctx.state();
404        // let config = state.config_options();
405
406        // skip optimize AnalyzeExec plan
407        let optimized_plan = if let Some(analyze_plan) = plan.as_any().downcast_ref::<AnalyzeExec>()
408        {
409            let format = if let Some(format) = ctx.query_ctx().explain_format()
410                && format.to_lowercase() == "json"
411            {
412                AnalyzeFormat::JSON
413            } else {
414                AnalyzeFormat::TEXT
415            };
416            // Sets the verbose flag of the query context.
417            // The MergeScanExec plan uses the verbose flag to determine whether to print the plan in verbose mode.
418            ctx.query_ctx().set_explain_verbose(analyze_plan.verbose());
419
420            Arc::new(DistAnalyzeExec::new(
421                analyze_plan.input().clone(),
422                analyze_plan.verbose(),
423                format,
424            ))
425            // let mut new_plan = analyze_plan.input().clone();
426            // for optimizer in state.physical_optimizers() {
427            //     new_plan = optimizer
428            //         .optimize(new_plan, config)
429            //         .context(DataFusionSnafu)?;
430            // }
431            // Arc::new(DistAnalyzeExec::new(new_plan))
432        } else {
433            plan
434            // let mut new_plan = plan;
435            // for optimizer in state.physical_optimizers() {
436            //     new_plan = optimizer
437            //         .optimize(new_plan, config)
438            //         .context(DataFusionSnafu)?;
439            // }
440            // new_plan
441        };
442
443        Ok(optimized_plan)
444    }
445}
446
447#[async_trait]
448impl QueryEngine for DatafusionQueryEngine {
449    fn as_any(&self) -> &dyn Any {
450        self
451    }
452
453    fn planner(&self) -> Arc<dyn LogicalPlanner> {
454        Arc::new(DfLogicalPlanner::new(self.state.clone()))
455    }
456
457    fn name(&self) -> &str {
458        "datafusion"
459    }
460
461    async fn describe(
462        &self,
463        plan: LogicalPlan,
464        _query_ctx: QueryContextRef,
465    ) -> Result<DescribeResult> {
466        Ok(DescribeResult { logical_plan: plan })
467    }
468
469    async fn execute(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result<Output> {
470        match plan {
471            LogicalPlan::Dml(dml) => self.exec_dml_statement(dml, query_ctx).await,
472            _ => self.exec_query_plan(plan, query_ctx).await,
473        }
474    }
475
476    /// Note in SQL queries, aggregate names are looked up using
477    /// lowercase unless the query uses quotes. For example,
478    ///
479    /// `SELECT MY_UDAF(x)...` will look for an aggregate named `"my_udaf"`
480    /// `SELECT "my_UDAF"(x)` will look for an aggregate named `"my_UDAF"`
481    ///
482    /// So it's better to make UDAF name lowercase when creating one.
483    fn register_aggregate_function(&self, func: AggregateUDF) {
484        self.state.register_aggr_function(func);
485    }
486
487    /// Register an scalar function.
488    /// Will override if the function with same name is already registered.
489    fn register_scalar_function(&self, func: ScalarFunctionFactory) {
490        self.state.register_scalar_function(func);
491    }
492
493    fn register_table_function(&self, func: Arc<TableFunction>) {
494        self.state.register_table_function(func);
495    }
496
497    fn register_window_function(&self, func: WindowUDF) {
498        self.state.register_window_function(func);
499    }
500
501    fn read_table(&self, table: TableRef) -> Result<DataFrame> {
502        self.state.read_table(table).map_err(Into::into)
503    }
504
505    fn engine_context(&self, query_ctx: QueryContextRef) -> QueryEngineContext {
506        let mut state = self.state.session_state();
507        state.config_mut().set_extension(query_ctx.clone());
508        state.config_mut().set_extension(self.state.clone());
509        // note that hints in "x-greptime-hints" is automatically parsed
510        // and set to query context's extension, so we can get it from query context.
511        if let Some(parallelism) = query_ctx.extension(QUERY_PARALLELISM_HINT) {
512            if let Ok(n) = parallelism.parse::<u64>() {
513                if n > 0 {
514                    let new_cfg = state.config().clone().with_target_partitions(n as usize);
515                    *state.config_mut() = new_cfg;
516                }
517            } else {
518                common_telemetry::warn!(
519                    "Failed to parse query_parallelism: {}, using default value",
520                    parallelism
521                );
522            }
523        }
524
525        // configure execution options
526        state.config_mut().options_mut().execution.time_zone =
527            Some(query_ctx.timezone().to_string());
528
529        // usually it's impossible to have both `set variable` set by sql client and
530        // hint in header by grpc client, so only need to deal with them separately
531        if query_ctx.configuration_parameter().allow_query_fallback() {
532            state
533                .config_mut()
534                .options_mut()
535                .extensions
536                .insert(DistPlannerOptions {
537                    allow_query_fallback: true,
538                });
539        } else if let Some(fallback) = query_ctx.extension(QUERY_FALLBACK_HINT) {
540            // also check the query context for fallback hint
541            // if it is set, we will enable the fallback
542            if fallback.to_lowercase().parse::<bool>().unwrap_or(false) {
543                state
544                    .config_mut()
545                    .options_mut()
546                    .extensions
547                    .insert(DistPlannerOptions {
548                        allow_query_fallback: true,
549                    });
550            }
551        }
552
553        state
554            .config_mut()
555            .options_mut()
556            .extensions
557            .insert(FunctionContext {
558                query_ctx: query_ctx.clone(),
559                state: self.engine_state().function_state(),
560            });
561
562        let config_options = state.config_options().clone();
563        let _ = state
564            .execution_props_mut()
565            .config_options
566            .insert(config_options);
567
568        // Apply scheduled time from query context if present, so that `now()` /
569        // `current_timestamp()` functions evaluate against the logical scheduled time
570        // rather than wall-clock.
571        match crate::options::parse_scheduled_time_datetime(&query_ctx.extensions()) {
572            Ok(Some(scheduled_rt)) => {
573                state.execution_props_mut().query_execution_start_time = Some(scheduled_rt);
574            }
575            Ok(None) => {}
576            Err(err) => {
577                common_telemetry::warn!(err; "Ignoring invalid scheduled time query extension");
578            }
579        }
580
581        QueryEngineContext::new(state, query_ctx)
582    }
583
584    fn engine_state(&self) -> &QueryEngineState {
585        &self.state
586    }
587}
588
589impl QueryExecutor for DatafusionQueryEngine {
590    #[tracing::instrument(skip_all)]
591    fn execute_stream(
592        &self,
593        ctx: &QueryEngineContext,
594        plan: &Arc<dyn ExecutionPlan>,
595    ) -> Result<SendableRecordBatchStream> {
596        let query_ctx = ctx.query_ctx();
597        let explain_verbose = query_ctx.explain_verbose();
598        let should_collect_region_watermark =
599            should_collect_region_watermark_from_query_ctx(&query_ctx)?;
600        let output_partitions = plan.properties().output_partitioning().partition_count();
601        if explain_verbose {
602            common_telemetry::info!("Executing query plan, output_partitions: {output_partitions}");
603        }
604
605        let exec_timer = metrics::EXEC_PLAN_ELAPSED.start_timer();
606        let task_ctx = ctx.build_task_ctx();
607        let span = Span::current();
608
609        match plan.properties().output_partitioning().partition_count() {
610            0 => {
611                let schema = Arc::new(
612                    Schema::try_from(plan.schema())
613                        .map_err(BoxedError::new)
614                        .context(QueryExecutionSnafu)?,
615                );
616                Ok(Box::pin(EmptyRecordBatchStream::new(schema)))
617            }
618            1 => {
619                let df_stream = plan.execute(0, task_ctx)?;
620                let mut stream = RecordBatchStreamAdapter::try_new_with_span(df_stream, span)
621                    .context(error::ConvertDfRecordBatchStreamSnafu)
622                    .map_err(BoxedError::new)
623                    .context(QueryExecutionSnafu)?;
624                stream.set_metrics2(plan.clone());
625                stream.set_query_load_region_id(query_load_region_id(plan));
626                stream.set_explain_verbose(explain_verbose);
627                let stream = OnDone::new(Box::pin(stream), move || {
628                    let exec_cost = exec_timer.stop_and_record();
629                    if explain_verbose {
630                        common_telemetry::info!(
631                            "DatafusionQueryEngine execute 1 stream, cost: {:?}s",
632                            exec_cost,
633                        );
634                    }
635                });
636                Ok(maybe_attach_region_watermark_metrics(
637                    Box::pin(stream),
638                    plan.clone(),
639                    should_collect_region_watermark,
640                ))
641            }
642            _ => {
643                // merge into a single partition
644                let merged_plan = CoalescePartitionsExec::new(plan.clone());
645                // CoalescePartitionsExec must produce a single partition
646                assert_eq!(
647                    1,
648                    merged_plan
649                        .properties()
650                        .output_partitioning()
651                        .partition_count()
652                );
653                let df_stream = merged_plan.execute(0, task_ctx)?;
654                let mut stream = RecordBatchStreamAdapter::try_new_with_span(df_stream, span)
655                    .context(error::ConvertDfRecordBatchStreamSnafu)
656                    .map_err(BoxedError::new)
657                    .context(QueryExecutionSnafu)?;
658                stream.set_metrics2(plan.clone());
659                stream.set_query_load_region_id(query_load_region_id(plan));
660                stream.set_explain_verbose(explain_verbose);
661                let stream = OnDone::new(Box::pin(stream), move || {
662                    let exec_cost = exec_timer.stop_and_record();
663                    if explain_verbose {
664                        common_telemetry::info!(
665                            "DatafusionQueryEngine execute {output_partitions} stream, cost: {:?}s",
666                            exec_cost
667                        );
668                    }
669                });
670                Ok(maybe_attach_region_watermark_metrics(
671                    Box::pin(stream),
672                    plan.clone(),
673                    should_collect_region_watermark,
674                ))
675            }
676        }
677    }
678}
679
680#[cfg(test)]
681mod tests {
682    use std::fmt;
683    use std::sync::Arc;
684    use std::sync::atomic::{AtomicUsize, Ordering};
685
686    use api::v1::SemanticType;
687    use arrow::array::{ArrayRef, UInt64Array};
688    use arrow_schema::SortOptions;
689    use catalog::RegisterTableRequest;
690    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, NUMBERS_TABLE_ID};
691    use common_error::ext::BoxedError;
692    use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream, util};
693    use datafusion::physical_plan::display::{DisplayAs, DisplayFormatType};
694    use datafusion::physical_plan::expressions::PhysicalSortExpr;
695    use datafusion::physical_plan::joins::{HashJoinExec, JoinOn, PartitionMode};
696    use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;
697    use datafusion::physical_plan::{ExecutionPlan, PhysicalExpr};
698    use datafusion::prelude::{col, lit};
699    use datafusion_common::{JoinType, NullEquality};
700    use datafusion_physical_expr::expressions::Column;
701    use datatypes::prelude::ConcreteDataType;
702    use datatypes::schema::{ColumnSchema, SchemaRef};
703    use datatypes::vectors::{Helper, UInt32Vector, VectorRef};
704    use session::context::{QueryContext, QueryContextBuilder};
705    use store_api::metadata::{ColumnMetadata, RegionMetadataBuilder, RegionMetadataRef};
706    use store_api::region_engine::{
707        PartitionRange, PrepareRequest, QueryScanContext, RegionScanner, ScannerProperties,
708    };
709    use store_api::storage::{RegionId, ScanRequest};
710    use table::table::numbers::{NUMBERS_TABLE_NAME, NumbersTable};
711    use table::table::scan::RegionScanExec;
712
713    use super::*;
714    use crate::options::QueryOptions;
715    use crate::parser::QueryLanguageParser;
716    use crate::part_sort::PartSortExec;
717    use crate::query_engine::{QueryEngineFactory, QueryEngineRef};
718
719    #[derive(Debug)]
720    struct RecordingScanner {
721        schema: SchemaRef,
722        metadata: RegionMetadataRef,
723        properties: ScannerProperties,
724        update_calls: Arc<AtomicUsize>,
725        last_filter_len: Arc<AtomicUsize>,
726    }
727
728    impl RecordingScanner {
729        fn new(
730            schema: SchemaRef,
731            metadata: RegionMetadataRef,
732            update_calls: Arc<AtomicUsize>,
733            last_filter_len: Arc<AtomicUsize>,
734        ) -> Self {
735            Self {
736                schema,
737                metadata,
738                properties: ScannerProperties::default(),
739                update_calls,
740                last_filter_len,
741            }
742        }
743    }
744
745    impl RegionScanner for RecordingScanner {
746        fn name(&self) -> &str {
747            "RecordingScanner"
748        }
749
750        fn properties(&self) -> &ScannerProperties {
751            &self.properties
752        }
753
754        fn schema(&self) -> SchemaRef {
755            self.schema.clone()
756        }
757
758        fn metadata(&self) -> RegionMetadataRef {
759            self.metadata.clone()
760        }
761
762        fn prepare(&mut self, request: PrepareRequest) -> std::result::Result<(), BoxedError> {
763            self.properties.prepare(request);
764            Ok(())
765        }
766
767        fn scan_partition(
768            &self,
769            _ctx: &QueryScanContext,
770            _metrics_set: &ExecutionPlanMetricsSet,
771            _partition: usize,
772        ) -> std::result::Result<SendableRecordBatchStream, BoxedError> {
773            Ok(Box::pin(EmptyRecordBatchStream::new(self.schema.clone())))
774        }
775
776        fn has_predicate_without_region(&self) -> bool {
777            true
778        }
779
780        fn add_dyn_filter_to_predicate(
781            &mut self,
782            filter_exprs: Vec<Arc<dyn PhysicalExpr>>,
783        ) -> Vec<bool> {
784            self.update_calls.fetch_add(1, Ordering::Relaxed);
785            self.last_filter_len
786                .store(filter_exprs.len(), Ordering::Relaxed);
787            vec![true; filter_exprs.len()]
788        }
789
790        fn set_logical_region(&mut self, logical_region: bool) {
791            self.properties.set_logical_region(logical_region);
792        }
793
794        fn set_query_load_region_id(&mut self, region_id: store_api::storage::RegionId) {
795            self.properties.set_query_load_region_id(region_id);
796        }
797    }
798
799    impl DisplayAs for RecordingScanner {
800        fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
801            write!(f, "RecordingScanner")
802        }
803    }
804
805    fn build_query_load_region_scan(
806        query_load_region_id: Option<RegionId>,
807    ) -> Arc<dyn ExecutionPlan> {
808        let schema = Arc::new(datatypes::schema::Schema::new(vec![ColumnSchema::new(
809            "ts",
810            ConcreteDataType::timestamp_millisecond_datatype(),
811            false,
812        )]));
813
814        let mut metadata_builder = RegionMetadataBuilder::new(RegionId::new(1024, 1));
815        metadata_builder
816            .push_column_metadata(ColumnMetadata {
817                column_schema: ColumnSchema::new(
818                    "ts",
819                    ConcreteDataType::timestamp_millisecond_datatype(),
820                    false,
821                )
822                .with_time_index(true),
823                semantic_type: SemanticType::Timestamp,
824                column_id: 1,
825            })
826            .primary_key(vec![]);
827        let metadata = Arc::new(metadata_builder.build().unwrap());
828        let mut scanner = RecordingScanner::new(
829            schema,
830            metadata,
831            Arc::new(AtomicUsize::new(0)),
832            Arc::new(AtomicUsize::new(0)),
833        );
834        if let Some(region_id) = query_load_region_id {
835            scanner.set_query_load_region_id(region_id);
836        }
837
838        Arc::new(RegionScanExec::new(Box::new(scanner), ScanRequest::default(), None).unwrap())
839    }
840
841    #[test]
842    fn query_load_region_id_ignores_scans_without_region_id() {
843        let query_load_region_id = RegionId::new(1024, 42);
844        let scan_without_region_id = build_query_load_region_scan(None);
845        let scan_with_region_id = build_query_load_region_scan(Some(query_load_region_id));
846        let on: JoinOn = vec![(
847            Arc::new(Column::new("ts", 0)) as Arc<dyn PhysicalExpr>,
848            Arc::new(Column::new("ts", 0)) as Arc<dyn PhysicalExpr>,
849        )];
850        let plan: Arc<dyn ExecutionPlan> = Arc::new(
851            HashJoinExec::try_new(
852                scan_without_region_id,
853                scan_with_region_id,
854                on,
855                None,
856                &JoinType::Inner,
857                None,
858                PartitionMode::CollectLeft,
859                NullEquality::NullEqualsNull,
860                false,
861            )
862            .unwrap(),
863        );
864
865        assert_eq!(
866            super::query_load_region_id(&plan),
867            Some(query_load_region_id.as_u64())
868        );
869    }
870
871    async fn create_test_engine() -> QueryEngineRef {
872        let catalog_manager = catalog::memory::new_memory_catalog_manager().unwrap();
873        let req = RegisterTableRequest {
874            catalog: DEFAULT_CATALOG_NAME.to_string(),
875            schema: DEFAULT_SCHEMA_NAME.to_string(),
876            table_name: NUMBERS_TABLE_NAME.to_string(),
877            table_id: NUMBERS_TABLE_ID,
878            table: NumbersTable::table(NUMBERS_TABLE_ID),
879        };
880        catalog_manager.register_table_sync(req).unwrap();
881
882        QueryEngineFactory::new(
883            catalog_manager,
884            None,
885            None,
886            None,
887            None,
888            false,
889            QueryOptions::default(),
890        )
891        .query_engine()
892    }
893
894    #[tokio::test]
895    async fn test_sql_to_plan() {
896        let engine = create_test_engine().await;
897        let sql = "select sum(number) from numbers limit 20";
898
899        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
900        let plan = engine
901            .planner()
902            .plan(&stmt, QueryContext::arc())
903            .await
904            .unwrap();
905
906        assert_eq!(
907            plan.to_string(),
908            r#"Limit: skip=0, fetch=20
909  Projection: sum(numbers.number)
910    Aggregate: groupBy=[[]], aggr=[[sum(numbers.number)]]
911      TableScan: numbers"#
912        );
913    }
914
915    #[tokio::test]
916    async fn test_execute() {
917        let engine = create_test_engine().await;
918        let sql = "select sum(number) from numbers limit 20";
919
920        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
921        let plan = engine
922            .planner()
923            .plan(&stmt, QueryContext::arc())
924            .await
925            .unwrap();
926
927        let output = engine.execute(plan, QueryContext::arc()).await.unwrap();
928
929        match output.data {
930            OutputData::Stream(recordbatch) => {
931                let numbers = util::collect(recordbatch).await.unwrap();
932                assert_eq!(1, numbers.len());
933                assert_eq!(numbers[0].num_columns(), 1);
934                assert_eq!(1, numbers[0].schema.num_columns());
935                assert_eq!(
936                    "sum(numbers.number)",
937                    numbers[0].schema.column_schemas()[0].name
938                );
939
940                let batch = &numbers[0];
941                assert_eq!(1, batch.num_columns());
942                assert_eq!(batch.column(0).len(), 1);
943
944                let expected = Arc::new(UInt64Array::from_iter_values([4950])) as ArrayRef;
945                assert_eq!(batch.column(0), &expected);
946            }
947            _ => unreachable!(),
948        }
949    }
950
951    #[tokio::test]
952    async fn test_read_table() {
953        let engine = create_test_engine().await;
954
955        let engine = engine
956            .as_any()
957            .downcast_ref::<DatafusionQueryEngine>()
958            .unwrap();
959        let query_ctx = Arc::new(QueryContextBuilder::default().build());
960        let table = engine
961            .find_table(
962                &ResolvedTableReference {
963                    catalog: "greptime".into(),
964                    schema: "public".into(),
965                    table: "numbers".into(),
966                },
967                &query_ctx,
968            )
969            .await
970            .unwrap();
971
972        let df = engine.read_table(table).unwrap();
973        let df = df
974            .select_columns(&["number"])
975            .unwrap()
976            .filter(col("number").lt(lit(10)))
977            .unwrap();
978        let batches = df.collect().await.unwrap();
979        assert_eq!(1, batches.len());
980        let batch = &batches[0];
981
982        assert_eq!(1, batch.num_columns());
983        assert_eq!(batch.column(0).len(), 10);
984
985        assert_eq!(
986            Helper::try_into_vector(batch.column(0)).unwrap(),
987            Arc::new(UInt32Vector::from_slice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) as VectorRef
988        );
989    }
990
991    #[tokio::test]
992    async fn test_describe() {
993        let engine = create_test_engine().await;
994        let sql = "select sum(number) from numbers limit 20";
995
996        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
997
998        let plan = engine
999            .planner()
1000            .plan(&stmt, QueryContext::arc())
1001            .await
1002            .unwrap();
1003
1004        let DescribeResult { logical_plan } =
1005            engine.describe(plan, QueryContext::arc()).await.unwrap();
1006
1007        let schema: Schema = logical_plan.schema().clone().try_into().unwrap();
1008
1009        assert_eq!(
1010            schema.column_schemas()[0],
1011            ColumnSchema::new(
1012                "sum(numbers.number)",
1013                ConcreteDataType::uint64_datatype(),
1014                true
1015            )
1016        );
1017        assert_eq!(
1018            "Limit: skip=0, fetch=20\n  Projection: sum(numbers.number)\n    Aggregate: groupBy=[[]], aggr=[[sum(numbers.number)]]\n      TableScan: numbers",
1019            format!("{}", logical_plan.display_indent())
1020        );
1021    }
1022
1023    #[tokio::test]
1024    async fn test_topk_dynamic_filter_pushdown_reaches_region_scan() {
1025        let engine = create_test_engine().await;
1026        let engine = engine
1027            .as_any()
1028            .downcast_ref::<DatafusionQueryEngine>()
1029            .unwrap();
1030        let engine_ctx = engine.engine_context(QueryContext::arc());
1031        let state = engine_ctx.state();
1032
1033        let schema = Arc::new(datatypes::schema::Schema::new(vec![ColumnSchema::new(
1034            "ts",
1035            ConcreteDataType::timestamp_millisecond_datatype(),
1036            false,
1037        )]));
1038
1039        let mut metadata_builder = RegionMetadataBuilder::new(RegionId::new(1024, 1));
1040        metadata_builder
1041            .push_column_metadata(ColumnMetadata {
1042                column_schema: ColumnSchema::new(
1043                    "ts",
1044                    ConcreteDataType::timestamp_millisecond_datatype(),
1045                    false,
1046                )
1047                .with_time_index(true),
1048                semantic_type: SemanticType::Timestamp,
1049                column_id: 1,
1050            })
1051            .primary_key(vec![]);
1052        let metadata = Arc::new(metadata_builder.build().unwrap());
1053
1054        let update_calls = Arc::new(AtomicUsize::new(0));
1055        let last_filter_len = Arc::new(AtomicUsize::new(0));
1056        let scanner = Box::new(RecordingScanner::new(
1057            schema,
1058            metadata,
1059            update_calls.clone(),
1060            last_filter_len.clone(),
1061        ));
1062        let scan = Arc::new(RegionScanExec::new(scanner, ScanRequest::default(), None).unwrap());
1063
1064        let sort_expr = PhysicalSortExpr {
1065            expr: Arc::new(Column::new("ts", 0)),
1066            options: SortOptions {
1067                descending: true,
1068                ..Default::default()
1069            },
1070        };
1071        let partition_ranges: Vec<Vec<PartitionRange>> = vec![vec![]];
1072        let mut plan: Arc<dyn ExecutionPlan> =
1073            Arc::new(PartSortExec::try_new(sort_expr, Some(3), partition_ranges, scan).unwrap());
1074
1075        for optimizer in state.physical_optimizers() {
1076            plan = optimizer.optimize(plan, state.config_options()).unwrap();
1077        }
1078
1079        assert!(update_calls.load(Ordering::Relaxed) > 0);
1080        assert!(last_filter_len.load(Ordering::Relaxed) > 0);
1081    }
1082
1083    #[tokio::test]
1084    async fn test_join_dynamic_filter_pushdown_reaches_region_scan() {
1085        let engine = create_test_engine().await;
1086        let engine = engine
1087            .as_any()
1088            .downcast_ref::<DatafusionQueryEngine>()
1089            .unwrap();
1090        let engine_ctx = engine.engine_context(QueryContext::arc());
1091        let state = engine_ctx.state();
1092
1093        assert!(
1094            state
1095                .config_options()
1096                .optimizer
1097                .enable_join_dynamic_filter_pushdown
1098        );
1099
1100        let schema = Arc::new(datatypes::schema::Schema::new(vec![ColumnSchema::new(
1101            "ts",
1102            ConcreteDataType::timestamp_millisecond_datatype(),
1103            false,
1104        )]));
1105
1106        let mut left_metadata_builder = RegionMetadataBuilder::new(RegionId::new(2048, 1));
1107        left_metadata_builder
1108            .push_column_metadata(ColumnMetadata {
1109                column_schema: ColumnSchema::new(
1110                    "ts",
1111                    ConcreteDataType::timestamp_millisecond_datatype(),
1112                    false,
1113                )
1114                .with_time_index(true),
1115                semantic_type: SemanticType::Timestamp,
1116                column_id: 1,
1117            })
1118            .primary_key(vec![]);
1119        let left_metadata = Arc::new(left_metadata_builder.build().unwrap());
1120
1121        let mut right_metadata_builder = RegionMetadataBuilder::new(RegionId::new(2048, 2));
1122        right_metadata_builder
1123            .push_column_metadata(ColumnMetadata {
1124                column_schema: ColumnSchema::new(
1125                    "ts",
1126                    ConcreteDataType::timestamp_millisecond_datatype(),
1127                    false,
1128                )
1129                .with_time_index(true),
1130                semantic_type: SemanticType::Timestamp,
1131                column_id: 1,
1132            })
1133            .primary_key(vec![]);
1134        let right_metadata = Arc::new(right_metadata_builder.build().unwrap());
1135
1136        let left_update_calls = Arc::new(AtomicUsize::new(0));
1137        let left_last_filter_len = Arc::new(AtomicUsize::new(0));
1138        let right_update_calls = Arc::new(AtomicUsize::new(0));
1139        let right_last_filter_len = Arc::new(AtomicUsize::new(0));
1140
1141        let left_scan = Arc::new(
1142            RegionScanExec::new(
1143                Box::new(RecordingScanner::new(
1144                    schema.clone(),
1145                    left_metadata,
1146                    left_update_calls.clone(),
1147                    left_last_filter_len.clone(),
1148                )),
1149                ScanRequest::default(),
1150                None,
1151            )
1152            .unwrap(),
1153        );
1154        let right_scan = Arc::new(
1155            RegionScanExec::new(
1156                Box::new(RecordingScanner::new(
1157                    schema,
1158                    right_metadata,
1159                    right_update_calls.clone(),
1160                    right_last_filter_len.clone(),
1161                )),
1162                ScanRequest::default(),
1163                None,
1164            )
1165            .unwrap(),
1166        );
1167
1168        let on: JoinOn = vec![(
1169            Arc::new(Column::new("ts", 0)) as Arc<dyn PhysicalExpr>,
1170            Arc::new(Column::new("ts", 0)) as Arc<dyn PhysicalExpr>,
1171        )];
1172
1173        let mut plan: Arc<dyn ExecutionPlan> = Arc::new(
1174            HashJoinExec::try_new(
1175                left_scan,
1176                right_scan,
1177                on,
1178                None,
1179                &JoinType::Inner,
1180                None,
1181                PartitionMode::CollectLeft,
1182                NullEquality::NullEqualsNull,
1183                false,
1184            )
1185            .unwrap(),
1186        );
1187
1188        for optimizer in state.physical_optimizers() {
1189            plan = optimizer.optimize(plan, state.config_options()).unwrap();
1190        }
1191
1192        assert!(left_update_calls.load(Ordering::Relaxed) > 0);
1193        assert_eq!(0, left_last_filter_len.load(Ordering::Relaxed));
1194        assert!(right_update_calls.load(Ordering::Relaxed) > 0);
1195        assert!(right_last_filter_len.load(Ordering::Relaxed) > 0);
1196    }
1197}