feat: mysql prepare replacing sql placeholder to param (#1086)

* feat: mysql prepare by replace ? in sql to param

* chore: mysql prepare statment support time param

* chore: prepare test more types

* chore: add TODO
This commit is contained in:
SSebo
2023-03-08 11:02:29 +08:00
committed by GitHub
parent 3a527c0fd5
commit 95090592f0
7 changed files with 274 additions and 26 deletions

1
Cargo.lock generated
View File

@@ -6949,6 +6949,7 @@ dependencies = [
"once_cell",
"openmetrics-parser",
"opensrv-mysql",
"parking_lot",
"pgwire",
"pin-project",
"postgres-types",

View File

@@ -43,6 +43,7 @@ num_cpus = "1.13"
once_cell = "1.16"
openmetrics-parser = "0.4"
opensrv-mysql = { git = "https://github.com/sunng87/opensrv", branch = "fix/buffer-overread" }
parking_lot = "0.12"
pgwire = "0.10"
pin-project = "1.0"
postgres-types = { version = "0.2", features = ["with-chrono-0_4"] }

View File

@@ -263,6 +263,8 @@ pub enum Error {
#[snafu(backtrace)]
source: common_mem_prof::error::Error,
},
#[snafu(display("Invalid prepare statement: {}", err_msg))]
InvalidPrepareStatement { err_msg: String },
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -302,6 +304,7 @@ impl ErrorExt for Error {
| DecompressPromRemoteRequest { .. }
| InvalidPromRemoteRequest { .. }
| InvalidFlightTicket { .. }
| InvalidPrepareStatement { .. }
| TimePrecision { .. } => StatusCode::InvalidArguments,
InfluxdbLinesWrite { source, .. } | ConvertFlightMessage { source } => {

View File

@@ -12,24 +12,33 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::time::Instant;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use chrono::{NaiveDate, NaiveDateTime};
use common_query::Output;
use common_telemetry::tracing::log;
use common_telemetry::{error, trace};
use opensrv_mysql::{
AsyncMysqlShim, ErrorKind, InitWriter, ParamParser, QueryResultWriter, StatementMetaWriter,
AsyncMysqlShim, Column, ColumnFlags, ColumnType, ErrorKind, InitWriter, ParamParser,
ParamValue, QueryResultWriter, StatementMetaWriter, ValueInner,
};
use parking_lot::RwLock;
use rand::RngCore;
use session::context::Channel;
use session::Session;
use snafu::ensure;
use sql::dialect::GenericDialect;
use sql::parser::ParserContext;
use sql::statements::statement::Statement;
use tokio::io::AsyncWrite;
use crate::auth::{Identity, Password, UserProviderRef};
use crate::error::{self, Result};
use crate::error::{self, InvalidPrepareStatementSnafu, Result};
use crate::mysql::writer::MysqlResultWriter;
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
@@ -39,6 +48,9 @@ pub struct MysqlInstanceShim {
salt: [u8; 20],
session: Arc<Session>,
user_provider: Option<UserProviderRef>,
// TODO(SSebo): use something like moka to achieve TTL or LRU
prepared_stmts: Arc<RwLock<HashMap<u32, String>>>,
prepared_stmts_counter: AtomicU32,
}
impl MysqlInstanceShim {
@@ -65,6 +77,8 @@ impl MysqlInstanceShim {
salt: scramble,
session: Arc::new(Session::new(client_addr, Channel::Mysql)),
user_provider,
prepared_stmts: Default::default(),
prepared_stmts_counter: AtomicU32::new(1),
}
}
@@ -91,6 +105,18 @@ impl MysqlInstanceShim {
);
output
}
fn set_query(&self, query: String) -> u32 {
let stmt_id = self.prepared_stmts_counter.fetch_add(1, Ordering::SeqCst);
let mut guard = self.prepared_stmts.write();
guard.insert(stmt_id, query);
stmt_id
}
fn query(&self, stmt_id: u32) -> Option<String> {
let guard = self.prepared_stmts.read();
guard.get(&stmt_id).map(|s| s.to_owned())
}
}
#[async_trait]
@@ -140,34 +166,59 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
true
}
async fn on_prepare<'a>(&'a mut self, _: &'a str, w: StatementMetaWriter<'a, W>) -> Result<()> {
w.error(
ErrorKind::ER_UNKNOWN_ERROR,
b"prepare statement is not supported yet",
)
.await?;
Ok(())
async fn on_prepare<'a>(
&'a mut self,
query: &'a str,
w: StatementMetaWriter<'a, W>,
) -> Result<()> {
let (query, param_num) = replace_placeholder(query);
if let Err(e) = validate_query(&query).await {
w.error(ErrorKind::ER_UNKNOWN_ERROR, e.to_string().as_bytes())
.await?;
return Ok(());
};
let stmt_id = self.set_query(query);
let params = dummy_params(param_num);
w.reply(stmt_id, &params, &[]).await?;
return Ok(());
}
async fn on_execute<'a>(
&'a mut self,
_: u32,
_: ParamParser<'a>,
stmt_id: u32,
p: ParamParser<'a>,
w: QueryResultWriter<'a, W>,
) -> Result<()> {
w.error(
ErrorKind::ER_UNKNOWN_ERROR,
b"prepare statement is not supported yet",
)
.await?;
let params: Vec<ParamValue> = p.into_iter().collect();
let query = match self.query(stmt_id) {
None => {
w.error(
ErrorKind::ER_UNKNOWN_STMT_HANDLER,
b"prepare statement not exist",
)
.await?;
return Ok(());
}
Some(query) => query,
};
let query = replace_params(params, query);
log::debug!("execute replaced query: {}", query);
let outputs = self.do_query(&query).await;
write_output(w, &query, outputs).await?;
Ok(())
}
async fn on_close<'a>(&'a mut self, _stmt_id: u32)
async fn on_close<'a>(&'a mut self, stmt_id: u32)
where
W: 'async_trait,
{
// do nothing because we haven't implemented prepare statement
let mut guard = self.prepared_stmts.write();
guard.remove(&stmt_id);
}
async fn on_query<'a>(
@@ -211,3 +262,97 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
w.ok().await.map_err(|e| e.into())
}
}
fn replace_params(params: Vec<ParamValue>, query: String) -> String {
let mut query = query;
let mut index = 1;
for param in params {
let s = match param.value.into_inner() {
ValueInner::Int(u) => u.to_string(),
ValueInner::UInt(u) => u.to_string(),
ValueInner::Double(u) => u.to_string(),
ValueInner::NULL => "NULL".to_string(),
ValueInner::Bytes(b) => format!("'{}'", &String::from_utf8_lossy(b)),
ValueInner::Date(_) => NaiveDate::from(param.value).to_string(),
ValueInner::Datetime(_) => NaiveDateTime::from(param.value).to_string(),
ValueInner::Time(_) => format_duration(Duration::from(param.value)),
};
query = query.replace(&format!("${}", index), &s);
index += 1;
}
query
}
fn format_duration(duration: Duration) -> String {
let seconds = duration.as_secs() % 60;
let minutes = (duration.as_secs() / 60) % 60;
let hours = (duration.as_secs() / 60) / 60;
format!("{}:{}:{}", hours, minutes, seconds)
}
async fn validate_query(query: &str) -> Result<Statement> {
let statement = ParserContext::create_with_dialect(query, &GenericDialect {});
let mut statement = statement.map_err(|e| {
InvalidPrepareStatementSnafu {
err_msg: e.to_string(),
}
.build()
})?;
ensure!(
statement.len() == 1,
InvalidPrepareStatementSnafu {
err_msg: "prepare statement only support single statement".to_string(),
}
);
let statement = statement.remove(0);
ensure!(
matches!(statement, Statement::Query(_)),
InvalidPrepareStatementSnafu {
err_msg: "prepare statement only support SELECT for now".to_string(),
}
);
Ok(statement)
}
async fn write_output<'a, W: AsyncWrite + Send + Sync + Unpin>(
w: QueryResultWriter<'a, W>,
query: &str,
outputs: Vec<Result<Output>>,
) -> Result<()> {
let mut writer = MysqlResultWriter::new(w);
for output in outputs {
writer.write(query, output).await?;
}
Ok(())
}
// dummy columns to satisfy opensrv_mysql, just the number of params is useful
// TODO(SSebo): use parameter type inference to return actual types
fn dummy_params(index: u32) -> Vec<Column> {
let mut params = vec![];
for _ in 1..index {
params.push(opensrv_mysql::Column {
table: "".to_string(),
column: "".to_string(),
coltype: ColumnType::MYSQL_TYPE_LONG,
colflags: ColumnFlags::NOT_NULL_FLAG,
});
}
params
}
fn replace_placeholder(query: &str) -> (String, u32) {
let mut query = query.to_string();
let mut index = 1;
while let Some(position) = query.find('?') {
let place_holder = format!("${}", index);
query.replace_range(position..position + 1, &place_holder);
index += 1;
}
(query, index)
}

View File

@@ -171,9 +171,8 @@ fn create_mysql_column(column_schema: &ColumnSchema) -> Result<Column> {
ConcreteDataType::Int64(_) | ConcreteDataType::UInt64(_) => {
Ok(ColumnType::MYSQL_TYPE_LONGLONG)
}
ConcreteDataType::Float32(_) | ConcreteDataType::Float64(_) => {
Ok(ColumnType::MYSQL_TYPE_FLOAT)
}
ConcreteDataType::Float32(_) => Ok(ColumnType::MYSQL_TYPE_FLOAT),
ConcreteDataType::Float64(_) => Ok(ColumnType::MYSQL_TYPE_DOUBLE),
ConcreteDataType::Binary(_) | ConcreteDataType::String(_) => {
Ok(ColumnType::MYSQL_TYPE_VARCHAR)
}
@@ -186,6 +185,14 @@ fn create_mysql_column(column_schema: &ColumnSchema) -> Result<Column> {
}
.fail(),
};
let mut colflags = ColumnFlags::empty();
match column_schema.data_type {
ConcreteDataType::UInt16(_)
| ConcreteDataType::UInt8(_)
| ConcreteDataType::UInt32(_)
| ConcreteDataType::UInt64(_) => colflags |= ColumnFlags::UNSIGNED_FLAG,
_ => {}
};
column_type.map(|column_type| Column {
column: column_schema.name.clone(),
coltype: column_type,
@@ -193,7 +200,7 @@ fn create_mysql_column(column_schema: &ColumnSchema) -> Result<Column> {
// TODO(LFC): Currently "table" and "colflags" are not relevant in MySQL server
// implementation, will revisit them again in the future.
table: "".to_string(),
colflags: ColumnFlags::empty(),
colflags,
})
}

View File

@@ -119,7 +119,7 @@ pub fn all_datatype_testing_data() -> TestingData {
ColumnType::MYSQL_TYPE_LONG,
ColumnType::MYSQL_TYPE_LONGLONG,
ColumnType::MYSQL_TYPE_FLOAT,
ColumnType::MYSQL_TYPE_FLOAT,
ColumnType::MYSQL_TYPE_DOUBLE,
ColumnType::MYSQL_TYPE_VARCHAR,
ColumnType::MYSQL_TYPE_VARCHAR,
];

View File

@@ -19,9 +19,11 @@ use std::time::Duration;
use common_catalog::consts::DEFAULT_SCHEMA_NAME;
use common_recordbatch::RecordBatch;
use common_runtime::Builder as RuntimeBuilder;
use datatypes::schema::Schema;
use datatypes::prelude::VectorRef;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::value::Value;
use mysql_async::prelude::*;
use mysql_async::SslOpts;
use mysql_async::{Conn, Row, SslOpts};
use rand::rngs::StdRng;
use rand::Rng;
use servers::error::Result;
@@ -451,6 +453,95 @@ async fn test_query_concurrently() -> Result<()> {
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_query_prepared() -> Result<()> {
common_telemetry::init_default_ut_logging();
let TestingData {
column_schemas,
mysql_columns_def: _,
columns,
mysql_text_output_rows: _,
} = all_datatype_testing_data();
let schema = Arc::new(Schema::new(column_schemas.clone()));
let recordbatch = RecordBatch::new(schema, columns.clone()).unwrap();
let table = MemTable::new("all_datatypes", recordbatch);
let mysql_server = create_mysql_server(
table,
MysqlOpts {
..Default::default()
},
)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
let mut connection = create_connection_default_db_name(server_addr.port(), false)
.await
.unwrap();
test_prepare_all_type(column_schemas, columns, &mut connection).await;
Ok(())
}
async fn test_prepare_all_type(
column_schemas: Vec<ColumnSchema>,
columns: Vec<VectorRef>,
connection: &mut Conn,
) {
let mut column_index = 0;
let mut stmt_id = 1;
for schema in column_schemas {
let query = format!(
"SELECT {} FROM all_datatypes WHERE {} = ?",
schema.name, schema.name
);
let statement = connection.prep(query).await;
let statement = statement.unwrap();
assert_eq!(stmt_id, statement.id());
stmt_id += 1;
let vector_ref = columns.get(column_index).unwrap();
for vector_index in 0..vector_ref.len() {
let v = vector_ref.get(vector_index);
let v = if let Some(v) = prepare_convert_type(v) {
v
} else {
continue;
};
let output: std::result::Result<Vec<Row>, mysql_async::Error> =
connection.exec(statement.clone(), vec![v]).await;
assert!(output.is_ok());
let rows = output.unwrap();
assert!(!rows.is_empty());
}
column_index += 1;
}
}
fn prepare_convert_type(item: Value) -> Option<mysql_async::Value> {
let v = match item {
Value::UInt8(u) => mysql_async::Value::UInt(u as u64),
Value::UInt16(u) => mysql_async::Value::UInt(u as u64),
Value::UInt32(u) => mysql_async::Value::UInt(u as u64),
Value::UInt64(u) => mysql_async::Value::UInt(u),
Value::Int8(i) => mysql_async::Value::Int(i as i64),
Value::Int16(i) => mysql_async::Value::Int(i as i64),
Value::Int32(i) => mysql_async::Value::Int(i as i64),
Value::Int64(i) => mysql_async::Value::Int(i),
Value::Float32(f) => mysql_async::Value::Float(f.into()),
Value::Float64(f) => mysql_async::Value::Double(f.into()),
Value::String(s) => mysql_async::Value::Bytes(s.as_utf8().as_bytes().to_vec()),
Value::Binary(b) => mysql_async::Value::Bytes(b.to_vec()),
_ => return None,
};
Some(v)
}
async fn create_connection_default_db_name(
port: u16,
ssl: bool,