feat: add a fallback parameter type inference by reading cast type (#7712)

* feat: add a fallback parameter type inference by reading cast

* fix: typo

* fix: lint and typo

* refactor: make extract function private

* refactor: fix_placeholder_types is no longer needed
This commit is contained in:
Ning Sun
2026-02-25 10:30:02 +08:00
committed by GitHub
parent 279b009583
commit 07737188ef
6 changed files with 201 additions and 66 deletions

View File

@@ -14,9 +14,11 @@
use std::any::Any;
use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::str::FromStr;
use std::sync::Arc;
use arrow_schema::DataType;
use async_trait::async_trait;
use catalog::table_source::DfTableSourceProvider;
use common_error::ext::BoxedError;
@@ -25,6 +27,7 @@ use datafusion::common::{DFSchema, plan_err};
use datafusion::execution::context::SessionState;
use datafusion::sql::planner::PlannerContext;
use datafusion_common::ToDFSchema;
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_expr::{
Analyze, Explain, ExplainFormat, Expr as DfExpr, LogicalPlan, LogicalPlanBuilder, PlanType,
ToStringifiedPlan, col,
@@ -405,6 +408,89 @@ impl DfLogicalPlanner {
.fail(),
}
}
/// Extracts cast types for all placeholders in a logical plan.
/// Returns a map where each placeholder ID is mapped to:
/// - Some(DataType) if the placeholder is cast to a specific type
/// - None if the placeholder exists but has no cast
///
/// Example: `$1::TEXT` returns `{"$1": Some(DataType::Utf8)}`
///
/// This function walks through all expressions in the logical plan,
/// including subqueries, to identify placeholders and their cast types.
fn extract_placeholder_cast_types(
plan: &LogicalPlan,
) -> Result<HashMap<String, Option<DataType>>> {
let mut placeholder_types = HashMap::new();
let mut casted_placeholders = HashSet::new();
plan.apply(|node| {
for expr in node.expressions() {
let _ = expr.apply(|e| {
if let DfExpr::Cast(cast) = e
&& let DfExpr::Placeholder(ph) = &*cast.expr
{
placeholder_types.insert(ph.id.clone(), Some(cast.data_type.clone()));
casted_placeholders.insert(ph.id.clone());
}
if let DfExpr::Placeholder(ph) = e
&& !casted_placeholders.contains(&ph.id)
&& !placeholder_types.contains_key(&ph.id)
{
placeholder_types.insert(ph.id.clone(), None);
}
Ok(TreeNodeRecursion::Continue)
});
}
Ok(TreeNodeRecursion::Continue)
})?;
Ok(placeholder_types)
}
/// Gets inferred parameter types from a logical plan.
/// Returns a map where each parameter ID is mapped to:
/// - Some(DataType) if the parameter type could be inferred
/// - None if the parameter type could not be inferred
///
/// This function first uses DataFusion's `get_parameter_types()` to infer types.
/// If any parameters have `None` values (i.e., DataFusion couldn't infer their types),
/// it falls back to using `extract_placeholder_cast_types()` to detect explicit casts.
///
/// This is because datafusion can only infer types for a limited cases.
///
/// Example: For query `WHERE $1::TEXT AND $2`, DataFusion may not infer `$2`'s type,
/// but this function will return `{"$1": Some(DataType::Utf8), "$2": None}`.
pub fn get_inferred_parameter_types(
plan: &LogicalPlan,
) -> Result<HashMap<String, Option<DataType>>> {
let param_types = plan.get_parameter_types().context(PlanSqlSnafu)?;
let has_none = param_types.values().any(|v| v.is_none());
if !has_none {
Ok(param_types)
} else {
let cast_types = Self::extract_placeholder_cast_types(plan)?;
let mut merged = param_types;
for (id, opt_type) in cast_types {
merged
.entry(id)
.and_modify(|existing| {
if existing.is_none() {
*existing = opt_type.clone();
}
})
.or_insert(opt_type);
}
Ok(merged)
}
}
}
#[async_trait]
@@ -453,3 +539,84 @@ impl LogicalPlanner for DfLogicalPlanner {
self
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use arrow_schema::DataType;
use datatypes::prelude::ConcreteDataType;
use datatypes::schema::{ColumnSchema, Schema};
use session::context::QueryContext;
use table::metadata::{TableInfoBuilder, TableMetaBuilder};
use table::test_util::EmptyTable;
use super::*;
use crate::QueryEngineRef;
use crate::parser::QueryLanguageParser;
async fn create_test_engine() -> QueryEngineRef {
let columns = vec![
ColumnSchema::new("id", ConcreteDataType::int32_datatype(), false),
ColumnSchema::new("name", ConcreteDataType::string_datatype(), true),
];
let schema = Arc::new(Schema::new(columns));
let table_meta = TableMetaBuilder::empty()
.schema(schema)
.primary_key_indices(vec![0])
.value_indices(vec![1])
.next_column_id(1024)
.build()
.unwrap();
let table_info = TableInfoBuilder::new("test", table_meta).build().unwrap();
let table = EmptyTable::from_table_info(&table_info);
crate::tests::new_query_engine_with_table(table)
}
async fn parse_sql_to_plan(sql: &str) -> LogicalPlan {
let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
let engine = create_test_engine().await;
engine
.planner()
.plan(&stmt, QueryContext::arc())
.await
.unwrap()
}
#[tokio::test]
async fn test_extract_placeholder_cast_types_multiple() {
let plan = parse_sql_to_plan(
"SELECT $1::INT, $2::TEXT, $3, $4::INTEGER FROM test WHERE $5::FLOAT > 0",
)
.await;
let types = DfLogicalPlanner::extract_placeholder_cast_types(&plan).unwrap();
assert_eq!(types.len(), 5);
assert_eq!(types.get("$1"), Some(&Some(DataType::Int32)));
assert_eq!(types.get("$2"), Some(&Some(DataType::Utf8)));
assert_eq!(types.get("$3"), Some(&None));
assert_eq!(types.get("$4"), Some(&Some(DataType::Int32)));
assert_eq!(types.get("$5"), Some(&Some(DataType::Float32)));
}
#[tokio::test]
async fn test_get_inferred_parameter_types_fallback_for_udf_args() {
// datafusion is not able to infer type for scalar function arguments
let plan = parse_sql_to_plan(
"SELECT parse_ident($1), parse_ident($2::TEXT) FROM test WHERE id > $3",
)
.await;
let types = DfLogicalPlanner::get_inferred_parameter_types(&plan).unwrap();
assert_eq!(types.len(), 3);
let type_1 = types.get("$1").unwrap();
let type_2 = types.get("$2").unwrap();
let type_3 = types.get("$3").unwrap();
assert!(type_1.is_none(), "Expected $1 to be None");
assert_eq!(type_2, &Some(DataType::Utf8));
assert_eq!(type_3, &Some(DataType::Int32));
}
}

View File

@@ -445,6 +445,14 @@ pub enum Error {
error: query::error::Error,
},
#[snafu(display("Failed to infer parameter types"))]
InferParameterTypes {
#[snafu(implicit)]
location: Location,
#[snafu(source)]
error: query::error::Error,
},
#[snafu(display("{}", reason))]
UnexpectedResult {
reason: String,
@@ -721,6 +729,7 @@ impl ErrorExt for Error {
| InvalidPromRemoteRequest { .. }
| InvalidFlightTicket { .. }
| InvalidPrepareStatement { .. }
| InferParameterTypes { .. }
| DataFrame { .. }
| PreparedStmtTypeMismatch { .. }
| TimePrecision { .. }

View File

@@ -34,6 +34,7 @@ use opensrv_mysql::{
StatementMetaWriter, ValueInner,
};
use parking_lot::RwLock;
use query::planner::DfLogicalPlanner;
use query::query_engine::DescribeResult;
use rand::RngCore;
use session::context::{Channel, QueryContextRef};
@@ -45,10 +46,12 @@ use sql::statements::statement::Statement;
use tokio::io::AsyncWrite;
use crate::SqlPlan;
use crate::error::{self, DataFrameSnafu, InvalidPrepareStatementSnafu, Result};
use crate::error::{
self, DataFrameSnafu, InferParameterTypesSnafu, InvalidPrepareStatementSnafu, Result,
};
use crate::metrics::METRIC_AUTH_FAILURE;
use crate::mysql::helper::{
self, fix_placeholder_types, format_placeholder, replace_placeholders, transform_placeholders,
self, format_placeholder, replace_placeholders, transform_placeholders,
};
use crate::mysql::writer;
use crate::mysql::writer::{create_mysql_column, handle_err};
@@ -206,7 +209,7 @@ impl MysqlInstanceShim {
let describe_result = self
.do_describe(statement.clone(), query_ctx.clone())
.await?;
let (mut plan, schema) = if let Some(DescribeResult {
let (plan, schema) = if let Some(DescribeResult {
logical_plan,
schema,
}) = describe_result
@@ -216,17 +219,13 @@ impl MysqlInstanceShim {
(None, None)
};
let params = if let Some(plan) = &mut plan {
fix_placeholder_types(plan)?;
debug!("Plan after fix placeholder types: {:#?}", plan);
prepared_params(
&plan
.get_parameter_types()
.context(DataFrameSnafu)?
.into_iter()
.map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
.collect(),
)?
let params = if let Some(plan) = &plan {
let param_types = DfLogicalPlanner::get_inferred_parameter_types(plan)
.context(InferParameterTypesSnafu)?
.into_iter()
.map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
.collect();
prepared_params(&param_types)?
} else {
dummy_params(param_num)?
};
@@ -293,11 +292,9 @@ impl MysqlInstanceShim {
};
let outputs = match sql_plan.plan {
Some(mut plan) => {
fix_placeholder_types(&mut plan)?;
let param_types = plan
.get_parameter_types()
.context(DataFrameSnafu)?
Some(plan) => {
let param_types = DfLogicalPlanner::get_inferred_parameter_types(&plan)
.context(InferParameterTypesSnafu)?
.into_iter()
.map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
.collect::<HashMap<_, _>>();

View File

@@ -13,16 +13,12 @@
// limitations under the License.
use std::ops::ControlFlow;
use std::sync::Arc;
use std::time::Duration;
use arrow_schema::Field;
use chrono::NaiveDate;
use common_query::prelude::ScalarValue;
use common_sql::convert::sql_value_to_value;
use common_time::{Date, Timestamp};
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_expr::LogicalPlan;
use datatypes::prelude::ConcreteDataType;
use datatypes::schema::ColumnSchema;
use datatypes::types::TimestampType;
@@ -33,7 +29,7 @@ use snafu::ResultExt;
use sql::ast::{Expr, Value as ValueExpr, ValueWithSpan, VisitMut, visit_expressions_mut};
use sql::statements::statement::Statement;
use crate::error::{self, DataFusionSnafu, Result};
use crate::error::{self, Result};
/// Returns the placeholder string "$i".
pub fn format_placeholder(i: usize) -> String {
@@ -82,40 +78,6 @@ pub fn transform_placeholders(stmt: Statement) -> Statement {
}
}
/// Give placeholder that cast to certain type `data_type` the same data type as is cast to
///
/// because it seems datafusion will not give data type to placeholder if it need to be cast to certain type, still unknown if this is a feature or a bug. And if a placeholder expr have no data type, datafusion will fail to extract it using `LogicalPlan::get_parameter_types`
pub fn fix_placeholder_types(plan: &mut LogicalPlan) -> Result<()> {
let give_placeholder_types = |mut e: datafusion_expr::Expr| {
if let datafusion_expr::Expr::Cast(cast) = &mut e {
if let datafusion_expr::Expr::Placeholder(ph) = &mut *cast.expr {
if ph.field.is_none() {
ph.field = Some(Arc::new(Field::new("", cast.data_type.clone(), true)));
common_telemetry::debug!(
"give placeholder type {:?} to {:?}",
cast.data_type,
ph
);
Ok(Transformed::yes(e))
} else {
Ok(Transformed::no(e))
}
} else {
Ok(Transformed::no(e))
}
} else {
Ok(Transformed::no(e))
}
};
let give_placeholder_types_recursively =
|e: datafusion_expr::Expr| e.transform(give_placeholder_types);
*plan = std::mem::take(plan)
.transform(|p| p.map_expressions(give_placeholder_types_recursively))
.context(DataFusionSnafu)?
.data;
Ok(())
}
fn visit_placeholders<V>(v: &mut V)
where
V: VisitMut,

View File

@@ -34,6 +34,7 @@ use pgwire::api::stmt::{QueryParser, StoredStatement};
use pgwire::api::{ClientInfo, ErrorHandler, Type};
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pgwire::messages::PgWireBackendMessage;
use query::planner::DfLogicalPlanner;
use query::query_engine::DescribeResult;
use session::Session;
use session::context::QueryContextRef;
@@ -43,7 +44,7 @@ use sql::parser::{ParseOptions, ParserContext};
use sql::statements::statement::Statement;
use crate::SqlPlan;
use crate::error::{DataFusionSnafu, Result};
use crate::error::{DataFusionSnafu, InferParameterTypesSnafu, Result};
use crate::postgres::types::*;
use crate::postgres::utils::convert_err;
use crate::postgres::{PostgresServerHandlerInner, fixtures};
@@ -369,9 +370,8 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner {
// client provided parameter types, can be empty if client doesn't try to parse statement
let provided_param_types = &stmt.parameter_types;
let server_inferenced_types = if let Some(plan) = &sql_plan.plan {
let param_types = plan
.get_parameter_types()
.context(DataFusionSnafu)
let param_types = DfLogicalPlanner::get_inferred_parameter_types(plan)
.context(InferParameterTypesSnafu)
.map_err(convert_err)?
.into_iter()
.map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))

View File

@@ -39,12 +39,13 @@ use pgwire::api::results::{DataRowEncoder, FieldInfo};
use pgwire::error::{PgWireError, PgWireResult};
use pgwire::messages::data::DataRow;
use pgwire::types::format::FormatOptions as PgFormatOptions;
use query::planner::DfLogicalPlanner;
use session::context::QueryContextRef;
use snafu::ResultExt;
pub use self::error::{PgErrorCode, PgErrorSeverity};
use crate::SqlPlan;
use crate::error::{self as server_error, DataFusionSnafu, Result};
use crate::error::{self as server_error, InferParameterTypesSnafu, Result};
use crate::postgres::utils::convert_err;
pub(super) fn schema_to_pg(
@@ -364,9 +365,8 @@ pub(super) fn parameters_to_scalar_values(
let mut results = Vec::with_capacity(param_count);
let client_param_types = &portal.statement.parameter_types;
let server_param_types = plan
.get_parameter_types()
.context(DataFusionSnafu)
let server_param_types = DfLogicalPlanner::get_inferred_parameter_types(plan)
.context(InferParameterTypesSnafu)
.map_err(convert_err)?
.into_iter()
.map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))