From 95090592f0ce105287d86a0fc618474011687717 Mon Sep 17 00:00:00 2001 From: SSebo Date: Wed, 8 Mar 2023 11:02:29 +0800 Subject: [PATCH] feat: mysql prepare replacing sql placeholder to param (#1086) * feat: mysql prepare by replace ? in sql to param * chore: mysql prepare statment support time param * chore: prepare test more types * chore: add TODO --- Cargo.lock | 1 + src/servers/Cargo.toml | 1 + src/servers/src/error.rs | 3 + src/servers/src/mysql/handler.rs | 183 +++++++++++++++++-- src/servers/src/mysql/writer.rs | 15 +- src/servers/tests/mysql/mod.rs | 2 +- src/servers/tests/mysql/mysql_server_test.rs | 95 +++++++++- 7 files changed, 274 insertions(+), 26 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8cdfe1285e..cde0a5087c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6949,6 +6949,7 @@ dependencies = [ "once_cell", "openmetrics-parser", "opensrv-mysql", + "parking_lot", "pgwire", "pin-project", "postgres-types", diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index 1cacfa59b8..b8203dfd1c 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -43,6 +43,7 @@ num_cpus = "1.13" once_cell = "1.16" openmetrics-parser = "0.4" opensrv-mysql = { git = "https://github.com/sunng87/opensrv", branch = "fix/buffer-overread" } +parking_lot = "0.12" pgwire = "0.10" pin-project = "1.0" postgres-types = { version = "0.2", features = ["with-chrono-0_4"] } diff --git a/src/servers/src/error.rs b/src/servers/src/error.rs index 470ab13554..4e6a3b5e5f 100644 --- a/src/servers/src/error.rs +++ b/src/servers/src/error.rs @@ -263,6 +263,8 @@ pub enum Error { #[snafu(backtrace)] source: common_mem_prof::error::Error, }, + #[snafu(display("Invalid prepare statement: {}", err_msg))] + InvalidPrepareStatement { err_msg: String }, } pub type Result = std::result::Result; @@ -302,6 +304,7 @@ impl ErrorExt for Error { | DecompressPromRemoteRequest { .. } | InvalidPromRemoteRequest { .. } | InvalidFlightTicket { .. } + | InvalidPrepareStatement { .. } | TimePrecision { .. } => StatusCode::InvalidArguments, InfluxdbLinesWrite { source, .. } | ConvertFlightMessage { source } => { diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index f3d3f35190..922d5a87a7 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -12,24 +12,33 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; use std::net::SocketAddr; +use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; -use std::time::Instant; +use std::time::{Duration, Instant}; use async_trait::async_trait; +use chrono::{NaiveDate, NaiveDateTime}; use common_query::Output; +use common_telemetry::tracing::log; use common_telemetry::{error, trace}; use opensrv_mysql::{ - AsyncMysqlShim, ErrorKind, InitWriter, ParamParser, QueryResultWriter, StatementMetaWriter, + AsyncMysqlShim, Column, ColumnFlags, ColumnType, ErrorKind, InitWriter, ParamParser, + ParamValue, QueryResultWriter, StatementMetaWriter, ValueInner, }; +use parking_lot::RwLock; use rand::RngCore; use session::context::Channel; use session::Session; use snafu::ensure; +use sql::dialect::GenericDialect; +use sql::parser::ParserContext; +use sql::statements::statement::Statement; use tokio::io::AsyncWrite; use crate::auth::{Identity, Password, UserProviderRef}; -use crate::error::{self, Result}; +use crate::error::{self, InvalidPrepareStatementSnafu, Result}; use crate::mysql::writer::MysqlResultWriter; use crate::query_handler::sql::ServerSqlQueryHandlerRef; @@ -39,6 +48,9 @@ pub struct MysqlInstanceShim { salt: [u8; 20], session: Arc, user_provider: Option, + // TODO(SSebo): use something like moka to achieve TTL or LRU + prepared_stmts: Arc>>, + prepared_stmts_counter: AtomicU32, } impl MysqlInstanceShim { @@ -65,6 +77,8 @@ impl MysqlInstanceShim { salt: scramble, session: Arc::new(Session::new(client_addr, Channel::Mysql)), user_provider, + prepared_stmts: Default::default(), + prepared_stmts_counter: AtomicU32::new(1), } } @@ -91,6 +105,18 @@ impl MysqlInstanceShim { ); output } + + fn set_query(&self, query: String) -> u32 { + let stmt_id = self.prepared_stmts_counter.fetch_add(1, Ordering::SeqCst); + let mut guard = self.prepared_stmts.write(); + guard.insert(stmt_id, query); + stmt_id + } + + fn query(&self, stmt_id: u32) -> Option { + let guard = self.prepared_stmts.read(); + guard.get(&stmt_id).map(|s| s.to_owned()) + } } #[async_trait] @@ -140,34 +166,59 @@ impl AsyncMysqlShim for MysqlInstanceShi true } - async fn on_prepare<'a>(&'a mut self, _: &'a str, w: StatementMetaWriter<'a, W>) -> Result<()> { - w.error( - ErrorKind::ER_UNKNOWN_ERROR, - b"prepare statement is not supported yet", - ) - .await?; - Ok(()) + async fn on_prepare<'a>( + &'a mut self, + query: &'a str, + w: StatementMetaWriter<'a, W>, + ) -> Result<()> { + let (query, param_num) = replace_placeholder(query); + if let Err(e) = validate_query(&query).await { + w.error(ErrorKind::ER_UNKNOWN_ERROR, e.to_string().as_bytes()) + .await?; + return Ok(()); + }; + + let stmt_id = self.set_query(query); + let params = dummy_params(param_num); + + w.reply(stmt_id, ¶ms, &[]).await?; + return Ok(()); } async fn on_execute<'a>( &'a mut self, - _: u32, - _: ParamParser<'a>, + stmt_id: u32, + p: ParamParser<'a>, w: QueryResultWriter<'a, W>, ) -> Result<()> { - w.error( - ErrorKind::ER_UNKNOWN_ERROR, - b"prepare statement is not supported yet", - ) - .await?; + let params: Vec = p.into_iter().collect(); + let query = match self.query(stmt_id) { + None => { + w.error( + ErrorKind::ER_UNKNOWN_STMT_HANDLER, + b"prepare statement not exist", + ) + .await?; + return Ok(()); + } + Some(query) => query, + }; + + let query = replace_params(params, query); + log::debug!("execute replaced query: {}", query); + + let outputs = self.do_query(&query).await; + write_output(w, &query, outputs).await?; + Ok(()) } - async fn on_close<'a>(&'a mut self, _stmt_id: u32) + async fn on_close<'a>(&'a mut self, stmt_id: u32) where W: 'async_trait, { - // do nothing because we haven't implemented prepare statement + let mut guard = self.prepared_stmts.write(); + guard.remove(&stmt_id); } async fn on_query<'a>( @@ -211,3 +262,97 @@ impl AsyncMysqlShim for MysqlInstanceShi w.ok().await.map_err(|e| e.into()) } } + +fn replace_params(params: Vec, query: String) -> String { + let mut query = query; + let mut index = 1; + for param in params { + let s = match param.value.into_inner() { + ValueInner::Int(u) => u.to_string(), + ValueInner::UInt(u) => u.to_string(), + ValueInner::Double(u) => u.to_string(), + ValueInner::NULL => "NULL".to_string(), + ValueInner::Bytes(b) => format!("'{}'", &String::from_utf8_lossy(b)), + ValueInner::Date(_) => NaiveDate::from(param.value).to_string(), + ValueInner::Datetime(_) => NaiveDateTime::from(param.value).to_string(), + ValueInner::Time(_) => format_duration(Duration::from(param.value)), + }; + query = query.replace(&format!("${}", index), &s); + index += 1; + } + query +} + +fn format_duration(duration: Duration) -> String { + let seconds = duration.as_secs() % 60; + let minutes = (duration.as_secs() / 60) % 60; + let hours = (duration.as_secs() / 60) / 60; + format!("{}:{}:{}", hours, minutes, seconds) +} + +async fn validate_query(query: &str) -> Result { + let statement = ParserContext::create_with_dialect(query, &GenericDialect {}); + let mut statement = statement.map_err(|e| { + InvalidPrepareStatementSnafu { + err_msg: e.to_string(), + } + .build() + })?; + + ensure!( + statement.len() == 1, + InvalidPrepareStatementSnafu { + err_msg: "prepare statement only support single statement".to_string(), + } + ); + + let statement = statement.remove(0); + + ensure!( + matches!(statement, Statement::Query(_)), + InvalidPrepareStatementSnafu { + err_msg: "prepare statement only support SELECT for now".to_string(), + } + ); + + Ok(statement) +} + +async fn write_output<'a, W: AsyncWrite + Send + Sync + Unpin>( + w: QueryResultWriter<'a, W>, + query: &str, + outputs: Vec>, +) -> Result<()> { + let mut writer = MysqlResultWriter::new(w); + for output in outputs { + writer.write(query, output).await?; + } + Ok(()) +} + +// dummy columns to satisfy opensrv_mysql, just the number of params is useful +// TODO(SSebo): use parameter type inference to return actual types +fn dummy_params(index: u32) -> Vec { + let mut params = vec![]; + + for _ in 1..index { + params.push(opensrv_mysql::Column { + table: "".to_string(), + column: "".to_string(), + coltype: ColumnType::MYSQL_TYPE_LONG, + colflags: ColumnFlags::NOT_NULL_FLAG, + }); + } + params +} + +fn replace_placeholder(query: &str) -> (String, u32) { + let mut query = query.to_string(); + let mut index = 1; + while let Some(position) = query.find('?') { + let place_holder = format!("${}", index); + query.replace_range(position..position + 1, &place_holder); + index += 1; + } + (query, index) +} diff --git a/src/servers/src/mysql/writer.rs b/src/servers/src/mysql/writer.rs index 6940b6a882..1553369504 100644 --- a/src/servers/src/mysql/writer.rs +++ b/src/servers/src/mysql/writer.rs @@ -171,9 +171,8 @@ fn create_mysql_column(column_schema: &ColumnSchema) -> Result { ConcreteDataType::Int64(_) | ConcreteDataType::UInt64(_) => { Ok(ColumnType::MYSQL_TYPE_LONGLONG) } - ConcreteDataType::Float32(_) | ConcreteDataType::Float64(_) => { - Ok(ColumnType::MYSQL_TYPE_FLOAT) - } + ConcreteDataType::Float32(_) => Ok(ColumnType::MYSQL_TYPE_FLOAT), + ConcreteDataType::Float64(_) => Ok(ColumnType::MYSQL_TYPE_DOUBLE), ConcreteDataType::Binary(_) | ConcreteDataType::String(_) => { Ok(ColumnType::MYSQL_TYPE_VARCHAR) } @@ -186,6 +185,14 @@ fn create_mysql_column(column_schema: &ColumnSchema) -> Result { } .fail(), }; + let mut colflags = ColumnFlags::empty(); + match column_schema.data_type { + ConcreteDataType::UInt16(_) + | ConcreteDataType::UInt8(_) + | ConcreteDataType::UInt32(_) + | ConcreteDataType::UInt64(_) => colflags |= ColumnFlags::UNSIGNED_FLAG, + _ => {} + }; column_type.map(|column_type| Column { column: column_schema.name.clone(), coltype: column_type, @@ -193,7 +200,7 @@ fn create_mysql_column(column_schema: &ColumnSchema) -> Result { // TODO(LFC): Currently "table" and "colflags" are not relevant in MySQL server // implementation, will revisit them again in the future. table: "".to_string(), - colflags: ColumnFlags::empty(), + colflags, }) } diff --git a/src/servers/tests/mysql/mod.rs b/src/servers/tests/mysql/mod.rs index 8da117cedc..0e9477f061 100644 --- a/src/servers/tests/mysql/mod.rs +++ b/src/servers/tests/mysql/mod.rs @@ -119,7 +119,7 @@ pub fn all_datatype_testing_data() -> TestingData { ColumnType::MYSQL_TYPE_LONG, ColumnType::MYSQL_TYPE_LONGLONG, ColumnType::MYSQL_TYPE_FLOAT, - ColumnType::MYSQL_TYPE_FLOAT, + ColumnType::MYSQL_TYPE_DOUBLE, ColumnType::MYSQL_TYPE_VARCHAR, ColumnType::MYSQL_TYPE_VARCHAR, ]; diff --git a/src/servers/tests/mysql/mysql_server_test.rs b/src/servers/tests/mysql/mysql_server_test.rs index 7d5a57da31..da5b777e6b 100644 --- a/src/servers/tests/mysql/mysql_server_test.rs +++ b/src/servers/tests/mysql/mysql_server_test.rs @@ -19,9 +19,11 @@ use std::time::Duration; use common_catalog::consts::DEFAULT_SCHEMA_NAME; use common_recordbatch::RecordBatch; use common_runtime::Builder as RuntimeBuilder; -use datatypes::schema::Schema; +use datatypes::prelude::VectorRef; +use datatypes::schema::{ColumnSchema, Schema}; +use datatypes::value::Value; use mysql_async::prelude::*; -use mysql_async::SslOpts; +use mysql_async::{Conn, Row, SslOpts}; use rand::rngs::StdRng; use rand::Rng; use servers::error::Result; @@ -451,6 +453,95 @@ async fn test_query_concurrently() -> Result<()> { Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn test_query_prepared() -> Result<()> { + common_telemetry::init_default_ut_logging(); + let TestingData { + column_schemas, + mysql_columns_def: _, + columns, + mysql_text_output_rows: _, + } = all_datatype_testing_data(); + let schema = Arc::new(Schema::new(column_schemas.clone())); + let recordbatch = RecordBatch::new(schema, columns.clone()).unwrap(); + let table = MemTable::new("all_datatypes", recordbatch); + + let mysql_server = create_mysql_server( + table, + MysqlOpts { + ..Default::default() + }, + )?; + + let listening = "127.0.0.1:0".parse::().unwrap(); + let server_addr = mysql_server.start(listening).await.unwrap(); + + let mut connection = create_connection_default_db_name(server_addr.port(), false) + .await + .unwrap(); + + test_prepare_all_type(column_schemas, columns, &mut connection).await; + + Ok(()) +} + +async fn test_prepare_all_type( + column_schemas: Vec, + columns: Vec, + connection: &mut Conn, +) { + let mut column_index = 0; + let mut stmt_id = 1; + for schema in column_schemas { + let query = format!( + "SELECT {} FROM all_datatypes WHERE {} = ?", + schema.name, schema.name + ); + let statement = connection.prep(query).await; + let statement = statement.unwrap(); + assert_eq!(stmt_id, statement.id()); + stmt_id += 1; + + let vector_ref = columns.get(column_index).unwrap(); + for vector_index in 0..vector_ref.len() { + let v = vector_ref.get(vector_index); + let v = if let Some(v) = prepare_convert_type(v) { + v + } else { + continue; + }; + + let output: std::result::Result, mysql_async::Error> = + connection.exec(statement.clone(), vec![v]).await; + + assert!(output.is_ok()); + + let rows = output.unwrap(); + assert!(!rows.is_empty()); + } + column_index += 1; + } +} + +fn prepare_convert_type(item: Value) -> Option { + let v = match item { + Value::UInt8(u) => mysql_async::Value::UInt(u as u64), + Value::UInt16(u) => mysql_async::Value::UInt(u as u64), + Value::UInt32(u) => mysql_async::Value::UInt(u as u64), + Value::UInt64(u) => mysql_async::Value::UInt(u), + Value::Int8(i) => mysql_async::Value::Int(i as i64), + Value::Int16(i) => mysql_async::Value::Int(i as i64), + Value::Int32(i) => mysql_async::Value::Int(i as i64), + Value::Int64(i) => mysql_async::Value::Int(i), + Value::Float32(f) => mysql_async::Value::Float(f.into()), + Value::Float64(f) => mysql_async::Value::Double(f.into()), + Value::String(s) => mysql_async::Value::Bytes(s.as_utf8().as_bytes().to_vec()), + Value::Binary(b) => mysql_async::Value::Bytes(b.to_vec()), + _ => return None, + }; + Some(v) +} + async fn create_connection_default_db_name( port: u16, ssl: bool,