mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2025-12-26 08:00:01 +00:00
feat: time_zone variable for mysql connections (#1607)
* feat: add timezone info to query context * feat: parse mysql compatible time zone string * feat: add method to timestamp for rendering timezone aware string * feat: use timezone from session for time string rendering * refactor: use querycontectref * feat: implement session/timezone variable read/write * style: resolve toml format * test: update tests * Apply suggestions from code review Co-authored-by: dennis zhuang <killme2008@gmail.com> * Update src/session/src/context.rs Co-authored-by: dennis zhuang <killme2008@gmail.com> * refactor: address review issues --------- Co-authored-by: dennis zhuang <killme2008@gmail.com>
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -1874,6 +1874,7 @@ name = "common-time"
|
||||
version = "0.2.0"
|
||||
dependencies = [
|
||||
"chrono",
|
||||
"chrono-tz 0.8.2",
|
||||
"common-error",
|
||||
"rand",
|
||||
"serde",
|
||||
@@ -7995,6 +7996,7 @@ dependencies = [
|
||||
"arc-swap",
|
||||
"common-catalog",
|
||||
"common-telemetry",
|
||||
"common-time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -6,6 +6,7 @@ license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
chrono.workspace = true
|
||||
chrono-tz = "0.8"
|
||||
common-error = { path = "../error" }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
use std::any::Any;
|
||||
use std::num::TryFromIntError;
|
||||
use std::num::{ParseIntError, TryFromIntError};
|
||||
|
||||
use chrono::ParseError;
|
||||
use common_error::ext::ErrorExt;
|
||||
@@ -40,14 +40,33 @@ pub enum Error {
|
||||
|
||||
#[snafu(display("Timestamp arithmetic overflow, msg: {}", msg))]
|
||||
ArithmeticOverflow { msg: String, location: Location },
|
||||
|
||||
#[snafu(display("Invalid time zone offset: {hours}:{minutes}"))]
|
||||
InvalidTimeZoneOffset {
|
||||
hours: i32,
|
||||
minutes: u32,
|
||||
location: Location,
|
||||
},
|
||||
|
||||
#[snafu(display("Invalid offset string {raw}: {source}"))]
|
||||
ParseOffsetStr {
|
||||
raw: String,
|
||||
source: ParseIntError,
|
||||
location: Location,
|
||||
},
|
||||
|
||||
#[snafu(display("Invalid time zone string {raw}"))]
|
||||
ParseTimeZoneName { raw: String, location: Location },
|
||||
}
|
||||
|
||||
impl ErrorExt for Error {
|
||||
fn status_code(&self) -> StatusCode {
|
||||
match self {
|
||||
Error::ParseDateStr { .. } | Error::ParseTimestamp { .. } => {
|
||||
StatusCode::InvalidArguments
|
||||
}
|
||||
Error::ParseDateStr { .. }
|
||||
| Error::ParseTimestamp { .. }
|
||||
| Error::InvalidTimeZoneOffset { .. }
|
||||
| Error::ParseOffsetStr { .. }
|
||||
| Error::ParseTimeZoneName { .. } => StatusCode::InvalidArguments,
|
||||
Error::TimestampOverflow { .. } => StatusCode::Internal,
|
||||
Error::InvalidDateStr { .. } | Error::ArithmeticOverflow { .. } => {
|
||||
StatusCode::InvalidArguments
|
||||
@@ -64,7 +83,10 @@ impl ErrorExt for Error {
|
||||
Error::ParseTimestamp { location, .. }
|
||||
| Error::TimestampOverflow { location, .. }
|
||||
| Error::ArithmeticOverflow { location, .. } => Some(*location),
|
||||
Error::ParseDateStr { .. } => None,
|
||||
Error::ParseDateStr { .. }
|
||||
| Error::InvalidTimeZoneOffset { .. }
|
||||
| Error::ParseOffsetStr { .. }
|
||||
| Error::ParseTimeZoneName { .. } => None,
|
||||
Error::InvalidDateStr { location, .. } => Some(*location),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ pub mod error;
|
||||
pub mod range;
|
||||
pub mod timestamp;
|
||||
pub mod timestamp_millis;
|
||||
pub mod timezone;
|
||||
pub mod util;
|
||||
|
||||
pub use date::Date;
|
||||
@@ -25,3 +26,4 @@ pub use datetime::DateTime;
|
||||
pub use range::RangeMillis;
|
||||
pub use timestamp::Timestamp;
|
||||
pub use timestamp_millis::TimestampMillis;
|
||||
pub use timezone::TimeZone;
|
||||
|
||||
@@ -20,12 +20,13 @@ use std::str::FromStr;
|
||||
use std::time::Duration;
|
||||
|
||||
use chrono::offset::Local;
|
||||
use chrono::{DateTime, LocalResult, NaiveDateTime, TimeZone, Utc};
|
||||
use chrono::{DateTime, LocalResult, NaiveDateTime, TimeZone as ChronoTimeZone, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use snafu::{OptionExt, ResultExt};
|
||||
|
||||
use crate::error;
|
||||
use crate::error::{ArithmeticOverflowSnafu, Error, ParseTimestampSnafu, TimestampOverflowSnafu};
|
||||
use crate::timezone::TimeZone;
|
||||
use crate::util::div_ceil;
|
||||
|
||||
#[derive(Debug, Clone, Default, Copy, Serialize, Deserialize)]
|
||||
@@ -171,17 +172,33 @@ impl Timestamp {
|
||||
/// Format timestamp to ISO8601 string. If the timestamp exceeds what chrono timestamp can
|
||||
/// represent, this function simply print the timestamp unit and value in plain string.
|
||||
pub fn to_iso8601_string(&self) -> String {
|
||||
self.as_formatted_string("%Y-%m-%d %H:%M:%S%.f%z")
|
||||
self.as_formatted_string("%Y-%m-%d %H:%M:%S%.f%z", None)
|
||||
}
|
||||
|
||||
pub fn to_local_string(&self) -> String {
|
||||
self.as_formatted_string("%Y-%m-%d %H:%M:%S%.f")
|
||||
self.as_formatted_string("%Y-%m-%d %H:%M:%S%.f", None)
|
||||
}
|
||||
|
||||
fn as_formatted_string(self, pattern: &str) -> String {
|
||||
/// Format timestamp for given timezone.
|
||||
/// When timezone is None, using local time by default.
|
||||
pub fn to_timezone_aware_string(&self, tz: Option<TimeZone>) -> String {
|
||||
self.as_formatted_string("%Y-%m-%d %H:%M:%S%.f", tz)
|
||||
}
|
||||
|
||||
fn as_formatted_string(self, pattern: &str, timezone: Option<TimeZone>) -> String {
|
||||
if let Some(v) = self.to_chrono_datetime() {
|
||||
let local = Local {};
|
||||
format!("{}", local.from_utc_datetime(&v).format(pattern))
|
||||
match timezone {
|
||||
Some(TimeZone::Offset(offset)) => {
|
||||
format!("{}", offset.from_utc_datetime(&v).format(pattern))
|
||||
}
|
||||
Some(TimeZone::Named(tz)) => {
|
||||
format!("{}", tz.from_utc_datetime(&v).format(pattern))
|
||||
}
|
||||
None => {
|
||||
let local = Local {};
|
||||
format!("{}", local.from_utc_datetime(&v).format(pattern))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
format!("[Timestamp{}: {}]", self.unit, self.value)
|
||||
}
|
||||
@@ -934,4 +951,54 @@ mod tests {
|
||||
Timestamp::new_millisecond(58).sub(Timestamp::new_millisecond(100))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_to_timezone_aware_string() {
|
||||
std::env::set_var("TZ", "Asia/Shanghai");
|
||||
|
||||
assert_eq!(
|
||||
"1970-01-01 08:00:00.001",
|
||||
Timestamp::new(1, TimeUnit::Millisecond).to_timezone_aware_string(None)
|
||||
);
|
||||
assert_eq!(
|
||||
"1970-01-01 08:00:00.001",
|
||||
Timestamp::new(1, TimeUnit::Millisecond)
|
||||
.to_timezone_aware_string(TimeZone::from_tz_string("SYSTEM").unwrap())
|
||||
);
|
||||
assert_eq!(
|
||||
"1970-01-01 08:00:00.001",
|
||||
Timestamp::new(1, TimeUnit::Millisecond)
|
||||
.to_timezone_aware_string(TimeZone::from_tz_string("+08:00").unwrap())
|
||||
);
|
||||
assert_eq!(
|
||||
"1970-01-01 07:00:00.001",
|
||||
Timestamp::new(1, TimeUnit::Millisecond)
|
||||
.to_timezone_aware_string(TimeZone::from_tz_string("+07:00").unwrap())
|
||||
);
|
||||
assert_eq!(
|
||||
"1969-12-31 23:00:00.001",
|
||||
Timestamp::new(1, TimeUnit::Millisecond)
|
||||
.to_timezone_aware_string(TimeZone::from_tz_string("-01:00").unwrap())
|
||||
);
|
||||
assert_eq!(
|
||||
"1970-01-01 08:00:00.001",
|
||||
Timestamp::new(1, TimeUnit::Millisecond)
|
||||
.to_timezone_aware_string(TimeZone::from_tz_string("Asia/Shanghai").unwrap())
|
||||
);
|
||||
assert_eq!(
|
||||
"1970-01-01 00:00:00.001",
|
||||
Timestamp::new(1, TimeUnit::Millisecond)
|
||||
.to_timezone_aware_string(TimeZone::from_tz_string("UTC").unwrap())
|
||||
);
|
||||
assert_eq!(
|
||||
"1970-01-01 01:00:00.001",
|
||||
Timestamp::new(1, TimeUnit::Millisecond)
|
||||
.to_timezone_aware_string(TimeZone::from_tz_string("Europe/Berlin").unwrap())
|
||||
);
|
||||
assert_eq!(
|
||||
"1970-01-01 03:00:00.001",
|
||||
Timestamp::new(1, TimeUnit::Millisecond)
|
||||
.to_timezone_aware_string(TimeZone::from_tz_string("Europe/Moscow").unwrap())
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
158
src/common/time/src/timezone.rs
Normal file
158
src/common/time/src/timezone.rs
Normal file
@@ -0,0 +1,158 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::fmt::Display;
|
||||
use std::str::FromStr;
|
||||
|
||||
use chrono::{FixedOffset, Local};
|
||||
use chrono_tz::Tz;
|
||||
use snafu::{OptionExt, ResultExt};
|
||||
|
||||
use crate::error::{
|
||||
InvalidTimeZoneOffsetSnafu, ParseOffsetStrSnafu, ParseTimeZoneNameSnafu, Result,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum TimeZone {
|
||||
Offset(FixedOffset),
|
||||
Named(Tz),
|
||||
}
|
||||
|
||||
impl TimeZone {
|
||||
/// Compute timezone from given offset hours and minutes
|
||||
/// Return `None` if given offset exceeds scope
|
||||
pub fn hours_mins_opt(offset_hours: i32, offset_mins: u32) -> Result<Self> {
|
||||
let offset_secs = if offset_hours > 0 {
|
||||
offset_hours * 3600 + offset_mins as i32 * 60
|
||||
} else {
|
||||
offset_hours * 3600 - offset_mins as i32 * 60
|
||||
};
|
||||
|
||||
FixedOffset::east_opt(offset_secs)
|
||||
.map(Self::Offset)
|
||||
.context(InvalidTimeZoneOffsetSnafu {
|
||||
hours: offset_hours,
|
||||
minutes: offset_mins,
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse timezone offset string and return None if given offset exceeds
|
||||
/// scope.
|
||||
///
|
||||
/// String examples are available as described in
|
||||
/// https://dev.mysql.com/doc/refman/8.0/en/time-zone-support.html
|
||||
///
|
||||
/// - `SYSTEM`
|
||||
/// - Offset to UTC: `+08:00` , `-11:30`
|
||||
/// - Named zones: `Asia/Shanghai`, `Europe/Berlin`
|
||||
pub fn from_tz_string(tz_string: &str) -> Result<Option<Self>> {
|
||||
// Use system timezone
|
||||
if tz_string.eq_ignore_ascii_case("SYSTEM") {
|
||||
Ok(None)
|
||||
} else if let Some((hrs, mins)) = tz_string.split_once(':') {
|
||||
let hrs = hrs
|
||||
.parse::<i32>()
|
||||
.context(ParseOffsetStrSnafu { raw: tz_string })?;
|
||||
let mins = mins
|
||||
.parse::<u32>()
|
||||
.context(ParseOffsetStrSnafu { raw: tz_string })?;
|
||||
Self::hours_mins_opt(hrs, mins).map(Some)
|
||||
} else if let Ok(tz) = Tz::from_str(tz_string) {
|
||||
Ok(Some(Self::Named(tz)))
|
||||
} else {
|
||||
ParseTimeZoneNameSnafu { raw: tz_string }.fail()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for TimeZone {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Named(tz) => write!(f, "{}", tz.name()),
|
||||
Self::Offset(offset) => write!(f, "{}", offset),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn system_time_zone_name() -> String {
|
||||
Local::now().offset().to_string()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_from_tz_string() {
|
||||
assert_eq!(None, TimeZone::from_tz_string("SYSTEM").unwrap());
|
||||
|
||||
let utc_plus_8 = Some(TimeZone::Offset(FixedOffset::east_opt(3600 * 8).unwrap()));
|
||||
assert_eq!(utc_plus_8, TimeZone::from_tz_string("+8:00").unwrap());
|
||||
assert_eq!(utc_plus_8, TimeZone::from_tz_string("+08:00").unwrap());
|
||||
assert_eq!(utc_plus_8, TimeZone::from_tz_string("08:00").unwrap());
|
||||
|
||||
let utc_minus_8 = Some(TimeZone::Offset(FixedOffset::west_opt(3600 * 8).unwrap()));
|
||||
assert_eq!(utc_minus_8, TimeZone::from_tz_string("-08:00").unwrap());
|
||||
assert_eq!(utc_minus_8, TimeZone::from_tz_string("-8:00").unwrap());
|
||||
|
||||
let utc_minus_8_5 = Some(TimeZone::Offset(
|
||||
FixedOffset::west_opt(3600 * 8 + 60 * 30).unwrap(),
|
||||
));
|
||||
assert_eq!(utc_minus_8_5, TimeZone::from_tz_string("-8:30").unwrap());
|
||||
|
||||
let utc_plus_max = Some(TimeZone::Offset(FixedOffset::east_opt(3600 * 14).unwrap()));
|
||||
assert_eq!(utc_plus_max, TimeZone::from_tz_string("14:00").unwrap());
|
||||
|
||||
let utc_minus_max = Some(TimeZone::Offset(
|
||||
FixedOffset::west_opt(3600 * 13 + 60 * 59).unwrap(),
|
||||
));
|
||||
assert_eq!(utc_minus_max, TimeZone::from_tz_string("-13:59").unwrap());
|
||||
|
||||
assert_eq!(
|
||||
Some(TimeZone::Named(Tz::Asia__Shanghai)),
|
||||
TimeZone::from_tz_string("Asia/Shanghai").unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
Some(TimeZone::Named(Tz::UTC)),
|
||||
TimeZone::from_tz_string("UTC").unwrap()
|
||||
);
|
||||
|
||||
assert!(TimeZone::from_tz_string("WORLD_PEACE").is_err());
|
||||
assert!(TimeZone::from_tz_string("A0:01").is_err());
|
||||
assert!(TimeZone::from_tz_string("20:0A").is_err());
|
||||
assert!(TimeZone::from_tz_string(":::::").is_err());
|
||||
assert!(TimeZone::from_tz_string("Asia/London").is_err());
|
||||
assert!(TimeZone::from_tz_string("Unknown").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timezone_to_string() {
|
||||
assert_eq!("UTC", TimeZone::Named(Tz::UTC).to_string());
|
||||
assert_eq!(
|
||||
"+01:00",
|
||||
TimeZone::from_tz_string("01:00")
|
||||
.unwrap()
|
||||
.unwrap()
|
||||
.to_string()
|
||||
);
|
||||
assert_eq!(
|
||||
"Asia/Shanghai",
|
||||
TimeZone::from_tz_string("Asia/Shanghai")
|
||||
.unwrap()
|
||||
.unwrap()
|
||||
.to_string()
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -20,6 +20,8 @@ use std::sync::Arc;
|
||||
|
||||
use common_query::Output;
|
||||
use common_recordbatch::RecordBatches;
|
||||
use common_time::timezone::system_time_zone_name;
|
||||
use common_time::TimeZone;
|
||||
use datatypes::prelude::ConcreteDataType;
|
||||
use datatypes::schema::{ColumnSchema, Schema};
|
||||
use datatypes::vectors::StringVector;
|
||||
@@ -54,6 +56,10 @@ static SELECT_TIME_DIFF_FUNC_PATTERN: Lazy<Regex> =
|
||||
static SHOW_SQL_MODE_PATTERN: Lazy<Regex> =
|
||||
Lazy::new(|| Regex::new("(?i)^(SHOW VARIABLES LIKE 'sql_mode'(.*))").unwrap());
|
||||
|
||||
// Time zone settings
|
||||
static SET_TIME_ZONE_PATTERN: Lazy<Regex> =
|
||||
Lazy::new(|| Regex::new(r"(?i)^SET TIME_ZONE\s*=\s*'(\S+)'").unwrap());
|
||||
|
||||
static OTHER_NOT_SUPPORTED_STMT: Lazy<RegexSet> = Lazy::new(|| {
|
||||
RegexSet::new([
|
||||
// Txn.
|
||||
@@ -124,8 +130,6 @@ static VAR_VALUES: Lazy<HashMap<&str, &str>> = Lazy::new(|| {
|
||||
("transaction_isolation", "REPEATABLE-READ"),
|
||||
("session.transaction_isolation", "REPEATABLE-READ"),
|
||||
("session.transaction_read_only", "0"),
|
||||
("time_zone", "UTC"),
|
||||
("system_time_zone", "UTC"),
|
||||
("max_allowed_packet", "134217728"),
|
||||
("interactive_timeout", "31536000"),
|
||||
("wait_timeout", "31536000"),
|
||||
@@ -168,7 +172,7 @@ fn show_variables(name: &str, value: &str) -> RecordBatches {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn select_variable(query: &str) -> Option<Output> {
|
||||
fn select_variable(query: &str, query_context: QueryContextRef) -> Option<Output> {
|
||||
let mut fields = vec![];
|
||||
let mut values = vec![];
|
||||
|
||||
@@ -191,12 +195,24 @@ fn select_variable(query: &str) -> Option<Output> {
|
||||
.unwrap_or("")
|
||||
})
|
||||
.collect();
|
||||
|
||||
// get value of variables from known sources or fallback to defaults
|
||||
let value = match var_as[0] {
|
||||
"time_zone" => query_context
|
||||
.time_zone()
|
||||
.map(|tz| tz.to_string())
|
||||
.unwrap_or_else(|| "".to_owned()),
|
||||
"system_time_zone" => system_time_zone_name(),
|
||||
_ => VAR_VALUES
|
||||
.get(var_as[0])
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_else(|| "0".to_owned()),
|
||||
};
|
||||
|
||||
values.push(Arc::new(StringVector::from(vec![value])) as _);
|
||||
match var_as.len() {
|
||||
1 => {
|
||||
// @@aa
|
||||
let value = VAR_VALUES.get(var_as[0]).unwrap_or(&"0");
|
||||
values.push(Arc::new(StringVector::from(vec![*value])) as _);
|
||||
|
||||
// field is '@@aa'
|
||||
fields.push(ColumnSchema::new(
|
||||
&format!("@@{}", var_as[0]),
|
||||
@@ -207,9 +223,6 @@ fn select_variable(query: &str) -> Option<Output> {
|
||||
2 => {
|
||||
// @@bb as cc:
|
||||
// var is 'bb'.
|
||||
let value = VAR_VALUES.get(var_as[0]).unwrap_or(&"0");
|
||||
values.push(Arc::new(StringVector::from(vec![*value])) as _);
|
||||
|
||||
// field is 'cc'.
|
||||
fields.push(ColumnSchema::new(
|
||||
var_as[1],
|
||||
@@ -227,12 +240,12 @@ fn select_variable(query: &str) -> Option<Output> {
|
||||
Some(Output::RecordBatches(batches))
|
||||
}
|
||||
|
||||
fn check_select_variable(query: &str) -> Option<Output> {
|
||||
fn check_select_variable(query: &str, query_context: QueryContextRef) -> Option<Output> {
|
||||
if vec![&SELECT_VAR_PATTERN, &MYSQL_CONN_JAVA_PATTERN]
|
||||
.iter()
|
||||
.any(|r| r.is_match(query))
|
||||
{
|
||||
select_variable(query)
|
||||
select_variable(query, query_context)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
@@ -251,6 +264,20 @@ fn check_show_variables(query: &str) -> Option<Output> {
|
||||
recordbatches.map(Output::RecordBatches)
|
||||
}
|
||||
|
||||
// TODO(sunng87): extract this to use sqlparser for more variables
|
||||
fn check_set_variables(query: &str, query_ctx: QueryContextRef) -> Option<Output> {
|
||||
if let Some(captures) = SET_TIME_ZONE_PATTERN.captures(query) {
|
||||
// get the capture
|
||||
let tz = captures.get(1).unwrap();
|
||||
if let Ok(timezone) = TimeZone::from_tz_string(tz.as_str()) {
|
||||
query_ctx.set_time_zone(timezone);
|
||||
return Some(Output::AffectedRows(0));
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
// Check for SET or others query, this is the final check of the federated query.
|
||||
fn check_others(query: &str, query_ctx: QueryContextRef) -> Option<Output> {
|
||||
if OTHER_NOT_SUPPORTED_STMT.is_match(query.as_bytes()) {
|
||||
@@ -283,19 +310,12 @@ pub(crate) fn check(query: &str, query_ctx: QueryContextRef) -> Option<Output> {
|
||||
}
|
||||
|
||||
// First to check the query is like "select @@variables".
|
||||
let output = check_select_variable(query);
|
||||
if output.is_some() {
|
||||
return output;
|
||||
}
|
||||
|
||||
// Then to check "show variables like ...".
|
||||
let output = check_show_variables(query);
|
||||
if output.is_some() {
|
||||
return output;
|
||||
}
|
||||
|
||||
// Last check.
|
||||
check_others(query, query_ctx)
|
||||
check_select_variable(query, query_ctx.clone())
|
||||
// Then to check "show variables like ...".
|
||||
.or_else(|| check_show_variables(query))
|
||||
.or_else(|| check_set_variables(query, query_ctx.clone()))
|
||||
// Last check
|
||||
.or_else(|| check_others(query, query_ctx))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -352,13 +372,15 @@ mod test {
|
||||
+-----------------+------------------------+";
|
||||
test(query, expected);
|
||||
|
||||
// set sysstem timezone
|
||||
std::env::set_var("TZ", "Asia/Shanghai");
|
||||
// complex variables
|
||||
let query = "/* mysql-connector-java-8.0.17 (Revision: 16a712ddb3f826a1933ab42b0039f7fb9eebc6ec) */SELECT @@session.auto_increment_increment AS auto_increment_increment, @@character_set_client AS character_set_client, @@character_set_connection AS character_set_connection, @@character_set_results AS character_set_results, @@character_set_server AS character_set_server, @@collation_server AS collation_server, @@collation_connection AS collation_connection, @@init_connect AS init_connect, @@interactive_timeout AS interactive_timeout, @@license AS license, @@lower_case_table_names AS lower_case_table_names, @@max_allowed_packet AS max_allowed_packet, @@net_write_timeout AS net_write_timeout, @@performance_schema AS performance_schema, @@sql_mode AS sql_mode, @@system_time_zone AS system_time_zone, @@time_zone AS time_zone, @@transaction_isolation AS transaction_isolation, @@wait_timeout AS wait_timeout;";
|
||||
let expected = "\
|
||||
+--------------------------+----------------------+--------------------------+-----------------------+----------------------+------------------+----------------------+--------------+---------------------+---------+------------------------+--------------------+-------------------+--------------------+----------+------------------+-----------+-----------------------+---------------+
|
||||
| auto_increment_increment | character_set_client | character_set_connection | character_set_results | character_set_server | collation_server | collation_connection | init_connect | interactive_timeout | license | lower_case_table_names | max_allowed_packet | net_write_timeout | performance_schema | sql_mode | system_time_zone | time_zone | transaction_isolation | wait_timeout; |
|
||||
+--------------------------+----------------------+--------------------------+-----------------------+----------------------+------------------+----------------------+--------------+---------------------+---------+------------------------+--------------------+-------------------+--------------------+----------+------------------+-----------+-----------------------+---------------+
|
||||
| 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 31536000 | 0 | 0 | 134217728 | 31536000 | 0 | 0 | UTC | UTC | REPEATABLE-READ | 31536000 |
|
||||
| 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 31536000 | 0 | 0 | 134217728 | 31536000 | 0 | 0 | +08:00 | | REPEATABLE-READ | 31536000 |
|
||||
+--------------------------+----------------------+--------------------------+-----------------------+----------------------+------------------+----------------------+--------------+---------------------+---------+------------------------+--------------------+-------------------+--------------------+----------+------------------+-----------+-----------------------+---------------+";
|
||||
test(query, expected);
|
||||
|
||||
@@ -395,4 +417,31 @@ mod test {
|
||||
+----------------------------------+";
|
||||
test(query, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_set_time_zone() {
|
||||
let query_context = Arc::new(QueryContext::new());
|
||||
let output = check("set time_zone = 'UTC'", query_context.clone());
|
||||
match output.unwrap() {
|
||||
Output::AffectedRows(rows) => {
|
||||
assert_eq!(rows, 0)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
assert_eq!("UTC", query_context.time_zone().unwrap().to_string());
|
||||
|
||||
let output = check("select @@time_zone", query_context);
|
||||
match output.unwrap() {
|
||||
Output::RecordBatches(r) => {
|
||||
let expected = "\
|
||||
+-------------+
|
||||
| @@time_zone |
|
||||
+-------------+
|
||||
| UTC |
|
||||
+-------------+";
|
||||
assert_eq!(r.pretty_print().unwrap(), expected);
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -231,7 +231,7 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
|
||||
log::debug!("execute replaced query: {}", query);
|
||||
|
||||
let outputs = self.do_query(&query).await;
|
||||
writer::write_output(w, &query, outputs).await?;
|
||||
writer::write_output(w, &query, self.session.context(), outputs).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -263,7 +263,7 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
|
||||
]
|
||||
);
|
||||
let outputs = self.do_query(query).await;
|
||||
writer::write_output(writer, query, outputs).await?;
|
||||
writer::write_output(writer, query, self.session.context(), outputs).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ use datatypes::schema::{ColumnSchema, SchemaRef};
|
||||
use opensrv_mysql::{
|
||||
Column, ColumnFlags, ColumnType, ErrorKind, OkResponse, QueryResultWriter, RowWriter,
|
||||
};
|
||||
use session::context::QueryContextRef;
|
||||
use snafu::prelude::*;
|
||||
use tokio::io::AsyncWrite;
|
||||
|
||||
@@ -31,9 +32,10 @@ use crate::error::{self, Error, Result};
|
||||
pub async fn write_output<'a, W: AsyncWrite + Send + Sync + Unpin>(
|
||||
w: QueryResultWriter<'a, W>,
|
||||
query: &str,
|
||||
query_context: QueryContextRef,
|
||||
outputs: Vec<Result<Output>>,
|
||||
) -> Result<()> {
|
||||
let mut writer = Some(MysqlResultWriter::new(w));
|
||||
let mut writer = Some(MysqlResultWriter::new(w, query_context.clone()));
|
||||
for output in outputs {
|
||||
let result_writer = writer.take().context(error::InternalSnafu {
|
||||
err_msg: "Sending multiple result set is unsupported",
|
||||
@@ -54,11 +56,18 @@ struct QueryResult {
|
||||
|
||||
pub struct MysqlResultWriter<'a, W: AsyncWrite + Unpin> {
|
||||
writer: QueryResultWriter<'a, W>,
|
||||
query_context: QueryContextRef,
|
||||
}
|
||||
|
||||
impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> {
|
||||
pub fn new(writer: QueryResultWriter<'a, W>) -> MysqlResultWriter<'a, W> {
|
||||
MysqlResultWriter::<'a, W> { writer }
|
||||
pub fn new(
|
||||
writer: QueryResultWriter<'a, W>,
|
||||
query_context: QueryContextRef,
|
||||
) -> MysqlResultWriter<'a, W> {
|
||||
MysqlResultWriter::<'a, W> {
|
||||
writer,
|
||||
query_context,
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to write one result set. If there are more than one result set, return `Some`.
|
||||
@@ -80,18 +89,23 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> {
|
||||
recordbatches,
|
||||
schema,
|
||||
};
|
||||
Self::write_query_result(query, query_result, self.writer).await?;
|
||||
Self::write_query_result(query, query_result, self.writer, self.query_context)
|
||||
.await?;
|
||||
}
|
||||
Output::RecordBatches(recordbatches) => {
|
||||
let query_result = QueryResult {
|
||||
schema: recordbatches.schema(),
|
||||
recordbatches: recordbatches.take(),
|
||||
};
|
||||
Self::write_query_result(query, query_result, self.writer).await?;
|
||||
Self::write_query_result(query, query_result, self.writer, self.query_context)
|
||||
.await?;
|
||||
}
|
||||
Output::AffectedRows(rows) => {
|
||||
let next_writer = Self::write_affected_rows(self.writer, rows).await?;
|
||||
return Ok(Some(MysqlResultWriter::new(next_writer)));
|
||||
return Ok(Some(MysqlResultWriter::new(
|
||||
next_writer,
|
||||
self.query_context,
|
||||
)));
|
||||
}
|
||||
},
|
||||
Err(error) => Self::write_query_error(query, error, self.writer).await?,
|
||||
@@ -122,6 +136,7 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> {
|
||||
query: &str,
|
||||
query_result: QueryResult,
|
||||
writer: QueryResultWriter<'a, W>,
|
||||
query_context: QueryContextRef,
|
||||
) -> Result<()> {
|
||||
match create_mysql_column_def(&query_result.schema) {
|
||||
Ok(column_def) => {
|
||||
@@ -129,7 +144,8 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> {
|
||||
// to return a new QueryResultWriter.
|
||||
let mut row_writer = writer.start(&column_def).await?;
|
||||
for recordbatch in &query_result.recordbatches {
|
||||
Self::write_recordbatch(&mut row_writer, recordbatch).await?;
|
||||
Self::write_recordbatch(&mut row_writer, recordbatch, query_context.clone())
|
||||
.await?;
|
||||
}
|
||||
row_writer.finish().await?;
|
||||
Ok(())
|
||||
@@ -141,6 +157,7 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> {
|
||||
async fn write_recordbatch(
|
||||
row_writer: &mut RowWriter<'_, W>,
|
||||
recordbatch: &RecordBatch,
|
||||
query_context: QueryContextRef,
|
||||
) -> Result<()> {
|
||||
for row in recordbatch.rows() {
|
||||
for value in row.into_iter() {
|
||||
@@ -161,7 +178,8 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> {
|
||||
Value::Binary(v) => row_writer.write_col(v.deref())?,
|
||||
Value::Date(v) => row_writer.write_col(v.val())?,
|
||||
Value::DateTime(v) => row_writer.write_col(v.val())?,
|
||||
Value::Timestamp(v) => row_writer.write_col(v.to_local_string())?,
|
||||
Value::Timestamp(v) => row_writer
|
||||
.write_col(v.to_timezone_aware_string(query_context.time_zone()))?,
|
||||
Value::List(_) => {
|
||||
return Err(Error::Internal {
|
||||
err_msg: format!(
|
||||
|
||||
@@ -8,3 +8,4 @@ license.workspace = true
|
||||
arc-swap = "1.5"
|
||||
common-catalog = { path = "../common/catalog" }
|
||||
common-telemetry = { path = "../common/telemetry" }
|
||||
common-time = { path = "../common/time" }
|
||||
|
||||
@@ -20,6 +20,7 @@ use arc_swap::ArcSwap;
|
||||
use common_catalog::build_db_string;
|
||||
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
|
||||
use common_telemetry::debug;
|
||||
use common_time::TimeZone;
|
||||
|
||||
pub type QueryContextRef = Arc<QueryContext>;
|
||||
pub type ConnInfoRef = Arc<ConnInfo>;
|
||||
@@ -28,6 +29,7 @@ pub type ConnInfoRef = Arc<ConnInfo>;
|
||||
pub struct QueryContext {
|
||||
current_catalog: ArcSwap<String>,
|
||||
current_schema: ArcSwap<String>,
|
||||
time_zone: ArcSwap<Option<TimeZone>>,
|
||||
}
|
||||
|
||||
impl Default for QueryContext {
|
||||
@@ -56,6 +58,7 @@ impl QueryContext {
|
||||
Self {
|
||||
current_catalog: ArcSwap::new(Arc::new(DEFAULT_CATALOG_NAME.to_string())),
|
||||
current_schema: ArcSwap::new(Arc::new(DEFAULT_SCHEMA_NAME.to_string())),
|
||||
time_zone: ArcSwap::new(Arc::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -63,6 +66,7 @@ impl QueryContext {
|
||||
Self {
|
||||
current_catalog: ArcSwap::new(Arc::new(catalog.to_string())),
|
||||
current_schema: ArcSwap::new(Arc::new(schema.to_string())),
|
||||
time_zone: ArcSwap::new(Arc::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -99,6 +103,16 @@ impl QueryContext {
|
||||
let schema = self.current_schema();
|
||||
build_db_string(&catalog, &schema)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn time_zone(&self) -> Option<TimeZone> {
|
||||
self.time_zone.load().as_ref().clone()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn set_time_zone(&self, tz: Option<TimeZone>) {
|
||||
self.time_zone.swap(Arc::new(tz));
|
||||
}
|
||||
}
|
||||
|
||||
pub const DEFAULT_USERNAME: &str = "greptime";
|
||||
|
||||
Reference in New Issue
Block a user