diff --git a/Cargo.lock b/Cargo.lock index 6ad48467cb..e1fc1cfdf5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4699,6 +4699,7 @@ version = "0.15.0" dependencies = [ "api", "arc-swap", + "async-stream", "async-trait", "auth", "bytes", diff --git a/src/common/recordbatch/src/error.rs b/src/common/recordbatch/src/error.rs index 3e48324821..e07d152d2d 100644 --- a/src/common/recordbatch/src/error.rs +++ b/src/common/recordbatch/src/error.rs @@ -178,8 +178,6 @@ pub enum Error { StreamTimeout { #[snafu(implicit)] location: Location, - #[snafu(source)] - error: tokio::time::error::Elapsed, }, #[snafu(display("RecordBatch slice index overflow: {visit_index} > {size}"))] diff --git a/src/frontend/Cargo.toml b/src/frontend/Cargo.toml index 15fd74cb32..d1e88f224a 100644 --- a/src/frontend/Cargo.toml +++ b/src/frontend/Cargo.toml @@ -14,6 +14,7 @@ workspace = true [dependencies] api.workspace = true arc-swap = "1.0" +async-stream.workspace = true async-trait.workspace = true auth.workspace = true bytes.workspace = true diff --git a/src/frontend/src/error.rs b/src/frontend/src/error.rs index 44f6dfdeb0..6dac8e9898 100644 --- a/src/frontend/src/error.rs +++ b/src/frontend/src/error.rs @@ -363,6 +363,12 @@ pub enum Error { #[snafu(implicit)] location: Location, }, + + #[snafu(display("Canceling statement due to statement timeout"))] + StatementTimeout { + #[snafu(implicit)] + location: Location, + }, } pub type Result = std::result::Result; @@ -443,6 +449,8 @@ impl ErrorExt for Error { Error::DataFusion { error, .. } => datafusion_status_code::(error, None), Error::Cancelled { .. } => StatusCode::Cancelled, + + Error::StatementTimeout { .. } => StatusCode::Cancelled, } } diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index d17257892e..c41475c50a 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -25,9 +25,11 @@ mod promql; mod region_query; pub mod standalone; +use std::pin::Pin; use std::sync::Arc; -use std::time::SystemTime; +use std::time::{Duration, SystemTime}; +use async_stream::stream; use async_trait::async_trait; use auth::{PermissionChecker, PermissionCheckerRef, PermissionReq}; use catalog::process_manager::ProcessManagerRef; @@ -44,8 +46,11 @@ use common_procedure::local::{LocalManager, ManagerConfig}; use common_procedure::options::ProcedureConfig; use common_procedure::ProcedureManagerRef; use common_query::Output; +use common_recordbatch::error::StreamTimeoutSnafu; +use common_recordbatch::RecordBatchStreamWrapper; use common_telemetry::{debug, error, info, tracing}; use datafusion_expr::LogicalPlan; +use futures::{Stream, StreamExt}; use log_store::raft_engine::RaftEngineBackend; use operator::delete::DeleterRef; use operator::insert::InserterRef; @@ -65,20 +70,21 @@ use servers::interceptor::{ }; use servers::prometheus_handler::PrometheusHandler; use servers::query_handler::sql::SqlQueryHandler; -use session::context::QueryContextRef; +use session::context::{Channel, QueryContextRef}; use session::table_name::table_idents_to_full_name; use snafu::prelude::*; use sql::dialect::Dialect; use sql::parser::{ParseOptions, ParserContext}; use sql::statements::copy::{CopyDatabase, CopyTable}; use sql::statements::statement::Statement; +use sql::statements::tql::Tql; use sqlparser::ast::ObjectName; pub use standalone::StandaloneDatanodeManager; use crate::error::{ self, Error, ExecLogicalPlanSnafu, ExecutePromqlSnafu, ExternalSnafu, InvalidSqlSnafu, ParseSqlSnafu, PermissionSnafu, PlanStatementSnafu, Result, SqlExecInterceptedSnafu, - TableOperationSnafu, + StatementTimeoutSnafu, TableOperationSnafu, }; use crate::limiter::LimiterRef; use crate::slow_query_recorder::SlowQueryRecorder; @@ -188,56 +194,7 @@ impl Instance { Some(query_ctx.process_id()), ); - let query_fut = async { - match stmt { - Statement::Query(_) | Statement::Explain(_) | Statement::Delete(_) => { - // TODO: remove this when format is supported in datafusion - if let Statement::Explain(explain) = &stmt { - if let Some(format) = explain.format() { - query_ctx.set_explain_format(format.to_string()); - } - } - - let stmt = QueryStatement::Sql(stmt); - let plan = self - .statement_executor - .plan(&stmt, query_ctx.clone()) - .await?; - - let QueryStatement::Sql(stmt) = stmt else { - unreachable!() - }; - query_interceptor.pre_execute(&stmt, Some(&plan), query_ctx.clone())?; - self.statement_executor - .exec_plan(plan, query_ctx) - .await - .context(TableOperationSnafu) - } - Statement::Tql(tql) => { - let plan = self - .statement_executor - .plan_tql(tql.clone(), &query_ctx) - .await?; - - query_interceptor.pre_execute( - &Statement::Tql(tql), - Some(&plan), - query_ctx.clone(), - )?; - self.statement_executor - .exec_plan(plan, query_ctx) - .await - .context(TableOperationSnafu) - } - _ => { - query_interceptor.pre_execute(&stmt, None, query_ctx.clone())?; - self.statement_executor - .execute_sql(stmt, query_ctx) - .await - .context(TableOperationSnafu) - } - } - }; + let query_fut = self.exec_statement_with_timeout(stmt, query_ctx, query_interceptor); CancellableFuture::new(query_fut, ticket.cancellation_handle.clone()) .await @@ -254,6 +211,149 @@ impl Instance { Output { data, meta } }) } + + async fn exec_statement_with_timeout( + &self, + stmt: Statement, + query_ctx: QueryContextRef, + query_interceptor: Option<&SqlQueryInterceptorRef>, + ) -> Result { + let timeout = derive_timeout(&stmt, &query_ctx); + match timeout { + Some(timeout) => { + let start = tokio::time::Instant::now(); + let output = tokio::time::timeout( + timeout, + self.exec_statement(stmt, query_ctx, query_interceptor), + ) + .await + .map_err(|_| StatementTimeoutSnafu.build())??; + // compute remaining timeout + let remaining_timeout = timeout.checked_sub(start.elapsed()).unwrap_or_default(); + attach_timeout(output, remaining_timeout) + } + None => { + self.exec_statement(stmt, query_ctx, query_interceptor) + .await + } + } + } + + async fn exec_statement( + &self, + stmt: Statement, + query_ctx: QueryContextRef, + query_interceptor: Option<&SqlQueryInterceptorRef>, + ) -> Result { + match stmt { + Statement::Query(_) | Statement::Explain(_) | Statement::Delete(_) => { + // TODO: remove this when format is supported in datafusion + if let Statement::Explain(explain) = &stmt { + if let Some(format) = explain.format() { + query_ctx.set_explain_format(format.to_string()); + } + } + + self.plan_and_exec_sql(stmt, &query_ctx, query_interceptor) + .await + } + Statement::Tql(tql) => { + self.plan_and_exec_tql(&query_ctx, query_interceptor, tql) + .await + } + _ => { + query_interceptor.pre_execute(&stmt, None, query_ctx.clone())?; + self.statement_executor + .execute_sql(stmt, query_ctx) + .await + .context(TableOperationSnafu) + } + } + } + + async fn plan_and_exec_sql( + &self, + stmt: Statement, + query_ctx: &QueryContextRef, + query_interceptor: Option<&SqlQueryInterceptorRef>, + ) -> Result { + let stmt = QueryStatement::Sql(stmt); + let plan = self + .statement_executor + .plan(&stmt, query_ctx.clone()) + .await?; + let QueryStatement::Sql(stmt) = stmt else { + unreachable!() + }; + query_interceptor.pre_execute(&stmt, Some(&plan), query_ctx.clone())?; + self.statement_executor + .exec_plan(plan, query_ctx.clone()) + .await + .context(TableOperationSnafu) + } + + async fn plan_and_exec_tql( + &self, + query_ctx: &QueryContextRef, + query_interceptor: Option<&SqlQueryInterceptorRef>, + tql: Tql, + ) -> Result { + let plan = self + .statement_executor + .plan_tql(tql.clone(), query_ctx) + .await?; + query_interceptor.pre_execute(&Statement::Tql(tql), Some(&plan), query_ctx.clone())?; + self.statement_executor + .exec_plan(plan, query_ctx.clone()) + .await + .context(TableOperationSnafu) + } +} + +/// If the relevant variables are set, the timeout is enforced for all PostgreSQL statements. +/// For MySQL, it applies only to read-only statements. +fn derive_timeout(stmt: &Statement, query_ctx: &QueryContextRef) -> Option { + let query_timeout = query_ctx.query_timeout()?; + match (query_ctx.channel(), stmt) { + (Channel::Mysql, Statement::Query(_)) | (Channel::Postgres, _) => Some(query_timeout), + (_, _) => None, + } +} + +fn attach_timeout(output: Output, mut timeout: Duration) -> Result { + if timeout.is_zero() { + return StatementTimeoutSnafu.fail(); + } + + let output = match output.data { + OutputData::AffectedRows(_) | OutputData::RecordBatches(_) => output, + OutputData::Stream(mut stream) => { + let schema = stream.schema(); + let s = Box::pin(stream! { + let mut start = tokio::time::Instant::now(); + while let Some(item) = tokio::time::timeout(timeout, stream.next()).await.map_err(|_| StreamTimeoutSnafu.build())? { + yield item; + + let now = tokio::time::Instant::now(); + timeout = timeout.checked_sub(now - start).unwrap_or(Duration::ZERO); + start = now; + // tokio::time::timeout may not return an error immediately when timeout is 0. + if timeout.is_zero() { + StreamTimeoutSnafu.fail()?; + } + } + }) as Pin + Send>>; + let stream = RecordBatchStreamWrapper { + schema, + stream: s, + output_ordering: None, + metrics: Default::default(), + }; + Output::new(OutputData::Stream(Box::pin(stream)), output.meta) + } + }; + + Ok(output) } #[async_trait] diff --git a/src/operator/src/error.rs b/src/operator/src/error.rs index c488efe8cc..137c72571c 100644 --- a/src/operator/src/error.rs +++ b/src/operator/src/error.rs @@ -23,7 +23,6 @@ use datafusion::parquet; use datatypes::arrow::error::ArrowError; use snafu::{Location, Snafu}; use table::metadata::TableType; -use tokio::time::error::Elapsed; #[derive(Snafu)] #[snafu(visibility(pub))] @@ -786,14 +785,6 @@ pub enum Error { json: String, }, - #[snafu(display("Canceling statement due to statement timeout"))] - StatementTimeout { - #[snafu(implicit)] - location: Location, - #[snafu(source)] - error: Elapsed, - }, - #[snafu(display("Cursor {name} is not found"))] CursorNotFound { name: String }, @@ -983,7 +974,6 @@ impl ErrorExt for Error { Error::ExecuteAdminFunction { source, .. } => source.status_code(), Error::BuildRecordBatch { source, .. } => source.status_code(), Error::UpgradeCatalogManagerRef { .. } => StatusCode::Internal, - Error::StatementTimeout { .. } => StatusCode::Cancelled, Error::ColumnOptions { source, .. } => source.status_code(), Error::DecodeFlightData { source, .. } => source.status_code(), Error::ComputeArrow { .. } => StatusCode::Internal, diff --git a/src/operator/src/statement.rs b/src/operator/src/statement.rs index acc9265347..698396528f 100644 --- a/src/operator/src/statement.rs +++ b/src/operator/src/statement.rs @@ -27,15 +27,12 @@ mod show; mod tql; use std::collections::HashMap; -use std::pin::Pin; use std::sync::Arc; -use std::time::Duration; -use async_stream::stream; use catalog::kvbackend::KvBackendCatalogManager; use catalog::process_manager::ProcessManagerRef; use catalog::CatalogManagerRef; -use client::{OutputData, RecordBatches}; +use client::RecordBatches; use common_error::ext::BoxedError; use common_meta::cache::TableRouteCacheRef; use common_meta::cache_invalidator::CacheInvalidatorRef; @@ -46,13 +43,10 @@ use common_meta::key::view_info::{ViewInfoManager, ViewInfoManagerRef}; use common_meta::key::{TableMetadataManager, TableMetadataManagerRef}; use common_meta::kv_backend::KvBackendRef; use common_query::Output; -use common_recordbatch::error::StreamTimeoutSnafu; -use common_recordbatch::RecordBatchStreamWrapper; use common_telemetry::tracing; use common_time::range::TimestampRange; use common_time::Timestamp; use datafusion_expr::LogicalPlan; -use futures::stream::{Stream, StreamExt}; use partition::manager::{PartitionRuleManager, PartitionRuleManagerRef}; use query::parser::QueryStatement; use query::QueryEngineRef; @@ -79,8 +73,8 @@ use self::set::{ }; use crate::error::{ self, CatalogSnafu, ExecLogicalPlanSnafu, ExternalSnafu, InvalidSqlSnafu, NotSupportedSnafu, - PlanStatementSnafu, Result, SchemaNotFoundSnafu, StatementTimeoutSnafu, - TableMetadataManagerSnafu, TableNotFoundSnafu, UpgradeCatalogManagerRefSnafu, + PlanStatementSnafu, Result, SchemaNotFoundSnafu, TableMetadataManagerSnafu, TableNotFoundSnafu, + UpgradeCatalogManagerRefSnafu, }; use crate::insert::InserterRef; use crate::statement::copy_database::{COPY_DATABASE_TIME_END_KEY, COPY_DATABASE_TIME_START_KEY}; @@ -490,19 +484,8 @@ impl StatementExecutor { #[tracing::instrument(skip_all)] async fn plan_exec(&self, stmt: QueryStatement, query_ctx: QueryContextRef) -> Result { - let timeout = derive_timeout(&stmt, &query_ctx); - match timeout { - Some(timeout) => { - let start = tokio::time::Instant::now(); - let output = tokio::time::timeout(timeout, self.plan_exec_inner(stmt, query_ctx)) - .await - .context(StatementTimeoutSnafu)?; - // compute remaining timeout - let remaining_timeout = timeout.checked_sub(start.elapsed()).unwrap_or_default(); - Ok(attach_timeout(output?, remaining_timeout)) - } - None => self.plan_exec_inner(stmt, query_ctx).await, - } + let plan = self.plan(&stmt, query_ctx.clone()).await?; + self.exec_plan(plan, query_ctx).await } async fn get_table(&self, table_ref: &TableReference<'_>) -> Result { @@ -519,49 +502,6 @@ impl StatementExecutor { table_name: table_ref.to_string(), }) } - - async fn plan_exec_inner( - &self, - stmt: QueryStatement, - query_ctx: QueryContextRef, - ) -> Result { - let plan = self.plan(&stmt, query_ctx.clone()).await?; - self.exec_plan(plan, query_ctx).await - } -} - -fn attach_timeout(output: Output, mut timeout: Duration) -> Output { - match output.data { - OutputData::AffectedRows(_) | OutputData::RecordBatches(_) => output, - OutputData::Stream(mut stream) => { - let schema = stream.schema(); - let s = Box::pin(stream! { - let start = tokio::time::Instant::now(); - while let Some(item) = tokio::time::timeout(timeout, stream.next()).await.context(StreamTimeoutSnafu)? { - yield item; - timeout = timeout.checked_sub(tokio::time::Instant::now() - start).unwrap_or(Duration::ZERO); - } - }) as Pin + Send>>; - let stream = RecordBatchStreamWrapper { - schema, - stream: s, - output_ordering: None, - metrics: Default::default(), - }; - Output::new(OutputData::Stream(Box::pin(stream)), output.meta) - } - } -} - -/// If the relevant variables are set, the timeout is enforced for all PostgreSQL statements. -/// For MySQL, it applies only to read-only statements. -fn derive_timeout(stmt: &QueryStatement, query_ctx: &QueryContextRef) -> Option { - let query_timeout = query_ctx.query_timeout()?; - match (query_ctx.channel(), stmt) { - (Channel::Mysql, QueryStatement::Sql(Statement::Query(_))) - | (Channel::Postgres, QueryStatement::Sql(_)) => Some(query_timeout), - (_, _) => None, - } } fn to_copy_query_request(stmt: CopyQueryToArgument) -> Result {