From 30472cebaeee399ce163780ddb9714806c012ebf Mon Sep 17 00:00:00 2001 From: dennis zhuang Date: Tue, 20 Jun 2023 12:07:28 +0800 Subject: [PATCH] feat: prepare supports caching logical plan and inferring param types (#1776) * feat: change do_describe function signature * feat: infer param type and cache logical plan for mysql prepared statements * fix: convert_value * fix: forgot helper * chore: comments * fix: typo * test: add more tests and test date, datetime in mysql * chore: fix CR comments * chore: add location * chore: by CR comments * Update tests-integration/tests/sql.rs Co-authored-by: Ruihang Xia * chore: remove the trace --------- Co-authored-by: Ruihang Xia --- Cargo.lock | 4 + src/common/time/src/date.rs | 6 + src/datatypes/src/data_type.rs | 6 + src/datatypes/src/value.rs | 6 +- src/frontend/src/instance.rs | 22 +- src/frontend/src/metrics.rs | 1 + src/query/src/datafusion.rs | 18 +- src/query/src/plan.rs | 29 ++- src/query/src/planner.rs | 1 + src/query/src/query_engine.rs | 11 +- src/servers/Cargo.toml | 3 + src/servers/src/error.rs | 52 ++++- src/servers/src/http.rs | 12 +- src/servers/src/mysql.rs | 1 + src/servers/src/mysql/handler.rs | 184 +++++++++++++---- src/servers/src/mysql/helper.rs | 238 ++++++++++++++++++++++ src/servers/src/mysql/writer.rs | 24 +-- src/servers/src/postgres/handler.rs | 3 +- src/servers/src/query_handler/sql.rs | 21 +- src/servers/tests/http/influxdb_test.rs | 13 +- src/servers/tests/http/opentsdb_test.rs | 13 +- src/servers/tests/http/prometheus_test.rs | 13 +- src/servers/tests/mod.rs | 9 +- src/sql/src/ast.rs | 6 +- tests-integration/Cargo.toml | 2 + tests-integration/tests/sql.rs | 50 ++++- 26 files changed, 650 insertions(+), 98 deletions(-) create mode 100644 src/servers/src/mysql/helper.rs diff --git a/Cargo.lock b/Cargo.lock index 36ace61d41..d17d0c05bd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8537,6 +8537,8 @@ dependencies = [ "common-test-util", "common-time", "datafusion", + 
"datafusion-common", + 
"datafusion-expr", "datatypes", "derive_builder 0.12.0", "digest", @@ -8973,6 +8975,7 @@ dependencies = [ "bitflags 1.3.2", "byteorder", "bytes", + "chrono", "crc", "crossbeam-queue", "digest", @@ -9553,6 +9556,7 @@ dependencies = [ "axum", "axum-test-helper", "catalog", + "chrono", "client", "common-base", "common-catalog", diff --git a/src/common/time/src/date.rs b/src/common/time/src/date.rs index fff9f412db..4540490111 100644 --- a/src/common/time/src/date.rs +++ b/src/common/time/src/date.rs @@ -52,6 +52,12 @@ impl From for Date { } } +impl From for Date { + fn from(date: NaiveDate) -> Self { + Self(date.num_days_from_ce() - UNIX_EPOCH_FROM_CE) + } +} + impl Display for Date { /// [Date] is formatted according to ISO-8601 standard. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { diff --git a/src/datatypes/src/data_type.rs b/src/datatypes/src/data_type.rs index f9e40880c8..f0d225fdb9 100644 --- a/src/datatypes/src/data_type.rs +++ b/src/datatypes/src/data_type.rs @@ -183,6 +183,12 @@ impl ConcreteDataType { } } +impl From<&ConcreteDataType> for ConcreteDataType { + fn from(t: &ConcreteDataType) -> Self { + t.clone() + } +} + impl TryFrom<&ArrowDataType> for ConcreteDataType { type Error = Error; diff --git a/src/datatypes/src/value.rs b/src/datatypes/src/value.rs index 733b853808..319da1066d 100644 --- a/src/datatypes/src/value.rs +++ b/src/datatypes/src/value.rs @@ -248,7 +248,7 @@ impl Value { Value::Binary(v) => ScalarValue::LargeBinary(Some(v.to_vec())), Value::Date(v) => ScalarValue::Date32(Some(v.val())), Value::DateTime(v) => ScalarValue::Date64(Some(v.val())), - Value::Null => to_null_value(output_type), + Value::Null => to_null_scalar_value(output_type), Value::List(list) => { // Safety: The logical type of the value and output_type are the same. 
let list_type = output_type.as_list().unwrap(); @@ -261,7 +261,7 @@ impl Value { } } -fn to_null_value(output_type: &ConcreteDataType) -> ScalarValue { +pub fn to_null_scalar_value(output_type: &ConcreteDataType) -> ScalarValue { match output_type { ConcreteDataType::Null(_) => ScalarValue::Null, ConcreteDataType::Boolean(_) => ScalarValue::Boolean(None), @@ -285,7 +285,7 @@ fn to_null_value(output_type: &ConcreteDataType) -> ScalarValue { } ConcreteDataType::Dictionary(dict) => ScalarValue::Dictionary( Box::new(dict.key_type().as_arrow_type()), - Box::new(to_null_value(dict.value_type())), + Box::new(to_null_scalar_value(dict.value_type())), ), } } diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index f76a44c00a..3507b7e36e 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -53,7 +53,9 @@ use meta_client::MetaClientOptions; use partition::manager::PartitionRuleManager; use partition::route::TableRoutes; use query::parser::{PromQuery, QueryLanguageParser, QueryStatement}; +use query::plan::LogicalPlan; use query::query_engine::options::{validate_catalog_and_schema, QueryOptions}; +use query::query_engine::DescribeResult; use query::{QueryEngineFactory, QueryEngineRef}; use servers::error as server_error; use servers::error::{ExecuteQuerySnafu, ParsePromQLSnafu}; @@ -73,8 +75,9 @@ use sql::statements::statement::Statement; use crate::catalog::FrontendCatalogManager; use crate::error::{ - self, Error, ExecutePromqlSnafu, ExternalSnafu, InvalidInsertRequestSnafu, - MissingMetasrvOptsSnafu, ParseSqlSnafu, PlanStatementSnafu, Result, SqlExecInterceptedSnafu, + self, Error, ExecLogicalPlanSnafu, ExecutePromqlSnafu, ExternalSnafu, + InvalidInsertRequestSnafu, MissingMetasrvOptsSnafu, ParseSqlSnafu, PlanStatementSnafu, Result, + SqlExecInterceptedSnafu, }; use crate::expr_factory::{CreateExprFactoryRef, DefaultCreateExprFactory}; use crate::frontend::FrontendOptions; @@ -506,6 +509,14 @@ impl SqlQueryHandler for 
Instance { } } + async fn do_exec_plan(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result { + let _timer = timer!(metrics::METRIC_EXEC_PLAN_ELAPSED); + self.query_engine + .execute(plan, query_ctx) + .await + .context(ExecLogicalPlanSnafu) + } + async fn do_promql_query( &self, query: &PromQuery, @@ -523,8 +534,11 @@ impl SqlQueryHandler for Instance { &self, stmt: Statement, query_ctx: QueryContextRef, - ) -> Result> { - if let Statement::Query(_) = stmt { + ) -> Result> { + if matches!( + stmt, + Statement::Insert(_) | Statement::Query(_) | Statement::Delete(_) + ) { let plan = self .query_engine .planner() diff --git a/src/frontend/src/metrics.rs b/src/frontend/src/metrics.rs index 61ffab4089..cb7745d8c0 100644 --- a/src/frontend/src/metrics.rs +++ b/src/frontend/src/metrics.rs @@ -13,6 +13,7 @@ // limitations under the License. pub(crate) const METRIC_HANDLE_SQL_ELAPSED: &str = "frontend.handle_sql_elapsed"; +pub(crate) const METRIC_EXEC_PLAN_ELAPSED: &str = "frontend.exec_plan_elapsed"; pub(crate) const METRIC_HANDLE_SCRIPTS_ELAPSED: &str = "frontend.handle_scripts_elapsed"; pub(crate) const METRIC_RUN_SCRIPT_ELAPSED: &str = "frontend.run_script_elapsed"; diff --git a/src/query/src/datafusion.rs b/src/query/src/datafusion.rs index efc868d1cd..4931ac58eb 100644 --- a/src/query/src/datafusion.rs +++ b/src/query/src/datafusion.rs @@ -37,7 +37,6 @@ use datafusion::physical_plan::ExecutionPlan; use datafusion_common::ResolvedTableReference; use datafusion_expr::{DmlStatement, LogicalPlan as DfLogicalPlan, WriteOp}; use datatypes::prelude::VectorRef; -use datatypes::schema::Schema; use futures_util::StreamExt; use session::context::QueryContextRef; use snafu::{ensure, OptionExt, ResultExt}; @@ -57,7 +56,7 @@ use crate::physical_optimizer::PhysicalOptimizer; use crate::physical_planner::PhysicalPlanner; use crate::plan::LogicalPlan; use crate::planner::{DfLogicalPlanner, LogicalPlanner}; -use crate::query_engine::{QueryEngineContext, QueryEngineState}; 
+use crate::query_engine::{DescribeResult, QueryEngineContext, QueryEngineState}; use crate::{metrics, QueryEngine}; pub struct DatafusionQueryEngine { @@ -221,11 +220,12 @@ impl QueryEngine for DatafusionQueryEngine { "datafusion" } - async fn describe(&self, plan: LogicalPlan) -> Result { - // TODO(sunng87): consider cache optmised logical plan between describe - // and execute + async fn describe(&self, plan: LogicalPlan) -> Result { let optimised_plan = self.optimize(&plan)?; - optimised_plan.schema() + Ok(DescribeResult { + schema: optimised_plan.schema()?, + logical_plan: optimised_plan, + }) } async fn execute(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result { @@ -540,7 +540,10 @@ mod tests { .await .unwrap(); - let schema = engine.describe(plan).await.unwrap(); + let DescribeResult { + schema, + logical_plan, + } = engine.describe(plan).await.unwrap(); assert_eq!( schema.column_schemas()[0], @@ -550,5 +553,6 @@ mod tests { true ) ); + assert_eq!("Limit: skip=0, fetch=20\n Aggregate: groupBy=[[]], aggr=[[SUM(numbers.number)]]\n TableScan: numbers projection=[number]", format!("{}", logical_plan.display_indent())); } } diff --git a/src/query/src/plan.rs b/src/query/src/plan.rs index b24ddc4504..14ff331122 100644 --- a/src/query/src/plan.rs +++ b/src/query/src/plan.rs @@ -12,13 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use std::collections::HashMap; use std::fmt::{Debug, Display}; +use common_query::prelude::ScalarValue; use datafusion_expr::LogicalPlan as DfLogicalPlan; +use datatypes::data_type::ConcreteDataType; use datatypes::schema::Schema; use snafu::ResultExt; -use crate::error::{ConvertDatafusionSchemaSnafu, Result}; +use crate::error::{ConvertDatafusionSchemaSnafu, DataFusionSnafu, Result}; /// A LogicalPlan represents the different types of relational /// operators (such as Projection, Filter, etc) and can be created by @@ -59,4 +62,28 @@ impl LogicalPlan { let LogicalPlan::DfPlan(plan) = self; plan.display_indent() } + + /// Walk the logical plan, find any `PlaceHolder` tokens, + /// and return a map of their IDs and ConcreteDataTypes + pub fn get_param_types(&self) -> Result>> { + let LogicalPlan::DfPlan(plan) = self; + let types = plan.get_parameter_types().context(DataFusionSnafu)?; + + Ok(types + .into_iter() + .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v)))) + .collect()) + } + + /// Return a logical plan with all placeholders/params (e.g $1 $2, + /// ...) 
replaced with corresponding values provided in the + /// params_values + pub fn replace_params_with_values(&self, values: &[ScalarValue]) -> Result { + let LogicalPlan::DfPlan(plan) = self; + + plan.clone() + .replace_params_with_values(values) + .context(DataFusionSnafu) + .map(LogicalPlan::DfPlan) + } } diff --git a/src/query/src/planner.rs b/src/query/src/planner.rs index 92131be37c..2ec425c24c 100644 --- a/src/query/src/planner.rs +++ b/src/query/src/planner.rs @@ -77,6 +77,7 @@ impl DfLogicalPlanner { }; PlanSqlSnafu { sql } })?; + Ok(LogicalPlan::DfPlan(result)) } diff --git a/src/query/src/query_engine.rs b/src/query/src/query_engine.rs index 5880f405c6..dbab74cf2b 100644 --- a/src/query/src/query_engine.rs +++ b/src/query/src/query_engine.rs @@ -43,6 +43,15 @@ pub use crate::query_engine::state::QueryEngineState; pub type SqlStatementExecutorRef = Arc; +/// Describe statement result +#[derive(Debug)] +pub struct DescribeResult { + /// The schema of statement + pub schema: Schema, + /// The logical plan for statement + pub logical_plan: LogicalPlan, +} + #[async_trait] pub trait SqlStatementExecutor: Send + Sync { async fn execute_sql(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result; @@ -58,7 +67,7 @@ pub trait QueryEngine: Send + Sync { fn name(&self) -> &str; - async fn describe(&self, plan: LogicalPlan) -> Result; + async fn describe(&self, plan: LogicalPlan) -> Result; async fn execute(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result; diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index 9212a1cd1d..8621aa0522 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -33,6 +33,9 @@ common-runtime = { path = "../common/runtime" } common-telemetry = { path = "../common/telemetry" } common-time = { path = "../common/time" } datafusion.workspace = true +datafusion-common.workspace = true +datafusion-expr.workspace = true + datatypes = { path = "../datatypes" } derive_builder = "0.12" digest = "0.10" diff 
--git a/src/servers/src/error.rs b/src/servers/src/error.rs index 2f224da225..317953d970 100644 --- a/src/servers/src/error.rs +++ b/src/servers/src/error.rs @@ -11,7 +11,6 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - use std::any::Any; use std::net::SocketAddr; use std::string::FromUtf8Error; @@ -23,6 +22,7 @@ use base64::DecodeError; use catalog; use common_error::prelude::*; use common_telemetry::logging; +use datatypes::prelude::ConcreteDataType; use query::parser::PromQuery; use serde_json::json; use snafu::Location; @@ -75,6 +75,12 @@ pub enum Error { source: BoxedError, }, + #[snafu(display("Failed to execute plan, source: {}", source))] + ExecutePlan { + location: Location, + source: BoxedError, + }, + #[snafu(display("{source}"))] ExecuteGrpcQuery { location: Location, @@ -250,6 +256,12 @@ pub enum Error { source: query::error::Error, }, + #[snafu(display("Failed to get param types, source: {source}, location: {location}"))] + GetPreparedStmtParams { + source: query::error::Error, + location: Location, + }, + #[snafu(display("{}", reason))] UnexpectedResult { reason: String, location: Location }, @@ -269,10 +281,7 @@ pub enum Error { #[cfg(feature = "pprof")] #[snafu(display("Failed to dump pprof data, source: {}", source))] - DumpPprof { - #[snafu(backtrace)] - source: common_pprof::Error, - }, + DumpPprof { source: common_pprof::Error }, #[snafu(display("Failed to update jemalloc metrics, source: {source}, location: {location}"))] UpdateJemallocMetrics { @@ -285,6 +294,31 @@ pub enum Error { source: datafusion::error::DataFusionError, location: Location, }, + + #[snafu(display( + "Failed to replace params with values in prepared statement, source: {source}, location: {location}" + ))] + ReplacePreparedStmtParams { + source: query::error::Error, + location: Location, + }, + + #[snafu(display("Failed to convert scalar 
value, source: {source}, location: {location}"))] + ConvertScalarValue { + source: datatypes::error::Error, + location: Location, + }, + + #[snafu(display( + "Expected type: {:?}, actual: {:?}, location: {location}", + expected, + actual + ))] + PreparedStmtTypeMismatch { + expected: ConcreteDataType, + actual: opensrv_mysql::ColumnType, + location: Location, + }, } pub type Result = std::result::Result; @@ -309,6 +343,7 @@ impl ErrorExt for Error { InsertScript { source, .. } | ExecuteScript { source, .. } | ExecuteQuery { source, .. } + | ExecutePlan { source, .. } | ExecuteGrpcQuery { source, .. } | CheckDatabaseValidity { source, .. } => source.status_code(), @@ -324,6 +359,7 @@ impl ErrorExt for Error { | InvalidFlightTicket { .. } | InvalidPrepareStatement { .. } | DataFrame { .. } + | PreparedStmtTypeMismatch { .. } | TimePrecision { .. } => StatusCode::InvalidArguments, InfluxdbLinesWrite { source, .. } | PromSeriesWrite { source, .. } => { @@ -347,7 +383,9 @@ impl ErrorExt for Error { DumpProfileData { source, .. } => source.status_code(), InvalidFlushArgument { .. } => StatusCode::InvalidArguments, - ParsePromQL { source, .. } => source.status_code(), + ReplacePreparedStmtParams { source, .. } + | GetPreparedStmtParams { source, .. } + | ParsePromQL { source, .. } => source.status_code(), Other { source, .. } => source.status_code(), UnexpectedResult { .. } => StatusCode::Unexpected, @@ -366,6 +404,8 @@ impl ErrorExt for Error { DumpPprof { source, .. } => source.status_code(), UpdateJemallocMetrics { .. } => StatusCode::Internal, + + ConvertScalarValue { source, .. 
} => source.status_code(), } } diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index cfe9f668a5..0b28ae0623 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -719,6 +719,8 @@ mod test { use datatypes::schema::{ColumnSchema, Schema}; use datatypes::vectors::{StringVector, UInt32Vector}; use query::parser::PromQuery; + use query::plan::LogicalPlan; + use query::query_engine::DescribeResult; use session::context::QueryContextRef; use tokio::sync::mpsc; @@ -760,11 +762,19 @@ mod test { unimplemented!() } + async fn do_exec_plan( + &self, + _plan: LogicalPlan, + _query_ctx: QueryContextRef, + ) -> std::result::Result { + unimplemented!() + } + async fn do_describe( &self, _stmt: sql::statements::statement::Statement, _query_ctx: QueryContextRef, - ) -> Result> { + ) -> Result> { unimplemented!() } diff --git a/src/servers/src/mysql.rs b/src/servers/src/mysql.rs index 04059124f7..0e73ae617a 100644 --- a/src/servers/src/mysql.rs +++ b/src/servers/src/mysql.rs @@ -14,5 +14,6 @@ mod federated; pub mod handler; +mod helper; pub mod server; pub mod writer; diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index d917db5597..6205829fcd 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -11,7 +11,6 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
- use std::collections::HashMap; use std::net::SocketAddr; use std::sync::atomic::{AtomicU32, Ordering}; @@ -22,18 +21,20 @@ use async_trait::async_trait; use chrono::{NaiveDate, NaiveDateTime}; use common_error::prelude::ErrorExt; use common_query::Output; -use common_telemetry::tracing::log; -use common_telemetry::{error, timer, trace, warn}; +use common_telemetry::{error, logging, timer, trace, warn}; +use datatypes::prelude::ConcreteDataType; use metrics::increment_counter; use opensrv_mysql::{ - AsyncMysqlShim, Column, ColumnFlags, ColumnType, ErrorKind, InitWriter, ParamParser, - ParamValue, QueryResultWriter, StatementMetaWriter, ValueInner, + AsyncMysqlShim, Column, ErrorKind, InitWriter, ParamParser, ParamValue, QueryResultWriter, + StatementMetaWriter, ValueInner, }; use parking_lot::RwLock; +use query::plan::LogicalPlan; +use query::query_engine::DescribeResult; use rand::RngCore; use session::context::Channel; use session::{Session, SessionRef}; -use snafu::ensure; +use snafu::{ensure, ResultExt}; use sql::dialect::MySqlDialect; use sql::parser::ParserContext; use sql::statements::statement::Statement; @@ -41,17 +42,27 @@ use tokio::io::AsyncWrite; use crate::auth::{Identity, Password, UserProviderRef}; use crate::error::{self, InvalidPrepareStatementSnafu, Result}; +use crate::mysql::helper::{ + self, format_placeholder, replace_placeholders, transform_placeholders, +}; use crate::mysql::writer; +use crate::mysql::writer::create_mysql_column; use crate::query_handler::sql::ServerSqlQueryHandlerRef; +/// Cached SQL and logical plan +#[derive(Clone)] +struct SqlPlan { + query: String, + plan: Option, +} + // An intermediate shim for executing MySQL queries. 
pub struct MysqlInstanceShim { query_handler: ServerSqlQueryHandlerRef, salt: [u8; 20], session: SessionRef, user_provider: Option, - // TODO(SSebo): use something like moka to achieve TTL or LRU - prepared_stmts: Arc>>, + prepared_stmts: Arc>>, prepared_stmts_counter: AtomicU32, } @@ -105,14 +116,34 @@ impl MysqlInstanceShim { output } - fn set_query(&self, query: String) -> u32 { - let stmt_id = self.prepared_stmts_counter.fetch_add(1, Ordering::SeqCst); - let mut guard = self.prepared_stmts.write(); - guard.insert(stmt_id, query); + /// Execute the logical plan and return the output + async fn do_exec_plan(&self, query: &str, plan: LogicalPlan) -> Result { + if let Some(output) = crate::mysql::federated::check(query, self.session.context()) { + Ok(output) + } else { + self.query_handler + .do_exec_plan(plan, self.session.context()) + .await + } + } + + /// Describe the statement + async fn do_describe(&self, statement: Statement) -> Result> { + self.query_handler + .do_describe(statement, self.session.context()) + .await + } + + /// Save query and logical plan, return the unique id + fn save_plan(&self, plan: SqlPlan) -> u32 { + let stmt_id = self.prepared_stmts_counter.fetch_add(1, Ordering::Relaxed); + let mut prepared_stmts = self.prepared_stmts.write(); + prepared_stmts.insert(stmt_id, plan); stmt_id } - fn query(&self, stmt_id: u32) -> Option { + /// Retrieve the query and logical plan by id + fn plan(&self, stmt_id: u32) -> Option { let guard = self.prepared_stmts.read(); guard.get(&stmt_id).cloned() } @@ -175,15 +206,36 @@ impl AsyncMysqlShim for MysqlInstanceShi query: &'a str, w: StatementMetaWriter<'a, W>, ) -> Result<()> { - let (query, param_num) = replace_placeholder(query); - if let Err(e) = validate_query(&query).await { - w.error(ErrorKind::ER_UNKNOWN_ERROR, e.to_string().as_bytes()) - .await?; - return Ok(()); + let raw_query = query.clone(); + let (query, param_num) = replace_placeholders(query); + + let statement = 
validate_query(raw_query).await?; + + // We have to transform the placeholder, because DataFusion only parses placeholders + // in the form of "$i", it can't process "?" right now. + let statement = transform_placeholders(statement); + + let plan = self + .do_describe(statement.clone()) + .await? + .map(|DescribeResult { logical_plan, .. }| logical_plan); + + let params = if let Some(plan) = &plan { + prepared_params( + &plan + .get_param_types() + .context(error::GetPreparedStmtParamsSnafu)?, + )? + } else { + dummy_params(param_num)? }; - let stmt_id = self.set_query(query); - let params = dummy_params(param_num); + debug_assert_eq!(params.len(), param_num - 1); + + let stmt_id = self.save_plan(SqlPlan { + query: query.to_string(), + plan, + }); w.reply(stmt_id, &params, &[]).await?; increment_counter!( @@ -216,7 +268,7 @@ impl AsyncMysqlShim for MysqlInstanceShi ] ); let params: Vec = p.into_iter().collect(); - let query = match self.query(stmt_id) { + let sql_plan = match self.plan(stmt_id) { None => { w.error( ErrorKind::ER_UNKNOWN_STMT_HANDLER, @@ -225,13 +277,36 @@ impl AsyncMysqlShim for MysqlInstanceShi .await?; return Ok(()); } - Some(query) => query, + Some(sql_plan) => sql_plan, }; - let query = replace_params(params, query); - log::debug!("execute replaced query: {}", query); + let (query, outputs) = match sql_plan.plan { + Some(plan) => { + let param_types = plan + .get_param_types() + .context(error::GetPreparedStmtParamsSnafu)?; + + if params.len() != param_types.len() { + return error::InternalSnafu { + err_msg: "prepare statement params number mismatch".to_string(), + } + .fail(); + } + let plan = replace_params_with_values(&plan, param_types, params)?; + logging::debug!("Mysql execute prepared plan: {}", plan.display_indent()); + let outputs = vec![self.do_exec_plan(&sql_plan.query, plan).await]; + + (sql_plan.query, outputs) + } + None => { + let query = replace_params(params, sql_plan.query); + logging::debug!("Mysql execute replaced 
query: {}", 
query); + let outputs = self.do_query(&query).await; + + (query, outputs) + } + }; - let outputs = self.do_query(&query).await; writer::write_output(w, &query, self.session.context(), outputs).await?; Ok(()) @@ -318,7 +393,7 @@ fn replace_params(params: Vec, query: String) -> String { ValueInner::Datetime(_) => NaiveDateTime::from(param.value).to_string(), ValueInner::Time(_) => format_duration(Duration::from(param.value)), }; - query = query.replace(&format!("${}", index), &s); + query = query.replace(&format_placeholder(index), &s); index += 1; } query @@ -331,6 +406,27 @@ fn format_duration(duration: Duration) -> String { format!("{}:{}:{}", hours, minutes, seconds) } +fn replace_params_with_values( + plan: &LogicalPlan, + param_types: HashMap>, + params: Vec, +) -> Result { + debug_assert_eq!(param_types.len(), params.len()); + + let mut values = Vec::with_capacity(params.len()); + + for (i, param) in params.iter().enumerate() { + if let Some(Some(t)) = param_types.get(&format_placeholder(i + 1)) { + let value = helper::convert_value(param, t)?; + + values.push(value); + } + } + + plan.replace_params_with_values(&values) + .context(error::ReplacePreparedStmtParamsSnafu) +} + async fn validate_query(query: &str) -> Result { let statement = ParserContext::create_with_dialect(query, &MySqlDialect {}); let mut statement = statement.map_err(|e| { @@ -352,29 +448,27 @@ async fn validate_query(query: &str) -> Result { Ok(statement) } -// dummy columns to satisfy opensrv_mysql, just the number of params is useful -// TODO(SSebo): use parameter type inference to return actual types -fn dummy_params(index: u32) -> Vec { - let mut params = vec![]; +fn dummy_params(index: usize) -> Result> { + let mut params = Vec::with_capacity(index - 1); for _ in 1..index { - params.push(opensrv_mysql::Column { - table: "".to_string(), - column: "".to_string(), - coltype: ColumnType::MYSQL_TYPE_LONG, - colflags: ColumnFlags::NOT_NULL_FLAG, - }); + 
params.push(create_mysql_column(&ConcreteDataType::null_datatype(), "")?); } - params + + Ok(params) } -fn replace_placeholder(query: &str) -> (String, u32) { - let mut query = query.to_string(); - let mut index = 1; - while let Some(position) = query.find('?') { - let place_holder = format!("${}", index); - query.replace_range(position..position + 1, &place_holder); - index += 1; +/// Parameters that the client must provide when executing the prepared statement. +fn prepared_params(param_types: &HashMap>) -> Result> { + let mut params = Vec::with_capacity(param_types.len()); + + // Placeholder index starts from 1 + for index in 1..=param_types.len() { + if let Some(Some(t)) = param_types.get(&format_placeholder(index)) { + let column = create_mysql_column(t, "")?; + params.push(column); + } } - (query, index) + + Ok(params) } diff --git a/src/servers/src/mysql/helper.rs b/src/servers/src/mysql/helper.rs new file mode 100644 index 0000000000..e734b821c2 --- /dev/null +++ b/src/servers/src/mysql/helper.rs @@ -0,0 +1,238 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+use std::ops::ControlFlow; +use std::time::Duration; + +use chrono::{NaiveDate, NaiveDateTime}; +use common_query::prelude::ScalarValue; +use datatypes::prelude::ConcreteDataType; +use datatypes::value::{self, Value}; +use itertools::Itertools; +use opensrv_mysql::{ParamValue, ValueInner}; +use snafu::ResultExt; +use sql::ast::{visit_expressions_mut, Expr, Value as ValueExpr, VisitMut}; +use sql::statements::statement::Statement; + +use crate::error::{self, Result}; + +/// Returns the placeholder string "$i". +pub fn format_placeholder(i: usize) -> String { + format!("${}", i) +} + +/// Replace all the "?" placeholder into "$i" in SQL, +/// returns the new SQL and the last placeholder index. +pub fn replace_placeholders(query: &str) -> (String, usize) { + let query_parts = query.split('?').collect::>(); + let parts_len = query_parts.len(); + let mut index = 0; + let query = query_parts + .into_iter() + .enumerate() + .map(|(i, part)| { + if i == parts_len - 1 { + return part.to_string(); + } + + index += 1; + format!("{part}{}", format_placeholder(index)) + }) + .join(""); + + (query, index + 1) +} + +/// Transform all the "?" placeholder into "$i". +/// Only works for Insert,Query and Delete statements. +pub fn transform_placeholders(stmt: Statement) -> Statement { + match stmt { + Statement::Query(mut query) => { + visit_placeholders(&mut query.inner); + Statement::Query(query) + } + Statement::Insert(mut insert) => { + visit_placeholders(&mut insert.inner); + Statement::Insert(insert) + } + Statement::Delete(mut delete) => { + visit_placeholders(&mut delete.inner); + Statement::Delete(delete) + } + stmt => stmt, + } +} + +fn visit_placeholders(v: &mut V) +where + V: VisitMut, +{ + let mut index = 1; + visit_expressions_mut(v, |expr| { + if let Expr::Value(ValueExpr::Placeholder(s)) = expr { + *s = format_placeholder(index); + index += 1; + } + ControlFlow::<()>::Continue(()) + }); +} + +/// Convert [`ParamValue`] into [`Value`] according to param type. 
+/// It will try its best to do type conversions if possible +pub fn convert_value(param: &ParamValue, t: &ConcreteDataType) -> Result { + match param.value.into_inner() { + ValueInner::Int(i) => match t { + ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(i as i8))), + ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(i as i16))), + ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(i as i32))), + ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(i))), + ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(i as u8))), + ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(i as u16))), + ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(i as u32))), + ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(i as u64))), + ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(i as f32))), + ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(i as f64))), + ConcreteDataType::Timestamp(ts_type) => Value::Timestamp(ts_type.create_timestamp(i)) + .try_to_scalar_value(t) + .context(error::ConvertScalarValueSnafu), + + _ => error::PreparedStmtTypeMismatchSnafu { + expected: t, + actual: param.coltype, + } + .fail(), + }, + ValueInner::UInt(u) => match t { + ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(u as i8))), + ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(u as i16))), + ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(u as i32))), + ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(u as i64))), + ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(u as u8))), + ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(u as u16))), + ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(u as u32))), + ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(u))), + ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(u as f32))), + ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(u as f64))), + ConcreteDataType::Timestamp(ts_type) => { + 
Value::Timestamp(ts_type.create_timestamp(u as i64)) + .try_to_scalar_value(t) + .context(error::ConvertScalarValueSnafu) + } + + _ => error::PreparedStmtTypeMismatchSnafu { + expected: t, + actual: param.coltype, + } + .fail(), + }, + ValueInner::Double(f) => match t { + ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(f as i8))), + ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(f as i16))), + ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(f as i32))), + ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(f as i64))), + ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(f as u8))), + ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(f as u16))), + ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(f as u32))), + ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(f as u64))), + ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(f as f32))), + ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(f))), + + _ => error::PreparedStmtTypeMismatchSnafu { + expected: t, + actual: param.coltype, + } + .fail(), + }, + ValueInner::NULL => Ok(value::to_null_scalar_value(t)), + ValueInner::Bytes(b) => match t { + ConcreteDataType::String(_) => Ok(ScalarValue::Utf8(Some( + String::from_utf8_lossy(b).to_string(), + ))), + ConcreteDataType::Binary(_) => Ok(ScalarValue::LargeBinary(Some(b.to_vec()))), + + _ => error::PreparedStmtTypeMismatchSnafu { + expected: t, + actual: param.coltype, + } + .fail(), + }, + ValueInner::Date(_) => { + let date: common_time::Date = NaiveDate::from(param.value).into(); + Ok(ScalarValue::Date32(Some(date.val()))) + } + ValueInner::Datetime(_) => Ok(ScalarValue::Date64(Some( + NaiveDateTime::from(param.value).timestamp_millis(), + ))), + ValueInner::Time(_) => Ok(ScalarValue::Time64Nanosecond(Some( + Duration::from(param.value).as_millis() as i64, + ))), + } +} + +#[cfg(test)] +mod tests { + use sql::dialect::MySqlDialect; + use sql::parser::ParserContext; + + use 
super::*; + + #[test] + fn test_format_placeholder() { + assert_eq!("$1", format_placeholder(1)); + assert_eq!("$3", format_placeholder(3)); + } + + #[test] + fn test_replace_placeholders() { + let create = "create table demo(host string, ts timestamp time index)"; + let (sql, index) = replace_placeholders(create); + assert_eq!(create, sql); + assert_eq!(1, index); + + let insert = "insert into demo values(?,?,?)"; + let (sql, index) = replace_placeholders(insert); + assert_eq!("insert into demo values($1,$2,$3)", sql); + assert_eq!(4, index); + + let query = "select from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?"; + let (sql, index) = replace_placeholders(query); + assert_eq!("select from demo where host=$1 and idc in (select idc from idcs where name=$2) and cpu>$3", sql); + assert_eq!(4, index); + } + + fn parse_sql(sql: &str) -> Statement { + let mut stmts = ParserContext::create_with_dialect(sql, &MySqlDialect {}).unwrap(); + stmts.remove(0) + } + + #[test] + fn test_transform_placeholders() { + let insert = parse_sql("insert into demo values(?,?,?)"); + let Statement::Insert(insert) = transform_placeholders(insert) else { unreachable!()}; + assert_eq!( + "INSERT INTO demo VALUES ($1, $2, $3)", + insert.inner.to_string() + ); + + let delete = parse_sql("delete from demo where host=? and idc=?"); + let Statement::Delete(delete) = transform_placeholders(delete) else { unreachable!()}; + assert_eq!( + "DELETE FROM demo WHERE host = $1 AND idc = $2", + delete.inner.to_string() + ); + + let select = parse_sql("select from demo where host=? and idc in (select idc from idcs where name=?) 
and cpu>?"); + let Statement::Query(select) = transform_placeholders(select) else { unreachable!()}; + assert_eq!("SELECT from AS demo WHERE host = $1 AND idc IN (SELECT idc FROM idcs WHERE name = $2) AND cpu > $3", select.inner.to_string()); + } +} diff --git a/src/servers/src/mysql/writer.rs b/src/servers/src/mysql/writer.rs index 6a060635b6..4249b82780 100644 --- a/src/servers/src/mysql/writer.rs +++ b/src/servers/src/mysql/writer.rs @@ -18,7 +18,7 @@ use common_query::Output; use common_recordbatch::{util, RecordBatch}; use common_telemetry::error; use datatypes::prelude::{ConcreteDataType, Value}; -use datatypes::schema::{ColumnSchema, SchemaRef}; +use datatypes::schema::SchemaRef; use opensrv_mysql::{ Column, ColumnFlags, ColumnType, ErrorKind, OkResponse, QueryResultWriter, RowWriter, }; @@ -176,8 +176,8 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> { Value::Float64(v) => row_writer.write_col(v.0)?, Value::String(v) => row_writer.write_col(v.as_utf8())?, Value::Binary(v) => row_writer.write_col(v.deref())?, - Value::Date(v) => row_writer.write_col(v.val())?, - Value::DateTime(v) => row_writer.write_col(v.val())?, + Value::Date(v) => row_writer.write_col(v.to_chrono_date())?, + Value::DateTime(v) => row_writer.write_col(v.to_chrono_datetime())?, Value::Timestamp(v) => row_writer .write_col(v.to_timezone_aware_string(query_context.time_zone()))?, Value::List(_) => { @@ -208,8 +208,11 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> { } } -fn create_mysql_column(column_schema: &ColumnSchema) -> Result { - let column_type = match column_schema.data_type { +pub(crate) fn create_mysql_column( + data_type: &ConcreteDataType, + column_name: &str, +) -> Result { + let column_type = match data_type { ConcreteDataType::Null(_) => Ok(ColumnType::MYSQL_TYPE_NULL), ConcreteDataType::Boolean(_) | ConcreteDataType::Int8(_) | ConcreteDataType::UInt8(_) => { Ok(ColumnType::MYSQL_TYPE_TINY) @@ -230,15 +233,12 @@ fn create_mysql_column(column_schema: 
&ColumnSchema) -> Result { ConcreteDataType::Date(_) => Ok(ColumnType::MYSQL_TYPE_DATE), ConcreteDataType::DateTime(_) => Ok(ColumnType::MYSQL_TYPE_DATETIME), _ => error::InternalSnafu { - err_msg: format!( - "not implemented for column datatype {:?}", - column_schema.data_type - ), + err_msg: format!("not implemented for column datatype {:?}", data_type), } .fail(), }; let mut colflags = ColumnFlags::empty(); - match column_schema.data_type { + match data_type { ConcreteDataType::UInt16(_) | ConcreteDataType::UInt8(_) | ConcreteDataType::UInt32(_) @@ -246,7 +246,7 @@ fn create_mysql_column(column_schema: &ColumnSchema) -> Result { _ => {} }; column_type.map(|column_type| Column { - column: column_schema.name.clone(), + column: column_name.to_string(), coltype: column_type, // TODO(LFC): Currently "table" and "colflags" are not relevant in MySQL server @@ -261,6 +261,6 @@ pub fn create_mysql_column_def(schema: &SchemaRef) -> Result> { schema .column_schemas() .iter() - .map(create_mysql_column) + .map(|column_schema| create_mysql_column(&column_schema.data_type, &column_schema.name)) .collect() } diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index 583e988474..5fc71be2f7 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -33,6 +33,7 @@ use pgwire::api::stmt::QueryParser; use pgwire::api::store::MemPortalStore; use pgwire::api::{ClientInfo, Type}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; +use query::query_engine::DescribeResult; use sql::dialect::PostgreSqlDialect; use sql::parser::ParserContext; use sql::statements::statement::Statement; @@ -405,7 +406,7 @@ impl ExtendedQueryHandler for PostgresServerHandler { // get Statement part of the tuple let (stmt, _) = stmt; - if let Some(schema) = self + if let Some(DescribeResult { schema, .. 
}) = self .query_handler .do_describe(stmt.clone(), self.session.context()) .await diff --git a/src/servers/src/query_handler/sql.rs b/src/servers/src/query_handler/sql.rs index af8bbed5c2..3ae964c9cc 100644 --- a/src/servers/src/query_handler/sql.rs +++ b/src/servers/src/query_handler/sql.rs @@ -17,8 +17,8 @@ use std::sync::Arc; use async_trait::async_trait; use common_error::prelude::*; use common_query::Output; -use datatypes::schema::Schema; use query::parser::PromQuery; +use query::plan::LogicalPlan; use session::context::QueryContextRef; use sql::statements::statement::Statement; @@ -26,6 +26,7 @@ use crate::error::{self, Result}; pub type SqlQueryHandlerRef = Arc + Send + Sync>; pub type ServerSqlQueryHandlerRef = SqlQueryHandlerRef; +use query::query_engine::DescribeResult; #[async_trait] pub trait SqlQueryHandler { @@ -37,6 +38,12 @@ pub trait SqlQueryHandler { query_ctx: QueryContextRef, ) -> Vec>; + async fn do_exec_plan( + &self, + plan: LogicalPlan, + query_ctx: QueryContextRef, + ) -> std::result::Result; + async fn do_promql_query( &self, query: &PromQuery, @@ -47,7 +54,7 @@ pub trait SqlQueryHandler { &self, stmt: Statement, query_ctx: QueryContextRef, - ) -> std::result::Result, Self::Error>; + ) -> std::result::Result, Self::Error>; async fn is_valid_schema( &self, @@ -83,6 +90,14 @@ where .collect() } + async fn do_exec_plan(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result { + self.0 + .do_exec_plan(plan, query_ctx) + .await + .map_err(BoxedError::new) + .context(error::ExecutePlanSnafu) + } + async fn do_promql_query( &self, query: &PromQuery, @@ -107,7 +122,7 @@ where &self, stmt: Statement, query_ctx: QueryContextRef, - ) -> Result> { + ) -> Result> { self.0 .do_describe(stmt, query_ctx) .await diff --git a/src/servers/tests/http/influxdb_test.rs b/src/servers/tests/http/influxdb_test.rs index 658dc929d1..d7a92543af 100644 --- a/src/servers/tests/http/influxdb_test.rs +++ b/src/servers/tests/http/influxdb_test.rs @@ -21,8 +21,9 
@@ use axum::{http, Router}; use axum_test_helper::TestClient; use common_query::Output; use common_test_util::ports; -use datatypes::schema::Schema; use query::parser::PromQuery; +use query::plan::LogicalPlan; +use query::query_engine::DescribeResult; use servers::error::{Error, Result}; use servers::http::{HttpOptions, HttpServerBuilder}; use servers::influxdb::InfluxdbRequest; @@ -71,6 +72,14 @@ impl SqlQueryHandler for DummyInstance { unimplemented!() } + async fn do_exec_plan( + &self, + _plan: LogicalPlan, + _query_ctx: QueryContextRef, + ) -> std::result::Result { + unimplemented!() + } + async fn do_promql_query( &self, _: &PromQuery, @@ -83,7 +92,7 @@ impl SqlQueryHandler for DummyInstance { &self, _stmt: sql::statements::statement::Statement, _query_ctx: QueryContextRef, - ) -> Result> { + ) -> Result> { unimplemented!() } diff --git a/src/servers/tests/http/opentsdb_test.rs b/src/servers/tests/http/opentsdb_test.rs index 9cce749fc9..d01f2a0a10 100644 --- a/src/servers/tests/http/opentsdb_test.rs +++ b/src/servers/tests/http/opentsdb_test.rs @@ -20,8 +20,9 @@ use axum::Router; use axum_test_helper::TestClient; use common_query::Output; use common_test_util::ports; -use datatypes::schema::Schema; use query::parser::PromQuery; +use query::plan::LogicalPlan; +use query::query_engine::DescribeResult; use servers::error::{self, Result}; use servers::http::{HttpOptions, HttpServerBuilder}; use servers::opentsdb::codec::DataPoint; @@ -70,6 +71,14 @@ impl SqlQueryHandler for DummyInstance { unimplemented!() } + async fn do_exec_plan( + &self, + _plan: LogicalPlan, + _query_ctx: QueryContextRef, + ) -> std::result::Result { + unimplemented!() + } + async fn do_promql_query( &self, _: &PromQuery, @@ -82,7 +91,7 @@ impl SqlQueryHandler for DummyInstance { &self, _stmt: sql::statements::statement::Statement, _query_ctx: QueryContextRef, - ) -> Result> { + ) -> Result> { unimplemented!() } diff --git a/src/servers/tests/http/prometheus_test.rs 
b/src/servers/tests/http/prometheus_test.rs index ba70759a72..ddcdbb7fa1 100644 --- a/src/servers/tests/http/prometheus_test.rs +++ b/src/servers/tests/http/prometheus_test.rs @@ -23,9 +23,10 @@ use axum::Router; use axum_test_helper::TestClient; use common_query::Output; use common_test_util::ports; -use datatypes::schema::Schema; use prost::Message; use query::parser::PromQuery; +use query::plan::LogicalPlan; +use query::query_engine::DescribeResult; use servers::error::{Error, Result}; use servers::http::{HttpOptions, HttpServerBuilder}; use servers::prometheus; @@ -95,6 +96,14 @@ impl SqlQueryHandler for DummyInstance { unimplemented!() } + async fn do_exec_plan( + &self, + _plan: LogicalPlan, + _query_ctx: QueryContextRef, + ) -> std::result::Result { + unimplemented!() + } + async fn do_promql_query( &self, _: &PromQuery, @@ -107,7 +116,7 @@ impl SqlQueryHandler for DummyInstance { &self, _stmt: sql::statements::statement::Statement, _query_ctx: QueryContextRef, - ) -> Result> { + ) -> Result> { unimplemented!() } diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs index e91eebb9d5..910b2ca132 100644 --- a/src/servers/tests/mod.rs +++ b/src/servers/tests/mod.rs @@ -21,8 +21,9 @@ use async_trait::async_trait; use catalog::local::{MemoryCatalogManager, MemoryCatalogProvider, MemorySchemaProvider}; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_query::Output; -use datatypes::schema::Schema; use query::parser::{PromQuery, QueryLanguageParser, QueryStatement}; +use query::plan::LogicalPlan; +use query::query_engine::DescribeResult; use query::{QueryEngineFactory, QueryEngineRef}; use script::engine::{CompileContext, EvalContext, Script, ScriptEngine}; use script::python::{PyEngine, PyScript}; @@ -78,6 +79,10 @@ impl SqlQueryHandler for DummyInstance { vec![Ok(output)] } + async fn do_exec_plan(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result { + Ok(self.query_engine.execute(plan, 
query_ctx).await.unwrap()) + } + async fn do_promql_query( &self, _: &PromQuery, @@ -90,7 +95,7 @@ impl SqlQueryHandler for DummyInstance { &self, stmt: Statement, query_ctx: QueryContextRef, - ) -> Result> { + ) -> Result> { if let Statement::Query(_) = stmt { let plan = self .query_engine diff --git a/src/sql/src/ast.rs b/src/sql/src/ast.rs index b35b71b51b..a72d7965b7 100644 --- a/src/sql/src/ast.rs +++ b/src/sql/src/ast.rs @@ -13,7 +13,7 @@ // limitations under the License. pub use sqlparser::ast::{ - BinaryOperator, ColumnDef, ColumnOption, ColumnOptionDef, DataType, Expr, Function, - FunctionArg, FunctionArgExpr, Ident, ObjectName, SqlOption, TableConstraint, TimezoneInfo, - Value, + visit_expressions_mut, BinaryOperator, ColumnDef, ColumnOption, ColumnOptionDef, DataType, + Expr, Function, FunctionArg, FunctionArgExpr, Ident, ObjectName, SqlOption, TableConstraint, + TimezoneInfo, Value, VisitMut, Visitor, }; diff --git a/tests-integration/Cargo.toml b/tests-integration/Cargo.toml index 90bed17fa3..e4cc07dde3 100644 --- a/tests-integration/Cargo.toml +++ b/tests-integration/Cargo.toml @@ -13,6 +13,7 @@ axum = "0.6" axum-test-helper = { git = "https://github.com/sunng87/axum-test-helper.git", branch = "patch-1" } async-trait = "0.1" catalog = { path = "../src/catalog" } +chrono.workspace = true client = { path = "../src/client", features = ["testing"] } common-base = { path = "../src/common/base" } common-catalog = { path = "../src/common/catalog" } @@ -49,6 +50,7 @@ sqlx = { version = "0.6", features = [ "runtime-tokio-rustls", "mysql", "postgres", + "chrono", ] } table = { path = "../src/table" } tempfile.workspace = true diff --git a/tests-integration/tests/sql.rs b/tests-integration/tests/sql.rs index 915ab909bb..f5d770582f 100644 --- a/tests-integration/tests/sql.rs +++ b/tests-integration/tests/sql.rs @@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. +use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc}; use sqlx::mysql::MySqlPoolOptions; use sqlx::postgres::PgPoolOptions; use sqlx::Row; @@ -62,20 +63,24 @@ pub async fn test_mysql_crud(store_type: StorageType) { .await .unwrap(); - sqlx::query("create table demo(i bigint, ts timestamp time index)") + sqlx::query("create table demo(i bigint, ts timestamp time index, d date, dt datetime)") .execute(&pool) .await .unwrap(); for i in 0..10 { - sqlx::query("insert into demo values(?, ?)") + let dt = DateTime::::from_utc(NaiveDateTime::from_timestamp_opt(60, i).unwrap(), Utc); + let d = NaiveDate::from_yo_opt(2015, 100).unwrap(); + sqlx::query("insert into demo values(?, ?, ?, ?)") .bind(i) .bind(i) + .bind(d) + .bind(dt) .execute(&pool) .await .unwrap(); } - let rows = sqlx::query("select i from demo") + let rows = sqlx::query("select i, d, dt from demo") .fetch_all(&pool) .await .unwrap(); @@ -83,7 +88,34 @@ pub async fn test_mysql_crud(store_type: StorageType) { for (i, row) in rows.iter().enumerate() { let ret: i64 = row.get(0); + let d: NaiveDate = row.get(1); + let dt: DateTime = row.get(2); assert_eq!(ret, i as i64); + + let expected_d = NaiveDate::from_yo_opt(2015, 100).unwrap(); + assert_eq!(expected_d, d); + + let expected_dt = DateTime::::from_utc( + NaiveDateTime::from_timestamp_opt(60, i as u32).unwrap(), + Utc, + ); + + assert_eq!( + format!("{}", expected_dt.format("%Y-%m-%d %H:%M:%S")), + format!("{}", dt.format("%Y-%m-%d %H:%M:%S")) + ); + } + + let rows = sqlx::query("select i from demo where i=?") + .bind(6) + .fetch_all(&pool) + .await + .unwrap(); + assert_eq!(rows.len(), 1); + + for row in rows { + let ret: i64 = row.get(0); + assert_eq!(ret, 6); } sqlx::query("delete from demo") @@ -133,6 +165,18 @@ pub async fn test_postgres_crud(store_type: StorageType) { assert_eq!(ret, i as i64); } + let rows = sqlx::query("select i from demo where 
i=$1") + .bind(6) + .fetch_all(&pool) + .await + .unwrap(); + assert_eq!(rows.len(), 1); + + for row in rows { + let ret: i64 = row.get(0); + assert_eq!(ret, 6); + } + sqlx::query("delete from demo") .execute(&pool) .await