fix: type cast from subquery

This commit is contained in:
Ning Sun
2026-03-18 00:04:32 +08:00
parent 07ae4369f9
commit aaba41ea07
2 changed files with 42 additions and 4 deletions

View File

@@ -909,7 +909,7 @@ mod tests {
)
);
assert_eq!(
"Limit: skip=0, fetch=20\n Aggregate: groupBy=[[]], aggr=[[sum(CAST(numbers.number AS UInt64))]]\n TableScan: numbers projection=[number]",
"Limit: skip=0, fetch=20\n Projection: sum(numbers.number)\n Aggregate: groupBy=[[]], aggr=[[sum(numbers.number)]]\n TableScan: numbers",
format!("{}", logical_plan.display_indent())
);
}

View File

@@ -28,6 +28,7 @@ use datafusion::execution::context::SessionState;
use datafusion::sql::planner::PlannerContext;
use datafusion_common::ToDFSchema;
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_expr::expr::{Exists, InSubquery};
use datafusion_expr::{
Analyze, Explain, ExplainFormat, Expr as DfExpr, LogicalPlan, LogicalPlanBuilder, PlanType,
ToStringifiedPlan, col,
@@ -418,15 +419,26 @@ impl DfLogicalPlanner {
///
/// This function walks through all expressions in the logical plan,
/// including subqueries, to identify placeholders and their cast types.
fn extract_placeholder_cast_types(
pub(crate) fn extract_placeholder_cast_types(
plan: &LogicalPlan,
) -> Result<HashMap<String, Option<DataType>>> {
let mut placeholder_types = HashMap::new();
let mut casted_placeholders = HashSet::new();
Self::extract_from_plan(plan, &mut placeholder_types, &mut casted_placeholders)?;
Ok(placeholder_types)
}
fn extract_from_plan(
plan: &LogicalPlan,
placeholder_types: &mut HashMap<String, Option<DataType>>,
casted_placeholders: &mut HashSet<String>,
) -> Result<()> {
plan.apply(|node| {
for expr in node.expressions() {
let _ = expr.apply(|e| {
// Handle casted placeholders
if let DfExpr::Cast(cast) = e
&& let DfExpr::Placeholder(ph) = &*cast.expr
{
@@ -434,6 +446,7 @@ impl DfLogicalPlanner {
casted_placeholders.insert(ph.id.clone());
}
// Handle bare (non-casted) placeholders
if let DfExpr::Placeholder(ph) = e
&& !casted_placeholders.contains(&ph.id)
&& !placeholder_types.contains_key(&ph.id)
@@ -441,13 +454,26 @@ impl DfLogicalPlanner {
placeholder_types.insert(ph.id.clone(), None);
}
// Recurse into subquery plans embedded in expressions
match e {
DfExpr::Exists(Exists { subquery, .. })
| DfExpr::InSubquery(InSubquery { subquery, .. })
| DfExpr::ScalarSubquery(subquery) => {
Self::extract_from_plan(
&subquery.subquery,
placeholder_types,
casted_placeholders,
)?;
}
_ => {}
}
Ok(TreeNodeRecursion::Continue)
});
}
Ok(TreeNodeRecursion::Continue)
})?;
Ok(placeholder_types)
Ok(())
}
/// Gets inferred parameter types from a logical plan.
@@ -619,4 +645,16 @@ mod tests {
assert_eq!(type_2, &Some(DataType::Utf8));
assert_eq!(type_3, &Some(DataType::Int32));
}
#[tokio::test]
async fn test_get_inferred_parameter_types_subquery() {
let plan = parse_sql_to_plan(
r#"SELECT * FROM test WHERE id = (SELECT id FROM test CROSS JOIN (SELECT parse_ident($1::TEXT) AS parts) p LIMIT 1)"#,
).await;
let types = DfLogicalPlanner::get_inferred_parameter_types(&plan).unwrap();
assert_eq!(types.len(), 1);
let type_1 = types.get("$1").unwrap();
assert_eq!(type_1, &Some(DataType::Utf8));
}
}