fix: type inference for sql rewrite (#8052)

fix: type inference for rewrited sql
This commit is contained in:
Ning Sun
2026-05-11 16:20:46 +08:00
committed by GitHub
parent e203ff9e1f
commit 5b47ec24ec
2 changed files with 37 additions and 3 deletions

View File

@@ -26,8 +26,8 @@ use common_telemetry::tracing;
use datafusion::common::{DFSchema, plan_err};
use datafusion::execution::context::SessionState;
use datafusion::sql::planner::PlannerContext;
use datafusion_common::ToDFSchema;
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_common::{ScalarValue, ToDFSchema};
use datafusion_expr::expr::{Exists, InSubquery};
use datafusion_expr::{
Analyze, Explain, ExplainFormat, Expr as DfExpr, LogicalPlan, LogicalPlanBuilder, PlanType,
@@ -451,6 +451,19 @@ impl DfLogicalPlanner {
casted_placeholders.insert(ph.id.clone());
}
// Handle arrow_cast(Placeholder, 'type_string') generated by SQL rewriter
if let DfExpr::ScalarFunction(scalar_func) = e
&& scalar_func.name() == "arrow_cast"
&& scalar_func.args.len() == 2
&& let DfExpr::Placeholder(ph) = &scalar_func.args[0]
&& let DfExpr::Literal(ScalarValue::Utf8(Some(type_str)), _) =
&scalar_func.args[1]
&& let Ok(data_type) = type_str.parse::<DataType>()
{
placeholder_types.insert(ph.id.clone(), Some(data_type));
casted_placeholders.insert(ph.id.clone());
}
// Handle bare (non-casted) placeholders
if let DfExpr::Placeholder(ph) = e
&& !casted_placeholders.contains(&ph.id)
@@ -869,4 +882,25 @@ mod tests {
assert_eq!(types.get("$3"), Some(&Some(DataType::Int32)));
assert_eq!(types.get("$4"), Some(&Some(DataType::Utf8)));
}
#[tokio::test]
async fn test_get_inferred_parameter_types_arrow_cast() {
let plan = parse_sql_to_plan("SELECT $1::INT64, $2::FLOAT64, $3::INT16, $4::INT32, $5::UINT8, $6::UINT16, $7::UINT32").await;
let types = DfLogicalPlanner::get_inferred_parameter_types(&plan).unwrap();
assert_eq!(types.get("$1"), Some(&Some(DataType::Int64)));
assert_eq!(types.get("$2"), Some(&Some(DataType::Float64)));
assert_eq!(types.get("$3"), Some(&Some(DataType::Int16)));
assert_eq!(types.get("$4"), Some(&Some(DataType::Int32)));
assert_eq!(types.get("$5"), Some(&Some(DataType::UInt8)));
assert_eq!(types.get("$6"), Some(&Some(DataType::UInt16)));
assert_eq!(types.get("$7"), Some(&Some(DataType::UInt32)));
let plan = parse_sql_to_plan("SELECT $1::INT8, $2::FLOAT8, $3::INT2, $4::INT8").await;
let types = DfLogicalPlanner::get_inferred_parameter_types(&plan).unwrap();
assert_eq!(types.get("$1"), Some(&Some(DataType::Int64)));
assert_eq!(types.get("$2"), Some(&Some(DataType::Float64)));
assert_eq!(types.get("$3"), Some(&Some(DataType::Int16)));
}
}

View File

@@ -398,8 +398,8 @@ pub(super) fn parameters_to_scalar_values(
return Err(invalid_parameter_error(
"unknown_parameter_type",
Some(format!(
"Cannot get parameter type information for parameter {}",
idx
"Cannot get type for parameter {}, try to provide a type using ${}::<type>",
idx, idx
)),
));
};