diff --git a/src/query/src/planner.rs b/src/query/src/planner.rs index b95803a52b..c4f4af3c6a 100644 --- a/src/query/src/planner.rs +++ b/src/query/src/planner.rs @@ -494,6 +494,36 @@ impl DfLogicalPlanner { Ok(()) } + fn infer_limit_placeholder_types( + plan: &LogicalPlan, + placeholder_types: &mut HashMap>, + ) -> Result<()> { + plan.apply(|node| { + if let LogicalPlan::Limit(limit) = node { + for expr in limit.skip.iter().chain(limit.fetch.iter()) { + expr.apply(|e| { + if let DfExpr::Placeholder(ph) = e { + placeholder_types + .entry(ph.id.clone()) + .and_modify(|existing| { + if existing.is_none() { + *existing = Some(DataType::Int64); + } + }) + .or_insert(Some(DataType::Int64)); + } + + Ok(TreeNodeRecursion::Continue) + })?; + } + } + + Ok(TreeNodeRecursion::Continue) + })?; + + Ok(()) + } + /// Gets inferred parameter types from a logical plan. /// Returns a map where each parameter ID is mapped to: /// - Some(DataType) if the parameter type could be inferred @@ -501,7 +531,8 @@ impl DfLogicalPlanner { /// /// This function first uses DataFusion's `get_parameter_types()` to infer types. /// If any parameters have `None` values (i.e., DataFusion couldn't infer their types), - /// it falls back to using `extract_placeholder_cast_types()` to detect explicit casts. + /// it falls back to using `extract_placeholder_cast_types()` to detect explicit casts + /// and applies context-specific inference such as LIMIT/OFFSET placeholders. /// /// This is because datafusion can only infer types for a limited cases. /// @@ -510,19 +541,15 @@ impl DfLogicalPlanner { pub fn get_inferred_parameter_types( plan: &LogicalPlan, ) -> Result>> { - let param_types = plan.get_parameter_types().context(PlanSqlSnafu)?; + let mut param_types = plan.get_parameter_types().context(PlanSqlSnafu)?; let has_none = param_types.values().any(|v| v.is_none()); - if !has_none { - Ok(param_types) - } else { + if has_none { let cast_types = Self::extract_placeholder_cast_types(plan)?; - let mut merged = param_types; - for (id, opt_type) in cast_types { - merged + param_types .entry(id) .and_modify(|existing| { if existing.is_none() { @@ -532,8 +559,10 @@ impl DfLogicalPlanner { .or_insert(opt_type); } - Ok(merged) + Self::infer_limit_placeholder_types(plan, &mut param_types)?; } + + Ok(param_types) } } @@ -793,6 +822,15 @@ mod tests { assert_eq!(type_3, &Some(DataType::Int32)); } + #[tokio::test] + async fn test_get_inferred_parameter_types_limit_offset() { + let plan = parse_sql_to_plan("SELECT id FROM test LIMIT $1 OFFSET $2").await; + let types = DfLogicalPlanner::get_inferred_parameter_types(&plan).unwrap(); + + assert_eq!(types.get("$1"), Some(&Some(DataType::Int64))); + assert_eq!(types.get("$2"), Some(&Some(DataType::Int64))); + } + #[tokio::test] async fn test_plan_pql_applies_extension_rules() { for inner_agg in ["count", "sum", "avg", "min", "max", "stddev", "stdvar"] { diff --git a/src/servers/tests/mysql/mysql_server_test.rs b/src/servers/tests/mysql/mysql_server_test.rs index e0cb086dda..0694cc7746 100644 --- a/src/servers/tests/mysql/mysql_server_test.rs +++ b/src/servers/tests/mysql/mysql_server_test.rs @@ -516,6 +516,38 @@ async fn test_query_prepared() -> Result<()> { _ => unreachable!(), } + // Regression test for #8142: LIMIT ? should work in prepared statements. + // The LIMIT placeholder should be inferred as Int64 so the MySQL prepare + // response advertises the correct parameter count. + { + let stmt = connection + .prep("SELECT uint32s FROM all_datatypes LIMIT ?") + .await + .unwrap(); + let rows: Vec = connection + .exec(stmt, vec![mysql_async::Value::Int(1)]) + .await + .unwrap(); + assert_eq!(rows.len(), 1, "LIMIT 1 should return 1 row"); + } + + // Also cover mixed placeholders: the WHERE placeholder is inferred from + // the column type and the LIMIT placeholder is inferred from its context. + { + let stmt = connection + .prep("SELECT uint32s FROM all_datatypes WHERE uint32s >= ? LIMIT ?") + .await + .unwrap(); + let rows: Vec = connection + .exec( + stmt, + vec![mysql_async::Value::UInt(0), mysql_async::Value::UInt(1)], + ) + .await + .unwrap(); + assert_eq!(rows.len(), 1, "LIMIT 1 should return 1 row"); + } + Ok(()) }