query/datafusion.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Planner, QueryEngine implementations based on DataFusion.

mod error;
mod json_expr_planner;
mod planner;

use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use common_base::Plugins;
use common_catalog::consts::is_readonly_schema;
use common_error::ext::BoxedError;
use common_function::function::FunctionContext;
use common_function::function_factory::ScalarFunctionFactory;
use common_query::{Output, OutputData, OutputMeta};
use common_recordbatch::adapter::RecordBatchStreamAdapter;
use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream};
use common_telemetry::tracing;
use datafusion::catalog::TableFunction;
use datafusion::dataframe::DataFrame;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::analyze::AnalyzeExec;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion_common::ResolvedTableReference;
use datafusion_expr::{
    AggregateUDF, DmlStatement, LogicalPlan as DfLogicalPlan, LogicalPlan, WindowUDF, WriteOp,
};
use datatypes::prelude::VectorRef;
use datatypes::schema::Schema;
use futures_util::StreamExt;
use session::context::QueryContextRef;
use snafu::{OptionExt, ResultExt, ensure};
use sqlparser::ast::AnalyzeFormat;
use table::TableRef;
use table::requests::{DeleteRequest, InsertRequest};
use tracing::Span;

use crate::analyze::DistAnalyzeExec;
pub use crate::datafusion::planner::DfContextProviderAdapter;
use crate::dist_plan::{DistPlannerOptions, MergeScanLogicalPlan};
use crate::error::{
    CatalogSnafu, CreateRecordBatchSnafu, MissingTableMutationHandlerSnafu,
    MissingTimestampColumnSnafu, QueryExecutionSnafu, Result, TableMutationSnafu,
    TableNotFoundSnafu, TableReadOnlySnafu, UnsupportedExprSnafu,
};
use crate::executor::QueryExecutor;
use crate::metrics::{OnDone, QUERY_STAGE_ELAPSED};
use crate::physical_wrapper::PhysicalPlanWrapperRef;
use crate::planner::{DfLogicalPlanner, LogicalPlanner};
use crate::query_engine::{DescribeResult, QueryEngineContext, QueryEngineState};
use crate::{QueryEngine, metrics};

/// Query parallelism hint key.
/// This hint can be set in the query context to control the parallelism of the query execution.
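/// For example, setting this hint to `"4"` makes the session plan with four
/// target partitions (see how the hint is parsed in `engine_context` below).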
pub const QUERY_PARALLELISM_HINT: &str = "query_parallelism";

/// Whether to fall back to the original plan when pushdown fails.
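/// The hint value is parsed as a boolean (e.g. `"true"`); when enabled it sets
/// `DistPlannerOptions::allow_query_fallback` (see `engine_context` below).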
pub const QUERY_FALLBACK_HINT: &str = "query_fallback";

pub struct DatafusionQueryEngine {
    state: Arc<QueryEngineState>,
    plugins: Plugins,
}

impl DatafusionQueryEngine {
    pub fn new(state: Arc<QueryEngineState>, plugins: Plugins) -> Self {
        Self { state, plugins }
    }

    #[tracing::instrument(skip_all)]
    async fn exec_query_plan(
        &self,
        plan: LogicalPlan,
        query_ctx: QueryContextRef,
    ) -> Result<Output> {
        let mut ctx = self.engine_context(query_ctx.clone());

        // `create_physical_plan` will optimize the logical plan internally
        let physical_plan = self.create_physical_plan(&mut ctx, &plan).await?;
        let optimized_physical_plan = self.optimize_physical_plan(&mut ctx, physical_plan)?;

        let physical_plan = if let Some(wrapper) = self.plugins.get::<PhysicalPlanWrapperRef>() {
            wrapper.wrap(optimized_physical_plan, query_ctx)
        } else {
            optimized_physical_plan
        };

        Ok(Output::new(
            OutputData::Stream(self.execute_stream(&ctx, &physical_plan)?),
            OutputMeta::new_with_plan(physical_plan),
        ))
    }

    #[tracing::instrument(skip_all)]
    async fn exec_dml_statement(
        &self,
        dml: DmlStatement,
        query_ctx: QueryContextRef,
    ) -> Result<Output> {
        ensure!(
            matches!(dml.op, WriteOp::Insert(_) | WriteOp::Delete),
            UnsupportedExprSnafu {
                name: format!("DML op {}", dml.op),
            }
        );

        let _timer = QUERY_STAGE_ELAPSED
            .with_label_values(&[dml.op.name()])
            .start_timer();

        let default_catalog = &query_ctx.current_catalog().to_owned();
        let default_schema = &query_ctx.current_schema();
        let table_name = dml.table_name.resolve(default_catalog, default_schema);
        let table = self.find_table(&table_name, &query_ctx).await?;

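        // Execute the DML's input plan to stream the rows to insert or delete.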
        let output = self
            .exec_query_plan((*dml.input).clone(), query_ctx.clone())
            .await?;
        let mut stream = match output.data {
            OutputData::RecordBatches(batches) => batches.as_stream(),
            OutputData::Stream(stream) => stream,
            _ => unreachable!(),
        };

        let mut affected_rows = 0;
        let mut insert_cost = 0;

        while let Some(batch) = stream.next().await {
            let batch = batch.context(CreateRecordBatchSnafu)?;
            let column_vectors = batch
                .column_vectors(&table_name.to_string(), table.schema())
                .map_err(BoxedError::new)
                .context(QueryExecutionSnafu)?;

            match dml.op {
                WriteOp::Insert(_) => {
                    // The specific insert kind carried by the op is ignored; all inserts are handled the same way.
                    let output = self
                        .insert(&table_name, column_vectors, query_ctx.clone())
                        .await?;
                    let (rows, cost) = output.extract_rows_and_cost();
                    affected_rows += rows;
                    insert_cost += cost;
                }
                WriteOp::Delete => {
                    affected_rows += self
                        .delete(&table_name, &table, column_vectors, query_ctx.clone())
                        .await?;
                }
                _ => unreachable!("guarded by the 'ensure!' at the beginning"),
            }
        }
        Ok(Output::new(
            OutputData::AffectedRows(affected_rows),
            OutputMeta::new_with_cost(insert_cost),
        ))
    }

    #[tracing::instrument(skip_all)]
    async fn delete(
        &self,
        table_name: &ResolvedTableReference,
        table: &TableRef,
        column_vectors: HashMap<String, VectorRef>,
        query_ctx: QueryContextRef,
    ) -> Result<usize> {
        let catalog_name = table_name.catalog.to_string();
        let schema_name = table_name.schema.to_string();
        let table_name = table_name.table.to_string();
        let table_schema = table.schema();

        ensure!(
            !is_readonly_schema(&schema_name),
            TableReadOnlySnafu { table: table_name }
        );

        let ts_column = table_schema
            .timestamp_column()
            .map(|x| &x.name)
            .with_context(|| MissingTimestampColumnSnafu {
                table_name: table_name.clone(),
            })?;

        let table_info = table.table_info();
        let rowkey_columns = table_info
            .meta
            .row_key_column_names()
            .collect::<Vec<&String>>();
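        // Rows to delete are identified by the time index and row key columns; all other columns are dropped.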
        let column_vectors = column_vectors
            .into_iter()
            .filter(|x| &x.0 == ts_column || rowkey_columns.contains(&&x.0))
            .collect::<HashMap<_, _>>();

        let request = DeleteRequest {
            catalog_name,
            schema_name,
            table_name,
            key_column_values: column_vectors,
        };

        self.state
            .table_mutation_handler()
            .context(MissingTableMutationHandlerSnafu)?
            .delete(request, query_ctx)
            .await
            .context(TableMutationSnafu)
    }

    #[tracing::instrument(skip_all)]
    async fn insert(
        &self,
        table_name: &ResolvedTableReference,
        column_vectors: HashMap<String, VectorRef>,
        query_ctx: QueryContextRef,
    ) -> Result<Output> {
        let catalog_name = table_name.catalog.to_string();
        let schema_name = table_name.schema.to_string();
        let table_name = table_name.table.to_string();

        ensure!(
            !is_readonly_schema(&schema_name),
            TableReadOnlySnafu { table: table_name }
        );

        let request = InsertRequest {
            catalog_name,
            schema_name,
            table_name,
            columns_values: column_vectors,
        };

        self.state
            .table_mutation_handler()
            .context(MissingTableMutationHandlerSnafu)?
            .insert(request, query_ctx)
            .await
            .context(TableMutationSnafu)
    }

    async fn find_table(
        &self,
        table_name: &ResolvedTableReference,
        query_context: &QueryContextRef,
    ) -> Result<TableRef> {
        let catalog_name = table_name.catalog.as_ref();
        let schema_name = table_name.schema.as_ref();
        let table_name = table_name.table.as_ref();

        self.state
            .catalog_manager()
            .table(catalog_name, schema_name, table_name, Some(query_context))
            .await
            .context(CatalogSnafu)?
            .with_context(|| TableNotFoundSnafu { table: table_name })
    }

    #[tracing::instrument(skip_all)]
    async fn create_physical_plan(
        &self,
        ctx: &mut QueryEngineContext,
        logical_plan: &LogicalPlan,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        /// Only print context on panic, to avoid cluttering logs.
        ///
        /// TODO(discord9): remove this once we catch the bug
        #[derive(Debug)]
        struct PanicLogger<'a> {
            input_logical_plan: &'a LogicalPlan,
            after_analyze: Option<LogicalPlan>,
            after_optimize: Option<LogicalPlan>,
            phy_plan: Option<Arc<dyn ExecutionPlan>>,
        }
        impl Drop for PanicLogger<'_> {
            fn drop(&mut self) {
                if std::thread::panicking() {
                    common_telemetry::error!(
                        "Panic while creating physical plan, input logical plan: {:?}, after analyze: {:?}, after optimize: {:?}, final physical plan: {:?}",
                        self.input_logical_plan,
                        self.after_analyze,
                        self.after_optimize,
                        self.phy_plan
                    );
                }
            }
        }

        let mut logger = PanicLogger {
            input_logical_plan: logical_plan,
            after_analyze: None,
            after_optimize: None,
            phy_plan: None,
        };

        let _timer = metrics::CREATE_PHYSICAL_ELAPSED.start_timer();
        let state = ctx.state();

        common_telemetry::debug!("Create physical plan, input plan: {logical_plan}");

        // special handling for EXPLAIN plans
        if matches!(logical_plan, DfLogicalPlan::Explain(_)) {
            return state
                .create_physical_plan(logical_plan)
                .await
                .map_err(Into::into);
        }

        // analyze first
        let analyzed_plan = state.analyzer().execute_and_check(
            logical_plan.clone(),
            state.config_options(),
            |_, _| {},
        )?;

        logger.after_analyze = Some(analyzed_plan.clone());

        common_telemetry::debug!("Create physical plan, analyzed plan: {analyzed_plan}");

        // skip optimization for MergeScan
        let optimized_plan = if let DfLogicalPlan::Extension(ext) = &analyzed_plan
            && ext.node.name() == MergeScanLogicalPlan::name()
        {
            analyzed_plan.clone()
        } else {
            state
                .optimizer()
                .optimize(analyzed_plan, state, |_, _| {})?
        };

        common_telemetry::debug!("Create physical plan, optimized plan: {optimized_plan}");
        logger.after_optimize = Some(optimized_plan.clone());

        let physical_plan = state
            .query_planner()
            .create_physical_plan(&optimized_plan, state)
            .await?;

        logger.phy_plan = Some(physical_plan.clone());
        drop(logger);
        Ok(physical_plan)
    }

    #[tracing::instrument(skip_all)]
    fn optimize_physical_plan(
        &self,
        ctx: &mut QueryEngineContext,
        plan: Arc<dyn ExecutionPlan>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let _timer = metrics::OPTIMIZE_PHYSICAL_ELAPSED.start_timer();

        // TODO(ruihang): `self.create_physical_plan()` already optimizes the plan; check
        // whether we need to optimize it again here.
        // let state = ctx.state();
        // let config = state.config_options();

        // skip optimizing the AnalyzeExec plan
        let optimized_plan = if let Some(analyze_plan) = plan.as_any().downcast_ref::<AnalyzeExec>()
        {
            let format = if let Some(format) = ctx.query_ctx().explain_format()
                && format.to_lowercase() == "json"
            {
                AnalyzeFormat::JSON
            } else {
                AnalyzeFormat::TEXT
            };
            // Sets the verbose flag of the query context.
            // The MergeScanExec plan uses the verbose flag to determine whether to print the plan in verbose mode.
            ctx.query_ctx().set_explain_verbose(analyze_plan.verbose());

            Arc::new(DistAnalyzeExec::new(
                analyze_plan.input().clone(),
                analyze_plan.verbose(),
                format,
            ))
            // let mut new_plan = analyze_plan.input().clone();
            // for optimizer in state.physical_optimizers() {
            //     new_plan = optimizer
            //         .optimize(new_plan, config)
            //         .context(DataFusionSnafu)?;
            // }
            // Arc::new(DistAnalyzeExec::new(new_plan))
        } else {
            plan
            // let mut new_plan = plan;
            // for optimizer in state.physical_optimizers() {
            //     new_plan = optimizer
            //         .optimize(new_plan, config)
            //         .context(DataFusionSnafu)?;
            // }
            // new_plan
        };

        Ok(optimized_plan)
    }
}

#[async_trait]
impl QueryEngine for DatafusionQueryEngine {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn planner(&self) -> Arc<dyn LogicalPlanner> {
        Arc::new(DfLogicalPlanner::new(self.state.clone()))
    }

    fn name(&self) -> &str {
        "datafusion"
    }

    async fn describe(
        &self,
        plan: LogicalPlan,
        _query_ctx: QueryContextRef,
    ) -> Result<DescribeResult> {
        Ok(DescribeResult { logical_plan: plan })
    }

    async fn execute(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result<Output> {
        match plan {
            LogicalPlan::Dml(dml) => self.exec_dml_statement(dml, query_ctx).await,
            _ => self.exec_query_plan(plan, query_ctx).await,
        }
    }

    /// Note that in SQL queries, aggregate names are looked up in
    /// lowercase unless the name is quoted. For example,
    ///
    /// `SELECT MY_UDAF(x)...` will look for an aggregate named `"my_udaf"`, while
    /// `SELECT "my_UDAF"(x)` will look for an aggregate named `"my_UDAF"`.
    ///
    /// So it's better to make the UDAF name lowercase when creating one.
    fn register_aggregate_function(&self, func: AggregateUDF) {
        self.state.register_aggr_function(func);
    }

    /// Registers a scalar function.
    /// Overrides any previously registered function with the same name.
    fn register_scalar_function(&self, func: ScalarFunctionFactory) {
        self.state.register_scalar_function(func);
    }

    fn register_table_function(&self, func: Arc<TableFunction>) {
        self.state.register_table_function(func);
    }

    fn register_window_function(&self, func: WindowUDF) {
        self.state.register_window_function(func);
    }

    fn read_table(&self, table: TableRef) -> Result<DataFrame> {
        self.state.read_table(table).map_err(Into::into)
    }

    fn engine_context(&self, query_ctx: QueryContextRef) -> QueryEngineContext {
        let mut state = self.state.session_state();
        state.config_mut().set_extension(query_ctx.clone());
        // Note that hints in the "x-greptime-hints" header are automatically parsed
        // and stored in the query context's extensions, so we can read them from the query context.
        if let Some(parallelism) = query_ctx.extension(QUERY_PARALLELISM_HINT) {
            if let Ok(n) = parallelism.parse::<u64>() {
                if n > 0 {
                    let new_cfg = state.config().clone().with_target_partitions(n as usize);
                    *state.config_mut() = new_cfg;
                }
            } else {
                common_telemetry::warn!(
                    "Failed to parse query_parallelism: {}, using default value",
                    parallelism
                );
            }
        }

        // configure execution options
        state.config_mut().options_mut().execution.time_zone =
            Some(query_ctx.timezone().to_string());

        // A query usually carries either a `SET` variable (from a SQL client) or a header
        // hint (from a gRPC client), never both, so the two sources are handled separately.
        if query_ctx.configuration_parameter().allow_query_fallback() {
            state
                .config_mut()
                .options_mut()
                .extensions
                .insert(DistPlannerOptions {
                    allow_query_fallback: true,
                });
        } else if let Some(fallback) = query_ctx.extension(QUERY_FALLBACK_HINT) {
            // Also check the query context for the fallback hint;
            // if it is set, enable the fallback.
            if fallback.to_lowercase().parse::<bool>().unwrap_or(false) {
                state
                    .config_mut()
                    .options_mut()
                    .extensions
                    .insert(DistPlannerOptions {
                        allow_query_fallback: true,
                    });
            }
        }

        state
            .config_mut()
            .options_mut()
            .extensions
            .insert(FunctionContext {
                query_ctx: query_ctx.clone(),
                state: self.engine_state().function_state(),
            });

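        // Copy the finalized session config options into the execution props.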
        let config_options = state.config_options().clone();
        let _ = state
            .execution_props_mut()
            .config_options
            .insert(config_options);

        QueryEngineContext::new(state, query_ctx)
    }

    fn engine_state(&self) -> &QueryEngineState {
        &self.state
    }
}

impl QueryExecutor for DatafusionQueryEngine {
    #[tracing::instrument(skip_all)]
    fn execute_stream(
        &self,
        ctx: &QueryEngineContext,
        plan: &Arc<dyn ExecutionPlan>,
    ) -> Result<SendableRecordBatchStream> {
        let explain_verbose = ctx.query_ctx().explain_verbose();
        let output_partitions = plan.properties().output_partitioning().partition_count();
        if explain_verbose {
            common_telemetry::info!("Executing query plan, output_partitions: {output_partitions}");
        }

        let exec_timer = metrics::EXEC_PLAN_ELAPSED.start_timer();
        let task_ctx = ctx.build_task_ctx();
        let span = Span::current();

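        // Dispatch on the number of output partitions: zero partitions yield an
        // empty stream, a single partition executes directly, and multiple
        // partitions are coalesced into one stream first.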
        match plan.properties().output_partitioning().partition_count() {
            0 => {
                let schema = Arc::new(
                    Schema::try_from(plan.schema())
                        .map_err(BoxedError::new)
                        .context(QueryExecutionSnafu)?,
                );
                Ok(Box::pin(EmptyRecordBatchStream::new(schema)))
            }
            1 => {
                let df_stream = plan.execute(0, task_ctx)?;
                let mut stream = RecordBatchStreamAdapter::try_new_with_span(df_stream, span)
                    .context(error::ConvertDfRecordBatchStreamSnafu)
                    .map_err(BoxedError::new)
                    .context(QueryExecutionSnafu)?;
                stream.set_metrics2(plan.clone());
                stream.set_explain_verbose(explain_verbose);
                let stream = OnDone::new(Box::pin(stream), move || {
                    let exec_cost = exec_timer.stop_and_record();
                    if explain_verbose {
                        common_telemetry::info!(
                            "DatafusionQueryEngine execute 1 stream, cost: {:?}s",
                            exec_cost,
                        );
                    }
                });
                Ok(Box::pin(stream))
            }
            _ => {
                // merge into a single partition
                let merged_plan = CoalescePartitionsExec::new(plan.clone());
                // CoalescePartitionsExec must produce a single partition
                assert_eq!(
                    1,
                    merged_plan
                        .properties()
                        .output_partitioning()
                        .partition_count()
                );
                let df_stream = merged_plan.execute(0, task_ctx)?;
                let mut stream = RecordBatchStreamAdapter::try_new_with_span(df_stream, span)
                    .context(error::ConvertDfRecordBatchStreamSnafu)
                    .map_err(BoxedError::new)
                    .context(QueryExecutionSnafu)?;
                stream.set_metrics2(plan.clone());
                stream.set_explain_verbose(ctx.query_ctx().explain_verbose());
                let stream = OnDone::new(Box::pin(stream), move || {
                    let exec_cost = exec_timer.stop_and_record();
                    if explain_verbose {
                        common_telemetry::info!(
                            "DatafusionQueryEngine execute {output_partitions} stream, cost: {:?}s",
                            exec_cost
                        );
                    }
                });
                Ok(Box::pin(stream))
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use std::fmt;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use api::v1::SemanticType;
    use arrow::array::{ArrayRef, UInt64Array};
    use arrow_schema::SortOptions;
    use catalog::RegisterTableRequest;
    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, NUMBERS_TABLE_ID};
    use common_error::ext::BoxedError;
    use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream, util};
    use datafusion::physical_plan::display::{DisplayAs, DisplayFormatType};
    use datafusion::physical_plan::expressions::PhysicalSortExpr;
    use datafusion::physical_plan::joins::{HashJoinExec, JoinOn, PartitionMode};
    use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;
    use datafusion::physical_plan::{ExecutionPlan, PhysicalExpr};
    use datafusion::prelude::{col, lit};
    use datafusion_common::{JoinType, NullEquality};
    use datafusion_physical_expr::expressions::Column;
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::{ColumnSchema, SchemaRef};
    use datatypes::vectors::{Helper, UInt32Vector, VectorRef};
    use session::context::{QueryContext, QueryContextBuilder};
    use store_api::metadata::{ColumnMetadata, RegionMetadataBuilder, RegionMetadataRef};
    use store_api::region_engine::{
        PartitionRange, PrepareRequest, QueryScanContext, RegionScanner, ScannerProperties,
    };
    use store_api::storage::{RegionId, ScanRequest};
    use table::table::numbers::{NUMBERS_TABLE_NAME, NumbersTable};
    use table::table::scan::RegionScanExec;

    use super::*;
    use crate::options::QueryOptions;
    use crate::parser::QueryLanguageParser;
    use crate::part_sort::PartSortExec;
    use crate::query_engine::{QueryEngineFactory, QueryEngineRef};

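    /// A `RegionScanner` stub that records how many times dynamic filters are
    /// pushed down to it and the length of the last filter batch it received.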
    #[derive(Debug)]
    struct RecordingScanner {
        schema: SchemaRef,
        metadata: RegionMetadataRef,
        properties: ScannerProperties,
        update_calls: Arc<AtomicUsize>,
        last_filter_len: Arc<AtomicUsize>,
    }

    impl RecordingScanner {
        fn new(
            schema: SchemaRef,
            metadata: RegionMetadataRef,
            update_calls: Arc<AtomicUsize>,
            last_filter_len: Arc<AtomicUsize>,
        ) -> Self {
            Self {
                schema,
                metadata,
                properties: ScannerProperties::default(),
                update_calls,
                last_filter_len,
            }
        }
    }

    impl RegionScanner for RecordingScanner {
        fn name(&self) -> &str {
            "RecordingScanner"
        }

        fn properties(&self) -> &ScannerProperties {
            &self.properties
        }

        fn schema(&self) -> SchemaRef {
            self.schema.clone()
        }

        fn metadata(&self) -> RegionMetadataRef {
            self.metadata.clone()
        }

        fn prepare(&mut self, request: PrepareRequest) -> std::result::Result<(), BoxedError> {
            self.properties.prepare(request);
            Ok(())
        }

        fn scan_partition(
            &self,
            _ctx: &QueryScanContext,
            _metrics_set: &ExecutionPlanMetricsSet,
            _partition: usize,
        ) -> std::result::Result<SendableRecordBatchStream, BoxedError> {
            Ok(Box::pin(EmptyRecordBatchStream::new(self.schema.clone())))
        }

        fn has_predicate_without_region(&self) -> bool {
            true
        }

        fn add_dyn_filter_to_predicate(
            &mut self,
            filter_exprs: Vec<Arc<dyn PhysicalExpr>>,
        ) -> Vec<bool> {
            self.update_calls.fetch_add(1, Ordering::Relaxed);
            self.last_filter_len
                .store(filter_exprs.len(), Ordering::Relaxed);
            vec![true; filter_exprs.len()]
        }

        fn set_logical_region(&mut self, logical_region: bool) {
            self.properties.set_logical_region(logical_region);
        }
    }

    impl DisplayAs for RecordingScanner {
        fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "RecordingScanner")
        }
    }

    async fn create_test_engine() -> QueryEngineRef {
        let catalog_manager = catalog::memory::new_memory_catalog_manager().unwrap();
        let req = RegisterTableRequest {
            catalog: DEFAULT_CATALOG_NAME.to_string(),
            schema: DEFAULT_SCHEMA_NAME.to_string(),
            table_name: NUMBERS_TABLE_NAME.to_string(),
            table_id: NUMBERS_TABLE_ID,
            table: NumbersTable::table(NUMBERS_TABLE_ID),
        };
        catalog_manager.register_table_sync(req).unwrap();

        QueryEngineFactory::new(
            catalog_manager,
            None,
            None,
            None,
            None,
            false,
            QueryOptions::default(),
        )
        .query_engine()
    }

    #[tokio::test]
    async fn test_sql_to_plan() {
        let engine = create_test_engine().await;
        let sql = "select sum(number) from numbers limit 20";

        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
        let plan = engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap();

        assert_eq!(
            plan.to_string(),
            r#"Limit: skip=0, fetch=20
  Projection: sum(numbers.number)
    Aggregate: groupBy=[[]], aggr=[[sum(numbers.number)]]
      TableScan: numbers"#
        );
    }

    #[tokio::test]
    async fn test_execute() {
        let engine = create_test_engine().await;
        let sql = "select sum(number) from numbers limit 20";

        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
        let plan = engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap();

        let output = engine.execute(plan, QueryContext::arc()).await.unwrap();

        match output.data {
            OutputData::Stream(recordbatch) => {
                let numbers = util::collect(recordbatch).await.unwrap();
                assert_eq!(1, numbers.len());
                assert_eq!(numbers[0].num_columns(), 1);
                assert_eq!(1, numbers[0].schema.num_columns());
                assert_eq!(
                    "sum(numbers.number)",
                    numbers[0].schema.column_schemas()[0].name
                );

                let batch = &numbers[0];
                assert_eq!(1, batch.num_columns());
                assert_eq!(batch.column(0).len(), 1);

                let expected = Arc::new(UInt64Array::from_iter_values([4950])) as ArrayRef;
                assert_eq!(batch.column(0), &expected);
            }
            _ => unreachable!(),
        }
    }

    #[tokio::test]
    async fn test_read_table() {
        let engine = create_test_engine().await;

        let engine = engine
            .as_any()
            .downcast_ref::<DatafusionQueryEngine>()
            .unwrap();
        let query_ctx = Arc::new(QueryContextBuilder::default().build());
        let table = engine
            .find_table(
                &ResolvedTableReference {
                    catalog: "greptime".into(),
                    schema: "public".into(),
                    table: "numbers".into(),
                },
                &query_ctx,
            )
            .await
            .unwrap();

        let df = engine.read_table(table).unwrap();
        let df = df
            .select_columns(&["number"])
            .unwrap()
            .filter(col("number").lt(lit(10)))
            .unwrap();
        let batches = df.collect().await.unwrap();
        assert_eq!(1, batches.len());
        let batch = &batches[0];

        assert_eq!(1, batch.num_columns());
        assert_eq!(batch.column(0).len(), 10);

        assert_eq!(
            Helper::try_into_vector(batch.column(0)).unwrap(),
            Arc::new(UInt32Vector::from_slice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) as VectorRef
        );
    }

    #[tokio::test]
    async fn test_describe() {
        let engine = create_test_engine().await;
        let sql = "select sum(number) from numbers limit 20";

        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();

        let plan = engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap();

        let DescribeResult { logical_plan } =
            engine.describe(plan, QueryContext::arc()).await.unwrap();

        let schema: Schema = logical_plan.schema().clone().try_into().unwrap();

        assert_eq!(
            schema.column_schemas()[0],
            ColumnSchema::new(
                "sum(numbers.number)",
                ConcreteDataType::uint64_datatype(),
                true
            )
        );
        assert_eq!(
            "Limit: skip=0, fetch=20\n  Projection: sum(numbers.number)\n    Aggregate: groupBy=[[]], aggr=[[sum(numbers.number)]]\n      TableScan: numbers",
            format!("{}", logical_plan.display_indent())
        );
    }

    #[tokio::test]
    async fn test_topk_dynamic_filter_pushdown_reaches_region_scan() {
        let engine = create_test_engine().await;
        let engine = engine
            .as_any()
            .downcast_ref::<DatafusionQueryEngine>()
            .unwrap();
        let engine_ctx = engine.engine_context(QueryContext::arc());
        let state = engine_ctx.state();

        let schema = Arc::new(datatypes::schema::Schema::new(vec![ColumnSchema::new(
            "ts",
            ConcreteDataType::timestamp_millisecond_datatype(),
            false,
        )]));

        let mut metadata_builder = RegionMetadataBuilder::new(RegionId::new(1024, 1));
        metadata_builder
            .push_column_metadata(ColumnMetadata {
                column_schema: ColumnSchema::new(
                    "ts",
                    ConcreteDataType::timestamp_millisecond_datatype(),
                    false,
                )
                .with_time_index(true),
                semantic_type: SemanticType::Timestamp,
                column_id: 1,
            })
            .primary_key(vec![]);
        let metadata = Arc::new(metadata_builder.build().unwrap());

        let update_calls = Arc::new(AtomicUsize::new(0));
        let last_filter_len = Arc::new(AtomicUsize::new(0));
        let scanner = Box::new(RecordingScanner::new(
            schema,
            metadata,
            update_calls.clone(),
            last_filter_len.clone(),
        ));
        let scan = Arc::new(RegionScanExec::new(scanner, ScanRequest::default(), None).unwrap());

        let sort_expr = PhysicalSortExpr {
            expr: Arc::new(Column::new("ts", 0)),
            options: SortOptions {
                descending: true,
                ..Default::default()
            },
        };
        let partition_ranges: Vec<Vec<PartitionRange>> = vec![vec![]];
        let mut plan: Arc<dyn ExecutionPlan> =
            Arc::new(PartSortExec::try_new(sort_expr, Some(3), partition_ranges, scan).unwrap());

        for optimizer in state.physical_optimizers() {
            plan = optimizer.optimize(plan, state.config_options()).unwrap();
        }

        assert!(update_calls.load(Ordering::Relaxed) > 0);
        assert!(last_filter_len.load(Ordering::Relaxed) > 0);
    }

    #[tokio::test]
    async fn test_join_dynamic_filter_pushdown_reaches_region_scan() {
        let engine = create_test_engine().await;
        let engine = engine
            .as_any()
            .downcast_ref::<DatafusionQueryEngine>()
            .unwrap();
        let engine_ctx = engine.engine_context(QueryContext::arc());
        let state = engine_ctx.state();

        assert!(
            state
                .config_options()
                .optimizer
                .enable_join_dynamic_filter_pushdown
        );

        let schema = Arc::new(datatypes::schema::Schema::new(vec![ColumnSchema::new(
            "ts",
            ConcreteDataType::timestamp_millisecond_datatype(),
            false,
        )]));

        let mut left_metadata_builder = RegionMetadataBuilder::new(RegionId::new(2048, 1));
        left_metadata_builder
            .push_column_metadata(ColumnMetadata {
                column_schema: ColumnSchema::new(
                    "ts",
                    ConcreteDataType::timestamp_millisecond_datatype(),
                    false,
                )
                .with_time_index(true),
                semantic_type: SemanticType::Timestamp,
                column_id: 1,
            })
            .primary_key(vec![]);
        let left_metadata = Arc::new(left_metadata_builder.build().unwrap());

        let mut right_metadata_builder = RegionMetadataBuilder::new(RegionId::new(2048, 2));
        right_metadata_builder
            .push_column_metadata(ColumnMetadata {
                column_schema: ColumnSchema::new(
                    "ts",
                    ConcreteDataType::timestamp_millisecond_datatype(),
                    false,
                )
                .with_time_index(true),
                semantic_type: SemanticType::Timestamp,
                column_id: 1,
            })
            .primary_key(vec![]);
        let right_metadata = Arc::new(right_metadata_builder.build().unwrap());

        let left_update_calls = Arc::new(AtomicUsize::new(0));
        let left_last_filter_len = Arc::new(AtomicUsize::new(0));
        let right_update_calls = Arc::new(AtomicUsize::new(0));
        let right_last_filter_len = Arc::new(AtomicUsize::new(0));

        let left_scan = Arc::new(
            RegionScanExec::new(
                Box::new(RecordingScanner::new(
                    schema.clone(),
                    left_metadata,
                    left_update_calls.clone(),
                    left_last_filter_len.clone(),
                )),
                ScanRequest::default(),
                None,
            )
            .unwrap(),
        );
        let right_scan = Arc::new(
            RegionScanExec::new(
                Box::new(RecordingScanner::new(
                    schema,
                    right_metadata,
                    right_update_calls.clone(),
                    right_last_filter_len.clone(),
                )),
                ScanRequest::default(),
                None,
            )
            .unwrap(),
        );

        let on: JoinOn = vec![(
            Arc::new(Column::new("ts", 0)) as Arc<dyn PhysicalExpr>,
            Arc::new(Column::new("ts", 0)) as Arc<dyn PhysicalExpr>,
        )];

        let mut plan: Arc<dyn ExecutionPlan> = Arc::new(
            HashJoinExec::try_new(
                left_scan,
                right_scan,
                on,
                None,
                &JoinType::Inner,
                None,
                PartitionMode::CollectLeft,
                NullEquality::NullEqualsNull,
                false,
            )
            .unwrap(),
        );

        for optimizer in state.physical_optimizers() {
            plan = optimizer.optimize(plan, state.config_options()).unwrap();
        }

        assert!(left_update_calls.load(Ordering::Relaxed) > 0);
        assert_eq!(0, left_last_filter_len.load(Ordering::Relaxed));
        assert!(right_update_calls.load(Ordering::Relaxed) > 0);
        assert!(right_last_filter_len.load(Ordering::Relaxed) > 0);
    }
}
1065}