mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2025-12-22 22:20:02 +00:00
feat: prepare supports caching logical plan and infering param types (#1776)
* feat: change do_describe function signature * feat: infer param type and cache logical plan for msyql prepared statments * fix: convert_value * fix: forgot helper * chore: comments * fix: typo * test: add more tests and test date, datatime in mysql * chore: fix CR comments * chore: add location * chore: by CR comments * Update tests-integration/tests/sql.rs Co-authored-by: Ruihang Xia <waynestxia@gmail.com> * chore: remove the trace --------- Co-authored-by: Ruihang Xia <waynestxia@gmail.com>
This commit is contained in:
4
Cargo.lock
generated
4
Cargo.lock
generated
@@ -8537,6 +8537,8 @@ dependencies = [
|
||||
"common-test-util",
|
||||
"common-time",
|
||||
"datafusion",
|
||||
"datafusion-common",
|
||||
"datafusion-expr",
|
||||
"datatypes",
|
||||
"derive_builder 0.12.0",
|
||||
"digest",
|
||||
@@ -8973,6 +8975,7 @@ dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"crc",
|
||||
"crossbeam-queue",
|
||||
"digest",
|
||||
@@ -9553,6 +9556,7 @@ dependencies = [
|
||||
"axum",
|
||||
"axum-test-helper",
|
||||
"catalog",
|
||||
"chrono",
|
||||
"client",
|
||||
"common-base",
|
||||
"common-catalog",
|
||||
|
||||
@@ -52,6 +52,12 @@ impl From<i32> for Date {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<NaiveDate> for Date {
|
||||
fn from(date: NaiveDate) -> Self {
|
||||
Self(date.num_days_from_ce() - UNIX_EPOCH_FROM_CE)
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Date {
|
||||
/// [Date] is formatted according to ISO-8601 standard.
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
|
||||
@@ -183,6 +183,12 @@ impl ConcreteDataType {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&ConcreteDataType> for ConcreteDataType {
|
||||
fn from(t: &ConcreteDataType) -> Self {
|
||||
t.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&ArrowDataType> for ConcreteDataType {
|
||||
type Error = Error;
|
||||
|
||||
|
||||
@@ -248,7 +248,7 @@ impl Value {
|
||||
Value::Binary(v) => ScalarValue::LargeBinary(Some(v.to_vec())),
|
||||
Value::Date(v) => ScalarValue::Date32(Some(v.val())),
|
||||
Value::DateTime(v) => ScalarValue::Date64(Some(v.val())),
|
||||
Value::Null => to_null_value(output_type),
|
||||
Value::Null => to_null_scalar_value(output_type),
|
||||
Value::List(list) => {
|
||||
// Safety: The logical type of the value and output_type are the same.
|
||||
let list_type = output_type.as_list().unwrap();
|
||||
@@ -261,7 +261,7 @@ impl Value {
|
||||
}
|
||||
}
|
||||
|
||||
fn to_null_value(output_type: &ConcreteDataType) -> ScalarValue {
|
||||
pub fn to_null_scalar_value(output_type: &ConcreteDataType) -> ScalarValue {
|
||||
match output_type {
|
||||
ConcreteDataType::Null(_) => ScalarValue::Null,
|
||||
ConcreteDataType::Boolean(_) => ScalarValue::Boolean(None),
|
||||
@@ -285,7 +285,7 @@ fn to_null_value(output_type: &ConcreteDataType) -> ScalarValue {
|
||||
}
|
||||
ConcreteDataType::Dictionary(dict) => ScalarValue::Dictionary(
|
||||
Box::new(dict.key_type().as_arrow_type()),
|
||||
Box::new(to_null_value(dict.value_type())),
|
||||
Box::new(to_null_scalar_value(dict.value_type())),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,7 +53,9 @@ use meta_client::MetaClientOptions;
|
||||
use partition::manager::PartitionRuleManager;
|
||||
use partition::route::TableRoutes;
|
||||
use query::parser::{PromQuery, QueryLanguageParser, QueryStatement};
|
||||
use query::plan::LogicalPlan;
|
||||
use query::query_engine::options::{validate_catalog_and_schema, QueryOptions};
|
||||
use query::query_engine::DescribeResult;
|
||||
use query::{QueryEngineFactory, QueryEngineRef};
|
||||
use servers::error as server_error;
|
||||
use servers::error::{ExecuteQuerySnafu, ParsePromQLSnafu};
|
||||
@@ -73,8 +75,9 @@ use sql::statements::statement::Statement;
|
||||
|
||||
use crate::catalog::FrontendCatalogManager;
|
||||
use crate::error::{
|
||||
self, Error, ExecutePromqlSnafu, ExternalSnafu, InvalidInsertRequestSnafu,
|
||||
MissingMetasrvOptsSnafu, ParseSqlSnafu, PlanStatementSnafu, Result, SqlExecInterceptedSnafu,
|
||||
self, Error, ExecLogicalPlanSnafu, ExecutePromqlSnafu, ExternalSnafu,
|
||||
InvalidInsertRequestSnafu, MissingMetasrvOptsSnafu, ParseSqlSnafu, PlanStatementSnafu, Result,
|
||||
SqlExecInterceptedSnafu,
|
||||
};
|
||||
use crate::expr_factory::{CreateExprFactoryRef, DefaultCreateExprFactory};
|
||||
use crate::frontend::FrontendOptions;
|
||||
@@ -506,6 +509,14 @@ impl SqlQueryHandler for Instance {
|
||||
}
|
||||
}
|
||||
|
||||
async fn do_exec_plan(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result<Output> {
|
||||
let _timer = timer!(metrics::METRIC_EXEC_PLAN_ELAPSED);
|
||||
self.query_engine
|
||||
.execute(plan, query_ctx)
|
||||
.await
|
||||
.context(ExecLogicalPlanSnafu)
|
||||
}
|
||||
|
||||
async fn do_promql_query(
|
||||
&self,
|
||||
query: &PromQuery,
|
||||
@@ -523,8 +534,11 @@ impl SqlQueryHandler for Instance {
|
||||
&self,
|
||||
stmt: Statement,
|
||||
query_ctx: QueryContextRef,
|
||||
) -> Result<Option<Schema>> {
|
||||
if let Statement::Query(_) = stmt {
|
||||
) -> Result<Option<DescribeResult>> {
|
||||
if matches!(
|
||||
stmt,
|
||||
Statement::Insert(_) | Statement::Query(_) | Statement::Delete(_)
|
||||
) {
|
||||
let plan = self
|
||||
.query_engine
|
||||
.planner()
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
pub(crate) const METRIC_HANDLE_SQL_ELAPSED: &str = "frontend.handle_sql_elapsed";
|
||||
pub(crate) const METRIC_EXEC_PLAN_ELAPSED: &str = "frontend.exec_plan_elapsed";
|
||||
pub(crate) const METRIC_HANDLE_SCRIPTS_ELAPSED: &str = "frontend.handle_scripts_elapsed";
|
||||
pub(crate) const METRIC_RUN_SCRIPT_ELAPSED: &str = "frontend.run_script_elapsed";
|
||||
|
||||
|
||||
@@ -37,7 +37,6 @@ use datafusion::physical_plan::ExecutionPlan;
|
||||
use datafusion_common::ResolvedTableReference;
|
||||
use datafusion_expr::{DmlStatement, LogicalPlan as DfLogicalPlan, WriteOp};
|
||||
use datatypes::prelude::VectorRef;
|
||||
use datatypes::schema::Schema;
|
||||
use futures_util::StreamExt;
|
||||
use session::context::QueryContextRef;
|
||||
use snafu::{ensure, OptionExt, ResultExt};
|
||||
@@ -57,7 +56,7 @@ use crate::physical_optimizer::PhysicalOptimizer;
|
||||
use crate::physical_planner::PhysicalPlanner;
|
||||
use crate::plan::LogicalPlan;
|
||||
use crate::planner::{DfLogicalPlanner, LogicalPlanner};
|
||||
use crate::query_engine::{QueryEngineContext, QueryEngineState};
|
||||
use crate::query_engine::{DescribeResult, QueryEngineContext, QueryEngineState};
|
||||
use crate::{metrics, QueryEngine};
|
||||
|
||||
pub struct DatafusionQueryEngine {
|
||||
@@ -221,11 +220,12 @@ impl QueryEngine for DatafusionQueryEngine {
|
||||
"datafusion"
|
||||
}
|
||||
|
||||
async fn describe(&self, plan: LogicalPlan) -> Result<Schema> {
|
||||
// TODO(sunng87): consider cache optmised logical plan between describe
|
||||
// and execute
|
||||
async fn describe(&self, plan: LogicalPlan) -> Result<DescribeResult> {
|
||||
let optimised_plan = self.optimize(&plan)?;
|
||||
optimised_plan.schema()
|
||||
Ok(DescribeResult {
|
||||
schema: optimised_plan.schema()?,
|
||||
logical_plan: optimised_plan,
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result<Output> {
|
||||
@@ -540,7 +540,10 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let schema = engine.describe(plan).await.unwrap();
|
||||
let DescribeResult {
|
||||
schema,
|
||||
logical_plan,
|
||||
} = engine.describe(plan).await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
schema.column_schemas()[0],
|
||||
@@ -550,5 +553,6 @@ mod tests {
|
||||
true
|
||||
)
|
||||
);
|
||||
assert_eq!("Limit: skip=0, fetch=20\n Aggregate: groupBy=[[]], aggr=[[SUM(numbers.number)]]\n TableScan: numbers projection=[number]", format!("{}", logical_plan.display_indent()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,13 +12,16 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::{Debug, Display};
|
||||
|
||||
use common_query::prelude::ScalarValue;
|
||||
use datafusion_expr::LogicalPlan as DfLogicalPlan;
|
||||
use datatypes::data_type::ConcreteDataType;
|
||||
use datatypes::schema::Schema;
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::error::{ConvertDatafusionSchemaSnafu, Result};
|
||||
use crate::error::{ConvertDatafusionSchemaSnafu, DataFusionSnafu, Result};
|
||||
|
||||
/// A LogicalPlan represents the different types of relational
|
||||
/// operators (such as Projection, Filter, etc) and can be created by
|
||||
@@ -59,4 +62,28 @@ impl LogicalPlan {
|
||||
let LogicalPlan::DfPlan(plan) = self;
|
||||
plan.display_indent()
|
||||
}
|
||||
|
||||
/// Walk the logical plan, find any `PlaceHolder` tokens,
|
||||
/// and return a map of their IDs and ConcreteDataTypes
|
||||
pub fn get_param_types(&self) -> Result<HashMap<String, Option<ConcreteDataType>>> {
|
||||
let LogicalPlan::DfPlan(plan) = self;
|
||||
let types = plan.get_parameter_types().context(DataFusionSnafu)?;
|
||||
|
||||
Ok(types
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Return a logical plan with all placeholders/params (e.g $1 $2,
|
||||
/// ...) replaced with corresponding values provided in the
|
||||
/// params_values
|
||||
pub fn replace_params_with_values(&self, values: &[ScalarValue]) -> Result<LogicalPlan> {
|
||||
let LogicalPlan::DfPlan(plan) = self;
|
||||
|
||||
plan.clone()
|
||||
.replace_params_with_values(values)
|
||||
.context(DataFusionSnafu)
|
||||
.map(LogicalPlan::DfPlan)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,6 +77,7 @@ impl DfLogicalPlanner {
|
||||
};
|
||||
PlanSqlSnafu { sql }
|
||||
})?;
|
||||
|
||||
Ok(LogicalPlan::DfPlan(result))
|
||||
}
|
||||
|
||||
|
||||
@@ -43,6 +43,15 @@ pub use crate::query_engine::state::QueryEngineState;
|
||||
|
||||
pub type SqlStatementExecutorRef = Arc<dyn SqlStatementExecutor>;
|
||||
|
||||
/// Describe statement result
|
||||
#[derive(Debug)]
|
||||
pub struct DescribeResult {
|
||||
/// The schema of statement
|
||||
pub schema: Schema,
|
||||
/// The logical plan for statement
|
||||
pub logical_plan: LogicalPlan,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait SqlStatementExecutor: Send + Sync {
|
||||
async fn execute_sql(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result<Output>;
|
||||
@@ -58,7 +67,7 @@ pub trait QueryEngine: Send + Sync {
|
||||
|
||||
fn name(&self) -> &str;
|
||||
|
||||
async fn describe(&self, plan: LogicalPlan) -> Result<Schema>;
|
||||
async fn describe(&self, plan: LogicalPlan) -> Result<DescribeResult>;
|
||||
|
||||
async fn execute(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result<Output>;
|
||||
|
||||
|
||||
@@ -33,6 +33,9 @@ common-runtime = { path = "../common/runtime" }
|
||||
common-telemetry = { path = "../common/telemetry" }
|
||||
common-time = { path = "../common/time" }
|
||||
datafusion.workspace = true
|
||||
datafusion-common.workspace = true
|
||||
datafusion-expr.workspace = true
|
||||
|
||||
datatypes = { path = "../datatypes" }
|
||||
derive_builder = "0.12"
|
||||
digest = "0.10"
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::any::Any;
|
||||
use std::net::SocketAddr;
|
||||
use std::string::FromUtf8Error;
|
||||
@@ -23,6 +22,7 @@ use base64::DecodeError;
|
||||
use catalog;
|
||||
use common_error::prelude::*;
|
||||
use common_telemetry::logging;
|
||||
use datatypes::prelude::ConcreteDataType;
|
||||
use query::parser::PromQuery;
|
||||
use serde_json::json;
|
||||
use snafu::Location;
|
||||
@@ -75,6 +75,12 @@ pub enum Error {
|
||||
source: BoxedError,
|
||||
},
|
||||
|
||||
#[snafu(display("Failed to execute plan, source: {}", source))]
|
||||
ExecutePlan {
|
||||
location: Location,
|
||||
source: BoxedError,
|
||||
},
|
||||
|
||||
#[snafu(display("{source}"))]
|
||||
ExecuteGrpcQuery {
|
||||
location: Location,
|
||||
@@ -250,6 +256,12 @@ pub enum Error {
|
||||
source: query::error::Error,
|
||||
},
|
||||
|
||||
#[snafu(display("Failed to get param types, source: {source}, location: {location}"))]
|
||||
GetPreparedStmtParams {
|
||||
source: query::error::Error,
|
||||
location: Location,
|
||||
},
|
||||
|
||||
#[snafu(display("{}", reason))]
|
||||
UnexpectedResult { reason: String, location: Location },
|
||||
|
||||
@@ -269,10 +281,7 @@ pub enum Error {
|
||||
|
||||
#[cfg(feature = "pprof")]
|
||||
#[snafu(display("Failed to dump pprof data, source: {}", source))]
|
||||
DumpPprof {
|
||||
#[snafu(backtrace)]
|
||||
source: common_pprof::Error,
|
||||
},
|
||||
DumpPprof { source: common_pprof::Error },
|
||||
|
||||
#[snafu(display("Failed to update jemalloc metrics, source: {source}, location: {location}"))]
|
||||
UpdateJemallocMetrics {
|
||||
@@ -285,6 +294,31 @@ pub enum Error {
|
||||
source: datafusion::error::DataFusionError,
|
||||
location: Location,
|
||||
},
|
||||
|
||||
#[snafu(display(
|
||||
"Failed to replace params with values in prepared statement, source: {source}, location: {location}"
|
||||
))]
|
||||
ReplacePreparedStmtParams {
|
||||
source: query::error::Error,
|
||||
location: Location,
|
||||
},
|
||||
|
||||
#[snafu(display("Failed to convert scalar value, source: {source}, location: {location}"))]
|
||||
ConvertScalarValue {
|
||||
source: datatypes::error::Error,
|
||||
location: Location,
|
||||
},
|
||||
|
||||
#[snafu(display(
|
||||
"Expected type: {:?}, actual: {:?}, location: {location}",
|
||||
expected,
|
||||
actual
|
||||
))]
|
||||
PreparedStmtTypeMismatch {
|
||||
expected: ConcreteDataType,
|
||||
actual: opensrv_mysql::ColumnType,
|
||||
location: Location,
|
||||
},
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
@@ -309,6 +343,7 @@ impl ErrorExt for Error {
|
||||
InsertScript { source, .. }
|
||||
| ExecuteScript { source, .. }
|
||||
| ExecuteQuery { source, .. }
|
||||
| ExecutePlan { source, .. }
|
||||
| ExecuteGrpcQuery { source, .. }
|
||||
| CheckDatabaseValidity { source, .. } => source.status_code(),
|
||||
|
||||
@@ -324,6 +359,7 @@ impl ErrorExt for Error {
|
||||
| InvalidFlightTicket { .. }
|
||||
| InvalidPrepareStatement { .. }
|
||||
| DataFrame { .. }
|
||||
| PreparedStmtTypeMismatch { .. }
|
||||
| TimePrecision { .. } => StatusCode::InvalidArguments,
|
||||
|
||||
InfluxdbLinesWrite { source, .. } | PromSeriesWrite { source, .. } => {
|
||||
@@ -347,7 +383,9 @@ impl ErrorExt for Error {
|
||||
DumpProfileData { source, .. } => source.status_code(),
|
||||
InvalidFlushArgument { .. } => StatusCode::InvalidArguments,
|
||||
|
||||
ParsePromQL { source, .. } => source.status_code(),
|
||||
ReplacePreparedStmtParams { source, .. }
|
||||
| GetPreparedStmtParams { source, .. }
|
||||
| ParsePromQL { source, .. } => source.status_code(),
|
||||
Other { source, .. } => source.status_code(),
|
||||
|
||||
UnexpectedResult { .. } => StatusCode::Unexpected,
|
||||
@@ -366,6 +404,8 @@ impl ErrorExt for Error {
|
||||
DumpPprof { source, .. } => source.status_code(),
|
||||
|
||||
UpdateJemallocMetrics { .. } => StatusCode::Internal,
|
||||
|
||||
ConvertScalarValue { source, .. } => source.status_code(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -719,6 +719,8 @@ mod test {
|
||||
use datatypes::schema::{ColumnSchema, Schema};
|
||||
use datatypes::vectors::{StringVector, UInt32Vector};
|
||||
use query::parser::PromQuery;
|
||||
use query::plan::LogicalPlan;
|
||||
use query::query_engine::DescribeResult;
|
||||
use session::context::QueryContextRef;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
@@ -760,11 +762,19 @@ mod test {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
async fn do_exec_plan(
|
||||
&self,
|
||||
_plan: LogicalPlan,
|
||||
_query_ctx: QueryContextRef,
|
||||
) -> std::result::Result<Output, Self::Error> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
async fn do_describe(
|
||||
&self,
|
||||
_stmt: sql::statements::statement::Statement,
|
||||
_query_ctx: QueryContextRef,
|
||||
) -> Result<Option<Schema>> {
|
||||
) -> Result<Option<DescribeResult>> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
|
||||
@@ -14,5 +14,6 @@
|
||||
|
||||
mod federated;
|
||||
pub mod handler;
|
||||
mod helper;
|
||||
pub mod server;
|
||||
pub mod writer;
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
@@ -22,18 +21,20 @@ use async_trait::async_trait;
|
||||
use chrono::{NaiveDate, NaiveDateTime};
|
||||
use common_error::prelude::ErrorExt;
|
||||
use common_query::Output;
|
||||
use common_telemetry::tracing::log;
|
||||
use common_telemetry::{error, timer, trace, warn};
|
||||
use common_telemetry::{error, logging, timer, trace, warn};
|
||||
use datatypes::prelude::ConcreteDataType;
|
||||
use metrics::increment_counter;
|
||||
use opensrv_mysql::{
|
||||
AsyncMysqlShim, Column, ColumnFlags, ColumnType, ErrorKind, InitWriter, ParamParser,
|
||||
ParamValue, QueryResultWriter, StatementMetaWriter, ValueInner,
|
||||
AsyncMysqlShim, Column, ErrorKind, InitWriter, ParamParser, ParamValue, QueryResultWriter,
|
||||
StatementMetaWriter, ValueInner,
|
||||
};
|
||||
use parking_lot::RwLock;
|
||||
use query::plan::LogicalPlan;
|
||||
use query::query_engine::DescribeResult;
|
||||
use rand::RngCore;
|
||||
use session::context::Channel;
|
||||
use session::{Session, SessionRef};
|
||||
use snafu::ensure;
|
||||
use snafu::{ensure, ResultExt};
|
||||
use sql::dialect::MySqlDialect;
|
||||
use sql::parser::ParserContext;
|
||||
use sql::statements::statement::Statement;
|
||||
@@ -41,17 +42,27 @@ use tokio::io::AsyncWrite;
|
||||
|
||||
use crate::auth::{Identity, Password, UserProviderRef};
|
||||
use crate::error::{self, InvalidPrepareStatementSnafu, Result};
|
||||
use crate::mysql::helper::{
|
||||
self, format_placeholder, replace_placeholders, transform_placeholders,
|
||||
};
|
||||
use crate::mysql::writer;
|
||||
use crate::mysql::writer::create_mysql_column;
|
||||
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
|
||||
|
||||
/// Cached SQL and logical plan
|
||||
#[derive(Clone)]
|
||||
struct SqlPlan {
|
||||
query: String,
|
||||
plan: Option<LogicalPlan>,
|
||||
}
|
||||
|
||||
// An intermediate shim for executing MySQL queries.
|
||||
pub struct MysqlInstanceShim {
|
||||
query_handler: ServerSqlQueryHandlerRef,
|
||||
salt: [u8; 20],
|
||||
session: SessionRef,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
// TODO(SSebo): use something like moka to achieve TTL or LRU
|
||||
prepared_stmts: Arc<RwLock<HashMap<u32, String>>>,
|
||||
prepared_stmts: Arc<RwLock<HashMap<u32, SqlPlan>>>,
|
||||
prepared_stmts_counter: AtomicU32,
|
||||
}
|
||||
|
||||
@@ -105,14 +116,34 @@ impl MysqlInstanceShim {
|
||||
output
|
||||
}
|
||||
|
||||
fn set_query(&self, query: String) -> u32 {
|
||||
let stmt_id = self.prepared_stmts_counter.fetch_add(1, Ordering::SeqCst);
|
||||
let mut guard = self.prepared_stmts.write();
|
||||
guard.insert(stmt_id, query);
|
||||
/// Execute the logical plan and return the output
|
||||
async fn do_exec_plan(&self, query: &str, plan: LogicalPlan) -> Result<Output> {
|
||||
if let Some(output) = crate::mysql::federated::check(query, self.session.context()) {
|
||||
Ok(output)
|
||||
} else {
|
||||
self.query_handler
|
||||
.do_exec_plan(plan, self.session.context())
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
/// Describe the statement
|
||||
async fn do_describe(&self, statement: Statement) -> Result<Option<DescribeResult>> {
|
||||
self.query_handler
|
||||
.do_describe(statement, self.session.context())
|
||||
.await
|
||||
}
|
||||
|
||||
/// Save query and logical plan, return the unique id
|
||||
fn save_plan(&self, plan: SqlPlan) -> u32 {
|
||||
let stmt_id = self.prepared_stmts_counter.fetch_add(1, Ordering::Relaxed);
|
||||
let mut prepared_stmts = self.prepared_stmts.write();
|
||||
prepared_stmts.insert(stmt_id, plan);
|
||||
stmt_id
|
||||
}
|
||||
|
||||
fn query(&self, stmt_id: u32) -> Option<String> {
|
||||
/// Retrieve the query and logical plan by id
|
||||
fn plan(&self, stmt_id: u32) -> Option<SqlPlan> {
|
||||
let guard = self.prepared_stmts.read();
|
||||
guard.get(&stmt_id).cloned()
|
||||
}
|
||||
@@ -175,15 +206,36 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
|
||||
query: &'a str,
|
||||
w: StatementMetaWriter<'a, W>,
|
||||
) -> Result<()> {
|
||||
let (query, param_num) = replace_placeholder(query);
|
||||
if let Err(e) = validate_query(&query).await {
|
||||
w.error(ErrorKind::ER_UNKNOWN_ERROR, e.to_string().as_bytes())
|
||||
.await?;
|
||||
return Ok(());
|
||||
let raw_query = query.clone();
|
||||
let (query, param_num) = replace_placeholders(query);
|
||||
|
||||
let statement = validate_query(raw_query).await?;
|
||||
|
||||
// We have to transform the placeholder, because DataFusion only parses placeholders
|
||||
// in the form of "$i", it can't process "?" right now.
|
||||
let statement = transform_placeholders(statement);
|
||||
|
||||
let plan = self
|
||||
.do_describe(statement.clone())
|
||||
.await?
|
||||
.map(|DescribeResult { logical_plan, .. }| logical_plan);
|
||||
|
||||
let params = if let Some(plan) = &plan {
|
||||
prepared_params(
|
||||
&plan
|
||||
.get_param_types()
|
||||
.context(error::GetPreparedStmtParamsSnafu)?,
|
||||
)?
|
||||
} else {
|
||||
dummy_params(param_num)?
|
||||
};
|
||||
|
||||
let stmt_id = self.set_query(query);
|
||||
let params = dummy_params(param_num);
|
||||
debug_assert_eq!(params.len(), param_num - 1);
|
||||
|
||||
let stmt_id = self.save_plan(SqlPlan {
|
||||
query: query.to_string(),
|
||||
plan,
|
||||
});
|
||||
|
||||
w.reply(stmt_id, ¶ms, &[]).await?;
|
||||
increment_counter!(
|
||||
@@ -216,7 +268,7 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
|
||||
]
|
||||
);
|
||||
let params: Vec<ParamValue> = p.into_iter().collect();
|
||||
let query = match self.query(stmt_id) {
|
||||
let sql_plan = match self.plan(stmt_id) {
|
||||
None => {
|
||||
w.error(
|
||||
ErrorKind::ER_UNKNOWN_STMT_HANDLER,
|
||||
@@ -225,13 +277,36 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
Some(query) => query,
|
||||
Some(sql_plan) => sql_plan,
|
||||
};
|
||||
|
||||
let query = replace_params(params, query);
|
||||
log::debug!("execute replaced query: {}", query);
|
||||
let (query, outputs) = match sql_plan.plan {
|
||||
Some(plan) => {
|
||||
let param_types = plan
|
||||
.get_param_types()
|
||||
.context(error::GetPreparedStmtParamsSnafu)?;
|
||||
|
||||
if params.len() != param_types.len() {
|
||||
return error::InternalSnafu {
|
||||
err_msg: "prepare statement params number mismatch".to_string(),
|
||||
}
|
||||
.fail();
|
||||
}
|
||||
let plan = replace_params_with_values(&plan, param_types, params)?;
|
||||
logging::debug!("Mysql execute prepared plan: {}", plan.display_indent());
|
||||
let outputs = vec![self.do_exec_plan(&sql_plan.query, plan).await];
|
||||
|
||||
(sql_plan.query, outputs)
|
||||
}
|
||||
None => {
|
||||
let query = replace_params(params, sql_plan.query);
|
||||
logging::debug!("Mysql execute replaced query: {}", query);
|
||||
let outputs = self.do_query(&query).await;
|
||||
|
||||
(query, outputs)
|
||||
}
|
||||
};
|
||||
|
||||
let outputs = self.do_query(&query).await;
|
||||
writer::write_output(w, &query, self.session.context(), outputs).await?;
|
||||
|
||||
Ok(())
|
||||
@@ -318,7 +393,7 @@ fn replace_params(params: Vec<ParamValue>, query: String) -> String {
|
||||
ValueInner::Datetime(_) => NaiveDateTime::from(param.value).to_string(),
|
||||
ValueInner::Time(_) => format_duration(Duration::from(param.value)),
|
||||
};
|
||||
query = query.replace(&format!("${}", index), &s);
|
||||
query = query.replace(&format_placeholder(index), &s);
|
||||
index += 1;
|
||||
}
|
||||
query
|
||||
@@ -331,6 +406,27 @@ fn format_duration(duration: Duration) -> String {
|
||||
format!("{}:{}:{}", hours, minutes, seconds)
|
||||
}
|
||||
|
||||
fn replace_params_with_values(
|
||||
plan: &LogicalPlan,
|
||||
param_types: HashMap<String, Option<ConcreteDataType>>,
|
||||
params: Vec<ParamValue>,
|
||||
) -> Result<LogicalPlan> {
|
||||
debug_assert_eq!(param_types.len(), params.len());
|
||||
|
||||
let mut values = Vec::with_capacity(params.len());
|
||||
|
||||
for (i, param) in params.iter().enumerate() {
|
||||
if let Some(Some(t)) = param_types.get(&format_placeholder(i + 1)) {
|
||||
let value = helper::convert_value(param, t)?;
|
||||
|
||||
values.push(value);
|
||||
}
|
||||
}
|
||||
|
||||
plan.replace_params_with_values(&values)
|
||||
.context(error::ReplacePreparedStmtParamsSnafu)
|
||||
}
|
||||
|
||||
async fn validate_query(query: &str) -> Result<Statement> {
|
||||
let statement = ParserContext::create_with_dialect(query, &MySqlDialect {});
|
||||
let mut statement = statement.map_err(|e| {
|
||||
@@ -352,29 +448,27 @@ async fn validate_query(query: &str) -> Result<Statement> {
|
||||
Ok(statement)
|
||||
}
|
||||
|
||||
// dummy columns to satisfy opensrv_mysql, just the number of params is useful
|
||||
// TODO(SSebo): use parameter type inference to return actual types
|
||||
fn dummy_params(index: u32) -> Vec<Column> {
|
||||
let mut params = vec![];
|
||||
fn dummy_params(index: usize) -> Result<Vec<Column>> {
|
||||
let mut params = Vec::with_capacity(index - 1);
|
||||
|
||||
for _ in 1..index {
|
||||
params.push(opensrv_mysql::Column {
|
||||
table: "".to_string(),
|
||||
column: "".to_string(),
|
||||
coltype: ColumnType::MYSQL_TYPE_LONG,
|
||||
colflags: ColumnFlags::NOT_NULL_FLAG,
|
||||
});
|
||||
params.push(create_mysql_column(&ConcreteDataType::null_datatype(), "")?);
|
||||
}
|
||||
params
|
||||
|
||||
Ok(params)
|
||||
}
|
||||
|
||||
fn replace_placeholder(query: &str) -> (String, u32) {
|
||||
let mut query = query.to_string();
|
||||
let mut index = 1;
|
||||
while let Some(position) = query.find('?') {
|
||||
let place_holder = format!("${}", index);
|
||||
query.replace_range(position..position + 1, &place_holder);
|
||||
index += 1;
|
||||
/// Parameters that the client must provide when executing the prepared statement.
|
||||
fn prepared_params(param_types: &HashMap<String, Option<ConcreteDataType>>) -> Result<Vec<Column>> {
|
||||
let mut params = Vec::with_capacity(param_types.len());
|
||||
|
||||
// Placeholder index starts from 1
|
||||
for index in 1..=param_types.len() {
|
||||
if let Some(Some(t)) = param_types.get(&format_placeholder(index)) {
|
||||
let column = create_mysql_column(t, "")?;
|
||||
params.push(column);
|
||||
}
|
||||
}
|
||||
(query, index)
|
||||
|
||||
Ok(params)
|
||||
}
|
||||
|
||||
238
src/servers/src/mysql/helper.rs
Normal file
238
src/servers/src/mysql/helper.rs
Normal file
@@ -0,0 +1,238 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
use std::ops::ControlFlow;
|
||||
use std::time::Duration;
|
||||
|
||||
use chrono::{NaiveDate, NaiveDateTime};
|
||||
use common_query::prelude::ScalarValue;
|
||||
use datatypes::prelude::ConcreteDataType;
|
||||
use datatypes::value::{self, Value};
|
||||
use itertools::Itertools;
|
||||
use opensrv_mysql::{ParamValue, ValueInner};
|
||||
use snafu::ResultExt;
|
||||
use sql::ast::{visit_expressions_mut, Expr, Value as ValueExpr, VisitMut};
|
||||
use sql::statements::statement::Statement;
|
||||
|
||||
use crate::error::{self, Result};
|
||||
|
||||
/// Returns the placeholder string "$i".
|
||||
pub fn format_placeholder(i: usize) -> String {
|
||||
format!("${}", i)
|
||||
}
|
||||
|
||||
/// Replace all the "?" placeholder into "$i" in SQL,
|
||||
/// returns the new SQL and the last placeholder index.
|
||||
pub fn replace_placeholders(query: &str) -> (String, usize) {
|
||||
let query_parts = query.split('?').collect::<Vec<_>>();
|
||||
let parts_len = query_parts.len();
|
||||
let mut index = 0;
|
||||
let query = query_parts
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(i, part)| {
|
||||
if i == parts_len - 1 {
|
||||
return part.to_string();
|
||||
}
|
||||
|
||||
index += 1;
|
||||
format!("{part}{}", format_placeholder(index))
|
||||
})
|
||||
.join("");
|
||||
|
||||
(query, index + 1)
|
||||
}
|
||||
|
||||
/// Transform all the "?" placeholder into "$i".
|
||||
/// Only works for Insert,Query and Delete statements.
|
||||
pub fn transform_placeholders(stmt: Statement) -> Statement {
|
||||
match stmt {
|
||||
Statement::Query(mut query) => {
|
||||
visit_placeholders(&mut query.inner);
|
||||
Statement::Query(query)
|
||||
}
|
||||
Statement::Insert(mut insert) => {
|
||||
visit_placeholders(&mut insert.inner);
|
||||
Statement::Insert(insert)
|
||||
}
|
||||
Statement::Delete(mut delete) => {
|
||||
visit_placeholders(&mut delete.inner);
|
||||
Statement::Delete(delete)
|
||||
}
|
||||
stmt => stmt,
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_placeholders<V>(v: &mut V)
|
||||
where
|
||||
V: VisitMut,
|
||||
{
|
||||
let mut index = 1;
|
||||
visit_expressions_mut(v, |expr| {
|
||||
if let Expr::Value(ValueExpr::Placeholder(s)) = expr {
|
||||
*s = format_placeholder(index);
|
||||
index += 1;
|
||||
}
|
||||
ControlFlow::<()>::Continue(())
|
||||
});
|
||||
}
|
||||
|
||||
/// Convert [`ParamValue`] into [`Value`] according to param type.
|
||||
/// It will try it's best to do type conversions if possible
|
||||
pub fn convert_value(param: &ParamValue, t: &ConcreteDataType) -> Result<ScalarValue> {
|
||||
match param.value.into_inner() {
|
||||
ValueInner::Int(i) => match t {
|
||||
ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(i as i8))),
|
||||
ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(i as i16))),
|
||||
ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(i as i32))),
|
||||
ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(i))),
|
||||
ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(i as u8))),
|
||||
ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(i as u16))),
|
||||
ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(i as u32))),
|
||||
ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(i as u64))),
|
||||
ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(i as f32))),
|
||||
ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(i as f64))),
|
||||
ConcreteDataType::Timestamp(ts_type) => Value::Timestamp(ts_type.create_timestamp(i))
|
||||
.try_to_scalar_value(t)
|
||||
.context(error::ConvertScalarValueSnafu),
|
||||
|
||||
_ => error::PreparedStmtTypeMismatchSnafu {
|
||||
expected: t,
|
||||
actual: param.coltype,
|
||||
}
|
||||
.fail(),
|
||||
},
|
||||
ValueInner::UInt(u) => match t {
|
||||
ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(u as i8))),
|
||||
ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(u as i16))),
|
||||
ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(u as i32))),
|
||||
ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(u as i64))),
|
||||
ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(u as u8))),
|
||||
ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(u as u16))),
|
||||
ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(u as u32))),
|
||||
ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(u))),
|
||||
ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(u as f32))),
|
||||
ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(u as f64))),
|
||||
ConcreteDataType::Timestamp(ts_type) => {
|
||||
Value::Timestamp(ts_type.create_timestamp(u as i64))
|
||||
.try_to_scalar_value(t)
|
||||
.context(error::ConvertScalarValueSnafu)
|
||||
}
|
||||
|
||||
_ => error::PreparedStmtTypeMismatchSnafu {
|
||||
expected: t,
|
||||
actual: param.coltype,
|
||||
}
|
||||
.fail(),
|
||||
},
|
||||
ValueInner::Double(f) => match t {
|
||||
ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(f as i8))),
|
||||
ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(f as i16))),
|
||||
ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(f as i32))),
|
||||
ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(f as i64))),
|
||||
ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(f as u8))),
|
||||
ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(f as u16))),
|
||||
ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(f as u32))),
|
||||
ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(f as u64))),
|
||||
ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(f as f32))),
|
||||
ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(f))),
|
||||
|
||||
_ => error::PreparedStmtTypeMismatchSnafu {
|
||||
expected: t,
|
||||
actual: param.coltype,
|
||||
}
|
||||
.fail(),
|
||||
},
|
||||
ValueInner::NULL => Ok(value::to_null_scalar_value(t)),
|
||||
ValueInner::Bytes(b) => match t {
|
||||
ConcreteDataType::String(_) => Ok(ScalarValue::Utf8(Some(
|
||||
String::from_utf8_lossy(b).to_string(),
|
||||
))),
|
||||
ConcreteDataType::Binary(_) => Ok(ScalarValue::LargeBinary(Some(b.to_vec()))),
|
||||
|
||||
_ => error::PreparedStmtTypeMismatchSnafu {
|
||||
expected: t,
|
||||
actual: param.coltype,
|
||||
}
|
||||
.fail(),
|
||||
},
|
||||
ValueInner::Date(_) => {
|
||||
let date: common_time::Date = NaiveDate::from(param.value).into();
|
||||
Ok(ScalarValue::Date32(Some(date.val())))
|
||||
}
|
||||
ValueInner::Datetime(_) => Ok(ScalarValue::Date64(Some(
|
||||
NaiveDateTime::from(param.value).timestamp_millis(),
|
||||
))),
|
||||
ValueInner::Time(_) => Ok(ScalarValue::Time64Nanosecond(Some(
|
||||
Duration::from(param.value).as_millis() as i64,
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use sql::dialect::MySqlDialect;
|
||||
use sql::parser::ParserContext;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_format_placeholder() {
|
||||
assert_eq!("$1", format_placeholder(1));
|
||||
assert_eq!("$3", format_placeholder(3));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_replace_placeholders() {
|
||||
let create = "create table demo(host string, ts timestamp time index)";
|
||||
let (sql, index) = replace_placeholders(create);
|
||||
assert_eq!(create, sql);
|
||||
assert_eq!(1, index);
|
||||
|
||||
let insert = "insert into demo values(?,?,?)";
|
||||
let (sql, index) = replace_placeholders(insert);
|
||||
assert_eq!("insert into demo values($1,$2,$3)", sql);
|
||||
assert_eq!(4, index);
|
||||
|
||||
let query = "select from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?";
|
||||
let (sql, index) = replace_placeholders(query);
|
||||
assert_eq!("select from demo where host=$1 and idc in (select idc from idcs where name=$2) and cpu>$3", sql);
|
||||
assert_eq!(4, index);
|
||||
}
|
||||
|
||||
fn parse_sql(sql: &str) -> Statement {
|
||||
let mut stmts = ParserContext::create_with_dialect(sql, &MySqlDialect {}).unwrap();
|
||||
stmts.remove(0)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transform_placeholders() {
|
||||
let insert = parse_sql("insert into demo values(?,?,?)");
|
||||
let Statement::Insert(insert) = transform_placeholders(insert) else { unreachable!()};
|
||||
assert_eq!(
|
||||
"INSERT INTO demo VALUES ($1, $2, $3)",
|
||||
insert.inner.to_string()
|
||||
);
|
||||
|
||||
let delete = parse_sql("delete from demo where host=? and idc=?");
|
||||
let Statement::Delete(delete) = transform_placeholders(delete) else { unreachable!()};
|
||||
assert_eq!(
|
||||
"DELETE FROM demo WHERE host = $1 AND idc = $2",
|
||||
delete.inner.to_string()
|
||||
);
|
||||
|
||||
let select = parse_sql("select from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?");
|
||||
let Statement::Query(select) = transform_placeholders(select) else { unreachable!()};
|
||||
assert_eq!("SELECT from AS demo WHERE host = $1 AND idc IN (SELECT idc FROM idcs WHERE name = $2) AND cpu > $3", select.inner.to_string());
|
||||
}
|
||||
}
|
||||
@@ -18,7 +18,7 @@ use common_query::Output;
|
||||
use common_recordbatch::{util, RecordBatch};
|
||||
use common_telemetry::error;
|
||||
use datatypes::prelude::{ConcreteDataType, Value};
|
||||
use datatypes::schema::{ColumnSchema, SchemaRef};
|
||||
use datatypes::schema::SchemaRef;
|
||||
use opensrv_mysql::{
|
||||
Column, ColumnFlags, ColumnType, ErrorKind, OkResponse, QueryResultWriter, RowWriter,
|
||||
};
|
||||
@@ -176,8 +176,8 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> {
|
||||
Value::Float64(v) => row_writer.write_col(v.0)?,
|
||||
Value::String(v) => row_writer.write_col(v.as_utf8())?,
|
||||
Value::Binary(v) => row_writer.write_col(v.deref())?,
|
||||
Value::Date(v) => row_writer.write_col(v.val())?,
|
||||
Value::DateTime(v) => row_writer.write_col(v.val())?,
|
||||
Value::Date(v) => row_writer.write_col(v.to_chrono_date())?,
|
||||
Value::DateTime(v) => row_writer.write_col(v.to_chrono_datetime())?,
|
||||
Value::Timestamp(v) => row_writer
|
||||
.write_col(v.to_timezone_aware_string(query_context.time_zone()))?,
|
||||
Value::List(_) => {
|
||||
@@ -208,8 +208,11 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> {
|
||||
}
|
||||
}
|
||||
|
||||
fn create_mysql_column(column_schema: &ColumnSchema) -> Result<Column> {
|
||||
let column_type = match column_schema.data_type {
|
||||
pub(crate) fn create_mysql_column(
|
||||
data_type: &ConcreteDataType,
|
||||
column_name: &str,
|
||||
) -> Result<Column> {
|
||||
let column_type = match data_type {
|
||||
ConcreteDataType::Null(_) => Ok(ColumnType::MYSQL_TYPE_NULL),
|
||||
ConcreteDataType::Boolean(_) | ConcreteDataType::Int8(_) | ConcreteDataType::UInt8(_) => {
|
||||
Ok(ColumnType::MYSQL_TYPE_TINY)
|
||||
@@ -230,15 +233,12 @@ fn create_mysql_column(column_schema: &ColumnSchema) -> Result<Column> {
|
||||
ConcreteDataType::Date(_) => Ok(ColumnType::MYSQL_TYPE_DATE),
|
||||
ConcreteDataType::DateTime(_) => Ok(ColumnType::MYSQL_TYPE_DATETIME),
|
||||
_ => error::InternalSnafu {
|
||||
err_msg: format!(
|
||||
"not implemented for column datatype {:?}",
|
||||
column_schema.data_type
|
||||
),
|
||||
err_msg: format!("not implemented for column datatype {:?}", data_type),
|
||||
}
|
||||
.fail(),
|
||||
};
|
||||
let mut colflags = ColumnFlags::empty();
|
||||
match column_schema.data_type {
|
||||
match data_type {
|
||||
ConcreteDataType::UInt16(_)
|
||||
| ConcreteDataType::UInt8(_)
|
||||
| ConcreteDataType::UInt32(_)
|
||||
@@ -246,7 +246,7 @@ fn create_mysql_column(column_schema: &ColumnSchema) -> Result<Column> {
|
||||
_ => {}
|
||||
};
|
||||
column_type.map(|column_type| Column {
|
||||
column: column_schema.name.clone(),
|
||||
column: column_name.to_string(),
|
||||
coltype: column_type,
|
||||
|
||||
// TODO(LFC): Currently "table" and "colflags" are not relevant in MySQL server
|
||||
@@ -261,6 +261,6 @@ pub fn create_mysql_column_def(schema: &SchemaRef) -> Result<Vec<Column>> {
|
||||
schema
|
||||
.column_schemas()
|
||||
.iter()
|
||||
.map(create_mysql_column)
|
||||
.map(|column_schema| create_mysql_column(&column_schema.data_type, &column_schema.name))
|
||||
.collect()
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@ use pgwire::api::stmt::QueryParser;
|
||||
use pgwire::api::store::MemPortalStore;
|
||||
use pgwire::api::{ClientInfo, Type};
|
||||
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
|
||||
use query::query_engine::DescribeResult;
|
||||
use sql::dialect::PostgreSqlDialect;
|
||||
use sql::parser::ParserContext;
|
||||
use sql::statements::statement::Statement;
|
||||
@@ -405,7 +406,7 @@ impl ExtendedQueryHandler for PostgresServerHandler {
|
||||
// get Statement part of the tuple
|
||||
let (stmt, _) = stmt;
|
||||
|
||||
if let Some(schema) = self
|
||||
if let Some(DescribeResult { schema, .. }) = self
|
||||
.query_handler
|
||||
.do_describe(stmt.clone(), self.session.context())
|
||||
.await
|
||||
|
||||
@@ -17,8 +17,8 @@ use std::sync::Arc;
|
||||
use async_trait::async_trait;
|
||||
use common_error::prelude::*;
|
||||
use common_query::Output;
|
||||
use datatypes::schema::Schema;
|
||||
use query::parser::PromQuery;
|
||||
use query::plan::LogicalPlan;
|
||||
use session::context::QueryContextRef;
|
||||
use sql::statements::statement::Statement;
|
||||
|
||||
@@ -26,6 +26,7 @@ use crate::error::{self, Result};
|
||||
|
||||
pub type SqlQueryHandlerRef<E> = Arc<dyn SqlQueryHandler<Error = E> + Send + Sync>;
|
||||
pub type ServerSqlQueryHandlerRef = SqlQueryHandlerRef<error::Error>;
|
||||
use query::query_engine::DescribeResult;
|
||||
|
||||
#[async_trait]
|
||||
pub trait SqlQueryHandler {
|
||||
@@ -37,6 +38,12 @@ pub trait SqlQueryHandler {
|
||||
query_ctx: QueryContextRef,
|
||||
) -> Vec<std::result::Result<Output, Self::Error>>;
|
||||
|
||||
async fn do_exec_plan(
|
||||
&self,
|
||||
plan: LogicalPlan,
|
||||
query_ctx: QueryContextRef,
|
||||
) -> std::result::Result<Output, Self::Error>;
|
||||
|
||||
async fn do_promql_query(
|
||||
&self,
|
||||
query: &PromQuery,
|
||||
@@ -47,7 +54,7 @@ pub trait SqlQueryHandler {
|
||||
&self,
|
||||
stmt: Statement,
|
||||
query_ctx: QueryContextRef,
|
||||
) -> std::result::Result<Option<Schema>, Self::Error>;
|
||||
) -> std::result::Result<Option<DescribeResult>, Self::Error>;
|
||||
|
||||
async fn is_valid_schema(
|
||||
&self,
|
||||
@@ -83,6 +90,14 @@ where
|
||||
.collect()
|
||||
}
|
||||
|
||||
async fn do_exec_plan(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result<Output> {
|
||||
self.0
|
||||
.do_exec_plan(plan, query_ctx)
|
||||
.await
|
||||
.map_err(BoxedError::new)
|
||||
.context(error::ExecutePlanSnafu)
|
||||
}
|
||||
|
||||
async fn do_promql_query(
|
||||
&self,
|
||||
query: &PromQuery,
|
||||
@@ -107,7 +122,7 @@ where
|
||||
&self,
|
||||
stmt: Statement,
|
||||
query_ctx: QueryContextRef,
|
||||
) -> Result<Option<Schema>> {
|
||||
) -> Result<Option<DescribeResult>> {
|
||||
self.0
|
||||
.do_describe(stmt, query_ctx)
|
||||
.await
|
||||
|
||||
@@ -21,8 +21,9 @@ use axum::{http, Router};
|
||||
use axum_test_helper::TestClient;
|
||||
use common_query::Output;
|
||||
use common_test_util::ports;
|
||||
use datatypes::schema::Schema;
|
||||
use query::parser::PromQuery;
|
||||
use query::plan::LogicalPlan;
|
||||
use query::query_engine::DescribeResult;
|
||||
use servers::error::{Error, Result};
|
||||
use servers::http::{HttpOptions, HttpServerBuilder};
|
||||
use servers::influxdb::InfluxdbRequest;
|
||||
@@ -71,6 +72,14 @@ impl SqlQueryHandler for DummyInstance {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
async fn do_exec_plan(
|
||||
&self,
|
||||
_plan: LogicalPlan,
|
||||
_query_ctx: QueryContextRef,
|
||||
) -> std::result::Result<Output, Self::Error> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
async fn do_promql_query(
|
||||
&self,
|
||||
_: &PromQuery,
|
||||
@@ -83,7 +92,7 @@ impl SqlQueryHandler for DummyInstance {
|
||||
&self,
|
||||
_stmt: sql::statements::statement::Statement,
|
||||
_query_ctx: QueryContextRef,
|
||||
) -> Result<Option<Schema>> {
|
||||
) -> Result<Option<DescribeResult>> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
|
||||
@@ -20,8 +20,9 @@ use axum::Router;
|
||||
use axum_test_helper::TestClient;
|
||||
use common_query::Output;
|
||||
use common_test_util::ports;
|
||||
use datatypes::schema::Schema;
|
||||
use query::parser::PromQuery;
|
||||
use query::plan::LogicalPlan;
|
||||
use query::query_engine::DescribeResult;
|
||||
use servers::error::{self, Result};
|
||||
use servers::http::{HttpOptions, HttpServerBuilder};
|
||||
use servers::opentsdb::codec::DataPoint;
|
||||
@@ -70,6 +71,14 @@ impl SqlQueryHandler for DummyInstance {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
async fn do_exec_plan(
|
||||
&self,
|
||||
_plan: LogicalPlan,
|
||||
_query_ctx: QueryContextRef,
|
||||
) -> std::result::Result<Output, Self::Error> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
async fn do_promql_query(
|
||||
&self,
|
||||
_: &PromQuery,
|
||||
@@ -82,7 +91,7 @@ impl SqlQueryHandler for DummyInstance {
|
||||
&self,
|
||||
_stmt: sql::statements::statement::Statement,
|
||||
_query_ctx: QueryContextRef,
|
||||
) -> Result<Option<Schema>> {
|
||||
) -> Result<Option<DescribeResult>> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
|
||||
@@ -23,9 +23,10 @@ use axum::Router;
|
||||
use axum_test_helper::TestClient;
|
||||
use common_query::Output;
|
||||
use common_test_util::ports;
|
||||
use datatypes::schema::Schema;
|
||||
use prost::Message;
|
||||
use query::parser::PromQuery;
|
||||
use query::plan::LogicalPlan;
|
||||
use query::query_engine::DescribeResult;
|
||||
use servers::error::{Error, Result};
|
||||
use servers::http::{HttpOptions, HttpServerBuilder};
|
||||
use servers::prometheus;
|
||||
@@ -95,6 +96,14 @@ impl SqlQueryHandler for DummyInstance {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
async fn do_exec_plan(
|
||||
&self,
|
||||
_plan: LogicalPlan,
|
||||
_query_ctx: QueryContextRef,
|
||||
) -> std::result::Result<Output, Self::Error> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
async fn do_promql_query(
|
||||
&self,
|
||||
_: &PromQuery,
|
||||
@@ -107,7 +116,7 @@ impl SqlQueryHandler for DummyInstance {
|
||||
&self,
|
||||
_stmt: sql::statements::statement::Statement,
|
||||
_query_ctx: QueryContextRef,
|
||||
) -> Result<Option<Schema>> {
|
||||
) -> Result<Option<DescribeResult>> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
|
||||
@@ -21,8 +21,9 @@ use async_trait::async_trait;
|
||||
use catalog::local::{MemoryCatalogManager, MemoryCatalogProvider, MemorySchemaProvider};
|
||||
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
|
||||
use common_query::Output;
|
||||
use datatypes::schema::Schema;
|
||||
use query::parser::{PromQuery, QueryLanguageParser, QueryStatement};
|
||||
use query::plan::LogicalPlan;
|
||||
use query::query_engine::DescribeResult;
|
||||
use query::{QueryEngineFactory, QueryEngineRef};
|
||||
use script::engine::{CompileContext, EvalContext, Script, ScriptEngine};
|
||||
use script::python::{PyEngine, PyScript};
|
||||
@@ -78,6 +79,10 @@ impl SqlQueryHandler for DummyInstance {
|
||||
vec![Ok(output)]
|
||||
}
|
||||
|
||||
async fn do_exec_plan(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result<Output> {
|
||||
Ok(self.query_engine.execute(plan, query_ctx).await.unwrap())
|
||||
}
|
||||
|
||||
async fn do_promql_query(
|
||||
&self,
|
||||
_: &PromQuery,
|
||||
@@ -90,7 +95,7 @@ impl SqlQueryHandler for DummyInstance {
|
||||
&self,
|
||||
stmt: Statement,
|
||||
query_ctx: QueryContextRef,
|
||||
) -> Result<Option<Schema>> {
|
||||
) -> Result<Option<DescribeResult>> {
|
||||
if let Statement::Query(_) = stmt {
|
||||
let plan = self
|
||||
.query_engine
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
pub use sqlparser::ast::{
|
||||
BinaryOperator, ColumnDef, ColumnOption, ColumnOptionDef, DataType, Expr, Function,
|
||||
FunctionArg, FunctionArgExpr, Ident, ObjectName, SqlOption, TableConstraint, TimezoneInfo,
|
||||
Value,
|
||||
visit_expressions_mut, BinaryOperator, ColumnDef, ColumnOption, ColumnOptionDef, DataType,
|
||||
Expr, Function, FunctionArg, FunctionArgExpr, Ident, ObjectName, SqlOption, TableConstraint,
|
||||
TimezoneInfo, Value, VisitMut, Visitor,
|
||||
};
|
||||
|
||||
@@ -13,6 +13,7 @@ axum = "0.6"
|
||||
axum-test-helper = { git = "https://github.com/sunng87/axum-test-helper.git", branch = "patch-1" }
|
||||
async-trait = "0.1"
|
||||
catalog = { path = "../src/catalog" }
|
||||
chrono.workspace = true
|
||||
client = { path = "../src/client", features = ["testing"] }
|
||||
common-base = { path = "../src/common/base" }
|
||||
common-catalog = { path = "../src/common/catalog" }
|
||||
@@ -49,6 +50,7 @@ sqlx = { version = "0.6", features = [
|
||||
"runtime-tokio-rustls",
|
||||
"mysql",
|
||||
"postgres",
|
||||
"chrono",
|
||||
] }
|
||||
table = { path = "../src/table" }
|
||||
tempfile.workspace = true
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc};
|
||||
use sqlx::mysql::MySqlPoolOptions;
|
||||
use sqlx::postgres::PgPoolOptions;
|
||||
use sqlx::Row;
|
||||
@@ -62,20 +63,24 @@ pub async fn test_mysql_crud(store_type: StorageType) {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
sqlx::query("create table demo(i bigint, ts timestamp time index)")
|
||||
sqlx::query("create table demo(i bigint, ts timestamp time index, d date, dt datetime)")
|
||||
.execute(&pool)
|
||||
.await
|
||||
.unwrap();
|
||||
for i in 0..10 {
|
||||
sqlx::query("insert into demo values(?, ?)")
|
||||
let dt = DateTime::<Utc>::from_utc(NaiveDateTime::from_timestamp_opt(60, i).unwrap(), Utc);
|
||||
let d = NaiveDate::from_yo_opt(2015, 100).unwrap();
|
||||
sqlx::query("insert into demo values(?, ?, ?, ?)")
|
||||
.bind(i)
|
||||
.bind(i)
|
||||
.bind(d)
|
||||
.bind(dt)
|
||||
.execute(&pool)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let rows = sqlx::query("select i from demo")
|
||||
let rows = sqlx::query("select i, d, dt from demo")
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -83,7 +88,34 @@ pub async fn test_mysql_crud(store_type: StorageType) {
|
||||
|
||||
for (i, row) in rows.iter().enumerate() {
|
||||
let ret: i64 = row.get(0);
|
||||
let d: NaiveDate = row.get(1);
|
||||
let dt: DateTime<Utc> = row.get(2);
|
||||
assert_eq!(ret, i as i64);
|
||||
|
||||
let expected_d = NaiveDate::from_yo_opt(2015, 100).unwrap();
|
||||
assert_eq!(expected_d, d);
|
||||
|
||||
let expected_dt = DateTime::<Utc>::from_utc(
|
||||
NaiveDateTime::from_timestamp_opt(60, i as u32).unwrap(),
|
||||
Utc,
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
format!("{}", expected_dt.format("%Y-%m-%d %H:%M:%S")),
|
||||
format!("{}", dt.format("%Y-%m-%d %H:%M:%S"))
|
||||
);
|
||||
}
|
||||
|
||||
let rows = sqlx::query("select i from demo where i=?")
|
||||
.bind(6)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(rows.len(), 1);
|
||||
|
||||
for row in rows {
|
||||
let ret: i64 = row.get(0);
|
||||
assert_eq!(ret, 6);
|
||||
}
|
||||
|
||||
sqlx::query("delete from demo")
|
||||
@@ -133,6 +165,18 @@ pub async fn test_postgres_crud(store_type: StorageType) {
|
||||
assert_eq!(ret, i as i64);
|
||||
}
|
||||
|
||||
let rows = sqlx::query("select i from demo where i=$1")
|
||||
.bind(6)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(rows.len(), 1);
|
||||
|
||||
for row in rows {
|
||||
let ret: i64 = row.get(0);
|
||||
assert_eq!(ret, 6);
|
||||
}
|
||||
|
||||
sqlx::query("delete from demo")
|
||||
.execute(&pool)
|
||||
.await
|
||||
|
||||
Reference in New Issue
Block a user