query/datafusion.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Planner and QueryEngine implementations based on DataFusion.

mod error;
mod planner;

use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use common_base::Plugins;
use common_catalog::consts::is_readonly_schema;
use common_error::ext::BoxedError;
use common_function::function::FunctionContext;
use common_function::function_factory::ScalarFunctionFactory;
use common_query::{Output, OutputData, OutputMeta};
use common_recordbatch::adapter::RecordBatchStreamAdapter;
use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream};
use common_telemetry::tracing;
use datafusion::catalog::TableFunction;
use datafusion::dataframe::DataFrame;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::analyze::AnalyzeExec;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion_common::ResolvedTableReference;
use datafusion_expr::{
    AggregateUDF, DmlStatement, LogicalPlan as DfLogicalPlan, LogicalPlan, WriteOp,
};
use datatypes::prelude::VectorRef;
use datatypes::schema::Schema;
use futures_util::StreamExt;
use session::context::QueryContextRef;
use snafu::{OptionExt, ResultExt, ensure};
use sqlparser::ast::AnalyzeFormat;
use table::TableRef;
use table::requests::{DeleteRequest, InsertRequest};
use tracing::Span;

use crate::analyze::DistAnalyzeExec;
pub use crate::datafusion::planner::DfContextProviderAdapter;
use crate::dist_plan::{DistPlannerOptions, MergeScanLogicalPlan};
use crate::error::{
    CatalogSnafu, ConvertSchemaSnafu, CreateRecordBatchSnafu, MissingTableMutationHandlerSnafu,
    MissingTimestampColumnSnafu, QueryExecutionSnafu, Result, TableMutationSnafu,
    TableNotFoundSnafu, TableReadOnlySnafu, UnsupportedExprSnafu,
};
use crate::executor::QueryExecutor;
use crate::metrics::{OnDone, QUERY_STAGE_ELAPSED};
use crate::physical_wrapper::PhysicalPlanWrapperRef;
use crate::planner::{DfLogicalPlanner, LogicalPlanner};
use crate::query_engine::{DescribeResult, QueryEngineContext, QueryEngineState};
use crate::{QueryEngine, metrics};

/// Query parallelism hint key.
/// This hint can be set in the query context to control the parallelism of the query execution.
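///
/// A hypothetical illustration of how this hint might arrive via the
/// `x-greptime-hints` header mentioned in `engine_context` below (the exact
/// header syntax shown here is an assumption for illustration, not a verified
/// client API):
///
/// ```text
/// x-greptime-hints: query_parallelism=8
/// ```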
pub const QUERY_PARALLELISM_HINT: &str = "query_parallelism";

/// Whether to fall back to the original plan when plan push-down fails.
pub const QUERY_FALLBACK_HINT: &str = "query_fallback";

pub struct DatafusionQueryEngine {
    state: Arc<QueryEngineState>,
    plugins: Plugins,
}

impl DatafusionQueryEngine {
    pub fn new(state: Arc<QueryEngineState>, plugins: Plugins) -> Self {
        Self { state, plugins }
    }

    #[tracing::instrument(skip_all)]
    async fn exec_query_plan(
        &self,
        plan: LogicalPlan,
        query_ctx: QueryContextRef,
    ) -> Result<Output> {
        let mut ctx = self.engine_context(query_ctx.clone());

        // `create_physical_plan` will optimize the logical plan internally.
        let physical_plan = self.create_physical_plan(&mut ctx, &plan).await?;
        let optimized_physical_plan = self.optimize_physical_plan(&mut ctx, physical_plan)?;

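        // A `PhysicalPlanWrapperRef` plugin, when installed, gets a chance to
        // wrap the optimized plan (for example to add extra instrumentation);
        // the exact behavior depends on the plugin implementation.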
        let physical_plan = if let Some(wrapper) = self.plugins.get::<PhysicalPlanWrapperRef>() {
            wrapper.wrap(optimized_physical_plan, query_ctx)
        } else {
            optimized_physical_plan
        };

        Ok(Output::new(
            OutputData::Stream(self.execute_stream(&ctx, &physical_plan)?),
            OutputMeta::new_with_plan(physical_plan),
        ))
    }

    #[tracing::instrument(skip_all)]
    async fn exec_dml_statement(
        &self,
        dml: DmlStatement,
        query_ctx: QueryContextRef,
    ) -> Result<Output> {
        ensure!(
            matches!(dml.op, WriteOp::Insert(_) | WriteOp::Delete),
            UnsupportedExprSnafu {
                name: format!("DML op {}", dml.op),
            }
        );

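        // The DML's input is executed as an ordinary query; each record batch
        // it produces is then turned into insert or delete requests against
        // the target table.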
        let _timer = QUERY_STAGE_ELAPSED
            .with_label_values(&[dml.op.name()])
            .start_timer();

        let default_catalog = &query_ctx.current_catalog().to_owned();
        let default_schema = &query_ctx.current_schema();
        let table_name = dml.table_name.resolve(default_catalog, default_schema);
        let table = self.find_table(&table_name, &query_ctx).await?;

        let output = self
            .exec_query_plan((*dml.input).clone(), query_ctx.clone())
            .await?;
        let mut stream = match output.data {
            OutputData::RecordBatches(batches) => batches.as_stream(),
            OutputData::Stream(stream) => stream,
            _ => unreachable!(),
        };

        let mut affected_rows = 0;
        let mut insert_cost = 0;

        while let Some(batch) = stream.next().await {
            let batch = batch.context(CreateRecordBatchSnafu)?;
            let column_vectors = batch
                .column_vectors(&table_name.to_string(), table.schema())
                .map_err(BoxedError::new)
                .context(QueryExecutionSnafu)?;

            match dml.op {
                WriteOp::Insert(_) => {
                    // The specific insert op variant (e.g. append vs. overwrite) is ignored here.
                    let output = self
                        .insert(&table_name, column_vectors, query_ctx.clone())
                        .await?;
                    let (rows, cost) = output.extract_rows_and_cost();
                    affected_rows += rows;
                    insert_cost += cost;
                }
                WriteOp::Delete => {
                    affected_rows += self
                        .delete(&table_name, &table, column_vectors, query_ctx.clone())
                        .await?;
                }
                _ => unreachable!("guarded by the 'ensure!' at the beginning"),
            }
        }
        Ok(Output::new(
            OutputData::AffectedRows(affected_rows),
            OutputMeta::new_with_cost(insert_cost),
        ))
    }

    #[tracing::instrument(skip_all)]
    async fn delete(
        &self,
        table_name: &ResolvedTableReference,
        table: &TableRef,
        column_vectors: HashMap<String, VectorRef>,
        query_ctx: QueryContextRef,
    ) -> Result<usize> {
        let catalog_name = table_name.catalog.to_string();
        let schema_name = table_name.schema.to_string();
        let table_name = table_name.table.to_string();
        let table_schema = table.schema();

        ensure!(
            !is_readonly_schema(&schema_name),
            TableReadOnlySnafu { table: table_name }
        );

        let ts_column = table_schema
            .timestamp_column()
            .map(|x| &x.name)
            .with_context(|| MissingTimestampColumnSnafu {
                table_name: table_name.clone(),
            })?;

        let table_info = table.table_info();
        let rowkey_columns = table_info
            .meta
            .row_key_column_names()
            .collect::<Vec<&String>>();
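        // Only the timestamp column and the row-key (primary key) columns are
        // needed to identify the rows to delete; all other columns are dropped.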
        let column_vectors = column_vectors
            .into_iter()
            .filter(|x| &x.0 == ts_column || rowkey_columns.contains(&&x.0))
            .collect::<HashMap<_, _>>();

        let request = DeleteRequest {
            catalog_name,
            schema_name,
            table_name,
            key_column_values: column_vectors,
        };

        self.state
            .table_mutation_handler()
            .context(MissingTableMutationHandlerSnafu)?
            .delete(request, query_ctx)
            .await
            .context(TableMutationSnafu)
    }

    #[tracing::instrument(skip_all)]
    async fn insert(
        &self,
        table_name: &ResolvedTableReference,
        column_vectors: HashMap<String, VectorRef>,
        query_ctx: QueryContextRef,
    ) -> Result<Output> {
        let catalog_name = table_name.catalog.to_string();
        let schema_name = table_name.schema.to_string();
        let table_name = table_name.table.to_string();

        ensure!(
            !is_readonly_schema(&schema_name),
            TableReadOnlySnafu { table: table_name }
        );

        let request = InsertRequest {
            catalog_name,
            schema_name,
            table_name,
            columns_values: column_vectors,
        };

        self.state
            .table_mutation_handler()
            .context(MissingTableMutationHandlerSnafu)?
            .insert(request, query_ctx)
            .await
            .context(TableMutationSnafu)
    }

    async fn find_table(
        &self,
        table_name: &ResolvedTableReference,
        query_context: &QueryContextRef,
    ) -> Result<TableRef> {
        let catalog_name = table_name.catalog.as_ref();
        let schema_name = table_name.schema.as_ref();
        let table_name = table_name.table.as_ref();

        self.state
            .catalog_manager()
            .table(catalog_name, schema_name, table_name, Some(query_context))
            .await
            .context(CatalogSnafu)?
            .with_context(|| TableNotFoundSnafu { table: table_name })
    }

    #[tracing::instrument(skip_all)]
    async fn create_physical_plan(
        &self,
        ctx: &mut QueryEngineContext,
        logical_plan: &LogicalPlan,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        /// Only print context on panic, to avoid cluttering logs.
        ///
        /// TODO(discord9): remove this once we catch the bug
        #[derive(Debug)]
        struct PanicLogger<'a> {
            input_logical_plan: &'a LogicalPlan,
            after_analyze: Option<LogicalPlan>,
            after_optimize: Option<LogicalPlan>,
            phy_plan: Option<Arc<dyn ExecutionPlan>>,
        }
        impl Drop for PanicLogger<'_> {
            fn drop(&mut self) {
                if std::thread::panicking() {
                    common_telemetry::error!(
                        "Panic while creating physical plan, input logical plan: {:?}, after analyze: {:?}, after optimize: {:?}, final physical plan: {:?}",
                        self.input_logical_plan,
                        self.after_analyze,
                        self.after_optimize,
                        self.phy_plan
                    );
                }
            }
        }

        let mut logger = PanicLogger {
            input_logical_plan: logical_plan,
            after_analyze: None,
            after_optimize: None,
            phy_plan: None,
        };

        let _timer = metrics::CREATE_PHYSICAL_ELAPSED.start_timer();
        let state = ctx.state();

        common_telemetry::debug!("Create physical plan, input plan: {logical_plan}");

        // Special-case EXPLAIN plans.
        if matches!(logical_plan, DfLogicalPlan::Explain(_)) {
            return state
                .create_physical_plan(logical_plan)
                .await
                .context(error::DatafusionSnafu)
                .map_err(BoxedError::new)
                .context(QueryExecutionSnafu);
        }

        // analyze first
        let analyzed_plan = state
            .analyzer()
            .execute_and_check(logical_plan.clone(), state.config_options(), |_, _| {})
            .context(error::DatafusionSnafu)
            .map_err(BoxedError::new)
            .context(QueryExecutionSnafu)?;

        logger.after_analyze = Some(analyzed_plan.clone());

        common_telemetry::debug!("Create physical plan, analyzed plan: {analyzed_plan}");

        // skip optimize for MergeScan
        let optimized_plan = if let DfLogicalPlan::Extension(ext) = &analyzed_plan
            && ext.node.name() == MergeScanLogicalPlan::name()
        {
            analyzed_plan.clone()
        } else {
            state
                .optimizer()
                .optimize(analyzed_plan, state, |_, _| {})
                .context(error::DatafusionSnafu)
                .map_err(BoxedError::new)
                .context(QueryExecutionSnafu)?
        };

        common_telemetry::debug!("Create physical plan, optimized plan: {optimized_plan}");
        logger.after_optimize = Some(optimized_plan.clone());

        let physical_plan = state
            .query_planner()
            .create_physical_plan(&optimized_plan, state)
            .await?;

        logger.phy_plan = Some(physical_plan.clone());
        drop(logger);
        Ok(physical_plan)
    }

    #[tracing::instrument(skip_all)]
    pub fn optimize(
        &self,
        context: &QueryEngineContext,
        plan: &LogicalPlan,
    ) -> Result<LogicalPlan> {
        let _timer = metrics::OPTIMIZE_LOGICAL_ELAPSED.start_timer();

        // First, optimize with extension rules.
        let optimized_plan = self
            .state
            .optimize_by_extension_rules(plan.clone(), context)
            .context(error::DatafusionSnafu)
            .map_err(BoxedError::new)
            .context(QueryExecutionSnafu)?;

        // Then, optimize with the DataFusion optimizer.
        let optimized_plan = self
            .state
            .session_state()
            .optimize(&optimized_plan)
            .context(error::DatafusionSnafu)
            .map_err(BoxedError::new)
            .context(QueryExecutionSnafu)?;

        Ok(optimized_plan)
    }

    #[tracing::instrument(skip_all)]
    fn optimize_physical_plan(
        &self,
        ctx: &mut QueryEngineContext,
        plan: Arc<dyn ExecutionPlan>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let _timer = metrics::OPTIMIZE_PHYSICAL_ELAPSED.start_timer();

        // TODO(ruihang): `self.create_physical_plan()` already optimizes the plan;
        // check whether we need to optimize it again here.
        // let state = ctx.state();
        // let config = state.config_options();

        // Skip the usual optimization for AnalyzeExec plans.
        let optimized_plan = if let Some(analyze_plan) = plan.as_any().downcast_ref::<AnalyzeExec>()
        {
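            // Rewrite `EXPLAIN ANALYZE` into `DistAnalyzeExec`, which (as the
            // name suggests) is meant to collect execution metrics in the
            // distributed setting as well.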
            let format = if let Some(format) = ctx.query_ctx().explain_format()
                && format.to_lowercase() == "json"
            {
                AnalyzeFormat::JSON
            } else {
                AnalyzeFormat::TEXT
            };
            // Sets the verbose flag of the query context.
            // The MergeScanExec plan uses the verbose flag to determine whether to print the plan in verbose mode.
            ctx.query_ctx().set_explain_verbose(analyze_plan.verbose());

            Arc::new(DistAnalyzeExec::new(
                analyze_plan.input().clone(),
                analyze_plan.verbose(),
                format,
            ))
            // let mut new_plan = analyze_plan.input().clone();
            // for optimizer in state.physical_optimizers() {
            //     new_plan = optimizer
            //         .optimize(new_plan, config)
            //         .context(DataFusionSnafu)?;
            // }
            // Arc::new(DistAnalyzeExec::new(new_plan))
        } else {
            plan
            // let mut new_plan = plan;
            // for optimizer in state.physical_optimizers() {
            //     new_plan = optimizer
            //         .optimize(new_plan, config)
            //         .context(DataFusionSnafu)?;
            // }
            // new_plan
        };

        Ok(optimized_plan)
    }
}

#[async_trait]
impl QueryEngine for DatafusionQueryEngine {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn planner(&self) -> Arc<dyn LogicalPlanner> {
        Arc::new(DfLogicalPlanner::new(self.state.clone()))
    }

    fn name(&self) -> &str {
        "datafusion"
    }

    async fn describe(
        &self,
        plan: LogicalPlan,
        query_ctx: QueryContextRef,
    ) -> Result<DescribeResult> {
        let ctx = self.engine_context(query_ctx);
        if let Ok(optimised_plan) = self.optimize(&ctx, &plan) {
            let schema = optimised_plan
                .schema()
                .clone()
                .try_into()
                .context(ConvertSchemaSnafu)?;
            Ok(DescribeResult {
                schema,
                logical_plan: optimised_plan,
            })
        } else {
            // Plans over tables like those in information_schema cannot be
            // optimized when they contain parameters, so we fall back to the
            // original plan.
            let schema = plan
                .schema()
                .clone()
                .try_into()
                .context(ConvertSchemaSnafu)?;
            Ok(DescribeResult {
                schema,
                logical_plan: plan,
            })
        }
    }

    async fn execute(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result<Output> {
        match plan {
            LogicalPlan::Dml(dml) => self.exec_dml_statement(dml, query_ctx).await,
            _ => self.exec_query_plan(plan, query_ctx).await,
        }
    }

    /// Note that in SQL queries, aggregate names are looked up using
    /// lowercase unless the query uses quotes. For example,
    ///
    /// `SELECT MY_UDAF(x)...` will look for an aggregate named `"my_udaf"`;
    /// `SELECT "my_UDAF"(x)` will look for an aggregate named `"my_UDAF"`.
    ///
    /// So it's better to make the UDAF name lowercase when creating one.
    fn register_aggregate_function(&self, func: AggregateUDF) {
        self.state.register_aggr_function(func);
    }

    /// Registers a scalar function.
    /// Overrides any existing function registered under the same name.
    fn register_scalar_function(&self, func: ScalarFunctionFactory) {
        self.state.register_scalar_function(func);
    }

    fn register_table_function(&self, func: Arc<TableFunction>) {
        self.state.register_table_function(func);
    }

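    /// Reads a table into a DataFusion [`DataFrame`].
    ///
    /// A minimal usage sketch, mirroring `test_read_table` in the tests below:
    ///
    /// ```ignore
    /// let df = engine.read_table(table)?;
    /// let batches = df.filter(col("number").lt(lit(10)))?.collect().await?;
    /// ```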
    fn read_table(&self, table: TableRef) -> Result<DataFrame> {
        self.state
            .read_table(table)
            .context(error::DatafusionSnafu)
            .map_err(BoxedError::new)
            .context(QueryExecutionSnafu)
    }

    fn engine_context(&self, query_ctx: QueryContextRef) -> QueryEngineContext {
        let mut state = self.state.session_state();
        state.config_mut().set_extension(query_ctx.clone());
        // Note that hints in the "x-greptime-hints" header are automatically
        // parsed and stored as query context extensions, so we can read them
        // from the query context here.
        if let Some(parallelism) = query_ctx.extension(QUERY_PARALLELISM_HINT) {
            if let Ok(n) = parallelism.parse::<u64>() {
                if n > 0 {
                    let new_cfg = state.config().clone().with_target_partitions(n as usize);
                    *state.config_mut() = new_cfg;
                }
            } else {
                common_telemetry::warn!(
                    "Failed to parse query_parallelism: {}, using default value",
                    parallelism
                );
            }
        }

        // configure execution options
        state.config_mut().options_mut().execution.time_zone = query_ctx.timezone().to_string();

        // It's usually impossible to have both a `SET variable` from a SQL
        // client and a header hint from a gRPC client at the same time, so the
        // two cases are handled separately.
        if query_ctx.configuration_parameter().allow_query_fallback() {
            state
                .config_mut()
                .options_mut()
                .extensions
                .insert(DistPlannerOptions {
                    allow_query_fallback: true,
                });
        } else if let Some(fallback) = query_ctx.extension(QUERY_FALLBACK_HINT) {
            // Also check the query context for the fallback hint;
            // if it is set, enable the fallback.
            if fallback.to_lowercase().parse::<bool>().unwrap_or(false) {
                state
                    .config_mut()
                    .options_mut()
                    .extensions
                    .insert(DistPlannerOptions {
                        allow_query_fallback: true,
                    });
            }
        }

        state
            .config_mut()
            .options_mut()
            .extensions
            .insert(FunctionContext {
                query_ctx: query_ctx.clone(),
                state: self.engine_state().function_state(),
            });

        let config_options = state.config_options().clone();
        let _ = state
            .execution_props_mut()
            .config_options
            .insert(config_options);

        QueryEngineContext::new(state, query_ctx)
    }

    fn engine_state(&self) -> &QueryEngineState {
        &self.state
    }
}

impl QueryExecutor for DatafusionQueryEngine {
    #[tracing::instrument(skip_all)]
    fn execute_stream(
        &self,
        ctx: &QueryEngineContext,
        plan: &Arc<dyn ExecutionPlan>,
    ) -> Result<SendableRecordBatchStream> {
        let explain_verbose = ctx.query_ctx().explain_verbose();
        let output_partitions = plan.properties().output_partitioning().partition_count();
        if explain_verbose {
            common_telemetry::info!("Executing query plan, output_partitions: {output_partitions}");
        }

        let exec_timer = metrics::EXEC_PLAN_ELAPSED.start_timer();
        let task_ctx = ctx.build_task_ctx();
        let span = Span::current();

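        // Dispatch on the number of output partitions: zero partitions yield
        // an empty stream, a single partition is executed directly, and
        // multiple partitions are first merged by `CoalescePartitionsExec`.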
        match plan.properties().output_partitioning().partition_count() {
            0 => {
                let schema = Arc::new(
                    Schema::try_from(plan.schema())
                        .map_err(BoxedError::new)
                        .context(QueryExecutionSnafu)?,
                );
                Ok(Box::pin(EmptyRecordBatchStream::new(schema)))
            }
            1 => {
                let df_stream = plan
                    .execute(0, task_ctx)
                    .context(error::DatafusionSnafu)
                    .map_err(BoxedError::new)
                    .context(QueryExecutionSnafu)?;
                let mut stream = RecordBatchStreamAdapter::try_new_with_span(df_stream, span)
                    .context(error::ConvertDfRecordBatchStreamSnafu)
                    .map_err(BoxedError::new)
                    .context(QueryExecutionSnafu)?;
                stream.set_metrics2(plan.clone());
                stream.set_explain_verbose(explain_verbose);
                let stream = OnDone::new(Box::pin(stream), move || {
                    let exec_cost = exec_timer.stop_and_record();
                    if explain_verbose {
                        common_telemetry::info!(
                            "DatafusionQueryEngine execute 1 stream, cost: {:?}s",
                            exec_cost,
                        );
                    }
                });
                Ok(Box::pin(stream))
            }
            _ => {
                // merge into a single partition
                let merged_plan = CoalescePartitionsExec::new(plan.clone());
                // CoalescePartitionsExec must produce a single partition
                assert_eq!(
                    1,
                    merged_plan
                        .properties()
                        .output_partitioning()
                        .partition_count()
                );
                let df_stream = merged_plan
                    .execute(0, task_ctx)
                    .context(error::DatafusionSnafu)
                    .map_err(BoxedError::new)
                    .context(QueryExecutionSnafu)?;
                let mut stream = RecordBatchStreamAdapter::try_new_with_span(df_stream, span)
                    .context(error::ConvertDfRecordBatchStreamSnafu)
                    .map_err(BoxedError::new)
                    .context(QueryExecutionSnafu)?;
                stream.set_metrics2(plan.clone());
                stream.set_explain_verbose(ctx.query_ctx().explain_verbose());
                let stream = OnDone::new(Box::pin(stream), move || {
                    let exec_cost = exec_timer.stop_and_record();
                    if explain_verbose {
                        common_telemetry::info!(
                            "DatafusionQueryEngine execute {output_partitions} stream, cost: {:?}s",
                            exec_cost
                        );
                    }
                });
                Ok(Box::pin(stream))
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, UInt64Array};
    use catalog::RegisterTableRequest;
    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, NUMBERS_TABLE_ID};
    use common_recordbatch::util;
    use datafusion::prelude::{col, lit};
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::ColumnSchema;
    use datatypes::vectors::{Helper, UInt32Vector, VectorRef};
    use session::context::{QueryContext, QueryContextBuilder};
    use table::table::numbers::{NUMBERS_TABLE_NAME, NumbersTable};

    use super::*;
    use crate::options::QueryOptions;
    use crate::parser::QueryLanguageParser;
    use crate::query_engine::{QueryEngineFactory, QueryEngineRef};

    async fn create_test_engine() -> QueryEngineRef {
        let catalog_manager = catalog::memory::new_memory_catalog_manager().unwrap();
        let req = RegisterTableRequest {
            catalog: DEFAULT_CATALOG_NAME.to_string(),
            schema: DEFAULT_SCHEMA_NAME.to_string(),
            table_name: NUMBERS_TABLE_NAME.to_string(),
            table_id: NUMBERS_TABLE_ID,
            table: NumbersTable::table(NUMBERS_TABLE_ID),
        };
        catalog_manager.register_table_sync(req).unwrap();

        QueryEngineFactory::new(
            catalog_manager,
            None,
            None,
            None,
            None,
            false,
            QueryOptions::default(),
        )
        .query_engine()
    }

    #[tokio::test]
    async fn test_sql_to_plan() {
        let engine = create_test_engine().await;
        let sql = "select sum(number) from numbers limit 20";

        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
        let plan = engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap();

        assert_eq!(
            plan.to_string(),
            r#"Limit: skip=0, fetch=20
  Projection: sum(numbers.number)
    Aggregate: groupBy=[[]], aggr=[[sum(numbers.number)]]
      TableScan: numbers"#
        );
    }

    #[tokio::test]
    async fn test_execute() {
        let engine = create_test_engine().await;
        let sql = "select sum(number) from numbers limit 20";

        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
        let plan = engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap();

        let output = engine.execute(plan, QueryContext::arc()).await.unwrap();

        match output.data {
            OutputData::Stream(recordbatch) => {
                let numbers = util::collect(recordbatch).await.unwrap();
                assert_eq!(1, numbers.len());
                assert_eq!(numbers[0].num_columns(), 1);
                assert_eq!(1, numbers[0].schema.num_columns());
                assert_eq!(
                    "sum(numbers.number)",
                    numbers[0].schema.column_schemas()[0].name
                );

                let batch = &numbers[0];
                assert_eq!(1, batch.num_columns());
                assert_eq!(batch.column(0).len(), 1);

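                // The numbers table holds the values 0..100, so the expected sum is 4950.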
                let expected = Arc::new(UInt64Array::from_iter_values([4950])) as ArrayRef;
                assert_eq!(batch.column(0), &expected);
            }
            _ => unreachable!(),
        }
    }

    #[tokio::test]
    async fn test_read_table() {
        let engine = create_test_engine().await;

        let engine = engine
            .as_any()
            .downcast_ref::<DatafusionQueryEngine>()
            .unwrap();
        let query_ctx = Arc::new(QueryContextBuilder::default().build());
        let table = engine
            .find_table(
                &ResolvedTableReference {
                    catalog: "greptime".into(),
                    schema: "public".into(),
                    table: "numbers".into(),
                },
                &query_ctx,
            )
            .await
            .unwrap();

        let df = engine.read_table(table).unwrap();
        let df = df
            .select_columns(&["number"])
            .unwrap()
            .filter(col("number").lt(lit(10)))
            .unwrap();
        let batches = df.collect().await.unwrap();
        assert_eq!(1, batches.len());
        let batch = &batches[0];

        assert_eq!(1, batch.num_columns());
        assert_eq!(batch.column(0).len(), 10);

        assert_eq!(
            Helper::try_into_vector(batch.column(0)).unwrap(),
            Arc::new(UInt32Vector::from_slice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) as VectorRef
        );
    }

    #[tokio::test]
    async fn test_describe() {
        let engine = create_test_engine().await;
        let sql = "select sum(number) from numbers limit 20";

        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();

        let plan = engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap();

        let DescribeResult {
            schema,
            logical_plan,
        } = engine.describe(plan, QueryContext::arc()).await.unwrap();

        assert_eq!(
            schema.column_schemas()[0],
            ColumnSchema::new(
                "sum(numbers.number)",
                ConcreteDataType::uint64_datatype(),
                true
            )
        );
        assert_eq!(
            "Limit: skip=0, fetch=20\n  Aggregate: groupBy=[[]], aggr=[[sum(CAST(numbers.number AS UInt64))]]\n    TableScan: numbers projection=[number]",
            format!("{}", logical_plan.display_indent())
        );
    }
}