fix(mysql): infer LIMIT placeholders in prepare (#8149)

Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
discord9
2026-05-22 07:40:10 +08:00
committed by GitHub
parent 4668dd43bd
commit f1ad472075
2 changed files with 79 additions and 9 deletions

View File

@@ -494,6 +494,36 @@ impl DfLogicalPlanner {
Ok(())
}
/// Infers `Int64` for placeholders used in `LIMIT` / `OFFSET` position.
///
/// Walks every node of `plan`, including plans nested inside expression
/// subqueries, and for each `Limit` node visits its `skip` and `fetch`
/// expressions. Each placeholder found there whose entry in
/// `placeholder_types` is missing or `None` is set to `Some(DataType::Int64)`.
/// Entries that already carry a type are left untouched.
fn infer_limit_placeholder_types(
    plan: &LogicalPlan,
    placeholder_types: &mut HashMap<String, Option<DataType>>,
) -> Result<()> {
    // `apply` only follows direct plan inputs; `apply_with_subqueries` also
    // visits plans held by expressions (e.g. `IN (SELECT ... LIMIT $1)`), so a
    // LIMIT placeholder inside such a subquery is not left untyped.
    plan.apply_with_subqueries(|node| {
        if let LogicalPlan::Limit(limit) = node {
            for expr in limit.skip.iter().chain(limit.fetch.iter()) {
                expr.apply(|e| {
                    if let DfExpr::Placeholder(ph) = e {
                        // Only fill in a type when none is known yet.
                        let slot = placeholder_types.entry(ph.id.clone()).or_insert(None);
                        if slot.is_none() {
                            *slot = Some(DataType::Int64);
                        }
                    }
                    Ok(TreeNodeRecursion::Continue)
                })?;
            }
        }
        Ok(TreeNodeRecursion::Continue)
    })?;
    Ok(())
}
/// Gets inferred parameter types from a logical plan.
/// Returns a map where each parameter ID is mapped to:
/// - Some(DataType) if the parameter type could be inferred
@@ -501,7 +531,8 @@ impl DfLogicalPlanner {
///
/// This function first uses DataFusion's `get_parameter_types()` to infer types.
/// If any parameters have `None` values (i.e., DataFusion couldn't infer their types),
/// it falls back to using `extract_placeholder_cast_types()` to detect explicit casts.
/// it falls back to using `extract_placeholder_cast_types()` to detect explicit casts
/// and applies context-specific inference such as LIMIT/OFFSET placeholders.
///
/// This is because DataFusion can only infer types in a limited set of cases.
///
@@ -510,19 +541,15 @@ impl DfLogicalPlanner {
pub fn get_inferred_parameter_types(
plan: &LogicalPlan,
) -> Result<HashMap<String, Option<DataType>>> {
let param_types = plan.get_parameter_types().context(PlanSqlSnafu)?;
let mut param_types = plan.get_parameter_types().context(PlanSqlSnafu)?;
let has_none = param_types.values().any(|v| v.is_none());
if !has_none {
Ok(param_types)
} else {
if has_none {
let cast_types = Self::extract_placeholder_cast_types(plan)?;
let mut merged = param_types;
for (id, opt_type) in cast_types {
merged
param_types
.entry(id)
.and_modify(|existing| {
if existing.is_none() {
@@ -532,8 +559,10 @@ impl DfLogicalPlanner {
.or_insert(opt_type);
}
Ok(merged)
Self::infer_limit_placeholder_types(plan, &mut param_types)?;
}
Ok(param_types)
}
}
@@ -793,6 +822,15 @@ mod tests {
assert_eq!(type_3, &Some(DataType::Int32));
}
#[tokio::test]
async fn test_get_inferred_parameter_types_limit_offset() {
    // Both the LIMIT and the OFFSET placeholder are expected to be reported
    // as `Int64`.
    let sql = "SELECT id FROM test LIMIT $1 OFFSET $2";
    let plan = parse_sql_to_plan(sql).await;
    let inferred = DfLogicalPlanner::get_inferred_parameter_types(&plan).unwrap();

    let expected = Some(DataType::Int64);
    for id in ["$1", "$2"] {
        assert_eq!(inferred.get(id), Some(&expected));
    }
}
#[tokio::test]
async fn test_plan_pql_applies_extension_rules() {
for inner_agg in ["count", "sum", "avg", "min", "max", "stddev", "stdvar"] {

View File

@@ -516,6 +516,38 @@ async fn test_query_prepared() -> Result<()> {
_ => unreachable!(),
}
// Regression test for #8142: LIMIT ? should work in prepared statements.
// The LIMIT placeholder should be inferred as Int64 so the MySQL prepare
// response advertises the correct parameter count.
{
let stmt = connection
.prep("SELECT uint32s FROM all_datatypes LIMIT ?")
.await
.unwrap();
let rows: Vec<Row> = connection
.exec(stmt, vec![mysql_async::Value::Int(1)])
.await
.unwrap();
assert_eq!(rows.len(), 1, "LIMIT 1 should return 1 row");
}
// Also cover mixed placeholders: the WHERE placeholder is inferred from
// the column type and the LIMIT placeholder is inferred from its context.
{
let stmt = connection
.prep("SELECT uint32s FROM all_datatypes WHERE uint32s >= ? LIMIT ?")
.await
.unwrap();
let rows: Vec<Row> = connection
.exec(
stmt,
vec![mysql_async::Value::UInt(0), mysql_async::Value::UInt(1)],
)
.await
.unwrap();
assert_eq!(rows.len(), 1, "LIMIT 1 should return 1 row");
}
Ok(())
}