Skip to main content

query/
planner.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::borrow::Cow;
17use std::collections::{HashMap, HashSet};
18use std::str::FromStr;
19use std::sync::Arc;
20
21use arrow_schema::DataType;
22use async_trait::async_trait;
23use catalog::table_source::DfTableSourceProvider;
24use common_error::ext::BoxedError;
25use common_telemetry::tracing;
26use datafusion::common::{DFSchema, plan_err};
27use datafusion::execution::context::SessionState;
28use datafusion::sql::planner::PlannerContext;
29use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
30use datafusion_common::{ScalarValue, ToDFSchema};
31use datafusion_expr::expr::{Exists, InSubquery};
32use datafusion_expr::{
33    Analyze, Explain, ExplainFormat, Expr as DfExpr, LogicalPlan, LogicalPlanBuilder, PlanType,
34    ToStringifiedPlan, col,
35};
36use datafusion_sql::planner::{ParserOptions, SqlToRel};
37use log_query::LogQuery;
38use promql_parser::parser::EvalStmt;
39use session::context::QueryContextRef;
40use snafu::{ResultExt, ensure};
41use sql::CteContent;
42use sql::ast::Expr as SqlExpr;
43use sql::statements::explain::ExplainStatement;
44use sql::statements::query::Query;
45use sql::statements::statement::Statement;
46use sql::statements::tql::Tql;
47
48use crate::error::{
49    CteColumnSchemaMismatchSnafu, PlanSqlSnafu, QueryPlanSnafu, Result, SqlSnafu,
50    UnimplementedSnafu,
51};
52use crate::log_query::planner::LogQueryPlanner;
53use crate::parser::{DEFAULT_LOOKBACK_STRING, PromQuery, QueryLanguageParser, QueryStatement};
54use crate::promql::planner::PromPlanner;
55use crate::query_engine::{DefaultPlanDecoder, QueryEngineState};
56use crate::range_select::plan_rewrite::RangePlanRewriter;
57use crate::{DfContextProviderAdapter, QueryEngineContext};
58
#[async_trait]
pub trait LogicalPlanner: Send + Sync {
    /// Plans a parsed statement (SQL or PromQL) into a DataFusion [`LogicalPlan`].
    async fn plan(&self, stmt: &QueryStatement, query_ctx: QueryContextRef) -> Result<LogicalPlan>;

    /// Plans a [`LogQuery`] into a [`LogicalPlan`].
    async fn plan_logs_query(
        &self,
        query: LogQuery,
        query_ctx: QueryContextRef,
    ) -> Result<LogicalPlan>;

    /// Runs the engine's logical optimizer passes over `plan`.
    fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan>;

    /// Enables downcasting to the concrete planner implementation.
    fn as_any(&self) -> &dyn Any;
}
73
/// Default [`LogicalPlanner`] implementation backed by DataFusion.
pub struct DfLogicalPlanner {
    /// Shared engine state (catalog manager, optimizer rules, ...).
    engine_state: Arc<QueryEngineState>,
    /// DataFusion session state supplying parser/planner configuration.
    session_state: SessionState,
}
78
79impl DfLogicalPlanner {
80    pub fn new(engine_state: Arc<QueryEngineState>) -> Self {
81        let session_state = engine_state.session_state();
82        Self {
83            engine_state,
84            session_state,
85        }
86    }
87
    /// Basically the same with `explain_to_plan` in DataFusion, but adapted to Greptime's
    /// `plan_sql` to support Greptime Statements.
    ///
    /// Plans the inner statement first, then wraps it in an `Explain` node (or
    /// an `Analyze` node for `EXPLAIN ANALYZE`).
    ///
    /// # Errors
    /// - nested `EXPLAIN`s are rejected
    /// - `EXPLAIN VERBOSE` combined with `FORMAT` is rejected
    async fn explain_to_plan(
        &self,
        explain: &ExplainStatement,
        query_ctx: QueryContextRef,
    ) -> Result<LogicalPlan> {
        // Plan the explained statement through Greptime's own `plan_sql` so
        // Greptime-specific statements are supported.
        let plan = self.plan_sql(&explain.statement, query_ctx).await?;
        if matches!(plan, LogicalPlan::Explain(_)) {
            return plan_err!("Nested EXPLAINs are not supported").context(PlanSqlSnafu);
        }

        let verbose = explain.verbose;
        let analyze = explain.analyze;
        let format = explain.format.map(|f| f.to_string());

        let plan = Arc::new(plan);
        // Both Explain and Analyze nodes use the fixed explain output schema.
        let schema = LogicalPlan::explain_schema();
        let schema = ToDFSchema::to_dfschema_ref(schema)?;

        if verbose && format.is_some() {
            return plan_err!("EXPLAIN VERBOSE with FORMAT is not supported").context(PlanSqlSnafu);
        }

        if analyze {
            // Note: the output format is already recorded in the query
            // context, so it can be ignored here.
            Ok(LogicalPlan::Analyze(Analyze {
                verbose,
                input: plan,
                schema,
            }))
        } else {
            let stringified_plans = vec![plan.to_stringified(PlanType::InitialLogicalPlan)];

            // default to configuration value
            let options = self.session_state.config().options();
            let format = format
                .map(|x| ExplainFormat::from_str(&x))
                .transpose()?
                .unwrap_or_else(|| options.explain.format.clone());

            Ok(LogicalPlan::Explain(Explain {
                verbose,
                explain_format: format,
                plan,
                stringified_plans,
                schema,
                logical_optimization_succeeded: false,
            }))
        }
    }
139
    /// Plans a SQL [`Statement`] into a [`LogicalPlan`].
    ///
    /// Special paths handled before the plain DataFusion planning path:
    /// - `EXPLAIN` is routed through [`Self::explain_to_plan`] so Greptime
    ///   statements can be explained (hence `async_recursion`).
    /// - Hybrid CTEs (TQL CTEs mixed into a SQL query) are pre-planned and
    ///   registered in the `PlannerContext` before the remaining SQL is planned.
    /// - `EXPLAIN ... FORMAT` through the raw DataFusion planner is rejected.
    ///
    /// The produced plan is then rewritten for range selects and optimized by
    /// the engine's extension rules.
    #[tracing::instrument(skip_all)]
    #[async_recursion::async_recursion]
    async fn plan_sql(&self, stmt: &Statement, query_ctx: QueryContextRef) -> Result<LogicalPlan> {
        let mut planner_context = PlannerContext::new();
        // Cow lets us keep borrowing the input statement unless the hybrid-CTE
        // path needs to rewrite it.
        let mut stmt = Cow::Borrowed(stmt);
        let mut is_tql_cte = false;

        // handle explain before normal processing so we can explain Greptime Statements
        if let Statement::Explain(explain) = stmt.as_ref() {
            return self.explain_to_plan(explain, query_ctx).await;
        }

        // Check for hybrid CTEs before normal processing
        if self.has_hybrid_ctes(stmt.as_ref()) {
            let stmt_owned = stmt.into_owned();
            let mut query = match stmt_owned {
                Statement::Query(query) => query.as_ref().clone(),
                _ => unreachable!("has_hybrid_ctes should only return true for Query statements"),
            };
            // Plan the TQL CTEs and register them in `planner_context`.
            self.plan_query_with_hybrid_ctes(&query, query_ctx.clone(), &mut planner_context)
                .await?;

            // remove the processed TQL CTEs from the query
            query.hybrid_cte = None;
            stmt = Cow::Owned(Statement::Query(Box::new(query)));
            is_tql_cte = true;
        }

        let mut df_stmt = stmt.as_ref().try_into().context(SqlSnafu)?;

        // TODO(LFC): Remove this when Datafusion supports **both** the syntax and implementation of "explain with format".
        if let datafusion::sql::parser::Statement::Statement(
            box datafusion::sql::sqlparser::ast::Statement::Explain { .. },
        ) = &mut df_stmt
        {
            UnimplementedSnafu {
                operation: "EXPLAIN with FORMAT using raw datafusion planner",
            }
            .fail()?;
        }

        let table_provider = DfTableSourceProvider::new(
            self.engine_state.catalog_manager().clone(),
            self.engine_state.disallow_cross_catalog_query(),
            query_ctx.clone(),
            Arc::new(DefaultPlanDecoder::new(
                self.session_state.clone(),
                &query_ctx,
            )?),
            self.session_state
                .config_options()
                .sql_parser
                .enable_ident_normalization,
        );

        let context_provider = DfContextProviderAdapter::try_new(
            self.engine_state.clone(),
            self.session_state.clone(),
            Some(&df_stmt),
            query_ctx.clone(),
        )
        .await?;

        let config_options = self.session_state.config().options();
        let parser_options = &config_options.sql_parser;
        let parser_options = ParserOptions {
            map_string_types_to_utf8view: false,
            ..parser_options.into()
        };

        let sql_to_rel = SqlToRel::new_with_options(&context_provider, parser_options);

        // this IF is to handle different version of ASTs
        let result = if is_tql_cte {
            let Statement::Query(query) = stmt.into_owned() else {
                unreachable!("is_tql_cte should only be true for Query statements");
            };
            // Plan with the context carrying the pre-planned TQL CTEs.
            let sqlparser_stmt = sqlparser::ast::Statement::Query(Box::new(query.inner));
            sql_to_rel
                .sql_statement_to_plan_with_context(sqlparser_stmt, &mut planner_context)
                .context(PlanSqlSnafu)?
        } else {
            sql_to_rel
                .statement_to_plan(df_stmt)
                .context(PlanSqlSnafu)?
        };

        common_telemetry::debug!("Logical planner, statement to plan result: {result}");
        let plan = RangePlanRewriter::new(table_provider, query_ctx.clone())
            .rewrite(result)
            .await?;

        // Optimize logical plan by extension rules
        let context = QueryEngineContext::new(self.session_state.clone(), query_ctx);
        let plan = self
            .engine_state
            .optimize_by_extension_rules(plan, &context)?;
        common_telemetry::debug!("Logical planner, optimize result: {plan}");

        Ok(plan)
    }
241
242    /// Generate a relational expression from a SQL expression
243    #[tracing::instrument(skip_all)]
244    pub(crate) async fn sql_to_expr(
245        &self,
246        sql: SqlExpr,
247        schema: &DFSchema,
248        normalize_ident: bool,
249        query_ctx: QueryContextRef,
250    ) -> Result<DfExpr> {
251        let context_provider = DfContextProviderAdapter::try_new(
252            self.engine_state.clone(),
253            self.session_state.clone(),
254            None,
255            query_ctx,
256        )
257        .await?;
258
259        let config_options = self.session_state.config().options();
260        let parser_options = &config_options.sql_parser;
261        let parser_options: ParserOptions = ParserOptions {
262            map_string_types_to_utf8view: false,
263            enable_ident_normalization: normalize_ident,
264            ..parser_options.into()
265        };
266
267        let sql_to_rel = SqlToRel::new_with_options(&context_provider, parser_options);
268
269        Ok(sql_to_rel.sql_to_expr(sql, schema, &mut PlannerContext::new())?)
270    }
271
    /// Plans a PromQL [`EvalStmt`] into a [`LogicalPlan`] and applies the
    /// engine's extension optimizer rules to the result.
    #[tracing::instrument(skip_all)]
    async fn plan_pql(&self, stmt: &EvalStmt, query_ctx: QueryContextRef) -> Result<LogicalPlan> {
        let plan_decoder = Arc::new(DefaultPlanDecoder::new(
            self.session_state.clone(),
            &query_ctx,
        )?);
        let table_provider = DfTableSourceProvider::new(
            self.engine_state.catalog_manager().clone(),
            self.engine_state.disallow_cross_catalog_query(),
            query_ctx.clone(),
            plan_decoder,
            self.session_state
                .config_options()
                .sql_parser
                .enable_ident_normalization,
        );
        // Translate the PromQL evaluation statement into a logical plan.
        let plan = PromPlanner::stmt_to_plan(table_provider, stmt, &self.engine_state)
            .await
            .map_err(BoxedError::new)
            .context(QueryPlanSnafu)?;

        let context = QueryEngineContext::new(self.session_state.clone(), query_ctx);
        Ok(self
            .engine_state
            .optimize_by_extension_rules(plan, &context)?)
    }
298
299    #[tracing::instrument(skip_all)]
300    fn optimize_logical_plan(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
301        Ok(self.engine_state.optimize_logical_plan(plan)?)
302    }
303
304    /// Check if a statement contains hybrid CTEs (mix of SQL and TQL)
305    fn has_hybrid_ctes(&self, stmt: &Statement) -> bool {
306        if let Statement::Query(query) = stmt {
307            query
308                .hybrid_cte
309                .as_ref()
310                .map(|hybrid_cte| !hybrid_cte.cte_tables.is_empty())
311                .unwrap_or(false)
312        } else {
313            false
314        }
315    }
316
    /// Plan a query with hybrid CTEs using DataFusion's native PlannerContext.
    ///
    /// Each TQL CTE is planned into a logical plan, optionally re-projected to
    /// the user-declared column names, wrapped in a `SubqueryAlias` carrying
    /// the CTE name, and registered in `planner_context` so the main SQL query
    /// can reference it like an ordinary CTE.
    async fn plan_query_with_hybrid_ctes(
        &self,
        query: &Query,
        query_ctx: QueryContextRef,
        planner_context: &mut PlannerContext,
    ) -> Result<()> {
        // Caller (`plan_sql`) only invokes this after `has_hybrid_ctes`
        // returned true, so `hybrid_cte` is present.
        let hybrid_cte = query.hybrid_cte.as_ref().unwrap();

        for cte in &hybrid_cte.cte_tables {
            match &cte.content {
                CteContent::Tql(tql) => {
                    // Plan TQL and register in PlannerContext
                    let mut logical_plan = self.tql_to_logical_plan(tql, query_ctx.clone()).await?;
                    if !cte.columns.is_empty() {
                        // An explicit column list was declared: its arity must
                        // match the TQL plan's schema, then each output column
                        // is aliased to the declared name (positionally).
                        let schema = logical_plan.schema();
                        let schema_fields = schema.fields().to_vec();
                        ensure!(
                            schema_fields.len() == cte.columns.len(),
                            CteColumnSchemaMismatchSnafu {
                                cte_name: cte.name.value.clone(),
                                original: schema_fields
                                    .iter()
                                    .map(|field| field.name().clone())
                                    .collect::<Vec<_>>(),
                                expected: cte
                                    .columns
                                    .iter()
                                    .map(|column| column.to_string())
                                    .collect::<Vec<_>>(),
                            }
                        );
                        let aliases = cte
                            .columns
                            .iter()
                            .zip(schema_fields.iter())
                            .map(|(column, field)| col(field.name()).alias(column.to_string()));
                        logical_plan = LogicalPlanBuilder::from(logical_plan)
                            .project(aliases)
                            .context(PlanSqlSnafu)?
                            .build()
                            .context(PlanSqlSnafu)?;
                    }

                    // Wrap in SubqueryAlias to ensure proper table qualification for CTE
                    logical_plan = LogicalPlan::SubqueryAlias(
                        datafusion_expr::SubqueryAlias::try_new(
                            Arc::new(logical_plan),
                            cte.name.value.clone(),
                        )
                        .context(PlanSqlSnafu)?,
                    );

                    planner_context.insert_cte(&cte.name.value, logical_plan);
                }
                CteContent::Sql(_) => {
                    // SQL CTEs should have been moved to the main query's WITH clause
                    // during parsing, so we shouldn't encounter them here
                    unreachable!("SQL CTEs should not be in hybrid_cte.cte_tables");
                }
            }
        }

        Ok(())
    }
382
383    /// Convert TQL to LogicalPlan directly
384    async fn tql_to_logical_plan(
385        &self,
386        tql: &Tql,
387        query_ctx: QueryContextRef,
388    ) -> Result<LogicalPlan> {
389        match tql {
390            Tql::Eval(eval) => {
391                // Convert TqlEval to PromQuery then to QueryStatement::Promql
392                let prom_query = PromQuery {
393                    query: eval.query.clone(),
394                    start: eval.start.clone(),
395                    end: eval.end.clone(),
396                    step: eval.step.clone(),
397                    lookback: eval
398                        .lookback
399                        .clone()
400                        .unwrap_or_else(|| DEFAULT_LOOKBACK_STRING.to_string()),
401                    alias: eval.alias.clone(),
402                };
403                let stmt = QueryLanguageParser::parse_promql(&prom_query, &query_ctx)?;
404
405                self.plan(&stmt, query_ctx).await
406            }
407            Tql::Explain(_) => UnimplementedSnafu {
408                operation: "TQL EXPLAIN in CTEs",
409            }
410            .fail(),
411            Tql::Analyze(_) => UnimplementedSnafu {
412                operation: "TQL ANALYZE in CTEs",
413            }
414            .fail(),
415        }
416    }
417
418    /// Extracts cast types for all placeholders in a logical plan.
419    /// Returns a map where each placeholder ID is mapped to:
420    /// - Some(DataType) if the placeholder is cast to a specific type
421    /// - None if the placeholder exists but has no cast
422    ///
423    /// Example: `$1::TEXT` returns `{"$1": Some(DataType::Utf8)}`
424    ///
425    /// This function walks through all expressions in the logical plan,
426    /// including subqueries, to identify placeholders and their cast types.
427    fn extract_placeholder_cast_types(
428        plan: &LogicalPlan,
429    ) -> Result<HashMap<String, Option<DataType>>> {
430        let mut placeholder_types = HashMap::new();
431        let mut casted_placeholders = HashSet::new();
432
433        Self::extract_from_plan(plan, &mut placeholder_types, &mut casted_placeholders)?;
434
435        Ok(placeholder_types)
436    }
437
438    fn extract_from_plan(
439        plan: &LogicalPlan,
440        placeholder_types: &mut HashMap<String, Option<DataType>>,
441        casted_placeholders: &mut HashSet<String>,
442    ) -> Result<()> {
443        plan.apply(|node| {
444            for expr in node.expressions() {
445                let _ = expr.apply(|e| {
446                    // Handle casted placeholders
447                    if let DfExpr::Cast(cast) = e
448                        && let DfExpr::Placeholder(ph) = &*cast.expr
449                    {
450                        placeholder_types.insert(ph.id.clone(), Some(cast.data_type.clone()));
451                        casted_placeholders.insert(ph.id.clone());
452                    }
453
454                    // Handle arrow_cast(Placeholder, 'type_string') generated by SQL rewriter
455                    if let DfExpr::ScalarFunction(scalar_func) = e
456                        && scalar_func.name() == "arrow_cast"
457                        && scalar_func.args.len() == 2
458                        && let DfExpr::Placeholder(ph) = &scalar_func.args[0]
459                        && let DfExpr::Literal(ScalarValue::Utf8(Some(type_str)), _) =
460                            &scalar_func.args[1]
461                        && let Ok(data_type) = type_str.parse::<DataType>()
462                    {
463                        placeholder_types.insert(ph.id.clone(), Some(data_type));
464                        casted_placeholders.insert(ph.id.clone());
465                    }
466
467                    // Handle bare (non-casted) placeholders
468                    if let DfExpr::Placeholder(ph) = e
469                        && !casted_placeholders.contains(&ph.id)
470                        && !placeholder_types.contains_key(&ph.id)
471                    {
472                        placeholder_types.insert(ph.id.clone(), None);
473                    }
474
475                    // Recurse into subquery plans embedded in expressions
476                    match e {
477                        DfExpr::Exists(Exists { subquery, .. })
478                        | DfExpr::InSubquery(InSubquery { subquery, .. })
479                        | DfExpr::ScalarSubquery(subquery) => {
480                            Self::extract_from_plan(
481                                &subquery.subquery,
482                                placeholder_types,
483                                casted_placeholders,
484                            )?;
485                        }
486                        _ => {}
487                    }
488
489                    Ok(TreeNodeRecursion::Continue)
490                });
491            }
492            Ok(TreeNodeRecursion::Continue)
493        })?;
494        Ok(())
495    }
496
497    /// Gets inferred parameter types from a logical plan.
498    /// Returns a map where each parameter ID is mapped to:
499    /// - Some(DataType) if the parameter type could be inferred
500    /// - None if the parameter type could not be inferred
501    ///
502    /// This function first uses DataFusion's `get_parameter_types()` to infer types.
503    /// If any parameters have `None` values (i.e., DataFusion couldn't infer their types),
504    /// it falls back to using `extract_placeholder_cast_types()` to detect explicit casts.
505    ///
506    /// This is because datafusion can only infer types for a limited cases.
507    ///
508    /// Example: For query `WHERE $1::TEXT AND $2`, DataFusion may not infer `$2`'s type,
509    /// but this function will return `{"$1": Some(DataType::Utf8), "$2": None}`.
510    pub fn get_inferred_parameter_types(
511        plan: &LogicalPlan,
512    ) -> Result<HashMap<String, Option<DataType>>> {
513        let param_types = plan.get_parameter_types().context(PlanSqlSnafu)?;
514
515        let has_none = param_types.values().any(|v| v.is_none());
516
517        if !has_none {
518            Ok(param_types)
519        } else {
520            let cast_types = Self::extract_placeholder_cast_types(plan)?;
521
522            let mut merged = param_types;
523
524            for (id, opt_type) in cast_types {
525                merged
526                    .entry(id)
527                    .and_modify(|existing| {
528                        if existing.is_none() {
529                            *existing = opt_type.clone();
530                        }
531                    })
532                    .or_insert(opt_type);
533            }
534
535            Ok(merged)
536        }
537    }
538}
539
540#[async_trait]
541impl LogicalPlanner for DfLogicalPlanner {
542    #[tracing::instrument(skip_all)]
543    async fn plan(&self, stmt: &QueryStatement, query_ctx: QueryContextRef) -> Result<LogicalPlan> {
544        match stmt {
545            QueryStatement::Sql(stmt) => self.plan_sql(stmt, query_ctx).await,
546            QueryStatement::Promql(stmt, _alias) => self.plan_pql(stmt, query_ctx).await,
547        }
548    }
549
550    async fn plan_logs_query(
551        &self,
552        query: LogQuery,
553        query_ctx: QueryContextRef,
554    ) -> Result<LogicalPlan> {
555        let plan_decoder = Arc::new(DefaultPlanDecoder::new(
556            self.session_state.clone(),
557            &query_ctx,
558        )?);
559        let table_provider = DfTableSourceProvider::new(
560            self.engine_state.catalog_manager().clone(),
561            self.engine_state.disallow_cross_catalog_query(),
562            query_ctx,
563            plan_decoder,
564            self.session_state
565                .config_options()
566                .sql_parser
567                .enable_ident_normalization,
568        );
569
570        let mut planner = LogQueryPlanner::new(table_provider, self.session_state.clone());
571        planner
572            .query_to_plan(query)
573            .await
574            .map_err(BoxedError::new)
575            .context(QueryPlanSnafu)
576    }
577
578    fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
579        self.optimize_logical_plan(plan)
580    }
581
582    fn as_any(&self) -> &dyn Any {
583        self
584    }
585}
586
587#[cfg(test)]
588mod tests {
589    use std::sync::Arc;
590
591    use arrow_schema::DataType;
592    use catalog::RegisterTableRequest;
593    use catalog::memory::MemoryCatalogManager;
594    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
595    use datatypes::prelude::ConcreteDataType;
596    use datatypes::schema::{ColumnSchema, Schema};
597    use session::context::QueryContext;
598    use store_api::metric_engine_consts::{
599        DATA_SCHEMA_TABLE_ID_COLUMN_NAME, DATA_SCHEMA_TSID_COLUMN_NAME, LOGICAL_TABLE_METADATA_KEY,
600        METRIC_ENGINE_NAME,
601    };
602    use table::metadata::{TableInfoBuilder, TableMetaBuilder};
603    use table::test_util::EmptyTable;
604
605    use super::*;
606    use crate::parser::{PromQuery, QueryLanguageParser};
607    use crate::{QueryEngineFactory, QueryEngineRef};
608
    /// Builds a query engine over a single empty table
    /// `test(id INT32 NOT NULL, name STRING NULL)` with `id` as primary key.
    async fn create_test_engine() -> QueryEngineRef {
        let columns = vec![
            ColumnSchema::new("id", ConcreteDataType::int32_datatype(), false),
            ColumnSchema::new("name", ConcreteDataType::string_datatype(), true),
        ];
        let schema = Arc::new(Schema::new(columns));
        let table_meta = TableMetaBuilder::empty()
            .schema(schema)
            .primary_key_indices(vec![0])
            .value_indices(vec![1])
            .next_column_id(1024)
            .build()
            .unwrap();
        let table_info = TableInfoBuilder::new("test", table_meta).build().unwrap();
        let table = EmptyTable::from_table_info(&table_info);

        crate::tests::new_query_engine_with_table(table)
    }
627
    /// Builds a query engine whose catalog contains a metric-engine physical
    /// table `phy` and a logical table `some_metric` mapped onto it, for
    /// PromQL planning tests.
    fn create_promql_test_engine() -> QueryEngineRef {
        let catalog_manager = MemoryCatalogManager::with_default_setup();
        let physical_table_name = "phy";
        let physical_table_id = 999u32;

        // Physical table: metric-engine internal columns (table id, tsid),
        // two tags, the time index, and one field column.
        let physical_schema = Arc::new(Schema::new(vec![
            ColumnSchema::new(
                DATA_SCHEMA_TABLE_ID_COLUMN_NAME.to_string(),
                ConcreteDataType::uint32_datatype(),
                false,
            ),
            ColumnSchema::new(
                DATA_SCHEMA_TSID_COLUMN_NAME.to_string(),
                ConcreteDataType::uint64_datatype(),
                false,
            ),
            ColumnSchema::new("tag_0", ConcreteDataType::string_datatype(), false),
            ColumnSchema::new("tag_1", ConcreteDataType::string_datatype(), false),
            ColumnSchema::new(
                "timestamp",
                ConcreteDataType::timestamp_millisecond_datatype(),
                false,
            )
            .with_time_index(true),
            ColumnSchema::new("field_0", ConcreteDataType::float64_datatype(), true),
        ]));
        let physical_meta = TableMetaBuilder::empty()
            .schema(physical_schema)
            .primary_key_indices(vec![0, 1, 2, 3])
            .value_indices(vec![4, 5])
            .engine(METRIC_ENGINE_NAME.to_string())
            .next_column_id(1024)
            .build()
            .unwrap();
        let physical_info = TableInfoBuilder::default()
            .table_id(physical_table_id)
            .name(physical_table_name)
            .meta(physical_meta)
            .build()
            .unwrap();
        catalog_manager
            .register_table_sync(RegisterTableRequest {
                catalog: DEFAULT_CATALOG_NAME.to_string(),
                schema: DEFAULT_SCHEMA_NAME.to_string(),
                table_name: physical_table_name.to_string(),
                table_id: physical_table_id,
                table: EmptyTable::from_table_info(&physical_info),
            })
            .unwrap();

        // The logical table points at its physical table through the
        // metric-engine metadata table option.
        let mut options = table::requests::TableOptions::default();
        options.extra_options.insert(
            LOGICAL_TABLE_METADATA_KEY.to_string(),
            physical_table_name.to_string(),
        );
        let logical_schema = Arc::new(Schema::new(vec![
            ColumnSchema::new("tag_0", ConcreteDataType::string_datatype(), false),
            ColumnSchema::new("tag_1", ConcreteDataType::string_datatype(), false),
            ColumnSchema::new(
                "timestamp",
                ConcreteDataType::timestamp_millisecond_datatype(),
                false,
            )
            .with_time_index(true),
            ColumnSchema::new("field_0", ConcreteDataType::float64_datatype(), true),
        ]));
        let logical_meta = TableMetaBuilder::empty()
            .schema(logical_schema)
            .primary_key_indices(vec![0, 1])
            .value_indices(vec![3])
            .engine(METRIC_ENGINE_NAME.to_string())
            .options(options)
            .next_column_id(1024)
            .build()
            .unwrap();
        let logical_info = TableInfoBuilder::default()
            .table_id(1024)
            .name("some_metric")
            .meta(logical_meta)
            .build()
            .unwrap();
        catalog_manager
            .register_table_sync(RegisterTableRequest {
                catalog: DEFAULT_CATALOG_NAME.to_string(),
                schema: DEFAULT_SCHEMA_NAME.to_string(),
                table_name: "some_metric".to_string(),
                table_id: 1024,
                table: EmptyTable::from_table_info(&logical_info),
            })
            .unwrap();

        QueryEngineFactory::new(
            catalog_manager,
            None,
            None,
            None,
            None,
            false,
            crate::options::QueryOptions::default(),
        )
        .query_engine()
    }
730
731    async fn parse_sql_to_plan(sql: &str) -> LogicalPlan {
732        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
733        let engine = create_test_engine().await;
734        engine
735            .planner()
736            .plan(&stmt, QueryContext::arc())
737            .await
738            .unwrap()
739    }
740
741    async fn parse_promql_to_plan(query: &str) -> LogicalPlan {
742        let engine = create_promql_test_engine();
743        let query_ctx = QueryContext::arc();
744        let stmt = QueryLanguageParser::parse_promql(
745            &PromQuery {
746                query: query.to_string(),
747                start: "0".to_string(),
748                end: "10".to_string(),
749                step: "5s".to_string(),
750                lookback: "300s".to_string(),
751                alias: None,
752            },
753            &query_ctx,
754        )
755        .unwrap();
756
757        engine.planner().plan(&stmt, query_ctx).await.unwrap()
758    }
759
760    #[tokio::test]
761    async fn test_extract_placeholder_cast_types_multiple() {
762        let plan = parse_sql_to_plan(
763            "SELECT $1::INT, $2::TEXT, $3, $4::INTEGER FROM test WHERE $5::FLOAT > 0",
764        )
765        .await;
766        let types = DfLogicalPlanner::extract_placeholder_cast_types(&plan).unwrap();
767
768        assert_eq!(types.len(), 5);
769        assert_eq!(types.get("$1"), Some(&Some(DataType::Int32)));
770        assert_eq!(types.get("$2"), Some(&Some(DataType::Utf8)));
771        assert_eq!(types.get("$3"), Some(&None));
772        assert_eq!(types.get("$4"), Some(&Some(DataType::Int32)));
773        assert_eq!(types.get("$5"), Some(&Some(DataType::Float32)));
774    }
775
776    #[tokio::test]
777    async fn test_get_inferred_parameter_types_fallback_for_udf_args() {
778        // datafusion is not able to infer type for scalar function arguments
779        let plan = parse_sql_to_plan(
780            "SELECT parse_ident($1), parse_ident($2::TEXT) FROM test WHERE id > $3",
781        )
782        .await;
783        let types = DfLogicalPlanner::get_inferred_parameter_types(&plan).unwrap();
784
785        assert_eq!(types.len(), 3);
786
787        let type_1 = types.get("$1").unwrap();
788        let type_2 = types.get("$2").unwrap();
789        let type_3 = types.get("$3").unwrap();
790
791        assert!(type_1.is_none(), "Expected $1 to be None");
792        assert_eq!(type_2, &Some(DataType::Utf8));
793        assert_eq!(type_3, &Some(DataType::Int32));
794    }
795
796    #[tokio::test]
797    async fn test_plan_pql_applies_extension_rules() {
798        for inner_agg in ["count", "sum", "avg", "min", "max", "stddev", "stdvar"] {
799            let plan = parse_promql_to_plan(&format!(
800                "sum(irate(some_metric[1h])) / scalar(count({inner_agg}(some_metric) by (tag_0)))"
801            ))
802            .await;
803            let plan_str = plan.display_indent_schema().to_string();
804            assert!(plan_str.contains("Distinct:"), "{inner_agg}: {plan_str}");
805        }
806    }
807
808    #[tokio::test]
809    async fn test_plan_pql_filters_null_only_groups_for_non_count_inner_aggs() {
810        let count_plan = parse_promql_to_plan("scalar(count(count(some_metric) by (tag_0)))").await;
811        let count_plan_str = count_plan.display_indent_schema().to_string();
812        assert!(
813            !count_plan_str.contains("field_0 IS NOT NULL"),
814            "{count_plan_str}"
815        );
816
817        for inner_agg in ["sum", "avg", "min", "max", "stddev", "stdvar"] {
818            let plan = parse_promql_to_plan(&format!(
819                "scalar(count({inner_agg}(some_metric) by (tag_0)))"
820            ))
821            .await;
822            let plan_str = plan.display_indent_schema().to_string();
823            assert!(
824                plan_str.contains("field_0 IS NOT NULL"),
825                "{inner_agg}: {plan_str}"
826            );
827        }
828    }
829
830    #[tokio::test]
831    async fn test_plan_pql_skips_extension_rules_for_non_direct_or_unsupported_inner_agg() {
832        for query in [
833            "sum(irate(some_metric[1h])) / scalar(count(sum(irate(some_metric[1h])) by (tag_0)))",
834            "sum(irate(some_metric[1h])) / scalar(count(group(some_metric) by (tag_0)))",
835        ] {
836            let plan = parse_promql_to_plan(query).await;
837            let plan_str = plan.display_indent_schema().to_string();
838            assert!(!plan_str.contains("Distinct:"), "{query}: {plan_str}");
839        }
840    }
841
842    #[tokio::test]
843    async fn test_plan_sql_does_not_apply_nested_count_rule() {
844        let plan = parse_sql_to_plan(
845            "SELECT id, count(inner_count) \
846             FROM ( \
847                 SELECT id, count(name) AS inner_count \
848                 FROM test \
849                 GROUP BY id \
850                 ORDER BY id \
851                 LIMIT 1000000 \
852             ) t \
853             GROUP BY id \
854             ORDER BY id",
855        )
856        .await;
857
858        let plan_str = plan.display_indent_schema().to_string();
859        assert!(!plan_str.contains("Distinct:"), "{plan_str}");
860    }
861
862    #[tokio::test]
863    async fn test_get_inferred_parameter_types_subquery() {
864        let plan = parse_sql_to_plan(
865            r#"SELECT * FROM test WHERE id = (SELECT id FROM test CROSS JOIN (SELECT parse_ident($1::TEXT) AS parts) p LIMIT 1)"#,
866        ).await;
867        let types = DfLogicalPlanner::get_inferred_parameter_types(&plan).unwrap();
868
869        assert_eq!(types.len(), 1);
870        let type_1 = types.get("$1").unwrap();
871        assert_eq!(type_1, &Some(DataType::Utf8));
872    }
873
874    #[tokio::test]
875    async fn test_get_inferred_parameter_types_insert() {
876        let plan = parse_sql_to_plan("INSERT INTO test (id, name) VALUES ($1, $2), ($3, $4)").await;
877        let types = DfLogicalPlanner::get_inferred_parameter_types(&plan).unwrap();
878
879        assert_eq!(types.len(), 4);
880        assert_eq!(types.get("$1"), Some(&Some(DataType::Int32)));
881        assert_eq!(types.get("$2"), Some(&Some(DataType::Utf8)));
882        assert_eq!(types.get("$3"), Some(&Some(DataType::Int32)));
883        assert_eq!(types.get("$4"), Some(&Some(DataType::Utf8)));
884    }
885
886    #[tokio::test]
887    async fn test_get_inferred_parameter_types_arrow_cast() {
888        let plan = parse_sql_to_plan("SELECT $1::INT64, $2::FLOAT64, $3::INT16, $4::INT32, $5::UINT8, $6::UINT16, $7::UINT32").await;
889        let types = DfLogicalPlanner::get_inferred_parameter_types(&plan).unwrap();
890
891        assert_eq!(types.get("$1"), Some(&Some(DataType::Int64)));
892        assert_eq!(types.get("$2"), Some(&Some(DataType::Float64)));
893        assert_eq!(types.get("$3"), Some(&Some(DataType::Int16)));
894        assert_eq!(types.get("$4"), Some(&Some(DataType::Int32)));
895        assert_eq!(types.get("$5"), Some(&Some(DataType::UInt8)));
896        assert_eq!(types.get("$6"), Some(&Some(DataType::UInt16)));
897        assert_eq!(types.get("$7"), Some(&Some(DataType::UInt32)));
898
899        let plan = parse_sql_to_plan("SELECT $1::INT8, $2::FLOAT8, $3::INT2, $4::INT8").await;
900        let types = DfLogicalPlanner::get_inferred_parameter_types(&plan).unwrap();
901
902        assert_eq!(types.get("$1"), Some(&Some(DataType::Int64)));
903        assert_eq!(types.get("$2"), Some(&Some(DataType::Float64)));
904        assert_eq!(types.get("$3"), Some(&Some(DataType::Int16)));
905    }
906}