From aaba41ea0774cf339ae576604c1a95dcc4a39a5a Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Wed, 18 Mar 2026 00:04:32 +0800 Subject: [PATCH] fix: type cast from subquery --- src/query/src/datafusion.rs | 2 +- src/query/src/planner.rs | 44 ++++++++++++++++++++++++++++++++++--- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/src/query/src/datafusion.rs b/src/query/src/datafusion.rs index 405ec05cbd..6f20764e20 100644 --- a/src/query/src/datafusion.rs +++ b/src/query/src/datafusion.rs @@ -909,7 +909,7 @@ mod tests { ) ); assert_eq!( - "Limit: skip=0, fetch=20\n Aggregate: groupBy=[[]], aggr=[[sum(CAST(numbers.number AS UInt64))]]\n TableScan: numbers projection=[number]", + "Limit: skip=0, fetch=20\n Projection: sum(numbers.number)\n Aggregate: groupBy=[[]], aggr=[[sum(numbers.number)]]\n TableScan: numbers", format!("{}", logical_plan.display_indent()) ); } diff --git a/src/query/src/planner.rs b/src/query/src/planner.rs index 44c9bc3956..0357a210b2 100644 --- a/src/query/src/planner.rs +++ b/src/query/src/planner.rs @@ -28,6 +28,7 @@ use datafusion::execution::context::SessionState; use datafusion::sql::planner::PlannerContext; use datafusion_common::ToDFSchema; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::{ Analyze, Explain, ExplainFormat, Expr as DfExpr, LogicalPlan, LogicalPlanBuilder, PlanType, ToStringifiedPlan, col, @@ -418,15 +419,26 @@ impl DfLogicalPlanner { /// /// This function walks through all expressions in the logical plan, /// including subqueries, to identify placeholders and their cast types. - fn extract_placeholder_cast_types( + pub(crate) fn extract_placeholder_cast_types( plan: &LogicalPlan, ) -> Result>> { let mut placeholder_types = HashMap::new(); let mut casted_placeholders = HashSet::new(); + Self::extract_from_plan(plan, &mut placeholder_types, &mut casted_placeholders)?; + + Ok(placeholder_types) + } + + fn extract_from_plan( + plan: &LogicalPlan, + placeholder_types: &mut HashMap>, + casted_placeholders: &mut HashSet, + ) -> Result<()> { plan.apply(|node| { for expr in node.expressions() { let _ = expr.apply(|e| { + // Handle casted placeholders if let DfExpr::Cast(cast) = e && let DfExpr::Placeholder(ph) = &*cast.expr { @@ -434,6 +446,7 @@ impl DfLogicalPlanner { casted_placeholders.insert(ph.id.clone()); } + // Handle bare (non-casted) placeholders if let DfExpr::Placeholder(ph) = e && !casted_placeholders.contains(&ph.id) && !placeholder_types.contains_key(&ph.id) @@ -441,13 +454,26 @@ impl DfLogicalPlanner { placeholder_types.insert(ph.id.clone(), None); } + // Recurse into subquery plans embedded in expressions + match e { + DfExpr::Exists(Exists { subquery, .. }) + | DfExpr::InSubquery(InSubquery { subquery, .. }) + | DfExpr::ScalarSubquery(subquery) => { + Self::extract_from_plan( + &subquery.subquery, + placeholder_types, + casted_placeholders, + )?; + } + _ => {} + } + Ok(TreeNodeRecursion::Continue) }); } Ok(TreeNodeRecursion::Continue) })?; - - Ok(placeholder_types) + Ok(()) } /// Gets inferred parameter types from a logical plan. @@ -619,4 +645,16 @@ mod tests { assert_eq!(type_2, &Some(DataType::Utf8)); assert_eq!(type_3, &Some(DataType::Int32)); } + + #[tokio::test] + async fn test_get_inferred_parameter_types_subquery() { + let plan = parse_sql_to_plan( + r#"SELECT * FROM test WHERE id = (SELECT id FROM test CROSS JOIN (SELECT parse_ident($1::TEXT) AS parts) p LIMIT 1)"#, + ).await; + let types = DfLogicalPlanner::get_inferred_parameter_types(&plan).unwrap(); + + assert_eq!(types.len(), 1); + let type_1 = types.get("$1").unwrap(); + assert_eq!(type_1, &Some(DataType::Utf8)); + } }