From 816d94892c88f3813e9dbd72ec210827ef0493f8 Mon Sep 17 00:00:00 2001 From: WU Jingdi Date: Mon, 15 Jan 2024 14:29:31 +0800 Subject: [PATCH] feat: support HTTP&gRPC&pg set timezone (#3125) * feat: support HTTP&gRPC&pg set timezone * chore: fix code advice * chore: fix code advice --- Cargo.lock | 2 +- Cargo.toml | 2 +- benchmarks/src/bin/nyc-taxi.rs | 2 +- src/catalog/src/error.rs | 2 +- src/catalog/src/kvbackend/manager.rs | 2 +- src/client/examples/logical.rs | 2 +- src/client/src/database.rs | 20 ++++- src/client/src/region.rs | 4 +- src/common/base/src/bytes.rs | 2 +- src/common/grpc-expr/src/alter.rs | 10 +-- src/common/procedure/src/local.rs | 2 +- src/common/query/src/lib.rs | 2 +- src/common/time/src/interval.rs | 8 +- src/common/time/src/timezone.rs | 10 +++ src/frontend/src/instance.rs | 2 + src/meta-srv/src/procedure/region_failover.rs | 4 +- src/meta-srv/src/procedure/utils.rs | 2 +- src/operator/src/expr_factory.rs | 4 +- src/operator/src/statement.rs | 49 +++++++++++- src/operator/src/statement/show.rs | 7 +- src/query/src/error.rs | 4 + src/query/src/sql.rs | 70 ++++++++++++++++- src/servers/src/export_metrics.rs | 6 +- src/servers/src/grpc/greptime_handler.rs | 4 +- src/servers/src/http.rs | 5 +- src/servers/src/http/authorize.rs | 24 +++++- src/servers/src/http/csv_result.rs | 2 +- src/servers/src/http/header.rs | 2 + src/servers/src/http/influxdb_result_v1.rs | 2 +- src/servers/src/mysql/federated.rs | 63 +-------------- src/servers/src/mysql/handler.rs | 4 +- src/servers/src/mysql/writer.rs | 2 +- src/servers/src/postgres/handler.rs | 3 +- src/session/src/context.rs | 40 +++++++--- src/session/src/lib.rs | 2 +- src/sql/src/parser.rs | 2 + src/sql/src/parsers.rs | 1 + src/sql/src/parsers/set_var_parser.rs | 78 +++++++++++++++++++ src/sql/src/parsers/show_parser.rs | 30 ++++++- src/sql/src/statements.rs | 3 +- src/sql/src/statements/create.rs | 4 +- src/sql/src/statements/set_variables.rs | 23 ++++++ src/sql/src/statements/show.rs | 6 ++ src/sql/src/statements/statement.rs | 6 ++ src/store-api/src/region_request.rs | 6 +- tests-integration/tests/grpc.rs | 62 +++++++++++++++ tests-integration/tests/http.rs | 40 +++++++++- tests-integration/tests/sql.rs | 68 +++++++++++++++- 48 files changed, 568 insertions(+), 132 deletions(-) create mode 100644 src/sql/src/parsers/set_var_parser.rs create mode 100644 src/sql/src/statements/set_variables.rs diff --git a/Cargo.lock b/Cargo.lock index 9c1c71356e..350fa50c1f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3635,7 +3635,7 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] name = "greptime-proto" version = "0.1.0" -source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=a31ea166fc015ea7ff111ac94e26c3a5d64364d2#a31ea166fc015ea7ff111ac94e26c3a5d64364d2" +source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=2c1f17dce7af748c9a1255e82d6ceb7959f8919b#2c1f17dce7af748c9a1255e82d6ceb7959f8919b" dependencies = [ "prost 0.12.3", "serde", diff --git a/Cargo.toml b/Cargo.toml index 8e1e9b07f8..5454343131 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -91,7 +91,7 @@ etcd-client = "0.12" fst = "0.4.7" futures = "0.3" futures-util = "0.3" -greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "a31ea166fc015ea7ff111ac94e26c3a5d64364d2" } +greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "2c1f17dce7af748c9a1255e82d6ceb7959f8919b" } humantime-serde = "1.1" itertools = "0.10" lazy_static = "1.4" diff --git a/benchmarks/src/bin/nyc-taxi.rs b/benchmarks/src/bin/nyc-taxi.rs index 678b5321c0..f357ba5d88 100644 --- a/benchmarks/src/bin/nyc-taxi.rs +++ b/benchmarks/src/bin/nyc-taxi.rs @@ -258,7 +258,7 @@ fn create_table_expr(table_name: &str) -> CreateTableExpr { catalog_name: CATALOG_NAME.to_string(), schema_name: SCHEMA_NAME.to_string(), table_name: table_name.to_string(), - desc: "".to_string(), + desc: String::default(), column_defs: vec![ ColumnDef { name: "VendorID".to_string(), diff --git a/src/catalog/src/error.rs b/src/catalog/src/error.rs index 6602aecbd4..5ec7988640 100644 --- a/src/catalog/src/error.rs +++ b/src/catalog/src/error.rs @@ -333,7 +333,7 @@ mod tests { assert_eq!( StatusCode::StorageUnavailable, Error::SystemCatalog { - msg: "".to_string(), + msg: String::default(), location: Location::generate(), } .status_code() diff --git a/src/catalog/src/kvbackend/manager.rs b/src/catalog/src/kvbackend/manager.rs index 5e7de6fe31..5bed429be1 100644 --- a/src/catalog/src/kvbackend/manager.rs +++ b/src/catalog/src/kvbackend/manager.rs @@ -83,7 +83,7 @@ impl KvBackendCatalogManager { catalog_manager: me.clone(), information_schema_provider: Arc::new(InformationSchemaProvider::new( // The catalog name is not used in system_catalog, so let it empty - "".to_string(), + String::default(), me.clone(), )), }, diff --git a/src/client/examples/logical.rs b/src/client/examples/logical.rs index ec4c11cdd9..13f1165555 100644 --- a/src/client/examples/logical.rs +++ b/src/client/examples/logical.rs @@ -37,7 +37,7 @@ async fn run() { catalog_name: "greptime".to_string(), schema_name: "public".to_string(), table_name: "test_logical_dist_exec".to_string(), - desc: "".to_string(), + desc: String::default(), column_defs: vec![ ColumnDef { name: "timestamp".to_string(), diff --git a/src/client/src/database.rs b/src/client/src/database.rs index 4060cc4797..46fcf4bcec 100644 --- a/src/client/src/database.rs +++ b/src/client/src/database.rs @@ -47,6 +47,9 @@ pub struct Database { // The dbname follows naming rule as out mysql, postgres and http // protocol. The server treat dbname in priority of catalog/schema. dbname: String, + // The time zone indicates the time zone where the user is located. + // Some queries need to be aware of the user's time zone to perform some specific actions. + timezone: String, client: Client, ctx: FlightContext, @@ -58,7 +61,8 @@ impl Database { Self { catalog: catalog.into(), schema: schema.into(), - dbname: "".to_string(), + dbname: String::default(), + timezone: String::default(), client, ctx: FlightContext::default(), } @@ -73,8 +77,9 @@ impl Database { /// environment pub fn new_with_dbname(dbname: impl Into, client: Client) -> Self { Self { - catalog: "".to_string(), - schema: "".to_string(), + catalog: String::default(), + schema: String::default(), + timezone: String::default(), dbname: dbname.into(), client, ctx: FlightContext::default(), @@ -105,6 +110,14 @@ impl Database { self.dbname = dbname.into(); } + pub fn timezone(&self) -> &String { + &self.timezone + } + + pub fn set_timezone(&mut self, timezone: impl Into) { + self.timezone = timezone.into(); + } + pub fn set_auth(&mut self, auth: AuthScheme) { self.ctx.auth_header = Some(AuthHeader { auth_scheme: Some(auth), @@ -161,6 +174,7 @@ impl Database { schema: self.schema.clone(), authorization: self.ctx.auth_header.clone(), dbname: self.dbname.clone(), + timezone: self.timezone.clone(), // TODO(Taylor-lagrange): add client grpc tracing tracing_context: W3cTrace::new(), }), diff --git a/src/client/src/region.rs b/src/client/src/region.rs index 3967c23ed0..574af84228 100644 --- a/src/client/src/region.rs +++ b/src/client/src/region.rs @@ -230,7 +230,7 @@ mod test { let result = check_response_header(Some(ResponseHeader { status: Some(PbStatus { status_code: StatusCode::Success as u32, - err_msg: "".to_string(), + err_msg: String::default(), }), })); assert!(result.is_ok()); @@ -238,7 +238,7 @@ mod test { let result = check_response_header(Some(ResponseHeader { status: Some(PbStatus { status_code: u32::MAX, - err_msg: "".to_string(), + err_msg: String::default(), }), })); assert!(matches!( diff --git a/src/common/base/src/bytes.rs b/src/common/base/src/bytes.rs index 7dff0a54b4..7f757917b8 100644 --- a/src/common/base/src/bytes.rs +++ b/src/common/base/src/bytes.rs @@ -216,7 +216,7 @@ mod tests { let bytes = StringBytes::from(hello.clone()); assert_eq!(bytes.len(), hello.len()); - let zero = "".to_string(); + let zero = String::default(); let bytes = StringBytes::from(zero); assert!(bytes.is_empty()); } diff --git a/src/common/grpc-expr/src/alter.rs b/src/common/grpc-expr/src/alter.rs index 532f9cf15c..6eebf00f4b 100644 --- a/src/common/grpc-expr/src/alter.rs +++ b/src/common/grpc-expr/src/alter.rs @@ -145,8 +145,8 @@ mod tests { #[test] fn test_alter_expr_to_request() { let expr = AlterExpr { - catalog_name: "".to_string(), - schema_name: "".to_string(), + catalog_name: String::default(), + schema_name: String::default(), table_name: "monitor".to_string(), kind: Some(Kind::AddColumns(AddColumns { @@ -186,8 +186,8 @@ mod tests { #[test] fn test_alter_expr_with_location_to_request() { let expr = AlterExpr { - catalog_name: "".to_string(), - schema_name: "".to_string(), + catalog_name: String::default(), + schema_name: String::default(), table_name: "monitor".to_string(), kind: Some(Kind::AddColumns(AddColumns { @@ -204,7 +204,7 @@ mod tests { }), location: Some(Location { location_type: LocationType::First.into(), - after_column_name: "".to_string(), + after_column_name: String::default(), }), }, AddColumn { diff --git a/src/common/procedure/src/local.rs b/src/common/procedure/src/local.rs index 30c0403f68..624e98d181 100644 --- a/src/common/procedure/src/local.rs +++ b/src/common/procedure/src/local.rs @@ -374,7 +374,7 @@ pub struct ManagerConfig { impl Default for ManagerConfig { fn default() -> Self { Self { - parent_path: "".to_string(), + parent_path: String::default(), max_retry_times: 3, retry_delay: Duration::from_millis(500), remove_outdated_meta_task_interval: Duration::from_secs(60 * 10), diff --git a/src/common/query/src/lib.rs b/src/common/query/src/lib.rs index 95d1aed2fe..679fd53a98 100644 --- a/src/common/query/src/lib.rs +++ b/src/common/query/src/lib.rs @@ -60,7 +60,7 @@ impl From<&AddColumnLocation> for Location { match value { AddColumnLocation::First => Location { location_type: LocationType::First.into(), - after_column_name: "".to_string(), + after_column_name: String::default(), }, AddColumnLocation::After { column_name } => Location { location_type: LocationType::After.into(), diff --git a/src/common/time/src/interval.rs b/src/common/time/src/interval.rs index 892e4807c4..95e8982f07 100644 --- a/src/common/time/src/interval.rs +++ b/src/common/time/src/interval.rs @@ -395,7 +395,7 @@ impl IntervalFormat { return "PT0S".to_string(); } let fract_str = match self.microseconds { - 0 => "".to_string(), + 0 => String::default(), _ => format!(".{:06}", self.microseconds) .trim_end_matches('0') .to_string(), @@ -446,7 +446,7 @@ impl IntervalFormat { if self.is_zero() { return "00:00:00".to_string(); } - let mut result = "".to_string(); + let mut result = String::default(); if self.has_year_month() { if self.years != 0 { result.push_str(&format!("{} year ", self.years)); @@ -464,7 +464,7 @@ impl IntervalFormat { /// get postgres time part(include hours, minutes, seconds, microseconds) fn get_postgres_time_part(&self) -> String { - let mut time_part = "".to_string(); + let mut time_part = String::default(); if self.has_time_part() { let sign = if !self.has_time_part_positive() { "-" @@ -516,7 +516,7 @@ fn get_time_part( is_time_part_positive: bool, is_only_time: bool, ) -> String { - let mut interval = "".to_string(); + let mut interval = String::default(); if is_time_part_positive && is_only_time { interval.push_str(&format!("{}:{:02}:{:02}", hours, mins, secs)); } else { diff --git a/src/common/time/src/timezone.rs b/src/common/time/src/timezone.rs index 700e0db073..0e59708494 100644 --- a/src/common/time/src/timezone.rs +++ b/src/common/time/src/timezone.rs @@ -52,6 +52,16 @@ pub fn get_timezone(tz: Option) -> Timezone { }) } +#[inline(always)] +/// If the `tz = Some("") || None || Some(Invalid timezone)`, return system timezone, +/// or return parsed `tz` as timezone. +pub fn parse_timezone(tz: Option<&str>) -> Timezone { + match tz { + None | Some("") => Timezone::Named(Tz::UTC), + Some(tz) => Timezone::from_tz_string(tz).unwrap_or(Timezone::Named(Tz::UTC)), + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub enum Timezone { Offset(FixedOffset), diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index 4470476772..c6ecabf342 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -457,6 +457,8 @@ pub fn check_permission( // show create table and alter are not supported yet Statement::ShowCreateTable(_) | Statement::CreateExternalTable(_) | Statement::Alter(_) => { } + // set/show variable now only alter/show variable in session + Statement::SetVariables(_) | Statement::ShowVariables(_) => {} Statement::Insert(insert) => { validate_param(insert.table_name(), query_ctx)?; diff --git a/src/meta-srv/src/procedure/region_failover.rs b/src/meta-srv/src/procedure/region_failover.rs index 7af5b86a3e..fbc8f781a0 100644 --- a/src/meta-srv/src/procedure/region_failover.rs +++ b/src/meta-srv/src/procedure/region_failover.rs @@ -532,7 +532,7 @@ mod tests { let nodes = (1..=region_distribution.len()) .map(|id| Peer { id: id as u64, - addr: "".to_string(), + addr: String::default(), }) .collect(); Arc::new(RandomNodeSelector { nodes }) @@ -751,7 +751,7 @@ mod tests { peers: Arc::new(Mutex::new(vec![ Some(Peer { id: 42, - addr: "".to_string(), + addr: String::default(), }), None, ])), diff --git a/src/meta-srv/src/procedure/utils.rs b/src/meta-srv/src/procedure/utils.rs index 094ab4d14c..b8565c8dc7 100644 --- a/src/meta-srv/src/procedure/utils.rs +++ b/src/meta-srv/src/procedure/utils.rs @@ -89,7 +89,7 @@ pub mod mock { header: Some(ResponseHeader { status: Some(PbStatus { status_code: 0, - err_msg: "".to_string(), + err_msg: String::default(), }), }), affected_rows: 0, diff --git a/src/operator/src/expr_factory.rs b/src/operator/src/expr_factory.rs index aec2b51566..0f1d5faeb7 100644 --- a/src/operator/src/expr_factory.rs +++ b/src/operator/src/expr_factory.rs @@ -167,7 +167,7 @@ pub(crate) async fn create_external_expr( catalog_name, schema_name, table_name, - desc: "".to_string(), + desc: String::default(), column_defs, time_index, primary_keys, @@ -198,7 +198,7 @@ pub fn create_to_expr(create: &CreateTable, query_ctx: QueryContextRef) -> Resul catalog_name, schema_name, table_name, - desc: "".to_string(), + desc: String::default(), column_defs: columns_to_expr(&create.columns, &time_index, &primary_keys)?, time_index, primary_keys, diff --git a/src/operator/src/statement.rs b/src/operator/src/statement.rs index b03ff12888..216942521a 100644 --- a/src/operator/src/statement.rs +++ b/src/operator/src/statement.rs @@ -34,7 +34,7 @@ use common_meta::table_name::TableName; use common_query::Output; use common_telemetry::tracing; use common_time::range::TimestampRange; -use common_time::Timestamp; +use common_time::{Timestamp, Timezone}; use partition::manager::{PartitionRuleManager, PartitionRuleManagerRef}; use query::parser::QueryStatement; use query::plan::LogicalPlan; @@ -45,14 +45,14 @@ use sql::statements::copy::{CopyDatabaseArgument, CopyTable, CopyTableArgument}; use sql::statements::statement::Statement; use sql::statements::OptionMap; use sql::util::format_raw_object_name; -use sqlparser::ast::ObjectName; +use sqlparser::ast::{Expr, ObjectName, Value}; use table::engine::TableReference; use table::requests::{CopyDatabaseRequest, CopyDirection, CopyTableRequest}; use table::TableRef; use crate::error::{ - self, CatalogSnafu, ExecLogicalPlanSnafu, ExternalSnafu, InvalidSqlSnafu, PlanStatementSnafu, - Result, TableNotFoundSnafu, + self, CatalogSnafu, ExecLogicalPlanSnafu, ExternalSnafu, InvalidSqlSnafu, NotSupportedSnafu, + PlanStatementSnafu, Result, TableNotFoundSnafu, }; use crate::insert::InserterRef; use crate::statement::backup::{COPY_DATABASE_TIME_END_KEY, COPY_DATABASE_TIME_START_KEY}; @@ -188,6 +188,20 @@ impl StatementExecutor { self.show_create_table(table_name, table_ref, query_ctx) .await } + Statement::SetVariables(set_var) => { + let var_name = set_var.variable.to_string().to_uppercase(); + match var_name.as_str() { + "TIMEZONE" | "TIME_ZONE" => set_timezone(set_var.value, query_ctx)?, + _ => { + return NotSupportedSnafu { + feat: format!("Unsupported set variable {}", var_name), + } + .fail() + } + } + Ok(Output::AffectedRows(0)) + } + Statement::ShowVariables(show_variable) => self.show_variable(show_variable, query_ctx), } } @@ -228,6 +242,33 @@ impl StatementExecutor { } } +fn set_timezone(exprs: Vec, ctx: QueryContextRef) -> Result<()> { + let tz_expr = exprs.first().context(NotSupportedSnafu { + feat: "No timezone find in set variable statement", + })?; + match tz_expr { + Expr::Value(Value::SingleQuotedString(tz)) | Expr::Value(Value::DoubleQuotedString(tz)) => { + match Timezone::from_tz_string(tz.as_str()) { + Ok(timezone) => ctx.set_timezone(timezone), + Err(_) => { + return NotSupportedSnafu { + feat: format!("Invalid timezone expr {} in set variable statement", tz), + } + .fail() + } + } + Ok(()) + } + expr => NotSupportedSnafu { + feat: format!( + "Unsupported timezone expr {} in set variable statement", + expr + ), + } + .fail(), + } +} + fn to_copy_table_request(stmt: CopyTable, query_ctx: QueryContextRef) -> Result { let direction = match stmt { CopyTable::To(_) => CopyDirection::Export, diff --git a/src/operator/src/statement/show.rs b/src/operator/src/statement/show.rs index 3841d721c1..ee3c2e1071 100644 --- a/src/operator/src/statement/show.rs +++ b/src/operator/src/statement/show.rs @@ -21,7 +21,7 @@ use session::context::QueryContextRef; use snafu::ResultExt; use sql::ast::{Ident, Value as SqlValue}; use sql::statements::create::{PartitionEntry, Partitions}; -use sql::statements::show::{ShowDatabases, ShowTables}; +use sql::statements::show::{ShowDatabases, ShowTables, ShowVariables}; use sql::{statements, MAXVALUE}; use table::TableRef; @@ -71,6 +71,11 @@ impl StatementExecutor { query::sql::show_create_table(table, partitions, query_ctx) .context(error::ExecuteStatementSnafu) } + + #[tracing::instrument(skip_all)] + pub fn show_variable(&self, stmt: ShowVariables, query_ctx: QueryContextRef) -> Result { + query::sql::show_variable(stmt, query_ctx).context(error::ExecuteStatementSnafu) + } } fn create_partitions_stmt(partitions: Vec) -> Result> { diff --git a/src/query/src/error.rs b/src/query/src/error.rs index b6ea37a1e5..c343d25d51 100644 --- a/src/query/src/error.rs +++ b/src/query/src/error.rs @@ -30,6 +30,9 @@ pub enum Error { #[snafu(display("Unsupported expr type: {}", name))] UnsupportedExpr { name: String, location: Location }, + #[snafu(display("Unsupported show variable: {}", name))] + UnsupportedVariable { name: String, location: Location }, + #[snafu(display("Operation {} not implemented yet", operation))] Unimplemented { operation: String, @@ -274,6 +277,7 @@ impl ErrorExt for Error { | ConvertSchema { .. } | AddSystemTimeOverflow { .. } | ColumnSchemaIncompatible { .. } + | UnsupportedVariable { .. } | ColumnSchemaNoDefault { .. } => StatusCode::InvalidArguments, BuildBackend { .. } | ListObjects { .. } => StatusCode::StorageUnavailable, diff --git a/src/query/src/sql.rs b/src/query/src/sql.rs index 28e803f48e..995c56a90a 100644 --- a/src/query/src/sql.rs +++ b/src/query/src/sql.rs @@ -28,6 +28,7 @@ use common_datasource::util::find_dir_and_filename; use common_query::prelude::GREPTIME_TIMESTAMP; use common_query::Output; use common_recordbatch::{RecordBatch, RecordBatches}; +use common_time::timezone::get_timezone; use common_time::Timestamp; use datatypes::prelude::*; use datatypes::schema::{ColumnDefaultConstraint, ColumnSchema, RawSchema, Schema}; @@ -38,12 +39,12 @@ use regex::Regex; use session::context::QueryContextRef; use snafu::{ensure, OptionExt, ResultExt}; use sql::statements::create::Partitions; -use sql::statements::show::{ShowDatabases, ShowKind, ShowTables}; +use sql::statements::show::{ShowDatabases, ShowKind, ShowTables, ShowVariables}; use table::requests::{FILE_TABLE_LOCATION_KEY, FILE_TABLE_PATTERN_KEY}; use table::TableRef; use crate::datafusion::execute_show_with_filter; -use crate::error::{self, Result}; +use crate::error::{self, Result, UnsupportedVariableSnafu}; const SCHEMAS_COLUMN: &str = "Schemas"; const TABLES_COLUMN: &str = "Tables"; @@ -229,6 +230,26 @@ pub async fn show_tables( } } +pub fn show_variable(stmt: ShowVariables, query_ctx: QueryContextRef) -> Result { + let variable = stmt.variable.to_string().to_uppercase(); + let value = match variable.as_str() { + "SYSTEM_TIME_ZONE" | "SYSTEM_TIMEZONE" => get_timezone(None).to_string(), + "TIME_ZONE" | "TIMEZONE" => query_ctx.timezone().to_string(), + _ => return UnsupportedVariableSnafu { name: variable }.fail(), + }; + let schema = Arc::new(Schema::new(vec![ColumnSchema::new( + variable, + ConcreteDataType::string_datatype(), + false, + )])); + let records = RecordBatches::try_from_columns( + schema, + vec![Arc::new(StringVector::from(vec![value])) as _], + ) + .context(error::CreateRecordBatchSnafu)?; + Ok(Output::RecordBatches(records)) +} + pub fn show_create_table( table: TableRef, partitions: Option, @@ -524,13 +545,18 @@ mod test { use common_query::Output; use common_recordbatch::{RecordBatch, RecordBatches}; use common_time::timestamp::TimeUnit; + use common_time::Timezone; use datatypes::prelude::ConcreteDataType; use datatypes::schema::{ColumnDefaultConstraint, ColumnSchema, Schema, SchemaRef}; use datatypes::vectors::{StringVector, TimestampMillisecondVector, UInt32Vector, VectorRef}; + use session::context::QueryContextBuilder; use snafu::ResultExt; + use sql::ast::{Ident, ObjectName}; + use sql::statements::show::ShowVariables; use table::test_util::MemTable; use table::TableRef; + use super::show_variable; use crate::error; use crate::error::Result; use crate::sql::{ @@ -603,4 +629,44 @@ mod test { let record_batch = RecordBatch::new(table_schema, data).unwrap(); MemTable::table(table_name, record_batch) } + + #[test] + fn test_show_variable() { + assert_eq!( + exec_show_variable("SYSTEM_TIME_ZONE", "Asia/Shanghai").unwrap(), + "UTC" + ); + assert_eq!( + exec_show_variable("SYSTEM_TIMEZONE", "Asia/Shanghai").unwrap(), + "UTC" + ); + assert_eq!( + exec_show_variable("TIME_ZONE", "Asia/Shanghai").unwrap(), + "Asia/Shanghai" + ); + assert_eq!( + exec_show_variable("TIMEZONE", "Asia/Shanghai").unwrap(), + "Asia/Shanghai" + ); + assert!(exec_show_variable("TIME ZONE", "Asia/Shanghai").is_err()); + assert!(exec_show_variable("SYSTEM TIME ZONE", "Asia/Shanghai").is_err()); + } + + fn exec_show_variable(variable: &str, tz: &str) -> Result { + let stmt = ShowVariables { + variable: ObjectName(vec![Ident::new(variable)]), + }; + let ctx = QueryContextBuilder::default() + .timezone(Timezone::from_tz_string(tz).unwrap()) + .build(); + match show_variable(stmt, ctx) { + Ok(Output::RecordBatches(record)) => { + let record = record.take().first().cloned().unwrap(); + let data = record.column(0); + Ok(data.get(0).to_string()) + } + Ok(_) => unreachable!(), + Err(e) => Err(e), + } + } } diff --git a/src/servers/src/export_metrics.rs b/src/servers/src/export_metrics.rs index 956d5ddd2f..12ca611aec 100644 --- a/src/servers/src/export_metrics.rs +++ b/src/servers/src/export_metrics.rs @@ -306,7 +306,9 @@ mod test { assert!(ExportMetricsTask::try_new( &ExportMetricsOption { enable: true, - self_import: Some(SelfImportOption { db: "".to_string() }), + self_import: Some(SelfImportOption { + db: String::default() + }), remote_write: None, ..Default::default() }, @@ -319,7 +321,7 @@ mod test { enable: true, self_import: None, remote_write: Some(RemoteWriteOption { - url: "".to_string(), + url: String::default(), ..Default::default() }), ..Default::default() diff --git a/src/servers/src/grpc/greptime_handler.rs b/src/servers/src/grpc/greptime_handler.rs index 36c5189964..01b1dc59a7 100644 --- a/src/servers/src/grpc/greptime_handler.rs +++ b/src/servers/src/grpc/greptime_handler.rs @@ -28,6 +28,7 @@ use common_error::status_code::StatusCode; use common_query::Output; use common_runtime::Runtime; use common_telemetry::logging; +use common_time::timezone::parse_timezone; use session::context::{QueryContextBuilder, QueryContextRef}; use snafu::{OptionExt, ResultExt}; @@ -161,10 +162,11 @@ pub(crate) fn create_query_context(header: Option<&RequestHeader>) -> QueryConte } }) .unwrap_or((DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME)); - + let timezone = parse_timezone(header.map(|h| h.timezone.as_str())); QueryContextBuilder::default() .current_catalog(catalog.to_string()) .current_schema(schema.to_string()) + .timezone(timezone) .build() } diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index 1507834d81..277e105b60 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -565,10 +565,7 @@ impl HttpServer { let config_router = self .route_config(GreptimeOptionsConfigState { - greptime_config_options: self - .greptime_config_options - .clone() - .unwrap_or("".to_string()), + greptime_config_options: self.greptime_config_options.clone().unwrap_or_default(), }) .finish_api(&mut api); diff --git a/src/servers/src/http/authorize.rs b/src/servers/src/http/authorize.rs index a704b54feb..6b38958957 100644 --- a/src/servers/src/http/authorize.rs +++ b/src/servers/src/http/authorize.rs @@ -23,12 +23,14 @@ use common_catalog::consts::DEFAULT_SCHEMA_NAME; use common_catalog::parse_catalog_and_schema_from_db_string; use common_error::ext::ErrorExt; use common_telemetry::warn; +use common_time::timezone::parse_timezone; +use common_time::Timezone; use headers::Header; use secrecy::SecretString; -use session::context::QueryContext; +use session::context::QueryContextBuilder; use snafu::{ensure, OptionExt, ResultExt}; -use super::header::GreptimeDbName; +use super::header::{GreptimeDbName, GREPTIME_TIMEZONE_HEADER_NAME}; use super::{ResponseFormat, PUBLIC_APIS}; use crate::error::{ self, InvalidAuthorizationHeaderSnafu, InvalidParameterSnafu, InvisibleASCIISnafu, @@ -56,7 +58,12 @@ pub async fn inner_auth( ) -> std::result::Result, Response> { // 1. prepare let (catalog, schema) = extract_catalog_and_schema(&req); - let query_ctx = QueryContext::with(catalog, schema); + let timezone = extract_timezone(&req); + let query_ctx = QueryContextBuilder::default() + .current_catalog(catalog.to_string()) + .current_schema(schema.to_string()) + .timezone(timezone) + .build(); let need_auth = need_auth(&req); let is_influxdb = req.uri().path().contains("influxdb"); @@ -142,6 +149,17 @@ pub fn extract_catalog_and_schema(request: &Request) -> (&str, &str) { parse_catalog_and_schema_from_db_string(dbname) } +fn extract_timezone(request: &Request) -> Timezone { + // parse timezone from header + let timezone = request + .headers() + .get(&GREPTIME_TIMEZONE_HEADER_NAME) + // eat this invalid ascii error and give user the final IllegalParam error + .and_then(|header| header.to_str().ok()) + .unwrap_or(""); + parse_timezone(Some(timezone)) +} + fn get_influxdb_credentials(request: &Request) -> Result> { // compat with influxdb v2 and v1 if let Some(header) = request.headers().get(http::header::AUTHORIZATION) { diff --git a/src/servers/src/http/csv_result.rs b/src/servers/src/http/csv_result.rs index 7c26d055da..28b4c3b44f 100644 --- a/src/servers/src/http/csv_result.rs +++ b/src/servers/src/http/csv_result.rs @@ -78,7 +78,7 @@ impl IntoResponse for CsvResponse { let execution_time = self.execution_time_ms; let payload = match self.output.pop() { - None => "".to_string(), + None => String::default(), Some(GreptimeQueryOutput::AffectedRows(n)) => { format!("{n}\n") } diff --git a/src/servers/src/http/header.rs b/src/servers/src/http/header.rs index 40bcfabf4b..f8907f1e75 100644 --- a/src/servers/src/http/header.rs +++ b/src/servers/src/http/header.rs @@ -18,6 +18,8 @@ pub const GREPTIME_DB_HEADER_FORMAT: &str = "x-greptime-format"; pub const GREPTIME_DB_HEADER_EXECUTION_TIME: &str = "x-greptime-execution-time"; pub static GREPTIME_DB_HEADER_NAME: HeaderName = HeaderName::from_static("x-greptime-name"); +pub static GREPTIME_TIMEZONE_HEADER_NAME: HeaderName = + HeaderName::from_static("x-greptime-timezone"); pub struct GreptimeDbName(Option); diff --git a/src/servers/src/http/influxdb_result_v1.rs b/src/servers/src/http/influxdb_result_v1.rs index cfcc8f1f94..aac76d253d 100644 --- a/src/servers/src/http/influxdb_result_v1.rs +++ b/src/servers/src/http/influxdb_result_v1.rs @@ -50,7 +50,7 @@ pub struct InfluxdbRecordsOutput { impl InfluxdbRecordsOutput { pub fn new(columns: Vec, values: Vec>) -> Self { Self { - name: "".to_string(), + name: String::default(), columns, values, } diff --git a/src/servers/src/mysql/federated.rs b/src/servers/src/mysql/federated.rs index 5963d7934b..2efc45128a 100644 --- a/src/servers/src/mysql/federated.rs +++ b/src/servers/src/mysql/federated.rs @@ -21,7 +21,6 @@ use std::sync::Arc; use common_query::Output; use common_recordbatch::RecordBatches; use common_time::timezone::system_timezone_name; -use common_time::Timezone; use datatypes::prelude::ConcreteDataType; use datatypes::schema::{ColumnSchema, Schema}; use datatypes::vectors::StringVector; @@ -52,10 +51,6 @@ static SELECT_TIME_DIFF_FUNC_PATTERN: Lazy = static SHOW_SQL_MODE_PATTERN: Lazy = Lazy::new(|| Regex::new("(?i)^(SHOW VARIABLES LIKE 'sql_mode'(.*))").unwrap()); -// Timezone settings -static SET_TIME_ZONE_PATTERN: Lazy = - Lazy::new(|| Regex::new(r"(?i)^SET TIME_ZONE\s*=\s*'(\S+)'").unwrap()); - static OTHER_NOT_SUPPORTED_STMT: Lazy = Lazy::new(|| { RegexSet::new([ // Txn. @@ -260,20 +255,6 @@ fn check_show_variables(query: &str) -> Option { recordbatches.map(Output::RecordBatches) } -// TODO(sunng87): extract this to use sqlparser for more variables -fn check_set_variables(query: &str, session: SessionRef) -> Option { - if let Some(captures) = SET_TIME_ZONE_PATTERN.captures(query) { - // get the capture - let tz = captures.get(1).unwrap(); - if let Ok(timezone) = Timezone::from_tz_string(tz.as_str()) { - session.set_timezone(timezone); - return Some(Output::AffectedRows(0)); - } - } - - None -} - // Check for SET or others query, this is the final check of the federated query. fn check_others(query: &str, query_ctx: QueryContextRef) -> Option { if OTHER_NOT_SUPPORTED_STMT.is_match(query.as_bytes()) { @@ -299,7 +280,7 @@ fn check_others(query: &str, query_ctx: QueryContextRef) -> Option { pub(crate) fn check( query: &str, query_ctx: QueryContextRef, - session: SessionRef, + _session: SessionRef, ) -> Option { // INSERT don't need MySQL federated check. We assume the query doesn't contain // federated or driver setup command if it starts with a 'INSERT' statement. @@ -311,7 +292,6 @@ pub(crate) fn check( check_select_variable(query, query_ctx.clone()) // Then to check "show variables like ...". .or_else(|| check_show_variables(query)) - .or_else(|| check_set_variables(query, session.clone())) // Last check .or_else(|| check_others(query, query_ctx)) } @@ -411,45 +391,4 @@ mod test { +----------------------------------+"; test(query, expected); } - - #[test] - fn test_set_timezone() { - // test default is UTC when no config in greptimedb - { - let session = Arc::new(Session::new(None, Channel::Mysql)); - let query_context = session.new_query_context(); - assert_eq!("UTC", query_context.timezone().to_string()); - } - set_default_timezone(Some("Asia/Shanghai")).unwrap(); - let session = Arc::new(Session::new(None, Channel::Mysql)); - let query_context = session.new_query_context(); - assert_eq!("Asia/Shanghai", query_context.timezone().to_string()); - let output = check( - "set time_zone = 'UTC'", - QueryContext::arc(), - session.clone(), - ); - match output.unwrap() { - Output::AffectedRows(rows) => { - assert_eq!(rows, 0) - } - _ => unreachable!(), - } - let query_context = session.new_query_context(); - assert_eq!("UTC", query_context.timezone().to_string()); - - let output = check("select @@time_zone", query_context.clone(), session.clone()); - match output.unwrap() { - Output::RecordBatches(r) => { - let expected = "\ -+-------------+ -| @@time_zone | -+-------------+ -| UTC | -+-------------+"; - assert_eq!(r.pretty_print().unwrap(), expected); - } - _ => unreachable!(), - } - } } diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index 7dace0c415..ee1c953015 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -99,7 +99,9 @@ impl MysqlInstanceShim { { vec![Ok(output)] } else { - self.query_handler.do_query(query, query_ctx).await + let output = self.query_handler.do_query(query, query_ctx.clone()).await; + query_ctx.update_session(&self.session); + output } } diff --git a/src/servers/src/mysql/writer.rs b/src/servers/src/mysql/writer.rs index e311324568..726daf061f 100644 --- a/src/servers/src/mysql/writer.rs +++ b/src/servers/src/mysql/writer.rs @@ -285,7 +285,7 @@ pub(crate) fn create_mysql_column( // TODO(LFC): Currently "table" and "colflags" are not relevant in MySQL server // implementation, will revisit them again in the future. - table: "".to_string(), + table: String::default(), colflags, }) } diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index 5356057238..fee9e71bad 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -53,7 +53,8 @@ impl SimpleQueryHandler for PostgresServerHandler { let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER .with_label_values(&[crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY, db.as_str()]) .start_timer(); - let outputs = self.query_handler.do_query(query, query_ctx).await; + let outputs = self.query_handler.do_query(query, query_ctx.clone()).await; + query_ctx.update_session(&self.session); let mut results = Vec::with_capacity(outputs.len()); diff --git a/src/session/src/context.rs b/src/session/src/context.rs index 256217d785..7446624ae5 100644 --- a/src/session/src/context.rs +++ b/src/session/src/context.rs @@ -26,6 +26,8 @@ use common_time::Timezone; use derive_builder::Builder; use sql::dialect::{Dialect, GreptimeDbDialect, MySqlDialect, PostgreSqlDialect}; +use crate::SessionRef; + pub type QueryContextRef = Arc; pub type ConnInfoRef = Arc; @@ -36,10 +38,18 @@ pub struct QueryContext { current_catalog: String, current_schema: String, current_user: ArcSwap>, - timezone: Timezone, + #[builder(setter(custom))] + timezone: ArcSwap, sql_dialect: Box, } +impl QueryContextBuilder { + pub fn timezone(mut self, tz: Timezone) -> Self { + self.timezone = Some(ArcSwap::new(Arc::new(tz))); + self + } +} + impl Display for QueryContext { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!( @@ -58,7 +68,8 @@ impl From<&RegionRequestHeader> for QueryContext { current_catalog: catalog.to_string(), current_schema: schema.to_string(), current_user: Default::default(), - timezone: get_timezone(None), + // for request send to datanode, all timestamp have converted to UTC, so timezone is not important + timezone: ArcSwap::new(Arc::new(get_timezone(None))), sql_dialect: Box::new(GreptimeDbDialect {}), } } @@ -94,17 +105,14 @@ impl QueryContext { .build() } - #[inline] pub fn current_schema(&self) -> &str { &self.current_schema } - #[inline] pub fn current_catalog(&self) -> &str { &self.current_catalog } - #[inline] pub fn sql_dialect(&self) -> &(dyn Dialect + Send + Sync) { &*self.sql_dialect } @@ -115,20 +123,30 @@ impl QueryContext { build_db_string(catalog, schema) } - #[inline] pub fn timezone(&self) -> Timezone { - self.timezone.clone() + self.timezone.load().as_ref().clone() } - #[inline] pub fn current_user(&self) -> Option { self.current_user.load().as_ref().clone() } - #[inline] pub fn set_current_user(&self, user: Option) { let _ = self.current_user.swap(Arc::new(user)); } + + pub fn set_timezone(&self, timezone: Timezone) { + let _ = self.timezone.swap(Arc::new(timezone)); + } + + /// SQL like `set variable` may change timezone or other info in `QueryContext`. + /// We need persist these change in `Session`. + pub fn update_session(&self, session: &SessionRef) { + let tz = self.timezone(); + if session.timezone() != tz { + session.set_timezone(tz) + } + } } impl QueryContextBuilder { @@ -143,7 +161,9 @@ impl QueryContextBuilder { current_user: self .current_user .unwrap_or_else(|| ArcSwap::new(Arc::new(None))), - timezone: self.timezone.unwrap_or(get_timezone(None)), + timezone: self + .timezone + .unwrap_or(ArcSwap::new(Arc::new(get_timezone(None)))), sql_dialect: self .sql_dialect .unwrap_or_else(|| Box::new(GreptimeDbDialect {})), diff --git a/src/session/src/lib.rs b/src/session/src/lib.rs index 49290826a0..35035cda27 100644 --- a/src/session/src/lib.rs +++ b/src/session/src/lib.rs @@ -59,7 +59,7 @@ impl Session { .current_catalog(self.catalog.load().to_string()) .current_schema(self.schema.load().to_string()) .sql_dialect(self.conn_info.channel.dialect()) - .timezone((**self.timezone.load()).clone()) + .timezone(self.timezone()) .build() } diff --git a/src/sql/src/parser.rs b/src/sql/src/parser.rs index 2720ee5a95..80a142581b 100644 --- a/src/sql/src/parser.rs +++ b/src/sql/src/parser.rs @@ -117,6 +117,8 @@ impl<'a> ParserContext<'a> { Keyword::TRUNCATE => self.parse_truncate(), + Keyword::SET => self.parse_set_variables(), + Keyword::NoKeyword if w.value.to_uppercase() == tql_parser::TQL && w.quote_style.is_none() => { diff --git a/src/sql/src/parsers.rs b/src/sql/src/parsers.rs index 1c62242b35..ca249cf640 100644 --- a/src/sql/src/parsers.rs +++ b/src/sql/src/parsers.rs @@ -21,6 +21,7 @@ pub(crate) mod drop_parser; pub(crate) mod explain_parser; pub(crate) mod insert_parser; pub(crate) mod query_parser; +pub(crate) mod set_var_parser; pub(crate) mod show_parser; pub(crate) mod tql_parser; pub(crate) mod truncate_parser; diff --git a/src/sql/src/parsers/set_var_parser.rs b/src/sql/src/parsers/set_var_parser.rs new file mode 100644 index 0000000000..bfa6a3dcaa --- /dev/null +++ b/src/sql/src/parsers/set_var_parser.rs @@ -0,0 +1,78 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use snafu::ResultExt; +use sqlparser::ast::Statement as SpStatement; + +use crate::error::{self, Result}; +use crate::parser::ParserContext; +use crate::statements::set_variables::SetVariables; +use crate::statements::statement::Statement; + +/// SET variables statement parser implementation +impl<'a> ParserContext<'a> { + pub(crate) fn parse_set_variables(&mut self) -> Result { + let _ = self.parser.next_token(); + let spstatement = self.parser.parse_set().context(error::SyntaxSnafu)?; + match spstatement { + SpStatement::SetVariable { + variable, + value, + local, + hivevar, + } if !local && !hivevar => { + Ok(Statement::SetVariables(SetVariables { variable, value })) + } + unexp => error::UnsupportedSnafu { + sql: self.sql.to_string(), + keyword: unexp.to_string(), + } + .fail(), + } + } +} + +#[cfg(test)] +mod tests { + use sqlparser::ast::{Expr, Ident, ObjectName, Value}; + + use super::*; + use crate::dialect::GreptimeDbDialect; + + #[test] + pub fn test_set_timezone() { + // mysql style + let sql = "SET time_zone = 'UTC'"; + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); + let mut stmts = result.unwrap(); + assert_eq!( + stmts.pop().unwrap(), + Statement::SetVariables(SetVariables { + variable: ObjectName(vec![Ident::new("time_zone")]), + value: vec![Expr::Value(Value::SingleQuotedString("UTC".to_string()))] + }) + ); + // postgresql style + let sql = "SET TIMEZONE TO 'UTC'"; + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); + let mut stmts = result.unwrap(); + assert_eq!( + stmts.pop().unwrap(), + Statement::SetVariables(SetVariables { + variable: ObjectName(vec![Ident::new("TIMEZONE")]), + value: vec![Expr::Value(Value::SingleQuotedString("UTC".to_string()))], + }) + ); + } +} diff --git a/src/sql/src/parsers/show_parser.rs b/src/sql/src/parsers/show_parser.rs index 1278ea55bf..2f96889c7d 100644 --- a/src/sql/src/parsers/show_parser.rs +++ b/src/sql/src/parsers/show_parser.rs @@ -18,7 +18,9 @@ use sqlparser::tokenizer::Token; use crate::error::{self, InvalidDatabaseNameSnafu, InvalidTableNameSnafu, Result}; use crate::parser::ParserContext; -use crate::statements::show::{ShowCreateTable, ShowDatabases, ShowKind, ShowTables}; +use crate::statements::show::{ + ShowCreateTable, ShowDatabases, ShowKind, ShowTables, ShowVariables, +}; use crate::statements::statement::Statement; /// SHOW statement parser implementation @@ -43,6 +45,16 @@ impl<'a> ParserContext<'a> { } else { self.unsupported(self.peek_token_as_string()) } + } else if self.consume_token("VARIABLES") { + let variable = + self.parser + .parse_object_name() + .with_context(|_| error::UnexpectedSnafu { + sql: self.sql, + expected: "a variable name", + actual: self.peek_token_as_string(), + })?; + Ok(Statement::ShowVariables(ShowVariables { variable })) } else { self.unsupported(self.peek_token_as_string()) } @@ -178,6 +190,8 @@ impl<'a> ParserContext<'a> { mod tests { use std::assert_matches::assert_matches; + use sqlparser::ast::{Ident, ObjectName}; + use super::*; use crate::dialect::GreptimeDbDialect; use crate::statements::show::ShowDatabases; @@ -387,4 +401,18 @@ mod tests { }) ); } + + #[test] + pub fn test_show_variables() { + let sql = "SHOW VARIABLES system_time_zone"; + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); + let stmts = result.unwrap(); + assert_eq!(1, stmts.len()); + assert_eq!( + stmts[0], + Statement::ShowVariables(ShowVariables { + variable: ObjectName(vec![Ident::new("system_time_zone")]), + }) + ); + } } diff --git a/src/sql/src/statements.rs b/src/sql/src/statements.rs index ebc1c4d9f4..c8da838234 100644 --- a/src/sql/src/statements.rs +++ b/src/sql/src/statements.rs @@ -22,6 +22,7 @@ pub mod explain; pub mod insert; mod option_map; pub mod query; +pub mod set_variables; pub mod show; pub mod statement; pub mod tql; @@ -505,7 +506,7 @@ pub fn sql_location_to_grpc_add_column_location( match location { Some(AddColumnLocation::First) => Some(Location { location_type: LocationType::First.into(), - after_column_name: "".to_string(), + after_column_name: String::default(), }), Some(AddColumnLocation::After { column_name }) => Some(Location { location_type: LocationType::After.into(), diff --git a/src/sql/src/statements/create.rs b/src/sql/src/statements/create.rs index e7fabd5a2d..bfa12812cc 100644 --- a/src/sql/src/statements/create.rs +++ b/src/sql/src/statements/create.rs @@ -96,7 +96,7 @@ impl CreateTable { if let Some(partitions) = &self.partitions { format!("{}\n", partitions) } else { - "".to_string() + String::default() } } @@ -112,7 +112,7 @@ impl CreateTable { #[inline] fn format_options(&self) -> String { if self.options.is_empty() { - "".to_string() + String::default() } else { let options: Vec<&SqlOption> = self.options.iter().sorted().collect(); let options = format_list_indent!(options); diff --git a/src/sql/src/statements/set_variables.rs b/src/sql/src/statements/set_variables.rs new file mode 100644 index 0000000000..71d6849833 --- /dev/null +++ b/src/sql/src/statements/set_variables.rs @@ -0,0 +1,23 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use sqlparser::ast::{Expr, ObjectName}; +use sqlparser_derive::{Visit, VisitMut}; + +/// SET variables statement. +#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut)] +pub struct SetVariables { + pub variable: ObjectName, + pub value: Vec, +} diff --git a/src/sql/src/statements/show.rs b/src/sql/src/statements/show.rs index b0d7121428..f3d066675c 100644 --- a/src/sql/src/statements/show.rs +++ b/src/sql/src/statements/show.rs @@ -63,6 +63,12 @@ pub struct ShowCreateTable { pub table_name: ObjectName, } +/// SQL structure for `SHOW VARIABLES xxx`. +#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut)] +pub struct ShowVariables { + pub variable: ObjectName, +} + #[cfg(test)] mod tests { use std::assert_matches::assert_matches; diff --git a/src/sql/src/statements/statement.rs b/src/sql/src/statements/statement.rs index 462535be4c..028f437c3d 100644 --- a/src/sql/src/statements/statement.rs +++ b/src/sql/src/statements/statement.rs @@ -16,6 +16,7 @@ use datafusion_sql::parser::Statement as DfStatement; use sqlparser::ast::Statement as SpStatement; use sqlparser_derive::{Visit, VisitMut}; +use super::show::ShowVariables; use crate::error::{ConvertToDfStatementSnafu, Error}; use crate::statements::alter::AlterTable; use crate::statements::create::{CreateDatabase, CreateExternalTable, CreateTable}; @@ -25,6 +26,7 @@ use crate::statements::drop::DropTable; use crate::statements::explain::Explain; use crate::statements::insert::Insert; use crate::statements::query::Query; +use crate::statements::set_variables::SetVariables; use crate::statements::show::{ShowCreateTable, ShowDatabases, ShowTables}; use crate::statements::tql::Tql; use crate::statements::truncate::TruncateTable; @@ -64,6 +66,10 @@ pub enum Statement { Tql(Tql), // TRUNCATE TABLE TruncateTable(TruncateTable), + // SET VARIABLES + SetVariables(SetVariables), + // SHOW VARIABLES + ShowVariables(ShowVariables), } /// Comment hints from SQL. diff --git a/src/store-api/src/region_request.rs b/src/store-api/src/region_request.rs index 7c0456f447..f3bfc4def6 100644 --- a/src/store-api/src/region_request.rs +++ b/src/store-api/src/region_request.rs @@ -501,14 +501,14 @@ mod tests { fn test_from_proto_location() { let proto_location = v1::AddColumnLocation { location_type: LocationType::First as i32, - after_column_name: "".to_string(), + after_column_name: String::default(), }; let location = AddColumnLocation::try_from(proto_location).unwrap(); assert_eq!(location, AddColumnLocation::First); let proto_location = v1::AddColumnLocation { location_type: 10, - after_column_name: "".to_string(), + after_column_name: String::default(), }; AddColumnLocation::try_from(proto_location).unwrap_err(); @@ -562,7 +562,7 @@ mod tests { }), location: Some(v1::AddColumnLocation { location_type: LocationType::First as i32, - after_column_name: "".to_string(), + after_column_name: String::default(), }), }], })), diff --git a/tests-integration/tests/grpc.rs b/tests-integration/tests/grpc.rs index 7f099adb25..12625182ed 100644 --- a/tests-integration/tests/grpc.rs +++ b/tests-integration/tests/grpc.rs @@ -71,6 +71,7 @@ macro_rules! grpc_tests { test_grpc_auth, test_health_check, test_prom_gateway_query, + test_grpc_timezone, ); )* }; @@ -627,3 +628,64 @@ pub async fn test_prom_gateway_query(store_type: StorageType) { let _ = fe_grpc_server.shutdown().await; guard.remove_all().await; } + +pub async fn test_grpc_timezone(store_type: StorageType) { + let config = GrpcServerConfig { + max_recv_message_size: 1024, + max_send_message_size: 1024, + }; + let (addr, mut guard, fe_grpc_server) = + setup_grpc_server_with(store_type, "auto_create_table", None, Some(config)).await; + + let grpc_client = Client::with_urls(vec![addr]); + let mut db = Database::new_with_dbname( + format!("{}-{}", DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME), + grpc_client, + ); + db.set_timezone("Asia/Shanghai"); + let sys1 = to_batch(db.sql("show variables system_time_zone;").await.unwrap()).await; + let user1 = to_batch(db.sql("show variables time_zone;").await.unwrap()).await; + db.set_timezone(""); + let sys2 = to_batch(db.sql("show variables system_time_zone;").await.unwrap()).await; + let user2 = to_batch(db.sql("show variables time_zone;").await.unwrap()).await; + assert_eq!(sys1, sys2); + assert_eq!( + sys2, + "\ ++------------------+ +| SYSTEM_TIME_ZONE | ++------------------+ +| UTC | ++------------------+" + ); + assert_eq!( + user1, + "\ ++---------------+ +| TIME_ZONE | ++---------------+ +| Asia/Shanghai | ++---------------+" + ); + assert_eq!( + user2, + "\ ++-----------+ +| TIME_ZONE | ++-----------+ +| UTC | ++-----------+" + ); + let _ = fe_grpc_server.shutdown().await; + guard.remove_all().await; +} + +async fn to_batch(output: Output) -> String { + match output { + Output::RecordBatches(batch) => batch, + Output::Stream(stream) => RecordBatches::try_collect(stream).await.unwrap(), + Output::AffectedRows(_) => unreachable!(), + } + .pretty_print() + .unwrap() +} diff --git a/tests-integration/tests/http.rs b/tests-integration/tests/http.rs index 351fa7cd0b..620d1bf4c5 100644 --- a/tests-integration/tests/http.rs +++ b/tests-integration/tests/http.rs @@ -15,13 +15,14 @@ use std::collections::BTreeMap; use auth::user_provider_from_option; -use axum::http::StatusCode; +use axum::http::{HeaderName, StatusCode}; use axum_test_helper::TestClient; use common_error::status_code::StatusCode as ErrorCode; use serde_json::json; use servers::http::error_result::ErrorResponse; use servers::http::greptime_result_v1::GreptimedbV1Response; use servers::http::handler::HealthResponse; +use servers::http::header::GREPTIME_TIMEZONE_HEADER_NAME; use servers::http::influxdb_result_v1::{InfluxdbOutput, InfluxdbV1Response}; use servers::http::prometheus::{PrometheusJsonResponse, PrometheusResponse}; use servers::http::GreptimeQueryOutput; @@ -318,6 +319,43 @@ pub async fn test_sql_api(store_type: StorageType) { let body = serde_json::from_str::(&res.text().await).unwrap(); assert_eq!(body.code(), ErrorCode::DatabaseNotFound as u32); + // test timezone header + let res = client + .get("/v1/sql?&sql=show variables system_time_zone") + .header( + TryInto::::try_into(GREPTIME_TIMEZONE_HEADER_NAME.to_string()).unwrap(), + "Asia/Shanghai", + ) + .send() + .await + .text() + .await; + assert!(res.contains("SYSTEM_TIME_ZONE") && res.contains("UTC")); + let res = client + .get("/v1/sql?&sql=show variables time_zone") + .header( + TryInto::::try_into(GREPTIME_TIMEZONE_HEADER_NAME.to_string()).unwrap(), + "Asia/Shanghai", + ) + .send() + .await + .text() + .await; + assert!(res.contains("TIME_ZONE") && res.contains("Asia/Shanghai")); + let res = client + .get("/v1/sql?&sql=show variables system_time_zone") + .send() + .await + .text() + .await; + assert!(res.contains("SYSTEM_TIME_ZONE") && res.contains("UTC")); + let res = client + .get("/v1/sql?&sql=show variables time_zone") + .send() + .await + .text() + .await; + assert!(res.contains("TIME_ZONE") && res.contains("UTC")); guard.remove_all().await; } diff --git a/tests-integration/tests/sql.rs b/tests-integration/tests/sql.rs index 93566296fa..4d4a49d0d2 100644 --- a/tests-integration/tests/sql.rs +++ b/tests-integration/tests/sql.rs @@ -21,7 +21,7 @@ use tests_integration::test_util::{ setup_mysql_server, setup_mysql_server_with_user_provider, setup_pg_server, setup_pg_server_with_user_provider, StorageType, }; -use tokio_postgres::NoTls; +use tokio_postgres::{NoTls, SimpleQueryMessage}; #[macro_export] macro_rules! sql_test { @@ -59,6 +59,7 @@ macro_rules! sql_tests { test_mysql_async_timestamp, test_postgres_auth, test_postgres_crud, + test_postgres_timezone, test_postgres_parameter_inference, test_mysql_prepare_stmt_insert_timestamp, ); @@ -218,9 +219,19 @@ pub async fn test_mysql_timezone(store_type: StorageType) { .await .unwrap(); + let _ = conn + .execute("SET time_zone = 'Asia/Shanghai'") + .await + .unwrap(); + let timezone = conn.fetch_all("SELECT @@time_zone").await.unwrap(); + assert_eq!(timezone[0].get::(0), "Asia/Shanghai"); + let timezone = conn.fetch_all("SELECT @@system_time_zone").await.unwrap(); + assert_eq!(timezone[0].get::(0), "UTC"); let _ = conn.execute("SET time_zone = 'UTC'").await.unwrap(); let timezone = conn.fetch_all("SELECT @@time_zone").await.unwrap(); assert_eq!(timezone[0].get::(0), "UTC"); + let timezone = conn.fetch_all("SELECT @@system_time_zone").await.unwrap(); + assert_eq!(timezone[0].get::(0), "UTC"); // test data let _ = conn @@ -388,6 +399,61 @@ pub async fn test_postgres_crud(store_type: StorageType) { guard.remove_all().await; } +pub async fn test_postgres_timezone(store_type: StorageType) { + let (addr, mut guard, fe_pg_server) = setup_pg_server(store_type, "sql_inference").await; + + let (client, connection) = tokio_postgres::connect(&format!("postgres://{addr}/public"), NoTls) + .await + .unwrap(); + + tokio::spawn(async move { + connection.await.unwrap(); + }); + + let get_row = |mess: Vec| -> String { + match &mess[0] { + SimpleQueryMessage::Row(row) => row.get(0).unwrap().to_string(), + _ => unreachable!(), + } + }; + + let _ = client.simple_query("SET time_zone = 'UTC'").await.unwrap(); + let timezone = get_row( + client + .simple_query("SHOW VARIABLES time_zone") + .await + .unwrap(), + ); + assert_eq!(timezone, "UTC"); + let timezone = get_row( + client + .simple_query("SHOW VARIABLES system_time_zone") + .await + .unwrap(), + ); + assert_eq!(timezone, "UTC"); + let _ = client + .simple_query("SET time_zone = 'Asia/Shanghai'") + .await + .unwrap(); + let timezone = get_row( + client + .simple_query("SHOW VARIABLES time_zone") + .await + .unwrap(), + ); + assert_eq!(timezone, "Asia/Shanghai"); + let timezone = get_row( + client + .simple_query("SHOW VARIABLES system_time_zone") + .await + .unwrap(), + ); + assert_eq!(timezone, "UTC"); + let _ = fe_pg_server.shutdown().await; + guard.remove_all().await; +} + pub async fn test_postgres_parameter_inference(store_type: StorageType) { let (addr, mut guard, fe_pg_server) = setup_pg_server(store_type, "sql_inference").await;