mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-26 01:40:36 +00:00
fix(mysql): keep unknown prepare placeholders (#8150)
* fix(mysql): keep unknown prepare placeholders Signed-off-by: discord9 <discord9@163.com> * fix(mysql): use span-based placeholder fallback Signed-off-by: discord9 <discord9@163.com> * fix(mysql): visit placeholders in all statements Signed-off-by: discord9 <discord9@163.com> * refactor(mysql): remove placeholder transform wrapper Signed-off-by: discord9 <discord9@163.com> --------- Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -12705,6 +12705,7 @@ dependencies = [
|
||||
"metric-engine",
|
||||
"mime_guess",
|
||||
"mysql_async",
|
||||
"mysql_common 0.34.1",
|
||||
"notify",
|
||||
"object-pool",
|
||||
"once_cell",
|
||||
|
||||
@@ -82,6 +82,7 @@ log-query.workspace = true
|
||||
loki-proto.workspace = true
|
||||
metric-engine.workspace = true
|
||||
mime_guess = "2.0"
|
||||
mysql_common = "0.34"
|
||||
notify.workspace = true
|
||||
object-pool = "0.5"
|
||||
once_cell.workspace = true
|
||||
|
||||
@@ -30,6 +30,7 @@ use datafusion_expr::LogicalPlan;
|
||||
use datatypes::prelude::ConcreteDataType;
|
||||
use datatypes::schema::Schema;
|
||||
use itertools::Itertools;
|
||||
use mysql_common::Value as MysqlValue;
|
||||
use opensrv_mysql::{
|
||||
AsyncMysqlShim, Column, ErrorKind, InitWriter, ParamParser, ParamValue, QueryResultWriter,
|
||||
StatementMetaWriter, ValueInner,
|
||||
@@ -51,9 +52,7 @@ use crate::error::{
|
||||
self, DataFrameSnafu, InferParameterTypesSnafu, InvalidPrepareStatementSnafu, Result,
|
||||
};
|
||||
use crate::metrics::METRIC_AUTH_FAILURE;
|
||||
use crate::mysql::helper::{
|
||||
self, format_placeholder, replace_placeholders, transform_placeholders,
|
||||
};
|
||||
use crate::mysql::helper::{self, format_placeholder, transform_placeholders_with_count};
|
||||
use crate::mysql::writer;
|
||||
use crate::mysql::writer::{create_mysql_column, handle_err};
|
||||
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
|
||||
@@ -192,28 +191,31 @@ impl MysqlInstanceShim {
|
||||
return Ok((vec![], vec![]));
|
||||
}
|
||||
|
||||
let (query, param_num) = replace_placeholders(raw_query);
|
||||
|
||||
let statement = validate_query(raw_query).await?;
|
||||
|
||||
// We have to transform the placeholder, because DataFusion only parses placeholders
|
||||
// in the form of "$i", it can't process "?" right now.
|
||||
let statement = transform_placeholders(statement);
|
||||
let (statement, placeholder_count) = transform_placeholders_with_count(statement);
|
||||
let param_num = placeholder_count + 1;
|
||||
|
||||
let describe_result = self
|
||||
.do_describe(statement.clone(), query_ctx.clone())
|
||||
.await?;
|
||||
let plan = describe_result.map(|DescribeResult { logical_plan }| logical_plan);
|
||||
|
||||
let params = if let Some(plan) = &plan {
|
||||
let (params, can_cache_as_plan) = if let Some(plan) = &plan {
|
||||
let param_types = DfLogicalPlanner::get_inferred_parameter_types(plan)
|
||||
.context(InferParameterTypesSnafu)?
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
|
||||
.collect();
|
||||
prepared_params(¶m_types)?
|
||||
|
||||
(
|
||||
prepared_params(¶m_types, param_num)?,
|
||||
all_params_have_types(¶m_types, param_num),
|
||||
)
|
||||
} else {
|
||||
dummy_params(param_num)?
|
||||
(dummy_params(param_num)?, false)
|
||||
};
|
||||
|
||||
let columns =
|
||||
@@ -239,17 +241,20 @@ impl MysqlInstanceShim {
|
||||
.unwrap_or_default();
|
||||
|
||||
match plan {
|
||||
Some(plan) if params.len() == param_num - 1 => {
|
||||
Some(plan) if can_cache_as_plan => {
|
||||
self.save_plan(SqlPlan::Plan(plan, statement), stmt_key)
|
||||
.inspect_err(|e| {
|
||||
error!(e; "Failed to save prepared statement");
|
||||
})?;
|
||||
}
|
||||
_ => {
|
||||
self.save_plan(SqlPlan::Statement(statement, query), stmt_key)
|
||||
.inspect_err(|e| {
|
||||
error!(e; "Failed to save prepared statement");
|
||||
})?;
|
||||
self.save_plan(
|
||||
SqlPlan::Statement(statement, raw_query.to_string()),
|
||||
stmt_key,
|
||||
)
|
||||
.inspect_err(|e| {
|
||||
error!(e; "Failed to save prepared statement");
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -312,7 +317,7 @@ impl MysqlInstanceShim {
|
||||
self.do_query(&query, query_ctx.clone()).await
|
||||
}
|
||||
}
|
||||
SqlPlan::Statement(_stmt, query) => {
|
||||
SqlPlan::Statement(stmt, query) => {
|
||||
let param_strs = match params {
|
||||
Params::ProtocolParams(params) => {
|
||||
params.iter().map(convert_param_value_to_string).collect()
|
||||
@@ -323,7 +328,7 @@ impl MysqlInstanceShim {
|
||||
"do_execute Replacing with Params: {:?}, Original Query: {}",
|
||||
param_strs, query
|
||||
);
|
||||
let query = replace_params(param_strs, query);
|
||||
let query = replace_params(param_strs, stmt, query)?;
|
||||
debug!("Mysql execute replaced query: {}", query);
|
||||
self.do_query(&query, query_ctx.clone()).await
|
||||
}
|
||||
@@ -662,19 +667,133 @@ fn convert_param_value_to_string(param: &ParamValue) -> String {
|
||||
ValueInner::UInt(u) => u.to_string(),
|
||||
ValueInner::Double(u) => u.to_string(),
|
||||
ValueInner::NULL => "NULL".to_string(),
|
||||
ValueInner::Bytes(b) => format!("'{}'", &String::from_utf8_lossy(b)),
|
||||
// MySQL prepared fallback emits SQL text. Delegate bytes/string literal
|
||||
// escaping to mysql_common. `false` means normal MySQL backslash escapes;
|
||||
// if NO_BACKSLASH_ESCAPES is supported in this path later, wire the
|
||||
// session SQL mode here.
|
||||
ValueInner::Bytes(b) => MysqlValue::Bytes(b.to_vec()).as_sql(false),
|
||||
ValueInner::Date(_) => format!("'{}'", NaiveDate::from(param.value)),
|
||||
ValueInner::Datetime(_) => format!("'{}'", NaiveDateTime::from(param.value)),
|
||||
ValueInner::Time(_) => format_duration(Duration::from(param.value)),
|
||||
}
|
||||
}
|
||||
|
||||
fn replace_params(params: Vec<String>, query: String) -> String {
|
||||
let mut query = query;
|
||||
for (index, param) in (1..).zip(params) {
|
||||
query = query.replace(&format_placeholder(index), ¶m);
|
||||
fn replace_params(params: Vec<String>, stmt: Statement, mut query: String) -> Result<String> {
|
||||
let spans = helper::placeholder_spans(stmt);
|
||||
ensure!(
|
||||
spans.len() == params.len(),
|
||||
error::InternalSnafu {
|
||||
err_msg: format!(
|
||||
"Prepared statement expected {} parameters but got {}",
|
||||
spans.len(),
|
||||
params.len()
|
||||
)
|
||||
}
|
||||
);
|
||||
|
||||
let mut replacements = Vec::with_capacity(spans.len());
|
||||
for span in spans {
|
||||
let start = location_to_byte_offset(&query, span.start_line, span.start_column)
|
||||
.ok_or_else(|| {
|
||||
error::InternalSnafu {
|
||||
err_msg: format!(
|
||||
"Invalid placeholder start span: line {}, column {}",
|
||||
span.start_line, span.start_column
|
||||
),
|
||||
}
|
||||
.build()
|
||||
})?;
|
||||
let end =
|
||||
location_to_byte_offset(&query, span.end_line, span.end_column).ok_or_else(|| {
|
||||
error::InternalSnafu {
|
||||
err_msg: format!(
|
||||
"Invalid placeholder end span: line {}, column {}",
|
||||
span.end_line, span.end_column
|
||||
),
|
||||
}
|
||||
.build()
|
||||
})?;
|
||||
let param = span
|
||||
.index
|
||||
.checked_sub(1)
|
||||
.and_then(|idx| params.get(idx))
|
||||
.ok_or_else(|| {
|
||||
error::InternalSnafu {
|
||||
err_msg: format!("Missing prepared statement parameter {}", span.index),
|
||||
}
|
||||
.build()
|
||||
})?;
|
||||
|
||||
ensure!(
|
||||
start < end && end <= query.len(),
|
||||
error::InternalSnafu {
|
||||
err_msg: format!(
|
||||
"Invalid placeholder byte span: {}..{} for query length {}",
|
||||
start,
|
||||
end,
|
||||
query.len()
|
||||
)
|
||||
}
|
||||
);
|
||||
ensure!(
|
||||
query.get(start..end) == Some("?"),
|
||||
error::InternalSnafu {
|
||||
err_msg: format!(
|
||||
"Prepared statement placeholder span maps to {:?} instead of '?'",
|
||||
query.get(start..end)
|
||||
)
|
||||
}
|
||||
);
|
||||
|
||||
replacements.push((start, end, param.clone()));
|
||||
}
|
||||
query
|
||||
|
||||
replacements.sort_unstable_by_key(|(start, _, _)| *start);
|
||||
for windows in replacements.windows(2) {
|
||||
ensure!(
|
||||
windows[0].1 <= windows[1].0,
|
||||
error::InternalSnafu {
|
||||
err_msg: "Overlapping placeholder spans in prepared statement".to_string()
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
// All spans are computed against the original query. Apply replacements
|
||||
// from right to left so changing one parameter's string length never shifts
|
||||
// the byte offsets of placeholders that have not been replaced yet.
|
||||
for (start, end, param) in replacements.into_iter().rev() {
|
||||
query.replace_range(start..end, ¶m);
|
||||
}
|
||||
|
||||
Ok(query)
|
||||
}
|
||||
|
||||
fn location_to_byte_offset(query: &str, line: u64, column: u64) -> Option<usize> {
|
||||
// sqlparser spans are 1-based line/column locations, and columns advance by
|
||||
// Rust `char`s rather than bytes. Convert them to byte offsets before using
|
||||
// `String::replace_range` on the original SQL text.
|
||||
if line == 0 || column == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut current_line = 1;
|
||||
let mut current_column = 1;
|
||||
for (index, ch) in query.char_indices() {
|
||||
if current_line == line && current_column == column {
|
||||
return Some(index);
|
||||
}
|
||||
|
||||
if ch == '\n' {
|
||||
current_line += 1;
|
||||
current_column = 1;
|
||||
} else {
|
||||
current_column += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// The exclusive end location of a trailing placeholder points just past
|
||||
// the last character, for example the end span of `SELECT ?`.
|
||||
(current_line == line && current_column == column).then_some(query.len())
|
||||
}
|
||||
|
||||
fn format_duration(duration: Duration) -> String {
|
||||
@@ -778,20 +897,33 @@ fn dummy_params(index: usize) -> Result<Vec<Column>> {
|
||||
}
|
||||
|
||||
/// Parameters that the client must provide when executing the prepared statement.
|
||||
fn prepared_params(param_types: &HashMap<String, Option<ConcreteDataType>>) -> Result<Vec<Column>> {
|
||||
let mut params = Vec::with_capacity(param_types.len());
|
||||
fn prepared_params(
|
||||
param_types: &HashMap<String, Option<ConcreteDataType>>,
|
||||
param_num: usize,
|
||||
) -> Result<Vec<Column>> {
|
||||
let mut params = Vec::with_capacity(param_num - 1);
|
||||
|
||||
// Placeholder index starts from 1
|
||||
for index in 1..=param_types.len() {
|
||||
if let Some(Some(t)) = param_types.get(&format_placeholder(index)) {
|
||||
let column = create_mysql_column(t, "")?;
|
||||
params.push(column);
|
||||
}
|
||||
for i in 1..param_num {
|
||||
let column = if let Some(Some(t)) = param_types.get(&format_placeholder(i)) {
|
||||
create_mysql_column(t, "")?
|
||||
} else {
|
||||
create_mysql_column(&ConcreteDataType::null_datatype(), "")?
|
||||
};
|
||||
params.push(column);
|
||||
}
|
||||
|
||||
Ok(params)
|
||||
}
|
||||
|
||||
fn all_params_have_types(
|
||||
param_types: &HashMap<String, Option<ConcreteDataType>>,
|
||||
param_num: usize,
|
||||
) -> bool {
|
||||
param_types.len() == param_num - 1
|
||||
&& (1..param_num).all(|i| matches!(param_types.get(&format_placeholder(i)), Some(Some(_))))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
@@ -852,6 +984,122 @@ mod tests {
|
||||
)
|
||||
}
|
||||
|
||||
fn statement_with_transformed_placeholders(query: &str) -> Statement {
|
||||
let mut statements =
|
||||
ParserContext::create_with_dialect(query, &MySqlDialect {}, ParseOptions::default())
|
||||
.unwrap();
|
||||
assert_eq!(statements.len(), 1);
|
||||
transform_placeholders_with_count(statements.remove(0)).0
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_prepared_params_keep_unknown_type_placeholders() {
|
||||
let mut param_types = HashMap::new();
|
||||
param_types.insert(format_placeholder(1), None);
|
||||
param_types.insert(
|
||||
format_placeholder(2),
|
||||
Some(ConcreteDataType::int32_datatype()),
|
||||
);
|
||||
|
||||
let params = prepared_params(¶m_types, 3).unwrap();
|
||||
assert_eq!(params.len(), 2);
|
||||
assert!(!all_params_have_types(¶m_types, 3));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_replace_params_by_placeholder_span() {
|
||||
let query = "SELECT ?, ?".to_string();
|
||||
let stmt = statement_with_transformed_placeholders(&query);
|
||||
let params = vec!["'$2 should stay'".to_string(), "'value'".to_string()];
|
||||
|
||||
assert_eq!(
|
||||
"SELECT '$2 should stay', 'value'",
|
||||
replace_params(params, stmt, query).unwrap()
|
||||
);
|
||||
|
||||
let query = "SELECT ?, ?, ?".to_string();
|
||||
let stmt = statement_with_transformed_placeholders(&query);
|
||||
let params = vec![
|
||||
"'much longer than a placeholder'".to_string(),
|
||||
"0".to_string(),
|
||||
"'also much longer than a placeholder'".to_string(),
|
||||
];
|
||||
|
||||
assert_eq!(
|
||||
"SELECT 'much longer than a placeholder', 0, 'also much longer than a placeholder'",
|
||||
replace_params(params, stmt, query).unwrap()
|
||||
);
|
||||
|
||||
let query = "SELECT '$1', \"$2\", `$3`, ?, ?".to_string();
|
||||
let stmt = statement_with_transformed_placeholders(&query);
|
||||
let params = vec!["'1'".to_string(), "'2'".to_string()];
|
||||
|
||||
assert_eq!(
|
||||
"SELECT '$1', \"$2\", `$3`, '1', '2'",
|
||||
replace_params(params, stmt, query).unwrap()
|
||||
);
|
||||
|
||||
let query = "SELECT /* ? */ ? -- ?\n, ?".to_string();
|
||||
let stmt = statement_with_transformed_placeholders(&query);
|
||||
let params = vec!["'first'".to_string(), "'second'".to_string()];
|
||||
|
||||
assert_eq!(
|
||||
"SELECT /* ? */ 'first' -- ?\n, 'second'",
|
||||
replace_params(params, stmt, query).unwrap()
|
||||
);
|
||||
|
||||
let query = "SELECT '中文', ?".to_string();
|
||||
let stmt = statement_with_transformed_placeholders(&query);
|
||||
let params = vec!["'value'".to_string()];
|
||||
|
||||
assert_eq!(
|
||||
"SELECT '中文', 'value'",
|
||||
replace_params(params, stmt, query).unwrap()
|
||||
);
|
||||
|
||||
let query = "SELECT '中文',\n ?".to_string();
|
||||
let stmt = statement_with_transformed_placeholders(&query);
|
||||
let params = vec!["'value'".to_string()];
|
||||
|
||||
assert_eq!(
|
||||
"SELECT '中文',\n 'value'",
|
||||
replace_params(params, stmt, query).unwrap()
|
||||
);
|
||||
|
||||
let query = "SELECT 'x'\r\n, ?".to_string();
|
||||
let stmt = statement_with_transformed_placeholders(&query);
|
||||
let params = vec!["'crlf'".to_string()];
|
||||
|
||||
assert_eq!(
|
||||
"SELECT 'x'\r\n, 'crlf'",
|
||||
replace_params(params, stmt, query).unwrap()
|
||||
);
|
||||
|
||||
let query = "SELECT\t?".to_string();
|
||||
let stmt = statement_with_transformed_placeholders(&query);
|
||||
let params = vec!["NULL".to_string()];
|
||||
|
||||
assert_eq!("SELECT\tNULL", replace_params(params, stmt, query).unwrap());
|
||||
|
||||
let query = "SELECT CAST(? AS INT64), ? + (SELECT ?)".to_string();
|
||||
let stmt = statement_with_transformed_placeholders(&query);
|
||||
let params = vec!["1".to_string(), "2".to_string(), "3".to_string()];
|
||||
|
||||
assert_eq!(
|
||||
"SELECT CAST(1 AS INT64), 2 + (SELECT 3)",
|
||||
replace_params(params, stmt, query).unwrap()
|
||||
);
|
||||
|
||||
let query = "SET time_zone = ?".to_string();
|
||||
let stmt = statement_with_transformed_placeholders(&query);
|
||||
let params = vec!["'UTC'".to_string()];
|
||||
|
||||
assert_eq!(
|
||||
"SET time_zone = 'UTC'",
|
||||
replace_params(params, stmt, query).unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_prepare_federated_query() {
|
||||
let mut shim = create_shim();
|
||||
|
||||
@@ -23,6 +23,7 @@ use datatypes::prelude::ConcreteDataType;
|
||||
use datatypes::schema::ColumnSchema;
|
||||
use datatypes::types::TimestampType;
|
||||
use datatypes::value::{self, Value};
|
||||
#[cfg(test)]
|
||||
use itertools::Itertools;
|
||||
use opensrv_mysql::{ParamValue, ValueInner, to_naive_datetime};
|
||||
use snafu::ResultExt;
|
||||
@@ -31,6 +32,17 @@ use sql::statements::statement::Statement;
|
||||
|
||||
use crate::error::{self, Result};
|
||||
|
||||
/// Location of a prepared-statement placeholder in the original SQL text.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) struct PlaceholderSpan {
|
||||
/// 1-based placeholder index.
|
||||
pub(crate) index: usize,
|
||||
pub(crate) start_line: u64,
|
||||
pub(crate) start_column: u64,
|
||||
pub(crate) end_line: u64,
|
||||
pub(crate) end_column: u64,
|
||||
}
|
||||
|
||||
/// Returns the placeholder string "$i".
|
||||
pub fn format_placeholder(i: usize) -> String {
|
||||
format!("${}", i)
|
||||
@@ -38,6 +50,7 @@ pub fn format_placeholder(i: usize) -> String {
|
||||
|
||||
/// Replace all the "?" placeholder into "$i" in SQL,
|
||||
/// returns the new SQL and the last placeholder index.
|
||||
#[cfg(test)]
|
||||
pub fn replace_placeholders(query: &str) -> (String, usize) {
|
||||
let query_parts = query.split('?').collect::<Vec<_>>();
|
||||
let parts_len = query_parts.len();
|
||||
@@ -58,27 +71,51 @@ pub fn replace_placeholders(query: &str) -> (String, usize) {
|
||||
(query, index + 1)
|
||||
}
|
||||
|
||||
/// Transform all the "?" placeholder into "$i".
|
||||
/// Only works for Insert,Query and Delete statements.
|
||||
pub fn transform_placeholders(stmt: Statement) -> Statement {
|
||||
match stmt {
|
||||
Statement::Query(mut query) => {
|
||||
visit_placeholders(&mut query.inner);
|
||||
Statement::Query(query)
|
||||
}
|
||||
Statement::Insert(mut insert) => {
|
||||
visit_placeholders(&mut insert.inner);
|
||||
Statement::Insert(insert)
|
||||
}
|
||||
Statement::Delete(mut delete) => {
|
||||
visit_placeholders(&mut delete.inner);
|
||||
Statement::Delete(delete)
|
||||
}
|
||||
stmt => stmt,
|
||||
}
|
||||
/// Transform all the "?" placeholders into "$i" and return the number of
|
||||
/// transformed placeholders.
|
||||
pub fn transform_placeholders_with_count(mut stmt: Statement) -> (Statement, usize) {
|
||||
let count = visit_placeholders(&mut stmt);
|
||||
(stmt, count)
|
||||
}
|
||||
|
||||
fn visit_placeholders<V>(v: &mut V)
|
||||
/// Collect spans of "$i" placeholders in a statement.
|
||||
pub(crate) fn placeholder_spans(mut stmt: Statement) -> Vec<PlaceholderSpan> {
|
||||
let mut spans = Vec::new();
|
||||
collect_placeholder_spans(&mut stmt, &mut spans);
|
||||
spans
|
||||
}
|
||||
|
||||
fn collect_placeholder_spans<V>(v: &mut V, spans: &mut Vec<PlaceholderSpan>)
|
||||
where
|
||||
V: VisitMut,
|
||||
{
|
||||
let _ = visit_expressions_mut(v, |expr| {
|
||||
if let Expr::Value(ValueWithSpan {
|
||||
value: ValueExpr::Placeholder(s),
|
||||
span,
|
||||
}) = expr
|
||||
&& let Some(index) = placeholder_index(s)
|
||||
{
|
||||
spans.push(PlaceholderSpan {
|
||||
index,
|
||||
start_line: span.start.line,
|
||||
start_column: span.start.column,
|
||||
end_line: span.end.line,
|
||||
end_column: span.end.column,
|
||||
});
|
||||
}
|
||||
ControlFlow::<()>::Continue(())
|
||||
});
|
||||
}
|
||||
|
||||
fn placeholder_index(s: &str) -> Option<usize> {
|
||||
s.strip_prefix('$')?
|
||||
.parse::<usize>()
|
||||
.ok()
|
||||
.filter(|i| *i > 0)
|
||||
}
|
||||
|
||||
fn visit_placeholders<V>(v: &mut V) -> usize
|
||||
where
|
||||
V: VisitMut,
|
||||
{
|
||||
@@ -88,12 +125,14 @@ where
|
||||
value: ValueExpr::Placeholder(s),
|
||||
..
|
||||
}) = expr
|
||||
&& s == "?"
|
||||
{
|
||||
*s = format_placeholder(index);
|
||||
index += 1;
|
||||
}
|
||||
ControlFlow::<()>::Continue(())
|
||||
});
|
||||
index - 1
|
||||
}
|
||||
|
||||
/// Convert [`ParamValue`] into [`Value`] according to param type.
|
||||
@@ -340,33 +379,52 @@ mod tests {
|
||||
#[test]
|
||||
fn test_transform_placeholders() {
|
||||
let insert = parse_sql("insert into demo values(?,?,?)");
|
||||
let Statement::Insert(insert) = transform_placeholders(insert) else {
|
||||
let (stmt, count) = transform_placeholders_with_count(insert);
|
||||
let Statement::Insert(insert) = stmt else {
|
||||
unreachable!()
|
||||
};
|
||||
assert_eq!(
|
||||
"INSERT INTO demo VALUES ($1, $2, $3)",
|
||||
insert.inner.to_string()
|
||||
);
|
||||
assert_eq!(3, count);
|
||||
|
||||
let delete = parse_sql("delete from demo where host=? and idc=?");
|
||||
let Statement::Delete(delete) = transform_placeholders(delete) else {
|
||||
let (stmt, count) = transform_placeholders_with_count(delete);
|
||||
let Statement::Delete(delete) = stmt else {
|
||||
unreachable!()
|
||||
};
|
||||
assert_eq!(
|
||||
"DELETE FROM demo WHERE host = $1 AND idc = $2",
|
||||
delete.inner.to_string()
|
||||
);
|
||||
assert_eq!(2, count);
|
||||
|
||||
let select = parse_sql(
|
||||
"select * from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?",
|
||||
);
|
||||
let Statement::Query(select) = transform_placeholders(select) else {
|
||||
let (stmt, count) = transform_placeholders_with_count(select);
|
||||
let Statement::Query(select) = stmt else {
|
||||
unreachable!()
|
||||
};
|
||||
assert_eq!(
|
||||
"SELECT * FROM demo WHERE host = $1 AND idc IN (SELECT idc FROM idcs WHERE name = $2) AND cpu > $3",
|
||||
select.inner.to_string()
|
||||
);
|
||||
assert_eq!(3, count);
|
||||
|
||||
let select = parse_sql("select '?', ?");
|
||||
let (stmt, count) = transform_placeholders_with_count(select);
|
||||
let Statement::Query(select) = stmt else {
|
||||
unreachable!()
|
||||
};
|
||||
assert_eq!("SELECT '?', $1", select.inner.to_string());
|
||||
assert_eq!(1, count);
|
||||
|
||||
let set = parse_sql("set time_zone = ?");
|
||||
let (stmt, count) = transform_placeholders_with_count(set);
|
||||
assert_eq!("SET time_zone = $1", stmt.to_string());
|
||||
assert_eq!(1, count);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -548,6 +548,140 @@ async fn test_query_prepared() -> Result<()> {
|
||||
assert_eq!(rows.len(), 1, "LIMIT 1 should return 1 row");
|
||||
}
|
||||
|
||||
// Untyped placeholders should still be advertised in the MySQL prepare
|
||||
// response. This used to fail on the client side because the server
|
||||
// reported 0 parameters for `SELECT ?`.
|
||||
{
|
||||
let stmt = connection.prep("SELECT ?").await.unwrap();
|
||||
assert_eq!(stmt.num_params(), 1);
|
||||
|
||||
let row: Option<String> = connection
|
||||
.exec_first(stmt, vec![mysql_async::Value::Bytes(b"can't".to_vec())])
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(row, Some("can't".to_string()));
|
||||
|
||||
let stmt = connection.prep("SELECT ?").await.unwrap();
|
||||
let row: Option<String> = connection
|
||||
.exec_first(stmt, vec![mysql_async::Value::Bytes(b"a\\'b".to_vec())])
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(row, Some("a\\'b".to_string()));
|
||||
|
||||
let stmt = connection.prep("SELECT ?").await.unwrap();
|
||||
let row: Option<Vec<u8>> = connection
|
||||
.exec_first(stmt, vec![mysql_async::Value::Bytes(vec![0xFF, 0xFE])])
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(row, Some(vec![0xFF, 0xFE]));
|
||||
|
||||
let stmt = connection.prep("SELECT ?").await.unwrap();
|
||||
let row: Option<Option<String>> = connection
|
||||
.exec_first(stmt, vec![mysql_async::Value::NULL])
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(row, Some(None));
|
||||
}
|
||||
|
||||
// Values inserted into the SQL text must not be processed again while
|
||||
// replacing later placeholders.
|
||||
{
|
||||
let stmt = connection.prep("SELECT ?, ?").await.unwrap();
|
||||
assert_eq!(stmt.num_params(), 2);
|
||||
|
||||
let row: Option<(String, String)> = connection
|
||||
.exec_first(
|
||||
stmt,
|
||||
vec![
|
||||
mysql_async::Value::Bytes(b"keep $2".to_vec()),
|
||||
mysql_async::Value::Bytes(b"second".to_vec()),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(row, Some(("keep $2".to_string(), "second".to_string())));
|
||||
}
|
||||
|
||||
// Non-placeholder question marks inside string literals must not affect
|
||||
// the advertised prepare parameter count.
|
||||
{
|
||||
let stmt = connection.prep("SELECT '?', ?").await.unwrap();
|
||||
assert_eq!(stmt.num_params(), 1);
|
||||
|
||||
let row: Option<(String, String)> = connection
|
||||
.exec_first(stmt, vec![mysql_async::Value::Bytes(b"actual".to_vec())])
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(row, Some(("?".to_string(), "actual".to_string())));
|
||||
|
||||
let stmt = connection.prep("SELECT '$1', ?").await.unwrap();
|
||||
assert_eq!(stmt.num_params(), 1);
|
||||
|
||||
let row: Option<(String, String)> = connection
|
||||
.exec_first(stmt, vec![mysql_async::Value::Bytes(b"actual".to_vec())])
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(row, Some(("$1".to_string(), "actual".to_string())));
|
||||
|
||||
let stmt = connection.prep("SELECT /* ? */ ? -- ?\n").await.unwrap();
|
||||
assert_eq!(stmt.num_params(), 1);
|
||||
|
||||
let row: Option<String> = connection
|
||||
.exec_first(stmt, vec![mysql_async::Value::Bytes(b"commented".to_vec())])
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(row, Some("commented".to_string()));
|
||||
|
||||
let stmt = connection.prep("SELECT '中文', ?").await.unwrap();
|
||||
assert_eq!(stmt.num_params(), 1);
|
||||
|
||||
let row: Option<(String, String)> = connection
|
||||
.exec_first(stmt, vec![mysql_async::Value::Bytes(b"actual".to_vec())])
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(row, Some(("中文".to_string(), "actual".to_string())));
|
||||
}
|
||||
|
||||
// Also cover mixed known and unknown placeholders. The projection
|
||||
// placeholder is untyped, while the WHERE placeholder is inferred from the
|
||||
// column type. The prepare response must advertise both parameters.
|
||||
{
|
||||
let stmt = connection
|
||||
.prep("SELECT ?, uint32s FROM all_datatypes WHERE uint32s >= ?")
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(stmt.num_params(), 2);
|
||||
|
||||
let rows: Vec<Row> = connection
|
||||
.exec(
|
||||
stmt,
|
||||
vec![
|
||||
mysql_async::Value::Bytes(b"unknown".to_vec()),
|
||||
mysql_async::Value::UInt(0),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!rows.is_empty());
|
||||
}
|
||||
|
||||
// LIMIT placeholders used to be a common case where DataFusion did not
|
||||
// infer a parameter type. The prepare response must still advertise the
|
||||
// parameter and execution must substitute it correctly.
|
||||
{
|
||||
let stmt = connection
|
||||
.prep("SELECT uint32s FROM all_datatypes ORDER BY uint32s LIMIT ?")
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(stmt.num_params(), 1);
|
||||
|
||||
let rows: Vec<Row> = connection
|
||||
.exec(stmt, vec![mysql_async::Value::UInt(1)])
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(rows.len(), 1);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user