Skip to main content

query/
datafusion.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Planner, QueryEngine implementations based on DataFusion.
16
17mod error;
18mod json_expr_planner;
19mod planner;
20
21use std::any::Any;
22use std::collections::HashMap;
23use std::sync::Arc;
24
25use async_trait::async_trait;
26use common_base::Plugins;
27use common_catalog::consts::is_readonly_schema;
28use common_error::ext::BoxedError;
29use common_function::function::FunctionContext;
30use common_function::function_factory::ScalarFunctionFactory;
31use common_query::{Output, OutputData, OutputMeta};
32use common_recordbatch::adapter::RecordBatchStreamAdapter;
33use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream};
34use common_telemetry::tracing;
35use datafusion::catalog::TableFunction;
36use datafusion::dataframe::DataFrame;
37use datafusion::physical_plan::ExecutionPlan;
38use datafusion::physical_plan::analyze::AnalyzeExec;
39use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
40use datafusion_common::ResolvedTableReference;
41use datafusion_expr::{
42    AggregateUDF, DmlStatement, LogicalPlan as DfLogicalPlan, LogicalPlan, WindowUDF, WriteOp,
43};
44use datatypes::prelude::VectorRef;
45use datatypes::schema::Schema;
46use futures_util::StreamExt;
47use session::context::QueryContextRef;
48use snafu::{OptionExt, ResultExt, ensure};
49use sqlparser::ast::AnalyzeFormat;
50use table::TableRef;
51use table::requests::{DeleteRequest, InsertRequest};
52use tracing::Span;
53
54use crate::analyze::DistAnalyzeExec;
55pub use crate::datafusion::planner::DfContextProviderAdapter;
56use crate::dist_plan::{DistPlannerOptions, MergeScanLogicalPlan};
57use crate::error::{
58    CatalogSnafu, CreateRecordBatchSnafu, MissingTableMutationHandlerSnafu,
59    MissingTimestampColumnSnafu, QueryExecutionSnafu, Result, TableMutationSnafu,
60    TableNotFoundSnafu, TableReadOnlySnafu, UnsupportedExprSnafu,
61};
62use crate::executor::QueryExecutor;
63use crate::metrics::{
64    OnDone, QUERY_STAGE_ELAPSED, maybe_attach_region_watermark_metrics,
65    should_collect_region_watermark_from_query_ctx,
66};
67use crate::physical_wrapper::PhysicalPlanWrapperRef;
68use crate::planner::{DfLogicalPlanner, LogicalPlanner};
69use crate::query_engine::{DescribeResult, QueryEngineContext, QueryEngineState};
70use crate::{QueryEngine, metrics};
71
/// Query parallelism hint key.
///
/// This hint can be set in the query context to control the parallelism of the query execution.
/// When the engine context is built, a value that parses as a positive integer becomes the
/// session's target partition count; `0` is ignored and an unparsable value only logs a warning.
pub const QUERY_PARALLELISM_HINT: &str = "query_parallelism";
75
/// Query fallback hint key.
///
/// Controls whether to fall back to the original plan when push down fails.
/// The hint enables fallback only when its lowercased value parses to `true`.
pub const QUERY_FALLBACK_HINT: &str = "query_fallback";
78
/// Query engine that plans and executes logical plans through DataFusion.
pub struct DatafusionQueryEngine {
    /// Shared engine state: source of the session state, the catalog manager,
    /// the table mutation handler and the function registries used below.
    state: Arc<QueryEngineState>,
    /// Plugin registry; consulted for an optional [`PhysicalPlanWrapperRef`]
    /// when a query plan is executed.
    plugins: Plugins,
}
83
84impl DatafusionQueryEngine {
85    pub fn new(state: Arc<QueryEngineState>, plugins: Plugins) -> Self {
86        Self { state, plugins }
87    }
88
89    #[tracing::instrument(skip_all)]
90    async fn exec_query_plan(
91        &self,
92        plan: LogicalPlan,
93        query_ctx: QueryContextRef,
94    ) -> Result<Output> {
95        let mut ctx = self.engine_context(query_ctx.clone());
96
97        // `create_physical_plan` will optimize logical plan internally
98        let physical_plan = self.create_physical_plan(&mut ctx, &plan).await?;
99        let optimized_physical_plan = self.optimize_physical_plan(&mut ctx, physical_plan)?;
100
101        let physical_plan = if let Some(wrapper) = self.plugins.get::<PhysicalPlanWrapperRef>() {
102            wrapper.wrap(optimized_physical_plan, query_ctx)
103        } else {
104            optimized_physical_plan
105        };
106
107        let stream = self.execute_stream(&ctx, &physical_plan)?;
108
109        Ok(Output::new(
110            OutputData::Stream(stream),
111            OutputMeta::new_with_plan(physical_plan),
112        ))
113    }
114
115    #[tracing::instrument(skip_all)]
116    async fn exec_dml_statement(
117        &self,
118        dml: DmlStatement,
119        query_ctx: QueryContextRef,
120    ) -> Result<Output> {
121        ensure!(
122            matches!(dml.op, WriteOp::Insert(_) | WriteOp::Delete),
123            UnsupportedExprSnafu {
124                name: format!("DML op {}", dml.op),
125            }
126        );
127
128        let _timer = QUERY_STAGE_ELAPSED
129            .with_label_values(&[dml.op.name()])
130            .start_timer();
131
132        let default_catalog = &query_ctx.current_catalog().to_owned();
133        let default_schema = &query_ctx.current_schema();
134        let table_name = dml.table_name.resolve(default_catalog, default_schema);
135        let table = self.find_table(&table_name, &query_ctx).await?;
136
137        let Output { data, meta } = self
138            .exec_query_plan((*dml.input).clone(), query_ctx.clone())
139            .await?;
140        let mut stream = match data {
141            OutputData::RecordBatches(batches) => batches.as_stream(),
142            OutputData::Stream(stream) => stream,
143            _ => unreachable!(),
144        };
145
146        let mut affected_rows = 0;
147        let mut insert_cost = 0;
148
149        while let Some(batch) = stream.next().await {
150            let batch = batch.context(CreateRecordBatchSnafu)?;
151            let column_vectors = batch
152                .column_vectors(&table_name.to_string(), table.schema())
153                .map_err(BoxedError::new)
154                .context(QueryExecutionSnafu)?;
155
156            match dml.op {
157                WriteOp::Insert(_) => {
158                    // We ignore the insert op.
159                    let output = self
160                        .insert(&table_name, column_vectors, query_ctx.clone())
161                        .await?;
162                    let (rows, cost) = output.extract_rows_and_cost();
163                    affected_rows += rows;
164                    insert_cost += cost;
165                }
166                WriteOp::Delete => {
167                    affected_rows += self
168                        .delete(&table_name, &table, column_vectors, query_ctx.clone())
169                        .await?;
170                }
171                _ => unreachable!("guarded by the 'ensure!' at the beginning"),
172            }
173        }
174        Ok(Output::new(
175            OutputData::AffectedRows(affected_rows),
176            OutputMeta::new(meta.plan, insert_cost),
177        ))
178    }
179
180    #[tracing::instrument(skip_all)]
181    async fn delete(
182        &self,
183        table_name: &ResolvedTableReference,
184        table: &TableRef,
185        column_vectors: HashMap<String, VectorRef>,
186        query_ctx: QueryContextRef,
187    ) -> Result<usize> {
188        let catalog_name = table_name.catalog.to_string();
189        let schema_name = table_name.schema.to_string();
190        let table_name = table_name.table.to_string();
191        let table_schema = table.schema();
192
193        ensure!(
194            !is_readonly_schema(&schema_name),
195            TableReadOnlySnafu { table: table_name }
196        );
197
198        let ts_column = table_schema
199            .timestamp_column()
200            .map(|x| &x.name)
201            .with_context(|| MissingTimestampColumnSnafu {
202                table_name: table_name.clone(),
203            })?;
204
205        let table_info = table.table_info();
206        let rowkey_columns = table_info
207            .meta
208            .row_key_column_names()
209            .collect::<Vec<&String>>();
210        let column_vectors = column_vectors
211            .into_iter()
212            .filter(|x| &x.0 == ts_column || rowkey_columns.contains(&&x.0))
213            .collect::<HashMap<_, _>>();
214
215        let request = DeleteRequest {
216            catalog_name,
217            schema_name,
218            table_name,
219            key_column_values: column_vectors,
220        };
221
222        self.state
223            .table_mutation_handler()
224            .context(MissingTableMutationHandlerSnafu)?
225            .delete(request, query_ctx)
226            .await
227            .context(TableMutationSnafu)
228    }
229
230    #[tracing::instrument(skip_all)]
231    async fn insert(
232        &self,
233        table_name: &ResolvedTableReference,
234        column_vectors: HashMap<String, VectorRef>,
235        query_ctx: QueryContextRef,
236    ) -> Result<Output> {
237        let catalog_name = table_name.catalog.to_string();
238        let schema_name = table_name.schema.to_string();
239        let table_name = table_name.table.to_string();
240
241        ensure!(
242            !is_readonly_schema(&schema_name),
243            TableReadOnlySnafu { table: table_name }
244        );
245
246        let request = InsertRequest {
247            catalog_name,
248            schema_name,
249            table_name,
250            columns_values: column_vectors,
251        };
252
253        self.state
254            .table_mutation_handler()
255            .context(MissingTableMutationHandlerSnafu)?
256            .insert(request, query_ctx)
257            .await
258            .context(TableMutationSnafu)
259    }
260
261    async fn find_table(
262        &self,
263        table_name: &ResolvedTableReference,
264        query_context: &QueryContextRef,
265    ) -> Result<TableRef> {
266        let catalog_name = table_name.catalog.as_ref();
267        let schema_name = table_name.schema.as_ref();
268        let table_name = table_name.table.as_ref();
269
270        self.state
271            .catalog_manager()
272            .table(catalog_name, schema_name, table_name, Some(query_context))
273            .await
274            .context(CatalogSnafu)?
275            .with_context(|| TableNotFoundSnafu { table: table_name })
276    }
277
278    #[tracing::instrument(skip_all)]
279    async fn create_physical_plan(
280        &self,
281        ctx: &mut QueryEngineContext,
282        logical_plan: &LogicalPlan,
283    ) -> Result<Arc<dyn ExecutionPlan>> {
284        /// Only print context on panic, to avoid cluttering logs.
285        ///
286        /// TODO(discord9): remove this once we catch the bug
287        #[derive(Debug)]
288        struct PanicLogger<'a> {
289            input_logical_plan: &'a LogicalPlan,
290            after_analyze: Option<LogicalPlan>,
291            after_optimize: Option<LogicalPlan>,
292            phy_plan: Option<Arc<dyn ExecutionPlan>>,
293        }
294        impl Drop for PanicLogger<'_> {
295            fn drop(&mut self) {
296                if std::thread::panicking() {
297                    common_telemetry::error!(
298                        "Panic while creating physical plan, input logical plan: {:?}, after analyze: {:?}, after optimize: {:?}, final physical plan: {:?}",
299                        self.input_logical_plan,
300                        self.after_analyze,
301                        self.after_optimize,
302                        self.phy_plan
303                    );
304                }
305            }
306        }
307
308        let mut logger = PanicLogger {
309            input_logical_plan: logical_plan,
310            after_analyze: None,
311            after_optimize: None,
312            phy_plan: None,
313        };
314
315        let _timer = metrics::CREATE_PHYSICAL_ELAPSED.start_timer();
316        let state = ctx.state();
317
318        common_telemetry::debug!("Create physical plan, input plan: {logical_plan}");
319
320        // special handle EXPLAIN plan
321        if matches!(logical_plan, DfLogicalPlan::Explain(_)) {
322            return state
323                .create_physical_plan(logical_plan)
324                .await
325                .map_err(Into::into);
326        }
327
328        // analyze first
329        let analyzed_plan = state.analyzer().execute_and_check(
330            logical_plan.clone(),
331            state.config_options(),
332            |_, _| {},
333        )?;
334
335        logger.after_analyze = Some(analyzed_plan.clone());
336
337        common_telemetry::debug!("Create physical plan, analyzed plan: {analyzed_plan}");
338
339        // skip optimize for MergeScan
340        let optimized_plan = if let DfLogicalPlan::Extension(ext) = &analyzed_plan
341            && ext.node.name() == MergeScanLogicalPlan::name()
342        {
343            analyzed_plan.clone()
344        } else {
345            state
346                .optimizer()
347                .optimize(analyzed_plan, state, |_, _| {})?
348        };
349
350        common_telemetry::debug!("Create physical plan, optimized plan: {optimized_plan}");
351        logger.after_optimize = Some(optimized_plan.clone());
352
353        let physical_plan = state
354            .query_planner()
355            .create_physical_plan(&optimized_plan, state)
356            .await?;
357
358        logger.phy_plan = Some(physical_plan.clone());
359        drop(logger);
360        Ok(physical_plan)
361    }
362
363    #[tracing::instrument(skip_all)]
364    fn optimize_physical_plan(
365        &self,
366        ctx: &mut QueryEngineContext,
367        plan: Arc<dyn ExecutionPlan>,
368    ) -> Result<Arc<dyn ExecutionPlan>> {
369        let _timer = metrics::OPTIMIZE_PHYSICAL_ELAPSED.start_timer();
370
371        // TODO(ruihang): `self.create_physical_plan()` already optimize the plan, check
372        // if we need to optimize it again here.
373        // let state = ctx.state();
374        // let config = state.config_options();
375
376        // skip optimize AnalyzeExec plan
377        let optimized_plan = if let Some(analyze_plan) = plan.as_any().downcast_ref::<AnalyzeExec>()
378        {
379            let format = if let Some(format) = ctx.query_ctx().explain_format()
380                && format.to_lowercase() == "json"
381            {
382                AnalyzeFormat::JSON
383            } else {
384                AnalyzeFormat::TEXT
385            };
386            // Sets the verbose flag of the query context.
387            // The MergeScanExec plan uses the verbose flag to determine whether to print the plan in verbose mode.
388            ctx.query_ctx().set_explain_verbose(analyze_plan.verbose());
389
390            Arc::new(DistAnalyzeExec::new(
391                analyze_plan.input().clone(),
392                analyze_plan.verbose(),
393                format,
394            ))
395            // let mut new_plan = analyze_plan.input().clone();
396            // for optimizer in state.physical_optimizers() {
397            //     new_plan = optimizer
398            //         .optimize(new_plan, config)
399            //         .context(DataFusionSnafu)?;
400            // }
401            // Arc::new(DistAnalyzeExec::new(new_plan))
402        } else {
403            plan
404            // let mut new_plan = plan;
405            // for optimizer in state.physical_optimizers() {
406            //     new_plan = optimizer
407            //         .optimize(new_plan, config)
408            //         .context(DataFusionSnafu)?;
409            // }
410            // new_plan
411        };
412
413        Ok(optimized_plan)
414    }
415}
416
417#[async_trait]
418impl QueryEngine for DatafusionQueryEngine {
    /// Exposes the engine as [`Any`] so callers can downcast it to the concrete type.
    fn as_any(&self) -> &dyn Any {
        self
    }
422
423    fn planner(&self) -> Arc<dyn LogicalPlanner> {
424        Arc::new(DfLogicalPlanner::new(self.state.clone()))
425    }
426
    /// Name of this query engine implementation.
    fn name(&self) -> &str {
        "datafusion"
    }
430
431    async fn describe(
432        &self,
433        plan: LogicalPlan,
434        _query_ctx: QueryContextRef,
435    ) -> Result<DescribeResult> {
436        Ok(DescribeResult { logical_plan: plan })
437    }
438
439    async fn execute(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result<Output> {
440        match plan {
441            LogicalPlan::Dml(dml) => self.exec_dml_statement(dml, query_ctx).await,
442            _ => self.exec_query_plan(plan, query_ctx).await,
443        }
444    }
445
446    /// Note in SQL queries, aggregate names are looked up using
447    /// lowercase unless the query uses quotes. For example,
448    ///
449    /// `SELECT MY_UDAF(x)...` will look for an aggregate named `"my_udaf"`
450    /// `SELECT "my_UDAF"(x)` will look for an aggregate named `"my_UDAF"`
451    ///
452    /// So it's better to make UDAF name lowercase when creating one.
453    fn register_aggregate_function(&self, func: AggregateUDF) {
454        self.state.register_aggr_function(func);
455    }
456
457    /// Register an scalar function.
458    /// Will override if the function with same name is already registered.
459    fn register_scalar_function(&self, func: ScalarFunctionFactory) {
460        self.state.register_scalar_function(func);
461    }
462
463    fn register_table_function(&self, func: Arc<TableFunction>) {
464        self.state.register_table_function(func);
465    }
466
467    fn register_window_function(&self, func: WindowUDF) {
468        self.state.register_window_function(func);
469    }
470
471    fn read_table(&self, table: TableRef) -> Result<DataFrame> {
472        self.state.read_table(table).map_err(Into::into)
473    }
474
475    fn engine_context(&self, query_ctx: QueryContextRef) -> QueryEngineContext {
476        let mut state = self.state.session_state();
477        state.config_mut().set_extension(query_ctx.clone());
478        state.config_mut().set_extension(self.state.clone());
479        // note that hints in "x-greptime-hints" is automatically parsed
480        // and set to query context's extension, so we can get it from query context.
481        if let Some(parallelism) = query_ctx.extension(QUERY_PARALLELISM_HINT) {
482            if let Ok(n) = parallelism.parse::<u64>() {
483                if n > 0 {
484                    let new_cfg = state.config().clone().with_target_partitions(n as usize);
485                    *state.config_mut() = new_cfg;
486                }
487            } else {
488                common_telemetry::warn!(
489                    "Failed to parse query_parallelism: {}, using default value",
490                    parallelism
491                );
492            }
493        }
494
495        // configure execution options
496        state.config_mut().options_mut().execution.time_zone =
497            Some(query_ctx.timezone().to_string());
498
499        // usually it's impossible to have both `set variable` set by sql client and
500        // hint in header by grpc client, so only need to deal with them separately
501        if query_ctx.configuration_parameter().allow_query_fallback() {
502            state
503                .config_mut()
504                .options_mut()
505                .extensions
506                .insert(DistPlannerOptions {
507                    allow_query_fallback: true,
508                });
509        } else if let Some(fallback) = query_ctx.extension(QUERY_FALLBACK_HINT) {
510            // also check the query context for fallback hint
511            // if it is set, we will enable the fallback
512            if fallback.to_lowercase().parse::<bool>().unwrap_or(false) {
513                state
514                    .config_mut()
515                    .options_mut()
516                    .extensions
517                    .insert(DistPlannerOptions {
518                        allow_query_fallback: true,
519                    });
520            }
521        }
522
523        state
524            .config_mut()
525            .options_mut()
526            .extensions
527            .insert(FunctionContext {
528                query_ctx: query_ctx.clone(),
529                state: self.engine_state().function_state(),
530            });
531
532        let config_options = state.config_options().clone();
533        let _ = state
534            .execution_props_mut()
535            .config_options
536            .insert(config_options);
537
538        QueryEngineContext::new(state, query_ctx)
539    }
540
    /// Borrows the engine state shared by this engine.
    fn engine_state(&self) -> &QueryEngineState {
        &self.state
    }
544}
545
546impl QueryExecutor for DatafusionQueryEngine {
547    #[tracing::instrument(skip_all)]
548    fn execute_stream(
549        &self,
550        ctx: &QueryEngineContext,
551        plan: &Arc<dyn ExecutionPlan>,
552    ) -> Result<SendableRecordBatchStream> {
553        let query_ctx = ctx.query_ctx();
554        let explain_verbose = query_ctx.explain_verbose();
555        let should_collect_region_watermark =
556            should_collect_region_watermark_from_query_ctx(&query_ctx)?;
557        let output_partitions = plan.properties().output_partitioning().partition_count();
558        if explain_verbose {
559            common_telemetry::info!("Executing query plan, output_partitions: {output_partitions}");
560        }
561
562        let exec_timer = metrics::EXEC_PLAN_ELAPSED.start_timer();
563        let task_ctx = ctx.build_task_ctx();
564        let span = Span::current();
565
566        match plan.properties().output_partitioning().partition_count() {
567            0 => {
568                let schema = Arc::new(
569                    Schema::try_from(plan.schema())
570                        .map_err(BoxedError::new)
571                        .context(QueryExecutionSnafu)?,
572                );
573                Ok(Box::pin(EmptyRecordBatchStream::new(schema)))
574            }
575            1 => {
576                let df_stream = plan.execute(0, task_ctx)?;
577                let mut stream = RecordBatchStreamAdapter::try_new_with_span(df_stream, span)
578                    .context(error::ConvertDfRecordBatchStreamSnafu)
579                    .map_err(BoxedError::new)
580                    .context(QueryExecutionSnafu)?;
581                stream.set_metrics2(plan.clone());
582                stream.set_explain_verbose(explain_verbose);
583                let stream = OnDone::new(Box::pin(stream), move || {
584                    let exec_cost = exec_timer.stop_and_record();
585                    if explain_verbose {
586                        common_telemetry::info!(
587                            "DatafusionQueryEngine execute 1 stream, cost: {:?}s",
588                            exec_cost,
589                        );
590                    }
591                });
592                Ok(maybe_attach_region_watermark_metrics(
593                    Box::pin(stream),
594                    plan.clone(),
595                    should_collect_region_watermark,
596                ))
597            }
598            _ => {
599                // merge into a single partition
600                let merged_plan = CoalescePartitionsExec::new(plan.clone());
601                // CoalescePartitionsExec must produce a single partition
602                assert_eq!(
603                    1,
604                    merged_plan
605                        .properties()
606                        .output_partitioning()
607                        .partition_count()
608                );
609                let df_stream = merged_plan.execute(0, task_ctx)?;
610                let mut stream = RecordBatchStreamAdapter::try_new_with_span(df_stream, span)
611                    .context(error::ConvertDfRecordBatchStreamSnafu)
612                    .map_err(BoxedError::new)
613                    .context(QueryExecutionSnafu)?;
614                stream.set_metrics2(plan.clone());
615                stream.set_explain_verbose(explain_verbose);
616                let stream = OnDone::new(Box::pin(stream), move || {
617                    let exec_cost = exec_timer.stop_and_record();
618                    if explain_verbose {
619                        common_telemetry::info!(
620                            "DatafusionQueryEngine execute {output_partitions} stream, cost: {:?}s",
621                            exec_cost
622                        );
623                    }
624                });
625                Ok(maybe_attach_region_watermark_metrics(
626                    Box::pin(stream),
627                    plan.clone(),
628                    should_collect_region_watermark,
629                ))
630            }
631        }
632    }
633}
634
635#[cfg(test)]
636mod tests {
637    use std::fmt;
638    use std::sync::Arc;
639    use std::sync::atomic::{AtomicUsize, Ordering};
640
641    use api::v1::SemanticType;
642    use arrow::array::{ArrayRef, UInt64Array};
643    use arrow_schema::SortOptions;
644    use catalog::RegisterTableRequest;
645    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, NUMBERS_TABLE_ID};
646    use common_error::ext::BoxedError;
647    use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream, util};
648    use datafusion::physical_plan::display::{DisplayAs, DisplayFormatType};
649    use datafusion::physical_plan::expressions::PhysicalSortExpr;
650    use datafusion::physical_plan::joins::{HashJoinExec, JoinOn, PartitionMode};
651    use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;
652    use datafusion::physical_plan::{ExecutionPlan, PhysicalExpr};
653    use datafusion::prelude::{col, lit};
654    use datafusion_common::{JoinType, NullEquality};
655    use datafusion_physical_expr::expressions::Column;
656    use datatypes::prelude::ConcreteDataType;
657    use datatypes::schema::{ColumnSchema, SchemaRef};
658    use datatypes::vectors::{Helper, UInt32Vector, VectorRef};
659    use session::context::{QueryContext, QueryContextBuilder};
660    use store_api::metadata::{ColumnMetadata, RegionMetadataBuilder, RegionMetadataRef};
661    use store_api::region_engine::{
662        PartitionRange, PrepareRequest, QueryScanContext, RegionScanner, ScannerProperties,
663    };
664    use store_api::storage::{RegionId, ScanRequest};
665    use table::table::numbers::{NUMBERS_TABLE_NAME, NumbersTable};
666    use table::table::scan::RegionScanExec;
667
668    use super::*;
669    use crate::options::QueryOptions;
670    use crate::parser::QueryLanguageParser;
671    use crate::part_sort::PartSortExec;
672    use crate::query_engine::{QueryEngineFactory, QueryEngineRef};
673
    /// Test scanner that yields no data and records how dynamic filters are pushed to it.
    #[derive(Debug)]
    struct RecordingScanner {
        /// Schema reported by the scanner and used for the empty streams it returns.
        schema: SchemaRef,
        /// Region metadata reported by the scanner.
        metadata: RegionMetadataRef,
        /// Scanner properties, updated by `prepare` and `set_logical_region`.
        properties: ScannerProperties,
        /// Incremented on every `add_dyn_filter_to_predicate` call.
        update_calls: Arc<AtomicUsize>,
        /// Number of filters passed to the latest `add_dyn_filter_to_predicate` call.
        last_filter_len: Arc<AtomicUsize>,
    }
682
683    impl RecordingScanner {
684        fn new(
685            schema: SchemaRef,
686            metadata: RegionMetadataRef,
687            update_calls: Arc<AtomicUsize>,
688            last_filter_len: Arc<AtomicUsize>,
689        ) -> Self {
690            Self {
691                schema,
692                metadata,
693                properties: ScannerProperties::default(),
694                update_calls,
695                last_filter_len,
696            }
697        }
698    }
699
700    impl RegionScanner for RecordingScanner {
701        fn name(&self) -> &str {
702            "RecordingScanner"
703        }
704
705        fn properties(&self) -> &ScannerProperties {
706            &self.properties
707        }
708
709        fn schema(&self) -> SchemaRef {
710            self.schema.clone()
711        }
712
713        fn metadata(&self) -> RegionMetadataRef {
714            self.metadata.clone()
715        }
716
717        fn prepare(&mut self, request: PrepareRequest) -> std::result::Result<(), BoxedError> {
718            self.properties.prepare(request);
719            Ok(())
720        }
721
722        fn scan_partition(
723            &self,
724            _ctx: &QueryScanContext,
725            _metrics_set: &ExecutionPlanMetricsSet,
726            _partition: usize,
727        ) -> std::result::Result<SendableRecordBatchStream, BoxedError> {
728            Ok(Box::pin(EmptyRecordBatchStream::new(self.schema.clone())))
729        }
730
731        fn has_predicate_without_region(&self) -> bool {
732            true
733        }
734
735        fn add_dyn_filter_to_predicate(
736            &mut self,
737            filter_exprs: Vec<Arc<dyn PhysicalExpr>>,
738        ) -> Vec<bool> {
739            self.update_calls.fetch_add(1, Ordering::Relaxed);
740            self.last_filter_len
741                .store(filter_exprs.len(), Ordering::Relaxed);
742            vec![true; filter_exprs.len()]
743        }
744
745        fn set_logical_region(&mut self, logical_region: bool) {
746            self.properties.set_logical_region(logical_region);
747        }
748    }
749
750    impl DisplayAs for RecordingScanner {
751        fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
752            write!(f, "RecordingScanner")
753        }
754    }
755
756    async fn create_test_engine() -> QueryEngineRef {
757        let catalog_manager = catalog::memory::new_memory_catalog_manager().unwrap();
758        let req = RegisterTableRequest {
759            catalog: DEFAULT_CATALOG_NAME.to_string(),
760            schema: DEFAULT_SCHEMA_NAME.to_string(),
761            table_name: NUMBERS_TABLE_NAME.to_string(),
762            table_id: NUMBERS_TABLE_ID,
763            table: NumbersTable::table(NUMBERS_TABLE_ID),
764        };
765        catalog_manager.register_table_sync(req).unwrap();
766
767        QueryEngineFactory::new(
768            catalog_manager,
769            None,
770            None,
771            None,
772            None,
773            false,
774            QueryOptions::default(),
775        )
776        .query_engine()
777    }
778
779    #[tokio::test]
780    async fn test_sql_to_plan() {
781        let engine = create_test_engine().await;
782        let sql = "select sum(number) from numbers limit 20";
783
784        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
785        let plan = engine
786            .planner()
787            .plan(&stmt, QueryContext::arc())
788            .await
789            .unwrap();
790
791        assert_eq!(
792            plan.to_string(),
793            r#"Limit: skip=0, fetch=20
794  Projection: sum(numbers.number)
795    Aggregate: groupBy=[[]], aggr=[[sum(numbers.number)]]
796      TableScan: numbers"#
797        );
798    }
799
800    #[tokio::test]
801    async fn test_execute() {
802        let engine = create_test_engine().await;
803        let sql = "select sum(number) from numbers limit 20";
804
805        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
806        let plan = engine
807            .planner()
808            .plan(&stmt, QueryContext::arc())
809            .await
810            .unwrap();
811
812        let output = engine.execute(plan, QueryContext::arc()).await.unwrap();
813
814        match output.data {
815            OutputData::Stream(recordbatch) => {
816                let numbers = util::collect(recordbatch).await.unwrap();
817                assert_eq!(1, numbers.len());
818                assert_eq!(numbers[0].num_columns(), 1);
819                assert_eq!(1, numbers[0].schema.num_columns());
820                assert_eq!(
821                    "sum(numbers.number)",
822                    numbers[0].schema.column_schemas()[0].name
823                );
824
825                let batch = &numbers[0];
826                assert_eq!(1, batch.num_columns());
827                assert_eq!(batch.column(0).len(), 1);
828
829                let expected = Arc::new(UInt64Array::from_iter_values([4950])) as ArrayRef;
830                assert_eq!(batch.column(0), &expected);
831            }
832            _ => unreachable!(),
833        }
834    }
835
836    #[tokio::test]
837    async fn test_read_table() {
838        let engine = create_test_engine().await;
839
840        let engine = engine
841            .as_any()
842            .downcast_ref::<DatafusionQueryEngine>()
843            .unwrap();
844        let query_ctx = Arc::new(QueryContextBuilder::default().build());
845        let table = engine
846            .find_table(
847                &ResolvedTableReference {
848                    catalog: "greptime".into(),
849                    schema: "public".into(),
850                    table: "numbers".into(),
851                },
852                &query_ctx,
853            )
854            .await
855            .unwrap();
856
857        let df = engine.read_table(table).unwrap();
858        let df = df
859            .select_columns(&["number"])
860            .unwrap()
861            .filter(col("number").lt(lit(10)))
862            .unwrap();
863        let batches = df.collect().await.unwrap();
864        assert_eq!(1, batches.len());
865        let batch = &batches[0];
866
867        assert_eq!(1, batch.num_columns());
868        assert_eq!(batch.column(0).len(), 10);
869
870        assert_eq!(
871            Helper::try_into_vector(batch.column(0)).unwrap(),
872            Arc::new(UInt32Vector::from_slice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) as VectorRef
873        );
874    }
875
876    #[tokio::test]
877    async fn test_describe() {
878        let engine = create_test_engine().await;
879        let sql = "select sum(number) from numbers limit 20";
880
881        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
882
883        let plan = engine
884            .planner()
885            .plan(&stmt, QueryContext::arc())
886            .await
887            .unwrap();
888
889        let DescribeResult { logical_plan } =
890            engine.describe(plan, QueryContext::arc()).await.unwrap();
891
892        let schema: Schema = logical_plan.schema().clone().try_into().unwrap();
893
894        assert_eq!(
895            schema.column_schemas()[0],
896            ColumnSchema::new(
897                "sum(numbers.number)",
898                ConcreteDataType::uint64_datatype(),
899                true
900            )
901        );
902        assert_eq!(
903            "Limit: skip=0, fetch=20\n  Projection: sum(numbers.number)\n    Aggregate: groupBy=[[]], aggr=[[sum(numbers.number)]]\n      TableScan: numbers",
904            format!("{}", logical_plan.display_indent())
905        );
906    }
907
908    #[tokio::test]
909    async fn test_topk_dynamic_filter_pushdown_reaches_region_scan() {
910        let engine = create_test_engine().await;
911        let engine = engine
912            .as_any()
913            .downcast_ref::<DatafusionQueryEngine>()
914            .unwrap();
915        let engine_ctx = engine.engine_context(QueryContext::arc());
916        let state = engine_ctx.state();
917
918        let schema = Arc::new(datatypes::schema::Schema::new(vec![ColumnSchema::new(
919            "ts",
920            ConcreteDataType::timestamp_millisecond_datatype(),
921            false,
922        )]));
923
924        let mut metadata_builder = RegionMetadataBuilder::new(RegionId::new(1024, 1));
925        metadata_builder
926            .push_column_metadata(ColumnMetadata {
927                column_schema: ColumnSchema::new(
928                    "ts",
929                    ConcreteDataType::timestamp_millisecond_datatype(),
930                    false,
931                )
932                .with_time_index(true),
933                semantic_type: SemanticType::Timestamp,
934                column_id: 1,
935            })
936            .primary_key(vec![]);
937        let metadata = Arc::new(metadata_builder.build().unwrap());
938
939        let update_calls = Arc::new(AtomicUsize::new(0));
940        let last_filter_len = Arc::new(AtomicUsize::new(0));
941        let scanner = Box::new(RecordingScanner::new(
942            schema,
943            metadata,
944            update_calls.clone(),
945            last_filter_len.clone(),
946        ));
947        let scan = Arc::new(RegionScanExec::new(scanner, ScanRequest::default(), None).unwrap());
948
949        let sort_expr = PhysicalSortExpr {
950            expr: Arc::new(Column::new("ts", 0)),
951            options: SortOptions {
952                descending: true,
953                ..Default::default()
954            },
955        };
956        let partition_ranges: Vec<Vec<PartitionRange>> = vec![vec![]];
957        let mut plan: Arc<dyn ExecutionPlan> =
958            Arc::new(PartSortExec::try_new(sort_expr, Some(3), partition_ranges, scan).unwrap());
959
960        for optimizer in state.physical_optimizers() {
961            plan = optimizer.optimize(plan, state.config_options()).unwrap();
962        }
963
964        assert!(update_calls.load(Ordering::Relaxed) > 0);
965        assert!(last_filter_len.load(Ordering::Relaxed) > 0);
966    }
967
968    #[tokio::test]
969    async fn test_join_dynamic_filter_pushdown_reaches_region_scan() {
970        let engine = create_test_engine().await;
971        let engine = engine
972            .as_any()
973            .downcast_ref::<DatafusionQueryEngine>()
974            .unwrap();
975        let engine_ctx = engine.engine_context(QueryContext::arc());
976        let state = engine_ctx.state();
977
978        assert!(
979            state
980                .config_options()
981                .optimizer
982                .enable_join_dynamic_filter_pushdown
983        );
984
985        let schema = Arc::new(datatypes::schema::Schema::new(vec![ColumnSchema::new(
986            "ts",
987            ConcreteDataType::timestamp_millisecond_datatype(),
988            false,
989        )]));
990
991        let mut left_metadata_builder = RegionMetadataBuilder::new(RegionId::new(2048, 1));
992        left_metadata_builder
993            .push_column_metadata(ColumnMetadata {
994                column_schema: ColumnSchema::new(
995                    "ts",
996                    ConcreteDataType::timestamp_millisecond_datatype(),
997                    false,
998                )
999                .with_time_index(true),
1000                semantic_type: SemanticType::Timestamp,
1001                column_id: 1,
1002            })
1003            .primary_key(vec![]);
1004        let left_metadata = Arc::new(left_metadata_builder.build().unwrap());
1005
1006        let mut right_metadata_builder = RegionMetadataBuilder::new(RegionId::new(2048, 2));
1007        right_metadata_builder
1008            .push_column_metadata(ColumnMetadata {
1009                column_schema: ColumnSchema::new(
1010                    "ts",
1011                    ConcreteDataType::timestamp_millisecond_datatype(),
1012                    false,
1013                )
1014                .with_time_index(true),
1015                semantic_type: SemanticType::Timestamp,
1016                column_id: 1,
1017            })
1018            .primary_key(vec![]);
1019        let right_metadata = Arc::new(right_metadata_builder.build().unwrap());
1020
1021        let left_update_calls = Arc::new(AtomicUsize::new(0));
1022        let left_last_filter_len = Arc::new(AtomicUsize::new(0));
1023        let right_update_calls = Arc::new(AtomicUsize::new(0));
1024        let right_last_filter_len = Arc::new(AtomicUsize::new(0));
1025
1026        let left_scan = Arc::new(
1027            RegionScanExec::new(
1028                Box::new(RecordingScanner::new(
1029                    schema.clone(),
1030                    left_metadata,
1031                    left_update_calls.clone(),
1032                    left_last_filter_len.clone(),
1033                )),
1034                ScanRequest::default(),
1035                None,
1036            )
1037            .unwrap(),
1038        );
1039        let right_scan = Arc::new(
1040            RegionScanExec::new(
1041                Box::new(RecordingScanner::new(
1042                    schema,
1043                    right_metadata,
1044                    right_update_calls.clone(),
1045                    right_last_filter_len.clone(),
1046                )),
1047                ScanRequest::default(),
1048                None,
1049            )
1050            .unwrap(),
1051        );
1052
1053        let on: JoinOn = vec![(
1054            Arc::new(Column::new("ts", 0)) as Arc<dyn PhysicalExpr>,
1055            Arc::new(Column::new("ts", 0)) as Arc<dyn PhysicalExpr>,
1056        )];
1057
1058        let mut plan: Arc<dyn ExecutionPlan> = Arc::new(
1059            HashJoinExec::try_new(
1060                left_scan,
1061                right_scan,
1062                on,
1063                None,
1064                &JoinType::Inner,
1065                None,
1066                PartitionMode::CollectLeft,
1067                NullEquality::NullEqualsNull,
1068                false,
1069            )
1070            .unwrap(),
1071        );
1072
1073        for optimizer in state.physical_optimizers() {
1074            plan = optimizer.optimize(plan, state.config_options()).unwrap();
1075        }
1076
1077        assert!(left_update_calls.load(Ordering::Relaxed) > 0);
1078        assert_eq!(0, left_last_filter_len.load(Ordering::Relaxed));
1079        assert!(right_update_calls.load(Ordering::Relaxed) > 0);
1080        assert!(right_last_filter_len.load(Ordering::Relaxed) > 0);
1081    }
1082}