diff --git a/Cargo.lock b/Cargo.lock index aafa225b4b..45b726ab82 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -12705,6 +12705,7 @@ dependencies = [ "metric-engine", "mime_guess", "mysql_async", + "mysql_common 0.34.1", "notify", "object-pool", "once_cell", diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index b35b30968a..f870d3c3ca 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -82,6 +82,7 @@ log-query.workspace = true loki-proto.workspace = true metric-engine.workspace = true mime_guess = "2.0" +mysql_common = "0.34" notify.workspace = true object-pool = "0.5" once_cell.workspace = true diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index 3a80593c63..88a539ea21 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -30,6 +30,7 @@ use datafusion_expr::LogicalPlan; use datatypes::prelude::ConcreteDataType; use datatypes::schema::Schema; use itertools::Itertools; +use mysql_common::Value as MysqlValue; use opensrv_mysql::{ AsyncMysqlShim, Column, ErrorKind, InitWriter, ParamParser, ParamValue, QueryResultWriter, StatementMetaWriter, ValueInner, @@ -51,9 +52,7 @@ use crate::error::{ self, DataFrameSnafu, InferParameterTypesSnafu, InvalidPrepareStatementSnafu, Result, }; use crate::metrics::METRIC_AUTH_FAILURE; -use crate::mysql::helper::{ - self, format_placeholder, replace_placeholders, transform_placeholders, -}; +use crate::mysql::helper::{self, format_placeholder, transform_placeholders_with_count}; use crate::mysql::writer; use crate::mysql::writer::{create_mysql_column, handle_err}; use crate::query_handler::sql::ServerSqlQueryHandlerRef; @@ -192,28 +191,31 @@ impl MysqlInstanceShim { return Ok((vec![], vec![])); } - let (query, param_num) = replace_placeholders(raw_query); - let statement = validate_query(raw_query).await?; // We have to transform the placeholder, because DataFusion only parses placeholders // in the form of "$i", it can't process "?" right now. - let statement = transform_placeholders(statement); + let (statement, placeholder_count) = transform_placeholders_with_count(statement); + let param_num = placeholder_count + 1; let describe_result = self .do_describe(statement.clone(), query_ctx.clone()) .await?; let plan = describe_result.map(|DescribeResult { logical_plan }| logical_plan); - let params = if let Some(plan) = &plan { + let (params, can_cache_as_plan) = if let Some(plan) = &plan { let param_types = DfLogicalPlanner::get_inferred_parameter_types(plan) .context(InferParameterTypesSnafu)? .into_iter() .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v)))) .collect(); - prepared_params(¶m_types)? + + ( + prepared_params(¶m_types, param_num)?, + all_params_have_types(¶m_types, param_num), + ) } else { - dummy_params(param_num)? + (dummy_params(param_num)?, false) }; let columns = @@ -239,17 +241,20 @@ impl MysqlInstanceShim { .unwrap_or_default(); match plan { - Some(plan) if params.len() == param_num - 1 => { + Some(plan) if can_cache_as_plan => { self.save_plan(SqlPlan::Plan(plan, statement), stmt_key) .inspect_err(|e| { error!(e; "Failed to save prepared statement"); })?; } _ => { - self.save_plan(SqlPlan::Statement(statement, query), stmt_key) - .inspect_err(|e| { - error!(e; "Failed to save prepared statement"); - })?; + self.save_plan( + SqlPlan::Statement(statement, raw_query.to_string()), + stmt_key, + ) + .inspect_err(|e| { + error!(e; "Failed to save prepared statement"); + })?; } } @@ -312,7 +317,7 @@ impl MysqlInstanceShim { self.do_query(&query, query_ctx.clone()).await } } - SqlPlan::Statement(_stmt, query) => { + SqlPlan::Statement(stmt, query) => { let param_strs = match params { Params::ProtocolParams(params) => { params.iter().map(convert_param_value_to_string).collect() @@ -323,7 +328,7 @@ impl MysqlInstanceShim { "do_execute Replacing with Params: {:?}, Original Query: {}", param_strs, query ); - let query = replace_params(param_strs, query); + let query = replace_params(param_strs, stmt, query)?; debug!("Mysql execute replaced query: {}", query); self.do_query(&query, query_ctx.clone()).await } @@ -662,19 +667,133 @@ fn convert_param_value_to_string(param: &ParamValue) -> String { ValueInner::UInt(u) => u.to_string(), ValueInner::Double(u) => u.to_string(), ValueInner::NULL => "NULL".to_string(), - ValueInner::Bytes(b) => format!("'{}'", &String::from_utf8_lossy(b)), + // MySQL prepared fallback emits SQL text. Delegate bytes/string literal + // escaping to mysql_common. `false` means normal MySQL backslash escapes; + // if NO_BACKSLASH_ESCAPES is supported in this path later, wire the + // session SQL mode here. + ValueInner::Bytes(b) => MysqlValue::Bytes(b.to_vec()).as_sql(false), ValueInner::Date(_) => format!("'{}'", NaiveDate::from(param.value)), ValueInner::Datetime(_) => format!("'{}'", NaiveDateTime::from(param.value)), ValueInner::Time(_) => format_duration(Duration::from(param.value)), } } -fn replace_params(params: Vec, query: String) -> String { - let mut query = query; - for (index, param) in (1..).zip(params) { - query = query.replace(&format_placeholder(index), ¶m); +fn replace_params(params: Vec, stmt: Statement, mut query: String) -> Result { + let spans = helper::placeholder_spans(stmt); + ensure!( + spans.len() == params.len(), + error::InternalSnafu { + err_msg: format!( + "Prepared statement expected {} parameters but got {}", + spans.len(), + params.len() + ) + } + ); + + let mut replacements = Vec::with_capacity(spans.len()); + for span in spans { + let start = location_to_byte_offset(&query, span.start_line, span.start_column) + .ok_or_else(|| { + error::InternalSnafu { + err_msg: format!( + "Invalid placeholder start span: line {}, column {}", + span.start_line, span.start_column + ), + } + .build() + })?; + let end = + location_to_byte_offset(&query, span.end_line, span.end_column).ok_or_else(|| { + error::InternalSnafu { + err_msg: format!( + "Invalid placeholder end span: line {}, column {}", + span.end_line, span.end_column + ), + } + .build() + })?; + let param = span + .index + .checked_sub(1) + .and_then(|idx| params.get(idx)) + .ok_or_else(|| { + error::InternalSnafu { + err_msg: format!("Missing prepared statement parameter {}", span.index), + } + .build() + })?; + + ensure!( + start < end && end <= query.len(), + error::InternalSnafu { + err_msg: format!( + "Invalid placeholder byte span: {}..{} for query length {}", + start, + end, + query.len() + ) + } + ); + ensure!( + query.get(start..end) == Some("?"), + error::InternalSnafu { + err_msg: format!( + "Prepared statement placeholder span maps to {:?} instead of '?'", + query.get(start..end) + ) + } + ); + + replacements.push((start, end, param.clone())); } - query + + replacements.sort_unstable_by_key(|(start, _, _)| *start); + for windows in replacements.windows(2) { + ensure!( + windows[0].1 <= windows[1].0, + error::InternalSnafu { + err_msg: "Overlapping placeholder spans in prepared statement".to_string() + } + ); + } + + // All spans are computed against the original query. Apply replacements + // from right to left so changing one parameter's string length never shifts + // the byte offsets of placeholders that have not been replaced yet. + for (start, end, param) in replacements.into_iter().rev() { + query.replace_range(start..end, ¶m); + } + + Ok(query) +} + +fn location_to_byte_offset(query: &str, line: u64, column: u64) -> Option { + // sqlparser spans are 1-based line/column locations, and columns advance by + // Rust `char`s rather than bytes. Convert them to byte offsets before using + // `String::replace_range` on the original SQL text. + if line == 0 || column == 0 { + return None; + } + + let mut current_line = 1; + let mut current_column = 1; + for (index, ch) in query.char_indices() { + if current_line == line && current_column == column { + return Some(index); + } + + if ch == '\n' { + current_line += 1; + current_column = 1; + } else { + current_column += 1; + } + } + + // The exclusive end location of a trailing placeholder points just past + // the last character, for example the end span of `SELECT ?`. + (current_line == line && current_column == column).then_some(query.len()) } fn format_duration(duration: Duration) -> String { @@ -778,20 +897,33 @@ fn dummy_params(index: usize) -> Result> { } /// Parameters that the client must provide when executing the prepared statement. -fn prepared_params(param_types: &HashMap>) -> Result> { - let mut params = Vec::with_capacity(param_types.len()); +fn prepared_params( + param_types: &HashMap>, + param_num: usize, +) -> Result> { + let mut params = Vec::with_capacity(param_num - 1); // Placeholder index starts from 1 - for index in 1..=param_types.len() { - if let Some(Some(t)) = param_types.get(&format_placeholder(index)) { - let column = create_mysql_column(t, "")?; - params.push(column); - } + for i in 1..param_num { + let column = if let Some(Some(t)) = param_types.get(&format_placeholder(i)) { + create_mysql_column(t, "")? + } else { + create_mysql_column(&ConcreteDataType::null_datatype(), "")? + }; + params.push(column); } Ok(params) } +fn all_params_have_types( + param_types: &HashMap>, + param_num: usize, +) -> bool { + param_types.len() == param_num - 1 + && (1..param_num).all(|i| matches!(param_types.get(&format_placeholder(i)), Some(Some(_)))) +} + #[cfg(test)] mod tests { use std::sync::Arc; @@ -852,6 +984,122 @@ mod tests { ) } + fn statement_with_transformed_placeholders(query: &str) -> Statement { + let mut statements = + ParserContext::create_with_dialect(query, &MySqlDialect {}, ParseOptions::default()) + .unwrap(); + assert_eq!(statements.len(), 1); + transform_placeholders_with_count(statements.remove(0)).0 + } + + #[test] + fn test_prepared_params_keep_unknown_type_placeholders() { + let mut param_types = HashMap::new(); + param_types.insert(format_placeholder(1), None); + param_types.insert( + format_placeholder(2), + Some(ConcreteDataType::int32_datatype()), + ); + + let params = prepared_params(¶m_types, 3).unwrap(); + assert_eq!(params.len(), 2); + assert!(!all_params_have_types(¶m_types, 3)); + } + + #[test] + fn test_replace_params_by_placeholder_span() { + let query = "SELECT ?, ?".to_string(); + let stmt = statement_with_transformed_placeholders(&query); + let params = vec!["'$2 should stay'".to_string(), "'value'".to_string()]; + + assert_eq!( + "SELECT '$2 should stay', 'value'", + replace_params(params, stmt, query).unwrap() + ); + + let query = "SELECT ?, ?, ?".to_string(); + let stmt = statement_with_transformed_placeholders(&query); + let params = vec![ + "'much longer than a placeholder'".to_string(), + "0".to_string(), + "'also much longer than a placeholder'".to_string(), + ]; + + assert_eq!( + "SELECT 'much longer than a placeholder', 0, 'also much longer than a placeholder'", + replace_params(params, stmt, query).unwrap() + ); + + let query = "SELECT '$1', \"$2\", `$3`, ?, ?".to_string(); + let stmt = statement_with_transformed_placeholders(&query); + let params = vec!["'1'".to_string(), "'2'".to_string()]; + + assert_eq!( + "SELECT '$1', \"$2\", `$3`, '1', '2'", + replace_params(params, stmt, query).unwrap() + ); + + let query = "SELECT /* ? */ ? -- ?\n, ?".to_string(); + let stmt = statement_with_transformed_placeholders(&query); + let params = vec!["'first'".to_string(), "'second'".to_string()]; + + assert_eq!( + "SELECT /* ? */ 'first' -- ?\n, 'second'", + replace_params(params, stmt, query).unwrap() + ); + + let query = "SELECT '中文', ?".to_string(); + let stmt = statement_with_transformed_placeholders(&query); + let params = vec!["'value'".to_string()]; + + assert_eq!( + "SELECT '中文', 'value'", + replace_params(params, stmt, query).unwrap() + ); + + let query = "SELECT '中文',\n ?".to_string(); + let stmt = statement_with_transformed_placeholders(&query); + let params = vec!["'value'".to_string()]; + + assert_eq!( + "SELECT '中文',\n 'value'", + replace_params(params, stmt, query).unwrap() + ); + + let query = "SELECT 'x'\r\n, ?".to_string(); + let stmt = statement_with_transformed_placeholders(&query); + let params = vec!["'crlf'".to_string()]; + + assert_eq!( + "SELECT 'x'\r\n, 'crlf'", + replace_params(params, stmt, query).unwrap() + ); + + let query = "SELECT\t?".to_string(); + let stmt = statement_with_transformed_placeholders(&query); + let params = vec!["NULL".to_string()]; + + assert_eq!("SELECT\tNULL", replace_params(params, stmt, query).unwrap()); + + let query = "SELECT CAST(? AS INT64), ? + (SELECT ?)".to_string(); + let stmt = statement_with_transformed_placeholders(&query); + let params = vec!["1".to_string(), "2".to_string(), "3".to_string()]; + + assert_eq!( + "SELECT CAST(1 AS INT64), 2 + (SELECT 3)", + replace_params(params, stmt, query).unwrap() + ); + + let query = "SET time_zone = ?".to_string(); + let stmt = statement_with_transformed_placeholders(&query); + let params = vec!["'UTC'".to_string()]; + + assert_eq!( + "SET time_zone = 'UTC'", + replace_params(params, stmt, query).unwrap() + ); + } + #[tokio::test] async fn test_prepare_federated_query() { let mut shim = create_shim(); diff --git a/src/servers/src/mysql/helper.rs b/src/servers/src/mysql/helper.rs index 2ee2421892..c4c072b007 100644 --- a/src/servers/src/mysql/helper.rs +++ b/src/servers/src/mysql/helper.rs @@ -23,6 +23,7 @@ use datatypes::prelude::ConcreteDataType; use datatypes::schema::ColumnSchema; use datatypes::types::TimestampType; use datatypes::value::{self, Value}; +#[cfg(test)] use itertools::Itertools; use opensrv_mysql::{ParamValue, ValueInner, to_naive_datetime}; use snafu::ResultExt; @@ -31,6 +32,17 @@ use sql::statements::statement::Statement; use crate::error::{self, Result}; +/// Location of a prepared-statement placeholder in the original SQL text. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) struct PlaceholderSpan { + /// 1-based placeholder index. + pub(crate) index: usize, + pub(crate) start_line: u64, + pub(crate) start_column: u64, + pub(crate) end_line: u64, + pub(crate) end_column: u64, +} + /// Returns the placeholder string "$i". pub fn format_placeholder(i: usize) -> String { format!("${}", i) @@ -38,6 +50,7 @@ pub fn format_placeholder(i: usize) -> String { /// Replace all the "?" placeholder into "$i" in SQL, /// returns the new SQL and the last placeholder index. +#[cfg(test)] pub fn replace_placeholders(query: &str) -> (String, usize) { let query_parts = query.split('?').collect::>(); let parts_len = query_parts.len(); @@ -58,27 +71,51 @@ pub fn replace_placeholders(query: &str) -> (String, usize) { (query, index + 1) } -/// Transform all the "?" placeholder into "$i". -/// Only works for Insert,Query and Delete statements. -pub fn transform_placeholders(stmt: Statement) -> Statement { - match stmt { - Statement::Query(mut query) => { - visit_placeholders(&mut query.inner); - Statement::Query(query) - } - Statement::Insert(mut insert) => { - visit_placeholders(&mut insert.inner); - Statement::Insert(insert) - } - Statement::Delete(mut delete) => { - visit_placeholders(&mut delete.inner); - Statement::Delete(delete) - } - stmt => stmt, - } +/// Transform all the "?" placeholders into "$i" and return the number of +/// transformed placeholders. +pub fn transform_placeholders_with_count(mut stmt: Statement) -> (Statement, usize) { + let count = visit_placeholders(&mut stmt); + (stmt, count) } -fn visit_placeholders(v: &mut V) +/// Collect spans of "$i" placeholders in a statement. +pub(crate) fn placeholder_spans(mut stmt: Statement) -> Vec { + let mut spans = Vec::new(); + collect_placeholder_spans(&mut stmt, &mut spans); + spans +} + +fn collect_placeholder_spans(v: &mut V, spans: &mut Vec) +where + V: VisitMut, +{ + let _ = visit_expressions_mut(v, |expr| { + if let Expr::Value(ValueWithSpan { + value: ValueExpr::Placeholder(s), + span, + }) = expr + && let Some(index) = placeholder_index(s) + { + spans.push(PlaceholderSpan { + index, + start_line: span.start.line, + start_column: span.start.column, + end_line: span.end.line, + end_column: span.end.column, + }); + } + ControlFlow::<()>::Continue(()) + }); +} + +fn placeholder_index(s: &str) -> Option { + s.strip_prefix('$')? + .parse::() + .ok() + .filter(|i| *i > 0) +} + +fn visit_placeholders(v: &mut V) -> usize where V: VisitMut, { @@ -88,12 +125,14 @@ where value: ValueExpr::Placeholder(s), .. }) = expr + && s == "?" { *s = format_placeholder(index); index += 1; } ControlFlow::<()>::Continue(()) }); + index - 1 } /// Convert [`ParamValue`] into [`Value`] according to param type. @@ -340,33 +379,52 @@ mod tests { #[test] fn test_transform_placeholders() { let insert = parse_sql("insert into demo values(?,?,?)"); - let Statement::Insert(insert) = transform_placeholders(insert) else { + let (stmt, count) = transform_placeholders_with_count(insert); + let Statement::Insert(insert) = stmt else { unreachable!() }; assert_eq!( "INSERT INTO demo VALUES ($1, $2, $3)", insert.inner.to_string() ); + assert_eq!(3, count); let delete = parse_sql("delete from demo where host=? and idc=?"); - let Statement::Delete(delete) = transform_placeholders(delete) else { + let (stmt, count) = transform_placeholders_with_count(delete); + let Statement::Delete(delete) = stmt else { unreachable!() }; assert_eq!( "DELETE FROM demo WHERE host = $1 AND idc = $2", delete.inner.to_string() ); + assert_eq!(2, count); let select = parse_sql( "select * from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?", ); - let Statement::Query(select) = transform_placeholders(select) else { + let (stmt, count) = transform_placeholders_with_count(select); + let Statement::Query(select) = stmt else { unreachable!() }; assert_eq!( "SELECT * FROM demo WHERE host = $1 AND idc IN (SELECT idc FROM idcs WHERE name = $2) AND cpu > $3", select.inner.to_string() ); + assert_eq!(3, count); + + let select = parse_sql("select '?', ?"); + let (stmt, count) = transform_placeholders_with_count(select); + let Statement::Query(select) = stmt else { + unreachable!() + }; + assert_eq!("SELECT '?', $1", select.inner.to_string()); + assert_eq!(1, count); + + let set = parse_sql("set time_zone = ?"); + let (stmt, count) = transform_placeholders_with_count(set); + assert_eq!("SET time_zone = $1", stmt.to_string()); + assert_eq!(1, count); } #[test] diff --git a/src/servers/tests/mysql/mysql_server_test.rs b/src/servers/tests/mysql/mysql_server_test.rs index 0694cc7746..888ac92fc3 100644 --- a/src/servers/tests/mysql/mysql_server_test.rs +++ b/src/servers/tests/mysql/mysql_server_test.rs @@ -548,6 +548,140 @@ async fn test_query_prepared() -> Result<()> { assert_eq!(rows.len(), 1, "LIMIT 1 should return 1 row"); } + // Untyped placeholders should still be advertised in the MySQL prepare + // response. This used to fail on the client side because the server + // reported 0 parameters for `SELECT ?`. + { + let stmt = connection.prep("SELECT ?").await.unwrap(); + assert_eq!(stmt.num_params(), 1); + + let row: Option = connection + .exec_first(stmt, vec![mysql_async::Value::Bytes(b"can't".to_vec())]) + .await + .unwrap(); + assert_eq!(row, Some("can't".to_string())); + + let stmt = connection.prep("SELECT ?").await.unwrap(); + let row: Option = connection + .exec_first(stmt, vec![mysql_async::Value::Bytes(b"a\\'b".to_vec())]) + .await + .unwrap(); + assert_eq!(row, Some("a\\'b".to_string())); + + let stmt = connection.prep("SELECT ?").await.unwrap(); + let row: Option> = connection + .exec_first(stmt, vec![mysql_async::Value::Bytes(vec![0xFF, 0xFE])]) + .await + .unwrap(); + assert_eq!(row, Some(vec![0xFF, 0xFE])); + + let stmt = connection.prep("SELECT ?").await.unwrap(); + let row: Option> = connection + .exec_first(stmt, vec![mysql_async::Value::NULL]) + .await + .unwrap(); + assert_eq!(row, Some(None)); + } + + // Values inserted into the SQL text must not be processed again while + // replacing later placeholders. + { + let stmt = connection.prep("SELECT ?, ?").await.unwrap(); + assert_eq!(stmt.num_params(), 2); + + let row: Option<(String, String)> = connection + .exec_first( + stmt, + vec![ + mysql_async::Value::Bytes(b"keep $2".to_vec()), + mysql_async::Value::Bytes(b"second".to_vec()), + ], + ) + .await + .unwrap(); + assert_eq!(row, Some(("keep $2".to_string(), "second".to_string()))); + } + + // Non-placeholder question marks inside string literals must not affect + // the advertised prepare parameter count. + { + let stmt = connection.prep("SELECT '?', ?").await.unwrap(); + assert_eq!(stmt.num_params(), 1); + + let row: Option<(String, String)> = connection + .exec_first(stmt, vec![mysql_async::Value::Bytes(b"actual".to_vec())]) + .await + .unwrap(); + assert_eq!(row, Some(("?".to_string(), "actual".to_string()))); + + let stmt = connection.prep("SELECT '$1', ?").await.unwrap(); + assert_eq!(stmt.num_params(), 1); + + let row: Option<(String, String)> = connection + .exec_first(stmt, vec![mysql_async::Value::Bytes(b"actual".to_vec())]) + .await + .unwrap(); + assert_eq!(row, Some(("$1".to_string(), "actual".to_string()))); + + let stmt = connection.prep("SELECT /* ? */ ? -- ?\n").await.unwrap(); + assert_eq!(stmt.num_params(), 1); + + let row: Option = connection + .exec_first(stmt, vec![mysql_async::Value::Bytes(b"commented".to_vec())]) + .await + .unwrap(); + assert_eq!(row, Some("commented".to_string())); + + let stmt = connection.prep("SELECT '中文', ?").await.unwrap(); + assert_eq!(stmt.num_params(), 1); + + let row: Option<(String, String)> = connection + .exec_first(stmt, vec![mysql_async::Value::Bytes(b"actual".to_vec())]) + .await + .unwrap(); + assert_eq!(row, Some(("中文".to_string(), "actual".to_string()))); + } + + // Also cover mixed known and unknown placeholders. The projection + // placeholder is untyped, while the WHERE placeholder is inferred from the + // column type. The prepare response must advertise both parameters. + { + let stmt = connection + .prep("SELECT ?, uint32s FROM all_datatypes WHERE uint32s >= ?") + .await + .unwrap(); + assert_eq!(stmt.num_params(), 2); + + let rows: Vec = connection + .exec( + stmt, + vec![ + mysql_async::Value::Bytes(b"unknown".to_vec()), + mysql_async::Value::UInt(0), + ], + ) + .await + .unwrap(); + assert!(!rows.is_empty()); + } + + // LIMIT placeholders used to be a common case where DataFusion did not + // infer a parameter type. The prepare response must still advertise the + // parameter and execution must substitute it correctly. + { + let stmt = connection + .prep("SELECT uint32s FROM all_datatypes ORDER BY uint32s LIMIT ?") + .await + .unwrap(); + assert_eq!(stmt.num_params(), 1); + + let rows: Vec = connection + .exec(stmt, vec![mysql_async::Value::UInt(1)]) + .await + .unwrap(); + assert_eq!(rows.len(), 1); + } + Ok(()) }