diff --git a/src/catalog/src/process_manager.rs b/src/catalog/src/process_manager.rs index 8796d948e1..6a3ae31b25 100644 --- a/src/catalog/src/process_manager.rs +++ b/src/catalog/src/process_manager.rs @@ -58,6 +58,8 @@ pub enum QueryStatement { Sql(Statement), // The optional string is the alias of the PromQL query. Promql(EvalStmt, Option), + /// Logical plan with original query string + Plan(String), } impl Display for QueryStatement { @@ -71,6 +73,7 @@ impl Display for QueryStatement { write!(f, "{}", eval_stmt) } } + QueryStatement::Plan(query) => write!(f, "{}", query), } } } @@ -369,6 +372,9 @@ impl SlowQueryTimer { QueryStatement::Sql(stmt) => { slow_query_event.query = stmt.to_string(); } + QueryStatement::Plan(query) => { + slow_query_event.query = query.clone(); + } } match self.record_type { diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index 99444bb2a2..14b8d44831 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -195,7 +195,7 @@ impl Instance { let query_interceptor = self.plugins.get::>(); let query_interceptor = query_interceptor.as_ref(); - if should_capture_statement(Some(&stmt)) { + if stmt.is_readonly() { let slow_query_timer = self .slow_query_options .enable @@ -483,24 +483,16 @@ fn derive_timeout(stmt: &Statement, query_ctx: &QueryContextRef) -> Option, - query_ctx: &QueryContextRef, -) -> Option { - match stmt { - Some(s) => derive_timeout(s, query_ctx), - None => { - let query_timeout = query_ctx.query_timeout()?; - if query_timeout.is_zero() { - return None; - } - match query_ctx.channel() { - Channel::Postgres => Some(query_timeout), - _ => None, - } - } +/// Derives timeout for plan execution. +fn derive_timeout_for_plan(plan: &LogicalPlan, query_ctx: &QueryContextRef) -> Option { + let query_timeout = query_ctx.query_timeout()?; + if query_timeout.is_zero() { + return None; + } + match query_ctx.channel() { + Channel::Mysql if is_readonly_plan(plan) => Some(query_timeout), + Channel::Postgres => Some(query_timeout), + _ => None, } } @@ -618,11 +610,10 @@ impl Instance { async fn exec_plan_with_timeout( &self, - stmt: Option, plan: LogicalPlan, query_ctx: QueryContextRef, ) -> Result { - let timeout = derive_timeout_for_plan(stmt.as_ref(), &query_ctx); + let timeout = derive_timeout_for_plan(&plan, &query_ctx); match timeout { Some(timeout) => { let start = tokio::time::Instant::now(); @@ -638,16 +629,13 @@ impl Instance { async fn do_exec_plan_inner( &self, - stmt: Option, plan: LogicalPlan, + query: String, query_ctx: QueryContextRef, ) -> Result { ensure!(!self.is_suspended(), error::SuspendedSnafu); - if should_capture_statement(stmt.as_ref()) { - // It's safe to unwrap here because we've already checked the type. - let stmt = stmt.unwrap(); - let query = stmt.to_string(); + if is_readonly_plan(&plan) { let slow_query_timer = self .slow_query_options .enable @@ -655,7 +643,7 @@ impl Instance { .flatten() .map(|event_recorder| { SlowQueryTimer::new( - CatalogQueryStatement::Sql(stmt.clone()), + CatalogQueryStatement::Plan(query.clone()), self.slow_query_options.threshold, self.slow_query_options.sample_ratio, self.slow_query_options.record_type, @@ -672,7 +660,7 @@ impl Instance { slow_query_timer, ); - let query_fut = self.exec_plan_with_timeout(Some(stmt), plan, query_ctx); + let query_fut = self.exec_plan_with_timeout(plan, query_ctx); CancellableFuture::new(query_fut, ticket.cancellation_handle.clone()) .await @@ -689,7 +677,7 @@ impl Instance { Output { data, meta } }) } else { - self.exec_plan_with_timeout(stmt, plan, query_ctx).await + self.exec_plan_with_timeout(plan, query_ctx).await } } @@ -769,11 +757,11 @@ impl SqlQueryHandler for Instance { async fn do_exec_plan( &self, - stmt: Option, plan: LogicalPlan, + query: String, query_ctx: QueryContextRef, ) -> server_error::Result { - self.do_exec_plan_inner(stmt, plan, query_ctx) + self.do_exec_plan_inner(plan, query, query_ctx) .await .map_err(BoxedError::new) .context(server_error::ExecutePlanSnafu) @@ -1161,13 +1149,8 @@ fn validate_database(name: &ObjectName, query_ctx: &QueryContextRef) -> Result<( .context(SqlExecInterceptedSnafu) } -// Create a query ticket and slow query timer if the statement is a query or readonly statement. -fn should_capture_statement(stmt: Option<&Statement>) -> bool { - if let Some(stmt) = stmt { - matches!(stmt, Statement::Query(_)) || stmt.is_readonly() - } else { - false - } +fn is_readonly_plan(plan: &LogicalPlan) -> bool { + !matches!(plan, LogicalPlan::Dml(_) | LogicalPlan::Ddl(_)) } #[cfg(test)] diff --git a/src/frontend/src/instance/grpc.rs b/src/frontend/src/instance/grpc.rs index 70ff50fadc..b544612389 100644 --- a/src/frontend/src/instance/grpc.rs +++ b/src/frontend/src/instance/grpc.rs @@ -129,8 +129,9 @@ impl GrpcQueryHandler for Instance { .decode(bytes::Bytes::from(plan), dummy_catalog_list, true) .await .context(SubstraitDecodeLogicalPlanSnafu)?; + let query = logical_plan.display_indent().to_string(); let output = - self.do_exec_plan_inner(None, logical_plan, ctx.clone()).await?; + self.do_exec_plan_inner(logical_plan, query, ctx.clone()).await?; attach_timer(output, timer) } @@ -466,8 +467,9 @@ impl Instance { // Optimize the plan let optimized_plan = state.optimize(&analyzed_plan).context(DataFusionSnafu)?; + let query = optimized_plan.display_indent().to_string(); let output = self - .do_exec_plan_inner(None, optimized_plan, ctx.clone()) + .do_exec_plan_inner(optimized_plan, query, ctx.clone()) .await?; Ok(attach_timer(output, timer)) diff --git a/src/query/src/datafusion.rs b/src/query/src/datafusion.rs index e2e577debf..6fc78c59e5 100644 --- a/src/query/src/datafusion.rs +++ b/src/query/src/datafusion.rs @@ -54,7 +54,7 @@ use crate::analyze::DistAnalyzeExec; pub use crate::datafusion::planner::DfContextProviderAdapter; use crate::dist_plan::{DistPlannerOptions, MergeScanLogicalPlan}; use crate::error::{ - CatalogSnafu, ConvertSchemaSnafu, CreateRecordBatchSnafu, MissingTableMutationHandlerSnafu, + CatalogSnafu, CreateRecordBatchSnafu, MissingTableMutationHandlerSnafu, MissingTimestampColumnSnafu, QueryExecutionSnafu, Result, TableMutationSnafu, TableNotFoundSnafu, TableReadOnlySnafu, UnsupportedExprSnafu, }; @@ -427,15 +427,7 @@ impl QueryEngine for DatafusionQueryEngine { plan: LogicalPlan, _query_ctx: QueryContextRef, ) -> Result { - let schema = plan - .schema() - .clone() - .try_into() - .context(ConvertSchemaSnafu)?; - Ok(DescribeResult { - schema, - logical_plan: plan, - }) + Ok(DescribeResult { logical_plan: plan }) } async fn execute(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result { @@ -876,10 +868,10 @@ mod tests { .await .unwrap(); - let DescribeResult { - schema, - logical_plan, - } = engine.describe(plan, QueryContext::arc()).await.unwrap(); + let DescribeResult { logical_plan } = + engine.describe(plan, QueryContext::arc()).await.unwrap(); + + let schema: Schema = logical_plan.schema().clone().try_into().unwrap(); assert_eq!( schema.column_schemas()[0], diff --git a/src/query/src/query_engine.rs b/src/query/src/query_engine.rs index c6199e872f..1415998486 100644 --- a/src/query/src/query_engine.rs +++ b/src/query/src/query_engine.rs @@ -31,7 +31,6 @@ use common_query::Output; use datafusion::catalog::TableFunction; use datafusion::dataframe::DataFrame; use datafusion_expr::{AggregateUDF, LogicalPlan, WindowUDF}; -use datatypes::schema::Schema; pub use default_serializer::{DefaultPlanDecoder, DefaultSerializer}; use partition::manager::PartitionRuleManagerRef; use session::context::QueryContextRef; @@ -48,8 +47,6 @@ use crate::region_query::RegionQueryHandlerRef; /// Describe statement result #[derive(Debug)] pub struct DescribeResult { - /// The schema of statement - pub schema: Schema, /// The logical plan for statement pub logical_plan: LogicalPlan, } diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index d4366e5d1f..7d8c01f048 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -7,6 +7,7 @@ license.workspace = true [features] default = [] dashboard = ["dep:rust-embed"] +enterprise = ["sql/enterprise"] mem-prof = ["dep:common-mem-prof"] pprof = ["dep:common-pprof"] testing = [] diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index 6e3209cc8a..af3474f942 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -1329,7 +1329,6 @@ mod test { use query::parser::PromQuery; use query::query_engine::DescribeResult; use session::context::QueryContextRef; - use sql::statements::statement::Statement; use tokio::sync::mpsc; use tokio::time::Instant; @@ -1354,8 +1353,8 @@ mod test { async fn do_exec_plan( &self, - _stmt: Option, _plan: LogicalPlan, + _query: String, _query_ctx: QueryContextRef, ) -> Result { unimplemented!() diff --git a/src/servers/src/lib.rs b/src/servers/src/lib.rs index 7221cfc6e2..a11cdd0af2 100644 --- a/src/servers/src/lib.rs +++ b/src/servers/src/lib.rs @@ -16,7 +16,6 @@ #![feature(exclusive_wrapper)] use datafusion_expr::LogicalPlan; -use datatypes::schema::Schema; use sql::statements::statement::Statement; // Re-export for use in add_service! macro #[doc(hidden)] @@ -54,13 +53,19 @@ mod row_writer; pub mod server; pub mod tls; -/// Cached SQL and logical plan for database interfaces +/// Cached sql plan or statement for database interfaces #[derive(Clone, Debug)] -pub struct SqlPlan { - pub(crate) query: String, - pub(crate) statement: Option, - pub(crate) plan: Option, - pub(crate) schema: Option, +pub enum SqlPlan { + /// Empty Query + Empty, + /// Hardcoded SQL shortcuts + Shortcut(String), + /// Datafusion parsed execution plan with the original query string + Plan(LogicalPlan, String), + /// Parsed statement when execution is not managed by datafusion + /// eg. CREATE TABLE + /// The String is the original query string to avoid AST round-trip issues + Statement(Statement, String), } /// Install the ring crypto provider for rustls process-wide. see: diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index 9bd54a2de4..cdfbb5ebe0 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -28,6 +28,7 @@ use common_telemetry::{debug, error, tracing, warn}; use datafusion_common::ParamValues; use datafusion_expr::LogicalPlan; use datatypes::prelude::ConcreteDataType; +use datatypes::schema::Schema; use itertools::Itertools; use opensrv_mysql::{ AsyncMysqlShim, Column, ErrorKind, InitWriter, ParamParser, ParamValue, QueryResultWriter, @@ -138,23 +139,6 @@ impl MysqlInstanceShim { } } - /// Execute the logical plan and return the output - async fn do_exec_plan( - &self, - query: &str, - stmt: Option, - plan: LogicalPlan, - query_ctx: QueryContextRef, - ) -> Result { - if let Some(output) = - crate::mysql::federated::check(query, query_ctx.clone(), self.session.clone()) - { - Ok(output) - } else { - self.query_handler.do_exec_plan(stmt, plan, query_ctx).await - } - } - /// Describe the statement async fn do_describe( &self, @@ -198,6 +182,16 @@ impl MysqlInstanceShim { query_ctx: QueryContextRef, stmt_key: String, ) -> Result<(Vec, Vec)> { + if crate::mysql::federated::check(raw_query, query_ctx.clone(), self.session.clone()) + .is_some() + { + self.save_plan(SqlPlan::Shortcut(raw_query.to_string()), stmt_key) + .inspect_err(|e| { + error!(e; "Failed to save prepared statement"); + })?; + return Ok((vec![], vec![])); + } + let (query, param_num) = replace_placeholders(raw_query); let statement = validate_query(raw_query).await?; @@ -209,15 +203,7 @@ impl MysqlInstanceShim { let describe_result = self .do_describe(statement.clone(), query_ctx.clone()) .await?; - let (plan, schema) = if let Some(DescribeResult { - logical_plan, - schema, - }) = describe_result - { - (Some(logical_plan), Some(schema)) - } else { - (None, None) - }; + let plan = describe_result.map(|DescribeResult { logical_plan }| logical_plan); let params = if let Some(plan) = &plan { let param_types = DfLogicalPlanner::get_inferred_parameter_types(plan) @@ -230,49 +216,41 @@ impl MysqlInstanceShim { dummy_params(param_num)? }; - let columns = schema - .as_ref() - .map(|schema| { - schema - .column_schemas() - .iter() - .map(|column_schema| { - create_mysql_column(&column_schema.data_type, &column_schema.name) - }) - .collect::>>() - }) - .transpose()? - .unwrap_or_default(); + let columns = + plan.as_ref() + .map(|plan| { + let schema: Schema = plan.schema().clone().try_into().map_err( + |e: datatypes::error::Error| { + error::InternalSnafu { + err_msg: e.to_string(), + } + .build() + }, + )?; + schema + .column_schemas() + .iter() + .map(|column_schema| { + create_mysql_column(&column_schema.data_type, &column_schema.name) + }) + .collect::>>() + }) + .transpose()? + .unwrap_or_default(); - // DataFusion may optimize the plan so that some parameters are not used. - if params.len() != param_num - 1 { - self.save_plan( - SqlPlan { - query: query.clone(), - statement: Some(statement), - plan: None, - schema: None, - }, - stmt_key, - ) - .map_err(|e| { - error!(e; "Failed to save prepared statement"); - e - })?; - } else { - self.save_plan( - SqlPlan { - query: query.clone(), - statement: Some(statement), - plan, - schema, - }, - stmt_key, - ) - .map_err(|e| { - error!(e; "Failed to save prepared statement"); - e - })?; + match plan { + Some(plan) if params.len() == param_num - 1 => { + self.save_plan(SqlPlan::Plan(plan, query.clone()), stmt_key) + .inspect_err(|e| { + error!(e; "Failed to save prepared statement"); + })?; + } + _ => { + self.save_plan(SqlPlan::Statement(statement, query), stmt_key) + .inspect_err(|e| { + error!(e; "Failed to save prepared statement"); + })?; + } } Ok((params, columns)) @@ -291,8 +269,8 @@ impl MysqlInstanceShim { Some(sql_plan) => sql_plan, }; - let outputs = match sql_plan.plan { - Some(plan) => { + let outputs = match sql_plan { + SqlPlan::Plan(plan, query) => { let param_types = DfLogicalPlanner::get_inferred_parameter_types(&plan) .context(InferParameterTypesSnafu)? .into_iter() @@ -306,7 +284,7 @@ impl MysqlInstanceShim { .fail(); } - let plan = match params { + let replaced_plan = match params { Params::ProtocolParams(params) => { replace_params_with_values(&plan, param_types, ¶ms) } @@ -315,18 +293,26 @@ impl MysqlInstanceShim { } }?; - debug!("Mysql execute prepared plan: {}", plan.display_indent()); + debug!( + "Mysql execute prepared plan: {}", + replaced_plan.display_indent() + ); vec![ - self.do_exec_plan( - &sql_plan.query, - sql_plan.statement.clone(), - plan, - query_ctx.clone(), - ) - .await, + self.query_handler + .do_exec_plan(replaced_plan, query, query_ctx.clone()) + .await, ] } - None => { + SqlPlan::Shortcut(query) => { + if let Some(output) = + crate::mysql::federated::check(&query, query_ctx.clone(), self.session.clone()) + { + vec![Ok(output)] + } else { + self.do_query(&query, query_ctx.clone()).await + } + } + SqlPlan::Statement(_stmt, query) => { let param_strs = match params { Params::ProtocolParams(params) => { params.iter().map(convert_param_value_to_string).collect() @@ -335,12 +321,15 @@ impl MysqlInstanceShim { }; debug!( "do_execute Replacing with Params: {:?}, Original Query: {}", - param_strs, sql_plan.query + param_strs, query ); - let query = replace_params(param_strs, sql_plan.query); + let query = replace_params(param_strs, query); debug!("Mysql execute replaced query: {}", query); self.do_query(&query, query_ctx.clone()).await } + _ => { + return error::PrepareStatementNotFoundSnafu { name: stmt_key }.fail(); + } }; Ok(outputs) @@ -802,3 +791,152 @@ fn prepared_params(param_types: &HashMap>) -> R Ok(params) } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use async_trait::async_trait; + use common_query::Output; + use datafusion_expr::LogicalPlan; + use query::parser::PromQuery; + use query::query_engine::DescribeResult; + use session::context::QueryContext; + use sql::statements::statement::Statement; + + use super::*; + use crate::error::Result; + use crate::query_handler::sql::SqlQueryHandler; + + struct DummyQueryHandler; + + #[async_trait] + impl SqlQueryHandler for DummyQueryHandler { + async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec> { + unimplemented!() + } + + async fn do_promql_query(&self, _: &PromQuery, _: QueryContextRef) -> Vec> { + unimplemented!() + } + + async fn do_exec_plan( + &self, + _: LogicalPlan, + _: String, + _: QueryContextRef, + ) -> Result { + unimplemented!() + } + + async fn do_describe( + &self, + _: Statement, + _: QueryContextRef, + ) -> Result> { + unimplemented!() + } + + async fn is_valid_schema(&self, _: &str, _: &str) -> Result { + Ok(true) + } + } + + fn create_shim() -> MysqlInstanceShim { + MysqlInstanceShim::create( + Arc::new(DummyQueryHandler), + None, + "127.0.0.1:3306".parse().unwrap(), + 1, + 1024, + ) + } + + #[tokio::test] + async fn test_prepare_federated_query() { + let mut shim = create_shim(); + let query_ctx = QueryContext::arc(); + let stmt_key = "test_federated".to_string(); + + let (params, columns) = shim + .do_prepare( + "SELECT @@version_comment", + query_ctx.clone(), + stmt_key.clone(), + ) + .await + .unwrap(); + + assert!(params.is_empty()); + assert!(columns.is_empty()); + + let plan = shim.plan(&stmt_key).unwrap(); + assert!(matches!(plan, SqlPlan::Shortcut(q) if q == "SELECT @@version_comment")); + } + + #[tokio::test] + async fn test_execute_federated_shortcut() { + let mut shim = create_shim(); + let query_ctx = QueryContext::arc(); + let stmt_key = "test_federated_exec".to_string(); + + shim.do_prepare( + "SELECT @@version_comment", + query_ctx.clone(), + stmt_key.clone(), + ) + .await + .unwrap(); + + let outputs = shim + .do_execute(query_ctx.clone(), stmt_key, Params::CliParams(vec![])) + .await + .unwrap(); + + assert_eq!(outputs.len(), 1); + let output = outputs.into_iter().next().unwrap().unwrap(); + let pretty = output.data.pretty_print().await; + assert!(pretty.contains("GreptimeDB")); + } + + #[tokio::test] + async fn test_prepare_non_federated_query_not_shortcut() { + let mut shim = create_shim(); + let query_ctx = QueryContext::arc(); + let stmt_key = "test_non_federated".to_string(); + + let result = shim + .do_prepare("SET NAMES utf8", query_ctx.clone(), stmt_key.clone()) + .await; + + assert!(result.is_ok()); + let plan = shim.plan(&stmt_key).unwrap(); + assert!(matches!(plan, SqlPlan::Shortcut(_))); + } + + #[tokio::test] + async fn test_execute_set_shortcut() { + let mut shim = create_shim(); + let query_ctx = QueryContext::arc(); + let stmt_key = "test_set_shortcut".to_string(); + + shim.do_prepare("SET NAMES utf8", query_ctx.clone(), stmt_key.clone()) + .await + .unwrap(); + + let outputs = shim + .do_execute(query_ctx.clone(), stmt_key, Params::CliParams(vec![])) + .await + .unwrap(); + + assert_eq!(outputs.len(), 1); + let output = outputs.into_iter().next().unwrap().unwrap(); + match output.data { + common_query::OutputData::RecordBatches(batches) => { + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 0); + } + other => panic!("Expected RecordBatches, got {:?}", other), + } + } +} diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index 3275ebfdb8..b4780cda4e 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -23,9 +23,10 @@ use common_recordbatch::error::Result as RecordBatchResult; use common_telemetry::{debug, info, tracing}; use datafusion::sql::sqlparser::ast::{CopyOption, CopyTarget, Statement as SqlParserStatement}; use datafusion_common::ParamValues; +use datafusion_expr::LogicalPlan; use datafusion_pg_catalog::sql::PostgresCompatibilityParser; use datatypes::prelude::ConcreteDataType; -use datatypes::schema::SchemaRef; +use datatypes::schema::{Schema, SchemaRef}; use futures::{Sink, SinkExt, Stream, StreamExt, future, stream}; use pgwire::api::portal::{Format, Portal}; use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler}; @@ -318,14 +319,16 @@ impl QueryParser for DefaultQueryParser { let query_ctx = self.session.new_query_context(); // do not parse if query is empty or matches rules - if sql.is_empty() || fixtures::matches(sql) { + if sql.is_empty() { return Ok(PgSqlPlan { - plan: SqlPlan { - query: sql.to_owned(), - statement: None, - plan: None, - schema: None, - }, + plan: SqlPlan::Empty, + copy_to_stdout_format: None, + }); + } + + if fixtures::matches(sql) { + return Ok(PgSqlPlan { + plan: SqlPlan::Shortcut(sql.to_string()), copy_to_stdout_format: None, }); } @@ -354,31 +357,23 @@ impl QueryParser for DefaultQueryParser { } else { let stmt = stmts.remove(0); - let describe_result = self + if let Some(logical_plan) = self .query_handler .do_describe(stmt.clone(), query_ctx) .await - .map_err(convert_err)?; - - let (plan, schema) = if let Some(DescribeResult { - logical_plan, - schema, - }) = describe_result + .map_err(convert_err)? + .map(|DescribeResult { logical_plan }| logical_plan) { - (Some(logical_plan), Some(schema)) + Ok(PgSqlPlan { + plan: SqlPlan::Plan(logical_plan, sql), + copy_to_stdout_format, + }) } else { - (None, None) - }; - - Ok(PgSqlPlan { - plan: SqlPlan { - query: sql.clone(), - statement: Some(stmt), - plan, - schema, - }, - copy_to_stdout_format, - }) + Ok(PgSqlPlan { + plan: SqlPlan::Statement(stmt, sql), + copy_to_stdout_format, + }) + } } } @@ -432,39 +427,45 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner { let pg_sql_plan = &portal.statement.statement; let sql_plan = &pg_sql_plan.plan; - if sql_plan.query.is_empty() { - // early return if query is empty - return Ok(Response::EmptyQuery); - } - - if let Some(mut resps) = fixtures::process(&sql_plan.query, query_ctx.clone()) { - send_warning_opt(client, query_ctx).await?; - // if the statement matches our predefined rules, return it early - return Ok(resps.remove(0)); - } - - let output = if let Some(plan) = &sql_plan.plan { - let values = parameters_to_scalar_values(plan, portal)?; - let plan = plan - .clone() - .replace_params_with_values(&ParamValues::List( - values.into_iter().map(Into::into).collect(), - )) - .context(DataFusionSnafu) - .map_err(convert_err)?; - self.query_handler - .do_exec_plan(sql_plan.statement.clone(), plan, query_ctx.clone()) - .await - } else { - // We won't replace params from statement manually any more. - // Newer version of datafusion can generate plan for SELECT/INSERT/UPDATE/DELETE. - // Only CREATE TABLE and others minor statements cannot generate sql plan, - // in this case, we assume these statements will not carry parameters - // and execute them directly. - self.query_handler - .do_query(&sql_plan.query, query_ctx.clone()) - .await - .remove(0) + let output = match sql_plan { + SqlPlan::Empty => { + // early return if query is empty + return Ok(Response::EmptyQuery); + } + SqlPlan::Shortcut(query) => { + if let Some(mut resps) = fixtures::process(query, query_ctx.clone()) { + send_warning_opt(client, query_ctx).await?; + // if the statement matches our predefined rules, return it early + return Ok(resps.remove(0)); + } else { + // unreachable logic + return Ok(Response::EmptyQuery); + } + } + SqlPlan::Plan(plan, query) => { + let values = parameters_to_scalar_values(plan, portal)?; + let plan = plan + .clone() + .replace_params_with_values(&ParamValues::List( + values.into_iter().map(Into::into).collect(), + )) + .context(DataFusionSnafu) + .map_err(convert_err)?; + self.query_handler + .do_exec_plan(plan, query.clone(), query_ctx.clone()) + .await + } + SqlPlan::Statement(_stmt, query) => { + // We won't replace params from statement manually any more. + // Newer version of datafusion can generate plan for SELECT/INSERT/UPDATE/DELETE. + // Only CREATE TABLE and others minor statements cannot generate sql plan, + // in this case, we assume these statements will not carry parameters + // and execute them directly. + self.query_handler + .do_query(query, query_ctx.clone()) + .await + .remove(0) + } }; send_warning_opt(client, query_ctx.clone()).await?; @@ -487,7 +488,7 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner { let sql_plan = &stmt.statement.plan; // client provided parameter types, can be empty if client doesn't try to parse statement let provided_param_types = &stmt.parameter_types; - let server_inferenced_types = if let Some(plan) = &sql_plan.plan { + let server_inferenced_types = if let SqlPlan::Plan(plan, _) = &sql_plan { let param_types = DfLogicalPlanner::get_inferred_parameter_types(plan) .context(InferParameterTypesSnafu) .map_err(convert_err)? @@ -525,23 +526,9 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner { }) .collect::>(); - if let Some(schema) = &sql_plan.schema { - schema_to_pg(schema, &Format::UnifiedText, None) - .map(|fields| DescribeStatementResponse::new(param_types, fields)) - .map_err(convert_err) - } else { - if let Some(mut resp) = - fixtures::process(&sql_plan.query, self.session.new_query_context()) - && let Response::Query(query_response) = resp.remove(0) - { - return Ok(DescribeStatementResponse::new( - param_types, - (*query_response.row_schema()).clone(), - )); - } + let fields = describe_fields(sql_plan, &Format::UnifiedText, &self.session)?; - Ok(DescribeStatementResponse::new(param_types, vec![])) - } + Ok(DescribeStatementResponse::new(param_types, fields)) } async fn do_describe_portal( @@ -555,68 +542,100 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner { let sql_plan = &portal.statement.statement.plan; let format = &portal.result_column_format; - match sql_plan.statement.as_ref() { - Some(Statement::Query(_)) => { - // if the query has a schema, it is managed by datafusion, use the schema - if let Some(schema) = &sql_plan.schema { - schema_to_pg(schema, format, None) - .map(DescribePortalResponse::new) - .map_err(convert_err) - } else { - // fallback to NoData - Ok(DescribePortalResponse::new(vec![])) - } - } - // We can cover only part of show statements - // these show create statements will return 2 columns - Some(Statement::ShowCreateDatabase(_)) - | Some(Statement::ShowCreateTable(_)) - | Some(Statement::ShowCreateFlow(_)) - | Some(Statement::ShowCreateView(_)) => Ok(DescribePortalResponse::new(vec![ - FieldInfo::new( - "name".to_string(), - None, - None, - Type::TEXT, - format.format_for(0), - ), - FieldInfo::new( - "create_statement".to_string(), - None, - None, - Type::TEXT, - format.format_for(1), - ), - ])), - // single column show statements - Some(Statement::ShowTables(_)) - | Some(Statement::ShowFlows(_)) - | Some(Statement::ShowViews(_)) => { - Ok(DescribePortalResponse::new(vec![FieldInfo::new( - "name".to_string(), - None, - None, - Type::TEXT, - format.format_for(0), - )])) - } - // we will not support other show statements for extended query protocol at least for now. - // because the return columns is not predictable at this stage - _ => { - // test if query caught by fixture - if let Some(mut resp) = - fixtures::process(&sql_plan.query, self.session.new_query_context()) - && let Response::Query(query_response) = resp.remove(0) - { - Ok(DescribePortalResponse::new( - (*query_response.row_schema()).clone(), - )) - } else { - // fallback to NoData - Ok(DescribePortalResponse::new(vec![])) - } + let fields = describe_fields(sql_plan, format, &self.session)?; + + Ok(DescribePortalResponse::new(fields)) + } +} + +fn describe_fields( + sql_plan: &SqlPlan, + format: &Format, + session: &Arc, +) -> PgWireResult> { + match sql_plan { + // query + SqlPlan::Plan(plan, _) if !matches!(plan, LogicalPlan::Dml(_) | LogicalPlan::Ddl(_)) => { + let schema: Schema = plan.schema().clone().try_into().map_err(convert_err)?; + schema_to_pg(&schema, format, None).map_err(convert_err) + } + // We can cover only part of show statements + // these show create statements will return 2 columns + SqlPlan::Statement( + Statement::ShowCreateDatabase(_) + | Statement::ShowCreateTable(_) + | Statement::ShowCreateFlow(_) + | Statement::ShowCreateView(_), + _, + ) => Ok(vec![ + FieldInfo::new( + "name".to_string(), + None, + None, + Type::TEXT, + format.format_for(0), + ), + FieldInfo::new( + "create_statement".to_string(), + None, + None, + Type::TEXT, + format.format_for(1), + ), + ]), + #[cfg(feature = "enterprise")] + SqlPlan::Statement(Statement::ShowCreateTrigger(_), _) => Ok(vec![ + FieldInfo::new( + "name".to_string(), + None, + None, + Type::TEXT, + format.format_for(0), + ), + FieldInfo::new( + "create_statement".to_string(), + None, + None, + Type::TEXT, + format.format_for(1), + ), + ]), + // single column show statements + SqlPlan::Statement( + Statement::ShowTables(_) | Statement::ShowFlows(_) | Statement::ShowViews(_), + _, + ) => Ok(vec![FieldInfo::new( + "name".to_string(), + None, + None, + Type::TEXT, + format.format_for(0), + )]), + #[cfg(feature = "enterprise")] + SqlPlan::Statement(Statement::ShowTriggers(_), _) => Ok(vec![FieldInfo::new( + "name".to_string(), + None, + None, + Type::TEXT, + format.format_for(0), + )]), + // we will not support other show statements for extended query protocol at least for now. + // because the return columns is not predictable at this stage + SqlPlan::Shortcut(query) => { + // test if query caught by fixture + if let Some(mut resp) = fixtures::process(query, session.new_query_context()) + && let Response::Query(query_response) = resp.remove(0) + { + Ok((*query_response.row_schema()).clone()) + } else { + // fallback to NoData + Ok(vec![]) } } + _ => { + // NoData + Ok(vec![]) + } } } diff --git a/src/servers/src/postgres/types.rs b/src/servers/src/postgres/types.rs index 33fc41164b..887300a573 100644 --- a/src/servers/src/postgres/types.rs +++ b/src/servers/src/postgres/types.rs @@ -1732,12 +1732,7 @@ mod test { let statement = Arc::new(StoredStatement::new( String::new(), PgSqlPlan { - plan: SqlPlan { - query: String::new(), - statement: None, - plan: None, - schema: None, - }, + plan: SqlPlan::Empty, copy_to_stdout_format: None, }, client_param_types, diff --git a/src/servers/src/query_handler/sql.rs b/src/servers/src/query_handler/sql.rs index 544f5bceb6..abc04342c3 100644 --- a/src/servers/src/query_handler/sql.rs +++ b/src/servers/src/query_handler/sql.rs @@ -32,8 +32,8 @@ pub trait SqlQueryHandler { async fn do_exec_plan( &self, - stmt: Option, plan: LogicalPlan, + query: String, query_ctx: QueryContextRef, ) -> Result; diff --git a/src/servers/tests/http/influxdb_test.rs b/src/servers/tests/http/influxdb_test.rs index 2939d2a11c..df0422eacd 100644 --- a/src/servers/tests/http/influxdb_test.rs +++ b/src/servers/tests/http/influxdb_test.rs @@ -31,7 +31,6 @@ use servers::influxdb::InfluxdbRequest; use servers::query_handler::InfluxdbLineProtocolHandler; use servers::query_handler::sql::SqlQueryHandler; use session::context::QueryContextRef; -use sql::statements::statement::Statement; use tokio::sync::mpsc; struct DummyInstance { @@ -58,8 +57,8 @@ impl SqlQueryHandler for DummyInstance { async fn do_exec_plan( &self, - _stmt: Option, _plan: LogicalPlan, + _query: String, _query_ctx: QueryContextRef, ) -> Result { unimplemented!() diff --git a/src/servers/tests/http/opentsdb_test.rs b/src/servers/tests/http/opentsdb_test.rs index a7a044d4db..7d4059d4b4 100644 --- a/src/servers/tests/http/opentsdb_test.rs +++ b/src/servers/tests/http/opentsdb_test.rs @@ -28,7 +28,6 @@ use servers::opentsdb::codec::DataPoint; use servers::query_handler::OpentsdbProtocolHandler; use servers::query_handler::sql::SqlQueryHandler; use session::context::QueryContextRef; -use sql::statements::statement::Statement; use tokio::sync::mpsc; struct DummyInstance { @@ -58,8 +57,8 @@ impl SqlQueryHandler for DummyInstance { async fn do_exec_plan( &self, - _stmt: Option, _plan: LogicalPlan, + _query: String, _query_ctx: QueryContextRef, ) -> Result { unimplemented!() diff --git a/src/servers/tests/http/prom_store_test.rs b/src/servers/tests/http/prom_store_test.rs index c5d5207486..07635eca29 100644 --- a/src/servers/tests/http/prom_store_test.rs +++ b/src/servers/tests/http/prom_store_test.rs @@ -36,7 +36,6 @@ use servers::prom_store::{Metrics, snappy_compress}; use servers::query_handler::sql::SqlQueryHandler; use servers::query_handler::{PromStoreProtocolHandler, PromStoreResponse}; use session::context::QueryContextRef; -use sql::statements::statement::Statement; use tokio::sync::mpsc; struct DummyInstance { @@ -87,8 +86,8 @@ impl SqlQueryHandler for DummyInstance { async fn do_exec_plan( &self, - _stmt: Option, _plan: LogicalPlan, + _query: String, _query_ctx: QueryContextRef, ) -> Result { unimplemented!() diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs index 3ecf2acf73..0933fed22c 100644 --- a/src/servers/tests/mod.rs +++ b/src/servers/tests/mod.rs @@ -82,8 +82,8 @@ impl SqlQueryHandler for DummyInstance { async fn do_exec_plan( &self, - _stmt: Option, plan: LogicalPlan, + _query: String, query_ctx: QueryContextRef, ) -> Result { Ok(self.query_engine.execute(plan, query_ctx).await.unwrap()) diff --git a/tests-integration/tests/sql.rs b/tests-integration/tests/sql.rs index aa0062e4fd..0416316053 100644 --- a/tests-integration/tests/sql.rs +++ b/tests-integration/tests/sql.rs @@ -85,6 +85,7 @@ macro_rules! sql_tests { test_postgres_uint64_parameter, test_postgres_array_types, test_mysql_prepare_stmt_insert_timestamp, + test_mysql_federated_prepare_stmt, test_declare_fetch_close_cursor, test_alter_update_on, ); @@ -1586,6 +1587,45 @@ pub async fn test_mysql_prepare_stmt_insert_timestamp(store_type: StorageType) { guard.remove_all().await; } +pub async fn test_mysql_federated_prepare_stmt(store_type: StorageType) { + common_telemetry::init_default_ut_logging(); + + let (mut guard, fe_mysql_server) = + setup_mysql_server(store_type, "test_mysql_federated_prepare_stmt").await; + let addr = fe_mysql_server.bind_addr().unwrap().to_string(); + + let pool = MySqlPoolOptions::new() + .max_connections(2) + .connect(&format!("mysql://{addr}/public")) + .await + .unwrap(); + + // sqlx::query uses binary prepared statement protocol (COM_STMT_PREPARE + COM_STMT_EXECUTE) + // "SELECT @@version_comment" is a federated query matched by federated::check + let rows = sqlx::query("SELECT @@version_comment") + .fetch_all(&pool) + .await + .unwrap(); + assert_eq!(rows.len(), 1); + let val: String = rows[0].get(0); + assert!(val.contains("GreptimeDB")); + + // "SET NAMES utf8" is another federated pattern + sqlx::query("SET NAMES utf8").execute(&pool).await.unwrap(); + + // "SELECT @@tx_isolation" is a federated variable query + let rows = sqlx::query("SELECT @@tx_isolation") + .fetch_all(&pool) + .await + .unwrap(); + assert_eq!(rows.len(), 1); + let val: String = rows[0].get(0); + assert_eq!(val, "REPEATABLE-READ"); + + let _ = fe_mysql_server.shutdown().await; + guard.remove_all().await; +} + pub async fn test_postgres_array_types(store_type: StorageType) { let (mut guard, fe_pg_server) = setup_pg_server(store_type, "test_postgres_array_types").await; let addr = fe_pg_server.bind_addr().unwrap().to_string();