feat: implement statement/execution timeout session variable (#4792)

* support set and show on statement/execution timeout session variables.

* implement statement timeout for mysql read, and postgres queries

* add mysql test with max execution time
This commit is contained in:
Lanqing Yang
2024-11-14 22:19:39 -08:00
committed by GitHub
parent 42bf7e9965
commit cdba7b442f
13 changed files with 330 additions and 17 deletions

1
Cargo.lock generated
View File

@@ -7714,6 +7714,7 @@ name = "operator"
version = "0.9.5"
dependencies = [
"api",
"async-stream",
"async-trait",
"catalog",
"chrono",

View File

@@ -20,6 +20,7 @@ pin-project.workspace = true
serde.workspace = true
serde_json.workspace = true
snafu.workspace = true
tokio.workspace = true
[dev-dependencies]
tokio.workspace = true

View File

@@ -161,6 +161,13 @@ pub enum Error {
#[snafu(implicit)]
location: Location,
},
#[snafu(display("Stream timeout"))]
StreamTimeout {
#[snafu(implicit)]
location: Location,
#[snafu(source)]
error: tokio::time::error::Elapsed,
},
}
impl ErrorExt for Error {
@@ -190,6 +197,8 @@ impl ErrorExt for Error {
Error::SchemaConversion { source, .. } | Error::CastVector { source, .. } => {
source.status_code()
}
Error::StreamTimeout { .. } => StatusCode::Cancelled,
}
}

View File

@@ -12,6 +12,7 @@ workspace = true
[dependencies]
api.workspace = true
async-stream.workspace = true
async-trait = "0.1"
catalog.workspace = true
chrono.workspace = true

View File

@@ -23,6 +23,7 @@ use datafusion::parquet;
use datatypes::arrow::error::ArrowError;
use snafu::{Location, Snafu};
use table::metadata::TableType;
use tokio::time::error::Elapsed;
#[derive(Snafu)]
#[snafu(visibility(pub))]
@@ -777,6 +778,14 @@ pub enum Error {
location: Location,
json: String,
},
#[snafu(display("Canceling statement due to statement timeout"))]
StatementTimeout {
#[snafu(implicit)]
location: Location,
#[snafu(source)]
error: Elapsed,
},
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -924,6 +933,7 @@ impl ErrorExt for Error {
Error::BuildRecordBatch { source, .. } => source.status_code(),
Error::UpgradeCatalogManagerRef { .. } => StatusCode::Internal,
Error::StatementTimeout { .. } => StatusCode::Cancelled,
}
}

View File

@@ -24,11 +24,14 @@ mod show;
mod tql;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use async_stream::stream;
use catalog::kvbackend::KvBackendCatalogManager;
use catalog::CatalogManagerRef;
use client::RecordBatches;
use client::{OutputData, RecordBatches};
use common_error::ext::BoxedError;
use common_meta::cache::TableRouteCacheRef;
use common_meta::cache_invalidator::CacheInvalidatorRef;
@@ -39,15 +42,19 @@ use common_meta::key::view_info::{ViewInfoManager, ViewInfoManagerRef};
use common_meta::key::{TableMetadataManager, TableMetadataManagerRef};
use common_meta::kv_backend::KvBackendRef;
use common_query::Output;
use common_recordbatch::error::StreamTimeoutSnafu;
use common_recordbatch::RecordBatchStreamWrapper;
use common_telemetry::tracing;
use common_time::range::TimestampRange;
use common_time::Timestamp;
use datafusion_expr::LogicalPlan;
use futures::stream::{Stream, StreamExt};
use partition::manager::{PartitionRuleManager, PartitionRuleManagerRef};
use query::parser::QueryStatement;
use query::QueryEngineRef;
use session::context::{Channel, QueryContextRef};
use session::table_name::table_idents_to_full_name;
use set::set_query_timeout;
use snafu::{ensure, OptionExt, ResultExt};
use sql::statements::copy::{CopyDatabase, CopyDatabaseArgument, CopyTable, CopyTableArgument};
use sql::statements::set_variables::SetVariables;
@@ -63,8 +70,8 @@ use table::TableRef;
use self::set::{set_bytea_output, set_datestyle, set_timezone, validate_client_encoding};
use crate::error::{
self, CatalogSnafu, ExecLogicalPlanSnafu, ExternalSnafu, InvalidSqlSnafu, NotSupportedSnafu,
PlanStatementSnafu, Result, SchemaNotFoundSnafu, TableMetadataManagerSnafu, TableNotFoundSnafu,
UpgradeCatalogManagerRefSnafu,
PlanStatementSnafu, Result, SchemaNotFoundSnafu, StatementTimeoutSnafu,
TableMetadataManagerSnafu, TableNotFoundSnafu, UpgradeCatalogManagerRefSnafu,
};
use crate::insert::InserterRef;
use crate::statement::copy_database::{COPY_DATABASE_TIME_END_KEY, COPY_DATABASE_TIME_START_KEY};
@@ -338,6 +345,28 @@ impl StatementExecutor {
"DATESTYLE" => set_datestyle(set_var.value, query_ctx)?,
"CLIENT_ENCODING" => validate_client_encoding(set_var)?,
"MAX_EXECUTION_TIME" => match query_ctx.channel() {
Channel::Mysql => set_query_timeout(set_var.value, query_ctx)?,
Channel::Postgres => {
query_ctx.set_warning(format!("Unsupported set variable {}", var_name))
}
_ => {
return NotSupportedSnafu {
feat: format!("Unsupported set variable {}", var_name),
}
.fail()
}
},
"STATEMENT_TIMEOUT" => {
if query_ctx.channel() == Channel::Postgres {
set_query_timeout(set_var.value, query_ctx)?
} else {
return NotSupportedSnafu {
feat: format!("Unsupported set variable {}", var_name),
}
.fail();
}
}
_ => {
// for postgres, we give unknown SET statements a warning with
// success, this is prevent the SET call becoming a blocker
@@ -387,8 +416,19 @@ impl StatementExecutor {
#[tracing::instrument(skip_all)]
async fn plan_exec(&self, stmt: QueryStatement, query_ctx: QueryContextRef) -> Result<Output> {
let plan = self.plan(&stmt, query_ctx.clone()).await?;
self.exec_plan(plan, query_ctx).await
let timeout = derive_timeout(&stmt, &query_ctx);
match timeout {
Some(timeout) => {
let start = tokio::time::Instant::now();
let output = tokio::time::timeout(timeout, self.plan_exec_inner(stmt, query_ctx))
.await
.context(StatementTimeoutSnafu)?;
// compute remaining timeout
let remaining_timeout = timeout.checked_sub(start.elapsed()).unwrap_or_default();
Ok(attach_timeout(output?, remaining_timeout))
}
None => self.plan_exec_inner(stmt, query_ctx).await,
}
}
async fn get_table(&self, table_ref: &TableReference<'_>) -> Result<TableRef> {
@@ -405,6 +445,49 @@ impl StatementExecutor {
table_name: table_ref.to_string(),
})
}
async fn plan_exec_inner(
&self,
stmt: QueryStatement,
query_ctx: QueryContextRef,
) -> Result<Output> {
let plan = self.plan(&stmt, query_ctx.clone()).await?;
self.exec_plan(plan, query_ctx).await
}
}
fn attach_timeout(output: Output, mut timeout: Duration) -> Output {
match output.data {
OutputData::AffectedRows(_) | OutputData::RecordBatches(_) => output,
OutputData::Stream(mut stream) => {
let schema = stream.schema();
let s = Box::pin(stream! {
let start = tokio::time::Instant::now();
while let Some(item) = tokio::time::timeout(timeout, stream.next()).await.context(StreamTimeoutSnafu)? {
yield item;
timeout = timeout.checked_sub(tokio::time::Instant::now() - start).unwrap_or(Duration::ZERO);
}
}) as Pin<Box<dyn Stream<Item = _> + Send>>;
let stream = RecordBatchStreamWrapper {
schema,
stream: s,
output_ordering: None,
metrics: Default::default(),
};
Output::new(OutputData::Stream(Box::pin(stream)), output.meta)
}
}
}
/// If the relevant variables are set, the timeout is enforced for all PostgreSQL statements.
/// For MySQL, it applies only to read-only statements.
fn derive_timeout(stmt: &QueryStatement, query_ctx: &QueryContextRef) -> Option<Duration> {
let query_timeout = query_ctx.query_timeout()?;
match (query_ctx.channel(), stmt) {
(Channel::Mysql, QueryStatement::Sql(Statement::Query(_)))
| (Channel::Postgres, QueryStatement::Sql(_)) => Some(query_timeout),
(_, _) => None,
}
}
fn to_copy_table_request(stmt: CopyTable, query_ctx: QueryContextRef) -> Result<CopyTableRequest> {

View File

@@ -12,7 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::time::Duration;
use common_time::Timezone;
use lazy_static::lazy_static;
use regex::Regex;
use session::context::Channel::Postgres;
use session::context::QueryContextRef;
use session::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle};
use snafu::{ensure, OptionExt, ResultExt};
@@ -21,6 +26,15 @@ use sql::statements::set_variables::SetVariables;
use crate::error::{InvalidConfigValueSnafu, InvalidSqlSnafu, NotSupportedSnafu, Result};
lazy_static! {
// Regex rules:
// The string must start with a number (one or more digits).
// The number must be followed by one of the valid time units (ms, s, min, h, d).
// The string must end immediately after the unit, meaning there can be no extra
// characters or spaces after the valid time specification.
static ref PG_TIME_INPUT_REGEX: Regex = Regex::new(r"^(\d+)(ms|s|min|h|d)$").unwrap();
}
pub fn set_timezone(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
let tz_expr = exprs.first().context(NotSupportedSnafu {
feat: "No timezone find in set variable statement",
@@ -177,3 +191,96 @@ pub fn set_datestyle(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
.set_pg_datetime_style(style.unwrap_or(old_style), order.unwrap_or(older_order));
Ok(())
}
pub fn set_query_timeout(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
let timeout_expr = exprs.first().context(NotSupportedSnafu {
feat: "No timeout value find in set query timeout statement",
})?;
match timeout_expr {
Expr::Value(Value::Number(timeout, _)) => {
match timeout.parse::<u64>() {
Ok(timeout) => ctx.set_query_timeout(Duration::from_millis(timeout)),
Err(_) => {
return NotSupportedSnafu {
feat: format!("Invalid timeout expr {} in set variable statement", timeout),
}
.fail()
}
}
Ok(())
}
// postgres support time units i.e. SET STATEMENT_TIMEOUT = '50ms';
Expr::Value(Value::SingleQuotedString(timeout))
| Expr::Value(Value::DoubleQuotedString(timeout)) => {
if ctx.channel() != Postgres {
return NotSupportedSnafu {
feat: format!("Invalid timeout expr {} in set variable statement", timeout),
}
.fail();
}
let timeout = parse_pg_query_timeout_input(timeout)?;
ctx.set_query_timeout(Duration::from_millis(timeout));
Ok(())
}
expr => NotSupportedSnafu {
feat: format!(
"Unsupported timeout expr {} in set variable statement",
expr
),
}
.fail(),
}
}
// support time units in ms, s, min, h, d for postgres protocol.
// https://www.postgresql.org/docs/8.4/config-setting.html#:~:text=Valid%20memory%20units%20are%20kB,%2C%20and%20d%20(days).
fn parse_pg_query_timeout_input(input: &str) -> Result<u64> {
match input.parse::<u64>() {
Ok(timeout) => Ok(timeout),
Err(_) => {
if let Some(captures) = PG_TIME_INPUT_REGEX.captures(input) {
let value = captures[1].parse::<u64>().expect("regex failed");
let unit = &captures[2];
match unit {
"ms" => Ok(value),
"s" => Ok(value * 1000),
"min" => Ok(value * 60 * 1000),
"h" => Ok(value * 60 * 60 * 1000),
"d" => Ok(value * 24 * 60 * 60 * 1000),
_ => unreachable!("regex failed"),
}
} else {
NotSupportedSnafu {
feat: format!(
"Unsupported timeout expr {} in set variable statement",
input
),
}
.fail()
}
}
}
}
#[cfg(test)]
mod test {
use crate::statement::set::parse_pg_query_timeout_input;
#[test]
fn test_parse_pg_query_timeout_input() {
assert!(parse_pg_query_timeout_input("").is_err());
assert!(parse_pg_query_timeout_input(" 50 ms").is_err());
assert!(parse_pg_query_timeout_input("5s 1ms").is_err());
assert!(parse_pg_query_timeout_input("3a").is_err());
assert!(parse_pg_query_timeout_input("1.5min").is_err());
assert!(parse_pg_query_timeout_input("ms").is_err());
assert!(parse_pg_query_timeout_input("a").is_err());
assert!(parse_pg_query_timeout_input("-1").is_err());
assert_eq!(50, parse_pg_query_timeout_input("50").unwrap());
assert_eq!(12, parse_pg_query_timeout_input("12ms").unwrap());
assert_eq!(2000, parse_pg_query_timeout_input("2s").unwrap());
assert_eq!(60000, parse_pg_query_timeout_input("1min").unwrap());
}
}

View File

@@ -48,7 +48,7 @@ use datatypes::vectors::StringVector;
use object_store::ObjectStore;
use once_cell::sync::Lazy;
use regex::Regex;
use session::context::QueryContextRef;
use session::context::{Channel, QueryContextRef};
pub use show_create_table::create_table_stmt;
use snafu::{ensure, OptionExt, ResultExt};
use sql::ast::Ident;
@@ -651,6 +651,23 @@ pub fn show_variable(stmt: ShowVariables, query_ctx: QueryContextRef) -> Result<
let (style, order) = *query_ctx.configuration_parameter().pg_datetime_style();
format!("{}, {}", style, order)
}
"MAX_EXECUTION_TIME" => {
if query_ctx.channel() == Channel::Mysql {
query_ctx.query_timeout_as_millis().to_string()
} else {
return UnsupportedVariableSnafu { name: variable }.fail();
}
}
"STATEMENT_TIMEOUT" => {
// Add time units to postgres query timeout display.
if query_ctx.channel() == Channel::Postgres {
let mut timeout = query_ctx.query_timeout_as_millis().to_string();
timeout.push_str("ms");
timeout
} else {
return UnsupportedVariableSnafu { name: variable }.fail();
}
}
_ => return UnsupportedVariableSnafu { name: variable }.fail(),
};
let schema = Arc::new(Schema::new(vec![ColumnSchema::new(

View File

@@ -16,6 +16,7 @@ use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use std::net::SocketAddr;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use api::v1::region::RegionRequestHeader;
use arc_swap::ArcSwap;
@@ -282,6 +283,22 @@ impl QueryContext {
pub fn set_warning(&self, msg: String) {
self.mutable_query_context_data.write().unwrap().warning = Some(msg);
}
pub fn query_timeout(&self) -> Option<Duration> {
self.mutable_session_data.read().unwrap().query_timeout
}
pub fn query_timeout_as_millis(&self) -> u128 {
let timeout = self.mutable_session_data.read().unwrap().query_timeout;
if let Some(t) = timeout {
return t.as_millis();
}
0
}
pub fn set_query_timeout(&self, timeout: Duration) {
self.mutable_session_data.write().unwrap().query_timeout = Some(timeout);
}
}
impl QueryContextBuilder {

View File

@@ -18,6 +18,7 @@ pub mod table_name;
use std::net::SocketAddr;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use auth::UserInfoRef;
use common_catalog::build_db_string;
@@ -45,6 +46,7 @@ pub(crate) struct MutableInner {
schema: String,
user_info: UserInfoRef,
timezone: Timezone,
query_timeout: Option<Duration>,
}
impl Default for MutableInner {
@@ -53,6 +55,7 @@ impl Default for MutableInner {
schema: DEFAULT_SCHEMA_NAME.into(),
user_info: auth::userinfo_by_name(None),
timezone: get_timezone(None).clone(),
query_timeout: None,
}
}
}

View File

@@ -58,47 +58,83 @@ mod tests {
use crate::dialect::GreptimeDbDialect;
use crate::parser::ParseOptions;
fn assert_mysql_parse_result(sql: &str) {
fn assert_mysql_parse_result(sql: &str, indent_str: &str, expr: Expr) {
let result =
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default());
let mut stmts = result.unwrap();
assert_eq!(
stmts.pop().unwrap(),
Statement::SetVariables(SetVariables {
variable: ObjectName(vec![Ident::new("time_zone")]),
value: vec![Expr::Value(Value::SingleQuotedString("UTC".to_string()))]
variable: ObjectName(vec![Ident::new(indent_str)]),
value: vec![expr]
})
);
}
fn assert_pg_parse_result(sql: &str) {
fn assert_pg_parse_result(sql: &str, indent: &str, expr: Expr) {
let result =
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default());
let mut stmts = result.unwrap();
assert_eq!(
stmts.pop().unwrap(),
Statement::SetVariables(SetVariables {
variable: ObjectName(vec![Ident::new("TIMEZONE")]),
value: vec![Expr::Value(Value::SingleQuotedString("UTC".to_string()))],
variable: ObjectName(vec![Ident::new(indent)]),
value: vec![expr],
})
);
}
#[test]
pub fn test_set_timezone() {
let expected_utc_expr = Expr::Value(Value::SingleQuotedString("UTC".to_string()));
// mysql style
let sql = "SET time_zone = 'UTC'";
assert_mysql_parse_result(sql);
assert_mysql_parse_result(sql, "time_zone", expected_utc_expr.clone());
// session or local style
let sql = "SET LOCAL time_zone = 'UTC'";
assert_mysql_parse_result(sql);
assert_mysql_parse_result(sql, "time_zone", expected_utc_expr.clone());
let sql = "SET SESSION time_zone = 'UTC'";
assert_mysql_parse_result(sql);
assert_mysql_parse_result(sql, "time_zone", expected_utc_expr.clone());
// postgresql style
let sql = "SET TIMEZONE TO 'UTC'";
assert_pg_parse_result(sql);
assert_pg_parse_result(sql, "TIMEZONE", expected_utc_expr.clone());
let sql = "SET TIMEZONE 'UTC'";
assert_pg_parse_result(sql);
assert_pg_parse_result(sql, "TIMEZONE", expected_utc_expr);
}
#[test]
pub fn test_set_query_timeout() {
let expected_query_timeout_expr = Expr::Value(Value::Number("5000".to_string(), false));
// mysql style
let sql = "SET MAX_EXECUTION_TIME = 5000";
assert_mysql_parse_result(
sql,
"MAX_EXECUTION_TIME",
expected_query_timeout_expr.clone(),
);
// session or local style
let sql = "SET LOCAL MAX_EXECUTION_TIME = 5000";
assert_mysql_parse_result(
sql,
"MAX_EXECUTION_TIME",
expected_query_timeout_expr.clone(),
);
let sql = "SET SESSION MAX_EXECUTION_TIME = 5000";
assert_mysql_parse_result(
sql,
"MAX_EXECUTION_TIME",
expected_query_timeout_expr.clone(),
);
// postgresql style
let sql = "SET STATEMENT_TIMEOUT = 5000";
assert_pg_parse_result(
sql,
"STATEMENT_TIMEOUT",
expected_query_timeout_expr.clone(),
);
let sql = "SET STATEMENT_TIMEOUT TO 5000";
assert_pg_parse_result(sql, "STATEMENT_TIMEOUT", expected_query_timeout_expr);
}
}

View File

@@ -179,3 +179,22 @@ DROP TABLE foo;
Affected Rows: 0
-- SQLNESS PROTOCOL MYSQL
SET MAX_EXECUTION_TIME = 2000;
affected_rows: 0
-- SQLNESS PROTOCOL MYSQL
SHOW VARIABLES MAX_EXECUTION_TIME;
+---------------+-------+
| Variable_name | Value |
+---------------+-------+
| | |
+---------------+-------+
-- SQLNESS PROTOCOL MYSQL
SET MAX_EXECUTION_TIME = 0;
affected_rows: 0

View File

@@ -72,3 +72,12 @@ DROP TABLE phy;
DROP TABLE system_metrics;
DROP TABLE foo;
-- SQLNESS PROTOCOL MYSQL
SET MAX_EXECUTION_TIME = 2000;
-- SQLNESS PROTOCOL MYSQL
SHOW VARIABLES MAX_EXECUTION_TIME;
-- SQLNESS PROTOCOL MYSQL
SET MAX_EXECUTION_TIME = 0;