diff --git a/Cargo.lock b/Cargo.lock index 32f82ce086..0c91587cc0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1874,6 +1874,7 @@ name = "common-time" version = "0.2.0" dependencies = [ "chrono", + "chrono-tz 0.8.2", "common-error", "rand", "serde", @@ -7995,6 +7996,7 @@ dependencies = [ "arc-swap", "common-catalog", "common-telemetry", + "common-time", ] [[package]] diff --git a/src/common/time/Cargo.toml b/src/common/time/Cargo.toml index 49d778a858..c8c93bec1d 100644 --- a/src/common/time/Cargo.toml +++ b/src/common/time/Cargo.toml @@ -6,6 +6,7 @@ license.workspace = true [dependencies] chrono.workspace = true +chrono-tz = "0.8" common-error = { path = "../error" } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" diff --git a/src/common/time/src/error.rs b/src/common/time/src/error.rs index c39b9dd61d..fd0a2269f9 100644 --- a/src/common/time/src/error.rs +++ b/src/common/time/src/error.rs @@ -13,7 +13,7 @@ // limitations under the License. use std::any::Any; -use std::num::TryFromIntError; +use std::num::{ParseIntError, TryFromIntError}; use chrono::ParseError; use common_error::ext::ErrorExt; @@ -40,14 +40,33 @@ pub enum Error { #[snafu(display("Timestamp arithmetic overflow, msg: {}", msg))] ArithmeticOverflow { msg: String, location: Location }, + + #[snafu(display("Invalid time zone offset: {hours}:{minutes}"))] + InvalidTimeZoneOffset { + hours: i32, + minutes: u32, + location: Location, + }, + + #[snafu(display("Invalid offset string {raw}: {source}"))] + ParseOffsetStr { + raw: String, + source: ParseIntError, + location: Location, + }, + + #[snafu(display("Invalid time zone string {raw}"))] + ParseTimeZoneName { raw: String, location: Location }, } impl ErrorExt for Error { fn status_code(&self) -> StatusCode { match self { - Error::ParseDateStr { .. } | Error::ParseTimestamp { .. } => { - StatusCode::InvalidArguments - } + Error::ParseDateStr { .. } + | Error::ParseTimestamp { .. } + | Error::InvalidTimeZoneOffset { .. } + | Error::ParseOffsetStr { .. } + | Error::ParseTimeZoneName { .. } => StatusCode::InvalidArguments, Error::TimestampOverflow { .. } => StatusCode::Internal, Error::InvalidDateStr { .. } | Error::ArithmeticOverflow { .. } => { StatusCode::InvalidArguments @@ -64,7 +83,10 @@ impl ErrorExt for Error { Error::ParseTimestamp { location, .. } | Error::TimestampOverflow { location, .. } | Error::ArithmeticOverflow { location, .. } => Some(*location), - Error::ParseDateStr { .. } => None, + Error::ParseDateStr { .. } + | Error::InvalidTimeZoneOffset { .. } + | Error::ParseOffsetStr { .. } + | Error::ParseTimeZoneName { .. } => None, Error::InvalidDateStr { location, .. } => Some(*location), } } diff --git a/src/common/time/src/lib.rs b/src/common/time/src/lib.rs index fdc9033bed..76558e7610 100644 --- a/src/common/time/src/lib.rs +++ b/src/common/time/src/lib.rs @@ -18,6 +18,7 @@ pub mod error; pub mod range; pub mod timestamp; pub mod timestamp_millis; +pub mod timezone; pub mod util; pub use date::Date; @@ -25,3 +26,4 @@ pub use datetime::DateTime; pub use range::RangeMillis; pub use timestamp::Timestamp; pub use timestamp_millis::TimestampMillis; +pub use timezone::TimeZone; diff --git a/src/common/time/src/timestamp.rs b/src/common/time/src/timestamp.rs index 898da08790..d1c2aef802 100644 --- a/src/common/time/src/timestamp.rs +++ b/src/common/time/src/timestamp.rs @@ -20,12 +20,13 @@ use std::str::FromStr; use std::time::Duration; use chrono::offset::Local; -use chrono::{DateTime, LocalResult, NaiveDateTime, TimeZone, Utc}; +use chrono::{DateTime, LocalResult, NaiveDateTime, TimeZone as ChronoTimeZone, Utc}; use serde::{Deserialize, Serialize}; use snafu::{OptionExt, ResultExt}; use crate::error; use crate::error::{ArithmeticOverflowSnafu, Error, ParseTimestampSnafu, TimestampOverflowSnafu}; +use crate::timezone::TimeZone; use crate::util::div_ceil; #[derive(Debug, Clone, Default, Copy, Serialize, Deserialize)] @@ -171,17 +172,33 @@ impl Timestamp { /// Format timestamp to ISO8601 string. If the timestamp exceeds what chrono timestamp can /// represent, this function simply print the timestamp unit and value in plain string. pub fn to_iso8601_string(&self) -> String { - self.as_formatted_string("%Y-%m-%d %H:%M:%S%.f%z") + self.as_formatted_string("%Y-%m-%d %H:%M:%S%.f%z", None) } pub fn to_local_string(&self) -> String { - self.as_formatted_string("%Y-%m-%d %H:%M:%S%.f") + self.as_formatted_string("%Y-%m-%d %H:%M:%S%.f", None) } - fn as_formatted_string(self, pattern: &str) -> String { + /// Format timestamp for given timezone. + /// When timezone is None, using local time by default. + pub fn to_timezone_aware_string(&self, tz: Option) -> String { + self.as_formatted_string("%Y-%m-%d %H:%M:%S%.f", tz) + } + + fn as_formatted_string(self, pattern: &str, timezone: Option) -> String { if let Some(v) = self.to_chrono_datetime() { - let local = Local {}; - format!("{}", local.from_utc_datetime(&v).format(pattern)) + match timezone { + Some(TimeZone::Offset(offset)) => { + format!("{}", offset.from_utc_datetime(&v).format(pattern)) + } + Some(TimeZone::Named(tz)) => { + format!("{}", tz.from_utc_datetime(&v).format(pattern)) + } + None => { + let local = Local {}; + format!("{}", local.from_utc_datetime(&v).format(pattern)) + } + } } else { format!("[Timestamp{}: {}]", self.unit, self.value) } @@ -934,4 +951,54 @@ mod tests { Timestamp::new_millisecond(58).sub(Timestamp::new_millisecond(100)) ); } + + #[test] + fn test_to_timezone_aware_string() { + std::env::set_var("TZ", "Asia/Shanghai"); + + assert_eq!( + "1970-01-01 08:00:00.001", + Timestamp::new(1, TimeUnit::Millisecond).to_timezone_aware_string(None) + ); + assert_eq!( + "1970-01-01 08:00:00.001", + Timestamp::new(1, TimeUnit::Millisecond) + .to_timezone_aware_string(TimeZone::from_tz_string("SYSTEM").unwrap()) + ); + assert_eq!( + "1970-01-01 08:00:00.001", + Timestamp::new(1, TimeUnit::Millisecond) + .to_timezone_aware_string(TimeZone::from_tz_string("+08:00").unwrap()) + ); + assert_eq!( + "1970-01-01 07:00:00.001", + Timestamp::new(1, TimeUnit::Millisecond) + .to_timezone_aware_string(TimeZone::from_tz_string("+07:00").unwrap()) + ); + assert_eq!( + "1969-12-31 23:00:00.001", + Timestamp::new(1, TimeUnit::Millisecond) + .to_timezone_aware_string(TimeZone::from_tz_string("-01:00").unwrap()) + ); + assert_eq!( + "1970-01-01 08:00:00.001", + Timestamp::new(1, TimeUnit::Millisecond) + .to_timezone_aware_string(TimeZone::from_tz_string("Asia/Shanghai").unwrap()) + ); + assert_eq!( + "1970-01-01 00:00:00.001", + Timestamp::new(1, TimeUnit::Millisecond) + .to_timezone_aware_string(TimeZone::from_tz_string("UTC").unwrap()) + ); + assert_eq!( + "1970-01-01 01:00:00.001", + Timestamp::new(1, TimeUnit::Millisecond) + .to_timezone_aware_string(TimeZone::from_tz_string("Europe/Berlin").unwrap()) + ); + assert_eq!( + "1970-01-01 03:00:00.001", + Timestamp::new(1, TimeUnit::Millisecond) + .to_timezone_aware_string(TimeZone::from_tz_string("Europe/Moscow").unwrap()) + ); + } } diff --git a/src/common/time/src/timezone.rs b/src/common/time/src/timezone.rs new file mode 100644 index 0000000000..6feba8b251 --- /dev/null +++ b/src/common/time/src/timezone.rs @@ -0,0 +1,158 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fmt::Display; +use std::str::FromStr; + +use chrono::{FixedOffset, Local}; +use chrono_tz::Tz; +use snafu::{OptionExt, ResultExt}; + +use crate::error::{ + InvalidTimeZoneOffsetSnafu, ParseOffsetStrSnafu, ParseTimeZoneNameSnafu, Result, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum TimeZone { + Offset(FixedOffset), + Named(Tz), +} + +impl TimeZone { + /// Compute timezone from given offset hours and minutes + /// Return `None` if given offset exceeds scope + pub fn hours_mins_opt(offset_hours: i32, offset_mins: u32) -> Result { + let offset_secs = if offset_hours > 0 { + offset_hours * 3600 + offset_mins as i32 * 60 + } else { + offset_hours * 3600 - offset_mins as i32 * 60 + }; + + FixedOffset::east_opt(offset_secs) + .map(Self::Offset) + .context(InvalidTimeZoneOffsetSnafu { + hours: offset_hours, + minutes: offset_mins, + }) + } + + /// Parse timezone offset string and return None if given offset exceeds + /// scope. + /// + /// String examples are available as described in + /// https://dev.mysql.com/doc/refman/8.0/en/time-zone-support.html + /// + /// - `SYSTEM` + /// - Offset to UTC: `+08:00` , `-11:30` + /// - Named zones: `Asia/Shanghai`, `Europe/Berlin` + pub fn from_tz_string(tz_string: &str) -> Result> { + // Use system timezone + if tz_string.eq_ignore_ascii_case("SYSTEM") { + Ok(None) + } else if let Some((hrs, mins)) = tz_string.split_once(':') { + let hrs = hrs + .parse::() + .context(ParseOffsetStrSnafu { raw: tz_string })?; + let mins = mins + .parse::() + .context(ParseOffsetStrSnafu { raw: tz_string })?; + Self::hours_mins_opt(hrs, mins).map(Some) + } else if let Ok(tz) = Tz::from_str(tz_string) { + Ok(Some(Self::Named(tz))) + } else { + ParseTimeZoneNameSnafu { raw: tz_string }.fail() + } + } +} + +impl Display for TimeZone { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Named(tz) => write!(f, "{}", tz.name()), + Self::Offset(offset) => write!(f, "{}", offset), + } + } +} + +#[inline] +pub fn system_time_zone_name() -> String { + Local::now().offset().to_string() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_from_tz_string() { + assert_eq!(None, TimeZone::from_tz_string("SYSTEM").unwrap()); + + let utc_plus_8 = Some(TimeZone::Offset(FixedOffset::east_opt(3600 * 8).unwrap())); + assert_eq!(utc_plus_8, TimeZone::from_tz_string("+8:00").unwrap()); + assert_eq!(utc_plus_8, TimeZone::from_tz_string("+08:00").unwrap()); + assert_eq!(utc_plus_8, TimeZone::from_tz_string("08:00").unwrap()); + + let utc_minus_8 = Some(TimeZone::Offset(FixedOffset::west_opt(3600 * 8).unwrap())); + assert_eq!(utc_minus_8, TimeZone::from_tz_string("-08:00").unwrap()); + assert_eq!(utc_minus_8, TimeZone::from_tz_string("-8:00").unwrap()); + + let utc_minus_8_5 = Some(TimeZone::Offset( + FixedOffset::west_opt(3600 * 8 + 60 * 30).unwrap(), + )); + assert_eq!(utc_minus_8_5, TimeZone::from_tz_string("-8:30").unwrap()); + + let utc_plus_max = Some(TimeZone::Offset(FixedOffset::east_opt(3600 * 14).unwrap())); + assert_eq!(utc_plus_max, TimeZone::from_tz_string("14:00").unwrap()); + + let utc_minus_max = Some(TimeZone::Offset( + FixedOffset::west_opt(3600 * 13 + 60 * 59).unwrap(), + )); + assert_eq!(utc_minus_max, TimeZone::from_tz_string("-13:59").unwrap()); + + assert_eq!( + Some(TimeZone::Named(Tz::Asia__Shanghai)), + TimeZone::from_tz_string("Asia/Shanghai").unwrap() + ); + assert_eq!( + Some(TimeZone::Named(Tz::UTC)), + TimeZone::from_tz_string("UTC").unwrap() + ); + + assert!(TimeZone::from_tz_string("WORLD_PEACE").is_err()); + assert!(TimeZone::from_tz_string("A0:01").is_err()); + assert!(TimeZone::from_tz_string("20:0A").is_err()); + assert!(TimeZone::from_tz_string(":::::").is_err()); + assert!(TimeZone::from_tz_string("Asia/London").is_err()); + assert!(TimeZone::from_tz_string("Unknown").is_err()); + } + + #[test] + fn test_timezone_to_string() { + assert_eq!("UTC", TimeZone::Named(Tz::UTC).to_string()); + assert_eq!( + "+01:00", + TimeZone::from_tz_string("01:00") + .unwrap() + .unwrap() + .to_string() + ); + assert_eq!( + "Asia/Shanghai", + TimeZone::from_tz_string("Asia/Shanghai") + .unwrap() + .unwrap() + .to_string() + ); + } +} diff --git a/src/servers/src/mysql/federated.rs b/src/servers/src/mysql/federated.rs index 036d4ae98b..b01033c68a 100644 --- a/src/servers/src/mysql/federated.rs +++ b/src/servers/src/mysql/federated.rs @@ -20,6 +20,8 @@ use std::sync::Arc; use common_query::Output; use common_recordbatch::RecordBatches; +use common_time::timezone::system_time_zone_name; +use common_time::TimeZone; use datatypes::prelude::ConcreteDataType; use datatypes::schema::{ColumnSchema, Schema}; use datatypes::vectors::StringVector; @@ -54,6 +56,10 @@ static SELECT_TIME_DIFF_FUNC_PATTERN: Lazy = static SHOW_SQL_MODE_PATTERN: Lazy = Lazy::new(|| Regex::new("(?i)^(SHOW VARIABLES LIKE 'sql_mode'(.*))").unwrap()); +// Time zone settings +static SET_TIME_ZONE_PATTERN: Lazy = + Lazy::new(|| Regex::new(r"(?i)^SET TIME_ZONE\s*=\s*'(\S+)'").unwrap()); + static OTHER_NOT_SUPPORTED_STMT: Lazy = Lazy::new(|| { RegexSet::new([ // Txn. @@ -124,8 +130,6 @@ static VAR_VALUES: Lazy> = Lazy::new(|| { ("transaction_isolation", "REPEATABLE-READ"), ("session.transaction_isolation", "REPEATABLE-READ"), ("session.transaction_read_only", "0"), - ("time_zone", "UTC"), - ("system_time_zone", "UTC"), ("max_allowed_packet", "134217728"), ("interactive_timeout", "31536000"), ("wait_timeout", "31536000"), @@ -168,7 +172,7 @@ fn show_variables(name: &str, value: &str) -> RecordBatches { .unwrap() } -fn select_variable(query: &str) -> Option { +fn select_variable(query: &str, query_context: QueryContextRef) -> Option { let mut fields = vec![]; let mut values = vec![]; @@ -191,12 +195,24 @@ fn select_variable(query: &str) -> Option { .unwrap_or("") }) .collect(); + + // get value of variables from known sources or fallback to defaults + let value = match var_as[0] { + "time_zone" => query_context + .time_zone() + .map(|tz| tz.to_string()) + .unwrap_or_else(|| "".to_owned()), + "system_time_zone" => system_time_zone_name(), + _ => VAR_VALUES + .get(var_as[0]) + .map(|v| v.to_string()) + .unwrap_or_else(|| "0".to_owned()), + }; + + values.push(Arc::new(StringVector::from(vec![value])) as _); match var_as.len() { 1 => { // @@aa - let value = VAR_VALUES.get(var_as[0]).unwrap_or(&"0"); - values.push(Arc::new(StringVector::from(vec![*value])) as _); - // field is '@@aa' fields.push(ColumnSchema::new( &format!("@@{}", var_as[0]), @@ -207,9 +223,6 @@ fn select_variable(query: &str) -> Option { 2 => { // @@bb as cc: // var is 'bb'. - let value = VAR_VALUES.get(var_as[0]).unwrap_or(&"0"); - values.push(Arc::new(StringVector::from(vec![*value])) as _); - // field is 'cc'. fields.push(ColumnSchema::new( var_as[1], @@ -227,12 +240,12 @@ fn select_variable(query: &str) -> Option { Some(Output::RecordBatches(batches)) } -fn check_select_variable(query: &str) -> Option { +fn check_select_variable(query: &str, query_context: QueryContextRef) -> Option { if vec![&SELECT_VAR_PATTERN, &MYSQL_CONN_JAVA_PATTERN] .iter() .any(|r| r.is_match(query)) { - select_variable(query) + select_variable(query, query_context) } else { None } @@ -251,6 +264,20 @@ fn check_show_variables(query: &str) -> Option { recordbatches.map(Output::RecordBatches) } +// TODO(sunng87): extract this to use sqlparser for more variables +fn check_set_variables(query: &str, query_ctx: QueryContextRef) -> Option { + if let Some(captures) = SET_TIME_ZONE_PATTERN.captures(query) { + // get the capture + let tz = captures.get(1).unwrap(); + if let Ok(timezone) = TimeZone::from_tz_string(tz.as_str()) { + query_ctx.set_time_zone(timezone); + return Some(Output::AffectedRows(0)); + } + } + + None +} + // Check for SET or others query, this is the final check of the federated query. fn check_others(query: &str, query_ctx: QueryContextRef) -> Option { if OTHER_NOT_SUPPORTED_STMT.is_match(query.as_bytes()) { @@ -283,19 +310,12 @@ pub(crate) fn check(query: &str, query_ctx: QueryContextRef) -> Option { } // First to check the query is like "select @@variables". - let output = check_select_variable(query); - if output.is_some() { - return output; - } - - // Then to check "show variables like ...". - let output = check_show_variables(query); - if output.is_some() { - return output; - } - - // Last check. - check_others(query, query_ctx) + check_select_variable(query, query_ctx.clone()) + // Then to check "show variables like ...". + .or_else(|| check_show_variables(query)) + .or_else(|| check_set_variables(query, query_ctx.clone())) + // Last check + .or_else(|| check_others(query, query_ctx)) } #[cfg(test)] @@ -352,13 +372,15 @@ mod test { +-----------------+------------------------+"; test(query, expected); + // set sysstem timezone + std::env::set_var("TZ", "Asia/Shanghai"); // complex variables let query = "/* mysql-connector-java-8.0.17 (Revision: 16a712ddb3f826a1933ab42b0039f7fb9eebc6ec) */SELECT @@session.auto_increment_increment AS auto_increment_increment, @@character_set_client AS character_set_client, @@character_set_connection AS character_set_connection, @@character_set_results AS character_set_results, @@character_set_server AS character_set_server, @@collation_server AS collation_server, @@collation_connection AS collation_connection, @@init_connect AS init_connect, @@interactive_timeout AS interactive_timeout, @@license AS license, @@lower_case_table_names AS lower_case_table_names, @@max_allowed_packet AS max_allowed_packet, @@net_write_timeout AS net_write_timeout, @@performance_schema AS performance_schema, @@sql_mode AS sql_mode, @@system_time_zone AS system_time_zone, @@time_zone AS time_zone, @@transaction_isolation AS transaction_isolation, @@wait_timeout AS wait_timeout;"; let expected = "\ +--------------------------+----------------------+--------------------------+-----------------------+----------------------+------------------+----------------------+--------------+---------------------+---------+------------------------+--------------------+-------------------+--------------------+----------+------------------+-----------+-----------------------+---------------+ | auto_increment_increment | character_set_client | character_set_connection | character_set_results | character_set_server | collation_server | collation_connection | init_connect | interactive_timeout | license | lower_case_table_names | max_allowed_packet | net_write_timeout | performance_schema | sql_mode | system_time_zone | time_zone | transaction_isolation | wait_timeout; | +--------------------------+----------------------+--------------------------+-----------------------+----------------------+------------------+----------------------+--------------+---------------------+---------+------------------------+--------------------+-------------------+--------------------+----------+------------------+-----------+-----------------------+---------------+ -| 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 31536000 | 0 | 0 | 134217728 | 31536000 | 0 | 0 | UTC | UTC | REPEATABLE-READ | 31536000 | +| 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 31536000 | 0 | 0 | 134217728 | 31536000 | 0 | 0 | +08:00 | | REPEATABLE-READ | 31536000 | +--------------------------+----------------------+--------------------------+-----------------------+----------------------+------------------+----------------------+--------------+---------------------+---------+------------------------+--------------------+-------------------+--------------------+----------+------------------+-----------+-----------------------+---------------+"; test(query, expected); @@ -395,4 +417,31 @@ mod test { +----------------------------------+"; test(query, expected); } + + #[test] + fn test_set_time_zone() { + let query_context = Arc::new(QueryContext::new()); + let output = check("set time_zone = 'UTC'", query_context.clone()); + match output.unwrap() { + Output::AffectedRows(rows) => { + assert_eq!(rows, 0) + } + _ => unreachable!(), + } + assert_eq!("UTC", query_context.time_zone().unwrap().to_string()); + + let output = check("select @@time_zone", query_context); + match output.unwrap() { + Output::RecordBatches(r) => { + let expected = "\ ++-------------+ +| @@time_zone | ++-------------+ +| UTC | ++-------------+"; + assert_eq!(r.pretty_print().unwrap(), expected); + } + _ => unreachable!(), + } + } } diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index d49d4fb12f..1c48ae9401 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -231,7 +231,7 @@ impl AsyncMysqlShim for MysqlInstanceShi log::debug!("execute replaced query: {}", query); let outputs = self.do_query(&query).await; - writer::write_output(w, &query, outputs).await?; + writer::write_output(w, &query, self.session.context(), outputs).await?; Ok(()) } @@ -263,7 +263,7 @@ impl AsyncMysqlShim for MysqlInstanceShi ] ); let outputs = self.do_query(query).await; - writer::write_output(writer, query, outputs).await?; + writer::write_output(writer, query, self.session.context(), outputs).await?; Ok(()) } diff --git a/src/servers/src/mysql/writer.rs b/src/servers/src/mysql/writer.rs index a1109cc7f3..6a060635b6 100644 --- a/src/servers/src/mysql/writer.rs +++ b/src/servers/src/mysql/writer.rs @@ -22,6 +22,7 @@ use datatypes::schema::{ColumnSchema, SchemaRef}; use opensrv_mysql::{ Column, ColumnFlags, ColumnType, ErrorKind, OkResponse, QueryResultWriter, RowWriter, }; +use session::context::QueryContextRef; use snafu::prelude::*; use tokio::io::AsyncWrite; @@ -31,9 +32,10 @@ use crate::error::{self, Error, Result}; pub async fn write_output<'a, W: AsyncWrite + Send + Sync + Unpin>( w: QueryResultWriter<'a, W>, query: &str, + query_context: QueryContextRef, outputs: Vec>, ) -> Result<()> { - let mut writer = Some(MysqlResultWriter::new(w)); + let mut writer = Some(MysqlResultWriter::new(w, query_context.clone())); for output in outputs { let result_writer = writer.take().context(error::InternalSnafu { err_msg: "Sending multiple result set is unsupported", @@ -54,11 +56,18 @@ struct QueryResult { pub struct MysqlResultWriter<'a, W: AsyncWrite + Unpin> { writer: QueryResultWriter<'a, W>, + query_context: QueryContextRef, } impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> { - pub fn new(writer: QueryResultWriter<'a, W>) -> MysqlResultWriter<'a, W> { - MysqlResultWriter::<'a, W> { writer } + pub fn new( + writer: QueryResultWriter<'a, W>, + query_context: QueryContextRef, + ) -> MysqlResultWriter<'a, W> { + MysqlResultWriter::<'a, W> { + writer, + query_context, + } } /// Try to write one result set. If there are more than one result set, return `Some`. @@ -80,18 +89,23 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> { recordbatches, schema, }; - Self::write_query_result(query, query_result, self.writer).await?; + Self::write_query_result(query, query_result, self.writer, self.query_context) + .await?; } Output::RecordBatches(recordbatches) => { let query_result = QueryResult { schema: recordbatches.schema(), recordbatches: recordbatches.take(), }; - Self::write_query_result(query, query_result, self.writer).await?; + Self::write_query_result(query, query_result, self.writer, self.query_context) + .await?; } Output::AffectedRows(rows) => { let next_writer = Self::write_affected_rows(self.writer, rows).await?; - return Ok(Some(MysqlResultWriter::new(next_writer))); + return Ok(Some(MysqlResultWriter::new( + next_writer, + self.query_context, + ))); } }, Err(error) => Self::write_query_error(query, error, self.writer).await?, @@ -122,6 +136,7 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> { query: &str, query_result: QueryResult, writer: QueryResultWriter<'a, W>, + query_context: QueryContextRef, ) -> Result<()> { match create_mysql_column_def(&query_result.schema) { Ok(column_def) => { @@ -129,7 +144,8 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> { // to return a new QueryResultWriter. let mut row_writer = writer.start(&column_def).await?; for recordbatch in &query_result.recordbatches { - Self::write_recordbatch(&mut row_writer, recordbatch).await?; + Self::write_recordbatch(&mut row_writer, recordbatch, query_context.clone()) + .await?; } row_writer.finish().await?; Ok(()) @@ -141,6 +157,7 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> { async fn write_recordbatch( row_writer: &mut RowWriter<'_, W>, recordbatch: &RecordBatch, + query_context: QueryContextRef, ) -> Result<()> { for row in recordbatch.rows() { for value in row.into_iter() { @@ -161,7 +178,8 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> { Value::Binary(v) => row_writer.write_col(v.deref())?, Value::Date(v) => row_writer.write_col(v.val())?, Value::DateTime(v) => row_writer.write_col(v.val())?, - Value::Timestamp(v) => row_writer.write_col(v.to_local_string())?, + Value::Timestamp(v) => row_writer + .write_col(v.to_timezone_aware_string(query_context.time_zone()))?, Value::List(_) => { return Err(Error::Internal { err_msg: format!( diff --git a/src/session/Cargo.toml b/src/session/Cargo.toml index f6dff95e46..06224ac8ef 100644 --- a/src/session/Cargo.toml +++ b/src/session/Cargo.toml @@ -8,3 +8,4 @@ license.workspace = true arc-swap = "1.5" common-catalog = { path = "../common/catalog" } common-telemetry = { path = "../common/telemetry" } +common-time = { path = "../common/time" } diff --git a/src/session/src/context.rs b/src/session/src/context.rs index fbfad79917..5adbe84ae9 100644 --- a/src/session/src/context.rs +++ b/src/session/src/context.rs @@ -20,6 +20,7 @@ use arc_swap::ArcSwap; use common_catalog::build_db_string; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_telemetry::debug; +use common_time::TimeZone; pub type QueryContextRef = Arc; pub type ConnInfoRef = Arc; @@ -28,6 +29,7 @@ pub type ConnInfoRef = Arc; pub struct QueryContext { current_catalog: ArcSwap, current_schema: ArcSwap, + time_zone: ArcSwap>, } impl Default for QueryContext { @@ -56,6 +58,7 @@ impl QueryContext { Self { current_catalog: ArcSwap::new(Arc::new(DEFAULT_CATALOG_NAME.to_string())), current_schema: ArcSwap::new(Arc::new(DEFAULT_SCHEMA_NAME.to_string())), + time_zone: ArcSwap::new(Arc::new(None)), } } @@ -63,6 +66,7 @@ impl QueryContext { Self { current_catalog: ArcSwap::new(Arc::new(catalog.to_string())), current_schema: ArcSwap::new(Arc::new(schema.to_string())), + time_zone: ArcSwap::new(Arc::new(None)), } } @@ -99,6 +103,16 @@ impl QueryContext { let schema = self.current_schema(); build_db_string(&catalog, &schema) } + + #[inline] + pub fn time_zone(&self) -> Option { + self.time_zone.load().as_ref().clone() + } + + #[inline] + pub fn set_time_zone(&self, tz: Option) { + self.time_zone.swap(Arc::new(tz)); + } } pub const DEFAULT_USERNAME: &str = "greptime";