refactor: update SqlPlan to use cleaner variants (#7966)

* refactor: update SqlPlan to use cleaner variants

* refactor: change how we check readonly plan

* fix: don't return schema for non-query statement

* chore: reflect review comments

* fix: federated statements
This commit is contained in:
Ning Sun
2026-04-21 19:50:47 +08:00
committed by GitHub
parent 449243a175
commit 80c395ee23
17 changed files with 474 additions and 300 deletions

View File

@@ -58,6 +58,8 @@ pub enum QueryStatement {
Sql(Statement),
// The optional string is the alias of the PromQL query.
Promql(EvalStmt, Option<String>),
/// Logical plan with original query string
Plan(String),
}
impl Display for QueryStatement {
@@ -71,6 +73,7 @@ impl Display for QueryStatement {
write!(f, "{}", eval_stmt)
}
}
QueryStatement::Plan(query) => write!(f, "{}", query),
}
}
}
@@ -369,6 +372,9 @@ impl SlowQueryTimer {
QueryStatement::Sql(stmt) => {
slow_query_event.query = stmt.to_string();
}
QueryStatement::Plan(query) => {
slow_query_event.query = query.clone();
}
}
match self.record_type {

View File

@@ -195,7 +195,7 @@ impl Instance {
let query_interceptor = self.plugins.get::<SqlQueryInterceptorRef<Error>>();
let query_interceptor = query_interceptor.as_ref();
if should_capture_statement(Some(&stmt)) {
if stmt.is_readonly() {
let slow_query_timer = self
.slow_query_options
.enable
@@ -483,24 +483,16 @@ fn derive_timeout(stmt: &Statement, query_ctx: &QueryContextRef) -> Option<Durat
}
}
/// Derives timeout for plan execution. When statement is not available,
/// applies timeout for PostgreSQL only (can't determine readonly status without statement).
fn derive_timeout_for_plan(
stmt: Option<&Statement>,
query_ctx: &QueryContextRef,
) -> Option<Duration> {
match stmt {
Some(s) => derive_timeout(s, query_ctx),
None => {
let query_timeout = query_ctx.query_timeout()?;
if query_timeout.is_zero() {
return None;
}
match query_ctx.channel() {
Channel::Postgres => Some(query_timeout),
_ => None,
}
}
/// Computes the timeout to enforce while executing `plan`, if any.
///
/// Returns `None` when the query context carries no timeout or a zero
/// timeout. Otherwise the configured timeout applies to every plan on the
/// PostgreSQL channel, only to read-only plans on the MySQL channel, and to
/// nothing on any other channel.
fn derive_timeout_for_plan(plan: &LogicalPlan, query_ctx: &QueryContextRef) -> Option<Duration> {
    // A zero duration is treated the same as "no timeout configured".
    let timeout = query_ctx.query_timeout().filter(|t| !t.is_zero())?;

    let applies = match query_ctx.channel() {
        Channel::Postgres => true,
        Channel::Mysql => is_readonly_plan(plan),
        _ => false,
    };
    applies.then_some(timeout)
}
@@ -618,11 +610,10 @@ impl Instance {
async fn exec_plan_with_timeout(
&self,
stmt: Option<Statement>,
plan: LogicalPlan,
query_ctx: QueryContextRef,
) -> Result<Output> {
let timeout = derive_timeout_for_plan(stmt.as_ref(), &query_ctx);
let timeout = derive_timeout_for_plan(&plan, &query_ctx);
match timeout {
Some(timeout) => {
let start = tokio::time::Instant::now();
@@ -638,16 +629,13 @@ impl Instance {
async fn do_exec_plan_inner(
&self,
stmt: Option<Statement>,
plan: LogicalPlan,
query: String,
query_ctx: QueryContextRef,
) -> Result<Output> {
ensure!(!self.is_suspended(), error::SuspendedSnafu);
if should_capture_statement(stmt.as_ref()) {
// It's safe to unwrap here because we've already checked the type.
let stmt = stmt.unwrap();
let query = stmt.to_string();
if is_readonly_plan(&plan) {
let slow_query_timer = self
.slow_query_options
.enable
@@ -655,7 +643,7 @@ impl Instance {
.flatten()
.map(|event_recorder| {
SlowQueryTimer::new(
CatalogQueryStatement::Sql(stmt.clone()),
CatalogQueryStatement::Plan(query.clone()),
self.slow_query_options.threshold,
self.slow_query_options.sample_ratio,
self.slow_query_options.record_type,
@@ -672,7 +660,7 @@ impl Instance {
slow_query_timer,
);
let query_fut = self.exec_plan_with_timeout(Some(stmt), plan, query_ctx);
let query_fut = self.exec_plan_with_timeout(plan, query_ctx);
CancellableFuture::new(query_fut, ticket.cancellation_handle.clone())
.await
@@ -689,7 +677,7 @@ impl Instance {
Output { data, meta }
})
} else {
self.exec_plan_with_timeout(stmt, plan, query_ctx).await
self.exec_plan_with_timeout(plan, query_ctx).await
}
}
@@ -769,11 +757,11 @@ impl SqlQueryHandler for Instance {
async fn do_exec_plan(
&self,
stmt: Option<Statement>,
plan: LogicalPlan,
query: String,
query_ctx: QueryContextRef,
) -> server_error::Result<Output> {
self.do_exec_plan_inner(stmt, plan, query_ctx)
self.do_exec_plan_inner(plan, query, query_ctx)
.await
.map_err(BoxedError::new)
.context(server_error::ExecutePlanSnafu)
@@ -1161,13 +1149,8 @@ fn validate_database(name: &ObjectName, query_ctx: &QueryContextRef) -> Result<(
.context(SqlExecInterceptedSnafu)
}
// Create a query ticket and slow query timer if the statement is a query or readonly statement.
fn should_capture_statement(stmt: Option<&Statement>) -> bool {
if let Some(stmt) = stmt {
matches!(stmt, Statement::Query(_)) || stmt.is_readonly()
} else {
false
}
/// Reports whether `plan` is considered read-only.
///
/// Only the root node is inspected: a plan is read-only unless its top-level
/// variant is `LogicalPlan::Dml` or `LogicalPlan::Ddl`.
fn is_readonly_plan(plan: &LogicalPlan) -> bool {
    let mutates = matches!(plan, LogicalPlan::Dml(_) | LogicalPlan::Ddl(_));
    !mutates
}
#[cfg(test)]

View File

@@ -129,8 +129,9 @@ impl GrpcQueryHandler for Instance {
.decode(bytes::Bytes::from(plan), dummy_catalog_list, true)
.await
.context(SubstraitDecodeLogicalPlanSnafu)?;
let query = logical_plan.display_indent().to_string();
let output =
self.do_exec_plan_inner(None, logical_plan, ctx.clone()).await?;
self.do_exec_plan_inner(logical_plan, query, ctx.clone()).await?;
attach_timer(output, timer)
}
@@ -466,8 +467,9 @@ impl Instance {
// Optimize the plan
let optimized_plan = state.optimize(&analyzed_plan).context(DataFusionSnafu)?;
let query = optimized_plan.display_indent().to_string();
let output = self
.do_exec_plan_inner(None, optimized_plan, ctx.clone())
.do_exec_plan_inner(optimized_plan, query, ctx.clone())
.await?;
Ok(attach_timer(output, timer))

View File

@@ -54,7 +54,7 @@ use crate::analyze::DistAnalyzeExec;
pub use crate::datafusion::planner::DfContextProviderAdapter;
use crate::dist_plan::{DistPlannerOptions, MergeScanLogicalPlan};
use crate::error::{
CatalogSnafu, ConvertSchemaSnafu, CreateRecordBatchSnafu, MissingTableMutationHandlerSnafu,
CatalogSnafu, CreateRecordBatchSnafu, MissingTableMutationHandlerSnafu,
MissingTimestampColumnSnafu, QueryExecutionSnafu, Result, TableMutationSnafu,
TableNotFoundSnafu, TableReadOnlySnafu, UnsupportedExprSnafu,
};
@@ -427,15 +427,7 @@ impl QueryEngine for DatafusionQueryEngine {
plan: LogicalPlan,
_query_ctx: QueryContextRef,
) -> Result<DescribeResult> {
let schema = plan
.schema()
.clone()
.try_into()
.context(ConvertSchemaSnafu)?;
Ok(DescribeResult {
schema,
logical_plan: plan,
})
Ok(DescribeResult { logical_plan: plan })
}
async fn execute(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result<Output> {
@@ -876,10 +868,10 @@ mod tests {
.await
.unwrap();
let DescribeResult {
schema,
logical_plan,
} = engine.describe(plan, QueryContext::arc()).await.unwrap();
let DescribeResult { logical_plan } =
engine.describe(plan, QueryContext::arc()).await.unwrap();
let schema: Schema = logical_plan.schema().clone().try_into().unwrap();
assert_eq!(
schema.column_schemas()[0],

View File

@@ -31,7 +31,6 @@ use common_query::Output;
use datafusion::catalog::TableFunction;
use datafusion::dataframe::DataFrame;
use datafusion_expr::{AggregateUDF, LogicalPlan, WindowUDF};
use datatypes::schema::Schema;
pub use default_serializer::{DefaultPlanDecoder, DefaultSerializer};
use partition::manager::PartitionRuleManagerRef;
use session::context::QueryContextRef;
@@ -48,8 +47,6 @@ use crate::region_query::RegionQueryHandlerRef;
/// Describe statement result
#[derive(Debug)]
pub struct DescribeResult {
/// The schema of statement
pub schema: Schema,
/// The logical plan for statement
pub logical_plan: LogicalPlan,
}

View File

@@ -7,6 +7,7 @@ license.workspace = true
[features]
default = []
dashboard = ["dep:rust-embed"]
enterprise = ["sql/enterprise"]
mem-prof = ["dep:common-mem-prof"]
pprof = ["dep:common-pprof"]
testing = []

View File

@@ -1329,7 +1329,6 @@ mod test {
use query::parser::PromQuery;
use query::query_engine::DescribeResult;
use session::context::QueryContextRef;
use sql::statements::statement::Statement;
use tokio::sync::mpsc;
use tokio::time::Instant;
@@ -1354,8 +1353,8 @@ mod test {
async fn do_exec_plan(
&self,
_stmt: Option<Statement>,
_plan: LogicalPlan,
_query: String,
_query_ctx: QueryContextRef,
) -> Result<Output> {
unimplemented!()

View File

@@ -16,7 +16,6 @@
#![feature(exclusive_wrapper)]
use datafusion_expr::LogicalPlan;
use datatypes::schema::Schema;
use sql::statements::statement::Statement;
// Re-export for use in add_service! macro
#[doc(hidden)]
@@ -54,13 +53,19 @@ mod row_writer;
pub mod server;
pub mod tls;
/// Cached SQL and logical plan for database interfaces
/// Cached sql plan or statement for database interfaces
#[derive(Clone, Debug)]
pub struct SqlPlan {
pub(crate) query: String,
pub(crate) statement: Option<Statement>,
pub(crate) plan: Option<LogicalPlan>,
pub(crate) schema: Option<Schema>,
pub enum SqlPlan {
/// Empty query string; there is nothing to execute.
Empty,
/// Query matched by a hardcoded shortcut rule; holds the original query string.
Shortcut(String),
/// DataFusion logical plan together with the original query string.
Plan(LogicalPlan, String),
/// Parsed statement whose execution is not managed by DataFusion,
/// e.g. `CREATE TABLE`.
/// The `String` is the original query string, kept to avoid AST round-trip issues.
Statement(Statement, String),
}
/// Install the ring crypto provider for rustls process-wide. see:

View File

@@ -28,6 +28,7 @@ use common_telemetry::{debug, error, tracing, warn};
use datafusion_common::ParamValues;
use datafusion_expr::LogicalPlan;
use datatypes::prelude::ConcreteDataType;
use datatypes::schema::Schema;
use itertools::Itertools;
use opensrv_mysql::{
AsyncMysqlShim, Column, ErrorKind, InitWriter, ParamParser, ParamValue, QueryResultWriter,
@@ -138,23 +139,6 @@ impl MysqlInstanceShim {
}
}
/// Execute the logical plan and return the output
async fn do_exec_plan(
&self,
query: &str,
stmt: Option<Statement>,
plan: LogicalPlan,
query_ctx: QueryContextRef,
) -> Result<Output> {
if let Some(output) =
crate::mysql::federated::check(query, query_ctx.clone(), self.session.clone())
{
Ok(output)
} else {
self.query_handler.do_exec_plan(stmt, plan, query_ctx).await
}
}
/// Describe the statement
async fn do_describe(
&self,
@@ -198,6 +182,16 @@ impl MysqlInstanceShim {
query_ctx: QueryContextRef,
stmt_key: String,
) -> Result<(Vec<Column>, Vec<Column>)> {
if crate::mysql::federated::check(raw_query, query_ctx.clone(), self.session.clone())
.is_some()
{
self.save_plan(SqlPlan::Shortcut(raw_query.to_string()), stmt_key)
.inspect_err(|e| {
error!(e; "Failed to save prepared statement");
})?;
return Ok((vec![], vec![]));
}
let (query, param_num) = replace_placeholders(raw_query);
let statement = validate_query(raw_query).await?;
@@ -209,15 +203,7 @@ impl MysqlInstanceShim {
let describe_result = self
.do_describe(statement.clone(), query_ctx.clone())
.await?;
let (plan, schema) = if let Some(DescribeResult {
logical_plan,
schema,
}) = describe_result
{
(Some(logical_plan), Some(schema))
} else {
(None, None)
};
let plan = describe_result.map(|DescribeResult { logical_plan }| logical_plan);
let params = if let Some(plan) = &plan {
let param_types = DfLogicalPlanner::get_inferred_parameter_types(plan)
@@ -230,49 +216,41 @@ impl MysqlInstanceShim {
dummy_params(param_num)?
};
let columns = schema
.as_ref()
.map(|schema| {
schema
.column_schemas()
.iter()
.map(|column_schema| {
create_mysql_column(&column_schema.data_type, &column_schema.name)
})
.collect::<Result<Vec<_>>>()
})
.transpose()?
.unwrap_or_default();
let columns =
plan.as_ref()
.map(|plan| {
let schema: Schema = plan.schema().clone().try_into().map_err(
|e: datatypes::error::Error| {
error::InternalSnafu {
err_msg: e.to_string(),
}
.build()
},
)?;
schema
.column_schemas()
.iter()
.map(|column_schema| {
create_mysql_column(&column_schema.data_type, &column_schema.name)
})
.collect::<Result<Vec<_>>>()
})
.transpose()?
.unwrap_or_default();
// DataFusion may optimize the plan so that some parameters are not used.
if params.len() != param_num - 1 {
self.save_plan(
SqlPlan {
query: query.clone(),
statement: Some(statement),
plan: None,
schema: None,
},
stmt_key,
)
.map_err(|e| {
error!(e; "Failed to save prepared statement");
e
})?;
} else {
self.save_plan(
SqlPlan {
query: query.clone(),
statement: Some(statement),
plan,
schema,
},
stmt_key,
)
.map_err(|e| {
error!(e; "Failed to save prepared statement");
e
})?;
match plan {
Some(plan) if params.len() == param_num - 1 => {
self.save_plan(SqlPlan::Plan(plan, query.clone()), stmt_key)
.inspect_err(|e| {
error!(e; "Failed to save prepared statement");
})?;
}
_ => {
self.save_plan(SqlPlan::Statement(statement, query), stmt_key)
.inspect_err(|e| {
error!(e; "Failed to save prepared statement");
})?;
}
}
Ok((params, columns))
@@ -291,8 +269,8 @@ impl MysqlInstanceShim {
Some(sql_plan) => sql_plan,
};
let outputs = match sql_plan.plan {
Some(plan) => {
let outputs = match sql_plan {
SqlPlan::Plan(plan, query) => {
let param_types = DfLogicalPlanner::get_inferred_parameter_types(&plan)
.context(InferParameterTypesSnafu)?
.into_iter()
@@ -306,7 +284,7 @@ impl MysqlInstanceShim {
.fail();
}
let plan = match params {
let replaced_plan = match params {
Params::ProtocolParams(params) => {
replace_params_with_values(&plan, param_types, &params)
}
@@ -315,18 +293,26 @@ impl MysqlInstanceShim {
}
}?;
debug!("Mysql execute prepared plan: {}", plan.display_indent());
debug!(
"Mysql execute prepared plan: {}",
replaced_plan.display_indent()
);
vec![
self.do_exec_plan(
&sql_plan.query,
sql_plan.statement.clone(),
plan,
query_ctx.clone(),
)
.await,
self.query_handler
.do_exec_plan(replaced_plan, query, query_ctx.clone())
.await,
]
}
None => {
SqlPlan::Shortcut(query) => {
if let Some(output) =
crate::mysql::federated::check(&query, query_ctx.clone(), self.session.clone())
{
vec![Ok(output)]
} else {
self.do_query(&query, query_ctx.clone()).await
}
}
SqlPlan::Statement(_stmt, query) => {
let param_strs = match params {
Params::ProtocolParams(params) => {
params.iter().map(convert_param_value_to_string).collect()
@@ -335,12 +321,15 @@ impl MysqlInstanceShim {
};
debug!(
"do_execute Replacing with Params: {:?}, Original Query: {}",
param_strs, sql_plan.query
param_strs, query
);
let query = replace_params(param_strs, sql_plan.query);
let query = replace_params(param_strs, query);
debug!("Mysql execute replaced query: {}", query);
self.do_query(&query, query_ctx.clone()).await
}
_ => {
return error::PrepareStatementNotFoundSnafu { name: stmt_key }.fail();
}
};
Ok(outputs)
@@ -802,3 +791,152 @@ fn prepared_params(param_types: &HashMap<String, Option<ConcreteDataType>>) -> R
Ok(params)
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use async_trait::async_trait;
    use common_query::Output;
    use datafusion_expr::LogicalPlan;
    use query::parser::PromQuery;
    use query::query_engine::DescribeResult;
    use session::context::QueryContext;
    use sql::statements::statement::Statement;

    use super::*;
    use crate::error::Result;
    use crate::query_handler::sql::SqlQueryHandler;

    /// Query handler whose query, plan and describe paths are all `unimplemented!()`.
    ///
    /// A test that reaches any of them panics, so a passing test shows the
    /// request was answered without going through the handler.
    struct DummyQueryHandler;

    #[async_trait]
    impl SqlQueryHandler for DummyQueryHandler {
        async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec<Result<Output>> {
            unimplemented!()
        }

        async fn do_promql_query(&self, _: &PromQuery, _: QueryContextRef) -> Vec<Result<Output>> {
            unimplemented!()
        }

        async fn do_exec_plan(
            &self,
            _: LogicalPlan,
            _: String,
            _: QueryContextRef,
        ) -> Result<Output> {
            unimplemented!()
        }

        async fn do_describe(
            &self,
            _: Statement,
            _: QueryContextRef,
        ) -> Result<Option<DescribeResult>> {
            unimplemented!()
        }

        async fn is_valid_schema(&self, _: &str, _: &str) -> Result<bool> {
            Ok(true)
        }
    }

    /// Creates a shim backed by [`DummyQueryHandler`].
    fn create_shim() -> MysqlInstanceShim {
        MysqlInstanceShim::create(
            Arc::new(DummyQueryHandler),
            None,
            "127.0.0.1:3306".parse().unwrap(),
            1,
            1024,
        )
    }

    /// Preparing a federated query stores a `Shortcut` plan and reports no
    /// parameters or columns.
    #[tokio::test]
    async fn test_prepare_federated_query() {
        let mut shim = create_shim();
        let query_ctx = QueryContext::arc();
        let stmt_key = "test_federated".to_string();

        let (params, columns) = shim
            .do_prepare(
                "SELECT @@version_comment",
                query_ctx.clone(),
                stmt_key.clone(),
            )
            .await
            .unwrap();
        assert!(params.is_empty());
        assert!(columns.is_empty());

        let plan = shim.plan(&stmt_key).unwrap();
        assert!(matches!(plan, SqlPlan::Shortcut(q) if q == "SELECT @@version_comment"));
    }

    /// Executing a prepared federated query returns the federated output.
    #[tokio::test]
    async fn test_execute_federated_shortcut() {
        let mut shim = create_shim();
        let query_ctx = QueryContext::arc();
        let stmt_key = "test_federated_exec".to_string();

        shim.do_prepare(
            "SELECT @@version_comment",
            query_ctx.clone(),
            stmt_key.clone(),
        )
        .await
        .unwrap();

        let outputs = shim
            .do_execute(query_ctx.clone(), stmt_key, Params::CliParams(vec![]))
            .await
            .unwrap();
        assert_eq!(outputs.len(), 1);

        let output = outputs.into_iter().next().unwrap().unwrap();
        let pretty = output.data.pretty_print().await;
        assert!(pretty.contains("GreptimeDB"));
    }

    /// Preparing `SET NAMES utf8` also stores a `Shortcut` plan.
    #[tokio::test]
    async fn test_prepare_set_names_is_shortcut() {
        let mut shim = create_shim();
        let query_ctx = QueryContext::arc();
        let stmt_key = "test_non_federated".to_string();

        // `unwrap` rather than `assert!(is_ok())` so a failure reports the error.
        shim.do_prepare("SET NAMES utf8", query_ctx.clone(), stmt_key.clone())
            .await
            .unwrap();

        let plan = shim.plan(&stmt_key).unwrap();
        assert!(matches!(plan, SqlPlan::Shortcut(_)));
    }

    /// Executing a prepared `SET NAMES utf8` yields record batches with no rows.
    #[tokio::test]
    async fn test_execute_set_shortcut() {
        let mut shim = create_shim();
        let query_ctx = QueryContext::arc();
        let stmt_key = "test_set_shortcut".to_string();

        shim.do_prepare("SET NAMES utf8", query_ctx.clone(), stmt_key.clone())
            .await
            .unwrap();

        let outputs = shim
            .do_execute(query_ctx.clone(), stmt_key, Params::CliParams(vec![]))
            .await
            .unwrap();
        assert_eq!(outputs.len(), 1);

        let output = outputs.into_iter().next().unwrap().unwrap();
        match output.data {
            common_query::OutputData::RecordBatches(batches) => {
                let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
                assert_eq!(total_rows, 0);
            }
            other => panic!("Expected RecordBatches, got {:?}", other),
        }
    }
}

View File

@@ -23,9 +23,10 @@ use common_recordbatch::error::Result as RecordBatchResult;
use common_telemetry::{debug, info, tracing};
use datafusion::sql::sqlparser::ast::{CopyOption, CopyTarget, Statement as SqlParserStatement};
use datafusion_common::ParamValues;
use datafusion_expr::LogicalPlan;
use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
use datatypes::prelude::ConcreteDataType;
use datatypes::schema::SchemaRef;
use datatypes::schema::{Schema, SchemaRef};
use futures::{Sink, SinkExt, Stream, StreamExt, future, stream};
use pgwire::api::portal::{Format, Portal};
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
@@ -318,14 +319,16 @@ impl QueryParser for DefaultQueryParser {
let query_ctx = self.session.new_query_context();
// do not parse if query is empty or matches rules
if sql.is_empty() || fixtures::matches(sql) {
if sql.is_empty() {
return Ok(PgSqlPlan {
plan: SqlPlan {
query: sql.to_owned(),
statement: None,
plan: None,
schema: None,
},
plan: SqlPlan::Empty,
copy_to_stdout_format: None,
});
}
if fixtures::matches(sql) {
return Ok(PgSqlPlan {
plan: SqlPlan::Shortcut(sql.to_string()),
copy_to_stdout_format: None,
});
}
@@ -354,31 +357,23 @@ impl QueryParser for DefaultQueryParser {
} else {
let stmt = stmts.remove(0);
let describe_result = self
if let Some(logical_plan) = self
.query_handler
.do_describe(stmt.clone(), query_ctx)
.await
.map_err(convert_err)?;
let (plan, schema) = if let Some(DescribeResult {
logical_plan,
schema,
}) = describe_result
.map_err(convert_err)?
.map(|DescribeResult { logical_plan }| logical_plan)
{
(Some(logical_plan), Some(schema))
Ok(PgSqlPlan {
plan: SqlPlan::Plan(logical_plan, sql),
copy_to_stdout_format,
})
} else {
(None, None)
};
Ok(PgSqlPlan {
plan: SqlPlan {
query: sql.clone(),
statement: Some(stmt),
plan,
schema,
},
copy_to_stdout_format,
})
Ok(PgSqlPlan {
plan: SqlPlan::Statement(stmt, sql),
copy_to_stdout_format,
})
}
}
}
@@ -432,39 +427,45 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner {
let pg_sql_plan = &portal.statement.statement;
let sql_plan = &pg_sql_plan.plan;
if sql_plan.query.is_empty() {
// early return if query is empty
return Ok(Response::EmptyQuery);
}
if let Some(mut resps) = fixtures::process(&sql_plan.query, query_ctx.clone()) {
send_warning_opt(client, query_ctx).await?;
// if the statement matches our predefined rules, return it early
return Ok(resps.remove(0));
}
let output = if let Some(plan) = &sql_plan.plan {
let values = parameters_to_scalar_values(plan, portal)?;
let plan = plan
.clone()
.replace_params_with_values(&ParamValues::List(
values.into_iter().map(Into::into).collect(),
))
.context(DataFusionSnafu)
.map_err(convert_err)?;
self.query_handler
.do_exec_plan(sql_plan.statement.clone(), plan, query_ctx.clone())
.await
} else {
// We won't replace params from statement manually any more.
// Newer version of datafusion can generate plan for SELECT/INSERT/UPDATE/DELETE.
// Only CREATE TABLE and others minor statements cannot generate sql plan,
// in this case, we assume these statements will not carry parameters
// and execute them directly.
self.query_handler
.do_query(&sql_plan.query, query_ctx.clone())
.await
.remove(0)
let output = match sql_plan {
SqlPlan::Empty => {
// early return if query is empty
return Ok(Response::EmptyQuery);
}
SqlPlan::Shortcut(query) => {
if let Some(mut resps) = fixtures::process(query, query_ctx.clone()) {
send_warning_opt(client, query_ctx).await?;
// if the statement matches our predefined rules, return it early
return Ok(resps.remove(0));
} else {
// unreachable logic
return Ok(Response::EmptyQuery);
}
}
SqlPlan::Plan(plan, query) => {
let values = parameters_to_scalar_values(plan, portal)?;
let plan = plan
.clone()
.replace_params_with_values(&ParamValues::List(
values.into_iter().map(Into::into).collect(),
))
.context(DataFusionSnafu)
.map_err(convert_err)?;
self.query_handler
.do_exec_plan(plan, query.clone(), query_ctx.clone())
.await
}
SqlPlan::Statement(_stmt, query) => {
// We won't replace params from statement manually any more.
// Newer version of datafusion can generate plan for SELECT/INSERT/UPDATE/DELETE.
// Only CREATE TABLE and others minor statements cannot generate sql plan,
// in this case, we assume these statements will not carry parameters
// and execute them directly.
self.query_handler
.do_query(query, query_ctx.clone())
.await
.remove(0)
}
};
send_warning_opt(client, query_ctx.clone()).await?;
@@ -487,7 +488,7 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner {
let sql_plan = &stmt.statement.plan;
// client provided parameter types, can be empty if client doesn't try to parse statement
let provided_param_types = &stmt.parameter_types;
let server_inferenced_types = if let Some(plan) = &sql_plan.plan {
let server_inferenced_types = if let SqlPlan::Plan(plan, _) = &sql_plan {
let param_types = DfLogicalPlanner::get_inferred_parameter_types(plan)
.context(InferParameterTypesSnafu)
.map_err(convert_err)?
@@ -525,23 +526,9 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner {
})
.collect::<Vec<_>>();
if let Some(schema) = &sql_plan.schema {
schema_to_pg(schema, &Format::UnifiedText, None)
.map(|fields| DescribeStatementResponse::new(param_types, fields))
.map_err(convert_err)
} else {
if let Some(mut resp) =
fixtures::process(&sql_plan.query, self.session.new_query_context())
&& let Response::Query(query_response) = resp.remove(0)
{
return Ok(DescribeStatementResponse::new(
param_types,
(*query_response.row_schema()).clone(),
));
}
let fields = describe_fields(sql_plan, &Format::UnifiedText, &self.session)?;
Ok(DescribeStatementResponse::new(param_types, vec![]))
}
Ok(DescribeStatementResponse::new(param_types, fields))
}
async fn do_describe_portal<C>(
@@ -555,68 +542,100 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner {
let sql_plan = &portal.statement.statement.plan;
let format = &portal.result_column_format;
match sql_plan.statement.as_ref() {
Some(Statement::Query(_)) => {
// if the query has a schema, it is managed by datafusion, use the schema
if let Some(schema) = &sql_plan.schema {
schema_to_pg(schema, format, None)
.map(DescribePortalResponse::new)
.map_err(convert_err)
} else {
// fallback to NoData
Ok(DescribePortalResponse::new(vec![]))
}
}
// We can cover only part of show statements
// these show create statements will return 2 columns
Some(Statement::ShowCreateDatabase(_))
| Some(Statement::ShowCreateTable(_))
| Some(Statement::ShowCreateFlow(_))
| Some(Statement::ShowCreateView(_)) => Ok(DescribePortalResponse::new(vec![
FieldInfo::new(
"name".to_string(),
None,
None,
Type::TEXT,
format.format_for(0),
),
FieldInfo::new(
"create_statement".to_string(),
None,
None,
Type::TEXT,
format.format_for(1),
),
])),
// single column show statements
Some(Statement::ShowTables(_))
| Some(Statement::ShowFlows(_))
| Some(Statement::ShowViews(_)) => {
Ok(DescribePortalResponse::new(vec![FieldInfo::new(
"name".to_string(),
None,
None,
Type::TEXT,
format.format_for(0),
)]))
}
// we will not support other show statements for extended query protocol at least for now.
// because the return columns is not predictable at this stage
_ => {
// test if query caught by fixture
if let Some(mut resp) =
fixtures::process(&sql_plan.query, self.session.new_query_context())
&& let Response::Query(query_response) = resp.remove(0)
{
Ok(DescribePortalResponse::new(
(*query_response.row_schema()).clone(),
))
} else {
// fallback to NoData
Ok(DescribePortalResponse::new(vec![]))
}
let fields = describe_fields(sql_plan, format, &self.session)?;
Ok(DescribePortalResponse::new(fields))
}
}
/// Builds the result-column description for a cached [`SqlPlan`].
///
/// * `SqlPlan::Plan` whose root is neither DML nor DDL: columns are derived
///   from the logical plan's schema.
/// * A fixed set of `SHOW` statements: hardcoded `TEXT` columns.
/// * `SqlPlan::Shortcut`: the row schema of the fixture response, when the
///   fixture produces a query response.
/// * Everything else: no columns (NoData).
///
/// # Errors
///
/// Returns an error when the plan schema cannot be converted into a
/// [`Schema`] or when that schema cannot be mapped to PostgreSQL fields.
fn describe_fields(
    sql_plan: &SqlPlan,
    format: &Format,
    session: &Arc<Session>,
) -> PgWireResult<Vec<FieldInfo>> {
    // All hardcoded SHOW columns are TEXT columns without table/column ids;
    // only the column name and its position in the row differ.
    let text_field = |name: &str, index: usize| {
        FieldInfo::new(
            name.to_string(),
            None,
            None,
            Type::TEXT,
            format.format_for(index),
        )
    };

    match sql_plan {
        // query
        SqlPlan::Plan(plan, _) if !matches!(plan, LogicalPlan::Dml(_) | LogicalPlan::Ddl(_)) => {
            let schema: Schema = plan.schema().clone().try_into().map_err(convert_err)?;
            schema_to_pg(&schema, format, None).map_err(convert_err)
        }
        // We can cover only part of show statements
        // these show create statements will return 2 columns
        SqlPlan::Statement(
            Statement::ShowCreateDatabase(_)
            | Statement::ShowCreateTable(_)
            | Statement::ShowCreateFlow(_)
            | Statement::ShowCreateView(_),
            _,
        ) => Ok(vec![
            text_field("name", 0),
            text_field("create_statement", 1),
        ]),
        #[cfg(feature = "enterprise")]
        SqlPlan::Statement(Statement::ShowCreateTrigger(_), _) => Ok(vec![
            text_field("name", 0),
            text_field("create_statement", 1),
        ]),
        // single column show statements
        SqlPlan::Statement(
            Statement::ShowTables(_) | Statement::ShowFlows(_) | Statement::ShowViews(_),
            _,
        ) => Ok(vec![text_field("name", 0)]),
        #[cfg(feature = "enterprise")]
        SqlPlan::Statement(Statement::ShowTriggers(_), _) => Ok(vec![text_field("name", 0)]),
        SqlPlan::Shortcut(query) => {
            // test if query caught by fixture
            if let Some(mut resp) = fixtures::process(query, session.new_query_context())
                && let Response::Query(query_response) = resp.remove(0)
            {
                Ok((*query_response.row_schema()).clone())
            } else {
                // fallback to NoData
                Ok(vec![])
            }
        }
        // we will not support other show statements for extended query protocol at least for now.
        // because the return columns is not predictable at this stage
        _ => {
            // NoData
            Ok(vec![])
        }
    }
}

View File

@@ -1732,12 +1732,7 @@ mod test {
let statement = Arc::new(StoredStatement::new(
String::new(),
PgSqlPlan {
plan: SqlPlan {
query: String::new(),
statement: None,
plan: None,
schema: None,
},
plan: SqlPlan::Empty,
copy_to_stdout_format: None,
},
client_param_types,

View File

@@ -32,8 +32,8 @@ pub trait SqlQueryHandler {
async fn do_exec_plan(
&self,
stmt: Option<Statement>,
plan: LogicalPlan,
query: String,
query_ctx: QueryContextRef,
) -> Result<Output>;

View File

@@ -31,7 +31,6 @@ use servers::influxdb::InfluxdbRequest;
use servers::query_handler::InfluxdbLineProtocolHandler;
use servers::query_handler::sql::SqlQueryHandler;
use session::context::QueryContextRef;
use sql::statements::statement::Statement;
use tokio::sync::mpsc;
struct DummyInstance {
@@ -58,8 +57,8 @@ impl SqlQueryHandler for DummyInstance {
async fn do_exec_plan(
&self,
_stmt: Option<Statement>,
_plan: LogicalPlan,
_query: String,
_query_ctx: QueryContextRef,
) -> Result<Output> {
unimplemented!()

View File

@@ -28,7 +28,6 @@ use servers::opentsdb::codec::DataPoint;
use servers::query_handler::OpentsdbProtocolHandler;
use servers::query_handler::sql::SqlQueryHandler;
use session::context::QueryContextRef;
use sql::statements::statement::Statement;
use tokio::sync::mpsc;
struct DummyInstance {
@@ -58,8 +57,8 @@ impl SqlQueryHandler for DummyInstance {
async fn do_exec_plan(
&self,
_stmt: Option<Statement>,
_plan: LogicalPlan,
_query: String,
_query_ctx: QueryContextRef,
) -> Result<Output> {
unimplemented!()

View File

@@ -36,7 +36,6 @@ use servers::prom_store::{Metrics, snappy_compress};
use servers::query_handler::sql::SqlQueryHandler;
use servers::query_handler::{PromStoreProtocolHandler, PromStoreResponse};
use session::context::QueryContextRef;
use sql::statements::statement::Statement;
use tokio::sync::mpsc;
struct DummyInstance {
@@ -87,8 +86,8 @@ impl SqlQueryHandler for DummyInstance {
async fn do_exec_plan(
&self,
_stmt: Option<Statement>,
_plan: LogicalPlan,
_query: String,
_query_ctx: QueryContextRef,
) -> Result<Output> {
unimplemented!()

View File

@@ -82,8 +82,8 @@ impl SqlQueryHandler for DummyInstance {
async fn do_exec_plan(
&self,
_stmt: Option<Statement>,
plan: LogicalPlan,
_query: String,
query_ctx: QueryContextRef,
) -> Result<Output> {
Ok(self.query_engine.execute(plan, query_ctx).await.unwrap())

View File

@@ -85,6 +85,7 @@ macro_rules! sql_tests {
test_postgres_uint64_parameter,
test_postgres_array_types,
test_mysql_prepare_stmt_insert_timestamp,
test_mysql_federated_prepare_stmt,
test_declare_fetch_close_cursor,
test_alter_update_on,
);
@@ -1586,6 +1587,45 @@ pub async fn test_mysql_prepare_stmt_insert_timestamp(store_type: StorageType) {
guard.remove_all().await;
}
/// Runs federated MySQL statements through `sqlx::query` against a test
/// server and checks the returned values.
pub async fn test_mysql_federated_prepare_stmt(store_type: StorageType) {
    common_telemetry::init_default_ut_logging();

    let (mut guard, fe_mysql_server) =
        setup_mysql_server(store_type, "test_mysql_federated_prepare_stmt").await;
    let addr = fe_mysql_server.bind_addr().unwrap().to_string();
    let url = format!("mysql://{addr}/public");

    let pool = MySqlPoolOptions::new()
        .max_connections(2)
        .connect(&url)
        .await
        .unwrap();

    // sqlx::query uses binary prepared statement protocol (COM_STMT_PREPARE + COM_STMT_EXECUTE)
    // "SELECT @@version_comment" is a federated query matched by federated::check
    let version_rows = sqlx::query("SELECT @@version_comment")
        .fetch_all(&pool)
        .await
        .unwrap();
    assert_eq!(version_rows.len(), 1);
    let version_comment: String = version_rows[0].get(0);
    assert!(version_comment.contains("GreptimeDB"));

    // "SET NAMES utf8" is another federated pattern
    sqlx::query("SET NAMES utf8").execute(&pool).await.unwrap();

    // "SELECT @@tx_isolation" is a federated variable query
    let isolation_rows = sqlx::query("SELECT @@tx_isolation")
        .fetch_all(&pool)
        .await
        .unwrap();
    assert_eq!(isolation_rows.len(), 1);
    let isolation: String = isolation_rows[0].get(0);
    assert_eq!(isolation, "REPEATABLE-READ");

    let _ = fe_mysql_server.shutdown().await;
    guard.remove_all().await;
}
pub async fn test_postgres_array_types(store_type: StorageType) {
let (mut guard, fe_pg_server) = setup_pg_server(store_type, "test_postgres_array_types").await;
let addr = fe_pg_server.bind_addr().unwrap().to_string();