From 07737188efb1aed2a7fee5b9cebb3b914424c85c Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Wed, 25 Feb 2026 10:30:02 +0800 Subject: [PATCH] feat: add a fallback parameter type inference by reading cast type (#7712) * feat: add a fallback parameter type inference by reading cast * fix: typo * fix: lint and typo * refactor: make extract function private * refactor: fix_placeholder_types is no longer needed --- src/query/src/planner.rs | 167 ++++++++++++++++++++++++++++ src/servers/src/error.rs | 9 ++ src/servers/src/mysql/handler.rs | 35 +++--- src/servers/src/mysql/helper.rs | 40 +------ src/servers/src/postgres/handler.rs | 8 +- src/servers/src/postgres/types.rs | 8 +- 6 files changed, 201 insertions(+), 66 deletions(-) diff --git a/src/query/src/planner.rs b/src/query/src/planner.rs index 91ed874a90..44c9bc3956 100644 --- a/src/query/src/planner.rs +++ b/src/query/src/planner.rs @@ -14,9 +14,11 @@ use std::any::Any; use std::borrow::Cow; +use std::collections::{HashMap, HashSet}; use std::str::FromStr; use std::sync::Arc; +use arrow_schema::DataType; use async_trait::async_trait; use catalog::table_source::DfTableSourceProvider; use common_error::ext::BoxedError; @@ -25,6 +27,7 @@ use datafusion::common::{DFSchema, plan_err}; use datafusion::execution::context::SessionState; use datafusion::sql::planner::PlannerContext; use datafusion_common::ToDFSchema; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_expr::{ Analyze, Explain, ExplainFormat, Expr as DfExpr, LogicalPlan, LogicalPlanBuilder, PlanType, ToStringifiedPlan, col, @@ -405,6 +408,89 @@ impl DfLogicalPlanner { .fail(), } } + + /// Extracts cast types for all placeholders in a logical plan. 
+ /// Returns a map where each placeholder ID is mapped to: + /// - Some(DataType) if the placeholder is cast to a specific type + /// - None if the placeholder exists but has no cast + /// + /// Example: `$1::TEXT` returns `{"$1": Some(DataType::Utf8)}` + /// + /// This function walks through all expressions in the logical plan, + /// including subqueries, to identify placeholders and their cast types. + fn extract_placeholder_cast_types( + plan: &LogicalPlan, + ) -> Result>> { + let mut placeholder_types = HashMap::new(); + let mut casted_placeholders = HashSet::new(); + + plan.apply(|node| { + for expr in node.expressions() { + let _ = expr.apply(|e| { + if let DfExpr::Cast(cast) = e + && let DfExpr::Placeholder(ph) = &*cast.expr + { + placeholder_types.insert(ph.id.clone(), Some(cast.data_type.clone())); + casted_placeholders.insert(ph.id.clone()); + } + + if let DfExpr::Placeholder(ph) = e + && !casted_placeholders.contains(&ph.id) + && !placeholder_types.contains_key(&ph.id) + { + placeholder_types.insert(ph.id.clone(), None); + } + + Ok(TreeNodeRecursion::Continue) + }); + } + Ok(TreeNodeRecursion::Continue) + })?; + + Ok(placeholder_types) + } + + /// Gets inferred parameter types from a logical plan. + /// Returns a map where each parameter ID is mapped to: + /// - Some(DataType) if the parameter type could be inferred + /// - None if the parameter type could not be inferred + /// + /// This function first uses DataFusion's `get_parameter_types()` to infer types. + /// If any parameters have `None` values (i.e., DataFusion couldn't infer their types), + /// it falls back to using `extract_placeholder_cast_types()` to detect explicit casts. + /// + /// This is because DataFusion can only infer types in a limited set of cases. + /// + /// Example: For query `WHERE $1::TEXT AND $2`, DataFusion may not infer `$2`'s type, + /// but this function will return `{"$1": Some(DataType::Utf8), "$2": None}`. 
+ pub fn get_inferred_parameter_types( + plan: &LogicalPlan, + ) -> Result>> { + let param_types = plan.get_parameter_types().context(PlanSqlSnafu)?; + + let has_none = param_types.values().any(|v| v.is_none()); + + if !has_none { + Ok(param_types) + } else { + let cast_types = Self::extract_placeholder_cast_types(plan)?; + + let mut merged = param_types; + + for (id, opt_type) in cast_types { + merged + .entry(id) + .and_modify(|existing| { + if existing.is_none() { + *existing = opt_type.clone(); + } + }) + .or_insert(opt_type); + } + + Ok(merged) + } + } } #[async_trait] @@ -453,3 +539,84 @@ impl LogicalPlanner for DfLogicalPlanner { self } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow_schema::DataType; + use datatypes::prelude::ConcreteDataType; + use datatypes::schema::{ColumnSchema, Schema}; + use session::context::QueryContext; + use table::metadata::{TableInfoBuilder, TableMetaBuilder}; + use table::test_util::EmptyTable; + + use super::*; + use crate::QueryEngineRef; + use crate::parser::QueryLanguageParser; + + async fn create_test_engine() -> QueryEngineRef { + let columns = vec![ + ColumnSchema::new("id", ConcreteDataType::int32_datatype(), false), + ColumnSchema::new("name", ConcreteDataType::string_datatype(), true), + ]; + let schema = Arc::new(Schema::new(columns)); + let table_meta = TableMetaBuilder::empty() + .schema(schema) + .primary_key_indices(vec![0]) + .value_indices(vec![1]) + .next_column_id(1024) + .build() + .unwrap(); + let table_info = TableInfoBuilder::new("test", table_meta).build().unwrap(); + let table = EmptyTable::from_table_info(&table_info); + + crate::tests::new_query_engine_with_table(table) + } + + async fn parse_sql_to_plan(sql: &str) -> LogicalPlan { + let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap(); + let engine = create_test_engine().await; + engine + .planner() + .plan(&stmt, QueryContext::arc()) + .await + .unwrap() + } + + #[tokio::test] + async fn 
test_extract_placeholder_cast_types_multiple() { + let plan = parse_sql_to_plan( + "SELECT $1::INT, $2::TEXT, $3, $4::INTEGER FROM test WHERE $5::FLOAT > 0", + ) + .await; + let types = DfLogicalPlanner::extract_placeholder_cast_types(&plan).unwrap(); + + assert_eq!(types.len(), 5); + assert_eq!(types.get("$1"), Some(&Some(DataType::Int32))); + assert_eq!(types.get("$2"), Some(&Some(DataType::Utf8))); + assert_eq!(types.get("$3"), Some(&None)); + assert_eq!(types.get("$4"), Some(&Some(DataType::Int32))); + assert_eq!(types.get("$5"), Some(&Some(DataType::Float32))); + } + + #[tokio::test] + async fn test_get_inferred_parameter_types_fallback_for_udf_args() { + // datafusion is not able to infer type for scalar function arguments + let plan = parse_sql_to_plan( + "SELECT parse_ident($1), parse_ident($2::TEXT) FROM test WHERE id > $3", + ) + .await; + let types = DfLogicalPlanner::get_inferred_parameter_types(&plan).unwrap(); + + assert_eq!(types.len(), 3); + + let type_1 = types.get("$1").unwrap(); + let type_2 = types.get("$2").unwrap(); + let type_3 = types.get("$3").unwrap(); + + assert!(type_1.is_none(), "Expected $1 to be None"); + assert_eq!(type_2, &Some(DataType::Utf8)); + assert_eq!(type_3, &Some(DataType::Int32)); + } +} diff --git a/src/servers/src/error.rs b/src/servers/src/error.rs index 92fb89af03..18ac964f05 100644 --- a/src/servers/src/error.rs +++ b/src/servers/src/error.rs @@ -445,6 +445,14 @@ pub enum Error { error: query::error::Error, }, + #[snafu(display("Failed to infer parameter types"))] + InferParameterTypes { + #[snafu(implicit)] + location: Location, + #[snafu(source)] + error: query::error::Error, + }, + #[snafu(display("{}", reason))] UnexpectedResult { reason: String, @@ -721,6 +729,7 @@ impl ErrorExt for Error { | InvalidPromRemoteRequest { .. } | InvalidFlightTicket { .. } | InvalidPrepareStatement { .. } + | InferParameterTypes { .. } | DataFrame { .. } | PreparedStmtTypeMismatch { .. } | TimePrecision { .. 
} diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index 25caff98a6..dd67012a52 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -34,6 +34,7 @@ use opensrv_mysql::{ StatementMetaWriter, ValueInner, }; use parking_lot::RwLock; +use query::planner::DfLogicalPlanner; use query::query_engine::DescribeResult; use rand::RngCore; use session::context::{Channel, QueryContextRef}; @@ -45,10 +46,12 @@ use sql::statements::statement::Statement; use tokio::io::AsyncWrite; use crate::SqlPlan; -use crate::error::{self, DataFrameSnafu, InvalidPrepareStatementSnafu, Result}; +use crate::error::{ + self, DataFrameSnafu, InferParameterTypesSnafu, InvalidPrepareStatementSnafu, Result, +}; use crate::metrics::METRIC_AUTH_FAILURE; use crate::mysql::helper::{ - self, fix_placeholder_types, format_placeholder, replace_placeholders, transform_placeholders, + self, format_placeholder, replace_placeholders, transform_placeholders, }; use crate::mysql::writer; use crate::mysql::writer::{create_mysql_column, handle_err}; @@ -206,7 +209,7 @@ impl MysqlInstanceShim { let describe_result = self .do_describe(statement.clone(), query_ctx.clone()) .await?; - let (mut plan, schema) = if let Some(DescribeResult { + let (plan, schema) = if let Some(DescribeResult { logical_plan, schema, }) = describe_result @@ -216,17 +219,13 @@ impl MysqlInstanceShim { (None, None) }; - let params = if let Some(plan) = &mut plan { - fix_placeholder_types(plan)?; - debug!("Plan after fix placeholder types: {:#?}", plan); - prepared_params( - &plan - .get_parameter_types() - .context(DataFrameSnafu)? - .into_iter() - .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v)))) - .collect(), - )? + let params = if let Some(plan) = &plan { + let param_types = DfLogicalPlanner::get_inferred_parameter_types(plan) + .context(InferParameterTypesSnafu)? 
+ .into_iter() + .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v)))) + .collect(); + prepared_params(¶m_types)? } else { dummy_params(param_num)? }; @@ -293,11 +292,9 @@ impl MysqlInstanceShim { }; let outputs = match sql_plan.plan { - Some(mut plan) => { - fix_placeholder_types(&mut plan)?; - let param_types = plan - .get_parameter_types() - .context(DataFrameSnafu)? + Some(plan) => { + let param_types = DfLogicalPlanner::get_inferred_parameter_types(&plan) + .context(InferParameterTypesSnafu)? .into_iter() .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v)))) .collect::>(); diff --git a/src/servers/src/mysql/helper.rs b/src/servers/src/mysql/helper.rs index ce940c3974..2ee2421892 100644 --- a/src/servers/src/mysql/helper.rs +++ b/src/servers/src/mysql/helper.rs @@ -13,16 +13,12 @@ // limitations under the License. use std::ops::ControlFlow; -use std::sync::Arc; use std::time::Duration; -use arrow_schema::Field; use chrono::NaiveDate; use common_query::prelude::ScalarValue; use common_sql::convert::sql_value_to_value; use common_time::{Date, Timestamp}; -use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_expr::LogicalPlan; use datatypes::prelude::ConcreteDataType; use datatypes::schema::ColumnSchema; use datatypes::types::TimestampType; @@ -33,7 +29,7 @@ use snafu::ResultExt; use sql::ast::{Expr, Value as ValueExpr, ValueWithSpan, VisitMut, visit_expressions_mut}; use sql::statements::statement::Statement; -use crate::error::{self, DataFusionSnafu, Result}; +use crate::error::{self, Result}; /// Returns the placeholder string "$i". 
pub fn format_placeholder(i: usize) -> String { @@ -82,40 +78,6 @@ pub fn transform_placeholders(stmt: Statement) -> Statement { } } -/// Give placeholder that cast to certain type `data_type` the same data type as is cast to -/// -/// because it seems datafusion will not give data type to placeholder if it need to be cast to certain type, still unknown if this is a feature or a bug. And if a placeholder expr have no data type, datafusion will fail to extract it using `LogicalPlan::get_parameter_types` -pub fn fix_placeholder_types(plan: &mut LogicalPlan) -> Result<()> { - let give_placeholder_types = |mut e: datafusion_expr::Expr| { - if let datafusion_expr::Expr::Cast(cast) = &mut e { - if let datafusion_expr::Expr::Placeholder(ph) = &mut *cast.expr { - if ph.field.is_none() { - ph.field = Some(Arc::new(Field::new("", cast.data_type.clone(), true))); - common_telemetry::debug!( - "give placeholder type {:?} to {:?}", - cast.data_type, - ph - ); - Ok(Transformed::yes(e)) - } else { - Ok(Transformed::no(e)) - } - } else { - Ok(Transformed::no(e)) - } - } else { - Ok(Transformed::no(e)) - } - }; - let give_placeholder_types_recursively = - |e: datafusion_expr::Expr| e.transform(give_placeholder_types); - *plan = std::mem::take(plan) - .transform(|p| p.map_expressions(give_placeholder_types_recursively)) - .context(DataFusionSnafu)? 
- .data; - Ok(()) -} - fn visit_placeholders(v: &mut V) where V: VisitMut, diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index 5fb7281472..56ca14e85d 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -34,6 +34,7 @@ use pgwire::api::stmt::{QueryParser, StoredStatement}; use pgwire::api::{ClientInfo, ErrorHandler, Type}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use pgwire::messages::PgWireBackendMessage; +use query::planner::DfLogicalPlanner; use query::query_engine::DescribeResult; use session::Session; use session::context::QueryContextRef; @@ -43,7 +44,7 @@ use sql::parser::{ParseOptions, ParserContext}; use sql::statements::statement::Statement; use crate::SqlPlan; -use crate::error::{DataFusionSnafu, Result}; +use crate::error::{DataFusionSnafu, InferParameterTypesSnafu, Result}; use crate::postgres::types::*; use crate::postgres::utils::convert_err; use crate::postgres::{PostgresServerHandlerInner, fixtures}; @@ -369,9 +370,8 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner { // client provided parameter types, can be empty if client doesn't try to parse statement let provided_param_types = &stmt.parameter_types; let server_inferenced_types = if let Some(plan) = &sql_plan.plan { - let param_types = plan - .get_parameter_types() - .context(DataFusionSnafu) + let param_types = DfLogicalPlanner::get_inferred_parameter_types(plan) + .context(InferParameterTypesSnafu) .map_err(convert_err)? 
.into_iter() .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v)))) diff --git a/src/servers/src/postgres/types.rs b/src/servers/src/postgres/types.rs index b11735015f..0b76819bc9 100644 --- a/src/servers/src/postgres/types.rs +++ b/src/servers/src/postgres/types.rs @@ -39,12 +39,13 @@ use pgwire::api::results::{DataRowEncoder, FieldInfo}; use pgwire::error::{PgWireError, PgWireResult}; use pgwire::messages::data::DataRow; use pgwire::types::format::FormatOptions as PgFormatOptions; +use query::planner::DfLogicalPlanner; use session::context::QueryContextRef; use snafu::ResultExt; pub use self::error::{PgErrorCode, PgErrorSeverity}; use crate::SqlPlan; -use crate::error::{self as server_error, DataFusionSnafu, Result}; +use crate::error::{self as server_error, InferParameterTypesSnafu, Result}; use crate::postgres::utils::convert_err; pub(super) fn schema_to_pg( @@ -364,9 +365,8 @@ pub(super) fn parameters_to_scalar_values( let mut results = Vec::with_capacity(param_count); let client_param_types = &portal.statement.parameter_types; - let server_param_types = plan - .get_parameter_types() - .context(DataFusionSnafu) + let server_param_types = DfLogicalPlanner::get_inferred_parameter_types(plan) + .context(InferParameterTypesSnafu) .map_err(convert_err)? .into_iter() .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))