mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-29 11:20:38 +00:00
feat: PREPARE and EXECUTE statement from mysql client (#4125)
* feat: prepare stmt in mysql client * feat: execute stmt in mysql client * fix: handle parameters properly * refactor: use existing funcs to convert expr to scalar value * refactor: use uuid strings as stmt_key for queries from COM_PREPARE packet * refactor: take prepare and execute parser as submodule * test: add unit test for converting expr to scalar value * feat: deallocate stmt in mysql client * chore: comments and duplicates --------- Co-authored-by: dennis zhuang <killme2008@gmail.com>
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -9938,6 +9938,7 @@ dependencies = [
|
||||
"tower",
|
||||
"tower-http",
|
||||
"urlencoding",
|
||||
"uuid",
|
||||
"zstd 0.13.1",
|
||||
]
|
||||
|
||||
|
||||
@@ -105,6 +105,7 @@ tonic-reflection = "0.11"
|
||||
tower = { workspace = true, features = ["full"] }
|
||||
tower-http = { version = "0.4", features = ["full"] }
|
||||
urlencoding = "2.1"
|
||||
uuid.workspace = true
|
||||
zstd.workspace = true
|
||||
|
||||
[target.'cfg(not(windows))'.dependencies]
|
||||
|
||||
@@ -59,7 +59,7 @@ pub struct MysqlInstanceShim {
|
||||
salt: [u8; 20],
|
||||
session: SessionRef,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
prepared_stmts: Arc<RwLock<HashMap<u32, SqlPlan>>>,
|
||||
prepared_stmts: Arc<RwLock<HashMap<String, SqlPlan>>>,
|
||||
prepared_stmts_counter: AtomicU32,
|
||||
}
|
||||
|
||||
@@ -134,18 +134,88 @@ impl MysqlInstanceShim {
|
||||
self.query_handler.do_describe(statement, query_ctx).await
|
||||
}
|
||||
|
||||
/// Save query and logical plan, return the unique id
|
||||
fn save_plan(&self, plan: SqlPlan) -> u32 {
|
||||
let stmt_id = self.prepared_stmts_counter.fetch_add(1, Ordering::Relaxed);
|
||||
/// Save query and logical plan with a given statement key
|
||||
fn save_plan(&self, plan: SqlPlan, stmt_key: String) {
|
||||
let mut prepared_stmts = self.prepared_stmts.write();
|
||||
let _ = prepared_stmts.insert(stmt_id, plan);
|
||||
stmt_id
|
||||
let _ = prepared_stmts.insert(stmt_key, plan);
|
||||
}
|
||||
|
||||
/// Retrieve the query and logical plan by id
|
||||
fn plan(&self, stmt_id: u32) -> Option<SqlPlan> {
|
||||
/// Retrieve the query and logical plan by a given statement key
|
||||
fn plan(&self, stmt_key: String) -> Option<SqlPlan> {
|
||||
let guard = self.prepared_stmts.read();
|
||||
guard.get(&stmt_id).cloned()
|
||||
guard.get(&stmt_key).cloned()
|
||||
}
|
||||
|
||||
/// Save the prepared statement and return the parameters and result columns
|
||||
async fn do_prepare(
|
||||
&mut self,
|
||||
raw_query: &str,
|
||||
query_ctx: QueryContextRef,
|
||||
stmt_key: String,
|
||||
) -> Result<(Vec<Column>, Vec<Column>)> {
|
||||
let (query, param_num) = replace_placeholders(raw_query);
|
||||
|
||||
let statement = validate_query(raw_query).await?;
|
||||
|
||||
// We have to transform the placeholder, because DataFusion only parses placeholders
|
||||
// in the form of "$i", it can't process "?" right now.
|
||||
let statement = transform_placeholders(statement);
|
||||
|
||||
let describe_result = self
|
||||
.do_describe(statement.clone(), query_ctx.clone())
|
||||
.await?;
|
||||
let (plan, schema) = if let Some(DescribeResult {
|
||||
logical_plan,
|
||||
schema,
|
||||
}) = describe_result
|
||||
{
|
||||
(Some(logical_plan), Some(schema))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
let params = if let Some(plan) = &plan {
|
||||
prepared_params(
|
||||
&plan
|
||||
.get_param_types()
|
||||
.context(error::GetPreparedStmtParamsSnafu)?,
|
||||
)?
|
||||
} else {
|
||||
dummy_params(param_num)?
|
||||
};
|
||||
|
||||
debug_assert_eq!(params.len(), param_num - 1);
|
||||
|
||||
let columns = schema
|
||||
.as_ref()
|
||||
.map(|schema| {
|
||||
schema
|
||||
.column_schemas()
|
||||
.iter()
|
||||
.map(|column_schema| {
|
||||
create_mysql_column(&column_schema.data_type, &column_schema.name)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or_default();
|
||||
|
||||
self.save_plan(
|
||||
SqlPlan {
|
||||
query: query.to_string(),
|
||||
plan,
|
||||
schema,
|
||||
},
|
||||
stmt_key,
|
||||
);
|
||||
|
||||
Ok((params, columns))
|
||||
}
|
||||
|
||||
/// Remove the prepared statement by a given statement key
|
||||
fn do_close(&mut self, stmt_key: String) {
|
||||
let mut guard = self.prepared_stmts.write();
|
||||
let _ = guard.remove(&stmt_key);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -210,59 +280,11 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
|
||||
w: StatementMetaWriter<'a, W>,
|
||||
) -> Result<()> {
|
||||
let query_ctx = self.session.new_query_context();
|
||||
let (query, param_num) = replace_placeholders(raw_query);
|
||||
|
||||
let statement = validate_query(raw_query).await?;
|
||||
|
||||
// We have to transform the placeholder, because DataFusion only parses placeholders
|
||||
// in the form of "$i", it can't process "?" right now.
|
||||
let statement = transform_placeholders(statement);
|
||||
|
||||
let describe_result = self
|
||||
.do_describe(statement.clone(), query_ctx.clone())
|
||||
let stmt_id = self.prepared_stmts_counter.fetch_add(1, Ordering::Relaxed);
|
||||
let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
|
||||
let (params, columns) = self
|
||||
.do_prepare(raw_query, query_ctx.clone(), stmt_key)
|
||||
.await?;
|
||||
let (plan, schema) = if let Some(DescribeResult {
|
||||
logical_plan,
|
||||
schema,
|
||||
}) = describe_result
|
||||
{
|
||||
(Some(logical_plan), Some(schema))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
let params = if let Some(plan) = &plan {
|
||||
prepared_params(
|
||||
&plan
|
||||
.get_param_types()
|
||||
.context(error::GetPreparedStmtParamsSnafu)?,
|
||||
)?
|
||||
} else {
|
||||
dummy_params(param_num)?
|
||||
};
|
||||
|
||||
debug_assert_eq!(params.len(), param_num - 1);
|
||||
|
||||
let columns = schema
|
||||
.as_ref()
|
||||
.map(|schema| {
|
||||
schema
|
||||
.column_schemas()
|
||||
.iter()
|
||||
.map(|column_schema| {
|
||||
create_mysql_column(&column_schema.data_type, &column_schema.name)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or_default();
|
||||
|
||||
let stmt_id = self.save_plan(SqlPlan {
|
||||
query: query.to_string(),
|
||||
plan,
|
||||
schema,
|
||||
});
|
||||
|
||||
w.reply(stmt_id, ¶ms, &columns).await?;
|
||||
crate::metrics::METRIC_MYSQL_PREPARED_COUNT
|
||||
.with_label_values(&[query_ctx.get_db_string().as_str()])
|
||||
@@ -283,11 +305,12 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
|
||||
.start_timer();
|
||||
|
||||
let params: Vec<ParamValue> = p.into_iter().collect();
|
||||
let sql_plan = match self.plan(stmt_id) {
|
||||
let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
|
||||
let sql_plan = match self.plan(stmt_key) {
|
||||
None => {
|
||||
w.error(
|
||||
ErrorKind::ER_UNKNOWN_STMT_HANDLER,
|
||||
b"prepare statement not exist",
|
||||
b"prepare statement not found",
|
||||
)
|
||||
.await?;
|
||||
return Ok(());
|
||||
@@ -334,7 +357,11 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
|
||||
]
|
||||
}
|
||||
None => {
|
||||
let query = replace_params(params, sql_plan.query);
|
||||
let param_strs = params
|
||||
.iter()
|
||||
.map(|x| convert_param_value_to_string(x))
|
||||
.collect();
|
||||
let query = replace_params(param_strs, sql_plan.query);
|
||||
debug!("Mysql execute replaced query: {}", query);
|
||||
self.do_query(&query, query_ctx.clone()).await
|
||||
}
|
||||
@@ -349,8 +376,8 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
|
||||
where
|
||||
W: 'async_trait,
|
||||
{
|
||||
let mut guard = self.prepared_stmts.write();
|
||||
let _ = guard.remove(&stmt_id);
|
||||
let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
|
||||
self.do_close(stmt_key);
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, fields(protocol = "mysql"))]
|
||||
@@ -364,6 +391,130 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
|
||||
let _timer = crate::metrics::METRIC_MYSQL_QUERY_TIMER
|
||||
.with_label_values(&[crate::metrics::METRIC_MYSQL_TEXTQUERY, db.as_str()])
|
||||
.start_timer();
|
||||
|
||||
let query_upcase = query.to_uppercase();
|
||||
if query_upcase.starts_with("PREPARE ") {
|
||||
match ParserContext::parse_mysql_prepare_stmt(query, query_ctx.sql_dialect()) {
|
||||
Ok((stmt_name, stmt)) => {
|
||||
let prepare_results =
|
||||
self.do_prepare(&stmt, query_ctx.clone(), stmt_name).await;
|
||||
match prepare_results {
|
||||
Ok(_) => {
|
||||
let outputs = vec![Ok(Output::new_with_affected_rows(0))];
|
||||
writer::write_output(writer, query_ctx, outputs).await?;
|
||||
return Ok(());
|
||||
}
|
||||
Err(e) => {
|
||||
writer
|
||||
.error(ErrorKind::ER_SP_BADSTATEMENT, e.output_msg().as_bytes())
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
writer
|
||||
.error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes())
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
} else if query_upcase.starts_with("EXECUTE ") {
|
||||
match ParserContext::parse_mysql_execute_stmt(query, query_ctx.sql_dialect()) {
|
||||
// TODO: similar to on_execute, refactor this
|
||||
Ok((stmt_name, params)) => {
|
||||
let sql_plan = match self.plan(stmt_name) {
|
||||
None => {
|
||||
writer
|
||||
.error(
|
||||
ErrorKind::ER_UNKNOWN_STMT_HANDLER,
|
||||
b"prepare statement not found",
|
||||
)
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
Some(sql_plan) => sql_plan,
|
||||
};
|
||||
|
||||
let outputs = match sql_plan.plan {
|
||||
Some(plan) => {
|
||||
let param_types = plan
|
||||
.get_param_types()
|
||||
.context(error::GetPreparedStmtParamsSnafu)?;
|
||||
|
||||
if params.len() != param_types.len() {
|
||||
writer
|
||||
.error(
|
||||
ErrorKind::ER_SP_BADSTATEMENT,
|
||||
b"prepare statement params number mismatch",
|
||||
)
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let plan = match replace_params_with_exprs(&plan, param_types, ¶ms)
|
||||
{
|
||||
Ok(plan) => plan,
|
||||
Err(e) => {
|
||||
if e.status_code().should_log_error() {
|
||||
error!(e; "params: {}", params
|
||||
.iter()
|
||||
.map(|x| format!("({:?})", x))
|
||||
.join(", "));
|
||||
}
|
||||
|
||||
writer
|
||||
.error(
|
||||
ErrorKind::ER_TRUNCATED_WRONG_VALUE,
|
||||
e.output_msg().as_bytes(),
|
||||
)
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
debug!("Mysql execute prepared plan: {}", plan.display_indent());
|
||||
vec![
|
||||
self.do_exec_plan(&sql_plan.query, plan, query_ctx.clone())
|
||||
.await,
|
||||
]
|
||||
}
|
||||
None => {
|
||||
let param_strs = params.iter().map(|x| x.to_string()).collect();
|
||||
let query = replace_params(param_strs, sql_plan.query);
|
||||
debug!("Mysql execute replaced query: {}", query);
|
||||
let outputs = self.do_query(&query, query_ctx.clone()).await;
|
||||
writer::write_output(writer, query_ctx, outputs).await?;
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
writer::write_output(writer, query_ctx, outputs).await?;
|
||||
return Ok(());
|
||||
}
|
||||
Err(e) => {
|
||||
writer
|
||||
.error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes())
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
} else if query_upcase.starts_with("DEALLOCATE ") {
|
||||
match ParserContext::parse_mysql_deallocate_stmt(query, query_ctx.sql_dialect()) {
|
||||
Ok(stmt_name) => {
|
||||
self.do_close(stmt_name);
|
||||
let outputs = vec![Ok(Output::new_with_affected_rows(0))];
|
||||
writer::write_output(writer, query_ctx, outputs).await?;
|
||||
return Ok(());
|
||||
}
|
||||
Err(e) => {
|
||||
writer
|
||||
.error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes())
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let outputs = self.do_query(query, query_ctx.clone()).await;
|
||||
writer::write_output(writer, query_ctx, outputs).await?;
|
||||
Ok(())
|
||||
@@ -420,21 +571,24 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
|
||||
}
|
||||
}
|
||||
|
||||
fn replace_params(params: Vec<ParamValue>, query: String) -> String {
|
||||
fn convert_param_value_to_string(param: &ParamValue) -> String {
|
||||
match param.value.into_inner() {
|
||||
ValueInner::Int(u) => u.to_string(),
|
||||
ValueInner::UInt(u) => u.to_string(),
|
||||
ValueInner::Double(u) => u.to_string(),
|
||||
ValueInner::NULL => "NULL".to_string(),
|
||||
ValueInner::Bytes(b) => format!("'{}'", &String::from_utf8_lossy(b)),
|
||||
ValueInner::Date(_) => NaiveDate::from(param.value).to_string(),
|
||||
ValueInner::Datetime(_) => NaiveDateTime::from(param.value).to_string(),
|
||||
ValueInner::Time(_) => format_duration(Duration::from(param.value)),
|
||||
}
|
||||
}
|
||||
|
||||
fn replace_params(params: Vec<String>, query: String) -> String {
|
||||
let mut query = query;
|
||||
let mut index = 1;
|
||||
for param in params {
|
||||
let s = match param.value.into_inner() {
|
||||
ValueInner::Int(u) => u.to_string(),
|
||||
ValueInner::UInt(u) => u.to_string(),
|
||||
ValueInner::Double(u) => u.to_string(),
|
||||
ValueInner::NULL => "NULL".to_string(),
|
||||
ValueInner::Bytes(b) => format!("'{}'", &String::from_utf8_lossy(b)),
|
||||
ValueInner::Date(_) => NaiveDate::from(param.value).to_string(),
|
||||
ValueInner::Datetime(_) => NaiveDateTime::from(param.value).to_string(),
|
||||
ValueInner::Time(_) => format_duration(Duration::from(param.value)),
|
||||
};
|
||||
query = query.replace(&format_placeholder(index), &s);
|
||||
query = query.replace(&format_placeholder(index), ¶m);
|
||||
index += 1;
|
||||
}
|
||||
query
|
||||
@@ -477,6 +631,33 @@ fn replace_params_with_values(
|
||||
.context(error::ReplacePreparedStmtParamsSnafu)
|
||||
}
|
||||
|
||||
fn replace_params_with_exprs(
|
||||
plan: &LogicalPlan,
|
||||
param_types: HashMap<String, Option<ConcreteDataType>>,
|
||||
params: &[sql::ast::Expr],
|
||||
) -> Result<LogicalPlan> {
|
||||
debug_assert_eq!(param_types.len(), params.len());
|
||||
|
||||
debug!(
|
||||
"replace_params_with_exprs(param_types: {:#?}, params: {:#?})",
|
||||
param_types,
|
||||
params.iter().map(|x| format!("({:?})", x)).join(", ")
|
||||
);
|
||||
|
||||
let mut values = Vec::with_capacity(params.len());
|
||||
|
||||
for (i, param) in params.iter().enumerate() {
|
||||
if let Some(Some(t)) = param_types.get(&format_placeholder(i + 1)) {
|
||||
let value = helper::convert_expr_to_scalar_value(param, t)?;
|
||||
|
||||
values.push(value);
|
||||
}
|
||||
}
|
||||
|
||||
plan.replace_params_with_values(&values)
|
||||
.context(error::ReplacePreparedStmtParamsSnafu)
|
||||
}
|
||||
|
||||
async fn validate_query(query: &str) -> Result<Statement> {
|
||||
let statement =
|
||||
ParserContext::create_with_dialect(query, &MySqlDialect {}, ParseOptions::default());
|
||||
|
||||
@@ -23,6 +23,7 @@ use itertools::Itertools;
|
||||
use opensrv_mysql::{to_naive_datetime, ParamValue, ValueInner};
|
||||
use snafu::ResultExt;
|
||||
use sql::ast::{visit_expressions_mut, Expr, Value as ValueExpr, VisitMut};
|
||||
use sql::statements::sql_value_to_value;
|
||||
use sql::statements::statement::Statement;
|
||||
|
||||
use crate::error::{self, Result};
|
||||
@@ -201,6 +202,27 @@ pub fn convert_value(param: &ParamValue, t: &ConcreteDataType) -> Result<ScalarV
|
||||
}
|
||||
}
|
||||
|
||||
pub fn convert_expr_to_scalar_value(param: &Expr, t: &ConcreteDataType) -> Result<ScalarValue> {
|
||||
match param {
|
||||
Expr::Value(v) => {
|
||||
let v = sql_value_to_value("", t, v, None);
|
||||
match v {
|
||||
Ok(v) => v
|
||||
.try_to_scalar_value(t)
|
||||
.context(error::ConvertScalarValueSnafu),
|
||||
Err(e) => error::InvalidParameterSnafu {
|
||||
reason: e.to_string(),
|
||||
}
|
||||
.fail(),
|
||||
}
|
||||
}
|
||||
_ => error::InvalidParameterSnafu {
|
||||
reason: format!("cannot convert {:?} to scalar value of type {}", param, t),
|
||||
}
|
||||
.fail(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use sql::dialect::MySqlDialect;
|
||||
@@ -265,4 +287,45 @@ mod tests {
|
||||
};
|
||||
assert_eq!("SELECT from AS demo WHERE host = $1 AND idc IN (SELECT idc FROM idcs WHERE name = $2) AND cpu > $3", select.inner.to_string());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_expr_to_scalar_value() {
|
||||
let expr = Expr::Value(ValueExpr::Number("123".to_string(), false));
|
||||
let t = ConcreteDataType::int32_datatype();
|
||||
let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
|
||||
assert_eq!(ScalarValue::Int32(Some(123)), v);
|
||||
|
||||
let expr = Expr::Value(ValueExpr::Number("123.456789".to_string(), false));
|
||||
let t = ConcreteDataType::float64_datatype();
|
||||
let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
|
||||
assert_eq!(ScalarValue::Float64(Some(123.456789)), v);
|
||||
|
||||
let expr = Expr::Value(ValueExpr::SingleQuotedString("2001-01-02".to_string()));
|
||||
let t = ConcreteDataType::date_datatype();
|
||||
let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
|
||||
let scalar_v = ScalarValue::Utf8(Some("2001-01-02".to_string()))
|
||||
.cast_to(&arrow_schema::DataType::Date32)
|
||||
.unwrap();
|
||||
assert_eq!(scalar_v, v);
|
||||
|
||||
let expr = Expr::Value(ValueExpr::SingleQuotedString(
|
||||
"2001-01-02 03:04:05".to_string(),
|
||||
));
|
||||
let t = ConcreteDataType::datetime_datatype();
|
||||
let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
|
||||
let scalar_v = ScalarValue::Utf8(Some("2001-01-02 03:04:05".to_string()))
|
||||
.cast_to(&arrow_schema::DataType::Date64)
|
||||
.unwrap();
|
||||
assert_eq!(scalar_v, v);
|
||||
|
||||
let expr = Expr::Value(ValueExpr::SingleQuotedString("hello".to_string()));
|
||||
let t = ConcreteDataType::string_datatype();
|
||||
let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
|
||||
assert_eq!(ScalarValue::Utf8(Some("hello".to_string())), v);
|
||||
|
||||
let expr = Expr::Value(ValueExpr::Null);
|
||||
let t = ConcreteDataType::time_microsecond_datatype();
|
||||
let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
|
||||
assert_eq!(ScalarValue::Time64Microsecond(None), v);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -175,6 +175,27 @@ impl<'a> ParserContext<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Parses MySQL style 'PREPARE stmt_name FROM stmt' into a (stmt_name, stmt) tuple.
|
||||
pub fn parse_mysql_prepare_stmt(
|
||||
sql: &'a str,
|
||||
dialect: &dyn Dialect,
|
||||
) -> Result<(String, String)> {
|
||||
ParserContext::new(dialect, sql)?.parse_mysql_prepare()
|
||||
}
|
||||
|
||||
/// Parses MySQL style 'EXECUTE stmt_name USING param_list' into a stmt_name string and a list of parameters.
|
||||
pub fn parse_mysql_execute_stmt(
|
||||
sql: &'a str,
|
||||
dialect: &dyn Dialect,
|
||||
) -> Result<(String, Vec<Expr>)> {
|
||||
ParserContext::new(dialect, sql)?.parse_mysql_execute()
|
||||
}
|
||||
|
||||
/// Parses MySQL style 'DEALLOCATE stmt_name' into a stmt_name string.
|
||||
pub fn parse_mysql_deallocate_stmt(sql: &'a str, dialect: &dyn Dialect) -> Result<String> {
|
||||
ParserContext::new(dialect, sql)?.parse_deallocate()
|
||||
}
|
||||
|
||||
/// Raises an "unsupported statement" error.
|
||||
pub fn unsupported<T>(&self, keyword: String) -> Result<T> {
|
||||
error::UnsupportedSnafu {
|
||||
@@ -257,6 +278,7 @@ impl<'a> ParserContext<'a> {
|
||||
mod tests {
|
||||
|
||||
use datatypes::prelude::ConcreteDataType;
|
||||
use sqlparser::dialect::MySqlDialect;
|
||||
|
||||
use super::*;
|
||||
use crate::dialect::GreptimeDbDialect;
|
||||
@@ -351,4 +373,57 @@ mod tests {
|
||||
assert_eq!(object_name.0.len(), 1);
|
||||
assert_eq!(object_name.to_string(), table_name.to_ascii_lowercase());
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_parse_mysql_prepare_stmt() {
|
||||
let sql = "PREPARE stmt1 FROM 'SELECT * FROM t1 WHERE id = ?';";
|
||||
let (stmt_name, stmt) =
|
||||
ParserContext::parse_mysql_prepare_stmt(sql, &MySqlDialect {}).unwrap();
|
||||
assert_eq!(stmt_name, "stmt1");
|
||||
assert_eq!(stmt, "SELECT * FROM t1 WHERE id = ?");
|
||||
|
||||
let sql = "PREPARE stmt2 FROM \"SELECT * FROM t1 WHERE id = ?\"";
|
||||
let (stmt_name, stmt) =
|
||||
ParserContext::parse_mysql_prepare_stmt(sql, &MySqlDialect {}).unwrap();
|
||||
assert_eq!(stmt_name, "stmt2");
|
||||
assert_eq!(stmt, "SELECT * FROM t1 WHERE id = ?");
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_parse_mysql_execute_stmt() {
|
||||
let sql = "EXECUTE stmt1 USING 1, 'hello';";
|
||||
let (stmt_name, params) =
|
||||
ParserContext::parse_mysql_execute_stmt(sql, &GreptimeDbDialect {}).unwrap();
|
||||
assert_eq!(stmt_name, "stmt1");
|
||||
assert_eq!(params.len(), 2);
|
||||
assert_eq!(params[0].to_string(), "1");
|
||||
assert_eq!(params[1].to_string(), "'hello'");
|
||||
|
||||
let sql = "EXECUTE stmt2;";
|
||||
let (stmt_name, params) =
|
||||
ParserContext::parse_mysql_execute_stmt(sql, &GreptimeDbDialect {}).unwrap();
|
||||
assert_eq!(stmt_name, "stmt2");
|
||||
assert_eq!(params.len(), 0);
|
||||
|
||||
let sql = "EXECUTE stmt3 USING 231, 'hello', \"2003-03-1\", NULL, ;";
|
||||
let (stmt_name, params) =
|
||||
ParserContext::parse_mysql_execute_stmt(sql, &GreptimeDbDialect {}).unwrap();
|
||||
assert_eq!(stmt_name, "stmt3");
|
||||
assert_eq!(params.len(), 4);
|
||||
assert_eq!(params[0].to_string(), "231");
|
||||
assert_eq!(params[1].to_string(), "'hello'");
|
||||
assert_eq!(params[2].to_string(), "\"2003-03-1\"");
|
||||
assert_eq!(params[3].to_string(), "NULL");
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_parse_mysql_deallocate_stmt() {
|
||||
let sql = "DEALLOCATE stmt1;";
|
||||
let stmt_name = ParserContext::parse_mysql_deallocate_stmt(sql, &MySqlDialect {}).unwrap();
|
||||
assert_eq!(stmt_name, "stmt1");
|
||||
|
||||
let sql = "DEALLOCATE stmt2";
|
||||
let stmt_name = ParserContext::parse_mysql_deallocate_stmt(sql, &MySqlDialect {}).unwrap();
|
||||
assert_eq!(stmt_name, "stmt2");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,12 +15,15 @@
|
||||
mod alter_parser;
|
||||
pub(crate) mod copy_parser;
|
||||
pub(crate) mod create_parser;
|
||||
pub(crate) mod deallocate_parser;
|
||||
pub(crate) mod delete_parser;
|
||||
pub(crate) mod describe_parser;
|
||||
pub(crate) mod drop_parser;
|
||||
pub(crate) mod error;
|
||||
pub(crate) mod execute_parser;
|
||||
pub(crate) mod explain_parser;
|
||||
pub(crate) mod insert_parser;
|
||||
pub(crate) mod prepare_parser;
|
||||
pub(crate) mod query_parser;
|
||||
pub(crate) mod set_var_parser;
|
||||
pub(crate) mod show_parser;
|
||||
|
||||
30
src/sql/src/parsers/deallocate_parser.rs
Normal file
30
src/sql/src/parsers/deallocate_parser.rs
Normal file
@@ -0,0 +1,30 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use snafu::ResultExt;
|
||||
use sqlparser::keywords::Keyword;
|
||||
|
||||
use crate::error::{Result, SyntaxSnafu};
|
||||
use crate::parser::ParserContext;
|
||||
|
||||
impl<'a> ParserContext<'a> {
|
||||
/// Parses MySQL style 'PREPARE stmt_name' into a stmt_name string.
|
||||
pub(crate) fn parse_deallocate(&mut self) -> Result<String> {
|
||||
self.parser
|
||||
.expect_keyword(Keyword::DEALLOCATE)
|
||||
.context(SyntaxSnafu)?;
|
||||
let stmt_name = self.parser.parse_identifier(false).context(SyntaxSnafu)?;
|
||||
Ok(stmt_name.value)
|
||||
}
|
||||
}
|
||||
41
src/sql/src/parsers/execute_parser.rs
Normal file
41
src/sql/src/parsers/execute_parser.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use snafu::ResultExt;
|
||||
use sqlparser::ast::Expr;
|
||||
use sqlparser::keywords::Keyword;
|
||||
use sqlparser::parser::Parser;
|
||||
|
||||
use crate::error::{Result, SyntaxSnafu};
|
||||
use crate::parser::ParserContext;
|
||||
|
||||
impl<'a> ParserContext<'a> {
|
||||
/// Parses MySQL style 'EXECUTE stmt_name USING param_list' into a stmt_name string and a list of parameters.
|
||||
/// Only use for MySQL. for PostgreSQL, use `sqlparser::parser::Parser::parse_execute` instead.
|
||||
pub(crate) fn parse_mysql_execute(&mut self) -> Result<(String, Vec<Expr>)> {
|
||||
self.parser
|
||||
.expect_keyword(Keyword::EXECUTE)
|
||||
.context(SyntaxSnafu)?;
|
||||
let stmt_name = self.parser.parse_identifier(false).context(SyntaxSnafu)?;
|
||||
if self.parser.parse_keyword(Keyword::USING) {
|
||||
let param_list = self
|
||||
.parser
|
||||
.parse_comma_separated(Parser::parse_expr)
|
||||
.context(SyntaxSnafu)?;
|
||||
Ok((stmt_name.value, param_list))
|
||||
} else {
|
||||
Ok((stmt_name.value, vec![]))
|
||||
}
|
||||
}
|
||||
}
|
||||
46
src/sql/src/parsers/prepare_parser.rs
Normal file
46
src/sql/src/parsers/prepare_parser.rs
Normal file
@@ -0,0 +1,46 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use snafu::ResultExt;
|
||||
use sqlparser::keywords::Keyword;
|
||||
use sqlparser::tokenizer::Token;
|
||||
|
||||
use crate::error::{Result, SyntaxSnafu};
|
||||
use crate::parser::ParserContext;
|
||||
|
||||
impl<'a> ParserContext<'a> {
|
||||
/// Parses MySQL style 'PREPARE stmt_name FROM stmt' into a (stmt_name, stmt) tuple.
|
||||
/// Only use for MySQL. for PostgreSQL, use `sqlparser::parser::Parser::parse_prepare` instead.
|
||||
pub(crate) fn parse_mysql_prepare(&mut self) -> Result<(String, String)> {
|
||||
self.parser
|
||||
.expect_keyword(Keyword::PREPARE)
|
||||
.context(SyntaxSnafu)?;
|
||||
let stmt_name = self.parser.parse_identifier(false).context(SyntaxSnafu)?;
|
||||
self.parser
|
||||
.expect_keyword(Keyword::FROM)
|
||||
.context(SyntaxSnafu)?;
|
||||
let next_token = self.parser.peek_token();
|
||||
let stmt = match next_token.token {
|
||||
Token::SingleQuotedString(s) | Token::DoubleQuotedString(s) => {
|
||||
let _ = self.parser.next_token();
|
||||
s
|
||||
}
|
||||
_ => self
|
||||
.parser
|
||||
.expected("string literal", next_token)
|
||||
.context(SyntaxSnafu)?,
|
||||
};
|
||||
Ok((stmt_name.value, stmt))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user