mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-24 17:00:37 +00:00
fix(mysql): infer LIMIT placeholders in prepare (#8149)
Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
@@ -494,6 +494,36 @@ impl DfLogicalPlanner {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn infer_limit_placeholder_types(
|
||||
plan: &LogicalPlan,
|
||||
placeholder_types: &mut HashMap<String, Option<DataType>>,
|
||||
) -> Result<()> {
|
||||
plan.apply(|node| {
|
||||
if let LogicalPlan::Limit(limit) = node {
|
||||
for expr in limit.skip.iter().chain(limit.fetch.iter()) {
|
||||
expr.apply(|e| {
|
||||
if let DfExpr::Placeholder(ph) = e {
|
||||
placeholder_types
|
||||
.entry(ph.id.clone())
|
||||
.and_modify(|existing| {
|
||||
if existing.is_none() {
|
||||
*existing = Some(DataType::Int64);
|
||||
}
|
||||
})
|
||||
.or_insert(Some(DataType::Int64));
|
||||
}
|
||||
|
||||
Ok(TreeNodeRecursion::Continue)
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(TreeNodeRecursion::Continue)
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Gets inferred parameter types from a logical plan.
|
||||
/// Returns a map where each parameter ID is mapped to:
|
||||
/// - Some(DataType) if the parameter type could be inferred
|
||||
@@ -501,7 +531,8 @@ impl DfLogicalPlanner {
|
||||
///
|
||||
/// This function first uses DataFusion's `get_parameter_types()` to infer types.
|
||||
/// If any parameters have `None` values (i.e., DataFusion couldn't infer their types),
|
||||
/// it falls back to using `extract_placeholder_cast_types()` to detect explicit casts.
|
||||
/// it falls back to using `extract_placeholder_cast_types()` to detect explicit casts
|
||||
/// and applies context-specific inference such as LIMIT/OFFSET placeholders.
|
||||
///
|
||||
/// This is because datafusion can only infer types for a limited cases.
|
||||
///
|
||||
@@ -510,19 +541,15 @@ impl DfLogicalPlanner {
|
||||
pub fn get_inferred_parameter_types(
|
||||
plan: &LogicalPlan,
|
||||
) -> Result<HashMap<String, Option<DataType>>> {
|
||||
let param_types = plan.get_parameter_types().context(PlanSqlSnafu)?;
|
||||
let mut param_types = plan.get_parameter_types().context(PlanSqlSnafu)?;
|
||||
|
||||
let has_none = param_types.values().any(|v| v.is_none());
|
||||
|
||||
if !has_none {
|
||||
Ok(param_types)
|
||||
} else {
|
||||
if has_none {
|
||||
let cast_types = Self::extract_placeholder_cast_types(plan)?;
|
||||
|
||||
let mut merged = param_types;
|
||||
|
||||
for (id, opt_type) in cast_types {
|
||||
merged
|
||||
param_types
|
||||
.entry(id)
|
||||
.and_modify(|existing| {
|
||||
if existing.is_none() {
|
||||
@@ -532,8 +559,10 @@ impl DfLogicalPlanner {
|
||||
.or_insert(opt_type);
|
||||
}
|
||||
|
||||
Ok(merged)
|
||||
Self::infer_limit_placeholder_types(plan, &mut param_types)?;
|
||||
}
|
||||
|
||||
Ok(param_types)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -793,6 +822,15 @@ mod tests {
|
||||
assert_eq!(type_3, &Some(DataType::Int32));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_inferred_parameter_types_limit_offset() {
|
||||
let plan = parse_sql_to_plan("SELECT id FROM test LIMIT $1 OFFSET $2").await;
|
||||
let types = DfLogicalPlanner::get_inferred_parameter_types(&plan).unwrap();
|
||||
|
||||
assert_eq!(types.get("$1"), Some(&Some(DataType::Int64)));
|
||||
assert_eq!(types.get("$2"), Some(&Some(DataType::Int64)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_plan_pql_applies_extension_rules() {
|
||||
for inner_agg in ["count", "sum", "avg", "min", "max", "stddev", "stdvar"] {
|
||||
|
||||
@@ -516,6 +516,38 @@ async fn test_query_prepared() -> Result<()> {
|
||||
_ => unreachable!(),
|
||||
}
|
||||
|
||||
// Regression test for #8142: LIMIT ? should work in prepared statements.
|
||||
// The LIMIT placeholder should be inferred as Int64 so the MySQL prepare
|
||||
// response advertises the correct parameter count.
|
||||
{
|
||||
let stmt = connection
|
||||
.prep("SELECT uint32s FROM all_datatypes LIMIT ?")
|
||||
.await
|
||||
.unwrap();
|
||||
let rows: Vec<Row> = connection
|
||||
.exec(stmt, vec![mysql_async::Value::Int(1)])
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(rows.len(), 1, "LIMIT 1 should return 1 row");
|
||||
}
|
||||
|
||||
// Also cover mixed placeholders: the WHERE placeholder is inferred from
|
||||
// the column type and the LIMIT placeholder is inferred from its context.
|
||||
{
|
||||
let stmt = connection
|
||||
.prep("SELECT uint32s FROM all_datatypes WHERE uint32s >= ? LIMIT ?")
|
||||
.await
|
||||
.unwrap();
|
||||
let rows: Vec<Row> = connection
|
||||
.exec(
|
||||
stmt,
|
||||
vec![mysql_async::Value::UInt(0), mysql_async::Value::UInt(1)],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(rows.len(), 1, "LIMIT 1 should return 1 row");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user